Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
[aot_inductor] move CudaWrapperCodeGen into a separate file (#119870)
This reverts commit 3ab08946d5052eaeda11d683d6a58e801a032755.
Differential Revision: [D53817852](https://our.internmc.facebook.com/intern/diff/D53817852)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119870
Approved by: https://github.com/khabinov
Committed by: PyTorch MergeBot
Parent: 8f9f12c068
Commit: 78c9b2948a
@@ -200,7 +200,10 @@ hipify_python.hipify(
     output_directory=out_dir,
     includes=includes,
     ignores=ignores,
-    extra_files=["torch/_inductor/codegen/wrapper.py"],
+    extra_files=[
+        "torch/_inductor/codegen/cpp_wrapper_cuda.py",
+        "torch/_inductor/codegen/wrapper.py",
+    ],
     out_of_place_only=args.out_of_place_only,
     hip_clang_launch=is_hip_clang(),
 )
torch/_inductor/codegen/cpp_wrapper_cuda.py (new file, 299 lines)
@@ -0,0 +1,299 @@
import functools
import os
from itertools import chain, count
from typing import Any, List, Optional

import sympy

from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name

from .. import config
from ..codecache import CudaKernelParamCache
from ..triton_heuristics import grid as default_grid
from ..virtualized import V
from .wrapper import CppWrapperCodeGen, SymbolicCallArg


def is_int(s: str) -> bool:
    # Cpp code gen adds L at the end of ints
    # Lets remove it for checking whether we have an int or not
    if s and s[-1] == "L":
        s = s[:-1]
    try:
        int(s)
    except ValueError:
        return False
    except TypeError:
        return False
    return True


def is_float(s: str) -> bool:
    try:
        float(s)
    except ValueError:
        return False
    return True


class CudaWrapperCodeGen(CppWrapperCodeGen):
    """
    Generates cpp wrapper for running on GPU and calls CUDA kernels
    """

    def __init__(self):
        super().__init__()
        self.grid_id = count()
        self.cuda = True

    def write_header(self):
        if V.graph.is_const_graph:
            # We do not write header for constant graph, it will be written by main module.
            return

        super().write_header()

        self.header.splice("#include <filesystem>")
        if config.abi_compatible:
            self.header.splice(
                "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
            )
        else:
            self.header.splice(
                """
                #include <c10/cuda/CUDAGuard.h>
                #include <c10/cuda/CUDAStream.h>
                """
            )

        self.header.splice(
            """
            #define CUDA_DRIVER_CHECK(EXPR)                   \\
            do {                                              \\
                CUresult code = EXPR;                         \\
                const char *msg;                              \\
                cuGetErrorString(code, &msg);                 \\
                if (code != CUDA_SUCCESS) {                   \\
                    throw std::runtime_error(                 \\
                        std::string("CUDA driver error: ") +  \\
                        std::string(msg));                    \\
                }                                             \\
            } while (0);

            namespace {

            struct Grid {
                Grid(uint32_t x, uint32_t y, uint32_t z)
                  : grid_x(x), grid_y(y), grid_z(z) {}
                uint32_t grid_x;
                uint32_t grid_y;
                uint32_t grid_z;

                bool is_non_zero() {
                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
                }
            };

            } // anonymous namespace

            static inline CUfunction loadKernel(
                    std::string filePath,
                    const std::string &funcName,
                    uint32_t sharedMemBytes,
                    const std::optional<std::string> &cubinDir = std::nullopt) {
                if (cubinDir) {
                    std::filesystem::path p1{*cubinDir};
                    std::filesystem::path p2{filePath};
                    filePath = (p1 / p2.filename()).string();
                }

                CUmodule mod;
                CUfunction func;
                CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
                CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
                if (sharedMemBytes > 0) {
                    CUDA_DRIVER_CHECK(cuFuncSetAttribute(
                        func,
                        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                        sharedMemBytes
                    ))
                }
                return func;
            }

            static inline void launchKernel(
                    CUfunction func,
                    uint32_t gridX,
                    uint32_t gridY,
                    uint32_t gridZ,
                    uint32_t numWarps,
                    uint32_t sharedMemBytes,
                    void* args[],
                    cudaStream_t stream) {
                CUDA_DRIVER_CHECK(cuLaunchKernel(
                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
                ));
            }
            """
        )

    def write_get_raw_stream(self, index):
        name = f"stream{index}"
        self.writeline(
            f"cudaStream_t {name} = at::cuda::getCurrentCUDAStream({index});"
        )
        return name

    def define_kernel(
        self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
    ):
        if not cuda:
            return super().define_kernel(name, kernel, metadata, cuda)

    def generate(self, is_inference):
        self.prefix.writeline("\n")
        if not V.graph.aot_mode:
            for kernel in chain(
                self.src_to_kernel.values(), self.user_defined_kernel_cache.values()
            ):
                self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
            self.prefix.writeline("\n")
        return super().generate(is_inference)

    @functools.lru_cache(None)
    def generate_load_kernel_once(
        self, name: str, mangled_name: str, cubin_path: str, shared_mem: int
    ):
        if V.graph.aot_mode:
            self.writeline(f"if (kernels.{name} == nullptr) {{")
            self.writeline(
                f"""    kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
            )
            self.writeline("}")
        else:
            self.writeline(f"if ({name} == nullptr) {{")
            self.writeline(
                f"""    {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
            )
            self.writeline("}")

    def generate_args_decl(self, call_args):
        dynamic_symbols = V.graph.sizevars.free_symbols()
        # TODO: only works for constant now, need type info
        new_args = []
        for arg in call_args:
            var_name = f"var_{next(self.arg_var_id)}"
            if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)):
                self.writeline(f"auto {var_name} = {arg};")
            elif isinstance(arg, sympy.Expr):
                self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
            elif is_int(arg):
                self.writeline(f"int {var_name} = {arg};")
            elif is_float(arg):
                self.writeline(f"float {var_name} = {arg};")
            elif any(str(arg) == s.name for s in dynamic_symbols):
                self.writeline(f"auto {var_name} = {arg};")
            elif arg == "nullptr":
                self.writeline(f"auto {var_name} = nullptr;")
            elif arg == "c10::nullopt":
                self.writeline(f"auto {var_name} = c10::nullopt;")
            else:
                if config.abi_compatible:
                    self.writeline(f"CUdeviceptr {var_name};")
                    self.writeline(
                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
                    )
                else:
                    self.writeline(
                        f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
                    )
            new_args.append(f"&{var_name}")

        return ", ".join(new_args)

    def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
        """
        Generate grid configs for launching a CUDA kernel using the grid
        function from triton_heuristics.
        """
        if not cuda:
            return grid
        assert isinstance(grid, list), f"expected {grid=} to be a list"
        grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
        grid_fn = default_grid(*grid)
        params = CudaKernelParamCache.get(name)
        assert (
            params is not None
        ), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
        block_cfg = {
            "XBLOCK": params["x_block"],
            "YBLOCK": params["y_block"],
            "ZBLOCK": params["z_block"],
        }
        return grid_fn(block_cfg)

    def generate_kernel_call(
        self,
        name,
        call_args,
        grid=None,
        device_index=None,
        cuda=True,
        triton=True,
        arg_types=None,
        grid_fn: str = "grid",
    ):
        if not cuda:
            # Even in CudaWrapperCodeGen, we may see cpp kernels
            return super().generate_kernel_call(
                name, call_args, grid, device_index, cuda, triton, arg_types
            )

        params = CudaKernelParamCache.get(name)
        assert (
            params is not None
        ), f"cuda kernel parameters for {name} should already exist at this moment"
        mangled_name = params.get("mangled_name", None)
        assert mangled_name is not None, "missing mangled_name"
        cubin_path = params.get(get_cpp_wrapper_cubin_path_name(), None)
        assert cubin_path is not None and os.path.exists(
            cubin_path
        ), f"cubin file should already exist at this moment: {cubin_path}"
        shared_mem = params.get("shared_mem", 0)

        self.generate_load_kernel_once(name, mangled_name, cubin_path, shared_mem)

        call_args = self.generate_args_decl(call_args)
        kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
        self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
        stream = (
            "stream" if V.graph.aot_mode else self.write_get_raw_stream(device_index)
        )
        grid_name = f"{name}_grid_{next(self.grid_id)}"
        assert isinstance(
            grid, (list, tuple)
        ), f"expected grid to be a list or tuple but got: {grid=}"

        grid = [V.graph.sizevars.simplify(item) for item in grid]
        grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
        grid_args = [self.grid_expr_printer(item) for item in grid]
        grid_args_str = ", ".join(grid_args)
        self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")

        if grid_uses_symbolic_shapes:
            self.writeline(f"if ({grid_name}.is_non_zero()) {{")
        kernel_var_name = f"kernels.{name}" if V.graph.aot_mode else name
        self.writeline(
            "launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
                kernel_var_name,
                f"{grid_name}.grid_x",
                f"{grid_name}.grid_y",
                f"{grid_name}.grid_z",
                params["num_warps"],
                params["shared_mem"],
                kernel_args_var,
                stream,
            )
        )
        if grid_uses_symbolic_shapes:
            self.writeline("}")
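To make the effect of the class concrete, the following is a hand-written sketch (not compiler output) of the C++ these methods write into the wrapper for a single Triton kernel in the JIT (non-AOT), non-ABI-compatible path. The kernel name, mangled name, cubin path, tensor names, and launch numbers are all hypothetical, and the generated function that surrounds this fragment is elided; comments mark which method emits each line.

// Hypothetical generated wrapper fragment; every identifier and number below is made up.
static CUfunction triton_poi_fused_add_0 = nullptr;                     // generate()

// ... inside the generated call function ...
if (triton_poi_fused_add_0 == nullptr) {                                // generate_load_kernel_once()
    triton_poi_fused_add_0 = loadKernel("/tmp/torchinductor_cache/ck/c_kernel.cubin", "triton_poi_fused_add_0_kernel", 0);
}
CUdeviceptr var_0 = reinterpret_cast<CUdeviceptr>(arg0_1.data_ptr());   // generate_args_decl()
CUdeviceptr var_1 = reinterpret_cast<CUdeviceptr>(buf0.data_ptr());
int var_2 = 1024;
void* kernel_args_var_0[] = {&var_0, &var_1, &var_2};
cudaStream_t stream0 = at::cuda::getCurrentCUDAStream(0);               // write_get_raw_stream()
Grid triton_poi_fused_add_0_grid_0 = Grid(4, 1, 1);                     // generate_kernel_call()
launchKernel(triton_poi_fused_add_0,
             triton_poi_fused_add_0_grid_0.grid_x,
             triton_poi_fused_add_0_grid_0.grid_y,
             triton_poi_fused_add_0_grid_0.grid_z,
             4, 0, kernel_args_var_0, stream0);

When any grid dimension depends on a symbolic shape, the launch is additionally wrapped in if (<grid>.is_non_zero()) { ... }, so an empty tensor skips the kernel rather than launching a zero-sized grid.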
@@ -7,7 +7,7 @@ import operator
 import os
 import re
 import sys
-from itertools import chain, count
+from itertools import count
 from typing import (
     Any,
     Callable,
@@ -27,7 +27,6 @@ from sympy import Expr
 import torch
 import torch._ops
 from torch._dynamo.utils import counters, dynamo_timed
-from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
 from torch._inductor.codegen.multi_kernel import MultiKernelState
 from torch.fx.experimental.symbolic_shapes import SymTypes
@@ -37,7 +36,6 @@ from torch.utils._sympy.singleton_int import SingletonInt
 from .. import codecache, config, ir
 from ..codecache import CudaKernelParamCache
 from ..ir import ReinterpretView
-from ..triton_heuristics import grid as default_grid
 from ..utils import (
     cache_on_self,
     get_benchmark_name,
@@ -70,28 +68,6 @@ def buffer_reuse_key(node: ir.Buffer) -> ReuseKey:
     )
 
 
-def is_int(s: str) -> bool:
-    # Cpp code gen adds L at the end of ints
-    # Lets remove it for checking whether we have an int or not
-    if s and s[-1] == "L":
-        s = s[:-1]
-    try:
-        int(s)
-    except ValueError:
-        return False
-    except TypeError:
-        return False
-    return True
-
-
-def is_float(s: str) -> bool:
-    try:
-        float(s)
-    except ValueError:
-        return False
-    return True
-
-
 def convert_arg_type(arg: torch.Argument) -> str:
     from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP
 
@@ -3094,266 +3070,3 @@ class CppWrapperCodeGen(WrapperCodeGen):
             return result
         else:
             return repr(val)
-
-
-class CudaWrapperCodeGen(CppWrapperCodeGen):
-    """
-    Generates cpp wrapper for running on GPU and calls CUDA kernels
-    """

(The remainder of this hunk deletes the rest of the CudaWrapperCodeGen implementation, line for line identical to the new torch/_inductor/codegen/cpp_wrapper_cuda.py shown above.)
@@ -30,7 +30,8 @@ from .codegen.common import (
     get_wrapper_codegen_for_device,
     register_backend_for_device,
 )
-from .codegen.wrapper import CppWrapperCodeGen, CudaWrapperCodeGen, WrapperCodeGen
+from .codegen.cpp_wrapper_cuda import CudaWrapperCodeGen
+from .codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen
 from .exc import (
     CppWrapperCodeGenError,
     LoweringException,