Closed
Labels
module: autograd
module: dynamic shapes
oncall: pt2
triaged
Description
🐛 Describe the bug
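The is_same_shape call in the autograd engine (torch/csrc/autograd/engine.cpp) doesn't use size-oblivious or expect-true semantics, so it fails with a data-dependent guard error when unbacked SymInts are involved.
You can work around it by dropping the shape check and calling reduce_grad unconditionally, with something like this: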
diff --git a/fbcode/caffe2/torch/csrc/autograd/engine.cpp b/fbcode/caffe2/torch/csrc/autograd/engine.cpp
--- a/fbcode/caffe2/torch/csrc/autograd/engine.cpp
+++ b/fbcode/caffe2/torch/csrc/autograd/engine.cpp
@@ -847,14 +847,7 @@
       continue;
     }
 
-    if (!metadata.is_same_shape(grad)) {
-      if (metadata.is_expandable_to_shape(grad)) {
-        grad = metadata.reduce_grad(grad);
-      } else {
-        const auto message = metadata.incompatible_shape_error_message(i, grad);
-        TORCH_CHECK(false, format_error(message.str()));
-      }
-    }
+    grad = metadata.reduce_grad(grad);
 
     bool input_is_complex =
         isComplexType(c10::typeMetaToScalarType(metadata.options().dtype()));
The error looks something like:
It appears that you're trying to get a value out of symbolic int/float whose value is data-dependent (and thus we do not know the true value.) The expression we were trying to evaluate is Eq(i2 + i3, i4 + i5) (unhinted: Eq(i2 + i3, i4 + i5)). Scroll up to see where each of these data-dependent accesses originally occurred.
Note that i4 and i5 look like reallocations of i2 and i3. These reallocations happen in user code (not in the framework), because the user code repeatedly calls tolist() on the same tensor, and each call allocates fresh unbacked symbols.
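To make the failure mode concrete, here is a toy C++ model of the difference between a guarding comparison and an expect-true style comparison. It is only a sketch: ToySymInt, fresh_unbacked, guard_eq, expect_eq and the two same_shape_* helpers are hypothetical names invented for illustration, not PyTorch APIs, and the real shape environment reasons about whole symbolic expressions rather than single named symbols.

```cpp
#include <cstddef>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Toy stand-in for a symbolic size (hypothetical, NOT c10::SymInt).
struct ToySymInt {
  std::string name;
  std::optional<long> hint;  // empty means "unbacked": no known value
};

using ToyShape = std::vector<ToySymInt>;

// Every call allocates a brand-new symbol, mimicking how each tolist()
// call in the report produces fresh unbacked symbols.
ToySymInt fresh_unbacked() {
  static int next_id = 0;
  return ToySymInt{"u" + std::to_string(next_id++), std::nullopt};
}

// Guarding equality: must return a concrete bool immediately, so it
// has to give up when the answer depends on unknown data.
bool guard_eq(const ToySymInt& a, const ToySymInt& b) {
  if (a.name == b.name) return true;  // same symbol, trivially equal
  if (!a.hint || !b.hint) {
    throw std::runtime_error("data-dependent: Eq(" + a.name + ", " + b.name + ")");
  }
  return *a.hint == *b.hint;
}

// Expect-true style equality: when the answer is undecidable, record
// the condition as a deferred runtime check and assume it holds.
bool expect_eq(const ToySymInt& a, const ToySymInt& b,
               std::vector<std::string>& deferred) {
  if (a.name == b.name) return true;
  if (a.hint && b.hint) return *a.hint == *b.hint;
  deferred.push_back("Eq(" + a.name + ", " + b.name + ")");
  return true;
}

// Shape comparison built on the guarding equality (throws on unbacked sizes).
bool same_shape_guarding(const ToyShape& x, const ToyShape& y) {
  if (x.size() != y.size()) return false;
  for (std::size_t i = 0; i < x.size(); ++i) {
    if (!guard_eq(x[i], y[i])) return false;
  }
  return true;
}

// Shape comparison built on the deferring equality (never throws).
bool same_shape_deferring(const ToyShape& x, const ToyShape& y,
                          std::vector<std::string>& deferred) {
  if (x.size() != y.size()) return false;
  for (std::size_t i = 0; i < x.size(); ++i) {
    if (!expect_eq(x[i], y[i], deferred)) return false;
  }
  return true;
}
```

In this model, two sizes obtained from separate fresh_unbacked() calls are distinct symbols even if they would be equal at runtime, so same_shape_guarding throws while same_shape_deferring returns true and leaves one deferred check behind.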
Versions
main
cc @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @msaroufim @bdhirsh @anijain2305