Commit 309b28e

James Reed authored and facebook-github-bot committed
Trace module calls
Summary: Pull Request resolved: #29261
Test Plan: Imported from OSS
Differential Revision: D18343363
Pulled By: jamesr66a
fbshipit-source-id: 0c6394205e2c0ea8708028d20df83fe17b466ff4
1 parent 0f4b226 commit 309b28e

21 files changed: 771 additions, 248 deletions
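
As a reading aid, a minimal usage sketch (not part of the commit; module names are illustrative) of the behavior the updated test_jit.py expectations below encode: calls into traced submodules are now recorded as prim::CallMethod nodes in the caller's graph rather than being inlined, and each traced submodule keeps its own graph.

    import torch

    class Submod(torch.nn.Module):
        def forward(self, x):
            return torch.mm(x, x)

    class Parent(torch.nn.Module):
        def __init__(self):
            super(Parent, self).__init__()
            self.mod = Submod()

        def forward(self, x):
            return self.mod(x) + 1.0

    tm = torch.jit.trace(Parent(), torch.rand(3, 3))
    # The submodule call should appear as a call, not be flattened away:
    print(tm.graph)      # expected to show prim::CallMethod[name="forward"]
    print(tm.mod.graph)  # the traced submodule keeps its own graph (aten::mm)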

aten/src/ATen/core/interned_strings.h

Lines changed: 7 additions & 3 deletions
@@ -94,7 +94,7 @@ namespace c10 {
   _(prim, range) \
   _(prim, rangelist) \
   _(prim, isinstance) \
-  _(prim, unchecked_cast) \
+  _(prim, unchecked_cast) \
   _(aten, _grad_sum_to_size) \
   _(aten, _size_if_not_equal) \
   _(aten, _ncf_unsqueeze) \
@@ -118,7 +118,10 @@ namespace c10 {
   _(prim, CallFunction) \
   _(prim, CallMethod) \
   _(prim, LoopContinuation) \
-  _(prim, annotate) \
+  _(prim, annotate) \
+  _(prim, TracedModuleForward) \
+  _(prim, TracedFork) \
+  _(prim, TracedAttr) \
   _(aten, append) \
   _(aten, item) \
   _(aten, format) \
@@ -226,7 +229,8 @@ namespace c10 {
   _(attr, split) \
   _(attr, slot) \
   _(attr, kinds) \
-  _(attr, types)
+  _(attr, types) \
+  _(attr, scope)
 #else
 #define FORALL_NS_SYMBOLS(_) \
   _(namespaces, prim) \

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -400,6 +400,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/passes/decompose_ops.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/canonicalize_ops.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/erase_number_types.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/passes/fixup_trace_scope_blocks.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/inline_fork_wait.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/graph_fuser.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/guard_elimination.cpp

test/cpp/jit/test_alias_analysis.cpp

Lines changed: 0 additions & 88 deletions
@@ -391,94 +391,6 @@ void testAliasAnalysis() {
     AliasDb aliasDb(graph);
     ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(c->node(), if_));
   }
-  {
-    // test fork/wait
-
-    // a = rand(1)
-    // fut = fork(a)
-    //  Subgraph is: return a.add_(1)
-    // ... some unrelated code
-    // c = wait(b)
-    // d = a + a
-
-    auto graph = std::make_shared<Graph>();
-    auto constant = graph->insertConstant(1);
-    auto a = graph->insert(aten::rand, {constant});
-
-    auto forkNode = graph->insertNode(graph->create(prim::fork));
-    auto forkBlock = forkNode->addBlock();
-    {
-      WithInsertPoint g(forkBlock);
-      auto aMut = graph->insert(aten::add_, {a, constant});
-      forkBlock->registerOutput(aMut);
-      forkNode->output()->setType(FutureType::create(aMut->type()));
-    }
-    script::lambdaLiftFork(forkNode);
-
-    auto fut = forkNode->output();
-    auto wait = graph->insert(aten::wait, {fut})->node();
-    auto d = graph->insert(aten::add, {a, a});
-
-    graph->lint();
-
-    // Should not be able to move `d` before the wait call
-    AliasDb aliasDb(graph);
-    ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(d->node(), wait));
-  }
-  {
-    // test fork/wait in an if statement
-
-    // a = rand(1)
-    // if 1:
-    //   fut = fork(a)
-    //    Subgraph is: return a.add_(1)
-    // else:
-    //   fut = fork(a)
-    //    Subgraph is: return a.sub_(1)
-    // c = wait(b)
-    // d = a + a
-
-    auto graph = std::make_shared<Graph>();
-    auto constant = graph->insertConstant(1);
-    auto a = graph->insert(aten::rand, {constant});
-    auto if_ = insertIf(
-        *graph,
-        constant,
-        [&]() -> std::vector<Value*> {
-          auto forkNode = graph->insertNode(graph->create(prim::fork));
-          auto forkBlock = forkNode->addBlock();
-          {
-            WithInsertPoint g(forkBlock);
-            auto aMut = graph->insert(aten::add_, {a, constant});
-            forkBlock->registerOutput(aMut);
-            forkNode->output()->setType(FutureType::create(aMut->type()));
-          }
-          script::lambdaLiftFork(forkNode);
-          return {forkNode->output()};
-        },
-        [&]() -> std::vector<Value*> {
-          auto forkNode = graph->insertNode(graph->create(prim::fork));
-          auto forkBlock = forkNode->addBlock();
-          {
-            WithInsertPoint g(forkBlock);
-            auto aMut = graph->insert(aten::sub_, {a, constant});
-            forkBlock->registerOutput(aMut);
-            forkNode->output()->setType(FutureType::create(aMut->type()));
-          }
-          script::lambdaLiftFork(forkNode);
-          return {forkNode->output()};
-        });
-
-    auto fut = if_->output();
-    auto wait = graph->insert(aten::wait, {fut})->node();
-    auto d = graph->insert(aten::add, {a, a});
-
-    graph->lint();
-
-    // Should not be able to move `d` before the wait call
-    AliasDb aliasDb(graph);
-    ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(d->node(), wait));
-  }
 
   // test none value does not have writers
   {

test/test_jit.py

Lines changed: 84 additions & 58 deletions
@@ -804,61 +804,6 @@ def run(f):
         traced_fn = torch.jit.trace(fn, torch.ones(1))
         self.assertEqual(run(fn), run(traced_fn))
 
-    def test_scopes(self):
-        x = torch.tensor([0.4], requires_grad=True)
-        y = torch.tensor([0.7], requires_grad=True)
-
-        def f(x, y):
-            out = x + y
-            with torch.jit.scope('Foo'):
-                out = x * out
-                with torch.jit.scope('Bar'):
-                    out = torch.tanh(out)
-                out = torch.sigmoid(out)
-            return out
-
-        self.checkTrace(f, (x, y))
-
-    def test_scopes_intermediate_node(self):
-        class Net(nn.Module):
-            def forward(self, x):
-                return F.log_softmax(x, dim=0)
-
-        net = Net()
-        t = torch.ones(2, requires_grad=True)
-        g, outputs, inputs = torch.jit._get_trace_graph(net, (t,), return_inputs=True)
-        self.assertEqual(outputs, self.createFunctionFromGraph(g)(*inputs))
-        self.assertExportImport(g, (t,))
-        g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX)
-        FileCheck().check("onnx::LogSoftmax").check("scope: Net").run(str(g))
-
-    def test_scopes_identity_node(self):
-
-        class Net(nn.Module):
-
-            def __init__(self):
-                super(Net, self).__init__()
-                self.features = nn.Sequential(
-                    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
-                    nn.ReLU(inplace=True),
-                    nn.MaxPool2d(kernel_size=3, stride=2),
-                )
-
-            def forward(self, x):
-                x = self.features(x)
-                return x
-
-        model = Net()
-
-        t = torch.ones(1, 3, 227, 227, requires_grad=True)
-
-        with torch.onnx.set_training(model, False):
-            g, _ = torch.jit._get_trace_graph(model, (t,))
-
-        self.assertExportImport(g, (t,) + tuple(model.parameters()))
-        g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX)
-        FileCheck().check("Net/Sequential[features]/Conv2d[0]").check("ReLU").check("MaxPool").run(str(g))
-
     def test_canonicalize_tensor_iterator(self):
         x = torch.randn(4, 4)
 
@@ -11949,9 +11894,12 @@ def forward(self, x):
 
         tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
 
-        FileCheck().check_not("value=<Tensor>").check_count("aten::mm", 2).check("aten::add") \
+        FileCheck().check_not("value=<Tensor>").check("aten::mm")\
+            .check("prim::CallMethod[name=\"forward\"]").check("aten::add") \
             .run(str(tm.graph))
 
+        FileCheck().check("aten::mm").run(str(tm.mod.graph))
+
     @_tmp_donotuse_dont_inline_everything
     def test_call_traced_fn_from_traced_module(self):
         @_trace(torch.rand(3, 4))
@@ -11969,9 +11917,9 @@ def forward(self, x):
         tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
 
         # Note: neg op from the traced function should be properly inlined
-        FileCheck().check("aten::mm").check_same("scope: TracedModule") \
+        FileCheck().check("aten::mm") \
             .check('name="traced_fn"') \
-            .check_next("prim::CallFunction").check("scope: TracedModule/traced_fn") \
+            .check_next("prim::CallFunction") \
             .run(str(tm.graph))
 
     def test_trace_hierarchy(self):
@@ -15015,6 +14963,84 @@ def forward(self, q, k, v):
                                             model.mod.out_proj.bias)[0]
         self.assertTrue(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4))
 
+    def test_trace_modulelist(self):
+        class MySubmod(torch.nn.Module):
+            def __init__(self):
+                super(MySubmod, self).__init__()
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                return self.relu(x)
+
+        class MyMod(torch.nn.Module):
+            def __init__(self):
+                super(MyMod, self).__init__()
+                self.ml = torch.nn.ModuleList([
+                    MySubmod(),
+                    MySubmod()
+                ])
+
+            def forward(self, x):
+                for mod in self.ml:
+                    x = mod(x)
+                return x
+
+        traced = torch.jit.trace(MyMod(), (torch.rand(3, 4),))
+
+    def test_trace_fork_join_and_module(self):
+        class MySubmod(torch.nn.Module):
+            def __init__(self):
+                super(MySubmod, self).__init__()
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                return self.relu(x), torch.neg(x)
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super(Mod, self).__init__()
+                self.ml = torch.nn.ModuleList([
+                    MySubmod() for i in range(2)
+                ])
+
+            def forward(self, x):
+                futs = []
+                for i in range(2):
+                    futs.append(torch.jit._fork(self.ml[i], x))
+
+                results = []
+                for i in range(2):
+                    results.append(torch.jit._wait(futs[i])[0])
+
+                return torch.stack(results)
+
+        m = Mod()
+        traced = torch.jit.trace(m, torch.rand(3, 4))
+
+    def test_trace_invert_module_hierarchy(self):
+        class MySubmod(torch.nn.Module):
+            def __init__(self):
+                super(MySubmod, self).__init__()
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                return self.relu(x), torch.neg(x)
+
+        class MyFunctionalMod(torch.nn.Module):
+            def forward(self, x, submod):
+                return submod(x)
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super(Mod, self).__init__()
+                self.sm = MySubmod()
+                self.fm = MyFunctionalMod()
+
+            def forward(self, x):
+                return self.fm(x, self.sm)
+
+        torch.jit.trace(Mod(), (torch.rand(3, 4),))
+
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_transformer_cuda(self):
 
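
A hedged usage sketch (not from the commit; class names are illustrative) of the pattern the new test_trace_fork_join_and_module above exercises: forking a submodule call inside a traced forward is expected to trace end to end and agree with eager execution.

    import torch

    class Worker(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x)

    class Fanout(torch.nn.Module):
        def __init__(self):
            super(Fanout, self).__init__()
            self.w = Worker()

        def forward(self, x):
            fut = torch.jit._fork(self.w, x)  # submodule call recorded during tracing
            y = torch.relu(x)
            return y + torch.jit._wait(fut)

    m = Fanout()
    x = torch.rand(3, 4)
    traced = torch.jit.trace(m, (x,))
    # Traced and eager execution should agree.
    assert torch.allclose(traced(x), m(x))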

tools/build_variables.py

Lines changed: 1 addition & 0 deletions
@@ -115,6 +115,7 @@
     "torch/csrc/jit/passes/create_autodiff_subgraphs.cpp",
     "torch/csrc/jit/passes/dead_code_elimination.cpp",
     "torch/csrc/jit/passes/erase_number_types.cpp",
+    "torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp",
     "torch/csrc/jit/passes/graph_fuser.cpp",
     "torch/csrc/jit/passes/guard_elimination.cpp",
     "torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp",

torch/csrc/jit/init.cpp

Lines changed: 1 addition & 4 deletions
@@ -535,7 +535,7 @@ void initJITBindings(PyObject* module) {
 
         if (jit::tracer::isTracing()) {
           auto graph = jit::tracer::getTracingState()->graph;
-          auto fork_node = graph->insertNode(graph->create(prim::fork, 1));
+          auto fork_node = graph->insertNode(graph->create(prim::TracedFork, 1));
           auto body_block = fork_node->addBlock();
 
           Value* node_output;
@@ -557,9 +557,6 @@ void initJITBindings(PyObject* module) {
           body_block->registerOutput(out_val);
           node_output =
              fork_node->output()->setType(FutureType::create(out_val->type()));
-
-          // Lambda lift into a Subgraph attribute
-          torch::jit::script::lambdaLiftFork(fork_node);
         }
 
         auto retval =

torch/csrc/jit/ir.cpp

Lines changed: 16 additions & 1 deletion
@@ -585,7 +585,22 @@ void Graph::dump() const {
   std::cout << *this << "\n";
 }
 
-void LintGraph(std::shared_ptr<Graph>& graph) {
+void Graph::push_scope(const std::string& scope_name) {
+  current_scope_ = current_scope_->push(Symbol::scope(scope_name));
+  Node* block_node = insertNode(create(prim::TracedModuleForward, 0));
+  block_node->s_(attr::scope, scope_name);
+  Block* b = block_node->addBlock();
+  setInsertPoint(b);
+}
+void Graph::pop_scope() {
+  current_scope_ = current_scope_->parent();
+  if (insertPoint()->owningBlock()->owningNode()->kind() ==
+      prim::TracedModuleForward) {
+    setInsertPoint(insertPoint()->owningBlock()->owningNode()->next());
+  }
+}
+
+void LintGraph(const std::shared_ptr<Graph>& graph) {
   graph->lint();
 }
 
