-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Description
Hello,
I encountered an issue where the loops generated from a reduction domain were inefficient. Halide correctly detected the stop condition for the loop over the RDom's x variable and limited iteration in that dimension. For the RDom's y variable, however, it not only failed to recognize the condition as a loop bound, but also placed an if branch inside the inner loop.
I encountered this issue in v14.0.0 and also on the release/15.x branch.
I was able to reduce the code to the following reproduction case. It could probably be reduced further, since at some point I just made things more complicated to try to trigger the issue; I think the key ingredient is the complicated totalEntries expression:
#include "Halide.h"
using namespace Halide;
/// Reproduction generator for the RDom-predicate loop-bound issue.
///
/// `internal` is updated through a 2-D RDom (`internalRDOM`) whose iteration
/// space is restricted with `where()` predicates that depend on the
/// data-dependent `totalEntries` expression. `result(x, y)` then sums the
/// first `totalEntries` slots (at most MAX_ENTRIES) of `internal(x, y, _)`.
class BugGenerator : public Generator<BugGenerator>
{
public:
    // Extent of both internalRDOM dimensions and of sumRDOM.
    static constexpr int MAX_ENTRIES = 512;

    Input<Buffer<std::uint32_t, 2>> input {"input"};
    Input<Buffer<bool, 3>> valid {"valid"};
    Output<Buffer<std::uint32_t, 2>> result {"result"};

    Func internal {"internal"};
    Var x, y, index;

    void generate()
    {
        const Expr entry1 = input(x / 2, y / 2);
        const Expr entry2 = input(x / 2, y / 2);

        const auto dim0 = input.dim(0);
        const auto dim1 = input.dim(1);
        // clamp() is inclusive on both ends, so the upper bound must be the
        // last valid coordinate (min + extent - 1), i.e. dim.max().
        // Clamping to min + extent would permit a read one element past the
        // end of `input` in each dimension.
        const Expr entry3 = input(
            clamp(cast<int>(entry1), dim0.min(), dim0.max()),
            clamp(cast<int>(entry2), dim1.min(), dim1.max()));

        // Data-dependent number of entries for (x, y). It is a uint32
        // expression; where it is compared with or combined with the int32
        // RDom variables below, the lowered stmt shows it converted to int32.
        const Expr totalEntries = 1 + 2 * entry1 + entry3;

        Func entries;
        entries(x, y, index) = Tuple(cast<int>(input(x, y) % 10), y, index);

        RDom rdom(0, MAX_ENTRIES, 0, MAX_ENTRIES, "internalRDOM");
        rdom.where(rdom.y < totalEntries);
        rdom.where(rdom.x < totalEntries);

        // Only visit entries whose `valid` flag is set.
        const Tuple entry = entries(x, y, rdom.y);
        rdom.where(valid(entry.as_vector()));

        internal(x, y, index) = cast<std::uint32_t>(0);

        // NOTE(review): this clamps to [0, MAX_ENTRIES] inclusive, i.e.
        // MAX_ENTRIES + 1 slots, while sumRDOM below only reads slots
        // [0, MAX_ENTRIES - 1]. Confirm whether MAX_ENTRIES - 1 was intended.
        const auto currentIndex = clamp(totalEntries - rdom.x - 1, 0, MAX_ENTRIES);
        const Expr currentValue = internal(x, y, currentIndex);
        rdom.where(currentValue < 1024);
        internal(x, y, currentIndex) = select(currentValue < 512,
                                              currentValue + cast<std::uint32_t>(rdom.y),
                                              currentValue / Expr(2u) + Expr(2u) * cast<std::uint32_t>(rdom.y));

        // Sum the first totalEntries slots (capped by the RDom extent).
        RDom sumDom(0, MAX_ENTRIES, "sumRDOM");
        sumDom.where(sumDom.x < totalEntries);
        result(x, y) = sum(sumDom, internal(x, y, sumDom.x));
    }

    void schedule()
    {
        internal.compute_root().parallel(y);
        internal.update(0).parallel(y);
        result.parallel(y);
    }
};
HALIDE_REGISTER_GENERATOR(BugGenerator, bug_generator);

Here is the interesting fragment of the lowered stmt (I added comments to mark the problem points):
// Unbounded loop: iterates over all 512 values of rdom.y instead of stopping at totalEntries
for (internal.s1.internalRDOM$y, 0, 512) {
let internal.s1.internalRDOM$x.new_max.s = let t414 = input[t387] in (let t415 = int32(t414) in (int32(((input[(max(min(t415, t373), input.min.1)*input.stride.1) + (max(min(t415, t372), input.min.0) - t385)] + (t414*(uint32)2)) + (uint32)1)) + -1))
let t398 = uint32(internal.s1.internalRDOM$y)
let t394 = max(min(internal.s1.internalRDOM$x.new_max.s, 511), -1)
let t395 = internal.s1.internalRDOM$y*valid.stride.2
for (internal.s1.internalRDOM$x, 0, t394 + 1) {
// Both RDom conditions are checked inside the inner loop (for x this is, as far as I understand, a second, redundant check)
if (let t416 = max(internal.s1.internalRDOM$x, internal.s1.internalRDOM$y) in (let t417 = input[t388] in (let t418 = int32(t417) in (let t419 = input[t387] in (let t420 = int32(t419) in (((t416 < int32(((input[(max(min(t418, t373), input.min.1)*input.stride.1) + (max(min(t418, t372), input.min.0) - t385)] + (t417*(uint32)2)) + (uint32)1))) && uint1(valid[(t389 + int32((input[t390] % (uint32)10))) + t395])) && (t416 < int32(((input[(max(min(t420, t373), input.min.1)*input.stride.1) + (max(min(t420, t372), input.min.0) - t385)] + (t419*(uint32)2)) + (uint32)1))))))))) {
if (let t421 = input[t387] in (let t422 = int32(t421) in (max(internal.s1.internalRDOM$x, internal.s1.internalRDOM$y) < int32(((input[(max(min(t422, t373), input.min.1)*input.stride.1) + (max(min(t422, t372), input.min.0) - t385)] + (t421*(uint32)2)) + (uint32)1))))) {
if (let t423 = input[t387] in (let t424 = int32(t423) in (internal[(max(min(int32(((input[(max(min(t424, t373), input.min.1)*input.stride.1) + (max(min(t424, t372), input.min.0) - t385)] + (t423*(uint32)2)) + (uint32)1)) - internal.s1.internalRDOM$x, 513), 1)*t383) + t391] < (uint32)1024))) {
let t332 = input[t387]
let t333 = int32(t332)
let t334.s = int32(((input[(max(min(t333, t373), input.min.1)*input.stride.1) + (max(min(t333, t372), input.min.0) - t385)] + (t332*(uint32)2)) + (uint32)1))
let t337 = internal[(max(min(t334.s - internal.s1.internalRDOM$x, 513), 1)*t383) + t391]
internal[((max(min(t334.s - internal.s1.internalRDOM$x, 513), 1) + -1)*t383) + t392] = select(t337 < (uint32)512, t337 + t398, (t337/(uint32)2) + (t398*(uint32)2))
}
}
}
}
}

In my original case there is a slight difference: my bounds check does not need any memory access, only arithmetic, yet it is emitted as the second condition, while the first condition does a memory access.
I'm not sure whether it is possible to force different statement generation, or whether this is a bug or a limitation.
Best regards,
Adrian