Skip to content

Commit 93db2b8

Browse files
suofacebook-github-bot
authored andcommitted
Fix type sharing on loaded ScriptModules (#29826)
Summary: Pull Request resolved: #29826 After save/load, we lose concrete type information. So if you tried to script something that contained a loaded ScriptModule as a submodule, the following sequence happened: 1. During ConcreteType inference, the loaded submodule got a new inferred type. 2. But it already has a type! So there was a type mismatch. To fix this, we should generate a ConcreteType directly from the loaded submodule type (similar to what we do for interfaces). This makes sense too--the ConcreteModuleType should be empty, since all the "sugaredness" was stripped out during the save/load process. Test Plan: Imported from OSS Differential Revision: D18575009 Pulled By: suo fbshipit-source-id: 4d329b7e9b7e7624f459e50092e35ab0ab813791
1 parent 558a777 commit 93db2b8

File tree

5 files changed

+34
-8
lines changed

5 files changed

+34
-8
lines changed

test/jit/test_recursive_script.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,3 +629,24 @@ def forward(self, x):
629629
dummies = nn.ModuleList([dummy])
630630
model = Model(dummies)
631631
self.checkModule(model, (torch.rand(5, 5), ))
632+
633+
def test_script_loaded_module(self):
634+
"""
635+
Test that we can hold a loaded ScriptModule as a submodule.
636+
"""
637+
class Dummy(nn.Module):
638+
def forward(self, x):
639+
return x
640+
641+
dummy = torch.jit.script(Dummy())
642+
dummy = self.getExportImportCopy(dummy)
643+
644+
class ContainsLoaded(torch.nn.Module):
645+
def __init__(self):
646+
super(ContainsLoaded, self).__init__()
647+
self.encoder = dummy
648+
649+
def forward(self, input):
650+
return self.encoder(input)
651+
652+
self.checkModule(ContainsLoaded(), (torch.rand(2, 3), ))

torch/csrc/jit/script/concrete_module_type.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,16 @@ ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
3535
return cls;
3636
}
3737

38-
std::shared_ptr<ConcreteModuleType> ConcreteModuleType::fromInterface(
39-
InterfaceTypePtr interface) {
40-
TORCH_INTERNAL_ASSERT(interface->is_module());
38+
std::shared_ptr<ConcreteModuleType> ConcreteModuleType::fromJitType(
39+
TypePtr type) {
40+
// `type` should either be a module interface or a class type
41+
if (auto interface = type->cast<InterfaceType>()){
42+
TORCH_INTERNAL_ASSERT(interface->is_module());
43+
} else {
44+
TORCH_INTERNAL_ASSERT(type->cast<ClassType>());
45+
}
4146
auto ret = std::shared_ptr<ConcreteModuleType>(new ConcreteModuleType());
42-
ret->jitType_ = std::move(interface);
47+
ret->jitType_ = std::move(type);
4348
return ret;
4449
}
4550

torch/csrc/jit/script/concrete_module_type.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
178178
public:
179179
explicit ConcreteModuleType(ConcreteModuleTypeBuilder data);
180180

181-
static std::shared_ptr<ConcreteModuleType> fromInterface(
182-
InterfaceTypePtr interface);
181+
static std::shared_ptr<ConcreteModuleType> fromJitType(TypePtr type);
183182

184183
TypePtr getJitType() const;
185184
py::object getPyClass() const;

torch/csrc/jit/script/init.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1082,7 +1082,7 @@ void initJitScriptBindings(PyObject* module) {
10821082
m, "ConcreteModuleType")
10831083
.def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
10841084
.def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
1085-
.def_static("from_interface", &ConcreteModuleType::fromInterface)
1085+
.def_static("from_jit_type", &ConcreteModuleType::fromJitType)
10861086
.def("get_constants", &ConcreteModuleType::getConstantsPy)
10871087
.def("get_attributes", &ConcreteModuleType::getAttributesPy)
10881088
.def("get_modules", &ConcreteModuleType::getModulesPy)

torch/jit/_recursive.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def infer_type(name, item):
115115
attr_type = infer_type(name, item)
116116
if attr_type is not None:
117117
# if the type can be inferred, it should be a module interface type
118-
sub_concrete_type = torch._C.ConcreteModuleType.from_interface(attr_type)
118+
sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(attr_type)
119119
else:
120120
# otherwise we get the concrete module type for item and add it to concrete_type
121121
sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
@@ -561,6 +561,7 @@ def wrap_cpp_module(cpp_module):
561561
def init_fn(script_module):
562562
for name, cpp_module in torch._C.ModuleDict(script_module._c).items():
563563
setattr(script_module, name, wrap_cpp_module(cpp_module))
564+
script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(script_module._c._type())
564565
return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
565566

566567
def compile_unbound_method(concrete_type, fn):

0 commit comments

Comments
 (0)