# Owner(s): ["module: inductor"]
# ruff: noqa: F841
# flake8: noqa: E731
# Skip do not assign a lambda expression, use a def
import functools
import logging

import torch
import torch._dynamo.testing
import torch._inductor.test_case
from torch._dynamo import config as dynamo_config
from torch._higher_order_ops.triton_kernel_wrap import (
    generate_ttir,
    triton_kernel_wrapper_functional,
    triton_kernel_wrapper_mutation,
)
from torch._inductor import config as inductor_config, metrics
from torch._inductor.utils import run_and_get_code, triton_version_uses_attrs_dict
from torch._library import capture_triton
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
    parametrize,
    skipIfRocm,
    skipIfWindows,
    skipIfXpu,
    TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU
from torch.testing._internal.logging_utils import log_settings, logs_to_string

# Defines all the kernels for tests
from torch.testing._internal.triton_utils import *  # noqa: F403
from torch.utils._triton import has_triton_package, has_triton_tma


if HAS_GPU:
    import triton
    from triton import language as tl

    if not TEST_WITH_ROCM:
        if HAS_CUDA:
            try:
                from triton.language.extra.libdevice import (  # @manual
                    fast_dividef,
                    fast_dividef as my_fast_dividef,
                )
            except ImportError:
                from triton.language.extra.cuda.libdevice import (  # @manual
                    fast_dividef,
                    fast_dividef as my_fast_dividef,
                )
        elif HAS_XPU:
            from triton.language.extra.intel.libdevice import (  # @manual
                fast_dividef,
                fast_dividef as my_fast_dividef,
            )

    def _triton_get_ast_equal_to_str(params):
        try:
            from triton.backends.compiler import AttrsDescriptor  # noqa: F401

            return f"'tt.equal_to': {params}"
        except ImportError:
            return f"equal_to_1={params}"

    # Define shared triton constants here.
    CONSTANT_C: tl.constexpr = 4
    STRING_CONSTANT_C: tl.constexpr = "CONSTANT_C"
    BOOL_CONSTANT_C: tl.constexpr = True
    FLOAT_CONSTANT_C = tl.constexpr(3.14)  # intentionally un-annotated


class KernelTests(torch._inductor.test_case.TestCase):
    def _kernel_launched_in_code(self, kernel_name: str, code: str) -> bool:
        if inductor_config.cpp_wrapper:
            return f"launchKernel({kernel_name}" in code
        return f"{kernel_name}.run(" in code
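
    # Note on the helper above: with the default Python wrapper, Inductor
    # launches user-defined kernels via `<kernel_name>.run(...)`, while in
    # cpp_wrapper mode codegen emits a C++ `launchKernel(<kernel_name>, ...)`
    # call instead, so the helper matches a different substring per mode.
    # A hedged sketch of the kind of generated lines it matches (illustrative,
    # not literal Inductor output):
    #
    #   pass_kernel_0.run(buf0, 16, grid=grid(16), stream=stream0)  # python
    #   launchKernel(pass_kernel_0, /*grid and args*/ ...);         # cpp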

    @requires_gpu
    def test_triton_kernel_with_kernel_param(self):
        @triton.jit
        def pass_kernel(kernel):
            pass

        @torch.compile(backend="eager")
        def f(x):
            grid = (x.numel(),)
            pass_kernel[grid](kernel=x)

        t1 = torch.rand(5, device=GPU_TYPE)
        f(t1)
        # No need to assert anything, the goal is to make sure dynamo does
        # not crash

    @requires_gpu
    def test_triton_kernel_higher_order_func(self):
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        add_kernel_id = kernel_side_table.add_kernel(add_kernel)

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)

        torch_add = t1 + t2

        # Test higher order function with mutation
        output = torch.zeros_like(t1)
        n_elements = output.numel()
        constant_args_idx = kernel_side_table.add_constant_args(
            {"n_elements": n_elements, "BLOCK_SIZE": 16}
        )
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        triton_kernel_wrapper_mutation(
            kernel_idx=add_kernel_id,
            constant_args_idx=constant_args_idx,
            grid=[grid],
            tma_descriptor_metadata={},
            kwargs={
                "in_ptr0": t1,
                "in_ptr1": t2,
                "out_ptr": output,
            },
        )
        self.assertEqual(output, torch_add)
        # Make sure it is modified
        self.assertNotEqual(output, torch.zeros_like(t1))

        # Test higher order function without mutation
        output = torch.zeros_like(t1)
        out_dict = triton_kernel_wrapper_functional(
            kernel_idx=add_kernel_id,
            constant_args_idx=constant_args_idx,
            grid=[grid],
            tma_descriptor_metadata={},
            kwargs={
                "in_ptr0": t1,
                "in_ptr1": t2,
                "out_ptr": output,
            },
            tensors_to_clone=["in_ptr0", "in_ptr1", "out_ptr"],
        )
        self.assertEqual(out_dict["out_ptr"], torch_add)
        # Make sure it is NOT modified
        self.assertEqual(output, torch.zeros_like(t1))
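
        # Note: both wrappers above share one calling convention: the kernel
        # and its constant args live in `kernel_side_table` and are referenced
        # by integer index, so only tensor kwargs flow through the graph.
        # A minimal sketch of that convention (hypothetical kernel/tensor
        # names, mirroring the calls above):
        #
        #   idx = kernel_side_table.add_kernel(my_kernel)         # -> int
        #   cargs = kernel_side_table.add_constant_args({...})    # -> int
        #   triton_kernel_wrapper_mutation(
        #       kernel_idx=idx, constant_args_idx=cargs, grid=[grid],
        #       tma_descriptor_metadata={}, kwargs={"ptr": t},
        #   )  # mutates t in place; the functional variant instead returns
        #      # a dict of cloned tensors named in `tensors_to_clone`.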

    @requires_gpu
    def test_triton_kernel_functionalize(self):
        from functorch import make_fx
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
        from torch._subclasses.functional_tensor import (
            CppFunctionalizeAPI,
            FunctionalTensorMode,
            PythonFunctionalizeAPI,
        )

        kernel_side_table.reset_table()

        def f(x, output):
            out = triton_kernel_wrapper_functional(
                kernel_idx=kernel_side_table.add_kernel(mul2_kernel),
                constant_args_idx=kernel_side_table.add_constant_args(
                    {"n_elements": output.numel(), "BLOCK_SIZE": 16}
                ),
                grid=[(x.numel(),)],
                tma_descriptor_metadata={},
                kwargs={
                    "in_ptr0": x,
                    "out_ptr": output,
                },
                tensors_to_clone=["in_ptr0", "out_ptr"],
            )
            return out["out_ptr"]

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)
        with FunctionalTensorMode():
            gm = make_fx(PythonFunctionalizeAPI().functionalize(f))(t1, t2)
        # Make sure t2 was not modified
        self.assertNotEqual(gm(t1, t2), t2)

        gm = make_fx(CppFunctionalizeAPI().functionalize(f))(t1, t2)
        # Make sure t2 was not modified
        self.assertNotEqual(gm(t1, t2), t2)

        gm = make_fx(torch.func.functionalize(f))(t1, t2)
        # Make sure t2 was not modified
        self.assertNotEqual(gm(t1, t2), t2)

        gm = make_fx(f, tracing_mode="fake")(t1, t2)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1, output_1):
    triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(5,)], tma_descriptor_metadata = {}, kwargs = {'in_ptr0': x_1, 'out_ptr': output_1}, tensors_to_clone = ['in_ptr0', 'out_ptr']); x_1 = output_1 = None
    getitem = triton_kernel_wrapper_functional_proxy['in_ptr0']; getitem = None
    getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr']; triton_kernel_wrapper_functional_proxy = None
    return getitem_1""",
        )

    @requires_gpu
    def test_triton_kernel_mutation_type(self):
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
        from torch._subclasses.fake_tensor import FakeTensorMode
        from torch._subclasses.functional_tensor import (
            FunctionalTensor,
            FunctionalTensorMode,
        )

        def prep():
            x = torch.ones(4, device=GPU_TYPE, requires_grad=True)
            with FunctionalTensorMode():
                x_func = FunctionalTensor.to_functional(x)
            self.assertTrue(torch._is_functional_tensor(x_func.elem))
            return x_func

        # normal mutation only
        with FakeTensorMode():
            x_func = prep()

            with FunctionalTensorMode():
                x_func.mul_(2)

            self.assertFalse(
                torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
            )

        # triton kernel mutation only
        with FakeTensorMode():
            x_func = prep()

            with FunctionalTensorMode():
                triton_kernel_wrapper_mutation(
                    kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel),
                    constant_args_idx=kernel_side_table.add_constant_args(
                        {"n_elements": x_func.numel(), "BLOCK_SIZE": 16}
                    ),
                    grid=[(x_func.numel(),)],
                    tma_descriptor_metadata={},
                    kwargs={
                        "ptr": x_func,
                    },
                )

            self.assertTrue(
                torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
            )

        # normal mutation + triton kernel mutation
        with FakeTensorMode():
            x_func = prep()

            with FunctionalTensorMode():
                x_func.mul_(2)
                triton_kernel_wrapper_mutation(
                    kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel),
                    constant_args_idx=kernel_side_table.add_constant_args(
                        {"n_elements": x_func.numel(), "BLOCK_SIZE": 16}
                    ),
                    grid=[(x_func.numel(),)],
                    tma_descriptor_metadata={},
                    kwargs={
                        "ptr": x_func,
                    },
                )

            self.assertFalse(
                torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
            )

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_with_views(self, dynamic, backend):
        def call_triton_take_view(x: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output

        def call_triton_return_view(x: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output.view(4, 4)

        t = torch.rand(4, 4, device=GPU_TYPE)
        t_view = t.view(16)

        compiled_func = torch.compile(
            call_triton_take_view, backend=backend, fullgraph=True, dynamic=dynamic
        )
        self.assertEqual(2 * t_view, compiled_func(t_view))
        self.assertEqual(2 * t, compiled_func(t_view).view(4, 4))

        compiled_func = torch.compile(
            call_triton_return_view, backend=backend, fullgraph=True, dynamic=dynamic
        )
        self.assertEqual(2 * t_view, compiled_func(t).view(16))
        self.assertEqual(2 * t, compiled_func(t))

    @requires_gpu
    def test_no_nan_kernels(self):
        @triton.jit
        def add_one_kernel(
            in_ptr0,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            output = x + 1
            tl.store(out_ptr + offsets, output, mask=mask)

        def add_one(x, out):
            n_elements = x.numel()
            add_one_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)

        class AddOne(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out = torch.empty_like(x)
                add_one(x, out)
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad):
                (saved,) = ctx.saved_tensors
                out = torch.empty_like(grad)
                add_one(saved, out)
                return out

        @torch.compile
        def f(x):
            return AddOne.apply(x)

        log_stream, ctx = logs_to_string("torch._inductor.codecache", "output_code")

        x = torch.randn(3, requires_grad=True, device=GPU_TYPE)
        with ctx():
            y = f(x)

        output_code = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
        self.assertTrue(len(output_code) > 0, msg="output code is empty")
        if inductor_config.cpp_wrapper:
            self.assertEqual(
                output_code.count("std::numeric_limits<double>::quiet_NaN()"), 0
            )
        else:
            self.assertEqual(output_code.count('float("nan")'), 0)
            self.assertEqual(output_code.count("float('nan')"), 0)

    @requires_gpu
    @common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_with_grad_option(self, grad_fn, backend):
        def call_triton(x: torch.Tensor):
            with grad_fn():
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
                mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
                return output

        t = torch.rand(5, device=GPU_TYPE)
        compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
        self.assertEqual(2 * t, compiled_func(t))

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_inner_triton_function(self, backend):
        def f(x: torch.Tensor):
            @triton.jit
            def pow2_kernel(
                in_ptr0,
                out_ptr,
                n_elements,
                BLOCK_SIZE: "tl.constexpr",
            ):
                pid = tl.program_id(axis=0)
                block_start = pid * BLOCK_SIZE
                offsets = block_start + tl.arange(0, BLOCK_SIZE)
                mask = offsets < n_elements
                x = tl.load(in_ptr0 + offsets, mask=mask)
                output = x * x
                tl.store(out_ptr + offsets, output, mask=mask)

            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            pow2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output

        t = torch.rand(5, device=GPU_TYPE)

        compiled_func = torch.compile(f, backend=backend, fullgraph=True)
        # TODO(oulgen): NYI - Support this
        # self.assertEqual(t * t, compiled_func(t))

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @inductor_config.patch("implicit_fallbacks", False)
    def test_triton_kernel_no_clones(self, grad, dynamic):
        from torch._inductor.utils import run_and_get_code

        def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
            n_elements = output.numel()

            tmp = torch.add(x, 1)
            grid = (x.numel(),)
            add_kernel.run(
                x, y, output, n_elements, warmup=False, grid=grid, BLOCK_SIZE=16
            )

            return output, tmp

        t1 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        o1 = torch.zeros_like(t1, requires_grad=grad)

        torch_add = call_triton(t1, t2, o1)
        metrics.reset()
        o2 = torch.zeros_like(t1, requires_grad=grad)
        test, (code,) = run_and_get_code(
            torch.compile(call_triton, dynamic=dynamic), t1, t2, o2
        )
        if not grad:
            self.assertEqual(metrics.generated_kernel_count, 1)
        self.assertEqual(torch_add, test)
        # These two asserts are not optimal since they require the original
        # aten ops to be in the metadata, so there might be false negatives.
        self.assertNotIn(
            "aoti_torch_copy_" if inductor_config.cpp_wrapper else "aten.copy", code
        )
        self.assertNotIn(
            "aoti_torch_clone" if inductor_config.cpp_wrapper else "aten.clone", code
        )
        # The following checks that only the tensor output is in the
        # compiled graph.
        if dynamic and grad:
            if inductor_config.cpp_wrapper:
                self.assertIn("output_handles[0] = ", code)
                self.assertIn("output_handles[1] = ", code)
            else:
                self.assertIn("return (buf0, s0, )", code)
        else:
            self.assertIn(
                "output_handles[0] = "
                if inductor_config.cpp_wrapper
                else "return (buf0, )",
                code,
            )

    @requires_gpu
    def test_triton_kernel_caching(self):
        from torch._inductor.utils import run_and_get_code

        def add_in_loop(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned[grid](x, y, output, n_elements)
            return output

        def call_triton_add(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            for i in range(4):
                x = add_in_loop(x, y)
            return x

        t1 = torch.ones(5, device=GPU_TYPE)
        t2 = torch.ones(5, device=GPU_TYPE)

        test, (code,) = run_and_get_code(torch.compile(call_triton_add), t1, t2)
        self.assertEqual(test, 5 * torch.ones(5, device=GPU_TYPE))
        self.assertTrue("add_kernel_autotuned_1.run" not in code)

    @requires_gpu
    def test_triton_kernel_caching_duplicate(self):
        from torch._inductor.utils import run_and_get_code

        class C:
            @triton.jit
            def pass_kernel(
                in_ptr0,
                out_ptr,
                n_elements,
                BLOCK_SIZE: "tl.constexpr",
            ):
                pid = tl.program_id(axis=0)
                block_start = pid * BLOCK_SIZE
                offsets = block_start + tl.arange(0, BLOCK_SIZE)
                mask = offsets < n_elements
                x = tl.load(in_ptr0 + offsets, mask=mask)
                tl.store(out_ptr + offsets, x, mask=mask)

        class D:
            @triton.jit
            def pass_kernel(
                in_ptr0,
                out_ptr,
                n_elements,
                BLOCK_SIZE: "tl.constexpr",
            ):
                pid = tl.program_id(axis=0)
                block_start = pid * BLOCK_SIZE
                offsets = block_start + tl.arange(0, BLOCK_SIZE)
                mask = offsets < n_elements
                x = tl.load(in_ptr0 + offsets, mask=mask)
                tl.store(out_ptr + offsets, x, mask=mask)

        def call_triton(x: torch.Tensor):
            output1 = torch.zeros_like(x)
            output2 = torch.zeros_like(x)
            n_elements = output1.numel()
            grid = (n_elements,)
            C.pass_kernel[grid](x, output1, n_elements, BLOCK_SIZE=16)
            D.pass_kernel[grid](x, output2, n_elements, BLOCK_SIZE=16)
            return output1 + output2

        t = torch.ones(5, device=GPU_TYPE)
        test, (code,) = run_and_get_code(torch.compile(call_triton), t)
        # Make sure we emitted two kernels here
        self.assertTrue(self._kernel_launched_in_code("pass_kernel_0", code))
        self.assertTrue(self._kernel_launched_in_code("pass_kernel_1", code))

    @requires_gpu
    def test_triton_kernel_various_args(self):
        @triton.autotune(
            configs=[triton.Config({"BLOCK_SIZE": 128})],
            key=[],
        )
        @triton.jit
        def pass_kernel(
            out_ptr,
            n_elements,
            dummy_None,
            dummy_empty,
            dummy_float,
            BLOCK_SIZE: "tl.constexpr",
            RANDOM_SIZE: "tl.constexpr",
        ):
            pass

        @torch.compile
        def call_triton(output):
            n_elements = output.numel()
            grid = (n_elements,)
            pass_kernel[grid](
                output,
                n_elements,
                None,
                torch.empty_like(output),
                3.1415926,
                RANDOM_SIZE=0,
            )
            return output

        output = torch.randn(5, device=GPU_TYPE)
        # Make sure this does not crash
        call_triton(output)

    @requires_gpu
    def test_triton_kernel_dependancies(self):
        def call_triton(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned[grid](x, y, output, n_elements)
            output2 = torch.zeros_like(output)
            add_kernel_autotuned[grid](output, y, output2, n_elements)
            output3 = torch.add(output2, 1)
            return output3

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)
        torch_result = call_triton(t1, t2)
        compiled_result = torch.compile(call_triton)(t1, t2)
        self.assertEqual(torch_result, compiled_result)

    @requires_gpu
    def test_triton_kernel_reinplace_inplaceable_pass(self):
        def call_triton(
            x: torch.Tensor,
            y: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned[grid](x, y, output, n_elements)
            add_kernel_autotuned[grid](output, x, output, n_elements)
            return output

        t1 = torch.rand(5, device=GPU_TYPE)
        t2 = torch.rand(5, device=GPU_TYPE)
        torch_result = call_triton(t1, t2)
        compiled_result = torch.compile(call_triton)(t1, t2)
        self.assertEqual(torch_result, compiled_result)

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    def test_triton_kernel_multi_kernel(self, grad):
        @triton.jit
        def mul2_and_add_and_zero_negatives_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
            ACTIVATION: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            indirection_kernel(
                in_ptr0,
                in_ptr0,
                n_elements,
                BLOCK_SIZE=BLOCK_SIZE,
                ACTIVATION="mul2_inplace_kernel",
            )
            indirection_kernel(
                in_ptr1,
                in_ptr1,
                n_elements,
                BLOCK_SIZE=BLOCK_SIZE,
                ACTIVATION="mul2_inplace_kernel",
            )
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            if ACTIVATION == "zero_negs":
                output = zero_negs(output)
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile
        def call_triton(
            x: torch.Tensor,
            y: torch.Tensor,
            xi: torch.Tensor,
            yi: torch.Tensor,
            output: torch.Tensor,
            outputi: torch.Tensor,
        ):
            n_elements = output.numel()

            grid = (x.numel(),)
            mul2_and_add_and_zero_negatives_kernel[grid](
                x, y, output, n_elements, BLOCK_SIZE=16, ACTIVATION="zero_negs"
            )
            mul2_and_add_and_zero_negatives_kernel[grid](
                xi, yi, outputi, n_elements, BLOCK_SIZE=16, ACTIVATION=None
            )

            return (output, outputi)

        t1 = torch.tensor(
            [-2.0, -1.0, 0.0, 1.0, 2.0], device=GPU_TYPE, requires_grad=grad
        )
        t2 = torch.tensor(
            [-2.0, -1.0, 0.0, 1.0, 2.0], device=GPU_TYPE, requires_grad=grad
        )
        float_result = 2 * t1 + 2 * t2
        float_result = float_result.where(float_result >= 0, 0.0)

        t1i = torch.randint(-2, 2, (5,), device=GPU_TYPE)
        t2i = torch.randint(-2, 2, (5,), device=GPU_TYPE)
        o = torch.zeros_like(t1, requires_grad=grad)
        oi = torch.zeros_like(t1i)
        int_result = 2 * t1i + 2 * t2i

        (result, resulti) = call_triton(t1, t2, t1i, t2i, o, oi)
        self.assertEqual(float_result, result)
        self.assertEqual(int_result, resulti)

    @requires_gpu
    @skipIfXpu
    def test_triton_kernel_constants(self):
        @triton.jit
        def mulC_kernel(
            in_ptr0,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
            CONSTANT_NAME: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            if CONSTANT_NAME == STRING_CONSTANT_C:
                output = CONSTANT_C * x
            if BOOL_CONSTANT_C:
                output *= CONSTANT_C
            tl.store(out_ptr + offsets, output, mask=mask)

        def call_triton(
            x: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()

            grid = (x.numel(),)
            mulC_kernel[grid](
                x, output, n_elements, BLOCK_SIZE=16, CONSTANT_NAME="CONSTANT_C"
            )
            return output

        # Triton kernels capture global constants by their parse-time value,
        # not their runtime value.
        global CONSTANT_C
        prev_c = CONSTANT_C
        # If this behavior of triton kernels changes, this test will fail.
        CONSTANT_C = 10
        assert CONSTANT_C != prev_c

        t = torch.randn(5, device=GPU_TYPE)
        torch_result = call_triton(t)
        compiled_result = torch.compile(call_triton)(t)

        self.assertEqual(torch_result, compiled_result)

        # reset back
        CONSTANT_C = prev_c
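
        # Illustrative sketch of the parse-time capture behavior the test
        # relies on (comments only, hypothetical kernel):
        #
        #   SCALE: tl.constexpr = 4
        #
        #   @triton.jit
        #   def scale_kernel(ptr):
        #       tl.store(ptr, tl.load(ptr) * SCALE)
        #
        # Later reassigning `SCALE = 10` does not change what an
        # already-parsed scale_kernel multiplies by, so eager and compiled
        # runs above both see the original value of CONSTANT_C.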

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("grid_type", [1, 2, 3])
    def test_triton_kernel_autotune(self, grad, dynamic, backend, grid_type):
        def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
            n_elements = output.numel()

            def grid_fn(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            if grid_type == 1:
                grid = (n_elements,)
            elif grid_type == 2:
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            elif grid_type == 3:
                grid = grid_fn

            add_kernel_autotuned[grid](x, y, output, n_elements)
            return output

        t1 = torch.rand(256, device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand(256, device=GPU_TYPE, requires_grad=grad)
        output = torch.zeros_like(t1, requires_grad=grad)

        torch_add = call_triton(t1, t2, output)
        compiled_func = torch.compile(
            call_triton, backend=backend, fullgraph=True, dynamic=dynamic
        )

        output2 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, output2), torch_add)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @inductor_config.patch("unsafe_ignore_unsupported_triton_autotune_args", True)
    def test_triton_kernel_autotune_with_unsupported_args(self, backend):
        def call_triton(x: torch.Tensor, y: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            add_kernel_autotuned_with_unsupported_args[(n_elements,)](
                x, y, output, n_elements
            )
            return output

        t1 = torch.rand(256, device=GPU_TYPE)
        t2 = torch.rand(256, device=GPU_TYPE)

        torch_add = call_triton(t1, t2)
        compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
        compiled_add = compiled_func(t1, t2)
        self.assertEqual(compiled_add, torch_add)

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("grid_type", [1, 2, 3])
    def test_triton_kernel_2d_autotune(self, grad, dynamic, backend, grid_type):
        def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
            x_elements = output.size()[0]
            y_elements = output.size()[1]

            def grid_fn(meta):
                return (
                    triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                    triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                )

            if grid_type == 1:
                grid = (x_elements, y_elements)
            elif grid_type == 2:
                grid = lambda meta: (
                    triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                    triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                )
            elif grid_type == 3:
                grid = grid_fn

            add_kernel_2d_autotuned[grid](x, y, output, x_elements, y_elements)
            return output

        t1 = torch.rand((512, 256), device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand((512, 256), device=GPU_TYPE, requires_grad=grad)
        output = torch.zeros_like(t1, requires_grad=grad)

        torch_result = call_triton(t1, t2, output)
        compiled_func = torch.compile(
            call_triton, backend=backend, fullgraph=True, dynamic=dynamic
        )
        output2 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, output2), torch_result)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_tracing(self, dynamic):
        def call_triton_add(
            x: torch.Tensor,
            y: torch.Tensor,
            grid_type: int,
            num=1,
            positional=False,
            autotuned=False,
        ):
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid_fn(meta):
                return (triton.cdiv(num, meta["BLOCK_SIZE"]),)

            if grid_type == 0:
                grid = (x.numel(),)
            elif grid_type == 1:
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            elif grid_type == 2:
                grid = grid_fn
            else:
                grid = [x.numel()]

            if autotuned:
                capture_triton(add_kernel_autotuned)[grid](x, y, output, n_elements)
            else:
                if positional:
                    capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
                else:
                    capture_triton(add_kernel)[grid](
                        x, y, output, n_elements, BLOCK_SIZE=16
                    )

            return output

        t0 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        t1 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        t2 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        t3 = torch.rand(5, device=GPU_TYPE, requires_grad=True)
        torch_add = t2 + t3

        tests = [
            functools.partial(call_triton_add, grid_type=0),
            functools.partial(call_triton_add, grid_type=1),
            functools.partial(call_triton_add, grid_type=1, num=1, positional=True),
            functools.partial(call_triton_add, grid_type=2, num=200),
            functools.partial(call_triton_add, grid_type=3),
            functools.partial(call_triton_add, grid_type=0, autotuned=True),
            functools.partial(call_triton_add, grid_type=1, num=1, autotuned=True),
            functools.partial(call_triton_add, grid_type=2, num=200, autotuned=True),
            functools.partial(call_triton_add, grid_type=3, autotuned=True),
        ]
        from functorch import make_fx

        tracing_mode = "symbolic" if dynamic else "fake"

        for test in tests:
            gm = make_fx(test, tracing_mode=tracing_mode)(t0, t1)
            result = test(t2, t3)
            self.assertEqual(result, torch_add)

    @requires_gpu
    @common_utils.parametrize("grad", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @inductor_config.patch("implicit_fallbacks", False)
    def test_triton_kernel_native(self, grad, dynamic, backend):
        def call_triton_add(
            x: torch.Tensor,
            y: torch.Tensor,
            output: torch.Tensor,
            grid_type: int,
            num=1,
            positional=False,
        ):
            n_elements = output.numel()

            def grid_fn(meta):
                return (triton.cdiv(num, meta["BLOCK_SIZE"]),)

            if grid_type == 0:
                grid = (x.numel(),)
            elif grid_type == 1:
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            else:
                grid = grid_fn

            if positional:
                add_kernel[grid](x, y, output, n_elements, 16)
            else:
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)

            return output

        t1 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        t2 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
        o1 = torch.zeros_like(t1, requires_grad=grad)

        torch_add = t1 + t2

        # No Dynamo -- Make sure triton kernel works
        self.assertEqual(call_triton_add(t1, t2, o1, 1), torch_add)
        # No Dynamo -- Make sure triton kernel works (with positional BLOCK_SIZE)
        o2 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(call_triton_add(t1, t2, o2, 1, True), torch_add)

        # With Dynamo
        compiled_func = torch.compile(
            call_triton_add, backend=backend, fullgraph=True, dynamic=dynamic
        )
        # With simple kernel
        o3 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o3, 0), torch_add)
        # With lambda kernel
        o4 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o4, 1), torch_add)
        # With lambda kernel (with positional BLOCK_SIZE)
        o5 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o5, 1, 1, True), torch_add)
        # With user defined function kernel
        o6 = torch.zeros_like(t1, requires_grad=grad)
        self.assertEqual(compiled_func(t1, t2, o6, 2, 200), torch_add)

    @requires_gpu
    def test_triton_kernel_mutation_not_mark_dirty(self):
        @torch.compile
        def f(x):
            n_elements = x.numel()
            add_kernel[(n_elements,)](x, x, x, n_elements, 16)
            return x

        x = torch.randn(5, device=GPU_TYPE, requires_grad=True)
        x_cloned = x.clone()
        out = x_cloned.sin()
        f(x_cloned)
        out.sum().backward()

    @requires_gpu
    @inductor_config.patch("allow_buffer_reuse", True)
    def test_triton_kernel_inputs_buffer_reuse(self):
        def _mul2(x):
            y = torch.empty_like(x)
            mul2_kernel[(10,)](
                in_ptr0=x,
                out_ptr=y,
                n_elements=x.numel(),
                BLOCK_SIZE=1,
            )
            return y

        @torch.compile
        def f(x):
            for _ in range(4):
                # The output of one kernel is the input to the next kernel, but
                # at some point we should re-use buffers rather than allocate
                # new ones.
                x = _mul2(x)
            return x + 1

        x = torch.randn(10, device=GPU_TYPE, dtype=torch.float32)
        eager_out = f(x)
        compiled_out, (code,) = run_and_get_code(torch.compile(f), x)
        self.assertEqual(compiled_out, eager_out)

        # Check that we're allocating the minimal # of buffers.
        code_string = (
            "aoti_torch_empty_strided("
            if inductor_config.cpp_wrapper
            else f"empty_strided_{GPU_TYPE}((10, ), (1, ), torch.float32)"
        )
        num_bufs_allocated = code.count(code_string)
        self.assertEqual(num_bufs_allocated, 2)

        # Check we're re-using buffers if not allocating.
        num_bufs_reused = code.count(
            "// reuse" if inductor_config.cpp_wrapper else "# reuse"
        )
        self.assertEqual(num_bufs_reused, 3)
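
        # Why 2 allocations and 3 reuses: the four chained launches plus the
        # final `x + 1` need five output buffers in total, but after the first
        # two are allocated the rest can ping-pong between them (each step
        # reads one buffer and writes the other), giving exactly three "reuse"
        # sites in the generated code. This is an explanatory note on the
        # counts asserted above, not additional test logic.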

    @requires_gpu
    def test_triton_kernel_matmul_tracking(self):
        @triton.jit
        def ones_kernel(x_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = 1.0
            tl.store(x_ptr + offsets, x, mask=mask)

        @torch.compile
        def f(x):
            out = torch.zeros_like(x)
            ones_kernel[(4,)](out, 16, BLOCK_SIZE=16)
            return torch.mm(out, x) + 10

        x = torch.randn(4, 4, device=GPU_TYPE)
        torch_out = f(x)
        python_out = torch.mm(torch.ones(4, 4, device=GPU_TYPE), x) + 10
        self.assertEqual(torch_out, python_out)

    @requires_gpu
    def test_triton_kernel_strided_input(self):
        def f(inp):
            # left has strides [256, 1]
            left, right = torch.split(inp, [128, 128], dim=1)
            out = torch.empty_like(left)
            X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
            grid = (left.size(1) // X_BLOCK_SIZE, left.size(0) // Y_BLOCK_SIZE)
            double_strided_kernel[grid](
                in_ptr=left,
                out_ptr=out,
                in_y_stride=left.stride(0),
                out_y_stride=out.stride(0),
                X_BLOCK_SIZE=X_BLOCK_SIZE,
                Y_BLOCK_SIZE=Y_BLOCK_SIZE,
            )
            return out

        inp = torch.randn(64, 256, device=GPU_TYPE)

        eager_out = f(inp)
        compiled_out = torch.compile(f)(inp)
        self.assertEqual(compiled_out, eager_out)

    @inductor_config.patch(
        triton_kernel_default_layout_constraint="needs_fixed_stride_order"
    )
    @requires_gpu
    def test_layout_constraint_needs_fixed_stride_order(self):
        # Construct a custom op whose output strides are (1, 2)
        @torch.library.custom_op("mylib::weird_op_with_lowering", mutates_args={})
        def weird_op_with_lowering(x: torch.Tensor) -> torch.Tensor:
            return torch.empty_strided((2, 2), (1, 2), dtype=x.dtype, device=x.device)

        @weird_op_with_lowering.register_fake
        def _(x):
            return torch.empty_strided((2, 2), (1, 2), dtype=x.dtype, device=x.device)

        # The lowering for the custom op produces output strides (2, 1).
        from torch._inductor.lowering import empty_strided, register_lowering

        @register_lowering(torch.ops.mylib.weird_op_with_lowering)
        def _(x):
            return empty_strided(
                x.shape, (2, 1), dtype=x.dtype, device=torch.device(GPU_TYPE, 0)
            )

        # Triton kernel that has different behavior depending on the input strides.
        @triton.jit
        def kernel(
            in_ptr0,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            output = offsets
            tl.store(out_ptr + offsets, output, mask=mask)

        def arange_out(x, out):
            n_elements = x.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            kernel[grid](x, out, n_elements, BLOCK_SIZE=4)

        def f(x):
            y = weird_op_with_lowering(x)
            # Inductor lowering will decide that y is better off having strides
            # (2, 1). This is different from the strides at tracing time, (1, 2).
            # Under the "needs_fixed_stride_order" config, inductor will coerce
            # y to have strides (1, 2) before passing it to arange_out.
            # If it doesn't, then the result will be different from eager mode.
            arange_out(x, y)
            return x + y

        x = torch.randn(2, 2, device=GPU_TYPE)
        eager_out = f(x)

        compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True)
        compiled_inductor_out = compiled_inductor_f(x)
        self.assertEqual(compiled_inductor_out, eager_out)

    @requires_gpu
    def test_triton_kernel_strided_input_nonzero_offset(self):
        def f(inp):
            # right has strides [256, 1] and storage offset 128
            left, right = torch.split(inp, [128, 128], dim=1)
            out = torch.empty_like(right)
            X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
            grid = (right.size(1) // X_BLOCK_SIZE, right.size(0) // Y_BLOCK_SIZE)
            double_strided_kernel[grid](
                in_ptr=right,
                out_ptr=out,
                in_y_stride=right.stride(0),
                out_y_stride=out.stride(0),
                X_BLOCK_SIZE=X_BLOCK_SIZE,
                Y_BLOCK_SIZE=Y_BLOCK_SIZE,
            )
            return out

        inp = torch.randn(64, 256, device=GPU_TYPE)

        eager_out = f(inp)
        compiled_out = torch.compile(f)(inp)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_slice_and_view_input(self):
        def f(inp):
            # left has strides [256, 1]
            left = inp[:, :128]
            left = left.view(64, 4, 32)
            out = torch.empty_like(left)
            X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
            grid = (
                (left.size(1) * left.size(2)) // X_BLOCK_SIZE,
                left.size(0) // Y_BLOCK_SIZE,
            )
            double_strided_kernel[grid](
                in_ptr=left,
                out_ptr=out,
                in_y_stride=left.stride(0),
                out_y_stride=out.stride(0),
                X_BLOCK_SIZE=X_BLOCK_SIZE,
                Y_BLOCK_SIZE=Y_BLOCK_SIZE,
            )
            return out + left

        inp = torch.randn(64, 256, device=GPU_TYPE)

        eager_out = f(inp)
        compiled_out = torch.compile(f)(inp)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_fallback(self):
        def f(x, y):
            out = torch.zeros_like(x)
            out2 = torch.zeros_like(x)
            # torch.mm is ExternKernelOut
            add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16)
            # torch.sort creates a fallback kernel and hence MultiOutput
            add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16)
            return out, out2

        x = torch.randn(4, 4, device=GPU_TYPE)
        y = torch.randn(4, 4, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out = torch.compile(f)(x, y)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_out_of_order(self):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            BLOCK_SIZE: "tl.constexpr",
            out_ptr,
            n_elements,
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x, y):
            out = torch.zeros_like(x)
            n_elements = x.numel()
            add_kernel[(n_elements,)](x, y, 4, out, n_elements)
            return out

        x = torch.randn(4, device=GPU_TYPE)
        y = torch.randn(4, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out = torch.compile(f)(x, y)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @dynamo_config.patch(capture_dynamic_output_shape_ops=True)
    @dynamo_config.patch(capture_scalar_outputs=True)
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_unbacked_shape_tensor(self, backend):
        @triton.jit
        def square(
            in_ptr,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr + offsets, mask=mask)
            output = x * x
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x):
            x = x[x > 2]
            n_elements = x.numel()
            output = torch.zeros_like(x)
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            square[grid](x, output, n_elements, BLOCK_SIZE=16)
            return output

        x = torch.randn(4, device=GPU_TYPE)
        eager_out = f(x)
        compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_equal_to_1_arg(self, dynamic):
        @triton.jit
        def add_kernel_half_n_elements(
            in_ptr0,
            in_ptr1,
            out_ptr,
            half_n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < half_n_elements * 2
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x, y):
            out = torch.empty_like(x)
            half_n_elements = x.numel() // 2
            add_kernel_half_n_elements[(half_n_elements,)](
                x, y, out, half_n_elements, BLOCK_SIZE=16
            )
            return out

        x = torch.randn(2, device=GPU_TYPE)
        y = torch.randn(2, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out, sources = run_and_get_code(
            torch.compile(f, dynamic=dynamic), x, y
        )

        if triton_version_uses_attrs_dict():
            self.assertFalse("equal_to" in sources[0])
        else:
            if dynamic:
                # when half_n_elements passed to the Triton kernel is
                # dynamic, equal_to_1 specialization can't be enforced

                # also, equal_to_1 specialization doesn't occur (or appear in
                # the signature) for newer versions of triton (i.e. the ones
                # where triton_version_uses_attrs_dict() == True)
                self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
            else:
                self.assertTrue(_triton_get_ast_equal_to_str((3,)) in sources[0])
        self.assertEqual(compiled_out, eager_out)
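
        # Background note: with x.numel() == 2, half_n_elements == 1, so with
        # static shapes older triton versions specialize that argument as
        # equal-to-1 (argument index 3 in the kernel signature), which is what
        # `_triton_get_ast_equal_to_str((3,))` checks for. Under dynamic
        # shapes the value is a symbol, so no such specialization is possible
        # and the equal_to set is empty.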

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
        def f(x, y):
            out = torch.empty_like(x)
            n_elements = x.numel()
            scaling_factor = (n_elements**0) / 1.0
            add_kernel_with_scaling[(n_elements,)](
                x,
                y,
                out,
                n_elements,
                scaling_factor,
                BLOCK_SIZE=16,
            )
            return out

        x = torch.randn(2, device=GPU_TYPE)
        y = torch.randn(2, device=GPU_TYPE)
        eager_out = f(x, y)
        compiled_out, sources = run_and_get_code(
            torch.compile(f, dynamic=dynamic), x, y
        )

        # float 1.0 (both literal or symbolic)
        # should not be added to equal_to_1
        if not triton_version_uses_attrs_dict():
            self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @skipIfRocm
    def test_triton_kernel_with_imported_symbol(self):
        @triton.jit
        def add_kernel_with_imported_symbol(
            in_ptr,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr + offsets, mask=mask)
            output = fast_dividef(x, 3.14)
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x):
            out = torch.empty_like(x)
            n_elements = x.numel()
            add_kernel_with_imported_symbol[(n_elements,)](
                x, out, n_elements, BLOCK_SIZE=16
            )
            return out

        x = torch.randn(4, device=GPU_TYPE)
        eager_out = f(x)
        compiled_out = torch.compile(f)(x)

        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @skipIfRocm
    def test_triton_kernel_with_imported_symbol_with_custom_name(self):
        @triton.jit
        def add_kernel_with_imported_symbol(
            in_ptr,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr + offsets, mask=mask)
            output = my_fast_dividef(x, 3.14)
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x):
            out = torch.empty_like(x)
            n_elements = x.numel()
            add_kernel_with_imported_symbol[(n_elements,)](
                x, out, n_elements, BLOCK_SIZE=16
            )
            return out

        x = torch.randn(4, device=GPU_TYPE)
        eager_out = f(x)
        compiled_out = torch.compile(f)(x)

        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @common_utils.parametrize("size", [4, 16])
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_different_shapes(self, size, dynamic):
        from torch._inductor.utils import run_and_get_code

        def f(x, y, xx, yy):
            n_elements = x.numel()
            output_1 = torch.zeros_like(x)
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output_1, n_elements, BLOCK_SIZE=4)

            n_elements = xx.numel()
            output_2 = torch.zeros_like(xx)
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](xx, yy, output_2, n_elements, BLOCK_SIZE=4)

            return output_1, output_2

        x = torch.rand(size, device=GPU_TYPE)
        y = torch.rand(size, device=GPU_TYPE)
        xx = torch.rand(size, size, device=GPU_TYPE)
        yy = torch.rand(size, size, device=GPU_TYPE)
        args = [x, y, xx, yy]

        eager_out = f(*args)
        compiled_out, (code,) = run_and_get_code(
            torch.compile(f, fullgraph=True, dynamic=dynamic, backend="inductor"), *args
        )
        if size == 4 and not dynamic:
            # Produce 2 kernels due to divisibility
            self.assertTrue(self._kernel_launched_in_code("add_kernel_0", code))
            self.assertTrue(self._kernel_launched_in_code("add_kernel_1", code))
        else:
            # size == 16 or dynamic
            # Only one kernel
            self.assertTrue(self._kernel_launched_in_code("add_kernel_0", code))
            self.assertFalse(self._kernel_launched_in_code("add_kernel_1", code))

        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_triton_dtype(self, dynamic, backend):
        @triton.jit
        def add_kernel_with_dtype(
            in_ptr0,
            in_ptr1,
            out_ptr,
            dtype: "tl.constexpr",
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask).to(dtype)
            y = tl.load(in_ptr1 + offsets, mask=mask).to(dtype)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x, y, dtype_torch, dtype_triton):
            output = torch.zeros_like(x).to(dtype=dtype_torch)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_with_dtype[grid](
                x, y, output, dtype_triton, n_elements, BLOCK_SIZE=4
            )
            return output

        x = torch.randn(4, device=GPU_TYPE)
        y = torch.randn(4, device=GPU_TYPE)
        args_list = [(x, y, torch.float32, tl.float32)]
        if torch.cuda.is_bf16_supported(including_emulation=False):
            args_list.append((x, y, torch.bfloat16, tl.bfloat16))

        for args in args_list:
            eager_out = f(*args)
            compiled_out = torch.compile(
                f, fullgraph=True, backend=backend, dynamic=dynamic
            )(*args)
            self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_special_kwargs_with_autotune(self, backend):
        @triton.autotune(
            configs=[
                triton.Config({"BLOCK_SIZE": 128}),
                triton.Config({"BLOCK_SIZE": 64}),
            ],
            key=["n_elements"],
        )
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile(fullgraph=True, backend=backend)
        def f(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](
                x,
                y,
                output,
                n_elements,
                num_warps=8,
                num_stages=3,
            )
            return output

        x = torch.randn(4, device=GPU_TYPE)
        f(x, x)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_empty_autotune_config_dict(self, backend):
        @triton.autotune(
            configs=[
                triton.Config({}, num_stages=2),
                triton.Config({}, num_stages=3),
            ],
            key=["n_elements"],
        )
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile(fullgraph=True, backend=backend)
        def f(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](
                x,
                y,
                output,
                n_elements,
                BLOCK_SIZE=128,
            )
            return output

        x = torch.randn(4, device=GPU_TYPE)
        f(x, x)

    @requires_gpu
    @common_utils.parametrize("autotune", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_special_params(self, autotune, backend):
        @triton.jit
        def special_params_kernel(
            in_ptr,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
            num_warps: "tl.constexpr",
            num_stages: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr + offsets, mask=mask)
            output = x * num_stages + num_warps
            tl.store(out_ptr + offsets, output, mask=mask)

        NUM_WARPS = 4
        NUM_STAGES = 3

        if autotune:
            special_params_kernel = triton.autotune(
                configs=[
                    triton.Config(
                        {"BLOCK_SIZE": 128},
                        num_stages=NUM_STAGES,
                        num_warps=NUM_WARPS,
                    ),
                    triton.Config(
                        {"BLOCK_SIZE": 64},
                        num_stages=NUM_STAGES,
                        num_warps=NUM_WARPS,
                    ),
                ],
                key=["n_elements"],
            )(special_params_kernel)
            kwargs = {}
        else:
            kwargs = {
                "BLOCK_SIZE": 128,
                "num_stages": NUM_STAGES,
                "num_warps": NUM_WARPS,
            }

        def f(x):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            special_params_kernel[grid](
                x,
                output,
                n_elements,
                **kwargs,
            )
            return output

        x = torch.randn(4, device=GPU_TYPE)
        eager_out = f(x)
        compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
        expected_out = x * NUM_STAGES + NUM_WARPS
        self.assertEqual(eager_out, expected_out)
        self.assertEqual(compiled_out, expected_out)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_multiple_outputs(self, dynamic, backend):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            out_ptr2,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)
            tl.store(out_ptr2 + offsets, output + 1, mask=mask)

        @torch.compile(fullgraph=True, backend=backend, dynamic=dynamic)
        def f(x, y, z):
            output = torch.empty_like(x)
            output2 = torch.empty_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output, output2, n_elements, BLOCK_SIZE=16)
            # The z return is intentional: we're testing training
            return output, output2, z**2

        x = torch.randn(3, requires_grad=True, device=GPU_TYPE)
        y = torch.randn(3, requires_grad=True, device=GPU_TYPE)
        z = torch.randn(3, requires_grad=True, device=GPU_TYPE)
        out, out2, out3 = f(x, y, z)
        self.assertEqual(out, x + y)
        self.assertEqual(out2, x + y + 1)
        self.assertEqual(out3, z**2)

    @requires_gpu
    @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
    @common_utils.parametrize("dynamic", [False, True])
    def test_tma_capture_and_functionalize(self, dynamic):
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        kernel_side_table.reset_table()

        def f(a, b):
            BLOCK_SIZE = 256
            out = torch.zeros_like(a)
            n_elements = out.numel()

            desc_a, desc_b, desc_out = (
                triton.tools.experimental_descriptor.create_1d_tma_descriptor(
                    t.data_ptr(),
                    n_elements,
                    BLOCK_SIZE,
                    t.element_size(),
                )
                for t in (a, b, out)
            )

            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_with_tma_1d[grid](
                desc_a,
                desc_b,
                desc_out,
                BLOCK_SIZE=BLOCK_SIZE,
            )

            return out

        a = torch.randn(301, device=GPU_TYPE)
        b = torch.randn(301, device=GPU_TYPE)

        backend = torch._dynamo.testing.AotEagerAndRecordGraphs()
        torch.compile(
            f,
            fullgraph=True,
            backend=backend,
            dynamic=dynamic,
        )(a, b)

        if dynamic:
            self.assertExpectedInline(
                backend.fw_graphs[0].code.strip(),
                """\
def forward(self, arg0_1, arg1_1, arg2_1):
    zeros_like = torch.ops.aten.zeros_like.default(arg1_1, pin_memory = False)
    add_2 = arg0_1 + 256
    sub_1 = add_2 - 1; add_2 = None
    floordiv = sub_1 // 256; sub_1 = None
    triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(floordiv, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ([arg0_1], [256], 4), 'in_desc_ptr1': ([arg0_1], [256], 4), 'out_desc_ptr': ([arg0_1], [256], 4)}, kwargs = {'in_desc_ptr0': arg1_1, 'in_desc_ptr1': arg2_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); floordiv = arg0_1 = arg1_1 = arg2_1 = zeros_like = None
    getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None
    return (getitem,)""",
            )
        else:
            self.assertExpectedInline(
                backend.fw_graphs[0].code.strip(),
                """\
def forward(self, arg0_1, arg1_1):
    zeros_like = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False)
    triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(2, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ([301], [256], 4), 'in_desc_ptr1': ([301], [256], 4), 'out_desc_ptr': ([301], [256], 4)}, kwargs = {'in_desc_ptr0': arg0_1, 'in_desc_ptr1': arg1_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); arg0_1 = arg1_1 = zeros_like = None
    getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None
    return (getitem,)""",
            )
|
|
|
|
    @requires_gpu
    @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
    @common_utils.parametrize("after_data_ptr", [False, True])
    @common_utils.parametrize("after_create_desc", [False, True])
    def test_tma_graph_breaks(self, after_data_ptr, after_create_desc):
        def f(a, b):
            BLOCK_SIZE = 256
            out = torch.zeros_like(a)
            n_elements = out.numel()

            if after_data_ptr:
                torch._dynamo.graph_break()

            descs = [
                triton.tools.experimental_descriptor.create_1d_tma_descriptor(
                    t.data_ptr(),
                    n_elements,
                    BLOCK_SIZE,
                    t.element_size(),
                )
                for t in (a, b, out)
            ]

            if after_create_desc:
                torch._dynamo.graph_break()

            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_with_tma_1d[grid](
                *descs,
                BLOCK_SIZE=BLOCK_SIZE,
            )

            return out

        a = torch.randn(301, device=GPU_TYPE)
        b = torch.randn(301, device=GPU_TYPE)

        expected_out = a + b
        eager_out = f(a, b)
        compiled_out = torch.compile(
            f,
            fullgraph=False,
            backend="eager",
            dynamic=False,
        )(a, b)

        self.assertEqual(eager_out, expected_out)
        self.assertEqual(compiled_out, expected_out)

    @requires_gpu
    @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_tma_descriptor_1d(self, dynamic, backend):
        def f(a, b):
            BLOCK_SIZE = 256
            out = torch.zeros_like(a)
            n_elements = out.numel()

            desc_a, desc_b, desc_out = (
                triton.tools.experimental_descriptor.create_1d_tma_descriptor(
                    t.data_ptr(),
                    n_elements,
                    BLOCK_SIZE,
                    t.element_size(),
                )
                for t in (a, b, out)
            )

            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_with_tma_1d[grid](
                desc_a,
                desc_b,
                desc_out,
                BLOCK_SIZE=BLOCK_SIZE,
            )

            return out

        a = torch.randn(301, device=GPU_TYPE)
        b = torch.randn(301, device=GPU_TYPE)

        expected_out = a + b
        eager_out = f(a, b)
        compiled_out = torch.compile(
            f,
            fullgraph=True,
            backend=backend,
            dynamic=dynamic,
        )(a, b)

        self.assertEqual(eager_out, expected_out)
        self.assertEqual(compiled_out, expected_out)

    @requires_gpu
    @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
    def test_tma_descriptor_dedup(self):
        def f(a):
            BLOCK_SIZE = 256
            out = torch.zeros_like(a)
            n_elements = out.numel()

            desc_a, desc_out = (
                triton.tools.experimental_descriptor.create_1d_tma_descriptor(
                    t.data_ptr(),
                    n_elements,
                    BLOCK_SIZE,
                    t.element_size(),
                )
                for t in (a, out)
            )

            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
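            # desc_a is passed for both inputs; inductor should create it once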
            add_kernel_with_tma_1d[grid](
                desc_a,
                desc_a,
                desc_out,
                BLOCK_SIZE=BLOCK_SIZE,
            )

            return out

        a = torch.randn(301, device=GPU_TYPE)

        expected_out = a + a
        eager_out = f(a)
        compiled_out, (code,) = run_and_get_code(
            torch.compile(
                f,
                fullgraph=True,
                backend="inductor",
                dynamic=True,
            ),
            a,
        )

        self.assertEqual(eager_out, expected_out)
        self.assertEqual(compiled_out, expected_out)

        # 2 calls: one for the two (deduped) inputs, one for the output
        self.assertEqual(code.count("create_1d_tma_descriptor("), 2)

    @requires_gpu
    @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager"])
    def test_tma_descriptor_2d(self, dynamic, backend):
        def f(a, b):
            BLOCK_SIZE_X = 16
            BLOCK_SIZE_Y = 32
            out = torch.zeros_like(a)
            x_size, y_size = out.size()

            desc_a, desc_b, desc_out = (
                triton.tools.experimental_descriptor.create_2d_tma_descriptor(
                    t.data_ptr(),
                    x_size,
                    y_size,
                    BLOCK_SIZE_X,
                    BLOCK_SIZE_Y,
                    t.element_size(),
                )
                for t in (a, b, out)
            )

            grid = lambda meta: (
                triton.cdiv(x_size, meta["BLOCK_SIZE_X"]),
                triton.cdiv(y_size, meta["BLOCK_SIZE_Y"]),
            )
            add_kernel_with_tma_2d[grid](
                desc_a,
                desc_b,
                desc_out,
                BLOCK_SIZE_X=BLOCK_SIZE_X,
                BLOCK_SIZE_Y=BLOCK_SIZE_Y,
            )

            return out

        a = torch.randn((25, 16), device=GPU_TYPE)
        b = torch.randn((25, 16), device=GPU_TYPE)

        expected_out = a + b
        eager_out = f(a, b)
        compiled_out = torch.compile(
            f,
            fullgraph=True,
            backend=backend,
            dynamic=dynamic,
        )(a, b)

        self.assertEqual(eager_out, expected_out)
        self.assertEqual(compiled_out, expected_out)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_num_ctas(self, backend):
        @triton.jit
        def kernel(X):
            return

        @torch.compile(fullgraph=True, backend=backend)
        def f(x):
            kernel[(1,)](x, num_ctas=1)
            kernel.run(x, num_ctas=1, grid=(1,), warmup=False)
            return x

        msg = "Passing num_ctas directly to the Triton kernel is not supported. Please use a Config in @triton.autotune instead."
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            x = torch.randn(4, device=GPU_TYPE)
            f(x)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_special_kwargs_without_autotune(self, backend):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile(fullgraph=True, backend=backend)
        def f(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](
                x,
                y,
                output,
                n_elements,
                BLOCK_SIZE=128,
                num_warps=8,
                num_stages=3,
            )
            return output

        x = torch.randn(4, device=GPU_TYPE)
        f(x, x)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("autotune_at_compile_time", [True, False])
    def test_triton_kernel_restore_value(self, backend, autotune_at_compile_time):
        if autotune_at_compile_time and backend != "inductor":
            raise unittest.SkipTest("compile-time autotuning only exists in inductor")
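
        # restore_value tells the autotuner to snapshot in_ptr0 and restore it
        # after each benchmarked config, since the kernel mutates it in place.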
        @triton.autotune(
            configs=[
                triton.Config({"BLOCK_SIZE": 16}, num_stages=3, num_warps=8),
                triton.Config({"BLOCK_SIZE": 32}, num_stages=3, num_warps=8),
            ],
            key=[],
            restore_value=["in_ptr0"],
        )
        @triton.jit
        def increment_kernel(
            in_ptr0,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            output = x + 1
            tl.store(in_ptr0 + offsets, output, mask=mask)

        @torch.compile(fullgraph=True, backend=backend)
        def f(x):
            n_elements = x.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            increment_kernel[grid](x, n_elements=n_elements)
            return x

        x = torch.rand(4, device=GPU_TYPE)
        prev = x.clone()

        with inductor_config.patch(
            {"triton.autotune_at_compile_time": autotune_at_compile_time}
        ):
            f(x)

        # make sure x was restored after autotuning
        torch.testing.assert_close(x, prev + 1)

    @requires_gpu
    @parametrize("dtype", (torch.float16, torch.float32, torch.float64))
    def test_triton_kernel_float64_constant(self, dtype):
        def f(x):
            return x * (0.12 * x.shape[0])

        x = torch.ones(200, device=GPU_TYPE, dtype=dtype)

        eager_out = f(x)
        compiled_out = torch.compile(f, dynamic=True)(x)
        self.assertEqual(compiled_out, eager_out)

    # TODO enable this test case on XPU.
    @requires_cuda
    @parametrize("cfg", ["normal", "cpp_wrapper"])
    def test_triton_kernel_dtype_view(self, cfg):
        # https://github.com/pytorch/pytorch/issues/136159
        if cfg == "normal":
            config_kwargs = {"cpp_wrapper": False}
        elif cfg == "cpp_wrapper":
            config_kwargs = {"cpp_wrapper": True}

        with inductor_config.patch(**config_kwargs):

            @triton.jit
            def _triton_kernel(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
                pid = tl.program_id(0)
                offsets = BLOCK_SIZE * pid + tl.arange(0, BLOCK_SIZE)
                mask = offsets < numel
                ones = tl.full((BLOCK_SIZE,), 1, tl.float16)
                tl.store(out_ptr + offsets, ones, mask)

            def fn(x):
                buf = torch.empty(x.shape, device=x.device, dtype=torch.float16)
                # the buf.view() should be a view sharing the same storage as buf.
                bfloat_buf = buf.view(dtype=torch.bfloat16)
                BLOCK_SIZE = 256
                numel = buf.numel()
                grid = (triton.cdiv(numel, BLOCK_SIZE),)
                _triton_kernel[grid](bfloat_buf, numel, BLOCK_SIZE)
                return buf, bfloat_buf

            fn_c = torch.compile(fn)

            x = torch.randn(8, device=GPU_TYPE)
            out_c = fn_c(x)
            out_e = fn(x)

            # expect view() to be an actual view, sharing the same data as the original buffer
            # verify first that this is true in the eager output
            self.assertEqual(out_e[0].data_ptr(), out_e[1].data_ptr())
            # .. and also in the compiled output
            self.assertEqual(out_c[0].data_ptr(), out_c[1].data_ptr())

            self.assertEqual(out_e[0], out_c[0])
            self.assertEqual(out_e[1], out_c[1])

    # TODO enable this test case on XPU.
    @requires_gpu
    def test_i64_input(self):
        # The i64 "seed" input needs to be marked as "i64", not "i32".
        @triton.jit
        def triton_add_noise_(x_ptr, y_ptr, seed, numel, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(0)
            offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

            x = tl.load(x_ptr + offsets, mask=(offsets < numel))
            rnd = tl.rand(seed, offsets)
            res = x + rnd
            tl.store(y_ptr + offsets, res, mask=(offsets < numel))

        def add_noise(x, seed):
            y = torch.empty_like(x)
            numel = x.numel()
            BLOCK_SIZE = 256

            def grid(meta):
                return (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

            triton_add_noise_[grid](x, y, seed, numel, BLOCK_SIZE)
            return y

        def fn(x):
            x = x * x
            seed = torch.randint(
                low=2**32, high=2**62, size=(1,), dtype=torch.int64
            ).item()
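            # the seed is sampled above 2**32, so it only round-trips as an i64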
            return add_noise(x, seed)

        inp = torch.rand(400, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(inp, 0)

        fn_c = torch.compile(fn, fullgraph=True)
        with dynamo_config.patch(capture_scalar_outputs=True):
            res = fn_c(inp)

        self.assertTrue(((res < 2) & (res >= 0)).all().item())

    @requires_gpu
    @parametrize("wrapped", [False, True])
    @parametrize("autotune", [False, True])
    def test_constexpr_dynamic_shapes(self, wrapped, autotune):
        # https://github.com/pytorch/pytorch/issues/136504
        @triton.jit
        def triton_(
            x_ptr,
            y_ptr,
            NUMEL: tl.constexpr,
            IS_ODD: tl.constexpr,
            BLOCK_SIZE: tl.constexpr,
        ):
            pid = tl.program_id(0)
            offsets = BLOCK_SIZE * pid + tl.arange(0, BLOCK_SIZE)
            mask = offsets < NUMEL

            data = tl.load(x_ptr + offsets, mask)
            result = data * data
            if IS_ODD:
                result = result + 1

            tl.store(y_ptr + offsets, result, mask)

        if autotune:
            triton_ = triton.autotune(
                [
                    triton.Config(kwargs={"BLOCK_SIZE": 128}),
                    triton.Config(kwargs={"BLOCK_SIZE": 256}),
                ],
                key=[],
            )(triton_)

        def triton_kernel_impl(x: torch.Tensor) -> torch.Tensor:
            y = torch.empty_like(x)
            numel = x.numel()

            args = [x, y, numel, numel % 2 == 0]
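            # the fourth arg feeds IS_ODD; it is False for the odd sizes below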
            if not autotune:
                args.append(256)  # BLOCK_SIZE

            def grid(meta):
                return (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

            if wrapped:
                capture_triton(triton_)[grid](*args)
            else:
                triton_[grid](*args)
            return y

        if wrapped:
            triton_kernel = torch.library.triton_op(
                "constexpr_test::square", triton_kernel_impl, mutates_args={}
            )
        else:
            triton_kernel = triton_kernel_impl

        def fn(x):
            return triton_kernel(x)

        fn_c = torch.compile(fn, dynamic=True)

        x = torch.randn(512 + 5, device=GPU_TYPE)
        res = fn_c(x)
        self.assertEqual(x * x, res)

        x2 = torch.randn(1024 + 5, device=GPU_TYPE)
        res2 = fn_c(x2)
        self.assertEqual(x2 * x2, res2)

    @requires_gpu
    def test_triton_kernel_none_args(self):
        # https://github.com/pytorch/pytorch/issues/115344
        @triton.autotune(
            configs=[
                triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2),
                triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
            ],
            key=["n_elements"],
        )
        @triton.jit
        def sin_kernel(
            in_ptr0,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            if in_ptr0 is not None:
                x = tl.load(in_ptr0 + offsets, mask=mask)
            else:
                x = 0.0
            output = tl.sin(x)
            tl.store(out_ptr + offsets, output, mask=mask)

        def sin_triton(x, out):
            n_elements = out.numel()
            sin_kernel[(n_elements,)](x, out, n_elements)

        x = torch.randn(65, device=GPU_TYPE)
        out = torch.empty_like(x)
        out_compiled = torch.empty_like(x)
        sin_triton_compiled = torch.compile(fullgraph=True)(sin_triton)

        sin_triton(x, out)
        sin_triton_compiled(x, out_compiled)
        self.assertEqual(out, out_compiled)

        sin_triton(None, out)
        sin_triton_compiled(None, out_compiled)
        self.assertEqual(out, out_compiled)

    @requires_gpu
    def test_triton_kernel_global_constexpr(self):
        @triton.jit
        def triton_(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(0)
            offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
            x = tl.load(in_ptr + offsets)
            output = x + FLOAT_CONSTANT_C
            tl.store(out_ptr + offsets, output)

        def fn(x):
            y = torch.empty_like(x)
            BLOCK_SIZE = 256
            grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)
            triton_[grid](x, y, BLOCK_SIZE)
            return y

        # make sure FLOAT_CONSTANT_C is NOT annotated
        self.assertFalse("FLOAT_CONSTANT_C" in globals().get("__annotations__", {}))
        # sanity check: STRING_CONSTANT_C _should_ be annotated
        self.assertTrue("STRING_CONSTANT_C" in globals().get("__annotations__", {}))

        x = torch.randn(512, device=GPU_TYPE)
        expected = x + 3.14
        actual = torch.compile(fn)(x)
        self.assertEqual(expected, actual)

    @requires_gpu
    @unittest.skipIf(
        not triton_version_uses_attrs_dict(),
        "Test is only valid for new triton versions where attrs is represented by a raw dict",
    )
    def test_triton_attrs_dict_equal_1_None_format(self):
        @triton.jit
        def triton_(in_ptr, out_ptr, numel, add_amount, BLOCK_SIZE: tl.constexpr):
            offsets = tl.arange(0, BLOCK_SIZE)
            x = tl.load(in_ptr + offsets, mask=(offsets < numel))
            output = x * x
            if add_amount is not None:
                output = output + add_amount
            tl.store(out_ptr + offsets, output, mask=(offsets < numel))

        def fn(x):
            y = torch.empty_like(x)
            BLOCK_SIZE = 256
            grid = (1,)
            triton_[grid](x, y, x.numel(), None, BLOCK_SIZE)
            return y

        x = torch.full((1,), 2.5, device=GPU_TYPE)
        expected = fn(x)

        fn_c = torch.compile(fn)
        res, code = run_and_get_code(fn_c, x)
        self.assertEqual(expected, res)
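
        # numel (== 1), add_amount (None), and BLOCK_SIZE should all be
        # specialized into 'constants' and marked 'constexpr' in the signature.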
        FileCheck().check("triton_meta=").check("'constants':").check("'numel': 1").run(
            code[0]
        )
        FileCheck().check("triton_meta=").check("'constants':").check(
            "'add_amount': None"
        ).run(code[0])
        FileCheck().check("triton_meta=").check("'constants':").check(
            "'BLOCK_SIZE': 256"
        ).run(code[0])

        FileCheck().check("triton_meta=").check("'signature':").check(
            "'numel': 'constexpr'"
        ).run(code[0])
        FileCheck().check("triton_meta=").check("'signature':").check(
            "'add_amount': 'constexpr'"
        ).run(code[0])
        FileCheck().check("triton_meta=").check("'signature':").check(
            "'BLOCK_SIZE': 'constexpr'"
        ).run(code[0])


def make_mutation_test(fn):
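    """Build a test case that checks Triton mutation analysis.

    `fn` must return a (kernel, kwargs, expected) triple; the generated
    test asserts that identify_mutated_tensors(kernel, kwargs) reports
    exactly the tensors named in `expected`.
    """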
    @requires_gpu
    def test_fn(self):
        from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors

        kernel, inputs, outputs = fn()
        self.assertListEqual(
            identify_mutated_tensors(kernel, inputs),
            outputs,
        )

    return test_fn


# Triton codegen suffers from scoping issues.
# Define helpers here
if HAS_GPU:

    @triton.jit
    def helper_id(p):
        return p

    @triton.jit
    def helper_add_and_out(x, y, out_ptr):
        return x + y, out_ptr


class MutationTests(torch._inductor.test_case.TestCase):
    # Tests injected below

    @make_mutation_test
    def test_out_of_order_kernel():
        @triton.jit
        def add_kernel_out_of_order(
            in_ptr0,
            n_elements,
            in_ptr1,
            out_ptr,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_kernel_out_of_order,
            {
                "in_ptr0": t,
                "n_elements": 4,
                "in_ptr1": t,
                "out_ptr": t,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_out_of_order_kernel_call():
        @triton.jit
        def add_kernel_out_of_order_fn1(
            in_ptr0,
            n_elements,
            in_ptr1,
            out_ptr,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            add_kernel_out_of_order_fn2(
                in_ptr0, in_ptr1, n_elements, out_ptr, BLOCK_SIZE=BLOCK_SIZE
            )

        t = torch.randn(4)
        return (
            add_kernel_out_of_order_fn1,
            {
                "in_ptr0": t,
                "n_elements": 4,
                "in_ptr1": t,
                "out_ptr": t,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_reduce_sum():
        @triton.jit
        def reduce_sum_kernel(a_ptr, c_ptr, stride_am, stride_an):
            offs_am = tl.arange(0, 4)
            offs_an = tl.arange(0, 4)
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_an[None, :] * stride_an
            )
            a = tl.load(a_ptrs)
            m = tl.sum(a, axis=1)
            tl.store(c_ptr + tl.arange(0, 4), m)

        t = torch.randn(4)
        kernel = reduce_sum_kernel
        kwargs = {
            "a_ptr": t,
            "c_ptr": t,
            "stride_am": 4,
            "stride_an": 4,
        }

        # TODO(aakhundov): tt.reduce is now supported, but only
        # in the new MLIR-based Triton analysis pass (not in the
        # old TTIR string parsing-based one). remove this gating
        # and use ["c_ptr"] as `expected` after the new Triton
        # pin lands both in OSS and internally.
        ttir_module, _ = generate_ttir(kernel, kwargs)
        if hasattr(ttir_module, "walk"):
            # with MLIR-based Triton analysis pass
            expected = ["c_ptr"]
        else:
            # with TTIR string parsing-based Triton analysis pass
            expected = ["a_ptr", "c_ptr"]

        return (
            kernel,
            kwargs,
            expected,
        )

    @make_mutation_test
    def test_argmax():
        @triton.jit
        def argmax_kernel(a_ptr, c_ptr, stride_am, stride_an):
            offs_am = tl.arange(0, 4)
            offs_an = tl.arange(0, 4)
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_an[None, :] * stride_an
            )
            a = tl.load(a_ptrs)
            m = tl.argmax(a, axis=1)
            tl.store(c_ptr + tl.arange(0, 4), m)

        t = torch.randn(4)
        kernel = argmax_kernel
        kwargs = {
            "a_ptr": t,
            "c_ptr": t,
            "stride_am": 4,
            "stride_an": 4,
        }

        # TODO(aakhundov): tt.reduce is now supported, but only
        # in the new MLIR-based Triton analysis pass (not in the
        # old TTIR string parsing-based one). remove this gating
        # and use ["c_ptr"] as `expected` after the new Triton
        # pin lands both in OSS and internally.
        ttir_module, _ = generate_ttir(kernel, kwargs)
        if hasattr(ttir_module, "walk"):
            # with MLIR-based Triton analysis pass
            expected = ["c_ptr"]
        else:
            # with TTIR string parsing-based Triton analysis pass
            expected = ["a_ptr", "c_ptr"]

        return (
            kernel,
            kwargs,
            expected,
        )

    @requires_gpu
    def test_triton_kernel_inference_mode(self):
        def f(x, y, out):
            n_elements = x.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=4)

        with torch.inference_mode():
            x = torch.ones(32, device=GPU_TYPE)
            y = torch.ones(32, device=GPU_TYPE)
            out_ref = torch.zeros_like(x)
            out_test = torch.zeros_like(x)
            f(x, y, out_ref)
            torch.compile(f)(x, y, out_test)
            self.assertEqual(out_ref, out_test)

    @make_mutation_test
    def test_cumsum():
        @triton.jit
        def cumsum_kernel(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
            rindex = tl.arange(0, RBLOCK)[None, :]
            xindex = tl.arange(0, XBLOCK)[:, None]
            data = tl.load(in_ptr + rindex)
            scan = tl.cumsum(data, 1)
            expected_max = tl.sum(data, 1)
            tl.device_assert(scan <= expected_max)
            tl.store(out_ptr + xindex * RBLOCK + rindex, scan)

        t = torch.randn(4)
        kernel = cumsum_kernel
        kwargs = {
            "in_ptr": t,
            "out_ptr": t,
            "XBLOCK": 4,
            "RBLOCK": 16,
        }

        # TODO(aakhundov): tt.scan is now supported, but only
        # in the new MLIR-based Triton analysis pass (not in the
        # old TTIR string parsing-based one). remove this gating
        # and use ["out_ptr"] as `expected` after the new Triton
        # pin lands both in OSS and internally.
        ttir_module, _ = generate_ttir(kernel, kwargs)
        if hasattr(ttir_module, "walk"):
            # with MLIR-based Triton analysis pass
            expected = ["out_ptr"]
        else:
            # with TTIR string parsing-based Triton analysis pass
            expected = ["in_ptr", "out_ptr"]

        return (
            kernel,
            kwargs,
            expected,
        )

    @make_mutation_test
    def test_fn_call_one_return():
        @triton.jit
        def add_kernel_with_fn_call(
            in_ptr0,
            in_ptr1,
            n_elements,
            out_ptr,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            out = helper_id(out_ptr)
            tl.store(out + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_kernel_with_fn_call,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "n_elements": 4,
                "out_ptr": t,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_fn_call_multi_return():
        @triton.jit
        def add_kernel_with_fn_call(
            in_ptr0,
            in_ptr1,
            n_elements,
            out_ptr,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output, out = helper_add_and_out(x, y, out_ptr)
            tl.store(out + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_kernel_with_fn_call,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "n_elements": 4,
                "out_ptr": t,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_nested_cond_op_kernel():
        @triton.jit
        def nested_cond_op_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            if tl.program_id(0) == 0:
                if tl.program_id(1) == 0:
                    output = x + y
                    tl.store(out_ptr + offsets, output, mask=mask)
            else:
                pass

        t = torch.randn(4)
        return (
            nested_cond_op_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_add_for_loop():
        @triton.jit
        def add_4_times_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = tl.zeros((n_elements,), dtype=tl.float32)
            for i in range(4):
                output += x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_4_times_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_add_for_loop2():
        @triton.jit
        def add_1_time_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            for i in range(0, BLOCK_SIZE):
                i = tl.multiple_of(i, 1)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_1_time_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_add_nested_for_loop():
        @triton.jit
        def add_4_times_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = tl.zeros((n_elements,), dtype=tl.float32)
            for i in range(2):
                for j in range(2):
                    output += x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_4_times_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_add_nested_for_loop_multi_return():
        @triton.jit
        def add_4_times_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output1 = tl.zeros((n_elements,), dtype=tl.float32)
            output2 = tl.zeros((n_elements,), dtype=tl.float32)
            for i in range(2):
                for j in range(2):
                    output1 += y
                    output2 += x
            output = output1 + output2
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            add_4_times_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_labels():
        @triton.jit
        def kernel_with_label(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            if pid > 1:
                return
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        t = torch.randn(4)
        return (
            kernel_with_label,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        )

    @make_mutation_test
    def test_for_loop_arg():
        @triton.jit
        def fwd_kernel(
            X_ptr,
            W1_ptr,
            b1_ptr,
            O_ptr,
            M: tl.constexpr,
            C1: tl.constexpr,
            C2: tl.constexpr,
            BLOCK_SIZE_M: tl.constexpr,
            BLOCK_SIZE_C2: tl.constexpr,
        ):
            # Get program ids
            pid_m = tl.program_id(0)

            # Compute offsets
            offs_c1 = tl.arange(0, C1)
            offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

            # Load input data
            x_block_ptr = X_ptr + offs_m[:, None] * C1 + offs_c1[None, :]
            x = tl.load(x_block_ptr)

            # Compute gating
            for c2 in range(0, tl.cdiv(C2, BLOCK_SIZE_C2)):
                # Compute block pointers
                offs_c2 = c2 * BLOCK_SIZE_C2 + tl.arange(0, BLOCK_SIZE_C2)
                o_block_ptr = O_ptr + offs_m[:, None] * C2 + offs_c2[None, :]
                w1_block_ptr = W1_ptr + offs_c1[:, None] * C2 + offs_c2[None, :]
                b1_block_ptr = b1_ptr + offs_c2

                # Compute output
                w = tl.load(w1_block_ptr)
                b = tl.load(b1_block_ptr)
                o = tl.dot(x, w, allow_tf32=False)
                o += b[None, :]

                # Store output
                tl.store(o_block_ptr, o)

        t = torch.randn(64)
        return (
            fwd_kernel,
            {
                "X_ptr": t,
                "W1_ptr": t,
                "b1_ptr": t,
                "O_ptr": t,
                "M": 64,
                "C1": 64,
                "C2": 64,
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_C2": 64,
            },
            ["O_ptr"],
        )

    @make_mutation_test
    def test_for_loop_arg_2():
        @triton.jit
        def fwd_kernel(
            x_ptr,
            o_ptr,
            M,
            N,
            stride_m,
            stride_n,
            BLOCK_B: tl.constexpr,
            BLOCK_M: tl.constexpr,
            BLOCK_N: tl.constexpr,
        ):
            # Get program ids
            pid_m = tl.program_id(0)
            X_block_ptr = tl.make_block_ptr(
                base=x_ptr,
                shape=(M, N),
                strides=(stride_m, stride_n),
                offsets=(0, 0),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )
            O_block_ptr = tl.make_block_ptr(
                base=o_ptr,
                shape=(M, N),
                strides=(stride_m, stride_n),
                offsets=(0, 0),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )

            for _ in range(BLOCK_B):
                x = tl.load(X_block_ptr)
                tl.store(O_block_ptr, x)

                X_block_ptr = tl.advance(X_block_ptr, (BLOCK_M, 0))
                O_block_ptr = tl.advance(O_block_ptr, (BLOCK_M, 0))

        t = torch.randn((32, 64, 128))
        o = torch.empty_like(t)
        B, M, N = t.shape
        return (
            fwd_kernel,
            {
                "x_ptr": t,
                "o_ptr": o,
                "M": M,
                "N": N,
                "stride_m": N,
                "stride_n": 1,
                "BLOCK_B": B,
                "BLOCK_M": M,
                "BLOCK_N": N,
            },
            ["o_ptr"],
        )

    @make_mutation_test
    def test_while_loop():
        @triton.jit
        def fwd_kernel(
            x_ptr,
            o_ptr,
            M,
            N,
            stride_m,
            stride_n,
            BLOCK_B: tl.constexpr,
            BLOCK_M: tl.constexpr,
            BLOCK_N: tl.constexpr,
        ):
            # Get program ids
            pid_m = tl.program_id(0)
            X_block_ptr = tl.make_block_ptr(
                base=x_ptr,
                shape=(M, N),
                strides=(stride_m, stride_n),
                offsets=(0, 0),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )
            O_block_ptr = tl.make_block_ptr(
                base=o_ptr,
                shape=(M, N),
                strides=(stride_m, stride_n),
                offsets=(0, 0),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )

            i = 0
            while i < BLOCK_B:
                x = tl.load(X_block_ptr)
                tl.store(O_block_ptr, x)

                X_block_ptr = tl.advance(X_block_ptr, (BLOCK_M, 0))
                O_block_ptr = tl.advance(O_block_ptr, (BLOCK_M, 0))
                i += 1

        t = torch.randn((32, 64, 128))
        o = torch.empty_like(t)
        B, M, N = t.shape
        return (
            fwd_kernel,
            {
                "x_ptr": t,
                "o_ptr": o,
                "M": M,
                "N": N,
                "stride_m": N,
                "stride_n": 1,
                "BLOCK_B": B,
                "BLOCK_M": M,
                "BLOCK_N": N,
            },
            ["o_ptr"],
        )

    @make_mutation_test
    def test_branch_with_multiple_yield_args():
        @triton.jit
        def branch_with_multiple_yield_args(
            in_ptr0,
            in_ptr1,
            out_ptr,
            conditional_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            conditional = tl.load(conditional_ptr)
            if conditional:
                in0 = in_ptr0 + 1
                in1 = in_ptr1 + 1
                out = out_ptr + 1
            else:
                in0 = in_ptr0
                in1 = in_ptr1
                out = out_ptr
            x = tl.load(in0 + offsets, mask=mask)
            y = tl.load(in1 + offsets, mask=mask)
            tl.store(out + offsets, x + y, mask=mask)

        x = torch.randn(15)
        y = torch.randn(15)
        out = torch.zeros(15)
        conditional = torch.tensor(True)
        return (
            branch_with_multiple_yield_args,
            {
                "in_ptr0": x,
                "in_ptr1": y,
                "out_ptr": out,
                "conditional_ptr": conditional,
                "n_elements": 14,
                "BLOCK_SIZE": 16,
            },
            ["out_ptr"],
        )


if HAS_GPU:
    t = torch.randn(4)
    tt = torch.randn(4, 1)
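    # Each (kernel, kwargs, expected) entry below is turned into a
    # MutationTests test case by the loop after the list.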
    tests = [
        [
            add_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        ],
        [
            add_kernel_2d_autotuned,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "x_elements": 4,
                "y_elements": 4,
            },
            ["out_ptr"],
        ],
        [
            indirection_kernel,
            {
                "in_ptr0": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
                "ACTIVATION": "mul2_inplace_kernel",
            },
            ["in_ptr0", "out_ptr"],
        ],
        [
            indirection_kernel,
            {
                "in_ptr0": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
                "ACTIVATION": "add_kernel",
            },
            ["out_ptr"],
        ],
        [
            mul2_inplace_kernel,
            {"ptr": t, "n_elements": 4, "BLOCK_SIZE": 4},
            ["ptr"],
        ],
        # Can't optimize since the kernel contains a tl.inline_asm_elementwise
        [
            inline_asm_kernel,
            {"X": t, "Y": t, "Z": t, "n": 4, "BLOCK": 4},
            ["X", "Y", "Z"],
        ],
        [
            add_kernel_with_block_ptr,
            {
                "x_ptr": t,
                "y_ptr": t,
                "output_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["output_ptr"],
        ],
        [
            kernel_with_block_ptr_2d,
            {
                "x_ptr": tt,
                "output_ptr": tt,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["output_ptr"],
        ],
        [
            add_kernel_with_import,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        ],
        [
            atomic_add_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        ],
        [
            add_4_times_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        ],
        [
            cond_op_kernel,
            {
                "in_ptr0": t,
                "in_ptr1": t,
                "out_ptr": t,
                "n_elements": 4,
                "BLOCK_SIZE": 4,
            },
            ["out_ptr"],
        ],
    ]
    for kernel, inputs, outputs in tests:
        fn = make_mutation_test(
            # Add default arguments to avoid the Python lambda capture pitfall.
            # This forces the capture at lambda creation.
            lambda kernel=kernel, inputs=inputs, outputs=outputs: (
                kernel,
                inputs,
                outputs,
            )
        )
        name = f"test_mutations_{kernel.fn.__name__}"
        # Crude way to make the test names unique
        while name in MutationTests.__dict__:
            name += "1"

        setattr(MutationTests, name, fn)


class CustomOpTests(torch._inductor.test_case.TestCase):
    """Tests for custom ops wrapping triton kernels"""

    @requires_gpu
    @common_utils.parametrize("autotuned", [False, True])
    @common_utils.parametrize("dynamic", [False, True])
    def test_add_kernel(self, autotuned, dynamic):
        from torch._inductor.utils import run_and_get_code

        libname = "my_cool_namespace"
        opname = "my_triton_operator"

        @torch.library.triton_op(f"{libname}::{opname}", mutates_args={})
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
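
            # capture_triton makes the raw Triton kernel traceable inside a
            # triton_op, so the launch is recorded instead of run opaquely.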
            if autotuned:
                capture_triton(add_kernel_autotuned)[grid](x, y, output, n_elements)
            else:
                capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
            return output

        def f(x, y):
            return add(x, y)

        x = torch.randn(3, device=GPU_TYPE)
        y = torch.randn(3, device=GPU_TYPE)

        out = f(x, y)
        expected = x + y
        self.assertEqual(out, expected)
        out_compiled, codes = run_and_get_code(torch.compile(f, dynamic=dynamic), x, y)
        self.assertEqual(out_compiled, expected)
        self.assertEqual(len(codes), 1)

        # Check that we decomposed the operator away
        code = "\n".join(codes[0])
        self.assertNotIn(libname, code)
        self.assertNotIn(opname, code)

    @requires_gpu
    @dynamo_config.patch("recompile_limit", 1)
    def test_triton_dynamic_grid_no_recompile(self):
        libname = "my_cool_namespace"
        opname = "my_triton_operator"

        @torch.library.triton_op(f"{libname}::{opname}", mutates_args={})
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()
            capture_triton(add_kernel)[(n_elements,)](x, y, output, n_elements, 16)
            return output

        @torch.compile(fullgraph=True, dynamic=True)
        def f(x):
            return add(x, x)

        f(torch.randn(8, device=GPU_TYPE))
        f(torch.randn(16, device=GPU_TYPE))

    @unittest.skipIf(not has_triton_package(), "requires triton")
    def test_capture_triton_meta(self):
        import triton
        import triton.language as tl

        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.library.triton_op("mylib::add", mutates_args=())
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
            return output

        def f(x, y):
            return add(x, y)

        x = torch.randn(3, device="meta")
        y = torch.randn(3, device="meta")
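
        # on the meta device no kernel actually runs; this exercises only the
        # tracing path and shape/dtype propagation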
        out = f(x, y)
        expected = torch.empty_like(x)
        self.assertEqual(out, expected)

    @requires_gpu
    def test_wrap_triton_disabled_in_triton_op(self):
        import triton  # @manual
        import triton.language as tl  # @manual

        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        add_kernel_decorated = torch.library.wrap_triton(add_kernel)

        status = []

        @torch.library.triton_op("mylib::add", mutates_args=())
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            import torch._higher_order_ops.triton_kernel_wrap

            status.append(torch._library.triton.is_wrap_triton_enabled())

            # wrap_triton should return the kernel unchanged when wrapping
            # is disabled (as it is inside a triton_op)
            result = torch.library.wrap_triton(add_kernel)
            self.assertIs(result, add_kernel)

            # Smoke test: check that with wrap_triton disabled this still does something
            output = torch.empty_like(x)
            output2 = torch.empty_like(x)

            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_decorated[grid](x, y, output, n_elements, BLOCK_SIZE=16)

            add_kernel_decorated.run(
                x, y, output2, n_elements, BLOCK_SIZE=16, grid=grid, warmup=False
            )

            return output + output2

        x = torch.randn(3, device=GPU_TYPE)
        y = torch.randn(3, device=GPU_TYPE)
        z = add(x, y)
        self.assertEqual(status[-1], False)
        self.assertEqual(z, (x + y) * 2)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("autotune", [False, True])
    def test_capture_triton_special_kwargs(self, dynamic, autotune):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        if autotune:
            add_kernel = triton.autotune(
                configs=[
                    triton.Config({"BLOCK_SIZE": 128}),
                    triton.Config({"BLOCK_SIZE": 64}),
                ],
                key=["n_elements"],
            )(add_kernel)

        def f(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            if autotune:
                kwargs = {}
            else:
                kwargs = {"BLOCK_SIZE": 128}
            capture_triton(add_kernel)[grid](
                x,
                y,
                output,
                n_elements,
                num_warps=8,
                num_stages=3,
                **kwargs,
            )
            return output

        x = torch.randn(4, device=GPU_TYPE)
        tracing_mode = "symbolic" if dynamic else "fake"

        result = f(x, x)
        self.assertEqual(result, x + x)

        from functorch import make_fx

        gm = make_fx(f, tracing_mode=tracing_mode)(x, x)
        self.assertEqual(gm(x, x), x + x)

    @skipIfWindows(msg="AOTI/Cpp_Wrapper have not enabled on Windows")
    @requires_gpu
    @inductor_config.patch("cpp_wrapper", True)
    @inductor_config.patch("triton.autotune_at_compile_time", True)
    def test_autotune_unbacked(self):
        import triton
        import triton.language as tl

        def get_op_configs():
            return [
                triton.Config(
                    {
                        "BLOCK_M": 32,
                        "BLOCK_N": 64,
                        "BLOCK_K": 32,
                        "GROUP_M": 8,
                    },
                    num_stages=5,
                    num_warps=2,
                ),
                triton.Config(
                    {
                        "BLOCK_M": 128,
                        "BLOCK_N": 256,
                        "BLOCK_K": 64,
                        "GROUP_M": 8,
                    },
                    num_stages=3,
                    num_warps=8,
                ),
            ]

        @triton.autotune(
            configs=get_op_configs(),
            key=["N", "K"],
        )
        @triton.jit
        def op_zeros(
            x_ptr,
            w_ptr,
            z_ptr,
            M,
            N,
            K,
            stride_xm,
            stride_xk,
            stride_wk,
            stride_wn,
            stride_zm,
            stride_zn,
            BLOCK_M: tl.constexpr,
            BLOCK_N: tl.constexpr,
            BLOCK_K: tl.constexpr,
            GROUP_M: tl.constexpr,
            ALLOW_TF32: tl.constexpr,
        ):
            pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_M)
            num_pid_n = tl.cdiv(N, BLOCK_N)
            num_pid_in_group = GROUP_M * num_pid_n
            group_id = pid // num_pid_in_group
            first_pid_m = group_id * GROUP_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
            pid_m = first_pid_m + (pid % group_size_m)
            pid_n = (pid % num_pid_in_group) // group_size_m

            offs_m = tl.arange(0, BLOCK_M)
            offs_n = tl.arange(0, BLOCK_N)
            mask_m = (pid_m * BLOCK_M + offs_m)[:, None] < M
            mask_n = (pid_n * BLOCK_N + offs_n)[None, :] < N

            z_mask = mask_m & mask_n
            z = 0.0
            z_ptr += pid_m.to(tl.int64) * BLOCK_M * stride_zm
            z_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_zn
            z_ptrs = z_ptr + stride_zm * offs_m[:, None] + stride_zn * offs_n[None, :]
            tl.store(z_ptrs, z, mask=z_mask)

        @torch.compile()
        def foo(x, w):
            M, K = x.shape
            KB, N = w.shape
            assert K == KB, f"incompatible dimensions {K}, {KB}"

            z = torch.empty((M, N), device=x.device, dtype=x.dtype)

            def grid(META):
                return (
                    triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
                )

            op_zeros[grid](
                x,
                w,
                z,
                M,
                N,
                K,
                x.stride(0),
                x.stride(1),
                w.stride(0),
                w.stride(1),
                z.stride(0),
                z.stride(1),
                ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
            )
            return z

        M, K, N = 128, 64, 32
        x = torch.randn(M, K, device=GPU_TYPE)
        w = torch.randn(K, N, device=GPU_TYPE)

        torch._dynamo.decorators.mark_unbacked(x, 0)
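
        # dim 0 of x is now an unbacked symbol, so the launch grid cannot be
        # specialized on M at compile time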
        with log_settings("+output_code"), self.assertLogs(
            logger="torch._inductor", level=logging.DEBUG
        ) as log:
            foo(x, w)

        output = "\n".join(record.getMessage() for record in log.records)
        # correct grid example values updated per block size
        FileCheck().check("Compile-time auto-tuning block:").check(
            "grid_wrapper_for_op_zeros_0"
        ).check_next("return (256").check_next("return (64").run(output)

    # Triton 3.2.0 adds the required flags to the Autotuner object for this test
    # PR: https://github.com/triton-lang/triton/pull/5092
    @requires_gpu
    def test_autotune_no_pre_or_post_hook_user_defined(self):
        from triton.runtime.autotuner import Autotuner

        def init_to_zero(name):
            return lambda nargs: nargs[name].zero_()

        @triton.autotune(
            configs=[
                triton.Config(
                    {"BLOCK_SIZE": 1024},
                    num_warps=4,
                    num_stages=2,
                    pre_hook=init_to_zero("output_ptr"),
                )
            ],
            pre_hook=init_to_zero("output_ptr"),
            post_hook=init_to_zero("output_ptr"),
            key=["n_elements"],
        )
        @triton.jit
        def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(axis=0)

            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements

            x = tl.load(x_ptr + offsets, mask=mask)
            y = tl.load(y_ptr + offsets, mask=mask)
            output = x + y
            tl.atomic_add(output_ptr + offsets, output, mask=mask)

        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output, n_elements)
            return output

        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)

        # should always pass
        assert add(x, y).mean() == 2, "Problem with add kernel"

        # assert that the user_defined_* flags are properly set on the kernel before compilation
        self.assertEqual(isinstance(add_kernel, Autotuner), True)
        if not hasattr(add_kernel, "user_defined_pre_hook") or not hasattr(
            add_kernel, "user_defined_post_hook"
        ):
            raise unittest.SkipTest(
                "test requires Triton version >= 3.2.0 for Autotuner.user_defined* hooks"
            )

        self.assertEqual(add_kernel.user_defined_pre_hook, True)
        self.assertEqual(add_kernel.user_defined_post_hook, True)

        # this should cause an exception, since pre_hook is not allowed
        msg = "pre_hook and post_hook are not supported in triton.Autotune or triton.Config"
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            add_compiled = torch.compile(add, mode="reduce-overhead", fullgraph=True)
            add_compiled(x, y).mean()

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("autotune_at_compile_time", [True, False])
    def test_triton_kernel_reset_to_zero(self, backend, autotune_at_compile_time):
        if autotune_at_compile_time and backend != "inductor":
            raise unittest.SkipTest("compile-time autotuning only exists in inductor")
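
        # reset_to_zero zeroes increment_ptr before each autotuned run, so the
        # in-place update below is not compounded while benchmarking configs.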
        @triton.autotune(
            configs=[
                triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
                triton.Config({"BLOCK_SIZE": 32}, num_stages=3, num_warps=8),
                triton.Config({"BLOCK_SIZE": 16}, num_stages=3, num_warps=8),
            ],
            key=[],
            reset_to_zero=["increment_ptr"],
        )
        @triton.jit
        def increment_kernel(
            in_ptr0,
            increment_ptr,  # reset this to zero every time
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements

            in_ptr_vals = tl.load(in_ptr0 + offsets, mask=mask)
            increment_val = tl.load(increment_ptr + offsets, mask=mask)
            # during autotuning, increment_val should always be zero
            tl.store(in_ptr0 + offsets, in_ptr_vals + increment_val, mask=mask)

        @torch.compile(fullgraph=True, backend=backend)
        def f(x, increment):
            n_elements = x.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            increment_kernel[grid](x, increment, n_elements=n_elements)
            return x

        x = torch.rand(4, device=GPU_TYPE)
        y = torch.clone(x)
        increment = torch.rand(4, device=GPU_TYPE)

        # during autotuning, x should not change in value
        with inductor_config.patch(
            {"triton.autotune_at_compile_time": autotune_at_compile_time}
        ):
            # we will add rand a single time to x
            f(x, increment)

        self.assertEqual(y + increment, x)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_single_autotune(self, backend):
        @triton.autotune(
            configs=[
                triton.Config(
                    {"BLOCK_SIZE": 4096},
                )
            ],
            key=["n_elements"],
        )
        # Currently, this autotuning decorator will never run!
        # We only support having a single autotuning decorator on each Triton kernel
        @triton.autotune(
            configs=[
                triton.Config(
                    {"BLOCK_SIZE": 1024},
                )
            ],
            key=["n_elements"],
        )
        @triton.jit
        def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(axis=0)

            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements

            x = tl.load(x_ptr + offsets, mask=mask)
            y = tl.load(y_ptr + offsets, mask=mask)
            output = x + y
            tl.store(output_ptr + offsets, output, mask=mask)

        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output, n_elements)
            return output

        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)

        # this should cause an exception, since stacking multiple
        # @triton.autotune decorators is not supported
        msg = "Passing multiple @triton.autotune decorators is not supported. Please use a single @triton.autotune decorator instead."
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            add_compiled = torch.compile(
                add, mode="reduce-overhead", fullgraph=True, backend=backend
            )
            add_compiled(x, y).mean()

    @requires_gpu
    @common_utils.parametrize("non_strict", [True, False])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("with_perf_model", [True, False])
    def test_triton_kernel_prune_configs_by(self, backend, with_perf_model, non_strict):
        # for non-strict mode
        libname = "my_cool_namespace"
        opname = "my_triton_operator"

        records = {}

        def early_config_prune(configs, named_args, **kwargs):
            # we need to save the records to the returned config
            records["run_early_config_prune"] = True
            if "N" in kwargs and kwargs["N"] == 1024:
                records["capture_kwargs"] = True
            # named args are: dst, src, add_float
            if "dst" in named_args and "src" in named_args and len(named_args) == 3:
                records["capture_named_args"] = True
            return [configs[0]]

        def perf_model(*args, **kwargs):
            records["run_perf_model"] = True
            return kwargs["BLOCK_SIZE"] * -1

        if with_perf_model:
            prune_configs_by = {"perf_model": perf_model, "top_k": 1}
        else:
            prune_configs_by = {"early_config_prune": early_config_prune}

        @triton.autotune(
            configs=[
                triton.Config(kwargs={"BLOCK_SIZE": 32}),
                triton.Config(kwargs={"BLOCK_SIZE": 128}),
            ],
            key=["N"],
            prune_configs_by=prune_configs_by,
        )
        @triton.jit
        def prune_by_kernel(
            dst,
            src,
            add_float,
            N,
            BLOCK_SIZE: tl.constexpr,
        ):
            offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
            x = tl.load(src + offsets, mask=offsets < N)
            # we only modify dst if our perf_model is applied (and a BLOCK_SIZE of 128 is selected)
            if BLOCK_SIZE == 128:
                x = x + add_float
            tl.store(dst + offsets, x, mask=offsets < N)

        def f(
            dst: torch.Tensor,
            src: torch.Tensor,
            add_float: float,
            N: int,
        ) -> None:
            grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
            if non_strict:
                torch.library.wrap_triton(prune_by_kernel)[grid](
                    dst, src, add_float, N=N
                )
            else:
                prune_by_kernel[grid](dst, src, add_float, N=N)

        if non_strict:
            decorator = torch.library.triton_op(
                f"{libname}::{opname}", mutates_args={"dst"}
            )(f)
        else:
            # we can just pass the function 'f' for dynamo
            decorator = f

        compiled_f = torch.compile(decorator, backend=backend)
        N = 1024
        src = torch.randn(N, device=GPU_TYPE)
        dst = torch.empty(N, device=GPU_TYPE)
        compiled_f(dst, src, 1.5, N)

        if with_perf_model:
            # when applying the perf_model kwargs["BLOCK_SIZE"] * -1, the largest config (BLOCK_SIZE==128) is selected
            self.assertEqual(len(records), 1)
            self.assertEqual(src + 1.5, dst)
        else:
            # without the perf_model, BLOCK_SIZE==32 is selected, so dst is not modified and remains equal to src
            self.assertEqual(src, dst)
            self.assertEqual(len(records), 3)
            self.assertTrue(records["run_early_config_prune"])
            self.assertTrue(records["capture_kwargs"])
            self.assertTrue(records["capture_named_args"])

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("with_perf_model", [True, False])
    def test_triton_kernel_prune_configs_by_recompile(self, backend, with_perf_model):
        """
        We want to recompile if anyone changes the configs on the autotuner object.
        In short, if the following sequence of events happens:
        1. foo = torch.compile(bar)
        2. call foo
        3. autotuner.configs = [new configs list]
        4. call foo

        A recompile event should occur, which we check with Dynamo counters.
        This tests that we are installing guards on input objects properly.
        """

        # We don't modify records here because we are testing whether or not
        # recompiles occur/guards are installed.
        # If we modified a non-local records dict here, that by itself would
        # trigger recompile events.
        def early_config_prune(configs, named_args, **kwargs):
            return [configs[0]]

        def perf_model(*args, **kwargs):
            return kwargs["BLOCK_SIZE"] * -1

        if with_perf_model:
            prune_configs_by = {"perf_model": perf_model, "top_k": 1}
        else:
            prune_configs_by = {"early_config_prune": early_config_prune}

        @triton.autotune(
            configs=[
                triton.Config(kwargs={"BLOCK_SIZE": 32}),
                triton.Config(kwargs={"BLOCK_SIZE": 128}),
            ],
            key=["N"],
            prune_configs_by=prune_configs_by,
        )
        @triton.jit
        def prune_by_kernel(
            dst,
            src,
            add_float,
            N,
            BLOCK_SIZE: tl.constexpr,
        ):
            offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
            x = tl.load(src + offsets, mask=offsets < N)
            # Let's make sure we always select a block size of 128 based on our perf_model
            if BLOCK_SIZE == 128:
                x = x + add_float
            tl.store(dst + offsets, x, mask=offsets < N)

        torch._dynamo.reset()
        counter = torch._dynamo.testing.CompileCounterWithBackend(backend=backend)

        @torch.compile(fullgraph=True, backend=counter)
        def f(dst, src, add_float, N):
            grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
            prune_by_kernel[grid](dst, src, add_float, N=N)

        N = 1024
        src = torch.randn(N, device=GPU_TYPE)
        dst = torch.empty(N, device=GPU_TYPE)

        # first compilation, this prunes the configs
        f(dst, src, 1.5, N)

        self.assertEqual(counter.op_count, 1)

        f(dst, src, 1.5, N)

        # this should not trigger a recompilation: unlike
        # test_triton_kernel_prune_configs_by, the pruning hooks here do not
        # mutate a records dict, which would otherwise force a recompile.
        self.assertEqual(counter.op_count, 1)

        # Modify the autotuner object
        prune_by_kernel.configs = [triton.Config(kwargs={"BLOCK_SIZE": 64})]

        # Calling the kernel after modifying the autotuner should
        # trigger a recompile
        f(dst, src, 1.5, N)

        self.assertEqual(counter.op_count, 2)

        # there should be no recompile here
        f(dst, src, 1.5, N)

        self.assertEqual(counter.op_count, 2)

    # see: https://github.com/triton-lang/triton/blob/67ea999935f4511a535a25bdecb27e79e3c3af41/python/test/unit/language/test_decorator.py#L31
    @requires_gpu
    @common_utils.parametrize("non_strict", [True, False])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("autotune_at_compile_time", [True, False])
    def test_triton_kernel_heuristic(
        self, backend, autotune_at_compile_time, non_strict
    ):
        # namespace/name for the custom op used in non-strict mode
        libname = "my_cool_namespace"
        opname = "my_triton_operator"

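        # The heuristics below are applied outermost-first (top to bottom as
        # written), so with N == 1023 the kernel should see:
        #   BLOCK_SIZE = 32 * 2 = 64
        #   EVEN_N = (1023 + 10) * 2 - 10 = 2056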
        @triton.autotune(
            configs=[
                triton.Config(kwargs={"BLOCK_SIZE": 32}),
            ],
            key=["N"],
        )
        # we should be able to modify existing keys in kwargs
        @triton.heuristics({"BLOCK_SIZE": lambda nargs: nargs["BLOCK_SIZE"] * 2})
        # test kwargs
        @triton.heuristics({"EVEN_N": lambda nargs: nargs["N"] + 10})
        @triton.heuristics({"EVEN_N": lambda nargs: nargs["EVEN_N"] * 2})
        # test args
        # There are differences here from OSS Triton because we run these functions in Dynamo
        # We don't have access to the .data_ptr() of TensorVariables
        @triton.heuristics({"NDIM_src": lambda nargs: nargs["src"] is None})
        # test that heuristics are applied in the correct order
        @triton.heuristics({"EVEN_N": lambda nargs: nargs["EVEN_N"] - 10})
        @triton.jit
        def heuristics_kernel(
            dst,
            src,
            N,
            BLOCK_SIZE: tl.constexpr,
            EVEN_N: tl.constexpr,
            NDIM_src: tl.constexpr,
        ):
            tl.store(dst, EVEN_N + BLOCK_SIZE)
            tl.store(dst + 1, NDIM_src)

        grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)

        def f(
            dst: torch.Tensor,
            src: torch.Tensor,
            N: int,
        ) -> None:
            grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
            if non_strict:
                torch.library.wrap_triton(heuristics_kernel)[grid](dst, src, N=N)
            else:
                heuristics_kernel[grid](dst, src, N=N)

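        # In non-strict mode, register f as a custom op via
        # torch.library.triton_op so the wrap_triton'd kernel can be traced;
        # in strict mode, Dynamo traces the plain function directly.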
        if non_strict:
            decorator = torch.library.triton_op(
                f"{libname}::{opname}", mutates_args={"dst"}
            )(f)
        else:
            # we can just pass the function 'f' to Dynamo
            decorator = f

        compiled_f = torch.compile(decorator, backend=backend)

        N = 1023
        src = torch.empty(N, device=GPU_TYPE)
        dst = torch.zeros(N, device=GPU_TYPE)

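        # exercise both the compile-time and runtime autotuning paths in the
        # Inductor-generated wrapper code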
        with inductor_config.patch(
            {"triton.autotune_at_compile_time": autotune_at_compile_time}
        ):
            compiled_f(dst, src, N=N)

        # now let's run without torch.compile to compare
        triton_src = torch.empty(N, device=GPU_TYPE)
        triton_dst = torch.zeros(N, device=GPU_TYPE)
        heuristics_kernel[grid](triton_dst, triton_src, N=N)

        # triton_dst[0].item() is 2120:
        # (1023 + 10) * 2 - 10 + BLOCK_SIZE = 2056 + 64 = 2120
        # this is to test that we apply the heuristics in the correct order
        self.assertEqual(triton_dst[0].item(), 2120)
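        # NDIM_src is False (src is not None), which is stored as 0.0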
        self.assertEqual(triton_dst[1].item(), 0.0)

        # Results should match
        self.assertEqual(dst[0].item(), triton_dst[0].item())
        self.assertEqual(dst[1].item(), triton_dst[1].item())

        # @triton.heuristics cannot return non-constant values;
        # check for the exception
        if not non_strict:

            @triton.autotune(
                configs=[
                    triton.Config(kwargs={"BLOCK_SIZE": 32}),
                ],
                key=["N"],
            )
            # torch.randint(...)[0] produces a tensor, i.e. a non-constant value
            @triton.heuristics({"EVEN_N": lambda nargs: torch.randint(1, (1, 1))[0]})
            @triton.jit
            def heuristics_kernel(
                dst,
                src,
                N,
                BLOCK_SIZE: tl.constexpr,
                EVEN_N: tl.constexpr,
            ):
                tl.store(dst, N)

            grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)

            def f(
                dst: torch.Tensor,
                src: torch.Tensor,
                N: int,
            ) -> None:
                grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
                heuristics_kernel[grid](dst, src, N=N)

            compiled_f = torch.compile(f, backend=backend, fullgraph=True)
            N = 1023
            src = torch.empty(N, device=GPU_TYPE)
            dst = torch.zeros(N, device=GPU_TYPE)
            msg = "@triton.heuristics must return constant values because configs can only contain constant values."
            with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
                compiled_f(dst, src, N=N)


common_utils.instantiate_parametrized_tests(KernelTests)
common_utils.instantiate_parametrized_tests(CustomOpTests)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()