[5/x][AMD][Lowering Enablement] Hipifying aoti code_wrapper (#124241)

Summary: As the title says, hipify the AOTI code_wrapper codegen so the generated C++ wrapper builds and runs on AMD/ROCm.

Test Plan:
CI & unit test

Patched on top of https://www.internalfb.com/phabricator/paste/view/P1214895953 to test.

Differential Revision: D56223917

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124241
Approved by: https://github.com/jansel, https://github.com/desertfire
Author: Zhuoran Zhao
Date: 2024-04-19 18:57:38 +00:00
Committed by: PyTorch MergeBot
Parent: 25c65d6642
Commit: b0d83726bd

5 changed files with 134 additions and 86 deletions

torch/_inductor/codegen/aoti_hipify_utils.py (new file)

@@ -0,0 +1,20 @@
+import torch
+
+from torch.utils.hipify.hipify_python import PYTORCH_MAP, RE_PYTORCH_PREPROCESSOR
+
+# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
+#   "...
+#    from ..codecache import CudaKernelParamCache
+#    ..."
+# In such cases, we do not need to hipify the original class/file name in codegen/codecache
+
+
+def maybe_hipify_code_wrapper(source_codes: str) -> str:
+    if torch.version.hip is None:
+        return source_codes
+
+    def c2_repl(m):
+        return PYTORCH_MAP[m.group(0)]
+
+    source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes)
+    return source_codes
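For illustration, a minimal sketch of how the new helper behaves (hypothetical snippet, not part of this PR; it assumes the file lands at torch/_inductor/codegen/aoti_hipify_utils.py, as the relative imports in the other changed files suggest):

# Hypothetical usage sketch of maybe_hipify_code_wrapper (not shipped in this PR).
from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper

src = "CUfunction kernel0{nullptr};"
out = maybe_hipify_code_wrapper(src)
# On a CUDA build (torch.version.hip is None) the input is returned unchanged.
# On a ROCm build the CUDA driver tokens are rewritten to their HIP
# counterparts, e.g. CUfunction -> hipFunction_t (assumed mapping).
print(out)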

torch/_inductor/codegen/codegen_device_driver.py (new file)

@@ -0,0 +1,88 @@
+import torch
+
+# Provide the HIP/CUDA driver code for launching AOTI module kernels. This file is also used for unit testing purposes.
+
+
+def cuda_kernel_driver() -> str:
+    source_codes = """
+            #define CUDA_DRIVER_CHECK(EXPR) \\
+            do { \\
+                CUresult code = EXPR; \\
+                const char *msg; \\
+                cuGetErrorString(code, &msg); \\
+                if (code != CUDA_SUCCESS) { \\
+                    throw std::runtime_error( \\
+                        std::string("CUDA driver error: ") + \\
+                        std::string(msg)); \\
+                } \\
+            } while (0);
+
+            namespace {
+
+            struct Grid {
+                Grid(uint32_t x, uint32_t y, uint32_t z)
+                  : grid_x(x), grid_y(y), grid_z(z) {}
+                uint32_t grid_x;
+                uint32_t grid_y;
+                uint32_t grid_z;
+
+                bool is_non_zero() {
+                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
+                }
+            };
+
+            } // anonymous namespace
+
+            static inline CUfunction loadKernel(
+                    std::string filePath,
+                    const std::string &funcName,
+                    uint32_t sharedMemBytes,
+                    const std::optional<std::string> &cubinDir = std::nullopt) {
+                if (cubinDir) {
+                    std::filesystem::path p1{*cubinDir};
+                    std::filesystem::path p2{filePath};
+                    filePath = (p1 / p2.filename()).string();
+                }
+
+                CUmodule mod;
+                CUfunction func;
+                CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
+                CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
+                if (sharedMemBytes > 0) {
+                    CUDA_DRIVER_CHECK(cuFuncSetAttribute(
+                        func,
+                        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+                        sharedMemBytes
+                    ))
+                }
+                return func;
+            }
+
+            static inline void launchKernel(
+                    CUfunction func,
+                    uint32_t gridX,
+                    uint32_t gridY,
+                    uint32_t gridZ,
+                    uint32_t numWarps,
+                    uint32_t sharedMemBytes,
+                    void* args[],
+                    cudaStream_t stream) {
+                CUDA_DRIVER_CHECK(cuLaunchKernel(
+                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
+                ));
+            }
+        """
+    if torch.version.hip is not None:
+        # Replace the warp size from 32 (cuLaunchKernel) to 64 (hipModuleLaunchKernel):
+        # the warp size on NVIDIA GPUs is 32, while the wavefront size on AMD GPUs is 64.
+        source_codes = source_codes.replace("32*numWarps", "64*numWarps")
+    return source_codes
+
+
+def cuda_kernel_header() -> str:
+    source_codes = """
+        #include <c10/cuda/CUDAGuard.h>
+        #include <c10/cuda/CUDAStream.h>
+        #include <ATen/cuda/EmptyTensor.h>
+    """
+    return source_codes
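The warp-size substitution above is the only HIP-specific edit to the driver source. A minimal sketch of how one might unit-test it (hypothetical; this PR's actual unit tests are not shown on this page, and the import path is assumed from the file location):

# Hypothetical check of the warp-size substitution (not the PR's real test).
import torch
from torch._inductor.codegen.codegen_device_driver import cuda_kernel_driver

driver = cuda_kernel_driver()
if torch.version.hip is not None:
    # AMD wavefronts are 64 threads wide, so blockDimX becomes 64*numWarps.
    assert "64*numWarps" in driver and "32*numWarps" not in driver
else:
    # NVIDIA warps are 32 threads wide.
    assert "32*numWarps" in driver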

torch/_inductor/codegen/cpp_wrapper_cpu.py

@@ -15,6 +15,7 @@ from .. import config, ir
 from ..codecache import CudaKernelParamCache
 from ..utils import cache_on_self, sympy_product
 from ..virtualized import V
+from .aoti_hipify_utils import maybe_hipify_code_wrapper
 from .common import IndentedBuffer
 from .wrapper import EnterSubgraphLine, ExitSubgraphLine, WrapperCodeGen
@@ -665,7 +666,9 @@ class CppWrapperCpu(WrapperCodeGen):
                 V.graph.const_module.wrapper_code.src_to_kernel.values()
             )
             for kernel in sorted(declare_kernel):
-                self.prefix.writeline(f"    CUfunction {kernel}{{nullptr}};")
+                self.prefix.writeline(
+                    maybe_hipify_code_wrapper(f"    CUfunction {kernel}{{nullptr}};")
+                )
             self.prefix.writeline("};")
             self.prefix.writeline("} // namespace")

torch/_inductor/codegen/cpp_wrapper_cuda.py

@@ -11,6 +11,8 @@ from .. import config
 from ..codecache import CudaKernelParamCache
 from ..triton_heuristics import grid as default_grid
 from ..virtualized import V
+from .aoti_hipify_utils import maybe_hipify_code_wrapper
+from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header
 from .cpp_wrapper_cpu import CppWrapperCpu
 from .wrapper import SymbolicCallArg
@@ -64,88 +66,12 @@ class CppWrapperCuda(CppWrapperCpu):
                 "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
             )
         else:
-            self.header.splice(
-                """
-                #include <c10/cuda/CUDAGuard.h>
-                #include <c10/cuda/CUDAStream.h>
-                #include <ATen/cuda/EmptyTensor.h>
-                """
-            )
-
-        self.header.splice(
-            """
-            #define CUDA_DRIVER_CHECK(EXPR) \\
-            do { \\
-                CUresult code = EXPR; \\
-                const char *msg; \\
-                cuGetErrorString(code, &msg); \\
-                if (code != CUDA_SUCCESS) { \\
-                    throw std::runtime_error( \\
-                        std::string("CUDA driver error: ") + \\
-                        std::string(msg)); \\
-                } \\
-            } while (0);
-
-            namespace {
-
-            struct Grid {
-                Grid(uint32_t x, uint32_t y, uint32_t z)
-                  : grid_x(x), grid_y(y), grid_z(z) {}
-                uint32_t grid_x;
-                uint32_t grid_y;
-                uint32_t grid_z;
-
-                bool is_non_zero() {
-                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
-                }
-            };
-
-            } // anonymous namespace
-
-            static inline CUfunction loadKernel(
-                    std::string filePath,
-                    const std::string &funcName,
-                    uint32_t sharedMemBytes,
-                    const std::optional<std::string> &cubinDir = std::nullopt) {
-                if (cubinDir) {
-                    std::filesystem::path p1{*cubinDir};
-                    std::filesystem::path p2{filePath};
-                    filePath = (p1 / p2.filename()).string();
-                }
-
-                CUmodule mod;
-                CUfunction func;
-                CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
-                CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
-                if (sharedMemBytes > 0) {
-                    CUDA_DRIVER_CHECK(cuFuncSetAttribute(
-                        func,
-                        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
-                        sharedMemBytes
-                    ))
-                }
-                return func;
-            }
-
-            static inline void launchKernel(
-                    CUfunction func,
-                    uint32_t gridX,
-                    uint32_t gridY,
-                    uint32_t gridZ,
-                    uint32_t numWarps,
-                    uint32_t sharedMemBytes,
-                    void* args[],
-                    cudaStream_t stream) {
-                CUDA_DRIVER_CHECK(cuLaunchKernel(
-                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
-                ));
-            }
-            """
-        )
+            self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_header()))
+        self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_driver()))
 
     def write_get_raw_stream(self, index, graph=None):
         name = f"stream{index}"
-        self.writeline(f"cudaStream_t {name};")
+        self.writeline(maybe_hipify_code_wrapper(f"cudaStream_t {name};"))
         self.writeline(
             f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));"
         )
@@ -164,7 +90,9 @@ class CppWrapperCuda(CppWrapperCpu):
                 sorted(self.src_to_kernel.values()),
                 sorted([entry[0] for entry in self.user_defined_kernel_cache.values()]),
             ):
-                self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
+                self.prefix.writeline(
+                    maybe_hipify_code_wrapper(f"static CUfunction {kernel} = nullptr;")
+                )
             self.prefix.writeline("\n")
 
         return super().generate(is_inference)
@@ -214,14 +142,18 @@ class CppWrapperCuda(CppWrapperCpu):
                 self.writeline(f"auto {var_name} = c10::nullopt;")
             else:
                 if config.abi_compatible:
-                    self.writeline(f"CUdeviceptr {var_name};")
+                    self.writeline(
+                        maybe_hipify_code_wrapper(f"CUdeviceptr {var_name};")
+                    )
                     self.writeline(
                         f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
                     )
                 else:
                     self.writeline(
-                        f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
+                        maybe_hipify_code_wrapper(
+                            f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
+                        )
                     )
             new_args.append(f"&{var_name}")
 
         return ", ".join(new_args)
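The same helper covers the other driver-API tokens used in this file. A hedged sketch of the rewrites one would expect on a ROCm build (the HIP spellings below are assumptions based on the standard hipify mappings, not output confirmed by this PR):

# Hypothetical illustration (not part of this PR). Assumed HIP counterparts:
#   cudaStream_t -> hipStream_t
#   CUdeviceptr  -> hipDeviceptr_t
from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper

for cuda_line in ("cudaStream_t stream0;", "CUdeviceptr var_0;"):
    # On CUDA builds each line is returned unchanged.
    print(maybe_hipify_code_wrapper(cuda_line))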

torch/_inductor/codegen/wrapper.py

@@ -42,6 +42,7 @@ from ..utils import (
     sympy_str,
 )
 from ..virtualized import V
+from .aoti_hipify_utils import maybe_hipify_code_wrapper
 from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
 from .triton_utils import config_of, signature_to_meta
@@ -264,9 +265,11 @@ class EnterDeviceContextManagerLine(WrapperLine):
                     )
                 else:
                     code.writeline(
-                        "at::cuda::CUDAStreamGuard stream_guard("
-                        + "at::cuda::getStreamFromExternal(stream, this->device_idx_));"
+                        maybe_hipify_code_wrapper(
+                            "at::cuda::CUDAStreamGuard stream_guard("
+                            + "at::cuda::getStreamFromExternal(stream, this->device_idx_));"
+                        )
                     )
             else:
                 assert (
                     self.last_seen_device_guard_index == self.device_idx
@@ -276,7 +279,9 @@
                 code.writeline(
                     f"AOTICudaGuard device_guard({self.device_idx});"
                     if config.abi_compatible
-                    else f"at::cuda::CUDAGuard device_guard({self.device_idx});"
+                    else maybe_hipify_code_wrapper(
+                        f"at::cuda::CUDAGuard device_guard({self.device_idx});"
+                    )
                 )
             else:
                 code.writeline(f"device_guard.set_index({self.device_idx});")
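Analogously for the guard types touched here: on a ROCm build the non-ABI-compatible branch would presumably emit HIP "masquerading" guard types. A sketch (the rewritten class name is an assumption based on PyTorch's hipify tables, e.g. at::cuda::CUDAGuard -> at::hip::HIPGuardMasqueradingAsCUDA):

# Hypothetical illustration (not part of this PR).
from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper

src = "at::cuda::CUDAGuard device_guard(0);"
# Expected on ROCm (assumed): "at::hip::HIPGuardMasqueradingAsCUDA device_guard(0);"
print(maybe_hipify_code_wrapper(src))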