mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor][ROCm][CK] Add standalone runner (#139441)
Generate a standalone executable to debug and profile CK GEMM instances.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139441
Approved by: https://github.com/ColinPeppler
This commit is contained in:
committed by
PyTorch MergeBot
parent
d36fdaf157
commit
ca30704f0b
@ -266,12 +266,19 @@ class AsyncCompile:
|
||||
|
||||
return self.submit(task)
|
||||
|
||||
def rocm(self, source_code, dst_file_ext, aot_compile=False):
|
||||
def rocm(
|
||||
self,
|
||||
source_code,
|
||||
dst_file_ext,
|
||||
aot_compile=False,
|
||||
):
|
||||
kernel_code_log.info("ROCm Kernel:\n%s", source_code)
|
||||
|
||||
def task():
|
||||
if aot_compile:
|
||||
_ = ROCmCodeCache.compile(source_code, dst_file_ext="o")
|
||||
if config.rocm.generate_test_runner:
|
||||
_ = ROCmCodeCache.compile(source_code, dst_file_ext="exe")
|
||||
return ROCmCodeCache.load(source_code, dst_file_ext)[0]
|
||||
|
||||
return self.submit(task)
|
||||
|
@ -3486,8 +3486,9 @@ class ROCmCodeCache:
|
||||
log.info(log_duration_msg)
|
||||
else:
|
||||
log.debug(
|
||||
"Compilation skipped: %s since output already exists",
|
||||
"Skip compiling %s: output %s already exists",
|
||||
input_path,
|
||||
output_path,
|
||||
)
|
||||
cls.cache[key] = ROCmCodeCache.CacheEntry(input_path, output_path)
|
||||
|
||||
|
@ -8,7 +8,9 @@ import sympy
|
||||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codegen.cpp_utils import DTYPE_TO_CPP
|
||||
from torch._inductor.codegen.rocm.ck_template import CKTemplate
|
||||
from torch._inductor.codegen.rocm.compile_command import rocm_compile_command
|
||||
from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel
|
||||
from torch._inductor.ir import Buffer, Layout
|
||||
|
||||
@ -73,12 +75,164 @@ class CKGemmTemplate(CKTemplate):
|
||||
return 0;
|
||||
}
|
||||
// run the kernel
|
||||
float elapsed_time = invoker.Run(argument, StreamConfig{stream, /* time kernel */ false, /* log level */ kDEBUG_LOG});
|
||||
#ifdef GENERATE_CK_STANDALONE_RUNNER
|
||||
const auto stream_config = StreamConfig{
|
||||
stream,
|
||||
/* time kernel */ 1,
|
||||
/* log level */ 1,
|
||||
/* n_cold_iter */ 100,
|
||||
/* n_hot_iter */ 100,
|
||||
/* flush_l2_cache */ 1,
|
||||
/* rotate_count */ 5};
|
||||
#else
|
||||
const auto stream_config = StreamConfig{stream, /* time kernel */ false, /* log level */ 0};
|
||||
#endif
|
||||
|
||||
const float elapsed_time = invoker.Run(argument, stream_config);
|
||||
|
||||
#ifdef GENERATE_CK_STANDALONE_RUNNER
|
||||
std::cout << "elapsed time: " << elapsed_time << " ms" << std::endl;
|
||||
#else
|
||||
(void)elapsed_time;
|
||||
#endif
|
||||
return 0;
|
||||
} // kernel definition
|
||||
} // extern C
|
||||
"""
|
||||
|
||||
standalone_runner_template = r"""
|
||||
#ifdef GENERATE_CK_STANDALONE_RUNNER
|
||||
// standalone runner for the generated CK GEMM kernel
|
||||
|
||||
{{inline_utils}}
|
||||
|
||||
extern "C" {
|
||||
int run_main(int argc, char** argv) {
|
||||
const int32_t M = {{M}};
|
||||
const int32_t N = {{N}};
|
||||
const int32_t K = {{K}};
|
||||
const int32_t LDA = {{LDA}};
|
||||
const int32_t LDB = {{LDB}};
|
||||
const int32_t LDC = {{LDC}};
|
||||
const int32_t LDD = {{LDD}};
|
||||
|
||||
using AElementType = {{a_ck_dtype}};
|
||||
using BElementType = {{b_ck_dtype}};
|
||||
using CElementType = {{c_ck_dtype}};
|
||||
{% if has_bias %}
|
||||
using BiasElementType = {{bias_ck_dtype}};
|
||||
{% endif %}
|
||||
{% if has_scale %}
|
||||
using ScaleAElementType = {{scale_a_ck_dtype}};
|
||||
using ScaleBElementType = {{scale_b_ck_dtype}};
|
||||
{% endif %}
|
||||
|
||||
using AArgType = {{a_torch_dtype}};
|
||||
using BArgType = {{b_torch_dtype}};
|
||||
using CArgType = {{c_torch_dtype}};
|
||||
{% if has_bias %}
|
||||
using BiasArgType = {{bias_torch_dtype}};
|
||||
{% endif %}
|
||||
{% if has_scale %}
|
||||
using ScaleAArgType = {{scale_a_torch_dtype}};
|
||||
using ScaleBArgType = {{scale_b_torch_dtype}};
|
||||
{% endif %}
|
||||
|
||||
using ALayout = {{a_layout}};
|
||||
using BLayout = {{b_layout}};
|
||||
using CLayout = {{c_layout}};
|
||||
{% if has_bias %}
|
||||
using BiasLayout = {{bias_layout}};
|
||||
{% endif %}
|
||||
|
||||
using strides_t = std::array<int32_t, 2>;
|
||||
|
||||
auto get_strides = [](int32_t leading_dimension, auto layout) constexpr -> strides_t {
|
||||
if constexpr (std::is_same_v<decltype(layout), Row>) {
|
||||
return {leading_dimension, 1};
|
||||
}
|
||||
return {1, leading_dimension};
|
||||
};
|
||||
|
||||
Tensor<AElementType> a_m_k ( HostTensorDescriptor ( strides_t{M, K}, get_strides(LDA, ALayout{}) ) );
|
||||
Tensor<BElementType> b_k_n ( HostTensorDescriptor ( strides_t{N, K}, get_strides(LDB, BLayout{}) ) );
|
||||
{% if has_bias %}
|
||||
Tensor<BiasElementType> d_m_n ( HostTensorDescriptor ( strides_t{M, N}, get_strides(LDD, BiasLayout{}) ) );
|
||||
{% endif %}
|
||||
{% if has_scale %}
|
||||
// NB: these are hardcoded
|
||||
Tensor<ScaleAElementType> s_a_m_n ( HostTensorDescriptor ( strides_t{M, N}, get_strides(0, Row{}) ));
|
||||
Tensor<ScaleAElementType> s_b_m_n ( HostTensorDescriptor ( strides_t{M, N}, get_strides(0, Col{}) ));
|
||||
{% endif %}
|
||||
|
||||
Tensor<CElementType> c_m_n_host ( HostTensorDescriptor ( strides_t{M, N}, get_strides(LDC, CLayout{}) ) );
|
||||
Tensor<CElementType> c_m_n_device ( HostTensorDescriptor ( strides_t{M, N}, get_strides(LDC, CLayout{}) ) );
|
||||
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<AElementType>());
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BElementType>());
|
||||
{% if has_bias %}
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_2<BiasElementType>());
|
||||
{% endif %}
|
||||
{% if has_scale %}
|
||||
s_a_m_n.GenerateTensorValue(GeneratorTensor_2<ScaleAElementType>());
|
||||
s_b_m_n.GenerateTensorValue(GeneratorTensor_2<ScaleBElementType>());
|
||||
{% endif %}
|
||||
DeviceMem a_m_k_device_buf(sizeof(AElementType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BElementType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
{% if has_bias %}
|
||||
DeviceMem d_m_n_device_buf(sizeof(BiasElementType) * d_m_n.mDesc.GetElementSpaceSize());
|
||||
{% endif %}
|
||||
{% if has_scale %}
|
||||
DeviceMem s_a_m_n_device_buf(sizeof(ScaleAElementType) * s_a_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem s_b_m_n_device_buf(sizeof(ScaleBElementType) * s_b_m_n.mDesc.GetElementSpaceSize());
|
||||
{% endif %}
|
||||
DeviceMem c_m_n_device_buf(sizeof(CElementType) * c_m_n_device.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
{% if has_bias %}
|
||||
d_m_n_device_buf.ToDevice(d_m_n.mData.data());
|
||||
{% endif %}
|
||||
{% if has_scale %}
|
||||
s_a_m_n_device_buf.ToDevice(s_a_m_n.mData.data());
|
||||
s_b_m_n_device_buf.ToDevice(s_b_m_n.mData.data());
|
||||
{% endif %}
|
||||
|
||||
{{kernel_name}}(
|
||||
static_cast<const AArgType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<const BArgType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
{% if has_bias %}
|
||||
static_cast<const BiasArgType*>(d_m_n_device_buf.GetDeviceBuffer()),
|
||||
{% endif %}
|
||||
{% if has_scale %}
|
||||
static_cast<const ScaleAArgType*>(s_a_m_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<const ScaleBArgType*>(s_b_m_n_device_buf.GetDeviceBuffer()),
|
||||
{% endif %}
|
||||
static_cast<CArgType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
LDA,
|
||||
LDB,
|
||||
LDC,
|
||||
LDD,
|
||||
nullptr, // workspace_size
|
||||
nullptr, // workspace
|
||||
nullptr); // stream
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
return 0;
|
||||
} // run_main
|
||||
} // extern C
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
return run_main(argc, argv);
|
||||
}
|
||||
// compile with: {{compile_cmd}}
|
||||
#endif // GENERATE_CK_STANDALONE_RUNNER
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_nodes: List[Buffer],
|
||||
@ -123,6 +277,16 @@ class CKGemmTemplate(CKTemplate):
|
||||
)
|
||||
return res
|
||||
|
||||
def inline_utils(self):
|
||||
res = IndentedBuffer()
|
||||
res.splice(
|
||||
"""
|
||||
#include "host_tensor.cpp"
|
||||
#include "device_memory.cpp"
|
||||
"""
|
||||
)
|
||||
return res
|
||||
|
||||
def filter_op(self, op: "CKGemmOperation"):
|
||||
"""
|
||||
Determines whether a given op definition is suitable for the current
|
||||
@ -333,7 +497,8 @@ class CKGemmTemplate(CKTemplate):
|
||||
|
||||
assert epilogue is not None, "CK GEMM epilogue is not set"
|
||||
|
||||
return self._template_from_string(self.gemm_template).render(
|
||||
res = self._template_from_string(self.gemm_template).render(
|
||||
inline_utils=self.inline_utils(),
|
||||
headers=self.header().getvalue(),
|
||||
globals=self.globals().getvalue(),
|
||||
instance_definition=instance_definition,
|
||||
@ -380,6 +545,67 @@ class CKGemmTemplate(CKTemplate):
|
||||
version_comment=version_comment,
|
||||
)
|
||||
|
||||
if config.rocm.generate_test_runner:
|
||||
is_static_problem = all(is_static_int(arg) for arg in self.size_args())
|
||||
M, N, K, LDA, LDB, LDC, LDD = (
|
||||
self.size_args()
|
||||
if is_static_problem
|
||||
else (
|
||||
f"std::stoi(argv[{k}])" for k, _ in enumerate(self.size_args(), 1)
|
||||
)
|
||||
)
|
||||
has_bias = Bias is not None
|
||||
has_scale = scale_x is not None and scale_w is not None
|
||||
runner_code = self._template_from_string(
|
||||
self.standalone_runner_template
|
||||
).render(
|
||||
inline_utils=self.inline_utils().getvalue(),
|
||||
kernel_name=kernel.kernel_name,
|
||||
M=M,
|
||||
N=N,
|
||||
K=K,
|
||||
LDA=LDA,
|
||||
LDB=LDB,
|
||||
LDC=LDC,
|
||||
LDD=LDD,
|
||||
has_bias=has_bias,
|
||||
has_scale=has_scale,
|
||||
a_ck_dtype=op.a_element_dtype,
|
||||
b_ck_dtype=op.b_element_dtype,
|
||||
c_ck_dtype=op.c_element_dtype,
|
||||
bias_ck_dtype=op.ds_element_dtypes[0] if has_bias else "",
|
||||
scale_a_ck_dtype=op.ds_element_dtypes[0]
|
||||
if has_scale and 2 == len(op.ds_element_dtypes)
|
||||
else "BF16",
|
||||
scale_b_ck_dtype=op.ds_element_dtypes[1]
|
||||
if has_scale and 2 == len(op.ds_element_dtypes)
|
||||
else "BF16",
|
||||
a_torch_dtype=DTYPE_TO_CPP[X.get_layout().dtype],
|
||||
b_torch_dtype=DTYPE_TO_CPP[W.get_layout().dtype],
|
||||
c_torch_dtype=DTYPE_TO_CPP[Y.get_layout().dtype],
|
||||
bias_torch_dtype=DTYPE_TO_CPP[Bias.get_layout().dtype]
|
||||
if Bias is not None
|
||||
else "",
|
||||
scale_a_torch_dtype=DTYPE_TO_CPP[scale_x.get_layout().dtype]
|
||||
if scale_x is not None
|
||||
else "",
|
||||
scale_b_torch_dtype=DTYPE_TO_CPP[scale_w.get_layout().dtype]
|
||||
if scale_w is not None
|
||||
else "",
|
||||
a_layout=torch_layout_to_ck_layout(X.get_layout()),
|
||||
b_layout=torch_layout_to_ck_layout(W.get_layout()),
|
||||
c_layout=torch_layout_to_ck_layout(Y.get_layout()),
|
||||
bias_layout=torch_layout_to_ck_layout(Bias.get_layout())
|
||||
if Bias is not None
|
||||
else "",
|
||||
compile_cmd=rocm_compile_command(
|
||||
["<source_file_name>"], "<executable_name>", "exe"
|
||||
),
|
||||
)
|
||||
res += runner_code
|
||||
|
||||
return res
|
||||
|
||||
def _is_rcr_f16(self):
|
||||
X_meta, W_meta, Y_meta = (
|
||||
T.get_layout() for T in [*self.input_nodes, self.output_node]
|
||||
|
@ -10,7 +10,7 @@ from torch._inductor.utils import is_linux
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _rocm_include_paths() -> List[str]:
|
||||
def _rocm_include_paths(dst_file_ext: str) -> List[str]:
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
rocm_include = (
|
||||
@ -38,10 +38,13 @@ def _rocm_include_paths() -> List[str]:
|
||||
paths = [
|
||||
os.path.realpath(p) for p in (ck_include, ck_library_include, rocm_include)
|
||||
]
|
||||
if dst_file_ext == "exe":
|
||||
ck_utility_include = os.path.join(ck_path, "library", "src", "utility")
|
||||
paths.append(os.path.realpath(ck_utility_include))
|
||||
return paths
|
||||
|
||||
|
||||
def _rocm_lib_options() -> List[str]:
|
||||
def _rocm_lib_options(dst_file_ext: str) -> List[str]:
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
rocm_lib_dir = (
|
||||
@ -55,12 +58,15 @@ def _rocm_lib_options() -> List[str]:
|
||||
else cpp_extension._join_rocm_home("hip", "lib")
|
||||
)
|
||||
|
||||
return [
|
||||
opts = [
|
||||
"-include __clang_hip_runtime_wrapper.h",
|
||||
f"-L{os.path.realpath(rocm_lib_dir)}",
|
||||
f"-L{os.path.realpath(hip_lib_dir)}",
|
||||
"-lamdhip64",
|
||||
]
|
||||
if dst_file_ext == "exe":
|
||||
opts += ["-lpthread", "-lstdc++"]
|
||||
return opts
|
||||
|
||||
|
||||
def _rocm_compiler_options() -> List[str]:
|
||||
@ -118,25 +124,24 @@ def rocm_compile_command(
|
||||
dst_file_ext: str,
|
||||
extra_args: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
include_paths = _rocm_include_paths()
|
||||
lib_options = _rocm_lib_options()
|
||||
include_paths = _rocm_include_paths(dst_file_ext)
|
||||
lib_options = _rocm_lib_options(dst_file_ext)
|
||||
compiler_options = _rocm_compiler_options()
|
||||
compiler = rocm_compiler()
|
||||
options = (
|
||||
compiler_options
|
||||
+ (extra_args if extra_args else [])
|
||||
+ ["-I" + path for path in include_paths]
|
||||
+ (extra_args or [])
|
||||
+ [f"-I{path}" for path in include_paths]
|
||||
+ lib_options
|
||||
)
|
||||
src_file = " ".join(src_files)
|
||||
res = ""
|
||||
# supported extensions: .o, .so, .exe
|
||||
if dst_file_ext == "o":
|
||||
res = f"{compiler} {' '.join(options)} -c -o {dst_file} {src_file}"
|
||||
options.append("-c")
|
||||
elif dst_file_ext == "so":
|
||||
options.append("-shared")
|
||||
res = f"{compiler} {' '.join(options)} -o {dst_file} {src_file}"
|
||||
elif dst_file_ext == "exe":
|
||||
res = f"{compiler} {' '.join(options)} -o {dst_file} {src_file}"
|
||||
options.append("-DGENERATE_CK_STANDALONE_RUNNER")
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
|
||||
return res
|
||||
return f"{compiler} {' '.join(options)} -o {dst_file} {src_file}"
|
||||
|
@ -7,6 +7,7 @@ from ctypes import byref, c_int, c_size_t, c_void_p
|
||||
from typing import Any, Callable, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch._inductor.autotune_process import (
|
||||
BenchmarkRequest,
|
||||
GPUDeviceBenchmarkMixin,
|
||||
@ -45,6 +46,8 @@ class ROCmBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
|
||||
# may happen in separate Threadpool
|
||||
log.debug("Precompiling %s", self)
|
||||
ROCmCodeCache.compile(self.source_code, "so")
|
||||
if config.rocm.generate_test_runner:
|
||||
ROCmCodeCache.compile(self.source_code, "exe")
|
||||
log.debug("Done precompiling %s", self)
|
||||
|
||||
def make_run_fn(
|
||||
|
@ -1193,6 +1193,11 @@ class rocm:
|
||||
# Install with `pip install git+https://github.com/rocm/composable_kernel@develop`.
|
||||
ck_dir = os.environ.get("TORCHINDUCTOR_CK_DIR")
|
||||
|
||||
# generate standalone executables for instances generated with the CK backend
|
||||
generate_test_runner: bool = (
|
||||
os.environ.get("INDUCTOR_CK_BACKEND_GENERATE_TEST_RUNNER_CODE", "0") == "1"
|
||||
)
|
||||
|
||||
# Number of op instance choices to trade off between runtime perf and compilation time
|
||||
n_max_profiling_configs: Optional[int] = None
|
||||
|
||||
|
Reference in New Issue
Block a user