import timeit
from functools import partial

import numpy as np
import pandas as pd
import torch

from functorch.compile import pointwise_operator

WRITE_CSV = False
CUDA = False
SIZES = [1, 512, 8192]
NUMBER = [100, 10, 1, 1]
REPEAT = 20


@pointwise_operator
def nnc_add(a, b):
    return a + b


@pointwise_operator
def nnc_addnorm(a, b, mean, std):
    return (a + b - mean) / std


def eager_addnorm(a, b, mean, std):
    return (a + b - mean) / std


def inplace_addnorm(a, b, mean, std, out):
    out = torch.add(a, b, out=out)
    torch.sub(out, mean, out=out)
    torch.div(out, std, out=out)
    return out


ts_addnorm = torch.jit.script(eager_addnorm)
ts_ip_addnorm = torch.jit.script(inplace_addnorm)


def maybe_synced(fn):
    # On CUDA, wrap fn so every timed call synchronizes the device.
    if CUDA:
        synchronize = torch.cuda.synchronize
        synchronize()  # warmup

        def _fn():
            result = fn()
            synchronize()
            return result

        return _fn
    return fn


def benchmark_loop(setup):
    # Time the NNC and ATen variants for each size, take the median over
    # REPEAT runs, and return the ATen/NNC time ratio (NNC speedup) per size.
    result = np.zeros((REPEAT, len(SIZES), 2), dtype=np.float64)
    for s, n in enumerate(SIZES):
        nnc, aten = setup(n)
        nnc = maybe_synced(nnc)
        aten = maybe_synced(aten)

        for r in range(result.shape[0]):
            result[r, s, 0] = timeit.timeit(nnc, number=NUMBER[s])
            result[r, s, 1] = timeit.timeit(aten, number=NUMBER[s])

    result = np.median(result, axis=0)
    assert result.shape == (len(SIZES), 2)
    result = result[:, 1] / result[:, 0]
    print(result)
    return result


def test(make_args, nnc=nnc_add, aten=torch.add):
    def setup(n):
        args = make_args(n)
        result_aten = aten(*args)
        result_nnc = nnc(*args)
        assert result_nnc.dtype == result_aten.dtype
        assert result_nnc.size() == result_aten.size()
        assert result_nnc.stride() == result_aten.stride()
        torch.testing.assert_close(result_aten, result_nnc)
        return (lambda: nnc(*args), lambda: aten(*args))

    return benchmark_loop(setup)


def test_inplace(make_args, nnc=nnc_add, aten=torch.add):
    def inplace_setup(n):
        a, b = make_args(n)
        result_aten = torch.clone(a)
        result_nnc = torch.clone(a)
        nnc(result_nnc, b, out=result_nnc)
        aten(result_aten, b, out=result_aten)
        torch.testing.assert_close(result_aten, result_nnc)
        return (lambda: nnc(a, b, out=a), lambda: aten(a, b, out=a))

    return benchmark_loop(inplace_setup)


def test_out(make_args, out, nnc=nnc_add, aten=torch.add):
    def out_setup(n):
        args = make_args(n)
        result_aten = out(n)
        result_nnc = out(n)
        aten(*args, out=result_aten)
        nnc(*args, out=result_nnc)
        torch.testing.assert_close(result_aten, result_nnc)
        result = out(n)
        return (lambda: nnc(*args, out=result), lambda: aten(*args, out=result))

    return benchmark_loop(out_setup)


def test_backwards(make_args, nnc=nnc_add, aten=torch.add):
    def backwards_setup(n):
        args = make_args(n)
        (grad_var,) = (a for a in args if a.requires_grad)
        aten(*args).sum().backward()
        correct = grad_var.grad.clone()
        grad_var.grad.zero_()
        nnc(*args).sum().backward()
        torch.testing.assert_close(correct, grad_var.grad)
        return (
            lambda: nnc(*args).sum().backward(),
            lambda: aten(*args).sum().backward(),
        )

    return benchmark_loop(backwards_setup)


