Skip to content

Commit 24ff92f

Browse files
BowenBao authored
and facebook-github-bot committed
[ONNX] Redesign inplace conversion (#55033) (#56173)
Summary: Pull Request resolved: #56173 * Create `InplaceConverter` and `ValueTracker` to keep track of aliases of values throughout the graph. For a given value, a new alias is created every time when there is an inplace operation, SetAttr, or through nested blocks owned by If/Loop nodes. * Fix bug where controlflow node output types are not set, when the complete node is unable to run ONNX shape inference due to containing non-onnx node. * Add symbolic for `__not__` ~~and `prim_min`~~(update: moved to a separate PR), and update `index_put` opset9 to support case of assignment without providing indices. * Bump ORT version in CI test. Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D27866138 Pulled By: SplitInfinity fbshipit-source-id: ab5c9188740c50f783ceba4d54fda43c26e2fde7
1 parent 818ce1d commit 24ff92f

File tree

6 files changed

+874
-592
lines changed

6 files changed

+874
-592
lines changed

.jenkins/caffe2/test.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
170170
# JIT C++ extensions require ninja, so put it into PATH.
171171
export PATH="/var/lib/jenkins/.local/bin:$PATH"
172172
if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then
173-
pip install -q --user onnxruntime==1.6.0
173+
pip install -q --user onnxruntime==1.7.0
174174
fi
175175
"$ROOT_DIR/scripts/onnx/test.sh"
176176
fi

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 238 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4569,6 +4569,24 @@ def forward(self, x, y):
45694569
y = torch.randn(4, 5)
45704570
self.run_test(model, (x, y))
45714571

4572+
@skipIfUnsupportedMinOpsetVersion(14) # Need onnx::identity of sequence in opset 14
4573+
def test_list_append_nested_2(self):
4574+
class ListModel(torch.nn.Module):
4575+
def forward(self, x):
4576+
res = []
4577+
res_replicate = []
4578+
for i in range(x.size(0)):
4579+
if len(res) > 2:
4580+
for j in range(x.size(1)):
4581+
res.append(x[i][j])
4582+
res_replicate.append(res[-1])
4583+
res.append(res_replicate[-1])
4584+
return res, res_replicate
4585+
4586+
model = torch.jit.script(ListModel())
4587+
x = torch.randn(4, 4, 3, 4)
4588+
self.run_test(model, (x, ))
4589+
45724590
@skipIfUnsupportedMinOpsetVersion(11)
45734591
def test_list_pop(self):
45744592
class ListModel(torch.nn.Module):
@@ -4651,6 +4669,36 @@ def forward(self, x, y):
46514669
y = torch.randn(4, 5)
46524670
self.run_test(model, (x, y))
46534671

4672+
@skipIfUnsupportedMinOpsetVersion(11)
4673+
def test_list_set(self):
4674+
class ListModel(torch.nn.Module):
4675+
def forward(self, x, y):
4676+
res = []
4677+
for i in range(x.size(0)):
4678+
res.append(x[i])
4679+
res[y] = x[y]
4680+
return res
4681+
4682+
model = torch.jit.script(ListModel())
4683+
x = torch.randn(12, 4)
4684+
y = torch.tensor(2, dtype=torch.long)
4685+
self.run_test(model, (x, y))
4686+
4687+
@skipIfUnsupportedMinOpsetVersion(13)
4688+
def test_list_idx_sum(self):
4689+
class ListModel(torch.nn.Module):
4690+
def forward(self, x, y):
4691+
indices = torch.arange(x.size(0))
4692+
res = []
4693+
for i in range(x.size(0)):
4694+
res.append(x[i])
4695+
return res[torch.sum(indices[:y])]
4696+
4697+
model = torch.jit.script(ListModel())
4698+
x = torch.randn(12, 4)
4699+
y = torch.tensor(2, dtype=torch.long)
4700+
self.run_test(model, (x, y))
4701+
46544702
@skipIfUnsupportedMinOpsetVersion(9)
46554703
def test_tensor_factories(self):
46564704
class TensorFactory(torch.nn.Module):
@@ -4830,6 +4878,125 @@ def forward(self, x, y):
48304878
self.run_test(InplaceAddModel(), (x, y), rtol=1e-2, atol=1e-2)
48314879
self.run_test(InplaceMulModel(), (x, y), rtol=1e-2, atol=1e-2)
48324880

4881+
@skipIfUnsupportedMinOpsetVersion(9)
4882+
def test_inplace_with_loop(self):
4883+
class M(torch.nn.Module):
4884+
def forward(self, x):
4885+
a = torch.ones(12,)
4886+
for i in range(10):
4887+
a.add_(torch.ones(12,))
4888+
return a + x
4889+
4890+
m = M()
4891+
x = torch.randn(12,)
4892+
self.run_test(torch.jit.script(M()), (x))
4893+
4894+
@skipIfUnsupportedMinOpsetVersion(9)
4895+
def test_inplace_with_loop_2(self):
4896+
class M(torch.nn.Module):
4897+
def forward(self, x):
4898+
_bias = torch.ones(12,)
4899+
a = torch.ones(12,) # used in loop, altered.
4900+
a_ref = a # not used in loop, should be altered.
4901+
b = x.clone() # used in loop, not be altered.
4902+
b_ref = b # not used in loop, should not be altered.
4903+
for i in range(10):
4904+
if i == 3:
4905+
for j in range(5):
4906+
a += _bias
4907+
_bias.add_(torch.ones(12,))
4908+
b = b + torch.ones(12,)
4909+
4910+
_bias.add_(torch.ones(12,))
4911+
a += _bias
4912+
# TODO: value for a_ref is incorrect.
4913+
# a_ref += torch.ones(12,)
4914+
b_ref += torch.ones(12,)
4915+
return _bias + x, a, b, b_ref
4916+
4917+
m = M()
4918+
x = torch.zeros(12,)
4919+
self.run_test(torch.jit.script(M()), (x))
4920+
4921+
@skipIfUnsupportedMinOpsetVersion(11)
4922+
def test_inplace_attr_with_loop(self):
4923+
class M(torch.nn.Module):
4924+
def __init__(self):
4925+
super().__init__()
4926+
self._bias = torch.arange(12,)
4927+
4928+
def forward(self, x):
4929+
self._bias = torch.arange(12,)
4930+
for i in range(10):
4931+
if i == 3:
4932+
for j in range(5):
4933+
self._bias += torch.arange(12,)
4934+
return self._bias + x
4935+
4936+
m = M()
4937+
x = torch.zeros(12,)
4938+
self.run_test(torch.jit.script(M()), (x))
4939+
4940+
@skipIfUnsupportedMinOpsetVersion(11)
4941+
def test_inplace_attr_copy_with_loop(self):
4942+
class M(torch.nn.Module):
4943+
def __init__(self):
4944+
super().__init__()
4945+
self._bias = torch.arange(12,)
4946+
4947+
def forward(self, x):
4948+
self._bias = torch.arange(12,)
4949+
for i in range(10):
4950+
if i == 3:
4951+
for j in range(5):
4952+
self._bias.copy_(torch.arange(12,))
4953+
self._bias.copy_(self._bias + torch.arange(12,))
4954+
4955+
self._bias.copy_(self._bias + torch.arange(12,))
4956+
return self._bias + x
4957+
4958+
m = M()
4959+
x = torch.zeros(12,)
4960+
self.run_test(torch.jit.script(M()), (x))
4961+
4962+
@skipIfUnsupportedMinOpsetVersion(14) # Need onnx::identity of sequence in opset 14
4963+
def test_inplace_sequence_with_loop(self):
4964+
class M(torch.nn.Module):
4965+
def process(self, beam_hyps: List[torch.Tensor], done: torch.Tensor, x):
4966+
batch_size = x.shape[0]
4967+
for i in range(batch_size):
4968+
if done[i]:
4969+
continue
4970+
4971+
beam_idx = 0
4972+
for _, token in enumerate(x[i]):
4973+
beam_hyps.append(token)
4974+
beam_idx += 1
4975+
4976+
if beam_idx == 6:
4977+
break
4978+
4979+
done[i] = len(beam_hyps) > 4
4980+
4981+
return beam_hyps, done
4982+
4983+
def forward(self, x):
4984+
beam_hyps: List[torch.Tensor] = []
4985+
batch_size = x.shape[0]
4986+
cur_len = 0
4987+
max_len = x.shape[1]
4988+
done = torch.zeros(batch_size, dtype=torch.bool)
4989+
while cur_len < max_len:
4990+
beam_hyps, done = self.process(beam_hyps, done, x[:, 0, :])
4991+
cur_len = cur_len + 1
4992+
4993+
return beam_hyps
4994+
4995+
m = torch.jit.script(M())
4996+
x = torch.randn(8, 4, 3)
4997+
self.run_test(torch.jit.script(M()), (x))
4998+
4999+
48335000
@disableScriptTest() # Sort with dynamic dim not supported in ONNX
48345001
def test_sort(self):
48355002
class SortModel(torch.nn.Module):
@@ -7601,6 +7768,37 @@ def forward(self, feature_maps, anchors) -> Tuple[torch.Tensor, torch.Tensor]:
76017768
anchors = torch.ones(3, 10, 3)
76027769
self.run_test(model, (x, anchors))
76037770

7771+
@skipIfUnsupportedMinOpsetVersion(11)
7772+
def test_set_attr_5(self):
7773+
class MyModule(torch.nn.Module):
7774+
def __init__(self):
7775+
super(MyModule, self).__init__()
7776+
self.conv = torch.nn.Conv1d(10, 3, 3)
7777+
self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
7778+
7779+
def set_cell_anchors(self, anchors):
7780+
self.conv.weight = torch.arange(10)
7781+
for i in range(10):
7782+
if i == 3:
7783+
for j in range(10):
7784+
w = self.conv.weight
7785+
self.conv.weight = torch.arange(10) + w
7786+
7787+
self.conv.weight = self.conv.weight + torch.arange(10)
7788+
# NOTE: `is not None` and `assert` is for passing torchscript.
7789+
if self.conv.bias is not None:
7790+
a = self.conv.bias
7791+
assert a is not None
7792+
self.conv.bias = anchors + a
7793+
7794+
def forward(self, anchors):
7795+
self.set_cell_anchors(anchors)
7796+
return self.conv.weight, self.conv.bias
7797+
7798+
model = torch.jit.script(MyModule())
7799+
anchors = torch.ones(3, 10, 3)
7800+
self.run_test(model, (anchors))
7801+
76047802
@skipIfUnsupportedMinOpsetVersion(11)
76057803
def test_set_attr_in_loop(self):
76067804
class MyModule(torch.nn.Module):
@@ -7698,7 +7896,11 @@ def forward(self, input_data, prev_state):
76987896
model = Example(10)
76997897
random_data = torch.rand((1, 5, 30, 30))
77007898
empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
7701-
self.run_test(model, (random_data, empty_tensor))
7899+
random_state = torch.rand((1, 1, 10, 30, 30))
7900+
self.run_test(model, (random_data, empty_tensor),
7901+
input_names=['data', 'state'],
7902+
dynamic_axes={'state': [0, 1, 2, 3, 4]},
7903+
test_with_inputs=[(random_data, random_state)])
77027904

77037905
@skipIfUnsupportedMinOpsetVersion(11)
77047906
def test_index_put_if_3(self):
@@ -7768,6 +7970,41 @@ def forward(self, input_data, prev_state):
77687970
empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
77697971
self.run_test(model, (random_data, empty_tensor))
77707972

7973+
7974+
@skipIfUnsupportedMinOpsetVersion(11)
7975+
def test_index_put_if_5(self):
7976+
@torch.jit.script
7977+
def check_init(input_data, hidden_size, prev_state):
7978+
# type: (torch.Tensor, int, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
7979+
batch_size = input_data.size(0)
7980+
spatial_size_0 = input_data.size(2)
7981+
spatial_size_1 = input_data.size(3)
7982+
# generate empty prev_state, if None is provided
7983+
state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
7984+
state = torch.zeros(state_size, device=input_data.device)
7985+
state_ref = state
7986+
if prev_state.size(0) == 0:
7987+
state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 3
7988+
state = state + 3
7989+
state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 4
7990+
else:
7991+
state = state + 2
7992+
return state, state_ref
7993+
7994+
class Example(torch.nn.Module):
7995+
def __init__(self, hidden_size):
7996+
super().__init__()
7997+
self.hidden_size = hidden_size
7998+
7999+
def forward(self, input_data, prev_state):
8000+
prev_state, state_ref = check_init(input_data, self.hidden_size, prev_state)
8001+
return prev_state, state_ref
8002+
8003+
model = Example(4)
8004+
random_data = torch.rand((1, 5, 4, 4))
8005+
empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
8006+
self.run_test(model, (random_data, empty_tensor))
8007+
77718008
@skipIfUnsupportedMinOpsetVersion(11)
77728009
def test_list_append_in_block(self):
77738010
class ListModel(torch.nn.Module):

torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -236,6 +236,11 @@ std::vector<Value*> FixupONNXLoopNode(Node* node, int opset_version) {
236236
// NOTE: the output order is deliberately changed to match expected order
237237
// since onnx loop requires scan outputs to be the last outputs.
238238
auto new_outputs = ConvertSequenceDependencies(node, opset_version);
239+
240+
// Copy type of block output to node output.
241+
for (size_t i = 0; i < node->outputs().size(); ++i) {
242+
node->output(i)->setType(node->blocks().at(0)->outputs().at(i + 1)->type());
243+
}
239244
TORCH_INTERNAL_ASSERT(output_size == new_outputs.size());
240245
return new_outputs;
241246
}
@@ -375,6 +380,11 @@ std::vector<Value*> FixupONNXIfNode(Node* node, int opset_version) {
375380
auto* graph = if_node->owningGraph();
376381
FixupONNXSubblockOutputs(node);
377382
ONNXFixupUninitializedOutput(if_node);
383+
// Copy type of block output to node output.
384+
for (size_t i = 0; i < node->outputs().size(); ++i) {
385+
node->output(i)->setType(node->blocks().at(0)->outputs().at(i)->type());
386+
}
387+
378388
GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph());
379389
return if_node->outputs().vec();
380390
}

0 commit comments

Comments (0)