
Conversation

@ErezYosef
Contributor

A proposal addressing Issue #1489: Optimizer should track parameter names and not id.

(also mentioned here: [RFC] Introducing FQNs/clarity eyeglasses to optim state_dict)

Summary

This PR introduces a backward-compatible enhancement where optimizers track parameter names instead of just their ids.
Optimizers can be initialized with named_parameters() as:

optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)

This allows for greater clarity and ease when handling optimizers, as the parameters' names are preserved within the optimizer’s state_dict as:

state_dict =
{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0, 1],
            'param_names': ['layer.weight', 'layer.bias']  # optional
        }
    ]
}

Loading a state_dict is unchanged (backward-compatible), and the param_names key is ignored when loading.

Key Features

Named Parameters in Optimizer Initialization:

Optimizers can accept the output of model.named_parameters() during initialization, allowing them to store parameter names directly.

Parameter Names in state_dict:

The parameter names are saved as a list in the optimizer’s state_dict with key param_names, alongside the params indices, ensuring seamless tracking of both names and parameters.
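For example (assuming the optimizer was constructed from named_parameters() as in the snippet above; the names shown are illustrative):

optimizer.state_dict()['param_groups'][0]['param_names']
# ['layer.weight', 'layer.bias']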

Backward Compatibility

No Breaking Changes:

This change is fully backward-compatible. The added param_names key in the optimizer's state_dict is ignored when loading a state to the optimizer.

Customization with Hooks:

For more control, the loaded state_dict can be modified using a custom register_load_state_dict_pre_hook, providing flexibility for different design needs.

Documentation Updates

Please refer to the documentation changes for more details on how this feature is implemented and how it can be used effectively.

Solution Example:

A suggested solution to the problem mentioned in #1489, where the same parameters are provided in a different order.
The following register_load_state_dict_pre_hook should be registered on the optimizer before loading to enable loading the state dict:

def adapt_state_dict_ids(optimizer, state_dict):
    # Assuming a single param group.
    current_state_group = optimizer.state_dict()['param_groups'][0]
    loaded_state_group = state_dict['param_groups'][0]

    # Same number of params, same names, only a different ordering.
    # Map each param name to its id in the loaded state dict.
    loaded_name_to_id_mapping = {}  # mapping -- param_name: id
    for i, name in enumerate(loaded_state_group['param_names']):
        loaded_name_to_id_mapping[name] = loaded_state_group['params'][i]

    # Rewrite the loaded ids to follow the current optimizer's parameter order,
    # so each loaded state entry is matched to the parameter with the same name.
    for i, name in enumerate(current_state_group['param_names']):
        loaded_state_group['params'][i] = loaded_name_to_id_mapping[name]

    return state_dict

In this code, the loaded state_dict ids are adapted to match the parameter order of the current optimizer's state_dict.
Both the previous and the current optimizers must be initialized with named_parameters() so that the 'param_names' key is present in their state dicts.
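
For completeness, a minimal usage sketch (model, saved_state_dict, and the hyperparameters are placeholders, not part of this PR):

import torch.optim as optim

# The new optimizer is built with named_parameters(), so its state_dict contains 'param_names'.
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
optimizer.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer.load_state_dict(saved_state_dict)  # a state_dict produced by the previous named optimizer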

Note

This is my first contribution to PyTorch, and I would appreciate any feedback or suggestions for improvement.

@pytorch-bot

pytorch-bot bot commented Aug 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134107

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8f3560a with merge base de4c2a3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla bot commented Aug 21, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

@vadimkantorov
Contributor

@ErezYosef
Contributor Author

Is this related?

Issue #72146 mentions this problem of missing names but suggests changing state_dict['param_groups']['params'] from a list to a dict. That is an option, but backward compatibility and the ordering of the parameters would need to be checked.

This problem is also mentioned in Issue 71683, comment 1019456452: "provide param_names list inside the param group", which is what is suggested here.

@janeyx99
Contributor

hi @ErezYosef! Thanks for the well-written proposal and implementation attempt. I agree this is a good direction to go towards, where we gently enable accepting named_parameters without breaking BC.

Before I review your PR in full, it'd be essential to add testing and docs:

  • Add tests in test/test_optim.py to verify behavior. For example, I would expect that an optimizer given named_parameters() would be saved and checkpointed properly (a rough sketch follows after this list). We'd also want to make sure that nothing semantically changes for all our implementations when parameters are named_parameters().
  • Add documentation for each optimizer in their respective files (though should mostly be modularized in optimizer.py).
  • Add docs to optim.rst to show the use case people would use this for
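
A rough sketch of the round-trip check meant in the first bullet (plain torch rather than the test harness; names and hyperparameters are illustrative and assume this PR's named_parameters() support):

import torch
from torch import nn, optim

model = nn.Linear(4, 2)
opt = optim.SGD(model.named_parameters(), lr=0.1, momentum=0.9)
model(torch.randn(3, 4)).sum().backward()
opt.step()  # creates momentum buffers in opt.state

sd = opt.state_dict()
assert sd['param_groups'][0]['param_names'] == ['weight', 'bias']

# Loading into a fresh optimizer built the same way should round-trip cleanly.
opt2 = optim.SGD(model.named_parameters(), lr=0.1, momentum=0.9)
opt2.load_state_dict(sd)
assert opt2.state_dict()['param_groups'][0]['param_names'] == ['weight', 'bias']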

@ErezYosef
Contributor Author

Thanks @janeyx99 .
I pushed the changes you requested.
Please review the test I added, I hope it is good.

Erez.

@vadimkantorov
Contributor

vadimkantorov commented Aug 23, 2024

Maybe also this is related:

could be an alternative, e.g. when .parameters() is called it could return param.with_name(name), which could then be used

With named_parameters() support of optimizers, maybe it's less pressing

@mikaylagawarecki added the triaged label on Aug 23, 2024
@ErezYosef
Contributor Author

Maybe also this is related:

could be an alternative, e.g. when .parameters() is called it could return param.with_name(name), which could then be used

With named_parameters() support of optimizers, maybe it's less pressing

It looks like a very useful feature that could enhance debugging and the overall experience.
Once it is fully characterized and implemented, the optimizers can also support it (according to the specific design choices that need to be made).

However, the current PR builds on the existing code as it is and enables more flexibility without major changes.
Therefore, I believe it can be merged independently of any potential future features.

Erez.

@ErezYosef
Contributor Author

Hi @janeyx99

I added a few more tests + explanations and an example to present how to use the new feature.
Please check these tests and verify their correctness.

Thanks,
Erez.

Contributor

@janeyx99 janeyx99 left a comment

Sorry this review took me so long--I just got back from hyperfocusing on the conference + traveling to the conference and am prioritizing responding to GH notifications now! I would like to get your PR in by the end of this week if possible.

Thanks for your work and for adding tests--I've added some more high-level comments throughout. Mainly, I'm curious how far we want to go in supporting this:

  1. Do we want to support mixed names + no names within a param group? (I'm thinking no)
  2. Do we want to support param groups with names + param groups without names within one optimizer? (Maybe, but if there's no real use case for this, I'd also argue to stay as simple as we can).
  3. What would you guess is the main use case for using named parameters? Is it for easier filtering later? For debugging? For sharding the statedict? Whatever that is, I'd like to showcase an example in the docs.
  4. What is the expected behavior of:
optim = AdamW(model.named_parameters())
named_sd = optim.state_dict()

optim2 = AdamW(model.parameters())
optim2.load_state_dict(named_sd)

and

optim = AdamW(model.parameters())
unnamed_sd = optim.state_dict()

optim2 = AdamW(model.named_parameters())
optim2.load_state_dict(unnamed_sd)
  5. For each of the above, I'd love to get a simple test case going exhibiting the behavior. For example, if we want to error, we should add combinations of mixed names/no names in get_error_inputs_for_all_optims.

@ErezYosef
Contributor Author

Thanks so much for your detailed review, @janeyx99, and no worries about the timing!

Regarding your questions:

  1. I completely agree with you. I’ll ensure that we don’t mix parameters with and without names within a parameter group.
  2. Keeping things simple sounds like a good approach, and I can enforce consistent behavior across all groups. However, this would introduce a constraint that may not be strictly necessary.
  3. Named parameters primarily offer flexibility for various use cases, including easier filtering, debugging, and managing the state_dict, as you mentioned. I’ll add examples to showcase how this can be useful.
    The example I added in optim.rst (involving fc1 and fc2) is similar to the scenario that led me to open this PR (training multiple experts).
    It can also be used to apply something like "strict=False" (see #3852, "strict keyword doesn't exist in the optimizer's load_state_dict", and the added example).
  4. The behavior in these cases would be to ignore the names in the state dict (for backward compatibility), treating them as if the optimizer was initialized with model.parameters(). I'll add tests to cover this as you suggested.

I’ll make sure to include test cases to capture these behaviors, including scenarios where we expect an error if mixed names are used.

Thanks again,
Erez.

@janeyx99
Contributor

janeyx99 commented Oct 2, 2024

Ah thank you. Let me know when this PR is ready for another review--I've kicked off the CI for now.

@ErezYosef
Contributor Author

Ah thank you. Let me know when this PR is ready for another review--I've kicked off the CI for now.

Thanks,
I assume these commits address all your points.

Addressing the points:

  1. I added the raise (a minimal trigger example follows after this list):
        if len(extracted_param_names) != 0:
            if len(extracted_param_names) == len(extracted_param_tensors):
                param_group['param_names'] = extracted_param_names
            else:
                raise ValueError("all optimizer params should be with/without names. Some param names are missing")
  2. I added:
                ....
                raise ValueError("all optimizer param groups should be with/without names. "
                                 f'cannot add param group {current_group_txt} to the optimizer')
  3. Examples were added to the docs.
  4. I added a test for the behavior:
            # Make sure that param_names are preserved when provided to at least one of the optimizers
            if is_named_optim0 or is_named_optim1:
                self.assertEqual(optimizer2.state_dict()['param_groups'][0]['param_names'],
                                 ['0.weight', '0.bias', '1.weight', '1.bias'])
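
For reference, a minimal input that should trigger the error from point 1 (illustrative only, not taken from the added tests):

import torch
from torch import optim

w = torch.nn.Parameter(torch.randn(2, 2))
b = torch.nn.Parameter(torch.randn(2))
# One named and one unnamed parameter in the same group should raise the new ValueError.
optim.SGD([('weight', w), b], lr=0.1)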

Feel free to let me know if there's anything else to add.
Erez.

Contributor

@janeyx99 janeyx99 left a comment

Looks really good! CI looks like a mess though--could you rebase and let's see if these tests all pass.

Comment on lines 16 to 17
Contributor

Suggested change
parameters (all should be :class:`~torch.nn.Parameter` s) to optimize or named parameters
(tuples of (str, :class:`~torch.nn.Parameter`)). Then,
parameters (all should be :class:`~torch.nn.Parameter` s) or named parameters
(tuples of (str, :class:`~torch.nn.Parameter`)) to optimize. Then,

Contributor

Suggested change
--------------------------------------------
------------------------------------------------------------

this needs to match the line above

Contributor

Suggested change
If ``param_names`` exist in loaded state dict ``param_groups`` they will be saved or will override
If ``param_names`` exist in loaded state dict ``param_groups`` they will be saved and override

Contributor

yea i agree with you this might be an unnecessary constraint but if someone wants this feature later it shouldn't be awful to add it. thanks for the error here!

@ErezYosef
Contributor Author

Thanks.
I am not sure how the CI works. Hope it will work easily after a rebase.

Do you want me to rebase on upstream viable/strict, or another branch?
Should the push be forced?

Do you know if these commands will work?

git pull --rebase upstream viable/strict
git push -f

Erez.

@janeyx99
Contributor

janeyx99 commented Oct 6, 2024

Rebasing onto viable/strict or main would both work. the commands you wrote should be fine—i typically go for a git fetch followed by a git rebase.

Contributor

@janeyx99 janeyx99 left a comment

Approval contingent on green CI :D (and that you address the nits on wording)

and optimizer states from ``model`` into both ``fc1`` and ``fc2`` of ``model2``
(and adjust them accordingly)::
Let's say that ``model`` implements an expert (MoE), and we want to duplicate it and resume training
for two experts, both initialized the same way as the ``fc`` layer. For the following ``model2`` wecreate two layers identical to ``fc`` and resume training by loading the model weights and optimizer states from ``model`` into both ``fc1`` and ``fc2`` of ``model2`` (and adjust them accordingly)::
Contributor

Suggested change
for two experts, both initialized the same way as the ``fc`` layer. For the following ``model2`` wecreate two layers identical to ``fc`` and resume training by loading the model weights and optimizer states from ``model`` into both ``fc1`` and ``fc2`` of ``model2`` (and adjust them accordingly)::
for two experts, both initialized the same way as the ``fc`` layer. For the following ``model2`` we create two layers identical to ``fc`` and resume training by loading the model weights and optimizer states from ``model`` into both ``fc1`` and ``fc2`` of ``model2`` (and adjust them accordingly)::

Contributor Author

Fixed in commit 9d7a208

@janeyx99
Contributor

janeyx99 commented Oct 9, 2024

CI failures are real! hud.pytorch.org/pr/134107. Please verify the tests pass locally: python test/test_optim.py -k <test_name> before committing. lmk if you're stuck on any errors, the SparseAdam ones you can follow the example from the other tests to make the grads sparse 😛

@ErezYosef
Contributor Author

Thanks.

  1. I can see how to make the grads sparse, but my test is mostly a copy-paste of test_can_load_older_state_dict, so why does it work in the original test?

  2. Not sure I understand this error:

X linux-jammy-py3.9-gcc11 / test (docs_test, 1, 1, linux.2xlarge)
    Queued: 0.0s, Duration: 10.0m

/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/jit/_trace.py:791: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:

2024-10-09T16:01:47.8645453Z Mismatched elements: 4 / 12 (33.3%)
2024-10-09T16:01:47.8646178Z Greatest absolute difference: 0.39807265996932983 at index (0, 2) (up to 1e-05 allowed)
2024-10-09T16:01:47.8647146Z Greatest relative difference: 1.0307998656385509 at index (0, 1) (up to 1e-05 allowed)
2024-10-09T16:01:47.8647768Z   _check_trace(
2024-10-09T16:01:47.8679892Z /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/jit/_trace.py:791: TracerWarning: Trace had nondeterministic nodes. Did you forget call .eval() on your model? Nodes:
2024-10-09T16:01:47.8681483Z 	%14 : Float(1, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::rand(%9, %10, %11, %12, %13) # <doctest default[0]>:2:0
2024-10-09T16:01:47.8682670Z This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
2024-10-09T16:01:47.8683459Z   _check_trace(
2024-10-09T16:01:47.8685149Z /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/jit/_trace.py:791: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
2024-10-09T16:01:47.8686536Z Tensor-likes are not close!

What is its source? Do you maybe have a clue how to solve it?

Thanks,
Erez.

@janeyx99
Contributor

janeyx99 commented Oct 9, 2024

Ah, it's skipped in our skip list: https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_optimizers.py#L2163-L2169 I think skipping it is fine as well, though I'd prefer it be in the test case, e.g.,

@optims([o for o in optims_db if not o.only_supports_sparse_grads])
def test...

@janeyx99
Contributor

janeyx99 commented Oct 9, 2024

For the docs test, the error in jit is a red herring. The more relevant line is

2024-10-09T16:18:53.1562928Z /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/optim/optimizer.py:docstring of torch.optim.optimizer.Optimizer.load_state_dict:7: WARNING: Block quote ends without a blank line; unexpected unindent.

@janeyx99
Contributor

For lint, I'd recommend pip install lintrunner and then running lintrunner --init; lintrunner -a to get all of lint to be fixed for ci.

@ErezYosef
Contributor Author

For lint, I'd recommend pip install lintrunner and then running lintrunner --init; lintrunner -a to get all of lint to be fixed for ci.

Thanks, @janeyx99. However, I'm running into some issues with this command, possibly related to its dependencies.

@ErezYosef
Contributor Author

Looks like we have green CI. 💯
What should I do next?

@janeyx99
Contributor

Wonderful!

@pytorchbot merge

@pytorch-bot added the ciflow/trunk label on Oct 14, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

How to utilize named parameters to load optimizer state dict
------------------------------------------------------------

The function :func:`~Optimizer.load_state_dict` stores the optional ``param_names``content from the
Contributor Author

Probably missing whitespace after param_names

Contributor

if u want to open another pr to address, i'd approve it

pytorchmergebot pushed a commit that referenced this pull request Oct 18, 2024
### Description:

This PR addresses a minor [formatting issue identified in a previous contribution to the Optimizer documentation](#134107 (comment)).

Specifically, it fixes the missing whitespace after `param_names` in the section on utilizing named parameters to load the optimizer state dict.

You can find the related docs here:
[Optimizer Documentation](https://pytorch.org/docs/main/optim.html#how-to-utilize-named-parameters-to-load-optimizer-state-dict).

@janeyx99

Pull Request resolved: #138321
Approved by: https://github.com/janeyx99

Labels

ciflow/trunk, Merged, open source, release notes: optim, triaged
