Skip to content

Commit 4f9fcd7

Browse files
jiashenCpytorchmergebot
authored andcommitted
Handle unpacking during TorchScript to ExportedProgram conversion (#127419)
Pull Request resolved: #127419 Approved by: https://github.com/angelayi
1 parent 9f2c4b9 commit 4f9fcd7

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

test/export/test_converter.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,23 @@ def forward(
247247
inp = (torch.randn(10, 10), torch.rand(10, 10))
248248
self._check_equal_ts_ep_converter(Module(), inp)
249249

250+
def test_ts2ep_converter_unpack(self):
251+
class MUnpackList(torch.nn.Module):
252+
def forward(self, x):
253+
x, y = torch.split(x, 2)
254+
return x + y
255+
256+
class MUnpackTuple(torch.nn.Module):
257+
def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]):
258+
x, y = x_tuple
259+
x = x.cos()
260+
return x + y
261+
262+
inp = torch.ones(1, 4)
263+
self._check_equal_ts_ep_converter(MUnpackList(), inp)
264+
inp = ((torch.zeros(1, 4), torch.ones(1, 4)),)
265+
self._check_equal_ts_ep_converter(MUnpackTuple(), inp)
266+
250267

251268
if __name__ == "__main__":
252269
run_tests()

torch/_export/converter.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,20 @@ def convert_prim_DictConstruct(self, node: torch._C.Node):
319319
output_name = node.output().debugName()
320320
self.name_to_node[output_name] = output_dict
321321

322+
def convert_prim_ListUnpack(self, node: torch._C.Node):
323+
self._convert_prim_unpack_iterator(node)
324+
325+
def convert_prim_TupleUnpack(self, node: torch._C.Node):
326+
self._convert_prim_unpack_iterator(node)
327+
328+
def _convert_prim_unpack_iterator(self, node: torch._C.Node):
329+
# Single input and multiple outputs for unpacking.
330+
for i, outp in enumerate(node.outputs()):
331+
outp_name = outp.debugName()
332+
inp = self.get_fx_value(node.input())
333+
fx_node = self.fx_graph.call_function(operator.getitem, (inp, i))
334+
self.name_to_node[outp_name] = fx_node
335+
322336
def convert_aten_Int(self, node: torch._C.Node):
323337
# converts aten::Int as aten._to_copy + aten::_local_scalar_dense
324338
target = torch.ops.aten._to_copy.default

0 commit comments

Comments
 (0)