【Flexcheckpoint】add_get_var_mapping_chain_macro#76013
Conversation
Your PR was submitted successfully. Thank you for contributing to this open-source project!
6b57f39 to efcdf60
Codecov Report❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## develop #76013 +/- ##
==========================================
Coverage ? 95.40%
==========================================
Files ? 4
Lines ? 87
Branches ? 0
==========================================
Hits ? 83
Misses ? 4
Partials ? 0
For example:
- reverse=False: temp_var -> dst_key
- reverse=True: temp_var -> src_key

How should "reverse=True: temp_var -> src_key" be understood?

It means we record which src_key temp_var was converted from, i.e. its source information. Consider the following scenario:
aoa_statements = [
"layers.0.gate_up_fused_proj.weight^T -> temp_var \n",
"temp_var -> new_name_layers.0.gate_up_fused_proj.weight,fused_ffn\n",
]
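Here fused_ffn is invoked, and for the variable on the left of the arrow we need to find which src_key temp_var was converted from, so that the source information of that src_key can be looked up.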
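
The reverse lookup described in this reply can be sketched as follows. This is a minimal illustration, not the FlexCheckpoint implementation: `parse_statement`, `build_reverse_mapping`, and the `macro_names` tuple are hypothetical helpers, and the parsing assumes that a `^T` suffix marks a transpose and that a right-hand token such as `fused_ffn` names a macro rather than a variable.

```python
def parse_statement(statement, macro_names=("fused_ffn",)):
    """Split 'lhs -> rhs[,macro]' into (left_vars, right_vars).

    Hypothetical helper: the real AOA parser is not shown in this thread.
    """
    left, right = statement.strip().split("->")
    # Drop a trailing marker such as '^T' from left-hand names (assumption).
    left_vars = [v.strip().split("^")[0] for v in left.split(",")]
    # Right-hand tokens that name a macro are not variables (assumption).
    right_vars = [
        v.strip() for v in right.split(",") if v.strip() not in macro_names
    ]
    return left_vars, right_vars


def build_reverse_mapping(statements):
    """Map every right-hand variable to the left-hand variables it came from."""
    right_from_left = {}
    for statement in statements:
        left_vars, right_vars = parse_statement(statement)
        for right_var in right_vars:
            right_from_left.setdefault(right_var, []).extend(left_vars)
    return right_from_left


aoa_statements = [
    "layers.0.gate_up_fused_proj.weight^T -> temp_var \n",
    "temp_var -> new_name_layers.0.gate_up_fused_proj.weight,fused_ffn\n",
]

right_from_left = build_reverse_mapping(aoa_statements)
# One reverse step: which source key produced temp_var?
source_of_temp_var = right_from_left["temp_var"]
print(source_of_temp_var)
```

With the source key in hand, the macro can read that key's source information instead of failing on the unknown name `temp_var`.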
mapping_dict = self.left_var_to_right_var_mapping

while current_key in mapping_dict:
    if current_key in visited:

Has a cycle been hit here? Why is it fine to return directly?

The thinking is that a correct AOA configuration normally contains no cycles. The cases that reach this branch usually look like this:
aoa_statements = [
"src_key -> dist_key \n",
]
or
aoa_statements = [
"src_key^T -> A\n",
"A -> A \n",
]
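That is, when src_key and dist_key have the same name, or when a user-defined intermediate variable has the same name as dst (and likewise for src), a mapping of the form dst_key:dst_key (or src_key:src_key) appears. The key passed in is then already the dst_key or src_key itself, so the loop must be broken to avoid spinning forever, and the key that was passed in is returned.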
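
The termination behaviour discussed here can be sketched as below. The loop is reconstructed from the diff fragments quoted in this thread (`while current_key in mapping_dict`, the `visited` check, `mapped_vars = mapping_dict[current_key]`); the function signature, the choice to follow `mapped_vars[0]`, and the example dictionaries are assumptions for illustration only.

```python
def resolve_mapping_chain(key, mapping_dict):
    """Follow key -> mapped key until no further mapping exists.

    A key that is seen twice (for example a self-mapping 'A -> A')
    ends the walk, and that key is returned as the result.
    """
    current_key = key
    visited = set()
    while current_key in mapping_dict:
        if current_key in visited:
            # Already seen: stop instead of looping forever.
            return current_key
        visited.add(current_key)
        mapped_vars = mapping_dict[current_key]
        if not mapped_vars:
            break
        # Assumption: the first mapped variable is representative.
        current_key = mapped_vars[0]
    return current_key


# "src_key^T -> A" followed by "A -> A" yields a self-mapping on A.
self_mapping = {"src_key": ["A"], "A": ["A"]}
print(resolve_mapping_chain("src_key", self_mapping))

# An ordinary chain through an intermediate variable.
plain_chain = {"src_key": ["temp_var"], "temp_var": ["dst_key"]}
print(resolve_mapping_chain("src_key", plain_chain))
```

The `visited` set guards against any repeated key, so a genuine two-key cycle in a misconfigured file would also terminate rather than hang, although the returned key would then be of limited use.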
visited.add(current_key)

mapped_vars = mapping_dict[current_key]
if mapped_vars and len(mapped_vars) > 0:

For cases such as a -> b,c or a,b -> c, can the sharding information still be passed down?

Yes, because the value stored here is a list. For example:
a -> b,c
For this statement, left_var_to_right_var_mapping stores:
{
a:[b,c]
}
while right_var_from_left_var_mapping stores:
{
b:[a]
c:[a]
}
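If a is a src_key and b, c are intermediate variables, then the sharding information of the src_key that b or c came from can be reached through right_var_from_left_var_mapping. Likewise, if a is an intermediate variable and b, c are dst_keys, look up left_var_to_right_var_mapping directly to get the list [b, c] and take element [0]: since both are dst_keys produced by the same operation, their sharding information in dst should also be identical.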
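
A short sketch of how both dictionaries could be filled for a one-to-many statement, using the two attribute names from this reply as plain variables. `record_statement` is a hypothetical helper, and the many-to-one case (`a,b -> c`) is shown under the assumption that it follows the same rule; the reply itself only spells out `a -> b,c`.

```python
from collections import defaultdict


def record_statement(left_vars, right_vars, left_to_right, right_from_left):
    """Record one 'left_vars -> right_vars' statement in both directions."""
    for left_var in left_vars:
        left_to_right[left_var].extend(right_vars)
    for right_var in right_vars:
        right_from_left[right_var].extend(left_vars)


# One-to-many: a -> b,c
left_var_to_right_var_mapping = defaultdict(list)
right_var_from_left_var_mapping = defaultdict(list)
record_statement(
    ["a"],
    ["b", "c"],
    left_var_to_right_var_mapping,
    right_var_from_left_var_mapping,
)

# Forward lookup: b and c come from the same operation, so element [0]
# is taken as the representative destination.
representative_dst = left_var_to_right_var_mapping["a"][0]

# Many-to-one: a,b -> c (assumed to be stored by the same rule).
fan_in_left_to_right = defaultdict(list)
fan_in_right_from_left = defaultdict(list)
record_statement(["a", "b"], ["c"], fan_in_left_to_right, fan_in_right_from_left)

print(dict(left_var_to_right_var_mapping))
print(dict(right_var_from_left_var_mapping))
print(dict(fan_in_right_from_left))
```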
@macro(name='get_var_mapping_chain_macro', priority=3)
def get_var_mapping_chain_macro(tokens, expression, context):
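
All macros appear to be split into two groups by this macro. Must every macro that matches keys in src or dst be expanded before this macro runs? If so, would it be better to enforce that restriction in code? Otherwise it is easy to get wrong.

Yes. This is controlled through priority, so that this macro runs after all the expanding macros. The priority should be 4; it was changed here at some point and will be corrected in a follow-up.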
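
One way the ordering constraint raised in this exchange could be enforced in code is sketched below. Everything here is hypothetical: the `macro` decorator is a stand-in that only records a name and priority, `expand_macro_example` is an invented placeholder for an expanding macro, and the sketch assumes that macros run in ascending priority order, which is inferred from the reply (raising the priority to 4 makes the macro run later) rather than confirmed by the source.

```python
MACRO_REGISTRY = []


def macro(name, priority):
    """Stand-in decorator: register a macro with its priority."""
    def decorator(func):
        MACRO_REGISTRY.append((priority, name, func))
        return func
    return decorator


@macro(name="expand_macro_example", priority=1)
def expand_macro_example(tokens, expression, context):
    # Placeholder for a macro that expands keys in src or dst.
    return expression


@macro(name="get_var_mapping_chain_macro", priority=4)
def get_var_mapping_chain_macro(tokens, expression, context):
    # Placeholder body: the real macro resolves variable mapping chains.
    return expression


def ordered_macros():
    """Return registered macros in the assumed execution order."""
    return sorted(MACRO_REGISTRY, key=lambda item: item[0])


def check_chain_macro_runs_last(chain_name="get_var_mapping_chain_macro"):
    """Raise if any other macro would run at or after the chain macro."""
    priorities = {name: priority for priority, name, _ in MACRO_REGISTRY}
    chain_priority = priorities[chain_name]
    offenders = [
        name
        for name, priority in priorities.items()
        if name != chain_name and priority >= chain_priority
    ]
    if offenders:
        raise ValueError(
            f"{chain_name} must run after all expanding macros, "
            f"but these do not precede it: {offenders}"
        )


check_chain_macro_runs_last()
execution_order = [name for _, name, _ in ordered_macros()]
print(execution_order)
```

A check like this turns the convention into a failure at registration or start-up time, so a newly added expanding macro with too high a priority is caught early.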

/re-run all-failed

LGTM
f390065 to 96db255
96db255 to 234ac3e
LGTM
fe96e7a to 2293d95
* add_get_var_mapping_chain_macro
* add note
* fix the bug input_vars and resolve_mapping_chain
* fix the code style
* fit the dtype assert bug
* fix the bug
* fix the merge_sharded_state_dict bug
…#76252)
* 【FlexCheckpoint】fix_the_layer_id_macro (#75556)
* fix_the_layer_id_macro
* fix the ctest
* add expert_id_macro
* fix the assert bug
* fix the code style
* Pr support load hf checkpoint (#75928)
* support hf checkpoint fix support cast add id macro fix
* add test and fix some bug
* fix full param bug
* add full param cast test
---------
Co-authored-by: xingmingyyj <[email protected]>
* 【Flexcheckpoint】add_get_var_mapping_chain_macro (#76013)
* add_get_var_mapping_chain_macro
* add note
* fix the bug input_vars and resolve_mapping_chain
* fix the code style
* fit the dtype assert bug
* fix the bug
* fix the merge_sharded_state_dict bug
* fix aoa transpose corner case (#76234)
---------
Co-authored-by: Tianyu Zheng <[email protected]>
….2 (#76249)
* 【FlexCP】merge_sharded_state_dict support distribute merge (#75005)
* fix data is nullptr
* add dist merge
* change test
* change test
* 【FlexCP】add Skip param param for merge_shard_state_dict (#75061)
* fix data is nullptr
* add dist merge
* change test
* change test
* add skip optimizer param
* [Flex CP]Fix merge_sharded_state_dict with aoa and offload (#75062)
* fix merge_state_dict with aoa and offload
* add tests
* refine
* fix
* fix
* add log
* fix
* fix
* 【FlexCheckpoint】Upgrade some macros and optimize load_state_dict communication (#75282)
* upgrad macros and load_state_dict comm task fix fix support 0-d tensor fix balance save and fix
* fix test
* Add the test about the sharded_state_dict of optimizer (#75067)
* fix the share_weight_bug
* add note
* add the unit test
* set the timeout
* add more test
* Trigger CI rebuild
* fix the CmakeLists
* handle_missing_edge_cases_in_fc (#75413)
* up_grade fc (#75613) fix and add test fix fix fix fix cmakelists add notion
* 【FlexCheckpoint】fix_the_layer_id_macro (#75556)
* fix_the_layer_id_macro
* fix the ctest
* add expert_id_macro
* fix the assert bug
* fix the code style
* Pr support load hf checkpoint (#75928)
* support hf checkpoint fix support cast add id macro fix
* add test and fix some bug
* fix full param bug
* add full param cast test
---------
Co-authored-by: xingmingyyj <[email protected]>
* 【Flexcheckpoint】add_get_var_mapping_chain_macro (#76013)
* add_get_var_mapping_chain_macro
* add note
* fix the bug input_vars and resolve_mapping_chain
* fix the code style
* fit the dtype assert bug
* fix the bug
* fix the merge_sharded_state_dict bug
* fix aoa transpose corner case (#76234)
---------
Co-authored-by: xiaoguoguo626807 <[email protected]>
Co-authored-by: Chen Zhiyang <[email protected]>
Co-authored-by: Tianyu Zheng <[email protected]>
PR Category
User Experience
PR Types
New features
Description