@@ -4569,6 +4569,24 @@ def forward(self, x, y):
45694569 y = torch .randn (4 , 5 )
45704570 self .run_test (model , (x , y ))
45714571
4572+ @skipIfUnsupportedMinOpsetVersion (14 ) # Need onnx::identity of sequence in opset 14
4573+ def test_list_append_nested_2 (self ):
4574+ class ListModel (torch .nn .Module ):
4575+ def forward (self , x ):
4576+ res = []
4577+ res_replicate = []
4578+ for i in range (x .size (0 )):
4579+ if len (res ) > 2 :
4580+ for j in range (x .size (1 )):
4581+ res .append (x [i ][j ])
4582+ res_replicate .append (res [- 1 ])
4583+ res .append (res_replicate [- 1 ])
4584+ return res , res_replicate
4585+
4586+ model = torch .jit .script (ListModel ())
4587+ x = torch .randn (4 , 4 , 3 , 4 )
4588+ self .run_test (model , (x , ))
4589+
45724590 @skipIfUnsupportedMinOpsetVersion (11 )
45734591 def test_list_pop (self ):
45744592 class ListModel (torch .nn .Module ):
@@ -4651,6 +4669,36 @@ def forward(self, x, y):
46514669 y = torch .randn (4 , 5 )
46524670 self .run_test (model , (x , y ))
46534671
4672+ @skipIfUnsupportedMinOpsetVersion (11 )
4673+ def test_list_set (self ):
4674+ class ListModel (torch .nn .Module ):
4675+ def forward (self , x , y ):
4676+ res = []
4677+ for i in range (x .size (0 )):
4678+ res .append (x [i ])
4679+ res [y ] = x [y ]
4680+ return res
4681+
4682+ model = torch .jit .script (ListModel ())
4683+ x = torch .randn (12 , 4 )
4684+ y = torch .tensor (2 , dtype = torch .long )
4685+ self .run_test (model , (x , y ))
4686+
4687+ @skipIfUnsupportedMinOpsetVersion (13 )
4688+ def test_list_idx_sum (self ):
4689+ class ListModel (torch .nn .Module ):
4690+ def forward (self , x , y ):
4691+ indices = torch .arange (x .size (0 ))
4692+ res = []
4693+ for i in range (x .size (0 )):
4694+ res .append (x [i ])
4695+ return res [torch .sum (indices [:y ])]
4696+
4697+ model = torch .jit .script (ListModel ())
4698+ x = torch .randn (12 , 4 )
4699+ y = torch .tensor (2 , dtype = torch .long )
4700+ self .run_test (model , (x , y ))
4701+
46544702 @skipIfUnsupportedMinOpsetVersion (9 )
46554703 def test_tensor_factories (self ):
46564704 class TensorFactory (torch .nn .Module ):
@@ -4830,6 +4878,125 @@ def forward(self, x, y):
48304878 self .run_test (InplaceAddModel (), (x , y ), rtol = 1e-2 , atol = 1e-2 )
48314879 self .run_test (InplaceMulModel (), (x , y ), rtol = 1e-2 , atol = 1e-2 )
48324880
4881+ @skipIfUnsupportedMinOpsetVersion (9 )
4882+ def test_inplace_with_loop (self ):
4883+ class M (torch .nn .Module ):
4884+ def forward (self , x ):
4885+ a = torch .ones (12 ,)
4886+ for i in range (10 ):
4887+ a .add_ (torch .ones (12 ,))
4888+ return a + x
4889+
4890+ m = M ()
4891+ x = torch .randn (12 ,)
4892+ self .run_test (torch .jit .script (M ()), (x ))
4893+
4894+ @skipIfUnsupportedMinOpsetVersion (9 )
4895+ def test_inplace_with_loop_2 (self ):
4896+ class M (torch .nn .Module ):
4897+ def forward (self , x ):
4898+ _bias = torch .ones (12 ,)
4899+ a = torch .ones (12 ,) # used in loop, altered.
4900+ a_ref = a # not used in loop, should be altered.
4901+ b = x .clone () # used in loop, not be altered.
4902+ b_ref = b # not used in loop, should not be altered.
4903+ for i in range (10 ):
4904+ if i == 3 :
4905+ for j in range (5 ):
4906+ a += _bias
4907+ _bias .add_ (torch .ones (12 ,))
4908+ b = b + torch .ones (12 ,)
4909+
4910+ _bias .add_ (torch .ones (12 ,))
4911+ a += _bias
4912+ # TODO: value for a_ref is incorrect.
4913+ # a_ref += torch.ones(12,)
4914+ b_ref += torch .ones (12 ,)
4915+ return _bias + x , a , b , b_ref
4916+
4917+ m = M ()
4918+ x = torch .zeros (12 ,)
4919+ self .run_test (torch .jit .script (M ()), (x ))
4920+
4921+ @skipIfUnsupportedMinOpsetVersion (11 )
4922+ def test_inplace_attr_with_loop (self ):
4923+ class M (torch .nn .Module ):
4924+ def __init__ (self ):
4925+ super ().__init__ ()
4926+ self ._bias = torch .arange (12 ,)
4927+
4928+ def forward (self , x ):
4929+ self ._bias = torch .arange (12 ,)
4930+ for i in range (10 ):
4931+ if i == 3 :
4932+ for j in range (5 ):
4933+ self ._bias += torch .arange (12 ,)
4934+ return self ._bias + x
4935+
4936+ m = M ()
4937+ x = torch .zeros (12 ,)
4938+ self .run_test (torch .jit .script (M ()), (x ))
4939+
4940+ @skipIfUnsupportedMinOpsetVersion (11 )
4941+ def test_inplace_attr_copy_with_loop (self ):
4942+ class M (torch .nn .Module ):
4943+ def __init__ (self ):
4944+ super ().__init__ ()
4945+ self ._bias = torch .arange (12 ,)
4946+
4947+ def forward (self , x ):
4948+ self ._bias = torch .arange (12 ,)
4949+ for i in range (10 ):
4950+ if i == 3 :
4951+ for j in range (5 ):
4952+ self ._bias .copy_ (torch .arange (12 ,))
4953+ self ._bias .copy_ (self ._bias + torch .arange (12 ,))
4954+
4955+ self ._bias .copy_ (self ._bias + torch .arange (12 ,))
4956+ return self ._bias + x
4957+
4958+ m = M ()
4959+ x = torch .zeros (12 ,)
4960+ self .run_test (torch .jit .script (M ()), (x ))
4961+
4962+ @skipIfUnsupportedMinOpsetVersion (14 ) # Need onnx::identity of sequence in opset 14
4963+ def test_inplace_sequence_with_loop (self ):
4964+ class M (torch .nn .Module ):
4965+ def process (self , beam_hyps : List [torch .Tensor ], done : torch .Tensor , x ):
4966+ batch_size = x .shape [0 ]
4967+ for i in range (batch_size ):
4968+ if done [i ]:
4969+ continue
4970+
4971+ beam_idx = 0
4972+ for _ , token in enumerate (x [i ]):
4973+ beam_hyps .append (token )
4974+ beam_idx += 1
4975+
4976+ if beam_idx == 6 :
4977+ break
4978+
4979+ done [i ] = len (beam_hyps ) > 4
4980+
4981+ return beam_hyps , done
4982+
4983+ def forward (self , x ):
4984+ beam_hyps : List [torch .Tensor ] = []
4985+ batch_size = x .shape [0 ]
4986+ cur_len = 0
4987+ max_len = x .shape [1 ]
4988+ done = torch .zeros (batch_size , dtype = torch .bool )
4989+ while cur_len < max_len :
4990+ beam_hyps , done = self .process (beam_hyps , done , x [:, 0 , :])
4991+ cur_len = cur_len + 1
4992+
4993+ return beam_hyps
4994+
4995+ m = torch .jit .script (M ())
4996+ x = torch .randn (8 , 4 , 3 )
4997+ self .run_test (torch .jit .script (M ()), (x ))
4998+
4999+
48335000 @disableScriptTest () # Sort with dynamic dim not supported in ONNX
48345001 def test_sort (self ):
48355002 class SortModel (torch .nn .Module ):
@@ -7601,6 +7768,37 @@ def forward(self, feature_maps, anchors) -> Tuple[torch.Tensor, torch.Tensor]:
76017768 anchors = torch .ones (3 , 10 , 3 )
76027769 self .run_test (model , (x , anchors ))
76037770
7771+ @skipIfUnsupportedMinOpsetVersion (11 )
7772+ def test_set_attr_5 (self ):
7773+ class MyModule (torch .nn .Module ):
7774+ def __init__ (self ):
7775+ super (MyModule , self ).__init__ ()
7776+ self .conv = torch .nn .Conv1d (10 , 3 , 3 )
7777+ self .conv .bias = torch .nn .Parameter (torch .zeros (3 , 10 , 3 ))
7778+
7779+ def set_cell_anchors (self , anchors ):
7780+ self .conv .weight = torch .arange (10 )
7781+ for i in range (10 ):
7782+ if i == 3 :
7783+ for j in range (10 ):
7784+ w = self .conv .weight
7785+ self .conv .weight = torch .arange (10 ) + w
7786+
7787+ self .conv .weight = self .conv .weight + torch .arange (10 )
7788+ # NOTE: `is not None` and `assert` is for passing torchscript.
7789+ if self .conv .bias is not None :
7790+ a = self .conv .bias
7791+ assert a is not None
7792+ self .conv .bias = anchors + a
7793+
7794+ def forward (self , anchors ):
7795+ self .set_cell_anchors (anchors )
7796+ return self .conv .weight , self .conv .bias
7797+
7798+ model = torch .jit .script (MyModule ())
7799+ anchors = torch .ones (3 , 10 , 3 )
7800+ self .run_test (model , (anchors ))
7801+
76047802 @skipIfUnsupportedMinOpsetVersion (11 )
76057803 def test_set_attr_in_loop (self ):
76067804 class MyModule (torch .nn .Module ):
@@ -7698,7 +7896,11 @@ def forward(self, input_data, prev_state):
76987896 model = Example (10 )
76997897 random_data = torch .rand ((1 , 5 , 30 , 30 ))
77007898 empty_tensor = torch .tensor ([], dtype = torch .float ).view (0 , 0 , 0 , 0 , 0 )
7701- self .run_test (model , (random_data , empty_tensor ))
7899+ random_state = torch .rand ((1 , 1 , 10 , 30 , 30 ))
7900+ self .run_test (model , (random_data , empty_tensor ),
7901+ input_names = ['data' , 'state' ],
7902+ dynamic_axes = {'state' : [0 , 1 , 2 , 3 , 4 ]},
7903+ test_with_inputs = [(random_data , random_state )])
77027904
77037905 @skipIfUnsupportedMinOpsetVersion (11 )
77047906 def test_index_put_if_3 (self ):
@@ -7768,6 +7970,41 @@ def forward(self, input_data, prev_state):
77687970 empty_tensor = torch .tensor ([], dtype = torch .float ).view (0 , 0 , 0 , 0 , 0 )
77697971 self .run_test (model , (random_data , empty_tensor ))
77707972
7973+
7974+ @skipIfUnsupportedMinOpsetVersion (11 )
7975+ def test_index_put_if_5 (self ):
7976+ @torch .jit .script
7977+ def check_init (input_data , hidden_size , prev_state ):
7978+ # type: (torch.Tensor, int, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
7979+ batch_size = input_data .size (0 )
7980+ spatial_size_0 = input_data .size (2 )
7981+ spatial_size_1 = input_data .size (3 )
7982+ # generate empty prev_state, if None is provided
7983+ state_size = (2 , batch_size , hidden_size , spatial_size_0 , spatial_size_1 )
7984+ state = torch .zeros (state_size , device = input_data .device )
7985+ state_ref = state
7986+ if prev_state .size (0 ) == 0 :
7987+ state [:] = torch .ones (batch_size , hidden_size , spatial_size_0 , spatial_size_1 ) * 3
7988+ state = state + 3
7989+ state [:] = torch .ones (batch_size , hidden_size , spatial_size_0 , spatial_size_1 ) * 4
7990+ else :
7991+ state = state + 2
7992+ return state , state_ref
7993+
7994+ class Example (torch .nn .Module ):
7995+ def __init__ (self , hidden_size ):
7996+ super ().__init__ ()
7997+ self .hidden_size = hidden_size
7998+
7999+ def forward (self , input_data , prev_state ):
8000+ prev_state , state_ref = check_init (input_data , self .hidden_size , prev_state )
8001+ return prev_state , state_ref
8002+
8003+ model = Example (4 )
8004+ random_data = torch .rand ((1 , 5 , 4 , 4 ))
8005+ empty_tensor = torch .tensor ([], dtype = torch .float ).view (0 , 0 , 0 , 0 , 0 )
8006+ self .run_test (model , (random_data , empty_tensor ))
8007+
77718008 @skipIfUnsupportedMinOpsetVersion (11 )
77728009 def test_list_append_in_block (self ):
77738010 class ListModel (torch .nn .Module ):
0 commit comments