From 46d1422afd3ba8393fc7560c9277327161e4fa4a Mon Sep 17 00:00:00 2001
From: Benjamin Glass
Date: Tue, 25 Feb 2025 15:46:09 +0000
Subject: [PATCH] cpp_wrapper: fix inductor triton tests (#146109)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146109
Approved by: https://github.com/desertfire
---
 .../inductor/test_triton_extension_backend.py | 3 +
 test/inductor/test_triton_kernels.py | 104 +++++++++++-------
 test/inductor/test_triton_syntax.py | 6 +-
 torch/_inductor/codegen/cpp_micro_gemm.py | 2 +-
 torch/_inductor/codegen/cpp_wrapper_cpu.py | 6 +-
 5 files changed, 77 insertions(+), 44 deletions(-)

diff --git a/test/inductor/test_triton_extension_backend.py b/test/inductor/test_triton_extension_backend.py
index c2a0a8cdea7f..37b32404508b 100644
--- a/test/inductor/test_triton_extension_backend.py
+++ b/test/inductor/test_triton_extension_backend.py
@@ -65,6 +65,9 @@ def mock_triton_hash_with_backend(*args, **kwargs):
 
 
 @unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
+@test_torchinductor.skip_if_cpp_wrapper(
+    "Not possible to fix until CppWrapperCpu supports triton for CPU"
+)
 class TritonExtensionBackendTests(BaseExtensionBackendTests):
     """
     Test creating a backend for inductor with Triton scheduling.
diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py
index 0bd492aa3289..b95a65a1b6d0 100644
--- a/test/inductor/test_triton_kernels.py
+++ b/test/inductor/test_triton_kernels.py
@@ -4,17 +4,17 @@
 # Skip do not assign a lambda expression, use a def
 import functools
 import logging
-from unittest.mock import patch
 
 import torch
 import torch._dynamo.testing
 import torch._inductor.test_case
+from torch._dynamo import config as dynamo_config
 from torch._higher_order_ops.triton_kernel_wrap import (
     generate_ttir,
     triton_kernel_wrapper_functional,
     triton_kernel_wrapper_mutation,
 )
-from torch._inductor import metrics
+from torch._inductor import config as inductor_config, metrics
 from torch._inductor.utils import run_and_get_code, triton_version_uses_attrs_dict
 from torch._library import capture_triton
 from torch.testing import FileCheck
@@ -72,6 +72,11 @@ if HAS_GPU:
 
 
 class KernelTests(torch._inductor.test_case.TestCase):
+    def _kernel_launched_in_code(self, kernel_name: str, code: str) -> bool:
+        if inductor_config.cpp_wrapper:
+            return f"launchKernel({kernel_name}" in code
+        return f"{kernel_name}.run(" in code
+
     @requires_gpu
     def test_triton_kernel_with_kernel_param(self):
         @triton.jit
@@ -344,8 +349,13 @@ def forward(self, x_1, output_1):
 
         output_code = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
         self.assertTrue(len(output_code) > 0, msg="output code is not empty")
-        self.assertEqual(output_code.count('float("nan")'), 0)
-        self.assertEqual(output_code.count("float('nan')"), 0)
+        if inductor_config.cpp_wrapper:
+            self.assertEqual(
+                output_code.count("std::numeric_limits<double>::quiet_NaN()"), 0
+            )
+        else:
+            self.assertEqual(output_code.count('float("nan")'), 0)
+            self.assertEqual(output_code.count("float('nan')"), 0)
 
     @requires_gpu
     @common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad])
@@ -397,7 +407,7 @@ def forward(self, x_1, output_1):
     @requires_gpu
     @common_utils.parametrize("grad", [False, True])
     @common_utils.parametrize("dynamic", [False, True])
-    @patch.object(torch._inductor.config, "implicit_fallbacks", False)
+    @inductor_config.patch("implicit_fallbacks", False)
    def test_triton_kernel_no_clones(self, grad, dynamic):
         from torch._inductor.utils import run_and_get_code
 
@@ -419,7 
+429,7 @@ def forward(self, x_1, output_1): torch_add = call_triton(t1, t2, o1) metrics.reset() o2 = torch.zeros_like(t1, requires_grad=grad) - test, codes = run_and_get_code( + test, (code,) = run_and_get_code( torch.compile(call_triton, dynamic=dynamic), t1, t2, o2 ) if not grad: @@ -427,14 +437,27 @@ def forward(self, x_1, output_1): self.assertEqual(torch_add, test) # These two asserts are not optimal since it requires original aten # to be in the metadata, so there might be false negatives - self.assertTrue("aten.copy" not in codes[0]) - self.assertTrue("aten.clone" not in codes[0]) + self.assertNotIn( + "aoti_torch_copy_" if inductor_config.cpp_wrapper else "aten.copy", code + ) + self.assertNotIn( + "aoti_torch_clone" if inductor_config.cpp_wrapper else "aten.clone", code + ) # The following checks that there are only the tensor output is in # the compiled graph if dynamic and grad: - self.assertTrue("return (buf0, s0, )" in codes[0]) + if inductor_config.cpp_wrapper: + self.assertIn("output_handles[0] = ", code) + self.assertIn("output_handles[1] = ", code) + else: + self.assertIn("return (buf0, s0, )", code) else: - self.assertTrue("return (buf0, )" in codes[0]) + self.assertIn( + "output_handles[0] = " + if inductor_config.cpp_wrapper + else "return (buf0, )", + code, + ) @requires_gpu def test_triton_kernel_caching(self): @@ -511,8 +534,8 @@ def forward(self, x_1, output_1): t = torch.ones(5, device=GPU_TYPE) test, (code,) = run_and_get_code(torch.compile(call_triton), t) # Make sure we emitted two kernels here - self.assertTrue("pass_kernel_0.run" in code) - self.assertTrue("pass_kernel_1.run" in code) + self.assertTrue(self._kernel_launched_in_code("pass_kernel_0", code)) + self.assertTrue(self._kernel_launched_in_code("pass_kernel_1", code)) @requires_gpu def test_triton_kernel_various_args(self): @@ -754,9 +777,7 @@ def forward(self, x_1, output_1): @requires_gpu @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) - @patch.object( - torch._inductor.config, "unsafe_ignore_unsupported_triton_autotune_args", True - ) + @inductor_config.patch("unsafe_ignore_unsupported_triton_autotune_args", True) def test_triton_kernel_autotune_with_unsupported_args(self, backend): def call_triton(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -882,7 +903,7 @@ def forward(self, x_1, output_1): @common_utils.parametrize("grad", [False, True]) @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) - @patch.object(torch._inductor.config, "implicit_fallbacks", False) + @inductor_config.patch("implicit_fallbacks", False) def test_triton_kernel_native(self, grad, dynamic, backend): def call_triton_add( x: torch.Tensor, @@ -955,7 +976,7 @@ def forward(self, x_1, output_1): out.sum().backward() @requires_gpu - @patch.object(torch._inductor.config, "allow_buffer_reuse", True) + @inductor_config.patch("allow_buffer_reuse", True) def test_triton_kernel_inputs_buffer_reuse(self): def _mul2(x): y = torch.empty_like(x) @@ -981,13 +1002,18 @@ def forward(self, x_1, output_1): self.assertEqual(compiled_out, eager_out) # Check that we're allocating the minimal # of buffers. 
- code_string = f"empty_strided_{GPU_TYPE}((10, ), (1, ), torch.float32)" - + code_string = ( + "aoti_torch_empty_strided(" + if inductor_config.cpp_wrapper + else f"empty_strided_{GPU_TYPE}((10, ), (1, ), torch.float32)" + ) num_bufs_allocated = code.count(code_string) self.assertEqual(num_bufs_allocated, 2) # Check we're re-using buffers if not allocating. - num_bufs_reused = code.count("# reuse") + num_bufs_reused = code.count( + "// reuse" if inductor_config.cpp_wrapper else "# reuse" + ) self.assertEqual(num_bufs_reused, 3) @requires_gpu @@ -1036,7 +1062,7 @@ def forward(self, x_1, output_1): compiled_out = torch.compile(f)(inp) self.assertEqual(compiled_out, eager_out) - @torch._inductor.config.patch( + @inductor_config.patch( triton_kernel_default_layout_constraint="needs_fixed_stride_order" ) @requires_gpu @@ -1197,8 +1223,8 @@ def forward(self, x_1, output_1): self.assertEqual(compiled_out, eager_out) @requires_gpu - @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) - @torch._dynamo.config.patch(capture_scalar_outputs=True) + @dynamo_config.patch(capture_dynamic_output_shape_ops=True) + @dynamo_config.patch(capture_scalar_outputs=True) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) def test_triton_kernel_unbacked_shape_tensor(self, backend): @triton.jit @@ -1403,13 +1429,13 @@ def forward(self, x_1, output_1): ) if size == 4 and not dynamic: # Produce 2 kernels due to divisibility - self.assertTrue("add_kernel_0.run" in code) - self.assertTrue("add_kernel_1.run" in code) + self.assertTrue(self._kernel_launched_in_code("add_kernel_0", code)) + self.assertTrue(self._kernel_launched_in_code("add_kernel_1", code)) else: # size == 16 or dynamic # Only one kernel - self.assertTrue("add_kernel_0.run" in code) - self.assertTrue("add_kernel_1.run" not in code) + self.assertTrue(self._kernel_launched_in_code("add_kernel_0", code)) + self.assertFalse(self._kernel_launched_in_code("add_kernel_1", code)) self.assertEqual(compiled_out, eager_out) @@ -1446,10 +1472,10 @@ def forward(self, x_1, output_1): x = torch.randn(4, device=GPU_TYPE) y = torch.randn(4, device=GPU_TYPE) - args_list = ( - [x, y, torch.float32, tl.float32], - [x, y, torch.bfloat16, tl.bfloat16], - ) + args_list = [(x, y, torch.float32, tl.float32)] + if torch.cuda.is_bf16_supported(including_emulation=False): + args_list.append((x, y, torch.bfloat16, tl.bfloat16)) + for args in args_list: eager_out = f(*args) compiled_out = torch.compile( @@ -2012,7 +2038,7 @@ def forward(self, arg0_1, arg1_1): x = torch.rand(4, device=GPU_TYPE) prev = x.clone() - with torch._inductor.config.patch( + with inductor_config.patch( {"triton.autotune_at_compile_time": autotune_at_compile_time} ): f(x) @@ -2042,7 +2068,7 @@ def forward(self, arg0_1, arg1_1): elif cfg == "cpp_wrapper": config_kwargs = {"cpp_wrapper": True} - with torch._inductor.config.patch(**config_kwargs): + with inductor_config.patch(**config_kwargs): @triton.jit def _triton_kernel(out_ptr, numel, BLOCK_SIZE: tl.constexpr): @@ -2113,7 +2139,7 @@ def forward(self, arg0_1, arg1_1): torch._dynamo.mark_dynamic(inp, 0) fn_c = torch.compile(fn, fullgraph=True) - with torch._dynamo.config.patch(capture_scalar_outputs=True): + with dynamo_config.patch(capture_scalar_outputs=True): res = fn_c(inp) self.assertTrue(((res < 2) & (res >= 0)).all().item()) @@ -3230,7 +3256,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase): self.assertNotIn(opname, code) @requires_gpu - @patch.object(torch._dynamo.config, "recompile_limit", 1) + 
@dynamo_config.patch("recompile_limit", 1)
     def test_triton_dynamic_grid_no_recompile(self):
         libname = "my_cool_namespace"
         opname = "my_triton_operator"
@@ -3410,8 +3436,8 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
 
     @skipIfWindows(msg="AOTI/Cpp_Wrapper have not enabled on Windows")
     @requires_gpu
-    @patch.object(torch._inductor.config, "cpp_wrapper", True)
-    @patch.object(torch._inductor.config, "triton.autotune_at_compile_time", True)
+    @inductor_config.patch("cpp_wrapper", True)
+    @inductor_config.patch("triton.autotune_at_compile_time", True)
     def test_autotune_unbacked(self):
         import triton
         import triton.language as tl
@@ -3644,7 +3670,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         increment = torch.rand(4, device=GPU_TYPE)
 
         # during autotuning, x should not change in value
-        with torch._inductor.config.patch(
+        with inductor_config.patch(
             {"triton.autotune_at_compile_time": autotune_at_compile_time}
         ):
             # we will add rand a single time to x
@@ -3957,7 +3983,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         src = torch.empty(N, device=GPU_TYPE)
         dst = torch.zeros(N, device=GPU_TYPE)
 
-        with torch._inductor.config.patch(
+        with inductor_config.patch(
             {"triton.autotune_at_compile_time": autotune_at_compile_time}
         ):
             compiled_f(dst, src, N=N)
diff --git a/test/inductor/test_triton_syntax.py b/test/inductor/test_triton_syntax.py
index 988b0b9f4c24..49f0bd06a8bc 100644
--- a/test/inductor/test_triton_syntax.py
+++ b/test/inductor/test_triton_syntax.py
@@ -16,7 +16,11 @@ class TestTritonSyntacticallyValid(TestCase):
     def newtonschulz5(G, steps: int, eps=1e-7):
         assert len(G.shape) == 2
         a, b, c = (3.4445, -4.7750, 2.0315)
-        X = G.bfloat16()
+        X = G.to(
+            torch.bfloat16
+            if torch.cuda.is_bf16_supported(including_emulation=False)
+            else torch.float16
+        )
         X /= X.norm() + eps  # ensure top singular value <= 1
         if G.size(0) > G.size(1):
             X = X.T
diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py
index 5c872462e950..e02171fd325b 100644
--- a/torch/_inductor/codegen/cpp_micro_gemm.py
+++ b/torch/_inductor/codegen/cpp_micro_gemm.py
@@ -441,7 +441,6 @@ inline void {{kernel_name}}_kernel(
     int64_t ldc
 ) {
     using Vectorized = at::vec::Vectorized<{{compute_t}}>;
-    using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
     constexpr auto VLEN = Vectorized::size();
     constexpr auto ROWS = BLOCK_M;
     constexpr auto COLS = BLOCK_N / VLEN;
@@ -475,6 +474,7 @@ inline void {{kernel_name}}_kernel(
 
         if constexpr (row == 0) {
 {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+            using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
             auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
             vb[col] = at::vec::convert<{{compute_t}}>(b);
 {%- elif input2_dtype == torch.int8 %}
diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py
index 9d4982084b47..0988bc5b2005 100644
--- a/torch/_inductor/codegen/cpp_wrapper_cpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -2007,11 +2007,11 @@ if (custom_op_wrapper.get() == NULL) {
     def generate_float_value(self, val):
         assert isinstance(val, float)
         if val == float("inf"):
-            return "std::numeric_limits<float>::infinity()"
+            return "std::numeric_limits<double>::infinity()"
         elif val == float("-inf"):
-            return "-std::numeric_limits<float>::infinity()"
+            return "-std::numeric_limits<double>::infinity()"
         elif math.isnan(val):
-            return "std::numeric_limits<float>::quiet_NaN()"
+            return "std::numeric_limits<double>::quiet_NaN()"
         else:
             return f"{val}"