def main():
    torch.set_num_threads(1)  # TODO(jansel): add parallel support
    torch._C._jit_override_can_fuse_on_cpu(True)
    device = "cuda" if CUDA else "cpu"
    I = partial(torch.randint, 0, 100, device=device)
    R = partial(torch.randn, device=device)

    results = [
        ("add", test(lambda n: (R(n, n), R(n, n)))),
        ("broadcast1", test(lambda n: (R(n, n), R(1)))),
        ("broadcast2", test(lambda n: (R(n, n), R(n, 1)))),
        ("broadcast3", test(lambda n: (R(n, 1), R(1, n)))),
        ("inplace", test_inplace(lambda n: (R(n, n), R(n, 1)))),
        ("out=", test_out(lambda n: (R(n, n), R(n, n)), out=lambda n: R(n, n))),
        ("transposed1", test(lambda n: (R(n, n), R(n, n).transpose(0, 1)))),
        (
            "transposed2",
            test(lambda n: (R(n, n).transpose(0, 1), R(n, n).transpose(0, 1))),
        ),
        ("slice1", test(lambda n: (R(n + 1, n + 1, 2)[:n, :n, 0], R(n, n)))),
        ("slice2", test(lambda n: (R(n, n, 2)[:, :, 0], R(n, n, 2)[:, :, 0]))),
        (
            "strided out",
            test_out(
                lambda n: (R(n, n), R(n, n)),
                out=lambda n: R(n + 1, n + 1, 2)[:n, :n, 0],
            ),
        ),
        (
            "out convert",
            test_out(
                lambda n: (R(n, n), R(n, n)), out=lambda n: R(n, n, dtype=torch.float64)
            ),
        ),
        ("issue #57611 (n,32,32,2)", test(lambda n: (R(1, 32, 32, 2), R(n, 1, 1, 2)))),
        ("float+double", test(lambda n: (R(n, n), R(n, n, dtype=torch.float64)))),
        (
            "int+long",
            test(
                lambda n: (I([n, n], dtype=torch.int32), I([n, n], dtype=torch.int64))
            ),
        ),
        (
            "int+short",
            test(
                lambda n: (I([n, n], dtype=torch.int32), I([n, n], dtype=torch.int16))
            ),
        ),
        (
            "float+int",
            test(
                lambda n: (R([n, n], dtype=torch.float32), I([n, n], dtype=torch.int32))
            ),
        ),
        (
            "double+long",
            test(
                lambda n: (R([n, n], dtype=torch.float64), I([n, n], dtype=torch.int64))
            ),
        ),
        (
            "fused addnorm",
            test(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=eager_addnorm,
            ),
        ),
        (
            "fused addnorm (vs TS)",
            test(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_addnorm,
            ),
        ),
        (
            "fused addnorm out=",
            test_out(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=inplace_addnorm,
                out=lambda n: R(n, n),
            ),
        ),
        (
            "fused addnorm out= (vs TS)",
            test_out(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_ip_addnorm,
                out=lambda n: R(n, n),
            ),
        ),
        (
            "fused addnorm backward",
            test_backwards(
                lambda n: (R(n, n), R(n, n, requires_grad=True), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=eager_addnorm,
            ),
        ),
        (
            "fused addnorm backward (vs TS)",
            test_backwards(
                lambda n: (R(n, n), R(n, n, requires_grad=True), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_addnorm,
            ),
        ),
    ]

    df = pd.DataFrame(
        np.stack([r for n, r in results]),
        columns=[f"{n}x{n}".rjust(9) for n in SIZES],
        index=[n for n, r in results],
    )

    if WRITE_CSV:
        df.to_csv("../operator_authoring_results.csv")
        print("wrote ../operator_authoring_results.csv")

    print()
    print("Speedups over aten")
    pd.options.display.float_format = "{:.2f}x".format
    print(df)


if __name__ == "__main__":
    main()