[Reland] First version of statically compiled launcher for triton compiled CUDA kernels (#149238)

This is a new version of https://github.com/pytorch/pytorch/pull/148561, fixing the ROCm test failure.

Putting this up for a first pass review, though I will likely make a bunch of changes before landing to add more features, etc.

This diff implements a first version of a static CUDA kernel launcher in `torch._C`. The goal is to take a cubin file and some metadata from a `triton` CompiledKernel and launch the cubin directly.

Background doc: https://docs.google.com/document/d/1rjRcHl6MfauHG30nCoQX-9UKvKyIs4WWMy_GsGyqb9g/edit?tab=t.0#heading=h.ut5lf39lzq66

Normally, using triton's CompiledKernel.make_launcher(), we would pay the cost of codegenning C++ and compiling it at compile time. With this new approach, we can use one statically compiled library to launch the kernel.

The tradeoff is that this new kernel launcher cannot use codegen to adapt to different numbers and types of arguments. Instead, we use templating to handle up to 10 arguments for now, and we allocate 8 bytes on the stack per argument regardless of its type, which can use more memory than the codegenned launcher. On the other hand, we improve compile time on both cold and warm start by not having to call the C++ compiler at all.
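
At a high level, the intended workflow looks roughly like this (a simplified sketch based on the test suite below; the kernel name, arguments, grid values, and `cubin_path` are illustrative):

    # Compile with triton as usual, then hand the CompiledKernel to the static launcher.
    compiled_kernel = my_triton_kernel[(1,)](arg0, arg1)      # triton compiles (and runs) the kernel, producing a cubin
    kernel = StaticallyLaunchedCudaKernel(compiled_kernel)    # parse the metadata/signature we need
    kernel.write_cubin_to_file(cubin_path)                    # persist the cubin
    kernel.load_kernel()                                      # cuModuleLoad + cuModuleGetFunction
    kernel.run((grid_x, grid_y, grid_z), stream, arg0, arg1)  # cuLaunchKernel, with no per-kernel C++ codegen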

This diff does not hook the launcher up to inductor yet; it only introduces the launcher itself and a basic test suite.

A list of TODOs that are not yet complete:
- Handle `nvTmaDesc` and `CUtensorMap`, which triton handles
- Embed the grid logic instead of passing in gridX,Y,Z
- Handle launch_enter and exit hooks? (Not sure if inductor has these)
- Benchmarking to see if there's runtime performance loss
- Probably lots of features of triton's generated C++ launcher code that I haven't handled yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149238
Approved by: https://github.com/oulgen
James Wu
2025-03-14 17:25:48 -07:00
committed by PyTorch MergeBot
parent c83c711da8
commit a9c55277d7
9 changed files with 1004 additions and 0 deletions

View File

@@ -156,6 +156,7 @@ NVRTC_STUB2(nvrtcGetProgramLogSize,nvrtcProgram, size_t*)
NVRTC_STUB2(nvrtcGetProgramLog, nvrtcProgram, char *)
NVRTC_STUB3(nvrtcGetLoweredName, nvrtcProgram, const char *, const char **)
CUDA_STUB2(cuModuleLoad, CUmodule*, const char*)
CUDA_STUB2(cuModuleLoadData, CUmodule *, const void *)
CUDA_STUB3(cuModuleGetFunction, CUfunction *, CUmodule, const char *)
CUDA_STUB4(cuOccupancyMaxActiveBlocksPerMultiprocessor, int *, CUfunction, int, size_t)
@@ -169,6 +170,8 @@ CUDA_STUB4(cuLinkCreate, unsigned int, CUjit_option *, void **, CUlinkState *)
CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *)
CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
CUDA_STUB3(cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
CUresult CUDAAPI

View File

@@ -43,6 +43,7 @@ namespace at::cuda {
_(nvrtcGetProgramLogSize) \
_(nvrtcGetProgramLog) \
_(nvrtcGetLoweredName) \
_(cuModuleLoad) \
_(cuModuleLoadData) \
_(cuModuleLoadDataEx) \
_(cuModuleGetFunction) \
@@ -60,6 +61,7 @@
_(cuLinkComplete) \
_(cuFuncSetAttribute) \
_(cuFuncGetAttribute) \
_(cuPointerGetAttribute) \
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
#define AT_FORALL_NVRTC_EXTENDED(_) \

View File

@@ -859,6 +859,7 @@ libtorch_python_core_sources = [
"torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
"torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp",
"torch/csrc/inductor/resize_storage_bytes.cpp",
"torch/csrc/inductor/static_cuda_launcher.cpp",
"torch/csrc/jit/backends/backend_init.cpp",
"torch/csrc/jit/python/init.cpp",
"torch/csrc/jit/passes/onnx.cpp",

View File

@@ -0,0 +1,319 @@
# Owner(s): ["module: inductor"]
import os
import tempfile
from typing import Any, Callable
import torch
from torch._dynamo.device_interface import get_interface_for_device
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.static_cuda_launcher import StaticallyLaunchedCudaKernel
from torch._inductor.runtime.triton_compat import tl, triton
from torch._inductor.runtime.triton_helpers import libdevice
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.triton_utils import requires_cuda
@requires_cuda
class TestStaticCudaLauncher(TestCase):
def setUp(self):
# Create a temporary file to store the cubin.
# We set delete=False so that the file persists after closing.
        self.tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False)
self.tmp_file.close() # Close now; we'll open it for writing later.
super().setUp()
def tearDown(self):
super().tearDown()
# Delete the temporary cubin file.
try:
os.remove(self.tmp_file.name)
except FileNotFoundError:
pass
def _make_launcher(
self,
kernel: Callable,
args: tuple[Any, ...],
grid: tuple[Any, ...] = (1,),
) -> StaticallyLaunchedCudaKernel:
"""
Compiles a Triton kernel with the provided *args,
writes its cubin to the temporary file, and returns the file path.
"""
fn = triton.jit(kernel)
# Launch the kernel to trigger compilation.
compiled_kernel = fn[grid](*args)
result = StaticallyLaunchedCudaKernel(compiled_kernel)
result.write_cubin_to_file(self.tmp_file.name)
result.load_kernel()
return result
@skipIfRocm
def test_basic(self):
def simple_kernel(arg0, arg1):
x = tl.load(arg0)
y = arg1
tl.store(arg0, x + y)
arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
arg1 = 5
args = (arg0, arg1)
launcher = self._make_launcher(simple_kernel, args, (1,))
self.assertEqual(arg0, torch.tensor([5], dtype=torch.int32, device="cuda"))
self.assertEqual(launcher.arg_tys, "Oi")
new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
device_interface = get_interface_for_device("cuda")
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run((1,), stream, new_arg0, arg1)
self.assertEqual(new_arg0, arg0)
    # I wish I could macro all the int types into a single unit test on a loop, but
    # 1. variables aren't allowed as type annotations in python
    # 2. triton relies on inspect.getsource to get the type annotations
# so I can't even use exec() to generate the test cases.
# So we'll just make a few kernels by hand
@skipIfRocm
def test_unsigned_integers(self):
def unsigned_integers(
arg0, arg1: tl.uint8, arg2: tl.uint16, arg3: tl.uint32, arg4: tl.uint64
):
x = tl.load(arg0)
y = arg1 + arg2 + arg3 + arg4
tl.store(arg0, x + y)
arg0 = torch.zeros(1, dtype=torch.uint64, device="cuda")
# Using small numbers creates a Literal type which triton treats as a constant
args = (arg0, 50, 50, 50, 50)
launcher = self._make_launcher(unsigned_integers, args, (1,))
self.assertEqual(arg0, torch.tensor([200], dtype=torch.uint64, device="cuda"))
self.assertEqual(launcher.arg_tys, "OBHIK")
new_arg0 = torch.zeros(1, dtype=torch.uint64, device="cuda")
device_interface = get_interface_for_device("cuda")
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run((1,), stream, new_arg0, 50, 50, 50, 50)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_signed_integers(self):
def signed_integers(
arg0, arg1: tl.int8, arg2: tl.int16, arg3: tl.int32, arg4: tl.int64
):
x = tl.load(arg0)
y = arg1 + arg2 + arg3 + arg4
tl.store(arg0, x + y)
arg0 = torch.zeros(1, dtype=torch.int64, device="cuda")
# Using small numbers creates a Literal type which triton treats as a constant
args = (arg0, 50, 50, 50, 50)
launcher = self._make_launcher(signed_integers, args, (1,))
self.assertEqual(arg0, torch.tensor([200], dtype=torch.int64, device="cuda"))
self.assertEqual(launcher.arg_tys, "Obhil")
new_arg0 = torch.zeros(1, dtype=torch.int64, device="cuda")
device_interface = get_interface_for_device("cuda")
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run((1,), stream, new_arg0, 50, 50, 50, 50)
self.assertEqual(new_arg0, arg0)
# TODO: floats don't work properly, triton seems to think they're all tl.float32
# despite type annotations.
# There's also not really a good way for me to make a float16 in python...
@skipIfRocm
def test_floats(self):
def floats(arg0, arg1: tl.float16, arg2: tl.float32, arg3: tl.float64):
x = tl.load(arg0)
y = arg1 + arg2 + arg3
tl.store(arg0, x + y)
arg0 = torch.zeros(1, dtype=torch.float64, device="cuda")
args = (arg0, 1.0, 1.0, 1.0)
launcher = self._make_launcher(floats, args, (1,))
        # TODO: in PyTorch's pinned version of triton, arg3 is typed as a regular float,
        # but in triton 3.3.0 this is fixed and it's "Offd". We'll need to update later.
self.assertEqual(launcher.arg_tys, "Offf")
self.assertEqual(arg0, torch.tensor([3.0], dtype=torch.float64, device="cuda"))
new_arg0 = torch.zeros(1, dtype=torch.float64, device="cuda")
device_interface = get_interface_for_device("cuda")
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run((1,), stream, new_arg0, 1.0, 1.0, 1.0)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_basic_1arg(self):
def simple_kernel_1_arg(arg0):
x = tl.load(arg0)
tl.store(arg0, x + 1)
arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
launcher = self._make_launcher(simple_kernel_1_arg, (arg0,), (1,))
self.assertEqual(arg0, torch.tensor([1], dtype=torch.int32, device="cuda"))
self.assertEqual(launcher.arg_tys, "O")
new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
device_interface = get_interface_for_device("cuda")
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run(
(1,),
stream,
new_arg0,
)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_constexpr(self):
        # Constexprs are compiled directly into the cubin file,
        # so we never need to pass them to StaticCudaLauncher.
@triton.jit
def kernel_constexpr(arg0, CONSTANT: tl.constexpr):
x = tl.load(arg0)
tl.store(arg0, x + CONSTANT)
# Can't use make_launcher because constexpr needs to be constant
arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
compiled_kernel = kernel_constexpr[(1,)](arg0, CONSTANT=5)
launcher = StaticallyLaunchedCudaKernel(compiled_kernel)
launcher.write_cubin_to_file(self.tmp_file.name)
launcher.load_kernel()
self.assertEqual(arg0, torch.tensor([5], dtype=torch.int32, device="cuda"))
self.assertEqual(launcher.arg_tys, "O")
new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
device_interface = get_interface_for_device("cuda")
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run(
(1,),
stream,
new_arg0,
)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_implied_constant(self):
"""xnumel is unused in this kernel, but isn't explicitly marked as a constexpr"""
# This kernel was generated by inductor so it has a bunch of unused arguments. We don't change it
@triton.jit
def triton_red_fused_any_isinf_0(
in_ptr0,
out_ptr0,
xnumel, # noqa: F841
r0_numel,
XBLOCK: tl.constexpr,
R0_BLOCK: tl.constexpr,
):
xnumel = 1 # noqa: F841
rnumel = r0_numel # noqa: F841
RBLOCK: tl.constexpr = R0_BLOCK # noqa: F841
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] # noqa: F841
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) # noqa: F841
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base # noqa: F841
_tmp3 = tl.full([XBLOCK, R0_BLOCK], False, tl.int1)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset # noqa: F841
rindex = r0_index # noqa: F841
r0_0 = r0_index
tmp0 = tl.load(
in_ptr0 + (r0_0), r0_mask, eviction_policy="evict_first", other=0.0
)
tmp1 = libdevice.isinf(tmp0).to(tl.int1)
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
tmp4 = _tmp3 | tmp2
_tmp3 = tl.where(r0_mask, tmp4, _tmp3)
tmp3 = triton_helpers.any(_tmp3.to(tl.int8), 1)[:, None].to(tl.int1)
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp3, None)
arg0 = torch.tensor([0.0, 0.5, float("inf"), 5], device="cuda")
arg1 = torch.tensor([False], device="cuda")
arg2 = torch.tensor([False], device="cuda")
compiled_kernel = triton_red_fused_any_isinf_0[1,](
arg0, arg1, 1, 128, XBLOCK=1, R0_BLOCK=1
)
launcher = StaticallyLaunchedCudaKernel(compiled_kernel)
launcher.write_cubin_to_file(self.tmp_file.name)
launcher.load_kernel()
device_interface = get_interface_for_device("cuda")
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run((1,), stream, arg0, arg2, 1, 128)
self.assertEqual(arg1, arg2)
@skipIfRocm
def test_kernel_empty_tensor(self):
# Triton kernel generated by torch.compile of the following:
# @torch.compile()
# def foo(x, y):
# return torch.cat(((x * 4), y + 10))
# Running with example input:
# torch._dynamo.decorators.mark_unbacked(t, 0)
# x = torch.rand(0, device="cuda")
# y = torch.rand(20, device="cuda")
@triton.jit
def triton_poi_fused_cat_0(
in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK: tl.constexpr
):
xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64)
xmask = xindex < xnumel
x0 = xindex
tmp0 = x0
tmp3 = ks0
tmp4 = tmp0 < tmp3
tmp5 = tl.load(
in_ptr0 + (x0), xmask & tmp4, eviction_policy="evict_last", other=0.0
)
tmp6 = 4.0
tmp7 = tmp5 * tmp6
tmp8 = tl.full(tmp7.shape, 0.0, tmp7.dtype)
tmp9 = tl.where(tmp4, tmp7, tmp8)
tmp10 = tmp0 >= tmp3
tmp13 = tl.load(
in_ptr1 + (x0 + ((-1) * ks0)),
xmask & tmp10,
eviction_policy="evict_last",
other=0.0,
)
tmp14 = 10.0
tmp15 = tmp13 + tmp14
tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
tmp17 = tl.where(tmp10, tmp15, tmp16)
tmp18 = tl.where(tmp4, tmp9, tmp17)
tl.store(out_ptr0 + (x0), tmp18, xmask)
arg0 = 0
arg1 = torch.randn(0, device="cuda")
arg2 = torch.randn(20, device="cuda")
buf0 = torch.empty(20, device="cuda")
buf1 = torch.empty(20, device="cuda")
xnumel = 20 + arg0
compiled_kernel = triton_poi_fused_cat_0[(1,)](
arg1, arg2, buf0, arg0, xnumel, XBLOCK=32
)
launcher = StaticallyLaunchedCudaKernel(compiled_kernel)
launcher.write_cubin_to_file(self.tmp_file.name)
launcher.load_kernel()
device_interface = get_interface_for_device("cuda")
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run((1, 1, 1), stream, arg1, arg2, buf1, arg0, xnumel)
self.assertEqual(buf0, buf1)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
run_tests()

View File

@@ -2545,3 +2545,28 @@ class _NodeIter(Iterator):
def __init__(self, root: FxNode, reversed: _bool) -> None: ...
def __iter__(self) -> Iterator[FxNode]: ...
def __next__(self) -> FxNode: ...
# Defined in torch/csrc/inductor/static_cuda_launcher.cpp
class _StaticCudaLauncher:
@staticmethod
def _load_kernel(
cubin_file: str,
func_name: str,
shared_mem_bytes: _int,
) -> Tuple[_int, _int, _int]:
...
@staticmethod
def _launch_kernel(
func: _int,
grid_x: _int,
grid_y: _int,
grid_z: _int,
num_warps: _int,
shared_mem_bytes: _int,
arg_types: str,
args: Tuple[Any, ...],
stream: _int,
) -> None:
...
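
For orientation, here is a minimal sketch of how these bindings are driven from the Python side; it mirrors StaticallyLaunchedCudaKernel further below, and `cubin_path`, `func_name`, the grid/launch values, and `args` are placeholders:

    from torch._C import _StaticCudaLauncher

    # Done once at compile time: returns the raw CUfunction pointer plus register/spill counts.
    func, n_regs, n_spills = _StaticCudaLauncher._load_kernel(
        cubin_path, func_name, shared_mem_bytes
    )

    # Done at runtime: arg_types is a per-argument char code (e.g. "Oi" = pointer, int32),
    # args is the tuple of kernel arguments with constexprs removed, and stream is the raw
    # CUDA stream from device_interface.get_raw_stream().
    _StaticCudaLauncher._launch_kernel(
        func, grid_x, grid_y, grid_z,
        num_warps, shared_mem_bytes,
        arg_types, args, stream,
    )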

View File

@@ -0,0 +1,225 @@
import functools
from typing import Any, Optional
from typing_extensions import Unpack
from torch.utils._ordered_set import OrderedSet
from .triton_compat import ASTSource, CompiledKernel
MAX_SHARED_MEMORY = 49152
MAX_ARGS = 50
class StaticallyLaunchedCudaKernel:
"""
Parses the metadata of a CompiledKernel from Triton into a structure that can
launch the cuda kernel directly. Only works for triton kernels compiled to cubin.
Doing this avoids C++ codegen and compilation during compile, since we can use a
statically compiled library to launch the kernel. To avoid mallocing for the arguments,
we have a launcher for different numbers of arguments up to a max. StaticCudaLauncher
    only supports up to MAX_ARGS (currently 50) arguments for now.
Workflow:
Compile time:
1. Compile a kernel with triton and get a CompiledKernel
2. Instantiate kernel = StaticallyLaunchedCudaKernel(triton_kernel)
3. Write to a cubin file: kernel.write_cubin_to_file(filepath)
4. Call kernel.load_kernel() (CUDA should be initialized by this point) to load the cubin
Runtime:
5. Call kernel.run(grid, stream, args) to launch the kernel
Note that after step 3, StaticallyLaunchedCudaKernel is fully pickleable/serializable.
This allows it to be cached by FXGraphCache/TritonBundler, as well as sent from the worker
to the parent process in inductor.
"""
def __init__(self, kernel: CompiledKernel) -> None:
# To be used later when hooking up with torch.compile:
# inductor knows where the cubin file should be from triton,
# so won't need to write to a tmp file directly.
if hasattr(kernel, "_cubin_path"):
self.cubin_path = kernel._cubin_path
else:
self.cubin = kernel.asm["cubin"]
# TODO: is this right?
self.name = kernel.src.fn.__name__
self.hash = kernel.hash
if (
kernel.__class__.launch_enter_hook is not None
or kernel.__class__.launch_exit_hook is not None
):
raise NotImplementedError(
"We don't support launch enter or launch exit hooks"
)
self.num_warps = kernel.metadata.num_warps
self.shared = (
kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared
)
# When shared memory > 48 KB, triton allocates CUDA memory via both static and dynamic
# memory allocation, which gets really complicated. We'll handle it later.
# See triton/third-party/nvidia/driver.c in loadBinary
if self.shared > MAX_SHARED_MEMORY:
raise NotImplementedError(
"Shared memory size > 48KB requires special triton handling"
)
# Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel.
# Inductor never uses this field or enables it, but we still have to pass an extra None
        # into the set of params if it's enabled
if hasattr(kernel.metadata, "global_scratch_size"):
if kernel.metadata.global_scratch_size > 0:
raise NotImplementedError("Global scratch not yet supported")
else:
self.has_global_scratch = True
else:
self.has_global_scratch = False
self.arg_tys, self.constant_idxs = self.arg_ty_from_signature(kernel.src)
self.function: Optional[int] = (
            None  # Loaded by load_kernel() on the parent process
)
num_args = len(self.arg_tys)
num_ctas = 1
if hasattr(kernel, "num_ctas"):
num_ctas = kernel.num_ctas
elif hasattr(kernel, "metadata"):
num_ctas = kernel.metadata.num_ctas
if num_ctas != 1:
raise NotImplementedError(
"Static cuda launcher only supports num_ctas == 1"
)
if num_args > MAX_ARGS or num_args == 0:
            raise NotImplementedError(
                f"No static cuda launcher available for {num_args} arguments"
            )
def load_kernel(self) -> None:
from torch._C import _StaticCudaLauncher
assert hasattr(self, "cubin_path")
if self.function is not None:
return
(self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel(
self.cubin_path, self.name, self.shared
)
def write_cubin_to_file(self, filepath: str) -> None:
"""
Only used for tests where we don't have a cubin path.
"""
if hasattr(self, "cubin_path"):
return
# Just used by tests for now.
# TODO: derive cubin_path from wherever triton stores the cubin file on disk.
with open(filepath, "wb") as f:
f.write(self.cubin)
del self.cubin
self.cubin_path = filepath
@staticmethod
@functools.lru_cache
def type_mappings() -> dict[str, str]:
return {
"i1": "i",
"i8": "b",
"i16": "h",
"i32": "i",
"i64": "l",
"u1": "I",
"u8": "B",
"u16": "H",
"u32": "I",
"u64": "K",
"fp16": "f",
"bf16": "f",
"fp32": "f",
"f32": "f",
"fp64": "d",
# TODO handle nvTmaDesc/CUtensormap
}
def extract_type(self, ty: str) -> str:
"""
Takes a triton type from CompiledKernel.signature and
converts it into a single char encoding. _StaticCudaLauncher
will switch on this char to figure out what type the underlying
value should be passed to the triton kernel as.
"""
if ty[0] == "*":
return "O"
elif ty == "nvTmaDesc":
raise NotImplementedError("nvTmaDesc kernels are not yet supported")
return StaticallyLaunchedCudaKernel.type_mappings()[ty]
def arg_ty_from_signature(self, src: ASTSource) -> tuple[str, OrderedSet[int]]:
def index_key(i: Any) -> int:
return src.fn.arg_names.index(i) if isinstance(i, str) else i
signature = {index_key(key): value for key, value in src.signature.items()}
constants = [index_key(key) for key in getattr(src, "constants", dict())]
# Despite requiring them to be passed in, the triton CUDA launcher
# completely ignores the constexprs passed into it when generating code.
# So we can ignore them here too
params = []
constant_idxs: OrderedSet[int] = OrderedSet()
for i in sorted(signature.keys()):
ty = signature[i]
# In newer triton versions, constants are passed in to signature with type `constexpr`
# In older triton versions, there can be constants in src.constants that are not `constexpr` in signature
# so we check both here
if ty == "constexpr" or i in constants:
constant_idxs.add(i)
else:
params.append(self.extract_type(ty))
return "".join(params), constant_idxs
def run(
self, grid: tuple[int, ...], stream: int, *args: Unpack[tuple[object, ...]]
) -> None:
"""Actually run the kernel at runtime. This function is the hot codepath."""
from torch._C import _StaticCudaLauncher
# Assert load_kernel() has been called and args match
assert self.function is not None
# TODO: actually, if the args *don't* match, we probably should
# throw an exception. But if inductor is the only one calling this
# thing, it should always match.
# Get rid of constants before passing to cubin launcher
# TODO: is this (and the check below) slow to do at runtime? The thing is,
# we already spend the time in CachingAutotuner.launch() to massage the arguments
# properly anyways so this isn't exactly slower than that...
args = tuple(args[i] for i in range(len(args)) if i not in self.constant_idxs)
# Add a None if triton wants an extra parameter to the cubin
if self.has_global_scratch:
arg_tys = self.arg_tys + "O"
args = (*args, None)
else:
arg_tys = self.arg_tys
assert len(args) == len(arg_tys)
# TODO: can handle grid functions here or in C++, so
# that we don't need the grid handler above.
grid_x = grid[0]
grid_y = grid[1] if len(grid) > 1 else 1
grid_z = grid[2] if len(grid) > 2 else 1
_StaticCudaLauncher._launch_kernel(
self.function,
grid_x,
grid_y,
grid_z,
self.num_warps,
self.shared,
arg_tys,
args,
stream,
)
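
Since the launcher is picklable once the cubin has been written to disk (see the class docstring above), a cached or serialized instance can be revived in another process. A minimal sketch, assuming `kernel` is a StaticallyLaunchedCudaKernel that has already called write_cubin_to_file(), and `stream`/`args` are placeholders:

    import pickle

    payload = pickle.dumps(kernel)    # e.g. the form FXGraphCache/TritonBundler would cache
    restored = pickle.loads(payload)
    restored.load_kernel()            # re-resolve the CUfunction in the current process
    restored.run((1, 1, 1), stream, *args)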

View File

@@ -112,6 +112,7 @@
#include <ATen/ROCmFABackend.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#include <torch/csrc/inductor/static_cuda_launcher.h>
#ifdef __HIP_PLATFORM_AMD__
#include <ATen/native/cudnn/hip/BatchNorm.h>
#else
@@ -1882,6 +1883,9 @@ PyObject* initModule() {
#ifdef USE_CUDA
torch::cuda::initModule(module);
#endif
#if defined(USE_CUDA) && !defined(USE_ROCM)
ASSERT_TRUE(StaticCudaLauncher_init(module));
#endif
#ifdef USE_MPS
torch::mps::initModule(module);
#endif

View File

@@ -0,0 +1,418 @@
#if defined(USE_CUDA) && !defined(USE_ROCM)
// We disable this file from being hipified because it uses CUDA driver APIs
// that hip has not implemented yet. Also, we're passing in a cubin file
// directly, so it would take more work to support ROCm anyway.
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <ATen/Context.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/static_cuda_launcher.h>
#include <cstdint>
#include <stdexcept>
#include <torch/csrc/utils/python_numbers.h>
#include <filesystem>
#include <optional>
/**
Implements a static launcher for triton compiled CUDA kernels.
Given a path to a cubin file, a function name, and some metadata,
this class loads and launches the cubin.
Doing this avoids C++ codegen and compilation during compile, since we can
use a statically compiled library to launch the kernel. To avoid mallocing
for the arguments, we have a launcher for different numbers of arguments up
to a max. StaticCudaLauncher only supports up to MAX_ARGS (currently 50)
arguments for now.
Note that we allocate 8 bytes per argument, no matter the types of each
argument, since we don't know ahead of time what the types of each argument
passed to the triton kernel are. This may take slightly more memory on the
stack, and will require some benchmarking. However, since the vast majority
of triton kernels have less than 10 args, this seems unlikely to be
expensive.
This launcher is paired with StaticallyLaunchedCudaKernel in
triton_heuristics.py.
TODO:
- Handle CUtensorMap, nvTmaDesc
- Handle launch_enter and launch_exit hooks (in python maybe?)
*/
// Use ATen/NVRTC.h to gain access to the CUDA driver API.
// This function is only called when CUDA is enabled, and only called to load
// and launch triton compiled CUDA kernels, so CUDA should always be
// initialized.
namespace {
const at::cuda::NVRTC& nvrtc() {
return at::globalContext().getNVRTC();
}
#define MAX_ARGS 50
CUdeviceptr getPointer(PyObject* obj) {
CUdeviceptr data_ptr = 0;
if (THPUtils_checkLong(obj)) {
data_ptr = THPUtils_unpackUInt64(obj);
return data_ptr;
}
if (obj == Py_None) {
// valid nullptr
return data_ptr;
}
auto ptr = THPObjectPtr{PyObject_GetAttrString(obj, "data_ptr")};
TORCH_CHECK(
ptr != nullptr,
"Pointer argument must be either uint64 or have data_ptr method")
auto empty_tuple = THPObjectPtr{PyTuple_New(0)};
auto ret = THPObjectPtr{PyObject_Call(ptr, empty_tuple, nullptr)};
TORCH_CHECK(
THPUtils_checkLong(ret),
"data_ptr method of Pointer object must return 64-bit int");
data_ptr = THPUtils_unpackUInt64(ret);
if (!data_ptr)
return data_ptr;
CUdeviceptr dev_ptr = 0;
AT_CUDA_DRIVER_CHECK(nvrtc().cuPointerGetAttribute(
&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr));
return dev_ptr;
}
CUfunction loadKernel(
std::string filePath,
const std::string& funcName,
uint32_t sharedMemBytes,
const std::optional<std::string>& cubinDir = std::nullopt) {
if (cubinDir) {
std::filesystem::path p1{*cubinDir};
std::filesystem::path p2{filePath};
filePath = (p1 / p2.filename()).string();
}
CUmodule mod = nullptr;
CUfunction func = nullptr;
AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str()));
AT_CUDA_DRIVER_CHECK(
nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str()));
if (sharedMemBytes > 0) {
AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncSetAttribute(
func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, sharedMemBytes));
}
return func;
}
template <size_t NUM_ARGS>
inline void launchKernel(
CUfunction func,
uint32_t gridX,
uint32_t gridY,
uint32_t gridZ,
uint32_t numWarps,
uint32_t sharedMemBytes,
std::array<void*, NUM_ARGS>& args,
cudaStream_t stream) {
  // num_ctas is always 1 for inductor-generated triton kernels,
  // so we don't need any special grid/cluster handling here
AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
func,
gridX,
gridY,
gridZ,
32 * numWarps, // blockDim.x
1, // blockDim.y
1, // blockDim.z
sharedMemBytes,
stream,
args.data(),
nullptr));
}
template <typename FINAL, typename F>
void convertType(F converter, const char* name, void* slot, PyObject* item) {
auto temp = converter(item);
if (PyErr_Occurred()) {
std::string msg = "Failed to convert argument to ";
msg += name;
TORCH_CHECK(false, msg);
}
*reinterpret_cast<FINAL*>(slot) = static_cast<FINAL>(temp);
}
/**
Given a list of args and their types (in a string), along with two stack
allocated arrays, puts each argument arg_{i} into argStorage[i], and a
pointer to the argument in kernelArgs[i]. We then can pass `kernelArgs`
directly to launchKernel. Note that some args can be less than 8 bytes, but
we'll still allocate 8 bytes on the stack for them.
 * TODO: Need to handle nvTmaDesc here.
*/
void parseKernelArgs(
PyObject* varArgs,
const char* argTypes,
uint64_t* argStorage,
void** kernelArgs) {
int numKernelArgs = static_cast<int>(std::strlen(argTypes));
TORCH_CHECK(
PyTuple_Check(varArgs), "Kernel arguments must be provided as a tuple");
TORCH_CHECK(
PyTuple_Size(varArgs) == static_cast<Py_ssize_t>(numKernelArgs),
"Mismatch between number of argument types and provided arguments");
for (int i = 0; i < numKernelArgs; ++i) {
// Get pointer to the ith 8-byte slot.
void* slot = static_cast<void*>(&argStorage[i]);
PyObject* item = PyTuple_GetItem(varArgs, i);
char typeChar = argTypes[i];
switch (typeChar) {
case 'b':
convertType<int8_t>(THPUtils_unpackInt, "int8", slot, item);
break;
case 'h':
convertType<int16_t>(THPUtils_unpackInt, "int16", slot, item);
break;
case 'i':
convertType<int32_t>(THPUtils_unpackLong, "int32", slot, item);
break;
case 'l':
convertType<int64_t>(THPUtils_unpackLong, "int64", slot, item);
break;
case 'B':
convertType<uint8_t>(THPUtils_unpackUInt32, "uint8", slot, item);
break;
case 'H':
convertType<uint16_t>(THPUtils_unpackUInt32, "uint16", slot, item);
break;
case 'I':
convertType<uint32_t>(THPUtils_unpackUInt32, "uint32", slot, item);
break;
case 'K':
convertType<uint64_t>(THPUtils_unpackUInt64, "uint64", slot, item);
break;
case 'f':
convertType<float>(THPUtils_unpackDouble, "float", slot, item);
break;
case 'd':
convertType<double>(THPUtils_unpackDouble, "double", slot, item);
break;
case 'O': { // pointer; using helper getPointer() (which may call
// data_ptr() if needed)
CUdeviceptr ptr = getPointer(item);
*reinterpret_cast<CUdeviceptr*>(slot) = ptr;
break;
}
default:
TORCH_CHECK(false, "Unknown type passed in: ", typeChar);
}
// Save the pointer to this slot.
kernelArgs[i] = slot;
}
}
/* Load the CUDA kernel into memory (called during torch.compile), and
return a pointer to it (along with nregs and nspills).
Called in python as:
(function, n_regs, n_spills) = load_kernel(cubin_path, func_name,
sharedMemBytes)
*/
PyObject* load_kernel(PyObject* self, PyObject* args) {
const char* filePath = nullptr;
const char* funcName = nullptr;
int sharedMemBytes = 0;
int n_regs = 0;
int n_spills = 0;
if (!PyArg_ParseTuple(args, "ssi", &filePath, &funcName, &sharedMemBytes)) {
return nullptr;
}
CUfunction func = nullptr;
func = loadKernel(filePath, funcName, sharedMemBytes);
// Taken from triton/nvidia/backend/driver.c
AT_CUDA_DRIVER_CHECK(
nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func));
AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute(
&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func));
n_spills /= 4;
// Return a tuple of CUFunction, n_regs, n_spills
return Py_BuildValue(
"(Kii)", reinterpret_cast<uint64_t>(func), n_regs, n_spills);
}
PyObject* launch_kernel_inner(
CUfunction func,
int gridX,
int gridY,
int gridZ,
int numWarps,
int sharedMemBytes,
const char* argTypes,
PyObject* varArgs,
cudaStream_t cudaStream) {
// Launch the kernel
// Prepare the arguments for the kernel
// We allocate 8 bytes per argument on the stack. We then allocate 8 more
// bytes to point to each 8 byte slot in argStorage, and pass that array of
// pointers to launchKernel.
std::array<uint64_t, MAX_ARGS> argStorage = {};
std::array<void*, MAX_ARGS> kernelArgs = {};
parseKernelArgs(varArgs, argTypes, argStorage.data(), kernelArgs.data());
launchKernel(
func,
gridX,
gridY,
gridZ,
numWarps,
sharedMemBytes,
kernelArgs,
cudaStream);
Py_RETURN_NONE;
}
/**
* Main entrypoint function called at runtime; called like this in python land:
launcher(
function, # CUfunction returned by load_kernel()
grid_x,
grid_y,
grid_z,
num_warps,
shared,
arg_tys, # e.g. "bO" for (int8_t, uint64_t)
args, # tuple of arguments passed to the kernel
stream,
)
*
*/
PyObject* launch_kernel(PyObject* self, PyObject* args) {
// Pointer to CUfunction generated by load_kernel()
uint64_t func_ptr = 0;
int gridX = 0, gridY = 0, gridZ = 0, numWarps = 0, sharedMemBytes = 0;
// stream here should be the raw stream gotten from
// device_interface.get_raw_stream()
uint64_t stream = 0;
const char* argTypes = nullptr;
PyObject* varArgs = nullptr;
// Parse the fixed arguments and the format string
if (!PyArg_ParseTuple(
args,
"KiiiiisOl",
&func_ptr,
&gridX,
&gridY,
&gridZ,
&numWarps,
&sharedMemBytes,
&argTypes,
&varArgs,
&stream)) {
return nullptr;
}
CUfunction func = reinterpret_cast<CUfunction>(func_ptr); // NOLINT
cudaStream_t cudaStream = reinterpret_cast<cudaStream_t>(stream); // NOLINT
auto num_args = std::strlen(argTypes);
TORCH_CHECK(
num_args <= MAX_ARGS,
"Static Cuda Launcher only supports up to 50 arguments");
return launch_kernel_inner(
func,
gridX,
gridY,
gridZ,
numWarps,
sharedMemBytes,
argTypes,
varArgs,
cudaStream);
}
std::array<PyMethodDef, 2> StaticCudaLauncherMethods = {
PyMethodDef{
"_launch_kernel",
(PyCFunction)launch_kernel,
METH_VARARGS,
"Cuda Launcher with up to 50 args"},
PyMethodDef{
"_load_kernel",
(PyCFunction)load_kernel,
METH_VARARGS,
"Load CUDA kernel from cubin file"}};
// Define a minimal type for StaticCudaLauncher.
// We don't implement __new__ or __init__ because we're using it only as a
// container for static methods.
PyTypeObject StaticCudaLauncherType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._StaticCudaLauncher", // tp_name
sizeof(PyObject), // tp_basicsize
0, // tp_itemsize
nullptr, // tp_dealloc
0, // tp_print (deprecated)
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_reserved
nullptr, // tp_repr
nullptr, // tp_as_number
nullptr, // tp_as_sequence
nullptr, // tp_as_mapping
nullptr, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT,
"Statically defined launchers for triton compiled CUDA kernels", // tp_doc
nullptr, // tp_traverse
nullptr, // tp_clear
nullptr, // tp_richcompare
0, // tp_weaklistoffset
nullptr, // tp_iter
nullptr, // tp_iternext
nullptr, // tp_methods
nullptr, // tp_members
nullptr, // tp_getset
nullptr, // tp_base
nullptr, // tp_dict (automatically allocated)
nullptr, // tp_descr_get
nullptr, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
nullptr, // tp_alloc
nullptr, // tp_new
};
} // anonymous namespace
// Module initialization: add StaticCudaLauncher to the module with our static
// methods.
bool StaticCudaLauncher_init(PyObject* module) {
if (PyType_Ready(&StaticCudaLauncherType) < 0) {
return false;
}
// Add our static methods to the type's dictionary.
PyObject* dict = StaticCudaLauncherType.tp_dict;
for (auto& def : StaticCudaLauncherMethods) {
PyObject* func = PyCFunction_New(&def, nullptr);
if (!func) {
return false;
}
PyObject* static_method = PyStaticMethod_New(func);
Py_DECREF(func);
if (PyDict_SetItemString(dict, def.ml_name, static_method) < 0) {
Py_DECREF(static_method);
return false;
}
Py_DECREF(static_method);
}
Py_INCREF(&StaticCudaLauncherType);
if (PyModule_AddObject(
module, "_StaticCudaLauncher", (PyObject*)&StaticCudaLauncherType) <
0) {
Py_DECREF(&StaticCudaLauncherType);
return false;
}
return true;
}
#endif

View File

@@ -0,0 +1,7 @@
#pragma once
#if defined(USE_CUDA) && !defined(USE_ROCM)
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h>
#include <torch/csrc/python_headers.h>
bool StaticCudaLauncher_init(PyObject* module);
#endif