mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Add configuration control for CUTLASS operation selection. (#155770)
Added a new configuration option `cutlass_enabled_ops` that allows users to control which operations use CUTLASS lowerings. By default, CUTLASS is enabled for all operations (maintaining backward compatibility), but users can now selectively enable it only for specific operations to optimize compilation time. **Fixes #155718** ## Usage Examples ```bash # Enable CUTLASS for all operations (default behavior) export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="ALL" # Enable CUTLASS only for matrix multiplication operations export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="mm,addmm" # Enable CUTLASS only for batch operations export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="bmm,baddbmm" # Disable CUTLASS for all operations export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="" ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/155770 Approved by: https://github.com/henrylhtsang
This commit is contained in:
committed by
PyTorch MergeBot
parent
1982ec2d22
commit
3e38feb05f
@ -434,6 +434,7 @@ max_autotune_gemm_backends = os.environ.get(
|
||||
"TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
|
||||
).upper()
|
||||
|
||||
|
||||
# As above, specify candidate backends for conv autotune.
|
||||
# NB: in some cases for 1x1 convs we emit as matmul,
|
||||
# which will use the backends of `max_autotune_gemm_backends`
|
||||
@ -1487,6 +1488,13 @@ class cuda:
|
||||
os.environ.get("TORCHINDUCTOR_CUTLASS_PRESCREENING", "1") == "1"
|
||||
)
|
||||
|
||||
# Specify which operations should use CUTLASS backend
|
||||
# Comma-separated list like "mm,addmm,bmm", "all" for all operations, and "" for none.
|
||||
# Acceptable operations: mm, int_mm, addmm, sparse_semi_structured_mm, bmm, scaled_mm
|
||||
cutlass_enabled_ops: str = os.environ.get(
|
||||
"TORCHINDUCTOR_CUTLASS_ENABLED_OPS", "all"
|
||||
)
|
||||
|
||||
|
||||
class rocm:
|
||||
# Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"].
|
||||
|
@ -13,6 +13,7 @@ from ..select_algorithm import (
|
||||
TritonTemplate,
|
||||
)
|
||||
from ..utils import (
|
||||
_use_cutlass_for_op,
|
||||
use_aten_gemm_kernels,
|
||||
use_ck_gemm_template,
|
||||
use_cpp_bmm_template,
|
||||
@ -218,7 +219,12 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
|
||||
)
|
||||
_, is_nonzero = _is_static_problem(layout)
|
||||
batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout)
|
||||
if batch_stride_largest and is_nonzero and use_cutlass_template(layout, m, n, k):
|
||||
if (
|
||||
batch_stride_largest
|
||||
and is_nonzero
|
||||
and use_cutlass_template(layout, m, n, k)
|
||||
and _use_cutlass_for_op("bmm")
|
||||
):
|
||||
from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
|
||||
|
||||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) # type: ignore[arg-type]
|
||||
|
@ -38,6 +38,7 @@ from ..select_algorithm import (
|
||||
TritonTemplate,
|
||||
)
|
||||
from ..utils import (
|
||||
_use_cutlass_for_op,
|
||||
get_k_splits,
|
||||
get_tma_workspace_arg,
|
||||
use_aten_gemm_kernels,
|
||||
@ -772,7 +773,11 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
layout=layout,
|
||||
)
|
||||
|
||||
if is_nonzero and use_cutlass_template(layout, m, n, k):
|
||||
if (
|
||||
is_nonzero
|
||||
and use_cutlass_template(layout, m, n, k)
|
||||
and _use_cutlass_for_op("mm")
|
||||
):
|
||||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
|
||||
|
||||
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
|
||||
@ -867,7 +872,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
|
||||
[aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
|
||||
)
|
||||
|
||||
if use_cutlass:
|
||||
if use_cutlass and _use_cutlass_for_op("int_mm"):
|
||||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
|
||||
choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
|
||||
)
|
||||
@ -991,7 +996,11 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
|
||||
)
|
||||
|
||||
if is_nonzero and use_cutlass_template(layout, m, n, k):
|
||||
if (
|
||||
is_nonzero
|
||||
and use_cutlass_template(layout, m, n, k)
|
||||
and _use_cutlass_for_op("addmm")
|
||||
):
|
||||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
|
||||
choices,
|
||||
layout,
|
||||
@ -1061,7 +1070,11 @@ def tuned_sparse_semi_structured_mm(
|
||||
else []
|
||||
)
|
||||
|
||||
if m * n != 0 and use_cutlass_template(layout, m, n, k):
|
||||
if (
|
||||
m * n != 0
|
||||
and use_cutlass_template(layout, m, n, k)
|
||||
and _use_cutlass_for_op("sparse_semi_structured_mm")
|
||||
):
|
||||
CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
|
||||
choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True
|
||||
)
|
||||
@ -1086,6 +1099,21 @@ def tuned_scaled_mm(
|
||||
use_fast_accum=False,
|
||||
layout=None,
|
||||
):
|
||||
"""
|
||||
Performs an optimized matrix multiplication where scaling factors are applied
|
||||
to the inputs and/or output.
|
||||
|
||||
Args:
|
||||
mat1 (Tensor): First input matrix
|
||||
mat2 (Tensor): Second input matrix
|
||||
scale1 (Tensor): Scale factor applied to mat1 (supports broadcasting)
|
||||
scale2 (Tensor): Scale factor applied to mat2 (supports broadcasting)
|
||||
bias (Tensor, optional): Optional bias tensor to add to the result
|
||||
layout: Layout hint for optimization
|
||||
|
||||
Returns:
|
||||
Tensor: The result of the scaled matrix multiplication
|
||||
"""
|
||||
m, n, k, layout, mat_a, mat_b = mm_args(
|
||||
mat_a, mat_b, layout=layout, out_dtype=out_dtype
|
||||
)
|
||||
@ -1213,7 +1241,11 @@ def tuned_scaled_mm(
|
||||
epilogue_fn_hash="scale_mm_epilogue",
|
||||
)
|
||||
|
||||
if is_nonzero and use_cutlass_template(layout, m, n, k):
|
||||
if (
|
||||
is_nonzero
|
||||
and use_cutlass_template(layout, m, n, k)
|
||||
and _use_cutlass_for_op("scaled_mm")
|
||||
):
|
||||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
|
||||
choices,
|
||||
layout,
|
||||
|
@ -1578,6 +1578,14 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
|
||||
return res
|
||||
|
||||
|
||||
def _use_cutlass_for_op(op_name: str) -> bool:
|
||||
"""Check if CUTLASS should be used for the given operation."""
|
||||
enabled_ops = config.cuda.cutlass_enabled_ops.upper()
|
||||
if enabled_ops == "ALL":
|
||||
return True
|
||||
return op_name.upper() in [x.strip() for x in enabled_ops.split(",")]
|
||||
|
||||
|
||||
decompose_k_threshold = 32
|
||||
|
||||
# To limit compile time
|
||||
|
Reference in New Issue
Block a user