reimplement __torch_function__ overrides for torch.functional using inline logic#32194
Closed
ngoldbaum wants to merge 2 commits into pytorch:master
Conversation
ezyang approved these changes on Jan 15, 2020
facebook-github-bot left a comment:
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request on Jan 30, 2020
reimplement `__torch_function__` overrides for `torch.functional` using inline logic (pytorch#32194)

Summary:
Fixes pytorch#30831.

This improves the performance of operators in the `torch.functional` namespace that are overridable by `__torch_function__` implementations when supplied with `Tensor` operands.

Running the split benchmark (PyTorch/Caffe2 operator micro-benchmarks, tag `short`, eager mode) in various configurations produces the following forward execution times, in µs:

| Name | master | this PR | `__torch_function__` dispatch disabled |
| --- | --- | --- | --- |
| split_M8_N8_parts2_cpu | 3.340 | 2.261 | 2.180 |
| split_M8_N8_parts2_cuda | 3.333 | 2.223 | 2.172 |
| split_M256_N512_parts2_cpu | 3.366 | 2.237 | 2.171 |
| split_M256_N512_parts2_cuda | 3.385 | 2.218 | 2.146 |
| split_M512_N512_parts2_cpu | 3.468 | 2.259 | 2.175 |
| split_M512_N512_parts2_cuda | 3.416 | 2.234 | 2.152 |

So at least on the machine I'm testing on, this brings the overhead down to less than 100 ns. For comparison, the overhead for `__array_function__` in NumPy is about 850 ns on the same machine:

```
In [1]: import numpy as np

In [2]: %timeit np.mean([1])
8.89 µs ± 17.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [3]: %timeit np.mean._implementation([1])
8.04 µs ± 28.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

See [the implementation in NumPy](https://github.com/numpy/numpy/blob/master/numpy/core/overrides.py#L195) for why this measures `__array_function__` overhead.

Pull Request resolved: pytorch#32194
Differential Revision: D19410396
Pulled By: ezyang
fbshipit-source-id: ada788a4399c81cd7eb2d548aa04a2459e96634a
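The "inline logic" in the title refers to checking for overrides directly inside each function body rather than through a wrapper or decorator, so the common no-override case pays only one cheap test. A minimal, torch-free sketch of that dispatch pattern; the helper names `has_override` and `handle_override` are illustrative stand-ins, not PyTorch's actual `has_torch_function`/`handle_torch_function` signatures, and `split` here is a toy stand-in for the real operator:

```python
# Torch-free sketch of inline __torch_function__-style dispatch.
# Helper names are illustrative, not PyTorch's actual API.

def has_override(args):
    # Inline fast-path check: does any argument's type define __torch_function__?
    return any(hasattr(type(a), '__torch_function__') for a in args)

def handle_override(func, relevant_args, *args, **kwargs):
    # Hand the call off to the first argument that implements the protocol.
    for a in relevant_args:
        if hasattr(type(a), '__torch_function__'):
            return type(a).__torch_function__(func, args, kwargs)
    raise TypeError("no __torch_function__ implementation found")

def split(tensor, parts):
    # The inline pattern: one cheap check on the hot path, no wrapper frame.
    if has_override((tensor,)):
        return handle_override(split, (tensor,), tensor, parts)
    # Toy stand-in for the real split kernel.
    return [tensor[i::parts] for i in range(parts)]

class MyTensor:
    @classmethod
    def __torch_function__(cls, func, args, kwargs):
        return f"MyTensor handled {func.__name__}"

print(split([1, 2, 3, 4], 2))  # normal path -> [[1, 3], [2, 4]]
print(split(MyTensor(), 2))    # override path -> 'MyTensor handled split'
```

Because the check is a plain `if` inside the function, Python skips the extra frame a decorator-based wrapper would add, which is where the measured overhead reduction comes from.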
facebook-github-bot pushed a commit that referenced this pull request on Feb 21, 2020
…onal (#32799)

Summary:
This adds `__torch_function__` support for all functions in `torch.functional` and `torch.nn.functional`.

The changes to C++ code and codegen scripts are to facilitate adding `__torch_function__` support for the native functions in `torch._C._nn`. Note that I moved the `handle_torch_function` C++ function to a header that both `python_torch_functions.cpp` and `python_nn_functions.cpp` include. The changes to `python_nn_functions.cpp` mirror the changes I made to `python_torch_functions.cpp` when `__torch_function__` support was first added in #27064. Due to the somewhat different way the `torch._C` and `torch._C._nn` namespaces are initialized, I needed to create a new static reference to the `torch._C._nn` namespace (`THPNNVariableFunctions`). I'm not sure if that is the best way to do this. In principle I could import these namespaces in each kernel and avoid the global variable, but that would have a runtime cost.

I added `__torch_function__` support to the Python functions in `torch.nn.functional` following the approach in #32194.

I re-enabled the test that checks if all functions in the `torch` namespace are explicitly tested for `__torch_function__` support, and generalized the check to work for `torch.functional` and `torch.nn.functional` as well. This test was explicitly disabled in #30730 and I'm happy to disable it again if you think that's appropriate; I figured now was as good a time as any to try to re-enable it.

Finally, I adjusted the existing torch API tests to suppress deprecation warnings and add keyword arguments used by some of the code in `torch.nn.functional` that were missed when I originally added the tests in #27064.

Pull Request resolved: #32799
Differential Revision: D19956809
Pulled By: ezyang
fbshipit-source-id: 40d34e0109cc4b9f3ef62f409d2d35a1d84e3d22
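The re-enabled coverage test described above works by enumerating a namespace's public callables and failing if any of them lacks an explicit `__torch_function__` test. A hedged, torch-free sketch of that idea, using the stdlib `math` module as a stand-in for a torch submodule and an illustrative `override_registry` name (not PyTorch's actual test code):

```python
# Torch-free sketch of a namespace coverage test: every public callable in a
# module must appear in a registry of functions that have an explicit
# __torch_function__ test. 'math' stands in for a torch submodule.
import math

def public_callables(module):
    # All public (non-underscore) callable attributes of the module.
    return {name for name in dir(module)
            if callable(getattr(module, name)) and not name.startswith('_')}

def untested(module, override_registry):
    # Functions missing from the registry: these would fail the coverage test.
    return sorted(public_callables(module) - set(override_registry))

# Simulate "forgetting" to register one function.
override_registry = public_callables(math) - {'hypot'}
missing = untested(math, override_registry)
print(missing)  # -> ['hypot']
```

Generalizing the check to additional namespaces then amounts to running the same assertion over each module in a list.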
ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request on Mar 4, 2020
reimplement `__torch_function__` overrides for `torch.functional` using inline logic (pytorch#32194)
ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request on Mar 4, 2020
…onal (pytorch#32799)
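The sub-100 ns overhead figures quoted in these commit messages come from operator micro-benchmarks. A torch-free sketch of how one might measure the cost of such an inline override check with the stdlib `timeit` module; the operation and helper names are illustrative, and the numbers you get depend entirely on your machine:

```python
# Torch-free sketch of measuring inline override-check overhead with timeit,
# in the spirit of the micro-benchmarks quoted above. Names are illustrative.
import timeit

def has_override(args):
    return any(hasattr(type(a), '__torch_function__') for a in args)

def op_plain(x):
    # Baseline: the operation with no dispatch check at all.
    return x * 2

def op_with_check(x):
    # Same operation behind an inline override check (never taken for ints).
    if has_override((x,)):
        raise NotImplementedError("override path, not exercised here")
    return x * 2

N = 100_000
plain = timeit.timeit(lambda: op_plain(3), number=N)
checked = timeit.timeit(lambda: op_with_check(3), number=N)

# Approximate per-call overhead of the inline check, in nanoseconds.
overhead_ns = (checked - plain) / N * 1e9
print(f"approximate per-call check overhead: {overhead_ns:.0f} ns")
```

Timing deltas this small are noisy, so repeated runs (or `timeit -r` style repeats taking the minimum) are needed before trusting any single number.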