@@ -80,11 +80,13 @@
     FALLBACK_ALLOW_LIST,
     fallback_handler,
     fallback_node_due_to_unsupported_type,
+    get_layout_constraint_tag,
     lowerings,
     make_fallback,
     maybe_layout_constraints,
     needs_realized_inputs,
     require_contiguous,
+    tag_to_layout_constraint,
     unsupported_output_tensor,
 )
 from .runtime import autotune_cache
@@ -1149,34 +1151,26 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> |
                     error.operator_str(target, args, kwargs),
                 )

-                # use contiguous unless the (custom) op asks something else
-                # explicitly
-                if torch._library.utils.needs_exact_strides(target):
-                    decided_constraint = constrain_to_fake_tensors
-                elif torch._C.Tag.needs_fixed_stride_order in target.tags:
-                    decided_constraint = constrain_to_fx_strides  # type: ignore[assignment]
-                elif torch._C.Tag.flexible_layout in target.tags:
-                    decided_constraint = None  # type: ignore[assignment]
-                else:
-                    # If there are no tags, we do different things depending on
-                    # if it's a builtin ATen/prim ops or custom ops.
-                    # For ATen ops, we require_contiguous to fix https://github.com/pytorch/pytorch/issues/140452
-                    # For custom ops, we constrain_to_fx_strides to maintain the
-                    # behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356
+                tag = get_layout_constraint_tag(target, with_default=False)
+                if (
+                    tag is None
+                    and torch._library.utils.is_builtin(target)
+                    and self.is_backward
+                ):
+                    # for implicit fallback ATen ops during backward, if there
+                    # is no layout constraint tag, we conservatively require contiguous
+                    # input since some eager kernels do not
+                    # support non-contiguous inputs. Otherwise they may silently cause
+                    # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
+                    # We only do this for ATen ops and for backward.
                     #
-                    # For ATen ops, only apply the constraint for backward
-                    # ops since fwd ops should work for any strides.
-                    if torch._library.utils.is_builtin(target) and self.is_backward:
-                        decided_constraint = require_contiguous  # type: ignore[assignment]
-                    else:
-                        # maybe_layout_constraints will decide the layout constraint for the custom op
-                        # lazily
-                        decided_constraint = None  # type: ignore[assignment]
-
-                # for implicitly fallback ops, we conservatively requires
-                # contiguous input since some eager kernels does not
-                # support non-contiguous inputs. They may silently cause
-                # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
+                    # TODO: should really switch to "needs_fixed_stride" constraint on these
+                    # and identify them one by one.
+                    decided_constraint = require_contiguous  # type: ignore[assignment]
+                else:
+                    tag = get_layout_constraint_tag(target, with_default=True)
+                    decided_constraint = tag_to_layout_constraint(tag)
+
                 make_fallback(target, layout_constraint=decided_constraint)

             elif get_decompositions([target]):
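For context, the sketch below shows how an op author can opt a custom op into one of the layout-constraint tags that the new code reads through get_layout_constraint_tag / tag_to_layout_constraint. It is an illustration only and not part of this commit: the op name mylib::scale and its implementation are made up, and it assumes torch.library.define accepts a tags argument; the tag torch._C.Tag.needs_fixed_stride_order itself comes from the removed code above.

# Illustrative only: a made-up custom op tagged so that Inductor's implicit
# fallback preserves the input stride order instead of forcing contiguous.
import torch

torch.library.define(
    "mylib::scale",
    "(Tensor x, float factor) -> Tensor",
    tags=(torch._C.Tag.needs_fixed_stride_order,),  # layout-constraint tag
)

@torch.library.impl("mylib::scale", "cpu")
def _(x, factor):
    # eager kernel; here it happens to handle any strides
    return x * factor

@torch.library.register_fake("mylib::scale")
def _(x, factor):
    return torch.empty_like(x)

@torch.compile
def f(x):
    # x.t() is non-contiguous; with the tag above the implicit fallback
    # keeps its strides rather than forcing a contiguous copy
    return torch.ops.mylib.scale(x.t(), 2.0)

print(f(torch.randn(4, 8)))

Ops that carry no tag take the new default path in the diff: get_layout_constraint_tag(target, with_default=True) followed by tag_to_layout_constraint, except builtin ATen ops during backward, which conservatively fall back to require_contiguous.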