cpp_wrapper: fix inductor triton tests (#146109)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146109
Approved by: https://github.com/desertfire
Committed by: PyTorch MergeBot
Parent: 9740d69e78
Commit: 46d1422afd
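Two patterns recur throughout the diff below. The functional part of the fix makes assertions on Inductor's generated wrapper code aware of cpp_wrapper mode (the C++ wrapper emits different strings than the default Python wrapper); the rest is a mechanical migration from unittest.mock's patch.object on torch._inductor.config / torch._dynamo.config to the config modules' own .patch() decorators. A minimal sketch of the new decorator style, not part of the diff (ExampleTests and my_test are hypothetical names):

from torch._dynamo import config as dynamo_config
from torch._inductor import config as inductor_config
from torch._inductor.test_case import TestCase


class ExampleTests(TestCase):
    # Previously written as:
    #   @patch.object(torch._inductor.config, "implicit_fallbacks", False)
    @inductor_config.patch("implicit_fallbacks", False)
    @dynamo_config.patch("recompile_limit", 1)
    def my_test(self):
        ...  # patched config values are restored when the test returns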
@@ -65,6 +65,9 @@ def mock_triton_hash_with_backend(*args, **kwargs):
 @unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
+@test_torchinductor.skip_if_cpp_wrapper(
+    "Not possible to fix until CppWrapperCpu supports triton for CPU"
+)
 class TritonExtensionBackendTests(BaseExtensionBackendTests):
     """
     Test creating a backend for inductor with Triton scheduling.
@@ -4,17 +4,17 @@
 # Skip do not assign a lambda expression, use a def
 import functools
 import logging
 from unittest.mock import patch

 import torch
 import torch._dynamo.testing
 import torch._inductor.test_case
+from torch._dynamo import config as dynamo_config
 from torch._higher_order_ops.triton_kernel_wrap import (
     generate_ttir,
     triton_kernel_wrapper_functional,
     triton_kernel_wrapper_mutation,
 )
-from torch._inductor import metrics
+from torch._inductor import config as inductor_config, metrics
 from torch._inductor.utils import run_and_get_code, triton_version_uses_attrs_dict
 from torch._library import capture_triton
 from torch.testing import FileCheck
@@ -72,6 +72,11 @@ if HAS_GPU:


 class KernelTests(torch._inductor.test_case.TestCase):
+    def _kernel_launched_in_code(self, kernel_name: str, code: str) -> bool:
+        if inductor_config.cpp_wrapper:
+            return f"launchKernel({kernel_name}" in code
+        return f"{kernel_name}.run(" in code
+
     @requires_gpu
     def test_triton_kernel_with_kernel_param(self):
         @triton.jit
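For reference, a minimal sketch of how the new helper is meant to be combined with run_and_get_code in the assertions below. my_add and check_kernel_launch are hypothetical; GPU_TYPE and self._kernel_launched_in_code come from the surrounding test file:

import torch
from torch._inductor.utils import run_and_get_code


def my_add(x, y):
    return x + y


# Inside KernelTests, an assertion on the generated wrapper could look like:
def check_kernel_launch(self, kernel_name: str) -> None:
    x = torch.randn(16, device=GPU_TYPE)
    y = torch.randn(16, device=GPU_TYPE)
    # run_and_get_code returns the compiled result plus the generated wrapper source
    _, (code,) = run_and_get_code(torch.compile(my_add), x, y)
    # The helper hides the codegen difference: "launchKernel(<name>" under
    # cpp_wrapper vs. "<name>.run(" with the default Python wrapper.
    self.assertTrue(self._kernel_launched_in_code(kernel_name, code))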
@@ -344,8 +349,13 @@ def forward(self, x_1, output_1):

         output_code = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
         self.assertTrue(len(output_code) > 0, msg="output code is not empty")
-        self.assertEqual(output_code.count('float("nan")'), 0)
-        self.assertEqual(output_code.count("float('nan')"), 0)
+        if inductor_config.cpp_wrapper:
+            self.assertEqual(
+                output_code.count("std::numeric_limits<double>::quiet_NaN()"), 0
+            )
+        else:
+            self.assertEqual(output_code.count('float("nan")'), 0)
+            self.assertEqual(output_code.count("float('nan')"), 0)

     @requires_gpu
     @common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad])
@@ -397,7 +407,7 @@ def forward(self, x_1, output_1):
     @requires_gpu
     @common_utils.parametrize("grad", [False, True])
     @common_utils.parametrize("dynamic", [False, True])
-    @patch.object(torch._inductor.config, "implicit_fallbacks", False)
+    @inductor_config.patch("implicit_fallbacks", False)
     def test_triton_kernel_no_clones(self, grad, dynamic):
         from torch._inductor.utils import run_and_get_code

@@ -419,7 +429,7 @@ def forward(self, x_1, output_1):
         torch_add = call_triton(t1, t2, o1)
         metrics.reset()
         o2 = torch.zeros_like(t1, requires_grad=grad)
-        test, codes = run_and_get_code(
+        test, (code,) = run_and_get_code(
             torch.compile(call_triton, dynamic=dynamic), t1, t2, o2
         )
         if not grad:
@@ -427,14 +437,27 @@ def forward(self, x_1, output_1):
         self.assertEqual(torch_add, test)
         # These two asserts are not optimal since it requires original aten
         # to be in the metadata, so there might be false negatives
-        self.assertTrue("aten.copy" not in codes[0])
-        self.assertTrue("aten.clone" not in codes[0])
+        self.assertNotIn(
+            "aoti_torch_copy_" if inductor_config.cpp_wrapper else "aten.copy", code
+        )
+        self.assertNotIn(
+            "aoti_torch_clone" if inductor_config.cpp_wrapper else "aten.clone", code
+        )
         # The following checks that there are only the tensor output is in
         # the compiled graph
         if dynamic and grad:
-            self.assertTrue("return (buf0, s0, )" in codes[0])
+            if inductor_config.cpp_wrapper:
+                self.assertIn("output_handles[0] = ", code)
+                self.assertIn("output_handles[1] = ", code)
+            else:
+                self.assertIn("return (buf0, s0, )", code)
         else:
-            self.assertTrue("return (buf0, )" in codes[0])
+            self.assertIn(
+                "output_handles[0] = "
+                if inductor_config.cpp_wrapper
+                else "return (buf0, )",
+                code,
+            )

     @requires_gpu
     def test_triton_kernel_caching(self):
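The output checks above follow the same split between the two codegens; a hypothetical helper (not part of the commit, assuming the test file's inductor_config alias) capturing the two conventions:

def assert_first_graph_output(self, code: str) -> None:
    # The C++ wrapper stores results into pre-allocated output_handles[...],
    # whereas the Python wrapper returns a tuple of buffers.
    self.assertIn(
        "output_handles[0] = " if inductor_config.cpp_wrapper else "return (buf0, )",
        code,
    )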
@@ -511,8 +534,8 @@ def forward(self, x_1, output_1):
         t = torch.ones(5, device=GPU_TYPE)
         test, (code,) = run_and_get_code(torch.compile(call_triton), t)
         # Make sure we emitted two kernels here
-        self.assertTrue("pass_kernel_0.run" in code)
-        self.assertTrue("pass_kernel_1.run" in code)
+        self.assertTrue(self._kernel_launched_in_code("pass_kernel_0", code))
+        self.assertTrue(self._kernel_launched_in_code("pass_kernel_1", code))

     @requires_gpu
     def test_triton_kernel_various_args(self):
@@ -754,9 +777,7 @@ def forward(self, x_1, output_1):

     @requires_gpu
     @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
-    @patch.object(
-        torch._inductor.config, "unsafe_ignore_unsupported_triton_autotune_args", True
-    )
+    @inductor_config.patch("unsafe_ignore_unsupported_triton_autotune_args", True)
     def test_triton_kernel_autotune_with_unsupported_args(self, backend):
         def call_triton(x: torch.Tensor, y: torch.Tensor):
             output = torch.zeros_like(x)
@@ -882,7 +903,7 @@ def forward(self, x_1, output_1):
     @common_utils.parametrize("grad", [False, True])
     @common_utils.parametrize("dynamic", [False, True])
     @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
-    @patch.object(torch._inductor.config, "implicit_fallbacks", False)
+    @inductor_config.patch("implicit_fallbacks", False)
     def test_triton_kernel_native(self, grad, dynamic, backend):
         def call_triton_add(
             x: torch.Tensor,
@@ -955,7 +976,7 @@ def forward(self, x_1, output_1):
             out.sum().backward()

     @requires_gpu
-    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
+    @inductor_config.patch("allow_buffer_reuse", True)
     def test_triton_kernel_inputs_buffer_reuse(self):
         def _mul2(x):
             y = torch.empty_like(x)
@@ -981,13 +1002,18 @@ def forward(self, x_1, output_1):
         self.assertEqual(compiled_out, eager_out)

         # Check that we're allocating the minimal # of buffers.
-        code_string = f"empty_strided_{GPU_TYPE}((10, ), (1, ), torch.float32)"
+        code_string = (
+            "aoti_torch_empty_strided("
+            if inductor_config.cpp_wrapper
+            else f"empty_strided_{GPU_TYPE}((10, ), (1, ), torch.float32)"
+        )
         num_bufs_allocated = code.count(code_string)
         self.assertEqual(num_bufs_allocated, 2)

         # Check we're re-using buffers if not allocating.
-        num_bufs_reused = code.count("# reuse")
+        num_bufs_reused = code.count(
+            "// reuse" if inductor_config.cpp_wrapper else "# reuse"
+        )
         self.assertEqual(num_bufs_reused, 3)

     @requires_gpu
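The same marker difference, as a standalone sketch for anyone grepping generated code by hand (helper name hypothetical, inductor_config alias assumed):

def count_buffer_reuse(code: str) -> int:
    # cpp_wrapper tags reused buffers with C++ comments, the Python wrapper
    # with Python comments.
    marker = "// reuse" if inductor_config.cpp_wrapper else "# reuse"
    return code.count(marker)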
@@ -1036,7 +1062,7 @@ def forward(self, x_1, output_1):
         compiled_out = torch.compile(f)(inp)
         self.assertEqual(compiled_out, eager_out)

-    @torch._inductor.config.patch(
+    @inductor_config.patch(
         triton_kernel_default_layout_constraint="needs_fixed_stride_order"
     )
     @requires_gpu
@@ -1197,8 +1223,8 @@ def forward(self, x_1, output_1):
         self.assertEqual(compiled_out, eager_out)

     @requires_gpu
-    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
-    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    @dynamo_config.patch(capture_dynamic_output_shape_ops=True)
+    @dynamo_config.patch(capture_scalar_outputs=True)
     @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
     def test_triton_kernel_unbacked_shape_tensor(self, backend):
         @triton.jit
@@ -1403,13 +1429,13 @@ def forward(self, x_1, output_1):
         )
         if size == 4 and not dynamic:
             # Produce 2 kernels due to divisibility
-            self.assertTrue("add_kernel_0.run" in code)
-            self.assertTrue("add_kernel_1.run" in code)
+            self.assertTrue(self._kernel_launched_in_code("add_kernel_0", code))
+            self.assertTrue(self._kernel_launched_in_code("add_kernel_1", code))
         else:
             # size == 16 or dynamic
             # Only one kernel
-            self.assertTrue("add_kernel_0.run" in code)
-            self.assertTrue("add_kernel_1.run" not in code)
+            self.assertTrue(self._kernel_launched_in_code("add_kernel_0", code))
+            self.assertFalse(self._kernel_launched_in_code("add_kernel_1", code))

         self.assertEqual(compiled_out, eager_out)

@@ -1446,10 +1472,10 @@ def forward(self, x_1, output_1):

         x = torch.randn(4, device=GPU_TYPE)
         y = torch.randn(4, device=GPU_TYPE)
-        args_list = (
-            [x, y, torch.float32, tl.float32],
-            [x, y, torch.bfloat16, tl.bfloat16],
-        )
+        args_list = [(x, y, torch.float32, tl.float32)]
+        if torch.cuda.is_bf16_supported(including_emulation=False):
+            args_list.append((x, y, torch.bfloat16, tl.bfloat16))
+
         for args in args_list:
             eager_out = f(*args)
             compiled_out = torch.compile(
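The bf16 handling above guards the extra test case on native hardware support; a minimal standalone sketch of that pattern (helper name hypothetical):

import torch


def dtypes_to_test():
    # Always cover fp32; add bf16 only when the device reports native support
    # (emulated bf16 is excluded via including_emulation=False).
    dtypes = [torch.float32]
    if torch.cuda.is_bf16_supported(including_emulation=False):
        dtypes.append(torch.bfloat16)
    return dtypes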
@@ -2012,7 +2038,7 @@ def forward(self, arg0_1, arg1_1):
         x = torch.rand(4, device=GPU_TYPE)
         prev = x.clone()

-        with torch._inductor.config.patch(
+        with inductor_config.patch(
             {"triton.autotune_at_compile_time": autotune_at_compile_time}
         ):
             f(x)
@@ -2042,7 +2068,7 @@ def forward(self, arg0_1, arg1_1):
         elif cfg == "cpp_wrapper":
             config_kwargs = {"cpp_wrapper": True}

-        with torch._inductor.config.patch(**config_kwargs):
+        with inductor_config.patch(**config_kwargs):

             @triton.jit
             def _triton_kernel(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
@@ -2113,7 +2139,7 @@ def forward(self, arg0_1, arg1_1):
         torch._dynamo.mark_dynamic(inp, 0)

         fn_c = torch.compile(fn, fullgraph=True)
-        with torch._dynamo.config.patch(capture_scalar_outputs=True):
+        with dynamo_config.patch(capture_scalar_outputs=True):
             res = fn_c(inp)

         self.assertTrue(((res < 2) & (res >= 0)).all().item())
@@ -3230,7 +3256,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         self.assertNotIn(opname, code)

     @requires_gpu
-    @patch.object(torch._dynamo.config, "recompile_limit", 1)
+    @dynamo_config.patch("recompile_limit", 1)
     def test_triton_dynamic_grid_no_recompile(self):
         libname = "my_cool_namespace"
         opname = "my_triton_operator"
@@ -3410,8 +3436,8 @@ class CustomOpTests(torch._inductor.test_case.TestCase):

     @skipIfWindows(msg="AOTI/Cpp_Wrapper have not enabled on Windows")
     @requires_gpu
-    @patch.object(torch._inductor.config, "cpp_wrapper", True)
-    @patch.object(torch._inductor.config, "triton.autotune_at_compile_time", True)
+    @inductor_config.patch("cpp_wrapper", True)
+    @inductor_config.patch("triton.autotune_at_compile_time", True)
     def test_autotune_unbacked(self):
         import triton
         import triton.language as tl
@@ -3644,7 +3670,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         increment = torch.rand(4, device=GPU_TYPE)

         # during autotuning, x should not change in value
-        with torch._inductor.config.patch(
+        with inductor_config.patch(
             {"triton.autotune_at_compile_time": autotune_at_compile_time}
         ):
             # we will add rand a single time to x
@@ -3957,7 +3983,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         src = torch.empty(N, device=GPU_TYPE)
         dst = torch.zeros(N, device=GPU_TYPE)

-        with torch._inductor.config.patch(
+        with inductor_config.patch(
             {"triton.autotune_at_compile_time": autotune_at_compile_time}
         ):
             compiled_f(dst, src, N=N)
@@ -16,7 +16,11 @@ class TestTritonSyntacticallyValid(TestCase):
         def newtonschulz5(G, steps: int, eps=1e-7):
             assert len(G.shape) == 2
             a, b, c = (3.4445, -4.7750, 2.0315)
-            X = G.bfloat16()
+            X = G.to(
+                torch.bfloat16
+                if torch.cuda.is_bf16_supported(including_emulation=False)
+                else torch.float16
+            )
             X /= X.norm() + eps  # ensure top singular value <= 1
             if G.size(0) > G.size(1):
                 X = X.T
@@ -441,7 +441,6 @@ inline void {{kernel_name}}_kernel(
     int64_t ldc
 ) {
     using Vectorized = at::vec::Vectorized<{{compute_t}}>;
-    using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
     constexpr auto VLEN = Vectorized::size();
     constexpr auto ROWS = BLOCK_M;
     constexpr auto COLS = BLOCK_N / VLEN;
@@ -475,6 +474,7 @@ inline void {{kernel_name}}_kernel(

         if constexpr (row == 0) {
 {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+            using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
             auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
             vb[col] = at::vec::convert<{{compute_t}}>(b);
 {%- elif input2_dtype == torch.int8 %}
@@ -2007,11 +2007,11 @@ if (custom_op_wrapper.get() == NULL) {
     def generate_float_value(self, val):
         assert isinstance(val, float)
         if val == float("inf"):
-            return "std::numeric_limits<float>::infinity()"
+            return "std::numeric_limits<double>::infinity()"
         elif val == float("-inf"):
-            return "-std::numeric_limits<float>::infinity()"
+            return "-std::numeric_limits<double>::infinity()"
         elif math.isnan(val):
-            return "std::numeric_limits<float>::quiet_NaN()"
+            return "std::numeric_limits<double>::quiet_NaN()"
         else:
             return f"{val}"

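For context, the last hunk changes how non-finite Python floats are printed into the generated C++; a standalone sketch mirroring the updated behavior (the function name here is hypothetical):

import math


def cpp_float_literal(val: float) -> str:
    # Non-finite floats have no plain C++ literal, so they are spelled via
    # std::numeric_limits<double> (previously <float>), matching the string
    # the NaN assertions earlier in this diff now search for.
    if val == float("inf"):
        return "std::numeric_limits<double>::infinity()"
    if val == float("-inf"):
        return "-std::numeric_limits<double>::infinity()"
    if math.isnan(val):
        return "std::numeric_limits<double>::quiet_NaN()"
    return f"{val}"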