
Commit a24f6c1

Chillee authored and facebook-github-bot committed
Fix broken indexing when using None and ellipses indexing together (#22905)
Summary: Fixes #20153. I believe you need two passes for this. Take this example:

```python
@torch.jit.script
def f():
    x = torch.ones(10, 9, 8, 7, 6)
    return x[..., None, None].shape
```

which results in `[10, 9, 8, 7, 6, 1, 1]`, versus

```python
@torch.jit.script
def f():
    x = torch.ones(10, 9, 8, 7, 6)
    return x[..., None, None, :].shape
```

which results in `[10, 9, 8, 7, 1, 1, 6]`.

After processing only `x[..., None, None`, we don't know whether we should be creating a new dimension at the end of the dimension list or somewhere in the middle; what we do depends on the elements to the right of it. Thus, I do two passes: one to collect the dimensions that the index operations operate on, and another that executes the index operations. This still doesn't work for an ellipsis index followed by a tensor index, but it wasn't working previously either.

Pull Request resolved: #22905

Differential Revision: D16433558

Pulled By: Chillee

fbshipit-source-id: c1b303cb97b1af8b6e405bad33495ef3b4c27c4a
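For reference, these are the shapes eager-mode PyTorch produces for the two examples above; the fix makes TorchScript agree with them (this snippet is illustrative, not part of the commit message):

```python
import torch

x = torch.ones(10, 9, 8, 7, 6)
print(x[..., None, None].shape)     # torch.Size([10, 9, 8, 7, 6, 1, 1])
print(x[..., None, None, :].shape)  # torch.Size([10, 9, 8, 7, 1, 1, 6])
```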
1 parent 648f10b commit a24f6c1

File tree

2 files changed: +120 −53 lines changed


test/test_jit.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -3761,6 +3761,25 @@ def func(x, value1, value2):
         check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
         check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
 
+    def test_index_ellipses(self):
+        vals = [":", 1, None]
+        for _ in range(100):
+            indices = [random.choice(vals) for _ in range(4)]
+            indices[random.randint(0, len(indices) - 1)] = "..."
+            test_str = dedent("""
+            def f():
+                x = torch.ones(10, 9, 8, 7, 6)
+                return x{indices}.shape
+            """.format(indices=indices))
+            test_str = test_str.replace(r"'", r'')
+            scope = {}
+            execWrapper(test_str, globals(), scope)
+            cu = torch.jit.CompilationUnit(test_str)
+            res1 = cu.f()
+            res2 = scope['f']()
+            self.assertEqual(res1, res2)
+
+
     def test_tensor_item(self):
         def test_scalar_cast(x):
             scalar = x.item()
```
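To illustrate what the test generates (this snippet is not part of the diff), one random draw might produce `indices = ['...', 1, None, ':']`; `str(indices)` with the quotes stripped then yields the subscript below, and the test checks that the TorchScript-compiled `f` and the plain-Python `f` return the same shape:

```python
# Hypothetical generated program for indices = ['...', 1, None, ':']
def f():
    x = torch.ones(10, 9, 8, 7, 6)
    return x[..., 1, None, :].shape
```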

torch/csrc/jit/script/compiler.cpp

Lines changed: 101 additions & 53 deletions
```diff
@@ -2617,13 +2617,13 @@ struct to_ir {
         loc, *graph, aten::slice, c10::nullopt, args, {step_nv}, true);
   }
 
-  Value* emitUnsqueeze(const SourceRange& loc, Value* input, int64_t dim) {
+  Value* emitUnsqueeze(const SourceRange& loc, Value* input, Value* dim_val) {
     return emitBuiltinCall(
         loc,
         *graph,
         aten::unsqueeze,
         c10::nullopt,
-        {input, graph->insertConstant(dim, nullptr, loc)},
+        {input, dim_val},
         {},
         true);
   }
```
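`emitUnsqueeze` now takes the dimension as a `Value*` rather than an `int64_t`, presumably so callers can pass constants built by `insert_value_for_dim`, including the negative dimensions produced by the right-to-left pass below. `unsqueeze` itself accepts negative dims, as this illustrative snippet (not part of the commit) shows:

```python
import torch

x = torch.ones(10, 9, 8, 7, 6)
# negative dims count from the end, as produced by the reverse pass
print(torch.unsqueeze(x, -2).shape)  # torch.Size([10, 9, 8, 7, 1, 6])
print(torch.unsqueeze(x, 5).shape)   # torch.Size([10, 9, 8, 7, 6, 1])
```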
```diff
@@ -2653,71 +2653,119 @@
       const SourceRange& loc,
       Value* sliceable,
       const List<Expr>& subscript_exprs) {
-    std::vector<Value*> tensor_indices;
-    size_t dim = 0;
+    // Overall, to handle indexing (other than Tensors), we need to handle a
+    // couple different things. For example, for x[1:3, None, 4], each of these
+    // different index types (slice, None, and integer) result in different
+    // number of dimensions. Slicing doesn't change the number of dimensions,
+    // None adds a dimension, and integer removes a dimension. As these
+    // indexing operations are applied left to right, the actual index that
+    // it's being applied to depends on the previous operations.
+    // Ellipses indexing throws another wrinkle. Ellipses selects any remaining
+    // unspecified dimensions. Thus, for indexes following an ellipses, the
+    // actual index an indexing operation is being applied to depends on the
+    // operations to the right.
+    // Thus, we do two passes, one from left to right up until the ellipses,
+    // and one from right to left.
 
-    auto handle_tensor = [&](Value* tensor) {
-      // NB: tensor_indices can have None holes because of how at::index works.
-      tensor_indices.resize(dim + 1);
-      tensor_indices[dim] = tensor;
-      dim++;
-    };
+    std::vector<Value*> tensor_indices;
 
-    // before ellipsis, dimension index should be `dim`
-    // after ellipsis, dimension index should be `-offset`
-    int offset = 0;
-    size_t ellipsis_dim = 0;
     auto insert_value_for_dim = [&](int64_t dim) {
-      return (offset == 0)
-          ? graph->insertConstant(dim, nullptr, loc)
-          :
-          // NB: offset is incremented to move to the next dimension index
-          graph->insertConstant(offset++, nullptr, loc);
+      return graph->insertConstant(dim, nullptr, loc);
     };
-
-    for (const auto& subscript_expr : subscript_exprs) {
-      // NB: ellipsis_dim is **always** incremented
-      // (comparing to dim) in order to compute
-      // the correct offsets for the remaining
-      // dimension indices following an ellipsis "..."
-      // token
-      ellipsis_dim++;
-      if (subscript_expr.kind() == TK_DOTS) {
-        offset = -(subscript_exprs.size() - ellipsis_dim);
-        ++dim;
-        continue;
-      }
+    std::vector<int64_t> dims(subscript_exprs.size());
+    std::vector<c10::optional<Value*>> exprs(
+        subscript_exprs.size(), c10::nullopt);
+
+    auto handle_indexing = [&](const Expr& subscript_expr,
+                               int expr_idx,
+                               int64_t dim,
+                               bool is_reverse = false) {
+      dims[expr_idx] = dim;
       if (subscript_expr.kind() == TK_SLICE_EXPR) {
-        auto dim_val = insert_value_for_dim(dim);
-        sliceable =
-            emitSlice(loc, sliceable, dim_val, SliceExpr(subscript_expr));
-        ++dim;
-        continue;
+        if (is_reverse) {
+          return dim - 1;
+        } else {
+          return dim + 1;
+        }
       }
       TypePtr type_hint = OptionalType::ofTensor();
       if (subscript_expr.kind() == TK_NONE) {
         type_hint = NoneType::get();
       }
       auto index = emitExpr(subscript_expr, type_hint);
-      if (index->type() == IntType::get()) {
-        // NB: note, select squeezes out a dimension,
-        // so dim is **not** incremented
-        auto dim_val = insert_value_for_dim(dim);
-        sliceable = emitSelect(loc, sliceable, dim_val, index);
-        continue;
-      } else if (index->type()->isSubtypeOf(NoneType::get())) {
-        sliceable = emitUnsqueeze(loc, sliceable, dim);
-        dim++;
-        continue;
+      exprs[expr_idx] = index;
+      if (index->type()->isSubtypeOf(NoneType::get())) {
+        if (is_reverse) {
+          return dim;
+        } else {
+          return dim + 1;
+        }
+      } else if (index->type() == IntType::get()) {
+        if (is_reverse) {
+          return dim - 1;
+        } else {
+          return dim;
+        }
       } else if (index->type()->isSubtypeOf(OptionalType::ofTensor())) {
-        // NB:index type can either be a Tensor or : (None of Optional Tensor)
-        handle_tensor(index);
+        if (is_reverse) {
+          throw ErrorReport(loc)
+              << "Ellipses followed by tensor indexing is currently not supported";
+        } else {
+          return dim + 1;
+        }
+      } else {
+        throw ErrorReport(loc)
+            << "Unsupported operation: indexing tensor with unsupported index type '"
+            << index->type()->python_str()
+            << "'. Only ints, slices, and tensors are supported";
+      }
+    };
+
+    size_t idx = 0;
+    int64_t dim = 0;
+    for (; idx < subscript_exprs.size(); idx++) {
+      auto subscript_expr = subscript_exprs[idx];
+      if (subscript_expr.kind() == TK_DOTS) {
+        break;
+      }
+      dim = handle_indexing(subscript_expr, idx, dim, /*is_reverse=*/false);
+    }
+    int64_t rdim = -1;
+    for (size_t rev_idx = subscript_exprs.size() - 1; rev_idx > idx;
+         rev_idx--) {
+      auto subscript_expr = subscript_exprs[rev_idx];
+      if (subscript_expr.kind() == TK_DOTS) {
+        throw ErrorReport(loc)
+            << "An index can only have a single ellipsis ('...')";
+      }
+      rdim =
+          handle_indexing(subscript_expr, rev_idx, rdim, /*is_reverse=*/true);
+    }
+    for (size_t i = 0; i < exprs.size(); i++) {
+      if (!exprs[i].has_value()) {
+        if (subscript_exprs[i].kind() == TK_SLICE_EXPR) {
+          sliceable = emitSlice(
+              loc,
+              sliceable,
+              insert_value_for_dim(dims[i]),
+              SliceExpr(subscript_exprs[i]));
+        }
         continue;
       }
-      throw ErrorReport(loc)
-          << "Unsupported operation: indexing tensor with unsupported index type '"
-          << index->type()->python_str()
-          << "'. Only ints, slices, and tensors are supported";
+      auto expr = exprs[i].value();
+      if (expr->type()->isSubtypeOf(NoneType::get())) {
+        sliceable =
+            emitUnsqueeze(loc, sliceable, insert_value_for_dim(dims[i]));
+      } else if (expr->type() == IntType::get()) {
+        sliceable =
+            emitSelect(loc, sliceable, insert_value_for_dim(dims[i]), expr);
+      } else if (expr->type()->isSubtypeOf(OptionalType::ofTensor())) {
+        tensor_indices.resize(dims[i] + 1);
+        tensor_indices[dims[i]] = expr;
+      } else {
+        TORCH_INTERNAL_ASSERT(
+            "Trying to process index type that we don't support.");
+      }
     }
     // at::index takes in a List[Optional[Tensor]] where some dims can be None.
     // create None node with optional tensor output type and pass to at::index.
```
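To make the two-pass bookkeeping concrete, here is a minimal Python sketch of the dimension assignment that `handle_indexing` performs (an illustration, not code from the commit; `assign_dims` and its index-kind encoding are hypothetical, and the unsupported ellipsis-then-tensor case is not modeled):

```python
def assign_dims(indices):
    """Assign each index expression the dimension it will operate on.

    Index kinds: "slice", "int", None (inserts a new dim), "..." (ellipsis).
    """
    dims = [None] * len(indices)

    def step(kind, dim, is_reverse):
        if kind == "slice":
            # a slice consumes one input dim and produces one output dim
            return dim - 1 if is_reverse else dim + 1
        if kind is None:
            # None inserts a dim: going forward we advance past it; going
            # in reverse it consumes no input dim, so dim is unchanged
            return dim if is_reverse else dim + 1
        if kind == "int":
            # an integer index removes a dim
            return dim - 1 if is_reverse else dim
        raise ValueError(kind)

    # pass 1: left to right, counting dims up from 0, stopping at the ellipsis
    i, dim = 0, 0
    while i < len(indices) and indices[i] != "...":
        dims[i] = dim
        dim = step(indices[i], dim, is_reverse=False)
        i += 1

    # pass 2: right to left, counting negative dims back toward the ellipsis
    rdim = -1
    for j in range(len(indices) - 1, i, -1):
        dims[j] = rdim
        rdim = step(indices[j], rdim, is_reverse=True)
    return dims

# x[None, 4, ..., 1:3] -> unsqueeze at dim 0, select at dim 1, slice at dim -1
print(assign_dims([None, "int", "...", "slice"]))  # [0, 1, None, -1]
```

Indices before the ellipsis get nonnegative dims and indices after it get negative dims, which is why `emitSelect`, `emitSlice`, and `emitUnsqueeze` receive their dimension as a constant that may be negative.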
