[aot_inductor] move CudaWrapperCodeGen into a separate file (#119870)

This reverts commit 3ab08946d5052eaeda11d683d6a58e801a032755.

Differential Revision: [D53817852](https://our.internmc.facebook.com/intern/diff/D53817852)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119870
Approved by: https://github.com/khabinov
Author:       Yang Chen
Date:         2024-02-15 14:43:36 -08:00
Committed by: PyTorch MergeBot
Parent:       8f9f12c068
Commit:       78c9b2948a

4 changed files with 306 additions and 290 deletions


@@ -200,7 +200,10 @@ hipify_python.hipify(
     output_directory=out_dir,
     includes=includes,
     ignores=ignores,
-    extra_files=["torch/_inductor/codegen/wrapper.py"],
+    extra_files=[
+        "torch/_inductor/codegen/cpp_wrapper_cuda.py",
+        "torch/_inductor/codegen/wrapper.py",
+    ],
     out_of_place_only=args.out_of_place_only,
     hip_clang_launch=is_hip_clang(),
 )

New file: torch/_inductor/codegen/cpp_wrapper_cuda.py

@@ -0,0 +1,299 @@
import functools
import os
from itertools import chain, count
from typing import Any, List, Optional
import sympy
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
from .. import config
from ..codecache import CudaKernelParamCache
from ..triton_heuristics import grid as default_grid
from ..virtualized import V
from .wrapper import CppWrapperCodeGen, SymbolicCallArg
def is_int(s: str) -> bool:
# Cpp code gen adds L at the end of ints
# Lets remove it for checking whether we have an int or not
if s and s[-1] == "L":
s = s[:-1]
try:
int(s)
except ValueError:
return False
except TypeError:
return False
return True
def is_float(s: str) -> bool:
try:
float(s)
except ValueError:
return False
return True
class CudaWrapperCodeGen(CppWrapperCodeGen):
"""
Generates cpp wrapper for running on GPU and calls CUDA kernels
"""
def __init__(self):
super().__init__()
self.grid_id = count()
self.cuda = True
def write_header(self):
if V.graph.is_const_graph:
# We do not write header for constant graph, it will be written by main module.
return
super().write_header()
self.header.splice("#include <filesystem>")
if config.abi_compatible:
self.header.splice(
"#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
)
else:
self.header.splice(
"""
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
"""
)
self.header.splice(
"""
#define CUDA_DRIVER_CHECK(EXPR) \\
do { \\
CUresult code = EXPR; \\
const char *msg; \\
cuGetErrorString(code, &msg); \\
if (code != CUDA_SUCCESS) { \\
throw std::runtime_error( \\
std::string("CUDA driver error: ") + \\
std::string(msg)); \\
} \\
} while (0);
namespace {
struct Grid {
Grid(uint32_t x, uint32_t y, uint32_t z)
: grid_x(x), grid_y(y), grid_z(z) {}
uint32_t grid_x;
uint32_t grid_y;
uint32_t grid_z;
bool is_non_zero() {
return grid_x > 0 && grid_y > 0 && grid_z > 0;
}
};
} // anonymous namespace
static inline CUfunction loadKernel(
std::string filePath,
const std::string &funcName,
uint32_t sharedMemBytes,
const std::optional<std::string> &cubinDir = std::nullopt) {
if (cubinDir) {
std::filesystem::path p1{*cubinDir};
std::filesystem::path p2{filePath};
filePath = (p1 / p2.filename()).string();
}
CUmodule mod;
CUfunction func;
CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
if (sharedMemBytes > 0) {
CUDA_DRIVER_CHECK(cuFuncSetAttribute(
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
sharedMemBytes
))
}
return func;
}
static inline void launchKernel(
CUfunction func,
uint32_t gridX,
uint32_t gridY,
uint32_t gridZ,
uint32_t numWarps,
uint32_t sharedMemBytes,
void* args[],
cudaStream_t stream) {
CUDA_DRIVER_CHECK(cuLaunchKernel(
func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
));
}
"""
)
def write_get_raw_stream(self, index):
name = f"stream{index}"
self.writeline(
f"cudaStream_t {name} = at::cuda::getCurrentCUDAStream({index});"
)
return name
def define_kernel(
self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
):
if not cuda:
return super().define_kernel(name, kernel, metadata, cuda)
def generate(self, is_inference):
self.prefix.writeline("\n")
if not V.graph.aot_mode:
for kernel in chain(
self.src_to_kernel.values(), self.user_defined_kernel_cache.values()
):
self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
self.prefix.writeline("\n")
return super().generate(is_inference)
@functools.lru_cache(None)
def generate_load_kernel_once(
self, name: str, mangled_name: str, cubin_path: str, shared_mem: int
):
if V.graph.aot_mode:
self.writeline(f"if (kernels.{name} == nullptr) {{")
self.writeline(
f""" kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
)
self.writeline("}")
else:
self.writeline(f"if ({name} == nullptr) {{")
self.writeline(
f""" {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
)
self.writeline("}")
def generate_args_decl(self, call_args):
dynamic_symbols = V.graph.sizevars.free_symbols()
# TODO: only works for constant now, need type info
new_args = []
for arg in call_args:
var_name = f"var_{next(self.arg_var_id)}"
if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)):
self.writeline(f"auto {var_name} = {arg};")
elif isinstance(arg, sympy.Expr):
self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
elif is_int(arg):
self.writeline(f"int {var_name} = {arg};")
elif is_float(arg):
self.writeline(f"float {var_name} = {arg};")
elif any(str(arg) == s.name for s in dynamic_symbols):
self.writeline(f"auto {var_name} = {arg};")
elif arg == "nullptr":
self.writeline(f"auto {var_name} = nullptr;")
elif arg == "c10::nullopt":
self.writeline(f"auto {var_name} = c10::nullopt;")
else:
if config.abi_compatible:
self.writeline(f"CUdeviceptr {var_name};")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
)
else:
self.writeline(
f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
)
new_args.append(f"&{var_name}")
return ", ".join(new_args)
def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
"""
Generate grid configs for launching a CUDA kernel using the grid
function from triton_heuristics.
"""
if not cuda:
return grid
assert isinstance(grid, list), f"expected {grid=} to be a list"
grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
grid_fn = default_grid(*grid)
params = CudaKernelParamCache.get(name)
assert (
params is not None
), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
block_cfg = {
"XBLOCK": params["x_block"],
"YBLOCK": params["y_block"],
"ZBLOCK": params["z_block"],
}
return grid_fn(block_cfg)
def generate_kernel_call(
self,
name,
call_args,
grid=None,
device_index=None,
cuda=True,
triton=True,
arg_types=None,
grid_fn: str = "grid",
):
if not cuda:
# Even in CudaWrapperCodeGen, we may see cpp kernels
return super().generate_kernel_call(
name, call_args, grid, device_index, cuda, triton, arg_types
)
params = CudaKernelParamCache.get(name)
assert (
params is not None
), f"cuda kernel parameters for {name} should already exist at this moment"
mangled_name = params.get("mangled_name", None)
assert mangled_name is not None, "missing mangled_name"
cubin_path = params.get(get_cpp_wrapper_cubin_path_name(), None)
assert cubin_path is not None and os.path.exists(
cubin_path
), f"cubin file should already exist at this moment: {cubin_path}"
shared_mem = params.get("shared_mem", 0)
self.generate_load_kernel_once(name, mangled_name, cubin_path, shared_mem)
call_args = self.generate_args_decl(call_args)
kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
stream = (
"stream" if V.graph.aot_mode else self.write_get_raw_stream(device_index)
)
grid_name = f"{name}_grid_{next(self.grid_id)}"
assert isinstance(
grid, (list, tuple)
), f"expected grid to be a list or tuple but got: {grid=}"
grid = [V.graph.sizevars.simplify(item) for item in grid]
grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
grid_args = [self.grid_expr_printer(item) for item in grid]
grid_args_str = ", ".join(grid_args)
self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")
if grid_uses_symbolic_shapes:
self.writeline(f"if ({grid_name}.is_non_zero()) {{")
kernel_var_name = f"kernels.{name}" if V.graph.aot_mode else name
self.writeline(
"launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
kernel_var_name,
f"{grid_name}.grid_x",
f"{grid_name}.grid_y",
f"{grid_name}.grid_z",
params["num_warps"],
params["shared_mem"],
kernel_args_var,
stream,
)
)
if grid_uses_symbolic_shapes:
self.writeline("}")

Changed file: torch/_inductor/codegen/wrapper.py

@@ -7,7 +7,7 @@ import operator
 import os
 import re
 import sys
-from itertools import chain, count
+from itertools import count
 from typing import (
     Any,
     Callable,
@@ -27,7 +27,6 @@ from sympy import Expr
 import torch
 import torch._ops
 from torch._dynamo.utils import counters, dynamo_timed
-from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
 from torch._inductor.codegen.multi_kernel import MultiKernelState
 from torch.fx.experimental.symbolic_shapes import SymTypes
@@ -37,7 +36,6 @@ from torch.utils._sympy.singleton_int import SingletonInt
 from .. import codecache, config, ir
 from ..codecache import CudaKernelParamCache
 from ..ir import ReinterpretView
-from ..triton_heuristics import grid as default_grid
 from ..utils import (
     cache_on_self,
     get_benchmark_name,
@@ -70,28 +68,6 @@ def buffer_reuse_key(node: ir.Buffer) -> ReuseKey:
     )
-def is_int(s: str) -> bool:
-    # Cpp code gen adds L at the end of ints
-    # Lets remove it for checking whether we have an int or not
-    if s and s[-1] == "L":
-        s = s[:-1]
-    try:
-        int(s)
-    except ValueError:
-        return False
-    except TypeError:
-        return False
-    return True
-def is_float(s: str) -> bool:
-    try:
-        float(s)
-    except ValueError:
-        return False
-    return True
 def convert_arg_type(arg: torch.Argument) -> str:
     from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP
@@ -3094,266 +3070,3 @@ class CppWrapperCodeGen(WrapperCodeGen):
            return result
        else:
            return repr(val)
class CudaWrapperCodeGen(CppWrapperCodeGen):
"""
Generates cpp wrapper for running on GPU and calls CUDA kernels
"""
def __init__(self):
super().__init__()
self.grid_id = count()
self.cuda = True
def write_header(self):
if V.graph.is_const_graph:
# We do not write header for constant graph, it will be written by main module.
return
super().write_header()
self.header.splice("#include <filesystem>")
if config.abi_compatible:
self.header.splice(
"#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
)
else:
self.header.splice(
"""
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
"""
)
self.header.splice(
"""
#define CUDA_DRIVER_CHECK(EXPR) \\
do { \\
CUresult code = EXPR; \\
const char *msg; \\
cuGetErrorString(code, &msg); \\
if (code != CUDA_SUCCESS) { \\
throw std::runtime_error( \\
std::string("CUDA driver error: ") + \\
std::string(msg)); \\
} \\
} while (0);
namespace {
struct Grid {
Grid(uint32_t x, uint32_t y, uint32_t z)
: grid_x(x), grid_y(y), grid_z(z) {}
uint32_t grid_x;
uint32_t grid_y;
uint32_t grid_z;
bool is_non_zero() {
return grid_x > 0 && grid_y > 0 && grid_z > 0;
}
};
} // anonymous namespace
static inline CUfunction loadKernel(
std::string filePath,
const std::string &funcName,
uint32_t sharedMemBytes,
const std::optional<std::string> &cubinDir = std::nullopt) {
if (cubinDir) {
std::filesystem::path p1{*cubinDir};
std::filesystem::path p2{filePath};
filePath = (p1 / p2.filename()).string();
}
CUmodule mod;
CUfunction func;
CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
if (sharedMemBytes > 0) {
CUDA_DRIVER_CHECK(cuFuncSetAttribute(
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
sharedMemBytes
))
}
return func;
}
static inline void launchKernel(
CUfunction func,
uint32_t gridX,
uint32_t gridY,
uint32_t gridZ,
uint32_t numWarps,
uint32_t sharedMemBytes,
void* args[],
cudaStream_t stream) {
CUDA_DRIVER_CHECK(cuLaunchKernel(
func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
));
}
"""
)
def write_get_raw_stream(self, index):
name = f"stream{index}"
self.writeline(
f"cudaStream_t {name} = at::cuda::getCurrentCUDAStream({index});"
)
return name
def define_kernel(
self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
):
if not cuda:
return super().define_kernel(name, kernel, metadata, cuda)
def generate(self, is_inference):
self.prefix.writeline("\n")
if not V.graph.aot_mode:
for kernel in chain(
self.src_to_kernel.values(), self.user_defined_kernel_cache.values()
):
self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
self.prefix.writeline("\n")
return super().generate(is_inference)
@functools.lru_cache(None)
def generate_load_kernel_once(
self, name: str, mangled_name: str, cubin_path: str, shared_mem: int
):
if V.graph.aot_mode:
self.writeline(f"if (kernels.{name} == nullptr) {{")
self.writeline(
f""" kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
)
self.writeline("}")
else:
self.writeline(f"if ({name} == nullptr) {{")
self.writeline(
f""" {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
)
self.writeline("}")
def generate_args_decl(self, call_args):
dynamic_symbols = V.graph.sizevars.free_symbols()
# TODO: only works for constant now, need type info
new_args = []
for arg in call_args:
var_name = f"var_{next(self.arg_var_id)}"
if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)):
self.writeline(f"auto {var_name} = {arg};")
elif isinstance(arg, sympy.Expr):
self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
elif is_int(arg):
self.writeline(f"int {var_name} = {arg};")
elif is_float(arg):
self.writeline(f"float {var_name} = {arg};")
elif any(str(arg) == s.name for s in dynamic_symbols):
self.writeline(f"auto {var_name} = {arg};")
elif arg == "nullptr":
self.writeline(f"auto {var_name} = nullptr;")
elif arg == "c10::nullopt":
self.writeline(f"auto {var_name} = c10::nullopt;")
else:
if config.abi_compatible:
self.writeline(f"CUdeviceptr {var_name};")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
)
else:
self.writeline(
f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
)
new_args.append(f"&{var_name}")
return ", ".join(new_args)
def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
"""
Generate grid configs for launching a CUDA kernel using the grid
function from triton_heuristics.
"""
if not cuda:
return grid
assert isinstance(grid, list), f"expected {grid=} to be a list"
grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
grid_fn = default_grid(*grid)
params = CudaKernelParamCache.get(name)
assert (
params is not None
), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
block_cfg = {
"XBLOCK": params["x_block"],
"YBLOCK": params["y_block"],
"ZBLOCK": params["z_block"],
}
return grid_fn(block_cfg)
def generate_kernel_call(
self,
name,
call_args,
grid=None,
device_index=None,
cuda=True,
triton=True,
arg_types=None,
grid_fn: str = "grid",
):
if not cuda:
# Even in CudaWrapperCodeGen, we may see cpp kernels
return super().generate_kernel_call(
name, call_args, grid, device_index, cuda, triton, arg_types
)
params = CudaKernelParamCache.get(name)
assert (
params is not None
), f"cuda kernel parameters for {name} should already exist at this moment"
mangled_name = params.get("mangled_name", None)
assert mangled_name is not None, "missing mangled_name"
cubin_path = params.get(get_cpp_wrapper_cubin_path_name(), None)
assert cubin_path is not None and os.path.exists(
cubin_path
), f"cubin file should already exist at this moment: {cubin_path}"
shared_mem = params.get("shared_mem", 0)
self.generate_load_kernel_once(name, mangled_name, cubin_path, shared_mem)
call_args = self.generate_args_decl(call_args)
kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
stream = (
"stream" if V.graph.aot_mode else self.write_get_raw_stream(device_index)
)
grid_name = f"{name}_grid_{next(self.grid_id)}"
assert isinstance(
grid, (list, tuple)
), f"expected grid to be a list or tuple but got: {grid=}"
grid = [V.graph.sizevars.simplify(item) for item in grid]
grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
grid_args = [self.grid_expr_printer(item) for item in grid]
grid_args_str = ", ".join(grid_args)
self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")
if grid_uses_symbolic_shapes:
self.writeline(f"if ({grid_name}.is_non_zero()) {{")
kernel_var_name = f"kernels.{name}" if V.graph.aot_mode else name
self.writeline(
"launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
kernel_var_name,
f"{grid_name}.grid_x",
f"{grid_name}.grid_y",
f"{grid_name}.grid_z",
params["num_warps"],
params["shared_mem"],
kernel_args_var,
stream,
)
)
if grid_uses_symbolic_shapes:
self.writeline("}")


@@ -30,7 +30,8 @@ from .codegen.common import (
     get_wrapper_codegen_for_device,
     register_backend_for_device,
 )
-from .codegen.wrapper import CppWrapperCodeGen, CudaWrapperCodeGen, WrapperCodeGen
+from .codegen.cpp_wrapper_cuda import CudaWrapperCodeGen
+from .codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen
 from .exc import (
     CppWrapperCodeGenError,
     LoweringException,
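
Note for downstream code: the class body is identical before and after this commit; only its module path changed. A minimal, hypothetical compatibility shim (not part of this PR) that works on both sides of the move might look like this:

# Hypothetical compatibility shim; not part of this PR. The class itself is
# untouched by the commit -- it only moved from torch/_inductor/codegen/wrapper.py
# to torch/_inductor/codegen/cpp_wrapper_cuda.py.
try:
    # New location (this commit and later)
    from torch._inductor.codegen.cpp_wrapper_cuda import CudaWrapperCodeGen
except ImportError:
    # Older builds that predate the move
    from torch._inductor.codegen.wrapper import CudaWrapperCodeGen  # type: ignore[attr-defined]

# Inductor instantiates this class itself during compilation; the shim only
# keeps direct references (e.g. isinstance checks or patches) working.
print(CudaWrapperCodeGen.__module__)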