**Problem:** Fusion can accumulate a large number of reads, which leads to a significant increase in peak memory utilization. Imagine we have the following code snippet:

```
total = torch.rand(N, N)
for _ in range(r):
    x = torch.rand(N, N)
    total = total + x
```

The default execution is memory efficient, as only two tensors of size N-by-N are in memory at any given time. With fusion, however, the additions are fused into a single operation and the execution becomes something like:

```
x_1 = torch.rand(N, N)
x_2 = torch.rand(N, N)
...
x_r = torch.rand(N, N)
total = x_1 + x_2 + ... + x_r
```

Though this is run-time efficient, for large `N` and/or large `r` it is not memory efficient.

[internal only] see [post](https://fb.workplace.com/groups/1075192433118967/permalink/1703374333634104/) for additional details

**Solution:** Our proposed solution is to ban fusions that would accumulate a large amount of reads. This is in addition to some existing logic during torch compile.

* During lowering (i.e., `ir.py`), the config `realize_acc_reads_threshold`, which defaults to 8, controls _the number of_ buffers that can be accumulated for a single operator. However, it is oblivious to the size of those buffers. Hence, we additionally introduce a config `realize_acc_reads_size_threshold` to control _the total size of_ the buffers that can be accumulated (a usage sketch follows this description).
* During scheduling (i.e., `scheduler.py`), additional fusion is performed, so we also need to capture this pattern there. The decisions are implemented under `choices.py`.

**Results:** For a small example similar to the one in the test case (but with a larger `N` and a higher number of loop repeats), the memory snapshots before and after are shown below. Note that the snapshot on the right is zoomed out so that the y-axes of the two snapshots match.

<img width="1328" alt="image" src="https://github.com/user-attachments/assets/670b5961-8454-4379-ae0f-62d4e7946c64" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157563
Approved by: https://github.com/jansel, https://github.com/mlazos
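As a rough illustration of how the new knob could be exercised, here is a minimal sketch. It assumes the config is exposed on `torch._inductor.config` in the same way the test below patches it; the toy `accumulate` model and the threshold value are made up for exposition and are not taken from the PR.

```python
import torch
import torch._inductor.config as inductor_config


def accumulate(r: int, n: int) -> torch.Tensor:
    # Toy version of the snippet above: a chain of adds whose inputs would
    # all be kept alive if Inductor fused them into one kernel.
    total = torch.rand(n, n)
    for _ in range(r):
        total = total + torch.rand(n, n)
    return total


# Setting realize_acc_reads_size_threshold to float("inf") disables the
# size-based check (mirroring the inductor_config.patch used in
# test_codegen_3pass_softmax_due_to_disable below); a finite value instead
# caps the total size of reads a single fused node may accumulate.
with inductor_config.patch(realize_acc_reads_size_threshold=float("inf")):
    out = torch.compile(accumulate)(16, 1024)
```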
# Owner(s): ["module: inductor"]

import math
import os

import torch
import torch._inductor.config as inductor_config
import torch.nn.functional as F
from torch._dynamo.utils import rmse, same
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_LINUX,
    parametrize,
    serialTest,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA


DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
USE_LARGE_INPUT = os.environ.get("USE_LARGE_INPUT") == "1" or DO_PERF_TEST


def _prepare_softmax(x, dim):
    # Reference two-pass formulation: first reduce for the row max, then
    # reduce again for the sum of exp(x - max).
    xmax = x.amax(dim=dim, keepdim=True)
    xsum = (x - xmax).exp().sum(dim=dim, keepdim=True)
    return xmax, xsum

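
# Illustration only (not used by the tests below): a minimal sketch of the
# merge step an online-softmax kernel performs when it combines two partial
# (max, sum-of-exp) pairs in a single pass over the data. The helper name and
# the exact formulation are assumptions for exposition; the real kernels are
# generated by Inductor. Note how the scale factor exp(m_i - m) is forced to 1
# whenever m_i == m, which also covers the '-inf' - '-inf' case discussed in
# test_tb_speech_transformer_attn below.
def _online_softmax_merge_sketch(m1, s1, m2, s2):
    m = torch.maximum(m1, m2)
    d1 = torch.where(m1 == m, torch.ones_like(s1), (m1 - m).exp())
    d2 = torch.where(m2 == m, torch.ones_like(s2), (m2 - m).exp())
    # The merged pair satisfies s == sum(exp(x - m)) over both chunks.
    return m, s1 * d1 + s2 * d2
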
class TestOnlineSoftmax(TestCase):
    def do_test_acc_and_perf(self, op):
        if DO_PERF_TEST:
            N = 32 * 1024
            V = 50304  # padded version for gpt2
        else:
            N, V = 1024, 2048  # small value to avoid OOM in CI

        def f(x):
            return op(x, dim=-1)

        x = torch.randn(N, V, dtype=torch.bfloat16, device=GPU_TYPE)
        opt_f = torch.compile(f)
        expected = f(x)
        actual = opt_f(x)

        self.assertTrue(same(expected, actual, tol=1e-2))

        if DO_PERF_TEST:
            from triton.testing import do_bench

            eager_ms = do_bench(lambda: f(x))
            opt_ms = do_bench(lambda: opt_f(x))
            print(f"{eager_ms=}")
            print(f"{opt_ms=}")

    def test_softmax(self):
        self.do_test_acc_and_perf(torch.softmax)

    def test_log_softmax(self):
        self.do_test_acc_and_perf(torch.log_softmax)

    @inductor_config.patch(use_fast_math=True)
    def test_prepare_softmax_perf(self):
        self.do_test_acc_and_perf(_prepare_softmax)

    def get_softmax_wrapper(self, V=50304, use_log_softmax=False, device=GPU_TYPE):
        N = 32 * 1024

        @torch.compile
        def f(x):
            if use_log_softmax:
                return torch.log_softmax(x, dim=-1)
            else:
                return torch.softmax(x, dim=-1)

        x = torch.randn(N, V, dtype=torch.bfloat16, device=device)
        out, source_codes = run_and_get_code(f, x)
        return source_codes[0]

    @serialTest()
    def test_codegen_3pass_softmax_due_to_disable(self):
        with inductor_config.patch(
            online_softmax=False,
            realize_acc_reads_size_threshold=float("inf"),
        ):
            wrapper_code = self.get_softmax_wrapper()

        self.assertEqual(wrapper_code.count("for r0_offset in"), 3)

    @serialTest()
    @parametrize("V", [2048, 50304])
    @parametrize("use_log_softmax", [False, True])
    def test_codegen_online_softmax(self, use_log_softmax, V):
        wrapper_code = self.get_softmax_wrapper(use_log_softmax=use_log_softmax, V=V)

        self.assertEqual(wrapper_code.count("for r0_offset in"), 2)

    def test_no_online_softmax_for_cpu(self):
        code = self.get_softmax_wrapper(V=2048, device="cpu")

        # CPU needs an explicit loop across different rows.
        # For GPU, this is parallelized by the hardware.
        self.assertEqual(code.count("for(int64_t"), 4)

    def test_codegen_softmax_persistent_reduction(self):
        """
        Persistent reduction has no for loops.
        """
        wrapper_code = self.get_softmax_wrapper(1024)
        self.assertEqual(wrapper_code.count("for r0_offset in"), 0)

    @inductor_config.patch("triton.persistent_reductions", False)
    def test_sdpa(self):
        """
        Make sure online softmax here does not conflict with the sdpa
        patterns.
        """
        q, k, v = (
            torch.randn((4, 2, 16, 32), device=GPU_TYPE, dtype=torch.bfloat16)
            for _ in range(3)
        )

        def f(q, k, v):
            return (
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(k.shape[-1]))
                .softmax(dim=-1)
                .matmul(v)
            )

        opt_f = torch.compile(f)
        ref = f(q, k, v)
        act, (code,) = run_and_get_code(opt_f, q, k, v)
        self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2))
        self.assertTrue("aten._scaled_dot_product_" in code)

@parametrize("nrow", [2, 2048])
|
|
@parametrize("dim", [-1, 0, 1])
|
|
def test_prepare_softmax(self, dim, nrow):
|
|
x = torch.randn(nrow, 2048, dtype=torch.bfloat16, device=GPU_TYPE)
|
|
act, (code,) = run_and_get_code(torch.compile(_prepare_softmax), x, dim)
|
|
ref = _prepare_softmax(x, dim)
|
|
self.assertTrue(same(ref, act, tol=1e-2))
|
|
|
|
if nrow == 2048 and dim == 0:
|
|
# split reduction is triggered. We have multiple kernels
|
|
self.assertTrue(code.count("def triton") >= 2)
|
|
else:
|
|
if nrow == 2 and dim == 0:
|
|
# persistent reduction triggered
|
|
expected_num_loop = 0
|
|
else:
|
|
# A single loop due to online softmax
|
|
expected_num_loop = 1
|
|
self.assertEqual(code.count("for r0_offset in"), expected_num_loop)
|
|
|
|
def test_split_reduction(self):
|
|
"""
|
|
We don't split online_softmax_reduce for now. Check
|
|
'Split online_softmax_reduce' note in the code.
|
|
|
|
When a split is promsing, we fallback for now.
|
|
|
|
This is just a manual example rather than something we
|
|
see in practice.
|
|
"""
|
|
# tensor shape to trigger split reduction
|
|
x = torch.randn(1, 2**20, dtype=torch.bfloat16, device=GPU_TYPE)
|
|
ref = torch.softmax(x, dim=-1)
|
|
act, (code,) = run_and_get_code(torch.compile(torch.softmax), x, dim=-1)
|
|
self.assertTrue(torch.allclose(ref, act, atol=1e-3, rtol=1e-3))
|
|
self.assertTrue(code.count("def triton") >= 2)
|
|
self.assertTrue("online_softmax_reduce" not in code)
|
|
|
|
@parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
|
|
def test_prepare_softmax_acc_with_fp64(self, dtype):
|
|
if USE_LARGE_INPUT:
|
|
M, N = 32768, 50257
|
|
else:
|
|
M, N = 1024, 2048
|
|
|
|
x = torch.randn(M, N, device=GPU_TYPE, dtype=dtype)
|
|
|
|
ref_fp64 = _prepare_softmax(x.to(dtype=torch.float64), dim=-1)
|
|
ref = _prepare_softmax(x, dim=-1)
|
|
res, (code,) = run_and_get_code(torch.compile(_prepare_softmax), x, dim=-1)
|
|
self.assertTrue("online_softmax_reduce" in code)
|
|
|
|
# Max should be exactly equal
|
|
self.assertEqual(ref[0], res[0])
|
|
self.assertEqual(ref[0].to(dtype=torch.float64), ref_fp64[0])
|
|
|
|
ref_error = rmse(ref_fp64[1], ref[1]).item()
|
|
res_error = rmse(ref_fp64[1], res[1]).item()
|
|
|
|
# My local tests even shows a smaller res_error:
|
|
# ref_error=2.1065, res_error=2.1028
|
|
# for bf16
|
|
# ref_error=0.2611, res_error=0.2609
|
|
# for fp16
|
|
# ref_error=0.0001, res_error=0.0001
|
|
# for fp32
|
|
print(f"{ref_error=:.4f}, {res_error=:.4f}")
|
|
|
|
self.assertTrue(
|
|
res_error < ref_error + 0.1
|
|
) # Is this good enough to make CI stable
|
|
|
|
@parametrize("fn", [torch.log_softmax, torch.softmax])
|
|
@parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
|
|
def test_softmax_acc_with_fp64(self, dtype, fn):
|
|
if USE_LARGE_INPUT:
|
|
M, N = 32768, 50257
|
|
else:
|
|
M, N = 1024, 2048
|
|
|
|
x = torch.randn(M, N, device=GPU_TYPE, dtype=dtype)
|
|
|
|
ref_fp64 = fn(x.to(dtype=torch.float64), dim=-1)
|
|
ref = fn(x, dim=-1)
|
|
res, (code,) = run_and_get_code(torch.compile(fn), x, dim=-1)
|
|
self.assertTrue("online_softmax_reduce" in code)
|
|
|
|
ref_error = rmse(ref_fp64, ref).item()
|
|
res_error = rmse(ref_fp64, res).item()
|
|
|
|
# For torch.softmax,
|
|
# I get almost 0 for ref_error/res_error for all 3 dtypes. It's because
|
|
# each value is very small since each row add up to 1.0
|
|
#
|
|
# For torch.log_softmax
|
|
# ref_error=0.0180399032, res_error=0.0180399031
|
|
# for bf16
|
|
# ref_error=0.0022548872, res_error=0.0022548872
|
|
# for fp16
|
|
# ref_error=0.0000003744, res_error=0.0000003748
|
|
# for fp32
|
|
print(f"{ref_error=:.10f}, {res_error=:.10f}")
|
|
|
|
self.assertTrue(
|
|
res_error < ref_error + 0.1
|
|
) # Is this good enough to make CI stable
|
|
|
|
    def test_softmin(self):
        """
        The rnumel==1 kind of reduction should be unrolled.
        """

        def f(x):
            return F.softmin(x, dim=0)

        x = torch.randn(1, device=GPU_TYPE)
        ref = f(x)
        act, (code,) = run_and_get_code(torch.compile(f), x)
        self.assertTrue(torch.allclose(ref, act))
        self.assertTrue("online_softmax_reduce" not in code)

    def test_causal_mask(self):
        def f(x):
            return x.softmax(dim=-1)

        x = torch.randn(2048, 2048, device=GPU_TYPE)
        mask = torch.tril(torch.ones(2048, 2048, device=GPU_TYPE))
        x.masked_fill_(mask == 0, float("-inf"))

        ref = f(x)
        act = torch.compile(f)(x)
        self.assertTrue(not ref.isnan().any())
        self.assertTrue(not act.isnan().any())
        self.assertTrue(torch.allclose(ref, act))

    def test_tb_speech_transformer_attn(self):
        """
        This is an example extracted from speech_transformer.
        Since online softmax uses the max of only a partial slice of an
        entire row, if the input contains '-inf' it is possible that the
        max of those partial elements is '-inf' even though the entire row
        contains non-'-inf' values. In that case, online softmax would need
        to compute 'float(-inf) - float(-inf)', which becomes 'nan'.
        We fixed this by interpreting 'float(-inf) - float(-inf)' as 0
        when both operands are 'float(-inf)'.
        """
        torch.manual_seed(1337)

        def f(x, mask):
            x = torch.where(mask, float("-inf"), x)
            xmax = x.amax(dim=-1, keepdim=True)
            xsum = (x - xmax).exp().sum(dim=-1, keepdim=True)
            return xsum

        x = torch.randn(8, 10, 22, 204, device=GPU_TYPE)
        mask = torch.randint(0, 2, (10, 204), device=GPU_TYPE) == 0
        mask = mask.view(1, 10, 1, 204)

        ref = f(x, mask)
        act = torch.compile(f)(x, mask)
        self.assertTrue(not ref.isnan().any())
        self.assertTrue(not act.isnan().any())
        self.assertTrue(torch.allclose(ref, act))


instantiate_parametrized_tests(TestOnlineSoftmax)

if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA:
        run_tests()