
Conversation

@jansel (Contributor) commented Sep 19, 2025

@pytorch-bot bot commented Sep 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163377

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ef87cf7 with merge base 51152ef:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jansel added a commit that referenced this pull request Sep 19, 2025 (Fixes #163037, ghstack-source-id: fc244cf, Pull-Request: #163377)

@jansel added the topic: not user facing label Sep 19, 2025

jansel added a commit that referenced this pull request Sep 20, 2025 (Fixes #163037, ghstack-source-id: 0c237b0, Pull-Request: #163377)

jansel added a commit that referenced this pull request Sep 20, 2025 (Fixes #163037, ghstack-source-id: 929c1d6, Pull-Request: #163377)

jansel added a commit that referenced this pull request Sep 20, 2025 (Fixes #163037, ghstack-source-id: 329d227, Pull-Request: #163377)

@jansel requested a review from bdhirsh September 23, 2025 04:40
@eellison requested reviews from ezyang and zou3519 September 23, 2025 17:02
jansel added a commit that referenced this pull request Sep 23, 2025

@ezyang (Contributor) left a comment

Unfortunately, it seems Codex found an existing, incorrect functionalization rule for rng in the prims directory. Here's what a real, full-bodied view functionalization rule looks like:

    at::Tensor view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) {
      at::Tensor self_;
      if (at::functionalization::impl::isFunctionalTensor(self)) {
        self_ = at::functionalization::impl::from_functional_tensor(self);
      } else {
        self_ = self;
      }
      if (!at::functionalization::impl::isFunctionalTensor(self)) {
        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
        at::AutoDispatchSkipFunctionalize guard;
        return at::_ops::view::call(self_, size);
      }
      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
      auto inverse_return_mode = (
          reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
            : at::functionalization::InverseReturnMode::NeverView
      );
      auto compute_reference_meta =
        self.key_set().has_backend(c10::BackendComponent::XLABit) ||
        self.key_set().has_backend(c10::BackendComponent::LazyBit);
      at::Tensor reference_tensor_output;
      if (compute_reference_meta && !disable_meta_reference()) {
        auto self_meta = to_meta(self);
        at::AutoDispatchSkipFunctionalize func_guard;
        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
        reference_tensor_output = at::_ops::view::call(self_meta, size);
      }
      at::Tensor tmp_output;
      {
        at::AutoDispatchSkipFunctionalize guard;
        if (reapply_views) {
          tmp_output = at::_ops::view::call(self_, size);
        } else {
          tmp_output = at::_ops::view_copy::call(self_, size);
        }
      }

      bool has_symbolic_inputs = false;
      has_symbolic_inputs = has_symbolic_inputs | (std::any_of(size.begin(), size.end(), [=](auto& arg) { return arg.is_symbolic(); }));
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
          if (reapply_views) {
            return at::_ops::view::call(base, size);
          } else {
            return at::_ops::view_copy::call(base, size);
          }
        },
        [inverse_return_mode = inverse_return_mode, size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::FunctionalInverses::view_inverse(base, mutated_view, inverse_return_mode, size);
        },
        /*has_symbolic_inputs=*/has_symbolic_inputs,
        /*is_multi_output=*/false,
        /*is_as_strided=*/false
      );
      auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      if (compute_reference_meta && !disable_meta_reference()) {
        at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output);
      }
      return out;
    }

For broadcast_in_dim to be implemented correctly, it needs to follow this structure more closely.
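
As a rough illustration, here is a minimal Python sketch of the view chain such a rule has to encode, assuming the prim's (a, shape, broadcast_dimensions) signature (the helper name is hypothetical, not code from this PR): broadcast_in_dim is a chain of unsqueeze views followed by an expand, and the ViewMeta forward/inverse lambdas in a C++ rule shaped like view would need to replay exactly this chain:

    import torch

    def broadcast_in_dim_as_views(a, shape, broadcast_dimensions):
        # Insert a size-1 dim at every output position that is not mapped
        # from an input dim, then expand to the target shape.
        t = a
        orig_idx = 0
        for idx in range(len(shape)):
            if orig_idx < len(broadcast_dimensions) and broadcast_dimensions[orig_idx] == idx:
                orig_idx += 1
            else:
                t = t.unsqueeze(idx)
        # expand is itself a view, so the result aliases `a`.
        return t.expand(shape)

    x = torch.zeros(2, 3)
    y = broadcast_in_dim_as_views(x, (2, 4, 3), (0, 2))
    assert y.shape == (2, 4, 3) and y._base is x  # y is a view of x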

jansel added a commit that referenced this pull request Sep 24, 2025
@jansel marked this pull request as draft September 24, 2025 04:09
jansel added a commit that referenced this pull request Sep 24, 2025

    original_idx = 0
    for idx in range(len(shape)):
        if idx in broadcast_dims_set:
            size = tensor_sizes[original_idx]

A reviewer (Contributor) commented on this snippet:

Can we write this to be unbacked-friendly? The unbacked semantics for broadcasting are: if we can't tell whether it's a broadcast case or not, we assume no broadcasting and do a torch._check. For example, this is the meta version:

https://www.internalfb.com/code/fbsource/[e18e1578407a804d7877bf6be709197b739f6eae]/fbcode/caffe2/torch/_prims/__init__.py?lines=1296
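
A short sketch of that pattern, assuming guard_size_oblivious from torch.fx.experimental.symbolic_shapes (the helper below is hypothetical): treat a dim as broadcasting only when its size is provably 1; otherwise assume no broadcast and record the size equality with torch._check rather than guarding on an unbacked SymInt:

    import torch
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    def pick_broadcast_size(tensor_size, target_size):
        # Broadcast only when the input size is provably 1; for an
        # unbacked SymInt this condition evaluates as False.
        if guard_size_oblivious(tensor_size == 1):
            return target_size
        # Assume no broadcasting and assert the sizes match at runtime
        # instead of introducing a guard.
        torch._check(
            tensor_size == target_size,
            lambda: f"size mismatch: {tensor_size} vs {target_size}",
        )
        return tensor_size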

jansel added a commit that referenced this pull request Sep 24, 2025
@jansel (Contributor, Author) commented Sep 24, 2025

Updated prompt (I reverted the prior version and redid it):

Write a functionalization rule in C++ for broadcast_in_dim to fix failures in repro_broadcast_in_dim.py and the test test_prims_broadcast_in_dim_alias. Try to follow the pattern of other rules like view. Perhaps we could even rewrite broadcast_in_dim to view during functionalization. At the end, please provide a one-paragraph summary of your approach for use in a GitHub comment. When you change C++, rebuild with python setup.py develop; you have a working build environment.

Summary:

Implemented a Functionalize kernel for prims::broadcast_in_dim that rewrites the op into the same unsqueeze+expand view pattern we use in eager, captures the resulting size/stride metadata, and registers a prims-specific rule so functionalization rewrites it to the appropriate as_strided view. With the new rule in place, both the standalone repro and CPUReproTests.test_prims_broadcast_in_dim_alias now pass.
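
For context, here is a hypothetical repro in the spirit of the targeted test (the actual repro_broadcast_in_dim.py is not shown in this thread): the prim returns an alias of its input, so functionalization has to keep reads through that alias consistent with mutations of the base:

    import torch

    def f(x):
        # prims.broadcast_in_dim returns a view of x in eager mode
        y = torch.ops.prims.broadcast_in_dim(x, [3, 2], [0])
        x.mul_(2)     # mutate the base...
        return y + 1  # ...then read through the alias

    x = torch.ones(3)
    eager = f(x.clone())
    # With the functionalization rule in place, the compiled output
    # should match eager.
    compiled = torch.compile(f)(x.clone())
    assert torch.equal(eager, compiled)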

jansel added a commit that referenced this pull request Sep 25, 2025
@jansel marked this pull request as ready for review September 25, 2025 03:25
@jansel requested a review from ezyang September 25, 2025 03:25
jansel added a commit that referenced this pull request Sep 25, 2025
@ngimel (Collaborator) commented Sep 25, 2025

When would this functionalization be needed? This prim should not be used in the usual lowering process, and if someone just wants to call a function with this behavior, they can write it with unsqueeze and expand, and it will be functionalized normally.

@jansel (Contributor, Author) commented Sep 26, 2025

Someone called the op directly.

@ngimel (Collaborator) commented Sep 26, 2025

Well, they shouldn't.

@eellison removed their request for review October 14, 2025 19:29

@github-actions bot (Contributor) commented:

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions bot added the Stale label Dec 13, 2025
@jansel closed this Dec 14, 2025
