Skip to content

Conversation

@mlaves
Copy link

@mlaves mlaves commented Jul 29, 2025

Metal implementation of grid_sampler_3d, including backward pass.

Runtime benchmarks vs. CPU (M2 Max)

Batch size:  1, CPU time: 7.164317 ms, MPS time: 0.817350 ms, speedup: 8.77x
Batch size:  2, CPU time: 14.793888 ms, MPS time: 1.111496 ms, speedup: 13.31x
Batch size:  4, CPU time: 30.779025 ms, MPS time: 1.712583 ms, speedup: 17.97x
Batch size:  8, CPU time: 68.820854 ms, MPS time: 3.391506 ms, speedup: 20.29x
Batch size: 16, CPU time: 142.618127 ms, MPS time: 5.831585 ms, speedup: 24.46x

Output accuracy test vs. CPU

Click to expand
Grid Sampler 3D Output Accuracy Test (MPS vs CPU)
input_shape           grid_shape         interp   padding   align_corners   max_diff      mean_diff     sum_diff
------------------------------------------------------------------------------------------------------------------------
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        0         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        0         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        1         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        1         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        2         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        2         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        0         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        0         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        1         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        1         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        2         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        2         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        0         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        0         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        1         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        1         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        2         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        2         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        0         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        0         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        1         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        1         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        2         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        2         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      0        0         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      0        0         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      0        1         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      0        1         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      0        2         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      0        2         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      1        0         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      1        0         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      1        1         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      1        1         1   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      1        2         0   0.00000000   0.00000000   0.00000000
(1, 3, 128, 128, 128) (1, 8, 8, 8, 3)      1        2         1   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   0        0         0   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   0        0         1   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   0        1         0   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   0        1         1   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   0        2         0   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   0        2         1   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   1        0         0   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   1        0         1   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   1        1         0   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   1        1         1   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   1        2         0   0.00000000   0.00000000   0.00000000
(2, 1, 32, 32, 32)    (2, 16, 16, 16, 3)   1        2         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        0         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        0         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        1         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        1         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        2         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        2         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        0         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        0         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        1         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        1         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        2         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        2         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        0         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        0         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        1         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        1         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        2         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        2         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        0         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        0         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        1         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        1         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        2         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        2         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      0        0         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      0        0         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      0        1         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      0        1         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      0        2         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      0        2         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      1        0         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      1        0         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      1        1         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      1        1         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      1        2         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 1, 1)       (1, 8, 8, 8, 3)      1        2         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   0        0         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   0        0         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   0        1         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   0        1         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   0        2         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   0        2         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   1        0         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   1        0         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   1        1         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   1        1         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   1        2         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 64, 64, 64, 3)   1        2         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      0        0         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      0        0         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      0        1         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      0        1         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      0        2         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      0        2         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      1        0         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      1        0         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      1        1         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      1        1         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      1        2         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 1, 1, 1, 3)      1        2         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      0        0         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      0        0         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      0        1         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      0        1         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      0        2         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      0        2         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      1        0         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      1        0         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      1        1         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      1        1         1   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      1        2         0   0.00000000   0.00000000   0.00000000
(1, 4, 16, 16, 16)    (1, 8, 8, 8, 3)      1        2         1   0.00000000   0.00000000   0.00000000
------------------------------------------------------------------------------------------------------------------------
Summary:
Total tests: 120
Failed tests: 0
Max difference overall: 0.00000000
Mean difference overall: 0.00000000
Sum difference overall: 0.00000000

Edge Cases Test
input_shape           grid_shape         interp   padding   align_corners   max_diff      mean_diff     sum_diff
------------------------------------------------------------------------------------------------------------------------
(1, 1, 2, 2, 2)       (1, 1, 1, 1, 3)      0        0         0   0.00000000   0.00000000   0.00000000
(1, 1, 2, 2, 2)       (1, 1, 1, 1, 3)      0        0         1   0.00000000   0.00000000   0.00000000
(1, 1, 2, 2, 2)       (1, 1, 1, 1, 3)      0        1         0   0.00000000   0.00000000   0.00000000
(1, 1, 2, 2, 2)       (1, 1, 1, 1, 3)      0        1         1   0.00000000   0.00000000   0.00000000
(1, 1, 2, 2, 2)       (1, 1, 1, 1, 3)      1        0         0   0.00000000   0.00000000   0.00000000
(1, 1, 2, 2, 2)       (1, 1, 1, 1, 3)      1        0         1   0.00000000   0.00000000   0.00000000
(1, 1, 2, 2, 2)       (1, 1, 1, 1, 3)      1        1         0   0.00000000   0.00000000   0.00000000
(1, 1, 2, 2, 2)       (1, 1, 1, 1, 3)      1        1         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 10, 10)     (1, 1, 5, 5, 3)      0        0         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 10, 10)     (1, 1, 5, 5, 3)      0        0         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 10, 10)     (1, 1, 5, 5, 3)      0        1         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 10, 10)     (1, 1, 5, 5, 3)      0        1         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 10, 10)     (1, 1, 5, 5, 3)      1        0         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 10, 10)     (1, 1, 5, 5, 3)      1        0         1   0.00000000   0.00000000   0.00000000
(1, 1, 1, 10, 10)     (1, 1, 5, 5, 3)      1        1         0   0.00000000   0.00000000   0.00000000
(1, 1, 1, 10, 10)     (1, 1, 5, 5, 3)      1        1         1   0.00000000   0.00000000   0.00000000
(1, 16, 8, 8, 8)      (1, 4, 4, 4, 3)      0        0         0   0.00000000   0.00000000   0.00000000
(1, 16, 8, 8, 8)      (1, 4, 4, 4, 3)      0        0         1   0.00000000   0.00000000   0.00000000
(1, 16, 8, 8, 8)      (1, 4, 4, 4, 3)      0        1         0   0.00000000   0.00000000   0.00000000
(1, 16, 8, 8, 8)      (1, 4, 4, 4, 3)      0        1         1   0.00000000   0.00000000   0.00000000
(1, 16, 8, 8, 8)      (1, 4, 4, 4, 3)      1        0         0   0.00000000   0.00000000   0.00000000
(1, 16, 8, 8, 8)      (1, 4, 4, 4, 3)      1        0         1   0.00000000   0.00000000   0.00000000
(1, 16, 8, 8, 8)      (1, 4, 4, 4, 3)      1        1         0   0.00000000   0.00000000   0.00000000
(1, 16, 8, 8, 8)      (1, 4, 4, 4, 3)      1        1         1   0.00000000   0.00000000   0.00000000

