RFC: Integrating bitsandbytes 8-bit optimizer / adding Embedding Norm #14819

🚀 Feature request
The BNB AdamW optimizer (https://github.com/facebookresearch/bitsandbytes), created by @TimDettmers, uses an 8-bit quantization technique that reduces the AdamW optimizer's memory usage from 8 bytes to 2 bytes per parameter. That is a huge memory saving, and I think our users will benefit a lot from it.

Additionally, we discovered that one of BNB's components, Embedding Norm, on its own made a huge improvement to the training stability of large models at @bigscience.
Therefore this is a two-features-in-one request.
Performance
We ran experiments at BigScience with a 104B model, and while we didn't have a chance to run a full training to the end, BNB performed on par with plain AdamW quality-wise.
I'm also currently running a full 1.3B model training with embed norm, to compare scaling laws against the same training w/o embed norm. It should finish in a few days.
Tech
This technology comes in two components:
- 8-bit quantization optimizer
- required Embedding Norm
The optimizer itself is a drop-in replacement for Adam:

```python
import bitsandbytes as bnb

optim = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=8)
```
but there is an important requirement: Embedding Norm, which is needed to ensure training stability and which we currently don't have.
In fact, for BigScience we discovered that adding embed norm on its own, w/o BNB, made a huge difference to training stability, and we will most likely enable it in the 200B GPT model training, since the current 104B GPT model results are best when embed norm is enabled. So once we release the 200B model, we will most likely want embed norm in transformers for that model's custom architecture.
Embedding norm currently appears to be a new default for google and openai models according to Tim.
BNB comes with StableEmbedding which replaces nn.Embedding
So the only integration needed on the HF side (other than adding --optim=adamw_bnb to the HF Trainer) is an embed norm plus a config option to enable or disable it. It also wants xavier_uniform init, but that's a minor detail.
Finetuning
Existing pre-trained transformers models could be used as is, with 8-bit optimizer state for all weights but 32-bit state for the embedding layer. This improves fine-tuning stability. Tim shared that for GLUE fine-tuning it is fine to use 8-bit state for the embedding layer too, but in general 32-bit should be more stable.
Pretraining
For pretraining it would make sense to implement the full stable embedding layer, i.e. add a configurable embed norm at the end of Embedding.forward. Here we would want to implement it ourselves rather than re-use StableEmbedding from BNB, so that any model trained with BNB can be loaded from the hub without depending on BNB.
We obviously can't make this the default for all our models, but perhaps we can start enabling it for some models where we know it makes a huge difference, or at least recommend it.
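A minimal sketch of what such a self-contained layer could look like (the class name `NormedEmbedding` and the `embed_norm` flag are hypothetical, not existing transformers API):

```python
import torch
import torch.nn as nn

class NormedEmbedding(nn.Embedding):
    """nn.Embedding with an optional LayerNorm applied at the end of forward,
    plus xavier_uniform init, mirroring the behavior of bnb's StableEmbedding.
    Hypothetical sketch, not existing transformers code."""

    def __init__(self, num_embeddings, embedding_dim, embed_norm=True, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        nn.init.xavier_uniform_(self.weight)
        self.norm = nn.LayerNorm(embedding_dim) if embed_norm else None

    def forward(self, input_ids):
        emb = super().forward(input_ids)
        return self.norm(emb) if self.norm is not None else emb

tokens = torch.tensor([[1, 2, 3]])
emb_layer = NormedEmbedding(10, 4)
print(emb_layer(tokens).shape)  # torch.Size([1, 3, 4])
```

When loading such a checkpoint, `embed_norm` would simply be read from the model config, with no bitsandbytes dependency.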
@TimDettmers, please let me know if I missed something or you'd like to add anything to my summary here. Thank you!
Comments are welcome.