@@ -804,61 +804,6 @@ def run(f):
         traced_fn = torch.jit.trace(fn, torch.ones(1))
         self.assertEqual(run(fn), run(traced_fn))
 
-    def test_scopes(self):
-        x = torch.tensor([0.4], requires_grad=True)
-        y = torch.tensor([0.7], requires_grad=True)
-
-        def f(x, y):
-            out = x + y
-            with torch.jit.scope('Foo'):
-                out = x * out
-                with torch.jit.scope('Bar'):
-                    out = torch.tanh(out)
-                out = torch.sigmoid(out)
-            return out
-
-        self.checkTrace(f, (x, y))
-
-    def test_scopes_intermediate_node(self):
-        class Net(nn.Module):
-            def forward(self, x):
-                return F.log_softmax(x, dim=0)
-
-        net = Net()
-        t = torch.ones(2, requires_grad=True)
-        g, outputs, inputs = torch.jit._get_trace_graph(net, (t,), return_inputs=True)
-        self.assertEqual(outputs, self.createFunctionFromGraph(g)(*inputs))
-        self.assertExportImport(g, (t,))
-        g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX)
-        FileCheck().check("onnx::LogSoftmax").check("scope: Net").run(str(g))
-
-    def test_scopes_identity_node(self):
-
-        class Net(nn.Module):
-
-            def __init__(self):
-                super(Net, self).__init__()
-                self.features = nn.Sequential(
-                    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
-                    nn.ReLU(inplace=True),
-                    nn.MaxPool2d(kernel_size=3, stride=2),
-                )
-
-            def forward(self, x):
-                x = self.features(x)
-                return x
-
-        model = Net()
-
-        t = torch.ones(1, 3, 227, 227, requires_grad=True)
-
-        with torch.onnx.set_training(model, False):
-            g, _ = torch.jit._get_trace_graph(model, (t,))
-
-        self.assertExportImport(g, (t,) + tuple(model.parameters()))
-        g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX)
-        FileCheck().check("Net/Sequential[features]/Conv2d[0]").check("ReLU").check("MaxPool").run(str(g))
-
     def test_canonicalize_tensor_iterator(self):
         x = torch.randn(4, 4)
 
@@ -11949,9 +11894,12 @@ def forward(self, x):
 
         tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
 
-        FileCheck().check_not("value=<Tensor>").check_count("aten::mm", 2).check("aten::add") \
+        FileCheck().check_not("value=<Tensor>").check("aten::mm")\
+            .check("prim::CallMethod[name=\"forward\"]").check("aten::add") \
             .run(str(tm.graph))
 
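+        # Added note (not from the original commit): the submodule call is
+        # kept as a prim::CallMethod node in the caller's graph, so the
+        # callee's aten::mm is looked up in its own graph below.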
+        FileCheck().check("aten::mm").run(str(tm.mod.graph))
+
     @_tmp_donotuse_dont_inline_everything
     def test_call_traced_fn_from_traced_module(self):
         @_trace(torch.rand(3, 4))
@@ -11969,9 +11917,9 @@ def forward(self, x):
         tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
 
         # Note: neg op from the traced function should be properly inlined
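+        # Added note (not from the original commit): the tracer no longer
+        # records "scope: ..." annotations, so the checks below only look for
+        # the prim::CallFunction emitted for traced_fn.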
-        FileCheck().check("aten::mm").check_same("scope: TracedModule") \
+        FileCheck().check("aten::mm") \
             .check('name="traced_fn"') \
-            .check_next("prim::CallFunction").check("scope: TracedModule/traced_fn") \
+            .check_next("prim::CallFunction") \
             .run(str(tm.graph))
 
     def test_trace_hierarchy(self):
@@ -15015,6 +14963,84 @@ def forward(self, q, k, v):
                                 model.mod.out_proj.bias)[0]
         self.assertTrue(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4))
 
+    def test_trace_modulelist(self):
+        class MySubmod(torch.nn.Module):
+            def __init__(self):
+                super(MySubmod, self).__init__()
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                return self.relu(x)
+
+        class MyMod(torch.nn.Module):
+            def __init__(self):
+                super(MyMod, self).__init__()
+                self.ml = torch.nn.ModuleList([
+                    MySubmod(),
+                    MySubmod()
+                ])
+
+            def forward(self, x):
+                for mod in self.ml:
+                    x = mod(x)
+                return x
+
+        traced = torch.jit.trace(MyMod(), (torch.rand(3, 4),))
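+        # Suggested sanity check (not in the original commit): tracing unrolls
+        # the Python loop over the ModuleList, so the traced module should
+        # match eager execution on a fresh input (ReLU is stateless).
+        x = torch.rand(3, 4)
+        self.assertEqual(traced(x), MyMod()(x))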
+
+    def test_trace_fork_join_and_module(self):
+        class MySubmod(torch.nn.Module):
+            def __init__(self):
+                super(MySubmod, self).__init__()
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                return self.relu(x), torch.neg(x)
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super(Mod, self).__init__()
+                self.ml = torch.nn.ModuleList([
+                    MySubmod() for i in range(2)
+                ])
+
+            def forward(self, x):
+                futs = []
+                for i in range(2):
+                    futs.append(torch.jit._fork(self.ml[i], x))
+
+                results = []
+                for i in range(2):
+                    results.append(torch.jit._wait(futs[i])[0])
+
+                return torch.stack(results)
+
+        m = Mod()
+        traced = torch.jit.trace(m, torch.rand(3, 4))
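+        # Suggested sanity check (not in the original commit): _fork/_wait run
+        # eagerly outside of scripting, so the traced module should match the
+        # eager module on a fresh input.
+        x = torch.rand(3, 4)
+        self.assertEqual(traced(x), m(x))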
+
+    def test_trace_invert_module_hierarchy(self):
+        class MySubmod(torch.nn.Module):
+            def __init__(self):
+                super(MySubmod, self).__init__()
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                return self.relu(x), torch.neg(x)
+
+        class MyFunctionalMod(torch.nn.Module):
+            def forward(self, x, submod):
+                return submod(x)
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super(Mod, self).__init__()
+                self.sm = MySubmod()
+                self.fm = MyFunctionalMod()
+
+            def forward(self, x):
+                return self.fm(x, self.sm)
+
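+        # Added note (not from the original commit): MyFunctionalMod receives
+        # the submodule as a call argument, inverting the usual ownership
+        # hierarchy; the test only checks that tracing still succeeds.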
+        torch.jit.trace(Mod(), (torch.rand(3, 4),))
+
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_transformer_cuda(self):
 