-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add new_empty (with dtype argument only) to torch::stable #159508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add new_empty (with dtype argument only) to torch::stable #159508
Conversation
…(pending header-onlyness) [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159508
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 188f656 with merge base a44a0d3 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…le::Tensor (pending header-onlyness)" needs ScalarType.h, which needs to be header-only [ghstack-poisoned]
…le::Tensor (pending header-onlyness)" needs ScalarType.h, which needs to be header-only [ghstack-poisoned]
| if not has_function_variant: | ||
| # Functions with both function and method variants can use the at::{*}_symint version | ||
| # (e.g., narrow -> at::narrow_symint), BUT | ||
| # Method-only functions with symint parameters should use at::symint:: namespace | ||
| # Remove the _symint suffix since at::symint:: namespace uses the base name | ||
| # (e.g., new_empty -> at::symint::new_empty<c10::SymInt>) | ||
| base_name = cpp_sig.name() | ||
| base_name = base_name.removesuffix("_symint") # Remove "_symint" suffix | ||
| return f"at::symint::{base_name}<c10::SymInt>" | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the relevant part here
Lines 717 to 739 in 6d91d6d
| if Variant.function in f.variants: | |
| result += f""" | |
| // aten::{f.func} | |
| inline {sig.decl()} {{ | |
| return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str}); | |
| }}""" | |
| # The template function can be used from template situations | |
| # where you want to switch between the symint or not version | |
| # depending on a template argument | |
| # | |
| # NB: we ALWAYS generate this even for methods. But we put it in | |
| # this header so it can take advantage of per-op headers | |
| if has_symint: | |
| result += f""" | |
| namespace symint {{ | |
| template <typename T, typename = std::enable_if_t<std::is_same_v<T, {intlike_t}>>> | |
| {sig.decl(suppress_symint_suffix=True)} {{ | |
| return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str}); | |
| }} | |
| }} | |
| """ | |
| return result |
needs ScalarType.h, which needs to be header-only [ghstack-poisoned]
torch/csrc/stable/ops.h
Outdated
| // Handle dtype - use input tensor's dtype if not specified | ||
| int32_t target_dtype; | ||
| if (dtype.has_value()) { | ||
| target_dtype = static_cast<int32_t>(dtype.value()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmmm lol is this actually stable...this looks to expose a detail on how dtype is encapsulated. Should we/I write that translation layer from the headeronly ScalarType to their corresponding shim aoti_torch_get_dtype... maybe there's no other way to hide this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it would be nice if you could write that!
Can you describe what you mean in more detail of what the translation layer would do
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
like a pair of utils that take in and return ScalarType::BFloat`6 -> aoti_torch_dtype_bfloat16 and vice versa. I'll prioritize writing this (i think it;d be helpful for my Tensor scalar_type PR too)
| cpp_sig = gen_static_dispatch_backend_call_signature(sig, f) | ||
|
|
||
| if backend_index is None: | ||
| # Check if this is a symint function and if the function only has method variants |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any other ops we expect to go through this branch? Maybe list new_empty as an example here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm it's listed below (when all the if checks pass) on 412
janeyx99
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks fine to me, this is making me think i need to land my scalartype liaisoning asap so we aren't exposing weird details by accident
test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
Outdated
Show resolved
Hide resolved
test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
Outdated
Show resolved
Hide resolved
test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
Show resolved
Hide resolved
[ghstack-poisoned]
[ghstack-poisoned]
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…9508) Pull Request resolved: pytorch#159508 Approved by: https://github.com/janeyx99 ghstack dependencies: pytorch#160557
Stack from ghstack (oldest at bottom):