[deepseek] update deepseek to real training loop, part 1 #1233

Merged
lessw2020 merged 14 commits into main from lessw2020/ds_training_1_of_4
May 30, 2025

Conversation

@lessw2020 (Contributor) commented May 28, 2025

This PR implements a core "real" training loop: it runs the deepseekv2 model using a number of Titan components to train on real (C4) data with AdamW, and displays initial training-loop metrics.

There is a lot more to be done, but the goal here is to get a true training loop going that additional PRs will then improve upon.

[Screenshot 2025-05-29 at 7:41:01 PM: initial training-loop metrics]

A couple of key highlights:
a - the model is now controllable via TOML or the command line, just like Titan main. Note that expert-parallel control is waiting for PR #1244 to land; at the moment ep is manually set to 2.
b - we use the HF deepseek tokenizer, and as a result I had to write a wrapper to handle the bos and eos params passed by Titan.
c - loss metrics, tps, etc. are displayed, but MFU and tflops still need to be updated.
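The tokenizer wrapper in highlight (b) can be pictured roughly as follows — a minimal sketch, assuming an HF-style tokenizer with `bos_token_id`/`eos_token_id` attributes; the class name `HFTokenizerWrapper` and its `encode` signature are illustrative, not torchtitan's actual API:

```python
# Hypothetical sketch: Titan's data pipeline passes bos/eos flags to
# encode(), while HF tokenizers manage special tokens themselves.
# The wrapper asks HF for raw ids and adds bos/eos explicitly.

class HFTokenizerWrapper:
    def __init__(self, hf_tokenizer):
        self.tok = hf_tokenizer

    def encode(self, text, bos=True, eos=False):
        # Suppress HF's own special-token insertion, then honor the
        # bos/eos flags Titan passes in.
        ids = self.tok.encode(text, add_special_tokens=False)
        if bos and self.tok.bos_token_id is not None:
            ids = [self.tok.bos_token_id] + ids
        if eos and self.tok.eos_token_id is not None:
            ids = ids + [self.tok.eos_token_id]
        return ids

    def decode(self, ids):
        return self.tok.decode(ids)
```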

A lot more improvements will come shortly, but for now I want to land this to ensure our base deepseek training loop is available to iterate on.

@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on May 28, 2025
@lessw2020 lessw2020 requested a review from kwen2501 May 30, 2025 04:38
@lessw2020 changed the title from "[WIP][deepseek] update deepseek to real training loop, part 1" to "[deepseek] update deepseek to real training loop, part 1" on May 30, 2025
@kwen2501 (Contributor) left a comment

Thanks for the long haul! Nice demo!

Comment on lines +20 to +21
@dataclass
class TransformerModelArgs(BaseModelArgs):

What motivates this new class over the existing one(s) in model_config.py?
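The quoted snippet shows only the class header. As a hedged illustration of why a model might carry its own args subclass, a sketch of a MoE-specific extension — every field name here is hypothetical, not torchtitan's actual definition:

```python
from dataclasses import dataclass

# Illustrative only: one common reason to subclass is to add MoE-specific
# fields (expert counts, etc.) that the dense-model args don't carry.

@dataclass
class BaseModelArgs:
    dim: int = 2048
    n_layers: int = 16

@dataclass
class TransformerModelArgs(BaseModelArgs):
    n_routed_experts: int = 64   # MoE-specific (hypothetical field)
    n_shared_experts: int = 2    # MoE-specific (hypothetical field)
```

The reviewer's question stands: if the existing classes in model_config.py already cover these fields, a separate subclass adds a second source of truth.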

)

# Synthetic setting
microbatches = pp_size * 2

Does JobConfig have a field for microbatches?
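One way to act on the reviewer's question: read the microbatch count from the job config when present, and keep `pp_size * 2` only as the synthetic fallback. The field name `pipeline_parallel_microbatches` is an assumption, not necessarily torchtitan's actual JobConfig field:

```python
# Hypothetical sketch: prefer a configured microbatch count over the
# hardcoded synthetic default of pp_size * 2.

def resolve_microbatches(job_config, pp_size: int) -> int:
    configured = getattr(
        job_config.parallelism, "pipeline_parallel_microbatches", None
    )
    # Fall back to the synthetic setting when the config says nothing.
    return configured if configured is not None else pp_size * 2
```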

TrainSpec(
name="deepseek3",
cls=DeepseekForCausalLM,
config=deepseek_configs,

Maybe not a problem of this PR, but I think the LHS `config` should point to a single model config rather than a dictionary of model configs, e.g.

config=deepseek_debug_config,

Reviewer comment (Contributor):

Perhaps no need to put it under an infra folder? I don't see how they are related.

@lessw2020 (Contributor, author) replied:

I'll move it - I was mostly just trying to keep some similarity with the llama4 layout.

Comment on lines 41 to 42
# Use DeepSeek-V2-Lite as a proxy
model_id = "deepseek-ai/DeepSeek-V2-Lite"

maybe model_id should come from JobConfig?

Comment on lines 143 to 151
proxy_parallel_dims = ParallelDims(
dp_replicate=ep_size,
dp_shard=fsdp_dim,
pp=pp_size,
cp=1,
tp=1,
world_size=world_mesh.size(),
enable_loss_parallel=False,
)

This looks like a duplicate of the information that DeviceMesh or config.parallelism already carries.
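The duplication the reviewer points at can be avoided by deriving the degrees from the parallelism config instead of hardcoding them. A minimal sketch — attribute names like `pipeline_parallel_degree` are assumptions modeled on common torchtitan-style config fields, not verified against the actual code:

```python
# Hypothetical sketch: compute the parallel degrees from one source of
# truth (the parallelism config) rather than re-stating them inline.

def parallel_degrees_from_config(parallelism, world_size: int) -> dict:
    explicit = (
        parallelism.pipeline_parallel_degree
        * parallelism.tensor_parallel_degree
        * parallelism.context_parallel_degree
        * parallelism.data_parallel_replicate_degree
    )
    # Whatever ranks remain form the FSDP (dp_shard) dimension.
    return {
        "pp": parallelism.pipeline_parallel_degree,
        "tp": parallelism.tensor_parallel_degree,
        "cp": parallelism.context_parallel_degree,
        "dp_replicate": parallelism.data_parallel_replicate_degree,
        "dp_shard": world_size // explicit,
    }
```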

Comment on lines 115 to 119
pp_size,
pp_rank,
pp_mesh,
ep_size,
ep_rank,
@kwen2501 (Contributor) commented May 30, 2025:

Logically, aren't these pre-known, or more easily obtained via `device_mesh.get_rank(dim="pp")`?
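The reviewer's point is that the per-dimension ranks are derivable from the mesh rather than needing to be threaded through as parameters. As a pure-Python illustration of why: on a 2-D (pp, ep) mesh, the coordinates follow from the global rank and the mesh shape by row-major decomposition. The dimension ordering here is an assumption; a real run would query the DeviceMesh instead of doing this arithmetic by hand:

```python
# Illustration only: a (pp_size x ep_size) mesh laid out row-major maps
# each global rank to a unique (pp_rank, ep_rank) coordinate pair.

def coords_in_mesh(global_rank: int, pp_size: int, ep_size: int):
    assert 0 <= global_rank < pp_size * ep_size
    pp_rank = global_rank // ep_size   # slowest-varying dimension
    ep_rank = global_rank % ep_size    # fastest-varying dimension
    return pp_rank, ep_rank
```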

Comment on lines +49 to +53
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=get_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,

It seems to make more sense to directly invoke these "build_..." functions in train.py. JobConfig should be more for int, str, float, etc.
If changed, the imports of these functions at the top can be moved away too (it would make __init__.py much cleaner, imo).
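The contrast the reviewer draws can be sketched like this: keep the spec as plain data and call the builders directly at the training-entry point, instead of registering them as callables on the spec. All names below are illustrative stand-ins, not torchtitan's actual signatures:

```python
from dataclasses import dataclass

# Hypothetical sketch: the spec holds data only; train.py invokes the
# builder functions directly, so no build_* callables (or their imports)
# need to live in the registration module.

@dataclass
class TrainSpec:
    name: str  # data only — no build_optimizers_fn etc.

def build_optimizers(model_parts):
    # Stand-in builder; a real one would construct AdamW instances.
    return {"optimizer": "adamw", "parts": len(model_parts)}

def setup_training(spec: TrainSpec, model_parts):
    # Direct invocation in train.py, rather than spec.build_optimizers_fn(...)
    return build_optimizers(model_parts)
```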

@lessw2020

There is additional layout-related feedback above, but I'm planning to address those remaining items in part 2 and land part 1 now, so that pending users can actively start using the training loop while I update the other items, which are not functionally related.

@lessw2020

GPU CI failure is not related (hit exact same error on earlier PR, which also was unrelated).

@lessw2020 lessw2020 merged commit d0ed9b4 into main May 30, 2025
6 of 7 checks passed
@lessw2020 lessw2020 deleted the lessw2020/ds_training_1_of_4 branch May 30, 2025 23:21
wwwjn pushed a commit to wwwjn/torchtitan that referenced this pull request Jun 2, 2025
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026
