# Feature

Support `torch.cond` in the FX converter. The generated FX IR is conceptually identical to what would come from `torch.export`:
- Submodules are stored as attributes, and accessed via `getattr`.
- The conditional is represented as `torch.ops.higher_order.cond`, which takes in the subgraphs, a predicate and submodule inputs.

# Implementation overview

The FX backend generates code for subgraphs using the following steps:
1. When `codegen_conditional` is called in `WrapperFxCodegen`, we emit a `ConditionalLine`.
   a. We also codegen the true/false subgraphs at this time, storing their subgms for later.
2. At the beginning of FX conversion, generate `get_attr` nodes accessing each subgraph. It's important to do this at the start, before registering the node metadata hook. This also matches the convention followed by `torch.export`.
3. When we see the `ConditionalLine` in the FX converter, we generate a corresponding `torch.ops.higher_order.cond`.

# Implementation details

This ended up being a substantial change, as wrapper codegen has some special logic for subgraphs. Certain methods of `PythonWrapperCodegen` are overridden by `SubgraphPythonWrapperCodegen`. To apply these overrides, we use multiple inheritance with the registered subclass of `WrapperFxCodegen`.

Unlike most other wrapper codegen methods, which map 1:1 to Wrapper IR lines, subgraph codegen generates a number of wrapper lines, including `EnterSubgraphLine` and `ExitSubgraphLine`, along with Python or C++ code calling the subgraph as a function. These lines are used for some backends' memory planning. In contrast, FX IR typically represents a subgraph call as a single HOP node, or a `call_module` op.

To account for this difference, this PR introduces a new Wrapper IR line called `ConditionalLine`, which is only used by the FX backend. We override the `codegen_conditional` method to emit this line. This sidesteps having to port the existing subgraph codegen and associated memory planning to Wrapper IR. (In principle, it seems possible to adapt the existing backends to `ConditionalLine`, but it could be a larger refactor, since we'd also have to update the memory planning.) Some of the lower-level subgraph codegen methods are still shared between the FX and Python backends, such as `generate_subgraph_common`. Those were easier to port to Wrapper IR.

This also required generalizing the way the FX converter handles graph inputs and outputs. Previously, it assumed the IO signature was the same as `V.graph.module`, but this is only true for the parent graph, and not subgraphs. Instead, we need to call `get_graph_inputs` and `get_graph_outputs` to populate the inputs and outputs for subgraphs.

# Test plan

This PR adds a couple of tests using `torch.cond`.
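For illustration, here is a minimal sketch of the kind of program these tests compile (adapted from the new tests below; the function and tensor names are illustrative, and it assumes a GPU device with the FX backend registered for it, as the test class does via `common.register_backend_for_device`):

```python
import torch

def foo(pred, x):
    # Each branch becomes an FX submodule; the conditional itself lowers to a
    # single torch.ops.higher_order.cond node in the generated FX IR.
    return torch.cond(pred, torch.cos, torch.sin, [x]) + 1

pred = torch.tensor([True], device="cuda")
x = torch.randn(2, 3, device="cuda")
out = torch.compile(foo)(pred, x)
```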
Here's an example graph generated by one of them:

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%arg0_1, %true_graph_0, %false_graph_0, (%arg1_1,)), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 6, constant_args_idx: 6, grid: [(1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 6, XBLOCK: 8}})
    return buf1
```

It also removes an existing negative test which checked that a certain error was raised when subgraphs were encountered.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163234
Approved by: https://github.com/angelayi, https://github.com/jansel
# Owner(s): ["module: inductor"]
"""
Test the FX IR backend.
"""

import itertools
import operator
import unittest
from typing import Callable, Optional

import sympy

import torch
import torch._inductor.codegen.common as common
import torch.utils._pytree as pytree
from torch._dynamo.exc import BackendCompilerFailed
from torch._dynamo.utils import same
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._inductor import config
from torch._inductor.codegen.cpp import CppScheduling
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.export import Dim
from torch.testing._internal.common_utils import (
    DeterministicGuard,
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_GPU,
    requires_gpu,
    TRITON_HAS_CPU,
)


if HAS_GPU:
    import triton
    import triton.language as tl

    from torch.testing._internal.triton_utils import add_kernel_2d_autotuned

test_config = {
    "compile_threads": 1,
    "alignment_asserts": False,
    "size_asserts": False,
    "scalar_asserts": False,
    "nan_asserts": False,
}


@requires_gpu()
@config.patch(test_config)
@instantiate_parametrized_tests
class FxirTestCase(InductorTestCase):
    device = GPU_TYPE

    def _count_ops(self, gm: torch.fx.GraphModule, target: Callable) -> int:
        return len(gm.graph.find_nodes(op="call_function", target=target))

    def _run_and_capture_graphs(self, opt, args) -> list[torch.fx.GraphModule]:
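        # Patch FxConverter.generate to record each GraphModule the FX backend
        # produces while running the compiled callable.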
        gms = []

        orig_generate = FxConverter.generate

        def generate(self) -> torch.fx.GraphModule:
            nonlocal gms
            gm = orig_generate(self)
            gms.append(gm)
            return gm

        with unittest.mock.patch.object(
            torch._inductor.codegen.wrapper_fxir.FxConverter, "generate", generate
        ):
            opt(*args)

        return gms

    def _compile_and_check(
        self,
        func,
        args,
        expected_num_triton_kernels: int = 1,
        metadata_only: bool = False,
        compile_kwargs: Optional[dict] = None,
    ):
        if compile_kwargs is None:
            compile_kwargs = {}

        opt = torch.compile(func, **compile_kwargs)

        # Get the FX graph from the backend.
        gms = self._run_and_capture_graphs(opt, args)

        # Check the code for triton kernels.
        num_kernels = sum(
            self._count_ops(gm, triton_kernel_wrapper_mutation) for gm in gms
        )
        self.assertEqual(num_kernels, expected_num_triton_kernels)

        # Check accuracy.
        result = opt(*args)
        ref = func(*args)
        if metadata_only:
            # When we only want to check metadata, fill in zeros for tensor data.
            ref, result = tuple(
                pytree.tree_map(torch.zeros_like, x) for x in (ref, result)
            )

        self.assertTrue(same(ref, result))

        return gms

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Register the FX backend, storing the default for later.
        common.init_backend_registration()
        cls._default_backend = common.device_codegens[cls.device]
        common.register_backend_for_device(
            cls.device, TritonScheduling, WrapperFxCodegen
        )

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Restore the default backend.
        common.device_codegens[cls.device] = cls._default_backend

    def test_basic(self):
        args = [torch.randn(8, device=self.device) for _ in range(2)]
        self._compile_and_check(torch.add, args)

    def test_multiple_kernels(self):
        def foo(x, y):
            return x.sum() + y.sum()

        args = [torch.randn(length, device=self.device) for length in [517, 1029]]
        self._compile_and_check(foo, args, expected_num_triton_kernels=2)

    def test_free(self):
        """
        Test a program that frees a buffer which is no longer in use.
        """

        def foo(x, y, z):
            w = x.sum() + y
            return z.sum() + w.sum()

        args = [torch.randn(length, device=self.device) for length in [517, 1029, 123]]
        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=3)

        # Check the generated code for frees.
        num_frees = gm.code.count("= None")
        self.assertGreater(num_frees, 0)

    def test_extern(self):
        """
        Test a program that calls an extern kernel.
        """

        def foo(x, y):
            return x @ y + y.sum()

        args = [
            torch.randn(size, device=self.device) for size in [(129, 129), (129, 1)]
        ]
        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)

        # Check for the extern kernel
        num_extern = self._count_ops(gm, torch.ops.aten.addmm.out)
        self.assertEqual(num_extern, 1)

    def test_fallback(self):
        """
        Test a program that calls aten fallbacks.
        """

        def foo(x):
            batch1 = torch.randn(2, 3, 5, device=self.device)
            batch2 = torch.randn(2, 5, 4, device=self.device)
            return torch.addbmm(x, batch1, batch2)

        args = (torch.randn(3, 4, device=self.device),)

        # Since the program has a random output, just check metadata.
        # Don't check for an exact value.
        (gm,) = self._compile_and_check(
            foo, args, expected_num_triton_kernels=2, metadata_only=True
        )

        # Check for the fallback kernel.
        num_fallback = self._count_ops(
            gm, torch.ops.aten.randint.low_out
        ) + self._count_ops(gm, torch.ops.aten.addbmm.default)
        self.assertEqual(num_fallback, 2)

    def test_cat_inputs(self):
        """
        Test concatenation of graph inputs.
        """

        def foo(x, y):
            return torch.cat((x, y)) + 1

        args = [torch.randn(8, device=self.device) for _ in range(2)]
        self._compile_and_check(foo, args, expected_num_triton_kernels=1)

    def test_cat_views(self):
        """
        Test concatenation with multiple kernels writing to the same buffer.
        """

        def foo(x, y):
            a = x - 2
            b = y.sum(0, keepdim=True)
            c = torch.cat((a, b)).clone()
            return a, b, c

        args = [torch.randn(8, device=self.device) for _ in range(2)]
        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2)

        def get_offset(node: torch.fx.Node) -> int:
            (input_, shape, stride, offset) = node.args
            assert isinstance(offset, int)
            return offset

        # Check for 2 views, one of which is offset.
        as_strided_nodes = list(
            gm.graph.find_nodes(op="call_function", target=torch.as_strided)
        )
        self.assertEqual(len(as_strided_nodes), 2)
        num_offset_views = sum(get_offset(node) > 0 for node in as_strided_nodes)
        self.assertEqual(num_offset_views, 1)

    def test_cat_to_alloc(self):
        """
        Test concatenation that's optimized out to an allocation.
        """
        length = 8

        def foo(x):
            y, z = tuple(
                torch.arange(length // 2, device=self.device) for _ in range(2)
            )
            return x + torch.cat((y, z))

        args = [torch.randn(length, device=self.device)]
        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)

        # Expect a single allocation, even though eager mode would use 2.
        num_allocs = self._count_ops(gm, torch.empty_strided)
        self.assertEqual(num_allocs, 1)

    def test_cat_reinterpret_view(self):
        """
        Test torch.cat using ReinterpretView.
        """
        length = 8

        def foo(x):
            y, z = tuple(torch.randn(length // 2, device=self.device) for _ in range(2))
            return x + torch.cat((y, z))

        args = [torch.randn(length, device=self.device)]

        # Since this test generates random numbers, check metadata only.
        (gm,) = self._compile_and_check(
            foo, args, expected_num_triton_kernels=3, metadata_only=True
        )

        # Check for as_strided. We map ReinterpretView to this.
        num_as_strided = self._count_ops(gm, torch.as_strided)
        self.assertEqual(num_as_strided, 2)

    def test_reshape_output(self):
        """
        Test reshaping the output, which maps to a ReinterpretView.
        """

        def foo(x, y):
            return torch.reshape(x + y, (8,))

        args = [torch.randn((2, 4), device=self.device) for _ in range(2)]
        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)

        # Check for as_strided. We map ReinterpretView to this.
        num_as_strided = self._count_ops(gm, torch.as_strided)
        self.assertEqual(num_as_strided, 1)

    def test_extern_multi_output(self):
        """
        Test an extern kernel with multiple outputs.
        Also test a graph with multiple outputs.
        """

        def foo(x):
            top, idx = torch.topk(x, 2)
            return top + 1, idx * 2

        args = [torch.randn(8, device=self.device)]
        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2)

        # Check for multiple kernel outputs via getitems.
        num_getitems = self._count_ops(gm, operator.getitem)
        self.assertEqual(num_getitems, 2)

        # Check for multiple graph outputs.
        output_node = gm.graph.find_nodes(op="output")[0]
        self.assertEqual(len(output_node.args[0]), 2)

    def test_duplicate_input(self):
        """
        Test duplicated inputs. This will collapse into a single input in the GM.
        """

        args = [torch.randn(4, device=self.device)] * 2
        (gm,) = self._compile_and_check(torch.add, args, expected_num_triton_kernels=1)

        num_placeholders = len(gm.graph.find_nodes(op="placeholder"))
        self.assertEqual(num_placeholders, 1)

    def test_backward(self):
        """
        Test a program with a backward pass.
        """

        x = torch.ones(5, device=self.device)  # input tensor
        y = torch.zeros(3, device=self.device)  # expected output
        w = torch.randn(5, 3, requires_grad=True, device=self.device)
        b = torch.randn(3, requires_grad=True, device=self.device)

        def foo(x, y):
            z = torch.matmul(x, w) + b
            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
            loss.backward()
            return w.grad, b.grad

        # Expect separate forward and backward graphs.
        (forward_gm, backward_gm) = self._compile_and_check(
            foo, (x, y), expected_num_triton_kernels=3
        )

    def test_custom_compiler(self):
        """
        Test a derived backend with a custom compiler.
        """
        offset = 1

        class CustomWrapperCodegen(WrapperFxCodegen):
            def compile_graph(self, gm):
                def compiled_fn(*args):
                    # Adds an offset to the program's outputs.
                    outputs = gm(*args)
                    return pytree.tree_map(lambda x: x + 1, outputs)

                return compiled_fn

        args = [torch.randn(8, device=self.device) for _ in range(2)]
        custom_backend = common.DeviceCodegen(
            TritonScheduling, CustomWrapperCodegen, None
        )
        with unittest.mock.patch.dict(
            common.device_codegens, {self.device: custom_backend}
        ):
            func = torch.add
            opt = torch.compile(func)
            result = opt(*args)

        # Check the output is offset from eager mode.
        ref = func(*args)
        self.assertFalse(same(result, ref))
        self.assertNotEqual(offset, 0)
        self.assertTrue(same(result - offset, ref))

    def test_dynamic_shapes_and_strides(self):
        """
        Test a graph with dynamic shapes and strides.
        """

        static_dims = (8, 8)

        def get_input():
            full_size = (16, 8)
            full = torch.randn(full_size, device=self.device)
            view = torch.as_strided(full, static_dims, full.stride())
            return view

        func = torch.add
        args = [get_input() for _ in range(2)]
        (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})

        # Check for a symbolic output shape.
        (empty_strided,) = gm.graph.find_nodes(
            op="call_function", target=torch.empty_strided
        )
        example_tensor = empty_strided.meta["val"]
        symbolic_dims = example_tensor.shape
        self.assertEqual(len(symbolic_dims), len(static_dims))

        # Check for symbolic output strides.
        (stride, one) = example_tensor.stride()
        self.assertEqual(one, sympy.S.One)

        # Find the size symbols, and check for corresponding placeholders defining them.
        for symbol in itertools.chain(symbolic_dims, [stride]):
            self.assertTrue(isinstance(symbol, torch.SymInt))
            (placeholder,) = [
                node
                for node in gm.graph.find_nodes(op="placeholder")
                if node.name == str(symbol)
            ]
            self.assertEqual(placeholder.meta["val"], symbol)

    @parametrize(
        "shape",
        [
            (20,),
            (50, 30),
            (50, 30, 40),
        ],
    )
    @torch._inductor.config.patch(
        {
            "pad_dynamic_shapes": True,
            "comprehensive_padding": True,
            "padding_alignment_bytes": 32,
            "pad_outputs": True,
        }
    )
    def test_dynamic_shapes_with_padding(self, shape):
        """
        Test a graph with dynamic shapes and padding.
        """

        def get_input(shape):
            pad_size = list(shape)
            pad_size[-1] = ((shape[-1] + 7) // 8) * 8
            pad = torch.randn(pad_size, dtype=torch.float32, device=self.device)
            view = torch.as_strided(pad, shape, pad.stride())
            return view

        args = [get_input(shape) for _ in range(2)]
        (gm,) = self._compile_and_check(
            torch.add, args, compile_kwargs={"dynamic": True}
        )

        # Check for a symbolic output shape.
        (empty_strided,) = gm.graph.find_nodes(
            op="call_function", target=torch.empty_strided
        )
        example_tensor = empty_strided.meta["val"]
        symbolic_dims = example_tensor.shape
        symbolic_strides = example_tensor.stride()

        align_elems = 32 // args[0].dtype.itemsize
        expected_strides = [1 for _ in range(len(shape))]
        for i in range(len(shape) - 1, 0, -1):
            expected_strides[i - 1] = align_elems * (
                ((expected_strides[i] * symbolic_dims[i]) + align_elems - 1)
                // align_elems
            )
        for i, j in zip(symbolic_strides, expected_strides):
            self.assertEqual(i, j)

    def test_dynamic_shapes_precomputed_size(self):
        """
        Test dynamic shapes where a kernel's size arg is precomputed.
        """
        func = torch.add
        args = [
            torch.randn(shape, device=self.device) for shape in [(7, 12, 9), (7, 1, 1)]
        ]
        (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})

        # Check for the precomputed size arg.
        (triton_node,) = gm.graph.find_nodes(
            op="call_function", target=triton_kernel_wrapper_mutation
        )
        self.assertIn("ks0", triton_node.kwargs["kwargs"])

    def test_dynamic_launch_grid_calc_python(self):
        """
        Test the dynamic launch grid calculation for the Triton kernel wrapper using python mode.
        """
        func = torch.add
        args = [torch.randn(shape, device=self.device) for shape in [(7, 12), (7, 1)]]
        (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})

        # Check the launch grid and the size args used to compute it.
        (triton_node,) = gm.graph.find_nodes(
            op="call_function", target=triton_kernel_wrapper_mutation
        )
        self.assertIn("grid", triton_node.kwargs)
        self.assertIn("xnumel", triton_node.kwargs["kwargs"])
        self.assertIn("XBLOCK", triton_node.kwargs["kwargs"])
        grid = triton_node.kwargs["grid"][0]
        xnumel = triton_node.kwargs["kwargs"]["xnumel"].meta["val"]
        xblock = triton_node.kwargs["kwargs"]["XBLOCK"]
        self.assertEqual(grid[0].meta["val"], -(-xnumel // xblock))
        self.assertEqual(grid[1], 1)
        self.assertEqual(grid[2], 1)

    def test_dynamic_launch_grid_calc_python_slow(self):
        """
        Test the dynamic launch grid calculation for the Triton kernel wrapper using python_slow mode.
        """
        from torch._inductor.runtime.triton_heuristics import GridExpr

        # Mock GridExpr.from_meta to use "python_slow" mode explicitly
        original_from_meta = GridExpr.from_meta

        def mocked_from_meta(inductor_meta, cfg, mode="python"):
            return original_from_meta(inductor_meta, cfg, mode="python_slow")

        with unittest.mock.patch.object(GridExpr, "from_meta", mocked_from_meta):
            func = torch.add
            args = [
                torch.randn(shape, device=self.device) for shape in [(7, 12), (7, 1)]
            ]
            (gm,) = self._compile_and_check(
                func, args, compile_kwargs={"dynamic": True}
            )

        # Check the launch grid and the size args used to compute it.
        (triton_node,) = gm.graph.find_nodes(
            op="call_function", target=triton_kernel_wrapper_mutation
        )
        self.assertIn("grid", triton_node.kwargs)
        self.assertIn("xnumel", triton_node.kwargs["kwargs"])
        self.assertIn("XBLOCK", triton_node.kwargs["kwargs"])
        grid = triton_node.kwargs["grid"][0]
        xnumel = triton_node.kwargs["kwargs"]["xnumel"].meta["val"]
        xblock = triton_node.kwargs["kwargs"]["XBLOCK"]
        self.assertEqual(grid[0].meta["val"], ((xnumel + xblock - 1) // xblock))
        self.assertEqual(grid[1], 1)
        self.assertEqual(grid[2], 1)

    @config.patch({"trace.enabled": True})
    @unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code")
    def test_debug(self, mock_output_code):
        # Compile in debug mode.
        args = [torch.randn(11, device=self.device) for _ in range(2)]
        self._compile_and_check(torch.sub, args)

        # Check the output code for a Triton kernel call.
        mock_output_code.assert_called_once()
        (output_filename,) = mock_output_code.call_args.args
        with open(output_filename) as f:
            output_code = f.read()
        self.assertIn("triton_kernel_wrapper_mutation", output_code)

    @parametrize(
        "const",
        (1, 1.5),
    )
    def test_export_const_placeholder(self, const):
        """
        Test that we can compile a graph coming from torch.export with a constant input.
        """

        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                return x - y

        args = (torch.randn(8, device=self.device), const)
        mod = TestModule()
        export_gm = torch.export.export(mod, args).module()

        def compile_module(*inps):
            torch._inductor.compile(export_gm, inps)

        (inductor_gm,) = self._run_and_capture_graphs(compile_module, args)
        result = inductor_gm(*args)
        ref = mod(*args)

        self.assertTrue(same(ref, result))

    def test_scatter_fallback_scalar_src(self):
        """
        Test a special case where ScatterFallback takes a scalar 'src' argument.
        """

        def foo(input_):
            dim = 0
            src = 1.5
            return torch.ops.aten.scatter(input_, dim, index, src)

        length = 8
        index = torch.randint(length, (length,), device=self.device)
        input_ = torch.randn(length, device=self.device)
        with DeterministicGuard(True):
            (gm,) = self._compile_and_check(
                foo,
                (input_,),
            )

        # Check for the fallback op.
        num_fallback = self._count_ops(gm, torch.ops.aten.scatter_.value)
        self.assertEqual(num_fallback, 1)

    def test_index_put_fallback(self):
        """
        Test the deterministic fallback for index_put.
        """
        length = 8
        out, values = [torch.randn(length, device=self.device) for _ in range(2)]
        indices = (torch.randint(length, (length,), device=self.device),)
        accumulate = True
        with DeterministicGuard(True):
            (gm,) = self._compile_and_check(
                torch.index_put,
                (out, indices, values, accumulate),
                expected_num_triton_kernels=1,
            )

        # Check for the fallback op.
        self.assertEqual(self._count_ops(gm, torch.ops.aten.index_put_.default), 1)

    def test_scatter_reduce_fallback(self):
        """
        Test the customized wrapper codegen for ScatterFallback ops.
        """
        fallback_op = torch.ops.aten.scatter_reduce_.two

        def foo(out, index, src):
            dim = 0
            out = fallback_op(out, dim, index, src, reduce="amax", include_self=False)
            return out + 1

        length = 8
        out, src = [torch.randn(length, device=self.device) for _ in range(2)]
        index = torch.randint(length, (length,), device=self.device)
        (gm,) = self._compile_and_check(
            foo, (out, index, src), expected_num_triton_kernels=2
        )

        # Check for the fallback.
        self.assertEqual(self._count_ops(gm, fallback_op), 1)

    @parametrize("pred", (False, True))
    def test_cond_subgraph(self, pred: bool):
        """
        Test a model with subgraphs.
        """

        def foo(pred, x):
            return torch.cond(pred, torch.cos, torch.sin, [x]) + 1

        x = torch.randn((2, 3), device=self.device)
        pred_tensor = torch.tensor([pred], device=self.device)
        gm = self._compile_and_check(
            foo, [pred_tensor, x], expected_num_triton_kernels=3
        )[-1]

        # Check for subgraphs.
        subgm_getattrs = list(gm.graph.find_nodes(op="get_attr"))
        self.assertEqual(len(subgm_getattrs), 2)
        for subgm_getattr in subgm_getattrs:
            target = subgm_getattr.name
            self.assertTrue(isinstance(getattr(gm, target), torch.fx.GraphModule))

    @parametrize("pred", (False, True))
    def test_cond_no_operands(self, pred: bool):
        """
        Test torch.cond when the subgraphs take no inputs.
        """

        length = 8

        def true_fn():
            return torch.zeros(length, device=self.device)

        def false_fn():
            return true_fn() + 5

        def foo(pred):
            return torch.cond(pred, true_fn, false_fn, ())

        pred_tensor = torch.tensor([pred], device=self.device)
        self._compile_and_check(foo, [pred_tensor], expected_num_triton_kernels=2)

    def test_cpp_raises(self):
        """
        Test the C++ CPU backend. C++ kernels are not yet supported, so for now check
        that we get the expected exception.
        """

        def foo(x, y):
            return x + y * 5

        device = torch.device("cpu")
        args = [torch.randn(5, device=device) for _ in range(2)]

        cpp_backend = common.DeviceCodegen(CppScheduling, WrapperFxCodegen, None)
        with (
            unittest.mock.patch.dict(
                common.device_codegens, {device.type: cpp_backend}
            ),
            self.assertRaisesRegex(BackendCompilerFailed, "Triton"),
        ):
            self._compile_and_check(foo, args)

    @parametrize("enable_tuning", (False, True))
    @parametrize("use_dynamic_shapes", (False, True))
    def test_autotune(self, use_dynamic_shapes: bool, enable_tuning: bool):
        orig_run = torch._inductor.runtime.triton_heuristics.CachingAutotuner.run
        called = False

        def run(*args, **kwargs):
            nonlocal called
            called = True
            return orig_run(*args, **kwargs)

        args = [torch.randn(8, device=self.device) for _ in range(2)]

        with (
            config.patch("triton.autotune_at_compile_time", enable_tuning),
            unittest.mock.patch.object(
                torch._inductor.runtime.triton_heuristics.CachingAutotuner, "run", run
            ),
        ):
            # Compile and check whether the tuner was called.
            self.assertFalse(called)
            (gm,) = self._compile_and_check(
                torch.mul, args, compile_kwargs={"dynamic": use_dynamic_shapes}
            )
            self.assertEqual(called, enable_tuning)

        # Check for a symbolic output shape.
        (empty_strided,) = gm.graph.find_nodes(
            op="call_function", target=torch.empty_strided
        )
        (shape, stride) = empty_strided.args
        if use_dynamic_shapes:
            self.assertEqual(type(shape[0]), torch.fx.Node)

    def test_custom_triton(self):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: tl.constexpr,
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            return output

        args = [torch.randn(32, device=self.device) for _ in range(2)]
        self._compile_and_check(add, args)

    def test_output_slice_view(self):
        """
        Test when the output is a view of the input.
        The sliced strides create a TensorBox in the output IR.
        """

        def foo(x):
            return x[0:2:2].T[3:].squeeze(0)

        args = [torch.rand([4, 4, 4, 4], device=self.device)]
        self._compile_and_check(foo, args, expected_num_triton_kernels=0)


@instantiate_parametrized_tests
class AOTFxirTestCase(InductorTestCase):
    device = GPU_TYPE

    def check(
        self, model, inp, dynamic_shapes=None, strict=False
    ) -> torch.fx.GraphModule:
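        # Export the model, AOT-compile it with the FX wrapper enabled, then check
        # numerics and that non-Triton call_function nodes carry "val" metadata.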
        if self.device == "xpu":
            raise unittest.SkipTest("The AOTFxir feature is not currently ready for XPU")
        with torch.no_grad():
            ep = torch.export.export(
                model, inp, dynamic_shapes=dynamic_shapes, strict=strict
            )
            gm = torch._inductor.aot_compile(
                ep.module(), inp, options={"fx_wrapper": True, **test_config}
            )
            self.assertTrue(same(model(*inp), gm(*inp)))

            for node in gm.graph.nodes:
                if (
                    node.op == "call_function"
                    and node.target != triton_kernel_wrapper_mutation
                ):
                    self.assertTrue(node.meta.get("val", None) is not None)

            return gm

    def test_aoti_fx_add(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
        self.check(M(), inp)

    def test_aoti_fx_const(self):
        class M(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.device = device
                self.a = torch.nn.Parameter(torch.ones(3, device=self.device))
                self.b = torch.ones(3, device=self.device)

            def forward(self, x, y):
                return x + y + self.a + self.b + torch.tensor(3, device=self.device)

        inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
        self.check(M(self.device), inp)

    def test_aoti_fx_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                return self.linear(x)

        inp = (torch.ones(3, 3, device=self.device),)
        self.check(M().to(self.device), inp)

    def test_aoti_fx_dynamic(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
        self.check(
            M().to(device=self.device),
            inp,
            dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
        )

    def test_custom_triton_autotune_dynamic(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                output = torch.zeros_like(x)
                x_elements = output.size()[0]
                y_elements = output.size()[1]

                def grid(meta):
                    return (
                        triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                        triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                    )

                add_kernel_2d_autotuned[grid](x, y, output, x_elements, y_elements)

                return output

        num_dims = 2
        dims = [10] * num_dims
        x = torch.randn(*dims, device=self.device)
        y = torch.randn(*dims, device=self.device)
        dim0_x = Dim("dim0_x", min=1, max=10)
        dim0_y = Dim("dim0_y", min=1, max=10)
        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
        self.check(
            Model().to(device=self.device),
            (x, y),
            dynamic_shapes=dynamic_shapes,
            strict=True,
        )

    def test_custom_backend(self):
        """
        Test registering a custom FX backend.
        """
        called = False

        class CustomWrapperCodegen(WrapperFxCodegen):
            def compile_graph(self, gm):
                """
                Simply records whether this override was called.
                """
                nonlocal called
                called = True
                return super().compile_graph(gm)

        class M(torch.nn.Module):
            def forward(self, x):
                return x + 1

        # Register a custom FX backend.
        custom_backend = common.DeviceCodegen(
            TritonScheduling,
            PythonWrapperCodegen,
            fx_wrapper_codegen=CustomWrapperCodegen,
        )
        with unittest.mock.patch.dict(
            common.device_codegens, {self.device: custom_backend}
        ):
            # The backend should not have been called yet.
            self.assertFalse(called)

            inp = (torch.randn(8, device=self.device),)
            self.check(M().to(self.device), inp)

            # Now the backend should have been called.
            self.assertTrue(called)

    @parametrize(
        "expr",
        [
            (2 * Dim("x") + 1),
            (Dim("x", min=3) - 3),
        ],
    )
    def test_dynamic_input_expr(self, expr: sympy.Expr):
        """
        Test dynamic shapes with a nontrivial input expression.
        """

        class M(torch.nn.Module):
            def forward(self, x):
                return x.reshape(x.shape[0] * x.shape[1]) + x.shape[1]

        dynamic_shapes = {"x": {0: expr}}
        inp = (torch.randn((5, 4), device=self.device),)
        gm = self.check(M().to(self.device), inp, dynamic_shapes=dynamic_shapes)

        # Check for dynamic size ops.
        self.assertEqual(
            len(
                gm.graph.find_nodes(
                    op="call_function", target=torch.ops.aten.sym_size.int
                )
            ),
            1,
        )

@parametrize("pred", (False, True))
|
|
def test_cond_multi_inputs_and_outputs(self, pred):
|
|
"""
|
|
Test torch.cond and check the output graphs.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, pred, x, y):
|
|
def true_fn(x, y):
|
|
return torch.tanh(x), torch.relu(y)
|
|
|
|
def false_fn(x, y):
|
|
return tuple(t / 2 for t in true_fn(x, y))
|
|
|
|
return torch.cond(pred, true_fn, false_fn, (x, y))
|
|
|
|
        pred = torch.tensor([pred], device=self.device)
        (x, y) = [torch.randn(8, device=self.device) for _ in range(2)]
        gm = self.check(M(), (pred, x, y))

        # Check the graph.
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(arg0_1, true_graph_0, false_graph_0, (arg1_1, arg2_1)); arg0_1 = true_graph_0 = false_graph_0 = arg1_1 = arg2_1 = None
    buf1 = cond[0]
    buf2 = cond[1]; cond = None
    return [buf1, buf2]""",  # noqa: B950
        )


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU or TRITON_HAS_CPU:
        run_tests(needs="filelock")