mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] Add CPU_MAX_FIRST_DIMENSION_DECOMPOSITION and CPU_MAX_OTHER_DIMENSION_DECOMPOSITION for decompose_mm_pass (#158183)
Differential Revision: D78209993 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158183 Approved by: https://github.com/houseroad
This commit is contained in:
committed by
PyTorch MergeBot
parent
1b389025ba
commit
7f9fc7e67c
@ -15,11 +15,17 @@ aten = torch.ops.aten
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# TODO: need a better strategy for decomposing mm
|
||||
# The following two constants are for CUDA device only
|
||||
MIN_FIRST_DIMENSION_DECOMPOSITION = 10240
|
||||
MAX_OTHER_DIMENSION_DECOMPOSITION = 32
|
||||
# The following two constants are for CPU device only
|
||||
CPU_MAX_FIRST_DIMENSION_DECOMPOSITION = 1
|
||||
CPU_MAX_OTHER_DIMENSION_DECOMPOSITION = 2048
|
||||
|
||||
min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION
|
||||
max_other_dimension_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION
|
||||
cpu_max_first_dimension_decomposition = CPU_MAX_FIRST_DIMENSION_DECOMPOSITION
|
||||
cpu_max_other_dimension_decomposition = CPU_MAX_OTHER_DIMENSION_DECOMPOSITION
|
||||
if "decompose_mm_pass" in config.post_grad_fusion_options:
|
||||
min_first_dimension_decomposition = config.post_grad_fusion_options[
|
||||
"decompose_mm_pass"
|
||||
@ -27,6 +33,16 @@ if "decompose_mm_pass" in config.post_grad_fusion_options:
|
||||
max_other_dimension_decomposition = config.post_grad_fusion_options[
|
||||
"decompose_mm_pass"
|
||||
].get("max_other_dimension_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION)
|
||||
cpu_max_first_dimension_decomposition = config.post_grad_fusion_options[
|
||||
"decompose_mm_pass"
|
||||
].get(
|
||||
"cpu_max_first_dimension_decomposition", CPU_MAX_FIRST_DIMENSION_DECOMPOSITION
|
||||
)
|
||||
cpu_max_other_dimension_decomposition = config.post_grad_fusion_options[
|
||||
"decompose_mm_pass"
|
||||
].get(
|
||||
"cpu_max_other_dimension_decomposition", CPU_MAX_OTHER_DIMENSION_DECOMPOSITION
|
||||
)
|
||||
|
||||
|
||||
def check_device(a: Tensor, b: Tensor, device="cuda") -> bool:
|
||||
@ -57,7 +73,10 @@ def should_decompose_bmm(mat1, mat2) -> bool:
|
||||
return False
|
||||
return True
|
||||
elif check_device(mat1, mat2, device="cpu"):
|
||||
if mat1.shape[0] == 1 and mat2.shape[0] == 1:
|
||||
if (
|
||||
mat1.shape[0] <= cpu_max_first_dimension_decomposition
|
||||
and mat2.shape[0] <= cpu_max_first_dimension_decomposition
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -77,9 +96,15 @@ def should_decompose_mm(mat1, mat2) -> bool:
|
||||
and statically_known_true(mat2.shape[1] < max_other_dimension_decomposition)
|
||||
) or (
|
||||
check_device(mat1, mat2, device="cpu")
|
||||
and statically_known_true(mat1.shape[0] == 1)
|
||||
and statically_known_true(mat2.shape[0] <= 128)
|
||||
and statically_known_true(mat2.shape[1] <= 512)
|
||||
and statically_known_true(
|
||||
mat1.shape[0] <= cpu_max_first_dimension_decomposition
|
||||
)
|
||||
and statically_known_true(
|
||||
mat2.shape[0] <= cpu_max_other_dimension_decomposition
|
||||
)
|
||||
and statically_known_true(
|
||||
mat2.shape[1] <= cpu_max_other_dimension_decomposition
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user