Revert "[Inductor][CPP] Enable Local Buffer for Outer loop fusion (#126967)"

This reverts commit 98929ceae3873f18f4747b88cdff708fde107aa7.

Reverted https://github.com/pytorch/pytorch/pull/126967 on behalf of https://github.com/leslie-fang-intel due to a broken trunk and the need to rebase ([comment](https://github.com/pytorch/pytorch/pull/126967#issuecomment-2212337926))
PyTorch MergeBot
2024-07-07 06:16:32 +00:00
parent 1b57dce35f
commit e423224546
6 changed files with 106 additions and 415 deletions
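
For context, the reverted change targets patterns like the softmax test below: two loop nests that share the same outer loops are fused, and the intermediate written by the first nest and read by the second shrinks from a whole-tensor temporary to a buffer covering only the inner dimension. A minimal NumPy sketch of the idea (illustrative only, assuming a 2D input; this is not the code Inductor generates):

    import numpy as np

    def softmax_sub_unfused(x):
        # Global temporary: same size as x, written by one loop nest
        # and read back by a second one.
        tmp = np.empty_like(x)
        for i in range(x.shape[0]):
            e = np.exp(x[i] - x[i].max())
            tmp[i] = e / e.sum()
        out = np.empty_like(x)
        for i in range(x.shape[0]):
            out[i] = x[i] - tmp[i]
        return out

    def softmax_sub_fused_local_buf(x):
        # After outer loop fusion, the temporary only needs to hold one
        # row at a time, so a small local buffer is reused per iteration.
        out = np.empty_like(x)
        local = np.empty(x.shape[-1], dtype=x.dtype)
        for i in range(x.shape[0]):
            e = np.exp(x[i] - x[i].max())
            local[:] = e / e.sum()
            out[i] = x[i] - local
        return out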

test/inductor/test_cpu_repro.py

@@ -2556,7 +2556,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 0
@config.patch(fx_graph_cache=False)
def test_outer_loop_fusion(self):
def fn(x):
max = torch.amax(x, dim=-1, keepdim=True)
@@ -2568,47 +2567,8 @@ class CPUReproTests(TestCase):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
self.assertEqual(
len(metrics.cpp_outer_loop_fused_inner_counts),
1,
)
self.assertEqual(
metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number,
2,
)
@config.patch(fx_graph_cache=False)
def test_local_buffer_in_outer_loop_fusion(self):
def fn(x):
max = torch.nn.functional.softmax(x, dim=-1)
return x - max
x = torch.randn(4, 12, 1023, 1022)
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
self.assertEqual(
len(metrics.cpp_outer_loop_fused_inner_counts),
1,
)
self.assertEqual(
metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number,
3,
)
self.assertEqual(
metrics.cpp_outer_loop_fused_inner_counts[0].local_buffer_number,
1,
)
# Check the number of global buffer allocation
torch._dynamo.reset()
metrics.reset()
_, code = run_and_get_cpp_code(
torch._dynamo.optimize("inductor")(fn),
x,
)
self.assertEqual(code.count("empty_strided_cpu("), 3)
assert len(metrics.cpp_outer_loop_fused_inner_counts) == 1
assert metrics.cpp_outer_loop_fused_inner_counts[0] == 2
def test_argmin(self):
def fn(x):

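The deleted test_local_buffer_in_outer_loop_fusion above pairs the fusion metric with an allocation count on the generated wrapper code. A standalone sketch of the same check (assuming a CPU build where Inductor's C++ backend kicks in; `torch.compile` stands in for the test harness's optimize call):

    import torch
    from torch._inductor import metrics
    from torch._inductor.utils import run_and_get_cpp_code

    def fn(x):
        return x - torch.nn.functional.softmax(x, dim=-1)

    x = torch.randn(4, 12, 1023, 1022)
    metrics.reset()
    _, code = run_and_get_cpp_code(torch.compile(fn, backend="inductor"), x)
    # On a build with the local buffer optimization, one softmax temporary
    # stays function-local, so fewer global empty_strided_cpu( allocations
    # show up in the generated wrapper.
    print(code.count("empty_strided_cpu("))
    print(metrics.cpp_outer_loop_fused_inner_counts)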
torch/_inductor/codegen/cpp.py

@@ -7,7 +7,6 @@ import logging
import math
import re
import sys
from collections import namedtuple
from copy import copy, deepcopy
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Sequence, Set, Tuple, Union
@@ -70,7 +69,6 @@ from .cpp_utils import (
cexpr_index,
DTYPE_TO_CPP,
INDEX_TYPE,
LocalBufferContext,
unify_mask_base_type,
value_to_cpp,
)
@@ -437,6 +435,8 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
loop_nest_list: List[LoopNestWithSplit] = [
kernel.loop_nest for kernel in cpp_kernel_proxy_list
]
metrics.cpp_outer_loop_fused_inner_counts.append(len(loop_nest_list))
kernel_group = cpp_kernel_proxy_list[0].kernel_group
def _merge_outer_fusion_loop_levels(
@@ -1915,10 +1915,7 @@ class CppKernel(Kernel):
threads = parallel_num_threads()
assert self.call_ranges is not None
kernels = loop_nest.get_kernels()
has_outer_loop_kernel = any(
isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels
)
if has_outer_loop_kernel:
if any(isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels):
assert len(kernels) == 1
assert isinstance(kernels[0], OuterLoopFusedKernel)
par_depth = kernels[0].decide_parallel_depth(
@@ -2048,31 +2045,6 @@ class CppKernel(Kernel):
stack.enter_context(code.indent())
if loop_nest.root:
if (
has_outer_loop_kernel
and isinstance(V.local_buffer_context, LocalBufferContext)
and V.local_buffer_context.local_buffers
):
# Allocate local buffer
local_buffers = V.local_buffer_context.local_buffers
assert len(local_buffers.items()) == 1
local_buffer = next(iter(local_buffers.items()))[1]
# For dynamic size, rename s to ks
local_buf_size = sympy_product(
[
self.rename_indexing(size_val)
for size_val in local_buffer.get_layout().size
]
)
local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype]
allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})"
code.splice(
f"std::unique_ptr<{local_buf_dtype} []> local_buffer = {allocate};"
)
local_buffer_name = local_buffer.get_name()
code.splice(
f"{local_buf_dtype}* {local_buffer_name} = local_buffer.get();"
)
gen_loops(loop_nest.root)
else:
gen_kernel(loop_nest.kernel)
@@ -3528,18 +3500,6 @@ class CppKernelProxy(CppKernel):
return node.codegen(index_vars)
fn_list = [functools.partial(fn, node) for node in nodes]
if (
isinstance(V.local_buffer_context, LocalBufferContext)
and V.local_buffer_context.local_buffers
):
fn_list = [
V.local_buffer_context.localize_function(
fn,
)
for fn in fn_list
]
var_sizes_list = [node.group[1] for node in nodes]
self.codegen_functions(fn_list, var_sizes_list, vec_dtype)
@@ -3847,159 +3807,6 @@ class CppScheduling(BaseScheduling):
self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
) or self.can_fuse_vertical_outer_loop(node1, node2)
def codegen_outer_loop_node(
self,
node: OuterLoopFusedSchedulerNode,
):
"""
Generate the code for the outer loop fused scheduler node.
1. Codegen with fused outer loop: depends on the analysis of
the outer loop fused scheduler node, with or without the local buffer.
2. If that fails, fall back to standard codegen.
"""
kernel_group = self.kernel_group
generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count
cpp_kernel_proxy_list: List[CppKernelProxy] = []
nodes_list: List[List[SchedulerNode]] = []
assert isinstance(node, OuterLoopFusedSchedulerNode)
def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode):
"""
Codegen code with fused outer loop and local Buffer.
"""
assert isinstance(node, OuterLoopFusedSchedulerNode)
cpp_kernel_proxy_list.clear()
nodes_list.clear()
def get_call_ranges(node: BaseSchedulerNode):
assert isinstance(node, (SchedulerNode, FusedSchedulerNode))
nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment]
_, (group, reduction_group) = max(
nodes, key=lambda x: int(x.is_reduction())
).group
call_ranges = tuple(group) + tuple(reduction_group)
return call_ranges
LocalBuffer = namedtuple("LocalBuffer", ["local_buf", "global_buf"])
local_buffers: List[LocalBuffer] = []
if all(
len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1
for _node in node.get_outer_nodes()
):
# Ref to the typical case of local buffer
# in https://github.com/pytorch/pytorch/blob/
# 1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159
# where the buffer has the size of the last dim and is contiguous.
# Only support this typical case at first.
for scheduler_node in node.get_nodes():
# all users inside same OuterLoopFusedSchedulerNode
if not scheduler_node.is_reduction() and all(
user.node in node.get_nodes() for user in scheduler_node.users
):
global_buffer = scheduler_node.node
assert isinstance(global_buffer, ir.ComputedBuffer)
global_buffer_layout = global_buffer.get_layout()
size_offset = node.outer_loop_fusion_depth - len(
get_call_ranges(scheduler_node)
)
def is_all_write_read_contiguous(scheduler_node):
contiguous_index_expr = 0
stride = 1
for var, range in reversed(
scheduler_node._body.var_ranges.items()
):
contiguous_index_expr += stride * var
stride *= range
write_index_expr = scheduler_node._body.writes_name2expr[
scheduler_node.get_name()
]
def is_contiguous_index(x):
return x == contiguous_index_expr
return is_contiguous_index(write_index_expr) and all(
is_contiguous_index(
user.node._body.reads_name2expr[
scheduler_node.get_name()
],
)
for user in scheduler_node.users
)
if not (
global_buffer_layout.is_contiguous()
and not scheduler_node.is_reduction()
and is_all_write_read_contiguous(scheduler_node)
):
continue
# Local Buffer is a view of global buffer
local_buffer_layout = ir.FixedLayout(
global_buffer_layout.device,
global_buffer_layout.dtype,
global_buffer_layout.size[size_offset:],
global_buffer_layout.stride[size_offset:],
)
local_buffers.append(
LocalBuffer(
local_buf=ir.Buffer(
"local_buffer_data", local_buffer_layout
),
global_buf=global_buffer,
)
)
# At most 1 node with local buf for each OuterLoopFusedSchedulerNode
break
assert len(local_buffers) in [0, 1]
with LocalBufferContext(kernel_group.args) as scope:
if len(local_buffers) > 0:
scope.add_local_buffer(
local_buffers[0].local_buf, local_buffers[0].global_buf
)
for _node in node.get_outer_nodes():
assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
cpp_kernel_proxy = CppKernelProxy(kernel_group)
cpp_kernel_proxy.codegen_nodes(_node.get_nodes()) # type: ignore[arg-type]
cpp_kernel_proxy_list.append(cpp_kernel_proxy)
nodes_list.append(_node.get_nodes()) # type: ignore[arg-type]
if not node.check_outer_fusion_loop_level_attr(
cpp_kernel_proxy_list, node.outer_loop_fusion_depth
):
return False
metrics.cpp_outer_loop_fused_inner_counts.append(
metrics.CppOuterLoopFusedCount(
len(cpp_kernel_proxy_list),
local_buffer_number=len(local_buffers),
)
)
outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
cpp_kernel_proxy_list,
)
kernel_group.finalize_kernel(
outer_fusion_cpp_kernel_proxy,
[_node for _nodes in nodes_list for _node in _nodes],
)
return True
if not try_outer_loop_fusion_with_local_buf(node):
# Reset generated_cpp_vec_kernel_count to codegen again
metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count
cpp_kernel_proxy_list.clear()
nodes_list.clear()
# Similar to the comment in
# https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272
# Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args.
with torch._inductor.config.patch(inplace_buffers=False):
for _node in node.get_outer_nodes():
assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
_nodes: List[SchedulerNode] = _node.get_nodes() # type: ignore[assignment]
cpp_kernel_proxy = CppKernelProxy(kernel_group)
cpp_kernel_proxy.codegen_nodes(_nodes)
kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes)
def codegen_node(
self,
node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode],
@@ -4010,7 +3817,38 @@ class CppScheduling(BaseScheduling):
kernel_group = self.kernel_group
if isinstance(node, OuterLoopFusedSchedulerNode):
self.codegen_outer_loop_node(node)
cpp_kernel_proxy_list: List[CppKernelProxy] = []
nodes_list: List[List[SchedulerNode]] = []
for _node in node.get_outer_nodes():
assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
_nodes: List[SchedulerNode] = _node.get_nodes() # type: ignore[assignment]
cpp_kernel_proxy = CppKernelProxy(kernel_group)
cpp_kernel_proxy.codegen_nodes(_nodes)
cpp_kernel_proxy_list.append(cpp_kernel_proxy)
nodes_list.append(_nodes)
# Note that, in the future, when every kernel can be vectorized,
# the function select_tiling will be much easier, and we'll be able to lift
# check_outer_fusion_loop_level_attr to the fusion phase,
# avoiding grouping kernels at fusion time that "look like we'll be able to fuse them"
# but then we actually won't.
if node.check_outer_fusion_loop_level_attr(
cpp_kernel_proxy_list, node.outer_loop_fusion_depth
):
# Merge the cpp_kernel_proxy_list into cpp_kernel_proxy
outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
cpp_kernel_proxy_list,
)
kernel_group.finalize_kernel(
outer_fusion_cpp_kernel_proxy,
[_node for _nodes in nodes_list for _node in _nodes],
)
else:
# Fall back to standard loop codegen
for _kernel_proxy, _nodes in zip(cpp_kernel_proxy_list, nodes_list):
kernel_group.finalize_kernel(_kernel_proxy, _nodes)
else:
nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment]
cpp_kernel_proxy = CppKernelProxy(kernel_group)

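The deleted branch above spliced a heap-backed scratch buffer into the kernel preamble before the fused loops were emitted. A minimal sketch of that splice, hard-coding the dtype, size, and name that the real code derived from the local buffer's layout (those three values are assumptions):

    from torch._inductor.utils import IndentedBuffer

    # Assumed stand-ins for DTYPE_TO_CPP[layout.dtype],
    # cexpr(sympy_product(layout.size)), and local_buffer.get_name().
    local_buf_dtype = "float"
    local_buf_size = "1022L"
    local_buffer_name = "local_buffer_data"

    code = IndentedBuffer()
    allocate = f"std::make_unique<{local_buf_dtype} []>({local_buf_size})"
    code.splice(f"std::unique_ptr<{local_buf_dtype} []> local_buffer = {allocate};")
    code.splice(f"{local_buf_dtype}* {local_buffer_name} = local_buffer.get();")
    print(code.getvalue())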
torch/_inductor/codegen/cpp_template_kernel.py

@@ -14,7 +14,7 @@ from ..select_algorithm import PartialRender
from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
from ..virtualized import V
from .cpp import CppKernel, CppKernelProxy, KernelGroup
from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext
from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferScope
def parse_expr_with_index_symbols(expr):
@@ -270,11 +270,13 @@ class CppTemplateKernel(CppKernel):
if offsets:
offsets = parse_expr_with_index_symbols(offsets)
if epilogue_nodes:
with LocalBufferContext(self.args) as scope:
with LocalBufferScope(self) as scope:
assert orig_src is not None
if orig_src.get_name() != src.get_name():
scope.add_local_buffer(src, orig_src)
epilogue_nodes = scope.localize_nodes(epilogue_nodes)
scope.add_local_buffer(src)
epilogue_nodes = scope.localize_buffer(
orig_src, src, epilogue_nodes
)
return self.store_pointwise_nodes(
dst, epilogue_nodes, offsets, reindexers # type: ignore[arg-type]
)
@@ -282,7 +284,7 @@ class CppTemplateKernel(CppKernel):
if dst.get_name() != src.get_name():
# src is local
copy = L.copy(dst, src).data.data
with LocalBufferContext(self.args) as scope:
with LocalBufferScope(self) as scope:
scope.add_local_buffer(src)
return self.store_pointwise_nodes(dst, [copy])
else:

torch/_inductor/codegen/cpp_utils.py

@@ -4,7 +4,7 @@ import copy
import math
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
from unittest.mock import patch
import sympy
@@ -12,10 +12,11 @@ import sympy
import torch
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import ir
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix
from ..virtualized import V
from .common import CSEVariable, ExprPrinter, Kernel, KernelArgs
from .common import CSEVariable, ExprPrinter, Kernel
DTYPE_TO_CPP = {
torch.float32: "float",
@@ -303,88 +304,7 @@ def value_to_cpp(value, cpp_type):
return f"static_cast<{cpp_type}>({repr(value)})"
def rewrite_index_for_function(
localize_buffer_handler: "LocalizeBufferHandler",
index: sympy.Expr,
):
# Local buffer at the inner dimensions
snode = V.graph.scheduler.name_to_node.get(
localize_buffer_handler.global_buf.get_name()
)
assert snode is not None
scheduler_nodes = snode.get_nodes()
_, (group, reduction_group) = max(
scheduler_nodes, key=lambda x: int(x.is_reduction())
).group
call_ranges = tuple(group) + tuple(reduction_group)
indices_to_keep = [
f"x{len(call_ranges) - (idx + 1)}"
for idx in range(len(localize_buffer_handler.local_buf.get_layout().size))
]
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined]
replacements = {}
for x in sorted_symbols:
if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined]
# Only keep index used by local buffer
replacements[x] = sympy.core.numbers.Zero()
index = sympy_subs(index, replacements) # type: ignore[arg-type]
return index
def rewrite_index_for_nodes(
localize_buffer_handler: "LocalizeBufferHandler",
index: sympy.Expr,
):
used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)}
index_vars = []
for i in range(len(localize_buffer_handler.local_buf.get_size())):
var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
index_vars.append(var if var in used_vars else 0)
index = localize_buffer_handler.local_buf.layout.make_indexer()(index_vars)
return index
class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
def __init__(
self,
inner,
global_buf: ir.Buffer,
local_buf: ir.Buffer,
rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr], sympy.Expr],
):
super().__init__(inner)
self.global_buf = global_buf
self.local_buf = local_buf
self.rewrite_index = rewrite_index
def localize(self, name: str, index: sympy.Expr):
if self.global_buf and name == self.global_buf.get_name():
assert self.rewrite_index is not None
name = self.local_buf.get_name()
index = self.rewrite_index(self, index)
return name, index
def load(self, name: str, index: sympy.Expr):
return self._inner.load(*self.localize(name, index))
def store(self, name, index, value, mode=None):
local_buffer_name, local_buffer_index = self.localize(name, index)
res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
if (
self.global_buf
and name == self.global_buf.get_name()
and isinstance(V.kernel, Kernel)
):
# Remove name of local buffer from Kernel.store_buffer_names
# local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store.
V.kernel.store_buffer_names.discard(local_buffer_name)
return res
def store_reduction(self, name, index, value):
return self._inner.store_reduction(*self.localize(name, index), value)
class LocalBufferContext:
class LocalBufferScope:
"""
This class creates a context that helps to generate code involving Inductor IR with
function local buffers. These buffers are constructed during the codegen process and
@@ -394,13 +314,10 @@ class LocalBufferContext:
these buffers without exposure to the outside world.
"""
def __init__(self, kernel_args: KernelArgs):
self.kernel_args = kernel_args
def __init__(self, kernel: Kernel):
self.kernel = kernel
self.exit_stack = contextlib.ExitStack()
# Map Local Buffer name to Local Buffer
self.local_buffers: Dict[str, ir.Buffer] = {}
# Map Local Buffer name to Global Buffer
self.local_to_global: Dict[str, ir.Buffer] = {}
def __enter__(self):
self.exit_stack.__enter__()
@@ -413,26 +330,23 @@ class LocalBufferContext:
self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype))
original_input = self.kernel_args.input
original_input = self.kernel.args.input
def input(name):
if name in self.local_buffers:
return name
return original_input(name)
self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input))
self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input))
original_output = self.kernel_args.output
original_output = self.kernel.args.output
def output(name):
if name in self.local_buffers:
return name
return original_output(name)
self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output))
# Set current LocalBufferContext into V
self.exit_stack.enter_context(V.set_local_buffer_context(self))
self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output))
return self
@@ -440,64 +354,53 @@ class LocalBufferContext:
self.local_buffers.clear()
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
def add_local_buffer(
self, local_buffer: ir.Buffer, global_buffer: Optional[ir.Buffer] = None
):
assert local_buffer.get_name() not in self.local_buffers
self.local_buffers[local_buffer.get_name()] = local_buffer
if global_buffer:
self.local_to_global[local_buffer.get_name()] = global_buffer
V.graph.removed_buffers.add(global_buffer.get_name())
def add_local_buffer(self, buffer: ir.Buffer):
assert buffer.get_name() not in self.local_buffers
self.local_buffers[buffer.get_name()] = buffer
def localize_function(
self,
fn: Callable[..., Any],
rewrite_index: Callable[
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
] = rewrite_index_for_function,
):
local_buffers = list(self.local_buffers.values())
global_buffers = list(self.local_to_global.values())
local_buf = local_buffers[0]
global_buf = global_buffers[0]
def inner(node, *index_vars):
with V.set_ops_handler(
LocalizeBufferHandler(
V.get_ops_handler(),
global_buf=global_buf,
local_buf=local_buf,
rewrite_index=rewrite_index,
)
):
return fn(node, *index_vars)
return inner
def localize_nodes(
self,
nodes: List[ir.IRNode],
rewrite_index: Callable[
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
] = rewrite_index_for_nodes,
def localize_buffer(
self, global_buf: ir.Buffer, local_buf: ir.Buffer, nodes: List[ir.IRNode]
) -> List[ir.IRNode]:
"""
Given `local_buf` and `global_buf` registered in current `LocalBufferContext`
through the method of `add_local_buffer`, localizes the `global_buf` to `local_buf`
for the given `nodes` and returns a new list of IR nodes that work on `local_buf`
instead of `global_buf`, i.e., all the loads and stores are redirected to
`local_buf`. This helps the fused loops to work on smaller-sized local buffers
for better data locality.
Localizes the buffer `global_buf` to `local_buf` in the given `nodes` and returns
a new list of IR nodes that work on `local_buf` instead of `global_buf`, i.e., all
the loads and stores are redirected to `local_buf`. This helps the fused loops to
work on smaller-sized local buffers for better data locality.
The data access of `local_buf` is assumed to be contiguous, in the
same order as the `global_buf`.
The `local_buf` should already be registered in the local scope and the data access
is assumed to be contiguous with the same order as the `global_buf`.
"""
local_buffers = list(self.local_buffers.values())
global_buffers = list(self.local_to_global.values())
assert len(global_buffers[0].get_size()) == len(local_buffers[0].get_size())
assert local_buf.get_name() in self.local_buffers
assert len(global_buf.get_size()) == len(local_buf.get_size())
assert len(nodes) > 0
def wrap_inner_fn_for_node(node: ir.IRNode):
class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
def __init__(self, inner):
super().__init__(inner)
def localize(self, name: str, index: sympy.Expr):
if name == global_buf.get_name():
name = local_buf.get_name()
used_vars = {
s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)
}
index_vars = []
for i in range(len(local_buf.get_size())):
var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
index_vars.append(var if var in used_vars else 0)
index = local_buf.layout.make_indexer()(index_vars)
return name, index
def load(self, name: str, index: sympy.Expr):
return self._inner.load(*self.localize(name, index))
def store(self, name, index, value, mode=None):
return self._inner.store(*self.localize(name, index), value, mode)
def store_reduction(self, name, index, value):
return self._inner.store_reduction(*self.localize(name, index), value)
def wrap_inner_fn_for_node(node: ir.IRNode, inner_fn_wrapper):
loops = node.data if isinstance(node, ir.ComputedBuffer) else node
assert isinstance(loops, ir.Loops)
new_loops = copy.copy(loops)
@@ -508,13 +411,17 @@ class LocalBufferContext:
else:
new_node = new_loops # type: ignore[assignment]
new_loops.inner_fn = self.localize_function(
new_loops.inner_fn,
rewrite_index,
)
new_loops.inner_fn = inner_fn_wrapper(new_loops.inner_fn)
return new_node
return [wrap_inner_fn_for_node(node) for node in nodes]
def inner_fn_wrapper(inner_fn):
def inner(index):
with V.set_ops_handler(LocalizeBufferHandler(V.get_ops_handler())):
return inner_fn(index)
return inner
return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes]
def unify_mask_base_type(

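Read together with the cpp_template_kernel.py hunks above, the restored LocalBufferScope is driven roughly like this; a sketch assuming `kernel` is a live CppKernel and the buffers come from a template epilogue (the wrapper function and its name are hypothetical):

    from typing import List
    from torch._inductor import ir
    from torch._inductor.codegen.cpp import CppKernel
    from torch._inductor.codegen.cpp_utils import LocalBufferScope

    def localize_epilogue(kernel: CppKernel, orig_src: ir.Buffer, src: ir.Buffer,
                          epilogue_nodes: List[ir.IRNode]) -> List[ir.IRNode]:
        # Register the template-local accumulator, then redirect the epilogue's
        # loads/stores from the global buffer onto it (mirrors the hunk above).
        with LocalBufferScope(kernel) as scope:
            if orig_src.get_name() != src.get_name():
                scope.add_local_buffer(src)
                epilogue_nodes = scope.localize_buffer(orig_src, src, epilogue_nodes)
            return epilogue_nodes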
torch/_inductor/metrics.py

@@ -41,15 +41,9 @@ ir_nodes_pre_fusion = 0
# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0
@dataclasses.dataclass
class CppOuterLoopFusedCount:
inner_kernel_number: int
local_buffer_number: int = 0
# The length counts the number of outer loop fusions.
cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = []
# Each element counts the number of inner kernels in each outer loop fusion.
cpp_outer_loop_fused_inner_counts: List[int] = []
num_comprehensive_padding = 0
num_matches_for_scatter_upon_const_tensor = 0

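Concretely, the revert narrows the metric's shape (illustrative values matching the tests above):

    # Pre-revert: one dataclass entry per outer loop fusion, which also
    # records how many local buffers the fusion used:
    #   cpp_outer_loop_fused_inner_counts == [
    #       CppOuterLoopFusedCount(inner_kernel_number=3, local_buffer_number=1)
    #   ]
    # Post-revert: one plain int per outer loop fusion, counting only the
    # inner kernels it contains:
    #   cpp_outer_loop_fused_inner_counts == [2]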
torch/_inductor/virtualized.py

@@ -72,7 +72,6 @@ from .ops_handler import ( # noqa: F401
if TYPE_CHECKING:
import torch
from torch._inductor.codegen.cpp_utils import LocalBufferContext
from torch._inductor.debug import DebugContext
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import InterpreterShim
@@ -163,9 +162,6 @@ _debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler)
_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler)
_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler)
_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler)
_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
"local_buffer_context", NullHandler
)
class OpsValue:
@@ -310,8 +306,6 @@ class _V:
get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
set_current_node: Callable[[Any], Any] = _current_node._set_handler
get_current_node: Callable[[], Any] = _current_node._get_handler
set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler
@property
def ops(self) -> OpsHandler[Any]:
@@ -354,9 +348,5 @@ class _V:
def current_node(self):
return _current_node._get_handler()
@property
def local_buffer_context(self):
return _local_buffer_context._get_handler()
V = _V()