The online_softmax_reduce runtime helper previously assumed the input tl.Tensor objects were 2D tensors, but with tiled reduction they can be 3D (y, x, r). Pull Request resolved: https://github.com/pytorch/pytorch/pull/162341. Approved by: https://github.com/jansel, https://github.com/eellison. ghstack dependencies: #162311
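For context, a minimal sketch of the combine step such a helper performs when merging two partial (running max, running sum of exponentials) pairs. The name combine_online_softmax is hypothetical and this is not Inductor's actual helper; the point is that the arithmetic is purely elementwise, so the same rule applies whether the partial tensors are 2D (y, r) or 3D (y, x, r):

    import torch

    def combine_online_softmax(m1, s1, m2, s2):
        # Rescale each partial sum to the shared max, then add.
        # Purely elementwise, so tensor rank does not matter.
        m = torch.maximum(m1, m2)
        s = s1 * (m1 - m).exp() + s2 * (m2 - m).exp()
        return m, s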
# Owner(s): ["module: inductor"]

import math
import os

import torch
import torch._inductor.config as inductor_config
import torch.nn.functional as F
from torch._dynamo.utils import rmse, same
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_LINUX,
    parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON


DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
USE_LARGE_INPUT = os.environ.get("USE_LARGE_INPUT") == "1" or DO_PERF_TEST
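# Both switches are read straight from the environment; an illustrative
# invocation (the exact test-file path depends on your checkout) would be
#   DO_PERF_TEST=1 python test_online_softmax.py
# which enables the do_bench timing code below and also implies USE_LARGE_INPUT.
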
def _prepare_softmax(x, dim):
    xmax = x.amax(dim=dim, keepdim=True)
    xsum = (x - xmax).exp().sum(dim=dim, keepdim=True)
    return xmax, xsum

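# For comparison, a rough single-pass ("online") reference for the same pair of
# values that _prepare_softmax computes above in two passes. This is only an
# illustrative sketch (the name is made up here), not the helper Inductor
# generates: keep a running max and rescale the running sum whenever that max
# grows, so max and sum come out of a single sweep over `dim`. This naive
# version would also hit the '-inf' corner case exercised in
# test_tb_speech_transformer_attn below, which Inductor handles specially.
def _online_softmax_reference(x, dim):
    chunks = x.unbind(dim=dim)
    xmax = torch.full_like(chunks[0], float("-inf"))
    xsum = torch.zeros_like(chunks[0])
    for chunk in chunks:
        new_max = torch.maximum(xmax, chunk)
        xsum = xsum * (xmax - new_max).exp() + (chunk - new_max).exp()
        xmax = new_max
    return xmax.unsqueeze(dim), xsum.unsqueeze(dim)

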
class TestOnlineSoftmax(TestCase):
    def do_test_acc_and_perf(self, op):
        if DO_PERF_TEST:
            N = 32 * 1024
            V = 50304  # padded version for gpt2
        else:
            N, V = 1024, 2048  # small value to avoid OOM in CI

        def f(x):
            return op(x, dim=-1)

        x = torch.randn(N, V, dtype=torch.bfloat16, device=GPU_TYPE)
        opt_f = torch.compile(f)
        expected = f(x)
        actual = opt_f(x)

        self.assertTrue(same(expected, actual, tol=1e-2))

        if DO_PERF_TEST:
            from triton.testing import do_bench

            eager_ms = do_bench(lambda: f(x))
            opt_ms = do_bench(lambda: opt_f(x))
            print(f"{eager_ms=}")
            print(f"{opt_ms=}")

    def test_softmax(self):
        self.do_test_acc_and_perf(torch.softmax)

    def test_log_softmax(self):
        self.do_test_acc_and_perf(torch.log_softmax)

    @inductor_config.patch(use_fast_math=True)
    def test_prepare_softmax_perf(self):
        self.do_test_acc_and_perf(_prepare_softmax)

    def get_softmax_wrapper(self, V=50304, use_log_softmax=False, device=GPU_TYPE):
        N = 32 * 1024

        @torch.compile
        def f(x):
            if use_log_softmax:
                return torch.log_softmax(x, dim=-1)
            else:
                return torch.softmax(x, dim=-1)

        x = torch.randn(N, V, dtype=torch.bfloat16, device=device)
        out, source_codes = run_and_get_code(f, x)
        return source_codes[0]

    def test_codegen_3pass_softmax_due_to_disable(self):
        with inductor_config.patch(online_softmax=False):
            wrapper_code = self.get_softmax_wrapper()

        self.assertEqual(wrapper_code.count("for r0_offset in"), 3)

    @parametrize("V", [2048, 50304])
    @parametrize("use_log_softmax", [False, True])
    def test_codegen_online_softmax(self, use_log_softmax, V):
        wrapper_code = self.get_softmax_wrapper(use_log_softmax=use_log_softmax, V=V)

        self.assertEqual(wrapper_code.count("for r0_offset in"), 2)

    def test_no_online_softmax_for_cpu(self):
        code = self.get_softmax_wrapper(V=2048, device="cpu")

        # CPU needs an explicit loop across the different rows.
        # For GPU, this is parallelized by the hardware.
        self.assertEqual(code.count("for(int64_t"), 4)

    def test_codegen_softmax_persistent_reduction(self):
        """
        Persistent reduction has no for loops.
        """
        wrapper_code = self.get_softmax_wrapper(1024)
        self.assertEqual(wrapper_code.count("for r0_offset in"), 0)

    @inductor_config.patch("triton.persistent_reductions", False)
    def test_sdpa(self):
        """
        Make sure online softmax here does not conflict with the sdpa
        patterns.
        """
        q, k, v = (
            torch.randn((4, 2, 16, 32), device=GPU_TYPE, dtype=torch.bfloat16)
            for _ in range(3)
        )

        def f(q, k, v):
            return (
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(k.shape[-1]))
                .softmax(dim=-1)
                .matmul(v)
            )

        opt_f = torch.compile(f)
        ref = f(q, k, v)
        act, (code,) = run_and_get_code(opt_f, q, k, v)
        self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2))
        self.assertTrue("aten._scaled_dot_product_" in code)

    @parametrize("nrow", [2, 2048])
    @parametrize("dim", [-1, 0, 1])
    def test_prepare_softmax(self, dim, nrow):
        x = torch.randn(nrow, 2048, dtype=torch.bfloat16, device=GPU_TYPE)
        act, (code,) = run_and_get_code(torch.compile(_prepare_softmax), x, dim)
        ref = _prepare_softmax(x, dim)
        self.assertTrue(same(ref, act, tol=1e-2))

        if nrow == 2048 and dim == 0:
            # Split reduction is triggered, so we get multiple kernels.
            self.assertTrue(code.count("def triton") >= 2)
        else:
            if nrow == 2 and dim == 0:
                # persistent reduction triggered
                expected_num_loop = 0
            else:
                # A single loop due to online softmax
                expected_num_loop = 1
            self.assertEqual(code.count("for r0_offset in"), expected_num_loop)

    def test_split_reduction(self):
        """
        We don't split online_softmax_reduce for now. Check the
        'Split online_softmax_reduce' note in the code.

        When a split looks promising, we fall back for now.

        This is just a manual example rather than something we
        see in practice.
        """
        # tensor shape that triggers split reduction
        x = torch.randn(1, 2**20, dtype=torch.bfloat16, device=GPU_TYPE)
        ref = torch.softmax(x, dim=-1)
        act, (code,) = run_and_get_code(torch.compile(torch.softmax), x, dim=-1)
        self.assertTrue(torch.allclose(ref, act, atol=1e-3, rtol=1e-3))
        self.assertTrue(code.count("def triton") >= 2)
        self.assertTrue("online_softmax_reduce" not in code)

    @parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
    def test_prepare_softmax_acc_with_fp64(self, dtype):
        if USE_LARGE_INPUT:
            M, N = 32768, 50257
        else:
            M, N = 1024, 2048

        x = torch.randn(M, N, device=GPU_TYPE, dtype=dtype)

        ref_fp64 = _prepare_softmax(x.to(dtype=torch.float64), dim=-1)
        ref = _prepare_softmax(x, dim=-1)
        res, (code,) = run_and_get_code(torch.compile(_prepare_softmax), x, dim=-1)
        self.assertTrue("online_softmax_reduce" in code)

        # Max should be exactly equal
        self.assertEqual(ref[0], res[0])
        self.assertEqual(ref[0].to(dtype=torch.float64), ref_fp64[0])

        ref_error = rmse(ref_fp64[1], ref[1]).item()
        res_error = rmse(ref_fp64[1], res[1]).item()

        # My local tests even show a slightly smaller res_error:
        #   ref_error=2.1065, res_error=2.1028 for bf16
        #   ref_error=0.2611, res_error=0.2609 for fp16
        #   ref_error=0.0001, res_error=0.0001 for fp32
        print(f"{ref_error=:.4f}, {res_error=:.4f}")

        self.assertTrue(
            res_error < ref_error + 0.1
        )  # Is this good enough to keep CI stable?

    @parametrize("fn", [torch.log_softmax, torch.softmax])
    @parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
    def test_softmax_acc_with_fp64(self, dtype, fn):
        if USE_LARGE_INPUT:
            M, N = 32768, 50257
        else:
            M, N = 1024, 2048

        x = torch.randn(M, N, device=GPU_TYPE, dtype=dtype)

        ref_fp64 = fn(x.to(dtype=torch.float64), dim=-1)
        ref = fn(x, dim=-1)
        res, (code,) = run_and_get_code(torch.compile(fn), x, dim=-1)
        self.assertTrue("online_softmax_reduce" in code)

        ref_error = rmse(ref_fp64, ref).item()
        res_error = rmse(ref_fp64, res).item()

        # For torch.softmax, ref_error/res_error are almost 0 for all 3 dtypes,
        # because each value is very small since each row adds up to 1.0.
        #
        # For torch.log_softmax:
        #   ref_error=0.0180399032, res_error=0.0180399031 for bf16
        #   ref_error=0.0022548872, res_error=0.0022548872 for fp16
        #   ref_error=0.0000003744, res_error=0.0000003748 for fp32
        print(f"{ref_error=:.10f}, {res_error=:.10f}")

        self.assertTrue(
            res_error < ref_error + 0.1
        )  # Is this good enough to keep CI stable?

    def test_softmin(self):
        """
        The rnumel==1 kind of reduction should be unrolled.
        """

        def f(x):
            return F.softmin(x, dim=0)

        x = torch.randn(1, device=GPU_TYPE)
        ref = f(x)
        act, (code,) = run_and_get_code(torch.compile(f), x)
        self.assertTrue(torch.allclose(ref, act))
        self.assertTrue("online_softmax_reduce" not in code)

    def test_causal_mask(self):
        def f(x):
            return x.softmax(dim=-1)

        x = torch.randn(2048, 2048, device=GPU_TYPE)
        mask = torch.tril(torch.ones(2048, 2048, device=GPU_TYPE))
        x.masked_fill_(mask == 0, float("-inf"))

        ref = f(x)
        act = torch.compile(f)(x)
        self.assertTrue(not ref.isnan().any())
        self.assertTrue(not act.isnan().any())
        self.assertTrue(torch.allclose(ref, act))

    def test_tb_speech_transformer_attn(self):
        """
        This is an example extracted from speech_transformer.
        Since online softmax uses the max over a partial block of a row,
        if the input contains '-inf', it's possible that the max of that
        partial block is '-inf' even though the entire row contains
        non-'-inf' values. In this case, online softmax would end up
        computing things like 'float(-inf) - float(-inf)', which becomes
        'nan'. We fixed this by interpreting 'float(-inf) - float(-inf)'
        as 0 when both operands are 'float(-inf)'.
        """
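        # A plain-Python sketch of the failure mode described above (not the
        # generated kernel): for a partial block that is entirely -inf, the
        # running max is -inf, so the rescale exponent would be
        #     float("-inf") - float("-inf")   # nan, and exp(nan) is nan
        # The fix substitutes 0.0 in that case (exp(0.0) == 1.0); that partial
        # sum is later scaled by exp(-inf - finite_max) == 0.0 once a finite
        # max is seen, so the final result is unaffected.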
        torch.manual_seed(1337)

        def f(x, mask):
            x = torch.where(mask, float("-inf"), x)
            xmax = x.amax(dim=-1, keepdim=True)
            xsum = (x - xmax).exp().sum(dim=-1, keepdim=True)
            return xsum

        x = torch.randn(8, 10, 22, 204, device=GPU_TYPE)
        mask = torch.randint(0, 2, (10, 204), device=GPU_TYPE) == 0
        mask = mask.view(1, 10, 1, 204)

        ref = f(x, mask)
        act = torch.compile(f)(x, mask)
        self.assertTrue(not ref.isnan().any())
        self.assertTrue(not act.isnan().any())
        self.assertTrue(torch.allclose(ref, act))

    @inductor_config.patch(split_reductions=False)
    def test_3d_tiled_online_softmax(self):
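        # The mismatched permutes below give x and y different memory layouts,
        # which is presumably what pushes Inductor into a tiled reduction here,
        # so the online-softmax helper sees 3D (y, x, r) blocks rather than the
        # 2D blocks it originally assumed (see the PR description at the top).
        # split_reductions is patched off above, presumably so the
        # online-softmax codegen path is exercised instead of a
        # split-reduction fallback.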
        def f(x, y):
            return (x * y).softmax(dim=-1)

        M, N, K = 32, 8, 1024

        x = torch.randn(K, N, M, device=GPU_TYPE).permute(2, 1, 0)
        y = torch.randn(K, M, N, device=GPU_TYPE).permute(1, 2, 0)

        opt_f = torch.compile(f)
        torch.testing.assert_close(f(x, y), opt_f(x, y), atol=1e-3, rtol=1e-3)


instantiate_parametrized_tests(TestOnlineSoftmax)

if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA_AND_TRITON:
        run_tests()