mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	Revert "[Inductor][CPP] Enable Local Buffer for Outer loop fusion (#126967)"
This reverts commit 98929ceae3873f18f4747b88cdff708fde107aa7. Reverted https://github.com/pytorch/pytorch/pull/126967 on behalf of https://github.com/leslie-fang-intel due to Broken trunk and need rebase ([comment](https://github.com/pytorch/pytorch/pull/126967#issuecomment-2212337926))
This commit is contained in:
		| @ -2556,7 +2556,6 @@ class CPUReproTests(TestCase): | ||||
|                 self.common(fn, (x,)) | ||||
|                 assert metrics.generated_cpp_vec_kernel_count == 0 | ||||
|  | ||||
|     @config.patch(fx_graph_cache=False) | ||||
|     def test_outer_loop_fusion(self): | ||||
|         def fn(x): | ||||
|             max = torch.amax(x, dim=-1, keepdim=True) | ||||
| @ -2568,47 +2567,8 @@ class CPUReproTests(TestCase): | ||||
|             torch._dynamo.reset() | ||||
|             metrics.reset() | ||||
|             self.common(fn, (x,)) | ||||
|             self.assertEqual( | ||||
|                 len(metrics.cpp_outer_loop_fused_inner_counts), | ||||
|                 1, | ||||
|             ) | ||||
|             self.assertEqual( | ||||
|                 metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number, | ||||
|                 2, | ||||
|             ) | ||||
|  | ||||
|     @config.patch(fx_graph_cache=False) | ||||
|     def test_local_buffer_in_outer_loop_fusion(self): | ||||
|         def fn(x): | ||||
|             max = torch.nn.functional.softmax(x, dim=-1) | ||||
|             return x - max | ||||
|  | ||||
|         x = torch.randn(4, 12, 1023, 1022) | ||||
|  | ||||
|         with config.patch({"cpp.simdlen": None}): | ||||
|             torch._dynamo.reset() | ||||
|             metrics.reset() | ||||
|             self.common(fn, (x,)) | ||||
|             self.assertEqual( | ||||
|                 len(metrics.cpp_outer_loop_fused_inner_counts), | ||||
|                 1, | ||||
|             ) | ||||
|             self.assertEqual( | ||||
|                 metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number, | ||||
|                 3, | ||||
|             ) | ||||
|             self.assertEqual( | ||||
|                 metrics.cpp_outer_loop_fused_inner_counts[0].local_buffer_number, | ||||
|                 1, | ||||
|             ) | ||||
|             # Check the number of global buffer allocation | ||||
|             torch._dynamo.reset() | ||||
|             metrics.reset() | ||||
|             _, code = run_and_get_cpp_code( | ||||
|                 torch._dynamo.optimize("inductor")(fn), | ||||
|                 x, | ||||
|             ) | ||||
|             self.assertEqual(code.count("empty_strided_cpu("), 3) | ||||
|             assert len(metrics.cpp_outer_loop_fused_inner_counts) == 1 | ||||
|             assert metrics.cpp_outer_loop_fused_inner_counts[0] == 2 | ||||
|  | ||||
|     def test_argmin(self): | ||||
|         def fn(x): | ||||
|  | ||||
| @ -7,7 +7,6 @@ import logging | ||||
| import math | ||||
| import re | ||||
| import sys | ||||
| from collections import namedtuple | ||||
| from copy import copy, deepcopy | ||||
| from enum import Enum | ||||
| from typing import Any, cast, Dict, List, Optional, Sequence, Set, Tuple, Union | ||||
| @ -70,7 +69,6 @@ from .cpp_utils import ( | ||||
|     cexpr_index, | ||||
|     DTYPE_TO_CPP, | ||||
|     INDEX_TYPE, | ||||
|     LocalBufferContext, | ||||
|     unify_mask_base_type, | ||||
|     value_to_cpp, | ||||
| ) | ||||
| @ -437,6 +435,8 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode): | ||||
|         loop_nest_list: List[LoopNestWithSplit] = [ | ||||
|             kernel.loop_nest for kernel in cpp_kernel_proxy_list | ||||
|         ] | ||||
|         metrics.cpp_outer_loop_fused_inner_counts.append(len(loop_nest_list)) | ||||
|  | ||||
|         kernel_group = cpp_kernel_proxy_list[0].kernel_group | ||||
|  | ||||
|         def _merge_outer_fusion_loop_levels( | ||||
| @ -1915,10 +1915,7 @@ class CppKernel(Kernel): | ||||
|         threads = parallel_num_threads() | ||||
|         assert self.call_ranges is not None | ||||
|         kernels = loop_nest.get_kernels() | ||||
|         has_outer_loop_kernel = any( | ||||
|             isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels | ||||
|         ) | ||||
|         if has_outer_loop_kernel: | ||||
|         if any(isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels): | ||||
|             assert len(kernels) == 1 | ||||
|             assert isinstance(kernels[0], OuterLoopFusedKernel) | ||||
|             par_depth = kernels[0].decide_parallel_depth( | ||||
| @ -2048,31 +2045,6 @@ class CppKernel(Kernel): | ||||
|  | ||||
|             stack.enter_context(code.indent()) | ||||
|             if loop_nest.root: | ||||
|                 if ( | ||||
|                     has_outer_loop_kernel | ||||
|                     and isinstance(V.local_buffer_context, LocalBufferContext) | ||||
|                     and V.local_buffer_context.local_buffers | ||||
|                 ): | ||||
|                     # Allocate local buffer | ||||
|                     local_buffers = V.local_buffer_context.local_buffers | ||||
|                     assert len(local_buffers.items()) == 1 | ||||
|                     local_buffer = next(iter(local_buffers.items()))[1] | ||||
|                     # For dynamic size, rename s to ks | ||||
|                     local_buf_size = sympy_product( | ||||
|                         [ | ||||
|                             self.rename_indexing(size_val) | ||||
|                             for size_val in local_buffer.get_layout().size | ||||
|                         ] | ||||
|                     ) | ||||
|                     local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype] | ||||
|                     allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})" | ||||
|                     code.splice( | ||||
|                         f"std::unique_ptr<{local_buf_dtype} []> local_buffer = {allocate};" | ||||
|                     ) | ||||
|                     local_buffer_name = local_buffer.get_name() | ||||
|                     code.splice( | ||||
|                         f"{local_buf_dtype}* {local_buffer_name} = local_buffer.get();" | ||||
|                     ) | ||||
|                 gen_loops(loop_nest.root) | ||||
|             else: | ||||
|                 gen_kernel(loop_nest.kernel) | ||||
| @ -3528,18 +3500,6 @@ class CppKernelProxy(CppKernel): | ||||
|                 return node.codegen(index_vars) | ||||
|  | ||||
|         fn_list = [functools.partial(fn, node) for node in nodes] | ||||
|  | ||||
|         if ( | ||||
|             isinstance(V.local_buffer_context, LocalBufferContext) | ||||
|             and V.local_buffer_context.local_buffers | ||||
|         ): | ||||
|             fn_list = [ | ||||
|                 V.local_buffer_context.localize_function( | ||||
|                     fn, | ||||
|                 ) | ||||
|                 for fn in fn_list | ||||
|             ] | ||||
|  | ||||
|         var_sizes_list = [node.group[1] for node in nodes] | ||||
|         self.codegen_functions(fn_list, var_sizes_list, vec_dtype) | ||||
|  | ||||
| @ -3847,159 +3807,6 @@ class CppScheduling(BaseScheduling): | ||||
|             self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() | ||||
|         ) or self.can_fuse_vertical_outer_loop(node1, node2) | ||||
|  | ||||
|     def codegen_outer_loop_node( | ||||
|         self, | ||||
|         node: OuterLoopFusedSchedulerNode, | ||||
|     ): | ||||
|         """ | ||||
|         Generate the code for the outer loop fused scheduler node. | ||||
|         1. Codegen with fused outer loop: depends on the analysis of | ||||
|             the outer loop fused scheduler node, with or without the local buffer. | ||||
|         2. If failed, fallback to standard codegen. | ||||
|         """ | ||||
|         kernel_group = self.kernel_group | ||||
|         generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count | ||||
|         cpp_kernel_proxy_list: List[CppKernelProxy] = [] | ||||
|         nodes_list: List[List[SchedulerNode]] = [] | ||||
|         assert isinstance(node, OuterLoopFusedSchedulerNode) | ||||
|  | ||||
|         def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode): | ||||
|             """ | ||||
|             Codegen code with fused outer loop and local Buffer. | ||||
|             """ | ||||
|             assert isinstance(node, OuterLoopFusedSchedulerNode) | ||||
|             cpp_kernel_proxy_list.clear() | ||||
|             nodes_list.clear() | ||||
|  | ||||
|             def get_call_ranges(node: BaseSchedulerNode): | ||||
|                 assert isinstance(node, (SchedulerNode, FusedSchedulerNode)) | ||||
|                 nodes: List[SchedulerNode] = node.get_nodes()  # type: ignore[assignment] | ||||
|                 _, (group, reduction_group) = max( | ||||
|                     nodes, key=lambda x: int(x.is_reduction()) | ||||
|                 ).group | ||||
|                 call_ranges = tuple(group) + tuple(reduction_group) | ||||
|                 return call_ranges | ||||
|  | ||||
|             LocalBuffer = namedtuple("LocalBuffer", ["local_buf", "global_buf"]) | ||||
|             local_buffers: List[LocalBuffer] = [] | ||||
|             if all( | ||||
|                 len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1 | ||||
|                 for _node in node.get_outer_nodes() | ||||
|             ): | ||||
|                 # Ref to the typical case of local buffer | ||||
|                 # in https://github.com/pytorch/pytorch/blob/ | ||||
|                 # 1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 | ||||
|                 # where the buffer is with size of last dim and contiguous. | ||||
|                 # Only support this typical case at first. | ||||
|                 for scheduler_node in node.get_nodes(): | ||||
|                     # all users inside same OuterLoopFusedSchedulerNode | ||||
|                     if not scheduler_node.is_reduction() and all( | ||||
|                         user.node in node.get_nodes() for user in scheduler_node.users | ||||
|                     ): | ||||
|                         global_buffer = scheduler_node.node | ||||
|                         assert isinstance(global_buffer, ir.ComputedBuffer) | ||||
|                         global_buffer_layout = global_buffer.get_layout() | ||||
|                         size_offset = node.outer_loop_fusion_depth - len( | ||||
|                             get_call_ranges(scheduler_node) | ||||
|                         ) | ||||
|  | ||||
|                         def is_all_write_read_contiguous(scheduler_node): | ||||
|                             contiguous_index_expr = 0 | ||||
|                             stride = 1 | ||||
|                             for var, range in reversed( | ||||
|                                 scheduler_node._body.var_ranges.items() | ||||
|                             ): | ||||
|                                 contiguous_index_expr += stride * var | ||||
|                                 stride *= range | ||||
|                             write_index_expr = scheduler_node._body.writes_name2expr[ | ||||
|                                 scheduler_node.get_name() | ||||
|                             ] | ||||
|  | ||||
|                             def is_contiguous_index(x): | ||||
|                                 return x == contiguous_index_expr | ||||
|  | ||||
|                             return is_contiguous_index(write_index_expr) and all( | ||||
|                                 is_contiguous_index( | ||||
|                                     user.node._body.reads_name2expr[ | ||||
|                                         scheduler_node.get_name() | ||||
|                                     ], | ||||
|                                 ) | ||||
|                                 for user in scheduler_node.users | ||||
|                             ) | ||||
|  | ||||
|                         if not ( | ||||
|                             global_buffer_layout.is_contiguous() | ||||
|                             and not scheduler_node.is_reduction() | ||||
|                             and is_all_write_read_contiguous(scheduler_node) | ||||
|                         ): | ||||
|                             continue | ||||
|                         # Local Buffer is a view of global buffer | ||||
|                         local_buffer_layout = ir.FixedLayout( | ||||
|                             global_buffer_layout.device, | ||||
|                             global_buffer_layout.dtype, | ||||
|                             global_buffer_layout.size[size_offset:], | ||||
|                             global_buffer_layout.stride[size_offset:], | ||||
|                         ) | ||||
|                         local_buffers.append( | ||||
|                             LocalBuffer( | ||||
|                                 local_buf=ir.Buffer( | ||||
|                                     "local_buffer_data", local_buffer_layout | ||||
|                                 ), | ||||
|                                 global_buf=global_buffer, | ||||
|                             ) | ||||
|                         ) | ||||
|                         # At most 1 node with local buf for each OuterLoopFusedSchedulerNode | ||||
|                         break | ||||
|             assert len(local_buffers) in [0, 1] | ||||
|  | ||||
|             with LocalBufferContext(kernel_group.args) as scope: | ||||
|                 if len(local_buffers) > 0: | ||||
|                     scope.add_local_buffer( | ||||
|                         local_buffers[0].local_buf, local_buffers[0].global_buf | ||||
|                     ) | ||||
|                 for _node in node.get_outer_nodes(): | ||||
|                     assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) | ||||
|                     cpp_kernel_proxy = CppKernelProxy(kernel_group) | ||||
|                     cpp_kernel_proxy.codegen_nodes(_node.get_nodes())  # type: ignore[arg-type] | ||||
|                     cpp_kernel_proxy_list.append(cpp_kernel_proxy) | ||||
|                     nodes_list.append(_node.get_nodes())  # type: ignore[arg-type] | ||||
|  | ||||
|                 if not node.check_outer_fusion_loop_level_attr( | ||||
|                     cpp_kernel_proxy_list, node.outer_loop_fusion_depth | ||||
|                 ): | ||||
|                     return False | ||||
|                 metrics.cpp_outer_loop_fused_inner_counts.append( | ||||
|                     metrics.CppOuterLoopFusedCount( | ||||
|                         len(cpp_kernel_proxy_list), | ||||
|                         local_buffer_number=len(local_buffers), | ||||
|                     ) | ||||
|                 ) | ||||
|                 outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels( | ||||
|                     cpp_kernel_proxy_list, | ||||
|                 ) | ||||
|                 kernel_group.finalize_kernel( | ||||
|                     outer_fusion_cpp_kernel_proxy, | ||||
|                     [_node for _nodes in nodes_list for _node in _nodes], | ||||
|                 ) | ||||
|  | ||||
|             return True | ||||
|  | ||||
|         if not try_outer_loop_fusion_with_local_buf(node): | ||||
|             # Reset generated_cpp_vec_kernel_count to codegen again | ||||
|             metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count | ||||
|             cpp_kernel_proxy_list.clear() | ||||
|             nodes_list.clear() | ||||
|             # Similar as comment in | ||||
|             # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272 | ||||
|             # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. | ||||
|             with torch._inductor.config.patch(inplace_buffers=False): | ||||
|                 for _node in node.get_outer_nodes(): | ||||
|                     assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) | ||||
|                     _nodes: List[SchedulerNode] = _node.get_nodes()  # type: ignore[assignment] | ||||
|                     cpp_kernel_proxy = CppKernelProxy(kernel_group) | ||||
|                     cpp_kernel_proxy.codegen_nodes(_nodes) | ||||
|                     kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes) | ||||
|  | ||||
|     def codegen_node( | ||||
|         self, | ||||
|         node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode], | ||||
| @ -4010,7 +3817,38 @@ class CppScheduling(BaseScheduling): | ||||
|         kernel_group = self.kernel_group | ||||
|  | ||||
|         if isinstance(node, OuterLoopFusedSchedulerNode): | ||||
|             self.codegen_outer_loop_node(node) | ||||
|             cpp_kernel_proxy_list: List[CppKernelProxy] = [] | ||||
|             nodes_list: List[List[SchedulerNode]] = [] | ||||
|  | ||||
|             for _node in node.get_outer_nodes(): | ||||
|                 assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) | ||||
|                 _nodes: List[SchedulerNode] = _node.get_nodes()  # type: ignore[assignment] | ||||
|                 cpp_kernel_proxy = CppKernelProxy(kernel_group) | ||||
|                 cpp_kernel_proxy.codegen_nodes(_nodes) | ||||
|  | ||||
|                 cpp_kernel_proxy_list.append(cpp_kernel_proxy) | ||||
|                 nodes_list.append(_nodes) | ||||
|  | ||||
|             # Note that, in the future, when every kernel can be vectorized, | ||||
|             # the function select_tiling will be much easier, and we'll be able to lift | ||||
|             # check_outer_fusion_loop_level_attr to the fusion phase, | ||||
|             # avoiding grouping kernels at fusion time that "look like we'll be able to fuse them" | ||||
|             # but then we actually won't. | ||||
|             if node.check_outer_fusion_loop_level_attr( | ||||
|                 cpp_kernel_proxy_list, node.outer_loop_fusion_depth | ||||
|             ): | ||||
|                 # Merge the cpp_kernel_proxy_list into cpp_kernel_proxy | ||||
|                 outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels( | ||||
|                     cpp_kernel_proxy_list, | ||||
|                 ) | ||||
|                 kernel_group.finalize_kernel( | ||||
|                     outer_fusion_cpp_kernel_proxy, | ||||
|                     [_node for _nodes in nodes_list for _node in _nodes], | ||||
|                 ) | ||||
|             else: | ||||
|                 # Fall back to standard loop codegen | ||||
|                 for _kernel_proxy, _nodes in zip(cpp_kernel_proxy_list, nodes_list): | ||||
|                     kernel_group.finalize_kernel(_kernel_proxy, _nodes) | ||||
|         else: | ||||
|             nodes: List[SchedulerNode] = node.get_nodes()  # type: ignore[assignment] | ||||
|             cpp_kernel_proxy = CppKernelProxy(kernel_group) | ||||
|  | ||||
| @ -14,7 +14,7 @@ from ..select_algorithm import PartialRender | ||||
| from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix | ||||
| from ..virtualized import V | ||||
| from .cpp import CppKernel, CppKernelProxy, KernelGroup | ||||
| from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext | ||||
| from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferScope | ||||
|  | ||||
|  | ||||
| def parse_expr_with_index_symbols(expr): | ||||
| @ -270,11 +270,13 @@ class CppTemplateKernel(CppKernel): | ||||
|         if offsets: | ||||
|             offsets = parse_expr_with_index_symbols(offsets) | ||||
|         if epilogue_nodes: | ||||
|             with LocalBufferContext(self.args) as scope: | ||||
|             with LocalBufferScope(self) as scope: | ||||
|                 assert orig_src is not None | ||||
|                 if orig_src.get_name() != src.get_name(): | ||||
|                     scope.add_local_buffer(src, orig_src) | ||||
|                     epilogue_nodes = scope.localize_nodes(epilogue_nodes) | ||||
|                     scope.add_local_buffer(src) | ||||
|                     epilogue_nodes = scope.localize_buffer( | ||||
|                         orig_src, src, epilogue_nodes | ||||
|                     ) | ||||
|                 return self.store_pointwise_nodes( | ||||
|                     dst, epilogue_nodes, offsets, reindexers  # type: ignore[arg-type] | ||||
|                 ) | ||||
| @ -282,7 +284,7 @@ class CppTemplateKernel(CppKernel): | ||||
|             if dst.get_name() != src.get_name(): | ||||
|                 # src is local | ||||
|                 copy = L.copy(dst, src).data.data | ||||
|                 with LocalBufferContext(self.args) as scope: | ||||
|                 with LocalBufferScope(self) as scope: | ||||
|                     scope.add_local_buffer(src) | ||||
|                     return self.store_pointwise_nodes(dst, [copy]) | ||||
|             else: | ||||
|  | ||||
| @ -4,7 +4,7 @@ import copy | ||||
| import math | ||||
|  | ||||
| from collections import namedtuple | ||||
| from typing import Any, Callable, Dict, List, Optional, Tuple | ||||
| from typing import Dict, List, Tuple | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import sympy | ||||
| @ -12,10 +12,11 @@ import sympy | ||||
| import torch | ||||
| from torch.utils._sympy.symbol import symbol_is_type, SymT | ||||
| from .. import ir | ||||
| from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs | ||||
| from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix | ||||
| from ..virtualized import V | ||||
|  | ||||
| from .common import CSEVariable, ExprPrinter, Kernel, KernelArgs | ||||
| from .common import CSEVariable, ExprPrinter, Kernel | ||||
|  | ||||
|  | ||||
| DTYPE_TO_CPP = { | ||||
|     torch.float32: "float", | ||||
| @ -303,88 +304,7 @@ def value_to_cpp(value, cpp_type): | ||||
|         return f"static_cast<{cpp_type}>({repr(value)})" | ||||
|  | ||||
|  | ||||
| def rewrite_index_for_function( | ||||
|     localize_buffer_handler: "LocalizeBufferHandler", | ||||
|     index: sympy.Expr, | ||||
| ): | ||||
|     # Local buffer at the inner dimensions | ||||
|     snode = V.graph.scheduler.name_to_node.get( | ||||
|         localize_buffer_handler.global_buf.get_name() | ||||
|     ) | ||||
|     assert snode is not None | ||||
|     scheduler_nodes = snode.get_nodes() | ||||
|     _, (group, reduction_group) = max( | ||||
|         scheduler_nodes, key=lambda x: int(x.is_reduction()) | ||||
|     ).group | ||||
|     call_ranges = tuple(group) + tuple(reduction_group) | ||||
|     indices_to_keep = [ | ||||
|         f"x{len(call_ranges) - (idx + 1)}" | ||||
|         for idx in range(len(localize_buffer_handler.local_buf.get_layout().size)) | ||||
|     ] | ||||
|     sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)  # type: ignore[attr-defined] | ||||
|     replacements = {} | ||||
|     for x in sorted_symbols: | ||||
|         if x.name.startswith("x") and x.name not in indices_to_keep:  # type: ignore[attr-defined] | ||||
|             # Only keep index used by local buffer | ||||
|             replacements[x] = sympy.core.numbers.Zero() | ||||
|     index = sympy_subs(index, replacements)  # type: ignore[arg-type] | ||||
|     return index | ||||
|  | ||||
|  | ||||
| def rewrite_index_for_nodes( | ||||
|     localize_buffer_handler: "LocalizeBufferHandler", | ||||
|     index: sympy.Expr, | ||||
| ): | ||||
|     used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)} | ||||
|     index_vars = [] | ||||
|     for i in range(len(localize_buffer_handler.local_buf.get_size())): | ||||
|         var = sympy_index_symbol_with_prefix(SymT.INDEX, i) | ||||
|         index_vars.append(var if var in used_vars else 0) | ||||
|     index = localize_buffer_handler.local_buf.layout.make_indexer()(index_vars) | ||||
|     return index | ||||
|  | ||||
|  | ||||
| class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined] | ||||
|     def __init__( | ||||
|         self, | ||||
|         inner, | ||||
|         global_buf: ir.Buffer, | ||||
|         local_buf: ir.Buffer, | ||||
|         rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr], sympy.Expr], | ||||
|     ): | ||||
|         super().__init__(inner) | ||||
|         self.global_buf = global_buf | ||||
|         self.local_buf = local_buf | ||||
|         self.rewrite_index = rewrite_index | ||||
|  | ||||
|     def localize(self, name: str, index: sympy.Expr): | ||||
|         if self.global_buf and name == self.global_buf.get_name(): | ||||
|             assert self.rewrite_index is not None | ||||
|             name = self.local_buf.get_name() | ||||
|             index = self.rewrite_index(self, index) | ||||
|         return name, index | ||||
|  | ||||
|     def load(self, name: str, index: sympy.Expr): | ||||
|         return self._inner.load(*self.localize(name, index)) | ||||
|  | ||||
|     def store(self, name, index, value, mode=None): | ||||
|         local_buffer_name, local_buffer_index = self.localize(name, index) | ||||
|         res = self._inner.store(local_buffer_name, local_buffer_index, value, mode) | ||||
|         if ( | ||||
|             self.global_buf | ||||
|             and name == self.global_buf.get_name() | ||||
|             and isinstance(V.kernel, Kernel) | ||||
|         ): | ||||
|             # Remove name of local buffer from Kernel.store_buffer_names | ||||
|             # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store. | ||||
|             V.kernel.store_buffer_names.discard(local_buffer_name) | ||||
|         return res | ||||
|  | ||||
|     def store_reduction(self, name, index, value): | ||||
|         return self._inner.store_reduction(*self.localize(name, index), value) | ||||
|  | ||||
|  | ||||
| class LocalBufferContext: | ||||
| class LocalBufferScope: | ||||
|     """ | ||||
|     This class creates a context that helps to generate code involving Inductor IR with | ||||
|     function local buffers. These buffers are constructed during the codegen process and | ||||
| @ -394,13 +314,10 @@ class LocalBufferContext: | ||||
|     these buffers without exposure to the outside world. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, kernel_args: KernelArgs): | ||||
|         self.kernel_args = kernel_args | ||||
|     def __init__(self, kernel: Kernel): | ||||
|         self.kernel = kernel | ||||
|         self.exit_stack = contextlib.ExitStack() | ||||
|         # Map Local Buffer name to Local Buffer | ||||
|         self.local_buffers: Dict[str, ir.Buffer] = {} | ||||
|         # Map Local Buffer name to Global Buffer | ||||
|         self.local_to_global: Dict[str, ir.Buffer] = {} | ||||
|  | ||||
|     def __enter__(self): | ||||
|         self.exit_stack.__enter__() | ||||
| @ -413,26 +330,23 @@ class LocalBufferContext: | ||||
|  | ||||
|         self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype)) | ||||
|  | ||||
|         original_input = self.kernel_args.input | ||||
|         original_input = self.kernel.args.input | ||||
|  | ||||
|         def input(name): | ||||
|             if name in self.local_buffers: | ||||
|                 return name | ||||
|             return original_input(name) | ||||
|  | ||||
|         self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input)) | ||||
|         self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input)) | ||||
|  | ||||
|         original_output = self.kernel_args.output | ||||
|         original_output = self.kernel.args.output | ||||
|  | ||||
|         def output(name): | ||||
|             if name in self.local_buffers: | ||||
|                 return name | ||||
|             return original_output(name) | ||||
|  | ||||
|         self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output)) | ||||
|  | ||||
|         # Set current LocalBufferContext into V | ||||
|         self.exit_stack.enter_context(V.set_local_buffer_context(self)) | ||||
|         self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output)) | ||||
|  | ||||
|         return self | ||||
|  | ||||
| @ -440,64 +354,53 @@ class LocalBufferContext: | ||||
|         self.local_buffers.clear() | ||||
|         self.exit_stack.__exit__(exc_type, exc_val, exc_tb) | ||||
|  | ||||
|     def add_local_buffer( | ||||
|         self, local_buffer: ir.Buffer, global_buffer: Optional[ir.Buffer] = None | ||||
|     ): | ||||
|         assert local_buffer.get_name() not in self.local_buffers | ||||
|         self.local_buffers[local_buffer.get_name()] = local_buffer | ||||
|         if global_buffer: | ||||
|             self.local_to_global[local_buffer.get_name()] = global_buffer | ||||
|             V.graph.removed_buffers.add(global_buffer.get_name()) | ||||
|     def add_local_buffer(self, buffer: ir.Buffer): | ||||
|         assert buffer.get_name() not in self.local_buffers | ||||
|         self.local_buffers[buffer.get_name()] = buffer | ||||
|  | ||||
|     def localize_function( | ||||
|         self, | ||||
|         fn: Callable[..., Any], | ||||
|         rewrite_index: Callable[ | ||||
|             ["LocalizeBufferHandler", sympy.Expr], sympy.Expr | ||||
|         ] = rewrite_index_for_function, | ||||
|     ): | ||||
|         local_buffers = list(self.local_buffers.values()) | ||||
|         global_buffers = list(self.local_to_global.values()) | ||||
|         local_buf = local_buffers[0] | ||||
|         global_buf = global_buffers[0] | ||||
|  | ||||
|         def inner(node, *index_vars): | ||||
|             with V.set_ops_handler( | ||||
|                 LocalizeBufferHandler( | ||||
|                     V.get_ops_handler(), | ||||
|                     global_buf=global_buf, | ||||
|                     local_buf=local_buf, | ||||
|                     rewrite_index=rewrite_index, | ||||
|                 ) | ||||
|             ): | ||||
|                 return fn(node, *index_vars) | ||||
|  | ||||
|         return inner | ||||
|  | ||||
|     def localize_nodes( | ||||
|         self, | ||||
|         nodes: List[ir.IRNode], | ||||
|         rewrite_index: Callable[ | ||||
|             ["LocalizeBufferHandler", sympy.Expr], sympy.Expr | ||||
|         ] = rewrite_index_for_nodes, | ||||
|     def localize_buffer( | ||||
|         self, global_buf: ir.Buffer, local_buf: ir.Buffer, nodes: List[ir.IRNode] | ||||
|     ) -> List[ir.IRNode]: | ||||
|         """ | ||||
|         Given `local_buf` and `global_buf` registered in current `LocalBufferContext` | ||||
|         though the method of `add_local_buffer`, localizes the `global_buf` to `local_buf` | ||||
|         for the given `nodes` and returns a new list of IR nodes that work on `local_buf` | ||||
|         instead of `global_buf`, i.e., all the loads and stores are redirected to | ||||
|         `local_buf`. This helps the fused loops to work on smaller-sized local buffers | ||||
|         for better data locality. | ||||
|         Localizes the buffer `global_buf` to `local_buf` in the given `nodes` and returns | ||||
|         a new list of IR nodes that work on `local_buf` instead of `global_buf`, i.e., all | ||||
|         the loads and stores are redirected to `local_buf`. This helps the fused loops to | ||||
|         work on smaller-sized local buffers for better data locality. | ||||
|  | ||||
|         The the data access of `local_buf` is assumed to be contiguous with the | ||||
|         same order as the `global_buf`. | ||||
|         The `local_buf` should already be registered in the local scope and the data access | ||||
|         is assumed to be contiguous with the same order as the `global_buf`. | ||||
|         """ | ||||
|         local_buffers = list(self.local_buffers.values()) | ||||
|         global_buffers = list(self.local_to_global.values()) | ||||
|         assert len(global_buffers[0].get_size()) == len(local_buffers[0].get_size()) | ||||
|         assert local_buf.get_name() in self.local_buffers | ||||
|         assert len(global_buf.get_size()) == len(local_buf.get_size()) | ||||
|         assert len(nodes) > 0 | ||||
|  | ||||
|         def wrap_inner_fn_for_node(node: ir.IRNode): | ||||
|         class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined] | ||||
|             def __init__(self, inner): | ||||
|                 super().__init__(inner) | ||||
|  | ||||
|             def localize(self, name: str, index: sympy.Expr): | ||||
|                 if name == global_buf.get_name(): | ||||
|                     name = local_buf.get_name() | ||||
|                     used_vars = { | ||||
|                         s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX) | ||||
|                     } | ||||
|                     index_vars = [] | ||||
|                     for i in range(len(local_buf.get_size())): | ||||
|                         var = sympy_index_symbol_with_prefix(SymT.INDEX, i) | ||||
|                         index_vars.append(var if var in used_vars else 0) | ||||
|                     index = local_buf.layout.make_indexer()(index_vars) | ||||
|                 return name, index | ||||
|  | ||||
|             def load(self, name: str, index: sympy.Expr): | ||||
|                 return self._inner.load(*self.localize(name, index)) | ||||
|  | ||||
|             def store(self, name, index, value, mode=None): | ||||
|                 return self._inner.store(*self.localize(name, index), value, mode) | ||||
|  | ||||
|             def store_reduction(self, name, index, value): | ||||
|                 return self._inner.store_reduction(*self.localize(name, index), value) | ||||
|  | ||||
|         def wrap_inner_fn_for_node(node: ir.IRNode, inner_fn_wrapper): | ||||
|             loops = node.data if isinstance(node, ir.ComputedBuffer) else node | ||||
|             assert isinstance(loops, ir.Loops) | ||||
|             new_loops = copy.copy(loops) | ||||
| @ -508,13 +411,17 @@ class LocalBufferContext: | ||||
|             else: | ||||
|                 new_node = new_loops  # type: ignore[assignment] | ||||
|  | ||||
|             new_loops.inner_fn = self.localize_function( | ||||
|                 new_loops.inner_fn, | ||||
|                 rewrite_index, | ||||
|             ) | ||||
|             new_loops.inner_fn = inner_fn_wrapper(new_loops.inner_fn) | ||||
|             return new_node | ||||
|  | ||||
|         return [wrap_inner_fn_for_node(node) for node in nodes] | ||||
|         def inner_fn_wrapper(inner_fn): | ||||
|             def inner(index): | ||||
|                 with V.set_ops_handler(LocalizeBufferHandler(V.get_ops_handler())): | ||||
|                     return inner_fn(index) | ||||
|  | ||||
|             return inner | ||||
|  | ||||
|         return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes] | ||||
|  | ||||
|  | ||||
| def unify_mask_base_type( | ||||
|  | ||||
| @ -41,15 +41,9 @@ ir_nodes_pre_fusion = 0 | ||||
| # counters for tracking to_dtype inserted | ||||
| cpp_to_dtype_count = 0 | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| class CppOuterLoopFusedCount: | ||||
|     inner_kernel_number: int | ||||
|     local_buffer_number: int = 0 | ||||
|  | ||||
|  | ||||
| # The length counts the number of outer loop fusions. | ||||
| cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = [] | ||||
| # Each element counts the number of inner kernels in each outer loop fusion. | ||||
| cpp_outer_loop_fused_inner_counts: List[int] = [] | ||||
|  | ||||
| num_comprehensive_padding = 0 | ||||
| num_matches_for_scatter_upon_const_tensor = 0 | ||||
|  | ||||
| @ -72,7 +72,6 @@ from .ops_handler import (  # noqa: F401 | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     import torch | ||||
|     from torch._inductor.codegen.cpp_utils import LocalBufferContext | ||||
|     from torch._inductor.debug import DebugContext | ||||
|     from torch._inductor.graph import GraphLowering | ||||
|     from torch._inductor.ir import InterpreterShim | ||||
| @ -163,9 +162,6 @@ _debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler) | ||||
| _interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler) | ||||
| _aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler) | ||||
| _current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler) | ||||
| _local_buffer_context: Virtualized[LocalBufferContext] = Virtualized( | ||||
|     "local_buffer_context", NullHandler | ||||
| ) | ||||
|  | ||||
|  | ||||
| class OpsValue: | ||||
| @ -310,8 +306,6 @@ class _V: | ||||
|     get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler | ||||
|     set_current_node: Callable[[Any], Any] = _current_node._set_handler | ||||
|     get_current_node: Callable[[], Any] = _current_node._get_handler | ||||
|     set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler | ||||
|     get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler | ||||
|  | ||||
|     @property | ||||
|     def ops(self) -> OpsHandler[Any]: | ||||
| @ -354,9 +348,5 @@ class _V: | ||||
|     def current_node(self): | ||||
|         return _current_node._get_handler() | ||||
|  | ||||
|     @property | ||||
|     def local_buffer_context(self): | ||||
|         return _local_buffer_context._get_handler() | ||||
|  | ||||
|  | ||||
| V = _V() | ||||
|  | ||||
		Reference in New Issue
	
	Block a user