mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 19:24:55 +08:00
Revert "[Inductor][CPP] Enable Local Buffer for Outer loop fusion (#126967)"
This reverts commit 98929ceae3873f18f4747b88cdff708fde107aa7. Reverted https://github.com/pytorch/pytorch/pull/126967 on behalf of https://github.com/leslie-fang-intel due to Broken trunk and need rebase ([comment](https://github.com/pytorch/pytorch/pull/126967#issuecomment-2212337926))
This commit is contained in:
@ -4,7 +4,7 @@ import copy
|
||||
import math
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
@ -12,10 +12,11 @@ import sympy
|
||||
import torch
|
||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
from .. import ir
|
||||
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
|
||||
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix
|
||||
from ..virtualized import V
|
||||
|
||||
from .common import CSEVariable, ExprPrinter, Kernel, KernelArgs
|
||||
from .common import CSEVariable, ExprPrinter, Kernel
|
||||
|
||||
|
||||
DTYPE_TO_CPP = {
|
||||
torch.float32: "float",
|
||||
@ -303,88 +304,7 @@ def value_to_cpp(value, cpp_type):
|
||||
return f"static_cast<{cpp_type}>({repr(value)})"
|
||||
|
||||
|
||||
def rewrite_index_for_function(
|
||||
localize_buffer_handler: "LocalizeBufferHandler",
|
||||
index: sympy.Expr,
|
||||
):
|
||||
# Local buffer at the inner dimensions
|
||||
snode = V.graph.scheduler.name_to_node.get(
|
||||
localize_buffer_handler.global_buf.get_name()
|
||||
)
|
||||
assert snode is not None
|
||||
scheduler_nodes = snode.get_nodes()
|
||||
_, (group, reduction_group) = max(
|
||||
scheduler_nodes, key=lambda x: int(x.is_reduction())
|
||||
).group
|
||||
call_ranges = tuple(group) + tuple(reduction_group)
|
||||
indices_to_keep = [
|
||||
f"x{len(call_ranges) - (idx + 1)}"
|
||||
for idx in range(len(localize_buffer_handler.local_buf.get_layout().size))
|
||||
]
|
||||
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined]
|
||||
replacements = {}
|
||||
for x in sorted_symbols:
|
||||
if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined]
|
||||
# Only keep index used by local buffer
|
||||
replacements[x] = sympy.core.numbers.Zero()
|
||||
index = sympy_subs(index, replacements) # type: ignore[arg-type]
|
||||
return index
|
||||
|
||||
|
||||
def rewrite_index_for_nodes(
|
||||
localize_buffer_handler: "LocalizeBufferHandler",
|
||||
index: sympy.Expr,
|
||||
):
|
||||
used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)}
|
||||
index_vars = []
|
||||
for i in range(len(localize_buffer_handler.local_buf.get_size())):
|
||||
var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
|
||||
index_vars.append(var if var in used_vars else 0)
|
||||
index = localize_buffer_handler.local_buf.layout.make_indexer()(index_vars)
|
||||
return index
|
||||
|
||||
|
||||
class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
|
||||
def __init__(
|
||||
self,
|
||||
inner,
|
||||
global_buf: ir.Buffer,
|
||||
local_buf: ir.Buffer,
|
||||
rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr], sympy.Expr],
|
||||
):
|
||||
super().__init__(inner)
|
||||
self.global_buf = global_buf
|
||||
self.local_buf = local_buf
|
||||
self.rewrite_index = rewrite_index
|
||||
|
||||
def localize(self, name: str, index: sympy.Expr):
|
||||
if self.global_buf and name == self.global_buf.get_name():
|
||||
assert self.rewrite_index is not None
|
||||
name = self.local_buf.get_name()
|
||||
index = self.rewrite_index(self, index)
|
||||
return name, index
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
return self._inner.load(*self.localize(name, index))
|
||||
|
||||
def store(self, name, index, value, mode=None):
|
||||
local_buffer_name, local_buffer_index = self.localize(name, index)
|
||||
res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
|
||||
if (
|
||||
self.global_buf
|
||||
and name == self.global_buf.get_name()
|
||||
and isinstance(V.kernel, Kernel)
|
||||
):
|
||||
# Remove name of local buffer from Kernel.store_buffer_names
|
||||
# local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store.
|
||||
V.kernel.store_buffer_names.discard(local_buffer_name)
|
||||
return res
|
||||
|
||||
def store_reduction(self, name, index, value):
|
||||
return self._inner.store_reduction(*self.localize(name, index), value)
|
||||
|
||||
|
||||
class LocalBufferContext:
|
||||
class LocalBufferScope:
|
||||
"""
|
||||
This class creates a context that helps to generate code involving Inductor IR with
|
||||
function local buffers. These buffers are constructed during the codegen process and
|
||||
@ -394,13 +314,10 @@ class LocalBufferContext:
|
||||
these buffers without exposure to the outside world.
|
||||
"""
|
||||
|
||||
def __init__(self, kernel_args: KernelArgs):
|
||||
self.kernel_args = kernel_args
|
||||
def __init__(self, kernel: Kernel):
|
||||
self.kernel = kernel
|
||||
self.exit_stack = contextlib.ExitStack()
|
||||
# Map Local Buffer name to Local Buffer
|
||||
self.local_buffers: Dict[str, ir.Buffer] = {}
|
||||
# Map Local Buffer name to Global Buffer
|
||||
self.local_to_global: Dict[str, ir.Buffer] = {}
|
||||
|
||||
def __enter__(self):
|
||||
self.exit_stack.__enter__()
|
||||
@ -413,26 +330,23 @@ class LocalBufferContext:
|
||||
|
||||
self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype))
|
||||
|
||||
original_input = self.kernel_args.input
|
||||
original_input = self.kernel.args.input
|
||||
|
||||
def input(name):
|
||||
if name in self.local_buffers:
|
||||
return name
|
||||
return original_input(name)
|
||||
|
||||
self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input))
|
||||
self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input))
|
||||
|
||||
original_output = self.kernel_args.output
|
||||
original_output = self.kernel.args.output
|
||||
|
||||
def output(name):
|
||||
if name in self.local_buffers:
|
||||
return name
|
||||
return original_output(name)
|
||||
|
||||
self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output))
|
||||
|
||||
# Set current LocalBufferContext into V
|
||||
self.exit_stack.enter_context(V.set_local_buffer_context(self))
|
||||
self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output))
|
||||
|
||||
return self
|
||||
|
||||
@ -440,64 +354,53 @@ class LocalBufferContext:
|
||||
self.local_buffers.clear()
|
||||
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def add_local_buffer(
|
||||
self, local_buffer: ir.Buffer, global_buffer: Optional[ir.Buffer] = None
|
||||
):
|
||||
assert local_buffer.get_name() not in self.local_buffers
|
||||
self.local_buffers[local_buffer.get_name()] = local_buffer
|
||||
if global_buffer:
|
||||
self.local_to_global[local_buffer.get_name()] = global_buffer
|
||||
V.graph.removed_buffers.add(global_buffer.get_name())
|
||||
def add_local_buffer(self, buffer: ir.Buffer):
|
||||
assert buffer.get_name() not in self.local_buffers
|
||||
self.local_buffers[buffer.get_name()] = buffer
|
||||
|
||||
def localize_function(
|
||||
self,
|
||||
fn: Callable[..., Any],
|
||||
rewrite_index: Callable[
|
||||
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
|
||||
] = rewrite_index_for_function,
|
||||
):
|
||||
local_buffers = list(self.local_buffers.values())
|
||||
global_buffers = list(self.local_to_global.values())
|
||||
local_buf = local_buffers[0]
|
||||
global_buf = global_buffers[0]
|
||||
|
||||
def inner(node, *index_vars):
|
||||
with V.set_ops_handler(
|
||||
LocalizeBufferHandler(
|
||||
V.get_ops_handler(),
|
||||
global_buf=global_buf,
|
||||
local_buf=local_buf,
|
||||
rewrite_index=rewrite_index,
|
||||
)
|
||||
):
|
||||
return fn(node, *index_vars)
|
||||
|
||||
return inner
|
||||
|
||||
def localize_nodes(
|
||||
self,
|
||||
nodes: List[ir.IRNode],
|
||||
rewrite_index: Callable[
|
||||
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
|
||||
] = rewrite_index_for_nodes,
|
||||
def localize_buffer(
|
||||
self, global_buf: ir.Buffer, local_buf: ir.Buffer, nodes: List[ir.IRNode]
|
||||
) -> List[ir.IRNode]:
|
||||
"""
|
||||
Given `local_buf` and `global_buf` registered in current `LocalBufferContext`
|
||||
through the method of `add_local_buffer`, localizes the `global_buf` to `local_buf`
|
||||
for the given `nodes` and returns a new list of IR nodes that work on `local_buf`
|
||||
instead of `global_buf`, i.e., all the loads and stores are redirected to
|
||||
`local_buf`. This helps the fused loops to work on smaller-sized local buffers
|
||||
for better data locality.
|
||||
Localizes the buffer `global_buf` to `local_buf` in the given `nodes` and returns
|
||||
a new list of IR nodes that work on `local_buf` instead of `global_buf`, i.e., all
|
||||
the loads and stores are redirected to `local_buf`. This helps the fused loops to
|
||||
work on smaller-sized local buffers for better data locality.
|
||||
|
||||
The data access of `local_buf` is assumed to be contiguous with the
|
||||
same order as the `global_buf`.
|
||||
The `local_buf` should already be registered in the local scope and the data access
|
||||
is assumed to be contiguous with the same order as the `global_buf`.
|
||||
"""
|
||||
local_buffers = list(self.local_buffers.values())
|
||||
global_buffers = list(self.local_to_global.values())
|
||||
assert len(global_buffers[0].get_size()) == len(local_buffers[0].get_size())
|
||||
assert local_buf.get_name() in self.local_buffers
|
||||
assert len(global_buf.get_size()) == len(local_buf.get_size())
|
||||
assert len(nodes) > 0
|
||||
|
||||
def wrap_inner_fn_for_node(node: ir.IRNode):
|
||||
class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
|
||||
def __init__(self, inner):
|
||||
super().__init__(inner)
|
||||
|
||||
def localize(self, name: str, index: sympy.Expr):
|
||||
if name == global_buf.get_name():
|
||||
name = local_buf.get_name()
|
||||
used_vars = {
|
||||
s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)
|
||||
}
|
||||
index_vars = []
|
||||
for i in range(len(local_buf.get_size())):
|
||||
var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
|
||||
index_vars.append(var if var in used_vars else 0)
|
||||
index = local_buf.layout.make_indexer()(index_vars)
|
||||
return name, index
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
return self._inner.load(*self.localize(name, index))
|
||||
|
||||
def store(self, name, index, value, mode=None):
|
||||
return self._inner.store(*self.localize(name, index), value, mode)
|
||||
|
||||
def store_reduction(self, name, index, value):
|
||||
return self._inner.store_reduction(*self.localize(name, index), value)
|
||||
|
||||
def wrap_inner_fn_for_node(node: ir.IRNode, inner_fn_wrapper):
|
||||
loops = node.data if isinstance(node, ir.ComputedBuffer) else node
|
||||
assert isinstance(loops, ir.Loops)
|
||||
new_loops = copy.copy(loops)
|
||||
@ -508,13 +411,17 @@ class LocalBufferContext:
|
||||
else:
|
||||
new_node = new_loops # type: ignore[assignment]
|
||||
|
||||
new_loops.inner_fn = self.localize_function(
|
||||
new_loops.inner_fn,
|
||||
rewrite_index,
|
||||
)
|
||||
new_loops.inner_fn = inner_fn_wrapper(new_loops.inner_fn)
|
||||
return new_node
|
||||
|
||||
return [wrap_inner_fn_for_node(node) for node in nodes]
|
||||
def inner_fn_wrapper(inner_fn):
|
||||
def inner(index):
|
||||
with V.set_ops_handler(LocalizeBufferHandler(V.get_ops_handler())):
|
||||
return inner_fn(index)
|
||||
|
||||
return inner
|
||||
|
||||
return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes]
|
||||
|
||||
|
||||
def unify_mask_base_type(
|
||||
|
||||
Reference in New Issue
Block a user