[Inductor] FX backend via Wrapper IR (#146942)

# Sub-PRs

These PRs contain refactors split out from the main one. They should be reviewed and merged first.

- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587

# Feature

This PR has two goals.

## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.

In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.

This PR takes some initial steps toward formalizing Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. The new IR is largely identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor should help encapsulate wrapper codegen: ideally, main compiler files such as `ir.py` no longer need to emit Python/C++ directly, and can instead delegate that to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.
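
To give a sense of the shape of this IR, each wrapper line is a small dataclass which knows how to emit itself for each backend. Below is a simplified sketch based on the `WrapperLine` classes in `torch/_inductor/codegen/wrapper.py`; the method bodies are abbreviated and the `code`/`converter` arguments are left untyped for brevity.

```
import dataclasses


@dataclasses.dataclass
class WrapperLine:
    """Base class for Wrapper IR ops (simplified sketch)."""

    def codegen(self, code) -> None:
        # Emit Python/C++ source for this op into an IndentedBuffer.
        raise NotImplementedError

    def codegen_fx(self, converter):
        # Return the FxConverter handler which lowers this op to FX nodes.
        raise NotImplementedError(f"FX codegen not yet supported for {type(self)}")


@dataclasses.dataclass
class CommentLine(WrapperLine):
    line: str

    def codegen(self, code) -> None:
        code.writeline(self.line)

    def codegen_fx(self, converter):
        # Comments have no FX equivalent; the converter simply drops them.
        return converter._generate_comment
```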

## Goal 2: Convert Wrapper IR into FX IR.

One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc.

It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.

The typical consumers of this feature would likely not be part of PyTorch, and may or may not support running kernels in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph describing the code Inductor wants to run, which should be easier to analyze and transform in a third-party system than Python or C++ source.

Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.
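
Concretely, the converter walks the list of Wrapper IR lines and dispatches each one to a handler which appends the corresponding FX nodes. The loop below is excerpted and lightly simplified from `FxConverter.generate` in `torch/_inductor/codegen/wrapper_fxir.py`:

```
def generate(self) -> torch.fx.GraphModule:
    self._generate_graph_inputs()  # graph inputs become FX placeholders
    for line in self.lines:
        if isinstance(line, WrapperLine):
            # Each Wrapper IR line returns its FX handler, which appends
            # nodes to self.gm.graph.
            line.codegen_fx(self)(line)
        else:
            raise NotImplementedError(
                "FX conversion only supports Wrapper IR lines."
            )
    self._generate_output()  # emit the FX output node
    self.gm.recompile()
    return self.gm
```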

# Current status

Things that seem to work:

- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
     - Handled the following cases, in addition to what was already in the Memory Planning IR:
         - Comments
         - Triton kernels
         - Extern/fallback kernels
         - Freeing tensors (`del buf0`)
         - MultiOutput
         - Graph outputs
         - ReinterpretView / StorageBox, for both call args and outputs.
     - FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases.
   - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
   - Calling wrapped Triton kernels.
   - Calling extern kernels and certain types of fallback kernels.
       - Support both `extern_kernels.*` and `aten.*`.
       - Support multi-output kernels like `torch.topk`.
   - Graphs with multiple inputs/outputs.
   - Training, i.e. calling `Tensor.backward()` in a compiled function.
   - Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.
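
For reference, the new unit tests exercise the FX backend roughly as follows: register the FX wrapper codegen for a device, then compile and run as usual; the resulting `GraphModule` can also be called directly to check correctness against eager mode. A minimal sketch (assuming a CUDA device):

```
import torch
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

# Route Inductor's wrapper codegen for CUDA through the FX backend.
register_backend_for_device("cuda", TritonScheduling, WrapperFxCodegen)

args = [torch.randn(8, device="cuda") for _ in range(2)]
compiled = torch.compile(torch.add)
out = compiled(*args)
```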

Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct, but probably not rich enough to achieve the goals outlined in the previous sections.
    - Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
    - Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
    - Comm buffers and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer it.

# Out-of-tree compilers

With this PR, out-of-tree backends can do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding its `compile_graph` method. This follows the same API as `torch.compile`'s [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html): the user simply returns a callable which runs the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See the example below.

```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

class MyCustomBackend(WrapperFxCodegen):
    def compile_graph(self, gm):
        # Add 1 to each of the graph's outputs.
        def compiled_fn(*args):
            return [x + 1 for x in gm(*args)]

        return compiled_fn
```
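
If `compile_graph` is not overridden, the base class simply returns the `GraphModule`'s own `forward` method, i.e. the FX graph is interpreted directly and its kernels run eagerly. From `WrapperFxCodegen` in this PR:

```
def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
    """
    Converts the graph module into a runnable function. The default implementation
    is simply an interpreter calling kernels in eager mode. Derived backends can
    override this to do further compilation.
    """
    return gm.forward
```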

# Example FX graphs

This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.

Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    return (buf0,)
```

Here's a more complicated graph that calls a `torch.addmm` extern kernel.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
    %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    return (buf2,)
```
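
For context on how these calls are produced: the FX converter resolves extern/fallback kernel names such as `extern_kernels.addmm` or `aten.topk.default` by splitting the dotted path and walking attributes. The helper below is an illustrative, self-contained restatement of that lookup (the PR's `_generate_extern_kernel_common` uses a module-globals lookup instead of an explicit dict):

```
import torch
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten


def resolve_kernel(kernel_name: str):
    # kernel_name is e.g. "extern_kernels.addmm" or "aten.topk.default".
    module_name, rest = kernel_name.split(".", 1)
    op = {"extern_kernels": extern_kernels, "aten": aten}[module_name]
    for subname in rest.split("."):
        op = getattr(op, subname)
    return op


assert resolve_kernel("extern_kernels.addmm") is extern_kernels.addmm
assert resolve_kernel("aten.topk.default") is aten.topk.default
```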

Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
    %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
    %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
    return (buf1, buf2)
```

Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
    return (buf0_view_buf0_0,)
```
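
The `as_strided` call reinterprets the existing storage with a new size, stride, and offset; no data is copied. For example:

```
import torch

buf0 = torch.empty_strided((2, 4), (4, 1), dtype=torch.float32)
# View the same storage as a flat tensor of 8 elements, like the
# buf0_view_buf0_0 node above.
flat = torch.as_strided(buf0, (8,), (1,), 0)
assert flat.data_ptr() == buf0.data_ptr()
```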

Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Shape expressions in the generated code can then refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]`, where `node` is the placeholder defining it. This works out in the generated Python code because the placeholder defines a Python variable named `s6`.
```
graph():
    %s6 : [num_users=0] = placeholder[target=s6]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
    return buf0
```
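
A consumer of the graph can recover the dynamic-shape symbols by inspecting the placeholders. A small illustrative helper (not part of this PR):

```
import torch


def collect_shape_symbols(gm: torch.fx.GraphModule) -> dict:
    # Map placeholder names (e.g. "s6") to their example values / size hints.
    return {
        node.name: node.meta.get("val")
        for node in gm.graph.find_nodes(op="placeholder")
    }
```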

Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.
```
graph():
    %s10 : [num_users=0] = placeholder[target=s10]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
    return buf0
```
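
As an aside, the negated floor divisions in these grid expressions (`-(((-s6)//8))`, `-(((s10**2)//(-64)))`) are just ceiling division written in terms of `//`:

```
# ceil(n / b) via floor division: -(n // -b), or equivalently -((-n) // b).
n, b = 129, 64
assert -(n // -b) == 3 == -((-n) // b)  # ceil(129 / 64) == 3
```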

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel
Author: Blaine Burton Rister
Date: 2025-05-05 19:34:49 +00:00
Committed by: PyTorch MergeBot
Commit: a7691140a0 (parent fdadda21b6)
9 changed files with 1251 additions and 49 deletions

View File

@ -0,0 +1,417 @@
# Owner(s): ["module: inductor"]
"""
Test the FX IR backend.
"""
import itertools
import operator
import unittest
from typing import Callable, Optional
import sympy
import torch
import torch._inductor.codegen.common as common
import torch.utils._pytree as pytree
from torch._dynamo.exc import BackendCompilerFailed
from torch._dynamo.utils import same
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._inductor import config
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.cpp import CppScheduling
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
HAS_GPU,
requires_gpu,
TRITON_HAS_CPU,
)
@requires_gpu()
@config.patch(
compile_threads=1,
alignment_asserts=False,
size_asserts=False,
scalar_asserts=False,
nan_asserts=False,
)
class FxirTestCase(InductorTestCase):
device = GPU_TYPE
def _count_ops(self, gm: torch.fx.GraphModule, target: Callable) -> int:
return len(gm.graph.find_nodes(op="call_function", target=target))
def _run_and_capture_graphs(self, opt, args) -> torch.fx.GraphModule:
gms = []
orig_generate = FxConverter.generate
def generate(self) -> torch.fx.GraphModule:
nonlocal gms
gm = orig_generate(self)
gms.append(gm)
return gm
with unittest.mock.patch.object(
torch._inductor.codegen.wrapper_fxir.FxConverter, "generate", generate
):
opt(*args)
return gms
def _compile_and_check(
self,
func,
args,
expected_num_triton_kernels: int = 1,
metadata_only: bool = False,
compile_kwargs: Optional[dict] = None,
):
if compile_kwargs is None:
compile_kwargs = {}
opt = torch.compile(func, **compile_kwargs)
# Get the FX graph from the backend.
gms = self._run_and_capture_graphs(opt, args)
# Check the code for triton kernels.
num_kernels = sum(
self._count_ops(gm, triton_kernel_wrapper_mutation) for gm in gms
)
self.assertEqual(num_kernels, expected_num_triton_kernels)
# Check accuracy.
result = opt(*args)
ref = func(*args)
if metadata_only:
# When we only want to check metadata, fill in zeros for tensor data.
ref, result = tuple(
pytree.tree_map(torch.zeros_like, x) for x in (ref, result)
)
self.assertTrue(same(ref, result))
return gms
@classmethod
def setUpClass(cls):
super().setUpClass()
# Register the FX backend.
register_backend_for_device(cls.device, TritonScheduling, WrapperFxCodegen)
def test_basic(self):
args = [torch.randn(8, device=self.device) for _ in range(2)]
self._compile_and_check(torch.add, args)
def test_multiple_kernels(self):
def foo(x, y):
return x.sum() + y.sum()
args = [torch.randn(length, device=self.device) for length in [517, 1029]]
self._compile_and_check(foo, args, expected_num_triton_kernels=2)
def test_free(self):
"""
Test a program that frees a buffer which is no longer in use.
"""
def foo(x, y, z):
w = x.sum() + y
return z.sum() + w.sum()
args = [torch.randn(length, device=self.device) for length in [517, 1029, 123]]
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=3)
# Check the generated code for frees.
num_frees = gm.code.count("= None")
self.assertGreater(num_frees, 0)
def test_extern(self):
"""
Test a program that calls an extern kernel.
"""
def foo(x, y):
return x @ y + y.sum()
args = [
torch.randn(size, device=self.device) for size in [(129, 129), (129, 1)]
]
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
# Check for the extern kernel
num_extern = self._count_ops(gm, extern_kernels.addmm)
self.assertEqual(num_extern, 1)
def test_fallback(self):
"""
Test a program that calls an aten fallback.
"""
length = 8
def foo(x):
return x + torch.randn(1, device=self.device)
args = (torch.randn(length, device=self.device),)
# Since the program has a random output, just check metadata.
# Don't check for an exact value.
(gm,) = self._compile_and_check(
foo, args, expected_num_triton_kernels=2, metadata_only=True
)
# Check for the fallback kernel.
num_fallback = self._count_ops(gm, torch.ops.aten.randint.low_out)
self.assertEqual(num_fallback, 1)
def test_cat_inputs(self):
"""
Test concatenation of graph inputs.
"""
def foo(x, y):
return torch.cat((x, y)) + 1
args = [torch.randn(8, device=self.device) for _ in range(2)]
self._compile_and_check(foo, args, expected_num_triton_kernels=1)
def test_cat_to_alloc(self):
"""
Test concatenation that's optimized out to an allocation.
"""
length = 8
def foo(x):
y, z = tuple(
torch.arange(length // 2, device=self.device) for _ in range(2)
)
return x + torch.cat((y, z))
args = [torch.randn(length, device=self.device)]
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
# Expect a single allocation, even though eager mode would use 2.
num_allocs = self._count_ops(gm, torch.empty_strided)
self.assertEqual(num_allocs, 1)
def test_cat_reinterpret_view(self):
"""
Test torch.cat using ReinterpretView.
"""
length = 8
def foo(x):
y, z = tuple(torch.randn(length // 2, device=self.device) for _ in range(2))
return x + torch.cat((y, z))
args = [torch.randn(length, device=self.device)]
# Since this test generates random numbers, check metadata only.
(gm,) = self._compile_and_check(
foo, args, expected_num_triton_kernels=3, metadata_only=True
)
# Check for as_strided. We map ReinterpretView to this.
num_as_strided = self._count_ops(gm, torch.as_strided)
self.assertEqual(num_as_strided, 2)
def test_reshape_output(self):
"""
Test reshaping the output, which maps to a ReinterpretView.
"""
def foo(x, y):
return torch.reshape(x + y, (8,))
args = [torch.randn((2, 4), device=self.device) for _ in range(2)]
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
# Check for as_strided. We map ReinterpretView to this.
num_as_strided = self._count_ops(gm, torch.as_strided)
self.assertEqual(num_as_strided, 1)
def test_extern_multi_output(self):
"""
Test an extern kernel with multiple outputs.
Also test a graph with multiple outputs.
"""
def foo(x):
top, idx = torch.topk(x, 2)
return top + 1, idx * 2
args = [torch.randn(8, device=self.device)]
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2)
# Check for multiple kernel outputs via getitems.
num_getitems = self._count_ops(gm, operator.getitem)
self.assertEqual(num_getitems, 2)
# Check for multiple graph outputs.
output_node = gm.graph.find_nodes(op="output")[0]
self.assertEqual(len(output_node.args[0]), 2)
def test_duplicate_input(self):
"""
Test duplicated inputs. This will collapse into a single input in the GM.
"""
args = [torch.randn(4, device=self.device)] * 2
(gm,) = self._compile_and_check(torch.add, args, expected_num_triton_kernels=1)
num_placeholders = len(gm.graph.find_nodes(op="placeholder"))
self.assertEqual(num_placeholders, 1)
def test_backward(self):
"""
Test a program with a backward pass.
"""
x = torch.ones(5, device=self.device) # input tensor
y = torch.zeros(3, device=self.device) # expected output
w = torch.randn(5, 3, requires_grad=True, device=self.device)
b = torch.randn(3, requires_grad=True, device=self.device)
def foo(x, y):
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()
return w.grad, b.grad
# Expect separate forward and backward graphs.
(forward_gm, backward_gm) = self._compile_and_check(
foo, (x, y), expected_num_triton_kernels=3
)
def test_custom_compiler(self):
"""
Test a derived backend with a custom compiler.
"""
offset = 1
class CustomWrapperCodegen(WrapperFxCodegen):
def compile_graph(self, gm):
def compiled_fn(*args):
# Adds an offset to the program's outputs.
outputs = gm(*args)
return pytree.tree_map(lambda x: x + 1, outputs)
return compiled_fn
args = [torch.randn(8, device=self.device) for _ in range(2)]
custom_backend = common.DeviceCodegen(
TritonScheduling, CustomWrapperCodegen, None
)
with unittest.mock.patch.dict(
common.device_codegens, {self.device: custom_backend}
):
func = torch.add
opt = torch.compile(func)
result = opt(*args)
# Check the output is offset from eager mode.
ref = func(*args)
self.assertFalse(same(result, ref))
self.assertNotEqual(offset, 0)
self.assertTrue(same(result - offset, ref))
def test_dynamic_shapes_and_strides(self):
"""
Test a graph with dynamic shapes and strides.
"""
static_dims = (8, 8)
def get_input():
full_size = (16, 8)
full = torch.randn(full_size, device=self.device)
view = torch.as_strided(full, static_dims, full.stride())
return view
func = torch.add
args = [get_input() for _ in range(2)]
(gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})
# Check for a symbolic output shape.
(empty_strided,) = gm.graph.find_nodes(
op="call_function", target=torch.empty_strided
)
example_tensor = empty_strided.meta["val"]
symbolic_dims = example_tensor.shape
self.assertEqual(len(symbolic_dims), len(static_dims))
# Check for symbolic output strides.
(stride, one) = example_tensor.stride()
self.assertEqual(one, sympy.S.One)
# Find the size symbols, and check for a corresponding placeholders defining them.
for symbol in itertools.chain(symbolic_dims, [stride]):
self.assertTrue(isinstance(symbol, torch.SymInt))
(placeholder,) = [
node
for node in gm.graph.find_nodes(op="placeholder")
if node.name == str(symbol)
]
self.assertEqual(placeholder.meta["val"], symbol)
@config.patch({"trace.enabled": True})
@unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code")
def test_debug(self, mock_output_code):
# Compile in debug mode.
args = [torch.randn(11, device=self.device) for _ in range(2)]
self._compile_and_check(torch.sub, args)
# Check the output code for a Triton kernel call.
mock_output_code.assert_called_once()
(output_filename,) = mock_output_code.call_args.args
with open(output_filename) as f:
output_code = f.read()
self.assertIn("triton_kernel_wrapper_mutation", output_code)
@torch._inductor.config.patch("graph_partition", True)
def test_subgraph_raises(self):
"""
Test a model with subgraphs. This is not yet supported, so check that we get the
expected exception.
"""
def foo(cond, x):
return torch.cond(cond, torch.cos, torch.sin, [x])
cond = torch.tensor([True], device=self.device)
x = torch.ones([2, 3], device=self.device)
with self.assertRaisesRegex(BackendCompilerFailed, "Subgraph"):
self._compile_and_check(foo, [cond, x])
def test_cpp_raises(self):
"""
Test the C++ CPU backend. C++ kernels are not yet supported, so for now check
that we get the expected exception.
"""
def foo(x, y):
return x + y * 5
device = torch.device("cpu")
args = [torch.randn(5, device=device) for _ in range(2)]
cpp_backend = common.DeviceCodegen(CppScheduling, WrapperFxCodegen, None)
with unittest.mock.patch.dict(
common.device_codegens, {device.type: cpp_backend}
), self.assertRaisesRegex(BackendCompilerFailed, "Triton"):
self._compile_and_check(foo, args)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_GPU or TRITON_HAS_CPU:
run_tests(needs="filelock")

View File

@ -1750,6 +1750,27 @@ class TracingTritonHOPifier(TritonHOPifier):
# normalize to tuple
return tuple(grid)
def store_non_graphable_args(
self,
combined_args: dict[str, Any],
) -> tuple[dict, int]:
"""
Some args cannot be stored in the FX graph.
Put them in the side table.
"""
def is_graphable(val: Any) -> bool:
return isinstance(val, (fx.node.base_types, fx.Node))
non_graphable_args = {
k: v for k, v in combined_args.items() if not is_graphable(v)
}
graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}
constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)
return graphable_args, constant_args_idx
def call_HOP(
self,
variable: "TraceableTritonKernelWrapper",
@ -1760,15 +1781,8 @@ class TracingTritonHOPifier(TritonHOPifier):
assert tx is None
assert isinstance(variable, TraceableTritonKernelWrapper)
def is_graphable(val: Any) -> bool:
return isinstance(val, fx.node.base_types)
graphable_args, constant_args_idx = self.store_non_graphable_args(combined_args)
non_graphable_args = {
k: v for k, v in combined_args.items() if not is_graphable(v)
}
graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}
constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)
assert isinstance(variable.kernel_idx, int)
return triton_kernel_wrapper_mutation(
kernel_idx=variable.kernel_idx,

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import atexit
import contextlib
import dataclasses
import enum
@ -8,8 +9,11 @@ import itertools
import logging
import math
import operator
import os
import re
import tempfile
import typing
from abc import ABC, abstractmethod
from enum import auto, Enum
from itertools import chain
from typing import (
@ -60,6 +64,8 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
if TYPE_CHECKING:
from collections.abc import Iterator, MutableMapping, Sequence
from torch.fx import GraphModule
from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
from ..loop_body import LoopBody
from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
@ -83,6 +89,38 @@ def data_type_logger(msg: str) -> None:
schedule_log.debug("Data type propagation: %s", msg)
@dataclasses.dataclass
class FileBackedGraphModule:
"""
Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these
map back to a GraphModule instead of Python source.
"""
gm: GraphModule
compiled_fn: Callable[..., Any]
def __post_init__(self) -> None:
# Write the code to a file for compatibility with debugging utilities.
# The file is deleted upon program termination.
self.tempfile = tempfile.NamedTemporaryFile(
mode="w+", suffix=".py", delete=False
)
atexit.register(os.remove, self.tempfile.name)
with self.tempfile as f:
f.write(self.value)
@property
def __file__(self) -> str:
return self.tempfile.name
def call(self, args: list[Any]) -> Any:
return self.compiled_fn(*args)
@property
def value(self) -> str:
return self.gm.code
class WorkspaceZeroMode(enum.Enum):
UNINITIALIZED = 0
ZERO_ON_CALL = 1 # kernel may leave workspace dirty
@ -103,8 +141,22 @@ class WorkspaceZeroMode(enum.Enum):
return WorkspaceZeroMode.UNINITIALIZED
class CodegenSymbol(ABC):
"""
An IR object possibly corresponding to a variable in the wrapper code.
"""
@abstractmethod
def get_name(self) -> str:
pass
@abstractmethod
def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
pass
@ir_dataclass(frozen=True)
class WorkspaceArg:
class WorkspaceArg(CodegenSymbol):
"""A temporary buffer used for a single kernel, then discarded.
Not registered as a traditional buffer since there are no users,
@ -167,6 +219,9 @@ class WorkspaceArg:
def get_dtype(self) -> torch.dtype:
return self.dtype
def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
return self.get_layout().get_example()
def get_layout(self) -> FixedLayout:
from ..ir import FixedLayout
@ -185,6 +240,9 @@ class WorkspaceArg:
maybe_get_output_spec = get_layout
maybe_get_layout = get_layout
def get_offset(self) -> sympy.Expr:
return sympy.S.Zero
def get_size(self) -> list[sympy.Expr]:
return [self.count]

View File

@ -74,6 +74,7 @@ if TYPE_CHECKING:
import triton
from ..graph import GraphLowering
from .wrapper_fxir import FxConverter
log = logging.getLogger(__name__)
@ -83,6 +84,7 @@ pexpr = PythonPrinter().doprint
ReuseKey = tuple[torch.device, torch.dtype, str, bool]
BufferLike = Union[ir.Buffer, WorkspaceArg]
FxConversionFunc = Callable[["WrapperLine"], None]
def buffer_reuse_key(node: BufferLike) -> ReuseKey:
@ -349,7 +351,8 @@ class MemoryPlanningState:
class WrapperLine:
pass
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
raise NotImplementedError("FX codegen not yet supported for type {type(self)}")
@dataclasses.dataclass
@ -364,6 +367,9 @@ class EnterSubgraphLine(WrapperLine):
self.wrapper.push_codegened_graph(self.graph)
code.do_indent()
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_enter_subgraph
@dataclasses.dataclass
class CommentLine(WrapperLine):
@ -372,6 +378,10 @@ class CommentLine(WrapperLine):
def codegen(self, code: IndentedBuffer) -> None:
code.writeline(self.line)
@staticmethod
def codegen_fx(converter: FxConverter) -> FxConversionFunc:
return converter._generate_comment
@dataclasses.dataclass
class ExitSubgraphLine(WrapperLine):
@ -384,6 +394,9 @@ class ExitSubgraphLine(WrapperLine):
self.wrapper.pop_codegened_graph()
code.do_unindent()
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_exit_subgraph
@dataclasses.dataclass
class EnterDeviceContextManagerLine(WrapperLine):
@ -419,12 +432,18 @@ class EnterDeviceContextManagerLine(WrapperLine):
code.do_indent()
code.writeline(V.graph.device_ops.set_device(self.device_idx))
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_enter_device_context_manager
class ExitDeviceContextManagerLine(WrapperLine):
def codegen(self, code: IndentedBuffer) -> None:
if not V.graph.cpp_wrapper:
code.do_unindent()
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_exit_device_context_manager
@dataclasses.dataclass
class ExternKernelAllocLine(WrapperLine):
@ -436,6 +455,9 @@ class ExternKernelAllocLine(WrapperLine):
args = [*node.codegen_args(), *node.codegen_kwargs()]
self.wrapper._generate_extern_kernel_alloc_helper(self.node, args)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_extern_kernel_alloc
@dataclasses.dataclass
class ExternKernelOutLine(WrapperLine):
@ -466,6 +488,9 @@ class ExternKernelOutLine(WrapperLine):
device,
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_extern_kernel_out
@dataclasses.dataclass
class FreeLine(WrapperLine):
@ -476,6 +501,9 @@ class FreeLine(WrapperLine):
assert self.node.get_name() not in V.graph.removed_buffers
code.writeline(self.wrapper.make_buffer_free(self.node))
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_free
@dataclasses.dataclass
class KernelCallLine(WrapperLine):
@ -505,6 +533,9 @@ class KernelCallLine(WrapperLine):
original_fxnode_name=self.original_fxnode_name,
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_kernel_call
@dataclasses.dataclass
class KernelDefinitionLine(WrapperLine):
@ -524,6 +555,9 @@ class KernelDefinitionLine(WrapperLine):
cpp_definition=self.cpp_definition,
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_kernel_definition
@dataclasses.dataclass
class MemoryPlanningLine(WrapperLine):
@ -580,6 +614,9 @@ class AllocateLine(MemoryPlanningLine):
line = self.wrapper.make_buffer_allocation(self.node)
code.writeline(line)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_allocate
@dataclasses.dataclass
class FreeIfNotReusedLine(MemoryPlanningLine):
@ -603,6 +640,9 @@ class FreeIfNotReusedLine(MemoryPlanningLine):
if not self.is_reused:
code.writeline(self.wrapper.make_buffer_free(self.node))
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_free_if_not_reused
@dataclasses.dataclass
class ReinterpretLine(MemoryPlanningLine):
@ -620,6 +660,9 @@ class ReinterpretLine(MemoryPlanningLine):
self.reused_as.get_name(), self.layout.view
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_reinterpret
@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
@ -641,9 +684,13 @@ class ReuseLine(MemoryPlanningLine):
self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old)
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_reuse
class NullLine(MemoryPlanningLine):
pass
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_null
@dataclasses.dataclass
@ -717,6 +764,9 @@ class CommBufferAllocateLine(CommBufferLine):
f"Unsupported comm buffer type: {comm_buffer_type}"
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_comm_buffer_allocate
@dataclasses.dataclass
class CommBufferFreeLine(CommBufferLine):
@ -724,6 +774,9 @@ class CommBufferFreeLine(CommBufferLine):
line = self.wrapper.make_buffer_free(self.node)
code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free")
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_comm_buffer_free
@dataclasses.dataclass
class MultiOutputLine(WrapperLine):
@ -760,6 +813,22 @@ class MultiOutputLine(WrapperLine):
f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}"
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_multi_output
@dataclasses.dataclass
class SymbolicCallArgLine(WrapperLine):
wrapper: PythonWrapperCodegen
arg: SymbolicCallArg
graph: GraphLowering
def codegen(self, code: IndentedBuffer) -> None:
self.wrapper._generate_symbolic_call_arg_helper(self.arg, self.graph)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_symbolic_call_arg
@dataclasses.dataclass
class SymbolicCallArgLine(WrapperLine):

View File

@ -0,0 +1,596 @@
import dataclasses
import operator
import textwrap
from collections import Counter
from typing import Any, Callable, Optional, Union
import sympy
import torch
from torch._higher_order_ops.triton_kernel_wrap import (
TraceableTritonKernelWrapper,
tracing_triton_hopifier_singleton,
triton_kernel_wrapper_mutation,
)
from torch._inductor.codecache import PyCodeCache
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._inductor.select_algorithm import extern_kernels # noqa: F401
from torch._inductor.virtualized import V
from torch._library.triton import wrap_triton
from torch.fx import GraphModule
from .. import ir
from ..utils import convert_shape_to_symint, convert_to_symint, LineContext
from .common import (
CodegenSymbol,
FileBackedGraphModule,
WorkspaceArg,
WorkspaceZeroMode,
)
from .wrapper import (
AllocateLine,
BufferLike,
CommBufferAllocateLine,
CommBufferFreeLine,
CommentLine,
EnterDeviceContextManagerLine,
EnterSubgraphLine,
ExitDeviceContextManagerLine,
ExitSubgraphLine,
ExternKernelAllocLine,
ExternKernelOutLine,
FreeIfNotReusedLine,
FreeLine,
KernelCallLine,
KernelDefinitionLine,
Line,
MultiOutputLine,
NullLine,
PythonWrapperCodegen,
ReinterpretLine,
ReuseLine,
SymbolicCallArg,
SymbolicCallArgLine,
WrapperLine,
)
aten = torch.ops.aten
@dataclasses.dataclass
class SymbolBuffer(CodegenSymbol):
"""
Represents a sympy.Symbol graph input.
"""
symbol: sympy.Symbol
def get_name(self) -> str:
return str(self.symbol)
def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
return self.symbol
CodegenBuffer = Union[BufferLike, SymbolBuffer]
@dataclasses.dataclass
class TritonKernel:
"""
Stores metadata about Triton kernels for use in FX.
"""
tuner: CachingAutotuner
wrapped: TraceableTritonKernelWrapper
class WrapperFxCodegen(PythonWrapperCodegen):
"""
Backend to generate wrapper code as an FX IR graph.
"""
supports_caching = False
def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]:
self.run_wrapper_ir_passes(is_inference)
prologue = "\n".join(
[
self.imports.getvalue(),
self.header.getvalue(),
]
)
gm = FxConverter(lines=self.lines, prologue=prologue).generate()
compiled_fn = self.compile_graph(gm)
return FileBackedGraphModule(gm, compiled_fn), None
def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
"""
Converts the graph module into a runnable function. The default implementation
is simply an interpreter calling kernels in eager mode. Derived backends can
override this to do further compilation.
"""
return gm.forward
@classmethod
def create(
cls,
is_subgraph: bool,
subgraph_name: Optional[str],
parent_wrapper: Optional[PythonWrapperCodegen],
partition_signatures: Optional[ir.GraphPartitionSignature] = None,
) -> "WrapperFxCodegen":
if is_subgraph:
raise NotImplementedError(
"Subgraphs are not yet supported by FX conversion"
)
# For derived backends, this could be a subclass.
return cls()
@dataclasses.dataclass
class FxConverter:
"""
Generates FX IR from Wrapper IR. As each instance is only meant to be used once, the
input and output code are stored as attributes.
"""
lines: list[Line]
prologue: str = ""
def __post_init__(self) -> None:
graph = torch.fx.Graph()
self.gm = GraphModule({}, graph) # Wrapper FX IR.
self.buffer_to_node: dict[
Optional[str], torch.fx.Node
] = {} # Symbol table for codegen.
self.kernels: dict[str, TritonKernel] = {} # Table to store Triton kernels.
self._unique_symbol_ids: Counter[str] = Counter()
def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner:
"""
Imports a kernel from source, possibly autotuning block parameters.
"""
module_code = "\n".join([self.prologue, code])
mod = PyCodeCache.load(module_code)
kernel = getattr(mod, kernel_name)
if not isinstance(kernel, CachingAutotuner):
raise NotImplementedError(
textwrap.dedent(f"""
Unsupported type for kernel {kernel_name}: {type(kernel)}.
FX conversion only supports Triton kernels.
""")
)
return kernel
def _fake_tensor(
self,
size: tuple[Any, ...],
stride: tuple[Any, ...],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
with V.fake_mode:
return torch.empty_strided(
convert_shape_to_symint(size),
convert_shape_to_symint(stride),
dtype=dtype,
device=device,
)
def _create_meta_from_buffer(
self, node: torch.fx.Node, buffer: CodegenBuffer
) -> None:
name = buffer.get_name()
assert name
node.name = name
node.meta["val"] = buffer.get_example()
def _record_allocation(self, buffer: CodegenBuffer, node: torch.fx.Node) -> None:
"""
Updates the symbol table to record that an Inductor buffer maps to the result of
an FX node.
"""
assert node not in self.buffer_to_node
self.buffer_to_node[buffer.get_name()] = node
def _free(self, buffer: Union[CodegenBuffer, ir.TorchBindObject]) -> None:
"""
Removes the buffer from the symbol table.
"""
name = buffer.get_name()
del self.buffer_to_node[name]
def _lookup_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
"""
Maps call args back to FX nodes.
"""
return tuple(
self.buffer_to_node[arg]
if isinstance(arg, str)
else arg.inner_expr
if isinstance(arg, SymbolicCallArg)
else arg
for arg in args
)
def _get_buffer(self, node: ir.IRNode) -> CodegenBuffer:
"""
Extract buffer data from an IR node.
"""
if isinstance(node, (ir.Buffer, WorkspaceArg)):
return node
elif isinstance(node, (ir.BaseView, ir.MutableBox)):
return self._get_buffer(node.data)
elif isinstance(node, sympy.Symbol):
return SymbolBuffer(node)
else:
raise NotImplementedError(f"Unable to extract buffer from node: {node}")
def _generate_graph_inputs(self) -> None:
"""
Converts graph inputs to FX placeholders.
"""
for ir_node in V.graph.graph_inputs.values():
buffer = self._get_buffer(ir_node)
node = self.gm.graph.placeholder(buffer.get_name())
self._create_meta_from_buffer(node, buffer)
self._record_allocation(buffer, node)
def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]:
"""
Generates FX IR for transformations on a buffer, such as ReinterpretView.
Does nothing if no such transformations are present.
"""
def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]:
if isinstance(node, (ir.Buffer, WorkspaceArg)):
return node
elif isinstance(node, ir.NoneAsConstantBuffer):
return None
elif isinstance(node, ir.StorageBox):
return generate_to_buffer(node.data)
elif isinstance(node, ir.ReinterpretView):
# We need to introduce a new symbol if the output is a ReinterpretView.
# Use a WorkspaceArg for this.
buffer = self._get_buffer(node.data)
assert isinstance(buffer, (ir.Buffer, WorkspaceArg))
unique_name = self.gm.graph._graph_namespace.create_name(
f"{buffer.get_name()}_view", None
)
device = buffer.get_device()
assert device
reused_as = WorkspaceArg(
count=buffer.get_size(),
zero_mode=WorkspaceZeroMode.UNINITIALIZED,
device=device,
outer_name=unique_name,
dtype=buffer.get_dtype(),
)
# Generate FX IR for the view.
self._generate_reinterpret_helper(buffer, reused_as, node.layout)
return reused_as
else:
raise NotImplementedError(f"Unrecognized buffer/view node: {node}")
buffer = generate_to_buffer(node)
return self.buffer_to_node[buffer.get_name()] if buffer is not None else None
def _generate_output(self) -> None:
"""
Generate FX IR for graph outputs.
"""
output_nodes = [
self._generate_buffer(node)
for idx, node in enumerate(V.graph.graph_outputs)
]
# Single return elements don't use a tuple.
output_value = output_nodes[0] if len(output_nodes) == 1 else output_nodes
self.gm.graph.output(output_value)
def generate(self) -> torch.fx.GraphModule:
"""
Main entrypoint for FX codegen.
"""
self._generate_graph_inputs()
# Generate FX IR from Wrapper IR lines.
for line in self.lines:
if isinstance(line, WrapperLine):
line.codegen_fx(self)(line)
elif isinstance(line, LineContext):
# Ignore line context in FX IR.
pass
else:
raise NotImplementedError(
textwrap.dedent(
f"""
Found line of unrecognized type '{type(line)}':
'{line}'
FX conversion only supports Wrapper IR lines.
"""
)
)
self._generate_output()
self.gm.recompile()
return self.gm
def _generate_allocate(self, line: WrapperLine) -> None:
assert isinstance(line, AllocateLine)
buffer = line.node
name = buffer.get_name()
assert name not in V.graph.removed_buffers
device = buffer.get_device()
dtype = buffer.get_dtype()
shape = convert_shape_to_symint(buffer.get_size())
stride = convert_shape_to_symint(buffer.get_stride())
node = self.gm.graph.call_function(
torch.empty_strided,
args=(shape, stride),
kwargs={"dtype": dtype, "device": device},
)
assert name
node.name = name
self._create_meta_from_buffer(node, buffer)
self._record_allocation(buffer, node)
def _generate_comment(self, line: WrapperLine) -> None:
assert isinstance(line, CommentLine)
# We ignore comments in FX IR.
def _generate_enter_device_context_manager(self, line: WrapperLine) -> None:
assert isinstance(line, EnterDeviceContextManagerLine)
# We ignore the device context in FX IR.
def _generate_exit_device_context_manager(self, line: WrapperLine) -> None:
assert isinstance(line, ExitDeviceContextManagerLine)
# We ignore the device context in FX IR.
def _generate_enter_subgraph(self, line: WrapperLine) -> None:
assert isinstance(line, EnterSubgraphLine)
raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
def _generate_exit_subgraph(self, line: WrapperLine) -> None:
assert isinstance(line, ExitSubgraphLine)
raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
def _generate_free(self, line: WrapperLine) -> None:
assert isinstance(line, FreeLine)
buf = line.node
# No need to free placeholders.
if self.buffer_to_node[buf.get_name()].op == "placeholder":
return
self._free(buf)
def _generate_free_if_not_reused(self, line: WrapperLine) -> None:
assert isinstance(line, FreeIfNotReusedLine)
buf = line.node
assert buf.get_name() not in V.graph.removed_buffers
if not line.is_reused:
self._free(buf)
def _generate_line_context(self, line: WrapperLine) -> None:
assert isinstance(line, LineContext)
# We ignore line context in FX IR.
def _generate_reinterpret(self, line: WrapperLine) -> None:
assert isinstance(line, ReinterpretLine)
self._generate_reinterpret_helper(line.node, line.reused_as, line.layout)
def _generate_reinterpret_helper(
self, input_buffer: BufferLike, result_buffer: BufferLike, layout: ir.Layout
) -> None:
input_node = self.buffer_to_node[input_buffer.get_name()]
# Look up output metadata.
name = result_buffer.get_name()
assert name
size = tuple(layout.size)
stride = tuple(layout.stride)
offset = input_buffer.get_offset() + layout.offset
# Map ReinterpretView to as_strided.
result_node = self.gm.graph.call_function(
torch.as_strided, args=(input_node, size, stride, offset)
)
result_node.name = name
result_node.meta["val"] = layout.get_example()
self._record_allocation(result_buffer, result_node)
def _generate_reuse(self, line: WrapperLine) -> None:
assert isinstance(line, ReuseLine)
old = line.node
new = line.reused_as
assert not any(buf.get_name() in V.graph.removed_buffers for buf in (old, new))
assert old.get_dtype() == new.get_dtype()
old_node = self.buffer_to_node[old.get_name()]
result_node = old_node
# Change shape and stride.
size = new.get_size()
stride = new.get_stride()
offset = new.get_offset()
if (
old.get_size() != size
or old.get_stride() != stride
or old.get_offset() != offset
):
result_node = self.gm.graph.call_function(
torch.as_strided, args=(old_node, size, stride, offset)
)
self._create_meta_from_buffer(result_node, new)
self._record_allocation(new, result_node)
# Free the old buffer, if we allocated a new tensor.
if (
old.get_name() not in V.graph.get_output_names()
and line.delete_old
and result_node is not old_node
):
self._free(old)
def _generate_multi_output(self, line: WrapperLine) -> None:
assert isinstance(line, MultiOutputLine)
# Extract the index for tuple access.
inds = line.indices[0][1:]
assert len(inds) == 1, f"Cannot convert {inds} to an index."
idx = inds[0]
arg_node = self.buffer_to_node[line.arg_name]
node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx))
node.meta["val"] = arg_node.meta["val"][idx]
node.name = line.result_name
self.buffer_to_node[line.result_name] = node
def _generate_null(self, line: WrapperLine) -> None:
assert isinstance(line, NullLine)
# Does nothing.
def _generate_comm_buffer_allocate(self, line: WrapperLine) -> None:
assert isinstance(line, CommBufferAllocateLine)
raise NotImplementedError("Comm buffer allocation is not yet supported")
def _generate_comm_buffer_free(self, line: WrapperLine) -> None:
assert isinstance(line, CommBufferFreeLine)
self._free(line.node)
def _generate_triton_call(self, line: WrapperLine) -> None:
assert isinstance(line, KernelCallLine)
# Collect all kwargs, including autotuned block sizes.
call_args = self._lookup_args(line.call_args)
kernel = self.kernels[line.kernel_name]
tuner = kernel.tuner
config = tuner.compile_results[0].config
call_args, grid = tuner._interpret_args_grid(call_args, config)
call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args))
call_kwargs.update(config.kwargs)
# Convert sympy expressions to symints.
for name, val in call_kwargs.items():
if isinstance(val, sympy.Expr):
call_kwargs[name] = convert_to_symint(val)
# Store non-graphable kwargs in the side table.
(
call_kwargs,
constant_args_idx,
) = tracing_triton_hopifier_singleton.store_non_graphable_args(call_kwargs)
self.gm.graph.call_function(
triton_kernel_wrapper_mutation,
kwargs={
"kernel_idx": kernel.wrapped.kernel_idx,
"constant_args_idx": constant_args_idx,
"grid": [convert_shape_to_symint(grid)],
"tma_descriptor_metadata": {},
"kwargs": call_kwargs,
},
)
def _generate_extern_kernel_alloc(self, line: WrapperLine) -> None:
assert isinstance(line, ExternKernelAllocLine)
node = line.node
self._generate_extern_kernel_common(node, node)
def _generate_extern_kernel_out(
self,
line: WrapperLine,
) -> None:
assert isinstance(line, ExternKernelOutLine)
node = line.node
out_node = node.output_view if node.output_view else node
self._generate_extern_kernel_common(node, out_node)
def _generate_extern_kernel_common(
self, kernel: ir.ExternKernel, out_ir_node: ir.IRNode
) -> None:
"""
Generates FX IR from either ExternKernelAlloc or ExternKernelOut.
"""
# Get FX nodes corresponding to the call args.
tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs)
args = tensor_nodes + tuple(kernel.constant_args)
# Get the result buffer.
# Some kernels write to a pre-existing output tensor via the "out" kwarg.
kwargs = kernel.kwargs.copy()
result_buffer: Optional[str] = None
if isinstance(kernel, ir.ExternKernelOut):
kwargs["out"] = self.buffer_to_node[out_ir_node.codegen_reference()]
elif isinstance(kernel.layout, (ir.Layout, ir.MultiOutputLayout)):
result_buffer = kernel.get_name()
elif isinstance(kernel.layout, ir.NoneLayout):
pass
else:
raise NotImplementedError(f"Unrecognized output layout: {kernel.layout}")
# Look up the kernel function from its name.
kernel_name = kernel.get_kernel_name()
module_name, kernel_name = kernel_name.split(".", 1)
op = globals()[module_name] # E.g. extern_kernels, aten, etc.
for subname in kernel_name.split("."):
op = getattr(op, subname) # E.g. extern_kernels.addmm
fx_node = self.gm.graph.call_function(op, args=args, kwargs=kwargs)
# Assign the result to the given name.
if result_buffer:
assert "out" not in kwargs, (
f"Extern kernel '{kernel}' has both result and out kwarg. Expected only one."
)
fx_node.name = result_buffer
self.buffer_to_node[result_buffer] = fx_node
arg_tensors = [
arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
for arg in args
]
# Run the operation to propagate metadata.
fx_node.meta["val"] = op(*arg_tensors, **kwargs)
def _generate_kernel_call(self, line: WrapperLine) -> None:
assert isinstance(line, KernelCallLine)
if not line.triton:
raise NotImplementedError("FX conversion only supports Triton kernels.")
self._generate_triton_call(line)
def _generate_kernel_definition(self, line: WrapperLine) -> None:
assert isinstance(line, KernelDefinitionLine)
# Generate code for the kernel.
kernel_code = PythonWrapperCodegen._format_kernel_definition(
line.kernel_name, line.kernel_body, metadata=line.metadata
)
# Import the module and store the JIT kernel.
tuner = self._import_kernel(kernel_code, line.kernel_name)
wrapped = wrap_triton(tuner.fn)
self.kernels[line.kernel_name] = TritonKernel(tuner, wrapped)
def _generate_symbolic_call_arg(self, line: WrapperLine) -> None:
assert isinstance(line, SymbolicCallArgLine)
# No need for an FX node, as we will pass the arg to kernels via a SymInt.

View File

@ -50,6 +50,7 @@ from . import config, ir, metrics
from .codegen.common import (
BackendFeature,
DeviceOpOverrides,
FileBackedGraphModule,
get_backend_features,
get_device_op_overrides,
get_wrapper_codegen_for_device,
@ -115,9 +116,12 @@ if TYPE_CHECKING:
from torch._higher_order_ops.effects import _EffectType
from torch.fx import GraphModule
from torch.fx.graph import Graph
from .codegen.wrapper import PythonWrapperCodegen
from .scheduler import BaseSchedulerNode
CompiledModule = Union[ModuleType, FileBackedGraphModule]
from torch._inductor.codecache import output_code_log
@ -2224,7 +2228,7 @@ class GraphLowering(torch.fx.Interpreter):
# No-op to be patched for unit tests
save_output_code: Optional[Callable[[str], None]] = None
def compile_to_module(self) -> ModuleType:
def compile_to_module(self) -> CompiledModule:
with dynamo_timed(
"GraphLowering.compile_to_module",
phase_name="code_gen",
@ -2233,14 +2237,41 @@ class GraphLowering(torch.fx.Interpreter):
):
return self._compile_to_module()
def _compile_to_module(self) -> ModuleType:
from .codecache import PyCodeCache
def _compile_to_module(self) -> CompiledModule:
# Currently, if we're here, we don't have to worry about the kernel code, which
# is only available in AOTInductor mode.
wrapper_code, _ = (
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
)
if isinstance(wrapper_code, ValueWithLineMap):
mod = self._compile_to_module_lines(wrapper_code)
elif isinstance(wrapper_code, FileBackedGraphModule):
mod = wrapper_code
else:
raise NotImplementedError(
f"Unrecognized wrapper code type: {type(wrapper_code)}"
)
# Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
# TODO. Revisit this once the logging API is more mature
assert mod.__file__ is not None
log_module_code(mod.__file__)
log.debug("Output code written to: %s", mod.__file__)
output_code_log.info("Output code written to: %s", mod.__file__)
if config.benchmark_kernel:
print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
V.debug.output_code(mod.__file__)
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
return mod
def _compile_to_module_lines(
self, wrapper_code: ValueWithLineMap
) -> CompiledModule:
from .codecache import PyCodeCache
if config.triton.autotune_at_compile_time:
tuning_code = (
'"""\n'
@ -2291,17 +2322,7 @@ class GraphLowering(torch.fx.Interpreter):
if config.benchmark_harness and config.profile_bandwidth_output:
# run the inputs code gen to get the bandwidth info
mod.benchmark_compiled_module(times=1, repeat=1)
# Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
# TODO. Revisit this once the logging API is more mature
assert mod.__file__ is not None
log_module_code(mod.__file__)
log.debug("Output code written to: %s", mod.__file__)
output_code_log.info("Output code written to: %s", mod.__file__)
if config.benchmark_kernel:
print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
V.debug.output_code(mod.__file__)
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
return mod
def get_output_names(self) -> list[str]:

View File

@ -65,6 +65,7 @@ from torch.utils._sympy.symbol import SymT
from . import config, dependencies
from .codegen.common import (
BackendFeature,
CodegenSymbol,
get_scheduling_for_device,
index_prevent_reordering,
)
@ -3423,6 +3424,15 @@ class Layout(OutputSpec):
def get_device(self) -> torch.device:
return self.device
def get_example(self) -> torch.Tensor:
with V.fake_mode:
return torch.empty_strided(
convert_shape_to_symint(self.size),
convert_shape_to_symint(self.stride),
dtype=self.dtype,
device=self.device,
)
def is_contiguous(self) -> bool:
return is_contiguous_strides_for_shape(self.stride, self.size)
@ -3926,7 +3936,7 @@ class MutationLayoutSHOULDREMOVE(Layout):
@ir_dataclass(frozen=False)
class Buffer(IRNode):
class Buffer(IRNode, CodegenSymbol):
# Name is sometimes None; e.g., ForceInPlace, where there isn't
# a meaningful name
name: Optional[str]
@ -3946,6 +3956,11 @@ class Buffer(IRNode):
assert self.name, self
return self.name
def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
if isinstance(self.layout, Layout):
return self.layout.get_example()
raise NotImplementedError(type(self.layout).__name__)
def get_device(self) -> Optional[torch.device]:
return self.get_output_spec().get_device()

View File

@ -85,7 +85,7 @@ class NoTritonConfigsError(RuntimeError):
if TYPE_CHECKING:
from collections.abc import Container, Hashable, Sequence
from collections.abc import Container, Hashable
from torch._guards import CompileId
@ -2564,13 +2564,15 @@ class GridExpr:
inductor_meta: dict[str, Any]
mode: Literal["python", "cpp"] = "python"
prefix: Sequence[str] = ()
prefix: list[str] = dataclasses.field(default_factory=list)
x_grid: Union[str, int] = 1
y_grid: Union[str, int] = 1
z_grid: Union[str, int] = 1
def __post_init__(self) -> None:
assert self.mode in ("python", "cpp")
if self.mode == "python":
self.prefix.append("from torch.utils._sympy.functions import FloorDiv")
def generate(self, meta: dict[str, int]) -> None:
raise NotImplementedError
@ -2583,7 +2585,9 @@ class GridExpr:
if isinstance(numel, int) and isinstance(block, int):
return ceildiv(numel, block) # constant fold
if self.mode == "python":
return f"-(({numel}) // -({block}))"
# Use FloorDiv instead of // so we can get better sympy expressions for
# dynamic shapes.
return f"-FloorDiv(({numel}), -({block}))"
# trick above doesn't work in C++ due to rounding differences
return f"(({numel} + ({block} - 1)) / ({block}))"
@ -2666,12 +2670,16 @@ class Grid3D(GridExpr):
class Grid2DWithYZOverflow(GridExpr):
def generate(self, meta: dict[str, int]) -> None:
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
self.prefix = [
self.assign_tmp("y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))),
self.assign_tmp(
"y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
),
]
self.prefix.extend(
[
self.assign_tmp(
"y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))
),
self.assign_tmp(
"y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
),
]
)
self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_")
self.z_grid = "y_grid_div_"

View File

@ -436,6 +436,23 @@ def convert_shape_to_inductor(
return [sympy.sympify(i) for i in lst]
def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]:
"""
Like convert_shape_to_symint, but operates on a single expression.
"""
from .virtualized import V
return (
i
if isinstance(i, int)
else (
int(i)
if isinstance(i, sympy.Integer)
else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
)
)
def convert_shape_to_symint(
lst: Iterable[Union[int, sympy.Expr]],
) -> list[Union[int, torch.SymInt]]:
@ -443,20 +460,7 @@ def convert_shape_to_symint(
Takes a list of shapes from Inductor and converts them into symints (or just
ints if all shapes are static).
"""
from .virtualized import V
return [
(
i
if isinstance(i, int)
else (
int(i)
if isinstance(i, sympy.Integer)
else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
)
)
for i in lst
]
return [convert_to_symint(i) for i in lst]
def is_view(op: torch._ops.OpOverload) -> bool: