[inductor] [cpp] add index check when fusing epilogue with GEMM template (#135661)

## Description
Fixes the accuracy failure of FP32 `jx_nest_base` of max-autotune.

The current epilogue fusion implementation in the GEMM template assumes that the read of the template buffer and the write of the epilogue output in the epilogue node have the same index (the layout could be different but the index should be the same).

If the condition is not satisfied, the computation is wrong, leading to correctness issue for FP32 `jx_nest_base`.

This PR disables the epilogue fusion with the GEMM template when the above condition is not satisfied.

### Unsupported epilogue:
`buf1` is the template buffer and `buf2` is the epilogue output buffer.
The store of `buf2`:
401408 * d0 + 100352 * d1 + **7168 * d2** + **1792 * d3** + 128 * d4 + d5

The load of `buf1` in the epilogue node:
401408 * d0 + 100352 * d1 + **1792 * d2** + **25088 * d3** + 128 * d4 + d5

The above two indexes are different.

```
CppTemplateBuffer(name='buf1', layout=FixedLayout('cpu', torch.float32, size=[25088, 128], stride=[128, 1]))
ComputedBuffer(name='buf2', layout=FixedLayout('cpu', torch.float32, size=[8, 4, 14, 4, 14, 128], stride=[401408, 100352, 7168, 1792, 128, 1]), data=Pointwise(
  'cpu',
  torch.float32,
  def inner_fn(index):
      i0, i1, i2, i3, i4, i5 = index
      tmp0 = ops.load(arg5_1, i5 + 128 * i4 + 1792 * i2 + 25088 * i3 + 100352 * i1 + 401408 * i0)
      tmp1 = ops.load(buf0, i5 + 128 * i4 + 1792 * i2 + 25088 * i3 + 100352 * i1 + 401408 * i0)
      tmp2 = tmp0 + tmp1
      tmp3 = ops.load(buf1, i5 + 128 * i4 + 1792 * i2 + 25088 * i3 + 100352 * i1 + 401408 * i0)
      tmp4 = tmp2 + tmp3
      return tmp4
  ,
  ranges=[8, 4, 14, 4, 14, 128],
  origin_node=clone,
  origins=OrderedSet([clone])
))
```

### Supported epilogue:
`buf1` is the template buffer and `buf2` is the epilogue output buffer.
The store of `buf2`:
d0 + 576 * d1 + 32 * d2

The load of `buf1` in the epilogue node:
d0 + 576 * d1 + 32 * d2

The above two indexes are the same.

The layouts of `buf2` and `buf1` are different, though; this is handled by the reindexer:
`buf1`: `size=[324, 32], stride=[32, 1]`
`buf2`: `size=[1, 32, 18, 18], stride=[10368, 1, 576, 32]`

```
CppTemplateBuffer(name='buf1', layout=FixedLayout('cpu', torch.bfloat16, size=[324, 32], stride=[32, 1]))
ComputedBuffer(name='buf2', layout=FixedLayout('cpu', torch.bfloat16, size=[1, 32, 18, 18], stride=[10368, 1, 576, 32]), data=Pointwise(
  'cpu',
  torch.bfloat16,
  def inner_fn(index):
      _, i1, i2, i3 = index
      tmp0 = ops.load(buf1, i1 + 32 * i3 + 576 * i2)
      tmp1 = ops.to_dtype(tmp0, torch.float32, src_dtype=torch.bfloat16)
      tmp2 = ops.load(_frozen_param4, i1)
      tmp3 = tmp1 * tmp2
      tmp4 = ops.load(arg7_1, i1 + 32 * i3 + 576 * i2)
      tmp5 = tmp3 + tmp4
      tmp6 = ops.to_dtype(tmp5, torch.bfloat16, src_dtype=torch.float32)
      return tmp6
  ,
  ranges=[1, 32, 18, 18],
  origin_node=convert_element_type_4,
  origins=OrderedSet([add, mul, convert_element_type_4])
))
```

## TODO
Add the support for fusions when the indexes are different in a follow-up PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135661
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
This commit is contained in:
Wu, Chunyuan
2024-09-24 10:05:20 +00:00
committed by PyTorch MergeBot
parent 7283530db2
commit 44c871c34b
3 changed files with 210 additions and 2 deletions

View File

@ -5,18 +5,21 @@ import functools
import math
import sys
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch
import sympy
import torch
from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges
from .. import ir
from ..dependencies import Dep
from ..loop_body import LoopBody
from ..scheduler import BaseSchedulerNode, SchedulerBuffer
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
from ..virtualized import ops, OpsValue, V
from .common import (
@ -916,3 +919,72 @@ def _get_dtype_from_loopbodies(loop_bodies):
continue
dtypes.add(node.meta[OptimizationContext.key].dtype)
return dtypes
def template_fusion_with_epilogues_supported(
template: BaseSchedulerNode, epilogues: List[BaseSchedulerNode]
) -> Tuple[bool, bool]:
def _get_indexes_of_template_buf_read(
epilogue_node: ir.Operation, template_buf_names: List[str]
) -> List[sympy.Expr]:
return [
read.index
for read in epilogue_node.get_reads()
if read.name in template_buf_names
]
def _check_supported_and_same_indexes(
index_of_template_buf_read: sympy.Expr, epilogue_writes: OrderedSet[Dep]
) -> Tuple[bool, bool]:
num_indexes = len(set(index_of_template_buf_read))
if num_indexes > 1:
same_index = False
supported = False # Different read indexes not supported
elif num_indexes == 0:
same_index = True
supported = True # No reads, automatically supported
elif num_indexes == 1:
index_of_template_buf_read = index_of_template_buf_read[0]
same_index = all(
write.index == index_of_template_buf_read for write in epilogue_writes
)
# TODO: Add support of fusion when the read of template buffer and the write of epilogue output
# in the epilogue node don't have the same index and change supported to True
supported = same_index
else:
raise AssertionError("Should not reach here")
return supported, same_index
def _template_fusion_supported(
template_outputs: Sequence[SchedulerBuffer], epilogue_nodes: List[ir.Operation]
) -> Tuple[bool, bool]:
template_buf_names = [x.get_name() for x in template_outputs]
indexes_of_template_buf_reads = [
_get_indexes_of_template_buf_read(epilogue_node, template_buf_names)
for epilogue_node in epilogue_nodes
]
epilogue_nodes_writes = [
epilogue_node.get_read_writes().writes for epilogue_node in epilogue_nodes
]
results = [
_check_supported_and_same_indexes(reads, writes)
for reads, writes in zip(
indexes_of_template_buf_reads, epilogue_nodes_writes
)
]
supported, same_indexes = zip(*results)
return all(supported), all(same_indexes)
assert template.is_template()
template_outputs = template.get_outputs()
epilogue_nodes = [
n.node
for epilogue in epilogues
for n in epilogue.get_nodes()
if n.node is not None
]
return _template_fusion_supported(template_outputs, epilogue_nodes)