Description
🐛 Describe the bug
In #106211 we added a translation layer that takes NumPy programs and translates them into corresponding PyTorch programs, so that we can transparently rewrite NumPy programs into accelerated PyTorch code.
We need a similar translation layer for programs that operate on native Python ints/floats/bools/etc. This is necessary to properly execute on the agreed-upon plan in https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit#heading=h.aytveihghlgn where we want to translate SymFloat operations into tensor operations. Dynamo sort of implements this today, but on a very ad hoc basis.
Naively, you might think this is pretty straightforward (a multiply on ints is just a multiply on 0-d int tensors), but doing this in complete generality has more edge cases than you might think:
- Sometimes the obvious translation doesn't actually match semantics
- Anywhere an operator accepts a Scalar, you need to translate it into an appropriate variant that accepts a Tensor. Sometimes there is no Tensor-ified variant and you have to get creative (with a decomp or something)
- Python scalars have different type promotion rules than tensors
- Python scalars always work with tensors that live on other devices. However, if you are compiling a call to cuda_tensor.item(), you would prefer NOT to immediately move the corresponding tensor to CPU, in case you can avoid the DtoH sync entirely (maybe Inductor can take care of this optimization directly)
- The precision of operations on Tensors may differ from Python precision. This may be fundamentally unsolvable, but it can cause problems when hard thresholds are applied against the results of floating point computation; e.g., see "Add non-TS'able _resize_image_and_masks variant with less tensor ops" (vision#7592)
- Some APIs only take scalars and aren't even operators, e.g., https://gist.github.com/ezyang/7ebe5e607451cdd2cc067c9eeb1de3b4
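To make two of these edge cases concrete, here is a minimal sketch of how the "obvious" translation can diverge from Python semantics. It deliberately does not use PyTorch: the helpers `as_int64` and `as_float32` are hypothetical names introduced for illustration, and they only emulate fixed-width storage with the standard library, on the assumption that a translated program would hold ints in a 64-bit integer tensor and floats in a 32-bit float tensor.

```python
import ctypes
import struct


def as_int64(value: int) -> int:
    """Wrap a Python int to 64 bits, as a fixed-width int64 would (hypothetical helper)."""
    return ctypes.c_int64(value).value


def as_float32(value: float) -> float:
    """Round a Python float (float64) to the nearest float32 (hypothetical helper)."""
    return struct.unpack("f", struct.pack("f", value))[0]


# Python ints are arbitrary precision; a 64-bit integer wraps around.
py_int = (2**63 - 1) + 1
wrapped_int = as_int64((2**63 - 1) + 1)

# Python floats are float64; float32 cannot represent 2**24 + 1 exactly,
# so a hard threshold check can flip after translation.
threshold = 16777216.0  # 2**24
py_float = threshold + 1.0
f32_float = as_float32(as_float32(threshold) + as_float32(1.0))

py_passes = py_float > threshold
f32_passes = f32_float > threshold

print(py_int, wrapped_int)
print(py_passes, f32_passes)
```

The same Python expression gives a different integer once it is confined to 64 bits, and the threshold comparison changes its answer once the arithmetic is done in 32-bit floats. This is the kind of divergence a general translation layer would have to either preserve or document.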
Versions
main
cc @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @wconstab