Fix prims.broadcast_in_dim functionalization #163377
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163377. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit ef87cf7 with merge base 51152ef. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
ezyang left a comment:
Unfortunately, it seems Codex found an existing wrong functionalization rule for rng in the prims directory. Here's what a real, full-bodied view functionalization rule looks like:
at::Tensor view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) {
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
    at::AutoDispatchSkipFunctionalize guard;
    return at::_ops::view::call(self_, size);
  }
  auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  auto inverse_return_mode = (
      reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
                    : at::functionalization::InverseReturnMode::NeverView
  );
  auto compute_reference_meta =
      self.key_set().has_backend(c10::BackendComponent::XLABit) ||
      self.key_set().has_backend(c10::BackendComponent::LazyBit);
  at::Tensor reference_tensor_output;
  if (compute_reference_meta && !disable_meta_reference()) {
    auto self_meta = to_meta(self);
    at::AutoDispatchSkipFunctionalize func_guard;
    c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
    reference_tensor_output = at::_ops::view::call(self_meta, size);
  }
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    if (reapply_views) {
      tmp_output = at::_ops::view::call(self_, size);
    } else {
      tmp_output = at::_ops::view_copy::call(self_, size);
    }
  }
  bool has_symbolic_inputs = false;
  has_symbolic_inputs = has_symbolic_inputs | (std::any_of(size.begin(), size.end(), [=](auto& arg) { return arg.is_symbolic(); }));
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
      if (reapply_views) {
        return at::_ops::view::call(base, size);
      } else {
        return at::_ops::view_copy::call(base, size);
      }
    },
    [inverse_return_mode = inverse_return_mode, size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return at::functionalization::FunctionalInverses::view_inverse(base, mutated_view, inverse_return_mode, size);
    },
    /*has_symbolic_inputs=*/has_symbolic_inputs,
    /*is_multi_output=*/false,
    /*is_as_strided=*/false
  );
  auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, view_meta);
  // See Note [Propagating strides in the functionalization pass]
  if (compute_reference_meta && !disable_meta_reference()) {
    at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output);
  }
  return out;
}
For broadcast_in_dim to be functionalized correctly, it needs to follow this structure more closely.
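For context, here is a minimal sketch of what broadcast_in_dim computes, i.e. the view that a ViewMeta-style rule's forward/inverse lambdas would have to replay. It assumes the usual "insert size-1 dims, then expand" decomposition; `broadcast_in_dim_ref` and the example shapes are illustrative, not the PR's code:

```python
import torch

def broadcast_in_dim_ref(a, shape, broadcast_dimensions):
    # Place each input dim at its mapped output position, leave size-1 dims
    # everywhere else, then broadcast to the target shape. The expand step
    # returns a view of the input, which is why a proper functionalization
    # rule needs a forward/inverse pair rather than a plain fallback.
    s = [1] * len(shape)
    for i, dim in enumerate(broadcast_dimensions):
        s[dim] = a.shape[i]
    return a.reshape(s).expand(shape)

x = torch.randn(3, 1)
y = broadcast_in_dim_ref(x, (2, 3, 4), broadcast_dimensions=(1, 2))
assert y.shape == (2, 3, 4)
```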
torch/_prims/__init__.py (Outdated)

original_idx = 0
for idx in range(len(shape)):
    if idx in broadcast_dims_set:
        size = tensor_sizes[original_idx]
Can we write this to be unbacked-friendly? The unbacked semantics for broadcasting are that if we can't tell whether it's a broadcast case or not, we assume no broadcasting and do a torch._check. For example, this is how the meta version handles it.
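As a hedged illustration of that pattern (a sketch, not the actual meta implementation; `pick_broadcast_size` and its arguments are made-up names), using guard_size_oblivious and torch._check:

```python
import torch
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

def pick_broadcast_size(tensor_size, target_size):
    # If we can show in a size-oblivious way that the input dim is 1, it's a
    # broadcast and the output takes the target size.
    if guard_size_oblivious(tensor_size == 1):
        return target_size
    # Otherwise (including the unbacked "can't tell" case) assume no
    # broadcasting and assert that the sizes agree.
    torch._check(
        tensor_size == target_size,
        lambda: f"expected size {tensor_size} to match {target_size}",
    )
    return tensor_size
```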
Updated prompt (I reverted the prior version and redid it):
Summary:
When would this functionalization be needed? This prim should not be used in the usual lowering process, and if someone wants to just call a function that has this behavior, they can write it with
Someone called the op directly.
Well, they shouldn't.
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Stack from ghstack (oldest at bottom):
Fixes #163037
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @chenyang78