Enable AMP for MPS devices #88415

@justusschock

Description

🚀 The feature, motivation and pitch

#78168 states that fp16 support should generally be possible on MPS devices, but autocast currently only works with the cpu and cuda device types. When enabling it manually on mps, it does not perform any additional conversions:

>>> import torch
>>> # LoggingTensorMode / capture_logs are internal test helpers that trace dispatched ops
>>> from torch.testing._internal.logging_tensor import LoggingTensorMode, capture_logs
>>> torch.set_autocast_enabled(True)
>>> with capture_logs(is_mode=True) as logs, LoggingTensorMode():
...     a = torch.rand(10, 10, dtype=torch.float, device='mps')
...     b = torch.rand(10, 10, dtype=torch.float, device='mps')
...     c = torch.addmm(a, a, b)
>>> for l in logs:
...     print(l)

prints the following:

$0 = torch._ops.aten.rand.default([10, 10], dtype=torch.float32, device=device(type='mps'), pin_memory=False)
$1 = torch._ops.aten.rand.default([10, 10], dtype=torch.float32, device=device(type='mps'), pin_memory=False)
$2 = torch._ops.aten.addmm.default($0, $0, $1)
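
Note that no cast ops appear and addmm runs in float32. For contrast, on a CUDA device the autocast context does rewrite addmm to run in float16 (a minimal illustration, assuming a CUDA-capable machine is available):

>>> a = torch.rand(10, 10, dtype=torch.float, device='cuda')
>>> b = torch.rand(10, 10, dtype=torch.float, device='cuda')
>>> with torch.autocast(device_type='cuda'):
...     c = torch.addmm(a, a, b)
>>> c.dtype
torch.float16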

Given the fp16 support, it would be nice to have autocast, and AMP in general, working on MPS devices as well.
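
The user-facing API could presumably mirror the existing cpu/cuda entry points. The snippet below is a hypothetical sketch of what this feature would enable, not current behavior (device_type='mps' is the assumed, not-yet-supported argument):

>>> with torch.autocast(device_type='mps', dtype=torch.float16):
...     a = torch.rand(10, 10, dtype=torch.float, device='mps')
...     b = torch.rand(10, 10, dtype=torch.float, device='mps')
...     c = torch.addmm(a, a, b)
>>> c.dtype  # expected once MPS autocast exists
torch.float16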

Alternatives

No response

Additional context

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev as per our discussion on Slack.

Metadata

Assignees

No one assigned

    Labels

    feature (a request for a proper, new feature)
    module: amp (automated mixed precision)
    autocast
    module: mps (related to Apple Metal Performance Shaders framework)
    triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
