
Conversation


@malfet malfet commented Jan 24, 2025

Stack from ghstack (oldest at bottom):

In preparation for more Metal shaders to come.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Jan 24, 2025
@malfet malfet requested a review from kulinseth as a code owner January 24, 2025 02:03

pytorch-bot bot commented Jan 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145578

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 12 Pending

As of commit 34cf566 with merge base 66bf7da:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@malfet malfet added the topic: not user facing topic category label Jan 24, 2025

malfet commented Jan 24, 2025

@pytorchbot merge -f "Lint is green, otherwise it's a no-op"

@pytorchmergebot

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort; instead, consider -i/--ignore-current to continue the merge while ignoring current failures. This allows currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Jan 25, 2025
This significantly improves performance and addresses a correctness problem (to the extent permitted by reducing the precision of the scale-factor computation to float32). The uint8 scaling algorithm mimics the CPU/Pillow implementation
https://github.com/python-pillow/Pillow/blob/569b785371aa717a004adb0166feb565bbb01b7b/src/libImaging/Resample.c#L306-L309
i.e. it uses fixed-precision integer arithmetic and rounds the results of horizontal interpolation back to integers before performing the vertical pass, which yields technically less accurate results.
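
For intuition, here is a minimal sketch of that fixed-precision scheme in plain Python (not the actual Metal shader; the helper names are made up for illustration, and `PRECISION_BITS` mirrors the constant in Pillow's Resample.c):
```python
# Minimal sketch of Pillow-style fixed-precision uint8 resampling.
PRECISION_BITS = 32 - 8 - 2  # same constant as Pillow's Resample.c

def quantize_coeffs(coeffs):
    # Convert float filter weights to fixed-point integers.
    return [int(round(c * (1 << PRECISION_BITS))) for c in coeffs]

def interpolate_pixels(pixels, int_coeffs):
    # Accumulate in integers; seeding with +0.5 in fixed point makes
    # the final right-shift round to nearest instead of truncating.
    acc = 1 << (PRECISION_BITS - 1)
    for p, c in zip(pixels, int_coeffs):
        acc += p * c
    return min(255, max(0, acc >> PRECISION_BITS))

# The horizontal pass rounds back to uint8 *before* the vertical pass
# runs, which is the source of the small accuracy loss noted above.
coeffs = quantize_coeffs([0.25, 0.5, 0.25])
print(interpolate_pixels([10, 20, 30], coeffs))  # 20
```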

But even with those changes, `atol` and `rtol` must be tweaked to `1, 0` when the scale factor is `1/3` or `2/3`, because those values round to different nearest representations as floats and doubles.
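
A quick illustration of that representation gap:
```python
import torch
# 1/3 has a different nearest representation in float32 than in float64,
# so an MPS scale factor computed in float32 cannot exactly match the
# CPU reference computed in double precision.
print(torch.tensor(1 / 3, dtype=torch.float32).item())  # 0.3333333432674408
print(torch.tensor(1 / 3, dtype=torch.float64).item())  # 0.3333333333333333
```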

The performance change can be measured with the following script:
```python
import torch
import time
import subprocess

def benchmark(device, dtype):
    # Create example inputs
    x = torch.testing.make_tensor(1, 1, 2048, 2048, device=device, dtype=dtype)
    sf = .5

    # Check output
    y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="bilinear")
    z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="bilinear")
    outputs_match = torch.allclose(y.cpu(), z)
    if not outputs_match:
        atol = (y.cpu() - z).abs().max()
        rtol = ((y.cpu() - z)[z != 0] / z[z != 0]).abs().max()
        print(f"atol={atol} rtol={rtol}")

    # Measure time manually: over 1000 iterations, the total elapsed
    # milliseconds numerically equal the average per-iteration time in us
    start_time = time.time() * 1000
    for _ in range(1000):
        y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="bilinear")
    if device == "mps":
        torch.mps.synchronize()  # flush queued GPU work before stopping the clock
    end_time = time.time() * 1000
    manual_delta = end_time - start_time
    average_time = f"{manual_delta:6.1f}"

    return "True " if outputs_match else "False", average_time

outputs_match_list = []
average_time_list = []
for device in ["mps", "cpu"]:
    for dtype in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
        outputs_match, average_time = benchmark(device, dtype)
        outputs_match_list.append(str(outputs_match))
        average_time_list.append(average_time)

brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
print(f"\nBenchmarking Results (collected on {brand_string}):")
print("-"*40)
print("Device            :                MPS                 |               CPU")
print("Dtype             :   FP32  |  FP16  |  BF16  |   U8   |  FP32  |  FP16  |  BF16  |  U8")
print(f"Outputs Match     :  ", " |  ".join(outputs_match_list))
print(f"Average Time (us) :", "  |".join(average_time_list))
```

Benchmark results before:
```
Benchmarking Results (collected on Apple M4 Pro):
----------------------------------------
Device            :                MPS                 |               CPU
Dtype             :   FP32  |  FP16  |  BF16  |   U8   |  FP32  |  FP16  |  BF16  |  U8
Outputs Match     :   True  |  True  |  True  |  False |  True  |  True  |  True  |  True
Average Time (us) :  277.3  | 197.2  | 188.0  | 163.5  | 302.8  | 248.1  | 308.7  | 650.9
```
After (almost **100x** perf gain):
```
Benchmarking Results (collected on Apple M4 Pro):
----------------------------------------
Device            :                MPS                 |               CPU
Dtype             :   FP32  |  FP16  |  BF16  |   U8   |  FP32  |  FP16  |  BF16  |  U8
Outputs Match     :   True  |  True  |  True  |  True  |  True  |  True  |  True  |  True
Average Time (us) :    1.7  |   1.5  |   1.7  |   1.5  | 296.5  | 236.0  | 310.8  | 642.6
```
Pull Request resolved: #145581
Approved by: https://github.com/Skylion007
ghstack dependencies: #145578
nWEIdia pushed a commit to nWEIdia/pytorch that referenced this pull request Jan 27, 2025
@github-actions github-actions bot deleted the gh/malfet/146/head branch February 24, 2025 02:07