Add contiguous subgraph transformation threshold (#162192)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162192
Approved by: https://github.com/coconutruben
This commit is contained in:
Gabriel Ferns
2025-09-06 02:48:00 +00:00
committed by PyTorch MergeBot
parent c3ceca2995
commit 20629b1619
2 changed files with 6 additions and 3 deletions

View File

@ -1786,6 +1786,9 @@ class rocm:
# The threshold at which we trigger a splitK config - K // max(M,N) has to be greater than this # The threshold at which we trigger a splitK config - K // max(M,N) has to be greater than this
split_k_threshold: int = 16 split_k_threshold: int = 16
# The threshold at which we trigger a contiguous subgraph transformation - K has to be at least this many times both M and N
contiguous_threshold: int = 16
# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) # Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental)
cpu_backend: Literal["cpp", "triton", "halide"] = "cpp" cpu_backend: Literal["cpp", "triton", "halide"] = "cpp"

View File

@ -1842,7 +1842,7 @@ def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
Check if we should use the contiguous subgraph transform. Check if we should use the contiguous subgraph transform.
This transform makes the second matrix contiguous before the matmul. This transform makes the second matrix contiguous before the matmul.
""" """
decompose_k_threshold = config.triton.decompose_k_threshold contiguous_threshold = config.rocm.contiguous_threshold
# Similar conditions to decompose_k but for contiguous transform # Similar conditions to decompose_k but for contiguous transform
from torch._inductor.virtualized import V from torch._inductor.virtualized import V
@ -1851,8 +1851,8 @@ def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
bool(torch.version.hip) # Only relevant on AMD bool(torch.version.hip) # Only relevant on AMD
and V.graph.sizevars.statically_known_true( and V.graph.sizevars.statically_known_true(
sympy.And( sympy.And(
sympy.Ge(k, decompose_k_threshold * m), sympy.Ge(k, contiguous_threshold * m),
sympy.Ge(k, decompose_k_threshold * n), sympy.Ge(k, contiguous_threshold * n),
) )
) )
and not V.graph.aot_mode and not V.graph.aot_mode