Compare commits

...

8 Commits

3 changed files with 179 additions and 82 deletions

View File

@ -20,7 +20,11 @@ from torch._inductor import config
from torch._inductor.codegen.cpp import CppScheduling
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen
from torch._inductor.codegen.wrapper_fxir import (
FxConverter,
replace_floor_div,
WrapperFxCodegen,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.export import Dim
from torch.testing._internal.common_utils import (
@ -34,6 +38,7 @@ from torch.testing._internal.inductor_utils import (
requires_gpu,
TRITON_HAS_CPU,
)
from torch.utils._sympy.functions import FloorDiv
if HAS_GPU:
@ -483,10 +488,11 @@ class FxirTestCase(InductorTestCase):
)
self.assertIn("ks0", triton_node.kwargs["kwargs"])
def test_dynamic_launch_grid_calc_python(self):
def test_dynamic_launch_grid_calc(self):
"""
Test the dynamic launch grid calculation for Triton kernel wrapper using python mode
Test the dynamic launch grid calculation.
"""
func = torch.add
args = [torch.randn(shape, device=self.device) for shape in [(7, 12), (7, 1)]]
(gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})
@ -505,41 +511,6 @@ class FxirTestCase(InductorTestCase):
self.assertEqual(grid[1], 1)
self.assertEqual(grid[2], 1)
def test_dynamic_launch_grid_calc_python_slow(self):
    """
    Test the dynamic launch grid calculation for Triton kernel wrapper using python_slow mode
    """
    from torch._inductor.runtime.triton_heuristics import GridExpr

    # Keep a handle on the real factory so the replacement can delegate to it.
    real_from_meta = GridExpr.from_meta

    def mocked_from_meta(inductor_meta, cfg, mode="python"):
        # Whatever mode the caller asked for, force "python_slow".
        return real_from_meta(inductor_meta, cfg, mode="python_slow")

    input_shapes = [(7, 12), (7, 1)]
    with unittest.mock.patch.object(GridExpr, "from_meta", mocked_from_meta):
        args = [torch.randn(shape, device=self.device) for shape in input_shapes]
        (gm,) = self._compile_and_check(
            torch.add, args, compile_kwargs={"dynamic": True}
        )

        # Check for the precomputed size arg.
        (triton_node,) = gm.graph.find_nodes(
            op="call_function", target=triton_kernel_wrapper_mutation
        )
        self.assertIn("grid", triton_node.kwargs)
        kernel_kwargs = triton_node.kwargs["kwargs"]
        for required_key in ("xnumel", "XBLOCK"):
            self.assertIn(required_key, kernel_kwargs)

        # The first grid dimension must be the ceil-division written as
        # (xnumel + XBLOCK - 1) // XBLOCK; the other two are the constant 1.
        grid_x, grid_y, grid_z = triton_node.kwargs["grid"][0][:3]
        xnumel = kernel_kwargs["xnumel"].meta["val"]
        xblock = kernel_kwargs["XBLOCK"]
        self.assertEqual(grid_x.meta["val"], ((xnumel + xblock - 1) // xblock))
        self.assertEqual(grid_y, 1)
        self.assertEqual(grid_z, 1)
@config.patch({"trace.enabled": True})
@unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code")
def test_debug(self, mock_output_code):
@ -990,6 +961,29 @@ def forward(self, arg0_1, arg1_1, arg2_1):
return [buf1, buf2]""", # noqa: B950
)
def test_dims_dynamic_outer_static_padded_inner(self):
    """
    Check an elementwise add whose inputs have dynamic outer dimensions and a
    static, padded innermost dimension.
    """

    class M(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    def get_input_padded_inner(shape):
        # Allocate a buffer twice as wide in the last dimension, then view
        # only `shape` of it while keeping the wider buffer's strides.
        padded_shape = (*shape[:-1], shape[-1] * 2)
        buffer = torch.randn(padded_shape, dtype=torch.float32, device=self.device)
        return torch.as_strided(buffer, shape, buffer.stride())

    shape = (4, 4, 4)
    args = tuple(get_input_padded_inner(shape) for _ in range(2))

    # Both inputs share one spec: outer two dims dynamic, innermost static.
    per_input_spec = {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.STATIC}
    self.check(M(), args, dynamic_shapes=(per_input_spec,) * 2)
@parametrize("length", (4, 8))
def test_cond_dynamic_shape_pred_scalar_closure(self, length: int):
"""
@ -1033,6 +1027,132 @@ def forward(self, arg0_1, arg1_1, arg2_1):
self.check(M(), (x,), dynamic_shapes=({0: Dim.DYNAMIC},))
class TestReplaceFloorDiv(InductorTestCase):
    """
    Tests for floor -> FloorDiv conversion.
    """

    def _check(self, expr: sympy.Expr) -> sympy.Expr:
        """
        Run replace_floor_div on ``expr``, validate the result and return it.

        Asserts that ``expr`` contains at least one sympy.floor, that the
        rewritten expression contains none, that the rewrite added no more
        FloorDiv's than there were floor's, and that expanding every FloorDiv
        back into floor(numerator / denominator) simplifies to the same
        expression on both sides.
        """
        # Check that we started with floor's.
        num_floors = expr.count(sympy.floor)
        self.assertGreater(num_floors, 0)

        replaced = replace_floor_div(expr)

        # Check that all floor's were replaced.
        # We should have no more new FloorDiv's than floor's in the original expression,
        # although we can have less due to simplification.
        self.assertEqual(replaced.count(sympy.floor), 0)
        self.assertLessEqual(
            replaced.count(FloorDiv) - expr.count(FloorDiv), num_floors
        )

        def expand_floor_div(
            numerator: sympy.Expr, denominator: sympy.Expr
        ) -> sympy.Expr:
            return sympy.floor(numerator / denominator)

        # Expand FloorDiv back into floor and check for equality.
        expanded_replaced = sympy.simplify(replaced.replace(FloorDiv, expand_floor_div))
        expanded_original = sympy.simplify(expr.replace(FloorDiv, expand_floor_div))
        self.assertEqual(expanded_replaced, expanded_original)

        return replaced

    def test_rewrite_floor_div_mul_pow(self):
        x, y = sympy.symbols("x y")
        expr = sympy.floor(x / y)
        self.assertEqual(expr.count(FloorDiv), 0)
        self.assertEqual(expr.count(sympy.core.mul.Mul), 1)
        self.assertEqual(expr.count(sympy.Pow), 1)

        rewritten = self._check(expr)
        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(rewritten, FloorDiv)
        self.assertEqual(rewritten.args, (x, y))

    def test_rewrite_floor_div_mul_rational(self):
        x = sympy.Symbol("x")
        expr = sympy.floor(x / 5)
        self.assertEqual(expr.count(FloorDiv), 0)
        self.assertEqual(expr.count(sympy.core.mul.Mul), 1)
        self.assertEqual(expr.count(sympy.Rational), 1)

        rewritten = self._check(expr)
        self.assertIsInstance(rewritten, FloorDiv)
        self.assertEqual(rewritten.args, (x, 5))

    def test_no_rewrite_div(self):
        x, y = sympy.symbols("x y")
        expr = x / y
        self.assertEqual(expr.count(FloorDiv), 0)

        # No floor in the input, so the expression must come back unchanged.
        rewritten = replace_floor_div(expr)
        self.assertEqual(rewritten, expr)

    def test_rewrite_floor_div_nested(self):
        x, y = sympy.symbols("x y")
        expr = sympy.floor((sympy.floor(x / 5) + 1) / y)
        self.assertEqual(expr.count(FloorDiv), 0)

        rewritten = self._check(expr)
        self.assertEqual(rewritten.count(FloorDiv), 2)

    def test_rewrite_floor_div_rational_const(self):
        expr = sympy.floor(sympy.S.One / 5, evaluate=False)
        self.assertEqual(expr.count(FloorDiv), 0)
        self.assertEqual(expr.count(sympy.Mul), 0)
        self.assertEqual(expr.count(sympy.Rational), 1)

        # Expression evaluates to a compile time constant
        rewritten = self._check(expr)
        self.assertEqual(rewritten, sympy.S.Zero)

    def test_no_distribute_mul_floordiv(self):
        """
        Test that multiplication doesn't distribute with floor division.
        """
        x = sympy.Symbol("x")
        expr = 2 * sympy.floor(x / 2)
        rewritten = self._check(expr)
        self.assertEqual(rewritten.count(sympy.Mul), 1)
        self.assertEqual(rewritten.count(FloorDiv), 1)

    def test_rational_multi_pows(self):
        """
        Test an expression with a rational and multiple pows.
        """
        x, y, z = sympy.symbols("x y z")
        expr = sympy.floor((x / 5) * (y**2) * (z**3))
        mul = expr.args[0]
        self.assertIsInstance(mul, sympy.Mul)
        self.assertIsInstance(mul.args[0], sympy.Rational)
        self.assertEqual(expr.count(sympy.Pow), 2)

        rewritten = self._check(expr)
        self.assertEqual(rewritten.count(FloorDiv), 1)

    def test_variable_exp(self):
        """
        Test pow when the exponent is a variable.
        """
        x = sympy.Symbol("x", positive=True)
        expr = sympy.floor(2**-x)
        replaced = self._check(expr)

        # Check that x went to the denominator.
        self.assertEqual(replaced.args, (1, 2**x))

    def test_launch_grid_dynamic_padding(self):
        """
        Test a complex launch grid expression arising from padding with dynamic shapes.
        """
        x, y = sympy.symbols("x y")
        expr = sympy.floor(-FloorDiv(x * y, 2) / FloorDiv(-x * y, 131070))
        self._check(expr)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@ -23,11 +23,7 @@ from torch._higher_order_ops.triton_kernel_wrap import (
from torch._inductor.codecache import LambdaFuture, PyCodeCache
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._inductor.select_algorithm import extern_kernels # noqa: F401
from torch._inductor.utils import (
convert_shape_to_symint,
convert_to_symint,
sympy_product,
)
from torch._inductor.utils import convert_shape_to_symint, convert_to_symint
from torch._inductor.virtualized import V
from torch._library.triton import wrap_triton
from torch.fx import GraphModule
@ -120,30 +116,20 @@ def replace_floor_div(expr: sympy.Expr) -> sympy.Expr:
def replace(expr: sympy.Expr) -> sympy.Expr:
expr = sympy.together(expr)
# Find division operations in the sympy.floor expression
# Div is either represented as Mul with:
# Rational denominator or Pow with negative exponent
if not isinstance(expr, sympy.core.mul.Mul):
return sympy.floor(expr)
if isinstance(expr.args[0], sympy.Rational):
frac = expr.args[0]
numerator = sympy_product(expr.args[1:]) * frac.numerator
denominator = frac.denominator
return FloorDiv(numerator, denominator)
elif isinstance(expr.args[0], sympy.Pow):
base = expr.args[0].base
exp = expr.args[0].exp
numerator = sympy_product(expr.args[1:])
if exp < 0:
denominator = base ** (-exp)
# Division is represented as a Mul with a Rational factor or a Pow with negative
# exponent. We convert floor(Mul(...)) to FloorDiv(numerator, denominator) by
# partitioning factors into the numerator and denominator.
(numerator, denominator) = (sympy.S.One,) * 2
for arg in sympy.Mul.make_args(expr):
if isinstance(arg, sympy.Rational):
numerator *= arg.numerator
denominator *= arg.denominator
elif isinstance(arg, sympy.Pow) and arg.exp.is_negative:
denominator *= arg.base**-arg.exp
else:
numerator = numerator * (base**exp)
denominator = 1
return FloorDiv(numerator, denominator)
else:
return sympy.floor(expr)
numerator *= arg
return FloorDiv(numerator, denominator)
return expr.replace(sympy.floor, replace)
@ -930,10 +916,6 @@ class FxConverter:
call_args = self._lookup_args(line.call_args)
kernel = self.kernels[line.kernel_name]
tuner = kernel.tuner
# Use python_slow mode instead of python mode to avoid
# the round to neginf behaviour, which is not the convention
# in other languages.
tuner.grid_mode = "python_slow"
# Optionally autotune the kernels.
# The FX backend currently only supports compile-time tuning.
@ -1007,8 +989,7 @@ class FxConverter:
call_kwargs = dict(zip(signature, call_args))
call_kwargs.update(kernel_config.kwargs)
# Replace all sympy.floor with FloorDiv
# _generate_sym_node does not support sympy.floor
# Replace sympy.floor with FloorDiv, to make the expression traceable.
grid = [replace_floor_div(x) if isinstance(x, sympy.Expr) else x for x in grid]
wrapper_grid = [tuple(self._generate_sym_nodes(grid))]
call_kwargs = {

View File

@ -375,7 +375,7 @@ class CachingAutotuner(KernelInterface):
self.is_backward = False
# Mode for launch grid calculation
self.grid_mode: Literal["python", "python_slow", "cpp"] = "python"
self.grid_mode: Literal["python", "cpp"] = "python"
def is_statically_launchable(self):
"""
@ -3192,14 +3192,14 @@ class GridExpr:
"""Generate code for grid size expressions in launcher"""
inductor_meta: dict[str, Any]
mode: Literal["python", "cpp", "python_slow"] = "python"
mode: Literal["python", "cpp"] = "python"
prefix: list[str] = dataclasses.field(default_factory=list)
x_grid: Union[str, int] = 1
y_grid: Union[str, int] = 1
z_grid: Union[str, int] = 1
def __post_init__(self) -> None:
assert self.mode in ("python", "cpp", "python_slow")
assert self.mode in ("python", "cpp")
def generate(self, meta: dict[str, int]) -> None:
    """
    Placeholder with no implementation at this level.

    Always raises NotImplementedError.
    NOTE(review): presumably overridden by the concrete grid classes that
    from_meta looks up by name -- confirm against the subclasses.
    """
    raise NotImplementedError
@ -3215,10 +3215,6 @@ class GridExpr:
# negative integer division is floored
if self.mode == "python":
return f"-(({numel}) // -({block}))"
# This is more generic than above, and works in languages where
# positive integer division is floored/truncated
elif self.mode == "python_slow":
return f"(({numel} + {block} - 1) // ({block}))"
# For cpp code gen
return f"(({numel} + ({block} - 1)) / ({block}))"
@ -3227,7 +3223,7 @@ class GridExpr:
items = self._constant_fold(max, seq)
if len(items) <= 1:
return items[0]
if self.mode in ("python", "python_slow"):
if self.mode == "python":
return f"max({', '.join(map(str, items))})"
return functools.reduce(lambda x, y: f"std::max({x}, {y})", items)
@ -3250,7 +3246,7 @@ class GridExpr:
def assign_tmp(self, name: str, expr: Union[str, int]) -> str:
# Grid functions are one per kernel, so name collisions are fine
if self.mode in ("python", "python_slow"):
if self.mode == "python":
return f"{name} = {expr}"
if self.mode == "cpp":
return f"uint32_t {name} = {expr};"
@ -3260,7 +3256,7 @@ class GridExpr:
def from_meta(
inductor_meta: dict[str, Any],
cfg: Union[Config, dict[str, int]],
mode: Literal["python", "cpp", "python_slow"] = "python",
mode: Literal["python", "cpp"] = "python",
) -> GridExpr:
grid_cls = globals()[inductor_meta["grid_type"]]
assert issubclass(grid_cls, GridExpr)