# [MPS][BE] Implement bilineard2d as shader #145581
That significantly improves performance and addresses a correctness problem (to the extent permitted by reducing the precision of the scale-factor computation to float32). The uint8 scaling algorithm mimics the CPU/Pillow implementation
https://github.com/python-pillow/Pillow/blob/569b785371aa717a004adb0166feb565bbb01b7b/src/libImaging/Resample.c#L306-L309
i.e. it uses fixed-precision integral arithmetic and rounds the results of the horizontal interpolation back to integers before performing the vertical one, which yields technically less accurate results.

But even with those changes, `atol`, `rtol` must be tweaked to `1, 0` when the scale factor is `1/3` or `2/3`, because those values are represented differently as floats and doubles.

Changes in performance can be measured using the following script:

```python
import torch
import time
import subprocess


def benchmark(device, dtype):
    # Create example inputs
    x = torch.testing.make_tensor(1, 1, 2048, 2048, device=device, dtype=dtype)
    sf = .5
    # Check output
    y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="bilinear")
    z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="bilinear")
    outputs_match = torch.allclose(y.cpu(), z)
    if not outputs_match:
        atol = (y.cpu() - z).abs().max()
        rtol = ((y.cpu() - z)[z != 0] / z[z != 0]).abs().max()
        print(f"atol={atol} rtol={rtol}")
    # Measure time manually
    start_time = time.time() * 1000
    for _ in range(1000):
        y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="bilinear")
    torch.mps.synchronize  # note: missing parentheses, so this never actually synchronizes (see the follow-up below)
    end_time = time.time() * 1000
    manual_delta = (end_time - start_time)
    average_time = f"{manual_delta:6.1f}"
    return "True " if outputs_match else "False", average_time


outputs_match_list = []
average_time_list = []
for device in ["mps", "cpu"]:
    for dtype in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
        outputs_match, average_time = benchmark(device, dtype)
        outputs_match_list.append(str(outputs_match))
        average_time_list.append(average_time)

brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
print(f"\nBenchmarking Results (collected on {brand_string}):")
print("-" * 40)
print("Device            : MPS | CPU")
print("Dtype             : FP32 | FP16 | BF16 | U8 | FP32 | FP16 | BF16 | U8")
print("Outputs Match     : ", " | ".join(outputs_match_list))
print("Average Time (us) :", " |".join(average_time_list))
```

Benchmark results before:

```
Benchmarking Results (collected on Apple M4 Pro):
----------------------------------------
Device            : MPS | CPU
Dtype             : FP32 | FP16 | BF16 | U8 | FP32 | FP16 | BF16 | U8
Outputs Match     : True | True | True | False | True | True | True | True
Average Time (us) : 277.3 | 197.2 | 188.0 | 163.5 | 302.8 | 248.1 | 308.7 | 650.9
```

After (almost **100x** perf gain):

```
Benchmarking Results (collected on Apple M4 Pro):
----------------------------------------
Device            : MPS | CPU
Dtype             : FP32 | FP16 | BF16 | U8 | FP32 | FP16 | BF16 | U8
Outputs Match     : True | True | True | True | True | True | True | True
Average Time (us) : 1.7 | 1.5 | 1.7 | 1.5 | 296.5 | 236.0 | 310.8 | 642.6
```

Pull Request resolved: pytorch#145581
Approved by: https://github.com/Skylion007
ghstack dependencies: pytorch#145578
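To make the Pillow-style fixed-point arithmetic described above concrete, here is a minimal Python sketch of a single interpolation step for uint8. It is an illustration only, not the shader code; `PRECISION_BITS` follows the constant in Pillow's `Resample.c`, and the helper names are invented for this example:

```python
# Illustrative sketch of Pillow-style fixed-point interpolation for uint8
# (not the actual Metal shader code).
PRECISION_BITS = 32 - 8 - 2  # as defined in Pillow's Resample.c

def to_fixed(weight: float) -> int:
    """Quantize a float interpolation weight to fixed-point."""
    return round(weight * (1 << PRECISION_BITS))

def interp_uint8(pixels: list[int], weights: list[float]) -> int:
    """Weighted sum in integer arithmetic, rounded back to a uint8 value."""
    acc = 1 << (PRECISION_BITS - 1)  # +0.5 in fixed-point, for round-to-nearest
    for p, w in zip(pixels, weights):
        acc += p * to_fixed(w)
    return min(max(acc >> PRECISION_BITS, 0), 255)

# The horizontal pass rounds back to 8-bit integers...
row0 = interp_uint8([10, 200], [0.25, 0.75])  # 152.5 -> 153
row1 = interp_uint8([50, 60], [0.25, 0.75])   # 57.5  -> 58
# ...before the vertical pass consumes the already-rounded values.
print(interp_uint8([row0, row1], [0.5, 0.5]))  # 105.5 -> 106
```

Rounding each horizontal result back to 8 bits before the vertical pass is what makes this scheme technically less accurate than interpolating in floating point, and it is also why matching the CPU/Pillow results requires mimicking it.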
## Follow-up: #148277

First of all, the perf claims made in #145581 and #148154 are too good to be true (because of a bug in the script: it never called `torch.mps.synchronize()` at the end of the benchmark loop), though the shader is still slightly better than MPS, probably due to launch overhead. While measuring performance correctly, I noticed that a lot of time is spent on the 64-bit integral division of `thread_index` to get spatial coordinates. Simply downcasting the divisor to a 32-bit integer (the same width as the thread index) speeds it up almost 2x for bilinear and bicubic, as can be demonstrated by running the following script:

```python
import torch
import time
import subprocess
import itertools


def benchmark(device, dtype, mode="bilinear", antialias=False, sf=.5):
    # Create example inputs
    x = torch.testing.make_tensor(1, 1, 2048, 2048, device=device, dtype=dtype)
    # Define kwargs
    kwargs = {"antialias": antialias, "mode": mode, "scale_factor": sf}
    # Skip unimplemented flavors
    if antialias and mode == "bicubic" and device == "mps":
        return None, "Skip"
    elif antialias and dtype != torch.float32:
        if device == "cpu":
            return None, "Skip"
        outputs_match = None
    else:
        # Check output
        y = torch.nn.functional.interpolate(x, **kwargs)
        z = torch.nn.functional.interpolate(x.cpu(), **kwargs)
        outputs_match = torch.allclose(y.cpu(), z)
        if not outputs_match:
            atol = (y.cpu() - z).abs().max()
            rtol = ((y.cpu() - z)[z != 0] / z[z != 0]).abs().max()
            print(f"atol={atol} rtol={rtol}")
    # Measure time manually
    start_time = time.time() * 1000
    for _ in range(1000):
        y = torch.nn.functional.interpolate(x, **kwargs)
    torch.mps.synchronize()
    end_time = time.time() * 1000
    manual_delta = (end_time - start_time)
    average_time = f"{manual_delta:6.1f}"
    return "True " if outputs_match else "False", average_time


brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
for mode, antialias in itertools.product(["bilinear", "bicubic"], [False, True]):
    outputs_match_list = []
    average_time_list = []
    for device in ["mps", "cpu"]:
        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            outputs_match, average_time = benchmark(device, dtype, mode=mode, antialias=antialias)
            outputs_match_list.append(str(outputs_match))
            average_time_list.append(average_time)
    print(f"\nBenchmarking Results (collected on {brand_string}) for {mode} interpolation {'with antialias' if antialias else ''}:")
    print("-" * 40)
    print("Device            : MPS | CPU")
    print("Dtype             : FP32 | FP16 | BF16 | FP32 | FP16 | BF16")
    print("Outputs Match     : ", " | ".join(outputs_match_list))
    print("Average Time (us) :", " |".join(average_time_list))
```

Before:

```
Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation :
----------------------------------------
Device            : MPS | CPU
Dtype             : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match     : True | True | True | True | True | True
Average Time (us) : 292.0 | 264.7 | 267.9 | 289.1 | 230.9 | 309.1

atol=1.430511474609375e-06 rtol=0.11363636702299118

Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation with antialias:
----------------------------------------
Device            : MPS | CPU
Dtype             : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match     : False | False | False | True | None | None
Average Time (us) : 698.3 | 684.2 | 683.8 | 851.0 | Skip | Skip

atol=2.086162567138672e-06 rtol=0.019750799983739853

Benchmarking Results (collected on Apple M4 Pro) for bicubic interpolation :
----------------------------------------
Device            : MPS | CPU
Dtype             : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match     : False | True | True | True | True | True
Average Time (us) : 314.3 | 301.0 | 298.8 | 681.5 | 616.7 | 833.7
```

After:

```
Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation :
----------------------------------------
Device            : MPS | CPU
Dtype             : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match     : True | True | True | True | True | True
Average Time (us) : 119.9 | 98.9 | 98.6 | 289.8 | 231.9 | 308.5

atol=1.430511474609375e-06 rtol=0.05681818351149559

Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation with antialias:
----------------------------------------
Device            : MPS | CPU
Dtype             : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match     : False | False | False | True | None | None
Average Time (us) : 541.9 | 531.1 | 531.0 | 846.8 | Skip | Skip

atol=2.0265579223632812e-06 rtol=0.008604463189840317

Benchmarking Results (collected on Apple M4 Pro) for bicubic interpolation :
----------------------------------------
Device            : MPS | CPU
Dtype             : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match     : False | True | True | True | True | True
Average Time (us) : 314.3 | 301.0 | 298.8 | 681.5 | 616.7 | 833.7
```

TODO:
- Figure out whether these ops make more sense as 3D jobs, with the n and c channels dispatched as one more dimension

Pull Request resolved: #148277
Approved by: https://github.com/Skylion007
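For intuition about the division change, here is a rough Python rendering (hypothetical, not the actual Metal kernel) of how a shader recovers spatial coordinates from a flat thread index:

```python
# Hypothetical sketch of the per-thread index math; in the Metal kernel the
# equivalent division/modulo previously used a 64-bit divisor.
def spatial_coords(thread_index: int, out_w: int, out_h: int) -> tuple[int, int]:
    # The thread index is itself a 32-bit value, so the quotient and
    # remainder fit in 32 bits: downcasting the divisor to a 32-bit
    # integer changes no results while avoiding slow 64-bit division.
    x = thread_index % out_w
    y = (thread_index // out_w) % out_h
    return x, y

print(spatial_coords(5000, out_w=1024, out_h=1024))  # (904, 4)
```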