Skip to content

Conversation

@laithsakka
Copy link
Contributor

@laithsakka laithsakka commented Jul 26, 2025

Stack from ghstack (oldest at bottom):

This might cause some new DDEs on call sites that do not use is_contiguous_or_false() or sym_is_contiguous()
but want to find those call sites to handle this properly by calling is_contiguous_or_false() and not is_contiguous() explitly when appropriate.
I had to fix one issue after removing the implicit size oblivious reasoning. here is context

we defined in this #157472 sym_is_contiguous to be the function computing contiguity for dynamic shapes in c++. It returns a symbolic expression that represents contiguity and guaranteed not to throw a DDE.

when people call is_contiguous we do sym_is_contiguous().guard_bool()
when people call is_contiguous_or_false we do sym_is_contiguous().guard_or_false()

one issue not handled well was this path

c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}

namely if we call sym_is_contiguous_custom but we have matches_python_custom(SizesStridesPolicy::CustomStrides) return true , then we used to call is_contiguous(this, memory_format);

This used to go through the load_pyobj_interpreter and end up calling the python is_contiguous call which used implicit size oblivious reasoning.
once we removed that implicit size oblivious reasoning, the right thing we want is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get DDE even if the caller is doing sym_is_contiguous.

so I had to define it for pyinterpreter, and then I had to override it for nested tensors.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Jul 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159197

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3c49eb1 with merge base 8ca8b60 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

