
__torch_function__ overrides for torch.functional and torch.nn.functional #32799

Closed
ngoldbaum wants to merge 20 commits into pytorch:master from ngoldbaum:more-functional-overrides

Conversation

@ngoldbaum
Contributor

This adds __torch_function__ support for all functions in torch.functional and torch.nn.functional.

The changes to C++ code and codegen scripts are to facilitate adding __torch_function__ support for the native functions in torch._C._nn. Note that I moved the handle_torch_function C++ function to a header that both python_torch_functions.cpp and python_nn_functions.cpp include. The changes to python_nn_functions.cpp mirror the changes I made to python_torch_functions.cpp when __torch_function__ support was first added in #27064. Due to the somewhat different way the torch._C and torch._C._nn namespaces are initialized I needed to create a new static reference to the torch._C._nn namespace (THPNNVariableFunctions). I'm not sure if that is the best way to do this. In principle I could import these namespaces in each kernel and avoid the global variable but that would have a runtime cost.

I added __torch_function__ support to the Python functions in torch.nn.functional following the approach in #32194.

I re-enabled the test that checks if all functions in the torch namespace are explicitly tested for __torch_function__ support. I also generalized the check to work for torch.functional and torch.nn.functional as well. This test was explicitly disabled in #30730 and I'm happy to disable it again if you think that's appropriate. I figured now was as good a time as any to try to re-enable it.

Finally I adjusted the existing torch API tests to suppress deprecation warnings and add keyword arguments used by some of the code in torch.nn.functional that were missed when I originally added the tests in #27064.
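For readers unfamiliar with the protocol, here is a minimal, self-contained emulation of the dispatch pattern this PR adds to the Python functions. The names mirror PyTorch's `has_torch_function`/`handle_torch_function` helpers, but this sketch does not import torch and simplifies the real `__torch_function__` signature:

```python
class Tensor:
    """Stand-in for torch.Tensor in this self-contained sketch."""
    pass

def has_torch_function(operands):
    # True if any operand's type defines __torch_function__.
    return any(hasattr(type(t), "__torch_function__") for t in operands)

def handle_torch_function(func, operands, *args, **kwargs):
    # Dispatch to the first operand whose type provides an override.
    for t in operands:
        override = getattr(type(t), "__torch_function__", None)
        if override is not None:
            return override(func, args, kwargs)
    raise TypeError("no __torch_function__ implementation found")

def relu(input):
    # The guard added throughout torch.nn.functional: exact Tensor instances
    # short-circuit the check, so the common eager path pays almost nothing.
    if type(input) is not Tensor and has_torch_function((input,)):
        return handle_torch_function(relu, (input,), input)
    return "eager relu"

class MyTensor(Tensor):
    @classmethod
    def __torch_function__(cls, func, args, kwargs):
        return "overridden " + func.__name__

print(relu(Tensor()))    # eager relu
print(relu(MyTensor()))  # overridden relu
```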

@ngoldbaum ngoldbaum requested a review from ezyang January 29, 2020 23:42
@ngoldbaum ngoldbaum requested a review from apaszke as a code owner January 29, 2020 23:42
Contributor Author

Note that the API of this function changed after I moved it to this file. I'm sorry that makes this a little difficult to review!

I'm now passing in torch_api as a PyObject*. I'm also passing in the module name to facilitate the code that generates the error message below. Both of these changes facilitate using this function from both python_torch_functions.cpp and python_nn_functions.cpp.

There should be no other changes to this function.

@kostmo
Member

kostmo commented Jan 30, 2020

💊 CircleCI build failures summary and remediations

As of commit 39b204b:

  • 5/8 failures introduced in this PR
  • 3/8 recognized as flaky ❄️
    • Re-run these jobs?

Detailed failure analysis

One may explore the probable reasons each build failed interactively on the Dr. CI website.

🕵️ 5 new failures recognized by patterns

The following build failures do not appear to be due to upstream breakage:

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_test (1/5)

Step: "Test" (full log | pattern match details)

Feb 18 22:16:50 RuntimeError: test_overrides failed!
Feb 18 22:16:50   File "test_overrides.py", line 909, in <module> 
Feb 18 22:16:50     generate_tensor_like_torch_implementations() 
Feb 18 22:16:50   File "test_overrides.py", line 904, in generate_tensor_like_torch_implementations 
Feb 18 22:16:50     assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs)) 
Feb 18 22:16:50 AssertionError: ['torch.quantized_batch_norm'] is in IGNORED_TORCH_FUNCTIONS but still has an explicit override 
Feb 18 22:16:50 Traceback (most recent call last): 
Feb 18 22:16:50   File "test/run_test.py", line 486, in <module> 
Feb 18 22:16:50     main() 
Feb 18 22:16:50   File "test/run_test.py", line 479, in main 
Feb 18 22:16:50     raise RuntimeError(message) 
Feb 18 22:16:50 RuntimeError: test_overrides failed! 
Feb 18 22:16:50 + cleanup 
Feb 18 22:16:50 + retcode=1 
Feb 18 22:16:50 + set +x 
Feb 18 22:16:50 =================== sccache compilation log =================== 
Feb 18 22:16:50 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/tmp/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/tmp/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n                       ^\n" } 
Feb 18 22:16:50  
Feb 18 22:16:50 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Feb 18 22:16:50 Compile requests                 61 
Feb 18 22:16:50 Compile requests executed        28 
Feb 18 22:16:50 Cache hits                       21 

See CircleCI build pytorch_linux_xenial_cuda10_1_cudnn7_py3_slow_test (2/5)

Step: "Test" (full log | pattern match details)

Feb 18 22:23:32 RuntimeError: test_overrides failed!
Feb 18 22:23:32   File "test_overrides.py", line 909, in <module> 
Feb 18 22:23:32     generate_tensor_like_torch_implementations() 
Feb 18 22:23:32   File "test_overrides.py", line 904, in generate_tensor_like_torch_implementations 
Feb 18 22:23:32     assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs)) 
Feb 18 22:23:32 AssertionError: ['torch.quantized_batch_norm'] is in IGNORED_TORCH_FUNCTIONS but still has an explicit override 
Feb 18 22:23:32 Traceback (most recent call last): 
Feb 18 22:23:32   File "test/run_test.py", line 486, in <module> 
Feb 18 22:23:32     main() 
Feb 18 22:23:32   File "test/run_test.py", line 479, in main 
Feb 18 22:23:32     raise RuntimeError(message) 
Feb 18 22:23:32 RuntimeError: test_overrides failed! 
Feb 18 22:23:33 + cleanup 
Feb 18 22:23:33 + retcode=1 
Feb 18 22:23:33 + set +x 
Feb 18 22:23:33 =================== sccache compilation log =================== 
Feb 18 22:23:33 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/tmp/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/tmp/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n                       ^\n" } 
Feb 18 22:23:33  
Feb 18 22:23:33 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Feb 18 22:23:33 Compile requests               161 
Feb 18 22:23:33 Compile requests executed       57 
Feb 18 22:23:33 Cache hits                      44 

See CircleCI build pytorch_macos_10_13_py3_test (3/5)

Step: "Test" (full log | pattern match details)

Feb 18 14:28:22 RuntimeError: test_overrides failed!
Feb 18 14:28:22   File "test_overrides.py", line 909, in <module> 
Feb 18 14:28:22     generate_tensor_like_torch_implementations() 
Feb 18 14:28:22   File "test_overrides.py", line 904, in generate_tensor_like_torch_implementations 
Feb 18 14:28:22     assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs)) 
Feb 18 14:28:22 AssertionError: ['torch.quantized_batch_norm'] is in IGNORED_TORCH_FUNCTIONS but still has an explicit override 
Feb 18 14:28:22 Traceback (most recent call last): 
Feb 18 14:28:22   File "test/run_test.py", line 486, in <module> 
Feb 18 14:28:22     main() 
Feb 18 14:28:22   File "test/run_test.py", line 479, in main 
Feb 18 14:28:22     raise RuntimeError(message) 
Feb 18 14:28:22 RuntimeError: test_overrides failed! 
Feb 18 14:28:22 + cleanup 
Feb 18 14:28:22 + retcode=1 
Feb 18 14:28:22 + set +x 

See CircleCI build pytorch_linux_xenial_cuda10_1_cudnn7_py3_nogpu_test (4/5)

Step: "Test" (full log | pattern match details)

Feb 18 22:33:50 RuntimeError: test_overrides failed!
Feb 18 22:33:50   File "test_overrides.py", line 909, in <module> 
Feb 18 22:33:50     generate_tensor_like_torch_implementations() 
Feb 18 22:33:50   File "test_overrides.py", line 904, in generate_tensor_like_torch_implementations 
Feb 18 22:33:50     assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs)) 
Feb 18 22:33:50 AssertionError: ['torch.quantized_batch_norm'] is in IGNORED_TORCH_FUNCTIONS but still has an explicit override 
Feb 18 22:33:50 Traceback (most recent call last): 
Feb 18 22:33:50   File "test/run_test.py", line 486, in <module> 
Feb 18 22:33:50     main() 
Feb 18 22:33:50   File "test/run_test.py", line 479, in main 
Feb 18 22:33:50     raise RuntimeError(message) 
Feb 18 22:33:50 RuntimeError: test_overrides failed! 
Feb 18 22:33:50 + cleanup 
Feb 18 22:33:50 + retcode=1 
Feb 18 22:33:50 + set +x 
Feb 18 22:33:50 =================== sccache compilation log =================== 
Feb 18 22:33:50 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/tmp/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/tmp/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n                       ^\n" } 
Feb 18 22:33:50  
Feb 18 22:33:50 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Feb 18 22:33:50 Compile requests                 61 
Feb 18 22:33:50 Compile requests executed        28 
Feb 18 22:33:50 Cache hits                       21 

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test2 (5/5)

Step: "Test" (full log | pattern match details)

RuntimeError: test_overrides failed!
  File "test_overrides.py", line 909, in <module> 
    generate_tensor_like_torch_implementations() 
  File "test_overrides.py", line 904, in generate_tensor_like_torch_implementations 
    assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs)) 
AssertionError: ['torch.quantized_batch_norm'] is in IGNORED_TORCH_FUNCTIONS but still has an explicit override 
Traceback (most recent call last): 
  File "run_test.py", line 486, in <module> 
    main() 
  File "run_test.py", line 479, in main 
    raise RuntimeError(message) 
RuntimeError: test_overrides failed! 
 
(base) circleci@PACKER-5E29F737 C:\Users\circleci\project\test>if ERRORLEVEL 1 exit /b 1  
+ cleanup
+ retcode=1
+ set +x

❄️ 3 failures recognized as flaky

The following build failures have been detected as flaky and may not be your fault:

See CircleCI build pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_test (1/3)

Step: "Test" (full log | pattern match details) ❄️

Feb 18 22:50:37 ConnectionResetError: [Errno 104] Connection reset by peer
Feb 18 22:50:37   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 455, in accept 
Feb 18 22:50:37     deliver_challenge(c, self._authkey) 
Feb 18 22:50:37   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 722, in deliver_challenge 
Feb 18 22:50:37     response = connection.recv_bytes(256)        # reject large message 
Feb 18 22:50:37   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes 
Feb 18 22:50:37     buf = self._recv_bytes(maxlength) 
Feb 18 22:50:37   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes 
Feb 18 22:50:37     buf = self._recv(4) 
Feb 18 22:50:37   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 379, in _recv 
Feb 18 22:50:37     chunk = read(handle, remaining) 
Feb 18 22:50:37 ConnectionResetError: [Errno 104] Connection reset by peer 
Feb 18 22:50:37 /opt/conda/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 14 leaked semaphores to clean up at shutdown 
Feb 18 22:50:37   len(cache)) 
Feb 18 22:50:40 Process ErrorTrackingProcess-122: 
Feb 18 22:50:40 Traceback (most recent call last): 
Feb 18 22:50:40   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap 
Feb 18 22:50:40     self.run() 
Feb 18 22:50:40   File "/var/lib/jenkins/workspace/test/test_dataloader.py", line 333, in run 
Feb 18 22:50:40     super(ErrorTrackingProcess, self).run() 
Feb 18 22:50:40   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 93, in run 
Feb 18 22:50:40     self._target(*self._args, **self._kwargs) 

See CircleCI build pytorch_linux_xenial_cuda10_1_cudnn7_py3_NO_AVX_NO_AVX2_test (2/3)

Step: "Test" (full log | pattern match details) ❄️

Feb 18 22:51:14 ConnectionResetError: [Errno 104] Connection reset by peer
Feb 18 22:51:14   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 455, in accept 
Feb 18 22:51:14     deliver_challenge(c, self._authkey) 
Feb 18 22:51:14   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 722, in deliver_challenge 
Feb 18 22:51:14     response = connection.recv_bytes(256)        # reject large message 
Feb 18 22:51:14   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes 
Feb 18 22:51:14     buf = self._recv_bytes(maxlength) 
Feb 18 22:51:14   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes 
Feb 18 22:51:14     buf = self._recv(4) 
Feb 18 22:51:14   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 379, in _recv 
Feb 18 22:51:14     chunk = read(handle, remaining) 
Feb 18 22:51:14 ConnectionResetError: [Errno 104] Connection reset by peer 
Feb 18 22:51:14 /opt/conda/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 14 leaked semaphores to clean up at shutdown 
Feb 18 22:51:14   len(cache)) 
Feb 18 22:51:17 Process ErrorTrackingProcess-126: 
Feb 18 22:51:17 Traceback (most recent call last): 
Feb 18 22:51:17   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap 
Feb 18 22:51:17     self.run() 
Feb 18 22:51:17   File "/var/lib/jenkins/workspace/test/test_dataloader.py", line 333, in run 
Feb 18 22:51:17     super(ErrorTrackingProcess, self).run() 
Feb 18 22:51:17   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 93, in run 
Feb 18 22:51:17     self._target(*self._args, **self._kwargs) 

See CircleCI build pytorch_linux_xenial_cuda10_1_cudnn7_py3_NO_AVX2_test (3/3)

Step: "Test" (full log | pattern match details) ❄️

Feb 18 22:51:25 ConnectionResetError: [Errno 104] Connection reset by peer
Feb 18 22:51:25   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 455, in accept 
Feb 18 22:51:25     deliver_challenge(c, self._authkey) 
Feb 18 22:51:25   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 722, in deliver_challenge 
Feb 18 22:51:25     response = connection.recv_bytes(256)        # reject large message 
Feb 18 22:51:25   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes 
Feb 18 22:51:25     buf = self._recv_bytes(maxlength) 
Feb 18 22:51:25   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes 
Feb 18 22:51:25     buf = self._recv(4) 
Feb 18 22:51:25   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 379, in _recv 
Feb 18 22:51:25     chunk = read(handle, remaining) 
Feb 18 22:51:25 ConnectionResetError: [Errno 104] Connection reset by peer 
Feb 18 22:51:25 /opt/conda/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 14 leaked semaphores to clean up at shutdown 
Feb 18 22:51:25   len(cache)) 
Feb 18 22:51:28 Process ErrorTrackingProcess-122: 
Feb 18 22:51:28 Traceback (most recent call last): 
Feb 18 22:51:28   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap 
Feb 18 22:51:28     self.run() 
Feb 18 22:51:28   File "/var/lib/jenkins/workspace/test/test_dataloader.py", line 333, in run 
Feb 18 22:51:28     super(ErrorTrackingProcess, self).run() 
Feb 18 22:51:28   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 93, in run 
Feb 18 22:51:28     self._target(*self._args, **self._kwargs) 

This comment was automatically generated by Dr. CI.

@ngoldbaum ngoldbaum changed the title __torch_function__ overrides for torch.functional and torch.nn.functional [WIP] __torch_function__ overrides for torch.functional and torch.nn.functional Jan 30, 2020
@ezyang
Contributor

ezyang commented Jan 30, 2020

Nifty! Does this still count as WIP? It looks like some merging is in order.

@ngoldbaum
Contributor Author

I marked it as WIP last night after seeing the test failures, which I’ll work through today. Please feel free to review; I doubt fixing the test failures will change the structure of this code too much.

@ezyang ezyang requested a review from bhosmer January 30, 2020 20:51
@ezyang
Contributor

ezyang commented Jan 30, 2020

Adding @bhosmer for codegen changes

Contributor

oooof lol

Contributor

For my curiosity; how come in some of the adjusted code, we need to define tens_ops, and in other cases, we don't? Is this just a code reduction measure?

Contributor Author

@ngoldbaum ngoldbaum Jan 30, 2020

If there are multiple tensor operands, I create the tuple before calling type because I need to create the tuple no matter what. When there's only one tensor operand I don't create the tuple up front: it's only needed if type(input) ends up not being Tensor, so we don't pay the cost of creating the tuple unless we're passed non-tensor operands. I figure optimizing for tensor operands is what you'd prefer :)
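The two patterns being discussed look roughly like this, as a hedged standalone sketch (the function names are illustrative, and Tensor plus the two helpers are stand-ins rather than the real torch symbols):

```python
class Tensor:
    pass

def has_torch_function(ops):
    return any(hasattr(type(t), "__torch_function__") for t in ops)

def handle_torch_function(func, ops, *args, **kwargs):
    return "dispatched"

def unary_op(input):
    # One tensor operand: thanks to short-circuiting, the (input,) tuple is
    # only built after the exact-type check has already failed, so plain
    # Tensors never pay for the allocation.
    if type(input) is not Tensor and has_torch_function((input,)):
        return handle_torch_function(unary_op, (input,), input)
    return "eager"

def binary_op(input, weight):
    # Multiple tensor operands: the tuple is needed to run the type checks
    # at all, so it is built up front.
    tens_ops = (input, weight)
    if any(type(t) is not Tensor for t in tens_ops) and has_torch_function(tens_ops):
        return handle_torch_function(binary_op, tens_ops, input, weight)
    return "eager"

print(unary_op(Tensor()), binary_op(Tensor(), Tensor()))  # eager eager
```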

@ezyang
Contributor

ezyang commented Jan 30, 2020

Can you post the modified generated code, before and after? (A diff works fine too.)

@ezyang
Contributor

ezyang commented Jan 30, 2020

This all looks very reasonable. It'll need a rebase and some testfixing though.

@bhosmer bhosmer left a comment

Codegen changes LGTM, mod echoing @ezyang's codegen diff request :)

Also, apologies for the rebase in your future: a refactor of gen_python_functions.py just landed. The rebase should be conceptually simple but nontrivial. Definitely hit me up here if anything is unclear.

BTW, I notice that the existing setup doesn't generate the check into no-args bindings, was that intentional?

@ngoldbaum
Contributor Author

I pinged @ezyang about this in the Quansight channel on the PyTorch Slack but I figure someone else might chime in here. It looks like most of the test failures are because TorchScript doesn't like the code I've added to check for Tensor operands. For example:

Jan 30 01:27:22     RuntimeError: 
Jan 30 01:27:22     builtin cannot be used as a value:
Jan 30 01:27:22       File "/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py", line 1028
Jan 30 01:27:22         :class:`~torch.nn.ReLU` for more details.
Jan 30 01:27:22         """
Jan 30 01:27:22         if type(input) is not Tensor and has_torch_function((input,)):
Jan 30 01:27:22                               ~~~~~~ <--- HERE
Jan 30 01:27:22             return handle_torch_function(relu, (input,), input, inplace=inplace)
Jan 30 01:27:22         if inplace:
Jan 30 01:27:22     'relu' is being compiled since it was called from 'MyModule.forward'
Jan 30 01:27:22         def forward(self, input):
Jan 30 01:27:22           input = F.relu(self.conv1(input))
Jan 30 01:27:22           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
Jan 30 01:27:22           input = F.relu(self.conv2(input))
Jan 30 01:27:22           return input

I've been playing around with rephrasing that code in a way that appeases TorchScript but I don't see it. Is there a way I can disable TorchScript for a block of code? Or does this mean I'll need to touch the JIT code?

@bhosmer

bhosmer commented Jan 30, 2020

@ngoldbaum tangentially related to your last comment, worth noting that the JIT doesn't see the generated interception logic at all - the bindings generated by gen_python_functions are eager mode only.

Approach and timeline for teaching the JIT about the overrides is TBD, but in the meantime, thought it was worth noting the disparity.

@eellison
Contributor

@ngoldbaum if `if type(input) is not Tensor and has_torch_function((input,)):` could be rewritten as

if not isinstance(input, Tensor):
   if has_torch_function((input,)): 

then it would compile in TorchScript. But I guess that probably doesn't work, because the inputs that will have __torch_function__ probably inherit from Tensor.
Otherwise you can try:

@torch.jit.ignore
def has_torch_function(Tensor):
     ....

@torch.jit.ignore
def handle_torch_function(....):
     return handle_torch_function(relu, (input,), input, inplace=inplace)


if not torch.jit.is_scripting():
    if has_torch_function():
        return handle_torch_function(relu, (input,), input, inplace=inplace)

@ngoldbaum
Contributor Author

Yeah, we want to check if it's specifically not Tensor, subclasses might have __torch_function__ defined.

Thank you for the suggestions!
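The distinction being made here fits in a small standalone sketch (Tensor is a stand-in class, not torch.Tensor):

```python
class Tensor:
    pass

class SubTensor(Tensor):
    # A subclass that wants to intercept torch-level functions.
    @classmethod
    def __torch_function__(cls, func, args, kwargs):
        return "intercepted"

x = SubTensor()

# isinstance-based guard: a subclass still counts as a Tensor, so the
# override check would be skipped entirely.
isinstance_guard_fires = not isinstance(x, Tensor)

# exact-type guard: subclasses fail the identity check, so they fall
# through to the __torch_function__ dispatch path.
exact_type_guard_fires = type(x) is not Tensor

print(isinstance_guard_fires, exact_type_guard_fires)  # False True
```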

@ngoldbaum
Contributor Author

Hmm, that doesn't seem to work. It's still trying to compile the generator expression in torch.nn.functional.linear, for example:

E           torch.jit.frontend.UnsupportedNodeError: GeneratorExp aren't supported:
E             File "/home/goldbaum/pytorch/torch/nn/functional.py", line 1541
E               if not torch.jit.is_scripting():
E                   tens_ops = (input, weight)
E                   if any(type(t) is not Tensor for t in tens_ops) and has_torch_function(tens_ops):
E                          ~ <--- HERE
E                       return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
E               if input.dim() == 2 and bias is not None:

@eellison
Contributor

@ngoldbaum yeah, in that case you would have to factor out the any call into a torch.jit.ignore function. However, there may be an easier solution; I'll get back to you.

@eellison
Contributor

@ngoldbaum when #32871 lands you will be able to just put the block you don't want compiled under if not torch.jit.is_scripting() and not have to use torch.jit.ignore. The example above should work:

if not torch.jit.is_scripting():
    tens_ops = (input, weight)
    if any(type(t) is not Tensor for t in tens_ops) and has_torch_function(tens_ops):
        return handle_torch_function(linear, tens_ops, input, weight, bias=bias)

@ngoldbaum
Contributor Author

Excellent, thank you for the quick fix on that :)

@ngoldbaum
Contributor Author

ngoldbaum commented Feb 12, 2020

I've rebased and am on a branch based on a version of master after #32871, which I will push right now.

@ngoldbaum ngoldbaum force-pushed the more-functional-overrides branch from 631e4bf to 3aaf9b5 Compare February 12, 2020 17:29
@ngoldbaum
Contributor Author

ngoldbaum commented Feb 12, 2020

And it looks like I can trigger this in the test @eellison added with the following addition:

diff --git a/test/test_jit.py b/test/test_jit.py
index 89d4b9434c..37165c2051 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15668,6 +15668,7 @@ a")
             if torch.jit.is_scripting():
                 return 1
             else:
+                any(a for a in 'hello')
                 print("hello") + 2

         self.assertEqual(foo(), 1)

self = <test_jit.TestScript testMethod=test_is_scripting_metacompile>

    def test_is_scripting_metacompile(self):
>       @torch.jit.script
        def foo():

test/test_jit.py:15666:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
torch/jit/__init__.py:1304: in script
    ast = get_jit_def(obj)
torch/jit/frontend.py:171: in get_jit_def
    return build_def(ctx, py_ast.body[0], type_line, self_name)
torch/jit/frontend.py:212: in build_def
    build_stmts(ctx, body))
torch/jit/frontend.py:127: in build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
torch/jit/frontend.py:127: in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
torch/jit/frontend.py:187: in __call__
    return method(ctx, node)
torch/jit/frontend.py:368: in build_If
    build_stmts(ctx, stmt.orelse))
torch/jit/frontend.py:127: in build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
torch/jit/frontend.py:127: in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
torch/jit/frontend.py:187: in __call__
    return method(ctx, node)
torch/jit/frontend.py:290: in build_Expr
    return ExprStmt(build_expr(ctx, value))
torch/jit/frontend.py:187: in __call__
    return method(ctx, node)
torch/jit/frontend.py:457: in build_Call
    args = [build_expr(ctx, py_arg) for py_arg in expr.args]
torch/jit/frontend.py:457: in <listcomp>
    args = [build_expr(ctx, py_arg) for py_arg in expr.args]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <torch.jit.frontend.ExprBuilder object at 0x7f64d1ff67d0>
ctx = <torch.jit.frontend.SourceContext object at 0x7f63f42a2fb0>, node = <_ast.GeneratorExp object at 0x7f63f42a3f10>

    def __call__(self, ctx, node):
        method = getattr(self, 'build_' + node.__class__.__name__, None)
        if method is None:
>           raise UnsupportedNodeError(ctx, node)
E           torch.jit.frontend.UnsupportedNodeError: GeneratorExp aren't supported:
E             File "/home/goldbaum/pytorch/test/test_jit.py", line 15671
E                           return 1
E                       else:
E                           any(a for a in 'hello')
E                               ~ <--- HERE
E                           print("hello") + 2

torch/jit/frontend.py:186: UnsupportedNodeError

So it's still trying to compile code in the block even though it doesn't get evaluated at runtime.

@eellison
Contributor

And it looks like I can trigger this in the test @eellison added with the following addition:

diff --git a/test/test_jit.py b/test/test_jit.py
index 89d4b9434c..37165c2051 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15668,6 +15668,7 @@ a")
             if torch.jit.is_scripting():
                 return 1
             else:
+                any(a for a in 'hello')
                 print("hello") + 2

         self.assertEqual(foo(), 1)

So it's still trying to compile code in the block even though it doesn't get evaluated at runtime.

It's still trying to parse the code and create the AST, which is expected. Looks like you ran into an AST node that we don't parse yet. What exactly is the node that isn't supported?

@ngoldbaum
Contributor Author

In this case it's a generator expression I'd like to use any with. Thank you for the hint, I'll see if I can rephrase it in a way that's friendlier to the JIT.
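One JIT-friendlier rephrasing, as an illustrative standalone sketch (whether a particular construct compiles depends on the TorchScript frontend version, so treat this as an assumption): replace the generator expression inside `any` with an explicit loop that has the same semantics.

```python
tens_ops = (1, "a", 2.0)

# Generator expression form (the GeneratorExp node the frontend rejected):
gen_result = any(type(t) is not int for t in tens_ops)

# Equivalent explicit loop, which avoids the GeneratorExp node entirely
# and keeps the same short-circuit behavior:
def any_non_int(ops):
    for t in ops:
        if type(t) is not int:
            return True
    return False

print(gen_result, any_non_int(tens_ops))  # True True
```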

@ngoldbaum ngoldbaum force-pushed the more-functional-overrides branch from a5f7b18 to 44e6cb3 Compare February 12, 2020 23:34
@ngoldbaum ngoldbaum changed the title [WIP] __torch_function__ overrides for torch.functional and torch.nn.functional __torch_function__ overrides for torch.functional and torch.nn.functional Feb 13, 2020
@ngoldbaum ngoldbaum requested review from bhosmer and ezyang February 13, 2020 02:57
@ngoldbaum
Contributor Author

@ezyang this should be good for re-review. Here's a gist with the diffs you asked for: https://gist.github.com/ngoldbaum/ebe6a0ad1fb91a0c701dc43ffbd65aee

// PyType_GenericNew returns a new reference
THPVariableFunctionsModule = PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
// PyModule_AddObject steals a reference
if (PyModule_AddObject(module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
Contributor Author

I had to slightly adjust how _VariableFunctions is set up to make it match how _nn is set up in python_nn_functions.cpp. I think this is safe and should behave identically but it may be a risky change because I don't know why this namespace was originally set up as a type instead of an instance.

@ezyang
Contributor

ezyang commented Feb 18, 2020

Sorry this has taken a while, hoping this will get reviewed today

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngoldbaum
Contributor Author

Looks like the internal tests are failing.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngoldbaum
Contributor Author

yay all the tests are passing :)

@facebook-github-bot
Contributor

@ezyang merged this pull request in fa80299.

ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
…onal (pytorch#32799)

Summary:
This adds `__torch_function__` support for all functions in `torch.functional` and `torch.nn.functional`.

The changes to C++ code and codegen scripts are to facilitate adding `__torch_function__` support for the native functions in `torch._C._nn`. Note that I moved the `handle_torch_function` C++ function to a header that both `python_torch_functions.cpp` and `python_nn_functions.cpp` include. The changes to `python_nn_functions.cpp` mirror the changes I made to `python_torch_functions.cpp` when `__torch_function__` support was first added in pytorch#27064. Due to the somewhat different way the `torch._C` and `torch._C._nn` namespaces are initialized I needed to create a new static reference to the `torch._C._nn` namespace (`THPNNVariableFunctions`). I'm not sure if that is the best way to do this. In principle I could import these namespaces in each kernel and avoid the global variable but that would have a runtime cost.

I added `__torch_function__` support to the Python functions in `torch.nn.functional` following the approach in pytorch#32194.

I re-enabled the test that checks if all functions in the `torch` namespace are explicitly tested for `__torch_function__` support. I also generalized the check to work for `torch.functional` and `torch.nn.functional` as well. This test was explicitly disabled in pytorch#30730 and I'm happy to disable it again if you think that's appropriate. I figured now was as good a time as any to try to re-enable it.

Finally I adjusted the existing torch API tests to suppress deprecation warnings and add keyword arguments used by some of the code in `torch.nn.functional` that were missed when I originally added the tests in pytorch#27064.
Pull Request resolved: pytorch#32799

Differential Revision: D19956809

Pulled By: ezyang

fbshipit-source-id: 40d34e0109cc4b9f3ef62f409d2d35a1d84e3d22
torch.zeros,
torch.nn.functional.assert_int_or_pair,
torch.nn.functional.boolean_dispatch,
torch.nn.functional.division,
Contributor

Do you know why this is here? AFAICT, the only `division` in torch.nn.functional comes from `from __future__ import division`, which is probably not what you want.

Contributor

It's a blacklist: IGNORED_TORCH_FUNCTIONS. The test loops over dir() of the module and somehow division was being exported lol
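The mechanics are easy to reproduce outside torch: executing a `from __future__` import binds the feature object as a regular module attribute, which a dir()-based sweep then reports as if it were an exported function (a minimal demonstration, unrelated to torch itself):

```python
import types

mod = types.ModuleType("demo")
# A future statement at the top of executed source binds `division`
# (a __future__._Feature object) into the module namespace...
exec("from __future__ import division", mod.__dict__)

# ...so enumerating the module's attributes picks it up.
print("division" in dir(mod))  # True
```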
