mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 03:04:55 +08:00
Revert "[Inductor][CPP] Support more than one LocalBuffer (#129121)"
This reverts commit f794cf59bd0891ff4a4337e0d919ee68ba1f0472. Reverted https://github.com/pytorch/pytorch/pull/129121 on behalf of https://github.com/leslie-fang-intel due to Broken trunk and need rebase ([comment](https://github.com/pytorch/pytorch/pull/129121#issuecomment-2212337590))
This commit is contained in:
@ -306,11 +306,11 @@ def value_to_cpp(value, cpp_type):
|
||||
def rewrite_index_for_function(
|
||||
localize_buffer_handler: "LocalizeBufferHandler",
|
||||
index: sympy.Expr,
|
||||
global_buf_name: str,
|
||||
):
|
||||
# Local buffer at the inner dimensions
|
||||
snode = V.graph.scheduler.name_to_node.get(global_buf_name)
|
||||
local_buf = localize_buffer_handler.global_to_local[global_buf_name]
|
||||
snode = V.graph.scheduler.name_to_node.get(
|
||||
localize_buffer_handler.global_buf.get_name()
|
||||
)
|
||||
assert snode is not None
|
||||
scheduler_nodes = snode.get_nodes()
|
||||
_, (group, reduction_group) = max(
|
||||
@ -319,7 +319,7 @@ def rewrite_index_for_function(
|
||||
call_ranges = tuple(group) + tuple(reduction_group)
|
||||
indices_to_keep = [
|
||||
f"x{len(call_ranges) - (idx + 1)}"
|
||||
for idx in range(len(local_buf.get_layout().size))
|
||||
for idx in range(len(localize_buffer_handler.local_buf.get_layout().size))
|
||||
]
|
||||
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined]
|
||||
replacements = {}
|
||||
@ -334,15 +334,13 @@ def rewrite_index_for_function(
|
||||
def rewrite_index_for_nodes(
|
||||
localize_buffer_handler: "LocalizeBufferHandler",
|
||||
index: sympy.Expr,
|
||||
global_buf_name: str,
|
||||
):
|
||||
used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)}
|
||||
index_vars = []
|
||||
local_buf = localize_buffer_handler.global_to_local[global_buf_name]
|
||||
for i in range(len(local_buf.get_size())):
|
||||
for i in range(len(localize_buffer_handler.local_buf.get_size())):
|
||||
var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
|
||||
index_vars.append(var if var in used_vars else 0)
|
||||
index = local_buf.layout.make_indexer()(index_vars)
|
||||
index = localize_buffer_handler.local_buf.layout.make_indexer()(index_vars)
|
||||
return index
|
||||
|
||||
|
||||
@ -350,18 +348,20 @@ class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
|
||||
def __init__(
|
||||
self,
|
||||
inner,
|
||||
global_to_local: Dict[str, ir.Buffer],
|
||||
rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr],
|
||||
global_buf: ir.Buffer,
|
||||
local_buf: ir.Buffer,
|
||||
rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr], sympy.Expr],
|
||||
):
|
||||
super().__init__(inner)
|
||||
self.global_to_local = global_to_local
|
||||
self.global_buf = global_buf
|
||||
self.local_buf = local_buf
|
||||
self.rewrite_index = rewrite_index
|
||||
|
||||
def localize(self, name: str, index: sympy.Expr):
|
||||
if self.global_to_local and name in self.global_to_local:
|
||||
if self.global_buf and name == self.global_buf.get_name():
|
||||
assert self.rewrite_index is not None
|
||||
index = self.rewrite_index(self, index, name)
|
||||
name = self.global_to_local[name].get_name()
|
||||
name = self.local_buf.get_name()
|
||||
index = self.rewrite_index(self, index)
|
||||
return name, index
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
@ -371,8 +371,8 @@ class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
|
||||
local_buffer_name, local_buffer_index = self.localize(name, index)
|
||||
res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
|
||||
if (
|
||||
self.global_to_local
|
||||
and name in self.global_to_local
|
||||
self.global_buf
|
||||
and name == self.global_buf.get_name()
|
||||
and isinstance(V.kernel, Kernel)
|
||||
):
|
||||
# Remove name of local buffer from Kernel.store_buffer_names
|
||||
@ -397,12 +397,10 @@ class LocalBufferContext:
|
||||
def __init__(self, kernel_args: KernelArgs):
|
||||
self.kernel_args = kernel_args
|
||||
self.exit_stack = contextlib.ExitStack()
|
||||
# map local buffer name to local buffer
|
||||
# Map Local Buffer name to Local Buffer
|
||||
self.local_buffers: Dict[str, ir.Buffer] = {}
|
||||
# map global buffer name to global buffer
|
||||
self.global_buffers: Dict[str, ir.Buffer] = {}
|
||||
# map global buffer name to local buffer
|
||||
self.global_to_local: Dict[str, ir.Buffer] = {}
|
||||
# Map Local Buffer name to Global Buffer
|
||||
self.local_to_global: Dict[str, ir.Buffer] = {}
|
||||
|
||||
def __enter__(self):
|
||||
self.exit_stack.__enter__()
|
||||
@ -443,33 +441,32 @@ class LocalBufferContext:
|
||||
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def add_local_buffer(
|
||||
self, local_buffer: ir.Buffer, global_buffers: Optional[List[ir.Buffer]] = None
|
||||
self, local_buffer: ir.Buffer, global_buffer: Optional[ir.Buffer] = None
|
||||
):
|
||||
assert local_buffer.get_name() not in self.local_buffers
|
||||
self.local_buffers[local_buffer.get_name()] = local_buffer
|
||||
if global_buffers:
|
||||
for global_buffer in global_buffers:
|
||||
global_buffer_name = global_buffer.get_name()
|
||||
assert (
|
||||
global_buffer_name not in self.global_buffers
|
||||
and global_buffer_name not in self.global_to_local
|
||||
)
|
||||
self.global_buffers[global_buffer_name] = global_buffer
|
||||
self.global_to_local[global_buffer_name] = local_buffer
|
||||
V.graph.removed_buffers.add(global_buffer_name)
|
||||
if global_buffer:
|
||||
self.local_to_global[local_buffer.get_name()] = global_buffer
|
||||
V.graph.removed_buffers.add(global_buffer.get_name())
|
||||
|
||||
def localize_function(
|
||||
self,
|
||||
fn: Callable[..., Any],
|
||||
rewrite_index: Callable[
|
||||
["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
|
||||
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
|
||||
] = rewrite_index_for_function,
|
||||
):
|
||||
local_buffers = list(self.local_buffers.values())
|
||||
global_buffers = list(self.local_to_global.values())
|
||||
local_buf = local_buffers[0]
|
||||
global_buf = global_buffers[0]
|
||||
|
||||
def inner(node, *index_vars):
|
||||
with V.set_ops_handler(
|
||||
LocalizeBufferHandler(
|
||||
V.get_ops_handler(),
|
||||
global_to_local=self.global_to_local,
|
||||
global_buf=global_buf,
|
||||
local_buf=local_buf,
|
||||
rewrite_index=rewrite_index,
|
||||
)
|
||||
):
|
||||
@ -481,7 +478,7 @@ class LocalBufferContext:
|
||||
self,
|
||||
nodes: List[ir.IRNode],
|
||||
rewrite_index: Callable[
|
||||
["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
|
||||
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
|
||||
] = rewrite_index_for_nodes,
|
||||
) -> List[ir.IRNode]:
|
||||
"""
|
||||
@ -495,6 +492,9 @@ class LocalBufferContext:
|
||||
The the data access of `local_buf` is assumed to be contiguous with the
|
||||
same order as the `global_buf`.
|
||||
"""
|
||||
local_buffers = list(self.local_buffers.values())
|
||||
global_buffers = list(self.local_to_global.values())
|
||||
assert len(global_buffers[0].get_size()) == len(local_buffers[0].get_size())
|
||||
assert len(nodes) > 0
|
||||
|
||||
def wrap_inner_fn_for_node(node: ir.IRNode):
|
||||
|
||||
Reference in New Issue
Block a user