
padding + cast cannot be compiled for all dtypes on CPU #122606

@malfet

Description

🐛 Describe the bug

The following example:

import torch

@torch.compile
def cast_and_pad(x):
    return torch.nn.functional.pad(x.to(torch.float32), (0, 3, 0, 0))

x = torch.ones(1, 1, 13, dtype=torch.int64)
print(cast_and_pad(x))

will fail (see Colab) with:

/tmp/torchinductor_root/qf/cqffzgc7mvvjhlx2uqho42flqfmxpnu4g7tu2mltyq57j7thf4jq.cpp: In lambda function:
/tmp/torchinductor_root/qf/cqffzgc7mvvjhlx2uqho42flqfmxpnu4g7tu2mltyq57j7thf4jq.cpp:16:40: error: no matching function for call to ‘masked_load(const long int*, at::vec::CPU_CAPABILITY::Vectorized<float>)’
   16 |                 auto tmp6 = masked_load(in_ptr0 + static_cast<long>(x0), to_float_mask(tmp4));

This fails because masked_load is only implemented for floating-point types, so the generated kernel cannot load from the int64 input pointer.
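For context, a masked load reads a vector of elements but substitutes a fill value in lanes where the mask is false. A scalar Python sketch of the intended semantics (the name, signature, and fill behavior here are illustrative only, not the actual at::vec::Vectorized API):

```python
# Hypothetical scalar model of a masked load: take src[i] where the
# mask lane is true, otherwise use the fill value. The compiled CPU
# kernel needs an overload of this operation for int64 inputs, which
# was the missing piece triggering the compile error above.
def masked_load(src, mask, fill=0):
    return [s if m else fill for s, m in zip(src, mask)]

print(masked_load([1, 2, 3, 4], [True, True, False, True]))
```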

Fixed by af9e30f

Discovered while debugging test_comprehensive_fft_ihfft2_cpu_int64 failures on ARM and Intel CPUs without AVX512 support.

Versions

CI

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang

Metadata


Labels

- oncall: cpu inductor (CPU Inductor issues for Intel team to triage)
- oncall: pt2
- triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
