CustomOp simple abstract implementation registration #99439
Conversation
This PR:
- adds a FakeTensor registration API for CustomOp (CustomOp.impl_fake)
- adds a FakeTensor registration API for torch.library (impl_fake)

The FakeTensor implementation provided by the user:
- accepts FakeTensors and returns FakeTensors
- If the signature of the operator is (*args, **kwargs), then the signature of the FakeTensor implementation is (ctx, *args, **kwargs).
- The `ctx` arg is a new FakeTensorImplCtx object that stores helper methods. So far, the helper methods are (1) constructing a new unbacked symint and (2) constraining the range of the symint.

Caveats:
- We really need to make FakeTensor/FakeTensorMode public. Otherwise, there isn't a way for the user to interactively test that their FakeTensor implementation is correct without running through large pieces of the PT2 stack (make_fx or torch.compile).
- We do not memoize the symints produced by ctx.new_data_dependent_symint(). It is possible to do this in the future, but it is difficult to do soundly and I am not convinced of the utility outside of the nonzero() usecase mentioned in #95399.

Test Plan:
- Updated existing custom_op_db operators
- Added new numpy_nonzero and numpy_nms operations that test operations that have data-dependent output shape.
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99439.
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit dce4b8e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR:
- adds a FakeTensor registration API for CustomOp (CustomOp.impl_fake)
- deletes CustomOp.impl_meta

The user story behind this API is that it is the one-stop shop for registering implementations for data-less Tensors, i.e. FakeTensor and Meta tensor.

The fake implementation provided by the user:
- gets registered as the FakeTensor implementation AND the meta formula
- can be written like a regular meta formula. If the user decides that they need something more special (i.e. data-dependent output shape), then they are able to query a current context object (FakeTensorImplCtx) that has methods to construct new unbacked symints.

Caveats:
- We really need to make FakeTensor/FakeTensorMode public. Otherwise, there isn't a way for the user to interactively test that their FakeTensor implementation is correct without running through large pieces of the PT2 stack (make_fx or torch.compile).
- We do not memoize the symints produced by ctx.new_data_dependent_symint(). It is possible to do this in the future, but it is difficult to do soundly and I am not convinced of the utility outside of the nonzero() usecase mentioned in #95399.

Public API:
- More docs will come when we actually expose this API to users by putting it in a public namespace, unless you folks want it now.
- The APIs mentioned in `__all__` are the ones that are intended to be public.

Test Plan:
- Updated existing custom_op_db operators
- Added new numpy_nonzero and numpy_nms operations that test operations that have data-dependent output shape.
I support making FakeTensor public. Similarly, proxy_tensor and symbolic_shapes should graduate from experimental.
def numpy_take_fake(x, ind, ind_inv, dim):
    return torch.empty_like(x)

@custom_op('(Tensor x) -> Tensor', ns='_torch_testing')
I'm getting some buyer's remorse from removing the deduplication; the main downside is I can no longer grep _torch_testing::numpy_nonzero to find the operator that defines this lol
Well, we can bring the name back and just assert that the name mentioned in the schema matches the function name:
# with schema
@custom_op('_torch_testing::numpy_nonzero', '(Tensor x) -> Tensor')
def numpy_nonzero(x):
...
# without schema (inferring schema from type annotations)
@custom_op('_torch_testing::numpy_nonzero')
def numpy_nonzero(x: Tensor) -> Tensor:
...
(will leave for the future)
# Users can register FakeTensor rules for custom operators.
# Call them if they exist.
if func.name() in torch._custom_op.global_registry:
Any reason this table isn't indexed on func itself?
Lifetime issues, but I am very close to declaring bankruptcy on lifetime issues anyways.
The problem is that once we create the OpOverload, if the operator gets unregistered and re-registered, invoking the OpOverload leads to segfaults. So as written in this PR, we index the registry on qualname to avoid creating the OpOverload.
This only really matters for testing, but I might (in a future PR) just develop a mechanism for testing to clear custom OpOverloads because testing needs to materialize the OpOverloads anyways when we do things like run make_fx.
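For illustration, a minimal sketch of the qualname-keyed lookup being discussed. The exact shape of `global_registry` is an assumption; only the `func.name()` key and the dict-like membership test appear in the diff above:

# The registry maps a qualified name string ("namespace::opname") to the
# CustomOp object, so FakeTensorMode can find a rule without ever holding
# onto the OpOverload itself (sidestepping the re-registration lifetime issue).
global_registry = {}  # e.g. {"_torch_testing::numpy_nonzero": <CustomOp object>}

def lookup_custom_op(func):
    # func is the OpOverload being dispatched; func.name() is assumed to
    # return its qualified name, which is the key used in this PR's diff.
    return global_registry.get(func.name())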
@numpy_nonzero.impl_fake()
def numpy_nonzero_fake(x):
    ctx = torch._custom_op.get_ctx()
    i0 = ctx.new_data_dependent_symint()
    shape = [x.dim(), i0]
    result = x.new_empty(shape, dtype=torch.long)
    return result
Opening up the floor for some bikeshedding. Here are some questions to get started (cc @ezyang):
1. ctx.new_data_dependent_symint()
   - ctx.new_symint()? ctx.new_size()? ctx.new_unbacked_symint()?
2. numpy_nonzero.impl_fake
   - Should we call our "fake implementation" an "abstract implementation" instead, and rename this to impl_abstract? "fake implementation" sounds a bit confusing to me because the implementation ends up being registered for both Meta tensors and FakeTensors, but another way to think about this is that FakeTensor is a superset of Meta tensors.
Re (1), absent any pressure, I would line up the name with ShapeEnv, so create_unbacked_symint
Re (2), I actually kind of think we should lie and call it impl_meta anyway. I think my opinion is a bit different if we're going to support mixed-device custom ops, but if you're registering these implementations for Meta, they can't actually rely on a fake tensor mode being active (and thus always have to use new_empty and similar to make sure they are generated on the correct device.)
Re (1), absent any pressure, I would line up the name with ShapeEnv, so create_unbacked_symint
sgtm
Re (2), I actually kind of think we should lie and call it impl_meta anyway. I think my opinion is a bit different if we're going to support mixed-device custom ops, but if you're registering these implementations for Meta, they can't actually rely on a fake tensor mode being active (and thus always have to use new_empty and similar to make sure they are generated on the correct device.)
I'm not sure we should lie. It would be bad if the user started creating new meta tensors inside the function (e.g. torch.empty(3, 3, device='meta')), because as it is currently set up, the FakeTensorMode is active inside the FakeTensor rule, so we would interpret them as FakeTensors wrapping meta tensors.
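To make the concern concrete, a minimal sketch of the pitfall, assuming FakeTensorMode is imported from its then-private location (torch._subclasses.fake_tensor):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# While a FakeTensorMode is active (as it is when the rule runs under
# FakeTensor), factory functions are intercepted by the mode, so the
# "meta" tensor created below comes back as a FakeTensor wrapping a meta
# tensor rather than a plain meta tensor.
with FakeTensorMode():
    t = torch.empty(3, 3, device="meta")
    print(type(t))  # a FakeTensor, not a plain tensor on the meta device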
OK, well, then we shouldn't call it fake or meta lol.
Why not fake? Because it's not always a FakeTensor rule? (e.g. when registered to meta?)
Re (2), I actually kind of think we should lie and call it impl_meta anyway. I think my opinion is a bit different if we're going to support mixed-device custom ops, but if you're registering these implementations for Meta, they can't actually rely on a fake tensor mode being active (and thus always have to use new_empty and similar to make sure they are generated on the correct device.)
So, one idea for this is that for the meta implementation, we can activate FakeTensor mode and send FakeTensors that wrap meta tensors into the implementation. The benefit of this is that we're no longer lying about this being a fake impl and that the user doesn't need to worry about a third thing (abstract impl vs meta impl vs fake impl).
So, one idea for this is that for the meta implementation, we can activate FakeTensor mode and send FakeTensors that wrap meta tensors into the implementation. The benefit of this is that we're no longer lying about this being a fake impl and that the user doesn't need to worry about a third thing (abstract impl vs meta impl vs fake impl).
This doesn't seem possible because the Meta dispatch key is after Python
Updated to impl_abstract.
with set_ctx_getter(error_on_ctx):
    return f(*args, **kwargs)

self._lib.impl(self._opname, f_with_ctx, "Meta")
question to cement my understanding: Is it the case that this registration to the Meta key will basically never be called when using torch.compile, since we're basically guaranteeing that every custom op will have both a Meta impl and a FakeTensor impl, and FakeTensorMode always gives priority to the FakeTensor impl?
I guess the only case where we'd use this meta impl is if there's user code directly using meta tensors.
Yes, that's correct. The last part of the FakeTensorMode __torch_dispatch__ is to call the meta impl, if it exists.
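For context, a hedged sketch of the "user code directly using meta tensors" path; the op name mylib::my_op is hypothetical, not from this PR:

import torch

# Outside torch.compile / FakeTensorMode, calling a custom op on plain meta
# tensors dispatches through the Meta key, so the registration above is the
# kernel that actually runs.
x = torch.empty(4, 4, device="meta")
out = torch.ops.mylib.my_op(x)  # hypothetical op with an abstract impl registered
assert out.device.type == "meta"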
assert scores.shape == (N,)

ctx = torch._custom_op.get_ctx()
i0 = ctx.new_data_dependent_symint()
very cool :)
I think you mentioned it before, but will it be possible to use this machinery to write FakeTensor rules for existing C++ custom ops?
The plan is for users to wrap their existing C++ custom ops in a CustomOp object, and then use this API to define the FakeTensor rule
This PR:
- adds an abstract registration API for CustomOp (CustomOp.impl_abstract) that is used for both FakeTensor and meta tensors
- deletes CustomOp.impl_meta

The user story behind this API is that it is the one-stop shop for registering implementations for data-less Tensors, i.e. FakeTensor and Meta tensor.

The abstract implementation provided by the user:
- gets registered as the FakeTensor implementation AND the meta formula
- can be written like a regular meta formula. If the user decides that they need something more special (i.e. data-dependent output shape), then they are able to query a current context object (FakeTensorImplCtx) that has methods to construct new unbacked symints.

Caveats:
- We really need to make FakeTensor/FakeTensorMode public. Otherwise, there isn't a way for the user to interactively test that their abstract implementation is correct without running through large pieces of the PT2 stack (make_fx or torch.compile).
- We do not memoize the symints produced by ctx.create_unbacked_symint(). It is possible to do this in the future, but it is difficult to do soundly and I am not convinced of the utility outside of the nonzero() usecase mentioned in #95399.

Public API:
- More docs will come when we actually expose this API to users by putting it in a public namespace, unless you folks want it now.
- The APIs mentioned in `__all__` are the ones that are intended to be public.

Test Plan:
- Updated existing custom_op_db operators
- Added new numpy_nonzero and numpy_nms operations that test operations that have data-dependent output shape.
The context behind this change is:
- We're adding a new custom operator API to PyTorch (#99439).
- This depends on `torchgen`, which is included as a part of PyTorch.
- `torchgen` depends on the `yaml` project, which is not included as a runtime dependency.
- #99439 breaks the binary CI tests because it needs yaml.

There are workarounds I can do, like refactoring what I need from torchgen to not import yaml. But the cleaner solution, because we do ship torchgen with PyTorch, is to actually include yaml as a runtime dependency. Putting up this PR for discussion.
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack:

This PR:
- adds an abstract registration API for CustomOp (CustomOp.impl_abstract) that is used for both FakeTensor and meta tensors
- deletes CustomOp.impl_meta

The user story behind this API is that it is the one-stop shop for registering implementations for data-less Tensors, i.e. FakeTensor and Meta tensor.

The abstract implementation provided by the user:
- gets registered as the FakeTensor implementation AND the meta formula
- can be written like a regular meta formula. If the user decides that they need something more special (i.e. data-dependent output shape), then they are able to query a current context object (FakeTensorImplCtx) that has methods to construct new unbacked symints.

Caveats:
- We really need to make FakeTensor/FakeTensorMode public. Otherwise, there isn't a way for the user to interactively test that their abstract implementation is correct without running through large pieces of the PT2 stack (make_fx or torch.compile).
- We do not memoize the symints produced by ctx.create_unbacked_symint(). It is possible to do this in the future, but it is difficult to do soundly and I am not convinced of the utility outside of the nonzero() usecase mentioned in "Memoize repeated nonzero calls to the same fake tensor" (#95399).

Public API:
- More docs will come when we actually expose this API to users by putting it in a public namespace, unless you folks want it now.
- The APIs mentioned in `__all__` are the ones that are intended to be public.

Test Plan:
- Updated existing custom_op_db operators
- Added new numpy_nonzero and numpy_nms operations that test operations that have data-dependent output shape.
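For reference, a minimal sketch of the registration flow this PR lands, mirroring the numpy_nonzero example from the diff with the renames discussed above applied (impl_fake → impl_abstract, new_data_dependent_symint → create_unbacked_symint). Decorator and ctx details follow the snippets in this conversation and may not match the merged code exactly:

import torch
from torch._custom_op import custom_op

@custom_op('(Tensor x) -> Tensor', ns='_torch_testing')
def numpy_nonzero(x):
    ...

@numpy_nonzero.impl_abstract()
def numpy_nonzero_abstract(x):
    # Data-dependent output shape: ask the context for a fresh unbacked symint
    # to stand in for the unknown number of nonzero elements.
    ctx = torch._custom_op.get_ctx()
    i0 = ctx.create_unbacked_symint()
    shape = [x.dim(), i0]
    return x.new_empty(shape, dtype=torch.long)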