[Inductor][CPP] Enable Epilogue Fusion for Grouped GEMM Template (#143897)

**Summary**
In this PR, we enable the epilogues fusion and code generation for Grouped GEMM. Here are the high-level description of how we implement it.

**Fusion**

- The Grouped GEMM Template produces a `Template Buffer` with a `MultiOutputLayout` and a set of `MultiOutput Buffers`, where each buffer corresponds to a specific GEMM.
- During the initial round of fusion, the `Template Buffer` and all associated `MultiOutput Buffers` are fused into a `FusedSchedulerNode` by extending the existing fusion design.
- In subsequent fusion rounds, this `FusedSchedulerNode` can further fuse with its epilogues, following the original fusion design principles.

**Code Gen**
We maintain a list of epilogues and codegen it one by one.

- If any of the GEMM has bias, we create  a extra `bias_add` epilogue and prepend it at first of the epilogue list.
- If any of the GEMM has no epilogue, we create a `to_bf16` copy epilogue and append it at last of the epilogue list.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_epilogue
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143897
Approved by: https://github.com/jansel, https://github.com/jgong5
ghstack dependencies: #143796
This commit is contained in:
leslie-fang-intel
2025-01-13 00:03:45 -08:00
committed by PyTorch MergeBot
parent 25de671ea8
commit 9d98b66e7b
7 changed files with 349 additions and 79 deletions

View File

@ -1822,6 +1822,31 @@ def pass_execution_and_save(func, gm, inp, msg):
)
def is_multi_outputs_template(input_buf) -> bool:
"""
Check if input buffer is a multi-outputs template buffer
"""
from . import ir
return isinstance(input_buf, ir.CppTemplateBuffer) and isinstance(
input_buf.layout, ir.MultiOutputLayout
)
def is_output_of_multi_outputs_template(input_buf) -> bool:
"""
Check if input buffer is a output of multi-outputs template buffer
"""
from . import ir
return (
isinstance(input_buf, ir.MultiOutput)
and len(input_buf.inputs) == 1
and isinstance(input_buf.inputs[0], ir.CppTemplateBuffer)
and isinstance(input_buf.inputs[0].layout, ir.MultiOutputLayout)
)
def is_collective(node, op=None):
from . import ir