cpp_wrapper: fix inductor triton tests (#146109)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146109
Approved by: https://github.com/desertfire
Benjamin Glass
2025-02-25 15:46:09 +00:00
committed by PyTorch MergeBot
parent 9740d69e78
commit 46d1422afd
5 changed files with 77 additions and 44 deletions
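The test changes below follow two recurring patterns: inductor config flags are patched through inductor's own config module (imported as `inductor_config`) rather than `unittest.mock.patch.object`, and assertions on generated code branch on `torch._inductor.config.cpp_wrapper`. A minimal, self-contained sketch of the config-patching pattern as the tests use it (the `double` function is purely illustrative):

import torch
from torch._inductor import config as inductor_config


def double(x):
    return x * 2


# Decorator form: the flag is patched for the duration of the call and
# restored afterwards.
@inductor_config.patch("implicit_fallbacks", False)
def compile_and_run(fn, *args):
    return torch.compile(fn)(*args)


# Context-manager form, as used around single compile/run blocks in the tests.
with inductor_config.patch({"triton.autotune_at_compile_time": True}):
    out = compile_and_run(double, torch.ones(4))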


@@ -65,6 +65,9 @@ def mock_triton_hash_with_backend(*args, **kwargs):
@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
@test_torchinductor.skip_if_cpp_wrapper(
"Not possible to fix until CppWrapperCpu supports triton for CPU"
)
class TritonExtensionBackendTests(BaseExtensionBackendTests):
"""
Test creating a backend for inductor with Triton scheduling.

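The `skip_if_cpp_wrapper` decorator used above takes a reason string and can wrap a whole test class or a single test. A rough stand-in in plain `unittest` terms (this helper body is an assumption for illustration only; the real implementation lives in test_torchinductor):

import unittest

from torch._inductor import config as inductor_config


def skip_if_cpp_wrapper(reason: str):
    # Hypothetical equivalent: skip the decorated test (or class) whenever
    # the cpp_wrapper codegen mode is active in the inductor config.
    return unittest.skipIf(inductor_config.cpp_wrapper, reason)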

@@ -4,17 +4,17 @@
# Skip do not assign a lambda expression, use a def
import functools
import logging
from unittest.mock import patch
import torch
import torch._dynamo.testing
import torch._inductor.test_case
from torch._dynamo import config as dynamo_config
from torch._higher_order_ops.triton_kernel_wrap import (
generate_ttir,
triton_kernel_wrapper_functional,
triton_kernel_wrapper_mutation,
)
from torch._inductor import metrics
from torch._inductor import config as inductor_config, metrics
from torch._inductor.utils import run_and_get_code, triton_version_uses_attrs_dict
from torch._library import capture_triton
from torch.testing import FileCheck
@@ -72,6 +72,11 @@ if HAS_GPU:
class KernelTests(torch._inductor.test_case.TestCase):
def _kernel_launched_in_code(self, kernel_name: str, code: str) -> bool:
if inductor_config.cpp_wrapper:
return f"launchKernel({kernel_name}" in code
return f"{kernel_name}.run(" in code
@requires_gpu
def test_triton_kernel_with_kernel_param(self):
@triton.jit
@@ -344,8 +349,13 @@ def forward(self, x_1, output_1):
output_code = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
self.assertTrue(len(output_code) > 0, msg="output code is not empty")
self.assertEqual(output_code.count('float("nan")'), 0)
self.assertEqual(output_code.count("float('nan')"), 0)
if inductor_config.cpp_wrapper:
self.assertEqual(
output_code.count("std::numeric_limits<double>::quiet_NaN()"), 0
)
else:
self.assertEqual(output_code.count('float("nan")'), 0)
self.assertEqual(output_code.count("float('nan')"), 0)
@requires_gpu
@common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad])
@@ -397,7 +407,7 @@ def forward(self, x_1, output_1):
@requires_gpu
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("dynamic", [False, True])
@patch.object(torch._inductor.config, "implicit_fallbacks", False)
@inductor_config.patch("implicit_fallbacks", False)
def test_triton_kernel_no_clones(self, grad, dynamic):
from torch._inductor.utils import run_and_get_code
@@ -419,7 +429,7 @@ def forward(self, x_1, output_1):
torch_add = call_triton(t1, t2, o1)
metrics.reset()
o2 = torch.zeros_like(t1, requires_grad=grad)
test, codes = run_and_get_code(
test, (code,) = run_and_get_code(
torch.compile(call_triton, dynamic=dynamic), t1, t2, o2
)
if not grad:
@@ -427,14 +437,27 @@ def forward(self, x_1, output_1):
self.assertEqual(torch_add, test)
# These two asserts are not optimal since they require the original aten
# ops to be in the metadata, so there might be false negatives
self.assertTrue("aten.copy" not in codes[0])
self.assertTrue("aten.clone" not in codes[0])
self.assertNotIn(
"aoti_torch_copy_" if inductor_config.cpp_wrapper else "aten.copy", code
)
self.assertNotIn(
"aoti_torch_clone" if inductor_config.cpp_wrapper else "aten.clone", code
)
# The following checks that only the tensor output is in
# the compiled graph
if dynamic and grad:
self.assertTrue("return (buf0, s0, )" in codes[0])
if inductor_config.cpp_wrapper:
self.assertIn("output_handles[0] = ", code)
self.assertIn("output_handles[1] = ", code)
else:
self.assertIn("return (buf0, s0, )", code)
else:
self.assertTrue("return (buf0, )" in codes[0])
self.assertIn(
"output_handles[0] = "
if inductor_config.cpp_wrapper
else "return (buf0, )",
code,
)
@requires_gpu
def test_triton_kernel_caching(self):
@@ -511,8 +534,8 @@ def forward(self, x_1, output_1):
t = torch.ones(5, device=GPU_TYPE)
test, (code,) = run_and_get_code(torch.compile(call_triton), t)
# Make sure we emitted two kernels here
self.assertTrue("pass_kernel_0.run" in code)
self.assertTrue("pass_kernel_1.run" in code)
self.assertTrue(self._kernel_launched_in_code("pass_kernel_0", code))
self.assertTrue(self._kernel_launched_in_code("pass_kernel_1", code))
@requires_gpu
def test_triton_kernel_various_args(self):
@@ -754,9 +777,7 @@ def forward(self, x_1, output_1):
@requires_gpu
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@patch.object(
torch._inductor.config, "unsafe_ignore_unsupported_triton_autotune_args", True
)
@inductor_config.patch("unsafe_ignore_unsupported_triton_autotune_args", True)
def test_triton_kernel_autotune_with_unsupported_args(self, backend):
def call_triton(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x)
@@ -882,7 +903,7 @@ def forward(self, x_1, output_1):
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@patch.object(torch._inductor.config, "implicit_fallbacks", False)
@inductor_config.patch("implicit_fallbacks", False)
def test_triton_kernel_native(self, grad, dynamic, backend):
def call_triton_add(
x: torch.Tensor,
@@ -955,7 +976,7 @@ def forward(self, x_1, output_1):
out.sum().backward()
@requires_gpu
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
@inductor_config.patch("allow_buffer_reuse", True)
def test_triton_kernel_inputs_buffer_reuse(self):
def _mul2(x):
y = torch.empty_like(x)
@@ -981,13 +1002,18 @@ def forward(self, x_1, output_1):
self.assertEqual(compiled_out, eager_out)
# Check that we're allocating the minimal # of buffers.
code_string = f"empty_strided_{GPU_TYPE}((10, ), (1, ), torch.float32)"
code_string = (
"aoti_torch_empty_strided("
if inductor_config.cpp_wrapper
else f"empty_strided_{GPU_TYPE}((10, ), (1, ), torch.float32)"
)
num_bufs_allocated = code.count(code_string)
self.assertEqual(num_bufs_allocated, 2)
# Check we're re-using buffers if not allocating.
num_bufs_reused = code.count("# reuse")
num_bufs_reused = code.count(
"// reuse" if inductor_config.cpp_wrapper else "# reuse"
)
self.assertEqual(num_bufs_reused, 3)
@requires_gpu
@@ -1036,7 +1062,7 @@ def forward(self, x_1, output_1):
compiled_out = torch.compile(f)(inp)
self.assertEqual(compiled_out, eager_out)
@torch._inductor.config.patch(
@inductor_config.patch(
triton_kernel_default_layout_constraint="needs_fixed_stride_order"
)
@requires_gpu
@@ -1197,8 +1223,8 @@ def forward(self, x_1, output_1):
self.assertEqual(compiled_out, eager_out)
@requires_gpu
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
@torch._dynamo.config.patch(capture_scalar_outputs=True)
@dynamo_config.patch(capture_dynamic_output_shape_ops=True)
@dynamo_config.patch(capture_scalar_outputs=True)
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_triton_kernel_unbacked_shape_tensor(self, backend):
@triton.jit
@@ -1403,13 +1429,13 @@ def forward(self, x_1, output_1):
)
if size == 4 and not dynamic:
# Produce 2 kernels due to divisibility
self.assertTrue("add_kernel_0.run" in code)
self.assertTrue("add_kernel_1.run" in code)
self.assertTrue(self._kernel_launched_in_code("add_kernel_0", code))
self.assertTrue(self._kernel_launched_in_code("add_kernel_1", code))
else:
# size == 16 or dynamic
# Only one kernel
self.assertTrue("add_kernel_0.run" in code)
self.assertTrue("add_kernel_1.run" not in code)
self.assertTrue(self._kernel_launched_in_code("add_kernel_0", code))
self.assertFalse(self._kernel_launched_in_code("add_kernel_1", code))
self.assertEqual(compiled_out, eager_out)
@@ -1446,10 +1472,10 @@ def forward(self, x_1, output_1):
x = torch.randn(4, device=GPU_TYPE)
y = torch.randn(4, device=GPU_TYPE)
args_list = (
[x, y, torch.float32, tl.float32],
[x, y, torch.bfloat16, tl.bfloat16],
)
args_list = [(x, y, torch.float32, tl.float32)]
if torch.cuda.is_bf16_supported(including_emulation=False):
args_list.append((x, y, torch.bfloat16, tl.bfloat16))
for args in args_list:
eager_out = f(*args)
compiled_out = torch.compile(
@@ -2012,7 +2038,7 @@ def forward(self, arg0_1, arg1_1):
x = torch.rand(4, device=GPU_TYPE)
prev = x.clone()
with torch._inductor.config.patch(
with inductor_config.patch(
{"triton.autotune_at_compile_time": autotune_at_compile_time}
):
f(x)
@@ -2042,7 +2068,7 @@ def forward(self, arg0_1, arg1_1):
elif cfg == "cpp_wrapper":
config_kwargs = {"cpp_wrapper": True}
with torch._inductor.config.patch(**config_kwargs):
with inductor_config.patch(**config_kwargs):
@triton.jit
def _triton_kernel(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
@@ -2113,7 +2139,7 @@ def forward(self, arg0_1, arg1_1):
torch._dynamo.mark_dynamic(inp, 0)
fn_c = torch.compile(fn, fullgraph=True)
with torch._dynamo.config.patch(capture_scalar_outputs=True):
with dynamo_config.patch(capture_scalar_outputs=True):
res = fn_c(inp)
self.assertTrue(((res < 2) & (res >= 0)).all().item())
@@ -3230,7 +3256,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
self.assertNotIn(opname, code)
@requires_gpu
@patch.object(torch._dynamo.config, "recompile_limit", 1)
@dynamo_config.patch("recompile_limit", 1)
def test_triton_dynamic_grid_no_recompile(self):
libname = "my_cool_namespace"
opname = "my_triton_operator"
@@ -3410,8 +3436,8 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
@skipIfWindows(msg="AOTI/Cpp_Wrapper has not been enabled on Windows")
@requires_gpu
@patch.object(torch._inductor.config, "cpp_wrapper", True)
@patch.object(torch._inductor.config, "triton.autotune_at_compile_time", True)
@inductor_config.patch("cpp_wrapper", True)
@inductor_config.patch("triton.autotune_at_compile_time", True)
def test_autotune_unbacked(self):
import triton
import triton.language as tl
@@ -3644,7 +3670,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
increment = torch.rand(4, device=GPU_TYPE)
# during autotuning, x should not change in value
with torch._inductor.config.patch(
with inductor_config.patch(
{"triton.autotune_at_compile_time": autotune_at_compile_time}
):
# we will add rand a single time to x
@@ -3957,7 +3983,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
src = torch.empty(N, device=GPU_TYPE)
dst = torch.zeros(N, device=GPU_TYPE)
with torch._inductor.config.patch(
with inductor_config.patch(
{"triton.autotune_at_compile_time": autotune_at_compile_time}
):
compiled_f(dst, src, N=N)
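Most of the rewritten assertions in this file reduce to one question: did the generated wrapper launch a given Triton kernel? The default Python wrapper emits `<kernel_name>.run(...)`, while the C++ wrapper emits `launchKernel(<kernel_name>, ...)`, which is what the new `_kernel_launched_in_code` helper encodes. A standalone sketch of the same check (the commented usage at the bottom is illustrative):

import torch
from torch._inductor import config as inductor_config
from torch._inductor.utils import run_and_get_code


def kernel_launched_in_code(kernel_name: str, code: str) -> bool:
    # Mirror of the helper added to KernelTests: the launch string differs
    # between the C++ wrapper and the default Python wrapper.
    if inductor_config.cpp_wrapper:
        return f"launchKernel({kernel_name}" in code
    return f"{kernel_name}.run(" in code


# Typical use inside a test (fn, inputs, and the kernel name are placeholders):
# result, (code,) = run_and_get_code(torch.compile(fn), *inputs)
# assert kernel_launched_in_code("add_kernel_0", code)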


@@ -16,7 +16,11 @@ class TestTritonSyntacticallyValid(TestCase):
def newtonschulz5(G, steps: int, eps=1e-7):
assert len(G.shape) == 2
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
X = G.to(
torch.bfloat16
if torch.cuda.is_bf16_supported(including_emulation=False)
else torch.float16
)
X /= X.norm() + eps # ensure top singular value <= 1
if G.size(0) > G.size(1):
X = X.T
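The dtype choice above hinges on `torch.cuda.is_bf16_supported(including_emulation=False)`, which reports whether the device supports bfloat16 natively; where it does not, the test falls back to float16. A minimal standalone illustration of the same fallback (the helper name is illustrative):

import torch


def pick_reduced_dtype() -> torch.dtype:
    # Prefer bfloat16 only when a CUDA device supports it natively
    # (no emulation); otherwise fall back to float16.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported(
        including_emulation=False
    ):
        return torch.bfloat16
    return torch.float16


x = torch.randn(4, 4)
x_reduced = x.to(pick_reduced_dtype())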


@@ -441,7 +441,6 @@ inline void {{kernel_name}}_kernel(
int64_t ldc
) {
using Vectorized = at::vec::Vectorized<{{compute_t}}>;
using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
constexpr auto VLEN = Vectorized::size();
constexpr auto ROWS = BLOCK_M;
constexpr auto COLS = BLOCK_N / VLEN;
@@ -475,6 +474,7 @@ inline void {{kernel_name}}_kernel(
if constexpr (row == 0) {
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
vb[col] = at::vec::convert<{{compute_t}}>(b);
{%- elif input2_dtype == torch.int8 %}


@@ -2007,11 +2007,11 @@ if (custom_op_wrapper.get() == NULL) {
def generate_float_value(self, val):
assert isinstance(val, float)
if val == float("inf"):
return "std::numeric_limits<float>::infinity()"
return "std::numeric_limits<double>::infinity()"
elif val == float("-inf"):
return "-std::numeric_limits<float>::infinity()"
return "-std::numeric_limits<double>::infinity()"
elif math.isnan(val):
return "std::numeric_limits<float>::quiet_NaN()"
return "std::numeric_limits<double>::quiet_NaN()"
else:
return f"{val}"
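The updated literals line up with the `std::numeric_limits<double>::quiet_NaN()` count the test hunk near the top now checks for. The helper is small enough to exercise standalone; a sketch with the method lifted out of its codegen class as a plain function:

import math


def generate_float_value(val: float) -> str:
    # Special float values are spelled via std::numeric_limits<double>,
    # consistent with the double-based checks in the updated tests;
    # finite values are printed directly.
    assert isinstance(val, float)
    if val == float("inf"):
        return "std::numeric_limits<double>::infinity()"
    elif val == float("-inf"):
        return "-std::numeric_limits<double>::infinity()"
    elif math.isnan(val):
        return "std::numeric_limits<double>::quiet_NaN()"
    return f"{val}"


print(generate_float_value(float("nan")))  # std::numeric_limits<double>::quiet_NaN()
print(generate_float_value(1.5))           # 1.5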