mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-30 03:34:56 +08:00
[inductor] [cpp] add index check when fusing epilogue with GEMM template (#135661)
## Description
Fixes the accuracy failure of FP32 `jx_nest_base` of max-autotune.
The current epilogue fusion implementation in GEMM template assumes that the read of template buffer and the write of epilogue output in the epilogue node have the same index (the layout could be different but the index should be the same).
If the condition is not satisfied, the computation is wrong, leading to correctness issue for FP32 `jx_nest_base`.
This PR disables the epilogue fusion with the GEMM template when the above condition is not satisfied.
### Unsupported epilogue:
`buf1` is the template buffer and `buf2` is the epilogue output buffer.
The store of `buf2`:
401408 * d0 + 100352 * d1 + **7168 * d2** + **1792 * d3** + 128 * d4 + d5
The load of `buf1` in the epilogue node:
401408 * d0 + 100352 * d1 + **1792 * d2** + **25088 * d3** + 128 * d4 + d5
The above two indexes are different.
```
CppTemplateBuffer(name='buf1', layout=FixedLayout('cpu', torch.float32, size=[25088, 128], stride=[128, 1]))
ComputedBuffer(name='buf2', layout=FixedLayout('cpu', torch.float32, size=[8, 4, 14, 4, 14, 128], stride=[401408, 100352, 7168, 1792, 128, 1]), data=Pointwise(
'cpu',
torch.float32,
def inner_fn(index):
i0, i1, i2, i3, i4, i5 = index
tmp0 = ops.load(arg5_1, i5 + 128 * i4 + 1792 * i2 + 25088 * i3 + 100352 * i1 + 401408 * i0)
tmp1 = ops.load(buf0, i5 + 128 * i4 + 1792 * i2 + 25088 * i3 + 100352 * i1 + 401408 * i0)
tmp2 = tmp0 + tmp1
tmp3 = ops.load(buf1, i5 + 128 * i4 + 1792 * i2 + 25088 * i3 + 100352 * i1 + 401408 * i0)
tmp4 = tmp2 + tmp3
return tmp4
,
ranges=[8, 4, 14, 4, 14, 128],
origin_node=clone,
origins=OrderedSet([clone])
))
```
### Supported epilogue:
`buf1` is the template buffer and `buf2` is the epilogue output buffer.
The store of `buf2`:
d0 + 576 * d1 + 32 * d2
The load of `buf1` in the epilogue node:
d0 + 576 * d1 + 32 * d2
The above two indexes are the same.
The layouts of `buf1` and `buf2` are different, though; this is handled by the reindexer:
`buf1`: `size=[324, 32], stride=[32, 1]`
`buf2`: `size=[1, 32, 18, 18], stride=[10368, 1, 576, 32]`
```
CppTemplateBuffer(name='buf1', layout=FixedLayout('cpu', torch.bfloat16, size=[324, 32], stride=[32, 1]))
ComputedBuffer(name='buf2', layout=FixedLayout('cpu', torch.bfloat16, size=[1, 32, 18, 18], stride=[10368, 1, 576, 32]), data=Pointwise(
'cpu',
torch.bfloat16,
def inner_fn(index):
_, i1, i2, i3 = index
tmp0 = ops.load(buf1, i1 + 32 * i3 + 576 * i2)
tmp1 = ops.to_dtype(tmp0, torch.float32, src_dtype=torch.bfloat16)
tmp2 = ops.load(_frozen_param4, i1)
tmp3 = tmp1 * tmp2
tmp4 = ops.load(arg7_1, i1 + 32 * i3 + 576 * i2)
tmp5 = tmp3 + tmp4
tmp6 = ops.to_dtype(tmp5, torch.bfloat16, src_dtype=torch.float32)
return tmp6
,
ranges=[1, 32, 18, 18],
origin_node=convert_element_type_4,
origins=OrderedSet([add, mul, convert_element_type_4])
))
```
## TODO
Add the support for fusions when the indexes are different in a follow-up PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135661
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
This commit is contained in:
committed by
PyTorch MergeBot
parent
7283530db2
commit
44c871c34b
@ -5,18 +5,21 @@ import functools
|
||||
import math
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._prims_common import is_integer_dtype
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
from .. import ir
|
||||
from ..dependencies import Dep
|
||||
from ..loop_body import LoopBody
|
||||
from ..scheduler import BaseSchedulerNode, SchedulerBuffer
|
||||
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
|
||||
from ..virtualized import ops, OpsValue, V
|
||||
from .common import (
|
||||
@ -916,3 +919,72 @@ def _get_dtype_from_loopbodies(loop_bodies):
|
||||
continue
|
||||
dtypes.add(node.meta[OptimizationContext.key].dtype)
|
||||
return dtypes
|
||||
|
||||
|
||||
def template_fusion_with_epilogues_supported(
    template: BaseSchedulerNode, epilogues: List[BaseSchedulerNode]
) -> Tuple[bool, bool]:
    """Check whether the epilogue nodes can be fused with the GEMM template.

    Returns a tuple ``(supported, same_index)``:
      * ``same_index`` — within every epilogue node, the index at which the
        template buffer is read equals the index at which the epilogue output
        is written.
      * ``supported`` — the fusion is supported. Currently this requires
        ``same_index`` (or that the epilogue does not read the template
        buffer at all); see the TODO below.
    """

    def _get_indexes_of_template_buf_read(
        epilogue_node: ir.Operation, template_buf_names: List[str]
    ) -> List[sympy.Expr]:
        # All indexes at which this epilogue node reads any template buffer.
        return [
            read.index
            for read in epilogue_node.get_reads()
            if read.name in template_buf_names
        ]

    def _check_supported_and_same_indexes(
        indexes_of_template_buf_read: List[sympy.Expr],
        epilogue_writes: OrderedSet[Dep],
    ) -> Tuple[bool, bool]:
        # Check one epilogue node: compare its (unique) template-buffer read
        # index against the indexes of all of its writes.
        num_indexes = len(set(indexes_of_template_buf_read))

        if num_indexes == 0:
            # No reads of the template buffer, so there is no index
            # constraint to violate: automatically supported.
            return True, True
        if num_indexes > 1:
            # Reading the template buffer at several distinct indexes is not
            # supported.
            return False, False

        read_index = indexes_of_template_buf_read[0]
        same_index = all(write.index == read_index for write in epilogue_writes)
        # TODO: Add support of fusion when the read of template buffer and the
        # write of epilogue output in the epilogue node don't have the same
        # index and change supported to True
        return same_index, same_index

    def _template_fusion_supported(
        template_outputs: Sequence[SchedulerBuffer], epilogue_nodes: List[ir.Operation]
    ) -> Tuple[bool, bool]:
        template_buf_names = [x.get_name() for x in template_outputs]
        indexes_of_template_buf_reads = [
            _get_indexes_of_template_buf_read(epilogue_node, template_buf_names)
            for epilogue_node in epilogue_nodes
        ]
        epilogue_nodes_writes = [
            epilogue_node.get_read_writes().writes for epilogue_node in epilogue_nodes
        ]

        results = [
            _check_supported_and_same_indexes(reads, writes)
            for reads, writes in zip(
                indexes_of_template_buf_reads, epilogue_nodes_writes
            )
        ]
        if not results:
            # No epilogue nodes: trivially supported. Without this guard,
            # unpacking zip(*[]) below would raise ValueError.
            return True, True
        supported, same_indexes = zip(*results)
        return all(supported), all(same_indexes)

    assert template.is_template()
    template_outputs = template.get_outputs()

    # Flatten each epilogue scheduler node into its underlying IR operations,
    # skipping entries with no backing node.
    epilogue_nodes = [
        n.node
        for epilogue in epilogues
        for n in epilogue.get_nodes()
        if n.node is not None
    ]
    return _template_fusion_supported(template_outputs, epilogue_nodes)
|
||||
|
||||
Reference in New Issue
Block a user