diff --git a/test/inductor/test_fxir_backend.py b/test/inductor/test_fxir_backend.py new file mode 100644 index 000000000000..cfd34fad50d7 --- /dev/null +++ b/test/inductor/test_fxir_backend.py @@ -0,0 +1,417 @@ +# Owner(s): ["module: inductor"] +""" +Test the FX IR backend. +""" + +import itertools +import operator +import unittest +from typing import Callable, Optional + +import sympy + +import torch +import torch._inductor.codegen.common as common +import torch.utils._pytree as pytree +from torch._dynamo.exc import BackendCompilerFailed +from torch._dynamo.utils import same +from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation +from torch._inductor import config +from torch._inductor.codegen.common import register_backend_for_device +from torch._inductor.codegen.cpp import CppScheduling +from torch._inductor.codegen.triton import TritonScheduling +from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen +from torch._inductor.select_algorithm import extern_kernels +from torch._inductor.test_case import TestCase as InductorTestCase +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_GPU, + requires_gpu, + TRITON_HAS_CPU, +) + + +@requires_gpu() +@config.patch( + compile_threads=1, + alignment_asserts=False, + size_asserts=False, + scalar_asserts=False, + nan_asserts=False, +) +class FxirTestCase(InductorTestCase): + device = GPU_TYPE + + def _count_ops(self, gm: torch.fx.GraphModule, target: Callable) -> int: + return len(gm.graph.find_nodes(op="call_function", target=target)) + + def _run_and_capture_graphs(self, opt, args) -> torch.fx.GraphModule: + gms = [] + + orig_generate = FxConverter.generate + + def generate(self) -> torch.fx.GraphModule: + nonlocal gms + gm = orig_generate(self) + gms.append(gm) + return gm + + with unittest.mock.patch.object( + torch._inductor.codegen.wrapper_fxir.FxConverter, "generate", generate + ): + opt(*args) + + return gms + + def _compile_and_check( + self, + func, + args, + expected_num_triton_kernels: int = 1, + metadata_only: bool = False, + compile_kwargs: Optional[dict] = None, + ): + if compile_kwargs is None: + compile_kwargs = {} + + opt = torch.compile(func, **compile_kwargs) + + # Get the FX graph from the backend. + gms = self._run_and_capture_graphs(opt, args) + + # Check the code for triton kernels. + num_kernels = sum( + self._count_ops(gm, triton_kernel_wrapper_mutation) for gm in gms + ) + self.assertEqual(num_kernels, expected_num_triton_kernels) + + # Check accuracy. + result = opt(*args) + ref = func(*args) + if metadata_only: + # When we only want to check metadata, fill in zeros for tensor data. + ref, result = tuple( + pytree.tree_map(torch.zeros_like, x) for x in (ref, result) + ) + + self.assertTrue(same(ref, result)) + + return gms + + @classmethod + def setUpClass(cls): + super().setUpClass() + + # Register the FX backend. + register_backend_for_device(cls.device, TritonScheduling, WrapperFxCodegen) + + def test_basic(self): + args = [torch.randn(8, device=self.device) for _ in range(2)] + self._compile_and_check(torch.add, args) + + def test_multiple_kernels(self): + def foo(x, y): + return x.sum() + y.sum() + + args = [torch.randn(length, device=self.device) for length in [517, 1029]] + self._compile_and_check(foo, args, expected_num_triton_kernels=2) + + def test_free(self): + """ + Test a program that frees a buffer which is no longer in use. 
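+        Freed buffers should show up as `name = None` assignments in the generated FX wrapper code.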
+ """ + + def foo(x, y, z): + w = x.sum() + y + return z.sum() + w.sum() + + args = [torch.randn(length, device=self.device) for length in [517, 1029, 123]] + (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=3) + + # Check the generated code for frees. + num_frees = gm.code.count("= None") + self.assertGreater(num_frees, 0) + + def test_extern(self): + """ + Test a program that calls an extern kernel. + """ + + def foo(x, y): + return x @ y + y.sum() + + args = [ + torch.randn(size, device=self.device) for size in [(129, 129), (129, 1)] + ] + (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1) + + # Check for the extern kernel + num_extern = self._count_ops(gm, extern_kernels.addmm) + self.assertEqual(num_extern, 1) + + def test_fallback(self): + """ + Test a program that calls an aten fallback. + """ + + length = 8 + + def foo(x): + return x + torch.randn(1, device=self.device) + + args = (torch.randn(length, device=self.device),) + + # Since the program has a random output, just check metadata. + # Don't check for an exact value. + (gm,) = self._compile_and_check( + foo, args, expected_num_triton_kernels=2, metadata_only=True + ) + + # Check for the fallback kernel. + num_fallback = self._count_ops(gm, torch.ops.aten.randint.low_out) + self.assertEqual(num_fallback, 1) + + def test_cat_inputs(self): + """ + Test concatenation of graph inputs. + """ + + def foo(x, y): + return torch.cat((x, y)) + 1 + + args = [torch.randn(8, device=self.device) for _ in range(2)] + self._compile_and_check(foo, args, expected_num_triton_kernels=1) + + def test_cat_to_alloc(self): + """ + Test concatenation that's optimized out to an allocation. + """ + length = 8 + + def foo(x): + y, z = tuple( + torch.arange(length // 2, device=self.device) for _ in range(2) + ) + return x + torch.cat((y, z)) + + args = [torch.randn(length, device=self.device)] + (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1) + + # Expect a single allocation, even though eager mode would use 2. + num_allocs = self._count_ops(gm, torch.empty_strided) + self.assertEqual(num_allocs, 1) + + def test_cat_reinterpret_view(self): + """ + Test torch.cat using ReinterpretView. + """ + length = 8 + + def foo(x): + y, z = tuple(torch.randn(length // 2, device=self.device) for _ in range(2)) + return x + torch.cat((y, z)) + + args = [torch.randn(length, device=self.device)] + + # Since this test generates random numbers, check metadata only. + (gm,) = self._compile_and_check( + foo, args, expected_num_triton_kernels=3, metadata_only=True + ) + + # Check for as_strided. We map ReinterpretView to this. + num_as_strided = self._count_ops(gm, torch.as_strided) + self.assertEqual(num_as_strided, 2) + + def test_reshape_output(self): + """ + Test reshaping the output, which maps to a ReinterpretView. + """ + + def foo(x, y): + return torch.reshape(x + y, (8,)) + + args = [torch.randn((2, 4), device=self.device) for _ in range(2)] + (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1) + + # Check for as_strided. We map ReinterpretView to this. + num_as_strided = self._count_ops(gm, torch.as_strided) + self.assertEqual(num_as_strided, 1) + + def test_extern_multi_output(self): + """ + Test an extern kernel with multiple outputs. + Also test a graph with multiple outputs. 
+ """ + + def foo(x): + top, idx = torch.topk(x, 2) + return top + 1, idx * 2 + + args = [torch.randn(8, device=self.device)] + (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2) + + # Check for multiple kernel outputs via getitems. + num_getitems = self._count_ops(gm, operator.getitem) + self.assertEqual(num_getitems, 2) + + # Check for multiple graph outputs. + output_node = gm.graph.find_nodes(op="output")[0] + self.assertEqual(len(output_node.args[0]), 2) + + def test_duplicate_input(self): + """ + Test duplicated inputs. This will collapse into a single input in the GM. + """ + + args = [torch.randn(4, device=self.device)] * 2 + (gm,) = self._compile_and_check(torch.add, args, expected_num_triton_kernels=1) + + num_placeholders = len(gm.graph.find_nodes(op="placeholder")) + self.assertEqual(num_placeholders, 1) + + def test_backward(self): + """ + Test a program with a backward pass. + """ + + x = torch.ones(5, device=self.device) # input tensor + y = torch.zeros(3, device=self.device) # expected output + w = torch.randn(5, 3, requires_grad=True, device=self.device) + b = torch.randn(3, requires_grad=True, device=self.device) + + def foo(x, y): + z = torch.matmul(x, w) + b + loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y) + loss.backward() + return w.grad, b.grad + + # Expect separate forward and backward graphs. + (forward_gm, backward_gm) = self._compile_and_check( + foo, (x, y), expected_num_triton_kernels=3 + ) + + def test_custom_compiler(self): + """ + Test a derived backend with a custom compiler. + """ + offset = 1 + + class CustomWrapperCodegen(WrapperFxCodegen): + def compile_graph(self, gm): + def compiled_fn(*args): + # Adds an offset to the program's outputs. + outputs = gm(*args) + return pytree.tree_map(lambda x: x + 1, outputs) + + return compiled_fn + + args = [torch.randn(8, device=self.device) for _ in range(2)] + custom_backend = common.DeviceCodegen( + TritonScheduling, CustomWrapperCodegen, None + ) + with unittest.mock.patch.dict( + common.device_codegens, {self.device: custom_backend} + ): + func = torch.add + opt = torch.compile(func) + result = opt(*args) + + # Check the output is offset from eager mode. + ref = func(*args) + self.assertFalse(same(result, ref)) + self.assertNotEqual(offset, 0) + self.assertTrue(same(result - offset, ref)) + + def test_dynamic_shapes_and_strides(self): + """ + Test a graph with dynamic shapes and strides. + """ + + static_dims = (8, 8) + + def get_input(): + full_size = (16, 8) + full = torch.randn(full_size, device=self.device) + view = torch.as_strided(full, static_dims, full.stride()) + return view + + func = torch.add + args = [get_input() for _ in range(2)] + (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True}) + + # Check for a symbolic output shape. + (empty_strided,) = gm.graph.find_nodes( + op="call_function", target=torch.empty_strided + ) + example_tensor = empty_strided.meta["val"] + symbolic_dims = example_tensor.shape + self.assertEqual(len(symbolic_dims), len(static_dims)) + + # Check for symbolic output strides. + (stride, one) = example_tensor.stride() + self.assertEqual(one, sympy.S.One) + + # Find the size symbols, and check for a corresponding placeholders defining them. 
+ for symbol in itertools.chain(symbolic_dims, [stride]): + self.assertTrue(isinstance(symbol, torch.SymInt)) + (placeholder,) = [ + node + for node in gm.graph.find_nodes(op="placeholder") + if node.name == str(symbol) + ] + self.assertEqual(placeholder.meta["val"], symbol) + + @config.patch({"trace.enabled": True}) + @unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code") + def test_debug(self, mock_output_code): + # Compile in debug mode. + args = [torch.randn(11, device=self.device) for _ in range(2)] + self._compile_and_check(torch.sub, args) + + # Check the output code for a Triton kernel call. + mock_output_code.assert_called_once() + (output_filename,) = mock_output_code.call_args.args + with open(output_filename) as f: + output_code = f.read() + self.assertIn("triton_kernel_wrapper_mutation", output_code) + + @torch._inductor.config.patch("graph_partition", True) + def test_subgraph_raises(self): + """ + Test a model with subgraphs. This is not yet supported, so check that we get the + expected exception. + """ + + def foo(cond, x): + return torch.cond(cond, torch.cos, torch.sin, [x]) + + cond = torch.tensor([True], device=self.device) + x = torch.ones([2, 3], device=self.device) + + with self.assertRaisesRegex(BackendCompilerFailed, "Subgraph"): + self._compile_and_check(foo, [cond, x]) + + def test_cpp_raises(self): + """ + Test the C++ CPU backend. C++ kernels are not yet supported, so for now check + that we get the expected exception. + """ + + def foo(x, y): + return x + y * 5 + + device = torch.device("cpu") + args = [torch.randn(5, device=device) for _ in range(2)] + + cpp_backend = common.DeviceCodegen(CppScheduling, WrapperFxCodegen, None) + with unittest.mock.patch.dict( + common.device_codegens, {device.type: cpp_backend} + ), self.assertRaisesRegex(BackendCompilerFailed, "Triton"): + self._compile_and_check(foo, args) + + +if __name__ == "__main__": + from torch._inductor.test_case import run_tests + + if HAS_GPU or TRITON_HAS_CPU: + run_tests(needs="filelock") diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 0fa3d20214f4..4703b79af93c 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -1750,6 +1750,27 @@ class TracingTritonHOPifier(TritonHOPifier): # normalize to tuple return tuple(grid) + def store_non_graphable_args( + self, + combined_args: dict[str, Any], + ) -> tuple[dict, int]: + """ + Some args cannot be stored in the FX graph. + Put them in the side table. 
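+        Returns the graphable args along with the index of the new side table entry.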
+ """ + + def is_graphable(val: Any) -> bool: + return isinstance(val, (fx.node.base_types, fx.Node)) + + non_graphable_args = { + k: v for k, v in combined_args.items() if not is_graphable(v) + } + graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)} + + constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args) + + return graphable_args, constant_args_idx + def call_HOP( self, variable: "TraceableTritonKernelWrapper", @@ -1760,15 +1781,8 @@ class TracingTritonHOPifier(TritonHOPifier): assert tx is None assert isinstance(variable, TraceableTritonKernelWrapper) - def is_graphable(val: Any) -> bool: - return isinstance(val, fx.node.base_types) + graphable_args, constant_args_idx = self.store_non_graphable_args(combined_args) - non_graphable_args = { - k: v for k, v in combined_args.items() if not is_graphable(v) - } - graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)} - - constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args) assert isinstance(variable.kernel_idx, int) return triton_kernel_wrapper_mutation( kernel_idx=variable.kernel_idx, diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index b85cfa778a8c..9236ebe7bb98 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1,5 +1,6 @@ from __future__ import annotations +import atexit import contextlib import dataclasses import enum @@ -8,8 +9,11 @@ import itertools import logging import math import operator +import os import re +import tempfile import typing +from abc import ABC, abstractmethod from enum import auto, Enum from itertools import chain from typing import ( @@ -60,6 +64,8 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V if TYPE_CHECKING: from collections.abc import Iterator, MutableMapping, Sequence + from torch.fx import GraphModule + from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode from ..loop_body import LoopBody from ..scheduler import BaseScheduling, Scheduler, SchedulerNode @@ -83,6 +89,38 @@ def data_type_logger(msg: str) -> None: schedule_log.debug("Data type propagation: %s", msg) +@dataclasses.dataclass +class FileBackedGraphModule: + """ + Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these + map back to a GraphModule instead of Python source. + """ + + gm: GraphModule + compiled_fn: Callable[..., Any] + + def __post_init__(self) -> None: + # Write the code to a file for compatibility with debugging utilities. + # The file is deleted upon program termination. + self.tempfile = tempfile.NamedTemporaryFile( + mode="w+", suffix=".py", delete=False + ) + atexit.register(os.remove, self.tempfile.name) + with self.tempfile as f: + f.write(self.value) + + @property + def __file__(self) -> str: + return self.tempfile.name + + def call(self, args: list[Any]) -> Any: + return self.compiled_fn(*args) + + @property + def value(self) -> str: + return self.gm.code + + class WorkspaceZeroMode(enum.Enum): UNINITIALIZED = 0 ZERO_ON_CALL = 1 # kernel may leave workspace dirty @@ -103,8 +141,22 @@ class WorkspaceZeroMode(enum.Enum): return WorkspaceZeroMode.UNINITIALIZED +class CodegenSymbol(ABC): + """ + An IR object possibly corresponding to a variable in the wrapper code. 
+ """ + + @abstractmethod + def get_name(self) -> str: + pass + + @abstractmethod + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + pass + + @ir_dataclass(frozen=True) -class WorkspaceArg: +class WorkspaceArg(CodegenSymbol): """A temporary buffer used for a single kernel, then discarded. Not registered as a traditional buffer since there are no users, @@ -167,6 +219,9 @@ class WorkspaceArg: def get_dtype(self) -> torch.dtype: return self.dtype + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + return self.get_layout().get_example() + def get_layout(self) -> FixedLayout: from ..ir import FixedLayout @@ -185,6 +240,9 @@ class WorkspaceArg: maybe_get_output_spec = get_layout maybe_get_layout = get_layout + def get_offset(self) -> sympy.Expr: + return sympy.S.Zero + def get_size(self) -> list[sympy.Expr]: return [self.count] diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index c94233bd728d..45e1c591a63f 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -74,6 +74,7 @@ if TYPE_CHECKING: import triton from ..graph import GraphLowering + from .wrapper_fxir import FxConverter log = logging.getLogger(__name__) @@ -83,6 +84,7 @@ pexpr = PythonPrinter().doprint ReuseKey = tuple[torch.device, torch.dtype, str, bool] BufferLike = Union[ir.Buffer, WorkspaceArg] +FxConversionFunc = Callable[["WrapperLine"], None] def buffer_reuse_key(node: BufferLike) -> ReuseKey: @@ -349,7 +351,8 @@ class MemoryPlanningState: class WrapperLine: - pass + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + raise NotImplementedError("FX codegen not yet supported for type {type(self)}") @dataclasses.dataclass @@ -364,6 +367,9 @@ class EnterSubgraphLine(WrapperLine): self.wrapper.push_codegened_graph(self.graph) code.do_indent() + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_enter_subgraph + @dataclasses.dataclass class CommentLine(WrapperLine): @@ -372,6 +378,10 @@ class CommentLine(WrapperLine): def codegen(self, code: IndentedBuffer) -> None: code.writeline(self.line) + @staticmethod + def codegen_fx(converter: FxConverter) -> FxConversionFunc: + return converter._generate_comment + @dataclasses.dataclass class ExitSubgraphLine(WrapperLine): @@ -384,6 +394,9 @@ class ExitSubgraphLine(WrapperLine): self.wrapper.pop_codegened_graph() code.do_unindent() + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_exit_subgraph + @dataclasses.dataclass class EnterDeviceContextManagerLine(WrapperLine): @@ -419,12 +432,18 @@ class EnterDeviceContextManagerLine(WrapperLine): code.do_indent() code.writeline(V.graph.device_ops.set_device(self.device_idx)) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_enter_device_context_manager + class ExitDeviceContextManagerLine(WrapperLine): def codegen(self, code: IndentedBuffer) -> None: if not V.graph.cpp_wrapper: code.do_unindent() + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_exit_device_context_manager + @dataclasses.dataclass class ExternKernelAllocLine(WrapperLine): @@ -436,6 +455,9 @@ class ExternKernelAllocLine(WrapperLine): args = [*node.codegen_args(), *node.codegen_kwargs()] self.wrapper._generate_extern_kernel_alloc_helper(self.node, args) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_extern_kernel_alloc + 
@dataclasses.dataclass class ExternKernelOutLine(WrapperLine): @@ -466,6 +488,9 @@ class ExternKernelOutLine(WrapperLine): device, ) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_extern_kernel_out + @dataclasses.dataclass class FreeLine(WrapperLine): @@ -476,6 +501,9 @@ class FreeLine(WrapperLine): assert self.node.get_name() not in V.graph.removed_buffers code.writeline(self.wrapper.make_buffer_free(self.node)) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_free + @dataclasses.dataclass class KernelCallLine(WrapperLine): @@ -505,6 +533,9 @@ class KernelCallLine(WrapperLine): original_fxnode_name=self.original_fxnode_name, ) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_kernel_call + @dataclasses.dataclass class KernelDefinitionLine(WrapperLine): @@ -524,6 +555,9 @@ class KernelDefinitionLine(WrapperLine): cpp_definition=self.cpp_definition, ) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_kernel_definition + @dataclasses.dataclass class MemoryPlanningLine(WrapperLine): @@ -580,6 +614,9 @@ class AllocateLine(MemoryPlanningLine): line = self.wrapper.make_buffer_allocation(self.node) code.writeline(line) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_allocate + @dataclasses.dataclass class FreeIfNotReusedLine(MemoryPlanningLine): @@ -603,6 +640,9 @@ class FreeIfNotReusedLine(MemoryPlanningLine): if not self.is_reused: code.writeline(self.wrapper.make_buffer_free(self.node)) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_free_if_not_reused + @dataclasses.dataclass class ReinterpretLine(MemoryPlanningLine): @@ -620,6 +660,9 @@ class ReinterpretLine(MemoryPlanningLine): self.reused_as.get_name(), self.layout.view ) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_reinterpret + @dataclasses.dataclass class ReuseLine(MemoryPlanningLine): @@ -641,9 +684,13 @@ class ReuseLine(MemoryPlanningLine): self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old) ) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_reuse + class NullLine(MemoryPlanningLine): - pass + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_null @dataclasses.dataclass @@ -717,6 +764,9 @@ class CommBufferAllocateLine(CommBufferLine): f"Unsupported comm buffer type: {comm_buffer_type}" ) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_comm_buffer_allocate + @dataclasses.dataclass class CommBufferFreeLine(CommBufferLine): @@ -724,6 +774,9 @@ class CommBufferFreeLine(CommBufferLine): line = self.wrapper.make_buffer_free(self.node) code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free") + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_comm_buffer_free + @dataclasses.dataclass class MultiOutputLine(WrapperLine): @@ -760,6 +813,22 @@ class MultiOutputLine(WrapperLine): f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}" ) + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_multi_output + + +@dataclasses.dataclass +class SymbolicCallArgLine(WrapperLine): + wrapper: PythonWrapperCodegen + arg: SymbolicCallArg + 
graph: GraphLowering + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._generate_symbolic_call_arg_helper(self.arg, self.graph) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_symbolic_call_arg + @dataclasses.dataclass class SymbolicCallArgLine(WrapperLine): diff --git a/torch/_inductor/codegen/wrapper_fxir.py b/torch/_inductor/codegen/wrapper_fxir.py new file mode 100644 index 000000000000..42cdc5d2e2f8 --- /dev/null +++ b/torch/_inductor/codegen/wrapper_fxir.py @@ -0,0 +1,596 @@ +import dataclasses +import operator +import textwrap +from collections import Counter +from typing import Any, Callable, Optional, Union + +import sympy + +import torch +from torch._higher_order_ops.triton_kernel_wrap import ( + TraceableTritonKernelWrapper, + tracing_triton_hopifier_singleton, + triton_kernel_wrapper_mutation, +) +from torch._inductor.codecache import PyCodeCache +from torch._inductor.runtime.triton_heuristics import CachingAutotuner +from torch._inductor.select_algorithm import extern_kernels # noqa: F401 +from torch._inductor.virtualized import V +from torch._library.triton import wrap_triton +from torch.fx import GraphModule + +from .. import ir +from ..utils import convert_shape_to_symint, convert_to_symint, LineContext +from .common import ( + CodegenSymbol, + FileBackedGraphModule, + WorkspaceArg, + WorkspaceZeroMode, +) +from .wrapper import ( + AllocateLine, + BufferLike, + CommBufferAllocateLine, + CommBufferFreeLine, + CommentLine, + EnterDeviceContextManagerLine, + EnterSubgraphLine, + ExitDeviceContextManagerLine, + ExitSubgraphLine, + ExternKernelAllocLine, + ExternKernelOutLine, + FreeIfNotReusedLine, + FreeLine, + KernelCallLine, + KernelDefinitionLine, + Line, + MultiOutputLine, + NullLine, + PythonWrapperCodegen, + ReinterpretLine, + ReuseLine, + SymbolicCallArg, + SymbolicCallArgLine, + WrapperLine, +) + + +aten = torch.ops.aten + + +@dataclasses.dataclass +class SymbolBuffer(CodegenSymbol): + """ + Represents a sympy.Symbol graph input. + """ + + symbol: sympy.Symbol + + def get_name(self) -> str: + return str(self.symbol) + + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + return self.symbol + + +CodegenBuffer = Union[BufferLike, SymbolBuffer] + + +@dataclasses.dataclass +class TritonKernel: + """ + Stores metadata about Triton kernels for use in FX. + """ + + tuner: CachingAutotuner + wrapped: TraceableTritonKernelWrapper + + +class WrapperFxCodegen(PythonWrapperCodegen): + """ + Backend to generate wrapper code as an FX IR graph. + """ + + supports_caching = False + + def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]: + self.run_wrapper_ir_passes(is_inference) + + prologue = "\n".join( + [ + self.imports.getvalue(), + self.header.getvalue(), + ] + ) + gm = FxConverter(lines=self.lines, prologue=prologue).generate() + compiled_fn = self.compile_graph(gm) + + return FileBackedGraphModule(gm, compiled_fn), None + + def compile_graph(self, gm: GraphModule) -> Callable[..., Any]: + """ + Converts the graph module into a runnable function. The default implementation + is simply an interpreter calling kernels in eager mode. Derived backends can + override this to do further compilation. 
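+
+        A hypothetical sketch of an override (the class name is illustrative and
+        `pytree` refers to torch.utils._pytree; the pattern mirrors the
+        CustomWrapperCodegen used in the unit tests):
+
+            class MyWrapperFxCodegen(WrapperFxCodegen):
+                def compile_graph(self, gm):
+                    def compiled_fn(*args):
+                        # Run the FX graph, then post-process its outputs.
+                        return pytree.tree_map(lambda x: x + 1, gm(*args))
+
+                    return compiled_fn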
+ """ + return gm.forward + + @classmethod + def create( + cls, + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ) -> "WrapperFxCodegen": + if is_subgraph: + raise NotImplementedError( + "Subgraphs are not yet supported by FX conversion" + ) + + # For derived backends, this could be a subclass. + return cls() + + +@dataclasses.dataclass +class FxConverter: + """ + Generates FX IR from Wrapper IR. As each instance is only meant to be used once, the + input and output code are stored as attributes. + """ + + lines: list[Line] + prologue: str = "" + + def __post_init__(self) -> None: + graph = torch.fx.Graph() + self.gm = GraphModule({}, graph) # Wrapper FX IR. + self.buffer_to_node: dict[ + Optional[str], torch.fx.Node + ] = {} # Symbol table for codegen. + self.kernels: dict[str, TritonKernel] = {} # Table to store Triton kernels. + self._unique_symbol_ids: Counter[str] = Counter() + + def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner: + """ + Imports a kernel from source, possibly autotuning block parameters. + """ + module_code = "\n".join([self.prologue, code]) + mod = PyCodeCache.load(module_code) + kernel = getattr(mod, kernel_name) + + if not isinstance(kernel, CachingAutotuner): + raise NotImplementedError( + textwrap.dedent(f""" + Unsupported type for kernel {kernel_name}: {type(kernel)}. + FX conversion only supports Triton kernels. + """) + ) + + return kernel + + def _fake_tensor( + self, + size: tuple[Any, ...], + stride: tuple[Any, ...], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + with V.fake_mode: + return torch.empty_strided( + convert_shape_to_symint(size), + convert_shape_to_symint(stride), + dtype=dtype, + device=device, + ) + + def _create_meta_from_buffer( + self, node: torch.fx.Node, buffer: CodegenBuffer + ) -> None: + name = buffer.get_name() + assert name + node.name = name + node.meta["val"] = buffer.get_example() + + def _record_allocation(self, buffer: CodegenBuffer, node: torch.fx.Node) -> None: + """ + Updates the symbol table to record that an Inductor buffer maps to the result of + an FX node. + """ + assert node not in self.buffer_to_node + self.buffer_to_node[buffer.get_name()] = node + + def _free(self, buffer: Union[CodegenBuffer, ir.TorchBindObject]) -> None: + """ + Removes the buffer from the symbol table. + """ + name = buffer.get_name() + del self.buffer_to_node[name] + + def _lookup_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]: + """ + Maps call args back to FX nodes. + """ + return tuple( + self.buffer_to_node[arg] + if isinstance(arg, str) + else arg.inner_expr + if isinstance(arg, SymbolicCallArg) + else arg + for arg in args + ) + + def _get_buffer(self, node: ir.IRNode) -> CodegenBuffer: + """ + Extract buffer data from an IR node. + """ + if isinstance(node, (ir.Buffer, WorkspaceArg)): + return node + elif isinstance(node, (ir.BaseView, ir.MutableBox)): + return self._get_buffer(node.data) + elif isinstance(node, sympy.Symbol): + return SymbolBuffer(node) + else: + raise NotImplementedError(f"Unable to extract buffer from node: {node}") + + def _generate_graph_inputs(self) -> None: + """ + Converts graph inputs to FX placeholders. 
+ """ + for ir_node in V.graph.graph_inputs.values(): + buffer = self._get_buffer(ir_node) + node = self.gm.graph.placeholder(buffer.get_name()) + self._create_meta_from_buffer(node, buffer) + self._record_allocation(buffer, node) + + def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]: + """ + Generates FX IR for transformations on a buffer, such as ReinterpretView. + Does nothing if no such transformations are present. + """ + + def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]: + if isinstance(node, (ir.Buffer, WorkspaceArg)): + return node + elif isinstance(node, ir.NoneAsConstantBuffer): + return None + elif isinstance(node, ir.StorageBox): + return generate_to_buffer(node.data) + elif isinstance(node, ir.ReinterpretView): + # We need to introduce a new symbol if the output is a ReinterpretView. + # Use a WorkspaceArg for this. + buffer = self._get_buffer(node.data) + assert isinstance(buffer, (ir.Buffer, WorkspaceArg)) + unique_name = self.gm.graph._graph_namespace.create_name( + f"{buffer.get_name()}_view", None + ) + device = buffer.get_device() + assert device + reused_as = WorkspaceArg( + count=buffer.get_size(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + device=device, + outer_name=unique_name, + dtype=buffer.get_dtype(), + ) + + # Generate FX IR for the view. + self._generate_reinterpret_helper(buffer, reused_as, node.layout) + + return reused_as + else: + raise NotImplementedError(f"Unrecognized buffer/view node: {node}") + + buffer = generate_to_buffer(node) + return self.buffer_to_node[buffer.get_name()] if buffer is not None else None + + def _generate_output(self) -> None: + """ + Generate FX IR for graph outputs. + """ + output_nodes = [ + self._generate_buffer(node) + for idx, node in enumerate(V.graph.graph_outputs) + ] + + # Single return elements don't use a tuple. + output_value = output_nodes[0] if len(output_nodes) == 1 else output_nodes + + self.gm.graph.output(output_value) + + def generate(self) -> torch.fx.GraphModule: + """ + Main entrypoint for FX codegen. + """ + self._generate_graph_inputs() + + # Generate FX IR from Wrapper IR lines. + for line in self.lines: + if isinstance(line, WrapperLine): + line.codegen_fx(self)(line) + elif isinstance(line, LineContext): + # Ignore line context in FX IR. + pass + else: + raise NotImplementedError( + textwrap.dedent( + f""" + Found line of unrecognized type '{type(line)}': + '{line}' + + FX conversion only supports Wrapper IR lines. + """ + ) + ) + + self._generate_output() + self.gm.recompile() + return self.gm + + def _generate_allocate(self, line: WrapperLine) -> None: + assert isinstance(line, AllocateLine) + buffer = line.node + name = buffer.get_name() + assert name not in V.graph.removed_buffers + + device = buffer.get_device() + dtype = buffer.get_dtype() + shape = convert_shape_to_symint(buffer.get_size()) + stride = convert_shape_to_symint(buffer.get_stride()) + + node = self.gm.graph.call_function( + torch.empty_strided, + args=(shape, stride), + kwargs={"dtype": dtype, "device": device}, + ) + assert name + node.name = name + self._create_meta_from_buffer(node, buffer) + self._record_allocation(buffer, node) + + def _generate_comment(self, line: WrapperLine) -> None: + assert isinstance(line, CommentLine) + # We ignore comments in FX IR. + + def _generate_enter_device_context_manager(self, line: WrapperLine) -> None: + assert isinstance(line, EnterDeviceContextManagerLine) + # We ignore the device context in FX IR. 
+ + def _generate_exit_device_context_manager(self, line: WrapperLine) -> None: + assert isinstance(line, ExitDeviceContextManagerLine) + # We ignore the device context in FX IR. + + def _generate_enter_subgraph(self, line: WrapperLine) -> None: + assert isinstance(line, EnterSubgraphLine) + raise NotImplementedError("Subgraphs are not yet supported by FX conversion") + + def _generate_exit_subgraph(self, line: WrapperLine) -> None: + assert isinstance(line, ExitSubgraphLine) + raise NotImplementedError("Subgraphs are not yet supported by FX conversion") + + def _generate_free(self, line: WrapperLine) -> None: + assert isinstance(line, FreeLine) + + buf = line.node + + # No need to free placeholders. + if self.buffer_to_node[buf.get_name()].op == "placeholder": + return + + self._free(buf) + + def _generate_free_if_not_reused(self, line: WrapperLine) -> None: + assert isinstance(line, FreeIfNotReusedLine) + buf = line.node + assert buf.get_name() not in V.graph.removed_buffers + if not line.is_reused: + self._free(buf) + + def _generate_line_context(self, line: WrapperLine) -> None: + assert isinstance(line, LineContext) + # We ignore line context in FX IR. + + def _generate_reinterpret(self, line: WrapperLine) -> None: + assert isinstance(line, ReinterpretLine) + self._generate_reinterpret_helper(line.node, line.reused_as, line.layout) + + def _generate_reinterpret_helper( + self, input_buffer: BufferLike, result_buffer: BufferLike, layout: ir.Layout + ) -> None: + input_node = self.buffer_to_node[input_buffer.get_name()] + + # Look up output metadata. + name = result_buffer.get_name() + assert name + size = tuple(layout.size) + stride = tuple(layout.stride) + offset = input_buffer.get_offset() + layout.offset + + # Map ReinterpretView to as_strided. + result_node = self.gm.graph.call_function( + torch.as_strided, args=(input_node, size, stride, offset) + ) + result_node.name = name + result_node.meta["val"] = layout.get_example() + self._record_allocation(result_buffer, result_node) + + def _generate_reuse(self, line: WrapperLine) -> None: + assert isinstance(line, ReuseLine) + old = line.node + new = line.reused_as + assert not any(buf.get_name() in V.graph.removed_buffers for buf in (old, new)) + assert old.get_dtype() == new.get_dtype() + + old_node = self.buffer_to_node[old.get_name()] + result_node = old_node + + # Change shape and stride. + size = new.get_size() + stride = new.get_stride() + offset = new.get_offset() + if ( + old.get_size() != size + or old.get_stride() != stride + or old.get_offset() != offset + ): + result_node = self.gm.graph.call_function( + torch.as_strided, args=(old_node, size, stride, offset) + ) + self._create_meta_from_buffer(result_node, new) + + self._record_allocation(new, result_node) + + # Free the old buffer, if we allocated a new tensor. + if ( + old.get_name() not in V.graph.get_output_names() + and line.delete_old + and result_node is not old_node + ): + self._free(old) + + def _generate_multi_output(self, line: WrapperLine) -> None: + assert isinstance(line, MultiOutputLine) + + # Extract the index for tuple access. + inds = line.indices[0][1:] + assert len(inds) == 1, f"Cannot convert {inds} to an index." 
+ idx = inds[0] + + arg_node = self.buffer_to_node[line.arg_name] + node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx)) + node.meta["val"] = arg_node.meta["val"][idx] + node.name = line.result_name + self.buffer_to_node[line.result_name] = node + + def _generate_null(self, line: WrapperLine) -> None: + assert isinstance(line, NullLine) + # Does nothing. + + def _generate_comm_buffer_allocate(self, line: WrapperLine) -> None: + assert isinstance(line, CommBufferAllocateLine) + raise NotImplementedError("Comm buffer allocation is not yet supported") + + def _generate_comm_buffer_free(self, line: WrapperLine) -> None: + assert isinstance(line, CommBufferFreeLine) + self._free(line.node) + + def _generate_triton_call(self, line: WrapperLine) -> None: + assert isinstance(line, KernelCallLine) + + # Collect all kwargs, including autotuned block sizes. + call_args = self._lookup_args(line.call_args) + kernel = self.kernels[line.kernel_name] + tuner = kernel.tuner + config = tuner.compile_results[0].config + call_args, grid = tuner._interpret_args_grid(call_args, config) + call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args)) + call_kwargs.update(config.kwargs) + + # Convert sympy expressions to symints. + for name, val in call_kwargs.items(): + if isinstance(val, sympy.Expr): + call_kwargs[name] = convert_to_symint(val) + + # Store non-graphable kwargs in the side table. + ( + call_kwargs, + constant_args_idx, + ) = tracing_triton_hopifier_singleton.store_non_graphable_args(call_kwargs) + + self.gm.graph.call_function( + triton_kernel_wrapper_mutation, + kwargs={ + "kernel_idx": kernel.wrapped.kernel_idx, + "constant_args_idx": constant_args_idx, + "grid": [convert_shape_to_symint(grid)], + "tma_descriptor_metadata": {}, + "kwargs": call_kwargs, + }, + ) + + def _generate_extern_kernel_alloc(self, line: WrapperLine) -> None: + assert isinstance(line, ExternKernelAllocLine) + node = line.node + self._generate_extern_kernel_common(node, node) + + def _generate_extern_kernel_out( + self, + line: WrapperLine, + ) -> None: + assert isinstance(line, ExternKernelOutLine) + node = line.node + out_node = node.output_view if node.output_view else node + self._generate_extern_kernel_common(node, out_node) + + def _generate_extern_kernel_common( + self, kernel: ir.ExternKernel, out_ir_node: ir.IRNode + ) -> None: + """ + Generates FX IR from either ExternKernelAlloc or ExternKernelOut. + """ + + # Get FX nodes corresponding to the call args. + tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs) + args = tensor_nodes + tuple(kernel.constant_args) + + # Get the result buffer. + # Some kernels write to a pre-existing output tensor via the "out" kwarg. + kwargs = kernel.kwargs.copy() + result_buffer: Optional[str] = None + if isinstance(kernel, ir.ExternKernelOut): + kwargs["out"] = self.buffer_to_node[out_ir_node.codegen_reference()] + elif isinstance(kernel.layout, (ir.Layout, ir.MultiOutputLayout)): + result_buffer = kernel.get_name() + elif isinstance(kernel.layout, ir.NoneLayout): + pass + else: + raise NotImplementedError(f"Unrecognized output layout: {kernel.layout}") + + # Look up the kernel function from its name. + kernel_name = kernel.get_kernel_name() + module_name, kernel_name = kernel_name.split(".", 1) + op = globals()[module_name] # E.g. extern_kernels, aten, etc. + for subname in kernel_name.split("."): + op = getattr(op, subname) # E.g. 
extern_kernels.addmm + + fx_node = self.gm.graph.call_function(op, args=args, kwargs=kwargs) + + # Assign the result to the given name. + if result_buffer: + assert "out" not in kwargs, ( + f"Extern kernel '{kernel}' has both result and out kwarg. Expected only one." + ) + fx_node.name = result_buffer + self.buffer_to_node[result_buffer] = fx_node + + arg_tensors = [ + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in args + ] + + # Run the operation to propagate metadata. + fx_node.meta["val"] = op(*arg_tensors, **kwargs) + + def _generate_kernel_call(self, line: WrapperLine) -> None: + assert isinstance(line, KernelCallLine) + if not line.triton: + raise NotImplementedError("FX conversion only supports Triton kernels.") + + self._generate_triton_call(line) + + def _generate_kernel_definition(self, line: WrapperLine) -> None: + assert isinstance(line, KernelDefinitionLine) + + # Generate code for the kernel. + kernel_code = PythonWrapperCodegen._format_kernel_definition( + line.kernel_name, line.kernel_body, metadata=line.metadata + ) + + # Import the module and store the JIT kernel. + tuner = self._import_kernel(kernel_code, line.kernel_name) + wrapped = wrap_triton(tuner.fn) + self.kernels[line.kernel_name] = TritonKernel(tuner, wrapped) + + def _generate_symbolic_call_arg(self, line: WrapperLine) -> None: + assert isinstance(line, SymbolicCallArgLine) + # No need for an FX node, as we will pass the arg to kernels via a SymInt. diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 892336a3f7e3..b8d4ba329472 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -50,6 +50,7 @@ from . import config, ir, metrics from .codegen.common import ( BackendFeature, DeviceOpOverrides, + FileBackedGraphModule, get_backend_features, get_device_op_overrides, get_wrapper_codegen_for_device, @@ -115,9 +116,12 @@ if TYPE_CHECKING: from torch._higher_order_ops.effects import _EffectType from torch.fx import GraphModule from torch.fx.graph import Graph + from .codegen.wrapper import PythonWrapperCodegen from .scheduler import BaseSchedulerNode + CompiledModule = Union[ModuleType, FileBackedGraphModule] + from torch._inductor.codecache import output_code_log @@ -2224,7 +2228,7 @@ class GraphLowering(torch.fx.Interpreter): # No-op to be patched for unit tests save_output_code: Optional[Callable[[str], None]] = None - def compile_to_module(self) -> ModuleType: + def compile_to_module(self) -> CompiledModule: with dynamo_timed( "GraphLowering.compile_to_module", phase_name="code_gen", @@ -2233,14 +2237,41 @@ class GraphLowering(torch.fx.Interpreter): ): return self._compile_to_module() - def _compile_to_module(self) -> ModuleType: - from .codecache import PyCodeCache - + def _compile_to_module(self) -> CompiledModule: # Currently, if we're here, we don't have to worry about the kernel code, which # is only available in AOTInductor mode. wrapper_code, _ = ( self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() ) + + if isinstance(wrapper_code, ValueWithLineMap): + mod = self._compile_to_module_lines(wrapper_code) + elif isinstance(wrapper_code, FileBackedGraphModule): + mod = wrapper_code + else: + raise NotImplementedError( + f"Unrecognized wrapper code type: {type(wrapper_code)}" + ) + + # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029 + # TODO. 
Revisit this once the logging API is more mature + assert mod.__file__ is not None + + log_module_code(mod.__file__) + log.debug("Output code written to: %s", mod.__file__) + output_code_log.info("Output code written to: %s", mod.__file__) + if config.benchmark_kernel: + print(f"Compiled module path: {mod.__file__}", file=sys.stderr) + V.debug.output_code(mod.__file__) + V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug") + + return mod + + def _compile_to_module_lines( + self, wrapper_code: ValueWithLineMap + ) -> CompiledModule: + from .codecache import PyCodeCache + if config.triton.autotune_at_compile_time: tuning_code = ( '"""\n' @@ -2291,17 +2322,7 @@ class GraphLowering(torch.fx.Interpreter): if config.benchmark_harness and config.profile_bandwidth_output: # run the inputs code gen to get the bandwidth info mod.benchmark_compiled_module(times=1, repeat=1) - # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029 - # TODO. Revisit this once the logging API is more mature - assert mod.__file__ is not None - log_module_code(mod.__file__) - log.debug("Output code written to: %s", mod.__file__) - output_code_log.info("Output code written to: %s", mod.__file__) - if config.benchmark_kernel: - print(f"Compiled module path: {mod.__file__}", file=sys.stderr) - V.debug.output_code(mod.__file__) - V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug") return mod def get_output_names(self) -> list[str]: diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 54f451ad5843..76fdb73e231b 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -65,6 +65,7 @@ from torch.utils._sympy.symbol import SymT from . import config, dependencies from .codegen.common import ( BackendFeature, + CodegenSymbol, get_scheduling_for_device, index_prevent_reordering, ) @@ -3423,6 +3424,15 @@ class Layout(OutputSpec): def get_device(self) -> torch.device: return self.device + def get_example(self) -> torch.Tensor: + with V.fake_mode: + return torch.empty_strided( + convert_shape_to_symint(self.size), + convert_shape_to_symint(self.stride), + dtype=self.dtype, + device=self.device, + ) + def is_contiguous(self) -> bool: return is_contiguous_strides_for_shape(self.stride, self.size) @@ -3926,7 +3936,7 @@ class MutationLayoutSHOULDREMOVE(Layout): @ir_dataclass(frozen=False) -class Buffer(IRNode): +class Buffer(IRNode, CodegenSymbol): # Name is sometimes None; e.g., ForceInPlace, where there isn't # a meaningful name name: Optional[str] @@ -3946,6 +3956,11 @@ class Buffer(IRNode): assert self.name, self return self.name + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + if isinstance(self.layout, Layout): + return self.layout.get_example() + raise NotImplementedError(type(self.layout).__name__) + def get_device(self) -> Optional[torch.device]: return self.get_output_spec().get_device() diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index a2625f3cfb61..d71c7324ddb8 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -85,7 +85,7 @@ class NoTritonConfigsError(RuntimeError): if TYPE_CHECKING: - from collections.abc import Container, Hashable, Sequence + from collections.abc import Container, Hashable from torch._guards import CompileId @@ -2564,13 +2564,15 @@ class GridExpr: inductor_meta: dict[str, Any] mode: Literal["python", "cpp"] = "python" - prefix: Sequence[str] = () + prefix: list[str] = 
dataclasses.field(default_factory=list) x_grid: Union[str, int] = 1 y_grid: Union[str, int] = 1 z_grid: Union[str, int] = 1 def __post_init__(self) -> None: assert self.mode in ("python", "cpp") + if self.mode == "python": + self.prefix.append("from torch.utils._sympy.functions import FloorDiv") def generate(self, meta: dict[str, int]) -> None: raise NotImplementedError @@ -2583,7 +2585,9 @@ class GridExpr: if isinstance(numel, int) and isinstance(block, int): return ceildiv(numel, block) # constant fold if self.mode == "python": - return f"-(({numel}) // -({block}))" + # Use FloorDiv instead of // so we can get better sympy expressions for + # dynamic shapes. + return f"-FloorDiv(({numel}), -({block}))" # trick above doesn't work in C++ due to rounding differences return f"(({numel} + ({block} - 1)) / ({block}))" @@ -2666,12 +2670,16 @@ class Grid3D(GridExpr): class Grid2DWithYZOverflow(GridExpr): def generate(self, meta: dict[str, int]) -> None: self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) - self.prefix = [ - self.assign_tmp("y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))), - self.assign_tmp( - "y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid()) - ), - ] + self.prefix.extend( + [ + self.assign_tmp( + "y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK")) + ), + self.assign_tmp( + "y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid()) + ), + ] + ) self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_") self.z_grid = "y_grid_div_" diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 051d60a02019..59777b2dc001 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -436,6 +436,23 @@ def convert_shape_to_inductor( return [sympy.sympify(i) for i in lst] +def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]: + """ + Like convert_shape_to_symint, but operates on a single expression. + """ + from .virtualized import V + + return ( + i + if isinstance(i, int) + else ( + int(i) + if isinstance(i, sympy.Integer) + else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) + ) + ) + + def convert_shape_to_symint( lst: Iterable[Union[int, sympy.Expr]], ) -> list[Union[int, torch.SymInt]]: @@ -443,20 +460,7 @@ def convert_shape_to_symint( Takes a list of shapes from Inductor and converts them into symints (or just ints if all shapes are static). """ - from .virtualized import V - - return [ - ( - i - if isinstance(i, int) - else ( - int(i) - if isinstance(i, sympy.Integer) - else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) - ) - ) - for i in lst - ] + return [convert_to_symint(i) for i in lst] def is_view(op: torch._ops.OpOverload) -> bool: