diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py
index a362092712e7..ffad29577276 100755
--- a/tools/amd_build/build_amd.py
+++ b/tools/amd_build/build_amd.py
@@ -200,7 +200,10 @@ hipify_python.hipify(
     output_directory=out_dir,
     includes=includes,
     ignores=ignores,
-    extra_files=["torch/_inductor/codegen/wrapper.py"],
+    extra_files=[
+        "torch/_inductor/codegen/cpp_wrapper_cuda.py",
+        "torch/_inductor/codegen/wrapper.py",
+    ],
     out_of_place_only=args.out_of_place_only,
     hip_clang_launch=is_hip_clang(),
 )
diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py
new file mode 100644
index 000000000000..a35b3ee2da09
--- /dev/null
+++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py
@@ -0,0 +1,299 @@
+import functools
+import os
+from itertools import chain, count
+from typing import Any, List, Optional
+
+import sympy
+
+from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
+
+from .. import config
+from ..codecache import CudaKernelParamCache
+from ..triton_heuristics import grid as default_grid
+from ..virtualized import V
+from .wrapper import CppWrapperCodeGen, SymbolicCallArg
+
+
+def is_int(s: str) -> bool:
+    # Cpp code gen adds L at the end of ints
+    # Lets remove it for checking whether we have an int or not
+    if s and s[-1] == "L":
+        s = s[:-1]
+    try:
+        int(s)
+    except ValueError:
+        return False
+    except TypeError:
+        return False
+    return True
+
+
+def is_float(s: str) -> bool:
+    try:
+        float(s)
+    except ValueError:
+        return False
+    return True
+
+
+class CudaWrapperCodeGen(CppWrapperCodeGen):
+    """
+    Generates cpp wrapper for running on GPU and calls CUDA kernels
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.grid_id = count()
+        self.cuda = True
+
+    def write_header(self):
+        if V.graph.is_const_graph:
+            # We do not write header for constant graph, it will be written by main module.
+            return
+
+        super().write_header()
+
+        self.header.splice("#include <filesystem>")
+        if config.abi_compatible:
+            self.header.splice(
+                "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
+            )
+        else:
+            self.header.splice(
+                """
+                #include <c10/cuda/CUDAGuard.h>
+                #include <c10/cuda/CUDAStream.h>
+                """
+            )
+
+        self.header.splice(
+            """
+            #define CUDA_DRIVER_CHECK(EXPR)                    \\
+            do {                                               \\
+                CUresult code = EXPR;                          \\
+                const char *msg;                               \\
+                cuGetErrorString(code, &msg);                  \\
+                if (code != CUDA_SUCCESS) {                    \\
+                    throw std::runtime_error(                  \\
+                        std::string("CUDA driver error: ") +   \\
+                        std::string(msg));                     \\
+                }                                              \\
+            } while (0);
+
+            namespace {
+
+            struct Grid {
+                Grid(uint32_t x, uint32_t y, uint32_t z)
+                  : grid_x(x), grid_y(y), grid_z(z) {}
+                uint32_t grid_x;
+                uint32_t grid_y;
+                uint32_t grid_z;
+
+                bool is_non_zero() {
+                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
+                }
+            };
+
+            } // anonymous namespace
+
+            static inline CUfunction loadKernel(
+                    std::string filePath,
+                    const std::string &funcName,
+                    uint32_t sharedMemBytes,
+                    const std::optional<std::string> &cubinDir = std::nullopt) {
+                if (cubinDir) {
+                    std::filesystem::path p1{*cubinDir};
+                    std::filesystem::path p2{filePath};
+                    filePath = (p1 / p2.filename()).string();
+                }
+
+                CUmodule mod;
+                CUfunction func;
+                CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
+                CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
+                if (sharedMemBytes > 0) {
+                    CUDA_DRIVER_CHECK(cuFuncSetAttribute(
+                        func,
+                        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+                        sharedMemBytes
+                    ))
+                }
+                return func;
+            }
+
+            static inline void launchKernel(
+                    CUfunction func,
+                    uint32_t gridX,
+                    uint32_t gridY,
+                    uint32_t gridZ,
+                    uint32_t numWarps,
+                    uint32_t sharedMemBytes,
+                    void* args[],
+                    cudaStream_t stream) {
+                CUDA_DRIVER_CHECK(cuLaunchKernel(
+                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
+                ));
+            }
+            """
+        )
+
+    def write_get_raw_stream(self, index):
+        name = f"stream{index}"
+        self.writeline(
+            f"cudaStream_t {name} = at::cuda::getCurrentCUDAStream({index});"
+        )
+        return name
+
+    def define_kernel(
+        self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
+    ):
+        if not cuda:
+            return super().define_kernel(name, kernel, metadata, cuda)
+
+    def generate(self, is_inference):
+        self.prefix.writeline("\n")
+        if not V.graph.aot_mode:
+            for kernel in chain(
+                self.src_to_kernel.values(), self.user_defined_kernel_cache.values()
+            ):
+                self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
+            self.prefix.writeline("\n")
+        return super().generate(is_inference)
+
+    @functools.lru_cache(None)
+    def generate_load_kernel_once(
+        self, name: str, mangled_name: str, cubin_path: str, shared_mem: int
+    ):
+        if V.graph.aot_mode:
+            self.writeline(f"if (kernels.{name} == nullptr) {{")
+            self.writeline(
+                f"""    kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
+            )
+            self.writeline("}")
+        else:
+            self.writeline(f"if ({name} == nullptr) {{")
+            self.writeline(
+                f"""    {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
+            )
+            self.writeline("}")
+
+    def generate_args_decl(self, call_args):
+        dynamic_symbols = V.graph.sizevars.free_symbols()
+        # TODO: only works for constant now, need type info
+        new_args = []
+        for arg in call_args:
+            var_name = f"var_{next(self.arg_var_id)}"
+            if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)):
+                self.writeline(f"auto {var_name} = {arg};")
+            elif isinstance(arg, sympy.Expr):
+                self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
+            elif is_int(arg):
+                self.writeline(f"int {var_name} = {arg};")
+            elif is_float(arg):
+                self.writeline(f"float {var_name} = {arg};")
+            elif any(str(arg) == s.name for s in dynamic_symbols):
+                self.writeline(f"auto {var_name} = {arg};")
+            elif arg == "nullptr":
+                self.writeline(f"auto {var_name} = nullptr;")
+            elif arg == "c10::nullopt":
+                self.writeline(f"auto {var_name} = c10::nullopt;")
+            else:
+                if config.abi_compatible:
+                    self.writeline(f"CUdeviceptr {var_name};")
+                    self.writeline(
+                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
+                    )
+                else:
+                    self.writeline(
+                        f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
+                    )
+            new_args.append(f"&{var_name}")
+
+        return ", ".join(new_args)
+
+    def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
+        """
+        Generate grid configs for launching a CUDA kernel using the grid
+        function from triton_heuristics.
+        """
+        if not cuda:
+            return grid
+        assert isinstance(grid, list), f"expected {grid=} to be a list"
+        grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
+        grid_fn = default_grid(*grid)
+        params = CudaKernelParamCache.get(name)
+        assert (
+            params is not None
+        ), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
+        block_cfg = {
+            "XBLOCK": params["x_block"],
+            "YBLOCK": params["y_block"],
+            "ZBLOCK": params["z_block"],
+        }
+        return grid_fn(block_cfg)
+
+    def generate_kernel_call(
+        self,
+        name,
+        call_args,
+        grid=None,
+        device_index=None,
+        cuda=True,
+        triton=True,
+        arg_types=None,
+        grid_fn: str = "grid",
+    ):
+        if not cuda:
+            # Even in CudaWrapperCodeGen, we may see cpp kernels
+            return super().generate_kernel_call(
+                name, call_args, grid, device_index, cuda, triton, arg_types
+            )
+
+        params = CudaKernelParamCache.get(name)
+        assert (
+            params is not None
+        ), f"cuda kernel parameters for {name} should already exist at this moment"
+        mangled_name = params.get("mangled_name", None)
+        assert mangled_name is not None, "missing mangled_name"
+        cubin_path = params.get(get_cpp_wrapper_cubin_path_name(), None)
+        assert cubin_path is not None and os.path.exists(
+            cubin_path
+        ), f"cubin file should already exist at this moment: {cubin_path}"
+        shared_mem = params.get("shared_mem", 0)
+
+        self.generate_load_kernel_once(name, mangled_name, cubin_path, shared_mem)
+
+        call_args = self.generate_args_decl(call_args)
+        kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
+        self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
+        stream = (
+            "stream" if V.graph.aot_mode else self.write_get_raw_stream(device_index)
+        )
+        grid_name = f"{name}_grid_{next(self.grid_id)}"
+        assert isinstance(
+            grid, (list, tuple)
+        ), f"expected grid to be a list or tuple but got: {grid=}"
+
+        grid = [V.graph.sizevars.simplify(item) for item in grid]
+        grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
+        grid_args = [self.grid_expr_printer(item) for item in grid]
+        grid_args_str = ", ".join(grid_args)
+        self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")
+
+        if grid_uses_symbolic_shapes:
+            self.writeline(f"if ({grid_name}.is_non_zero()) {{")
+        kernel_var_name = f"kernels.{name}" if V.graph.aot_mode else name
+        self.writeline(
+            "launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
+                kernel_var_name,
+                f"{grid_name}.grid_x",
+                f"{grid_name}.grid_y",
+                f"{grid_name}.grid_z",
+                params["num_warps"],
+                params["shared_mem"],
+                kernel_args_var,
+                stream,
+            )
+        )
+        if grid_uses_symbolic_shapes:
+            self.writeline("}")
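
Note: a quick sketch of the two module-level helpers defined above, assuming this patch is applied so they are importable from the new module. is_int first strips the trailing "L" suffix that the C++ codegen appends to integer literals before attempting the conversion:

    # Sketch only; the import path is the one created by this patch.
    from torch._inductor.codegen.cpp_wrapper_cuda import is_float, is_int

    assert is_int("64L")         # trailing "L" from C++ codegen is stripped first
    assert is_int("-3")
    assert not is_int("1.5")     # int("1.5") raises ValueError
    assert is_float("1.5")
    assert not is_float("buf0")  # such args fall through to the data_ptr branch
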
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 9f5c6db3ed37..18dfc71151a9 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -7,7 +7,7 @@ import operator
 import os
 import re
 import sys
-from itertools import chain, count
+from itertools import count
 from typing import (
     Any,
     Callable,
@@ -27,7 +27,6 @@ from sympy import Expr
 import torch
 import torch._ops
 from torch._dynamo.utils import counters, dynamo_timed
-from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
 from torch._inductor.codegen.multi_kernel import MultiKernelState
 from torch.fx.experimental.symbolic_shapes import SymTypes
 
@@ -37,7 +36,6 @@ from torch.utils._sympy.singleton_int import SingletonInt
 from .. import codecache, config, ir
 from ..codecache import CudaKernelParamCache
 from ..ir import ReinterpretView
-from ..triton_heuristics import grid as default_grid
 from ..utils import (
     cache_on_self,
     get_benchmark_name,
@@ -70,28 +68,6 @@ def buffer_reuse_key(node: ir.Buffer) -> ReuseKey:
     )
 
 
-def is_int(s: str) -> bool:
-    # Cpp code gen adds L at the end of ints
-    # Lets remove it for checking whether we have an int or not
-    if s and s[-1] == "L":
-        s = s[:-1]
-    try:
-        int(s)
-    except ValueError:
-        return False
-    except TypeError:
-        return False
-    return True
-
-
-def is_float(s: str) -> bool:
-    try:
-        float(s)
-    except ValueError:
-        return False
-    return True
-
-
 def convert_arg_type(arg: torch.Argument) -> str:
     from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP
 
@@ -3094,266 +3070,3 @@ class CppWrapperCodeGen(WrapperCodeGen):
             return result
         else:
             return repr(val)
-
-
-class CudaWrapperCodeGen(CppWrapperCodeGen):
-    """
-    Generates cpp wrapper for running on GPU and calls CUDA kernels
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.grid_id = count()
-        self.cuda = True
-
-    def write_header(self):
-        if V.graph.is_const_graph:
-            # We do not write header for constant graph, it will be written by main module.
-            return
-
-        super().write_header()
-
-        self.header.splice("#include <filesystem>")
-        if config.abi_compatible:
-            self.header.splice(
-                "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
-            )
-        else:
-            self.header.splice(
-                """
-                #include <c10/cuda/CUDAGuard.h>
-                #include <c10/cuda/CUDAStream.h>
-                """
-            )
-
-        self.header.splice(
-            """
-            #define CUDA_DRIVER_CHECK(EXPR)                    \\
-            do {                                               \\
-                CUresult code = EXPR;                          \\
-                const char *msg;                               \\
-                cuGetErrorString(code, &msg);                  \\
-                if (code != CUDA_SUCCESS) {                    \\
-                    throw std::runtime_error(                  \\
-                        std::string("CUDA driver error: ") +   \\
-                        std::string(msg));                     \\
-                }                                              \\
-            } while (0);
-
-            namespace {
-
-            struct Grid {
-                Grid(uint32_t x, uint32_t y, uint32_t z)
-                  : grid_x(x), grid_y(y), grid_z(z) {}
-                uint32_t grid_x;
-                uint32_t grid_y;
-                uint32_t grid_z;
-
-                bool is_non_zero() {
-                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
-                }
-            };
-
-            } // anonymous namespace
-
-            static inline CUfunction loadKernel(
-                    std::string filePath,
-                    const std::string &funcName,
-                    uint32_t sharedMemBytes,
-                    const std::optional<std::string> &cubinDir = std::nullopt) {
-                if (cubinDir) {
-                    std::filesystem::path p1{*cubinDir};
-                    std::filesystem::path p2{filePath};
-                    filePath = (p1 / p2.filename()).string();
-                }
-
-                CUmodule mod;
-                CUfunction func;
-                CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
-                CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
-                if (sharedMemBytes > 0) {
-                    CUDA_DRIVER_CHECK(cuFuncSetAttribute(
-                        func,
-                        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
-                        sharedMemBytes
-                    ))
-                }
-                return func;
-            }
-
-            static inline void launchKernel(
-                    CUfunction func,
-                    uint32_t gridX,
-                    uint32_t gridY,
-                    uint32_t gridZ,
-                    uint32_t numWarps,
-                    uint32_t sharedMemBytes,
-                    void* args[],
-                    cudaStream_t stream) {
-                CUDA_DRIVER_CHECK(cuLaunchKernel(
-                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
-                ));
-            }
-            """
-        )
-
-    def write_get_raw_stream(self, index):
-        name = f"stream{index}"
-        self.writeline(
-            f"cudaStream_t {name} = at::cuda::getCurrentCUDAStream({index});"
-        )
-        return name
-
-    def define_kernel(
-        self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
-    ):
-        if not cuda:
-            return super().define_kernel(name, kernel, metadata, cuda)
-
-    def generate(self, is_inference):
-        self.prefix.writeline("\n")
-        if not V.graph.aot_mode:
-            for kernel in chain(
-                self.src_to_kernel.values(), self.user_defined_kernel_cache.values()
-            ):
-                self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
-            self.prefix.writeline("\n")
-        return super().generate(is_inference)
-
-    @functools.lru_cache(None)
-    def generate_load_kernel_once(
-        self, name: str, mangled_name: str, cubin_path: str, shared_mem: int
-    ):
-        if V.graph.aot_mode:
-            self.writeline(f"if (kernels.{name} == nullptr) {{")
-            self.writeline(
-                f"""    kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
-            )
-            self.writeline("}")
-        else:
-            self.writeline(f"if ({name} == nullptr) {{")
-            self.writeline(
-                f"""    {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
-            )
-            self.writeline("}")
-
-    def generate_args_decl(self, call_args):
-        dynamic_symbols = V.graph.sizevars.free_symbols()
-        # TODO: only works for constant now, need type info
-        new_args = []
-        for arg in call_args:
-            var_name = f"var_{next(self.arg_var_id)}"
-            if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)):
-                self.writeline(f"auto {var_name} = {arg};")
-            elif isinstance(arg, sympy.Expr):
-                self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
-            elif is_int(arg):
-                self.writeline(f"int {var_name} = {arg};")
-            elif is_float(arg):
-                self.writeline(f"float {var_name} = {arg};")
-            elif any(str(arg) == s.name for s in dynamic_symbols):
-                self.writeline(f"auto {var_name} = {arg};")
-            elif arg == "nullptr":
-                self.writeline(f"auto {var_name} = nullptr;")
-            elif arg == "c10::nullopt":
-                self.writeline(f"auto {var_name} = c10::nullopt;")
-            else:
-                if config.abi_compatible:
-                    self.writeline(f"CUdeviceptr {var_name};")
-                    self.writeline(
-                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
-                    )
-                else:
-                    self.writeline(
-                        f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
-                    )
-            new_args.append(f"&{var_name}")
-
-        return ", ".join(new_args)
-
-    def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
-        """
-        Generate grid configs for launching a CUDA kernel using the grid
-        function from triton_heuristics.
-        """
-        if not cuda:
-            return grid
-        assert isinstance(grid, list), f"expected {grid=} to be a list"
-        grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
-        grid_fn = default_grid(*grid)
-        params = CudaKernelParamCache.get(name)
-        assert (
-            params is not None
-        ), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
-        block_cfg = {
-            "XBLOCK": params["x_block"],
-            "YBLOCK": params["y_block"],
-            "ZBLOCK": params["z_block"],
-        }
-        return grid_fn(block_cfg)
-
-    def generate_kernel_call(
-        self,
-        name,
-        call_args,
-        grid=None,
-        device_index=None,
-        cuda=True,
-        triton=True,
-        arg_types=None,
-        grid_fn: str = "grid",
-    ):
-        if not cuda:
-            # Even in CudaWrapperCodeGen, we may see cpp kernels
-            return super().generate_kernel_call(
-                name, call_args, grid, device_index, cuda, triton, arg_types
-            )
-
-        params = CudaKernelParamCache.get(name)
-        assert (
-            params is not None
-        ), f"cuda kernel parameters for {name} should already exist at this moment"
-        mangled_name = params.get("mangled_name", None)
-        assert mangled_name is not None, "missing mangled_name"
-        cubin_path = params.get(get_cpp_wrapper_cubin_path_name(), None)
-        assert cubin_path is not None and os.path.exists(
-            cubin_path
-        ), f"cubin file should already exist at this moment: {cubin_path}"
-        shared_mem = params.get("shared_mem", 0)
-
-        self.generate_load_kernel_once(name, mangled_name, cubin_path, shared_mem)
-
-        call_args = self.generate_args_decl(call_args)
-        kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
-        self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
-        stream = (
-            "stream" if V.graph.aot_mode else self.write_get_raw_stream(device_index)
-        )
-        grid_name = f"{name}_grid_{next(self.grid_id)}"
-        assert isinstance(
-            grid, (list, tuple)
-        ), f"expected grid to be a list or tuple but got: {grid=}"
-
-        grid = [V.graph.sizevars.simplify(item) for item in grid]
-        grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
-        grid_args = [self.grid_expr_printer(item) for item in grid]
-        grid_args_str = ", ".join(grid_args)
-        self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")
-
-        if grid_uses_symbolic_shapes:
-            self.writeline(f"if ({grid_name}.is_non_zero()) {{")
-        kernel_var_name = f"kernels.{name}" if V.graph.aot_mode else name
-        self.writeline(
-            "launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
-                kernel_var_name,
-                f"{grid_name}.grid_x",
-                f"{grid_name}.grid_y",
-                f"{grid_name}.grid_z",
-                params["num_warps"],
-                params["shared_mem"],
-                kernel_args_var,
-                stream,
-            )
-        )
-        if grid_uses_symbolic_shapes:
-            self.writeline("}")
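
Note: since the definitions are deleted here rather than re-exported, any out-of-tree code still importing them from the old location breaks. A minimal sketch of the resulting module surface, assuming the patch is applied:

    # Sketch: the helpers and the class now live only in cpp_wrapper_cuda.
    import importlib

    old = importlib.import_module("torch._inductor.codegen.wrapper")
    new = importlib.import_module("torch._inductor.codegen.cpp_wrapper_cuda")

    assert not hasattr(old, "CudaWrapperCodeGen")  # removed by this patch
    assert not hasattr(old, "is_int")              # removed by this patch
    assert hasattr(new, "CudaWrapperCodeGen")      # added by this patch
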
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 5ed55310b4d3..e2b5ded299b3 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -30,7 +30,8 @@ from .codegen.common import (
     get_wrapper_codegen_for_device,
     register_backend_for_device,
 )
-from .codegen.wrapper import CppWrapperCodeGen, CudaWrapperCodeGen, WrapperCodeGen
+from .codegen.cpp_wrapper_cuda import CudaWrapperCodeGen
+from .codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen
 from .exc import (
     CppWrapperCodeGenError,
     LoweringException,
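
Note: the graph.py change is import-only; the logic that chooses among the three wrapper codegens is untouched by this patch. A simplified, hypothetical sketch of selection after the move (pick_wrapper_codegen is illustrative, not the actual GraphLowering code):

    from torch._inductor.codegen.cpp_wrapper_cuda import CudaWrapperCodeGen
    from torch._inductor.codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen

    def pick_wrapper_codegen(cpp_wrapper: bool, cuda: bool):
        # Hypothetical helper mirroring the import split: the C++ wrapper
        # comes in a CUDA flavor and a CPU-only flavor; everything else
        # uses the default Python wrapper codegen.
        if cpp_wrapper:
            return CudaWrapperCodeGen() if cuda else CppWrapperCodeGen()
        return WrapperCodeGen()
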