mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor][CPP] Enable Epilogue Fusion for Grouped GEMM Template (#143897)
**Summary** In this PR, we enable the epilogues fusion and code generation for Grouped GEMM. Here are the high-level description of how we implement it. **Fusion** - The Grouped GEMM Template produces a `Template Buffer` with a `MultiOutputLayout` and a set of `MultiOutput Buffers`, where each buffer corresponds to a specific GEMM. - During the initial round of fusion, the `Template Buffer` and all associated `MultiOutput Buffers` are fused into a `FusedSchedulerNode` by extending the existing fusion design. - In subsequent fusion rounds, this `FusedSchedulerNode` can further fuse with its epilogues, following the original fusion design principles. **Code Gen** We maintain a list of epilogues and codegen it one by one. - If any of the GEMM has bias, we create a extra `bias_add` epilogue and prepend it at first of the epilogue list. - If any of the GEMM has no epilogue, we create a `to_bf16` copy epilogue and append it at last of the epilogue list. **TestPlan** ``` python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_epilogue ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/143897 Approved by: https://github.com/jansel, https://github.com/jgong5 ghstack dependencies: #143796
This commit is contained in:
committed by
PyTorch MergeBot
parent
25de671ea8
commit
9d98b66e7b
@ -1822,6 +1822,31 @@ def pass_execution_and_save(func, gm, inp, msg):
|
||||
)
|
||||
|
||||
|
||||
def is_multi_outputs_template(input_buf) -> bool:
|
||||
"""
|
||||
Check if input buffer is a multi-outputs template buffer
|
||||
"""
|
||||
from . import ir
|
||||
|
||||
return isinstance(input_buf, ir.CppTemplateBuffer) and isinstance(
|
||||
input_buf.layout, ir.MultiOutputLayout
|
||||
)
|
||||
|
||||
|
||||
def is_output_of_multi_outputs_template(input_buf) -> bool:
|
||||
"""
|
||||
Check if input buffer is a output of multi-outputs template buffer
|
||||
"""
|
||||
from . import ir
|
||||
|
||||
return (
|
||||
isinstance(input_buf, ir.MultiOutput)
|
||||
and len(input_buf.inputs) == 1
|
||||
and isinstance(input_buf.inputs[0], ir.CppTemplateBuffer)
|
||||
and isinstance(input_buf.inputs[0].layout, ir.MultiOutputLayout)
|
||||
)
|
||||
|
||||
|
||||
def is_collective(node, op=None):
|
||||
from . import ir
|
||||
|
||||
|
||||
Reference in New Issue
Block a user