Revert "[Inductor][CPP] Support more than one LocalBuffer (#129121)"

This reverts commit f794cf59bd0891ff4a4337e0d919ee68ba1f0472.

Reverted https://github.com/pytorch/pytorch/pull/129121 on behalf of https://github.com/leslie-fang-intel due to Broken trunk and need rebase ([comment](https://github.com/pytorch/pytorch/pull/129121#issuecomment-2212337590))
This commit is contained in:
PyTorch MergeBot
2024-07-07 06:13:40 +00:00
parent f794cf59bd
commit 1b57dce35f
4 changed files with 70 additions and 193 deletions

View File

@ -306,11 +306,11 @@ def value_to_cpp(value, cpp_type):
def rewrite_index_for_function(
localize_buffer_handler: "LocalizeBufferHandler",
index: sympy.Expr,
global_buf_name: str,
):
# Local buffer at the inner dimensions
snode = V.graph.scheduler.name_to_node.get(global_buf_name)
local_buf = localize_buffer_handler.global_to_local[global_buf_name]
snode = V.graph.scheduler.name_to_node.get(
localize_buffer_handler.global_buf.get_name()
)
assert snode is not None
scheduler_nodes = snode.get_nodes()
_, (group, reduction_group) = max(
@ -319,7 +319,7 @@ def rewrite_index_for_function(
call_ranges = tuple(group) + tuple(reduction_group)
indices_to_keep = [
f"x{len(call_ranges) - (idx + 1)}"
for idx in range(len(local_buf.get_layout().size))
for idx in range(len(localize_buffer_handler.local_buf.get_layout().size))
]
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined]
replacements = {}
@ -334,15 +334,13 @@ def rewrite_index_for_function(
def rewrite_index_for_nodes(
localize_buffer_handler: "LocalizeBufferHandler",
index: sympy.Expr,
global_buf_name: str,
):
used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)}
index_vars = []
local_buf = localize_buffer_handler.global_to_local[global_buf_name]
for i in range(len(local_buf.get_size())):
for i in range(len(localize_buffer_handler.local_buf.get_size())):
var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
index_vars.append(var if var in used_vars else 0)
index = local_buf.layout.make_indexer()(index_vars)
index = localize_buffer_handler.local_buf.layout.make_indexer()(index_vars)
return index
@ -350,18 +348,20 @@ class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
def __init__(
self,
inner,
global_to_local: Dict[str, ir.Buffer],
rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr],
global_buf: ir.Buffer,
local_buf: ir.Buffer,
rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr], sympy.Expr],
):
super().__init__(inner)
self.global_to_local = global_to_local
self.global_buf = global_buf
self.local_buf = local_buf
self.rewrite_index = rewrite_index
def localize(self, name: str, index: sympy.Expr):
if self.global_to_local and name in self.global_to_local:
if self.global_buf and name == self.global_buf.get_name():
assert self.rewrite_index is not None
index = self.rewrite_index(self, index, name)
name = self.global_to_local[name].get_name()
name = self.local_buf.get_name()
index = self.rewrite_index(self, index)
return name, index
def load(self, name: str, index: sympy.Expr):
@ -371,8 +371,8 @@ class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
local_buffer_name, local_buffer_index = self.localize(name, index)
res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
if (
self.global_to_local
and name in self.global_to_local
self.global_buf
and name == self.global_buf.get_name()
and isinstance(V.kernel, Kernel)
):
# Remove name of local buffer from Kernel.store_buffer_names
@ -397,12 +397,10 @@ class LocalBufferContext:
def __init__(self, kernel_args: KernelArgs):
self.kernel_args = kernel_args
self.exit_stack = contextlib.ExitStack()
# map local buffer name to local buffer
# Map Local Buffer name to Local Buffer
self.local_buffers: Dict[str, ir.Buffer] = {}
# map global buffer name to global buffer
self.global_buffers: Dict[str, ir.Buffer] = {}
# map global buffer name to local buffer
self.global_to_local: Dict[str, ir.Buffer] = {}
# Map Local Buffer name to Global Buffer
self.local_to_global: Dict[str, ir.Buffer] = {}
def __enter__(self):
self.exit_stack.__enter__()
@ -443,33 +441,32 @@ class LocalBufferContext:
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
def add_local_buffer(
self, local_buffer: ir.Buffer, global_buffers: Optional[List[ir.Buffer]] = None
self, local_buffer: ir.Buffer, global_buffer: Optional[ir.Buffer] = None
):
assert local_buffer.get_name() not in self.local_buffers
self.local_buffers[local_buffer.get_name()] = local_buffer
if global_buffers:
for global_buffer in global_buffers:
global_buffer_name = global_buffer.get_name()
assert (
global_buffer_name not in self.global_buffers
and global_buffer_name not in self.global_to_local
)
self.global_buffers[global_buffer_name] = global_buffer
self.global_to_local[global_buffer_name] = local_buffer
V.graph.removed_buffers.add(global_buffer_name)
if global_buffer:
self.local_to_global[local_buffer.get_name()] = global_buffer
V.graph.removed_buffers.add(global_buffer.get_name())
def localize_function(
self,
fn: Callable[..., Any],
rewrite_index: Callable[
["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
] = rewrite_index_for_function,
):
local_buffers = list(self.local_buffers.values())
global_buffers = list(self.local_to_global.values())
local_buf = local_buffers[0]
global_buf = global_buffers[0]
def inner(node, *index_vars):
with V.set_ops_handler(
LocalizeBufferHandler(
V.get_ops_handler(),
global_to_local=self.global_to_local,
global_buf=global_buf,
local_buf=local_buf,
rewrite_index=rewrite_index,
)
):
@ -481,7 +478,7 @@ class LocalBufferContext:
self,
nodes: List[ir.IRNode],
rewrite_index: Callable[
["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
] = rewrite_index_for_nodes,
) -> List[ir.IRNode]:
"""
@ -495,6 +492,9 @@ class LocalBufferContext:
The the data access of `local_buf` is assumed to be contiguous with the
same order as the `global_buf`.
"""
local_buffers = list(self.local_buffers.values())
global_buffers = list(self.local_to_global.values())
assert len(global_buffers[0].get_size()) == len(local_buffers[0].get_size())
assert len(nodes) > 0
def wrap_inner_fn_for_node(node: ir.IRNode):