Skip to content

Commit 0021165

Browse files
authored
Make random faster by putting the innermost var last (#6504)
* Make random 2x faster by putting the innermost var last.
* Improve period of low bits of random noise.
* Add new rewrite rules for quadratics. By pulling constant additions outside of quadratics, we can shave off a few add instructions in the inner loop for random number generation, which uses a quadratic modulo 2^32. I also removed the !overflows predicates, because rules already fail to match if a fold overflows. New rules formally verified.
* Make expensive_zero actually always zero.
1 parent f11d820 commit 0021165

File tree

4 files changed

+19
-9
lines changed

4 files changed

+19
-9
lines changed

src/Random.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ Expr random_int(const vector<Expr> &e) {
8282
rng32(Variable::make(UInt(32), name)));
8383
}
8484
}
85+
// The low bytes of this have a poor period, so mix in the high bytes for
86+
// two additional instructions.
87+
result = result ^ (result >> 16);
88+
8589
return result;
8690
}
8791

@@ -101,7 +105,9 @@ class LowerRandom : public IRMutator {
101105
Expr visit(const Call *op) override {
102106
if (op->is_intrinsic(Call::random)) {
103107
vector<Expr> args = op->args;
104-
args.insert(args.end(), extra_args.begin(), extra_args.end());
108+
// Insert the free vars in reverse, so innermost vars typically end
109+
// up last.
110+
args.insert(args.end(), extra_args.rbegin(), extra_args.rend());
105111
if (op->type == Float(32)) {
106112
return random_float(args);
107113
} else if (op->type == Int(32)) {
@@ -121,14 +127,14 @@ class LowerRandom : public IRMutator {
121127

122128
public:
123129
LowerRandom(const vector<VarOrRVar> &free_vars, int tag) {
124-
extra_args.emplace_back(tag);
125130
for (const VarOrRVar &v : free_vars) {
126131
if (v.is_rvar) {
127132
extra_args.push_back(v.rvar);
128133
} else {
129134
extra_args.push_back(v.var);
130135
}
131136
}
137+
extra_args.emplace_back(tag);
132138
}
133139
};
134140

src/Simplify_Mul.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,16 @@ Expr Simplify::visit(const Mul *op, ExprInfo *bounds) {
6969
}
7070

7171
if (rewrite(c0 * c1, fold(c0 * c1)) ||
72-
rewrite((x + c0) * c1, x * c1 + fold(c0 * c1), !overflows(c0 * c1)) ||
73-
rewrite((c0 - x) * c1, x * fold(-c1) + fold(c0 * c1), !overflows(c0 * c1)) ||
72+
rewrite((x + c0) * (x + c1), x * (x + fold(c0 + c1)) + fold(c0 * c1)) ||
73+
rewrite((x * c0 + c1) * (x + c2), x * (x * c0 + fold(c1 + c0 * c2)) + fold(c1 * c2)) ||
74+
rewrite((x + c2) * (x * c0 + c1), x * (x * c0 + fold(c1 + c0 * c2)) + fold(c1 * c2)) ||
75+
rewrite((x * c0 + c1) * (x * c2 + c3), x * (x * fold(c0 * c2) + fold(c0 * c3 + c1 * c2)) + fold(c1 * c3)) ||
76+
rewrite((x + c0) * c1, x * c1 + fold(c0 * c1)) ||
77+
rewrite((c0 - x) * c1, x * fold(-c1) + fold(c0 * c1)) ||
7478
rewrite((0 - x) * y, 0 - x * y) ||
7579
rewrite(x * (0 - y), 0 - x * y) ||
7680
rewrite((x - y) * c0, (y - x) * fold(-c0), c0 < 0 && -c0 > 0) ||
77-
rewrite((x * c0) * c1, x * fold(c0 * c1), !overflows(c0 * c1)) ||
81+
rewrite((x * c0) * c1, x * fold(c0 * c1)) ||
7882
rewrite((x * c0) * y, (x * y) * c0, !is_const(y)) ||
7983
rewrite(x * (y * c0), (x * y) * c0) ||
8084
rewrite(max(x, y) * min(x, y), x * y) ||

test/correctness/async_device_copy.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ Expr expensive_zero(Expr x, Expr y, Expr t, int n) {
99
RDom r(0, n);
1010
Func a, b, c;
1111
Var z;
12-
a(x, y, t, z) = random_int() % 1024;
13-
b(x, y, t, z) = random_int() % 1024;
14-
c(x, y, t, z) = random_int() % 1024;
12+
a(x, y, t, z) = random_int() % 1024 + 5;
13+
b(x, y, t, z) = random_int() % 1024 + 5;
14+
c(x, y, t, z) = random_int() % 1024 + 5;
1515
return sum(select(pow(a(x, y, t, r), 3) + pow(b(x, y, t, r), 3) == pow(c(x, y, t, r), 3), 1, 0));
1616
}
1717

test/correctness/simplify.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ void check_casts() {
130130

131131
check(cast(Float(64), 0.5f), Expr(0.5));
132132
check((x - cast(Float(64), 0.5f)) * (x - cast(Float(64), 0.5f)),
133-
(x + Expr(-0.5)) * (x + Expr(-0.5)));
133+
(cast(Float(64), x) + Expr(-1.0)) * cast(Float(64), x) + Expr(0.25));
134134

135135
check(cast(Int(64, 3), ramp(5.5f, 2.0f, 3)),
136136
cast(Int(64, 3), ramp(5.5f, 2.0f, 3)));

0 commit comments

Comments
 (0)