
FSDP models are not being correctly converted to another data-type after initialization. #137522

@ysiraichi


🐛 Describe the bug

The test test_to_float64_after_init starts failing after a dtype-mismatch check is introduced for lerp.

def test_to_float64_after_init(self):
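The test body is elided above. As a rough illustration only (not the actual test source), a minimal sketch of what such a test exercises, assuming the FSDP2 fully_shard API under torch.distributed._composable.fsdp, a hypothetical two-layer model, and an already-initialized distributed test case, could look like:

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

def test_to_float64_after_init(self):
    # Shard first, convert the dtype afterwards -- the point of the test.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    fully_shard(model)
    model.to(torch.float64)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
    # Run forward/backward/step. Adam's exp_avg update uses lerp_, which is
    # one place the dtype check discussed below can surface (illustrative;
    # the real failure path may differ).
    loss = model(torch.rand(2, 8, dtype=torch.float64)).sum()
    loss.backward()
    optim.step()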

Context: lerp's eager implementation actually errors on mismatched data-types for its inputs, whereas its decomposition implements type promotion instead. After fixing that discrepancy (#136909), the above-mentioned test starts failing.

TORCH_META_FUNC(lerp_Tensor)(
    const Tensor& self, const Tensor& end, const Tensor& weight) {
  TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(),
              " for `end` but got dtype ", end.dtype());
  TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(),
              " for `weight` but got dtype ", weight.dtype());
  // ...
}

Upon closer inspection, I observed that even though the DTensor had the same dtype for both start and end, by the time execution reached the decomposition (without #136909) the dtypes were different. Moving the dtype conversion in the test, model.to(dtype), to before the fully_shard call on the model fixes the issue (though that defeats the purpose of the test).
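For reference, the workaround described above amounts to swapping the order of the two calls (same hypothetical names as in the sketch earlier):

# Converting before sharding avoids the mismatch, but no longer tests
# post-init dtype conversion:
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
model.to(torch.float64)   # convert first...
fully_shard(model)        # ...then shard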

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @albanD

Versions

PyTorch version: 2.6.0a0+git85eed60
Is debug build: True
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
