
[Extensions] ensure extension default values from extension toml are used, not base class #1244

Merged
lessw2020 merged 3 commits into main from lessw2020/fix_jobconfig_extension_processing
May 30, 2025

Conversation

@lessw2020 (Contributor) commented May 29, 2025

This PR:
Ensures that JobConfig extension default values are included in the final merged JobConfig, rather than having all custom fields (incorrectly) set to zero.
The fix has gone from a closure to a one-line adjustment (credit @tianyu-l).
Unit testing has been expanded to verify the behavior and catch any future regressions like this (credit @jaysonfrancis).

Testing - verified all unit tests passing (with added assert for custom field update) and expert parallel extension use case working.
Example:

~~~
@dataclass
class Parallelism:
    expert_parallel_degree: int = 2
    """ degree to parallelize experts """
~~~

results in

~~~
parallelism=MergedParallelism(... expert_parallel_degree=0)
~~~

which is not what is desired.

With PR:

~~~
parallelism=MergedParallelism(... expert_parallel_degree=2)
~~~
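The failure mode can be reproduced outside torchtitan with a minimal sketch (illustrative only; the `Parallelism` class and the merge loops below are stand-ins for the real extension-processing code):

```python
from dataclasses import dataclass, field, fields, make_dataclass

@dataclass
class Parallelism:
    data_parallel_degree: int = 1
    expert_parallel_degree: int = 2  # extension default we want to keep

# Buggy merge: default_factory=f.type calls the field's *type* (e.g. int()),
# which yields 0 regardless of the declared default.
BuggyMerged = make_dataclass(
    "MergedParallelism",
    [(f.name, f.type, field(default_factory=f.type)) for f in fields(Parallelism)],
)
print(BuggyMerged())
# MergedParallelism(data_parallel_degree=0, expert_parallel_degree=0)

# Fixed merge: reusing the original Field object carries the default through.
FixedMerged = make_dataclass(
    "MergedParallelism",
    [(f.name, f.type, f) for f in fields(Parallelism)],
)
print(FixedMerged())
# MergedParallelism(data_parallel_degree=1, expert_parallel_degree=2)
```

The key point is that `int()` evaluates to `0`, which is exactly the "all values set to zero" symptom in the issue.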

@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) May 29, 2025
@lessw2020 requested review from fegin, kwen2501 and tianyu-l May 29, 2025 22:58
@fegin (Contributor) left a comment


Why does the original approach not work? Can you update the docstring to describe why `field(default_factory=)` does not work?

@lessw2020 (Contributor, Author)

Why does the original approach not work?

When merging nested dataclasses (like Parallelism), the original code created a new merged type but didn't preserve the default values from the custom config.

It creates the merged dataclass type correctly, but when creating the default_factory for its fields it just used field(default_factory=m_type). This creates a new instance of the merged type with base defaults, not the correct custom ones.

So when I defined expert_parallel_degree = 2, for example, that value was lost during the merge process and you got the default value of 0 instead.

The fix uses a closure to capture both the merged type and custom field, ensuring the custom defaults are preserved when creating instances of the merged type.
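A sketch of what such a closure-based `default_factory` can look like (illustrative only; `make_factory` and the `JobConfig` shape here are hypothetical stand-ins, not the actual torchtitan code):

```python
from dataclasses import MISSING, dataclass, field, fields, make_dataclass

def make_factory(m_type, custom_cls):
    # The closure captures both the merged type and the custom class, so each
    # new instance of the merged type starts from the custom defaults.
    def factory():
        kwargs = {
            f.name: f.default
            for f in fields(custom_cls)
            if f.default is not MISSING
        }
        return m_type(**kwargs)
    return factory

@dataclass
class Parallelism:
    expert_parallel_degree: int = 2

# Merged type built without defaults; the closure supplies them.
MergedParallelism = make_dataclass(
    "MergedParallelism",
    [(f.name, f.type) for f in fields(Parallelism)],
)

@dataclass
class JobConfig:
    parallelism: MergedParallelism = field(
        default_factory=make_factory(MergedParallelism, Parallelism)
    )

print(JobConfig().parallelism.expert_parallel_degree)  # 2
```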

@lessw2020 (Contributor, Author) commented May 30, 2025

re: Can you update the docstring to describe why `field(default_factory=)` does not work?

I can, but imo it's not beneficial to record code-history explanations in a docstring - docstrings should focus on current code usage for the user.
GitHub maintains the history of code changes if anyone wants to review them.

@fegin (Contributor) left a comment

I don't think relying on the GitHub summary is the right way to maintain code readability. It's easy to lose track of where the code came from after many refactors. It is reasonable to add a comment, not necessarily a docstring, like "Using field() cannot support the use cases for ....".

@tianyu-l (Contributor) left a comment

cc @jaysonfrancis for a review on this PR

@lessw2020 (Contributor, Author)

I don't think relying on the GitHub summary is the right way to maintain code readability. It's easy to lose track of where the code came from after many refactors. It is reasonable to add a comment, not necessarily a docstring, like "Using field() cannot support the use cases for ....".

I see - for code readability, I will add an explanatory comment above the closure to clarify why it's needed. I think that makes more sense than putting it in the docstring.

~~~
# We do this by using a closure to capture both the merged type and custom fields,
# ensuring the custom defaults are preserved when creating instances of the merged type.
# Previously, we were using a default_factory that would create an instance of the merged type,
# but this would not preserve the custom defaults as it would use base class defaults.
~~~

I guess I'm a bit confused by what the "base class defaults" are and where they come from.
If we already recursively use the default from c_map[name], how could it not use the default in c_map? This is assuming we already modify the line below
from

~~~
result.append((name, f.type, field(default_factory=f.type)))
~~~

to

~~~
result.append((name, f.type, f))
~~~
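For context, this works because a `Field` object returned by `dataclasses.fields()` already carries the declared default with it, so passing it straight back into `make_dataclass` keeps that default on the merged type (a minimal illustration, not the torchtitan code itself):

```python
from dataclasses import dataclass, fields

@dataclass
class Parallelism:
    expert_parallel_degree: int = 2

# The Field object records both the name and the declared default.
f = fields(Parallelism)[0]
print(f.name, f.default)  # expert_parallel_degree 2
```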

@kwen2501 (Contributor)

If it makes things easier, why don't we just add an expert_parallel_degree field in the "upstream" JobConfig?

@jaysonfrancis (Contributor) commented May 30, 2025

Ah sorry I missed this!

Looks like your change to `result.append((name, f.type, f))` fixes this.

~~~
@dataclass
class Training:
    steps: int = 99
    my_custom_steps: int = 32
~~~

Forgot to add this assert here to check:

~~~
++ assert config.training.my_custom_steps == 32
~~~

https://github.com/pytorch/torchtitan/blob/main/tests/unit_tests/test_job_config.py#L283
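A minimal sketch of this check, using the same `Training` example (the `make_dataclass` merge below is a stand-in for JobConfig's real extension processing, and `MergedTraining` is an illustrative name):

```python
from dataclasses import dataclass, fields, make_dataclass

@dataclass
class Training:
    steps: int = 99
    my_custom_steps: int = 32

# Reuse each original Field object so declared defaults survive the merge.
MergedTraining = make_dataclass(
    "MergedTraining",
    [(f.name, f.type, f) for f in fields(Training)],
)

config_training = MergedTraining()
assert config_training.steps == 99
assert config_training.my_custom_steps == 32  # custom default preserved
```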

@lessw2020 (Contributor, Author)

ok let me update to the simpler fix, verify it works for the expert parallel case, and add the assert to the testing to close this out.

@tianyu-l (Contributor) left a comment

LGTM, glad that we have a simple fix!

@lessw2020 (Contributor, Author)

I'll update above in the main PR comments, but for closure:
the simpler fix works for both the unit testing and for my expert parallel use case (thanks @tianyu-l for the improved solution).
Updated the unit testing (thanks @jaysonfrancis for the pointer on where to add it) and verified all tests passing.

@lessw2020 changed the title from "[Extensions] add closure to ensure extension default values from extension toml are used, not base class" to "[Extensions] ensure extension default values from extension toml are used, not base class" May 30, 2025
@lessw2020 (Contributor, Author)

The CI GPU failure is not related.

@lessw2020 merged commit 29fb3f9 into main May 30, 2025
5 of 6 checks passed
@lessw2020 deleted the lessw2020/fix_jobconfig_extension_processing branch May 30, 2025 20:31
lessw2020 added a commit that referenced this pull request May 30, 2025
This PR implements a core 'real' training loop in that it runs
deepseekv2 model using a number of Titan components to train on real
(C4) data with adamW and displays initial training loop metrics.

There is a lot more to be done but the goal here is to get a true
training loop going from which additional PRs will then improve upon it.

<img width="1192" alt="Screenshot 2025-05-29 at 7 41 01 PM"
src="https://github.com/user-attachments/assets/36ae2ff1-aa99-42c9-8b97-1e0a1ef8376e"
/>

A couple key highlights:
a - the model is now controllable via toml or cmd line just like Titan
main. Note that the expert parallel control is waiting for PR
#1244 to land...atm it just
manually puts ep to 2.
b - we use the HF deepseek tokenizer and as a result I had to make a
wrapper to deal with the bos and eos params passed by Titan.
c - loss metrics, tps, etc are displaying but MFU and tflops need to be
updated.

A lot more improvements will come shortly but for now want to land this
to ensure our base deepseek training loop is available to iterate on.
wwwjn pushed a commit to wwwjn/torchtitan that referenced this pull request Jun 2, 2025
…used, not base class (pytorch#1244)

wwwjn pushed a commit to wwwjn/torchtitan that referenced this pull request Jun 2, 2025
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
…used, not base class (pytorch#1244)

xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026
…used, not base class (pytorch#1244)

xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026


Development

Successfully merging this pull request may close these issues.

[Extensions] JobConfig extensions processing does not propagate actual values, sets to zero

6 participants