[Inductor][CPP] Reuse the pre-existing kernel for the same kernels (#158404)
Reuse a pre-existing kernel to avoid defining redundant kernels. Inductor CPP can generate identical kernels for different nodes. For example:
```
# Example
class Model(torch.nn.Module):
    def __init__(self, K, N):
        super().__init__()
        self.linear0 = torch.nn.Linear(K, N)
        self.linear1 = torch.nn.Linear(N, K)
        self.linear2 = torch.nn.Linear(K, N)

    def forward(self, input):
        out = self.linear0(input)
        out = self.linear1(out)
        out = self.linear2(out)
        return out
```
In the above example, linear2 is the same as linear0, so Inductor CPP generates two identical kernels: cpp_fused_addmm_0 and cpp_fused_addmm_2.
```
# Generated code:
...
cpp_fused_addmm_0 = async_compile.cpp_pybinding(['const at::BFloat16*', 'const at::BFloat16*', 'const at::BFloat16*', 'at::BFloat16*'], '''
...
extern "C" void kernel(const at::BFloat16* X, const at::BFloat16* W, const at::BFloat16* inp, at::BFloat16* Y)
{
    constexpr int64_t num_threads = 32;
    constexpr int64_t N = 1024;
    constexpr int64_t K = 2048;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    ...

cpp_fused_addmm_1 = async_compile.cpp_pybinding(['const at::BFloat16*', 'const at::BFloat16*', 'const at::BFloat16*', 'at::BFloat16*'], '''
...
extern "C" void kernel(const at::BFloat16* X, const at::BFloat16* W, const at::BFloat16* inp, at::BFloat16* Y)
{
    constexpr int64_t num_threads = 32;
    constexpr int64_t N = 2048;
    constexpr int64_t K = 1024;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    ...

cpp_fused_addmm_2 = async_compile.cpp_pybinding(['const at::BFloat16*', 'const at::BFloat16*', 'const at::BFloat16*', 'at::BFloat16*'], '''
extern "C" void kernel(const at::BFloat16* X, const at::BFloat16* W, const at::BFloat16* inp, at::BFloat16* Y)
{
    constexpr int64_t num_threads = 32;
    constexpr int64_t N = 1024;
    constexpr int64_t K = 2048;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    ...

def call(self, args):
    arg6_1, = args
    args.clear()
    buf0 = empty_strided_cpu((1024, 1024), (1024, 1), torch.bfloat16)
    cpp_fused_addmm_0(arg6_1, constant6, _frozen_param6, buf0)
    del arg6_1
    buf1 = empty_strided_cpu((1024, 2048), (2048, 1), torch.bfloat16)
    cpp_fused_addmm_1(buf0, constant6_0, _frozen_param8, buf1)
    buf2 = buf0; del buf0  # reuse
    cpp_fused_addmm_2(buf1, constant6_1, _frozen_param10, buf2)
    return (buf2, )
```
After reusing the pre-existing kernel, Inductor CPP reuses cpp_fused_addmm_0 instead of defining cpp_fused_addmm_2:
```
cpp_fused_addmm_0 = async_compile.cpp_pybinding(['const at::BFloat16*', 'const at::BFloat16*', 'const at::BFloat16*', 'at::BFloat16*'], '''
...
extern "C" void kernel(const at::BFloat16* X, const at::BFloat16* W, const at::BFloat16* inp, at::BFloat16* Y)
{
    constexpr int64_t num_threads = 32;
    constexpr int64_t N = 1024;
    constexpr int64_t K = 2048;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    ...

cpp_fused_addmm_1 = async_compile.cpp_pybinding(['const at::BFloat16*', 'const at::BFloat16*', 'const at::BFloat16*', 'at::BFloat16*'], '''
...
extern "C" void kernel(const at::BFloat16* X, const at::BFloat16* W, const at::BFloat16* inp, at::BFloat16* Y)
{
    constexpr int64_t num_threads = 32;
    constexpr int64_t N = 2048;
    constexpr int64_t K = 1024;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    ...

def call(self, args):
    arg6_1, = args
    args.clear()
    buf0 = empty_strided_cpu((1024, 1024), (1024, 1), torch.bfloat16)
    cpp_fused_addmm_0(arg6_1, constant6, _frozen_param6, buf0)
    del arg6_1
    buf1 = empty_strided_cpu((1024, 2048), (2048, 1), torch.bfloat16)
    cpp_fused_addmm_1(buf0, constant6_0, _frozen_param8, buf1)
    buf2 = buf0; del buf0  # reuse
    cpp_fused_addmm_0(buf1, constant6_1, _frozen_param10, buf2)
    return (buf2, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158404
Approved by: https://github.com/jansel, https://github.com/leslie-fang-intel
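The mechanism behind this change (see the `CppScheduling.define_kernel` diff below) is to key kernel definitions on their generated source text. The following is a minimal, standalone sketch of that idea, not the actual Inductor code; the class and cache names here are illustrative:
```
# Sketch only: cache kernel definitions by their generated source text so that
# identical kernels are defined once and later requests reuse the existing name.
class KernelCache:
    def __init__(self) -> None:
        self.src_to_kernel: dict[str, str] = {}  # generated source -> kernel name
        self._counter = 0

    def define_kernel(self, src_code: str) -> str:
        # Reuse a previously defined kernel whose source is identical.
        if src_code in self.src_to_kernel:
            return self.src_to_kernel[src_code]
        name = f"cpp_fused_addmm_{self._counter}"
        self._counter += 1
        self.src_to_kernel[src_code] = name
        # ... emit the actual kernel definition under `name` here ...
        return name


cache = KernelCache()
assert cache.define_kernel("gemm N=1024 K=2048") == "cpp_fused_addmm_0"
assert cache.define_kernel("gemm N=2048 K=1024") == "cpp_fused_addmm_1"
assert cache.define_kernel("gemm N=1024 K=2048") == "cpp_fused_addmm_0"  # reused
```
In the actual change, the cache is `wrapper.src_to_kernel`, and the lookup happens before the kernel-name placeholders are substituted into the source, so two nodes that produce byte-identical source map to the same kernel even though they would otherwise have received different name suffixes from `next_kernel_suffix()`.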
```diff
@@ -2910,6 +2910,37 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
         with verify(u.dtype) as (atol, rtol):
             self.common(mod, (u, v))
 
+    @unittest.skipIf(
+        not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
+    )
+    @inductor_config.patch({"freezing": True})
+    @patches
+    @torch.no_grad
+    @parametrize("batch_size", (1024,))
+    @parametrize("in_features", (1024,))
+    @parametrize("out_features", (2048,))
+    @dtypes(torch.bfloat16)
+    def test_linear_reuse_kernels(self, batch_size, in_features, out_features, dtype):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear_x = torch.nn.Linear(in_features, out_features)
+                self.linear_y = torch.nn.Linear(out_features, in_features)
+                self.linear_z = torch.nn.Linear(in_features, out_features)
+
+            def forward(self, x):
+                out = self.linear_x(x)
+                out = self.linear_y(out)
+                out = self.linear_z(out)
+                return out
+
+        x = torch.randn(batch_size, in_features).to(dtype=dtype)
+        mod = M().to(dtype=dtype).eval()
+        self.common(mod, (x))
+        _, code = run_and_get_cpp_code(mod, x)
+        # Check that only 2 kernels are in the generated code
+        assert code.count("AMXState amx_state") == 2
+
 
 @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
 class _DynamicShapesTestBase(BaseTestSelectAlgorithm):
```
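For reference, a hedged sketch of how one might observe the deduplication outside the test harness. The config values (`freezing`, `max_autotune`, `max_autotune_gemm_backends`) and the `async_compile.cpp_pybinding(` count used here are illustrative choices, not part of the PR; whether templated CPP GEMM kernels are emitted at all depends on the CPU ISA and configuration:
```
# Sketch only: compile a model with two identical linear layers and count the
# CPP kernel definitions in the generated wrapper code.
import torch
from torch._inductor import config as inductor_config
from torch._inductor.utils import run_and_get_cpp_code


class M(torch.nn.Module):
    def __init__(self, in_features=1024, out_features=2048):
        super().__init__()
        self.linear_x = torch.nn.Linear(in_features, out_features)
        self.linear_y = torch.nn.Linear(out_features, in_features)
        self.linear_z = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear_z(self.linear_y(self.linear_x(x)))


if __name__ == "__main__":
    mod = M().to(torch.bfloat16).eval()
    x = torch.randn(1024, 1024, dtype=torch.bfloat16)
    with torch.no_grad(), inductor_config.patch(
        {"freezing": True, "max_autotune": True, "max_autotune_gemm_backends": "CPP,ATEN"}
    ):
        compiled = torch.compile(mod)
        _, code = run_and_get_cpp_code(compiled, x)
    # With kernel reuse, the identical first and third GEMMs share one definition,
    # so fewer kernels are defined than there are linear layers.
    print(code.count("async_compile.cpp_pybinding("))
```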
```diff
@@ -5390,42 +5390,48 @@ class CppScheduling(BaseScheduling):
 
     def define_kernel(self, src_code, nodes, kernel_args=None):
         wrapper = V.graph.wrapper_code
-        fused_name = (
-            get_fused_kernel_name(nodes, config.cpp.descriptive_names)
-            if config.cpp.descriptive_names
-            else ""
-        )
-        kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
-        kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
-        src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name)
-        src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
-        # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
-        # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
-        src_code = src_code.replace("#pragma CMT", "//")
+        if src_code in wrapper.src_to_kernel:
+            kernel_name = wrapper.src_to_kernel[src_code]
+        else:
+            fused_name = (
+                get_fused_kernel_name(nodes, config.cpp.descriptive_names)
+                if config.cpp.descriptive_names
+                else ""
+            )
+            kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
+            wrapper.src_to_kernel[src_code] = kernel_name
+            kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
+            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name)
+            src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
+            # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
+            # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
+            src_code = src_code.replace("#pragma CMT", "//")
 
-        # Get the lines in the source code representing the function definition,
-        # excluding the the first line including cpp_prefix.h.
-        first_char = src_code.rfind('extern "C"')
-        last_char = src_code.find(")", first_char)
-        if _IS_WINDOWS:
-            # get_export_declaration introduced one more ')' in Windows
-            last_char = src_code.find(")", last_char + 1)
-        kernel_definition = f"{src_code[first_char : last_char + 1]};\n"
+            # Get the lines in the source code representing the function definition,
+            # excluding the the first line including cpp_prefix.h.
+            first_char = src_code.rfind('extern "C"')
+            last_char = src_code.find(")", first_char)
+            if _IS_WINDOWS:
+                # get_export_declaration introduced one more ')' in Windows
+                last_char = src_code.find(")", last_char + 1)
+            kernel_definition = f"{src_code[first_char : last_char + 1]};\n"
 
-        compile_wrapper = IndentedBuffer()
-        args = self.kernel_group.args if kernel_args is None else kernel_args
-        _, _, arg_types = args.cpp_argdefs()
-        if not V.graph.cpp_wrapper:
-            compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
-        compile_wrapper.splice(src_code, strip=True)
-        if not V.graph.cpp_wrapper:
-            compile_wrapper.writeline("''')")
-        wrapper.define_kernel(
-            kernel_name,
-            compile_wrapper.getvalue(),
-            gpu=False,
-            cpp_definition=kernel_definition,
-        )
+            compile_wrapper = IndentedBuffer()
+            args = self.kernel_group.args if kernel_args is None else kernel_args
+            _, _, arg_types = args.cpp_argdefs()
+            if not V.graph.cpp_wrapper:
+                compile_wrapper.writeline(
+                    f"async_compile.cpp_pybinding({arg_types!r}, '''"
+                )
+            compile_wrapper.splice(src_code, strip=True)
+            if not V.graph.cpp_wrapper:
+                compile_wrapper.writeline("''')")
+            wrapper.define_kernel(
+                kernel_name,
+                compile_wrapper.getvalue(),
+                gpu=False,
+                cpp_definition=kernel_definition,
+            )
         return kernel_name
 
     def flush(self):
```
```diff
@@ -14,6 +14,9 @@ class CpuDeviceOpOverrides(DeviceOpOverrides):
             """
         )
 
+    def cpp_kernel_type(self) -> str:
+        return "void*"
+
     def set_device(self, device_idx: int) -> str:
         return "pass"
 
```