[5/x][AMD][Lowering Enablement] Hipifying aoti code_wrapper (#124241)
Summary: as title

Test Plan: CI & unit test
patch on top of https://www.internalfb.com/phabricator/paste/view/P1214895953 to test

Differential Revision: D56223917

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124241
Approved by: https://github.com/jansel, https://github.com/desertfire
Committed by: PyTorch MergeBot
Parent: 25c65d6642
Commit: b0d83726bd
torch/_inductor/codegen/aoti_hipify_utils.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+import torch
+
+from torch.utils.hipify.hipify_python import PYTORCH_MAP, RE_PYTORCH_PREPROCESSOR
+
+# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
+# "...
+# from ..codecache import CudaKernelParamCache
+# ..."
+# In such cases, we do not need to hipify_torch the original class/file name in codegen/codecache
+
+
+def maybe_hipify_code_wrapper(source_codes: str) -> str:
+    if torch.version.hip is None:
+        return source_codes
+
+    def c2_repl(m):
+        return PYTORCH_MAP[m.group(0)]
+
+    source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes)
+    return source_codes
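For context, a minimal usage sketch of the helper above. This is not part of the commit; the HIP spellings in the comments are assumptions based on the PYTORCH_MAP mappings in torch.utils.hipify, and on a CUDA build the input is returned unchanged.

# Hypothetical usage of maybe_hipify_code_wrapper (illustration only).
# On a CUDA build torch.version.hip is None, so the line passes through untouched;
# on a ROCm build the regex substitution is expected to rewrite CUDA driver types.
from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper

line = "CUfunction kernel{nullptr};"
print(maybe_hipify_code_wrapper(line))
# CUDA build: CUfunction kernel{nullptr};
# ROCm build (expected): hipFunction_t kernel{nullptr};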
torch/_inductor/codegen/codegen_device_driver.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+import torch
+
+# Provide aoti module launch hip/cuda drivers. This file is also used for unit testing purposes
+
+
+def cuda_kernel_driver() -> str:
+    source_codes = """
+    #define CUDA_DRIVER_CHECK(EXPR) \\
+    do { \\
+        CUresult code = EXPR; \\
+        const char *msg; \\
+        cuGetErrorString(code, &msg); \\
+        if (code != CUDA_SUCCESS) { \\
+            throw std::runtime_error( \\
+                std::string("CUDA driver error: ") + \\
+                std::string(msg)); \\
+        } \\
+    } while (0);
+
+    namespace {
+
+    struct Grid {
+        Grid(uint32_t x, uint32_t y, uint32_t z)
+          : grid_x(x), grid_y(y), grid_z(z) {}
+        uint32_t grid_x;
+        uint32_t grid_y;
+        uint32_t grid_z;
+
+        bool is_non_zero() {
+            return grid_x > 0 && grid_y > 0 && grid_z > 0;
+        }
+    };
+
+    } // anonymous namespace
+
+    static inline CUfunction loadKernel(
+            std::string filePath,
+            const std::string &funcName,
+            uint32_t sharedMemBytes,
+            const std::optional<std::string> &cubinDir = std::nullopt) {
+        if (cubinDir) {
+            std::filesystem::path p1{*cubinDir};
+            std::filesystem::path p2{filePath};
+            filePath = (p1 / p2.filename()).string();
+        }
+
+        CUmodule mod;
+        CUfunction func;
+        CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
+        CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
+        if (sharedMemBytes > 0) {
+            CUDA_DRIVER_CHECK(cuFuncSetAttribute(
+                func,
+                CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+                sharedMemBytes
+            ))
+        }
+        return func;
+    }
+
+    static inline void launchKernel(
+            CUfunction func,
+            uint32_t gridX,
+            uint32_t gridY,
+            uint32_t gridZ,
+            uint32_t numWarps,
+            uint32_t sharedMemBytes,
+            void* args[],
+            cudaStream_t stream) {
+        CUDA_DRIVER_CHECK(cuLaunchKernel(
+            func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
+        ));
+    }
+    """
+    if torch.version.hip is not None:
+        # Replace the warp size from 32 (cuLaunchKernel) to 64 (hipModuleLaunchKernel)
+        # The warp size on NV GPU is 32, while the wavefront size on AMD GPU is 64
+        source_codes = source_codes.replace("32*numWarps", "64*numWarps")
+    return source_codes
+
+
+def cuda_kernel_header() -> str:
+    source_codes = """
+    #include <c10/cuda/CUDAGuard.h>
+    #include <c10/cuda/CUDAStream.h>
+    #include <ATen/cuda/EmptyTensor.h>
+    """
+    return source_codes
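A short, hedged sketch of how these two helpers are meant to be combined with maybe_hipify_code_wrapper; it mirrors the CppWrapperCuda changes below rather than adding anything new, and the asserts only restate the 32-vs-64 launch width encoded in cuda_kernel_driver().

# Sketch only: compose the generated driver/header source with the hipify wrapper,
# as the AOTI C++ wrapper codegen does after this change.
import torch

from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper
from torch._inductor.codegen.codegen_device_driver import (
    cuda_kernel_driver,
    cuda_kernel_header,
)

header_src = maybe_hipify_code_wrapper(cuda_kernel_header())
driver_src = maybe_hipify_code_wrapper(cuda_kernel_driver())

# The raw driver source already switches the launch width before hipification:
# 32 threads per warp on NVIDIA, 64-wide wavefronts on AMD.
raw_driver = cuda_kernel_driver()
if torch.version.hip is not None:
    assert "64*numWarps" in raw_driver  # AMD wavefront size
else:
    assert "32*numWarps" in raw_driver  # NVIDIA warp size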
torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -15,6 +15,7 @@ from .. import config, ir
 from ..codecache import CudaKernelParamCache
 from ..utils import cache_on_self, sympy_product
 from ..virtualized import V
+from .aoti_hipify_utils import maybe_hipify_code_wrapper
 from .common import IndentedBuffer
 from .wrapper import EnterSubgraphLine, ExitSubgraphLine, WrapperCodeGen
 
@@ -665,7 +666,9 @@ class CppWrapperCpu(WrapperCodeGen):
                 V.graph.const_module.wrapper_code.src_to_kernel.values()
             )
             for kernel in sorted(declare_kernel):
-                self.prefix.writeline(f" CUfunction {kernel}{{nullptr}};")
+                self.prefix.writeline(
+                    maybe_hipify_code_wrapper(f" CUfunction {kernel}{{nullptr}};")
+                )
             self.prefix.writeline("};")
             self.prefix.writeline("} // namespace")
 
torch/_inductor/codegen/cpp_wrapper_cuda.py
@@ -11,6 +11,8 @@ from .. import config
 from ..codecache import CudaKernelParamCache
 from ..triton_heuristics import grid as default_grid
 from ..virtualized import V
+from .aoti_hipify_utils import maybe_hipify_code_wrapper
+from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header
 from .cpp_wrapper_cpu import CppWrapperCpu
 from .wrapper import SymbolicCallArg
 
@@ -64,88 +66,12 @@ class CppWrapperCuda(CppWrapperCpu):
                 "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
             )
         else:
-            self.header.splice(
-                """
-                #include <c10/cuda/CUDAGuard.h>
-                #include <c10/cuda/CUDAStream.h>
-                #include <ATen/cuda/EmptyTensor.h>
-                """
-            )
-
-        self.header.splice(
-            """
-            #define CUDA_DRIVER_CHECK(EXPR) \\
-            do { \\
-                CUresult code = EXPR; \\
-                const char *msg; \\
-                cuGetErrorString(code, &msg); \\
-                if (code != CUDA_SUCCESS) { \\
-                    throw std::runtime_error( \\
-                        std::string("CUDA driver error: ") + \\
-                        std::string(msg)); \\
-                } \\
-            } while (0);
-
-            namespace {
-
-            struct Grid {
-                Grid(uint32_t x, uint32_t y, uint32_t z)
-                  : grid_x(x), grid_y(y), grid_z(z) {}
-                uint32_t grid_x;
-                uint32_t grid_y;
-                uint32_t grid_z;
-
-                bool is_non_zero() {
-                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
-                }
-            };
-
-            } // anonymous namespace
-
-            static inline CUfunction loadKernel(
-                    std::string filePath,
-                    const std::string &funcName,
-                    uint32_t sharedMemBytes,
-                    const std::optional<std::string> &cubinDir = std::nullopt) {
-                if (cubinDir) {
-                    std::filesystem::path p1{*cubinDir};
-                    std::filesystem::path p2{filePath};
-                    filePath = (p1 / p2.filename()).string();
-                }
-
-                CUmodule mod;
-                CUfunction func;
-                CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
-                CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
-                if (sharedMemBytes > 0) {
-                    CUDA_DRIVER_CHECK(cuFuncSetAttribute(
-                        func,
-                        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
-                        sharedMemBytes
-                    ))
-                }
-                return func;
-            }
-
-            static inline void launchKernel(
-                    CUfunction func,
-                    uint32_t gridX,
-                    uint32_t gridY,
-                    uint32_t gridZ,
-                    uint32_t numWarps,
-                    uint32_t sharedMemBytes,
-                    void* args[],
-                    cudaStream_t stream) {
-                CUDA_DRIVER_CHECK(cuLaunchKernel(
-                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
-                ));
-            }
-            """
-        )
+            self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_header()))
+        self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_driver()))
 
     def write_get_raw_stream(self, index, graph=None):
         name = f"stream{index}"
-        self.writeline(f"cudaStream_t {name};")
+        self.writeline(maybe_hipify_code_wrapper(f"cudaStream_t {name};"))
         self.writeline(
             f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));"
         )
@@ -164,7 +90,9 @@ class CppWrapperCuda(CppWrapperCpu):
             sorted(self.src_to_kernel.values()),
             sorted([entry[0] for entry in self.user_defined_kernel_cache.values()]),
         ):
-            self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
+            self.prefix.writeline(
+                maybe_hipify_code_wrapper(f"static CUfunction {kernel} = nullptr;")
+            )
         self.prefix.writeline("\n")
         return super().generate(is_inference)
 
@@ -214,13 +142,17 @@ class CppWrapperCuda(CppWrapperCpu):
                 self.writeline(f"auto {var_name} = c10::nullopt;")
             else:
                 if config.abi_compatible:
-                    self.writeline(f"CUdeviceptr {var_name};")
+                    self.writeline(
+                        maybe_hipify_code_wrapper(f"CUdeviceptr {var_name};")
+                    )
                     self.writeline(
                         f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
                     )
                 else:
                     self.writeline(
-                        f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
+                        maybe_hipify_code_wrapper(
+                            f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
+                        )
                    )
                 new_args.append(f"&{var_name}")
 
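As a concrete illustration of the write_get_raw_stream and CUdeviceptr changes above, a hedged sketch; the hipStream_t and hipDeviceptr_t spellings are assumptions based on the standard hipify type mappings, not output verified here.

# Sketch only: declarations emitted by CppWrapperCuda are now routed through
# maybe_hipify_code_wrapper, so a ROCm build is expected to emit HIP types while
# a CUDA build keeps the strings exactly as written.
from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper

for decl in ("cudaStream_t stream0;", "CUdeviceptr var_0;"):
    print(maybe_hipify_code_wrapper(decl))
# CUDA build: unchanged
# ROCm build (expected): hipStream_t stream0; / hipDeviceptr_t var_0;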
torch/_inductor/codegen/wrapper.py
@@ -42,6 +42,7 @@ from ..utils import (
     sympy_str,
 )
 from ..virtualized import V
+from .aoti_hipify_utils import maybe_hipify_code_wrapper
 from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
 from .triton_utils import config_of, signature_to_meta
 
@@ -264,8 +265,10 @@ class EnterDeviceContextManagerLine(WrapperLine):
                 )
             else:
                 code.writeline(
-                    "at::cuda::CUDAStreamGuard stream_guard("
-                    + "at::cuda::getStreamFromExternal(stream, this->device_idx_));"
+                    maybe_hipify_code_wrapper(
+                        "at::cuda::CUDAStreamGuard stream_guard("
+                        + "at::cuda::getStreamFromExternal(stream, this->device_idx_));"
+                    )
                 )
         else:
             assert (
@@ -276,7 +279,9 @@ class EnterDeviceContextManagerLine(WrapperLine):
             code.writeline(
                 f"AOTICudaGuard device_guard({self.device_idx});"
                 if config.abi_compatible
-                else f"at::cuda::CUDAGuard device_guard({self.device_idx});"
+                else maybe_hipify_code_wrapper(
+                    f"at::cuda::CUDAGuard device_guard({self.device_idx});"
+                )
             )
         else:
             code.writeline(f"device_guard.set_index({self.device_idx});")