Update functionalization metadata more eagerly. *_scatter ops should preserve input stride/storage_offset #88198
Conversation
This PR needs a label. If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`. If not, please add the `topic: not user facing` label. For more information, see https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.
Darn, I'm getting the cursed error. Maybe, somehow, when we regenerate the views for a mutated view, the regenerated tensor isn't getting proxies attached to its sizes/strides.
Stack from ghstack (oldest at bottom):

Two major changes in this PR:

(1) Outputs of `*_scatter` ops will now reflect the same storage size, stride, and storage_offset as their inputs. See more details at https://github.com/pytorch/pytorch/pull/87610/files#r1007264456. This fixes some silent correctness issues where functionalization advertised incorrect strides/storage_offsets on some tensors.

(2) That alone isn't enough: we need to ensure that any time someone calls `.stride()` or `.storage_offset()` on a tensor, its metadata is not stale. To fix this, I updated all of the custom metadata calls on `FunctionalTensorWrapper` to perform a sync first if the metadata is out of date.

As a motivating example, consider this code:

```
a = torch.ones(2, 2)
a_diag = torch.diagonal(a, dim1=0, dim2=1)
a_diag.add_(1)
```

Functionalization will temporarily turn this code into:

```
a = torch.ones(2, 2)
a_diag = torch.diagonal(a, dim1=0, dim2=1)
# a_diag_updated has incorrect metadata! The output of add() is always a
# contiguous, densely packed tensor, but it should (properly) advertise as
# not being contiguous: it has different strides.
a_diag_updated = a_diag.add(1)
```

As mentioned in the comment linked above, this isn't 100% correct: `a_diag_updated` has different metadata than `a_diag`. If user code queries the metadata on `a_diag` at this point, we'll return incorrect metadata.

The fix: by eagerly running the syncing logic before any metadata on `a_diag` is accessed, we'll do the following:

```
# reapply the mutation to the base, a
a_updated = torch.diagonal_scatter(a, a_diag.add(1))
# regenerate a_diag with proper strides, from a_updated
a_diag_updated = torch.diagonal(a_updated, dim1=0, dim2=1)
```

This ensures that any user code that calls `a_diag.stride()` or `a_diag.storage_offset()` sees accurate metadata.
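To make the mismatch concrete, here is a small eager-mode sketch (plain PyTorch, no functionalization involved); the printed values assume a contiguous 2x2 input:

```python
import torch

a = torch.ones(2, 2)
a_diag = torch.diagonal(a, dim1=0, dim2=1)        # a view into a's storage
print(a_diag.stride(), a_diag.storage_offset())   # (3,) 0

# An out-of-place add() allocates a fresh, densely packed tensor, so its
# metadata differs from the view it conceptually replaces:
out = a_diag.add(1)
print(out.stride(), out.storage_offset())         # (1,) 0

# diagonal_scatter() returns a tensor shaped like the base, so re-taking
# the diagonal recovers the view's true metadata:
a_updated = torch.diagonal_scatter(a, out)
a_diag_updated = torch.diagonal(a_updated, dim1=0, dim2=1)
print(a_diag_updated.stride())                    # (3,) again
```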
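The rewrite can also be observed directly by tracing through functionalization. A sketch, assuming a PyTorch build where `torch.func.functionalize` is available (in older releases the same functionality lived in `functorch`):

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    a = x.clone()
    a.diagonal().add_(1)   # in-place mutation of a view
    return a

# Tracing the functionalized program: the view mutation is rewritten into
# out-of-place ops, with diagonal_scatter re-applying the update to the base.
gm = make_fx(functionalize(f))(torch.ones(2, 2))
print(gm.code)  # expect diagonal, add, and diagonal_scatter calls in the graph
```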
So I am actually not sure why you need to `sync_` the metadata on each pass now. Don't you just need to update the in-place tensor's metadata? The other tensors not changing their sizes and strides is fine, because whatever they used to have should still be valid!
Yeah, I think you're right. At first I was worried about other views of the base that also came from slice/select calls, but their metadata should be accurate to begin with; the only time the metadata becomes "wrong" is when we convert an in-place op into an out-of-place one.
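As a rough illustration of the scheme under discussion (hypothetical Python, not the actual C++ `FunctionalTensorWrapper` API): a wrapped view regenerates itself from its base before answering any metadata query, so stale sizes/strides are never observable.

```python
import torch

class ViewWrapper:
    """Hypothetical sketch of eager metadata syncing: regenerate the view
    from its base before exposing metadata, so queries never see stale values."""

    def __init__(self, get_base, view_fn):
        self.get_base = get_base   # returns the current (functional) base
        self.view_fn = view_fn     # e.g. lambda b: torch.diagonal(b)
        self.value = view_fn(get_base())
        self.stale = False         # set when the base is updated out-of-line

    def mark_stale(self):
        self.stale = True

    def sync_(self):
        if self.stale:             # regenerate the view from the new base
            self.value = self.view_fn(self.get_base())
            self.stale = False

    def stride(self):
        self.sync_()               # metadata queries trigger an eager sync
        return self.value.stride()

# Usage: the base is functionally updated via diagonal_scatter, the wrapper
# is marked stale, and the next stride() query re-syncs before answering.
base = [torch.ones(2, 2)]
w = ViewWrapper(lambda: base[0], lambda b: torch.diagonal(b, dim1=0, dim2=1))
base[0] = torch.diagonal_scatter(base[0], w.value.add(1))
w.mark_stale()
print(w.stride())  # (3,), regenerated from the updated base
```

This mirrors the conclusion in the exchange above: views whose metadata was already correct regenerate to the same sizes and strides, so in practice only the tensor that was mutated in place can ever go stale.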
This is a carve-out of #88198 to diagnose CI failures. Signed-off-by: Edward Z. Yang <[email protected]>
There are a bunch of things going on here, but when I ablate everything except the clone-preserving-strides changes (#89474), it still fails.
So, I'm a little confused about whether or not this PR is still necessary. The stated justification was to allow the "do not use unsafe restriding for subclasses" PR to go in, c.f. https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#diff-c4e35e76279419d9edda979a995412ee494fad3a97d3fb59486a1f0e326d48be. But as best as I can tell, that PR has already gone into master without trouble. So is this just for a hypothetical situation where functionalization can produce incorrect strides?
We think this is obsolete.