mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Reland] First version of statically compiled launcher for triton compiled CUDA kernels (#149238)
This is a new version of https://github.com/pytorch/pytorch/pull/148561 fixing the ROCM test failure Putting this up for a first pass review, though I will likely make a bunch of changes before landing to add more features, etc. This diff implements a first version of a static CUDA kernel launcher in `torch._C`. The goal here is to take a cubin file and some metadata from a CompiledKernel from `triton`, and launch the cubin file directly. Background doc: https://docs.google.com/document/d/1rjRcHl6MfauHG30nCoQX-9UKvKyIs4WWMy_GsGyqb9g/edit?tab=t.0#heading=h.ut5lf39lzq66 Normally, using triton's CompiledKernel.make_launcher(), we would pay the cost of codegenning C++ and running it at compile time. With this new approach, we can use one statically compiled library to launch the kernel. The tradeoff here is that this new kernel launcher will not be able to use codegen to deal with different lengths/types of arguments. So we use templating to handle up to 10 arguments for now. We also allocate 8 bytes on the stack per argument no matter the argument type, which can take more memory than codegenning. On the other hand, we improve compile time on cold and warm start by not having to call the C++ compiler at all. This diff does not add the launcher to torch, but introduces a basic test suite. A list of TODOs that are not yet complete: - Handle `nvTmaDesc` and `cuTensorMap`, which triton handles - Embed the grid logic instead of passing in gridX,Y,Z - Handle launch_enter and exit hooks? (Not sure if inductor has these) - Benchmarking to see if there's runtime performance loss - Probably lots of features of the triton C++ generated code that I haven't handled yet. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149238 Approved by: https://github.com/oulgen
This commit is contained in:
committed by
PyTorch MergeBot
parent
c83c711da8
commit
a9c55277d7
@ -156,6 +156,7 @@ NVRTC_STUB2(nvrtcGetProgramLogSize,nvrtcProgram, size_t*)
|
||||
NVRTC_STUB2(nvrtcGetProgramLog, nvrtcProgram, char *)
|
||||
NVRTC_STUB3(nvrtcGetLoweredName, nvrtcProgram, const char *, const char **)
|
||||
|
||||
CUDA_STUB2(cuModuleLoad, CUmodule*, const char*)
|
||||
CUDA_STUB2(cuModuleLoadData, CUmodule *, const void *)
|
||||
CUDA_STUB3(cuModuleGetFunction, CUfunction *, CUmodule, const char *)
|
||||
CUDA_STUB4(cuOccupancyMaxActiveBlocksPerMultiprocessor, int *, CUfunction, int, size_t)
|
||||
@ -169,6 +170,8 @@ CUDA_STUB4(cuLinkCreate, unsigned int, CUjit_option *, void **, CUlinkState *)
|
||||
CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *)
|
||||
CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
|
||||
CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
|
||||
CUDA_STUB3(cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
|
||||
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
|
||||
CUresult CUDAAPI
|
||||
|
@ -43,6 +43,7 @@ namespace at::cuda {
|
||||
_(nvrtcGetProgramLogSize) \
|
||||
_(nvrtcGetProgramLog) \
|
||||
_(nvrtcGetLoweredName) \
|
||||
_(cuModuleLoad) \
|
||||
_(cuModuleLoadData) \
|
||||
_(cuModuleLoadDataEx) \
|
||||
_(cuModuleGetFunction) \
|
||||
@ -60,6 +61,7 @@ namespace at::cuda {
|
||||
_(cuLinkComplete) \
|
||||
_(cuFuncSetAttribute) \
|
||||
_(cuFuncGetAttribute) \
|
||||
_(cuPointerGetAttribute) \
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
|
||||
#define AT_FORALL_NVRTC_EXTENDED(_) \
|
||||
|
@ -859,6 +859,7 @@ libtorch_python_core_sources = [
|
||||
"torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
|
||||
"torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp",
|
||||
"torch/csrc/inductor/resize_storage_bytes.cpp",
|
||||
"torch/csrc/inductor/static_cuda_launcher.cpp",
|
||||
"torch/csrc/jit/backends/backend_init.cpp",
|
||||
"torch/csrc/jit/python/init.cpp",
|
||||
"torch/csrc/jit/passes/onnx.cpp",
|
||||
|
319
test/inductor/test_static_cuda_launcher.py
Normal file
319
test/inductor/test_static_cuda_launcher.py
Normal file
@ -0,0 +1,319 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
from torch._dynamo.device_interface import get_interface_for_device
|
||||
from torch._inductor.runtime import triton_helpers
|
||||
from torch._inductor.runtime.static_cuda_launcher import StaticallyLaunchedCudaKernel
|
||||
from torch._inductor.runtime.triton_compat import tl, triton
|
||||
from torch._inductor.runtime.triton_helpers import libdevice
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch.testing._internal.common_utils import skipIfRocm
|
||||
from torch.testing._internal.triton_utils import requires_cuda
|
||||
|
||||
|
||||
@requires_cuda
|
||||
class TestStaticCudaLauncher(TestCase):
|
||||
def setUp(self):
|
||||
# Create a temporary file to store the cubin.
|
||||
# We set delete=False so that the file persists after closing.
|
||||
self.tmp_file = tempfile.NamedTemporaryFile(mode="wb")
|
||||
self.tmp_file.close() # Close now; we'll open it for writing later.
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
# Delete the temporary cubin file.
|
||||
try:
|
||||
os.remove(self.tmp_file.name)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def _make_launcher(
|
||||
self,
|
||||
kernel: Callable,
|
||||
args: tuple[Any, ...],
|
||||
grid: tuple[Any, ...] = (1,),
|
||||
) -> StaticallyLaunchedCudaKernel:
|
||||
"""
|
||||
Compiles a Triton kernel with the provided *args,
|
||||
writes its cubin to the temporary file, and returns the file path.
|
||||
"""
|
||||
fn = triton.jit(kernel)
|
||||
# Launch the kernel to trigger compilation.
|
||||
compiled_kernel = fn[grid](*args)
|
||||
result = StaticallyLaunchedCudaKernel(compiled_kernel)
|
||||
result.write_cubin_to_file(self.tmp_file.name)
|
||||
result.load_kernel()
|
||||
return result
|
||||
|
||||
@skipIfRocm
|
||||
def test_basic(self):
|
||||
def simple_kernel(arg0, arg1):
|
||||
x = tl.load(arg0)
|
||||
y = arg1
|
||||
tl.store(arg0, x + y)
|
||||
|
||||
arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
|
||||
arg1 = 5
|
||||
args = (arg0, arg1)
|
||||
|
||||
launcher = self._make_launcher(simple_kernel, args, (1,))
|
||||
self.assertEqual(arg0, torch.tensor([5], dtype=torch.int32, device="cuda"))
|
||||
self.assertEqual(launcher.arg_tys, "Oi")
|
||||
new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
|
||||
device_interface = get_interface_for_device("cuda")
|
||||
stream = device_interface.get_raw_stream(device_interface.current_device())
|
||||
|
||||
launcher.run((1,), stream, new_arg0, arg1)
|
||||
self.assertEqual(new_arg0, arg0)
|
||||
|
||||
# I wish I could macro all int types this into a single unit test on a loop, but
|
||||
# 1. variables aren't allowed as type annotations in python
|
||||
# 2. triton relies on inspect.get_source to get the type annotations
|
||||
# so I can't even use exec() to generate the test cases.
|
||||
# So we'll just make a few kernels by hand
|
||||
@skipIfRocm
|
||||
def test_unsigned_integers(self):
|
||||
def unsigned_integers(
|
||||
arg0, arg1: tl.uint8, arg2: tl.uint16, arg3: tl.uint32, arg4: tl.uint64
|
||||
):
|
||||
x = tl.load(arg0)
|
||||
y = arg1 + arg2 + arg3 + arg4
|
||||
tl.store(arg0, x + y)
|
||||
|
||||
arg0 = torch.zeros(1, dtype=torch.uint64, device="cuda")
|
||||
# Using small numbers creates a Literal type which triton treats as a constant
|
||||
args = (arg0, 50, 50, 50, 50)
|
||||
|
||||
launcher = self._make_launcher(unsigned_integers, args, (1,))
|
||||
self.assertEqual(arg0, torch.tensor([200], dtype=torch.uint64, device="cuda"))
|
||||
self.assertEqual(launcher.arg_tys, "OBHIK")
|
||||
new_arg0 = torch.zeros(1, dtype=torch.uint64, device="cuda")
|
||||
device_interface = get_interface_for_device("cuda")
|
||||
stream = device_interface.get_raw_stream(device_interface.current_device())
|
||||
launcher.run((1,), stream, new_arg0, 50, 50, 50, 50)
|
||||
self.assertEqual(new_arg0, arg0)
|
||||
|
||||
@skipIfRocm
|
||||
def test_signed_integers(self):
|
||||
def signed_integers(
|
||||
arg0, arg1: tl.int8, arg2: tl.int16, arg3: tl.int32, arg4: tl.int64
|
||||
):
|
||||
x = tl.load(arg0)
|
||||
y = arg1 + arg2 + arg3 + arg4
|
||||
tl.store(arg0, x + y)
|
||||
|
||||
arg0 = torch.zeros(1, dtype=torch.int64, device="cuda")
|
||||
# Using small numbers creates a Literal type which triton treats as a constant
|
||||
args = (arg0, 50, 50, 50, 50)
|
||||
|
||||
launcher = self._make_launcher(signed_integers, args, (1,))
|
||||
self.assertEqual(arg0, torch.tensor([200], dtype=torch.int64, device="cuda"))
|
||||
self.assertEqual(launcher.arg_tys, "Obhil")
|
||||
new_arg0 = torch.zeros(1, dtype=torch.int64, device="cuda")
|
||||
device_interface = get_interface_for_device("cuda")
|
||||
stream = device_interface.get_raw_stream(device_interface.current_device())
|
||||
launcher.run((1,), stream, new_arg0, 50, 50, 50, 50)
|
||||
self.assertEqual(new_arg0, arg0)
|
||||
|
||||
# TODO: floats don't work properly, triton seems to think they're all tl.float32
|
||||
# despite type annotations.
|
||||
# There's also not really a good way for me to make a float16 in python...
|
||||
@skipIfRocm
|
||||
def test_floats(self):
|
||||
def floats(arg0, arg1: tl.float16, arg2: tl.float32, arg3: tl.float64):
|
||||
x = tl.load(arg0)
|
||||
y = arg1 + arg2 + arg3
|
||||
tl.store(arg0, x + y)
|
||||
|
||||
arg0 = torch.zeros(1, dtype=torch.float64, device="cuda")
|
||||
|
||||
args = (arg0, 1.0, 1.0, 1.0)
|
||||
|
||||
launcher = self._make_launcher(floats, args, (1,))
|
||||
# TODO: in Pytorch's pinned version of triton, arg3 is typed as regular float
|
||||
# but in triton 3.3.0, this is fixed and it's 0ffd. We'll need to update later.
|
||||
self.assertEqual(launcher.arg_tys, "Offf")
|
||||
self.assertEqual(arg0, torch.tensor([3.0], dtype=torch.float64, device="cuda"))
|
||||
new_arg0 = torch.zeros(1, dtype=torch.float64, device="cuda")
|
||||
device_interface = get_interface_for_device("cuda")
|
||||
stream = device_interface.get_raw_stream(device_interface.current_device())
|
||||
launcher.run((1,), stream, new_arg0, 1.0, 1.0, 1.0)
|
||||
self.assertEqual(new_arg0, arg0)
|
||||
|
||||
@skipIfRocm
|
||||
def test_basic_1arg(self):
|
||||
def simple_kernel_1_arg(arg0):
|
||||
x = tl.load(arg0)
|
||||
tl.store(arg0, x + 1)
|
||||
|
||||
arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
|
||||
launcher = self._make_launcher(simple_kernel_1_arg, (arg0,), (1,))
|
||||
self.assertEqual(arg0, torch.tensor([1], dtype=torch.int32, device="cuda"))
|
||||
self.assertEqual(launcher.arg_tys, "O")
|
||||
new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
|
||||
device_interface = get_interface_for_device("cuda")
|
||||
stream = device_interface.get_raw_stream(device_interface.current_device())
|
||||
|
||||
launcher.run(
|
||||
(1,),
|
||||
stream,
|
||||
new_arg0,
|
||||
)
|
||||
self.assertEqual(new_arg0, arg0)
|
||||
|
||||
@skipIfRocm
|
||||
def test_constexpr(self):
|
||||
# Constexprs are compiled directly into the cubin file,
|
||||
# so we never need to pass it to StaticCudaLauncher.
|
||||
|
||||
@triton.jit
|
||||
def kernel_constexpr(arg0, CONSTANT: tl.constexpr):
|
||||
x = tl.load(arg0)
|
||||
tl.store(arg0, x + CONSTANT)
|
||||
|
||||
# Can't use make_launcher because constexpr needs to be constant
|
||||
arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
|
||||
compiled_kernel = kernel_constexpr[(1,)](arg0, CONSTANT=5)
|
||||
launcher = StaticallyLaunchedCudaKernel(compiled_kernel)
|
||||
launcher.write_cubin_to_file(self.tmp_file.name)
|
||||
launcher.load_kernel()
|
||||
|
||||
self.assertEqual(arg0, torch.tensor([5], dtype=torch.int32, device="cuda"))
|
||||
self.assertEqual(launcher.arg_tys, "O")
|
||||
new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
|
||||
device_interface = get_interface_for_device("cuda")
|
||||
stream = device_interface.get_raw_stream(device_interface.current_device())
|
||||
launcher.run(
|
||||
(1,),
|
||||
stream,
|
||||
new_arg0,
|
||||
)
|
||||
self.assertEqual(new_arg0, arg0)
|
||||
|
||||
@skipIfRocm
|
||||
def test_implied_constant(self):
|
||||
"""xnumel is unused in this kernel, but isn't explicitly marked as a constexpr"""
|
||||
|
||||
# This kernel was generated by inductor so it has a bunch of unused arguments. We don't change it
|
||||
@triton.jit
|
||||
def triton_red_fused_any_isinf_0(
|
||||
in_ptr0,
|
||||
out_ptr0,
|
||||
xnumel, # noqa: F841
|
||||
r0_numel,
|
||||
XBLOCK: tl.constexpr,
|
||||
R0_BLOCK: tl.constexpr,
|
||||
):
|
||||
xnumel = 1 # noqa: F841
|
||||
rnumel = r0_numel # noqa: F841
|
||||
RBLOCK: tl.constexpr = R0_BLOCK # noqa: F841
|
||||
xoffset = tl.program_id(0) * XBLOCK
|
||||
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] # noqa: F841
|
||||
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) # noqa: F841
|
||||
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
||||
rbase = r0_base # noqa: F841
|
||||
_tmp3 = tl.full([XBLOCK, R0_BLOCK], False, tl.int1)
|
||||
for r0_offset in range(0, r0_numel, R0_BLOCK):
|
||||
r0_index = r0_offset + r0_base
|
||||
r0_mask = r0_index < r0_numel
|
||||
roffset = r0_offset # noqa: F841
|
||||
rindex = r0_index # noqa: F841
|
||||
r0_0 = r0_index
|
||||
tmp0 = tl.load(
|
||||
in_ptr0 + (r0_0), r0_mask, eviction_policy="evict_first", other=0.0
|
||||
)
|
||||
tmp1 = libdevice.isinf(tmp0).to(tl.int1)
|
||||
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
|
||||
tmp4 = _tmp3 | tmp2
|
||||
_tmp3 = tl.where(r0_mask, tmp4, _tmp3)
|
||||
tmp3 = triton_helpers.any(_tmp3.to(tl.int8), 1)[:, None].to(tl.int1)
|
||||
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp3, None)
|
||||
|
||||
arg0 = torch.tensor([0.0, 0.5, float("inf"), 5], device="cuda")
|
||||
arg1 = torch.tensor([False], device="cuda")
|
||||
arg2 = torch.tensor([False], device="cuda")
|
||||
compiled_kernel = triton_red_fused_any_isinf_0[1,](
|
||||
arg0, arg1, 1, 128, XBLOCK=1, R0_BLOCK=1
|
||||
)
|
||||
|
||||
launcher = StaticallyLaunchedCudaKernel(compiled_kernel)
|
||||
launcher.write_cubin_to_file(self.tmp_file.name)
|
||||
launcher.load_kernel()
|
||||
|
||||
device_interface = get_interface_for_device("cuda")
|
||||
stream = device_interface.get_raw_stream(device_interface.current_device())
|
||||
launcher.run((1,), stream, arg0, arg2, 1, 128)
|
||||
self.assertEqual(arg1, arg2)
|
||||
|
||||
@skipIfRocm
|
||||
def test_kernel_empty_tensor(self):
|
||||
# Triton kernel generated by torch.compile of the following:
|
||||
# @torch.compile()
|
||||
# def foo(x, y):
|
||||
# return torch.cat(((x * 4), y + 10))
|
||||
|
||||
# Running with example input:
|
||||
# torch._dynamo.decorators.mark_unbacked(t, 0)
|
||||
# x = torch.rand(0, device="cuda")
|
||||
# y = torch.rand(20, device="cuda")
|
||||
|
||||
@triton.jit
|
||||
def triton_poi_fused_cat_0(
|
||||
in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK: tl.constexpr
|
||||
):
|
||||
xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
|
||||
xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64)
|
||||
xmask = xindex < xnumel
|
||||
x0 = xindex
|
||||
tmp0 = x0
|
||||
tmp3 = ks0
|
||||
tmp4 = tmp0 < tmp3
|
||||
tmp5 = tl.load(
|
||||
in_ptr0 + (x0), xmask & tmp4, eviction_policy="evict_last", other=0.0
|
||||
)
|
||||
tmp6 = 4.0
|
||||
tmp7 = tmp5 * tmp6
|
||||
tmp8 = tl.full(tmp7.shape, 0.0, tmp7.dtype)
|
||||
tmp9 = tl.where(tmp4, tmp7, tmp8)
|
||||
tmp10 = tmp0 >= tmp3
|
||||
tmp13 = tl.load(
|
||||
in_ptr1 + (x0 + ((-1) * ks0)),
|
||||
xmask & tmp10,
|
||||
eviction_policy="evict_last",
|
||||
other=0.0,
|
||||
)
|
||||
tmp14 = 10.0
|
||||
tmp15 = tmp13 + tmp14
|
||||
tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
|
||||
tmp17 = tl.where(tmp10, tmp15, tmp16)
|
||||
tmp18 = tl.where(tmp4, tmp9, tmp17)
|
||||
tl.store(out_ptr0 + (x0), tmp18, xmask)
|
||||
|
||||
arg0 = 0
|
||||
arg1 = torch.randn(0, device="cuda")
|
||||
arg2 = torch.randn(20, device="cuda")
|
||||
buf0 = torch.empty(20, device="cuda")
|
||||
buf1 = torch.empty(20, device="cuda")
|
||||
xnumel = 20 + arg0
|
||||
compiled_kernel = triton_poi_fused_cat_0[(1,)](
|
||||
arg1, arg2, buf0, arg0, xnumel, XBLOCK=32
|
||||
)
|
||||
launcher = StaticallyLaunchedCudaKernel(compiled_kernel)
|
||||
|
||||
launcher.write_cubin_to_file(self.tmp_file.name)
|
||||
launcher.load_kernel()
|
||||
device_interface = get_interface_for_device("cuda")
|
||||
stream = device_interface.get_raw_stream(device_interface.current_device())
|
||||
|
||||
launcher.run((1, 1, 1), stream, arg1, arg2, buf1, arg0, xnumel)
|
||||
self.assertEqual(buf0, buf1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
||||
run_tests()
|
@ -2545,3 +2545,28 @@ class _NodeIter(Iterator):
|
||||
def __init__(self, root: FxNode, reversed: _bool) -> None: ...
|
||||
def __iter__(self) -> Iterator[FxNode]: ...
|
||||
def __next__(self) -> FxNode: ...
|
||||
|
||||
|
||||
# Defined in torch/csrc/inductor/static_cuda_launcher.cpp
|
||||
class _StaticCudaLauncher:
|
||||
@staticmethod
|
||||
def _load_kernel(
|
||||
cubin_file: str,
|
||||
func_name: str,
|
||||
shared_mem_bytes: _int,
|
||||
) -> Tuple[_int, _int, _int]:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def _launch_kernel(
|
||||
func: _int,
|
||||
grid_x: _int,
|
||||
grid_y: _int,
|
||||
grid_z: _int,
|
||||
num_warps: _int,
|
||||
shared_mem_bytes: _int,
|
||||
arg_types: str,
|
||||
args: Tuple[Any, ...],
|
||||
stream: _int,
|
||||
) -> None:
|
||||
...
|
||||
|
225
torch/_inductor/runtime/static_cuda_launcher.py
Normal file
225
torch/_inductor/runtime/static_cuda_launcher.py
Normal file
@ -0,0 +1,225 @@
|
||||
import functools
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .triton_compat import ASTSource, CompiledKernel
|
||||
|
||||
|
||||
MAX_SHARED_MEMORY = 49152
|
||||
MAX_ARGS = 50
|
||||
|
||||
|
||||
class StaticallyLaunchedCudaKernel:
|
||||
"""
|
||||
Parses the metadata of a CompiledKernel from Triton into a structure that can
|
||||
launch the cuda kernel directly. Only works for triton kernels compiled to cubin.
|
||||
|
||||
Doing this avoids C++ codegen and compilation during compile, since we can use a
|
||||
statically compiled library to launch the kernel. To avoid mallocing for the arguments,
|
||||
we have a launcher for different numbers of arguments up to a max. StaticCudaLauncher
|
||||
only supports # of arguments up until 10 for now.
|
||||
|
||||
Workflow:
|
||||
Compile time:
|
||||
1. Compile a kernel with triton and get a CompiledKernel
|
||||
2. Instantiate kernel = StaticallyLaunchedCudaKernel(triton_kernel)
|
||||
3. Write to a cubin file: kernel.write_cubin_to_file(filepath)
|
||||
4. Call kernel.load_kernel() (CUDA should be initialized by this point) to load the cubin
|
||||
Runtime:
|
||||
5. Call kernel.run(grid, stream, args) to launch the kernel
|
||||
|
||||
Note that after step 3, StaticallyLaunchedCudaKernel is fully pickleable/serializable.
|
||||
This allows it to be cached by FXGraphCache/TritonBundler, as well as sent from the worker
|
||||
to the parent process in inductor.
|
||||
"""
|
||||
|
||||
def __init__(self, kernel: CompiledKernel) -> None:
|
||||
# To be used later when hooking up with torch.compile:
|
||||
# inductor knows where the cubin file should be from triton,
|
||||
# so won't need to write to a tmp file directly.
|
||||
if hasattr(kernel, "_cubin_path"):
|
||||
self.cubin_path = kernel._cubin_path
|
||||
else:
|
||||
self.cubin = kernel.asm["cubin"]
|
||||
|
||||
# TODO: is this right?
|
||||
self.name = kernel.src.fn.__name__
|
||||
self.hash = kernel.hash
|
||||
if (
|
||||
kernel.__class__.launch_enter_hook is not None
|
||||
or kernel.__class__.launch_exit_hook is not None
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"We don't support launch enter or launch exit hooks"
|
||||
)
|
||||
self.num_warps = kernel.metadata.num_warps
|
||||
self.shared = (
|
||||
kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared
|
||||
)
|
||||
# When shared memory > 48 KB, triton allocates CUDA memory via both static and dynamic
|
||||
# memory allocation, which gets really complicated. We'll handle it later.
|
||||
# See triton/third-party/nvidia/driver.c in loadBinary
|
||||
if self.shared > MAX_SHARED_MEMORY:
|
||||
raise NotImplementedError(
|
||||
"Shared memory size > 48KB requires special triton handling"
|
||||
)
|
||||
|
||||
# Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel.
|
||||
# Inductor never uses this field or enables it, but we still have to pass an extra None
|
||||
# into the set of params if its enabled
|
||||
if hasattr(kernel.metadata, "global_scratch_size"):
|
||||
if kernel.metadata.global_scratch_size > 0:
|
||||
raise NotImplementedError("Global scratch not yet supported")
|
||||
else:
|
||||
self.has_global_scratch = True
|
||||
else:
|
||||
self.has_global_scratch = False
|
||||
|
||||
self.arg_tys, self.constant_idxs = self.arg_ty_from_signature(kernel.src)
|
||||
self.function: Optional[int] = (
|
||||
None # Loaded by load_kernel(on the parent process)
|
||||
)
|
||||
num_args = len(self.arg_tys)
|
||||
num_ctas = 1
|
||||
if hasattr(kernel, "num_ctas"):
|
||||
num_ctas = kernel.num_ctas
|
||||
elif hasattr(kernel, "metadata"):
|
||||
num_ctas = kernel.metadata.num_ctas
|
||||
|
||||
if num_ctas != 1:
|
||||
raise NotImplementedError(
|
||||
"Static cuda launcher only supports num_ctas == 1"
|
||||
)
|
||||
|
||||
if num_args > MAX_ARGS or num_args == 0:
|
||||
raise NotImplementedError(
|
||||
"No static cuda launcher available for %d arguments", num_args
|
||||
)
|
||||
|
||||
def load_kernel(self) -> None:
|
||||
from torch._C import _StaticCudaLauncher
|
||||
|
||||
assert hasattr(self, "cubin_path")
|
||||
if self.function is not None:
|
||||
return
|
||||
(self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel(
|
||||
self.cubin_path, self.name, self.shared
|
||||
)
|
||||
|
||||
def write_cubin_to_file(self, filepath: str) -> None:
|
||||
"""
|
||||
Only used for tests where we don't have a cubin path.
|
||||
"""
|
||||
if hasattr(self, "cubin_path"):
|
||||
return
|
||||
# Just used by tests for now.
|
||||
# TODO: derive cubin_path from wherever triton stores the cubin file on disk.
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(self.cubin)
|
||||
del self.cubin
|
||||
self.cubin_path = filepath
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache
|
||||
def type_mappings() -> dict[str, str]:
|
||||
return {
|
||||
"i1": "i",
|
||||
"i8": "b",
|
||||
"i16": "h",
|
||||
"i32": "i",
|
||||
"i64": "l",
|
||||
"u1": "I",
|
||||
"u8": "B",
|
||||
"u16": "H",
|
||||
"u32": "I",
|
||||
"u64": "K",
|
||||
"fp16": "f",
|
||||
"bf16": "f",
|
||||
"fp32": "f",
|
||||
"f32": "f",
|
||||
"fp64": "d",
|
||||
# TODO handle nvTmaDesc/CUtensormap
|
||||
}
|
||||
|
||||
def extract_type(self, ty: str) -> str:
|
||||
"""
|
||||
Takes a triton type from CompiledKernel.signature and
|
||||
converts it into a single char encoding. _StaticCudaLauncher
|
||||
will switch on this char to figure out what type the underlying
|
||||
value should be passed to the triton kernel as.
|
||||
"""
|
||||
if ty[0] == "*":
|
||||
return "O"
|
||||
elif ty == "nvTmaDesc":
|
||||
raise NotImplementedError("nvTmaDesc kernels are not yet supported")
|
||||
return StaticallyLaunchedCudaKernel.type_mappings()[ty]
|
||||
|
||||
def arg_ty_from_signature(self, src: ASTSource) -> tuple[str, OrderedSet[int]]:
|
||||
def index_key(i: Any) -> int:
|
||||
return src.fn.arg_names.index(i) if isinstance(i, str) else i
|
||||
|
||||
signature = {index_key(key): value for key, value in src.signature.items()}
|
||||
constants = [index_key(key) for key in getattr(src, "constants", dict())]
|
||||
# Despite requiring them to be passed in, the triton CUDA launcher
|
||||
# completely ignores the constexprs passed into it when generating code.
|
||||
# So we can ignore them here too
|
||||
params = []
|
||||
|
||||
constant_idxs: OrderedSet[int] = OrderedSet()
|
||||
for i in sorted(signature.keys()):
|
||||
ty = signature[i]
|
||||
# In newer triton versions, constants are passed in to signature with type `constexpr`
|
||||
# In older triton versions, there can be constants in src.constants that are not `constexpr` in signature
|
||||
# so we check both here
|
||||
if ty == "constexpr" or i in constants:
|
||||
constant_idxs.add(i)
|
||||
else:
|
||||
params.append(self.extract_type(ty))
|
||||
return "".join(params), constant_idxs
|
||||
|
||||
def run(
|
||||
self, grid: tuple[int, ...], stream: int, *args: Unpack[tuple[object, ...]]
|
||||
) -> None:
|
||||
"""Actually run the kernel at runtime. This function is the hot codepath."""
|
||||
from torch._C import _StaticCudaLauncher
|
||||
|
||||
# Assert load_kernel() has been called and args match
|
||||
assert self.function is not None
|
||||
|
||||
# TODO: actually, if the args *don't* match, we probably should
|
||||
# throw an exception. But if inductor is the only one calling this
|
||||
# thing, it should always match.
|
||||
# Get rid of constants before passing to cubin launcher
|
||||
|
||||
# TODO: is this (and the check below) slow to do at runtime? The thing is,
|
||||
# we already spend the time in CachingAutotuner.launch() to massage the arguments
|
||||
# properly anyways so this isn't exactly slower than that...
|
||||
args = tuple(args[i] for i in range(len(args)) if i not in self.constant_idxs)
|
||||
|
||||
# Add a None if triton wants an extra parameter to the cubin
|
||||
if self.has_global_scratch:
|
||||
arg_tys = self.arg_tys + "O"
|
||||
args = (*args, None)
|
||||
else:
|
||||
arg_tys = self.arg_tys
|
||||
|
||||
assert len(args) == len(arg_tys)
|
||||
|
||||
# TODO: can handle grid functions here or in C++, so
|
||||
# that we don't need the grid handler above.
|
||||
grid_x = grid[0]
|
||||
grid_y = grid[1] if len(grid) > 1 else 1
|
||||
grid_z = grid[2] if len(grid) > 2 else 1
|
||||
_StaticCudaLauncher._launch_kernel(
|
||||
self.function,
|
||||
grid_x,
|
||||
grid_y,
|
||||
grid_z,
|
||||
self.num_warps,
|
||||
self.shared,
|
||||
arg_tys,
|
||||
args,
|
||||
stream,
|
||||
)
|
@ -112,6 +112,7 @@
|
||||
#include <ATen/ROCmFABackend.h>
|
||||
#include <ATen/cuda/CUDAConfig.h>
|
||||
#include <ATen/native/transformers/cuda/sdp_utils.h>
|
||||
#include <torch/csrc/inductor/static_cuda_launcher.h>
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#include <ATen/native/cudnn/hip/BatchNorm.h>
|
||||
#else
|
||||
@ -1882,6 +1883,9 @@ PyObject* initModule() {
|
||||
#ifdef USE_CUDA
|
||||
torch::cuda::initModule(module);
|
||||
#endif
|
||||
#if defined(USE_CUDA) && !defined(USE_ROCM)
|
||||
ASSERT_TRUE(StaticCudaLauncher_init(module));
|
||||
#endif
|
||||
#ifdef USE_MPS
|
||||
torch::mps::initModule(module);
|
||||
#endif
|
||||
|
418
torch/csrc/inductor/static_cuda_launcher.cpp
Normal file
418
torch/csrc/inductor/static_cuda_launcher.cpp
Normal file
@ -0,0 +1,418 @@
|
||||
#if defined(USE_CUDA) && !defined(USE_ROCM)
|
||||
// We disable this file from being hipified because there are CUDA drivers hip
|
||||
// has not implemented yet. Also, we're passing in a cubin file directly, so it
|
||||
// would take more work to support ROCM anyway.
|
||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/csrc/inductor/static_cuda_launcher.h>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <torch/csrc/utils/python_numbers.h>
|
||||
#include <filesystem>
|
||||
#include <optional>
|
||||
/**
|
||||
Implements a static launcher for triton compiled CUDA kernels.
|
||||
Given a path to a cubin file, a function name, and some metadata,
|
||||
this class loads and launches the cubin.
|
||||
|
||||
Doing this avoids C++ codegen and compilation during compile, since we can
|
||||
use a statically compiled library to launch the kernel. To avoid mallocing
|
||||
for the arguments, we have a launcher for different numbers of arguments up
|
||||
to a max. StaticCudaLauncher only supports # of arguments up until 10 for
|
||||
now.
|
||||
|
||||
Note that we allocate 8 bytes per argument, no matter the types of each
|
||||
argument, since we don't know ahead of time what the types of each argument
|
||||
passed to the triton kernel are. This may take slightly more memory on the
|
||||
stack, and will require some benchmarking. However, since the vast majority
|
||||
of triton kernels have less than 10 args, this seems unlikely to be
|
||||
expensive.
|
||||
|
||||
This launcher is paired with StaticallyLaunchedCudaKernel in
|
||||
triton_heuristics.py.
|
||||
|
||||
TODO:
|
||||
- Handle CutensorMap, NvtmDesc
|
||||
- Handle launch_enter and launch_exit hooks (in python maybe?)
|
||||
*/
|
||||
|
||||
// Use ATen/NVRTC.h to gain access to the CUDA driver API.
|
||||
// This function is only called when CUDA is enabled, and only called to load
|
||||
// and launch triton compiled CUDA kernels, so CUDA should always be
|
||||
// initialized.
|
||||
namespace {
|
||||
const at::cuda::NVRTC& nvrtc() {
|
||||
return at::globalContext().getNVRTC();
|
||||
}
|
||||
|
||||
#define MAX_ARGS 50
|
||||
|
||||
CUdeviceptr getPointer(PyObject* obj) {
|
||||
CUdeviceptr data_ptr = 0;
|
||||
if (THPUtils_checkLong(obj)) {
|
||||
data_ptr = THPUtils_unpackUInt64(obj);
|
||||
return data_ptr;
|
||||
}
|
||||
if (obj == Py_None) {
|
||||
// valid nullptr
|
||||
return data_ptr;
|
||||
}
|
||||
auto ptr = THPObjectPtr{PyObject_GetAttrString(obj, "data_ptr")};
|
||||
TORCH_CHECK(
|
||||
ptr != nullptr,
|
||||
"Pointer argument must be either uint64 or have data_ptr method")
|
||||
auto empty_tuple = THPObjectPtr{PyTuple_New(0)};
|
||||
auto ret = THPObjectPtr{PyObject_Call(ptr, empty_tuple, nullptr)};
|
||||
TORCH_CHECK(
|
||||
THPUtils_checkLong(ret),
|
||||
"data_ptr method of Pointer object must return 64-bit int");
|
||||
data_ptr = THPUtils_unpackUInt64(ret);
|
||||
if (!data_ptr)
|
||||
return data_ptr;
|
||||
|
||||
CUdeviceptr dev_ptr = 0;
|
||||
AT_CUDA_DRIVER_CHECK(nvrtc().cuPointerGetAttribute(
|
||||
&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr));
|
||||
return dev_ptr;
|
||||
}
|
||||
|
||||
CUfunction loadKernel(
|
||||
std::string filePath,
|
||||
const std::string& funcName,
|
||||
uint32_t sharedMemBytes,
|
||||
const std::optional<std::string>& cubinDir = std::nullopt) {
|
||||
if (cubinDir) {
|
||||
std::filesystem::path p1{*cubinDir};
|
||||
std::filesystem::path p2{filePath};
|
||||
filePath = (p1 / p2.filename()).string();
|
||||
}
|
||||
|
||||
CUmodule mod = nullptr;
|
||||
CUfunction func = nullptr;
|
||||
AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str()));
|
||||
AT_CUDA_DRIVER_CHECK(
|
||||
nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str()));
|
||||
if (sharedMemBytes > 0) {
|
||||
AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncSetAttribute(
|
||||
func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, sharedMemBytes));
|
||||
}
|
||||
return func;
|
||||
}
|
||||
|
||||
template <size_t NUM_ARGS>
|
||||
inline void launchKernel(
|
||||
CUfunction func,
|
||||
uint32_t gridX,
|
||||
uint32_t gridY,
|
||||
uint32_t gridZ,
|
||||
uint32_t numWarps,
|
||||
uint32_t sharedMemBytes,
|
||||
std::array<void*, NUM_ARGS>& args,
|
||||
cudaStream_t stream) {
|
||||
// cta_args is always 1 for inductor generated triton kernels,
|
||||
// so we don't need to figure out grid dimension here
|
||||
AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
|
||||
func,
|
||||
gridX,
|
||||
gridY,
|
||||
gridZ,
|
||||
32 * numWarps, // blockDim.x
|
||||
1, // blockDim.y
|
||||
1, // blockDim.z
|
||||
sharedMemBytes,
|
||||
stream,
|
||||
args.data(),
|
||||
nullptr));
|
||||
}
|
||||
|
||||
template <typename FINAL, typename F>
|
||||
void convertType(F converter, const char* name, void* slot, PyObject* item) {
|
||||
auto temp = converter(item);
|
||||
if (PyErr_Occurred()) {
|
||||
std::string msg = "Failed to convert argument to ";
|
||||
msg += name;
|
||||
TORCH_CHECK(false, msg);
|
||||
}
|
||||
*reinterpret_cast<FINAL*>(slot) = static_cast<FINAL>(temp);
|
||||
}
|
||||
|
||||
/**
|
||||
Given a list of args and their types (in a string), along with two stack
|
||||
allocated arrays, puts each argument arg_{i} into argStorage[i], and a
|
||||
pointer to the argument in kernelArgs[i]. We then can pass `kernelArgs`
|
||||
directly to launchKernel. Note that some args can be less than 8 bytes, but
|
||||
we'll still allocate 8 bytes on the stack for them.
|
||||
|
||||
* TODO: Need to handle NvtmDesc here.
|
||||
*/
|
||||
void parseKernelArgs(
|
||||
PyObject* varArgs,
|
||||
const char* argTypes,
|
||||
uint64_t* argStorage,
|
||||
void** kernelArgs) {
|
||||
int numKernelArgs = static_cast<int>(std::strlen(argTypes));
|
||||
TORCH_CHECK(
|
||||
PyTuple_Check(varArgs), "Kernel arguments must be provided as a tuple");
|
||||
TORCH_CHECK(
|
||||
PyTuple_Size(varArgs) == static_cast<Py_ssize_t>(numKernelArgs),
|
||||
"Mismatch between number of argument types and provided arguments");
|
||||
|
||||
for (int i = 0; i < numKernelArgs; ++i) {
|
||||
// Get pointer to the ith 8-byte slot.
|
||||
void* slot = static_cast<void*>(&argStorage[i]);
|
||||
PyObject* item = PyTuple_GetItem(varArgs, i);
|
||||
char typeChar = argTypes[i];
|
||||
switch (typeChar) {
|
||||
case 'b':
|
||||
convertType<int8_t>(THPUtils_unpackInt, "int8", slot, item);
|
||||
break;
|
||||
case 'h':
|
||||
convertType<int16_t>(THPUtils_unpackInt, "int16", slot, item);
|
||||
break;
|
||||
case 'i':
|
||||
convertType<int32_t>(THPUtils_unpackLong, "int32", slot, item);
|
||||
break;
|
||||
case 'l':
|
||||
convertType<int64_t>(THPUtils_unpackLong, "int64", slot, item);
|
||||
break;
|
||||
case 'B':
|
||||
convertType<uint8_t>(THPUtils_unpackUInt32, "uint8", slot, item);
|
||||
break;
|
||||
case 'H':
|
||||
convertType<uint16_t>(THPUtils_unpackUInt32, "uint16", slot, item);
|
||||
break;
|
||||
case 'I':
|
||||
convertType<uint32_t>(THPUtils_unpackUInt32, "uint32", slot, item);
|
||||
break;
|
||||
case 'K':
|
||||
convertType<uint64_t>(THPUtils_unpackUInt64, "uint64", slot, item);
|
||||
break;
|
||||
case 'f':
|
||||
convertType<float>(THPUtils_unpackDouble, "float", slot, item);
|
||||
break;
|
||||
case 'd':
|
||||
convertType<double>(THPUtils_unpackDouble, "double", slot, item);
|
||||
break;
|
||||
case 'O': { // pointer; using helper getPointer() (which may call
|
||||
// data_ptr() if needed)
|
||||
CUdeviceptr ptr = getPointer(item);
|
||||
*reinterpret_cast<CUdeviceptr*>(slot) = ptr;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK(false, "Unknown type passed in: ", typeChar);
|
||||
}
|
||||
// Save the pointer to this slot.
|
||||
kernelArgs[i] = slot;
|
||||
}
|
||||
}
|
||||
|
||||
/* Load the CUDA kernel into memory (called during torch.compile), and
|
||||
return a pointer to it (along with nregs and nspills).
|
||||
Called in python as:
|
||||
(function, n_regs, n_spills) = load_kernel(cubin_path, func_name,
|
||||
sharedMemBytes)
|
||||
*/
|
||||
PyObject* load_kernel(PyObject* self, PyObject* args) {
|
||||
const char* filePath = nullptr;
|
||||
const char* funcName = nullptr;
|
||||
int sharedMemBytes = 0;
|
||||
int n_regs = 0;
|
||||
int n_spills = 0;
|
||||
if (!PyArg_ParseTuple(args, "ssi", &filePath, &funcName, &sharedMemBytes)) {
|
||||
return nullptr;
|
||||
}
|
||||
CUfunction func = nullptr;
|
||||
func = loadKernel(filePath, funcName, sharedMemBytes);
|
||||
// Taken from triton/nvidia/backend/driver.c
|
||||
AT_CUDA_DRIVER_CHECK(
|
||||
nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func));
|
||||
AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute(
|
||||
&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func));
|
||||
n_spills /= 4;
|
||||
// Return a tuple of CUFunction, n_regs, n_spills
|
||||
return Py_BuildValue(
|
||||
"(Kii)", reinterpret_cast<uint64_t>(func), n_regs, n_spills);
|
||||
}
|
||||
|
||||
PyObject* launch_kernel_inner(
|
||||
CUfunction func,
|
||||
int gridX,
|
||||
int gridY,
|
||||
int gridZ,
|
||||
int numWarps,
|
||||
int sharedMemBytes,
|
||||
const char* argTypes,
|
||||
PyObject* varArgs,
|
||||
cudaStream_t cudaStream) {
|
||||
// Launch the kernel
|
||||
// Prepare the arguments for the kernel
|
||||
// We allocate 8 bytes per argument on the stack. We then allocate 8 more
|
||||
// bytes to point to each 8 byte slot in argStorage, and pass that array of
|
||||
// pointers to launchKernel.
|
||||
std::array<uint64_t, MAX_ARGS> argStorage = {};
|
||||
std::array<void*, MAX_ARGS> kernelArgs = {};
|
||||
parseKernelArgs(varArgs, argTypes, argStorage.data(), kernelArgs.data());
|
||||
|
||||
launchKernel(
|
||||
func,
|
||||
gridX,
|
||||
gridY,
|
||||
gridZ,
|
||||
numWarps,
|
||||
sharedMemBytes,
|
||||
kernelArgs,
|
||||
cudaStream);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Main entrypoint function called at runtime; called like this in python land:
|
||||
launcher(
|
||||
function, # CUfunction returned by load_kernel()
|
||||
grid_x,
|
||||
grid_y,
|
||||
grid_z,
|
||||
num_warps,
|
||||
shared,
|
||||
arg_tys, # e.g. "bO" for (int8_t, uint64_t)
|
||||
args, # tuple of arguments passed to the kernel
|
||||
stream,
|
||||
)
|
||||
*
|
||||
*/
|
||||
PyObject* launch_kernel(PyObject* self, PyObject* args) {
|
||||
// Pointer to CUfunction generated by load_kernel()
|
||||
uint64_t func_ptr = 0;
|
||||
int gridX = 0, gridY = 0, gridZ = 0, numWarps = 0, sharedMemBytes = 0;
|
||||
// stream here should be the raw stream gotten from
|
||||
// device_interface.get_raw_stream()
|
||||
uint64_t stream = 0;
|
||||
const char* argTypes = nullptr;
|
||||
PyObject* varArgs = nullptr;
|
||||
// Parse the fixed arguments and the format string
|
||||
if (!PyArg_ParseTuple(
|
||||
args,
|
||||
"KiiiiisOl",
|
||||
&func_ptr,
|
||||
&gridX,
|
||||
&gridY,
|
||||
&gridZ,
|
||||
&numWarps,
|
||||
&sharedMemBytes,
|
||||
&argTypes,
|
||||
&varArgs,
|
||||
&stream)) {
|
||||
return nullptr;
|
||||
}
|
||||
CUfunction func = reinterpret_cast<CUfunction>(func_ptr); // NOLINT
|
||||
cudaStream_t cudaStream = reinterpret_cast<cudaStream_t>(stream); // NOLINT
|
||||
auto num_args = std::strlen(argTypes);
|
||||
TORCH_CHECK(
|
||||
num_args <= MAX_ARGS,
|
||||
"Static Cuda Launcher only supports up to 50 arguments");
|
||||
return launch_kernel_inner(
|
||||
func,
|
||||
gridX,
|
||||
gridY,
|
||||
gridZ,
|
||||
numWarps,
|
||||
sharedMemBytes,
|
||||
argTypes,
|
||||
varArgs,
|
||||
cudaStream);
|
||||
}
|
||||
|
||||
std::array<PyMethodDef, 2> StaticCudaLauncherMethods = {
|
||||
PyMethodDef{
|
||||
"_launch_kernel",
|
||||
(PyCFunction)launch_kernel,
|
||||
METH_VARARGS,
|
||||
"Cuda Launcher with up to 50 args"},
|
||||
PyMethodDef{
|
||||
"_load_kernel",
|
||||
(PyCFunction)load_kernel,
|
||||
METH_VARARGS,
|
||||
"Load CUDA kernel from cubin file"}};
|
||||
|
||||
// Define a minimal type for StaticCudaLauncher.
|
||||
// We don't implement __new__ or __init__ because we're using it only as a
|
||||
// container for static methods.
|
||||
PyTypeObject StaticCudaLauncherType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
"torch._C._StaticCudaLauncher", // tp_name
|
||||
sizeof(PyObject), // tp_basicsize
|
||||
0, // tp_itemsize
|
||||
nullptr, // tp_dealloc
|
||||
0, // tp_print (deprecated)
|
||||
nullptr, // tp_getattr
|
||||
nullptr, // tp_setattr
|
||||
nullptr, // tp_reserved
|
||||
nullptr, // tp_repr
|
||||
nullptr, // tp_as_number
|
||||
nullptr, // tp_as_sequence
|
||||
nullptr, // tp_as_mapping
|
||||
nullptr, // tp_hash
|
||||
nullptr, // tp_call
|
||||
nullptr, // tp_str
|
||||
nullptr, // tp_getattro
|
||||
nullptr, // tp_setattro
|
||||
nullptr, // tp_as_buffer
|
||||
Py_TPFLAGS_DEFAULT,
|
||||
"Statically defined launchers for triton compiled CUDA kernels", // tp_doc
|
||||
nullptr, // tp_traverse
|
||||
nullptr, // tp_clear
|
||||
nullptr, // tp_richcompare
|
||||
0, // tp_weaklistoffset
|
||||
nullptr, // tp_iter
|
||||
nullptr, // tp_iternext
|
||||
nullptr, // tp_methods
|
||||
nullptr, // tp_members
|
||||
nullptr, // tp_getset
|
||||
nullptr, // tp_base
|
||||
nullptr, // tp_dict (automatically allocated)
|
||||
nullptr, // tp_descr_get
|
||||
nullptr, // tp_descr_set
|
||||
0, // tp_dictoffset
|
||||
nullptr, // tp_init
|
||||
nullptr, // tp_alloc
|
||||
nullptr, // tp_new
|
||||
};
|
||||
} // anonymous namespace
|
||||
// Module initialization: add StaticCudaLauncher to the module with our static
|
||||
// methods.
|
||||
bool StaticCudaLauncher_init(PyObject* module) {
|
||||
if (PyType_Ready(&StaticCudaLauncherType) < 0) {
|
||||
return false;
|
||||
}
|
||||
// Add our static methods to the type's dictionary.
|
||||
PyObject* dict = StaticCudaLauncherType.tp_dict;
|
||||
for (auto& def : StaticCudaLauncherMethods) {
|
||||
PyObject* func = PyCFunction_New(&def, nullptr);
|
||||
if (!func) {
|
||||
return false;
|
||||
}
|
||||
PyObject* static_method = PyStaticMethod_New(func);
|
||||
Py_DECREF(func);
|
||||
if (PyDict_SetItemString(dict, def.ml_name, static_method) < 0) {
|
||||
Py_DECREF(static_method);
|
||||
return false;
|
||||
}
|
||||
Py_DECREF(static_method);
|
||||
}
|
||||
Py_INCREF(&StaticCudaLauncherType);
|
||||
if (PyModule_AddObject(
|
||||
module, "_StaticCudaLauncher", (PyObject*)&StaticCudaLauncherType) <
|
||||
0) {
|
||||
Py_DECREF(&StaticCudaLauncherType);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#endif
|
7
torch/csrc/inductor/static_cuda_launcher.h
Normal file
7
torch/csrc/inductor/static_cuda_launcher.h
Normal file
@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
#if defined(USE_CUDA) && !defined(USE_ROCM)
|
||||
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h>
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
bool StaticCudaLauncher_init(PyObject* module);
|
||||
#endif
|
Reference in New Issue
Block a user