All tests passed!

Gradient accuracy test vs. CPU

Click to expand
grad wrt. input
input_shape           grid_shape         interp   padding   align_corners   diff
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        0         0    6.324674117763607e-10
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        0         1    7.693850001544433e-10
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        1         0    6.303384481043395e-10
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        1         1    7.976201921167103e-10
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        0         0                      0.0
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        0         1                      0.0
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        1         0                      0.0
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        1         1                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        0         0                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        0         1                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        1         0                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        1         1                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        0         0                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        0         1                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        1         0                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        1         1                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        0         0   1.1622905731201172e-06
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        0         1                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        1         0                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        1         1                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        0         0                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        0         1                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        1         0                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        1         1                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        0         0                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        0         1                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        1         0                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        1         1                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        0         0                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        0         1                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        1         0                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        1         1                      0.0
max diff for input gradients:  1.1622905731201172e-06
grad wrt. grid
input_shape           grid_shape         interp   padding   align_corners   diff
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        0         0                      0.0
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        0         1                      0.0
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        1         0                      0.0
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   0        1         1                      0.0
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        0         0                      0.0
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        0         1                      0.0
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        1         0                      0.0
(1, 3, 128, 128, 128) (1, 64, 64, 64, 3)   1        1         1                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        0         0                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        0         1                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        1         0                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      0        1         1                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        0         0                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        0         1                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        1         0                      0.0
(1, 3, 128, 128, 128) (1, 1, 1, 1, 3)      1        1         1                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        0         0                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        0         1                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        1         0                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   0        1         1                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        0         0                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        0         1                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        1         0                      0.0
(1, 1, 1, 1, 1)       (1, 64, 64, 64, 3)   1        1         1                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        0         0                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        0         1                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        1         0                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      0        1         1                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        0         0                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        0         1                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        1         0                      0.0
(1, 1, 1, 1, 1)       (1, 1, 1, 1, 3)      1        1         1                      0.0
max diff for grid gradients:  0.0

@mlaves mlaves requested review from kulinseth and malfet as code owners July 29, 2025 21:24
@pytorch-bot
Copy link

pytorch-bot bot commented Jul 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159421

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Unrelated Failures

As of commit 36d9498 with merge base c665594 (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: mps Release notes category label Jul 29, 2025
@github-actions
Copy link
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@mlaves mlaves requested a review from Skylion007 August 3, 2025 20:58
@mlaves
Copy link
Author

mlaves commented Aug 22, 2025

Any updates on this? @malfet

@mlaves
Copy link
Author

mlaves commented Aug 27, 2025

I see that there was another PR #160541 that was opened after this one, but already merged. It looks like that the merged PR does not implement the backward pass or nearest interpolation. Therefore, my PR might still be worth merging. @malfet @kurtamohler

@kurtamohler
Copy link
Collaborator

kurtamohler commented Aug 27, 2025

Hey @mlaves, I'm sorry I had not noticed that you opened this PR. I have started to work on the backward part as well, but I could stop if you'd like to update your PR to only add the backward function.

You'll need to enable tests for it by removing this line:

"grid_sampler_3d": None,

You can run the tests on your machine with python test/test_mps.py -k grid_sampler_3d

@malfet , does that sound good to you?

@malfet
Copy link
Contributor

malfet commented Aug 27, 2025

@mlaves sorry I've missed your change earlier (I usually go over all the changes that claim to Fix some of the MPS issues or has ciflow/mps in the title.

Please rebase your change and I'll have a look at it soon

@mlaves
Copy link
Author

mlaves commented Aug 28, 2025

@kurtamohler No worries! I also missed your PR and saw it just now. It's rather remarkable that this has been missing for years, only to be implemented by two individuals within a couple of days lol.

@malfet I will rebase my changes on current main as soon as possible.

@github-actions
Copy link
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Oct 27, 2025
@github-actions github-actions bot closed this Nov 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants