[inductor] Add configuration control for CUTLASS operation selection. (#155770)

Added a new configuration option, `cutlass_enabled_ops`, that controls which operations use CUTLASS lowerings. By default, CUTLASS remains enabled for all operations (preserving backward compatibility), but it can now be restricted to specific operations to reduce compilation time.

**Fixes #155718**

## Usage Examples

```bash
# Enable CUTLASS for all operations (default behavior)
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="ALL"

# Enable CUTLASS only for matrix multiplication operations
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="mm,addmm"

# Enable CUTLASS only for batch operations
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="bmm,baddbmm"

# Disable CUTLASS for all operations
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS=""
```
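
The environment variable maps onto the new `cuda.cutlass_enabled_ops` inductor config, so the same selection can also be made in-process. A minimal sketch, assuming a CUDA build where the CUTLASS backend is usable; the max-autotune settings are included only so a GEMM autotune actually happens:

```python
import torch
import torch._inductor.config as inductor_config

# Turn on GEMM autotuning and restrict CUTLASS lowerings to mm/addmm only.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "ATEN,TRITON,CUTLASS"
inductor_config.cuda.cutlass_enabled_ops = "mm,addmm"  # equivalent to the env var above

@torch.compile
def f(a, b):
    return a @ b

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
out = f(a, b)
```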

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155770
Approved by: https://github.com/henrylhtsang
penknife6153
2025-06-14 08:19:50 +00:00
committed by PyTorch MergeBot
parent 1982ec2d22
commit 3e38feb05f
4 changed files with 60 additions and 6 deletions

View File

@@ -434,6 +434,7 @@ max_autotune_gemm_backends = os.environ.get(
"TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
).upper()
# As above, specify candidate backends for conv autotune.
# NB: in some cases for 1x1 convs we emit as matmul,
# which will use the backends of `max_autotune_gemm_backends`
@@ -1487,6 +1488,13 @@ class cuda:
os.environ.get("TORCHINDUCTOR_CUTLASS_PRESCREENING", "1") == "1"
)
# Specify which operations should use CUTLASS backend
# Comma-separated list like "mm,addmm,bmm", "all" for all operations, and "" for none.
# Acceptable operations: mm, int_mm, addmm, sparse_semi_structured_mm, bmm, scaled_mm
cutlass_enabled_ops: str = os.environ.get(
"TORCHINDUCTOR_CUTLASS_ENABLED_OPS", "all"
)
class rocm:
# Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"].

View File

@@ -13,6 +13,7 @@ from ..select_algorithm import (
TritonTemplate,
)
from ..utils import (
_use_cutlass_for_op,
use_aten_gemm_kernels,
use_ck_gemm_template,
use_cpp_bmm_template,
@@ -218,7 +219,12 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
)
_, is_nonzero = _is_static_problem(layout)
batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout)
if batch_stride_largest and is_nonzero and use_cutlass_template(layout, m, n, k):
if (
batch_stride_largest
and is_nonzero
and use_cutlass_template(layout, m, n, k)
and _use_cutlass_for_op("bmm")
):
from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) # type: ignore[arg-type]

View File

@@ -38,6 +38,7 @@ from ..select_algorithm import (
TritonTemplate,
)
from ..utils import (
_use_cutlass_for_op,
get_k_splits,
get_tma_workspace_arg,
use_aten_gemm_kernels,
@@ -772,7 +773,11 @@ def tuned_mm(mat1, mat2, *, layout=None):
layout=layout,
)
if is_nonzero and use_cutlass_template(layout, m, n, k):
if (
is_nonzero
and use_cutlass_template(layout, m, n, k)
and _use_cutlass_for_op("mm")
):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
@@ -867,7 +872,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
[aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
)
if use_cutlass:
if use_cutlass and _use_cutlass_for_op("int_mm"):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
)
@@ -991,7 +996,11 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
)
if is_nonzero and use_cutlass_template(layout, m, n, k):
if (
is_nonzero
and use_cutlass_template(layout, m, n, k)
and _use_cutlass_for_op("addmm")
):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices,
layout,
@@ -1061,7 +1070,11 @@ def tuned_sparse_semi_structured_mm(
else []
)
if m * n != 0 and use_cutlass_template(layout, m, n, k):
if (
m * n != 0
and use_cutlass_template(layout, m, n, k)
and _use_cutlass_for_op("sparse_semi_structured_mm")
):
CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True
)
@@ -1086,6 +1099,21 @@ def tuned_scaled_mm(
use_fast_accum=False,
layout=None,
):
"""
Performs an optimized matrix multiplication where scaling factors are applied
to the inputs and/or output.
Args:
mat1 (Tensor): First input matrix
mat2 (Tensor): Second input matrix
scale1 (Tensor): Scale factor applied to mat1 (supports broadcasting)
scale2 (Tensor): Scale factor applied to mat2 (supports broadcasting)
bias (Tensor, optional): Optional bias tensor to add to the result
layout: Layout hint for optimization
Returns:
Tensor: The result of the scaled matrix multiplication
"""
m, n, k, layout, mat_a, mat_b = mm_args(
mat_a, mat_b, layout=layout, out_dtype=out_dtype
)
@@ -1213,7 +1241,11 @@ def tuned_scaled_mm(
epilogue_fn_hash="scale_mm_epilogue",
)
if is_nonzero and use_cutlass_template(layout, m, n, k):
if (
is_nonzero
and use_cutlass_template(layout, m, n, k)
and _use_cutlass_for_op("scaled_mm")
):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices,
layout,

View File

@@ -1578,6 +1578,14 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
return res
def _use_cutlass_for_op(op_name: str) -> bool:
"""Check if CUTLASS should be used for the given operation."""
enabled_ops = config.cuda.cutlass_enabled_ops.upper()
if enabled_ops == "ALL":
return True
return op_name.upper() in [x.strip() for x in enabled_ops.split(",")]
decompose_k_threshold = 32
# To limit compile time
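
For reference, the matching in `_use_cutlass_for_op` is case-insensitive and tolerates whitespace around the commas. A standalone sketch of those semantics, re-implemented here purely for illustration (not imported from torch):

```python
def cutlass_op_enabled(op_name: str, enabled_ops: str) -> bool:
    # Mirrors _use_cutlass_for_op above, with the config value passed in explicitly.
    enabled = enabled_ops.upper()
    if enabled == "ALL":
        return True
    return op_name.upper() in [x.strip() for x in enabled.split(",")]

assert cutlass_op_enabled("mm", "all")          # "all" enables every op
assert cutlass_op_enabled("bmm", "mm, BMM")     # case-insensitive, spaces stripped
assert not cutlass_op_enabled("scaled_mm", "")  # empty string disables everything
```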