pytorch/test/inductor/test_mmdecomp.py
Arsh Zahed 24d07b3a67 [inductor] Fix mm decomposition evaluating symints (#158998)
Fixes #154111

Resolves an issue during compilation with dynamic shapes where `torch._inductor.decomposition.mm` evaluates the SymInt expression for the input tensor's size inside a for loop, so the output tensor is not dynamically shaped. The issue is limited to small (Mx1)x(1xN) matrix multiplications and surfaces as an explicit error with tensor subclasses such as DTensor.

The proposed fix replaces the loop with a simple product. A benchmark run is currently in progress: https://hud.pytorch.org/benchmark/compilers
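For illustration, here is a minimal, self-contained sketch of the two forms for an (Mx1)x(1xN) product. The function names and the exact loop shape are illustrative, not the verbatim decomposition code: iterating over `range(a.size(0))` forces a symbolic M to be evaluated, while a broadcasted product yields the same (M, N) result without reading the size on the Python side.

```python
import torch


def mm_via_loop(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Loop form: under dynamic shapes, range(a.size(0)) needs a concrete
    # integer, so a SymInt first dimension gets specialized.
    return torch.cat([a[i, :].unsqueeze(0) * b for i in range(a.size(0))])


def mm_via_product(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Product form: (M, 1) * (1, N) broadcasts to (M, N) without touching
    # the sizes in Python.
    return a * b


a = torch.randn(4, 1)
b = torch.randn(1, 3)
torch.testing.assert_close(torch.mm(a, b), mm_via_loop(a, b))
torch.testing.assert_close(torch.mm(a, b), mm_via_product(a, b))
```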

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158998
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
2025-07-30 16:34:15 +00:00


# Owner(s): ["module: nn"]

import math
import unittest
from typing import Union

import torch
from torch._inductor import config
from torch._inductor.decomposition import mm
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, run_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
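
# Per-dtype tolerances for comparing eager vs. compiled results; bfloat16 maps
# to inf, which makes run_comp_nocomp skip the numeric comparison entirely.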
default_atol = {
    torch.float16: 1e-3,
    torch.bfloat16: float("infinity"),
    torch.float32: 1e-5,
}
default_rtol = {
    torch.float16: 1e-3,
    torch.bfloat16: float("infinity"),
    torch.float32: 1.3e-6,
}


def rand_math_tensor(
    shape: tuple[Union[int, list[int]]],
    device: str,
    dtype: torch.dtype,
    requires_grad: bool = False,
    packed: bool = False,
) -> torch.Tensor:
    """Creates a random dense tensor with the given shape and dtype.

    Args:
        shape (tuple[int]): Shape of the tensor to construct
        device (str): Device to create the tensor on
        dtype (torch.dtype): The tensor's dtype
        requires_grad (bool, optional): The tensor's grad status. Defaults to False.
        packed (bool, optional): Currently unused by this helper. Defaults to False.

    Returns:
        torch.Tensor: A new tensor
    """
    return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)


def init_tensor(tensor_list, **kwargs) -> torch.Tensor:
    return torch.Tensor(tensor_list).to(**kwargs)
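

# Runs `function` both eagerly and under torch.compile and asserts that the two
# results match within the given tolerances (skipped when a tolerance is inf).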
def run_comp_nocomp(function, *inputs, **kwargs):
    c_function = torch.compile(function)
    f_res = function(*inputs)
    cf_res = c_function(*inputs)
    if not (math.isinf(kwargs.get("atol", 0.0)) or math.isinf(kwargs.get("rtol", 0.0))):
        torch.testing.assert_close(f_res, cf_res, **kwargs)


# The test functions are used by several tests
def torch_mm(a, b):
    return torch.mm(a, b)


def torch_addmm(add, b, c):
    return torch.addmm(add, b, c)


def torch_bmm(a, b):
    return torch.bmm(a, b)


def torch_baddbmm(add, b, c, alpha, beta):
    return torch.baddbmm(add, b, c, alpha=alpha, beta=beta)
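

# Converts a real tensor into a FakeTensor whose sizes are all dynamic SymInts
# (strides inferred), so the mm decomposition sees symbolic shapes.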
def create_fake_tensor_with_dynamic_size(x, fake_mode):
    with fake_mode:
        dynamic_sizes = [DimDynamic.DYNAMIC for _ in range(x.dim())]
        dynamic_strides = [DimDynamic.INFER_STRIDE for _ in range(x.dim())]
        return fake_mode.from_tensor(
            x,
            symbolic_context=StatelessSymbolicContext(
                dynamic_sizes=dynamic_sizes,
                dynamic_strides=dynamic_strides,
            ),
        )


# The shapes we test on, given as (a1_0, a1_1, a2_0, a2_1) for an
# (a1_0 x a1_1) @ (a2_0 x a2_1) product; several are small (Mx1)x(1xN) cases.
ts_list = [
    (1, 32, 32, 1),
    (1, 10, 10, 1),
    (1, 3, 3, 1),
    (32, 1, 1, 32),
    (3, 1, 1, 3),
    (4, 1, 1, 9),
    (9, 1, 1, 4),
]


class TestDecomp(NNTestCase):
    _do_cuda_memory_leak_check = GPU_TYPE == "cuda"
    _do_cuda_non_default_stream = GPU_TYPE == "cuda"

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16])
    def test_simple_mm(self, device, dtype):
        fudge = 10
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

        for t_size in ts_list:
            (a1_0, a1_1, a2_0, a2_1) = t_size

            t1 = rand_math_tensor((a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((a2_0, a2_1), dtype=dtype, device=device)
            tadd = rand_math_tensor((a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_mm, t1, t2, rtol=rtol, atol=atol)
            run_comp_nocomp(torch_addmm, tadd, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize(
        "dtype", [torch.float, torch.bfloat16] if SM80OrLater else [torch.float]
    )
    @parametrize("bs", [1, 2, 4, 10])
    def test_batched_mm(self, device, dtype, bs):
        fudge = 3
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

        for t_size in ts_list:
            (a1_0, a1_1, a2_0, a2_1) = t_size

            t1 = rand_math_tensor((bs, a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((bs, a2_0, a2_1), dtype=dtype, device=device)
            tadd = rand_math_tensor((bs, a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)

            for alpha in (0, 1, -1, 0.5, -0.5):
                for beta in (0, 1, -1, 0.5, -0.5):
                    run_comp_nocomp(
                        torch_baddbmm, tadd, t1, t2, alpha, beta, rtol=rtol, atol=atol
                    )

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @config.patch(coordinate_descent_tuning=True)
    def test_bmm_batch2_last_dim_size_is_one(self, device):
        fudge = 3
        rtol = default_rtol[torch.float32] * fudge
        atol = default_atol[torch.float32] * fudge

        t1 = torch.randn(1, 32, 2, device=device)
        t2 = torch.randn(1, 2, 1, device=device)
        run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    def test_some(self, device, dtype):
        # This dtype is not fully supported on CUDA today; unfortunately we can't
        # use skipIf because the parametrized values are not visible inside skipIf.
        if device.startswith(GPU_TYPE) and dtype == torch.int:
            return

        run_comp_nocomp(
            torch_mm,
            init_tensor([[1], [2], [3], [4]], dtype=dtype, device=device),
            init_tensor([[1, 2, 3, 4]], dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_mm,
            init_tensor([[1, 2, 3, 4]], dtype=dtype, device=device),
            init_tensor([[1], [2], [3], [4]], dtype=dtype, device=device),
        )

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    @parametrize("bs", [1, 2, 4, 10])
    def test_some_batched(self, device, dtype, bs):
        # This dtype is not fully supported on CUDA today; unfortunately we can't
        # use skipIf because the parametrized values are not visible inside skipIf.
        if device.startswith(GPU_TYPE) and dtype == torch.int:
            return

        run_comp_nocomp(
            torch_bmm,
            init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
            init_tensor([[[1, 2, 3, 4]]] * bs, dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_bmm,
            init_tensor([[[1, 2, 3, 4]]] * bs, dtype=dtype, device=device),
            init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
        )

    @parametrize("dtype", [torch.float, torch.bfloat16])
    def test_dynamic_shape_mm(self, device, dtype):
        # Test that the mm decomp does not evaluate expressions for dynamic shapes
        shape_env = ShapeEnv()
        fake_mode = FakeTensorMode(shape_env=shape_env)

        # Only test the decomp on cpu to match the fake tensors produced by dynamo
        if device != "cpu":
            return

        for t_size in ts_list:
            (a1_0, a1_1, a2_0, a2_1) = t_size

            # Create the fake tensors
            t1 = create_fake_tensor_with_dynamic_size(
                rand_math_tensor((a1_0, a1_1), dtype=dtype, device=device),
                fake_mode,
            )
            t2 = create_fake_tensor_with_dynamic_size(
                rand_math_tensor((a2_0, a2_1), dtype=dtype, device=device),
                fake_mode,
            )

            # Save the expression types to check whether any SymInts get evaluated
            og_t1_expr_types = [
                type(d.node.expr) if type(d) is torch.SymInt else int for d in t1.size()
            ]
            og_t2_expr_types = [
                type(d.node.expr) if type(d) is torch.SymInt else int for d in t2.size()
            ]

            r = mm(t1, t2)

            # Make sure none of the SymInts were evaluated (specialized to ints)
            new_t1_expr_types = [
                type(d.node.expr) if type(d) is torch.SymInt else int for d in t1.size()
            ]
            new_t2_expr_types = [
                type(d.node.expr) if type(d) is torch.SymInt else int for d in t2.size()
            ]
            self.assertTrue(
                all(
                    og_t1_expr_types[i] == new_t1_expr_types[i]
                    for i in range(len(og_t1_expr_types))
                )
            )
            self.assertTrue(
                all(
                    og_t2_expr_types[i] == new_t2_expr_types[i]
                    for i in range(len(og_t2_expr_types))
                )
            )

            if r is not NotImplemented:
                # Check that the output is well formed
                self.assertEqual(t1.size(0), r.size(0))
                self.assertEqual(t2.size(1), r.size(1))
                r_expr_types = [
                    type(d.node.expr) if type(d) is torch.SymInt else int
                    for d in r.size()
                ]
                self.assertTrue(r_expr_types[0] == og_t1_expr_types[0])
                self.assertTrue(r_expr_types[1] == og_t2_expr_types[1])


device_types = ("cpu", GPU_TYPE)
instantiate_device_type_tests(TestDecomp, globals(), only_for=device_types)

if __name__ == "__main__":
    # We don't support torch.compile() on Windows
    if not IS_WINDOWS:
        run_tests()