[Inductor][CPP] Reuse the pre-existing kernel for the same kernels (#158404)

Reuse the pre-existing kernel to avoid defining redundant kernels.
Inductor CPP can generate identical kernels for identical operations. For example:
```
# Example
class Model(torch.nn.Module):
    def __init__(self, K, N):
        super().__init__()
        self.linear0 = torch.nn.Linear(K, N)
        self.linear1 = torch.nn.Linear(N, K)
        self.linear2 = torch.nn.Linear(K, N)

    def forward(self, input):
        out = self.linear0(input)
        out = self.linear1(out)
        out = self.linear2(out)
        return out
```

For the above example, linear2 has the same configuration as linear0, so Inductor CPP generates two identical kernels, cpp_fused_addmm_0 and cpp_fused_addmm_2:

```
# Generated code:
...
cpp_fused_addmm_0 = async_compile.cpp_pybinding(['const at::BFloat16*', 'const at::BFloat16*', 'const at::BFloat16*', 'at::BFloat16*'], '''
...
extern "C"
void kernel(const at::BFloat16* X, const at::BFloat16* W, const at::BFloat16* inp, at::BFloat16* Y)
{

    constexpr int64_t num_threads = 32;
    constexpr int64_t N = 1024;
    constexpr int64_t K = 2048;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
...
cpp_fused_addmm_1 = async_compile.cpp_pybinding(['const at::BFloat16*', 'const at::BFloat16*', 'const at::BFloat16*', 'at::BFloat16*'], '''
...
extern "C"
void kernel(const at::BFloat16* X, const at::BFloat16* W, const at::BFloat16* inp, at::BFloat16* Y)
{

    constexpr int64_t num_threads = 32;
    constexpr int64_t N = 2048;
    constexpr int64_t K = 1024;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
...
cpp_fused_addmm_2 = async_compile.cpp_pybinding(['const at::BFloat16*', 'const at::BFloat16*', 'const at::BFloat16*', 'at::BFloat16*'], '''
extern "C"
void kernel(const at::BFloat16* X, const at::BFloat16* W, const at::BFloat16* inp, at::BFloat16* Y)
{

    constexpr int64_t num_threads = 32;
    constexpr int64_t N = 1024;
    constexpr int64_t K = 2048;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
...
    def call(self, args):
        arg6_1, = args
        args.clear()
        buf0 = empty_strided_cpu((1024, 1024), (1024, 1), torch.bfloat16)
        cpp_fused_addmm_0(arg6_1, constant6, _frozen_param6, buf0)
        del arg6_1
        buf1 = empty_strided_cpu((1024, 2048), (2048, 1), torch.bfloat16)
        cpp_fused_addmm_1(buf0, constant6_0, _frozen_param8, buf1)
        buf2 = buf0; del buf0  # reuse
        cpp_fused_addmm_2(buf1, constant6_1, _frozen_param10, buf2)
        return (buf2, )
```
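
The duplicate definitions can be reproduced by dumping the generated code. Below is a minimal sketch, not the test from this PR; the config flags are assumptions (it relies on an AMX-capable CPU and on freezing plus the CPP max-autotune GEMM backend so the linears are lowered to cpp_fused_addmm template kernels):

```
import torch
from torch._inductor import config as inductor_config
from torch._inductor.utils import run_and_get_cpp_code

# Reuses the Model class from the example above.
K, N = 2048, 1024
model = Model(K, N).to(torch.bfloat16).eval()
x = torch.randn(1024, K, dtype=torch.bfloat16)

with torch.no_grad(), inductor_config.patch(
    {"freezing": True, "max_autotune": True, "max_autotune_gemm_backends": "CPP"}
):
    compiled = torch.compile(model)
    # Capture the generated wrapper (Python with embedded C++ kernels) while running.
    _, code = run_and_get_cpp_code(compiled, x)

# Each defined kernel corresponds to one async_compile.cpp_pybinding(...) call:
# 3 without kernel reuse, 2 with it.
print(code.count("async_compile.cpp_pybinding"))
```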

After this change, Inductor CPP reuses cpp_fused_addmm_0 for the third linear instead of defining cpp_fused_addmm_2:
```
cpp_fused_addmm_0 = async_compile.cpp_pybinding(['const at::BFloat16*', 'const at::BFloat16*', 'const at::BFloat16*', 'at::BFloat16*'], '''
...
extern "C"
void kernel(const at::BFloat16* X, const at::BFloat16* W, const at::BFloat16* inp, at::BFloat16* Y)
{

    constexpr int64_t num_threads = 32;
    constexpr int64_t N = 1024;
    constexpr int64_t K = 2048;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
...
cpp_fused_addmm_1 = async_compile.cpp_pybinding(['const at::BFloat16*', 'const at::BFloat16*', 'const at::BFloat16*', 'at::BFloat16*'], '''
...
extern "C"
void kernel(const at::BFloat16* X, const at::BFloat16* W, const at::BFloat16* inp, at::BFloat16* Y)
{

    constexpr int64_t num_threads = 32;
    constexpr int64_t N = 2048;
    constexpr int64_t K = 1024;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
...
    def call(self, args):
        arg6_1, = args
        args.clear()
        buf0 = empty_strided_cpu((1024, 1024), (1024, 1), torch.bfloat16)
        cpp_fused_addmm_0(arg6_1, constant6, _frozen_param6, buf0)
        del arg6_1
        buf1 = empty_strided_cpu((1024, 2048), (2048, 1), torch.bfloat16)
        cpp_fused_addmm_1(buf0, constant6_0, _frozen_param8, buf1)
        buf2 = buf0; del buf0  # reuse
        cpp_fused_addmm_0(buf1, constant6_1, _frozen_param10, buf2)
        return (buf2, )
```
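
The reuse is keyed on the generated kernel source: CppScheduling.define_kernel now looks the source text up in wrapper.src_to_kernel and returns the previously registered kernel name on a hit, only allocating a new suffix and emitting a definition on a miss (see the CppScheduling.define_kernel hunk below). A standalone sketch of the idea, using illustrative names rather than the real Inductor classes:

```
class KernelRegistry:
    """Illustrative sketch (not the real Inductor API): dedupe kernels by source."""

    def __init__(self):
        self.src_to_kernel: dict[str, str] = {}  # generated source -> kernel name
        self.defined: list[str] = []  # kernels actually emitted
        self._counter = 0

    def define_kernel(self, src_code: str, fused_name: str) -> str:
        # Reuse the pre-existing kernel if the exact same source was already seen.
        if src_code in self.src_to_kernel:
            return self.src_to_kernel[src_code]
        kernel_name = f"cpp_{fused_name}_{self._counter}"
        self._counter += 1
        self.src_to_kernel[src_code] = kernel_name
        self.defined.append(kernel_name)  # only a cache miss emits a new definition
        return kernel_name


reg = KernelRegistry()
k0 = reg.define_kernel("gemm body for N=1024, K=2048", "fused_addmm")  # cpp_fused_addmm_0
k1 = reg.define_kernel("gemm body for N=2048, K=1024", "fused_addmm")  # cpp_fused_addmm_1
k2 = reg.define_kernel("gemm body for N=1024, K=2048", "fused_addmm")  # reuses _0
assert k0 == k2 and reg.defined == ["cpp_fused_addmm_0", "cpp_fused_addmm_1"]
```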

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158404
Approved by: https://github.com/jansel, https://github.com/leslie-fang-intel
Authored by CaoE on 2025-09-16 01:54:21 +00:00; committed by PyTorch MergeBot
parent 61be0f1c11
commit 1aa41eccc2
3 changed files with 74 additions and 34 deletions

```
@@ -2910,6 +2910,37 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
         with verify(u.dtype) as (atol, rtol):
             self.common(mod, (u, v))
 
+    @unittest.skipIf(
+        not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
+    )
+    @inductor_config.patch({"freezing": True})
+    @patches
+    @torch.no_grad
+    @parametrize("batch_size", (1024,))
+    @parametrize("in_features", (1024,))
+    @parametrize("out_features", (2048,))
+    @dtypes(torch.bfloat16)
+    def test_linear_reuse_kernels(self, batch_size, in_features, out_features, dtype):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear_x = torch.nn.Linear(in_features, out_features)
+                self.linear_y = torch.nn.Linear(out_features, in_features)
+                self.linear_z = torch.nn.Linear(in_features, out_features)
+
+            def forward(self, x):
+                out = self.linear_x(x)
+                out = self.linear_y(out)
+                out = self.linear_z(out)
+                return out
+
+        x = torch.randn(batch_size, in_features).to(dtype=dtype)
+        mod = M().to(dtype=dtype).eval()
+        self.common(mod, (x))
+        _, code = run_and_get_cpp_code(mod, x)
+        # Check that only 2 kernels are in the generated code
+        assert code.count("AMXState amx_state") == 2
+
 
 @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
 class _DynamicShapesTestBase(BaseTestSelectAlgorithm):
```


```
@@ -5390,42 +5390,48 @@ class CppScheduling(BaseScheduling):
     def define_kernel(self, src_code, nodes, kernel_args=None):
         wrapper = V.graph.wrapper_code
-        fused_name = (
-            get_fused_kernel_name(nodes, config.cpp.descriptive_names)
-            if config.cpp.descriptive_names
-            else ""
-        )
-        kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
-        kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
-        src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name)
-        src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
-        # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
-        # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
-        src_code = src_code.replace("#pragma CMT", "//")
+        if src_code in wrapper.src_to_kernel:
+            kernel_name = wrapper.src_to_kernel[src_code]
+        else:
+            fused_name = (
+                get_fused_kernel_name(nodes, config.cpp.descriptive_names)
+                if config.cpp.descriptive_names
+                else ""
+            )
+            kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
+            wrapper.src_to_kernel[src_code] = kernel_name
+            kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
+            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name)
+            src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
+            # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
+            # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
+            src_code = src_code.replace("#pragma CMT", "//")
 
-        # Get the lines in the source code representing the function definition,
-        # excluding the the first line including cpp_prefix.h.
-        first_char = src_code.rfind('extern "C"')
-        last_char = src_code.find(")", first_char)
-        if _IS_WINDOWS:
-            # get_export_declaration introduced one more ')' in Windows
-            last_char = src_code.find(")", last_char + 1)
-        kernel_definition = f"{src_code[first_char : last_char + 1]};\n"
+            # Get the lines in the source code representing the function definition,
+            # excluding the the first line including cpp_prefix.h.
+            first_char = src_code.rfind('extern "C"')
+            last_char = src_code.find(")", first_char)
+            if _IS_WINDOWS:
+                # get_export_declaration introduced one more ')' in Windows
+                last_char = src_code.find(")", last_char + 1)
+            kernel_definition = f"{src_code[first_char : last_char + 1]};\n"
 
-        compile_wrapper = IndentedBuffer()
-        args = self.kernel_group.args if kernel_args is None else kernel_args
-        _, _, arg_types = args.cpp_argdefs()
-        if not V.graph.cpp_wrapper:
-            compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
-        compile_wrapper.splice(src_code, strip=True)
-        if not V.graph.cpp_wrapper:
-            compile_wrapper.writeline("''')")
-        wrapper.define_kernel(
-            kernel_name,
-            compile_wrapper.getvalue(),
-            gpu=False,
-            cpp_definition=kernel_definition,
-        )
+            compile_wrapper = IndentedBuffer()
+            args = self.kernel_group.args if kernel_args is None else kernel_args
+            _, _, arg_types = args.cpp_argdefs()
+            if not V.graph.cpp_wrapper:
+                compile_wrapper.writeline(
+                    f"async_compile.cpp_pybinding({arg_types!r}, '''"
+                )
+            compile_wrapper.splice(src_code, strip=True)
+            if not V.graph.cpp_wrapper:
+                compile_wrapper.writeline("''')")
+            wrapper.define_kernel(
+                kernel_name,
+                compile_wrapper.getvalue(),
+                gpu=False,
+                cpp_definition=kernel_definition,
+            )
         return kernel_name
 
     def flush(self):
```


```
@@ -14,6 +14,9 @@ class CpuDeviceOpOverrides(DeviceOpOverrides):
             """
         )
 
+    def cpp_kernel_type(self) -> str:
+        return "void*"
+
     def set_device(self, device_idx: int) -> str:
         return "pass"
```