
Commit 43a0918

manuelcandales authored and pytorchmergebot committed
[MPS] Add benchmark for scan with indices (#156860)
Baseline performance on M4 Max 64GB (macOS 15.5):

```
[----------------------------------  ----------------------------------]
                                              |    eager  |   compile
1 threads: --------------------------------------------------------------
      cummin-dim0-32x32 (torch.float16)       |    102.5  |     115.0
      cummin-dim0-128x128 (torch.float16)     |    133.6  |     147.8
      cummin-dim0-512x512 (torch.float16)     |    233.1  |     243.1
      cummin-dim0-1024x1024 (torch.float16)   |    364.2  |     385.2
      cummin-dim1-32x32 (torch.float16)       |     94.4  |     109.8
      cummin-dim1-128x128 (torch.float16)     |    109.9  |     122.5
      cummin-dim1-512x512 (torch.float16)     |    227.0  |     233.8
      cummin-dim1-1024x1024 (torch.float16)   |    985.1  |    1010.5
      cummin-1d-100 (torch.float16)           |    100.7  |     114.3
      cummin-1d-10000 (torch.float16)         |    805.0  |     879.1
      cummin-1d-1000000 (torch.float16)       |  70545.6  |   71310.3
      cummin-dim0-32x32 (torch.float32)       |    102.7  |     115.5
      cummin-dim0-128x128 (torch.float32)     |    137.2  |     143.8
      cummin-dim0-512x512 (torch.float32)     |    209.7  |     222.0
      cummin-dim0-1024x1024 (torch.float32)   |    340.1  |     389.9
      cummin-dim1-32x32 (torch.float32)       |     99.2  |     107.8
      cummin-dim1-128x128 (torch.float32)     |    111.9  |     119.3
      cummin-dim1-512x512 (torch.float32)     |    250.7  |     255.1
      cummin-dim1-1024x1024 (torch.float32)   |    987.9  |    1013.2
      cummin-1d-100 (torch.float32)           |    100.6  |     114.6
      cummin-1d-10000 (torch.float32)         |    794.7  |     862.2
      cummin-1d-1000000 (torch.float32)       |  71995.3  |   71963.5
      cummin-dim0-32x32 (torch.bfloat16)      |    105.9  |     113.9
      cummin-dim0-128x128 (torch.bfloat16)    |    135.7  |     147.9
      cummin-dim0-512x512 (torch.bfloat16)    |    231.9  |     240.7
      cummin-dim0-1024x1024 (torch.bfloat16)  |    327.7  |     366.9
      cummin-dim1-32x32 (torch.bfloat16)      |     91.3  |     103.3
      cummin-dim1-128x128 (torch.bfloat16)    |    108.5  |     117.4
      cummin-dim1-512x512 (torch.bfloat16)    |    222.0  |     233.6
      cummin-dim1-1024x1024 (torch.bfloat16)  |    936.9  |     982.5
      cummin-1d-100 (torch.bfloat16)          |    106.6  |     112.4
      cummin-1d-10000 (torch.bfloat16)        |    795.8  |     819.6
      cummin-1d-1000000 (torch.bfloat16)      |  68667.4  |   68557.9

Times are in microseconds (us).
```

Pull Request resolved: #156860
Approved by: https://github.com/malfet
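For context (not part of the commit message): unlike torch.cumsum, torch.cummin returns a (values, indices) named tuple, which is why the benchmark grows a with_indices path that compares both fields between the eager and compiled runs. A minimal sketch with illustrative inputs; the use of torch.compile here is only to show the comparison pattern:

```python
import torch

# torch.cummin returns a named tuple: the running minimum values and the
# index at which each running minimum occurs along the scanned dimension.
x = torch.tensor([3.0, 1.0, 4.0, 0.5, 5.0])
values, indices = torch.cummin(x, dim=0)
print(values)   # tensor([3.0000, 1.0000, 1.0000, 0.5000, 0.5000])
print(indices)  # tensor([0, 1, 1, 3, 3])

# A plain scan such as torch.cumsum returns a single tensor; with a tuple
# result, eager and compiled outputs must be compared field by field.
f = lambda t: torch.cummin(t, dim=0)
f_c = torch.compile(f)
rc_e, rc_c = f(x), f_c(x)
assert torch.allclose(rc_e[0], rc_c[0])  # values
assert torch.equal(rc_e[1], rc_c[1])     # indices
```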
1 parent 9fe2d15 commit 43a0918

File tree

1 file changed: +30 -19 lines changed

test/bench_mps_ops.py

Lines changed: 30 additions & 19 deletions

```diff
@@ -71,6 +71,15 @@ def bench_binary(
     return rc
 
 
+def check_eager_vs_compile(rc_c, rc_e, func, dtype):
+    if not torch.allclose(rc_c, rc_e):
+        mdiff = (rc_c - rc_e).abs().max()
+        warnings.warn(
+            f"Eager and compile reduction do not match for {func.__name__} and {dtype} max_diff={mdiff}",
+            stacklevel=2,
+        )
+
+
 def bench_reduction(
     reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
 ) -> list[Measurement]:
```

```diff
@@ -87,19 +96,17 @@ def f(t):
         x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
         rc_c, rc_e = f(x), f_c(x)
         rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
-        if not torch.allclose(rc_c, rc_e):
-            mdiff = (rc_c - rc_e).abs().max()
-            warnings.warn(
-                f"Eager and compile reduction do not match for {reduction_func.__name__} and {dtype} max_diff={mdiff}",
-                stacklevel=2,
-            )
+        check_eager_vs_compile(rc_c, rc_e, reduction_func, dtype)
         rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
         rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
     return rc
 
 
 def bench_scan(
-    scan_func, device: str = "mps", dtype: torch.dtype = torch.float32
+    scan_func,
+    device: str = "mps",
+    dtype: torch.dtype = torch.float32,
+    with_indices: bool = False,
 ) -> list[Measurement]:
     rc = []
 
```

```diff
@@ -116,12 +123,11 @@ def f(t):
         f_c.__name__ = f.__name__
         x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
         rc_c, rc_e = f(x), f_c(x)
-        if not torch.allclose(rc_c, rc_e):
-            mdiff = (rc_c - rc_e).abs().max()
-            warnings.warn(
-                f"Eager and compile scan do not match for {scan_func.__name__} dim={dim} and {dtype} max_diff={mdiff}",
-                stacklevel=2,
-            )
+        if with_indices:
+            check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype)
+            check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype)
+        else:
+            check_eager_vs_compile(rc_c, rc_e, scan_func, dtype)
         rc.append(bench_unary_op(f, x, "eager"))
         rc.append(bench_unary_op(f_c, x, "compile"))
 
```

```diff
@@ -136,12 +142,11 @@ def f_1d(t):
         f_1d_c.__name__ = f_1d.__name__
         x = torch.testing.make_tensor(size, device=device, dtype=dtype)
         rc_c, rc_e = f_1d(x), f_1d_c(x)
-        if not torch.allclose(rc_c, rc_e):
-            mdiff = (rc_c - rc_e).abs().max()
-            warnings.warn(
-                f"Eager and compile 1D scan do not match for {scan_func.__name__} and {dtype} max_diff={mdiff}",
-                stacklevel=2,
-            )
+        if with_indices:
+            check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype)
+            check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype)
+        else:
+            check_eager_vs_compile(rc_c, rc_e, scan_func, dtype)
         rc.append(bench_unary_op(f_1d, x, "eager"))
         rc.append(bench_unary_op(f_1d_c, x, "compile"))
 
```

```diff
@@ -171,6 +176,12 @@ def main() -> None:
         rc.extend(bench_scan(torch.cumsum, dtype=dtype))
     Compare(rc).print()
 
+    # Profile scan with indices ops (cummin)
+    rc = []
+    for dtype in dtypes:
+        rc.extend(bench_scan(torch.cummin, dtype=dtype, with_indices=True))
+    Compare(rc).print()
+
     # Profile binary ops
     rc = []
     ops = [torch.fmax, torch.add]
```
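
As a usage note (not part of the commit): the baseline table in the commit message is presumably produced by running the script's main(), which now also profiles torch.cummin. A single row can be reproduced approximately with torch.utils.benchmark; the sketch below assumes an MPS-capable machine and mirrors, rather than reuses, the script's bench_unary_op helper:

```python
import torch
from torch.utils.benchmark import Compare, Timer

def cummin_dim0(t):
    # Scan along dim 0, matching the "cummin-dim0-*" rows in the table above.
    return torch.cummin(t, dim=0)

cummin_dim0_compiled = torch.compile(cummin_dim0)

x = torch.testing.make_tensor(1024, 1024, device="mps", dtype=torch.float32)

results = []
for description, fn in [("eager", cummin_dim0), ("compile", cummin_dim0_compiled)]:
    results.append(
        Timer(
            stmt="fn(x)",
            globals={"fn": fn, "x": x},
            label="cummin-dim0-1024x1024 (torch.float32)",
            description=description,
        ).blocked_autorange()
    )

# Prints an eager-vs-compile comparison similar to the table in the commit message.
Compare(results).print()
```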
