Add contiguous subgraph transformation threshold (#162192)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162192
Approved by: https://github.com/coconutruben
Committed by: PyTorch MergeBot
Parent: c3ceca2995
Commit: 20629b1619
@@ -1786,6 +1786,9 @@ class rocm:
     # The threshold at which we trigger a splitK config - K // max(M,N) has to be greater than this
     split_k_threshold: int = 16
 
+    # The threshold at which we trigger a contiguous subgraph transformation
+    contiguous_threshold: int = 16
+
 
 # Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental)
 cpu_backend: Literal["cpp", "triton", "halide"] = "cpp"
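The new knob sits next to the existing ROCm-specific tunables, so it can be overridden like any other Inductor config field. A minimal sketch, assuming the field is exposed as config.rocm.contiguous_threshold exactly as shown in the hunk above; the compiled matmul is purely illustrative:

import torch
import torch._inductor.config as inductor_config

# Require K to be at least 32x larger than both M and N before the
# contiguous subgraph transform is considered (the default in the diff is 16).
inductor_config.rocm.contiguous_threshold = 32

# Illustrative use: compile a matmul so Inductor's mm lowering sees the knob.
@torch.compile
def mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b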
@@ -1842,7 +1842,7 @@ def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
     Check if we should use the contiguous subgraph transform.
     This transform makes the second matrix contiguous before the matmul.
     """
-    decompose_k_threshold = config.triton.decompose_k_threshold
+    contiguous_threshold = config.rocm.contiguous_threshold
 
     # Similar conditions to decompose_k but for contiguous transform
     from torch._inductor.virtualized import V
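For context, "makes the second matrix contiguous" in the docstring amounts to materializing a contiguous copy of the B operand before the GEMM, which matters when B is a strided view such as a transpose. A minimal eager-mode sketch of that idea; this is not the Inductor subgraph itself, and the function name is made up for illustration:

import torch

def matmul_with_contiguous_b(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # If b is a non-contiguous view (e.g. a transpose), copy it into
    # contiguous memory first so the GEMM reads it with unit stride.
    if not b.is_contiguous():
        b = b.contiguous()
    return a @ b

a = torch.randn(64, 4096)
b = torch.randn(64, 4096).t()  # transposed view -> non-contiguous
out = matmul_with_contiguous_b(a, b)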
@@ -1851,8 +1851,8 @@ def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
         bool(torch.version.hip)  # Only relevant on AMD
         and V.graph.sizevars.statically_known_true(
             sympy.And(
-                sympy.Ge(k, decompose_k_threshold * m),
-                sympy.Ge(k, decompose_k_threshold * n),
+                sympy.Ge(k, contiguous_threshold * m),
+                sympy.Ge(k, contiguous_threshold * n),
             )
         )
         and not V.graph.aot_mode
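To make the gating concrete: with the default threshold of 16, the transform is only considered when K is at least 16x both M and N (and only on ROCm, with statically known shapes, outside AOT mode). A small standalone check that mirrors the sympy condition above, with made-up integer shapes for illustration:

import sympy

m, n, k = 64, 32, 2048       # example GEMM shape, illustrative only
contiguous_threshold = 16    # default value added in the config hunk above

qualifies = sympy.And(
    sympy.Ge(k, contiguous_threshold * m),  # 2048 >= 16 * 64 = 1024 -> True
    sympy.Ge(k, contiguous_threshold * n),  # 2048 >= 16 * 32 = 512  -> True
)
print(bool(qualifies))  # True: this shape would pass the threshold check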