[SPMD] Preserve parameter sharding with output data sharding #4721
Conversation
Yeah, we need at least 2 devices to create HLO sharding. Added the safeguard.
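A minimal sketch of the safeguard idea described in the comment above, in Python for illustration only: skip constructing a sharding annotation when fewer than 2 devices are available, since a tiled HLO sharding is only meaningful across multiple devices. The function and spec format here are hypothetical, not pytorch/xla APIs.

```python
from typing import Optional, Tuple


def maybe_create_sharding_spec(num_devices: int,
                               mesh_shape: Tuple[int, ...]) -> Optional[dict]:
    """Return a hypothetical tiled-sharding spec, or None when sharding
    does not apply (fewer than 2 devices => replicated execution)."""
    if num_devices < 2:
        # Safeguard: with a single device there is nothing to tile over,
        # so skip the sharding annotation entirely.
        return None
    return {"type": "tiled", "mesh_shape": mesh_shape,
            "num_devices": num_devices}
```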
JackCaoG
left a comment
Thanks!
jonb377
left a comment
LGTM, thanks!
```cpp
void Assign(const Data& data) override {
  XLA_ERROR() << __FUNCTION__ << " not supported.";
}
```
Nice! We can retry the simple MpDeviceLoader hack for SPMD once this lands; this was the blocker.
[SPMD] Persist tensor sharding with XLA sharding propagation
This addresses the same problem as #4696 with an alternative solution: we shard the replicated output while handling the computation results. This avoids a post-traversal pass to replace the original data node with a sharded one, and is thus more efficient. Key changes include:
- Introduce `ShardingUtil::OutputHandler`, with the `XLAShardingTest.OutputHandler` unit test; `test_optimizer_step_with_sharding` already checks the validity of the change with a simple e2e example.
- Add `std::optional<xla::Shape>` to `ShardingSpec` and `std::optional<xla::OpSharding>` to `PjRtShardedData`.
- Add a `std::vector<XLATensor::ShardingSpecPtr>` param to `XLAGraphExecutor::ScheduleSyncTensorsGraph`, since the async function now calls `ShardingUtil::OutputHandler`; call `XLAGraphExecutor::CollectShardingSpecs` before calling `ScheduleSyncTensorsGraph`.
- Add `WrapDataShards` and `GetDataSharding` APIs in `ComputationClient`.
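To make the output-handler idea concrete, here is an illustrative Python sketch of sharding replicated outputs while the computation results are handled, rather than rewriting data nodes in a post-traversal pass. Every name here (`ShardedData`, `shard`, `output_handler`) is hypothetical and only mimics the shape of the change; it is not the pytorch/xla implementation, and the even row-wise split stands in for a real `OpSharding`-driven partition.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ShardedData:
    shards: List[list]        # one shard per device
    sharding: Optional[dict]  # annotation; None means replicated


def shard(values: list, num_devices: int) -> List[list]:
    # Even split for illustration; a real implementation follows the
    # tile assignment in the output's OpSharding annotation.
    base, extra = divmod(len(values), num_devices)
    out, i = [], 0
    for d in range(num_devices):
        n = base + (1 if d < extra else 0)
        out.append(values[i:i + n])
        i += n
    return out


def output_handler(outputs: List[list],
                   sharding_specs: List[Optional[dict]],
                   num_devices: int) -> List[ShardedData]:
    """Wrap each computation output as sharded data when a spec is present,
    keeping outputs without a spec replicated on every device."""
    results = []
    for value, spec in zip(outputs, sharding_specs):
        if spec is None:
            results.append(ShardedData([list(value)] * num_devices, None))
        else:
            results.append(ShardedData(shard(list(value), num_devices), spec))
    return results
```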