laithsakka added a commit that referenced this pull request Jul 26, 2025
@laithsakka laithsakka marked this pull request as draft July 26, 2025 05:23
[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Jul 27, 2025
@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label Jul 27, 2025
@laithsakka laithsakka added the keep-going Don't stop on first failure, keep running tests until the end label Jul 28, 2025
This might cause some new DDEs on call sites that do not use is_contiguous_or_false
but want to find those call sites to handle this properly by calling  is_contiguous_or_false and not is_contiguous

[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Jul 29, 2025
laithsakka added a commit that referenced this pull request Jul 29, 2025
[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Jul 29, 2025
@laithsakka laithsakka changed the title remove default gso from normal contiguity checks remove default guard size oblivuous from normal contiguity checks and add aten.sym_is_contiguous. Jul 29, 2025
[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Jul 29, 2025
[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Jul 29, 2025
[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Jul 29, 2025
[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Jul 29, 2025
@laithsakka laithsakka changed the title remove default guard size oblivuous from normal contiguity checks and add aten.sym_is_contiguous. remove default guard size oblivuous from normal contiguity python checks and add aten.sym_is_contiguous. Jul 29, 2025
@laithsakka laithsakka changed the title remove default guard size oblivuous from normal contiguity python checks and add aten.sym_is_contiguous. Remove default guard size oblivious from normal contiguity python checks and add aten.sym_is_contiguous. Jul 29, 2025
@laithsakka
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@laithsakka
Copy link
Contributor Author

@pytorchbot revert

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 18, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -m/--message, -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@laithsakka
Copy link
Contributor Author

@pytorchbot revert -m "internal build faluires"

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 18, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@laithsakka
Copy link
Contributor Author

@pytorchbot revert -m "internal build failures" -c nosignal

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Aug 18, 2025
@pytorchmergebot
Copy link
Collaborator

@laithsakka your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Aug 18, 2025
pytorch-bot bot pushed a commit that referenced this pull request Aug 18, 2025
… add aten.sym_is_contiguous. (#159197) (#159197)

Summary:
This might cause some new DDEs on call sites that do not use is_contiguous_or_false() or sym_is_contiguous()
but want to find those call sites to handle this properly by calling  is_contiguous_or_false() and not is_contiguous() explitly when appropriate.
I had to fix one issue after removing the implicit size oblivious reasoning. here is context

we defined in this #157472 sym_is_contiguous to be the function computing contiguity for dynamic shapes in c++. It returns a symbolic expression that represents contiguity and guaranteed not to throw a DDE.

when people call is_contiguous we do sym_is_contiguous().guard_bool()
when people call is_contiguous_or_false we do sym_is_contiguous().guard_or_false()

one issue not handled well was this path
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}
```
namely if we call sym_is_contiguous_custom but we have matches_python_custom(SizesStridesPolicy::CustomStrides) return true , then we used to call is_contiguous(this, memory_format);

This used to go through the load_pyobj_interpreter and end up calling the python is_contiguous call which used implicit size oblivious reasoning.
once we removed that implicit size oblivious reasoning, the right thing we want is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get DDE even if the caller is doing sym_is_contiguous.

so I had to define it for pyinterpreter, and then I had to override it for nested tensors.

Pull Request resolved: #159197
Approved by: https://github.com/ezyang

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/e444cd24d48b3a46f067974f2cc157f5ed27709f

Rollback Plan:

Differential Revision: D80435179
@github-actions
Copy link
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@sajjadghodsizad1-lab
Copy link

**
**

@sajjadghodsizad1-lab
Copy link

[

](url)

can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
… add aten.sym_is_contiguous. (pytorch#159197)

This might cause some new DDEs on call sites that do not use is_contiguous_or_false() or sym_is_contiguous()
but want to find those call sites to handle this properly by calling  is_contiguous_or_false() and not is_contiguous() explitly when appropriate.
I had to fix one issue after removing the implicit size oblivious reasoning. here is context

we defined in this pytorch#157472 sym_is_contiguous to be the function computing contiguity for dynamic shapes in c++. It returns a symbolic expression that represents contiguity and guaranteed not to throw a DDE.

when people call is_contiguous we do sym_is_contiguous().guard_bool()
when people call is_contiguous_or_false we do sym_is_contiguous().guard_or_false()

one issue not handled well was this path
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}
```
namely if we call sym_is_contiguous_custom but we have matches_python_custom(SizesStridesPolicy::CustomStrides) return true , then we used to call is_contiguous(this, memory_format);

This used to go through the load_pyobj_interpreter and end up calling the python is_contiguous call which used implicit size oblivious reasoning.
once we removed that implicit size oblivious reasoning, the right thing we want is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get DDE even if the caller is doing sym_is_contiguous.

so I had to define it for pyinterpreter, and then I had to override it for nested tensors.

Pull Request resolved: pytorch#159197
Approved by: https://github.com/ezyang
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
@laithsakka
Copy link
Contributor Author

published another identical PR

@laithsakka laithsakka closed this Sep 1, 2025
pytorch-bot bot pushed a commit that referenced this pull request Sep 2, 2025
… add aten.sym_is_contiguous. (#159197) (#160869)

Summary:

This might cause some new DDEs on call sites that do not use is_contiguous_or_false() or sym_is_contiguous()
but want to find those call sites to handle this properly by calling  is_contiguous_or_false() and not is_contiguous() explitly when appropriate.
I had to fix one issue after removing the implicit size oblivious reasoning. here is context

we defined in this #157472 sym_is_contiguous to be the function computing contiguity for dynamic shapes in c++. It returns a symbolic expression that represents contiguity and guaranteed not to throw a DDE.

when people call is_contiguous we do sym_is_contiguous().guard_bool()
when people call is_contiguous_or_false we do sym_is_contiguous().guard_or_false()

one issue not handled well was this path
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}
```
namely if we call sym_is_contiguous_custom but we have matches_python_custom(SizesStridesPolicy::CustomStrides) return true , then we used to call is_contiguous(this, memory_format);

This used to go through the load_pyobj_interpreter and end up calling the python is_contiguous call which used implicit size oblivious reasoning.
once we removed that implicit size oblivious reasoning, the right thing we want is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get DDE even if the caller is doing sym_is_contiguous.

so I had to define it for pyinterpreter, and then I had to override it for nested tensors.

Approved by: https://github.com/ezyang

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/e444cd24d48b3a46f067974f2cc157f5ed27709f

Rollback Plan:

Reviewed By: ezyang

Differential Revision: D80435179
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
… add aten.sym_is_contiguous. (pytorch#159197)

This might cause some new DDEs on call sites that do not use is_contiguous_or_false() or sym_is_contiguous()
but want to find those call sites to handle this properly by calling  is_contiguous_or_false() and not is_contiguous() explitly when appropriate.
I had to fix one issue after removing the implicit size oblivious reasoning. here is context

we defined in this pytorch#157472 sym_is_contiguous to be the function computing contiguity for dynamic shapes in c++. It returns a symbolic expression that represents contiguity and guaranteed not to throw a DDE.

when people call is_contiguous we do sym_is_contiguous().guard_bool()
when people call is_contiguous_or_false we do sym_is_contiguous().guard_or_false()

one issue not handled well was this path
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}
```
namely if we call sym_is_contiguous_custom but we have matches_python_custom(SizesStridesPolicy::CustomStrides) return true , then we used to call is_contiguous(this, memory_format);

This used to go through the load_pyobj_interpreter and end up calling the python is_contiguous call which used implicit size oblivious reasoning.
once we removed that implicit size oblivious reasoning, the right thing we want is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get DDE even if the caller is doing sym_is_contiguous.

so I had to define it for pyinterpreter, and then I had to override it for nested tensors.

Pull Request resolved: pytorch#159197
Approved by: https://github.com/ezyang
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
@github-actions github-actions bot deleted the gh/laithsakka/249/head branch October 2, 2025 02:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-no-td Do not run TD on this PR ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request keep-going Don't stop on first failure, keep running tests until the end Merged module: dynamo release notes: fx release notes category Reverted

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants