torch.jit.trace returns unwrapped C type #20017

@VitalyFedyunin

🐛 Bug

With PyTorch 1.1.0:

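import torch

def f(x):
    return x * 2

# Trace a plain Python function with an example input.
z = torch.jit.trace(f, (torch.zeros(10),))
print(type(z))
# The save call that the traceback below points at.
torch.jit.save(z, "filename")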

The traced result is a torch._C.Function, which cannot be saved with torch.jit.save:

Traceback (most recent call last):
  File "1.py", line 7, in <module>
    torch.jit.save(z, "filename")
  File "/home/vitalyf/local/miniconda/envs/wp_1/lib/python3.7/site-packages/torch/jit/__init__.py", line 198, in save
    m.save(f, _extra_files=_extra_files)
AttributeError: 'torch._C.Function' object has no attribute 'save'

Expected behavior, as in PyTorch 1.0.1:

import torch
def f(x):
  return x*2
z = torch.jit.trace(f, (torch.zeros(10),))
print(type(z))

Returns <class 'torch.jit.TopLevelTracedModule'>
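To see which of the two behaviors an installed version exhibits without triggering the AttributeError, one option is to inspect the traced object directly. The following is only a minimal sketch: `describe_traced` is a hypothetical helper, not part of PyTorch. The `torch` import is guarded so the helper can also be run on a machine without PyTorch.

```python
try:
    import torch
except ImportError:  # torch is optional here; the helper itself is plain Python
    torch = None


def describe_traced(obj):
    """Return (fully qualified type name, whether obj has a callable .save).

    Hypothetical helper for this report, not a PyTorch API.
    """
    t = type(obj)
    type_name = f"{t.__module__}.{t.__qualname__}"
    can_save = callable(getattr(obj, "save", None))
    return type_name, can_save


if torch is not None:
    def f(x):
        return x * 2

    # Same trace call as in the report above.
    z = torch.jit.trace(f, (torch.zeros(10),))
    type_name, can_save = describe_traced(z)
    print(torch.__version__, type_name, "has .save" if can_save else "no .save")
```

Going by the traceback above, torch.jit.save calls `m.save(...)` on the object it is given, so a result without a callable `save` attribute will hit the same AttributeError.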

Labels

module: regression (It used to work, and now it doesn't)
oncall: jit (Add this issue/PR to JIT oncall triage queue)
