mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Summary:
Benchmark: NVIDIA GTX 1650 + AMD Ryzen Threadripper 3970X
```python
import torch
print(torch.__version__)

for i in range(1000):
    torch.randn(1024 * 128, device='cuda')

def cuda(e):
    a = torch.randn(2 ** e, 32, device='cuda')
    s = torch.randn(32, device='cuda')
    z = torch.randn(32, device='cuda')
    torch.cuda.synchronize()
    %timeit torch.fake_quantize_per_channel_affine(a, s, z, 1, -999, 999); torch.cuda.synchronize()

def cpu(e):
    a = torch.randn(2 ** e, 32, device='cpu')
    s = torch.randn(32, device='cpu')
    z = torch.randn(32, device='cpu')
    %timeit torch.fake_quantize_per_channel_affine(a, s, z, 1, -999, 999);

for i in range(10, 24):
    cuda(i)
print()
for i in range(10, 32):
    cpu(i)
```
Before
```
1.5.0a0+9bc922d
849 µs ± 44.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
817 µs ± 30.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
814 µs ± 2.93 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.11 ms ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.19 ms ± 4.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.6 ms ± 5.58 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.44 ms ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.14 ms ± 2.55 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.41 ms ± 2.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
13.9 ms ± 2.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
26.9 ms ± 254 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
52.6 ms ± 260 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
104 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
207 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

249 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
420 µs ± 230 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
766 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.45 ms ± 574 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.84 ms ± 34.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.69 ms ± 83 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.29 ms ± 2.58 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.32 ms ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
17.4 ms ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
47.5 ms ± 264 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
187 ms ± 1.19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
379 ms ± 5.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
652 ms ± 11.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.22 s ± 4.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.34 s ± 8.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.56 s ± 7.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8.97 s ± 33.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
17.8 s ± 32.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
35.2 s ± 167 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
After
```
1.5.0a0+a7ec8cc
92.5 µs ± 2.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
97.7 µs ± 469 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
109 µs ± 4.73 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
119 µs ± 6.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
146 µs ± 1.84 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
211 µs ± 2.45 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
347 µs ± 4.18 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
624 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.17 ms ± 16.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.25 ms ± 48.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.43 ms ± 220 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
8.51 ms ± 44.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
16.9 ms ± 30.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
33.7 ms ± 7.64 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

201 µs ± 234 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
285 µs ± 465 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
287 µs ± 214 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
287 µs ± 221 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
287 µs ± 761 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
347 µs ± 399 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
675 µs ± 213 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.34 ms ± 643 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
4.82 ms ± 34.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.7 ms ± 88.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
20.3 ms ± 25.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
39.4 ms ± 242 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
78.8 ms ± 2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
153 ms ± 786 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
285 ms ± 911 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
541 ms ± 1.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.03 s ± 1.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.97 s ± 8.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.81 s ± 10.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
Fixes https://github.com/pytorch/pytorch/issues/33647

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33772

Differential Revision: D20112531

Pulled By: ngimel

fbshipit-source-id: f90e3ef1b5be8276851637f3e1251cb8f1af411f
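
Reading the Before and After timings as speedups: the sketch below divides a few of the reported means (Before / After) to make the improvement easier to see. It is only illustrative arithmetic on the numbers in this message, not part of the benchmark. The names `samples`, `speedup`, and `speedups` are hypothetical, and rows are paired by position (first with first, last with last), which assumes both runs printed rows in the same order.

```python
# Mean timings copied from the Before/After listings, converted to microseconds.
# Only the first and last reported row of each device section are included.
US, MS, S = 1.0, 1e3, 1e6  # unit multipliers relative to one microsecond

samples = {
    # label: (before, after)
    "cuda, first reported row": (849 * US, 92.5 * US),
    "cuda, last reported row": (207 * MS, 33.7 * MS),
    "cpu, first reported row": (249 * US, 201 * US),
    "cpu, last reported row": (35.2 * S, 3.81 * S),
}


def speedup(before_us, after_us):
    """Return how many times faster the After run is than the Before run."""
    return before_us / after_us


speedups = {label: speedup(b, a) for label, (b, a) in samples.items()}

for label, ratio in speedups.items():
    print(f"{label}: {ratio:.1f}x")
```

On these rows the ratio is roughly 9x and 6x for the first and last CUDA rows, and roughly 1.2x and 9x for the first and last CPU rows.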