@@ -2617,13 +2617,13 @@ struct to_ir {
         loc, *graph, aten::slice, c10::nullopt, args, {step_nv}, true);
   }
 
-  Value* emitUnsqueeze(const SourceRange& loc, Value* input, int64_t dim) {
+  Value* emitUnsqueeze(const SourceRange& loc, Value* input, Value* dim_val) {
     return emitBuiltinCall(
         loc,
         *graph,
         aten::unsqueeze,
         c10::nullopt,
-        {input, graph->insertConstant(dim, nullptr, loc)},
+        {input, dim_val},
         {},
         true);
   }
@@ -2653,71 +2653,119 @@ struct to_ir {
       const SourceRange& loc,
       Value* sliceable,
       const List<Expr>& subscript_exprs) {
-    std::vector<Value*> tensor_indices;
-    size_t dim = 0;
+    // Overall, to handle indexing (other than with Tensors), we need to handle
+    // a couple of different things. For example, for x[1:3, None, 4], each of
+    // these index types (slice, None, and integer) affects the number of
+    // dimensions differently: slicing doesn't change the number of dimensions,
+    // None adds a dimension, and an integer removes a dimension. Since these
+    // indexing operations are applied left to right, the actual dimension each
+    // one applies to depends on the operations before it.
+    // Ellipsis indexing adds another wrinkle: an ellipsis selects all
+    // remaining unspecified dimensions. Thus, for indices following an
+    // ellipsis, the actual dimension an indexing operation applies to depends
+    // on the operations to its right.
+    // Thus, we do two passes: one from left to right up to the ellipsis, and
+    // one from right to left after it.
 
-    auto handle_tensor = [&](Value* tensor) {
-      // NB: tensor_indices can have None holes because of how at::index works.
-      tensor_indices.resize(dim + 1);
-      tensor_indices[dim] = tensor;
-      dim++;
-    };
+    std::vector<Value*> tensor_indices;
 
-    // before ellipsis, dimension index should be `dim`
-    // after ellipsis, dimension index should be `-offset`
-    int offset = 0;
-    size_t ellipsis_dim = 0;
     auto insert_value_for_dim = [&](int64_t dim) {
-      return (offset == 0)
-          ? graph->insertConstant(dim, nullptr, loc)
-          :
-          // NB: offset is incremented to move to the next dimension index
-          graph->insertConstant(offset++, nullptr, loc);
+      return graph->insertConstant(dim, nullptr, loc);
     };
-
-    for (const auto& subscript_expr : subscript_exprs) {
-      // NB: ellipsis_dim is **always** incremented
-      // (comparing to dim) in order to compute
-      // the correct offsets for the remaining
-      // dimension indices following an ellipsis "..."
-      // token
-      ellipsis_dim++;
-      if (subscript_expr.kind() == TK_DOTS) {
-        offset = -(subscript_exprs.size() - ellipsis_dim);
-        ++dim;
-        continue;
-      }
+    std::vector<int64_t> dims(subscript_exprs.size());
+    std::vector<c10::optional<Value*>> exprs(
+        subscript_exprs.size(), c10::nullopt);
+
+    auto handle_indexing = [&](const Expr& subscript_expr,
+                               int expr_idx,
+                               int64_t dim,
+                               bool is_reverse = false) {
+      dims[expr_idx] = dim;
       if (subscript_expr.kind() == TK_SLICE_EXPR) {
-        auto dim_val = insert_value_for_dim(dim);
-        sliceable =
-            emitSlice(loc, sliceable, dim_val, SliceExpr(subscript_expr));
-        ++dim;
-        continue;
+        if (is_reverse) {
+          return dim - 1;
+        } else {
+          return dim + 1;
+        }
       }
       TypePtr type_hint = OptionalType::ofTensor();
       if (subscript_expr.kind() == TK_NONE) {
         type_hint = NoneType::get();
       }
       auto index = emitExpr(subscript_expr, type_hint);
-      if (index->type() == IntType::get()) {
-        // NB: note, select squeezes out a dimension,
-        // so dim is **not** incremented
-        auto dim_val = insert_value_for_dim(dim);
-        sliceable = emitSelect(loc, sliceable, dim_val, index);
-        continue;
-      } else if (index->type()->isSubtypeOf(NoneType::get())) {
-        sliceable = emitUnsqueeze(loc, sliceable, dim);
-        dim++;
-        continue;
+      exprs[expr_idx] = index;
+      if (index->type()->isSubtypeOf(NoneType::get())) {
+        if (is_reverse) {
+          return dim;
+        } else {
+          return dim + 1;
+        }
+      } else if (index->type() == IntType::get()) {
+        if (is_reverse) {
+          return dim - 1;
+        } else {
+          return dim;
+        }
       } else if (index->type()->isSubtypeOf(OptionalType::ofTensor())) {
-        // NB:index type can either be a Tensor or : (None of Optional Tensor)
-        handle_tensor(index);
+        if (is_reverse) {
+          throw ErrorReport(loc)
+              << "Ellipses followed by tensor indexing is currently not supported";
+        } else {
+          return dim + 1;
+        }
+      } else {
+        throw ErrorReport(loc)
+            << "Unsupported operation: indexing tensor with unsupported index type '"
+            << index->type()->python_str()
+            << "'. Only ints, slices, and tensors are supported";
+      }
+    };
+
+    size_t idx = 0;
+    int64_t dim = 0;
+    for (; idx < subscript_exprs.size(); idx++) {
+      auto subscript_expr = subscript_exprs[idx];
+      if (subscript_expr.kind() == TK_DOTS) {
+        break;
+      }
+      dim = handle_indexing(subscript_expr, idx, dim, /*is_reverse=*/false);
+    }
+    int64_t rdim = -1;
+    for (size_t rev_idx = subscript_exprs.size() - 1; rev_idx > idx;
+         rev_idx--) {
+      auto subscript_expr = subscript_exprs[rev_idx];
+      if (subscript_expr.kind() == TK_DOTS) {
+        throw ErrorReport(loc)
+            << "An index can only have a single ellipsis ('...')";
+      }
+      rdim =
+          handle_indexing(subscript_expr, rev_idx, rdim, /*is_reverse=*/true);
+    }
+    for (size_t i = 0; i < exprs.size(); i++) {
+      if (!exprs[i].has_value()) {
+        if (subscript_exprs[i].kind() == TK_SLICE_EXPR) {
+          sliceable = emitSlice(
+              loc,
+              sliceable,
+              insert_value_for_dim(dims[i]),
+              SliceExpr(subscript_exprs[i]));
+        }
         continue;
       }
-      throw ErrorReport(loc)
-          << "Unsupported operation: indexing tensor with unsupported index type '"
-          << index->type()->python_str()
-          << "'. Only ints, slices, and tensors are supported";
+      auto expr = exprs[i].value();
+      if (expr->type()->isSubtypeOf(NoneType::get())) {
+        sliceable =
+            emitUnsqueeze(loc, sliceable, insert_value_for_dim(dims[i]));
+      } else if (expr->type() == IntType::get()) {
+        sliceable =
+            emitSelect(loc, sliceable, insert_value_for_dim(dims[i]), expr);
+      } else if (expr->type()->isSubtypeOf(OptionalType::ofTensor())) {
+        tensor_indices.resize(dims[i] + 1);
+        tensor_indices[dims[i]] = expr;
+      } else {
+        TORCH_INTERNAL_ASSERT(
+            false, "Trying to process index type that we don't support.");
+      }
     }
     // at::index takes in a List[Optional[Tensor]] where some dims can be None.
     // create None node with optional tensor output type and pass to at::index.
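
For intuition, here is a small standalone sketch of the two-pass dimension assignment the new code performs. It is plain C++ and not part of compiler.cpp; the `Kind` enum and `step` function are illustrative stand-ins for the subscript expression kinds and the `handle_indexing` lambda. It computes the dimension each index applies to for a hypothetical expression like `x[0:2, ..., None, 3]`.

```cpp
// Minimal sketch, assuming the same step rules as handle_indexing above:
// a slice keeps the dimension count, None inserts a dimension, and an
// integer removes one. Indices before the ellipsis are numbered from the
// front (0, 1, ...), indices after it from the back (-1, -2, ...).
#include <cstdint>
#include <iostream>
#include <vector>

enum class Kind { Slice, None, Integer, Dots };

int64_t step(Kind k, int64_t dim, bool is_reverse) {
  switch (k) {
    case Kind::Slice:
      return is_reverse ? dim - 1 : dim + 1;
    case Kind::None:
      return is_reverse ? dim : dim + 1;
    case Kind::Integer:
      return is_reverse ? dim - 1 : dim;
    default: // the ellipsis itself never advances a dimension here
      return dim;
  }
}

int main() {
  // Hypothetical index expression: x[0:2, ..., None, 3]
  std::vector<Kind> exprs = {Kind::Slice, Kind::Dots, Kind::None, Kind::Integer};
  std::vector<int64_t> dims(exprs.size(), 0);

  // Pass 1: left to right, stopping at the ellipsis.
  size_t idx = 0;
  int64_t dim = 0;
  for (; idx < exprs.size() && exprs[idx] != Kind::Dots; idx++) {
    dims[idx] = dim;
    dim = step(exprs[idx], dim, /*is_reverse=*/false);
  }

  // Pass 2: right to left, counting back from -1 (the last dimension).
  int64_t rdim = -1;
  for (size_t rev_idx = exprs.size() - 1; rev_idx > idx; rev_idx--) {
    dims[rev_idx] = rdim;
    rdim = step(exprs[rev_idx], rdim, /*is_reverse=*/true);
  }

  // Prints "0 0 -2 -1": the slice applies to dim 0, None unsqueezes at
  // dim -2, and the integer selects along dim -1 (the second 0 is the
  // unused slot for the ellipsis itself). On a (5, 6, 7) tensor this
  // yields shape (2, 6, 1), matching Python indexing semantics.
  for (int64_t d : dims) {
    std::cout << d << " ";
  }
  std::cout << "\n";
  return 0;
}
```

Numbering indices after the ellipsis from the back is what makes the second pass work: the ellipsis can cover any number of dimensions, so only negative (end-relative) dimension numbers are known for the indices that follow it.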