@@ -1356,6 +1356,16 @@ def test_state_dict_deterministic(self, device, dtype, optim_info):
         input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
         params = [weight, bias]
 
+        def make_param_and_named_param(param):
+            named_param = [(f'name{i}', p) for i, p in enumerate(param)]
+            return [param, named_param]
+
+        def without_param_names(state_dict):
+            new_state_dict = deepcopy(state_dict)
+            for pg in new_state_dict['param_groups']:
+                pg.pop('param_names', None)
+            return new_state_dict
+
         def fwd_bwd(optim, w, b, i):
             optim.zero_grad()
             loss = (w.mv(i) + b).pow(2).sum()
@@ -1368,51 +1378,55 @@ def fwd_bwd(optim, w, b, i):
             return loss
 
         for optim_input in all_optim_inputs:
-            optimizer = optim_cls(params, **optim_input.kwargs)
-            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
-
-            # Prime the optimizer
-            for _ in range(10):
-                if optim_info.step_requires_closure:
-                    optimizer.step(closure)
-                else:
-                    closure()
-                    optimizer.step()
+            for param_in in make_param_and_named_param(params):
+                optimizer = optim_cls(param_in, **optim_input.kwargs)
+                closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
 
-            # Clone the weights and construct a new optimizer for them
-            with torch.no_grad():
-                weight_c = Parameter(weight.clone())
-                bias_c = Parameter(bias.clone())
-
-            optimizer_c = optim_cls([weight_c, bias_c], **optim_input.kwargs)
-            closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input)
-
-            # Load the state dict from the original optimizer into the new one
-            optimizer_c.load_state_dict(deepcopy(optimizer.state_dict()))
+                # Prime the optimizer
+                for _ in range(10):
+                    if optim_info.step_requires_closure:
+                        optimizer.step(closure)
+                    else:
+                        closure()
+                        optimizer.step()
 
-            # Run both optimizers in parallel
-            for _ in range(10):
-                if optim_info.step_requires_closure:
-                    optimizer.step(closure)
-                    optimizer_c.step(closure_c)
-                else:
-                    closure()
-                    closure_c()
-                    optimizer.step()
-                    optimizer_c.step()
+                for param_c_index in range(2):
+                    # Clone the weights and construct a new optimizer for them
+                    with torch.no_grad():
+                        weight_c = Parameter(weight.clone())
+                        bias_c = Parameter(bias.clone())
+                    param_c = make_param_and_named_param([weight_c, bias_c])[param_c_index]
+                    optimizer_c = optim_cls(param_c, **optim_input.kwargs)
+                    closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input)
+
+                    # Load the state dict from the original optimizer into the new one
+                    optimizer_c.load_state_dict(deepcopy(optimizer.state_dict()))
+
+                    # Run both optimizers in parallel
+                    for _ in range(10):
+                        if optim_info.step_requires_closure:
+                            optimizer.step(closure)
+                            optimizer_c.step(closure_c)
+                        else:
+                            closure()
+                            closure_c()
+                            optimizer.step()
+                            optimizer_c.step()
 
-            self.assertEqual(weight, weight_c)
-            self.assertEqual(bias, bias_c)
+                    self.assertEqual(weight, weight_c)
+                    self.assertEqual(bias, bias_c)
 
-            # Make sure state dict is deterministic with equal (not identical) parameters
-            self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
+                    # Make sure state dict is deterministic with equal (not identical) parameters
+                    # Param names are optional and need not be consistent between the two optimizers.
+                    self.assertEqual(without_param_names(optimizer.state_dict()),
+                                     without_param_names(optimizer_c.state_dict()))
 
-            # Make sure repeated parameters have identical representation (see #36831)
-            optimizer_c.param_groups.extend(optimizer_c.param_groups)
-            self.assertEqual(
-                optimizer.state_dict()["param_groups"][-1],
-                optimizer_c.state_dict()["param_groups"][-1],
-            )
+                    # Make sure repeated parameters have identical representation (see #36831)
+                    optimizer_c.param_groups.extend(optimizer_c.param_groups)
+                    self.assertEqual(
+                        without_param_names(optimizer.state_dict())["param_groups"][-1],
+                        without_param_names(optimizer_c.state_dict())["param_groups"][-1],
+                    )
 
     @optims(optim_db, dtypes=[torch.float32])
     def test_can_load_older_state_dict(self, device, dtype, optim_info):
@@ -1538,6 +1552,10 @@ def test_save_load_equality_with_weights_only(self, device, dtype, optim_info):
         input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
         params = [weight, bias]
 
+        def make_param_and_named_param(param):
+            named_param = [(f'name{i}', p) for i, p in enumerate(param)]
+            return [param, named_param]
+
         def fwd_bwd(optim, w, b, i):
             optim.zero_grad()
             loss = (w.mv(i) + b).pow(2).sum()
@@ -1548,25 +1566,26 @@ def fwd_bwd(optim, w, b, i):
             return loss
 
         for optim_input in all_optim_inputs:
-            optimizer = optim_cls(params, **optim_input.kwargs)
-            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
+            for params_in in make_param_and_named_param(params):
+                optimizer = optim_cls(params_in, **optim_input.kwargs)
+                closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
 
-            # Prime the optimizer
-            for _ in range(3):
-                optimizer.step(closure)
+                # Prime the optimizer
+                for _ in range(3):
+                    optimizer.step(closure)
 
-            sd = optimizer.state_dict()
-
-            # === Check saved/loaded state_dict are the same (including weights_only load). ===
-            with tempfile.TemporaryFile() as f:
-                torch.save(sd, f)
-                f.seek(0)
-                sd_copy = torch.load(f)
-                self.assertEqual(sd_copy, sd)
-                del sd_copy
-                f.seek(0)
-                sd_copy_wo = torch.load(f, weights_only=True)
-                self.assertEqual(sd_copy_wo, sd)
+                sd = optimizer.state_dict()
+
+                # === Check saved/loaded state_dict are the same (including weights_only load). ===
+                with tempfile.TemporaryFile() as f:
+                    torch.save(sd, f)
+                    f.seek(0)
+                    sd_copy = torch.load(f)
+                    self.assertEqual(sd_copy, sd)
+                    del sd_copy
+                    f.seek(0)
+                    sd_copy_wo = torch.load(f, weights_only=True)
+                    self.assertEqual(sd_copy_wo, sd)
 
     @optims(optim_db, dtypes=[torch.float32])
     def test_load_nontensor_step(self, device, dtype, optim_info):
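For context, the new `make_param_and_named_param` helper runs each test body twice, once with bare tensors and once with `(name, param)` tuples, while `without_param_names` drops the optional `param_names` entry before comparing state dicts. Below is a minimal sketch of the behavior these tests exercise, assuming a PyTorch build where optimizer constructors accept named parameters; the choice of `SGD` and the literal names are illustrative, not part of the test.

```python
import torch
from torch.nn import Parameter

weight = Parameter(torch.randn(2, 3))
bias = Parameter(torch.randn(2))

# Same shape of input the tests build via make_param_and_named_param:
# a list of (name, tensor) tuples instead of bare tensors.
named_params = [("name0", weight), ("name1", bias)]
opt = torch.optim.SGD(named_params, lr=0.1)  # SGD chosen arbitrarily for the sketch

sd = opt.state_dict()
# When constructed from named parameters, each param group is expected to carry
# an optional "param_names" entry, which without_param_names() strips before
# comparing against an optimizer built from unnamed parameters.
print(sd["param_groups"][0].get("param_names"))  # e.g. ['name0', 'name1']
```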