Revert "[Inductor][CPP] Enable Local Buffer for Outer loop fusion (#126967)"

This reverts commit 98929ceae3873f18f4747b88cdff708fde107aa7.

Reverted https://github.com/pytorch/pytorch/pull/126967 on behalf of https://github.com/leslie-fang-intel due to Broken trunk and need rebase ([comment](https://github.com/pytorch/pytorch/pull/126967#issuecomment-2212337926))
This commit is contained in:
PyTorch MergeBot
2024-07-07 06:16:32 +00:00
parent 1b57dce35f
commit e423224546
6 changed files with 106 additions and 415 deletions

View File

@ -4,7 +4,7 @@ import copy
import math
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
from unittest.mock import patch
import sympy
@ -12,10 +12,11 @@ import sympy
import torch
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import ir
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix
from ..virtualized import V
from .common import CSEVariable, ExprPrinter, Kernel, KernelArgs
from .common import CSEVariable, ExprPrinter, Kernel
DTYPE_TO_CPP = {
torch.float32: "float",
@ -303,88 +304,7 @@ def value_to_cpp(value, cpp_type):
return f"static_cast<{cpp_type}>({repr(value)})"
def rewrite_index_for_function(
localize_buffer_handler: "LocalizeBufferHandler",
index: sympy.Expr,
):
# Local buffer at the inner dimensions
snode = V.graph.scheduler.name_to_node.get(
localize_buffer_handler.global_buf.get_name()
)
assert snode is not None
scheduler_nodes = snode.get_nodes()
_, (group, reduction_group) = max(
scheduler_nodes, key=lambda x: int(x.is_reduction())
).group
call_ranges = tuple(group) + tuple(reduction_group)
indices_to_keep = [
f"x{len(call_ranges) - (idx + 1)}"
for idx in range(len(localize_buffer_handler.local_buf.get_layout().size))
]
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined]
replacements = {}
for x in sorted_symbols:
if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined]
# Only keep index used by local buffer
replacements[x] = sympy.core.numbers.Zero()
index = sympy_subs(index, replacements) # type: ignore[arg-type]
return index
def rewrite_index_for_nodes(
localize_buffer_handler: "LocalizeBufferHandler",
index: sympy.Expr,
):
used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)}
index_vars = []
for i in range(len(localize_buffer_handler.local_buf.get_size())):
var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
index_vars.append(var if var in used_vars else 0)
index = localize_buffer_handler.local_buf.layout.make_indexer()(index_vars)
return index
class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
def __init__(
self,
inner,
global_buf: ir.Buffer,
local_buf: ir.Buffer,
rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr], sympy.Expr],
):
super().__init__(inner)
self.global_buf = global_buf
self.local_buf = local_buf
self.rewrite_index = rewrite_index
def localize(self, name: str, index: sympy.Expr):
if self.global_buf and name == self.global_buf.get_name():
assert self.rewrite_index is not None
name = self.local_buf.get_name()
index = self.rewrite_index(self, index)
return name, index
def load(self, name: str, index: sympy.Expr):
return self._inner.load(*self.localize(name, index))
def store(self, name, index, value, mode=None):
local_buffer_name, local_buffer_index = self.localize(name, index)
res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
if (
self.global_buf
and name == self.global_buf.get_name()
and isinstance(V.kernel, Kernel)
):
# Remove name of local buffer from Kernel.store_buffer_names
# local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store.
V.kernel.store_buffer_names.discard(local_buffer_name)
return res
def store_reduction(self, name, index, value):
return self._inner.store_reduction(*self.localize(name, index), value)
class LocalBufferContext:
class LocalBufferScope:
"""
This class creates a context that helps to generate code involving Inductor IR with
function local buffers. These buffers are constructed during the codegen process and
@ -394,13 +314,10 @@ class LocalBufferContext:
these buffers without exposure to the outside world.
"""
def __init__(self, kernel_args: KernelArgs):
self.kernel_args = kernel_args
def __init__(self, kernel: Kernel):
self.kernel = kernel
self.exit_stack = contextlib.ExitStack()
# Map Local Buffer name to Local Buffer
self.local_buffers: Dict[str, ir.Buffer] = {}
# Map Local Buffer name to Global Buffer
self.local_to_global: Dict[str, ir.Buffer] = {}
def __enter__(self):
self.exit_stack.__enter__()
@ -413,26 +330,23 @@ class LocalBufferContext:
self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype))
original_input = self.kernel_args.input
original_input = self.kernel.args.input
def input(name):
if name in self.local_buffers:
return name
return original_input(name)
self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input))
self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input))
original_output = self.kernel_args.output
original_output = self.kernel.args.output
def output(name):
if name in self.local_buffers:
return name
return original_output(name)
self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output))
# Set current LocalBufferContext into V
self.exit_stack.enter_context(V.set_local_buffer_context(self))
self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output))
return self
@ -440,64 +354,53 @@ class LocalBufferContext:
self.local_buffers.clear()
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
def add_local_buffer(
self, local_buffer: ir.Buffer, global_buffer: Optional[ir.Buffer] = None
):
assert local_buffer.get_name() not in self.local_buffers
self.local_buffers[local_buffer.get_name()] = local_buffer
if global_buffer:
self.local_to_global[local_buffer.get_name()] = global_buffer
V.graph.removed_buffers.add(global_buffer.get_name())
def add_local_buffer(self, buffer: ir.Buffer):
assert buffer.get_name() not in self.local_buffers
self.local_buffers[buffer.get_name()] = buffer
def localize_function(
self,
fn: Callable[..., Any],
rewrite_index: Callable[
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
] = rewrite_index_for_function,
):
local_buffers = list(self.local_buffers.values())
global_buffers = list(self.local_to_global.values())
local_buf = local_buffers[0]
global_buf = global_buffers[0]
def inner(node, *index_vars):
with V.set_ops_handler(
LocalizeBufferHandler(
V.get_ops_handler(),
global_buf=global_buf,
local_buf=local_buf,
rewrite_index=rewrite_index,
)
):
return fn(node, *index_vars)
return inner
def localize_nodes(
self,
nodes: List[ir.IRNode],
rewrite_index: Callable[
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
] = rewrite_index_for_nodes,
def localize_buffer(
self, global_buf: ir.Buffer, local_buf: ir.Buffer, nodes: List[ir.IRNode]
) -> List[ir.IRNode]:
"""
Given `local_buf` and `global_buf` registered in current `LocalBufferContext`
though the method of `add_local_buffer`, localizes the `global_buf` to `local_buf`
for the given `nodes` and returns a new list of IR nodes that work on `local_buf`
instead of `global_buf`, i.e., all the loads and stores are redirected to
`local_buf`. This helps the fused loops to work on smaller-sized local buffers
for better data locality.
Localizes the buffer `global_buf` to `local_buf` in the given `nodes` and returns
a new list of IR nodes that work on `local_buf` instead of `global_buf`, i.e., all
the loads and stores are redirected to `local_buf`. This helps the fused loops to
work on smaller-sized local buffers for better data locality.
The the data access of `local_buf` is assumed to be contiguous with the
same order as the `global_buf`.
The `local_buf` should already be registered in the local scope and the data access
is assumed to be contiguous with the same order as the `global_buf`.
"""
local_buffers = list(self.local_buffers.values())
global_buffers = list(self.local_to_global.values())
assert len(global_buffers[0].get_size()) == len(local_buffers[0].get_size())
assert local_buf.get_name() in self.local_buffers
assert len(global_buf.get_size()) == len(local_buf.get_size())
assert len(nodes) > 0
def wrap_inner_fn_for_node(node: ir.IRNode):
class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
def __init__(self, inner):
super().__init__(inner)
def localize(self, name: str, index: sympy.Expr):
if name == global_buf.get_name():
name = local_buf.get_name()
used_vars = {
s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)
}
index_vars = []
for i in range(len(local_buf.get_size())):
var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
index_vars.append(var if var in used_vars else 0)
index = local_buf.layout.make_indexer()(index_vars)
return name, index
def load(self, name: str, index: sympy.Expr):
return self._inner.load(*self.localize(name, index))
def store(self, name, index, value, mode=None):
return self._inner.store(*self.localize(name, index), value, mode)
def store_reduction(self, name, index, value):
return self._inner.store_reduction(*self.localize(name, index), value)
def wrap_inner_fn_for_node(node: ir.IRNode, inner_fn_wrapper):
loops = node.data if isinstance(node, ir.ComputedBuffer) else node
assert isinstance(loops, ir.Loops)
new_loops = copy.copy(loops)
@ -508,13 +411,17 @@ class LocalBufferContext:
else:
new_node = new_loops # type: ignore[assignment]
new_loops.inner_fn = self.localize_function(
new_loops.inner_fn,
rewrite_index,
)
new_loops.inner_fn = inner_fn_wrapper(new_loops.inner_fn)
return new_node
return [wrap_inner_fn_for_node(node) for node in nodes]
def inner_fn_wrapper(inner_fn):
def inner(index):
with V.set_ops_handler(LocalizeBufferHandler(V.get_ops_handler())):
return inner_fn(index)
return inner
return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes]
def unify_mask_base_type(