
Conversation

@toothacher17

A proof of concept for implementing Distributed Muon as described in: https://github.com/MoonshotAI/Moonlight

  • Example script: see examples/muon/training.sh

  • Tested with TP=2, PP=2, DP=2; compared against AdamW and against runs without TP/PP

  • Used the data from bigscience and the provided example script



@mactavish91 mactavish91 left a comment

@toothacher17 Hello, I am comparing the performance between AdamW and Muon. The experiment involves training a 1B parameter MoE model on the H800 cluster, with a maximum learning rate of 1e-3 that decays to 1e-4 using cosine decay. Muon uses default parameters. I observed that Muon has a significant advantage over AdamW in the early stages of training, but after 20k steps, their performance becomes similar, with AdamW sometimes even outperforming Muon. Is this phenomenon normal?

@toothacher17
Author

toothacher17 commented Feb 25, 2025


hi, @mactavish91

Thanks a lot for trying it out! I think I probably know the reason:

  1. First question: are you reporting training loss or validation loss? It's better to observe validation loss rather than training loss.

  2. Next question: how many tokens did you train with Muon over your 20k steps? It is likely that your token count is already in the over-train regime. Your curve looks like a typical case where the Muon-trained model is not properly weight-decayed.

  3. In the over-train setting, as our paper mentions (https://arxiv.org/pdf/2502.16982), it is important to apply weight decay to all parameters, including the RMSNorm gamma; see Section 2.1 and Appendix D.

  4. However, Megatron's default is to apply no weight decay to the RMSNorm gamma:
    https://github.com/NVIDIA/Megatron-LM/pull/1428/files#diff-b5fac51ecd0148c2f4f8f2f1e64535089e90be87606c1f9357778d05af823220R100

A simple hack is to add one line after line 114 to force lr_mult = 1.0 and wd_mult = 1.0 for all parameters.
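For illustration only, a minimal sketch of that kind of override (the function and variable names here are hypothetical, not the actual Megatron-LM code around line 114):

import torch

def get_param_groups_uniform(model: torch.nn.Module, lr: float, weight_decay: float):
    # Force lr_mult = 1.0 and wd_mult = 1.0 for every parameter, so RMSNorm gammas
    # and biases receive weight decay like everything else.
    params = [p for p in model.parameters() if p.requires_grad]
    return [{"params": params, "lr": lr * 1.0, "weight_decay": weight_decay * 1.0}]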

@toothacher17
Author


Your exp looks very much like this:

[figure: example loss curves]

Let us know if adding weight decay to all params helps!

@mactavish91


Thank you for your kind help. I used training loss; since the dataset is of pretraining-level size, it can be approximated as validation loss. After 20k steps, approximately 50B tokens have been trained. I tried applying weight decay to all parameters, but it doesn't seem to help much.

@toothacher17
Author


Hi, thanks for sharing. Yeah, all your settings look fine and reasonable. Since the dataset is pretrain-level size, reporting pretraining loss is also reasonable. I am not sure what the root cause is for Muon's advantage diminishing; we found Muon performed well as long as we added the correct weight decay and adjusted the update RMS for the matrix's shape.

If it is OK, do you mind sharing your model architecture details? We can try it in our code base with our data and see what's going on.
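For context, the update-RMS adjustment mentioned above follows the rule from the Moonlight paper: scale the orthogonalized update by roughly 0.2 * sqrt(max(fan_out, fan_in)) so its RMS matches a typical AdamW update (presumably what the --muon-matched-adamw-rms 0.2 flag in the script shared later controls). A hedged sketch, not the exact code in this PR:

import math
import torch

def adjust_update_rms(update: torch.Tensor, matched_adamw_rms: float = 0.2) -> torch.Tensor:
    # Rescale a 2D orthogonalized Muon update so its RMS is comparable to AdamW's.
    fan_out, fan_in = update.shape
    return update * matched_adamw_rms * math.sqrt(max(fan_out, fan_in))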

@toothacher17
Author


Another thing to debug: observe your weight RMS, max logit, per-layer output RMS, and update RMS during training, and see if anything weird is happening.
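For example, a small helper of the kind one might use for this (illustrative sketch only; it simply computes the statistics listed above from tensors you pass in):

import torch

@torch.no_grad()
def debug_stats(weight: torch.Tensor, update: torch.Tensor, layer_output: torch.Tensor) -> dict:
    # Per-layer statistics worth watching during training.
    rms = lambda t: t.float().pow(2).mean().sqrt().item()
    return {
        "weight_rms": rms(weight),          # scale of the parameter matrix
        "update_rms": rms(update),          # scale of this step's optimizer update
        "output_rms": rms(layer_output),    # scale of the layer's output activations
        "max_logit": layer_output.float().abs().max().item(),
    }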

@hjlee1371

Hi, thank you for open-sourcing this great work! I have some questions:

  1. It seems that this implementation treats the weight matrices for the QKV projection as one parameter (linear_qkv of attention), rather than splitting them as done here. Is this intended?
  2. Can you share some intuitions about the effect of splitting WQ, WK, and WV matrices (e.g., treating each head as a separate parameter, as mentioned in the paper)?

@SeunghyunSEO

SeunghyunSEO commented Feb 26, 2025


I think the reason Megatron uses fused QKV is to reduce memory access and allow the GPU to perform larger matrix multiplications. This seems more like Megatron's built-in optimization rather than the Moonlight authors' intention.
In the original codebase, it looks like they fused the operation as well, but in a slightly different way. But I'm not sure how Muon processes this 3D qkv_w tensor.

Well, but... I don't know why using separate weights is better.
If Muon is just an approximation of a second-order optimizer like Shampoo, shouldn't it perform better when it considers more correlations between the matrices?

@toothacher17
Author

toothacher17 commented Feb 26, 2025


Very good questions!

  1. For the Moonlight and Moonlight-A models, we used MLA, so Q, K, and V are naturally split. Besides, following Keller's blog, it seems that splitting performs better. We recommend splitting Q, K, and V into three matrices and updating them separately for now.

  2. For splitting Q, K, and V into multiple heads and updating them separately, I think https://leloykun.github.io/ has some experiments. For now we do not split heads; we update all heads of Q, K, and V together. But in Section 3.1 of the paper, you can see that the query projection matrix behaves very differently from the MLP matrices. While the Update Norm method strictly controls the RMS to match AdamW, the Adjusted LR method we use here does not, so I think there is some room for further improvement.

In general, the concept of a 'matrix' might not be well defined in Muon; for now we rely on empirical results to decide the matrix split.
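As an illustration of what "splitting" means here (a sketch only; it assumes a plain [3 * hidden_size, hidden_size] concatenated layout, whereas Megatron's fused linear_qkv is interleaved per attention group, so real code must undo that interleaving first):

import torch

def split_qkv_for_update(qkv_weight: torch.Tensor, hidden_size: int):
    # Treat Q, K and V as three separate matrices for the Muon update,
    # then re-concatenate after each has been orthogonalized and updated.
    wq, wk, wv = torch.split(qkv_weight, hidden_size, dim=0)
    return wq, wk, wv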

@toothacher17
Author


Yeah, splitting them into three matrices performed better empirically, so we followed that. Moonlight uses MLA, so Q, K, and V are naturally split there.

@toothacher17
Author


Besides the larger matrix multiplications, another advantage of fused QKV is that you only need to gather the input across the TP group once (if TP and SP are enabled) and then use it for all three projections.

@mactavish91


The following are the settings I used in the experiment

#!/bin/bash

TEXT_DATA_PATH=""

NAME="1b-1e-3-qknorm-factor2.45-muon"
CHECKPOINT_PATH="checkpoints/${NAME}"
TENSORBOARD_PATH="runs/research/${NAME}"
KEEP_LATEST_CKPT=3  

MICRO_BATCH_SIZE=1
GLOBAL_BATCH_SIZE=1152

TP_SIZE=1
PP_SIZE=1
EP_SIZE=1

MOE_ROUTED_EXPERTS=64
MOE_ACTIVE_ROUTED_EXPERTS=6
MOE_SHARED_EXPERTS=2

NHIDDEN=728
MOE_FFN_HIDDEN=408
MOE_SHARED_EXPERT_INTERMEDIATE_SIZE=$(($MOE_FFN_HIDDEN * $MOE_SHARED_EXPERTS))
FFN_HIDDEN=2176
NLAYERS=18
NHEADS=8

SEQ_LEN=2048

SAVE_INTERVAL=50000

TRAIN_TOKENS=100000000000 # 100B tokens
TRAIN_SAMPLES=$((TRAIN_TOKENS / SEQ_LEN))
LR_DECAY_SAMPLES=$((TRAIN_SAMPLES * 98 / 100))
LR_WARMUP_SAMPLES=$((TRAIN_SAMPLES * 1 / 100))

NCCL_IB_QPS_PER_CONNECTION=2

script_path="pretrain_gpt.py"

OPTIMIZER_ARGS="
    --optimizer muon
    --muon-matched-adamw-rms 0.2
    --adam-beta1 0.9
    --adam-beta2 0.95
    --adam-eps 1e-8
    --lr 1e-3
    --min-lr 1e-4
    --lr-decay-style cosine
    --lr-decay-samples $LR_DECAY_SAMPLES
    --lr-warmup-samples $LR_WARMUP_SAMPLES
    --clip-grad 1.0
    --weight-decay 1e-1
    --hidden-dropout 0.0
    --attention-dropout 0.0
    --initial-loss-scale 65536
"

MOE_ARGS="
    --num-experts $MOE_ROUTED_EXPERTS
    --moe-shared-expert-intermediate-size $MOE_SHARED_EXPERT_INTERMEDIATE_SIZE
    --moe-shared-expert-overlap
    --moe-router-topk $MOE_ACTIVE_ROUTED_EXPERTS
    --moe-grouped-gemm
    --moe-num-first-dense-layers 2
    --moe-ffn-hidden-size $MOE_FFN_HIDDEN
    --expert-model-parallel-size $EP_SIZE
    --moe-permute-fusion
    --moe-router-enable-expert-bias
    --moe-router-bias-update-rate 1e-3
    --expert-balance-factor 0
    --device-balance-factor 0
    --moe-global-batch-balance
    --moe-router-activation-type softmax
    --moe-routed-scaling-factor 2.45
"

MODEL_ARGS="
    --bf16
    --num-layers $NLAYERS
    --hidden-size $NHIDDEN
    --ffn-hidden-size $FFN_HIDDEN
    --seq-length $SEQ_LEN
    --no-interleaved-qkv
    --max-position-embeddings $SEQ_LEN
    --num-attention-heads $NHEADS
    --disable-bias-linear
    --add-qkv-bias
    --rotary-percent 0.5
    --swiglu
    --use-flash-attn
    --transformer-impl transformer_engine
    --untie-embeddings-and-output-weights
    --position-embedding-type rope
    --no-position-embedding
    --normalization RMSNorm
    --use-mcore-models
    --manual-gc
    --kv-channels 128
    --qk-layernorm
"

TRAINING_ARGS="
    --micro-batch-size $MICRO_BATCH_SIZE
    --global-batch-size $GLOBAL_BATCH_SIZE
    --train-samples $TRAIN_SAMPLES
    --tensor-model-parallel-size $TP_SIZE
    --pipeline-model-parallel-size $PP_SIZE
    --use-distributed-optimizer
    --overlap-grad-reduce
"

DATA_ARGS="
    --num-workers 1
    --train-data-path $TEXT_DATA_PATH
"

OUTPUT_ARGS="
    --log-throughput \
    --log-interval 1 \
    --eval-interval 0 \
    --timing-log-level 0 \
    --save-interval $SAVE_INTERVAL \
    --tensorboard-dir $TENSORBOARD_PATH/tensorboard \
    --wandb-save-dir $CHECKPOINT_PATH \
    --wandb-exp-name $NAME \
"

gpt_options="
    $MODEL_ARGS
    $MOE_ARGS
    $TRAINING_ARGS
    $OPTIMIZER_ARGS
    $DATA_ARGS
    $OUTPUT_ARGS
    --distributed-timeout-minutes 20
    --init-method-std 0.006
    --save $CHECKPOINT_PATH
    --load $CHECKPOINT_PATH
    --save-async-fast-checkpoint
"

@toothacher17
Author

toothacher17 commented Feb 26, 2025


Hi, @mactavish91, your model architecture looks reasonable.

For debugging purposes, I'll need more monitoring than the current open-source Megatron-LM provides, so I'll run it on our internal infra with some slight changes:

  1. We will use our own data, so the seq-len will change from 2048 to 8192. Correspondingly, the global batch size will change from 1152 to 288.
  2. We will not add the attention QK bias.
  3. We will update the Q, K, and V matrices separately.
  4. Since we are using MoE with an aux-free bias and a scaling factor of 2.45, I'll use the sigmoid gate rather than the softmax gate.

Other settings will remain the same as you posted. We'll keep you posted on our findings.

@toothacher17
Author

toothacher17 commented Feb 26, 2025

@mactavish91

I am running your two configs right now and am not sure about the results yet. But I did some math and probably found the problem: the model might be too small compared to its embedding. We have a 160K vocab size (I am not sure about yours, do you mind sharing it?), so the parameters break down as follows:

Total Params:
total = 1,351,721,016
embedding = 768 X 163840 X 2 = 251,658,240
total excluding embedding = 1,351,721,016 - 251,658,240 = 1,100,062,776

Activated Params:
not activated = 18 X 408 X 728 X 3 X 58 = 930,279,168
total activated = 1,351,721,016 - 930,279,168 = 421,441,848
total activated excluding embedding: 1,100,062,776 - 930,279,168 = 169,783,608

So you can see, the model has ~170M non-embedding activated params, about 1.1B non-embedding total, and ~252M in word embeddings and the LM head. Because the word embeddings and LM head are updated by AdamW, maybe in the long run there is not much difference.
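The arithmetic above as a quick runnable check:

# Reproduce the parameter breakdown quoted above (numbers from this thread).
total = 1_351_721_016
embedding = 768 * 163_840 * 2                                    # 251,658,240
non_embedding_total = total - embedding                          # 1,100,062,776
inactive_expert_params = 18 * 408 * 728 * 3 * 58                 # 930,279,168
activated_total = total - inactive_expert_params                 # 421,441,848
activated_non_embedding = non_embedding_total - inactive_expert_params
print(activated_non_embedding)                                   # 169,783,608 (~170M)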

I would recommend trying a larger model as well, for example the 822M one listed below. We ran this model with AdamW and Muon for 100B tokens and still see big differences:

[figure: 822M model config and AdamW vs. Muon loss curves over 100B tokens]

@mactavish91


Our tokenizer size is 150k, and that is very likely the reason behind the issue. I will switch to a 60k tokenizer and increase the hidden size and number of layers for a new experiment.

@toothacher17
Author

toothacher17 commented Feb 26, 2025


Yeah, that would be better for getting rid of the impact of the large embeddings. The two comparison jobs based on your previous, smaller model setting are still running.

Besides increasing the model size, another thing worth mentioning is to use/report OOD validation data rather than in-domain validation data for a more accurate evaluation of the model.

@toothacher17
Author

toothacher17 commented Feb 27, 2025

Hi, @mactavish91, we have run your settings for about ~17K steps so far, roughly 40+B tokens (you mentioned before that the advantage diminished around ~20K steps). Even with the big-embedding issue, I actually think the result is promising. We plot the figures as shown below:

  1. With proper smoothing, we can see that Muon's training loss gap is not diminishing.
  2. We define a new metric, Muon Leading Steps, to measure how many extra steps AdamW needs to match Muon's performance.
  3. Besides, we can use a simple ratio metric, Muon_Leading_Steps / Muon_Trained_Steps, to check whether Muon is consistently leading.

[figure: training loss comparison, steps AdamW needs to match Muon, and the Muon leading-step ratio]

@toothacher17
Author

toothacher17 commented Feb 27, 2025

For reproducibility, we provide the script used to generate these figures. @mactavish91, can you try producing the same figures from your previous small-run data as well?

if "validation" or 'training' in tag:
    # smooth the data by emw
    ewm_alpha = 0.005
    muon_data = muon_data.ewm(alpha=ewm_alpha).mean()
    adam_data = adam_data.ewm(alpha=ewm_alpha).mean()

    # Subsample both dataframes to ~1000 rows for cleaner plotting
    num_samples = 1000
    stride = max(1, len(muon_data) // num_samples)
    muon_data = muon_data.iloc[::stride]
    adam_data = adam_data.iloc[::stride]

# columns = wall_time, step, value

# assuming steps are sorted
# while losses are not sorted, but it is generally decreasing, and the lower the loss the better
# for each step of muon loss, find the smallest step of adam loss that is smaller than the muon loss
# plot the step difference between the two steps, that is how much steps does adam need to take to match the muon loss
import matplotlib.pyplot as plt


def find_matching_step(target_value: float, reference_df: pd.DataFrame) -> int:
    """Find the first step where reference loss is lower than target value.

    Args:
        target_value: Loss value to match or beat
        reference_df: DataFrame containing step and value columns

    Returns:
        Step number where the reference loss first beats the target, or None if it never does
    """
    mask = reference_df["value"] <= target_value
    if not mask.any():
        return None
    return reference_df.loc[mask, "step"].iloc[0]


step_differences = []
for _, muon_row in muon_data.iterrows():
    muon_step = muon_row["step"]
    muon_loss = muon_row["value"]
    adam_matching_step = find_matching_step(muon_loss, adam_data)
    if adam_matching_step is None:
        break
    step_diff = adam_matching_step - muon_step
    step_differences.append((muon_step, step_diff))

step_diff_df = pd.DataFrame(step_differences, columns=["muon_step", "step_difference"])
num_plot_rows = 3
fig, (ax1, ax2, ax3) = plt.subplots(num_plot_rows, 1, figsize=(10, 6 * num_plot_rows))

# Plot losses
muon_run_name = muon_run.split("/")[0]
adam_run_name = adam_run.split("/")[0]
ax1.plot(muon_data["step"], muon_data["value"], label=f"muon: {muon_run_name}")
ax1.plot(adam_data["step"], adam_data["value"], label=f"adam: {adam_run_name}")
ax1.set_xlabel("Training Steps")
#ax1.set_ylabel("Validation Loss Value")
ax1.set_ylabel("Training Loss Value")
ax1.set_title(
    f"Training Loss Comparison\ncollection={collection_id}\ntag={tag}\nsmoothing: {ewm_alpha}, subsampling: {num_samples}"
)

ylim_min = min(muon_data["value"].min(), adam_data["value"].min()) - 0.02
ylim_max = ylim_min + 0.5
ax1.set_ylim(ylim_min, ylim_max)

ax1.grid(True)
ax1.legend()

# Plot step differences
ax2.plot(step_diff_df["muon_step"], step_diff_df["step_difference"])
ax2.set_xlabel("Muon Training Steps")
ax2.set_ylabel("Additional Steps Needed by AdamW")
ax2.set_title("Steps AdamW Needs to Match Muon Performance")
ax2.grid(True)

# plot step diff/step
leading_step_ratio = step_diff_df["step_difference"] / step_diff_df["muon_step"]
ax3.plot(step_diff_df["muon_step"], leading_step_ratio)
ax3.set_xlabel("Muon Training Steps")
ax3.set_ylabel("Muon leading step ratio")
ax3.set_title("Muon leading step ratio")
ax3.set_ylim(0, leading_step_ratio.max() * 1.1)
ax3.grid(True)

plt.tight_layout()

Here adam_data and muon_data are the runs we fetched from TensorBoard, and tag is simply 'lm-loss-training/lm loss'.
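For completeness, one way to build such DataFrames from local TensorBoard event files (the log directories below are hypothetical; adapt to wherever your runs live):

import pandas as pd
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def load_scalar(logdir: str, tag: str = "lm-loss-training/lm loss") -> pd.DataFrame:
    # Read one scalar series into the (wall_time, step, value) layout used above.
    acc = EventAccumulator(logdir)
    acc.Reload()
    events = acc.Scalars(tag)
    return pd.DataFrame(
        [(e.wall_time, e.step, e.value) for e in events],
        columns=["wall_time", "step", "value"],
    )

muon_data = load_scalar("runs/muon-run/tensorboard")   # hypothetical paths
adam_data = load_scalar("runs/adam-run/tensorboard")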

[figure: plots generated by the script above]

@toothacher17
Author

@mactavish91 Besides, we also evaluated the OOD LM validation loss, and it showed pretty good results:

[figure: OOD validation loss comparison]

@toothacher17
Author

@mactavish91 We'll wait for your visualization results and see how it goes! Thanks!

@SeunghyunSEO

Hi guys, let me share my vibe-check results.
I tested small-scale proxy models with 64*4096 = 262k batch tokens and a 40k-step horizon, so they consumed 10.5B tokens.
My model config uses standard parameterization (SP) with 0.2 init std, GQA, non-separated QKV, lr 0.00195, and weight decay 0.1 (without decaying the RMSNorm gamma), for 12 layers at two widths (hidden sizes): 1024 and 4096.
The larger one (4096) is approximately 3.5B parameters.
Here are my results.

For the smaller one, Muon outperforms AdamW.


For the larger model, Muon looks promising too, but there is an issue where AdamW diverges.


However, the real problem is that Muon's throughput is poor in a multi-node setup.
I used 4 A100 nodes for the larger model, and its throughput looks like it needs to be optimized.
I think the gradient all-gather for NS should be overlapped, or the grad bucketing should be carefully designed.


@toothacher17
Author

@SeunghyunSEO Thanks for sharing! I have some comments regarding your runs:

  1. For Muon's performance, this is what we actually expected to see! Thanks for sharing! Stability would be better if you also apply weight decay to your RMSNorm gamma (as mentioned in the appendix), and maybe your AdamW lr is too large, so it does not converge.

  2. For the throughput issue, this is probably because the distributed optimizer implementation changed. We noticed this gap when porting our internal impl to the open-source one. Previously in Megatron-LM, the distributed optimizer states were concatenated and flattened into one list (the concatenation order is defined by the params' init order), and the list was then split into DP parts. So only those params (very few, actually, if you think about it) that straddle a DP boundary need the extra gather.

However, the current distributed optimizer impl first groups several params into a bucket, and every param in that bucket is split into DP parts and needs a gather! This pushes the number of extra all-gathers to its upper bound: every param on every rank needs a gather. For Distributed Muon to work efficiently as described in our paper, we need the original way of DP-sharding the optimizer states, which requires the extra gathering for only a very limited number of params.
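A toy illustration of the difference (all sizes made up):

# Count how many parameters need an extra gather under the two DP-sharding schemes.
dp_size = 8
param_sizes = [4096 * 4096] * 32 + [1024] * 64        # hypothetical model

# (a) legacy flat-concat sharding: one flat buffer cut into dp_size contiguous shards,
#     so only params that straddle a shard boundary need the extra gather.
total = sum(param_sizes)
shard = total // dp_size
boundaries = [shard * i for i in range(1, dp_size)]
cut, offset = 0, 0
for size in param_sizes:
    if any(offset < b < offset + size for b in boundaries):
        cut += 1
    offset += size
print("flat-concat: params needing a gather:", cut)    # at most dp_size - 1

# (b) bucket-wise / dim-0 sharding: every param is split across every DP rank,
#     so every param needs a gather before Newton-Schulz.
print("bucket-wise: params needing a gather:", len(param_sizes))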

@SeunghyunSEO

SeunghyunSEO commented Feb 27, 2025

@toothacher17 wow, your response is as fast as the speed of light, lol. I didn’t even know that megatron changed its sharding logic. (I’m also familiar with the sharding strategy you mentioned in point 2.) I’ll dig into the codebase and come back if I find any clues to improve performance.

(edited) Can you share the related PR for the param and grad bucketing refactor? I'm not sure this one is right.

@SeunghyunSEO

SeunghyunSEO commented Feb 28, 2025

@toothacher17 Thanks for sharing!
Oh, it's just the layerwise output norm.
I mean, I want to log both per-param activation and grad norms, but it's kind of messy.
I know there are cleaner ways, like lingua's probing module, so I just wanted to ask whether you use a clean and efficient logging module for Megatron :)

@spliew

spliew commented Mar 5, 2025

Question About Implementation Differences

Hi, thanks for open sourcing this interesting work!

I was comparing the implementations with the original one, and I noticed that yours seems to differ. I wanted to understand the reasoning behind your design choices and whether they affect the final results.

Your Implementation (from this PR):

state = self.state[p]
if not "muon_buffer" in state:
    state["muon_buffer"] = torch.zeros_like(g)
buf = state["muon_buffer"]
buf.mul_(momentum).add_(g)

# save to ns input
g = g.add(buf, alpha=momentum) if group['nesterov'] else buf
ns_inputs[p] = g.bfloat16()

Original Implementation:

g = p.grad
assert g is not None
state = self.state[p]
if "momentum_buffer" not in state:
    state["momentum_buffer"] = torch.zeros_like(g)
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf

By letting $v$ denote the buffer, and $\gamma$ denote momentum, notice that

  • Your implementation uses buf.mul_(momentum).add_(g), which corresponds to $v \leftarrow \gamma v + g$ .
  • The original implementation updates buf via lerp_(g, 1 - momentum), which is $v \leftarrow \gamma v + (1-\gamma)g$.
  • Your implementation applies g.add(buf, alpha=momentum) if group['nesterov'] is True, which is $g \leftarrow g+\gamma v $.
  • The original implementation applies g.lerp_(buf, group["momentum"]), which is $g \leftarrow (1-\gamma)g+\gamma v$

@github-actions
Contributor

github-actions bot commented May 9, 2025

Marking as stale. No activity in 60 days.

@github-actions github-actions bot added the stale No activity in 60 days on issue or PR label May 9, 2025
@toothacher17
Author

toothacher17 commented Jun 6, 2025


Hey @spliew, thanks for pointing this out. I think you are right, they are different implementations, and Keller's current impl is probably closer to the conventional momentum; we should verify this and run some experiments. I think our current implementation was copied from an earlier version of the open-source work, and it works pretty well, so we kept it. But maybe using Keller's current equation can yield even better results.

To conclude, since nesterov is true by default,

Keller's impl actually gives this (where b is the past momentum, g is the grad, b' is the updated momentum, and u is the update before NS):

b' = m * b + (1-m) * g
u = m^2 * b + (1 - m^2) * g

And our impl actually gives this:

b' = m * b + g
u = m^2 * b + (1 + m) * g

While both equations lead to an EMA of the momentum, our (Moonshot's) impl puts much more weight on the current step's gradient, and I do not think it is 'that' stable. Surprisingly, thanks to the NS, it still seems to work fine.
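A quick numeric check of the two recursions (constant gradient g = 1, m = 0.95), just to see how differently they weight the current gradient:

m, g = 0.95, 1.0
b_keller = b_ours = 0.0
for step in range(5):
    b_keller = m * b_keller + (1 - m) * g      # Keller's current buffer update (lerp)
    u_keller = m * b_keller + (1 - m) * g      # Nesterov update fed to NS
    b_ours = m * b_ours + g                    # buffer update in this PR
    u_ours = m * b_ours + g                    # Nesterov update fed to NS
    print(step, round(u_keller, 4), round(u_ours, 4))
# u_keller stays O(1) while u_ours grows toward g / (1 - m); the NS step then
# normalizes the overall scale away, which is presumably why both still work.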

@toothacher17 toothacher17 reopened this Jun 6, 2025
@toothacher17
Author

ah, @spliew I know why it looks like this, see Keller's blog (https://kellerjordan.github.io/posts/muon/):

[screenshot: excerpt from Keller's Muon blog post]

So I guess an earlier version is actually what we implemented, where b' = m * b + g. Since the core idea is for the NS to produce a semi-orthogonalization, maybe the momentum update was changed in later versions based on Keller's experiments.

@toothacher17
Author

toothacher17 commented Jun 6, 2025

@spliew Thanks for pointing this out. This is actually interesting; we should experiment with both, I think. The new one seems more stable at first glance, as it favors the historical momentum more.

@github-actions github-actions bot removed the stale No activity in 60 days on issue or PR label Jun 6, 2025
@jaesuny

jaesuny commented Jun 13, 2025

@toothacher17
Thank you for this great work!

I have a question.

If it's important to apply weight decay to the RMSNorm as well, I'm curious whether you used zero_centered_gamma, under which weight decay regularizes the RMSNorm scaling parameter toward 1.

instance = te.pytorch.RMSNorm(
    hidden_size=hidden_size,
    eps=eps,
    sequence_parallel=config.sequence_parallel,
    zero_centered_gamma=config.layernorm_zero_centered_gamma,
    **_get_extra_te_kwargs(config),
)

If this wasn’t used in your previous experiments, do you think applying it could help improve stability?

@toothacher17
Author


Yeah, I tried both, and both versions are stable as long as you apply weight decay. The difference is that with zero-centered gamma, you get a much smaller weight decay effect in the beginning phase of training.

Directly applying weight decay to the original version yields slightly better results, so I chose to use the original version.
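For readers unfamiliar with the flag, the difference is roughly the following (illustrative sketch, not TE's actual kernel): with zero_centered_gamma the stored parameter is gamma - 1, so plain weight decay pulls the effective scale toward 1 instead of toward 0.

import torch

def rmsnorm(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6, zero_centered: bool = False) -> torch.Tensor:
    x_normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Weight decay drives the stored gamma toward 0 in both cases; with the
    # zero-centered convention that means an effective scale of 1, not 0.
    scale = 1.0 + gamma if zero_centered else gamma
    return x_normed * scale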

@jaesuny

jaesuny commented Jul 15, 2025

@toothacher17

I have also found that the original version performs better.
Thank you very much for your response! It was very helpful.

@switiz

switiz commented Jul 23, 2025

It seems that the latest commit to Megatron-LM is less optimized for speed compared to the original Adam implementation.
Have others also mentioned that it's slower?

@skyw
Contributor

skyw commented Jul 25, 2025

We are evaluating options to support Muon and other preconditioning-based methods (Shampoo and SOAP, for example) at scale.
There are a lot of things to consider for algorithmic and training-throughput reasons. For example, Muon has those coefficients in the Newton-Schulz iterations (https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py#L44) that we need to choose. There are also many choices in the distributed optimizer and FSDP to support it efficiently. Before this point, the optimizer step was pointwise and the communication in the distributed optimizer was uniform.

@toothacher17
Author


There is a discussion here about why the current bucketing / FSDP dim-0 sharding ZeRO-1 distributed optimizer hurts the performance of this impl:

https://www.zhihu.com/question/1927140506573435010/answer/1927378524513219780

@toothacher17
Author


So exciting to see that Megatron-LM is interested in adding Muon!!! Actually, the PyTorch team is also interested in integrating it into FSDP:
pytorch/pytorch#148819
pytorch/pytorch#159465

In general, it would be nice to allow a customized adjust-LR method and msign function, as in this PR (pytorch/pytorch#159465), so researchers can customize their methods: as you said, different a, b, c coefficients, or adjusting the LR like Moonshot did for standard parameterization, or adjusting the LR following a muP setting.

However, I believe the main challenge is not how to perform the NS once you have the full matrix (researchers will have different methods), but how to efficiently obtain the full matrix, which is heavily dependent on the framework itself. After the bucketing / dim-0 sharding of the ZeRO-1 optimizer implementation, the current PR's impl might not be efficient anymore, and it requires further support from the Nvidia team for a better unsharding mechanism.

I am happy to chat more and help contribute to bring Muon into Megatron-LM if needed!

@sbhavani
Contributor

sbhavani commented Aug 1, 2025

@toothacher17 we'd be happy to chat more and collaborate!

Comment on lines +354 to +355
tp_split_dim = -1 if getattr(model_param, 'tensor_model_parallel', False) else \
    getattr(model_param, 'partition_dim')


Should these if and else blocks be swapped?

Suggested change
tp_split_dim = -1 if getattr(model_param, 'tensor_model_parallel', False) else \
    getattr(model_param, 'partition_dim')
tp_split_dim = getattr(model_param, 'partition_dim') if getattr(model_param, 'tensor_model_parallel', False) else \
    -1

@ftgreat

ftgreat commented Sep 16, 2025

This pull request does not seem to include the logic in MuonClip that clips certain parameters.

Could you help confirm this? Thanks. @toothacher17

# init dist group
for i in range(tp_size):
    ranks = range(i, world_size, tp_size)
    group = dist.new_group(ranks)

@JenWei0312 JenWei0312 Oct 10, 2025


should group = dist.new_group(ranks) be inside the for loop? like

# init dist group
for i in range(tp_size):
    ranks = range(i, world_size, tp_size)
    if rank in ranks:
        group = dist.new_group(ranks) # <-- Only called once per process.
        dist_group = group

# init tp group
for i in range(dp_size):
    ranks = range(i * tp_size, (i + 1) * tp_size)
    group = dist.new_group(ranks)


should group = dist.new_group(ranks) be inside the for loop?

@BoxiangW
Contributor

We have merged 879a7a1 into the dev branch if you want to test it out. Installing the Emerging-Optimizers package is needed.
