
[JIT] JIT frontend uses len(op_token) before checking for None #25360

@HapeMask

Description


🐛 Bug

If you try to script a module that uses an unsupported unary op (in this case, Invert applied to a tensor, as in ~X), build_UnaryOp() calls len(op_token) before checking whether the token is None. Instead of reporting the actual error message, it fails with a TypeError:

...

", line 504, in build_UnaryOp                                                                                     
    r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(op_token))                             
TypeError: object of type 'NoneType' has no len()

If I add a check for None and use 0 instead of len(op_token) in that case, I get the real error, which says that Invert is not supported: torch.jit.frontend.NotSupportedError: unsupported unary operator: Invert
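
For reference, here is a minimal sketch of that workaround inside build_UnaryOp (the names op_token, ctx, and expr follow the traceback; the exact surrounding code in torch/jit/frontend.py may differ):

# Fall back to a zero-length range when the operator has no token, so the
# NotSupportedError raised a few lines later is reported instead of the
# TypeError from len(None).
op_len = 0 if op_token is None else len(op_token)
r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + op_len)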

Follow-up question: is that expected? If I replace ~ with th.bitwise_not(), it works just fine.
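
For completeness, a sketch of the scriptable variant of forward from the repro module below, with the inversion spelled explicitly:

def forward(self, x):
    # th.bitwise_not(...) compiles under torch.jit.script, while
    # ~self.bool_tensor currently trips the error above.
    return x[th.bitwise_not(self.bool_tensor)]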

To Reproduce

import torch as th

class TestMod(th.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("bool_tensor", th.zeros(3,).bool())

    def forward(self, x):
        return x[~self.bool_tensor]

mod = TestMod()
smod = th.jit.script(mod)

Run the above script.

Traceback (most recent call last):
  File "test.py", line 12, in <module>
    smod = th.jit.script(mod)
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/__init__.py", line 1161, in script
    return torch.jit.torch.jit._recursive.recursive_script(obj)                                                  
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/_recursive.py", line 133, in recursive_script
    stubs = list(map(make_stub, methods))
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/_recursive.py", line 131, in make_stub
    return torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))                
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/__init__.py", line 1226, in script_method
    ast = get_jit_def(fn, self_name="ScriptModule")
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/frontend.py", line 166, in get_jit_def
    return build_def(ctx, py_ast.body[0], type_line, self_name)                                                  
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/frontend.py", line 206, in build_def
    build_stmts(ctx, body))
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/frontend.py", line 182, in __call__
    return method(ctx, node)
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/frontend.py", line 298, in build_Return
    return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))                                
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/frontend.py", line 182, in __call__
    return method(ctx, node)
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/frontend.py", line 596, in build_Subscript
    return Subscript(base, [build_expr(ctx, expr.slice.value)])                                                  
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/frontend.py", line 182, in __call__
    return method(ctx, node)
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/frontend.py", line 503, in build_UnaryOp
    r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(op_token))                            
TypeError: object of type 'NoneType' has no len()

Expected behavior

", line 507, in build_UnaryOp                  
    raise NotSupportedError(err_range, "unsupported unary operator: " + op.__name__)                              
torch.jit.frontend.NotSupportedError: unsupported unary operator: Invert
:                           
at test.py:9:18                                                                                                   
    def forward(self, x):   
        return x[~self.bool_tensor]                                              
                 ~~~~~~~~~~ <--- HERE

This is what I get after applying the hack described above, and it seems correct.

Environment

  • PyTorch Version (e.g., 1.0): 1.3.0.dev20190816
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): conda (nightly)
  • Python version: 3.7.1
  • CUDA/cuDNN version: 10 / 7.6
  • GPU models and configuration: Titan RTX

cc @suo

Labels

oncall: jit, triaged
