Fixes #154111

Resolves an issue during compilation with dynamic shapes where `torch._inductor.decomposition.mm` evaluates the SymInt expression for the input tensor because of a for loop, so the output tensor is not dynamically shaped. The issue is limited to small (Mx1)x(1xN) matrix multiplications, and it surfaces as an explicit error with tensor subclasses such as DTensor. The fix replaces the loop with a simple product.

Benchmark currently running: https://hud.pytorch.org/benchmark/compilers

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158998
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
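For context, here is a minimal, illustrative sketch of the effect the test file (which follows after this sketch) checks for; it is not the decomposition source, and the tensor shape is arbitrary. With a fake tensor whose sizes are backed by a ShapeEnv, looping over a symbolic size forces the SymInt to be specialized to its hint, which shows up as its `node.expr` turning from a symbolic expression into a constant, whereas combining the sizes symbolically (for example as a product) leaves them unevaluated. `test_dynamic_shape_mm` below applies exactly this before/after check on `node.expr` types to the real `mm` decomposition.

# Illustrative sketch only; this is not the decomposition code.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
with fake_mode:
    t = fake_mode.from_tensor(
        torch.randn(8, 3),  # arbitrary shape; both dims marked dynamic
        symbolic_context=StatelessSymbolicContext(
            dynamic_sizes=[DimDynamic.DYNAMIC, DimDynamic.DYNAMIC],
            dynamic_strides=[DimDynamic.INFER_STRIDE, DimDynamic.INFER_STRIDE],
        ),
    )

# Combining sizes symbolically does not evaluate them: the result is still a SymInt.
print(t.size(0) * t.size(1))  # e.g. s0*s1, still symbolic
print(type(t.size(0).node.expr))  # expected: a sympy Symbol

# Looping over a symbolic size calls __index__ on the SymInt, which specializes
# it to its hint and adds a guard; its expr then becomes a constant.
for _ in range(t.size(0)):
    pass
print(type(t.size(0).node.expr))  # expected: a sympy integer constant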
# Owner(s): ["module: nn"]

import math
import unittest
from typing import Union

import torch
from torch._inductor import config
from torch._inductor.decomposition import mm
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, run_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


default_atol = {
    torch.float16: 1e-3,
    torch.bfloat16: float("infinity"),
    torch.float32: 1e-5,
}
default_rtol = {
    torch.float16: 1e-3,
    torch.bfloat16: float("infinity"),
    torch.float32: 1.3e-6,
}


def rand_math_tensor(
    shape: tuple[Union[int, list[int]]],
    device: str,
    dtype: torch.dtype,
    requires_grad: bool = False,
    packed: bool = False,
) -> torch.Tensor:
    """Creates a random dense tensor with the given shape and dtype.

    Args:
        shape (Tuple[int]): Shape of the Tensor to construct
        device (str): which device to create the tensor on
        dtype (torch.dtype): the Tensor's dtype
        requires_grad (bool, optional): the Tensor's grad status. Defaults to False.
        packed (bool, optional): unused by this helper. Defaults to False.

    Returns:
        torch.Tensor: A new tensor
    """
    return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)


def init_tensor(tensor_list, **kwargs) -> torch.Tensor:
    return torch.Tensor(tensor_list).to(**kwargs)


def run_comp_nocomp(function, *inputs, **kwargs):
    c_function = torch.compile(function)

    f_res = function(*inputs)
    cf_res = c_function(*inputs)

    if not (math.isinf(kwargs.get("atol", 0.0)) or math.isinf(kwargs.get("rtol", 0.0))):
        torch.testing.assert_close(f_res, cf_res, **kwargs)


# The test functions are used by several tests
def torch_mm(a, b):
    return torch.mm(a, b)


def torch_addmm(add, b, c):
    return torch.addmm(add, b, c)


def torch_bmm(a, b):
    return torch.bmm(a, b)


def torch_baddbmm(add, b, c, alpha, beta):
    return torch.baddbmm(add, b, c, alpha=alpha, beta=beta)


def create_fake_tensor_with_dynamic_size(x, fake_mode):
    with fake_mode:
        dynamic_sizes = [DimDynamic.DYNAMIC for _ in range(x.dim())]
        dynamic_strides = [DimDynamic.INFER_STRIDE for _ in range(x.dim())]
        return fake_mode.from_tensor(
            x,
            symbolic_context=StatelessSymbolicContext(
                dynamic_sizes=dynamic_sizes,
                dynamic_strides=dynamic_strides,
            ),
        )


# The shapes we test on: each entry is (a1_0, a1_1, a2_0, a2_1), i.e. an
# (a1_0 x a1_1) @ (a2_0 x a2_1) matmul
ts_list = [
    (1, 32, 32, 1),
    (1, 10, 10, 1),
    (1, 3, 3, 1),
    (32, 1, 1, 32),
    (3, 1, 1, 3),
    (4, 1, 1, 9),
    (9, 1, 1, 4),
]


class TestDecomp(NNTestCase):
    _do_cuda_memory_leak_check = GPU_TYPE == "cuda"
    _do_cuda_non_default_stream = GPU_TYPE == "cuda"

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16])
    def test_simple_mm(self, device, dtype):
        fudge = 10
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

        for t_size in ts_list:
            a1_0, a1_1, a2_0, a2_1 = t_size

            t1 = rand_math_tensor((a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((a2_0, a2_1), dtype=dtype, device=device)
            tadd = rand_math_tensor((a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_mm, t1, t2, rtol=rtol, atol=atol)
            run_comp_nocomp(torch_addmm, tadd, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize(
        "dtype", [torch.float, torch.bfloat16] if SM80OrLater else [torch.float]
    )
    @parametrize("bs", [1, 2, 4, 10])
    def test_batched_mm(self, device, dtype, bs):
        fudge = 3
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

        for t_size in ts_list:
            a1_0, a1_1, a2_0, a2_1 = t_size

            t1 = rand_math_tensor((bs, a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((bs, a2_0, a2_1), dtype=dtype, device=device)
            tadd = rand_math_tensor((bs, a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)

            for alpha in (0, 1, -1, 0.5, -0.5):
                for beta in (0, 1, -1, 0.5, -0.5):
                    run_comp_nocomp(
                        torch_baddbmm, tadd, t1, t2, alpha, beta, rtol=rtol, atol=atol
                    )

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @config.patch(coordinate_descent_tuning=True)
    def test_bmm_batch2_last_dim_size_is_one(self, device):
        fudge = 3
        rtol = default_rtol[torch.float32] * fudge
        atol = default_atol[torch.float32] * fudge

        t1 = torch.randn(1, 32, 2, device=device)
        t2 = torch.randn(1, 2, 1, device=device)

        run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    def test_some(self, device, dtype):
        # This PyTorch data type is not fully supported on CUDA today
        # - unfortunately we can't use skipIf because it can't see the actual params
        if device.startswith(GPU_TYPE) and dtype == torch.int:
            return

        run_comp_nocomp(
            torch_mm,
            init_tensor([[1], [2], [3], [4]], dtype=dtype, device=device),
            init_tensor([[1, 2, 3, 4]], dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_mm,
            init_tensor([[1, 2, 3, 4]], dtype=dtype, device=device),
            init_tensor([[1], [2], [3], [4]], dtype=dtype, device=device),
        )

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    @parametrize("bs", [1, 2, 4, 10])
    def test_some_batched(self, device, dtype, bs):
        # This PyTorch data type is not fully supported on CUDA today
        # - unfortunately we can't use skipIf because it can't see the actual params
        if device.startswith(GPU_TYPE) and dtype == torch.int:
            return

        run_comp_nocomp(
            torch_bmm,
            init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
            init_tensor([[[1, 2, 3, 4]]] * bs, dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_bmm,
            init_tensor([[[1, 2, 3, 4]]] * bs, dtype=dtype, device=device),
            init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
        )

    @parametrize("dtype", [torch.float, torch.bfloat16])
    def test_dynamic_shape_mm(self, device, dtype):
        # Test that the mm decomp does not evaluate expressions for dynamic shapes

        shape_env = ShapeEnv()
        fake_mode = FakeTensorMode(shape_env=shape_env)

        # Only test the decomp for cpu to match fake tensors from dynamo
        if device != "cpu":
            return

        for t_size in ts_list:
            a1_0, a1_1, a2_0, a2_1 = t_size

            # Create the fake tensors
            t1 = create_fake_tensor_with_dynamic_size(
                rand_math_tensor((a1_0, a1_1), dtype=dtype, device=device),
                fake_mode,
            )
            t2 = create_fake_tensor_with_dynamic_size(
                rand_math_tensor((a2_0, a2_1), dtype=dtype, device=device),
                fake_mode,
            )

            # Save the expression types to check if any SymInts get evaluated
            # (evaluating a SymInt replaces its symbolic expr with a constant)
            og_t1_expr_types = [
                type(d.node.expr) if type(d) is torch.SymInt else int for d in t1.size()
            ]
            og_t2_expr_types = [
                type(d.node.expr) if type(d) is torch.SymInt else int for d in t2.size()
            ]

            r = mm(t1, t2)

            # Make sure none of the SymInts were evaluated
            new_t1_expr_types = [
                type(d.node.expr) if type(d) is torch.SymInt else int for d in t1.size()
            ]
            new_t2_expr_types = [
                type(d.node.expr) if type(d) is torch.SymInt else int for d in t2.size()
            ]
            self.assertTrue(
                all(
                    og_t1_expr_types[i] == new_t1_expr_types[i]
                    for i in range(len(og_t1_expr_types))
                )
            )
            self.assertTrue(
                all(
                    og_t2_expr_types[i] == new_t2_expr_types[i]
                    for i in range(len(og_t2_expr_types))
                )
            )

            if r is not NotImplemented:
                # Check that the output is well formed
                self.assertEqual(t1.size(0), r.size(0))
                self.assertEqual(t2.size(1), r.size(1))
                r_expr_types = [
                    type(d.node.expr) if type(d) is torch.SymInt else int
                    for d in r.size()
                ]
                self.assertTrue(r_expr_types[0] == og_t1_expr_types[0])
                self.assertTrue(r_expr_types[1] == og_t2_expr_types[1])


device_types = ("cpu", GPU_TYPE)
instantiate_device_type_tests(TestDecomp, globals(), only_for=device_types)

if __name__ == "__main__":
    # We don't support torch.compile() on Windows
    if not IS_WINDOWS:
        run_tests()