Commit 3f0abc3

Adding BC tests: same results with named_parameters, and state_dict save-load equality
1 parent 18afbe2 commit 3f0abc3
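
For context, both tests touched below exercise the named-parameter form of the optimizer constructor: a param group may be passed as a list of (name, tensor) pairs, and each entry of state_dict()['param_groups'] then carries an extra 'param_names' list. A minimal sketch of the equivalence the first test relies on; SGD and the 'name0'/'name1' labels are illustrative (they mirror the test helper), and the named-parameter constructor support is assumed to be available on this branch:

import torch
from copy import deepcopy

w = torch.nn.Parameter(torch.randn(10, 3))
b = torch.nn.Parameter(torch.randn(10))

plain = [w, b]                              # ordinary parameter list
named = [("name0", w), ("name1", b)]        # same params, with illustrative names

opt_plain = torch.optim.SGD(plain, lr=0.1)
opt_named = torch.optim.SGD(named, lr=0.1)  # assumes the named-param support under test

sd_plain = opt_plain.state_dict()
sd_named = deepcopy(opt_named.state_dict())

# The only expected difference is the optional 'param_names' entry per group.
for pg in sd_named["param_groups"]:
    pg.pop("param_names", None)
assert sd_plain == sd_named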

1 file changed

test/test_optim.py

Lines changed: 76 additions & 57 deletions
@@ -1356,6 +1356,16 @@ def test_state_dict_deterministic(self, device, dtype, optim_info):
         input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
         params = [weight, bias]

+        def make_param_and_named_param(param):
+            named_param = [(f'name{i}', p) for i, p in enumerate(param)]
+            return [param, named_param]
+
+        def without_param_names(state_dict):
+            new_state_dict = deepcopy(state_dict)
+            for pg in new_state_dict['param_groups']:
+                pg.pop('param_names', None)
+            return new_state_dict
+
         def fwd_bwd(optim, w, b, i):
             optim.zero_grad()
             loss = (w.mv(i) + b).pow(2).sum()
@@ -1368,51 +1378,55 @@ def fwd_bwd(optim, w, b, i):
             return loss

         for optim_input in all_optim_inputs:
-            optimizer = optim_cls(params, **optim_input.kwargs)
-            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
-
-            # Prime the optimizer
-            for _ in range(10):
-                if optim_info.step_requires_closure:
-                    optimizer.step(closure)
-                else:
-                    closure()
-                    optimizer.step()
+            for param_in in make_param_and_named_param(params):
+                optimizer = optim_cls(param_in, **optim_input.kwargs)
+                closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)

-            # Clone the weights and construct a new optimizer for them
-            with torch.no_grad():
-                weight_c = Parameter(weight.clone())
-                bias_c = Parameter(bias.clone())
-
-            optimizer_c = optim_cls([weight_c, bias_c], **optim_input.kwargs)
-            closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input)
-
-            # Load the state dict from the original optimizer into the new one
-            optimizer_c.load_state_dict(deepcopy(optimizer.state_dict()))
+                # Prime the optimizer
+                for _ in range(10):
+                    if optim_info.step_requires_closure:
+                        optimizer.step(closure)
+                    else:
+                        closure()
+                        optimizer.step()

-            # Run both optimizers in parallel
-            for _ in range(10):
-                if optim_info.step_requires_closure:
-                    optimizer.step(closure)
-                    optimizer_c.step(closure_c)
-                else:
-                    closure()
-                    closure_c()
-                    optimizer.step()
-                    optimizer_c.step()
+                for param_c_index in range(2):
+                    # Clone the weights and construct a new optimizer for them
+                    with torch.no_grad():
+                        weight_c = Parameter(weight.clone())
+                        bias_c = Parameter(bias.clone())
+                    param_c = make_param_and_named_param([weight_c, bias_c])[param_c_index]
+                    optimizer_c = optim_cls(param_c, **optim_input.kwargs)
+                    closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input)
+
+                    # Load the state dict from the original optimizer into the new one
+                    optimizer_c.load_state_dict(deepcopy(optimizer.state_dict()))
+
+                    # Run both optimizers in parallel
+                    for _ in range(10):
+                        if optim_info.step_requires_closure:
+                            optimizer.step(closure)
+                            optimizer_c.step(closure_c)
+                        else:
+                            closure()
+                            closure_c()
+                            optimizer.step()
+                            optimizer_c.step()

-            self.assertEqual(weight, weight_c)
-            self.assertEqual(bias, bias_c)
+                    self.assertEqual(weight, weight_c)
+                    self.assertEqual(bias, bias_c)

-            # Make sure state dict is deterministic with equal (not identical) parameters
-            self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
+                    # Make sure state dict is deterministic with equal (not identical) parameters
+                    # Param names are optional and do not need to be consistent.
+                    self.assertEqual(without_param_names(optimizer.state_dict()),
+                                     without_param_names(optimizer_c.state_dict()))

-            # Make sure repeated parameters have identical representation (see #36831)
-            optimizer_c.param_groups.extend(optimizer_c.param_groups)
-            self.assertEqual(
-                optimizer.state_dict()["param_groups"][-1],
-                optimizer_c.state_dict()["param_groups"][-1],
-            )
+                    # Make sure repeated parameters have identical representation (see #36831)
+                    optimizer_c.param_groups.extend(optimizer_c.param_groups)
+                    self.assertEqual(
+                        without_param_names(optimizer.state_dict())["param_groups"][-1],
+                        without_param_names(optimizer_c.state_dict())["param_groups"][-1],
+                    )

     @optims(optim_db, dtypes=[torch.float32])
     def test_can_load_older_state_dict(self, device, dtype, optim_info):
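
The nested param_c_index loop in the hunk above is the backward-compatibility core of the change: a state_dict produced by an optimizer built from plain params must load into one built from named params, and vice versa. A standalone sketch of that cross-loading, under the same assumptions as before (SGD with momentum is chosen here so there is real per-param state to transfer):

import torch
from copy import deepcopy

w = torch.nn.Parameter(torch.randn(10, 3))
w_c = torch.nn.Parameter(w.detach().clone())

src = torch.optim.SGD([w], lr=0.1, momentum=0.9)               # plain form
dst = torch.optim.SGD([("name0", w_c)], lr=0.1, momentum=0.9)  # named form (assumed support)

# One step so the source optimizer owns a momentum buffer.
w.sum().backward()
src.step()

# The behaviour the test asserts: loading works across the two forms
# and transfers the per-param state.
dst.load_state_dict(deepcopy(src.state_dict()))
torch.testing.assert_close(
    dst.state_dict()["state"][0]["momentum_buffer"],
    src.state_dict()["state"][0]["momentum_buffer"],
)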
@@ -1538,6 +1552,10 @@ def test_save_load_equality_with_weights_only(self, device, dtype, optim_info):
         input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
         params = [weight, bias]

+        def make_param_and_named_param(param):
+            named_param = [(f'name{i}', p) for i, p in enumerate(param)]
+            return [param, named_param]
+
         def fwd_bwd(optim, w, b, i):
             optim.zero_grad()
             loss = (w.mv(i) + b).pow(2).sum()
@@ -1548,25 +1566,26 @@ def fwd_bwd(optim, w, b, i):
             return loss

         for optim_input in all_optim_inputs:
-            optimizer = optim_cls(params, **optim_input.kwargs)
-            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
+            for params_in in make_param_and_named_param(params):
+                optimizer = optim_cls(params_in, **optim_input.kwargs)
+                closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)

-            # Prime the optimizer
-            for _ in range(3):
-                optimizer.step(closure)
+                # Prime the optimizer
+                for _ in range(3):
+                    optimizer.step(closure)

-            sd = optimizer.state_dict()
-
-            # === Check saved/loaded state_dict are the same (including weights_only load). ===
-            with tempfile.TemporaryFile() as f:
-                torch.save(sd, f)
-                f.seek(0)
-                sd_copy = torch.load(f)
-                self.assertEqual(sd_copy, sd)
-                del sd_copy
-                f.seek(0)
-                sd_copy_wo = torch.load(f, weights_only=True)
-                self.assertEqual(sd_copy_wo, sd)
+                sd = optimizer.state_dict()
+
+                # === Check saved/loaded state_dict are the same (including weights_only load). ===
+                with tempfile.TemporaryFile() as f:
+                    torch.save(sd, f)
+                    f.seek(0)
+                    sd_copy = torch.load(f)
+                    self.assertEqual(sd_copy, sd)
+                    del sd_copy
+                    f.seek(0)
+                    sd_copy_wo = torch.load(f, weights_only=True)
+                    self.assertEqual(sd_copy_wo, sd)

     @optims(optim_db, dtypes=[torch.float32])
     def test_load_nontensor_step(self, device, dtype, optim_info):
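
The second test's weights_only branch goes through torch.load's restricted unpickler, which only admits tensors and plain containers, so the new named-parameter entry has to survive that path as well. A minimal sketch of the round trip, with Adam and the single 'weight' label as illustrative choices and the 'param_names' entry assumed to be the representation under test:

import tempfile

import torch

w = torch.nn.Parameter(torch.randn(10, 3))
opt = torch.optim.Adam([("weight", w)], lr=1e-2)  # named form; assumes support under test
w.sum().backward()
opt.step()                                        # populate real (tensor) optimizer state

sd = opt.state_dict()
with tempfile.TemporaryFile() as f:
    torch.save(sd, f)
    f.seek(0)
    sd_wo = torch.load(f, weights_only=True)      # safe load: no arbitrary pickle code

# Spot-check that the safe load round-tripped both the state and the names.
torch.testing.assert_close(sd_wo["state"][0]["exp_avg"], sd["state"][0]["exp_avg"])
assert sd_wo["param_groups"][0].get("param_names") == ["weight"]  # assumed representation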
