
Implement Maximal Update Parametrization (muP) #16157

@thegregyang

Description


🚀 Feature request

This request is to open a discussion on 1) whether it makes sense to implement Maximal Update Parametrization (abbreviated muP) in Huggingface, and 2) if so, how to do it.

Motivation

Hi,

I'm a maintainer of the mup package (paper). This repo lets one implement in a model a special parametrization called the maximal update parametrization, or muP, which has the special property that narrow and wide networks share the same optimal hyperparameters (such as learning rate and initialization). This is demonstrated below on a Transformer trained with Adam, with the PyTorch default parametrization on the left and muP on the right.
[figure: optimal hyperparameters across widths under the PyTorch default parametrization (left) vs. muP (right)]
Most strikingly, this property can be used to tune hyperparameters for extremely large neural networks like GPT-3, which are too expensive to train more than once, by tuning only a tiny version of the model. But even for "regular joe" users, muP can alleviate a lot of the pain of transitioning from exploration to scaling up, only to find performance suffering for mysterious reasons. Transformers in particular are somewhat infamous for problems like training instability, so having muP integrated natively into Huggingface could benefit many users at once.

muP can be implemented in a backward-compatible way, as shown below, so users do not need to worry about it breaking existing codebases.

See this twitter thread for a brief overview of how this works, and this blog post for a longer one.

Your contribution

Let's now return to the two questions posed at the beginning: 1) whether it makes sense to implement muP in Huggingface, and 2) if so, how to do it.

For 1), the popularity (or not) of this issue should serve as an indicator of community interest, and the above makes the case for the utility of this integration.

For 2), we have examples of how to integrate muP with some common (PyTorch) Huggingface transformers in our mutransformers repo.

Current Example Implementation

In summary, to modify an existing Huggingface transformer to implement muP, one needs to:

  1. Switch any readout layer (dimensions: width -> number of labels) from nn.Linear to mup.MuReadout.
  2. Modify the _init_weights method to use mup.init.* methods instead of nn.init.* methods (or equivalent).
  3. Scale the attention logits like 1/d instead of 1/sqrt(d).
  4. Use mup.MuAdamW instead of the PyTorch or Huggingface AdamW.

In addition, when using a mutransformer, one needs to provide a "base shape file" that tells the model how to properly scale the learning rate and attention logits with width. This is designed so that if the model's parameter shapes equal the base shapes, the model is in the original parametrization, i.e. the scheme is backward compatible.

from mutransformers import BertConfig, BertForMaskedLM
from mup import set_base_shapes
# instantiate model
model = BertForMaskedLM(config=BertConfig(...))
# set base shapes so muP scaling is applied correctly
set_base_shapes(model, path_to_base_shape_file)
# re-initialize with the muP-aware _init_weights
model.apply(model._init_weights)

More Seamless Integration

The mutransformers repo is primarily meant to serve as a set of examples of how to implement muP in existing transformers; all of the above can be streamlined if we want truly seamless integration into Huggingface.

For example, the user interface for instantiating a model could stay exactly as it is now, with just an additional flag mup=True in BertConfig that switches muP on. BertConfig itself could carry a default set of base shapes for this scenario, which the user can modify if necessary.

# the model automatically sets base shapes based on defaults in BertConfig
# no need to re-initialize either
model = BertForMaskedLM(config=BertConfig(mup=True,...))
# use model immediately, e.g., train

In addition, mup.MuAdamW can be incorporated natively into Huggingface as well, so that there is no dependency on the mup package at all.

muP for All Transformers?

Since there is currently no automatic way of retrofitting existing transformers, adding muP to all of the transformers in Huggingface could be quite a task. A good practical compromise is to implement muP only for the most commonly used models in Huggingface.

In the interim, research can be done on a method for such automatic retrofitting. This could even involve a pull request into PyTorch core.

Conclusion

Again, this issue is intended to start a discussion of whether and how to make muP available to Huggingface users natively. It could be that the best course forward is to have users implement muP transformers themselves, as in mutransformers, or to grow mutransformers itself into a repository of muP transformers. And even if we do decide to integrate muP into Huggingface, there are many ways to do it.

I hope the discussion here can elucidate the right course of action.
