-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[MPS] Implement linear1d as shader #148154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
And get rid of MPS call, as for some reason implementation via MPSGraph
API call is 100x+ times slower that Metal shader, at least according to
the following benchmark
```python
import torch
import time
import subprocess
def benchmark(device, dtype):
# Create example inputs
x = torch.testing.make_tensor(3, 5, 65536, device=device, dtype=dtype)
sf = .5
# Check output
y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="linear")
outputs_match = torch.allclose(y.cpu(), z)
if not outputs_match:
atol = (y.cpu() - z).abs().max()
rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max()
print(f"atol={atol} rtol={rtol}")
# Measure time manually
start_time = time.time() * 1000
for _ in range(1000):
y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
torch.mps.synchronize
end_time = time.time() * 1000
manual_delta = (end_time - start_time)
average_time = f"{manual_delta:6.1f}"
return "True " if outputs_match else "False", average_time
outputs_match_list = []
average_time_list = []
for device in ["mps", "cpu"]:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
outputs_match, average_time = benchmark(device, dtype)
outputs_match_list.append(str(outputs_match))
average_time_list.append(average_time)
brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
print(f"\nBenchmarking Results (collected on {brand_string}):")
print("-"*40)
print("Device : MPS | CPU")
print("Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 ")
print(f"Outputs Match : ", " | ".join(outputs_match_list))
print(f"Average Time (us) :", " |".join(average_time_list))
```
Benchmark results after the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match : True | True | True | True | True | True
Average Time (us) : 2.5 | 2.1 | 2.2 | 161.4 | 115.0 | 161.1
```
And before the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match : True | True | True | True | True | True
Average Time (us) : 354.0 | 336.0 | 332.4 | 145.5 | 114.7 | 148.3
```
Fixes #144245
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148154
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 31 PendingAs of commit 5bc969d with merge base 926b7b5 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
And get rid of MPS call, as for some reason implementation via MPSGraph
API call is 100x+ times slower that Metal shader, at least according to
the following benchmark
```python
import torch
import time
import subprocess
def benchmark(device, dtype):
# Create example inputs
x = torch.testing.make_tensor(3, 5, 65536, device=device, dtype=dtype)
sf = .5
# Check output
y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="linear")
outputs_match = torch.allclose(y.cpu(), z)
if not outputs_match:
atol = (y.cpu() - z).abs().max()
rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max()
print(f"atol={atol} rtol={rtol}")
# Measure time manually
start_time = time.time() * 1000
for _ in range(1000):
y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
torch.mps.synchronize
end_time = time.time() * 1000
manual_delta = (end_time - start_time)
average_time = f"{manual_delta:6.1f}"
return "True " if outputs_match else "False", average_time
outputs_match_list = []
average_time_list = []
for device in ["mps", "cpu"]:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
outputs_match, average_time = benchmark(device, dtype)
outputs_match_list.append(str(outputs_match))
average_time_list.append(average_time)
brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
print(f"\nBenchmarking Results (collected on {brand_string}):")
print("-"*40)
print("Device : MPS | CPU")
print("Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 ")
print(f"Outputs Match : ", " | ".join(outputs_match_list))
print(f"Average Time (us) :", " |".join(average_time_list))
```
Benchmark results after the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match : True | True | True | True | True | True
Average Time (us) : 2.5 | 2.1 | 2.2 | 161.4 | 115.0 | 161.1
```
And before the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match : True | True | True | True | True | True
Average Time (us) : 354.0 | 336.0 | 332.4 | 145.5 | 114.7 | 148.3
```
Fixes #144245
ghstack-source-id: 08a057f
Pull Request resolved: #148154
dcci
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice, thanks.
And get rid of MPS call, as for some reason implementation via MPSGraph
API call is 100x+ times slower that Metal shader, at least according to
the following benchmark
```python
import torch
import time
import subprocess
def benchmark(device, dtype):
# Create example inputs
x = torch.testing.make_tensor(3, 5, 65536, device=device, dtype=dtype)
sf = .5
# Check output
y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="linear")
outputs_match = torch.allclose(y.cpu(), z)
if not outputs_match:
atol = (y.cpu() - z).abs().max()
rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max()
print(f"atol={atol} rtol={rtol}")
# Measure time manually
start_time = time.time() * 1000
for _ in range(1000):
y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
torch.mps.synchronize
end_time = time.time() * 1000
manual_delta = (end_time - start_time)
average_time = f"{manual_delta:6.1f}"
return "True " if outputs_match else "False", average_time
outputs_match_list = []
average_time_list = []
for device in ["mps", "cpu"]:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
outputs_match, average_time = benchmark(device, dtype)
outputs_match_list.append(str(outputs_match))
average_time_list.append(average_time)
brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
print(f"\nBenchmarking Results (collected on {brand_string}):")
print("-"*40)
print("Device : MPS | CPU")
print("Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 ")
print(f"Outputs Match : ", " | ".join(outputs_match_list))
print(f"Average Time (us) :", " |".join(average_time_list))
```
Benchmark results after the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match : True | True | True | True | True | True
Average Time (us) : 2.5 | 2.1 | 2.2 | 161.4 | 115.0 | 161.1
```
And before the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match : True | True | True | True | True | True
Average Time (us) : 354.0 | 336.0 | 332.4 | 145.5 | 114.7 | 148.3
```
Fixes #144245
[ghstack-poisoned]
And get rid of MPS call, as for some reason implementation via MPSGraph
API call is 100x+ times slower that Metal shader, at least according to
the following benchmark
```python
import torch
import time
import subprocess
def benchmark(device, dtype):
# Create example inputs
x = torch.testing.make_tensor(3, 5, 65536, device=device, dtype=dtype)
sf = .5
# Check output
y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="linear")
outputs_match = torch.allclose(y.cpu(), z)
if not outputs_match:
atol = (y.cpu() - z).abs().max()
rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max()
print(f"atol={atol} rtol={rtol}")
# Measure time manually
start_time = time.time() * 1000
for _ in range(1000):
y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
torch.mps.synchronize
end_time = time.time() * 1000
manual_delta = (end_time - start_time)
average_time = f"{manual_delta:6.1f}"
return "True " if outputs_match else "False", average_time
outputs_match_list = []
average_time_list = []
for device in ["mps", "cpu"]:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
outputs_match, average_time = benchmark(device, dtype)
outputs_match_list.append(str(outputs_match))
average_time_list.append(average_time)
brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
print(f"\nBenchmarking Results (collected on {brand_string}):")
print("-"*40)
print("Device : MPS | CPU")
print("Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 ")
print(f"Outputs Match : ", " | ".join(outputs_match_list))
print(f"Average Time (us) :", " |".join(average_time_list))
```
Benchmark results after the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match : True | True | True | True | True | True
Average Time (us) : 2.5 | 2.1 | 2.2 | 161.4 | 115.0 | 161.1
```
And before the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match : True | True | True | True | True | True
Average Time (us) : 354.0 | 336.0 | 332.4 | 145.5 | 114.7 | 148.3
```
Fixes #144245
[ghstack-poisoned]
|
@pytorchbot merge -f "Lint + MPS are green" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Refactor `INSTANTIATE_UPSAMPLE_BILINEAR2D(DTYPE)`, `INSTANTIATE_UPSAMPLE_BICUBIC2D(DTYPE)` and `INSTANTIATE_UPSAMPLE_BILINEAR2DAA(DTYPE)` use common `INSTANTIATE_UPSAMPLE2D` Then combine multiple invocations into `INSTANTIATE_UPSAMPLE_ALL` I.e. functionally it's a no-op, but achieves the same with fewer lines of code Pull Request resolved: #148187 Approved by: https://github.com/Skylion007 ghstack dependencies: #148154
- First, by stopp inverting sizes and strides, i.e. passing them as is, but reading them in inverse order in the shader as 1st stride of 4D tensor is one used for batches, 2nd for channels and 3rd and 4th for spatial coordinates - Pass `scales` as float2 even in linear tensor Above allows one to collide two flavors `upsample_kernel_out_template` into one Pull Request resolved: #148211 Approved by: https://github.com/dcci ghstack dependencies: #148154, #148187
Not sure why tolerances were set like that, this logic was added in #104181 without much explanation But if I'm to make a guess, it's likely due to the inaccuracy of bilinear op, that has since been replaced by shader Pull Request resolved: #148224 Approved by: https://github.com/Skylion007, https://github.com/dcci ghstack dependencies: #148154, #148187, #148211
First of all, perf claims made in #145581 and #148154 are too good to be true (due to the bug in the script that did not call `torch.mps.synchronize` at the end of the benchmark script, but still slightly better than MPS, probably due to the launch overhead. And while measure performance correctly, I've noticed that a lot of time is spent on 64-bit integral division of thread_index to get spatial coordinates. Simply downcasting divisior to 32-bit integer (which is also the thread index) speeds it up almost 2x for bilinear and bicubic as could be demonstrated by running following script ```python import torch import time import subprocess import itertools def benchmark(device, dtype, mode="bilinear", antialias=False, sf=.5): # Create example inputs x = torch.testing.make_tensor(1, 1, 2048, 2048, device=device, dtype=dtype) # define kwargs kwargs = {"antialias": antialias, "mode": mode, "scale_factor": sf} # Skip for unimplemented flavors if antialias and mode == "bicubic" and device == "mps": return None, "Skip" elif antialias and dtype != torch.float32: if device == "cpu": return None, "Skip" outputs_match = None else: # Check output y = torch.nn.functional.interpolate(x, **kwargs) z = torch.nn.functional.interpolate(x.cpu(), **kwargs) outputs_match = torch.allclose(y.cpu(), z) if not outputs_match: atol = (y.cpu() - z).abs().max() rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max() print(f"atol={atol} rtol={rtol}") # Measure time manually start_time = time.time() * 1000 for _ in range(1000): y = torch.nn.functional.interpolate(x, **kwargs) torch.mps.synchronize() end_time = time.time() * 1000 manual_delta = (end_time - start_time) average_time = f"{manual_delta:6.1f}" return "True " if outputs_match else "False", average_time brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip() for mode,antialias in itertools.product(["bilinear", "bicubic"], [False, True]): outputs_match_list = [] average_time_list = [] for device in ["mps", "cpu"]: for dtype in [torch.float32, torch.float16, torch.bfloat16]: outputs_match, average_time = benchmark(device, dtype, mode=mode, antialias=antialias) outputs_match_list.append(str(outputs_match)) average_time_list.append(average_time) print(f"\nBenchmarking Results (collected on {brand_string}) for {mode} interpolation {'with antialias' if antialias else ''}:") print("-"*40) print("Device : MPS | CPU") print("Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16") print(f"Outputs Match : ", " | ".join(outputs_match_list)) print(f"Average Time (us) :", " |".join(average_time_list)) ``` Before ``` Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : True | True | True | True | True | True Average Time (us) : 292.0 | 264.7 | 267.9 | 289.1 | 230.9 | 309.1 atol=1.430511474609375e-06 rtol=0.11363636702299118 Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation with antialias: ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | False | False | True | None | None Average Time (us) : 698.3 | 684.2 | 683.8 | 851.0 |Skip |Skip atol=2.086162567138672e-06 rtol=0.019750799983739853 Benchmarking Results (collected on Apple M4 Pro) for bicubic interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | True | True | True | True | True Average Time (us) : 314.3 | 301.0 | 298.8 | 681.5 | 616.7 | 833.7 ``` After ``` Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : True | True | True | True | True | True Average Time (us) : 119.9 | 98.9 | 98.6 | 289.8 | 231.9 | 308.5 atol=1.430511474609375e-06 rtol=0.05681818351149559 Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation with antialias: ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | False | False | True | None | None Average Time (us) : 541.9 | 531.1 | 531.0 | 846.8 |Skip |Skip atol=2.0265579223632812e-06 rtol=0.008604463189840317 Benchmarking Results (collected on Apple M4 Pro) for bicubic interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | True | True | True | True | True Average Time (us) : 314.3 | 301.0 | 298.8 | 681.5 | 616.7 | 833.7 ``` TODO: - Figure out if this ops make more sense as 3D jobs with n and c channels dispatch as one more dimension Pull Request resolved: #148277 Approved by: https://github.com/Skylion007
First of all, perf claims made in pytorch/pytorch#145581 and pytorch/pytorch#148154 are too good to be true (due to the bug in the script that did not call `torch.mps.synchronize` at the end of the benchmark script, but still slightly better than MPS, probably due to the launch overhead. And while measure performance correctly, I've noticed that a lot of time is spent on 64-bit integral division of thread_index to get spatial coordinates. Simply downcasting divisior to 32-bit integer (which is also the thread index) speeds it up almost 2x for bilinear and bicubic as could be demonstrated by running following script ``` import torch import time import subprocess import itertools def benchmark(device, dtype, mode="bilinear", antialias=False, sf=.5): # Create example inputs x = torch.testing.make_tensor(1, 1, 2048, 2048, device=device, dtype=dtype) # define kwargs kwargs = {"antialias": antialias, "mode": mode, "scale_factor": sf} # Skip for unimplemented flavors if antialias and mode == "bicubic" and device == "mps": return None, "Skip" elif antialias and dtype != torch.float32: if device == "cpu": return None, "Skip" outputs_match = None else: # Check output y = torch.nn.functional.interpolate(x, **kwargs) z = torch.nn.functional.interpolate(x.cpu(), **kwargs) outputs_match = torch.allclose(y.cpu(), z) if not outputs_match: atol = (y.cpu() - z).abs().max() rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max() print(f"atol={atol} rtol={rtol}") # Measure time manually start_time = time.time() * 1000 for _ in range(1000): y = torch.nn.functional.interpolate(x, **kwargs) torch.mps.synchronize() end_time = time.time() * 1000 manual_delta = (end_time - start_time) average_time = f"{manual_delta:6.1f}" return "True " if outputs_match else "False", average_time brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip() for mode,antialias in itertools.product(["bilinear", "bicubic"], [False, True]): outputs_match_list = [] average_time_list = [] for device in ["mps", "cpu"]: for dtype in [torch.float32, torch.float16, torch.bfloat16]: outputs_match, average_time = benchmark(device, dtype, mode=mode, antialias=antialias) outputs_match_list.append(str(outputs_match)) average_time_list.append(average_time) print(f"\nBenchmarking Results (collected on {brand_string}) for {mode} interpolation {'with antialias' if antialias else ''}:") print("-"*40) print("Device : MPS | CPU") print("Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16") print(f"Outputs Match : ", " | ".join(outputs_match_list)) print(f"Average Time (us) :", " |".join(average_time_list)) ``` Before ``` Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : True | True | True | True | True | True Average Time (us) : 292.0 | 264.7 | 267.9 | 289.1 | 230.9 | 309.1 atol=1.430511474609375e-06 rtol=0.11363636702299118 Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation with antialias: ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | False | False | True | None | None Average Time (us) : 698.3 | 684.2 | 683.8 | 851.0 |Skip |Skip atol=2.086162567138672e-06 rtol=0.019750799983739853 Benchmarking Results (collected on Apple M4 Pro) for bicubic interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | True | True | True | True | True Average Time (us) : 314.3 | 301.0 | 298.8 | 681.5 | 616.7 | 833.7 ``` After ``` Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : True | True | True | True | True | True Average Time (us) : 119.9 | 98.9 | 98.6 | 289.8 | 231.9 | 308.5 atol=1.430511474609375e-06 rtol=0.05681818351149559 Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation with antialias: ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | False | False | True | None | None Average Time (us) : 541.9 | 531.1 | 531.0 | 846.8 |Skip |Skip atol=2.0265579223632812e-06 rtol=0.008604463189840317 Benchmarking Results (collected on Apple M4 Pro) for bicubic interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | True | True | True | True | True Average Time (us) : 314.3 | 301.0 | 298.8 | 681.5 | 616.7 | 833.7 ``` ghstack-source-id: c622e55 Pull Request resolved: pytorch/pytorch#148277
Stack from ghstack (oldest at bottom):
And get rid of MPS call, as for some reason implementation via MPSGraph
API call is 100x+ times slower that Metal shader, at least according to
the following benchmark
Benchmark results after the change
And before the change
Fixes #144245