Mirror of https://github.com/pytorch/pytorch.git
# Problem

Inductor's FX backend receives sympy expressions for Triton launch grids and passes these to a tracer to generate equivalent FX IR. However, the tracer does not support all possible sympy expressions. In particular, it can't handle ops like `floor` and `Pow`, which would be found in an expression like `floor(x / y)`. Instead, it expects `FloorDiv(x, y)`, which has the advantage that all intermediate values are integers, unlike `x / y`.

Inductor's Python backend uses a trick where `ceil(x / y)` is computed in Python as `-(x // -y)`, which is faster when evaluating Python launch grids at runtime. However, this trick generates more complex sympy expressions, so the FX backend introduced a `"python_slow"` mode using a more familiar form of ceil division. Unfortunately, this mode is slower to evaluate, which increased production CPU usage. (Internal reviewers see T237853632.)

# Solution

To get the best of both worlds, this PR removes `"python_slow"` mode and generalizes the `replace_floor_div` function to handle the more complex expressions resulting from the `"python"` grid mode. The new algorithm is conceptually similar to the existing one, except that instead of analyzing only the first argument to a `sympy.Mul` op, it checks all factors, so it can handle expressions containing both `Rational` and `Pow` ops, among other cases. It also uses `Mul.make_args` to handle the case where the argument to `floor` is not a `Mul`. Finally, it uses `expr.is_positive` to check the sign of symbolic exponents.

This new algorithm is guaranteed to convert all `floor` ops to an equivalent expression using `FloorDiv`. (To see this, consider that `floor(x) == FloorDiv(x, 1)`.) Note that it may not remove all `Pow` ops, with a counterexample being `floor(x / (2 + z ** y))`, but it covers everything we've seen in practice for symbolic launch grids. In particular, it covers the typical case where `Pow` is a factor of the argument to `floor` and the exponent is `-1`. In this situation, we move the `Pow` to the denominator of `FloorDiv` and the exponent becomes `1`, eliminating the `Pow` op.

# Test plan

This PR adds an end-to-end test for static padding with dynamic outer dimensions, which creates a difficult sympy expression that the existing algorithm would not be able to handle.

This PR also adds some unit tests for the `replace_floor_div` function. It can be difficult to construct end-to-end tests that expose all the trickiest expressions, as those tests have to pass through a number of other systems handling dynamic shapes. Therefore, it's easier to expose the edge cases with these new unit tests. The tests check that we can replace all `floor` ops in the input expression with `FloorDiv`, then they expand `FloorDiv` back to `floor` and check equality with the original expression.

Note this PR also requires some MTIA changes to pass internal tests. Those will be stacked onto the imported diff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163828
Approved by: https://github.com/nandesuka, https://github.com/angelayi, https://github.com/jansel
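For illustration, the sketch below captures the factor-scanning idea described above; it is not the actual `replace_floor_div` implementation, and the helper name `sketch_replace_floor_div` is hypothetical. It sorts each factor of the argument to `floor` into a numerator or a denominator, relies on `Mul.make_args` for arguments that are not a `Mul`, and checks the sign of symbolic exponents via `is_positive`, assuming a PyTorch checkout that provides `FloorDiv`.

```python
# Minimal sketch of the factor-scanning idea; illustrative only, not the
# actual replace_floor_div implementation from this PR.
import sympy

from torch.utils._sympy.functions import FloorDiv


def sketch_replace_floor_div(expr: sympy.Expr) -> sympy.Expr:  # hypothetical helper
    def rewrite(arg: sympy.Expr) -> sympy.Expr:
        numerator, denominator = sympy.S.One, sympy.S.One
        # Mul.make_args returns (arg,) when arg is not a Mul, so a bare
        # floor(x) still maps to FloorDiv(x, 1) (which may simplify to x).
        for factor in sympy.Mul.make_args(arg):
            if isinstance(factor, sympy.Rational) and factor.q != 1:
                # Split a rational factor p/q across numerator and denominator.
                numerator *= factor.p
                denominator *= factor.q
            elif isinstance(factor, sympy.Pow) and (-factor.exp).is_positive:
                # Move factors with negative exponents (e.g. y**-1) into the
                # denominator, flipping the sign of the exponent.
                denominator *= sympy.Pow(factor.base, -factor.exp)
            else:
                numerator *= factor
        return FloorDiv(numerator, denominator)

    # Replace every floor(...) node with a FloorDiv built from its factors.
    return expr.replace(sympy.floor, rewrite)


if __name__ == "__main__":
    x, y = sympy.symbols("x y", positive=True, integer=True)
    print(sketch_replace_floor_div(sympy.floor(x / y)))  # -> FloorDiv(x, y)
    # The Python-grid ceil trick mentioned above: ceil(a / b) == -(a // -b).
    assert -(10 // -4) == 3
```

The unit tests added below exercise the real function on the same kinds of expressions (rational factors, `Pow` factors, nested `floor`, and symbolic exponents).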
1196 lines
39 KiB
Python
# Owner(s): ["module: inductor"]
|
|
"""
|
|
Test the FX IR backend.
|
|
"""
|
|
|
|
import itertools
|
|
import operator
|
|
import unittest
|
|
from collections.abc import Callable
|
|
from typing import Optional
|
|
|
|
import sympy
|
|
|
|
import torch
|
|
import torch._inductor.codegen.common as common
|
|
import torch.utils._pytree as pytree
|
|
from torch._dynamo.exc import BackendCompilerFailed
|
|
from torch._dynamo.utils import same
|
|
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
|
|
from torch._inductor import config
|
|
from torch._inductor.codegen.cpp import CppScheduling
|
|
from torch._inductor.codegen.triton import TritonScheduling
|
|
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
|
|
from torch._inductor.codegen.wrapper_fxir import (
|
|
FxConverter,
|
|
replace_floor_div,
|
|
WrapperFxCodegen,
|
|
)
|
|
from torch._inductor.test_case import TestCase as InductorTestCase
|
|
from torch.export import Dim
|
|
from torch.testing._internal.common_utils import (
|
|
DeterministicGuard,
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
)
|
|
from torch.testing._internal.inductor_utils import (
|
|
GPU_TYPE,
|
|
HAS_GPU,
|
|
requires_gpu,
|
|
TRITON_HAS_CPU,
|
|
)
|
|
from torch.utils._sympy.functions import FloorDiv
|
|
|
|
|
|
try:
|
|
from .test_control_flow import CondModels
|
|
except ImportError:
|
|
from test_control_flow import (
|
|
CondModels, # @manual=fbcode//caffe2/test/inductor:control_flow-library
|
|
)
|
|
|
|
if HAS_GPU:
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from torch.testing._internal.triton_utils import add_kernel_2d_autotuned
|
|
|
|
test_config = {
|
|
"compile_threads": 1,
|
|
"alignment_asserts": False,
|
|
"size_asserts": False,
|
|
"scalar_asserts": False,
|
|
"nan_asserts": False,
|
|
}
|
|
|
|
|
|
@requires_gpu()
|
|
@config.patch(test_config)
|
|
@instantiate_parametrized_tests
|
|
class FxirTestCase(InductorTestCase):
|
|
device = GPU_TYPE
|
|
|
|
def _count_ops(self, gm: torch.fx.GraphModule, target: Callable) -> int:
|
|
return len(gm.graph.find_nodes(op="call_function", target=target))
|
|
|
|
def _run_and_capture_graphs(self, opt, args) -> torch.fx.GraphModule:
|
|
gms = []
|
|
|
|
orig_generate = FxConverter.generate
|
|
|
|
def generate(self) -> torch.fx.GraphModule:
|
|
nonlocal gms
|
|
gm = orig_generate(self)
|
|
gms.append(gm)
|
|
return gm
|
|
|
|
with unittest.mock.patch.object(
|
|
torch._inductor.codegen.wrapper_fxir.FxConverter, "generate", generate
|
|
):
|
|
opt(*args)
|
|
|
|
return gms
|
|
|
|
def _compile_and_check(
|
|
self,
|
|
func,
|
|
args,
|
|
expected_num_triton_kernels: int = 1,
|
|
metadata_only: bool = False,
|
|
compile_kwargs: Optional[dict] = None,
|
|
):
|
|
if compile_kwargs is None:
|
|
compile_kwargs = {}
|
|
|
|
opt = torch.compile(func, **compile_kwargs)
|
|
|
|
# Get the FX graph from the backend.
|
|
gms = self._run_and_capture_graphs(opt, args)
|
|
|
|
# Check the code for triton kernels.
|
|
num_kernels = sum(
|
|
self._count_ops(gm, triton_kernel_wrapper_mutation) for gm in gms
|
|
)
|
|
self.assertEqual(num_kernels, expected_num_triton_kernels)
|
|
|
|
# Check accuracy.
|
|
result = opt(*args)
|
|
ref = func(*args)
|
|
if metadata_only:
|
|
# When we only want to check metadata, fill in zeros for tensor data.
|
|
ref, result = tuple(
|
|
pytree.tree_map(torch.zeros_like, x) for x in (ref, result)
|
|
)
|
|
|
|
self.assertTrue(same(ref, result))
|
|
|
|
return gms
|
|
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
super().setUpClass()
|
|
|
|
# Register the FX backend, storing the default for later.
|
|
common.init_backend_registration()
|
|
cls._default_backend = common.device_codegens[cls.device]
|
|
common.register_backend_for_device(
|
|
cls.device, TritonScheduling, WrapperFxCodegen
|
|
)
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
super().tearDownClass()
|
|
|
|
# Restore the default backend.
|
|
common.device_codegens[cls.device] = cls._default_backend
|
|
|
|
def test_basic(self):
|
|
args = [torch.randn(8, device=self.device) for _ in range(2)]
|
|
self._compile_and_check(torch.add, args)
|
|
|
|
def test_multiple_kernels(self):
|
|
def foo(x, y):
|
|
return x.sum() + y.sum()
|
|
|
|
args = [torch.randn(length, device=self.device) for length in [517, 1029]]
|
|
self._compile_and_check(foo, args, expected_num_triton_kernels=2)
|
|
|
|
def test_free(self):
|
|
"""
|
|
Test a program that frees a buffer which is no longer in use.
|
|
"""
|
|
|
|
def foo(x, y, z):
|
|
w = x.sum() + y
|
|
return z.sum() + w.sum()
|
|
|
|
args = [torch.randn(length, device=self.device) for length in [517, 1029, 123]]
|
|
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=3)
|
|
|
|
# Check the generated code for frees.
|
|
num_frees = gm.code.count("= None")
|
|
self.assertGreater(num_frees, 0)
|
|
|
|
def test_extern(self):
|
|
"""
|
|
Test a program that calls an extern kernel.
|
|
"""
|
|
|
|
def foo(x, y):
|
|
return x @ y + y.sum()
|
|
|
|
args = [
|
|
torch.randn(size, device=self.device) for size in [(129, 129), (129, 1)]
|
|
]
|
|
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
|
|
|
|
# Check for the extern kernel
|
|
num_extern = self._count_ops(gm, torch.ops.aten.addmm.out)
|
|
self.assertEqual(num_extern, 1)
|
|
|
|
def test_fallback(self):
|
|
"""
|
|
Test a program that calls aten fallbacks.
|
|
"""
|
|
|
|
def foo(x):
|
|
batch1 = torch.randn(2, 3, 5, device=self.device)
|
|
batch2 = torch.randn(2, 5, 4, device=self.device)
|
|
return torch.addbmm(x, batch1, batch2)
|
|
|
|
args = (torch.randn(3, 4, device=self.device),)
|
|
|
|
# Since the program has a random output, just check metadata.
|
|
# Don't check for an exact value.
|
|
(gm,) = self._compile_and_check(
|
|
foo, args, expected_num_triton_kernels=2, metadata_only=True
|
|
)
|
|
|
|
# Check for the fallback kernel.
|
|
num_fallback = self._count_ops(
|
|
gm, torch.ops.aten.randint.low_out
|
|
) + self._count_ops(gm, torch.ops.aten.addbmm.default)
|
|
self.assertEqual(num_fallback, 2)
|
|
|
|
def test_cat_inputs(self):
|
|
"""
|
|
Test concatenation of graph inputs.
|
|
"""
|
|
|
|
def foo(x, y):
|
|
return torch.cat((x, y)) + 1
|
|
|
|
args = [torch.randn(8, device=self.device) for _ in range(2)]
|
|
self._compile_and_check(foo, args, expected_num_triton_kernels=1)
|
|
|
|
def test_cat_views(self):
|
|
"""
|
|
Test concatenation with multiple kernels writing to the same buffer.
|
|
"""
|
|
|
|
def foo(x, y):
|
|
a = x - 2
|
|
b = y.sum(0, keepdim=True)
|
|
c = torch.cat((a, b)).clone()
|
|
return a, b, c
|
|
|
|
args = [torch.randn(8, device=self.device) for _ in range(2)]
|
|
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2)
|
|
|
|
def get_offset(node: torch.fx.Node) -> int:
|
|
(input_, shape, stride, offset) = node.args
|
|
assert isinstance(offset, int)
|
|
return offset
|
|
|
|
# Check for 2 views, one of which is offset.
|
|
as_strided_nodes = list(
|
|
gm.graph.find_nodes(op="call_function", target=torch.as_strided)
|
|
)
|
|
self.assertEqual(len(as_strided_nodes), 2)
|
|
num_offset_views = sum(get_offset(node) > 0 for node in as_strided_nodes)
|
|
self.assertEqual(num_offset_views, 1)
|
|
|
|
def test_cat_to_alloc(self):
|
|
"""
|
|
Test concatenation that's optimized out to an allocation.
|
|
"""
|
|
length = 8
|
|
|
|
def foo(x):
|
|
y, z = tuple(
|
|
torch.arange(length // 2, device=self.device) for _ in range(2)
|
|
)
|
|
return x + torch.cat((y, z))
|
|
|
|
args = [torch.randn(length, device=self.device)]
|
|
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
|
|
|
|
# Expect a single allocation, even though eager mode would use 2.
|
|
num_allocs = self._count_ops(gm, torch.empty_strided)
|
|
self.assertEqual(num_allocs, 1)
|
|
|
|
def test_cat_reinterpret_view(self):
|
|
"""
|
|
Test torch.cat using ReinterpretView.
|
|
"""
|
|
length = 8
|
|
|
|
def foo(x):
|
|
y, z = tuple(torch.randn(length // 2, device=self.device) for _ in range(2))
|
|
return x + torch.cat((y, z))
|
|
|
|
args = [torch.randn(length, device=self.device)]
|
|
|
|
# Since this test generates random numbers, check metadata only.
|
|
(gm,) = self._compile_and_check(
|
|
foo, args, expected_num_triton_kernels=3, metadata_only=True
|
|
)
|
|
|
|
# Check for as_strided. We map ReinterpretView to this.
|
|
num_as_strided = self._count_ops(gm, torch.as_strided)
|
|
self.assertEqual(num_as_strided, 2)
|
|
|
|
def test_reshape_output(self):
|
|
"""
|
|
Test reshaping the output, which maps to a ReinterpretView.
|
|
"""
|
|
|
|
def foo(x, y):
|
|
return torch.reshape(x + y, (8,))
|
|
|
|
args = [torch.randn((2, 4), device=self.device) for _ in range(2)]
|
|
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
|
|
|
|
# Check for as_strided. We map ReinterpretView to this.
|
|
num_as_strided = self._count_ops(gm, torch.as_strided)
|
|
self.assertEqual(num_as_strided, 1)
|
|
|
|
def test_extern_multi_output(self):
|
|
"""
|
|
Test an extern kernel with multiple outputs.
|
|
Also test a graph with multiple outputs.
|
|
"""
|
|
|
|
def foo(x):
|
|
top, idx = torch.topk(x, 2)
|
|
return top + 1, idx * 2
|
|
|
|
args = [torch.randn(8, device=self.device)]
|
|
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2)
|
|
|
|
# Check for multiple kernel outputs via getitems.
|
|
num_getitems = self._count_ops(gm, operator.getitem)
|
|
self.assertEqual(num_getitems, 2)
|
|
|
|
# Check for multiple graph outputs.
|
|
output_node = gm.graph.find_nodes(op="output")[0]
|
|
self.assertEqual(len(output_node.args[0]), 2)
|
|
|
|
def test_duplicate_input(self):
|
|
"""
|
|
Test duplicated inputs. This will collapse into a single input in the GM.
|
|
"""
|
|
|
|
args = [torch.randn(4, device=self.device)] * 2
|
|
(gm,) = self._compile_and_check(torch.add, args, expected_num_triton_kernels=1)
|
|
|
|
num_placeholders = len(gm.graph.find_nodes(op="placeholder"))
|
|
self.assertEqual(num_placeholders, 1)
|
|
|
|
def test_backward(self):
|
|
"""
|
|
Test a program with a backward pass.
|
|
"""
|
|
|
|
x = torch.ones(5, device=self.device) # input tensor
|
|
y = torch.zeros(3, device=self.device) # expected output
|
|
w = torch.randn(5, 3, requires_grad=True, device=self.device)
|
|
b = torch.randn(3, requires_grad=True, device=self.device)
|
|
|
|
def foo(x, y):
|
|
z = torch.matmul(x, w) + b
|
|
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
|
|
loss.backward()
|
|
return w.grad, b.grad
|
|
|
|
# Expect separate forward and backward graphs.
|
|
(forward_gm, backward_gm) = self._compile_and_check(
|
|
foo, (x, y), expected_num_triton_kernels=3
|
|
)
|
|
|
|
def test_custom_compiler(self):
|
|
"""
|
|
Test a derived backend with a custom compiler.
|
|
"""
|
|
offset = 1
|
|
|
|
class CustomWrapperCodegen(WrapperFxCodegen):
|
|
def compile_graph(self, gm):
|
|
def compiled_fn(*args):
|
|
# Adds an offset to the program's outputs.
|
|
outputs = gm(*args)
|
|
return pytree.tree_map(lambda x: x + 1, outputs)
|
|
|
|
return compiled_fn
|
|
|
|
args = [torch.randn(8, device=self.device) for _ in range(2)]
|
|
custom_backend = common.DeviceCodegen(
|
|
TritonScheduling, CustomWrapperCodegen, None
|
|
)
|
|
with unittest.mock.patch.dict(
|
|
common.device_codegens, {self.device: custom_backend}
|
|
):
|
|
func = torch.add
|
|
opt = torch.compile(func)
|
|
result = opt(*args)
|
|
|
|
# Check the output is offset from eager mode.
|
|
ref = func(*args)
|
|
self.assertFalse(same(result, ref))
|
|
self.assertNotEqual(offset, 0)
|
|
self.assertTrue(same(result - offset, ref))
|
|
|
|
def test_dynamic_shapes_and_strides(self):
|
|
"""
|
|
Test a graph with dynamic shapes and strides.
|
|
"""
|
|
|
|
static_dims = (8, 8)
|
|
|
|
def get_input():
|
|
full_size = (16, 8)
|
|
full = torch.randn(full_size, device=self.device)
|
|
view = torch.as_strided(full, static_dims, full.stride())
|
|
return view
|
|
|
|
func = torch.add
|
|
args = [get_input() for _ in range(2)]
|
|
(gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})
|
|
|
|
# Check for a symbolic output shape.
|
|
(empty_strided,) = gm.graph.find_nodes(
|
|
op="call_function", target=torch.empty_strided
|
|
)
|
|
example_tensor = empty_strided.meta["val"]
|
|
symbolic_dims = example_tensor.shape
|
|
self.assertEqual(len(symbolic_dims), len(static_dims))
|
|
|
|
# Check for symbolic output strides.
|
|
(stride, one) = example_tensor.stride()
|
|
self.assertEqual(one, sympy.S.One)
|
|
|
|
        # Find the size symbols, and check for corresponding placeholders defining them.
|
|
for symbol in itertools.chain(symbolic_dims, [stride]):
|
|
self.assertTrue(isinstance(symbol, torch.SymInt))
|
|
(placeholder,) = [
|
|
node
|
|
for node in gm.graph.find_nodes(op="placeholder")
|
|
if node.name == str(symbol)
|
|
]
|
|
self.assertEqual(placeholder.meta["val"], symbol)
|
|
|
|
@parametrize(
|
|
"shape",
|
|
[
|
|
(20,),
|
|
(50, 30),
|
|
(50, 30, 40),
|
|
],
|
|
)
|
|
@torch._inductor.config.patch(
|
|
{
|
|
"pad_dynamic_shapes": True,
|
|
"comprehensive_padding": True,
|
|
"padding_alignment_bytes": 32,
|
|
"pad_outputs": True,
|
|
}
|
|
)
|
|
def test_dynamic_shapes_with_padding(self, shape):
|
|
"""
|
|
Test a graph with dynamic shapes with padding.
|
|
"""
|
|
|
|
def get_input(shape):
|
|
pad_size = list(shape)
|
|
pad_size[-1] = ((shape[-1] + 7) // 8) * 8
|
|
pad = torch.randn(pad_size, dtype=torch.float32, device=self.device)
|
|
view = torch.as_strided(pad, shape, pad.stride())
|
|
return view
|
|
|
|
args = [get_input(shape) for _ in range(2)]
|
|
(gm,) = self._compile_and_check(
|
|
torch.add, args, compile_kwargs={"dynamic": True}
|
|
)
|
|
|
|
# Check for a symbolic output shape.
|
|
(empty_strided,) = gm.graph.find_nodes(
|
|
op="call_function", target=torch.empty_strided
|
|
)
|
|
example_tensor = empty_strided.meta["val"]
|
|
symbolic_dims = example_tensor.shape
|
|
symbolic_strides = example_tensor.stride()
|
|
|
|
align_elems = 32 // args[0].dtype.itemsize
|
|
expected_strides = [1 for _ in range(len(shape))]
|
|
for i in range(len(shape) - 1, 0, -1):
|
|
expected_strides[i - 1] = align_elems * (
|
|
((expected_strides[i] * symbolic_dims[i]) + align_elems - 1)
|
|
// align_elems
|
|
)
|
|
for i, j in zip(symbolic_strides, expected_strides):
|
|
self.assertEqual(i, j)
|
|
|
|
def test_dynamic_shapes_precomputed_size(self):
|
|
"""
|
|
Test dynamic shapes where a kernel's size arg is precomputed.
|
|
"""
|
|
func = torch.add
|
|
args = [
|
|
torch.randn(shape, device=self.device) for shape in [(7, 12, 9), (7, 1, 1)]
|
|
]
|
|
(gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})
|
|
|
|
# Check for the precomputed size arg.
|
|
(triton_node,) = gm.graph.find_nodes(
|
|
op="call_function", target=triton_kernel_wrapper_mutation
|
|
)
|
|
self.assertIn("ks0", triton_node.kwargs["kwargs"])
|
|
|
|
def test_dynamic_launch_grid_calc(self):
|
|
"""
|
|
        Test the dynamic launch grid calculation.
|
|
"""
|
|
|
|
func = torch.add
|
|
args = [torch.randn(shape, device=self.device) for shape in [(7, 12), (7, 1)]]
|
|
(gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})
|
|
|
|
# Check for the precomputed size arg.
|
|
(triton_node,) = gm.graph.find_nodes(
|
|
op="call_function", target=triton_kernel_wrapper_mutation
|
|
)
|
|
self.assertIn("grid", triton_node.kwargs)
|
|
self.assertIn("xnumel", triton_node.kwargs["kwargs"])
|
|
self.assertIn("XBLOCK", triton_node.kwargs["kwargs"])
|
|
grid = triton_node.kwargs["grid"][0]
|
|
xnumel = triton_node.kwargs["kwargs"]["xnumel"].meta["val"]
|
|
xblock = triton_node.kwargs["kwargs"]["XBLOCK"]
|
|
self.assertEqual(grid[0].meta["val"], -(-xnumel // xblock))
|
|
self.assertEqual(grid[1], 1)
|
|
self.assertEqual(grid[2], 1)
|
|
|
|
@config.patch({"trace.enabled": True})
|
|
@unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code")
|
|
def test_debug(self, mock_output_code):
|
|
# Compile in debug mode.
|
|
args = [torch.randn(11, device=self.device) for _ in range(2)]
|
|
self._compile_and_check(torch.sub, args)
|
|
|
|
# Check the output code for a Triton kernel call.
|
|
mock_output_code.assert_called_once()
|
|
(output_filename,) = mock_output_code.call_args.args
|
|
with open(output_filename) as f:
|
|
output_code = f.read()
|
|
self.assertIn("triton_kernel_wrapper_mutation", output_code)
|
|
|
|
@parametrize(
|
|
"const",
|
|
(1, 1.5),
|
|
)
|
|
def test_export_const_placeholder(self, const):
|
|
"""
|
|
Test that we can compile a graph coming from torch.export with a constant input.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return x - y
|
|
|
|
args = (torch.randn(8, device=self.device), const)
|
|
mod = TestModule()
|
|
export_gm = torch.export.export(mod, args).module()
|
|
|
|
def compile_module(*inps):
|
|
torch._inductor.compile(export_gm, inps)
|
|
|
|
(inductor_gm,) = self._run_and_capture_graphs(compile_module, args)
|
|
result = inductor_gm(*args)
|
|
ref = mod(*args)
|
|
|
|
self.assertTrue(same(ref, result))
|
|
|
|
def test_scatter_fallback_scalar_src(self):
|
|
"""
|
|
Test a special case where ScatterFallback takes a scalar 'src' argument.
|
|
"""
|
|
|
|
def foo(input_):
|
|
dim = 0
|
|
src = 1.5
|
|
return torch.ops.aten.scatter(input_, dim, index, src)
|
|
|
|
length = 8
|
|
index = torch.randint(length, (length,), device=self.device)
|
|
input_ = torch.randn(length, device=self.device)
|
|
with DeterministicGuard(True):
|
|
(gm,) = self._compile_and_check(
|
|
foo,
|
|
(input_,),
|
|
)
|
|
|
|
# Check for the fallback op.
|
|
num_fallback = self._count_ops(gm, torch.ops.aten.scatter_.value)
|
|
self.assertEqual(num_fallback, 1)
|
|
|
|
def test_index_put_fallback(self):
|
|
"""
|
|
Test the deterministic fallback for index_put.
|
|
"""
|
|
length = 8
|
|
out, values = [torch.randn(length, device=self.device) for _ in range(2)]
|
|
indices = (torch.randint(length, (length,), device=self.device),)
|
|
accumulate = True
|
|
with DeterministicGuard(True):
|
|
(gm,) = self._compile_and_check(
|
|
torch.index_put,
|
|
(out, indices, values, accumulate),
|
|
expected_num_triton_kernels=1,
|
|
)
|
|
|
|
# Check for the fallback op.
|
|
self.assertEqual(self._count_ops(gm, torch.ops.aten.index_put_.default), 1)
|
|
|
|
def test_scatter_reduce_fallback(self):
|
|
"""
|
|
Test the customized wrapper codegen for ScatterFallback ops.
|
|
"""
|
|
fallback_op = torch.ops.aten.scatter_reduce_.two
|
|
|
|
def foo(out, index, src):
|
|
dim = 0
|
|
out = fallback_op(out, dim, index, src, reduce="amax", include_self=False)
|
|
return out + 1
|
|
|
|
length = 8
|
|
out, src = [torch.randn(length, device=self.device) for _ in range(2)]
|
|
index = torch.randint(length, (length,), device=self.device)
|
|
(gm,) = self._compile_and_check(
|
|
foo, (out, index, src), expected_num_triton_kernels=2
|
|
)
|
|
|
|
# Check for the fallback.
|
|
self.assertEqual(self._count_ops(gm, fallback_op), 1)
|
|
|
|
@parametrize("pred", (False, True))
|
|
def test_cond_subgraph(self, pred: bool):
|
|
"""
|
|
Test a model with subgraphs.
|
|
"""
|
|
|
|
def foo(pred, x):
|
|
return torch.cond(pred, torch.cos, torch.sin, [x]) + 1
|
|
|
|
x = torch.randn((2, 3), device=self.device)
|
|
pred_tensor = torch.tensor([pred], device=self.device)
|
|
gm = self._compile_and_check(
|
|
foo, [pred_tensor, x], expected_num_triton_kernels=3
|
|
)[-1]
|
|
|
|
# Check for subgraphs.
|
|
subgm_getattrs = list(gm.graph.find_nodes(op="get_attr"))
|
|
self.assertEqual(len(subgm_getattrs), 2)
|
|
for subgm_getattr in subgm_getattrs:
|
|
target = subgm_getattr.name
|
|
self.assertTrue(isinstance(getattr(gm, target), torch.fx.GraphModule))
|
|
|
|
@parametrize("pred", (False, True))
|
|
def test_cond_no_operands(self, pred: bool):
|
|
"""
|
|
Test torch.cond when the subgraphs take no inputs.
|
|
"""
|
|
|
|
length = 8
|
|
|
|
def true_fn():
|
|
return torch.zeros(length, device=self.device)
|
|
|
|
def false_fn():
|
|
return true_fn() + 5
|
|
|
|
def foo(pred):
|
|
return torch.cond(pred, true_fn, false_fn, ())
|
|
|
|
pred_tensor = torch.tensor([pred], device=self.device)
|
|
self._compile_and_check(foo, [pred_tensor], expected_num_triton_kernels=2)
|
|
|
|
def test_cpp_raises(self):
|
|
"""
|
|
Test the C++ CPU backend. C++ kernels are not yet supported, so for now check
|
|
that we get the expected exception.
|
|
"""
|
|
|
|
def foo(x, y):
|
|
return x + y * 5
|
|
|
|
device = torch.device("cpu")
|
|
args = [torch.randn(5, device=device) for _ in range(2)]
|
|
|
|
cpp_backend = common.DeviceCodegen(CppScheduling, WrapperFxCodegen, None)
|
|
with (
|
|
unittest.mock.patch.dict(
|
|
common.device_codegens, {device.type: cpp_backend}
|
|
),
|
|
self.assertRaisesRegex(BackendCompilerFailed, "Triton"),
|
|
):
|
|
self._compile_and_check(foo, args)
|
|
|
|
@parametrize("enable_tuning", (False, True))
|
|
@parametrize("use_dynamic_shapes", (False, True))
|
|
def test_autotune(self, use_dynamic_shapes: bool, enable_tuning: bool):
|
|
orig_run = torch._inductor.runtime.triton_heuristics.CachingAutotuner.run
|
|
called = False
|
|
|
|
def run(*args, **kwargs):
|
|
nonlocal called
|
|
called = True
|
|
return orig_run(*args, **kwargs)
|
|
|
|
args = [torch.randn(8, device=self.device) for _ in range(2)]
|
|
|
|
with (
|
|
config.patch("triton.autotune_at_compile_time", enable_tuning),
|
|
unittest.mock.patch.object(
|
|
torch._inductor.runtime.triton_heuristics.CachingAutotuner, "run", run
|
|
),
|
|
):
|
|
# Compile and check that the tuner was called.
|
|
self.assertFalse(called)
|
|
(gm,) = self._compile_and_check(
|
|
torch.mul, args, compile_kwargs={"dynamic": use_dynamic_shapes}
|
|
)
|
|
self.assertEqual(called, enable_tuning)
|
|
|
|
# Check for a symbolic output shape.
|
|
(empty_strided,) = gm.graph.find_nodes(
|
|
op="call_function", target=torch.empty_strided
|
|
)
|
|
(shape, stride) = empty_strided.args
|
|
if use_dynamic_shapes:
|
|
self.assertEqual(type(shape[0]), torch.fx.Node)
|
|
|
|
def test_custom_triton(self):
|
|
@triton.jit
|
|
def add_kernel(
|
|
in_ptr0,
|
|
in_ptr1,
|
|
out_ptr,
|
|
n_elements,
|
|
BLOCK_SIZE: tl.constexpr,
|
|
):
|
|
pid = tl.program_id(axis=0)
|
|
block_start = pid * BLOCK_SIZE
|
|
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
|
mask = offsets < n_elements
|
|
x = tl.load(in_ptr0 + offsets, mask=mask)
|
|
y = tl.load(in_ptr1 + offsets, mask=mask)
|
|
output = x + y
|
|
tl.store(out_ptr + offsets, output, mask=mask)
|
|
|
|
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
output = torch.empty_like(x)
|
|
n_elements = output.numel()
|
|
|
|
def grid(meta):
|
|
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
|
|
|
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
|
|
return output
|
|
|
|
args = [torch.randn(32, device=self.device) for _ in range(2)]
|
|
self._compile_and_check(add, args)
|
|
|
|
def test_output_slice_view(self):
|
|
"""
|
|
Test when the output is a view of the input.
|
|
The sliced strides create a TensorBox in the output IR.
|
|
"""
|
|
|
|
def foo(x):
|
|
return x[0:2:2].T[3:].squeeze(0)
|
|
|
|
args = [torch.rand([4, 4, 4, 4], device=self.device)]
|
|
self._compile_and_check(foo, args, expected_num_triton_kernels=0)
|
|
|
|
|
|
@instantiate_parametrized_tests
|
|
class AOTFxirTestCase(InductorTestCase):
|
|
device = GPU_TYPE
|
|
|
|
def check(
|
|
self, model, inp, dynamic_shapes=None, strict=False
|
|
) -> torch.fx.GraphModule:
|
|
if self.device == "xpu":
|
|
            raise unittest.SkipTest("The AOTFxir feature is not currently ready for XPU")
|
|
with torch.no_grad():
|
|
ep = torch.export.export(
|
|
model, inp, dynamic_shapes=dynamic_shapes, strict=strict
|
|
)
|
|
gm = torch._inductor.aot_compile(
|
|
ep.module(), inp, options={"fx_wrapper": True, **test_config}
|
|
)
|
|
self.assertTrue(same(model(*inp), gm(*inp)))
|
|
|
|
for node in gm.graph.nodes:
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target != triton_kernel_wrapper_mutation
|
|
):
|
|
self.assertTrue(node.meta.get("val", None) is not None)
|
|
|
|
return gm
|
|
|
|
def test_aoti_fx_add(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return x + y
|
|
|
|
inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
|
|
self.check(M(), inp)
|
|
|
|
def test_aoti_fx_const(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.device = device
|
|
self.a = torch.nn.Parameter(torch.ones(3, device=self.device))
|
|
self.b = torch.ones(3, device=self.device)
|
|
|
|
def forward(self, x, y):
|
|
return x + y + self.a + self.b + torch.tensor(3, device=self.device)
|
|
|
|
inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
|
|
self.check(M(self.device), inp)
|
|
|
|
def test_aoti_fx_linear(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 3)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
inp = (torch.ones(3, 3, device=self.device),)
|
|
self.check(M().to(self.device), inp)
|
|
|
|
def test_aoti_fx_dynamic(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
return x + y
|
|
|
|
inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
|
|
self.check(
|
|
M().to(device=self.device),
|
|
inp,
|
|
dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
|
|
)
|
|
|
|
def test_custom_triton_autotune_dynamic(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
output = torch.zeros_like(x)
|
|
x_elements = output.size()[0]
|
|
y_elements = output.size()[1]
|
|
|
|
def grid(meta):
|
|
return (
|
|
triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
|
|
triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
|
|
)
|
|
|
|
add_kernel_2d_autotuned[grid](x, y, output, x_elements, y_elements)
|
|
|
|
return output
|
|
|
|
num_dims = 2
|
|
dims = [10] * num_dims
|
|
x = torch.randn(*dims, device=self.device)
|
|
y = torch.randn(*dims, device=self.device)
|
|
dim0_x = Dim("dim0_x", min=1, max=10)
|
|
dim0_y = Dim("dim0_y", min=1, max=10)
|
|
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
|
|
self.check(
|
|
Model().to(device=self.device),
|
|
(x, y),
|
|
dynamic_shapes=dynamic_shapes,
|
|
strict=True,
|
|
)
|
|
|
|
def test_custom_backend(self):
|
|
"""
|
|
Test registering a custom FX backend.
|
|
"""
|
|
called = False
|
|
|
|
class CustomWrapperCodegen(WrapperFxCodegen):
|
|
def compile_graph(self, gm):
|
|
"""
|
|
Simply records whether this override was called.
|
|
"""
|
|
nonlocal called
|
|
called = True
|
|
return super().compile_graph(gm)
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x + 1
|
|
|
|
# Register a custom FX backend.
|
|
custom_backend = common.DeviceCodegen(
|
|
TritonScheduling,
|
|
PythonWrapperCodegen,
|
|
fx_wrapper_codegen=CustomWrapperCodegen,
|
|
)
|
|
with unittest.mock.patch.dict(
|
|
common.device_codegens, {self.device: custom_backend}
|
|
):
|
|
# The backend should not have been called yet.
|
|
self.assertFalse(called)
|
|
|
|
inp = (torch.randn(8, device=self.device),)
|
|
self.check(M().to(self.device), inp)
|
|
|
|
# Now the backend should have been called.
|
|
self.assertTrue(called)
|
|
|
|
@parametrize(
|
|
"expr",
|
|
[
|
|
(2 * Dim("x") + 1),
|
|
(Dim("x", min=3) - 3),
|
|
],
|
|
)
|
|
def test_dynamic_input_expr(self, expr: sympy.Expr):
|
|
"""
|
|
Test dynamic shapes with a nontrivial input expression.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x.reshape(x.shape[0] * x.shape[1]) + x.shape[1]
|
|
|
|
dynamic_shapes = {"x": {0: expr}}
|
|
inp = (torch.randn((5, 4), device=self.device),)
|
|
gm = self.check(M().to(self.device), inp, dynamic_shapes=dynamic_shapes)
|
|
|
|
# Check for dynamic size ops.
|
|
self.assertEqual(
|
|
len(
|
|
gm.graph.find_nodes(
|
|
op="call_function", target=torch.ops.aten.sym_size.int
|
|
)
|
|
),
|
|
1,
|
|
)
|
|
|
|
@parametrize("pred", (False, True))
|
|
def test_cond_multi_inputs_and_outputs(self, pred):
|
|
"""
|
|
Test torch.cond and check the output graphs.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, pred, x, y):
|
|
def true_fn(x, y):
|
|
return torch.tanh(x), torch.relu(y)
|
|
|
|
def false_fn(x, y):
|
|
return tuple(t / 2 for t in true_fn(x, y))
|
|
|
|
return torch.cond(pred, true_fn, false_fn, (x, y))
|
|
|
|
pred = torch.tensor([True], device=self.device)
|
|
(x, y) = [torch.randn(8, device=self.device) for _ in range(2)]
|
|
gm = self.check(M(), (pred, x, y))
|
|
|
|
# Check the graph.
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(arg0_1, true_graph_0, false_graph_0, (arg1_1, arg2_1)); arg0_1 = true_graph_0 = false_graph_0 = arg1_1 = arg2_1 = None
|
|
buf1 = cond[0]
|
|
buf2 = cond[1]; cond = None
|
|
return [buf1, buf2]""", # noqa: B950
|
|
)
|
|
|
|
def test_dims_dynamic_outer_static_padded_inner(self):
|
|
"""
|
|
Test padding on inner dimensions, with dynamic outer dimensions.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return x + y
|
|
|
|
def get_input_padded_inner(shape):
|
|
full_shape = shape[:-1] + (shape[-1] * 2,)
|
|
full = torch.randn(full_shape, dtype=torch.float32, device=self.device)
|
|
view = torch.as_strided(full, shape, full.stride())
|
|
return view
|
|
|
|
shape = (4, 4, 4)
|
|
args = tuple(get_input_padded_inner(shape) for _ in range(2))
|
|
self.check(
|
|
M(),
|
|
args,
|
|
dynamic_shapes=({0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.STATIC},) * 2,
|
|
)
|
|
|
|
@parametrize("length", (4, 8))
|
|
def test_cond_dynamic_shape_pred_scalar_closure(self, length: int):
|
|
"""
|
|
Test cond using a predicate computed from dynamic shapes.
|
|
Also test a dynamic scalar computed outside the branches.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
z = x.reshape(-1)
|
|
a = y.shape[0]
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return true_fn(x) / 2
|
|
|
|
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (z,))
|
|
|
|
(x, y) = [
|
|
torch.randn(shape, device=self.device)
|
|
for shape in [(length // 2,) * 2, (length,)]
|
|
]
|
|
dynamic_shapes = {
|
|
"x": {0: Dim.DYNAMIC},
|
|
"y": {0: Dim.DYNAMIC},
|
|
}
|
|
self.check(M(), (x, y), dynamic_shapes=dynamic_shapes)
|
|
|
|
def test_dynamic_scalar_output(self):
|
|
"""
|
|
Test an output scalar from dynamic shapes.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x.shape[0] * 3
|
|
|
|
x = torch.randn(7, device=self.device)
|
|
self.check(M(), (x,), dynamic_shapes=({0: Dim.DYNAMIC},))
|
|
|
|
@parametrize("pred", (False, True))
|
|
def test_mismatched_branch_dynamic(self, pred: bool):
|
|
"""
|
|
Test cond branches with mismatched dynamic shapes.
|
|
"""
|
|
|
|
        # Apply an offset to guarantee the truth of the predicate.
|
|
pred_offset = 1 if pred else -1
|
|
|
|
inputs = [
|
|
torch.tensor([pred], device=self.device),
|
|
] + [torch.randn(10, 20, device=self.device) + pred_offset for _ in range(3)]
|
|
dim0_a = Dim("s0", min=4, max=1024)
|
|
dim0_b = Dim("s1", min=4, max=1024)
|
|
dynamic_shapes = {
|
|
"p": {},
|
|
"x": {0: dim0_a, 1: None},
|
|
"y": {0: dim0_b, 1: None},
|
|
"z": {0: dim0_a, 1: None},
|
|
}
|
|
|
|
self.check(
|
|
CondModels.MismatchedOutputSize(),
|
|
tuple(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
|
|
class TestReplaceFloorDiv(InductorTestCase):
|
|
"""
|
|
Tests for floor -> FloorDiv conversion.
|
|
"""
|
|
|
|
def _check(self, expr: sympy.Expr) -> sympy.Expr:
|
|
        # Check that we started with floor ops.
|
|
num_floors = expr.count(sympy.floor)
|
|
self.assertGreater(num_floors, 0)
|
|
|
|
replaced = replace_floor_div(expr)
|
|
|
|
        # Check that all floor ops were replaced.
|
|
        # We should have no more new FloorDiv ops than floor ops in the original expression,
|
|
        # although we can have fewer due to simplification.
|
|
self.assertEqual(replaced.count(sympy.floor), 0)
|
|
self.assertLessEqual(
|
|
replaced.count(FloorDiv) - expr.count(FloorDiv), num_floors
|
|
)
|
|
|
|
def expand_floor_div(
|
|
numerator: sympy.Expr, denominator: sympy.Expr
|
|
) -> sympy.Expr:
|
|
return sympy.floor(numerator / denominator)
|
|
|
|
# Expand FloorDiv back into floor and check for equality.
|
|
self.assertEqual(
|
|
*[
|
|
sympy.simplify(e.replace(FloorDiv, expand_floor_div))
|
|
for e in (replaced, expr)
|
|
]
|
|
)
|
|
|
|
return replaced
|
|
|
|
def test_rewrite_floor_div_mul_pow(self):
|
|
x, y = sympy.symbols("x y")
|
|
expr = sympy.floor(x / y)
|
|
self.assertEqual(expr.count(FloorDiv), 0)
|
|
self.assertEqual(expr.count(sympy.core.mul.Mul), 1)
|
|
self.assertEqual(expr.count(sympy.Pow), 1)
|
|
|
|
rewritten = self._check(expr)
|
|
self.assertTrue(isinstance(rewritten, FloorDiv))
|
|
self.assertEqual(rewritten.args, (x, y))
|
|
|
|
def test_rewrite_floor_div_mul_rational(self):
|
|
x = sympy.Symbol("x")
|
|
expr = sympy.floor(x / 5)
|
|
self.assertEqual(expr.count(FloorDiv), 0)
|
|
self.assertEqual(expr.count(sympy.core.mul.Mul), 1)
|
|
self.assertEqual(expr.count(sympy.Rational), 1)
|
|
|
|
rewritten = self._check(expr)
|
|
self.assertTrue(isinstance(rewritten, FloorDiv))
|
|
self.assertEqual(rewritten.args, (x, 5))
|
|
|
|
def test_no_rewrite_div(self):
|
|
x, y = sympy.symbols("x y")
|
|
expr = x / y
|
|
self.assertEqual(expr.count(FloorDiv), 0)
|
|
|
|
rewritten = replace_floor_div(expr)
|
|
self.assertEqual(rewritten, expr)
|
|
|
|
def test_rewrite_floor_div_nested(self):
|
|
x, y = sympy.symbols("x y")
|
|
expr = sympy.floor((sympy.floor(x / 5) + 1) / y)
|
|
self.assertEqual(expr.count(FloorDiv), 0)
|
|
|
|
rewritten = self._check(expr)
|
|
self.assertEqual(rewritten.count(FloorDiv), 2)
|
|
|
|
def test_rewrite_floor_div_rational_const(self):
|
|
expr = sympy.floor(sympy.S.One / 5, evaluate=False)
|
|
self.assertEqual(expr.count(FloorDiv), 0)
|
|
self.assertEqual(expr.count(sympy.Mul), 0)
|
|
self.assertEqual(expr.count(sympy.Rational), 1)
|
|
|
|
# Expression evaluates to a compile time constant
|
|
rewritten = self._check(expr)
|
|
self.assertEqual(rewritten, sympy.S.Zero)
|
|
|
|
def test_no_distribute_mul_floordiv(self):
|
|
"""
|
|
Test that multiplication doesn't distribute with floor division.
|
|
"""
|
|
x = sympy.Symbol("x")
|
|
expr = 2 * sympy.floor(x / 2)
|
|
rewritten = self._check(expr)
|
|
self.assertEqual(rewritten.count(sympy.Mul), 1)
|
|
self.assertEqual(rewritten.count(FloorDiv), 1)
|
|
|
|
def test_rational_multi_pows(self):
|
|
"""
|
|
Test an expression with a rational and multiple pows.
|
|
"""
|
|
x, y, z = sympy.symbols("x y z")
|
|
expr = sympy.floor((x / 5) * (y**2) * (z**3))
|
|
mul = expr.args[0]
|
|
self.assertTrue(isinstance(mul, sympy.Mul))
|
|
self.assertTrue(isinstance(mul.args[0], sympy.Rational))
|
|
self.assertEqual(expr.count(sympy.Pow), 2)
|
|
rewritten = self._check(expr)
|
|
self.assertEqual(rewritten.count(FloorDiv), 1)
|
|
|
|
def test_variable_exp(self):
|
|
"""
|
|
Test pow when the exponent is a variable.
|
|
"""
|
|
x = sympy.Symbol("x", positive=True)
|
|
expr = sympy.floor(2**-x)
|
|
replaced = self._check(expr)
|
|
|
|
# Check that x went to the denominator.
|
|
self.assertEqual(replaced.args, (1, 2**x))
|
|
|
|
def test_launch_grid_dynamic_padding(self):
|
|
"""
|
|
Test a complex launch grid expression arising from padding with dynamic shapes.
|
|
"""
|
|
x, y = sympy.symbols("x y")
|
|
expr = sympy.floor(-FloorDiv(x * y, 2) / FloorDiv(-x * y, 131070))
|
|
self._check(expr)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from torch._inductor.test_case import run_tests
|
|
|
|
if HAS_GPU or TRITON_HAS_CPU:
|
|
run_tests(needs="filelock")
|