[aot_inductor] move CudaWrapperCodeGen into a separate file (#119870)

This reverts commit 3ab08946d5052eaeda11d683d6a58e801a032755.

Differential Revision: [D53817852](https://our.internmc.facebook.com/intern/diff/D53817852)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119870
Approved by: https://github.com/khabinov
Authored by: Yang Chen, 2024-02-15 14:43:36 -08:00
Committed by: PyTorch MergeBot
Parent: 8f9f12c068
Commit: 78c9b2948a
4 changed files with 306 additions and 290 deletions
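
For downstream code, the effect of this change is purely an import-path move: CudaWrapperCodeGen now lives in torch/_inductor/codegen/cpp_wrapper_cuda.py instead of torch/_inductor/codegen/wrapper.py. A minimal before/after sketch of the imports, mirroring the import change in the last hunk of this diff:

# before this change, everything came from the wrapper module:
# from torch._inductor.codegen.wrapper import CppWrapperCodeGen, CudaWrapperCodeGen, WrapperCodeGen

# after this change, the CUDA wrapper codegen has its own module:
from torch._inductor.codegen.cpp_wrapper_cuda import CudaWrapperCodeGen
from torch._inductor.codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen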

View File

@@ -200,7 +200,10 @@ hipify_python.hipify(
output_directory=out_dir,
includes=includes,
ignores=ignores,
extra_files=["torch/_inductor/codegen/wrapper.py"],
extra_files=[
"torch/_inductor/codegen/cpp_wrapper_cuda.py",
"torch/_inductor/codegen/wrapper.py",
],
out_of_place_only=args.out_of_place_only,
hip_clang_launch=is_hip_clang(),
)
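
Because the new cpp_wrapper_cuda.py contains CUDA-specific includes and driver calls, it is added to hipify's extra_files so those references get rewritten when PyTorch is built for ROCm. A minimal sketch of the full call, assuming the surrounding out_dir, includes, ignores, args, and is_hip_clang() names defined elsewhere in this build script, plus a hypothetical proj_dir for the project_directory argument:

from torch.utils.hipify import hipify_python

hipify_python.hipify(
    project_directory=proj_dir,  # assumed: the PyTorch source root
    output_directory=out_dir,
    includes=includes,
    ignores=ignores,
    extra_files=[
        "torch/_inductor/codegen/cpp_wrapper_cuda.py",
        "torch/_inductor/codegen/wrapper.py",
    ],
    out_of_place_only=args.out_of_place_only,
    hip_clang_launch=is_hip_clang(),
)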

View File

@@ -0,0 +1,299 @@
import functools
import os
from itertools import chain, count
from typing import Any, List, Optional
import sympy
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
from .. import config
from ..codecache import CudaKernelParamCache
from ..triton_heuristics import grid as default_grid
from ..virtualized import V
from .wrapper import CppWrapperCodeGen, SymbolicCallArg
def is_int(s: str) -> bool:
# Cpp code gen adds L at the end of ints
# Lets remove it for checking whether we have an int or not
if s and s[-1] == "L":
s = s[:-1]
try:
int(s)
except ValueError:
return False
except TypeError:
return False
return True
def is_float(s: str) -> bool:
try:
float(s)
except ValueError:
return False
return True
class CudaWrapperCodeGen(CppWrapperCodeGen):
"""
Generates cpp wrapper for running on GPU and calls CUDA kernels
"""
def __init__(self):
super().__init__()
self.grid_id = count()
self.cuda = True
def write_header(self):
if V.graph.is_const_graph:
# We do not write header for constant graph, it will be written by main module.
return
super().write_header()
self.header.splice("#include <filesystem>")
if config.abi_compatible:
self.header.splice(
"#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
)
else:
self.header.splice(
"""
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
"""
)
self.header.splice(
"""
#define CUDA_DRIVER_CHECK(EXPR) \\
do { \\
CUresult code = EXPR; \\
const char *msg; \\
cuGetErrorString(code, &msg); \\
if (code != CUDA_SUCCESS) { \\
throw std::runtime_error( \\
std::string("CUDA driver error: ") + \\
std::string(msg)); \\
} \\
} while (0);
namespace {
struct Grid {
Grid(uint32_t x, uint32_t y, uint32_t z)
: grid_x(x), grid_y(y), grid_z(z) {}
uint32_t grid_x;
uint32_t grid_y;
uint32_t grid_z;
bool is_non_zero() {
return grid_x > 0 && grid_y > 0 && grid_z > 0;
}
};
} // anonymous namespace
static inline CUfunction loadKernel(
std::string filePath,
const std::string &funcName,
uint32_t sharedMemBytes,
const std::optional<std::string> &cubinDir = std::nullopt) {
if (cubinDir) {
std::filesystem::path p1{*cubinDir};
std::filesystem::path p2{filePath};
filePath = (p1 / p2.filename()).string();
}
CUmodule mod;
CUfunction func;
CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
if (sharedMemBytes > 0) {
CUDA_DRIVER_CHECK(cuFuncSetAttribute(
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
sharedMemBytes
))
}
return func;
}
static inline void launchKernel(
CUfunction func,
uint32_t gridX,
uint32_t gridY,
uint32_t gridZ,
uint32_t numWarps,
uint32_t sharedMemBytes,
void* args[],
cudaStream_t stream) {
CUDA_DRIVER_CHECK(cuLaunchKernel(
func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
));
}
"""
)
def write_get_raw_stream(self, index):
name = f"stream{index}"
self.writeline(
f"cudaStream_t {name} = at::cuda::getCurrentCUDAStream({index});"
)
return name
def define_kernel(
self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
):
if not cuda:
return super().define_kernel(name, kernel, metadata, cuda)
def generate(self, is_inference):
self.prefix.writeline("\n")
if not V.graph.aot_mode:
for kernel in chain(
self.src_to_kernel.values(), self.user_defined_kernel_cache.values()
):
self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
self.prefix.writeline("\n")
return super().generate(is_inference)
@functools.lru_cache(None)
def generate_load_kernel_once(
self, name: str, mangled_name: str, cubin_path: str, shared_mem: int
):
if V.graph.aot_mode:
self.writeline(f"if (kernels.{name} == nullptr) {{")
self.writeline(
f""" kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
)
self.writeline("}")
else:
self.writeline(f"if ({name} == nullptr) {{")
self.writeline(
f""" {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
)
self.writeline("}")
def generate_args_decl(self, call_args):
dynamic_symbols = V.graph.sizevars.free_symbols()
# TODO: only works for constant now, need type info
new_args = []
for arg in call_args:
var_name = f"var_{next(self.arg_var_id)}"
if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)):
self.writeline(f"auto {var_name} = {arg};")
elif isinstance(arg, sympy.Expr):
self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
elif is_int(arg):
self.writeline(f"int {var_name} = {arg};")
elif is_float(arg):
self.writeline(f"float {var_name} = {arg};")
elif any(str(arg) == s.name for s in dynamic_symbols):
self.writeline(f"auto {var_name} = {arg};")
elif arg == "nullptr":
self.writeline(f"auto {var_name} = nullptr;")
elif arg == "c10::nullopt":
self.writeline(f"auto {var_name} = c10::nullopt;")
else:
if config.abi_compatible:
self.writeline(f"CUdeviceptr {var_name};")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
)
else:
self.writeline(
f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
)
new_args.append(f"&{var_name}")
return ", ".join(new_args)
def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
"""
Generate grid configs for launching a CUDA kernel using the grid
function from triton_heuristics.
"""
if not cuda:
return grid
assert isinstance(grid, list), f"expected {grid=} to be a list"
grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
grid_fn = default_grid(*grid)
params = CudaKernelParamCache.get(name)
assert (
params is not None
), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
block_cfg = {
"XBLOCK": params["x_block"],
"YBLOCK": params["y_block"],
"ZBLOCK": params["z_block"],
}
return grid_fn(block_cfg)
def generate_kernel_call(
self,
name,
call_args,
grid=None,
device_index=None,
cuda=True,
triton=True,
arg_types=None,
grid_fn: str = "grid",
):
if not cuda:
# Even in CudaWrapperCodeGen, we may see cpp kernels
return super().generate_kernel_call(
name, call_args, grid, device_index, cuda, triton, arg_types
)
params = CudaKernelParamCache.get(name)
assert (
params is not None
), f"cuda kernel parameters for {name} should already exist at this moment"
mangled_name = params.get("mangled_name", None)
assert mangled_name is not None, "missing mangled_name"
cubin_path = params.get(get_cpp_wrapper_cubin_path_name(), None)
assert cubin_path is not None and os.path.exists(
cubin_path
), f"cubin file should already exist at this moment: {cubin_path}"
shared_mem = params.get("shared_mem", 0)
self.generate_load_kernel_once(name, mangled_name, cubin_path, shared_mem)
call_args = self.generate_args_decl(call_args)
kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
stream = (
"stream" if V.graph.aot_mode else self.write_get_raw_stream(device_index)
)
grid_name = f"{name}_grid_{next(self.grid_id)}"
assert isinstance(
grid, (list, tuple)
), f"expected grid to be a list or tuple but got: {grid=}"
grid = [V.graph.sizevars.simplify(item) for item in grid]
grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
grid_args = [self.grid_expr_printer(item) for item in grid]
grid_args_str = ", ".join(grid_args)
self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")
if grid_uses_symbolic_shapes:
self.writeline(f"if ({grid_name}.is_non_zero()) {{")
kernel_var_name = f"kernels.{name}" if V.graph.aot_mode else name
self.writeline(
"launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
kernel_var_name,
f"{grid_name}.grid_x",
f"{grid_name}.grid_y",
f"{grid_name}.grid_z",
params["num_warps"],
params["shared_mem"],
kernel_args_var,
stream,
)
)
if grid_uses_symbolic_shapes:
self.writeline("}")
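
A small usage sketch of the is_int and is_float helpers defined at the top of this new file; the cpp codegen appends an L suffix to integer literals, which is_int strips before attempting the conversion:

is_int("128L")       # True: the trailing L emitted by cpp codegen is stripped first
is_int("7")          # True
is_int("s0")         # False: a symbolic size name, not an integer literal
is_int(None)         # False: the TypeError from int(None) is caught as well
is_float("0.5")      # True
is_float("nullptr")  # False: float() raises ValueError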

View File

@@ -7,7 +7,7 @@ import operator
import os
import re
import sys
from itertools import chain, count
from itertools import count
from typing import (
Any,
Callable,
@@ -27,7 +27,6 @@ from sympy import Expr
import torch
import torch._ops
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
from torch._inductor.codegen.multi_kernel import MultiKernelState
from torch.fx.experimental.symbolic_shapes import SymTypes
@@ -37,7 +36,6 @@ from torch.utils._sympy.singleton_int import SingletonInt
from .. import codecache, config, ir
from ..codecache import CudaKernelParamCache
from ..ir import ReinterpretView
from ..triton_heuristics import grid as default_grid
from ..utils import (
cache_on_self,
get_benchmark_name,
@@ -70,28 +68,6 @@ def buffer_reuse_key(node: ir.Buffer) -> ReuseKey:
)
def is_int(s: str) -> bool:
# Cpp code gen adds L at the end of ints
# Lets remove it for checking whether we have an int or not
if s and s[-1] == "L":
s = s[:-1]
try:
int(s)
except ValueError:
return False
except TypeError:
return False
return True
def is_float(s: str) -> bool:
try:
float(s)
except ValueError:
return False
return True
def convert_arg_type(arg: torch.Argument) -> str:
from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP
@@ -3094,266 +3070,3 @@ class CppWrapperCodeGen(WrapperCodeGen):
return result
else:
return repr(val)
class CudaWrapperCodeGen(CppWrapperCodeGen):
"""
Generates cpp wrapper for running on GPU and calls CUDA kernels
"""
def __init__(self):
super().__init__()
self.grid_id = count()
self.cuda = True
def write_header(self):
if V.graph.is_const_graph:
# We do not write header for constant graph, it will be written by main module.
return
super().write_header()
self.header.splice("#include <filesystem>")
if config.abi_compatible:
self.header.splice(
"#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
)
else:
self.header.splice(
"""
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
"""
)
self.header.splice(
"""
#define CUDA_DRIVER_CHECK(EXPR) \\
do { \\
CUresult code = EXPR; \\
const char *msg; \\
cuGetErrorString(code, &msg); \\
if (code != CUDA_SUCCESS) { \\
throw std::runtime_error( \\
std::string("CUDA driver error: ") + \\
std::string(msg)); \\
} \\
} while (0);
namespace {
struct Grid {
Grid(uint32_t x, uint32_t y, uint32_t z)
: grid_x(x), grid_y(y), grid_z(z) {}
uint32_t grid_x;
uint32_t grid_y;
uint32_t grid_z;
bool is_non_zero() {
return grid_x > 0 && grid_y > 0 && grid_z > 0;
}
};
} // anonymous namespace
static inline CUfunction loadKernel(
std::string filePath,
const std::string &funcName,
uint32_t sharedMemBytes,
const std::optional<std::string> &cubinDir = std::nullopt) {
if (cubinDir) {
std::filesystem::path p1{*cubinDir};
std::filesystem::path p2{filePath};
filePath = (p1 / p2.filename()).string();
}
CUmodule mod;
CUfunction func;
CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
if (sharedMemBytes > 0) {
CUDA_DRIVER_CHECK(cuFuncSetAttribute(
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
sharedMemBytes
))
}
return func;
}
static inline void launchKernel(
CUfunction func,
uint32_t gridX,
uint32_t gridY,
uint32_t gridZ,
uint32_t numWarps,
uint32_t sharedMemBytes,
void* args[],
cudaStream_t stream) {
CUDA_DRIVER_CHECK(cuLaunchKernel(
func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
));
}
"""
)
def write_get_raw_stream(self, index):
name = f"stream{index}"
self.writeline(
f"cudaStream_t {name} = at::cuda::getCurrentCUDAStream({index});"
)
return name
def define_kernel(
self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
):
if not cuda:
return super().define_kernel(name, kernel, metadata, cuda)
def generate(self, is_inference):
self.prefix.writeline("\n")
if not V.graph.aot_mode:
for kernel in chain(
self.src_to_kernel.values(), self.user_defined_kernel_cache.values()
):
self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
self.prefix.writeline("\n")
return super().generate(is_inference)
@functools.lru_cache(None)
def generate_load_kernel_once(
self, name: str, mangled_name: str, cubin_path: str, shared_mem: int
):
if V.graph.aot_mode:
self.writeline(f"if (kernels.{name} == nullptr) {{")
self.writeline(
f""" kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
)
self.writeline("}")
else:
self.writeline(f"if ({name} == nullptr) {{")
self.writeline(
f""" {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
)
self.writeline("}")
def generate_args_decl(self, call_args):
dynamic_symbols = V.graph.sizevars.free_symbols()
# TODO: only works for constant now, need type info
new_args = []
for arg in call_args:
var_name = f"var_{next(self.arg_var_id)}"
if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)):
self.writeline(f"auto {var_name} = {arg};")
elif isinstance(arg, sympy.Expr):
self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
elif is_int(arg):
self.writeline(f"int {var_name} = {arg};")
elif is_float(arg):
self.writeline(f"float {var_name} = {arg};")
elif any(str(arg) == s.name for s in dynamic_symbols):
self.writeline(f"auto {var_name} = {arg};")
elif arg == "nullptr":
self.writeline(f"auto {var_name} = nullptr;")
elif arg == "c10::nullopt":
self.writeline(f"auto {var_name} = c10::nullopt;")
else:
if config.abi_compatible:
self.writeline(f"CUdeviceptr {var_name};")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
)
else:
self.writeline(
f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
)
new_args.append(f"&{var_name}")
return ", ".join(new_args)
def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
"""
Generate grid configs for launching a CUDA kernel using the grid
function from triton_heuristics.
"""
if not cuda:
return grid
assert isinstance(grid, list), f"expected {grid=} to be a list"
grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
grid_fn = default_grid(*grid)
params = CudaKernelParamCache.get(name)
assert (
params is not None
), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
block_cfg = {
"XBLOCK": params["x_block"],
"YBLOCK": params["y_block"],
"ZBLOCK": params["z_block"],
}
return grid_fn(block_cfg)
def generate_kernel_call(
self,
name,
call_args,
grid=None,
device_index=None,
cuda=True,
triton=True,
arg_types=None,
grid_fn: str = "grid",
):
if not cuda:
# Even in CudaWrapperCodeGen, we may see cpp kernels
return super().generate_kernel_call(
name, call_args, grid, device_index, cuda, triton, arg_types
)
params = CudaKernelParamCache.get(name)
assert (
params is not None
), f"cuda kernel parameters for {name} should already exist at this moment"
mangled_name = params.get("mangled_name", None)
assert mangled_name is not None, "missing mangled_name"
cubin_path = params.get(get_cpp_wrapper_cubin_path_name(), None)
assert cubin_path is not None and os.path.exists(
cubin_path
), f"cubin file should already exist at this moment: {cubin_path}"
shared_mem = params.get("shared_mem", 0)
self.generate_load_kernel_once(name, mangled_name, cubin_path, shared_mem)
call_args = self.generate_args_decl(call_args)
kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
stream = (
"stream" if V.graph.aot_mode else self.write_get_raw_stream(device_index)
)
grid_name = f"{name}_grid_{next(self.grid_id)}"
assert isinstance(
grid, (list, tuple)
), f"expected grid to be a list or tuple but got: {grid=}"
grid = [V.graph.sizevars.simplify(item) for item in grid]
grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
grid_args = [self.grid_expr_printer(item) for item in grid]
grid_args_str = ", ".join(grid_args)
self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")
if grid_uses_symbolic_shapes:
self.writeline(f"if ({grid_name}.is_non_zero()) {{")
kernel_var_name = f"kernels.{name}" if V.graph.aot_mode else name
self.writeline(
"launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
kernel_var_name,
f"{grid_name}.grid_x",
f"{grid_name}.grid_y",
f"{grid_name}.grid_z",
params["num_warps"],
params["shared_mem"],
kernel_args_var,
stream,
)
)
if grid_uses_symbolic_shapes:
self.writeline("}")

View File

@@ -30,7 +30,8 @@ from .codegen.common import (
get_wrapper_codegen_for_device,
register_backend_for_device,
)
from .codegen.wrapper import CppWrapperCodeGen, CudaWrapperCodeGen, WrapperCodeGen
from .codegen.cpp_wrapper_cuda import CudaWrapperCodeGen
from .codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen
from .exc import (
CppWrapperCodeGenError,
LoweringException,