Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Inductor] FX backend via Wrapper IR (#146942)
# Sub-PRs

These PRs contain refactors from the main one. They should be reviewed and merged first.

- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587

# Feature

The goals of this PR are twofold.

## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.

In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).)

Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.

This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.

## Goal 2: Convert Wrapper IR into FX IR.

One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc. It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section.

The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes. The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source. Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.
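To make the consumption model concrete, here is a minimal sketch (not part of this PR, just an illustration) of how a third-party backend might walk the graph Inductor hands it. It relies only on the stable `torch.fx` node API and the call targets that appear in the example graphs later in this description.

```
import operator

import torch
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation


def summarize_inductor_fx(gm: torch.fx.GraphModule) -> None:
    """Print what the Wrapper-IR-derived FX graph asks the backend to do."""
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            print("graph input:", node.name)
        elif node.op == "call_function" and node.target is torch.empty_strided:
            print("allocate:", node.name, node.args, node.kwargs)
        elif node.op == "call_function" and node.target is triton_kernel_wrapper_mutation:
            # The kernel body and non-graphable args live in Dynamo's side tables,
            # keyed by kernel_idx / constant_args_idx.
            print("triton call:", node.kwargs["kernel_idx"], "grid:", node.kwargs["grid"])
        elif node.op == "call_function" and node.target is operator.getitem:
            print("tuple access:", node.name)
        elif node.op == "output":
            print("graph outputs:", node.args[0])
```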
# Current status

Things that seem to work:

- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
- Handled the following cases, in addition to what was already in the Memory Planning IR:
  - Comments
  - Triton kernels
  - Extern/fallback kernels
  - Freeing tensors (`del buf0`)
  - MultiOutput
  - Graph outputs
  - ReinterpretView / StorageBox, for both call args and outputs.
- FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases:
  - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
  - Calling wrapped Triton kernels.
  - Calling extern kernels and certain types of fallback kernels.
    - Support both `extern_kernels.*` and `aten.*`.
    - Support multi-output kernels like `torch.topk`.
  - Graphs with multiple inputs/outputs.
  - Training, i.e. calling `Tensor.backward()` in a compiled function.
  - Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.

Things that don't work:

- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections.
  - Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
- Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
- CommBuffers and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.

# Out-of-tree compilers

With this PR, out-of-tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.

```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

class MyCustomBackend(WrapperFxCodegen):
    def compile_graph(self, gm):
        # Add 1 to the graph's outputs
        def compiled_fn(*args):
            return [x + 1 for x in gm(*args)]
        return compiled_fn
```
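For completeness, here is a sketch (modeled on the new test suite in this PR, not an official recipe) of how such a backend can be wired into Inductor for a device. The tests register the codegen class with `register_backend_for_device` and disable size/alignment asserts via the Inductor config, as noted above; the `"cuda"` device string and the function `f` below are just placeholders for illustration.

```
import torch
from torch._inductor import config
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.triton import TritonScheduling

# Keep Inductor's Triton kernel scheduling, but emit the wrapper as an FX graph
# that is handed to the MyCustomBackend class defined above.
register_backend_for_device("cuda", TritonScheduling, MyCustomBackend)


def f(x, y):
    return x + y


with config.patch(size_asserts=False, alignment_asserts=False):
    compiled = torch.compile(f)
    # On a machine with a GPU:
    # out = compiled(torch.randn(8, device="cuda"), torch.randn(8, device="cuda"))
```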
# Example FX graphs

This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.

Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    return (buf0,)
```

Here's a more complicated graph that calls a `torch.addmm` extern kernel.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
    %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    return (buf2,)
```

Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
    %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
    %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
    return (buf1, buf2)
```
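Spelled out as eager Python, the multi-output pattern in the graph above corresponds roughly to the following sketch (ignoring the two Triton post-processing kernels):

```
import operator

import torch


def topk_wrapper(arg0_1: torch.Tensor):
    # aten.topk returns a (values, indices) tuple. Each MultiOutput line becomes
    # an operator.getitem call that unpacks one element, and the tuple itself is
    # then freed, mirroring the %delete node above.
    buf0 = torch.ops.aten.topk.default(arg0_1, 2)
    buf1 = operator.getitem(buf0, 0)  # values
    buf2 = operator.getitem(buf0, 1)  # indices
    del buf0
    return buf1, buf2
```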
Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
    return (buf0_view_buf0_0,)
```

Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]`, where `node` is the placeholder defining it. This works out in the generated Python code because the placeholder defines a Python variable with the name `s6`.

```
graph():
    %s6 : [num_users=0] = placeholder[target=s6]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
    return buf0
```

Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.

```
graph():
    %s10 : [num_users=0] = placeholder[target=s10]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
    return buf0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: fdadda21b6
Commit: a7691140a0
test/inductor/test_fxir_backend.py (new file, 417 lines)
@ -0,0 +1,417 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
"""
|
||||
Test the FX IR backend.
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import operator
|
||||
import unittest
|
||||
from typing import Callable, Optional
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
import torch._inductor.codegen.common as common
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._dynamo.exc import BackendCompilerFailed
|
||||
from torch._dynamo.utils import same
|
||||
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codegen.common import register_backend_for_device
|
||||
from torch._inductor.codegen.cpp import CppScheduling
|
||||
from torch._inductor.codegen.triton import TritonScheduling
|
||||
from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen
|
||||
from torch._inductor.select_algorithm import extern_kernels
|
||||
from torch._inductor.test_case import TestCase as InductorTestCase
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
GPU_TYPE,
|
||||
HAS_GPU,
|
||||
requires_gpu,
|
||||
TRITON_HAS_CPU,
|
||||
)
|
||||
|
||||
|
||||
@requires_gpu()
|
||||
@config.patch(
|
||||
compile_threads=1,
|
||||
alignment_asserts=False,
|
||||
size_asserts=False,
|
||||
scalar_asserts=False,
|
||||
nan_asserts=False,
|
||||
)
|
||||
class FxirTestCase(InductorTestCase):
|
||||
device = GPU_TYPE
|
||||
|
||||
def _count_ops(self, gm: torch.fx.GraphModule, target: Callable) -> int:
|
||||
return len(gm.graph.find_nodes(op="call_function", target=target))
|
||||
|
||||
def _run_and_capture_graphs(self, opt, args) -> torch.fx.GraphModule:
|
||||
gms = []
|
||||
|
||||
orig_generate = FxConverter.generate
|
||||
|
||||
def generate(self) -> torch.fx.GraphModule:
|
||||
nonlocal gms
|
||||
gm = orig_generate(self)
|
||||
gms.append(gm)
|
||||
return gm
|
||||
|
||||
with unittest.mock.patch.object(
|
||||
torch._inductor.codegen.wrapper_fxir.FxConverter, "generate", generate
|
||||
):
|
||||
opt(*args)
|
||||
|
||||
return gms
|
||||
|
||||
def _compile_and_check(
|
||||
self,
|
||||
func,
|
||||
args,
|
||||
expected_num_triton_kernels: int = 1,
|
||||
metadata_only: bool = False,
|
||||
compile_kwargs: Optional[dict] = None,
|
||||
):
|
||||
if compile_kwargs is None:
|
||||
compile_kwargs = {}
|
||||
|
||||
opt = torch.compile(func, **compile_kwargs)
|
||||
|
||||
# Get the FX graph from the backend.
|
||||
gms = self._run_and_capture_graphs(opt, args)
|
||||
|
||||
# Check the code for triton kernels.
|
||||
num_kernels = sum(
|
||||
self._count_ops(gm, triton_kernel_wrapper_mutation) for gm in gms
|
||||
)
|
||||
self.assertEqual(num_kernels, expected_num_triton_kernels)
|
||||
|
||||
# Check accuracy.
|
||||
result = opt(*args)
|
||||
ref = func(*args)
|
||||
if metadata_only:
|
||||
# When we only want to check metadata, fill in zeros for tensor data.
|
||||
ref, result = tuple(
|
||||
pytree.tree_map(torch.zeros_like, x) for x in (ref, result)
|
||||
)
|
||||
|
||||
self.assertTrue(same(ref, result))
|
||||
|
||||
return gms
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
# Register the FX backend.
|
||||
register_backend_for_device(cls.device, TritonScheduling, WrapperFxCodegen)
|
||||
|
||||
def test_basic(self):
|
||||
args = [torch.randn(8, device=self.device) for _ in range(2)]
|
||||
self._compile_and_check(torch.add, args)
|
||||
|
||||
def test_multiple_kernels(self):
|
||||
def foo(x, y):
|
||||
return x.sum() + y.sum()
|
||||
|
||||
args = [torch.randn(length, device=self.device) for length in [517, 1029]]
|
||||
self._compile_and_check(foo, args, expected_num_triton_kernels=2)
|
||||
|
||||
def test_free(self):
|
||||
"""
|
||||
Test a program that frees a buffer which is no longer in use.
|
||||
"""
|
||||
|
||||
def foo(x, y, z):
|
||||
w = x.sum() + y
|
||||
return z.sum() + w.sum()
|
||||
|
||||
args = [torch.randn(length, device=self.device) for length in [517, 1029, 123]]
|
||||
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=3)
|
||||
|
||||
# Check the generated code for frees.
|
||||
num_frees = gm.code.count("= None")
|
||||
self.assertGreater(num_frees, 0)
|
||||
|
||||
def test_extern(self):
|
||||
"""
|
||||
Test a program that calls an extern kernel.
|
||||
"""
|
||||
|
||||
def foo(x, y):
|
||||
return x @ y + y.sum()
|
||||
|
||||
args = [
|
||||
torch.randn(size, device=self.device) for size in [(129, 129), (129, 1)]
|
||||
]
|
||||
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
|
||||
|
||||
# Check for the extern kernel
|
||||
num_extern = self._count_ops(gm, extern_kernels.addmm)
|
||||
self.assertEqual(num_extern, 1)
|
||||
|
||||
def test_fallback(self):
|
||||
"""
|
||||
Test a program that calls an aten fallback.
|
||||
"""
|
||||
|
||||
length = 8
|
||||
|
||||
def foo(x):
|
||||
return x + torch.randn(1, device=self.device)
|
||||
|
||||
args = (torch.randn(length, device=self.device),)
|
||||
|
||||
# Since the program has a random output, just check metadata.
|
||||
# Don't check for an exact value.
|
||||
(gm,) = self._compile_and_check(
|
||||
foo, args, expected_num_triton_kernels=2, metadata_only=True
|
||||
)
|
||||
|
||||
# Check for the fallback kernel.
|
||||
num_fallback = self._count_ops(gm, torch.ops.aten.randint.low_out)
|
||||
self.assertEqual(num_fallback, 1)
|
||||
|
||||
def test_cat_inputs(self):
|
||||
"""
|
||||
Test concatenation of graph inputs.
|
||||
"""
|
||||
|
||||
def foo(x, y):
|
||||
return torch.cat((x, y)) + 1
|
||||
|
||||
args = [torch.randn(8, device=self.device) for _ in range(2)]
|
||||
self._compile_and_check(foo, args, expected_num_triton_kernels=1)
|
||||
|
||||
def test_cat_to_alloc(self):
|
||||
"""
|
||||
Test concatenation that's optimized out to an allocation.
|
||||
"""
|
||||
length = 8
|
||||
|
||||
def foo(x):
|
||||
y, z = tuple(
|
||||
torch.arange(length // 2, device=self.device) for _ in range(2)
|
||||
)
|
||||
return x + torch.cat((y, z))
|
||||
|
||||
args = [torch.randn(length, device=self.device)]
|
||||
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
|
||||
|
||||
# Expect a single allocation, even though eager mode would use 2.
|
||||
num_allocs = self._count_ops(gm, torch.empty_strided)
|
||||
self.assertEqual(num_allocs, 1)
|
||||
|
||||
def test_cat_reinterpret_view(self):
|
||||
"""
|
||||
Test torch.cat using ReinterpretView.
|
||||
"""
|
||||
length = 8
|
||||
|
||||
def foo(x):
|
||||
y, z = tuple(torch.randn(length // 2, device=self.device) for _ in range(2))
|
||||
return x + torch.cat((y, z))
|
||||
|
||||
args = [torch.randn(length, device=self.device)]
|
||||
|
||||
# Since this test generates random numbers, check metadata only.
|
||||
(gm,) = self._compile_and_check(
|
||||
foo, args, expected_num_triton_kernels=3, metadata_only=True
|
||||
)
|
||||
|
||||
# Check for as_strided. We map ReinterpretView to this.
|
||||
num_as_strided = self._count_ops(gm, torch.as_strided)
|
||||
self.assertEqual(num_as_strided, 2)
|
||||
|
||||
def test_reshape_output(self):
|
||||
"""
|
||||
Test reshaping the output, which maps to a ReinterpretView.
|
||||
"""
|
||||
|
||||
def foo(x, y):
|
||||
return torch.reshape(x + y, (8,))
|
||||
|
||||
args = [torch.randn((2, 4), device=self.device) for _ in range(2)]
|
||||
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
|
||||
|
||||
# Check for as_strided. We map ReinterpretView to this.
|
||||
num_as_strided = self._count_ops(gm, torch.as_strided)
|
||||
self.assertEqual(num_as_strided, 1)
|
||||
|
||||
def test_extern_multi_output(self):
|
||||
"""
|
||||
Test an extern kernel with multiple outputs.
|
||||
Also test a graph with multiple outputs.
|
||||
"""
|
||||
|
||||
def foo(x):
|
||||
top, idx = torch.topk(x, 2)
|
||||
return top + 1, idx * 2
|
||||
|
||||
args = [torch.randn(8, device=self.device)]
|
||||
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2)
|
||||
|
||||
# Check for multiple kernel outputs via getitems.
|
||||
num_getitems = self._count_ops(gm, operator.getitem)
|
||||
self.assertEqual(num_getitems, 2)
|
||||
|
||||
# Check for multiple graph outputs.
|
||||
output_node = gm.graph.find_nodes(op="output")[0]
|
||||
self.assertEqual(len(output_node.args[0]), 2)
|
||||
|
||||
def test_duplicate_input(self):
|
||||
"""
|
||||
Test duplicated inputs. This will collapse into a single input in the GM.
|
||||
"""
|
||||
|
||||
args = [torch.randn(4, device=self.device)] * 2
|
||||
(gm,) = self._compile_and_check(torch.add, args, expected_num_triton_kernels=1)
|
||||
|
||||
num_placeholders = len(gm.graph.find_nodes(op="placeholder"))
|
||||
self.assertEqual(num_placeholders, 1)
|
||||
|
||||
def test_backward(self):
|
||||
"""
|
||||
Test a program with a backward pass.
|
||||
"""
|
||||
|
||||
x = torch.ones(5, device=self.device) # input tensor
|
||||
y = torch.zeros(3, device=self.device) # expected output
|
||||
w = torch.randn(5, 3, requires_grad=True, device=self.device)
|
||||
b = torch.randn(3, requires_grad=True, device=self.device)
|
||||
|
||||
def foo(x, y):
|
||||
z = torch.matmul(x, w) + b
|
||||
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
|
||||
loss.backward()
|
||||
return w.grad, b.grad
|
||||
|
||||
# Expect separate forward and backward graphs.
|
||||
(forward_gm, backward_gm) = self._compile_and_check(
|
||||
foo, (x, y), expected_num_triton_kernels=3
|
||||
)
|
||||
|
||||
def test_custom_compiler(self):
|
||||
"""
|
||||
Test a derived backend with a custom compiler.
|
||||
"""
|
||||
offset = 1
|
||||
|
||||
class CustomWrapperCodegen(WrapperFxCodegen):
|
||||
def compile_graph(self, gm):
|
||||
def compiled_fn(*args):
|
||||
# Adds an offset to the program's outputs.
|
||||
outputs = gm(*args)
|
||||
return pytree.tree_map(lambda x: x + 1, outputs)
|
||||
|
||||
return compiled_fn
|
||||
|
||||
args = [torch.randn(8, device=self.device) for _ in range(2)]
|
||||
custom_backend = common.DeviceCodegen(
|
||||
TritonScheduling, CustomWrapperCodegen, None
|
||||
)
|
||||
with unittest.mock.patch.dict(
|
||||
common.device_codegens, {self.device: custom_backend}
|
||||
):
|
||||
func = torch.add
|
||||
opt = torch.compile(func)
|
||||
result = opt(*args)
|
||||
|
||||
# Check the output is offset from eager mode.
|
||||
ref = func(*args)
|
||||
self.assertFalse(same(result, ref))
|
||||
self.assertNotEqual(offset, 0)
|
||||
self.assertTrue(same(result - offset, ref))
|
||||
|
||||
def test_dynamic_shapes_and_strides(self):
|
||||
"""
|
||||
Test a graph with dynamic shapes and strides.
|
||||
"""
|
||||
|
||||
static_dims = (8, 8)
|
||||
|
||||
def get_input():
|
||||
full_size = (16, 8)
|
||||
full = torch.randn(full_size, device=self.device)
|
||||
view = torch.as_strided(full, static_dims, full.stride())
|
||||
return view
|
||||
|
||||
func = torch.add
|
||||
args = [get_input() for _ in range(2)]
|
||||
(gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})
|
||||
|
||||
# Check for a symbolic output shape.
|
||||
(empty_strided,) = gm.graph.find_nodes(
|
||||
op="call_function", target=torch.empty_strided
|
||||
)
|
||||
example_tensor = empty_strided.meta["val"]
|
||||
symbolic_dims = example_tensor.shape
|
||||
self.assertEqual(len(symbolic_dims), len(static_dims))
|
||||
|
||||
# Check for symbolic output strides.
|
||||
(stride, one) = example_tensor.stride()
|
||||
self.assertEqual(one, sympy.S.One)
|
||||
|
||||
# Find the size symbols, and check for a corresponding placeholders defining them.
|
||||
for symbol in itertools.chain(symbolic_dims, [stride]):
|
||||
self.assertTrue(isinstance(symbol, torch.SymInt))
|
||||
(placeholder,) = [
|
||||
node
|
||||
for node in gm.graph.find_nodes(op="placeholder")
|
||||
if node.name == str(symbol)
|
||||
]
|
||||
self.assertEqual(placeholder.meta["val"], symbol)
|
||||
|
||||
@config.patch({"trace.enabled": True})
|
||||
@unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code")
|
||||
def test_debug(self, mock_output_code):
|
||||
# Compile in debug mode.
|
||||
args = [torch.randn(11, device=self.device) for _ in range(2)]
|
||||
self._compile_and_check(torch.sub, args)
|
||||
|
||||
# Check the output code for a Triton kernel call.
|
||||
mock_output_code.assert_called_once()
|
||||
(output_filename,) = mock_output_code.call_args.args
|
||||
with open(output_filename) as f:
|
||||
output_code = f.read()
|
||||
self.assertIn("triton_kernel_wrapper_mutation", output_code)
|
||||
|
||||
@torch._inductor.config.patch("graph_partition", True)
|
||||
def test_subgraph_raises(self):
|
||||
"""
|
||||
Test a model with subgraphs. This is not yet supported, so check that we get the
|
||||
expected exception.
|
||||
"""
|
||||
|
||||
def foo(cond, x):
|
||||
return torch.cond(cond, torch.cos, torch.sin, [x])
|
||||
|
||||
cond = torch.tensor([True], device=self.device)
|
||||
x = torch.ones([2, 3], device=self.device)
|
||||
|
||||
with self.assertRaisesRegex(BackendCompilerFailed, "Subgraph"):
|
||||
self._compile_and_check(foo, [cond, x])
|
||||
|
||||
def test_cpp_raises(self):
|
||||
"""
|
||||
Test the C++ CPU backend. C++ kernels are not yet supported, so for now check
|
||||
that we get the expected exception.
|
||||
"""
|
||||
|
||||
def foo(x, y):
|
||||
return x + y * 5
|
||||
|
||||
device = torch.device("cpu")
|
||||
args = [torch.randn(5, device=device) for _ in range(2)]
|
||||
|
||||
cpp_backend = common.DeviceCodegen(CppScheduling, WrapperFxCodegen, None)
|
||||
with unittest.mock.patch.dict(
|
||||
common.device_codegens, {device.type: cpp_backend}
|
||||
), self.assertRaisesRegex(BackendCompilerFailed, "Triton"):
|
||||
self._compile_and_check(foo, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
||||
if HAS_GPU or TRITON_HAS_CPU:
|
||||
run_tests(needs="filelock")
|
@ -1750,6 +1750,27 @@ class TracingTritonHOPifier(TritonHOPifier):
|
||||
# normalize to tuple
|
||||
return tuple(grid)
|
||||
|
||||
def store_non_graphable_args(
|
||||
self,
|
||||
combined_args: dict[str, Any],
|
||||
) -> tuple[dict, int]:
|
||||
"""
|
||||
Some args cannot be stored in the FX graph.
|
||||
Put them in the side table.
|
||||
"""
|
||||
|
||||
def is_graphable(val: Any) -> bool:
|
||||
return isinstance(val, (fx.node.base_types, fx.Node))
|
||||
|
||||
non_graphable_args = {
|
||||
k: v for k, v in combined_args.items() if not is_graphable(v)
|
||||
}
|
||||
graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}
|
||||
|
||||
constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)
|
||||
|
||||
return graphable_args, constant_args_idx
|
||||
|
||||
def call_HOP(
|
||||
self,
|
||||
variable: "TraceableTritonKernelWrapper",
|
||||
@ -1760,15 +1781,8 @@ class TracingTritonHOPifier(TritonHOPifier):
|
||||
assert tx is None
|
||||
assert isinstance(variable, TraceableTritonKernelWrapper)
|
||||
|
||||
def is_graphable(val: Any) -> bool:
|
||||
return isinstance(val, fx.node.base_types)
|
||||
graphable_args, constant_args_idx = self.store_non_graphable_args(combined_args)
|
||||
|
||||
non_graphable_args = {
|
||||
k: v for k, v in combined_args.items() if not is_graphable(v)
|
||||
}
|
||||
graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}
|
||||
|
||||
constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)
|
||||
assert isinstance(variable.kernel_idx, int)
|
||||
return triton_kernel_wrapper_mutation(
|
||||
kernel_idx=variable.kernel_idx,
|
||||
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import enum
|
||||
@ -8,8 +9,11 @@ import itertools
|
||||
import logging
|
||||
import math
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto, Enum
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
@ -60,6 +64,8 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, MutableMapping, Sequence
|
||||
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
|
||||
from ..loop_body import LoopBody
|
||||
from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
|
||||
@ -83,6 +89,38 @@ def data_type_logger(msg: str) -> None:
|
||||
schedule_log.debug("Data type propagation: %s", msg)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FileBackedGraphModule:
|
||||
"""
|
||||
Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these
|
||||
map back to a GraphModule instead of Python source.
|
||||
"""
|
||||
|
||||
gm: GraphModule
|
||||
compiled_fn: Callable[..., Any]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Write the code to a file for compatibility with debugging utilities.
|
||||
# The file is deleted upon program termination.
|
||||
self.tempfile = tempfile.NamedTemporaryFile(
|
||||
mode="w+", suffix=".py", delete=False
|
||||
)
|
||||
atexit.register(os.remove, self.tempfile.name)
|
||||
with self.tempfile as f:
|
||||
f.write(self.value)
|
||||
|
||||
@property
|
||||
def __file__(self) -> str:
|
||||
return self.tempfile.name
|
||||
|
||||
def call(self, args: list[Any]) -> Any:
|
||||
return self.compiled_fn(*args)
|
||||
|
||||
@property
|
||||
def value(self) -> str:
|
||||
return self.gm.code
|
||||
|
||||
|
||||
class WorkspaceZeroMode(enum.Enum):
|
||||
UNINITIALIZED = 0
|
||||
ZERO_ON_CALL = 1 # kernel may leave workspace dirty
|
||||
@ -103,8 +141,22 @@ class WorkspaceZeroMode(enum.Enum):
|
||||
return WorkspaceZeroMode.UNINITIALIZED
|
||||
|
||||
|
||||
class CodegenSymbol(ABC):
|
||||
"""
|
||||
An IR object possibly corresponding to a variable in the wrapper code.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
|
||||
pass
|
||||
|
||||
|
||||
@ir_dataclass(frozen=True)
|
||||
class WorkspaceArg:
|
||||
class WorkspaceArg(CodegenSymbol):
|
||||
"""A temporary buffer used for a single kernel, then discarded.
|
||||
|
||||
Not registered as a traditional buffer since there are no users,
|
||||
@ -167,6 +219,9 @@ class WorkspaceArg:
|
||||
def get_dtype(self) -> torch.dtype:
|
||||
return self.dtype
|
||||
|
||||
def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
|
||||
return self.get_layout().get_example()
|
||||
|
||||
def get_layout(self) -> FixedLayout:
|
||||
from ..ir import FixedLayout
|
||||
|
||||
@ -185,6 +240,9 @@ class WorkspaceArg:
|
||||
maybe_get_output_spec = get_layout
|
||||
maybe_get_layout = get_layout
|
||||
|
||||
def get_offset(self) -> sympy.Expr:
|
||||
return sympy.S.Zero
|
||||
|
||||
def get_size(self) -> list[sympy.Expr]:
|
||||
return [self.count]
|
||||
|
||||
|
@ -74,6 +74,7 @@ if TYPE_CHECKING:
|
||||
import triton
|
||||
|
||||
from ..graph import GraphLowering
|
||||
from .wrapper_fxir import FxConverter
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -83,6 +84,7 @@ pexpr = PythonPrinter().doprint
|
||||
|
||||
ReuseKey = tuple[torch.device, torch.dtype, str, bool]
|
||||
BufferLike = Union[ir.Buffer, WorkspaceArg]
|
||||
FxConversionFunc = Callable[["WrapperLine"], None]
|
||||
|
||||
|
||||
def buffer_reuse_key(node: BufferLike) -> ReuseKey:
|
||||
@ -349,7 +351,8 @@ class MemoryPlanningState:
|
||||
|
||||
|
||||
class WrapperLine:
|
||||
pass
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
raise NotImplementedError("FX codegen not yet supported for type {type(self)}")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -364,6 +367,9 @@ class EnterSubgraphLine(WrapperLine):
|
||||
self.wrapper.push_codegened_graph(self.graph)
|
||||
code.do_indent()
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_enter_subgraph
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CommentLine(WrapperLine):
|
||||
@ -372,6 +378,10 @@ class CommentLine(WrapperLine):
|
||||
def codegen(self, code: IndentedBuffer) -> None:
|
||||
code.writeline(self.line)
|
||||
|
||||
@staticmethod
|
||||
def codegen_fx(converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_comment
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExitSubgraphLine(WrapperLine):
|
||||
@ -384,6 +394,9 @@ class ExitSubgraphLine(WrapperLine):
|
||||
self.wrapper.pop_codegened_graph()
|
||||
code.do_unindent()
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_exit_subgraph
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class EnterDeviceContextManagerLine(WrapperLine):
|
||||
@ -419,12 +432,18 @@ class EnterDeviceContextManagerLine(WrapperLine):
|
||||
code.do_indent()
|
||||
code.writeline(V.graph.device_ops.set_device(self.device_idx))
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_enter_device_context_manager
|
||||
|
||||
|
||||
class ExitDeviceContextManagerLine(WrapperLine):
|
||||
def codegen(self, code: IndentedBuffer) -> None:
|
||||
if not V.graph.cpp_wrapper:
|
||||
code.do_unindent()
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_exit_device_context_manager
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExternKernelAllocLine(WrapperLine):
|
||||
@ -436,6 +455,9 @@ class ExternKernelAllocLine(WrapperLine):
|
||||
args = [*node.codegen_args(), *node.codegen_kwargs()]
|
||||
self.wrapper._generate_extern_kernel_alloc_helper(self.node, args)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_extern_kernel_alloc
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExternKernelOutLine(WrapperLine):
|
||||
@ -466,6 +488,9 @@ class ExternKernelOutLine(WrapperLine):
|
||||
device,
|
||||
)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_extern_kernel_out
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FreeLine(WrapperLine):
|
||||
@ -476,6 +501,9 @@ class FreeLine(WrapperLine):
|
||||
assert self.node.get_name() not in V.graph.removed_buffers
|
||||
code.writeline(self.wrapper.make_buffer_free(self.node))
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_free
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class KernelCallLine(WrapperLine):
|
||||
@ -505,6 +533,9 @@ class KernelCallLine(WrapperLine):
|
||||
original_fxnode_name=self.original_fxnode_name,
|
||||
)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_kernel_call
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class KernelDefinitionLine(WrapperLine):
|
||||
@ -524,6 +555,9 @@ class KernelDefinitionLine(WrapperLine):
|
||||
cpp_definition=self.cpp_definition,
|
||||
)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_kernel_definition
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MemoryPlanningLine(WrapperLine):
|
||||
@ -580,6 +614,9 @@ class AllocateLine(MemoryPlanningLine):
|
||||
line = self.wrapper.make_buffer_allocation(self.node)
|
||||
code.writeline(line)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_allocate
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FreeIfNotReusedLine(MemoryPlanningLine):
|
||||
@ -603,6 +640,9 @@ class FreeIfNotReusedLine(MemoryPlanningLine):
|
||||
if not self.is_reused:
|
||||
code.writeline(self.wrapper.make_buffer_free(self.node))
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_free_if_not_reused
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ReinterpretLine(MemoryPlanningLine):
|
||||
@ -620,6 +660,9 @@ class ReinterpretLine(MemoryPlanningLine):
|
||||
self.reused_as.get_name(), self.layout.view
|
||||
)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_reinterpret
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ReuseLine(MemoryPlanningLine):
|
||||
@ -641,9 +684,13 @@ class ReuseLine(MemoryPlanningLine):
|
||||
self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old)
|
||||
)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_reuse
|
||||
|
||||
|
||||
class NullLine(MemoryPlanningLine):
|
||||
pass
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_null
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -717,6 +764,9 @@ class CommBufferAllocateLine(CommBufferLine):
|
||||
f"Unsupported comm buffer type: {comm_buffer_type}"
|
||||
)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_comm_buffer_allocate
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CommBufferFreeLine(CommBufferLine):
|
||||
@ -724,6 +774,9 @@ class CommBufferFreeLine(CommBufferLine):
|
||||
line = self.wrapper.make_buffer_free(self.node)
|
||||
code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free")
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_comm_buffer_free
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MultiOutputLine(WrapperLine):
|
||||
@ -760,6 +813,22 @@ class MultiOutputLine(WrapperLine):
|
||||
f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}"
|
||||
)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_multi_output
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SymbolicCallArgLine(WrapperLine):
|
||||
wrapper: PythonWrapperCodegen
|
||||
arg: SymbolicCallArg
|
||||
graph: GraphLowering
|
||||
|
||||
def codegen(self, code: IndentedBuffer) -> None:
|
||||
self.wrapper._generate_symbolic_call_arg_helper(self.arg, self.graph)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_symbolic_call_arg
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SymbolicCallArgLine(WrapperLine):
|
||||
|
torch/_inductor/codegen/wrapper_fxir.py (new file, 596 lines)
@ -0,0 +1,596 @@
|
||||
import dataclasses
|
||||
import operator
|
||||
import textwrap
|
||||
from collections import Counter
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.triton_kernel_wrap import (
|
||||
TraceableTritonKernelWrapper,
|
||||
tracing_triton_hopifier_singleton,
|
||||
triton_kernel_wrapper_mutation,
|
||||
)
|
||||
from torch._inductor.codecache import PyCodeCache
|
||||
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
|
||||
from torch._inductor.select_algorithm import extern_kernels # noqa: F401
|
||||
from torch._inductor.virtualized import V
|
||||
from torch._library.triton import wrap_triton
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from .. import ir
|
||||
from ..utils import convert_shape_to_symint, convert_to_symint, LineContext
|
||||
from .common import (
|
||||
CodegenSymbol,
|
||||
FileBackedGraphModule,
|
||||
WorkspaceArg,
|
||||
WorkspaceZeroMode,
|
||||
)
|
||||
from .wrapper import (
|
||||
AllocateLine,
|
||||
BufferLike,
|
||||
CommBufferAllocateLine,
|
||||
CommBufferFreeLine,
|
||||
CommentLine,
|
||||
EnterDeviceContextManagerLine,
|
||||
EnterSubgraphLine,
|
||||
ExitDeviceContextManagerLine,
|
||||
ExitSubgraphLine,
|
||||
ExternKernelAllocLine,
|
||||
ExternKernelOutLine,
|
||||
FreeIfNotReusedLine,
|
||||
FreeLine,
|
||||
KernelCallLine,
|
||||
KernelDefinitionLine,
|
||||
Line,
|
||||
MultiOutputLine,
|
||||
NullLine,
|
||||
PythonWrapperCodegen,
|
||||
ReinterpretLine,
|
||||
ReuseLine,
|
||||
SymbolicCallArg,
|
||||
SymbolicCallArgLine,
|
||||
WrapperLine,
|
||||
)
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SymbolBuffer(CodegenSymbol):
|
||||
"""
|
||||
Represents a sympy.Symbol graph input.
|
||||
"""
|
||||
|
||||
symbol: sympy.Symbol
|
||||
|
||||
def get_name(self) -> str:
|
||||
return str(self.symbol)
|
||||
|
||||
def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
|
||||
return self.symbol
|
||||
|
||||
|
||||
CodegenBuffer = Union[BufferLike, SymbolBuffer]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TritonKernel:
|
||||
"""
|
||||
Stores metadata about Triton kernels for use in FX.
|
||||
"""
|
||||
|
||||
tuner: CachingAutotuner
|
||||
wrapped: TraceableTritonKernelWrapper
|
||||
|
||||
|
||||
class WrapperFxCodegen(PythonWrapperCodegen):
|
||||
"""
|
||||
Backend to generate wrapper code as an FX IR graph.
|
||||
"""
|
||||
|
||||
supports_caching = False
|
||||
|
||||
def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]:
|
||||
self.run_wrapper_ir_passes(is_inference)
|
||||
|
||||
prologue = "\n".join(
|
||||
[
|
||||
self.imports.getvalue(),
|
||||
self.header.getvalue(),
|
||||
]
|
||||
)
|
||||
gm = FxConverter(lines=self.lines, prologue=prologue).generate()
|
||||
compiled_fn = self.compile_graph(gm)
|
||||
|
||||
return FileBackedGraphModule(gm, compiled_fn), None
|
||||
|
||||
def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
|
||||
"""
|
||||
Converts the graph module into a runnable function. The default implementation
|
||||
is simply an interpreter calling kernels in eager mode. Derived backends can
|
||||
override this to do further compilation.
|
||||
"""
|
||||
return gm.forward
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
is_subgraph: bool,
|
||||
subgraph_name: Optional[str],
|
||||
parent_wrapper: Optional[PythonWrapperCodegen],
|
||||
partition_signatures: Optional[ir.GraphPartitionSignature] = None,
|
||||
) -> "WrapperFxCodegen":
|
||||
if is_subgraph:
|
||||
raise NotImplementedError(
|
||||
"Subgraphs are not yet supported by FX conversion"
|
||||
)
|
||||
|
||||
# For derived backends, this could be a subclass.
|
||||
return cls()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FxConverter:
|
||||
"""
|
||||
Generates FX IR from Wrapper IR. As each instance is only meant to be used once, the
|
||||
input and output code are stored as attributes.
|
||||
"""
|
||||
|
||||
lines: list[Line]
|
||||
prologue: str = ""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
graph = torch.fx.Graph()
|
||||
self.gm = GraphModule({}, graph) # Wrapper FX IR.
|
||||
self.buffer_to_node: dict[
|
||||
Optional[str], torch.fx.Node
|
||||
] = {} # Symbol table for codegen.
|
||||
self.kernels: dict[str, TritonKernel] = {} # Table to store Triton kernels.
|
||||
self._unique_symbol_ids: Counter[str] = Counter()
|
||||
|
||||
def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner:
|
||||
"""
|
||||
Imports a kernel from source, possibly autotuning block parameters.
|
||||
"""
|
||||
module_code = "\n".join([self.prologue, code])
|
||||
mod = PyCodeCache.load(module_code)
|
||||
kernel = getattr(mod, kernel_name)
|
||||
|
||||
if not isinstance(kernel, CachingAutotuner):
|
||||
raise NotImplementedError(
|
||||
textwrap.dedent(f"""
|
||||
Unsupported type for kernel {kernel_name}: {type(kernel)}.
|
||||
FX conversion only supports Triton kernels.
|
||||
""")
|
||||
)
|
||||
|
||||
return kernel
|
||||
|
||||
def _fake_tensor(
|
||||
self,
|
||||
size: tuple[Any, ...],
|
||||
stride: tuple[Any, ...],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
with V.fake_mode:
|
||||
return torch.empty_strided(
|
||||
convert_shape_to_symint(size),
|
||||
convert_shape_to_symint(stride),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def _create_meta_from_buffer(
|
||||
self, node: torch.fx.Node, buffer: CodegenBuffer
|
||||
) -> None:
|
||||
name = buffer.get_name()
|
||||
assert name
|
||||
node.name = name
|
||||
node.meta["val"] = buffer.get_example()
|
||||
|
||||
def _record_allocation(self, buffer: CodegenBuffer, node: torch.fx.Node) -> None:
|
||||
"""
|
||||
Updates the symbol table to record that an Inductor buffer maps to the result of
|
||||
an FX node.
|
||||
"""
|
||||
assert node not in self.buffer_to_node
|
||||
self.buffer_to_node[buffer.get_name()] = node
|
||||
|
||||
def _free(self, buffer: Union[CodegenBuffer, ir.TorchBindObject]) -> None:
|
||||
"""
|
||||
Removes the buffer from the symbol table.
|
||||
"""
|
||||
name = buffer.get_name()
|
||||
del self.buffer_to_node[name]
|
||||
|
||||
def _lookup_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||
"""
|
||||
Maps call args back to FX nodes.
|
||||
"""
|
||||
return tuple(
|
||||
self.buffer_to_node[arg]
|
||||
if isinstance(arg, str)
|
||||
else arg.inner_expr
|
||||
if isinstance(arg, SymbolicCallArg)
|
||||
else arg
|
||||
for arg in args
|
||||
)
|
||||
|
||||
def _get_buffer(self, node: ir.IRNode) -> CodegenBuffer:
|
||||
"""
|
||||
Extract buffer data from an IR node.
|
||||
"""
|
||||
if isinstance(node, (ir.Buffer, WorkspaceArg)):
|
||||
return node
|
||||
elif isinstance(node, (ir.BaseView, ir.MutableBox)):
|
||||
return self._get_buffer(node.data)
|
||||
elif isinstance(node, sympy.Symbol):
|
||||
return SymbolBuffer(node)
|
||||
else:
|
||||
raise NotImplementedError(f"Unable to extract buffer from node: {node}")
|
||||
|
||||
def _generate_graph_inputs(self) -> None:
|
||||
"""
|
||||
Converts graph inputs to FX placeholders.
|
||||
"""
|
||||
for ir_node in V.graph.graph_inputs.values():
|
||||
buffer = self._get_buffer(ir_node)
|
||||
node = self.gm.graph.placeholder(buffer.get_name())
|
||||
self._create_meta_from_buffer(node, buffer)
|
||||
self._record_allocation(buffer, node)
|
||||
|
||||
def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]:
|
||||
"""
|
||||
Generates FX IR for transformations on a buffer, such as ReinterpretView.
|
||||
Does nothing if no such transformations are present.
|
||||
"""
|
||||
|
||||
def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]:
|
||||
if isinstance(node, (ir.Buffer, WorkspaceArg)):
|
||||
return node
|
||||
elif isinstance(node, ir.NoneAsConstantBuffer):
|
||||
return None
|
||||
elif isinstance(node, ir.StorageBox):
|
||||
return generate_to_buffer(node.data)
|
||||
elif isinstance(node, ir.ReinterpretView):
|
||||
# We need to introduce a new symbol if the output is a ReinterpretView.
|
||||
# Use a WorkspaceArg for this.
|
||||
buffer = self._get_buffer(node.data)
|
||||
assert isinstance(buffer, (ir.Buffer, WorkspaceArg))
|
||||
unique_name = self.gm.graph._graph_namespace.create_name(
|
||||
f"{buffer.get_name()}_view", None
|
||||
)
|
||||
device = buffer.get_device()
|
||||
assert device
|
||||
reused_as = WorkspaceArg(
|
||||
count=buffer.get_size(),
|
||||
zero_mode=WorkspaceZeroMode.UNINITIALIZED,
|
||||
device=device,
|
||||
outer_name=unique_name,
|
||||
dtype=buffer.get_dtype(),
|
||||
)
|
||||
|
||||
# Generate FX IR for the view.
|
||||
self._generate_reinterpret_helper(buffer, reused_as, node.layout)
|
||||
|
||||
return reused_as
|
||||
else:
|
||||
raise NotImplementedError(f"Unrecognized buffer/view node: {node}")
|
||||
|
||||
buffer = generate_to_buffer(node)
|
||||
return self.buffer_to_node[buffer.get_name()] if buffer is not None else None
|
||||
|
||||
def _generate_output(self) -> None:
|
||||
"""
|
||||
Generate FX IR for graph outputs.
|
||||
"""
|
||||
output_nodes = [
|
||||
self._generate_buffer(node)
|
||||
for idx, node in enumerate(V.graph.graph_outputs)
|
||||
]
|
||||
|
||||
# Single return elements don't use a tuple.
|
||||
output_value = output_nodes[0] if len(output_nodes) == 1 else output_nodes
|
||||
|
||||
self.gm.graph.output(output_value)
|
||||
|
||||
def generate(self) -> torch.fx.GraphModule:
|
||||
"""
|
||||
Main entrypoint for FX codegen.
|
||||
"""
|
||||
self._generate_graph_inputs()
|
||||
|
||||
# Generate FX IR from Wrapper IR lines.
|
||||
for line in self.lines:
|
||||
if isinstance(line, WrapperLine):
|
||||
line.codegen_fx(self)(line)
|
||||
elif isinstance(line, LineContext):
|
||||
# Ignore line context in FX IR.
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
textwrap.dedent(
|
||||
f"""
|
||||
Found line of unrecognized type '{type(line)}':
|
||||
'{line}'
|
||||
|
||||
FX conversion only supports Wrapper IR lines.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
self._generate_output()
|
||||
self.gm.recompile()
|
||||
return self.gm
|
||||
|
||||
def _generate_allocate(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, AllocateLine)
|
||||
buffer = line.node
|
||||
name = buffer.get_name()
|
||||
assert name not in V.graph.removed_buffers
|
||||
|
||||
device = buffer.get_device()
|
||||
dtype = buffer.get_dtype()
|
||||
shape = convert_shape_to_symint(buffer.get_size())
|
||||
stride = convert_shape_to_symint(buffer.get_stride())
|
||||
|
||||
node = self.gm.graph.call_function(
|
||||
torch.empty_strided,
|
||||
args=(shape, stride),
|
||||
kwargs={"dtype": dtype, "device": device},
|
||||
)
|
||||
assert name
|
||||
node.name = name
|
||||
self._create_meta_from_buffer(node, buffer)
|
||||
self._record_allocation(buffer, node)
|
||||
|
||||
def _generate_comment(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, CommentLine)
|
||||
# We ignore comments in FX IR.
|
||||
|
||||
def _generate_enter_device_context_manager(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, EnterDeviceContextManagerLine)
|
||||
# We ignore the device context in FX IR.
|
||||
|
||||
def _generate_exit_device_context_manager(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, ExitDeviceContextManagerLine)
|
||||
# We ignore the device context in FX IR.
|
||||
|
||||
def _generate_enter_subgraph(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, EnterSubgraphLine)
|
||||
raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
|
||||
|
||||
def _generate_exit_subgraph(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, ExitSubgraphLine)
|
||||
raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
|
||||
|
||||
def _generate_free(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, FreeLine)
|
||||
|
||||
buf = line.node
|
||||
|
||||
# No need to free placeholders.
|
||||
if self.buffer_to_node[buf.get_name()].op == "placeholder":
|
||||
return
|
||||
|
||||
self._free(buf)
|
||||
|
||||
def _generate_free_if_not_reused(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, FreeIfNotReusedLine)
|
||||
buf = line.node
|
||||
assert buf.get_name() not in V.graph.removed_buffers
|
||||
if not line.is_reused:
|
||||
self._free(buf)
|
||||
|
||||
def _generate_line_context(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, LineContext)
|
||||
# We ignore line context in FX IR.
|
||||
|
||||
def _generate_reinterpret(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, ReinterpretLine)
|
||||
self._generate_reinterpret_helper(line.node, line.reused_as, line.layout)
|
||||
|
||||
def _generate_reinterpret_helper(
|
||||
self, input_buffer: BufferLike, result_buffer: BufferLike, layout: ir.Layout
|
||||
) -> None:
|
||||
input_node = self.buffer_to_node[input_buffer.get_name()]
|
||||
|
||||
# Look up output metadata.
|
||||
name = result_buffer.get_name()
|
||||
assert name
|
||||
size = tuple(layout.size)
|
||||
stride = tuple(layout.stride)
|
||||
offset = input_buffer.get_offset() + layout.offset
|
||||
|
||||
# Map ReinterpretView to as_strided.
|
||||
result_node = self.gm.graph.call_function(
|
||||
torch.as_strided, args=(input_node, size, stride, offset)
|
||||
)
|
||||
result_node.name = name
|
||||
result_node.meta["val"] = layout.get_example()
|
||||
self._record_allocation(result_buffer, result_node)
|
||||
|
||||
def _generate_reuse(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, ReuseLine)
|
||||
old = line.node
|
||||
new = line.reused_as
|
||||
assert not any(buf.get_name() in V.graph.removed_buffers for buf in (old, new))
|
||||
assert old.get_dtype() == new.get_dtype()
|
||||
|
||||
old_node = self.buffer_to_node[old.get_name()]
|
||||
result_node = old_node
|
||||
|
||||
# Change shape and stride.
|
||||
size = new.get_size()
|
||||
stride = new.get_stride()
|
||||
offset = new.get_offset()
|
||||
if (
|
||||
old.get_size() != size
|
||||
or old.get_stride() != stride
|
||||
or old.get_offset() != offset
|
||||
):
|
||||
result_node = self.gm.graph.call_function(
|
||||
torch.as_strided, args=(old_node, size, stride, offset)
|
||||
)
|
||||
self._create_meta_from_buffer(result_node, new)
|
||||
|
||||
self._record_allocation(new, result_node)
|
||||
|
||||
# Free the old buffer, if we allocated a new tensor.
|
||||
if (
|
||||
old.get_name() not in V.graph.get_output_names()
|
||||
and line.delete_old
|
||||
and result_node is not old_node
|
||||
):
|
||||
self._free(old)
|
||||
|
||||
def _generate_multi_output(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, MultiOutputLine)
|
||||
|
||||
# Extract the index for tuple access.
|
||||
inds = line.indices[0][1:]
|
||||
assert len(inds) == 1, f"Cannot convert {inds} to an index."
|
||||
idx = inds[0]
|
||||
|
||||
arg_node = self.buffer_to_node[line.arg_name]
|
||||
node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx))
|
||||
node.meta["val"] = arg_node.meta["val"][idx]
|
||||
node.name = line.result_name
|
||||
self.buffer_to_node[line.result_name] = node
|
||||
|
||||
def _generate_null(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, NullLine)
|
||||
# Does nothing.
|
||||
|
||||
def _generate_comm_buffer_allocate(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, CommBufferAllocateLine)
|
||||
raise NotImplementedError("Comm buffer allocation is not yet supported")
|
||||
|
||||
def _generate_comm_buffer_free(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, CommBufferFreeLine)
|
||||
self._free(line.node)

    def _generate_triton_call(self, line: WrapperLine) -> None:
        assert isinstance(line, KernelCallLine)

        # Collect all kwargs, including autotuned block sizes.
        call_args = self._lookup_args(line.call_args)
        kernel = self.kernels[line.kernel_name]
        tuner = kernel.tuner
        config = tuner.compile_results[0].config
        call_args, grid = tuner._interpret_args_grid(call_args, config)
        call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args))
        call_kwargs.update(config.kwargs)

        # Convert sympy expressions to symints.
        for name, val in call_kwargs.items():
            if isinstance(val, sympy.Expr):
                call_kwargs[name] = convert_to_symint(val)

        # Store non-graphable kwargs in the side table.
        (
            call_kwargs,
            constant_args_idx,
        ) = tracing_triton_hopifier_singleton.store_non_graphable_args(call_kwargs)

        self.gm.graph.call_function(
            triton_kernel_wrapper_mutation,
            kwargs={
                "kernel_idx": kernel.wrapped.kernel_idx,
                "constant_args_idx": constant_args_idx,
                "grid": [convert_shape_to_symint(grid)],
                "tma_descriptor_metadata": {},
                "kwargs": call_kwargs,
            },
        )
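    # Sketch of the node emitted by _generate_triton_call above (all values are
    # hypothetical): the graph ends up with a call such as
    #   triton_kernel_wrapper_mutation(
    #       kernel_idx=0,
    #       constant_args_idx=1,
    #       grid=[[64, 1, 1]],
    #       tma_descriptor_metadata={},
    #       kwargs={"in_ptr0": buf0, "out_ptr0": buf1, "xnumel": 4096, "XBLOCK": 64},
    #   )
    # so a downstream backend can look up the wrapped Triton kernel by index and
    # launch it with the autotuned config.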

    def _generate_extern_kernel_alloc(self, line: WrapperLine) -> None:
        assert isinstance(line, ExternKernelAllocLine)
        node = line.node
        self._generate_extern_kernel_common(node, node)

    def _generate_extern_kernel_out(
        self,
        line: WrapperLine,
    ) -> None:
        assert isinstance(line, ExternKernelOutLine)
        node = line.node
        out_node = node.output_view if node.output_view else node
        self._generate_extern_kernel_common(node, out_node)

    def _generate_extern_kernel_common(
        self, kernel: ir.ExternKernel, out_ir_node: ir.IRNode
    ) -> None:
        """
        Generates FX IR from either ExternKernelAlloc or ExternKernelOut.
        """

        # Get FX nodes corresponding to the call args.
        tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs)
        args = tensor_nodes + tuple(kernel.constant_args)

        # Get the result buffer.
        # Some kernels write to a pre-existing output tensor via the "out" kwarg.
        kwargs = kernel.kwargs.copy()
        result_buffer: Optional[str] = None
        if isinstance(kernel, ir.ExternKernelOut):
            kwargs["out"] = self.buffer_to_node[out_ir_node.codegen_reference()]
        elif isinstance(kernel.layout, (ir.Layout, ir.MultiOutputLayout)):
            result_buffer = kernel.get_name()
        elif isinstance(kernel.layout, ir.NoneLayout):
            pass
        else:
            raise NotImplementedError(f"Unrecognized output layout: {kernel.layout}")

        # Look up the kernel function from its name.
        kernel_name = kernel.get_kernel_name()
        module_name, kernel_name = kernel_name.split(".", 1)
        op = globals()[module_name]  # E.g. extern_kernels, aten, etc.
        for subname in kernel_name.split("."):
            op = getattr(op, subname)  # E.g. extern_kernels.addmm

        fx_node = self.gm.graph.call_function(op, args=args, kwargs=kwargs)

        # Assign the result to the given name.
        if result_buffer:
            assert "out" not in kwargs, (
                f"Extern kernel '{kernel}' has both result and out kwarg. Expected only one."
            )
            fx_node.name = result_buffer
            self.buffer_to_node[result_buffer] = fx_node

        arg_tensors = [
            arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
            for arg in args
        ]

        # Run the operation to propagate metadata.
        fx_node.meta["val"] = op(*arg_tensors, **kwargs)
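    # Example of the name lookup above: a kernel name like "extern_kernels.addmm"
    # (per the inline comments) splits into module_name="extern_kernels", which is
    # fetched via globals(), and the attribute path "addmm", which the getattr loop
    # resolves to the addmm callable before the call_function node is emitted.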

    def _generate_kernel_call(self, line: WrapperLine) -> None:
        assert isinstance(line, KernelCallLine)
        if not line.triton:
            raise NotImplementedError("FX conversion only supports Triton kernels.")

        self._generate_triton_call(line)

    def _generate_kernel_definition(self, line: WrapperLine) -> None:
        assert isinstance(line, KernelDefinitionLine)

        # Generate code for the kernel.
        kernel_code = PythonWrapperCodegen._format_kernel_definition(
            line.kernel_name, line.kernel_body, metadata=line.metadata
        )

        # Import the module and store the JIT kernel.
        tuner = self._import_kernel(kernel_code, line.kernel_name)
        wrapped = wrap_triton(tuner.fn)
        self.kernels[line.kernel_name] = TritonKernel(tuner, wrapped)

    def _generate_symbolic_call_arg(self, line: WrapperLine) -> None:
        assert isinstance(line, SymbolicCallArgLine)
        # No need for an FX node, as we will pass the arg to kernels via a SymInt.

@@ -50,6 +50,7 @@ from . import config, ir, metrics
from .codegen.common import (
    BackendFeature,
    DeviceOpOverrides,
    FileBackedGraphModule,
    get_backend_features,
    get_device_op_overrides,
    get_wrapper_codegen_for_device,
@@ -115,9 +116,12 @@ if TYPE_CHECKING:
    from torch._higher_order_ops.effects import _EffectType
    from torch.fx import GraphModule
    from torch.fx.graph import Graph

    from .codegen.wrapper import PythonWrapperCodegen
    from .scheduler import BaseSchedulerNode

CompiledModule = Union[ModuleType, FileBackedGraphModule]

from torch._inductor.codecache import output_code_log

@@ -2224,7 +2228,7 @@ class GraphLowering(torch.fx.Interpreter):
    # No-op to be patched for unit tests
    save_output_code: Optional[Callable[[str], None]] = None

    def compile_to_module(self) -> ModuleType:
    def compile_to_module(self) -> CompiledModule:
        with dynamo_timed(
            "GraphLowering.compile_to_module",
            phase_name="code_gen",
@@ -2233,14 +2237,41 @@ class GraphLowering(torch.fx.Interpreter):
        ):
            return self._compile_to_module()

    def _compile_to_module(self) -> ModuleType:
        from .codecache import PyCodeCache

    def _compile_to_module(self) -> CompiledModule:
        # Currently, if we're here, we don't have to worry about the kernel code, which
        # is only available in AOTInductor mode.
        wrapper_code, _ = (
            self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
        )

        if isinstance(wrapper_code, ValueWithLineMap):
            mod = self._compile_to_module_lines(wrapper_code)
        elif isinstance(wrapper_code, FileBackedGraphModule):
            mod = wrapper_code
        else:
            raise NotImplementedError(
                f"Unrecognized wrapper code type: {type(wrapper_code)}"
            )

        # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
        # TODO. Revisit this once the logging API is more mature
        assert mod.__file__ is not None

        log_module_code(mod.__file__)
        log.debug("Output code written to: %s", mod.__file__)
        output_code_log.info("Output code written to: %s", mod.__file__)
        if config.benchmark_kernel:
            print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
        V.debug.output_code(mod.__file__)
        V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")

        return mod

    def _compile_to_module_lines(
        self, wrapper_code: ValueWithLineMap
    ) -> CompiledModule:
        from .codecache import PyCodeCache

        if config.triton.autotune_at_compile_time:
            tuning_code = (
                '"""\n'
@@ -2291,17 +2322,7 @@ class GraphLowering(torch.fx.Interpreter):
        if config.benchmark_harness and config.profile_bandwidth_output:
            # run the inputs code gen to get the bandwidth info
            mod.benchmark_compiled_module(times=1, repeat=1)
        # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
        # TODO. Revisit this once the logging API is more mature
        assert mod.__file__ is not None

        log_module_code(mod.__file__)
        log.debug("Output code written to: %s", mod.__file__)
        output_code_log.info("Output code written to: %s", mod.__file__)
        if config.benchmark_kernel:
            print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
        V.debug.output_code(mod.__file__)
        V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
        return mod

    def get_output_names(self) -> list[str]:

@@ -65,6 +65,7 @@ from torch.utils._sympy.symbol import SymT
from . import config, dependencies
from .codegen.common import (
    BackendFeature,
    CodegenSymbol,
    get_scheduling_for_device,
    index_prevent_reordering,
)
@@ -3423,6 +3424,15 @@ class Layout(OutputSpec):
    def get_device(self) -> torch.device:
        return self.device

    def get_example(self) -> torch.Tensor:
        with V.fake_mode:
            return torch.empty_strided(
                convert_shape_to_symint(self.size),
                convert_shape_to_symint(self.stride),
                dtype=self.dtype,
                device=self.device,
            )
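    # Note on get_example above: the tensor is created under V.fake_mode, so it
    # carries the layout's size/stride/dtype/device without real storage; the FX
    # wrapper backend uses it to populate node.meta["val"] on the nodes it emits.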

    def is_contiguous(self) -> bool:
        return is_contiguous_strides_for_shape(self.stride, self.size)

@@ -3926,7 +3936,7 @@ class MutationLayoutSHOULDREMOVE(Layout):


@ir_dataclass(frozen=False)
class Buffer(IRNode):
class Buffer(IRNode, CodegenSymbol):
    # Name is sometimes None; e.g., ForceInPlace, where there isn't
    # a meaningful name
    name: Optional[str]
@@ -3946,6 +3956,11 @@ class Buffer(IRNode):
        assert self.name, self
        return self.name

    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
        if isinstance(self.layout, Layout):
            return self.layout.get_example()
        raise NotImplementedError(type(self.layout).__name__)

    def get_device(self) -> Optional[torch.device]:
        return self.get_output_spec().get_device()

@@ -85,7 +85,7 @@ class NoTritonConfigsError(RuntimeError):


if TYPE_CHECKING:
    from collections.abc import Container, Hashable, Sequence
    from collections.abc import Container, Hashable

    from torch._guards import CompileId

@@ -2564,13 +2564,15 @@ class GridExpr:

    inductor_meta: dict[str, Any]
    mode: Literal["python", "cpp"] = "python"
    prefix: Sequence[str] = ()
    prefix: list[str] = dataclasses.field(default_factory=list)
    x_grid: Union[str, int] = 1
    y_grid: Union[str, int] = 1
    z_grid: Union[str, int] = 1

    def __post_init__(self) -> None:
        assert self.mode in ("python", "cpp")
        if self.mode == "python":
            self.prefix.append("from torch.utils._sympy.functions import FloorDiv")

    def generate(self, meta: dict[str, int]) -> None:
        raise NotImplementedError
@@ -2583,7 +2585,9 @@ class GridExpr:
        if isinstance(numel, int) and isinstance(block, int):
            return ceildiv(numel, block)  # constant fold
        if self.mode == "python":
            return f"-(({numel}) // -({block}))"
            # Use FloorDiv instead of // so we can get better sympy expressions for
            # dynamic shapes.
            return f"-FloorDiv(({numel}), -({block}))"
        # trick above doesn't work in C++ due to rounding differences
        return f"(({numel} + ({block} - 1)) / ({block}))"

@@ -2666,12 +2670,16 @@ class Grid3D(GridExpr):
class Grid2DWithYZOverflow(GridExpr):
    def generate(self, meta: dict[str, int]) -> None:
        self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
        self.prefix = [
            self.assign_tmp("y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))),
            self.assign_tmp(
                "y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
            ),
        ]
        self.prefix.extend(
            [
                self.assign_tmp(
                    "y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))
                ),
                self.assign_tmp(
                    "y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
                ),
            ]
        )
        self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_")
        self.z_grid = "y_grid_div_"
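        # Worked example (illustrative, assuming a maximum Y grid of 65535): if
        # ceil(ynumel / YBLOCK) == 200000, then y_grid_div_ == ceil(200000 / 65535)
        # == 4, so y_grid == ceil(200000 / 4) == 50000 and z_grid == 4, which keeps
        # the launch within the Y limit while covering all 200000 blocks (50000 * 4).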


@@ -436,6 +436,23 @@ def convert_shape_to_inductor(
    return [sympy.sympify(i) for i in lst]


def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]:
    """
    Like convert_shape_to_symint, but operates on a single expression.
    """
    from .virtualized import V

    return (
        i
        if isinstance(i, int)
        else (
            int(i)
            if isinstance(i, sympy.Integer)
            else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
        )
    )
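# Illustrative behavior of convert_to_symint above: a plain int or sympy.Integer
# comes back as a Python int (e.g. sympy.Integer(8) -> 8), while a symbolic
# expression such as s0 is wrapped into a torch.SymInt backed by the graph's ShapeEnv.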


def convert_shape_to_symint(
    lst: Iterable[Union[int, sympy.Expr]],
) -> list[Union[int, torch.SymInt]]:
@@ -443,20 +460,7 @@ def convert_shape_to_symint(
    Takes a list of shapes from Inductor and converts them into symints (or just
    ints if all shapes are static).
    """
    from .virtualized import V

    return [
        (
            i
            if isinstance(i, int)
            else (
                int(i)
                if isinstance(i, sympy.Integer)
                else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
            )
        )
        for i in lst
    ]
    return [convert_to_symint(i) for i in lst]

def is_view(op: torch._ops.OpOverload) -> bool: