[Inductor] Add CPU_MAX_FIRST_DIMENSION_DECOMPOSITION and CPU_MAX_OTHER_DIMENSION_DECOMPOSITION for decompose_mm_pass (#158183)

Differential Revision: D78209993

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158183
Approved by: https://github.com/houseroad
This commit is contained in:
Huamin Li
2025-07-15 10:07:25 +00:00
committed by PyTorch MergeBot
parent 1b389025ba
commit 7f9fc7e67c

View File

@ -15,11 +15,17 @@ aten = torch.ops.aten
log = logging.getLogger(__name__)
# TODO: need a better strategy for decomposing mm
# The following two constants are for CUDA device only
MIN_FIRST_DIMENSION_DECOMPOSITION = 10240
MAX_OTHER_DIMENSION_DECOMPOSITION = 32
# The following two constants are for CPU device only
CPU_MAX_FIRST_DIMENSION_DECOMPOSITION = 1
CPU_MAX_OTHER_DIMENSION_DECOMPOSITION = 2048
min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION
max_other_dimension_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION
cpu_max_first_dimension_decomposition = CPU_MAX_FIRST_DIMENSION_DECOMPOSITION
cpu_max_other_dimension_decomposition = CPU_MAX_OTHER_DIMENSION_DECOMPOSITION
if "decompose_mm_pass" in config.post_grad_fusion_options:
min_first_dimension_decomposition = config.post_grad_fusion_options[
"decompose_mm_pass"
@ -27,6 +33,16 @@ if "decompose_mm_pass" in config.post_grad_fusion_options:
max_other_dimension_decomposition = config.post_grad_fusion_options[
"decompose_mm_pass"
].get("max_other_dimension_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION)
cpu_max_first_dimension_decomposition = config.post_grad_fusion_options[
"decompose_mm_pass"
].get(
"cpu_max_first_dimension_decomposition", CPU_MAX_FIRST_DIMENSION_DECOMPOSITION
)
cpu_max_other_dimension_decomposition = config.post_grad_fusion_options[
"decompose_mm_pass"
].get(
"cpu_max_other_dimension_decomposition", CPU_MAX_OTHER_DIMENSION_DECOMPOSITION
)
def check_device(a: Tensor, b: Tensor, device="cuda") -> bool:
@ -57,7 +73,10 @@ def should_decompose_bmm(mat1, mat2) -> bool:
return False
return True
elif check_device(mat1, mat2, device="cpu"):
if mat1.shape[0] == 1 and mat2.shape[0] == 1:
if (
mat1.shape[0] <= cpu_max_first_dimension_decomposition
and mat2.shape[0] <= cpu_max_first_dimension_decomposition
):
return True
return False
@ -77,9 +96,15 @@ def should_decompose_mm(mat1, mat2) -> bool:
and statically_known_true(mat2.shape[1] < max_other_dimension_decomposition)
) or (
check_device(mat1, mat2, device="cpu")
and statically_known_true(mat1.shape[0] == 1)
and statically_known_true(mat2.shape[0] <= 128)
and statically_known_true(mat2.shape[1] <= 512)
and statically_known_true(
mat1.shape[0] <= cpu_max_first_dimension_decomposition
)
and statically_known_true(
mat2.shape[0] <= cpu_max_other_dimension_decomposition
)
and statically_known_true(
mat2.shape[1] <= cpu_max_other_dimension_decomposition
)
)