
[JIT] make is_scripting a condvalue #32871

Closed

eellison wants to merge 1 commit into pytorch:master from eellison:is_scripting_metacompile

Conversation

@eellison
Contributor

Add torch.jit.is_scripting to the list of CondValues, i.e. values that, when used as the condition of an if statement, cause only one side of the if to be compiled. I'm not sure if we actually want this PR.

Pros:

  • Makes it easier to add features that are not yet supported in TorchScript (like has_torch_function)
  • The current idiom of checking torch.jit.is_scripting and factoring the block out into a function annotated with torch.jit.ignore is functionally equivalent but much more cumbersome

Cons:

  • Makes it easier to add features that are not yet supported in TorchScript
  • It may be confusing to a reader which code is being compiled. We could give it an all-caps name or otherwise rename it so it stands out visually.
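The "compile only one side of the if" behavior described above can be sketched in plain Python using the standard ast module. This is a toy stand-in for the TorchScript frontend, not its actual implementation; the names FoldIsScripting, SCRIPTING, and f are invented for illustration:

```python
import ast

# Toy illustration of a CondValue-style if: when the test is the known
# compile-time value `is_scripting()`, emit only the taken branch, so the
# other branch never needs to be compilable.

SCRIPTING = True  # compile-time value of is_scripting() in this sketch

class FoldIsScripting(ast.NodeTransformer):
    def visit_If(self, node: ast.If):
        self.generic_visit(node)
        t = node.test
        if (isinstance(t, ast.Call) and isinstance(t.func, ast.Name)
                and t.func.id == "is_scripting" and not t.args):
            # Keep only the taken side; the other side is dropped before
            # compilation, exactly like a one-sided CondValue if.
            return node.body if SCRIPTING else (node.orelse or [ast.Pass()])
        return node

SRC = """
def f(x):
    if is_scripting():
        return x
    else:
        return totally_unsupported_feature(x)
"""

tree = ast.fix_missing_locations(FoldIsScripting().visit(ast.parse(SRC)))
ns = {}
exec(compile(tree, "<sketch>", "exec"), ns)
print(ns["f"](3))  # prints 3; the else branch was never compiled
```

After the transform, `totally_unsupported_feature` does not appear in f's bytecode at all, which is the point of the pros above: the unsupported branch never has to pass the compiler.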

@eellison eellison requested a review from apaszke as a code owner January 31, 2020 17:57
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jan 31, 2020
Contributor

@driazati driazati left a comment


Accepting, since this doesn't really add any new behavior and looks like it makes more sense than the is_scripting plus @ignore pattern

Contributor

@zdevito zdevito left a comment


I think this is an improvement on what was there before. Please don't expand its use just because it is better, though; it is still a code smell whenever it appears, since it breaks the contract that script and Python behave the same.

    }
    return CondValue(
        emitToBool(emitExpr(expr)), RefinementSet({}), c10::nullopt);
    auto expr_out = emitToBool(emitExpr(expr));
Contributor


For robustness, I'd prefer this looked at the node before emitExpr to determine whether it is an is_scripting call. Doing it afterward is subject to arbitrary peephole optimizations that code emission might do and doesn't match the pattern in the rest of this function.

Contributor Author


The AST looks like:

    (apply
      (.
        (.
          (variable (ident torch))
          (ident jit))
        (ident is_scripting))
      (list)
      (list))

How should I pattern match then? Are we taking the stance that only torch.jit.is_scripting works, and not jit.is_scripting, or is_scripting, or

    from torch.jit import is_scripting as COMPILE_BLOCK
    if not COMPILE_BLOCK():
        ...

I'm fine with whatever. Up to this point we have only used CondValues with builtin reserved keywords, so this is a different case.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@Godricly

Is this is_scripting stuff working when tracing a module? I tried it with a simple demo but it is not working; only False is returned.

import torch

def nemo(x):
    return x + 1

class Demo(torch.nn.Module):
    def forward(self, x):
        print(torch.jit.is_scripting())
        if torch.jit.is_scripting():
            return x
        else:
            return nemo(x)

if __name__ == "__main__":
    # torch.range is deprecated; torch.arange gives an equivalent 1-D input
    input_tensor = torch.arange(20.0, 80.0)
    demo = Demo()
    out = demo(input_tensor)
    torch.onnx.export(demo, input_tensor, "debug.onnx", verbose=True,
                      input_names=['data'],
                      do_constant_folding=True,
                      dynamic_axes={'data': {0: 'batch'}})

My env is:

PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.15.3

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce GTX 1080
Nvidia driver version: 430.50
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.3

Versions of relevant libraries:
[pip3] numpy==1.17.2
[pip3] torch==1.4.0
[pip3] torchvision==0.5.0
[conda] Could not collect

@eellison
Contributor Author

Hi, this doesn't work with tracing; it should just return False when tracing. You can use torch.jit.is_tracing for that.

ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
Summary:
Add `torch.jit.is_scripting` to the list of CondValues, i.e. values that, when used as the condition of an if statement, cause only one side of the if to be compiled. I'm not sure if we actually want this PR.

Pros:
- Makes it easier to add features that are not yet supported in TorchScript (like has_torch_function)
- The current idiom of checking `torch.jit.is_scripting` and factoring the block out into a function annotated with `torch.jit.ignore` is functionally equivalent but much more cumbersome

Cons:
- Makes it easier to add features that are not yet supported in TorchScript
- It may be confusing to a reader which code is being compiled. We could give it an all-caps name or otherwise rename it so it stands out visually.
Pull Request resolved: pytorch#32871

Differential Revision: D19670383

Pulled By: eellison

fbshipit-source-id: 5257b0bd23c66f199d59a7f2c911e948301e5588
