@@ -1120,12 +1120,130 @@ void testSimplifyDiv() {
11201120
11211121 IS_VAR_WITH_NAME (simplified.node (), " x" );
11221122 }
1123+ }
1124+
1125+ void testSimplifyMod () {
1126+ KernelScope kernel_scope;
1127+ VarHandle x (" x" , kInt );
1128+ VarHandle y (" y" , kInt );
1129+ VarHandle z (" z" , kInt );
1130+
1131+ {
1132+ // Constant folding works.
1133+ ExprHandle body = ExprHandle (10 ) % 8 ;
1134+ ExprHandle simplified = IRSimplifier::simplify (body);
1135+ IS_IMM_WITH_VAL (Int, simplified.node (), 2 );
1136+ }
11231137
11241138 {
1125- ExprHandle body = x / x;
1139+ // x % x => 0
1140+ ExprHandle body = x % x;
11261141 ExprHandle simplified = IRSimplifier::simplify (body);
1142+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1143+ }
11271144
1128- IS_IMM_WITH_VAL (Int, simplified.node (), 1 );
1145+ {
1146+ // 0 % x => 0
1147+ ExprHandle body = ExprHandle (0 ) % x;
1148+ ExprHandle simplified = IRSimplifier::simplify (body);
1149+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1150+ }
1151+
1152+ {
1153+ // x % 1 => 0
1154+ ExprHandle body = x % 1 ;
1155+ ExprHandle simplified = IRSimplifier::simplify (body);
1156+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1157+ }
1158+
1159+ {
1160+ // Doesn't change unknown mods.
1161+ // x % y => x % y
1162+ ExprHandle body = x % y;
1163+ ExprHandle simplified = IRSimplifier::simplify (body);
1164+ IS_NODE_WITH_NAME (Mod, simplified.node (), mod);
1165+ IS_VAR_WITH_NAME (mod->lhs (), " x" );
1166+ IS_VAR_WITH_NAME (mod->rhs (), " y" );
1167+ }
1168+
1169+ {
1170+ // don't touch if RHS is unknown.
1171+ // 4 % x => 4 % x
1172+ ExprHandle body = ExprHandle (4 ) % x;
1173+ ExprHandle simplified = IRSimplifier::simplify (body);
1174+ IS_NODE_WITH_NAME (Mod, simplified.node (), mod);
1175+ IS_IMM_WITH_VAL (Int, mod->lhs (), 4 );
1176+ IS_VAR_WITH_NAME (mod->rhs (), " x" );
1177+ }
1178+
1179+ {
1180+ // don't touch if LHS is unknown.
1181+ // x % 4 => x % 4
1182+ ExprHandle body = x % 4 ;
1183+ ExprHandle simplified = IRSimplifier::simplify (body);
1184+ IS_NODE_WITH_NAME (Mod, simplified.node (), mod);
1185+ IS_VAR_WITH_NAME (mod->lhs (), " x" );
1186+ IS_IMM_WITH_VAL (Int, mod->rhs (), 4 );
1187+ }
1188+
1189+ {
1190+ // if LHS is a multiple of RHS, mod is zero.
1191+ // 2 * x % x => 0
1192+ ExprHandle body = (x * 2 ) % x;
1193+ ExprHandle simplified = IRSimplifier::simplify (body);
1194+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1195+ }
1196+
1197+ {
1198+ // true even if the multiple is not constant.
1199+ // x * y % x => 0
1200+ ExprHandle body = (x * y) % x;
1201+ ExprHandle simplified = IRSimplifier::simplify (body);
1202+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1203+ }
1204+
1205+ {
1206+ // true with multiple unknown values in LHS.
1207+ // x * y * z % x => 0
1208+ ExprHandle body = (x * y * z) % x;
1209+ ExprHandle simplified = IRSimplifier::simplify (body);
1210+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1211+ }
1212+
1213+ {
1214+ // true if the denom is compound.
1215+ // x * y * z % y * z => 0
1216+ ExprHandle body = (x * y * z) % (y * z);
1217+ ExprHandle simplified = IRSimplifier::simplify (body);
1218+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1219+ }
1220+
1221+ {
1222+ // Sanity check true with scalars that are multiples.
1223+ // 12 * x % 4 => 0
1224+ ExprHandle body = (x * 12 ) % 4 ;
1225+ ExprHandle simplified = IRSimplifier::simplify (body);
1226+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1227+ }
1228+
1229+ {
1230+ // Sanity check not true if the smaller scalar is on LHS.
1231+ // 4 * x % 12 => 4 * x % 12
1232+ ExprHandle body = (x * 4 ) % 12 ;
1233+ ExprHandle simplified = IRSimplifier::simplify (body);
1234+ IS_NODE_WITH_NAME (Mod, simplified.node (), mod);
1235+ IS_NODE_WITH_NAME (Mul, mod->lhs (), mul);
1236+ IS_IMM_WITH_VAL (Int, mul->lhs (), 4 );
1237+ IS_VAR_WITH_NAME (mul->rhs (), " x" );
1238+ IS_IMM_WITH_VAL (Int, mod->rhs (), 12 );
1239+ }
1240+
1241+ {
1242+ // Both scalar and symbolic in multiple.
1243+ // (6 * x * y) % (3 * x * y) => 0
1244+ ExprHandle body = (ExprHandle (6 ) * x * y) % (x * y * 3 );
1245+ ExprHandle simplified = IRSimplifier::simplify (body);
1246+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
11291247 }
11301248}
11311249
@@ -2807,6 +2925,189 @@ void testSimplifyEliminateEmptyCond() {
28072925 }
28082926}
28092927
2928+ void testSimplifyConstantComparisons () {
2929+ KernelScope kernel_scope;
2930+
2931+ auto ComparisonTest =
2932+ [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) {
2933+ ExprHandle body = CompareSelect::make (a, b, op);
2934+ ExprHandle simplified = IRSimplifier::simplify (body);
2935+ IS_IMM_WITH_VAL (Int, simplified.node (), result);
2936+ };
2937+
2938+ // Equals.
2939+ ComparisonTest (2 , 2 , kEQ , 1 );
2940+ ComparisonTest (1 , 2 , kEQ , 0 );
2941+ ComparisonTest (2 , 1 , kEQ , 0 );
2942+
2943+ // Greater than.
2944+ ComparisonTest (2 , 2 , kGT , 0 );
2945+ ComparisonTest (1 , 2 , kGT , 0 );
2946+ ComparisonTest (2 , 1 , kGT , 1 );
2947+
2948+ // Greater or Equal.
2949+ ComparisonTest (2 , 2 , kGE , 1 );
2950+ ComparisonTest (1 , 2 , kGE , 0 );
2951+ ComparisonTest (2 , 1 , kGE , 1 );
2952+
2953+ // Less Than.
2954+ ComparisonTest (2 , 2 , kLT , 0 );
2955+ ComparisonTest (1 , 2 , kLT , 1 );
2956+ ComparisonTest (2 , 1 , kLT , 0 );
2957+
2958+ // Less or Equal.
2959+ ComparisonTest (2 , 2 , kLE , 1 );
2960+ ComparisonTest (1 , 2 , kLE , 1 );
2961+ ComparisonTest (2 , 1 , kLE , 0 );
2962+
2963+ // Not equal.
2964+ ComparisonTest (2 , 2 , kNE , 0 );
2965+ ComparisonTest (1 , 2 , kNE , 1 );
2966+ ComparisonTest (2 , 1 , kNE , 1 );
2967+
2968+ // With specified results:
2969+ ExprHandle body = CompareSelect::make (2 , 2 , 5 , 42 , kNE );
2970+ ExprHandle simplified = IRSimplifier::simplify (body);
2971+ IS_IMM_WITH_VAL (Int, simplified.node (), 42 );
2972+ }
2973+
2974+ void testSimplifySymbolicComparisons () {
2975+ KernelScope kernel_scope;
2976+ VarHandle x (" x" , kInt );
2977+ VarHandle y (" y" , kInt );
2978+
2979+ auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL (Int, a.node (), 1 ); };
2980+ auto TookFalseBranch = [](ExprHandle a) {
2981+ IS_IMM_WITH_VAL (Int, a.node (), 0 );
2982+ };
2983+
2984+ // EQ
2985+
2986+ // x == x => 1
2987+ ExprHandle body = CompareSelect::make (x, x, kEQ );
2988+ TookTrueBranch (IRSimplifier::simplify (body));
2989+
2990+ // x == x+1 => 0
2991+ body = CompareSelect::make (x, x + 1 , kEQ );
2992+ TookFalseBranch (IRSimplifier::simplify (body));
2993+
2994+ // x == x * 2 cannot simplify since we don't know x is nonzero.
2995+ body = CompareSelect::make (x, x * 2 , kEQ );
2996+ IS_NODE (CompareSelect, IRSimplifier::simplify (body).node ());
2997+
2998+ // x == x * 1 => 1
2999+ body = CompareSelect::make (x, x * 1 , kEQ );
3000+ TookTrueBranch (IRSimplifier::simplify (body));
3001+
3002+ {
3003+ // x == y => x == y
3004+ body = CompareSelect::make (x, y, kEQ );
3005+ ExprHandle simplified = IRSimplifier::simplify (body);
3006+ IS_NODE_WITH_NAME (CompareSelect, simplified.node (), cmp);
3007+ ASSERT_EQ (cmp->compare_select_op (), kEQ );
3008+ IS_VAR_WITH_NAME (cmp->lhs (), " x" );
3009+ IS_VAR_WITH_NAME (cmp->rhs (), " y" );
3010+ }
3011+
3012+ {
3013+ // x == 5 => x == 5
3014+ body = CompareSelect::make (x, 5 , kEQ );
3015+ ExprHandle simplified = IRSimplifier::simplify (body);
3016+ IS_NODE_WITH_NAME (CompareSelect, simplified.node (), cmp);
3017+ ASSERT_EQ (cmp->compare_select_op (), kEQ );
3018+ IS_VAR_WITH_NAME (cmp->lhs (), " x" );
3019+ IS_IMM_WITH_VAL (Int, cmp->rhs (), 5 );
3020+ }
3021+
3022+ // GT
3023+
3024+ // x+1 > x => 1
3025+ body = CompareSelect::make (x + 1 , x, kGT );
3026+ TookTrueBranch (IRSimplifier::simplify (body));
3027+
3028+ // x > x + 1 => 0
3029+ body = CompareSelect::make (x, x + 1 , kGT );
3030+ TookFalseBranch (IRSimplifier::simplify (body));
3031+
3032+ // x > x - 1 => 1
3033+ body = CompareSelect::make (x, x - 1 , kGT );
3034+ TookTrueBranch (IRSimplifier::simplify (body));
3035+
3036+ // x - 1 > x => 0
3037+ body = CompareSelect::make (x - 1 , x, kGT );
3038+ TookFalseBranch (IRSimplifier::simplify (body));
3039+
3040+ // x > x => 0
3041+ body = CompareSelect::make (x, x, kGT );
3042+ TookFalseBranch (IRSimplifier::simplify (body));
3043+
3044+ // x * 2 > x => x * 2 > x
3045+ // since we don't know the sign of x.
3046+ body = CompareSelect::make (x * 2 , x, kGT );
3047+ IS_NODE (CompareSelect, IRSimplifier::simplify (body).node ());
3048+
3049+ // GE
3050+
3051+ // x+1 >= x => 1
3052+ body = CompareSelect::make (x + 1 , x, kGE );
3053+ TookTrueBranch (IRSimplifier::simplify (body));
3054+
3055+ // x >= x + 1 => 0
3056+ body = CompareSelect::make (x, x + 1 , kGE );
3057+ TookFalseBranch (IRSimplifier::simplify (body));
3058+
3059+ // x >= x => 1
3060+ body = CompareSelect::make (x, x, kGE );
3061+ TookTrueBranch (IRSimplifier::simplify (body));
3062+
3063+ // x * 2 >= x => x * 2 >= x
3064+ // since we don't know the sign of x.
3065+ body = CompareSelect::make (x * 2 , x, kGE );
3066+ IS_NODE (CompareSelect, IRSimplifier::simplify (body).node ());
3067+
3068+ // LT
3069+
3070+ // x+1 < x => 0
3071+ body = CompareSelect::make (x + 1 , x, kLT );
3072+ TookFalseBranch (IRSimplifier::simplify (body));
3073+
3074+ // x < x + 1 => 1
3075+ body = CompareSelect::make (x, x + 1 , kLT );
3076+ TookTrueBranch (IRSimplifier::simplify (body));
3077+
3078+ // x < x => 0
3079+ body = CompareSelect::make (x, x, kLT );
3080+ TookFalseBranch (IRSimplifier::simplify (body));
3081+
3082+ // LE
3083+
3084+ // x+1 <= x => 0
3085+ body = CompareSelect::make (x + 1 , x, kLE );
3086+ TookFalseBranch (IRSimplifier::simplify (body));
3087+
3088+ // x <= x + 1 => 1
3089+ body = CompareSelect::make (x, x + 1 , kLE );
3090+ TookTrueBranch (IRSimplifier::simplify (body));
3091+
3092+ // x <= x => 1
3093+ body = CompareSelect::make (x, x, kLE );
3094+ TookTrueBranch (IRSimplifier::simplify (body));
3095+
3096+ // NE
3097+
3098+ // x+1 != x => 1
3099+ body = CompareSelect::make (x + 1 , x, kNE );
3100+ TookTrueBranch (IRSimplifier::simplify (body));
3101+
3102+ // x != x + 1 => 1
3103+ body = CompareSelect::make (x, x + 1 , kNE );
3104+ TookTrueBranch (IRSimplifier::simplify (body));
3105+
3106+ // x != x => 0
3107+ body = CompareSelect::make (x, x, kNE );
3108+ TookFalseBranch (IRSimplifier::simplify (body));
3109+ }
3110+
28103111void testSimplifyEliminateZeroLengthFor () {
28113112 KernelScope kernel_scope;
28123113
0 commit comments