Files
pytorch/test/bench_mps_ops.py
Manuel Candales 12b02137af [MPS] Add benchmark for scan operations (#156241)
Comparison of cumsum performance before and after the Metal implementation:

Previous performance (using torch==2.7.1):
```
[-------------------------------  -------------------------------]
                                              |  eager  |  compile
1 threads: -------------------------------------------------------
      cumsum-dim0-32x32 (torch.float16)       |  131.0  |   136.9
      cumsum-dim0-128x128 (torch.float16)     |  116.9  |   121.2
      cumsum-dim0-512x512 (torch.float16)     |  132.5  |   151.9
      cumsum-dim0-1024x1024 (torch.float16)   |  150.0  |   163.0
      cumsum-dim1-32x32 (torch.float16)       |  125.9  |   140.9
      cumsum-dim1-128x128 (torch.float16)     |  116.4  |   129.4
      cumsum-dim1-512x512 (torch.float16)     |  135.9  |   150.1
      cumsum-dim1-1024x1024 (torch.float16)   |  139.5  |   154.2
      cumsum-1d-100 (torch.float16)           |  119.5  |   127.1
      cumsum-1d-10000 (torch.float16)         |  128.9  |   142.5
      cumsum-1d-1000000 (torch.float16)       |  140.6  |   145.6
      cumsum-dim0-32x32 (torch.float32)       |  115.7  |   132.5
      cumsum-dim0-128x128 (torch.float32)     |  118.0  |   131.5
      cumsum-dim0-512x512 (torch.float32)     |  138.8  |   151.6
      cumsum-dim0-1024x1024 (torch.float32)   |  155.5  |   164.2
      cumsum-dim1-32x32 (torch.float32)       |  127.2  |   141.7
      cumsum-dim1-128x128 (torch.float32)     |  117.7  |   130.5
      cumsum-dim1-512x512 (torch.float32)     |  138.2  |   152.3
      cumsum-dim1-1024x1024 (torch.float32)   |  144.4  |   158.6
      cumsum-1d-100 (torch.float32)           |  118.6  |   128.0
      cumsum-1d-10000 (torch.float32)         |  125.5  |   141.5
      cumsum-1d-1000000 (torch.float32)       |  143.9  |   158.4
      cumsum-dim0-32x32 (torch.bfloat16)      |  106.6  |   137.6
      cumsum-dim0-128x128 (torch.bfloat16)    |  118.1  |   131.0
      cumsum-dim0-512x512 (torch.bfloat16)    |  140.0  |   154.3
      cumsum-dim0-1024x1024 (torch.bfloat16)  |  153.2  |   164.4
      cumsum-dim1-32x32 (torch.bfloat16)      |  127.9  |   132.6
      cumsum-dim1-128x128 (torch.bfloat16)    |  116.5  |   129.6
      cumsum-dim1-512x512 (torch.bfloat16)    |  136.5  |   151.2
      cumsum-dim1-1024x1024 (torch.bfloat16)  |  139.8  |   144.8
      cumsum-1d-100 (torch.bfloat16)          |  115.7  |   129.4
      cumsum-1d-10000 (torch.bfloat16)        |  125.0  |   143.3
      cumsum-1d-1000000 (torch.bfloat16)      |  127.8  |   143.4

Times are in microseconds (us).
```

Current performance:
```
[--------------------------------  --------------------------------]
                                              |   eager   |  compile
1 threads: ---------------------------------------------------------
      cumsum-dim0-32x32 (torch.float16)       |    107.4  |    123.8
      cumsum-dim0-128x128 (torch.float16)     |    134.2  |    145.8
      cumsum-dim0-512x512 (torch.float16)     |    207.3  |    231.6
      cumsum-dim0-1024x1024 (torch.float16)   |    318.9  |    355.3
      cumsum-dim1-32x32 (torch.float16)       |     98.0  |    114.3
      cumsum-dim1-128x128 (torch.float16)     |    110.8  |    121.6
      cumsum-dim1-512x512 (torch.float16)     |    193.0  |    209.1
      cumsum-dim1-1024x1024 (torch.float16)   |    844.7  |    870.8
      cumsum-1d-100 (torch.float16)           |    108.4  |    125.0
      cumsum-1d-10000 (torch.float16)         |    784.7  |    852.3
      cumsum-1d-1000000 (torch.float16)       |  65855.2  |  66725.9
      cumsum-dim0-32x32 (torch.float32)       |    114.7  |    115.7
      cumsum-dim0-128x128 (torch.float32)     |    139.0  |    151.6
      cumsum-dim0-512x512 (torch.float32)     |    197.3  |    208.0
      cumsum-dim0-1024x1024 (torch.float32)   |    312.7  |    332.9
      cumsum-dim1-32x32 (torch.float32)       |     92.0  |    110.8
      cumsum-dim1-128x128 (torch.float32)     |    114.2  |    125.0
      cumsum-dim1-512x512 (torch.float32)     |    186.2  |    196.1
      cumsum-dim1-1024x1024 (torch.float32)   |    752.0  |    825.0
      cumsum-1d-100 (torch.float32)           |    112.4  |    122.0
      cumsum-1d-10000 (torch.float32)         |    793.5  |    863.5
      cumsum-1d-1000000 (torch.float32)       |  66431.8  |  66040.0
      cumsum-dim0-32x32 (torch.bfloat16)      |    111.6  |    121.6
      cumsum-dim0-128x128 (torch.bfloat16)    |    139.0  |    138.4
      cumsum-dim0-512x512 (torch.bfloat16)    |    217.6  |    230.1
      cumsum-dim0-1024x1024 (torch.bfloat16)  |    305.2  |    325.6
      cumsum-dim1-32x32 (torch.bfloat16)      |    100.5  |    110.9
      cumsum-dim1-128x128 (torch.bfloat16)    |    112.8  |    125.0
      cumsum-dim1-512x512 (torch.bfloat16)    |    187.8  |    208.9
      cumsum-dim1-1024x1024 (torch.bfloat16)  |    790.9  |    864.7
      cumsum-1d-100 (torch.bfloat16)          |    111.6  |    124.6
      cumsum-1d-10000 (torch.bfloat16)        |    778.1  |    844.9
      cumsum-1d-1000000 (torch.bfloat16)      |  64654.3  |  64082.5

Times are in microseconds (us).
```
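
Each row in the tables above is a single `Measurement` produced by `torch.utils.benchmark.Timer` and aggregated with `Compare`, as done by the script below (the tables can be reproduced by running it on an MPS-capable machine). The following is a minimal, illustrative sketch of how one such cumsum measurement is collected; the names used here are placeholders, and the actual helpers are in `bench_mps_ops.py` below:

```python
# Illustrative sketch only: time eager vs. compiled cumsum on one MPS tensor.
# The real benchmark (bench_scan below) loops over dims, sizes and dtypes.
import timeit

import torch
from torch.utils.benchmark import Compare, Timer


def cumsum_dim0(t):
    return torch.cumsum(t, dim=0)


cumsum_dim0_c = torch.compile(cumsum_dim0, dynamic=False)

x = torch.testing.make_tensor(1024, 1024, device="mps", dtype=torch.float16)
results = []
for fn, desc in ((cumsum_dim0, "eager"), (cumsum_dim0_c, "compile")):
    t = Timer(
        stmt="fn(x); torch.mps.synchronize()",  # sync so GPU time is included
        globals={"fn": fn, "x": x},
        timer=timeit.default_timer,
        sub_label="cumsum-dim0-1024x1024 (torch.float16)",
        description=desc,
    )
    results.append(t.blocked_autorange())
Compare(results).print()
```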
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156241
Approved by: https://github.com/malfet
2025-06-17 22:30:22 +00:00


# Owner(s): ["module: mps"]
# Collection of op level benchmarks for MPS
# Useful as reference tool when migrating ops from MPS to Metal
import itertools
import timeit
import warnings
from typing import Optional

import torch
from torch.utils.benchmark import Compare, Measurement, Timer


def bench_unary_op(func, x, label) -> Measurement:
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x);{sync_cmd}",
        globals={"f": func, "x": x},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()


def bench_binary_op(func, x, y, label) -> Measurement:
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x, y);{sync_cmd}",
        globals={"f": func, "x": x, "y": y},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)}, {str(y.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()


def bench_unary(
    unary_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    x = torch.testing.make_tensor(1024, 1024, device=device, dtype=dtype)
    x_s = torch.testing.make_tensor(1024, 2048, device=device, dtype=dtype)[::, ::2]
    rc = []
    rc.append(bench_unary_op(unary_func, x, "dense"))
    rc.append(bench_unary_op(unary_func, x.t(), "transposed"))
    rc.append(bench_unary_op(unary_func, x_s, "strided"))
    rc.append(bench_unary_op(unary_func, x_s.t(), "strided + transposed"))
    return rc


def bench_binary(
    binary_func,
    device: str = "mps",
    dt_a: torch.dtype = torch.float32,
    dt_b: Optional[torch.dtype] = None,
) -> list[Measurement]:
    dt_b = dt_b if dt_b is not None else dt_a
    x = torch.testing.make_tensor(1024, 1024, device=device, dtype=dt_a)
    y = torch.testing.make_tensor(1024, 1024, device=device, dtype=dt_b)
    s = torch.testing.make_tensor((), device=device, dtype=dt_b)
    rc = []
    rc.append(bench_binary_op(binary_func, x, y, "dense-dense"))
    rc.append(bench_binary_op(binary_func, x.t(), y.t(), "transp-transp"))
    rc.append(bench_binary_op(binary_func, x, y.t(), "dense-transp"))
    rc.append(bench_binary_op(binary_func, x.t(), y, "transp-dense"))
    rc.append(bench_binary_op(binary_func, x, s, "dense-scalar"))
    rc.append(bench_binary_op(binary_func, x, y[0], "dense-bcast"))
    return rc


def bench_reduction(
    reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    rc = []

    # Bench 2D with reduction over dim=0
    def f(t):
        return reduction_func(t, dim=0)

    f.__name__ = reduction_func.__name__
    f_c = torch.compile(f, dynamic=False)

    for size in (512, 1024, 2048, 4096):
        x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
        rc_c, rc_e = f(x), f_c(x)
        rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
        if not torch.allclose(rc_c, rc_e):
            mdiff = (rc_c - rc_e).abs().max()
            warnings.warn(
                f"Eager and compile reduction do not match for {reduction_func.__name__} and {dtype} max_diff={mdiff}",
                stacklevel=2,
            )
        rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
        rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
    return rc


def bench_scan(
    scan_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    rc = []
    # Bench cumsum along different dimensions
    for dim in [0, 1]:

        def f(t):
            return scan_func(t, dim=dim)

        f_c = torch.compile(f, dynamic=False)
        for size in (32, 128, 512, 1024):
            f.__name__ = f"{scan_func.__name__}-dim{dim}-{size}x{size}"
            f_c.__name__ = f.__name__
            x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
            rc_c, rc_e = f(x), f_c(x)
            if not torch.allclose(rc_c, rc_e):
                mdiff = (rc_c - rc_e).abs().max()
                warnings.warn(
                    f"Eager and compile scan do not match for {scan_func.__name__} dim={dim} and {dtype} max_diff={mdiff}",
                    stacklevel=2,
                )
            rc.append(bench_unary_op(f, x, "eager"))
            rc.append(bench_unary_op(f_c, x, "compile"))

    # Bench 1D cumsum for different sizes
    def f_1d(t):
        return scan_func(t, dim=0)

    f_1d_c = torch.compile(f_1d, dynamic=False)
    for size in (100, 10000, 1000000):
        f_1d.__name__ = f"{scan_func.__name__}-1d-{size}"
        f_1d_c.__name__ = f_1d.__name__
        x = torch.testing.make_tensor(size, device=device, dtype=dtype)
        rc_c, rc_e = f_1d(x), f_1d_c(x)
        if not torch.allclose(rc_c, rc_e):
            mdiff = (rc_c - rc_e).abs().max()
            warnings.warn(
                f"Eager and compile 1D scan do not match for {scan_func.__name__} and {dtype} max_diff={mdiff}",
                stacklevel=2,
            )
        rc.append(bench_unary_op(f_1d, x, "eager"))
        rc.append(bench_unary_op(f_1d_c, x, "compile"))
    return rc


def main() -> None:
    dtypes = [torch.float16, torch.float32]
    if torch.backends.mps.is_macos_or_newer(14, 0):
        dtypes.append(torch.bfloat16)

    # Profile unary ops
    rc = []
    for op, dtype in itertools.product([torch.sqrt, torch.sin], dtypes):
        rc.extend(bench_unary(op, dtype=dtype))
    Compare(rc).print()

    # Profile reduction ops
    rc = []
    for op in [torch.sum, torch.max]:
        rc.extend(bench_reduction(op))
    Compare(rc).print()

    # Profile scan ops (cumsum)
    rc = []
    for dtype in dtypes:
        rc.extend(bench_scan(torch.cumsum, dtype=dtype))
    Compare(rc).print()

    # Profile binary ops
    rc = []
    ops = [torch.fmax, torch.add]
    for op, dtype in itertools.product(ops, dtypes):
        rc.extend(bench_binary(op, dt_a=dtype))
        if dtype == torch.float32:
            rc.extend(bench_binary(op, dt_b=torch.float16))
    Compare(rc).print()


if __name__ == "__main__":
    main()
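
If only the scan benchmarks are of interest (for example, while iterating on a Metal kernel), the helpers above can be reused from a separate driver. A hypothetical sketch, assuming the file is importable as `bench_mps_ops` from `pytorch/test/`:

```python
# Hypothetical standalone driver (not part of the file above): benchmark only
# cumsum scans by reusing bench_scan and Compare.
import torch
from torch.utils.benchmark import Compare

from bench_mps_ops import bench_scan  # assumes pytorch/test/ is on sys.path

results = []
for dtype in (torch.float16, torch.float32):
    results.extend(bench_scan(torch.cumsum, device="mps", dtype=dtype))
Compare(results).print()
```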