Compare commits

...

3 Commits

Author SHA1 Message Date
2ed602fbde fix 2025-10-30 20:41:43 -07:00
8ad7fb48fd lint 2025-10-30 08:39:49 -07:00
922b23136d enable decompose_mm_pass on xpu 2025-10-30 05:35:37 +00:00
2 changed files with 25 additions and 27 deletions

View File

@ -15,9 +15,8 @@ from torch.testing._internal.common_utils import (
is_navi3_arch,
parametrize,
patch_test_members,
TEST_XPU,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON
from torch.testing._internal.triton_utils import requires_gpu
@ -61,11 +60,6 @@ class TestDecomposeAddMM(torch.nn.Module):
@requires_gpu
@unittest.skipIf(
TEST_XPU,
"Intel GPU has not enabled decompose_mem_bound_mm PASS in "
"torch/_inductor/fx_passes/decompose_mem_bound_mm.py",
)
@torch._inductor.config.patch(
post_grad_fusion_options={
"decompose_mm_pass": {},
@ -144,7 +138,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_bmm"],
expected_val,
@ -155,7 +149,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)
expected_val = 3 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 3 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_bmm"],
expected_val,
@ -204,7 +198,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],
@ -259,7 +253,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],
@ -304,7 +298,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"],
expected_val,
@ -316,7 +310,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
expected_val,
@ -374,7 +368,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"],
expected_val,
@ -386,7 +380,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
expected_val,
@ -410,7 +404,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],
@ -424,7 +418,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_gradients(module, traced)
expected_val = 0
if HAS_CUDA_AND_TRITON:
if HAS_GPU_AND_TRITON:
expected_val = 1 if has_bias else 2
self.assertEqual(
@ -447,12 +441,8 @@ class TestDecomposeMemMM(TestCase):
_, code = run_and_get_code(foo, input1, input2)
if GPU_TYPE == "xpu":
# only 1 kernel generated on the XPU stack
FileCheck().check_count(".run(", 1, exactly=True).run(code[0])
else:
# two kernels generated
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
# two kernels generated
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
def test_check_device(self):
m = 5
@ -462,7 +452,7 @@ class TestDecomposeMemMM(TestCase):
input1 = torch.randn(m, k, device=GPU_TYPE)
input2 = torch.randn(k, n, device=GPU_TYPE)
self.assertTrue(check_device(input1, input2))
self.assertTrue(check_device(input1, input2, device=GPU_TYPE))
self.assertFalse(check_device(input1, input2, device="cpu"))
input1 = torch.randn(m, k)

View File

@ -66,7 +66,9 @@ def should_decompose_bmm(mat1, mat2) -> bool:
return False
if len(mat1.shape) != 3 or len(mat2.shape) != 3:
return False
if check_device(mat1, mat2, device="cuda"):
if check_device(mat1, mat2, device="cuda") or check_device(
mat1, mat2, device="xpu"
):
if mat1.shape[0] < min_first_dimension_decomposition:
return False
# 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION
@ -130,7 +132,10 @@ def should_decompose_mm(mat1, mat2) -> bool:
"skip_dynamic_shape_dim_check", False
):
return (
check_device(mat1, mat2, device="cuda")
(
check_device(mat1, mat2, device="cuda")
or check_device(mat1, mat2, device="xpu")
)
and statically_known_true(
mat1.shape[0] >= min_first_dimension_decomposition
)
@ -151,7 +156,10 @@ def should_decompose_mm(mat1, mat2) -> bool:
# case 2: we decompose mm if the input is dynamic shape
else:
return (
check_device(mat1, mat2, device="cuda")
(
check_device(mat1, mat2, device="cuda")
or check_device(mat1, mat2, device="xpu")
)
and (
statically_known_true(
mat1.shape[0] >= min_first_dimension_decomposition