Comparison of cumsum performance before and after the Metal implementation.

Previous performance (using torch==2.7.1):
```
[------------------------------- -------------------------------]
                                              |    eager  |  compile
1 threads: -------------------------------------------------------
           cumsum-dim0-32x32 (torch.float16)  |    131.0  |    136.9
         cumsum-dim0-128x128 (torch.float16)  |    116.9  |    121.2
         cumsum-dim0-512x512 (torch.float16)  |    132.5  |    151.9
       cumsum-dim0-1024x1024 (torch.float16)  |    150.0  |    163.0
           cumsum-dim1-32x32 (torch.float16)  |    125.9  |    140.9
         cumsum-dim1-128x128 (torch.float16)  |    116.4  |    129.4
         cumsum-dim1-512x512 (torch.float16)  |    135.9  |    150.1
       cumsum-dim1-1024x1024 (torch.float16)  |    139.5  |    154.2
               cumsum-1d-100 (torch.float16)  |    119.5  |    127.1
             cumsum-1d-10000 (torch.float16)  |    128.9  |    142.5
           cumsum-1d-1000000 (torch.float16)  |    140.6  |    145.6
           cumsum-dim0-32x32 (torch.float32)  |    115.7  |    132.5
         cumsum-dim0-128x128 (torch.float32)  |    118.0  |    131.5
         cumsum-dim0-512x512 (torch.float32)  |    138.8  |    151.6
       cumsum-dim0-1024x1024 (torch.float32)  |    155.5  |    164.2
           cumsum-dim1-32x32 (torch.float32)  |    127.2  |    141.7
         cumsum-dim1-128x128 (torch.float32)  |    117.7  |    130.5
         cumsum-dim1-512x512 (torch.float32)  |    138.2  |    152.3
       cumsum-dim1-1024x1024 (torch.float32)  |    144.4  |    158.6
               cumsum-1d-100 (torch.float32)  |    118.6  |    128.0
             cumsum-1d-10000 (torch.float32)  |    125.5  |    141.5
           cumsum-1d-1000000 (torch.float32)  |    143.9  |    158.4
          cumsum-dim0-32x32 (torch.bfloat16)  |    106.6  |    137.6
        cumsum-dim0-128x128 (torch.bfloat16)  |    118.1  |    131.0
        cumsum-dim0-512x512 (torch.bfloat16)  |    140.0  |    154.3
      cumsum-dim0-1024x1024 (torch.bfloat16)  |    153.2  |    164.4
          cumsum-dim1-32x32 (torch.bfloat16)  |    127.9  |    132.6
        cumsum-dim1-128x128 (torch.bfloat16)  |    116.5  |    129.6
        cumsum-dim1-512x512 (torch.bfloat16)  |    136.5  |    151.2
      cumsum-dim1-1024x1024 (torch.bfloat16)  |    139.8  |    144.8
              cumsum-1d-100 (torch.bfloat16)  |    115.7  |    129.4
            cumsum-1d-10000 (torch.bfloat16)  |    125.0  |    143.3
          cumsum-1d-1000000 (torch.bfloat16)  |    127.8  |    143.4

Times are in microseconds (us).
```

Current performance:
```
[-------------------------------- --------------------------------]
                                              |    eager  |  compile
1 threads: ---------------------------------------------------------
           cumsum-dim0-32x32 (torch.float16)  |    107.4  |    123.8
         cumsum-dim0-128x128 (torch.float16)  |    134.2  |    145.8
         cumsum-dim0-512x512 (torch.float16)  |    207.3  |    231.6
       cumsum-dim0-1024x1024 (torch.float16)  |    318.9  |    355.3
           cumsum-dim1-32x32 (torch.float16)  |     98.0  |    114.3
         cumsum-dim1-128x128 (torch.float16)  |    110.8  |    121.6
         cumsum-dim1-512x512 (torch.float16)  |    193.0  |    209.1
       cumsum-dim1-1024x1024 (torch.float16)  |    844.7  |    870.8
               cumsum-1d-100 (torch.float16)  |    108.4  |    125.0
             cumsum-1d-10000 (torch.float16)  |    784.7  |    852.3
           cumsum-1d-1000000 (torch.float16)  |  65855.2  |  66725.9
           cumsum-dim0-32x32 (torch.float32)  |    114.7  |    115.7
         cumsum-dim0-128x128 (torch.float32)  |    139.0  |    151.6
         cumsum-dim0-512x512 (torch.float32)  |    197.3  |    208.0
       cumsum-dim0-1024x1024 (torch.float32)  |    312.7  |    332.9
           cumsum-dim1-32x32 (torch.float32)  |     92.0  |    110.8
         cumsum-dim1-128x128 (torch.float32)  |    114.2  |    125.0
         cumsum-dim1-512x512 (torch.float32)  |    186.2  |    196.1
       cumsum-dim1-1024x1024 (torch.float32)  |    752.0  |    825.0
               cumsum-1d-100 (torch.float32)  |    112.4  |    122.0
             cumsum-1d-10000 (torch.float32)  |    793.5  |    863.5
           cumsum-1d-1000000 (torch.float32)  |  66431.8  |  66040.0
          cumsum-dim0-32x32 (torch.bfloat16)  |    111.6  |    121.6
        cumsum-dim0-128x128 (torch.bfloat16)  |    139.0  |    138.4
        cumsum-dim0-512x512 (torch.bfloat16)  |    217.6  |    230.1
      cumsum-dim0-1024x1024 (torch.bfloat16)  |    305.2  |    325.6
          cumsum-dim1-32x32 (torch.bfloat16)  |    100.5  |    110.9
        cumsum-dim1-128x128 (torch.bfloat16)  |    112.8  |    125.0
        cumsum-dim1-512x512 (torch.bfloat16)  |    187.8  |    208.9
      cumsum-dim1-1024x1024 (torch.bfloat16)  |    790.9  |    864.7
              cumsum-1d-100 (torch.bfloat16)  |    111.6  |    124.6
            cumsum-1d-10000 (torch.bfloat16)  |    778.1  |    844.9
          cumsum-1d-1000000 (torch.bfloat16)  |  64654.3  |  64082.5

Times are in microseconds (us).
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156241
Approved by: https://github.com/malfet
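For context, the rows above are produced by the `bench_scan` routine in the benchmark script below, which times `torch.cumsum` both in eager mode and under `torch.compile` on the MPS device. The snippet here is only a minimal sketch of one such measurement, assuming a machine where the MPS backend is available; the full sweep over dims, sizes, and dtypes is done by the script itself.

```python
# Minimal sketch of a single cumsum timing, mirroring bench_scan below.
# Assumes an Apple Silicon machine with the MPS backend available.
import timeit

import torch
from torch.utils.benchmark import Timer


def cumsum_dim0(t):
    return torch.cumsum(t, dim=0)


cumsum_dim0_compiled = torch.compile(cumsum_dim0, dynamic=False)

x = torch.testing.make_tensor(1024, 1024, device="mps", dtype=torch.float32)

for label, fn in (("eager", cumsum_dim0), ("compile", cumsum_dim0_compiled)):
    m = Timer(
        # torch.mps.synchronize() ensures the GPU work is included in the timing
        stmt="f(x); torch.mps.synchronize()",
        globals={"f": fn, "x": x},
        timer=timeit.default_timer,
        sub_label="cumsum-dim0-1024x1024 (torch.float32)",
        description=label,
    ).blocked_autorange()
    print(m)
```

The script below wraps exactly this pattern for unary, binary, reduction, and scan ops and prints the results as `Compare` tables like the ones in this commit message.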
186 lines · 6.2 KiB · Python
# Owner(s): ["module: mps"]
# Collection of op level benchmarks for MPS
# Useful as reference tool when migrating ops from MPS to Metal
import itertools
import timeit
import warnings
from typing import Optional

import torch
from torch.utils.benchmark import Compare, Measurement, Timer

def bench_unary_op(func, x, label) -> Measurement:
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x);{sync_cmd}",
        globals={"f": func, "x": x},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()

def bench_binary_op(func, x, y, label) -> Measurement:
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x, y);{sync_cmd}",
        globals={"f": func, "x": x, "y": y},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)}, {str(y.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()

def bench_unary(
    unary_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    x = torch.testing.make_tensor(1024, 1024, device=device, dtype=dtype)
    x_s = torch.testing.make_tensor(1024, 2048, device=device, dtype=dtype)[::, ::2]
    rc = []
    rc.append(bench_unary_op(unary_func, x, "dense"))
    rc.append(bench_unary_op(unary_func, x.t(), "transposed"))
    rc.append(bench_unary_op(unary_func, x_s, "strided"))
    rc.append(bench_unary_op(unary_func, x_s.t(), "strided + transposed"))
    return rc

def bench_binary(
    binary_func,
    device: str = "mps",
    dt_a: torch.dtype = torch.float32,
    dt_b: Optional[torch.dtype] = None,
) -> list[Measurement]:
    dt_b = dt_b if dt_b is not None else dt_a
    x = torch.testing.make_tensor(1024, 1024, device=device, dtype=dt_a)
    y = torch.testing.make_tensor(1024, 1024, device=device, dtype=dt_b)
    s = torch.testing.make_tensor((), device=device, dtype=dt_b)
    rc = []
    rc.append(bench_binary_op(binary_func, x, y, "dense-dense"))
    rc.append(bench_binary_op(binary_func, x.t(), y.t(), "transp-transp"))
    rc.append(bench_binary_op(binary_func, x, y.t(), "dense-transp"))
    rc.append(bench_binary_op(binary_func, x.t(), y, "transp-dense"))
    rc.append(bench_binary_op(binary_func, x, s, "dense-scalar"))
    rc.append(bench_binary_op(binary_func, x, y[0], "dense-bcast"))
    return rc

def bench_reduction(
    reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    rc = []

    # Bench 2D with reduction over dim=0
    def f(t):
        return reduction_func(t, dim=0)

    f.__name__ = reduction_func.__name__
    f_c = torch.compile(f, dynamic=False)

    for size in (512, 1024, 2048, 4096):
        x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
        rc_c, rc_e = f(x), f_c(x)
        rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
        if not torch.allclose(rc_c, rc_e):
            mdiff = (rc_c - rc_e).abs().max()
            warnings.warn(
                f"Eager and compile reduction do not match for {reduction_func.__name__} and {dtype} max_diff={mdiff}",
                stacklevel=2,
            )
        rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
        rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
    return rc

def bench_scan(
    scan_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    rc = []

    # Bench cumsum along different dimensions
    for dim in [0, 1]:

        def f(t):
            return scan_func(t, dim=dim)

        f_c = torch.compile(f, dynamic=False)

        for size in (32, 128, 512, 1024):
            f.__name__ = f"{scan_func.__name__}-dim{dim}-{size}x{size}"
            f_c.__name__ = f.__name__
            x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
            rc_c, rc_e = f(x), f_c(x)
            if not torch.allclose(rc_c, rc_e):
                mdiff = (rc_c - rc_e).abs().max()
                warnings.warn(
                    f"Eager and compile scan do not match for {scan_func.__name__} dim={dim} and {dtype} max_diff={mdiff}",
                    stacklevel=2,
                )
            rc.append(bench_unary_op(f, x, "eager"))
            rc.append(bench_unary_op(f_c, x, "compile"))

    # Bench 1D cumsum for different sizes
    def f_1d(t):
        return scan_func(t, dim=0)

    f_1d_c = torch.compile(f_1d, dynamic=False)

    for size in (100, 10000, 1000000):
        f_1d.__name__ = f"{scan_func.__name__}-1d-{size}"
        f_1d_c.__name__ = f_1d.__name__
        x = torch.testing.make_tensor(size, device=device, dtype=dtype)
        rc_c, rc_e = f_1d(x), f_1d_c(x)
        if not torch.allclose(rc_c, rc_e):
            mdiff = (rc_c - rc_e).abs().max()
            warnings.warn(
                f"Eager and compile 1D scan do not match for {scan_func.__name__} and {dtype} max_diff={mdiff}",
                stacklevel=2,
            )
        rc.append(bench_unary_op(f_1d, x, "eager"))
        rc.append(bench_unary_op(f_1d_c, x, "compile"))

    return rc

def main() -> None:
    dtypes = [torch.float16, torch.float32]
    if torch.backends.mps.is_macos_or_newer(14, 0):
        dtypes.append(torch.bfloat16)

    # Profile unary ops
    rc = []
    for op, dtype in itertools.product([torch.sqrt, torch.sin], dtypes):
        rc.extend(bench_unary(op, dtype=dtype))
    Compare(rc).print()

    # Profile reduction ops
    rc = []
    for op in [torch.sum, torch.max]:
        rc.extend(bench_reduction(op))
    Compare(rc).print()

    # Profile scan ops (cumsum)
    rc = []
    for dtype in dtypes:
        rc.extend(bench_scan(torch.cumsum, dtype=dtype))
    Compare(rc).print()

    # Profile binary ops
    rc = []
    ops = [torch.fmax, torch.add]
    for op, dtype in itertools.product(ops, dtypes):
        rc.extend(bench_binary(op, dt_a=dtype))
        if dtype == torch.float32:
            rc.extend(bench_binary(op, dt_b=torch.float16))
    Compare(rc).print()


if __name__ == "__main__":
    main()