【Flexcheckpoint】add_get_var_mapping_chain_macro #76013
Your PR was submitted successfully. Thank you for contributing to this open-source project!
6b57f39 to efcdf60
Codecov Report
❌ Patch coverage is
Additional details and impacted files

@@ Coverage Diff @@
## develop #76013 +/- ##
==========================================
Coverage ? 95.40%
==========================================
Files ? 4
Lines ? 87
Branches ? 0
==========================================
Hits ? 83
Misses ? 4
Partials ? 0
For example:
- reverse=False: temp_var -> dst_key
- reverse=True: temp_var -> src_key
How should "reverse=True: temp_var -> src_key" be interpreted?
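It means we record which src_key temp_var was converted from, that is, its source information. Consider the following case:
aoa_statements = [
    "layers.0.gate_up_fused_proj.weight^T -> temp_var \n",
    "temp_var -> new_name_layers.0.gate_up_fused_proj.weight,fused_ffn\n",
]
Here fused_ffn is invoked, and the left side of the arrow has to determine which src_key temp_var was converted from, so that the source information associated with that src_key can be found.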
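To make the two lookup directions concrete, here is a minimal sketch built on the two statements in this reply. The dictionaries are written by hand rather than produced by the real parser, the "^T" operator and the fused_ffn macro are left out, and the `lookup` helper is hypothetical; only the two mapping names come from the PR.

```python
# Hand-written maps for the two statements above (illustrative only).
SRC = "layers.0.gate_up_fused_proj.weight"
DST = "new_name_layers.0.gate_up_fused_proj.weight"

# Forward direction: left-hand name -> names it is written to.
left_var_to_right_var_mapping = {SRC: ["temp_var"], "temp_var": [DST]}
# Reverse direction: right-hand name -> names it was produced from.
right_var_from_left_var_mapping = {"temp_var": [SRC], DST: ["temp_var"]}


def lookup(var, reverse=False):
    """Hypothetical one-hop lookup: forward toward dst_key, backward toward src_key."""
    table = (
        right_var_from_left_var_mapping
        if reverse
        else left_var_to_right_var_mapping
    )
    # Fall back to the name itself when no mapping is recorded.
    return table.get(var, [var])[0]


forward = lookup("temp_var", reverse=False)  # resolves toward dst_key
backward = lookup("temp_var", reverse=True)  # resolves toward src_key
```

With reverse=True, temp_var resolves to the checkpoint key it was derived from, which is the name whose source information a macro on the left side of the arrow needs.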
mapping_dict = self.left_var_to_right_var_mapping

while current_key in mapping_dict:
    if current_key in visited:
Doesn't this mean a cycle was encountered? Why is it fine to return directly?
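The reasoning is that a correct AOA configuration normally does not contain a cycle. The cases where one does show up typically look like this:
aoa_statements = [
    "src_key -> dst_key \n",
]
or
aoa_statements = [
    "src_key^T -> A\n",
    "A -> A \n",
]
That is, when src_key and dst_key have the same name, or when a user-defined intermediate variable has the same name as the dst (the same applies to src), a dst_key:dst_key (or src_key:src_key) mapping appears. In that case the key passed in is already the dst_key or src_key, so the loop has to be interrupted to avoid running forever, and the key that was passed in is returned.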
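The guard described in this reply can be sketched as a standalone function. This is a reconstruction from the diff fragments visible in this thread, not the merged code: the function signature, the choice of element [0], and the `break` on an empty list are assumptions.

```python
def resolve_mapping_chain(key, mapping_dict):
    """Follow key -> mapping_dict[key][0] -> ... and return the last name reached."""
    visited = set()
    current_key = key
    while current_key in mapping_dict:
        if current_key in visited:
            # A self-mapping such as "A -> A" brings us back to a key we
            # have already processed: stop and return it.
            return current_key
        visited.add(current_key)
        mapped_vars = mapping_dict[current_key]
        if not mapped_vars:
            # Assumed behavior: an empty list ends the chain.
            break
        current_key = mapped_vars[0]
    return current_key


# "src_key^T -> A" followed by "A -> A" (operators omitted).
self_mapped = {"src_key": ["A"], "A": ["A"]}
resolved_self = resolve_mapping_chain("src_key", self_mapped)

# An ordinary chain without any self-mapping.
plain_chain = {"src_key": ["temp_var"], "temp_var": ["dst_key"]}
resolved_plain = resolve_mapping_chain("src_key", plain_chain)
```

Without the visited set, the "A -> A" entry would keep the while loop spinning, because A is always present in the mapping and always maps to itself.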
    visited.add(current_key)

    mapped_vars = mapping_dict[current_key]
    if mapped_vars and len(mapped_vars) > 0:
In cases such as a -> b,c or a,b -> c, can the sharding information still be passed down the chain?
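Yes, because the value stored here is a list. For example:
a -> b,c
For left_var_to_right_var_mapping, what is stored is:
{
    a: [b, c]
}
and for right_var_from_left_var_mapping, what is stored is:
{
    b: [a]
    c: [a]
}
If a is a src_key and b, c are intermediate variables, then finding the sharding information of the src_key that b or c came from only requires a lookup in right_var_from_left_var_mapping. Likewise, if a is an intermediate variable and b, c are dst_keys, a direct lookup in left_var_to_right_var_mapping returns the list [b, c], and taking element [0] is enough: both are dst_keys produced by the same operation, so their sharding information in dst should be identical.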
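The list-valued bookkeeping in this reply can be sketched as follows. The `record` helper and its argument shape are hypothetical; only the two mapping names and the stored contents are taken from the comment above.

```python
from collections import defaultdict

left_var_to_right_var_mapping = defaultdict(list)
right_var_from_left_var_mapping = defaultdict(list)


def record(left_vars, right_vars):
    """Hypothetical helper: register one statement's names in both directions."""
    for left in left_vars:
        for right in right_vars:
            if right not in left_var_to_right_var_mapping[left]:
                left_var_to_right_var_mapping[left].append(right)
            if left not in right_var_from_left_var_mapping[right]:
                right_var_from_left_var_mapping[right].append(left)


record(["a"], ["b", "c"])  # a -> b,c
record(["x", "y"], ["z"])  # x,y -> z

# Forward lookup: any element of the list works when b and c share the
# same sharding information, so element [0] is taken.
first_dst = left_var_to_right_var_mapping["a"][0]
# Reverse lookup: b and c both point back to a.
origin_of_c = right_var_from_left_var_mapping["c"]
```

Taking element [0] relies on the premise stated in the reply, that outputs of the same operation share sharding information. The sketch does not check that premise.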
@macro(name='get_var_mapping_chain_macro', priority=3)
def get_var_mapping_chain_macro(tokens, expression, context):
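All macros appear to be split into two groups by this macro. Do all macros that match keys in src or dst have to be expanded before this macro runs? If so, would it be better to enforce that restriction in code? Otherwise it is easy to make mistakes.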
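Yes. This is controlled through priority, so that this macro runs after all the macros that perform expansion. The priority should be 4; it was changed at some point and will be corrected in a follow-up.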
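As a rough illustration of ordering by priority, the sketch below uses a toy registry. It is not Paddle's macro implementation: the registry, the rule that lower numbers run first, and `hypothetical_expand_macro` are all assumptions made for this example. Only the decorator's name and priority arguments and the macro's signature appear in the diff.

```python
# Toy stand-in for a macro registry; NOT the real implementation.
MACRO_REGISTRY = []


def macro(name, priority):
    """Register a macro. Assumed here: lower priority numbers run first."""

    def decorator(func):
        MACRO_REGISTRY.append((priority, name, func))
        # Keep the registry sorted so iteration order equals run order.
        MACRO_REGISTRY.sort(key=lambda entry: entry[0])
        return func

    return decorator


@macro(name="get_var_mapping_chain_macro", priority=4)
def get_var_mapping_chain_macro(tokens, expression, context):
    return expression  # placeholder body


@macro(name="hypothetical_expand_macro", priority=1)
def hypothetical_expand_macro(tokens, expression, context):
    return expression  # placeholder body


run_order = [name for _, name, _ in MACRO_REGISTRY]
```

Under that assumption, the chain-resolution macro runs last even though it was registered first, which is the ordering the reply asks for.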
/re-run all-failed
LGTM
f390065 to 96db255
96db255 to 234ac3e
LGTM
fe96e7a to 2293d95
* add_get_var_mapping_chain_macro
* add note
* fix the bug input_vars and resolve_mapping_chain
* fix the code style
* fit the dtype assert bug
* fix the bug
* fix the merge_sharded_state_dict bug
…#76252)
* 【FlexCheckpoint】fix_the_layer_id_macro (#75556)
* fix_the_layer_id_macro
* fix the ctest
* add expert_id_macro
* fix the assert bug
* fix the code style
* Pr support load hf checkpoint (#75928)
* support hf checkpoint fix support cast add id macro fix
* add test and fix some bug
* fix full param bug
* add full param cast test
---------
Co-authored-by: xingmingyyj <[email protected]>
* 【Flexcheckpoint】add_get_var_mapping_chain_macro (#76013)
* add_get_var_mapping_chain_macro
* add note
* fix the bug input_vars and resolve_mapping_chain
* fix the code style
* fit the dtype assert bug
* fix the bug
* fix the merge_sharded_state_dict bug
* fix aoa transpose corner case (#76234)
---------
Co-authored-by: Tianyu Zheng <[email protected]>
….2 (#76249)
* 【FlexCP】merge_sharded_state_dict support distribute merge (#75005)
* fix data is nullptr
* add dist merge
* change test
* change test
* 【FlexCP】add Skip param param for merge_shard_state_dict (#75061)
* fix data is nullptr
* add dist merge
* change test
* change test
* add skip optimizer param
* [Flex CP]Fix merge_sharded_state_dict with aoa and offload (#75062)
* fix merge_state_dict with aoa and offload
* add tests
* refine
* fix
* fix
* add log
* fix
* fix
* 【FlexCheckpoint】Upgrade some macros and optimize load_state_dict communication (#75282)
* upgrad macros and load_state_dict comm task fix fix support 0-d tensor fix balance save and fix
* fix test
* Add the test about the sharded_state_dict of optimizer (#75067)
* fix the share_weight_bug
* add note
* add the unit test
* set the timeout
* add more test
* Trigger CI rebuild
* fix the CmakeLists
* handle_missing_edge_cases_in_fc (#75413)
* up_grade fc (#75613) fix and add test fix fix fix fix cmakelists add notion
* 【FlexCheckpoint】fix_the_layer_id_macro (#75556)
* fix_the_layer_id_macro
* fix the ctest
* add expert_id_macro
* fix the assert bug
* fix the code style
* Pr support load hf checkpoint (#75928)
* support hf checkpoint fix support cast add id macro fix
* add test and fix some bug
* fix full param bug
* add full param cast test
---------
Co-authored-by: xingmingyyj <[email protected]>
* 【Flexcheckpoint】add_get_var_mapping_chain_macro (#76013)
* add_get_var_mapping_chain_macro
* add note
* fix the bug input_vars and resolve_mapping_chain
* fix the code style
* fit the dtype assert bug
* fix the bug
* fix the merge_sharded_state_dict bug
* fix aoa transpose corner case (#76234)
---------
Co-authored-by: xiaoguoguo626807 <[email protected]>
Co-authored-by: Chen Zhiyang <[email protected]>
Co-authored-by: Tianyu Zheng <[email protected]>
PR Category: User Experience
PR Types: New features
Description