Compare commits

...

40 Commits

Author SHA1 Message Date
10fb096070 Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
This PR implements the CUTLASS XPU backend kernel generation as proposed in RFC #160175. It reuses most of the CUTLASS CUDA kernel generation code, with only minor adjustments made to handle XPU-specific code generation.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-31 03:17:53 +00:00
d943cfe814 Update base for Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
This PR implements the CUTLASS XPU backend kernel generation as proposed in RFC #160175. It reuses most of the CUTLASS CUDA kernel generation code, with only minor adjustments made to handle XPU-specific code generation.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-31 03:17:53 +00:00
0839124dae Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
This PR implements the CUTLASS XPU backend kernel generation as proposed in RFC #160175. It reuses most of the CUTLASS CUDA kernel generation code, with only minor adjustments made to handle XPU-specific code generation.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-08 01:07:24 +00:00
b7e073c4d0 Update base for Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
This PR implements the CUTLASS XPU backend kernel generation as proposed in RFC #160175. It reuses most of the CUTLASS CUDA kernel generation code, with only minor adjustments made to handle XPU-specific code generation.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-08 01:07:24 +00:00
08ceb31352 Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
This PR implements the CUTLASS XPU backend kernel generation as proposed in RFC #160175. It reuses most of the CUTLASS CUDA kernel generation code, with only minor adjustments made to handle XPU-specific code generation.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-08 01:06:16 +00:00
8c6cab4dad Update base for Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
This PR implements the CUTLASS XPU backend kernel generation as proposed in RFC #160175. It reuses most of the CUTLASS CUDA kernel generation code, with only minor adjustments made to handle XPU-specific code generation.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-08 01:06:16 +00:00
0a0edbbfcf Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
This PR implements the CUTLASS XPU backend kernel generation as proposed in RFC #160175. It reuses most of the CUTLASS CUDA kernel generation code, with only minor adjustments made to handle XPU-specific code generation.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-08 00:25:09 +00:00
690ed68946 Update base for Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
This PR implements the CUTLASS XPU backend kernel generation as proposed in RFC #160175. It reuses most of the CUTLASS CUDA kernel generation code, with only minor adjustments made to handle XPU-specific code generation.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-08 00:25:09 +00:00
76ab0f1a09 Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
This PR implements the CUTLASS XPU backend kernel generation as proposed in RFC #160175. It reuses most of the CUTLASS CUDA kernel generation code, with only minor adjustments made to handle XPU-specific code generation.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-04 08:53:38 +00:00
5bd370a702 Update base for Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
This PR implements the CUTLASS XPU backend kernel generation as proposed in RFC #160175. It reuses most of the CUTLASS CUDA kernel generation code, with only minor adjustments made to handle XPU-specific code generation.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-04 08:53:38 +00:00
b1eb229d30 Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
kernel.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-03 03:02:54 +00:00
1c4b94c0d2 Update base for Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm kernel"
kernel.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-03 03:02:54 +00:00
0f60e4c18d Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm"
kernel.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-02 10:58:36 +00:00
23e3d4cfe2 Update base for Update on "[Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm"
kernel.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-02 10:58:36 +00:00
4124804146 [Inductor XPU GEMM] Step 9/N: Support generating XPU cutlass gemm
kernel.

[ghstack-poisoned]
2025-09-02 03:10:02 +00:00
663aa67934 [Inductor XPU GEMM] Step 8/N: Add XPU code compilation and codecache.
[ghstack-poisoned]
2025-09-02 03:09:58 +00:00
f0d3f7db52 Update on "[Inductor XPU GEMM] Step 8/N: Refactor CUDABenchmarkRequest"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-02 03:09:58 +00:00
ec887d1962 Update base for Update on "[Inductor XPU GEMM] Step 8/N: Refactor CUDABenchmarkRequest"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-02 03:09:58 +00:00
2cae11a0d1 Update on "[Inductor XPU GEMM] Step 8/N: Refactor CUDABenchmarkRequest"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-08-16 00:36:15 +00:00
c5b67af4e7 Update base for Update on "[Inductor XPU GEMM] Step 8/N: Refactor CUDABenchmarkRequest"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-08-16 00:36:15 +00:00
53e25e520e [Inductor XPU GEMM] Step 8/N: Refactor CUDABenchmarkRequest
[ghstack-poisoned]
2025-08-15 08:52:19 +00:00
4381da8371 Update on "[Inductor XPU GEMM] Step 7/N: Refactor CUDACodeCache."
[ghstack-poisoned]
2025-08-15 08:52:19 +00:00
938b5b7424 Update base for Update on "[Inductor XPU GEMM] Step 7/N: Refactor CUDACodeCache."
[ghstack-poisoned]
2025-08-15 08:52:19 +00:00
1fc3e1abe8 Update on "[Inductor XPU GEMM] Step 7/N: Refactor CUDACodeCache."
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-08-15 05:54:37 +00:00
f7d1d70526 Update base for Update on "[Inductor XPU GEMM] Step 7/N: Refactor CUDACodeCache."
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-08-15 05:54:37 +00:00
3ae7cacd26 Update on "[Inductor XPU GEMM] Step 7/N: Refactor CUDACodeCache."
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-08-15 02:57:44 +00:00
2060620b08 Update base for Update on "[Inductor XPU GEMM] Step 7/N: Refactor CUDACodeCache."
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-08-15 02:57:44 +00:00
501a28a3e7 [Inductor XPU GEMM] Step 7/N: Refactor CUDACodeCache.
[ghstack-poisoned]
2025-08-15 02:56:21 +00:00
ac2acb2097 Update on "[Inductor XPU GEMM] Step 6/N: Refactor CUDACombinedScheduling and CUDACppScheduling."
cc jeffdaily sunway513 jithunnair-amd pruthvistony ROCmSupport dllehr-amd jataylo hongxiayang naromero77amd voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-08-15 02:56:21 +00:00
8bc69def41 Update base for Update on "[Inductor XPU GEMM] Step 6/N: Refactor CUDACombinedScheduling and CUDACppScheduling."
cc jeffdaily sunway513 jithunnair-amd pruthvistony ROCmSupport dllehr-amd jataylo hongxiayang naromero77amd voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-08-15 02:56:21 +00:00
874fc3199c Update on " [Inductor XPU GEMM] Step 6/N: Refactor CUDACombinedScheduling and CUDACppScheduling."
cc jeffdaily sunway513 jithunnair-amd pruthvistony ROCmSupport dllehr-amd jataylo hongxiayang naromero77amd voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-08-14 23:13:35 +00:00
a0819a3a88 Update base for Update on " [Inductor XPU GEMM] Step 6/N: Refactor CUDACombinedScheduling and CUDACppScheduling."
cc jeffdaily sunway513 jithunnair-amd pruthvistony ROCmSupport dllehr-amd jataylo hongxiayang naromero77amd voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-08-14 23:13:35 +00:00
e36aeb0a2e [Inductor XPU GEMM] Step 6/N: Refactor CUDACombinedScheduling and CUDACppScheduling.
[ghstack-poisoned]
2025-08-14 22:25:35 +00:00
c924349562 [Inductor XPU GEMM] Step 5/N Refactor CUDAKernel/CUDATemplateKernel/CUDATemplateCaller/CUDATemplateBuffer to CUTLASSKernel/CUTLASSTemplateKernel/CUTLASSTemplateCaller/CUTLASSTemplateBuffer.
[ghstack-poisoned]
2025-08-14 22:25:31 +00:00
c4a551d72f [Inductor XPU GEMM] Step 4/N Refactor CUDATemplate to CUTLASSTemplate.
[ghstack-poisoned]
2025-08-14 22:25:28 +00:00
08f5fe8139 [Inductor XPU GEMM] Step 3/N: Move cutlass files from torch/_inductor/codegen/cuda to
torch/_inductor/codegen/cutlass.

Signed-off-by: xinan.lin <xinan.lin@intel.com>

[ghstack-poisoned]
2025-08-14 22:25:24 +00:00
8604af4bfa Update on "[Inductor XPU GEMM] Step 2/N: Generalize cutlass configuration."
This PR is the second step toward implementing RFC #160175.
Currently, all Cutlass-related Torch Inductor configs are located in `torch._inductor.config.cuda`. This PR refactors the device-agnostic Cutlass configurations into a separate module, `torch._inductor.config.cutlass`, so they can be shared and reused by XPU as well.


[ghstack-poisoned]
2025-08-14 22:25:24 +00:00
f4ef77b220 Update base for Update on "[Inductor XPU GEMM] Step 2/N: Generalize cutlass configuration."
This PR is the second step toward implementing RFC #160175.
Currently, all Cutlass-related Torch Inductor configs are located in `torch._inductor.config.cuda`. This PR refactors the device-agnostic Cutlass configurations into a separate module, `torch._inductor.config.cutlass`, so they can be shared and reused by XPU as well.


[ghstack-poisoned]
2025-08-14 22:25:24 +00:00
44789dc58d [Inductor Intel Cutlass] Step 2/N: Generalize cutlass configuration.
[ghstack-poisoned]
2025-08-08 06:03:45 +00:00
2ddfc76a10 [Inductor Intel Cutlass] Step 1/N: Add Intel Cutlass repro into third_party.
[ghstack-poisoned]
2025-08-08 06:03:41 +00:00
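
The Step 9/N and Step 2/N descriptions in the commit list above state the goal of the stack: reuse the CUTLASS CUDA codegen path (RFC #160175) so Inductor can generate CUTLASS GEMM kernels for XPU. A minimal usage sketch, mirroring how the tests in this diff drive the backend; it assumes an XPU build with this stack applied and is illustrative, not code taken from the PR:

    import torch
    from torch._inductor import config

    # Hypothetical end-to-end use of the CUTLASS GEMM backend on an XPU device.
    # The config keys below are the ones this stack re-homes under cutlass.*.
    a = torch.randn(128, 64, device="xpu", dtype=torch.float16)
    b = torch.randn(64, 256, device="xpu", dtype=torch.float16)

    with config.patch(
        {
            "max_autotune": True,
            "max_autotune_gemm_backends": "CUTLASS",
            "cutlass.cutlass_max_profiling_configs": 2,
        }
    ):
        compiled_mm = torch.compile(torch.mm)
        y = compiled_mm(a, b)  # autotunes against generated CUTLASS kernels
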
47 changed files with 1071 additions and 692 deletions

View File

@ -125,7 +125,7 @@ class CutlassExperimentConfig(ExperimentConfig):
def to_options(self) -> dict[str, Any]:
return {
**super().to_options(),
"cuda.cutlass_instantiation_level": self.cutlass_instantiation_level,
"cutlass.cutlass_instantiation_level": self.cutlass_instantiation_level,
}
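
The hunk above shows the pattern that repeats throughout the test changes below: CUTLASS-related Inductor options move from the CUDA-specific cuda.* config namespace to a new device-agnostic cutlass.* namespace (torch._inductor.config.cutlass, per the Step 2/N commit message), so CUDA and XPU can share them. A small before/after sketch, assuming this stack is applied:

    from torch._inductor import config

    # Pre-refactor spelling, scoped under the CUDA-only namespace:
    #   config.patch({"cuda.cutlass_max_profiling_configs": 2})
    # Post-refactor spelling, shared by CUDA and XPU:
    with config.patch({"cutlass.cutlass_max_profiling_configs": 2}):
        pass  # compile / autotune under the patched options here
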

View File

@ -24,7 +24,6 @@ from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._inductor import config, metrics
from torch._inductor.codecache import (
BypassFxGraphCache,
cuda_compile_command,
CUDACodeCache,
FxGraphCachePickler,
FxGraphHashDetails,
@ -32,6 +31,7 @@ from torch._inductor.codecache import (
TensorMetadata,
TensorMetadataAndValues,
)
from torch._inductor.codegen.cuda.compile_utils import cuda_compile_command
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.custom_graph_pass import (
CustomGraphModulePass,

View File

@ -14,7 +14,9 @@ from pathlib import Path
from typing import Optional
from torch._dynamo.exc import BackendCompilerFailed
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
from torch._inductor.codegen.cutlass.serialization import (
get_cutlass_operation_serializer,
)
from torch._inductor.utils import clear_caches
from torch.export import Dim
from torch.testing._internal.logging_utils import log_settings
@ -32,11 +34,8 @@ import torch.version
from torch._dynamo import config as dynamo_config
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
from torch._inductor.codegen.cuda.cutlass_utils import (
_gen_ops_cached,
get_max_alignment,
)
from torch._inductor.codegen.cutlass.kernel import CUTLASSTemplateCaller
from torch._inductor.codegen.cutlass.utils import _gen_ops_cached, get_max_alignment
from torch._inductor.exc import InductorError
from torch._inductor.ir import FixedLayout
from torch._inductor.select_algorithm import NoValidChoicesError
@ -133,10 +132,10 @@ use_evt_config = config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 1,
"cutlass.cutlass_max_profiling_configs": 1,
"benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet
"cuda.cutlass_tma_only": True,
"cuda.cutlass_epilogue_fusion_enabled": True,
"cutlass.cutlass_tma_only": True,
"cutlass.cutlass_epilogue_fusion_enabled": True,
}
)
@ -144,9 +143,9 @@ fp8_config = config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 1,
"cutlass.cutlass_max_profiling_configs": 1,
"benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet
"cuda.cutlass_tma_only": True,
"cutlass.cutlass_tma_only": True,
}
)
@ -198,7 +197,7 @@ class TestCutlassBackend(TestCase):
ref_result = model(a, b, extra_args)
self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
num_fusions,
)
torch.testing.assert_close(result, ref_result)
@ -206,7 +205,7 @@ class TestCutlassBackend(TestCase):
def test_check_paths(self):
cutlass_mock_imports_path = os.path.join(
os.path.dirname(torch.__file__),
"_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports",
"_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports",
)
cutlass_mock_cuda_path = os.path.join(cutlass_mock_imports_path, "cuda")
cutlass_mock_pydot_path = os.path.join(cutlass_mock_imports_path, "pydot")
@ -234,8 +233,8 @@ class TestCutlassBackend(TestCase):
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"compile_threads": 4,
"cuda.cutlass_backend_min_gemm_size": 100000,
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_backend_min_gemm_size": 100000,
"cutlass.cutlass_max_profiling_configs": 2,
}
):
with mock.patch(
@ -251,7 +250,7 @@ class TestCutlassBackend(TestCase):
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_import_cutlass(self):
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
from torch._inductor.codegen.cutlass.utils import try_import_cutlass
self.assertTrue(try_import_cutlass())
@ -259,7 +258,7 @@ class TestCutlassBackend(TestCase):
import cutlass_library # noqa: F401
def test_cutlass_key(self):
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
from torch._inductor.codegen.cutlass.utils import try_import_cutlass
self.assertTrue(try_import_cutlass())
from torch._inductor.codecache import cutlass_key
@ -287,7 +286,7 @@ class TestCutlassBackend(TestCase):
"autotune_in_subproc": True,
"max_autotune_gemm_backends": "CUTLASS",
"compile_threads": 4,
"cuda.cutlass_max_profiling_configs": 4,
"cutlass.cutlass_max_profiling_configs": 4,
}
):
Y_compiled = torch.compile(torch.mm)(a, b)
@ -324,7 +323,7 @@ class TestCutlassBackend(TestCase):
"autotune_in_subproc": True,
"max_autotune_gemm_backends": "CUTLASS",
"compile_threads": 4,
"cuda.cutlass_max_profiling_configs": 4,
"cutlass.cutlass_max_profiling_configs": 4,
}
):
for x_shape in x_shapes:
@ -354,7 +353,7 @@ class TestCutlassBackend(TestCase):
"autotune_in_subproc": True,
"max_autotune_gemm_backends": "CUTLASS",
"compile_threads": 4,
"cuda.cutlass_max_profiling_configs": 4,
"cutlass.cutlass_max_profiling_configs": 4,
}
):
Y_compiled = torch.compile(torch.bmm)(a, b)
@ -386,7 +385,7 @@ class TestCutlassBackend(TestCase):
"max_autotune": True,
"autotune_in_subproc": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 1,
"cutlass.cutlass_max_profiling_configs": 1,
}
):
from torch._inductor.utils import run_and_get_code
@ -428,8 +427,8 @@ class TestCutlassBackend(TestCase):
"max_autotune": True,
"autotune_in_subproc": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 1,
"cuda.cutlass_max_profiling_swizzle_options": [
"cutlass.cutlass_max_profiling_configs": 1,
"cutlass.cutlass_max_profiling_swizzle_options": [
1,
2,
4,
@ -505,7 +504,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
}
),
dynamo_config.patch({"error_on_recompile": dynamic}),
@ -595,9 +594,9 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
"benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet
"cuda.cutlass_tma_only": True,
"cutlass.cutlass_tma_only": True,
}
),
dynamo_config.patch({"error_on_recompile": dynamic}),
@ -677,7 +676,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
}
),
dynamo_config.patch({"error_on_recompile": dynamic}),
@ -746,7 +745,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
}
):
expected = [model(*input) for input in inputs]
@ -775,8 +774,8 @@ class TestCutlassBackend(TestCase):
"max_autotune": True,
"autotune_in_subproc": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 2,
"cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
"cutlass.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
}
):
for M, K, N in (
@ -819,7 +818,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
"cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
}
):
with self.assertRaisesRegex(InductorError, r".*NoValidChoicesError.*"):
@ -849,8 +848,8 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 1,
"cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
"cutlass.cutlass_max_profiling_configs": 1,
"cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
}
):
_ = compiled_model(a, b)
@ -884,15 +883,15 @@ class TestCutlassBackend(TestCase):
"max_autotune": True,
"autotune_in_subproc": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 4,
"cuda.version": "12.2", # required to enable the Kernels we need
"cutlass.cutlass_max_profiling_configs": 4,
"cuda.cuda_version": "12.2", # required to enable the Kernels we need
}
):
counters["inductor"]["cuda_epilogue_fusion_counter"] = 0
counters["inductor"]["cutlass_epilogue_fusion_counter"] = 0
assert mm is not None
Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
Y = mm(a, b)
actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"]
actual_count = counters["inductor"]["cutlass_epilogue_fusion_counter"]
assert actual_count == expected_fuse_count, (
f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
)
@ -983,7 +982,7 @@ class TestCutlassBackend(TestCase):
"max_autotune": True,
"autotune_in_subproc": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
}
):
Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
@ -1002,7 +1001,7 @@ class TestCutlassBackend(TestCase):
"max_autotune": True,
"autotune_in_subproc": False,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
}
):
model = MyModel()
@ -1040,7 +1039,7 @@ class TestCutlassBackend(TestCase):
"max_autotune": True,
"autotune_in_subproc": False,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
}
):
model = MyModel()
@ -1073,8 +1072,8 @@ class TestCutlassBackend(TestCase):
"max_autotune": True,
"autotune_in_subproc": False,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem",
"cuda.cutlass_max_profiling_configs": 1,
"cutlass.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem",
"cutlass.cutlass_max_profiling_configs": 1,
}
):
model = MyModel()
@ -1117,7 +1116,7 @@ class TestCutlassBackend(TestCase):
"max_autotune": True,
"autotune_in_subproc": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
"autotune_local_cache": True,
}
):
@ -1157,9 +1156,9 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 2,
"cuda.cutlass_op_allowlist_regex": "",
"cuda.cutlass_op_denylist_regex": "pingpong",
"cutlass.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_op_allowlist_regex": "",
"cutlass.cutlass_op_denylist_regex": "pingpong",
}
):
with mock.patch(
@ -1175,7 +1174,7 @@ class TestCutlassBackend(TestCase):
assert op_name == "addmm"
cuda_template_count = 0
for choice in choices:
if isinstance(choice, CUDATemplateCaller):
if isinstance(choice, CUTLASSTemplateCaller):
choice_info = choice.info_dict()
op_conf_name = choice_info.get("op_conf_name", "")
assert isinstance(op_conf_name, str)
@ -1183,7 +1182,7 @@ class TestCutlassBackend(TestCase):
"All pingpong Kernels should have been filtered"
)
cuda_template_count += 1
assert cuda_template_count > 0, "No CUDATemplateCaller choices"
assert cuda_template_count > 0, "No CUTLASSTemplateCaller choices"
@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@ -1202,9 +1201,9 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 2,
"cuda.cutlass_op_allowlist_regex": "pingpong",
"cuda.cutlass_op_denylist_regex": None,
"cutlass.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_op_allowlist_regex": "pingpong",
"cutlass.cutlass_op_denylist_regex": None,
}
):
with mock.patch(
@ -1220,7 +1219,7 @@ class TestCutlassBackend(TestCase):
assert op_name == "addmm"
cuda_template_count = 0
for choice in choices:
if isinstance(choice, CUDATemplateCaller):
if isinstance(choice, CUTLASSTemplateCaller):
choice_info = choice.info_dict()
op_conf_name = choice_info.get("op_conf_name", "")
assert isinstance(op_conf_name, str)
@ -1228,7 +1227,7 @@ class TestCutlassBackend(TestCase):
"Only pingpong Kernels should have been allowed"
)
cuda_template_count += 1
assert cuda_template_count > 0, "No CUDATemplateCaller choices"
assert cuda_template_count > 0, "No CUTLASSTemplateCaller choices"
@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@ -1273,7 +1272,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
}
):
with mock.patch(
@ -1296,7 +1295,7 @@ class TestCutlassBackend(TestCase):
_, choices, _, _ = args
cuda_template_count = 0
for choice in choices:
if isinstance(choice, CUDATemplateCaller):
if isinstance(choice, CUTLASSTemplateCaller):
choice_info = choice.info_dict()
op_conf_name = choice_info.get("op_conf_name", "")
assert isinstance(op_conf_name, str)
@ -1309,7 +1308,9 @@ class TestCutlassBackend(TestCase):
"fastaccum Kernels should have been filtered"
)
cuda_template_count += 1
assert cuda_template_count > 0, "No CUDATemplateCaller choices"
assert cuda_template_count > 0, (
"No CUTLASSTemplateCaller choices"
)
run_test(True)
run_test(False)
@ -1350,7 +1351,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
}
),
mock.patch(
@ -1375,7 +1376,7 @@ class TestCutlassBackend(TestCase):
assert op_name == "mm"
cuda_template_count = 0
for choice in choices:
if isinstance(choice, CUDATemplateCaller):
if isinstance(choice, CUTLASSTemplateCaller):
choice_info = choice.info_dict()
op_conf_name = choice_info.get("op_conf_name", "")
assert isinstance(op_conf_name, str)
@ -1384,7 +1385,7 @@ class TestCutlassBackend(TestCase):
self.assertGreater(
cuda_template_count,
0,
"No CUDATemplateCaller choices found for matmul with shape "
"No CUTLASSTemplateCaller choices found for matmul with shape "
f"M={M}, N={N}, K={K}",
)
@ -1461,13 +1462,13 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 2,
"cuda.generate_test_runner": True, # put standalone runner in the generated code
"cutlass.cutlass_max_profiling_configs": 2,
"cutlass.generate_test_runner": True, # put standalone runner in the generated code
}
):
from tempfile import NamedTemporaryFile
from torch._inductor.codegen.cuda.cutlass_utils import (
from torch._inductor.codegen.cuda.compile_utils import (
cuda_standalone_runner_compile_command,
CUDACompileSourceCapturingContext,
)
@ -1544,7 +1545,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "ATEN,TRITON,CUTLASS",
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
# needed for log searching
"fx_graph_cache": False,
"fx_graph_remote_cache": False,
@ -1553,7 +1554,7 @@ class TestCutlassBackend(TestCase):
with (
log_settings("+inductor"),
self.assertLogs(
logger="torch._inductor.codegen.cuda", level=logging.DEBUG
logger="torch._inductor.codegen.cutlass", level=logging.DEBUG
) as test_log,
):
Y_compiled = torch.compile(mm, dynamic=False)(a, b)
@ -1591,7 +1592,7 @@ class TestCutlassBackend(TestCase):
expected = model(A, B)
# Track render calls
from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate
from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate
original_render = CUTLASSGemmTemplate.render
render_call_count = 0
@ -1608,8 +1609,8 @@ class TestCutlassBackend(TestCase):
"max_autotune_gemm_backends": "CUTLASS",
"fx_graph_cache": False,
"fx_graph_remote_cache": False,
"cuda.enable_caching_codegen": True,
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.enable_caching_codegen": True,
"cutlass.cutlass_max_profiling_configs": 2,
}
):
compiled_model = torch.compile(model, fullgraph=True)
@ -1645,7 +1646,7 @@ class TestCutlassBackend(TestCase):
d = torch.randn(64, 128).cuda().half().t()
# Track render calls
from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate
from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate
original_render = CUTLASSGemmTemplate.render
render_call_count = 0
@ -1660,10 +1661,10 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
"fx_graph_cache": False,
"fx_graph_remote_cache": False,
"cuda.enable_caching_codegen": True,
"cutlass.enable_caching_codegen": True,
}
):
# Get expected results
@ -1706,7 +1707,7 @@ class TestCutlassBackend(TestCase):
b = torch.randn(32, 64).cuda().half().t()
# Track render calls
from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate
from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate
original_render = CUTLASSGemmTemplate.render
render_call_count = 0
@ -1721,10 +1722,10 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
"fx_graph_cache": False,
"fx_graph_remote_cache": False,
"cuda.enable_caching_codegen": True,
"cutlass.enable_caching_codegen": True,
}
):
# Get expected results
@ -1752,7 +1753,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
}
):
compiled = torch.compile(torch.mm)
@ -1771,7 +1772,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 2,
"cutlass.cutlass_max_profiling_configs": 2,
}
):
compiled = torch.compile(torch.mm)
@ -1795,7 +1796,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 1,
"cutlass.cutlass_max_profiling_configs": 1,
}
):
_ = torch.compile(model)(B)
@ -1817,13 +1818,14 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 1,
"cutlass.cutlass_max_profiling_configs": 1,
}
):
_ = torch.compile(model)(B)
self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
1,
)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@ -1845,7 +1847,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 1,
"cutlass.cutlass_max_profiling_configs": 1,
}
):
_ = torch.compile(model)(B)
@ -1871,7 +1873,7 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 1,
"cutlass.cutlass_max_profiling_configs": 1,
}
):
if use_aoti:
@ -1917,7 +1919,8 @@ class TestCutlassBackend(TestCase):
ref_result = model(a, b, extra_args)
self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
1,
)
torch.testing.assert_close(result, ref_result)
@ -1968,18 +1971,19 @@ class TestCutlassBackend(TestCase):
# baseline is cutlass kernel + triton
# matches expected casting behavior
with config.patch({"cuda.cutlass_epilogue_fusion_enabled": False}):
with config.patch({"cutlass.cutlass_epilogue_fusion_enabled": False}):
ref_result = torch.compile(model)(a, b, extra_args)
self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 0
torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
0,
)
torch._dynamo.reset()
result = torch.compile(model)(a, b, extra_args)
self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
1,
)
@ -2037,7 +2041,7 @@ class TestCutlassBackend(TestCase):
self.assertEqual(
torch._dynamo.utils.counters["inductor"][
"cuda_epilogue_fusion_counter"
"cutlass_epilogue_fusion_counter"
],
2 * (i + 1),
)
@ -2064,7 +2068,8 @@ class TestCutlassBackend(TestCase):
ref_result = model(a, b, extra_args)
self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
1,
)
torch.testing.assert_close(result, ref_result)
@ -2368,7 +2373,7 @@ class TestCutlassBackend(TestCase):
"max_autotune_gemm_backends": "CUTLASS",
# needed for log searching
"force_disable_caches": True,
"cuda.cutlass_max_profiling_swizzle_options": [2],
"cutlass.cutlass_max_profiling_swizzle_options": [2],
}
):
with mock.patch(

View File

@ -5,7 +5,7 @@ import sympy
import torch
from torch._dynamo.test_case import TestCase
from torch._inductor.codegen.cuda.cutlass_utils import (
from torch._inductor.codegen.cutlass.utils import (
torch_dtype_to_cutlass_type,
try_import_cutlass,
)
@ -28,7 +28,7 @@ if try_import_cutlass():
DataType = cutlass_lib.DataType
from cutlass_cppgen.backend.evt.ir.tensor import Tensor as CutlassTensor
from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
from torch._inductor.codegen.cutlass.lib_extensions.evt_extensions import (
_render_argument_type,
_trace,
trace,
@ -107,7 +107,7 @@ class TestCutlassEVT(TestCase):
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_py_codegen_accumulator_return(self):
from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
from torch._inductor.virtualized import V
size = (100, 300, 200)
@ -164,7 +164,7 @@ return tmp_0, tmp_2, D""",
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_py_codegen_disjoint_read_indexing(self):
from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
from torch._inductor.virtualized import V
size = (100, 300, 200)
@ -213,7 +213,7 @@ index strides [200, 60000, 1], and layout stride [60000, 200, 1]""",
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_py_codegen_broadcasting(self):
from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
from torch._inductor.virtualized import V
size = (100, 300, 200)
@ -273,7 +273,7 @@ return tmp_0, tmp_2, D""",
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_py_codegen(self):
from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
from torch._inductor.virtualized import V
size = (100, 300, 200)
@ -329,7 +329,7 @@ return tmp_1, D""",
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_example_tensor_creation(self):
from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
from torch._inductor.codegen.cutlass.lib_extensions.evt_extensions import (
create_example_tensors,
)
from torch._inductor.virtualized import V

View File

@ -292,7 +292,7 @@ class TestPublicBindings(TestCase):
# do not get imported by public code.
# DO NOT add public modules here.
private_allowlist = {
"torch._inductor.codegen.cuda.cuda_kernel",
"torch._inductor.codegen.cutlass.kernel",
# TODO(#133647): Remove the onnx._internal entries after
# onnx and onnxscript are installed in CI.
"torch.onnx._internal.exporter",
@ -357,8 +357,8 @@ class TestPublicBindings(TestCase):
"torch.testing._internal.distributed.rpc.rpc_test",
"torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture",
"torch.testing._internal.distributed.rpc_utils",
"torch._inductor.codegen.cuda.cuda_template",
"torch._inductor.codegen.cuda.gemm_template",
"torch._inductor.codegen.cutlass.template",
"torch._inductor.codegen.cutlass.gemm_template",
"torch._inductor.codegen.cpp_template",
"torch._inductor.codegen.cpp_gemm_template",
"torch._inductor.codegen.cpp_micro_gemm",

View File

@ -264,4 +264,4 @@
"torch/_inductor/utils.py": {
"class IndentedBuffer": 145
}
}
}

View File

@ -37,6 +37,7 @@ from torch._inductor.codecache import (
ROCmCodeCache,
StaticAutotunerFuture,
torch_key,
XPUCodeCache,
)
from torch._inductor.compile_worker.subproc_pool import (
AnyPool,
@ -557,6 +558,19 @@ class AsyncCompile:
return self.submit(task)
def xpu(self, source_code, dst_file_ext, aot_compile=False):
kernel_code_log.info("XPU Kernel:\n%s", source_code)
def task():
if aot_compile:
# We rely on JIT Inductor to compile the XPU code,
# so that we can load it into AOTInductor.
output_path, *_ = XPUCodeCache.compile(source_code, "o")
XPUCodeCache.aot_kernels_o.append(output_path)
return XPUCodeCache.load(source_code, dst_file_ext)[0]
return self.submit(task)
def rocm(
self,
source_code,

View File

@ -30,6 +30,7 @@ from torch._inductor.codecache import (
DLLWrapper,
get_hash,
PyCodeCache,
XPUCodeCache,
)
from torch._inductor.utils import (
get_gpu_type,
@ -682,7 +683,7 @@ class TritonCPUBenchmarkRequest(CPUDeviceBenchmarkMixin, TritonBenchmarkRequest)
pass
class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
class CUTLASSBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
"""
A class to handle CUDA (CUTLASS) benchmark requests. This class is for
managing the lifecycle of a CUDA kernel benchmark, including compiling
@ -699,6 +700,7 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
extra_args: Iterable[Any],
source_code: str,
device_type: str = "cuda",
) -> None:
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
self.source_code = source_code
@ -708,7 +710,12 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
self._workspace_size_updated = False
self.hash_key: str = ""
self.source_file: str = ""
self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")
self.device_type = device_type
self.codecache_cls = XPUCodeCache if device_type == "xpu" else CUDACodeCache
self.device_interface = get_interface_for_device(device_type)
self.hash_key, self.source_file = self.codecache_cls.write(
self.source_code, "so"
)
def precompile(self):
"""
@ -716,14 +723,14 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
This may happen in a separate thread pool.
"""
autotuning_log.debug("Precompiling %s", self)
CUDACodeCache.compile(self.source_code, "so")
self.codecache_cls.compile(self.source_code, "so")
autotuning_log.debug("Done precompiling %s", self)
def make_run_fn(
self, *input_tensors: torch.Tensor, out: torch.Tensor
) -> Callable[[], None]:
"""
Create a function to run the CUDA kernel with the given input and output tensors.
Create a function to run the CUDA/XPU kernel with the given input and output tensors.
"""
self.ensure_dll_loaded()
@ -738,7 +745,9 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
args,
self.extra_args,
)
stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
stream_ptr = c_void_p(
self.device_interface.get_raw_stream(self.device_interface.current_device())
)
run_method = getattr(self.DLL, self.kernel_name)
workspace_ptr = c_void_p(0)
if self.workspace_size > 0:
@ -781,7 +790,9 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
dict.fromkeys(meta.name for meta in self.input_tensor_meta)
)
args = [c_void_p(None) for _ in range(unique_input_count + 1)]
stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
stream_ptr = c_void_p(
self.device_interface.get_raw_stream(self.device_interface.current_device())
)
run_method = getattr(self.DLL, self.kernel_name)
# Retrieve workspace_size and initialize workspace.
@ -795,7 +806,7 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
None, # null workspace ptr
stream_ptr,
)
torch.cuda.synchronize() # shake out any CUDA errors
self.device_interface.synchronize() # shake out any device errors
self.workspace_size = c_workspace_size.value
autotuning_log.debug(
"update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950
@ -811,7 +822,7 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
def ensure_dll_loaded(self):
if self.DLL is None:
self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
self.DLL, self.hash_key, self.source_file = self.codecache_cls.load(
self.source_code, "so"
)

View File

@ -34,7 +34,17 @@ from pathlib import Path
from tempfile import _TemporaryFileWrapper
from time import time, time_ns
from types import ModuleType
from typing import Any, Callable, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar, Union
from typing import (
Any,
Callable,
cast,
Generic,
NoReturn,
Optional,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import override, Self
import torch
@ -54,7 +64,7 @@ from torch._inductor.codegen.common import (
custom_backend_passes,
init_backend_registration,
)
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.codegen.cuda import compile_utils as cuda_compile_utils
from torch._inductor.codegen.rocm.compile_command import (
rocm_compile_command,
rocm_compiler,
@ -64,7 +74,6 @@ from torch._inductor.cpp_builder import (
_LINKER_SCRIPT,
_set_gpu_runtime_env,
_TORCH_PATH,
_transform_cuda_paths,
convert_cubin_to_obj,
CppBuilder,
CppOptions,
@ -119,10 +128,6 @@ from .triton_bundler import TritonBundler
from .virtualized import V
if config.is_fbcode():
from triton.fb.build import build_paths
T = TypeVar("T")
if TYPE_CHECKING:
@ -148,17 +153,6 @@ autotuning_log = torch._logging.getArtifactLogger(__name__, "autotuning")
log = logging.getLogger(__name__)
def use_re_build() -> bool:
"""
Use for CUTLASS compilation only right now.
"""
if config.is_fbcode() and not cuda_env.nvcc_exist(_cuda_compiler()):
from triton.fb.re_build_helper import should_build_locally
return not should_build_locally()
return False
def get_cpp_wrapper_cubin_path_name() -> str:
return "cubin_path" if torch.version.hip is None else "hsaco_path"
@ -2352,7 +2346,8 @@ end
f.write(json.dumps(qual_name_to_id))
generated_files.append(constants_config_json)
gpu_codecache: ROCmCodeCache | CUDACodeCache = (
gpu_codecache: ROCmCodeCache | CUDACodeCache | XPUCodeCache = (
XPUCodeCache() if device_type == "xpu" else
ROCmCodeCache() if torch.version.hip else CUDACodeCache()
)
gpu_kernels_o = gpu_codecache.aot_kernels_o.copy()
@ -2383,10 +2378,10 @@ end
config.aot_inductor.emit_multi_arch_kernel
and device_type == "cuda"
):
current_arch = _nvcc_arch_as_compile_option()
current_arch = cuda_compile_utils._nvcc_arch_as_compile_option()
cmd = (
# pyrefly: ignore [unbound-name]
f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} "
f"{cuda_compile_utils._cuda_compiler()} -fatbin {asm_file} -o {cubin_file} "
# Triton only allows generating PTX version as same as the current arch
f"-gencode arch=compute_{current_arch},code=compute_{current_arch} "
# Include SASS for the current specific arch
@ -3686,55 +3681,6 @@ def _load_triton_kernel_from_source(
return getattr(PyCodeCache.load(source_code), kernel_name)
def _cuda_compiler() -> str | None:
if cuda_env.nvcc_exist(config.cuda.cuda_cxx):
return config.cuda.cuda_cxx
if config.is_fbcode():
return os.path.join(build_paths.sdk_home, "bin", "nvcc")
if cuda_env.nvcc_exist(os.getenv("CUDACXX")):
return os.getenv("CUDACXX", "")
if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")):
return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc"))
return "nvcc"
def _cutlass_path() -> str:
if config.is_fbcode():
from libfb.py import parutil
return parutil.get_dir_path("cutlass-4-headers")
else:
return config.cuda.cutlass_dir
def _cutlass_paths() -> list[str]:
return [
"include",
"tools/library/include",
"tools/library/src",
"tools/util/include",
]
def _clone_cutlass_paths(build_root: str) -> list[str]:
paths = _cutlass_paths()
cutlass_root = _cutlass_path()
for path in _cutlass_paths():
old_path = os.path.join(cutlass_root, path)
new_path = os.path.join(build_root, path)
shutil.copytree(old_path, new_path, dirs_exist_ok=True)
return paths
def _cutlass_include_paths() -> list[str]:
cutlass_path = _cutlass_path()
return [
# Use realpath to get canonical absolute paths, in order not to mess up cache keys
os.path.realpath(os.path.join(cutlass_path, path))
for path in _cutlass_paths()
]
@torch_key_cache
def cutlass_key() -> bytes:
"""
@ -3750,151 +3696,10 @@ def cutlass_key() -> bytes:
return resource_file.read().encode()
combined_hash = hashlib.sha256()
build_code_hash([config.cuda.cutlass_dir], "", combined_hash)
build_code_hash([config.cutlass.cutlass_dir], "", combined_hash)
return combined_hash.digest()
def _cuda_lib_options() -> list[str]:
"""
Util function for CUTLASS backend to find the correct CUDA libraries.
"""
_set_gpu_runtime_env() # cpp_extension consults the env
from torch.utils import cpp_extension
lpaths = cpp_extension.library_paths(device_type="cuda")
if use_re_build():
lpaths += [
build_paths.sdk_lib,
os.path.join(build_paths.sdk_lib, "stubs"),
]
extra_ldflags: list[str] = []
if is_linux():
_transform_cuda_paths(lpaths)
for path in lpaths:
if "torch/lib" in path:
# don't want to depend on pytorch
continue
extra_ldflags.append(f"-L{path}")
# -rpath ensures the DLL can find its dependencies when loaded, even
# if the library path is non-standard.
# But do not add the stubs folder to rpath as the driver is expected to be found at runtime
if os.path.basename(path) != "stubs":
extra_ldflags.extend(["-Xlinker", f"-rpath={path}"])
extra_ldflags.append("-lcuda")
extra_ldflags.append("-lcudart")
else:
raise NotImplementedError(
"Unsupported env, failed to find cuda libs! Currently only Linux is supported."
)
return extra_ldflags
def _nvcc_host_compiler_options() -> list[str]:
return [
"-fPIC",
"-fno-strict-aliasing",
"-fvisibility=hidden",
"-Wconversion",
]
def _nvcc_arch_as_compile_option() -> str:
arch = cuda_env.get_cuda_arch()
if arch == "90":
# Required by cutlass compilation.
return "90a"
if arch == "100":
return "100a"
return arch
def _nvcc_compiler_options() -> list[str]:
arch = _nvcc_arch_as_compile_option()
code = [f"sm_{arch}", f"compute_{arch}"]
if config.cuda.enable_cuda_lto:
code += [f"lto_{arch}"]
options = [
"-t=0",
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
"-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1",
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
"-w",
f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
config.cuda.compile_opt_level,
"-std=c++17",
"--expt-relaxed-constexpr",
"-DNDEBUG",
]
if config.is_fbcode():
options.extend(["-ccbin", os.path.dirname(build_paths.gcc)])
if config.cuda.enable_debug_info:
options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"])
if config.cuda.enable_ptxas_info:
options.extend(
[
"--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.)
"--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels
"--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels
"--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.)
"--source-in-ptx",
]
) # Annotate the ptx file with source information
if config.cuda.use_fast_math:
options.extend(
[
"--use_fast_math",
"-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
]
)
return options
def cuda_compile_command(
src_files: list[str],
dst_file: str,
dst_file_ext: str,
extra_args: list[str] | None = None,
) -> str:
if extra_args is None:
extra_args = []
if use_re_build():
build_path = os.path.dirname(dst_file)
include_paths = _clone_cutlass_paths(build_path)
src_files = [os.path.basename(src_file) for src_file in src_files]
dst_file = os.path.basename(dst_file)
else:
include_paths = _cutlass_include_paths()
cuda_lib_options = _cuda_lib_options()
nvcc_host_compiler_options = _nvcc_host_compiler_options()
nvcc_compiler_options = _nvcc_compiler_options()
options = (
nvcc_compiler_options
+ extra_args
+ [
f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}"
for opt in nvcc_host_compiler_options
]
+ ["-I" + path for path in include_paths]
+ cuda_lib_options
)
src_file = " ".join(src_files)
res = ""
if dst_file_ext == "o":
res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}"
elif dst_file_ext == "so":
options.append("-shared")
res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
elif dst_file_ext == "exe":
res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
else:
raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
if log.isEnabledFor(logging.DEBUG):
log.debug("CUDA command: %s", res)
else:
autotuning_log.debug("CUDA command: %s", res)
return res
class DLLWrapper:
"""A wrapper for a dynamic library."""
@ -3978,10 +3783,9 @@ def binary_error_path(output_path: str) -> str:
return output_path + ".error"
@clear_on_fresh_cache
class CUDACodeCache:
class CUTLASSCodeCache:
"""
A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS.
A cache for managing the compilation and loading of source code specifically for CUTLASS.
This class handles writing source code to files, compiling them into shared objects, and caching
the results to avoid redundant compilations. It also manages error handling and logging for the
compilation process.
@ -3995,12 +3799,15 @@ class CUDACodeCache:
cache: dict[str, CacheEntry] = {}
aot_kernels_o: list[str] = []
_SOURCE_CODE_SUFFIX = "cu"
_SOURCE_CODE_SUFFIX: str = ""
_BACKEND: str = ""
@staticmethod
def cache_clear() -> None:
CUDACodeCache.cache.clear()
CUDACodeCache.aot_kernels_o.clear()
CUTLASSCodeCache.cache.clear()
CUTLASSCodeCache.aot_kernels_o.clear()
CUTLASSCodeCache.write.cache_clear()
@staticmethod
@lru_cache(maxsize=4)
@ -4035,6 +3842,24 @@ class CUDACodeCache:
)
return None
@classmethod
def _use_re_build(cls) -> bool:
raise NotImplementedError
@classmethod
def _compile_command(
cls,
src_files: list[str],
dst_file: str,
dst_file_ext: str,
extra_args: Optional[list[str]] = None,
) -> str:
raise NotImplementedError
@classmethod
def _source_code_extra(cls) -> str:
raise NotImplementedError
@classmethod
@lru_cache(None)
def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]:
@ -4043,25 +3868,14 @@ class CUDACodeCache:
Returns the hash key of source code, and the path to the file.
"""
if config.cuda.cutlass_hash_with_compile_cmd:
cuda_command = repr(
cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext)
if config.cutlass.cutlass_hash_with_compile_cmd:
compile_command = repr(
cls._compile_command(["dummy_input"], "dummy_output", dst_file_ext)
)
extra = cuda_command
extra = compile_command
else:
extra = repr(
[
# nvcc and cuda hash
_cuda_compiler(),
# cutlass flags and gcc hash
_nvcc_compiler_options(),
# flags
_nvcc_host_compiler_options(),
# cutlass key
cutlass_key(),
# hack to deal with AOTI .o compilation
]
)
extra = cls._source_code_extra()
key, input_path = write(source_code, cls._SOURCE_CODE_SUFFIX, extra=extra)
return key, input_path
@ -4094,7 +3908,7 @@ class CUDACodeCache:
output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext
error_path = binary_error_path(output_path)
binary_remote_cache = cls.get_kernel_binary_remote_cache(
caching_enabled=config.cuda.use_binary_remote_cache
caching_enabled=config.cutlass.use_binary_remote_cache
and not config.force_disable_caches,
caching_available=config.is_fbcode(),
)
@ -4109,30 +3923,30 @@ class CUDACodeCache:
cmd_parts, error_output = json.loads(error_json)
if (
binary_remote_cache is not None
and config.cuda.upload_to_binary_remote_cache
and config.cutlass.upload_to_binary_remote_cache
):
# This ensures that a local error is uploaded to the remote cache,
# as we make no assumptions about the remote cache having the same
# information as the local cache
binary_remote_cache.put(
error_path, config.cuda.binary_remote_cache_force_write
error_path, config.cutlass.binary_remote_cache_force_write
)
cls.cache[key_with_ext] = CUDACodeCache.CacheEntry(
cls.cache[key_with_ext] = CUTLASSCodeCache.CacheEntry(
input_path, output_path, error_json
)
raise exc.CUDACompileError(cmd_parts, error_output)
if not os.path.exists(output_path):
cmd = cuda_compile_command(
cmd = cls._compile_command(
src_files, output_path, dst_file_ext, extra_args
)
with open(input_path, "a") as f:
f.write("\n")
f.write(f"// CUDA {operation_name} cmd\n// {cmd}\n")
f.write(f"// {cls._BACKEND} {operation_name} cmd\n// {cmd}\n")
start_time = time()
log.debug("CUDA %s: %s", operation_name, cmd)
log.debug("%s %s: %s", cls._BACKEND, operation_name, cmd)
cmd_parts = cmd.split(" ")
try:
if use_re_build():
if cls._use_re_build():
from triton.fb.re_build_helper import run_build_command
run_build_command(
@ -4145,7 +3959,7 @@ class CUDACodeCache:
cmd_parts, stderr=subprocess.STDOUT, env=os.environ
)
except subprocess.CalledProcessError as error:
cls._record_cuda_compile_error(
cls._record_compile_error(
error.output.decode("utf-8"),
key_with_ext,
cmd_parts,
@ -4156,7 +3970,7 @@ class CUDACodeCache:
raise exc.CUDACompileError(cmd_parts, error.output) from error
except Exception as error:
if "COMPILE FAILED WITH" in str(error):
cls._record_cuda_compile_error(
cls._record_compile_error(
str(error),
key_with_ext,
cmd_parts,
@ -4167,29 +3981,30 @@ class CUDACodeCache:
raise exc.CUDACompileError(cmd_parts, str(error)) from error
raise error
end_time = time()
log_duration_msg = f"CUDA {operation_name} took {end_time - start_time} seconds. Command: {cmd}"
log_duration_msg = f"{cls._BACKEND} {operation_name} took {end_time - start_time} seconds. Command: {cmd}"
log.info(log_duration_msg)
else:
log.debug(
"CUDA %s skipped: %s since output already exists",
"%s %s skipped: %s since output already exists",
cls._BACKEND,
operation_name,
output_path,
)
# Upload to remote cache if enabled
if (
binary_remote_cache is not None
and config.cuda.upload_to_binary_remote_cache
and config.cutlass.upload_to_binary_remote_cache
):
# will log on errors, but not fail out
binary_remote_cache.put(
output_path, config.cuda.binary_remote_cache_force_write
output_path, config.cutlass.binary_remote_cache_force_write
)
cls.cache[key_with_ext] = CUDACodeCache.CacheEntry(
cls.cache[key_with_ext] = CUTLASSCodeCache.CacheEntry(
input_path, output_path, None
)
cache_entry: CUDACodeCache.CacheEntry = cls.cache[key_with_ext]
cache_entry: CUTLASSCodeCache.CacheEntry = cls.cache[key_with_ext]
if cache_entry.error_json is not None:
# Restore cached Exception and raise it as if we had compiled
cmd_parts, error_output = json.loads(cache_entry.error_json)
@ -4214,7 +4029,7 @@ class CUDACodeCache:
return (DLLWrapper(dst_file_path), hash_key, source_code_path)
@classmethod
def _record_cuda_compile_error(
def _record_compile_error(
cls,
error_str: str,
key_with_ext: str,
@ -4226,7 +4041,7 @@ class CUDACodeCache:
binary_remote_cache: Any = None,
) -> None:
error_json = json.dumps([cmd_parts, error_str])
cls.cache[key_with_ext] = CUDACodeCache.CacheEntry(
cls.cache[key_with_ext] = CUTLASSCodeCache.CacheEntry(
input_path, output_path, error_json
)
error_path = binary_error_path(output_path)
@ -4236,13 +4051,94 @@ class CUDACodeCache:
# Upload to remote cache directly from memory if enabled
if (
binary_remote_cache is not None
and config.cuda.upload_to_binary_remote_cache
and config.cutlass.upload_to_binary_remote_cache
):
binary_remote_cache.put(
error_path, config.cuda.binary_remote_cache_force_write
error_path, config.cutlass.binary_remote_cache_force_write
)
@clear_on_fresh_cache
class CUDACodeCache(CUTLASSCodeCache):
_SOURCE_CODE_SUFFIX = "cu"
_BACKEND = "CUDA"
@classmethod
def _use_re_build(cls) -> bool:
return cuda_compile_utils.use_re_build()
@classmethod
def _compile_command(
cls,
src_files: list[str],
dst_file: str,
dst_file_ext: str,
extra_args: Optional[list[str]] = None,
) -> str:
return cuda_compile_utils.cuda_compile_command(
src_files, dst_file, dst_file_ext, extra_args=extra_args
)
@classmethod
def _source_code_extra(cls) -> str:
extra = repr(
[
# nvcc and cuda hash
cuda_compile_utils._cuda_compiler(),
# cutlass flags and gcc hash
cuda_compile_utils._nvcc_compiler_options(),
# flags
cuda_compile_utils._nvcc_host_compiler_options(),
# cutlass key
cutlass_key(),
# hack to deal with AOTI .o compilation
]
)
return extra
from torch._inductor.codegen.xpu import compile_utils as xpu_compile_utils
@clear_on_fresh_cache
class XPUCodeCache(CUTLASSCodeCache):
_SOURCE_CODE_SUFFIX = "cpp"
_BACKEND = "XPU"
@classmethod
def _use_re_build(cls) -> bool:
return False
@classmethod
def _compile_command(
cls,
src_files: list[str],
dst_file: str,
dst_file_ext: str,
extra_args: Optional[list[str]] = None,
) -> str:
return xpu_compile_utils.xpu_compile_command(
src_files, dst_file, dst_file_ext, extra_args=extra_args
)
@classmethod
def _source_code_extra(cls) -> str:
extra = repr(
[
# sycl compiler path
xpu_compile_utils._sycl_compiler(),
# cutlass flags and sycl compiler options
xpu_compile_utils._sycl_compiler_options(),
# host compiler flags
xpu_compile_utils._sycl_host_compiler_options(),
# cutlass key
cutlass_key(),
# hack to deal with AOTI .o compilation
]
)
return extra
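The refactor above turns CUTLASSCodeCache into a template-method base: CUDA and XPU differ only in the three hooks, while write(), compile(), and the local/remote binary caching stay shared. As a hedged sketch (the class name, the "mycc" compiler driver, and its flags are purely illustrative, not part of this PR), a further backend would only need:
# Hedged sketch only: a hypothetical extra backend illustrating the hook pattern.
@clear_on_fresh_cache
class MyBackendCodeCache(CUTLASSCodeCache):
    _SOURCE_CODE_SUFFIX = "cpp"   # suffix used when writing the generated source
    _BACKEND = "MYBACKEND"        # label used in log messages and source comments

    @classmethod
    def _use_re_build(cls) -> bool:
        # assume no remote-execution build support for this backend
        return False

    @classmethod
    def _compile_command(
        cls,
        src_files: list[str],
        dst_file: str,
        dst_file_ext: str,
        extra_args: Optional[list[str]] = None,
    ) -> str:
        # "mycc" is a stand-in compiler driver, not a real tool
        args = extra_args or []
        flags = ["-shared"] if dst_file_ext == "so" else ["-c"]
        return " ".join(["mycc", *args, *flags, "-o", dst_file, *src_files])

    @classmethod
    def _source_code_extra(cls) -> str:
        # anything that should change the cache key when it changes
        return repr(["mycc", cutlass_key()])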
@clear_on_fresh_cache
class ROCmCodeCache:
@dataclasses.dataclass

View File

@ -10,8 +10,8 @@ from ..scheduler import (
Scheduler,
SchedulerNode,
)
from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
from .cutedsl.cutedsl_scheduling import CuteDSLScheduling
from .cutlass.scheduling import CUTLASSScheduling
from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling
from .triton import TritonScheduling
@ -30,7 +30,7 @@ if TYPE_CHECKING:
_IntLike: TypeAlias = Union[int, Expr]
class CUDACombinedScheduling(BaseScheduling):
class CombinedScheduling(BaseScheduling):
"""
Scheduler for CUDA Kernels, which delegates calls as appropriate
to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices
@ -43,7 +43,7 @@ class CUDACombinedScheduling(BaseScheduling):
def __init__(self, scheduler: Optional[Scheduler]) -> None:
super().__init__(scheduler)
self._triton_scheduling = TritonScheduling(scheduler)
self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
self._cutlass_scheduling = CUTLASSScheduling(scheduler)
self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)
self._cutedsl_scheduling = CuteDSLScheduling(scheduler)
@ -51,8 +51,8 @@ class CUDACombinedScheduling(BaseScheduling):
return self._triton_scheduling.get_backend_features(device)
def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:
if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
return self._cuda_cpp_scheduling
if self._cutlass_scheduling.is_cutlass_template(node):
return self._cutlass_scheduling
if self._rocm_cpp_scheduling.is_rocm_cpp_template(node):
return self._rocm_cpp_scheduling
if self._cutedsl_scheduling.is_cutedsl_template(node):
@ -62,11 +62,11 @@ class CUDACombinedScheduling(BaseScheduling):
def can_fuse_vertical(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2):
if self._cutlass_scheduling.can_fuse_vertical(node1, node2):
return True
elif self._cuda_cpp_scheduling.is_cuda_cpp_template(
elif self._cutlass_scheduling.is_cutlass_template(
node1
) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2):
) or self._cutlass_scheduling.is_cutlass_template(node2):
return False
# CuteDSL doesn't support vertical fusion currently
elif self._cutedsl_scheduling.is_cutedsl_template(
@ -79,8 +79,8 @@ class CUDACombinedScheduling(BaseScheduling):
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
for node in (node1, node2):
if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
return self._cuda_cpp_scheduling.can_fuse_horizontal(
if self._cutlass_scheduling.is_cutlass_template(node):
return self._cutlass_scheduling.can_fuse_horizontal(
node1, node2
) # always False at the moment
if self._cutedsl_scheduling.is_cutedsl_template(node):
@ -100,9 +100,9 @@ class CUDACombinedScheduling(BaseScheduling):
epilogue_nodes: Sequence[BaseSchedulerNode],
prologue_nodes: Sequence[BaseSchedulerNode],
) -> Optional[str]:
if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
if self._cutlass_scheduling.is_cutlass_template(template_node):
assert not prologue_nodes
return self._cuda_cpp_scheduling.codegen_template(
return self._cutlass_scheduling.codegen_template(
template_node, epilogue_nodes, prologue_nodes
)
elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node):

View File

@ -371,6 +371,12 @@ class DeviceOpOverrides:
# optionally return (scratch definition, arg name)
raise NotImplementedError
def get_device_arch(self) -> str:
raise NotImplementedError
def get_toolkit_version(self) -> str:
raise NotImplementedError
device_op_overrides_dict: dict[str, DeviceOpOverrides] = {}
custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {}
@ -495,12 +501,12 @@ def init_backend_registration() -> None:
Register the backend for different devices, including the scheduling
for kernel code generation and the host side wrapper code generation.
"""
from .combined_scheduling import CombinedScheduling
from .cpp import CppScheduling
from .cpp_wrapper_cpu import CppWrapperCpu
from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
from .cpp_wrapper_gpu import CppWrapperGpu
from .cpp_wrapper_mps import CppWrapperMps
from .cuda_combined_scheduling import CUDACombinedScheduling
from .halide import HalideScheduling
from .mps import MetalScheduling
from .python_wrapper_mtia import PythonWrapperMtia
@ -525,9 +531,9 @@ def init_backend_registration() -> None:
)
if get_scheduling_for_device("cuda") is None:
# CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
# CombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
cuda_backends = {
"triton": CUDACombinedScheduling,
"triton": CombinedScheduling,
"halide": HalideScheduling,
}
register_backend_for_device(
@ -2366,7 +2372,7 @@ class KernelTemplate:
"""
Base class for defining kernel templates.
Children classes: TritonTemplate, CUDATemplate
Children classes: TritonTemplate, CUTLASSTemplate
"""
@staticmethod
@ -2618,12 +2624,12 @@ class CSEProxy(DefaultHandler):
"""
from ..bounds import ValueRangeAnalysis
from ..select_algorithm import TritonTemplateKernel
from .cuda.cuda_kernel import CUDATemplateKernel
from .cutlass.kernel import CUTLASSTemplateKernel
if isinstance(V.kernel, TritonTemplateKernel):
return ValueRanges.unknown()
if isinstance(V.kernel, CUDATemplateKernel):
if isinstance(V.kernel, CUTLASSTemplateKernel):
return ValueRanges.unknown()
fx_node = V.interpreter.current_node

View File

@ -0,0 +1,264 @@
# mypy: allow-untyped-defs
import logging
import os
import shutil
from pathlib import Path
from typing import Optional
import torch
from torch._inductor import config
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.cpp_builder import _set_gpu_runtime_env, _transform_cuda_paths
from torch._inductor.utils import is_linux
if config.is_fbcode():
from triton.fb.build import build_paths
log = logging.getLogger(__name__)
autotuning_log = torch._logging.getArtifactLogger(__name__, "autotuning")
def use_re_build() -> bool:
"""
Use for CUTLASS compilation only right now.
"""
if config.is_fbcode() and not cuda_env.nvcc_exist(_cuda_compiler()):
from triton.fb.re_build_helper import should_build_locally
return not should_build_locally()
return False
def _cutlass_path() -> str:
if config.is_fbcode():
from libfb.py import parutil
return parutil.get_dir_path("cutlass-4-headers")
else:
return config.cutlass.cutlass_dir
def _cutlass_paths() -> list[str]:
return [
"include",
"tools/library/include",
"tools/library/src",
"tools/util/include",
]
def _clone_cutlass_paths(build_root: str) -> list[str]:
paths = _cutlass_paths()
cutlass_root = _cutlass_path()
for path in _cutlass_paths():
old_path = os.path.join(cutlass_root, path)
new_path = os.path.join(build_root, path)
shutil.copytree(old_path, new_path, dirs_exist_ok=True)
return paths
def _cutlass_include_paths() -> list[str]:
cutlass_path = _cutlass_path()
return [
# Use realpath to get canonical absolute paths, in order not to mess up cache keys
os.path.realpath(os.path.join(cutlass_path, path))
for path in _cutlass_paths()
]
def _cuda_compiler() -> Optional[str]:
if cuda_env.nvcc_exist(config.cutlass.cuda_cxx):
return config.cutlass.cuda_cxx
if config.is_fbcode():
return os.path.join(build_paths.sdk_home, "bin", "nvcc")
if cuda_env.nvcc_exist(os.getenv("CUDACXX")):
return os.getenv("CUDACXX", "")
if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")):
return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc"))
return "nvcc"
def _cuda_lib_options() -> list[str]:
"""
Util function for CUTLASS backend to find the correct CUDA libraries.
"""
_set_gpu_runtime_env() # cpp_extension consults the env
from torch.utils import cpp_extension
lpaths = cpp_extension.library_paths(device_type="cuda")
if use_re_build():
lpaths += [
build_paths.sdk_lib,
os.path.join(build_paths.sdk_lib, "stubs"),
]
extra_ldflags: list[str] = []
if is_linux():
_transform_cuda_paths(lpaths)
for path in lpaths:
if "torch/lib" in path:
# don't want to depend on pytorch
continue
extra_ldflags.append(f"-L{path}")
# -rpath ensures the DLL can find its dependencies when loaded, even
# if the library path is non-standard.
# But do not add the stubs folder to rpath as the driver is expected to be found at runtime
if os.path.basename(path) != "stubs":
extra_ldflags.extend(["-Xlinker", f"-rpath={path}"])
extra_ldflags.append("-lcuda")
extra_ldflags.append("-lcudart")
else:
raise NotImplementedError(
"Unsupported env, failed to find cuda libs! Currently only Linux is supported."
)
return extra_ldflags
def _nvcc_host_compiler_options() -> list[str]:
return [
"-fPIC",
"-fno-strict-aliasing",
"-fvisibility=hidden",
"-Wconversion",
]
def _nvcc_arch_as_compile_option() -> str:
arch = cuda_env.get_cuda_arch()
if arch == "90":
# Required by cutlass compilation.
return "90a"
if arch == "100":
return "100a"
return arch
def _nvcc_compiler_options() -> list[str]:
arch = _nvcc_arch_as_compile_option()
code = [f"sm_{arch}", f"compute_{arch}"]
if config.cutlass.enable_cuda_lto:
code += [f"lto_{arch}"]
options = [
"-t=0",
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
"-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1",
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
"-w",
f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
config.cutlass.compile_opt_level,
"-std=c++17",
"--expt-relaxed-constexpr",
"-DNDEBUG",
]
if config.is_fbcode():
options.extend(["-ccbin", os.path.dirname(build_paths.gcc)])
if config.cutlass.enable_debug_info:
options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"])
if config.cutlass.enable_ptxas_info:
options.extend(
[
"--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.)
"--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels
"--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels
"--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.)
"--source-in-ptx",
]
) # Annotate the ptx file with source information
if config.cutlass.use_fast_math:
options.extend(
[
"--use_fast_math",
"-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
]
)
return options
def cuda_compile_command(
src_files: list[str],
dst_file: str,
dst_file_ext: str,
extra_args: list[str] | None = None,
) -> str:
if extra_args is None:
extra_args = []
if use_re_build():
build_path = os.path.dirname(dst_file)
include_paths = _clone_cutlass_paths(build_path)
src_files = [os.path.basename(src_file) for src_file in src_files]
dst_file = os.path.basename(dst_file)
else:
include_paths = _cutlass_include_paths()
cuda_lib_options = _cuda_lib_options()
nvcc_host_compiler_options = _nvcc_host_compiler_options()
nvcc_compiler_options = _nvcc_compiler_options()
options = (
nvcc_compiler_options
+ extra_args
+ [
f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}"
for opt in nvcc_host_compiler_options
]
+ ["-I" + path for path in include_paths]
+ cuda_lib_options
)
src_file = " ".join(src_files)
res = ""
if dst_file_ext == "o":
res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}"
elif dst_file_ext == "so":
options.append("-shared")
res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
elif dst_file_ext == "exe":
res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
else:
raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
if log.isEnabledFor(logging.DEBUG):
log.debug("CUDA command: %s", res)
else:
autotuning_log.debug("CUDA command: %s", res)
return res
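For orientation, the command assembled for a shared-library build roughly takes the following shape (sketch only; the exact compiler path, flags, and include/library directories depend on config.cutlass.* and the detected architecture):
# Illustrative only: output shape of cuda_compile_command for a ".so" target on SM90.
cmd = cuda_compile_command(["gemm_kernel.cu"], "gemm_kernel.so", "so")
# cmd is roughly:
#   nvcc -t=0 -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 ... -w
#        -gencode=arch=compute_90a,code=[sm_90a,compute_90a] -O<level> -std=c++17
#        --expt-relaxed-constexpr -DNDEBUG -Xcompiler=-fPIC ... -I<cutlass>/include ...
#        -L<cuda lib path> -Xlinker -rpath=<cuda lib path> -lcuda -lcudart
#        -shared -o gemm_kernel.so gemm_kernel.cu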
class CUDACompileSourceCapturingContext:
# Helper class for Benchmarking and Testing CUTLASS Kernels in isolation.
# Can be used to capture the source code passed to CUDACodeCache.compile
def __init__(self):
self.sources = []
self._compile_patch = None
def __enter__(self, *args, **kwargs):
import unittest.mock as mock
import torch._inductor.codecache
_compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile
def my_compile(
source_code, dst_file_ext, extra_args: Optional[list[str]] = None
):
self.sources.append(source_code)
return _compile_method_orig(source_code, dst_file_ext)
# pyrefly: ignore [bad-assignment]
self._compile_patch = mock.patch(
"torch._inductor.codecache.CUDACodeCache.compile", my_compile
)
self._compile_patch.__enter__(*args, **kwargs) # type: ignore[union-attr]
return self
def __exit__(self, *args, **kwargs):
self._compile_patch.__exit__(*args, **kwargs) # type: ignore[union-attr]
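A hedged usage sketch (the compilation trigger below is hypothetical; any code path that ends up calling CUDACodeCache.compile would do):
# Illustrative only: capture the CUDA C++ sources handed to CUDACodeCache.compile.
with CUDACompileSourceCapturingContext() as ctx:
    run_cutlass_autotuning()  # hypothetical trigger that compiles CUTLASS kernels
if ctx.sources:
    srcpath = Path("captured_kernel.cu")
    srcpath.write_text(ctx.sources[-1])  # last captured kernel source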
def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path):
# returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run
# Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled.
from torch._inductor.codecache import cuda_compile_command
extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]
compile_command = cuda_compile_command(
[str(srcpath)], str(exepath), "exe", extra_args=extra_args
)
return compile_command

View File

@ -9,6 +9,7 @@ from ..common import (
register_device_op_overrides,
TritonScratchWorkspace,
)
from .cuda_env import get_cuda_arch, get_cuda_version
class CUDADeviceOpOverrides(DeviceOpOverrides):
@ -360,5 +361,11 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
else:
return [f"CUdeviceptr {var_name} = 0;"], var_name
def get_device_arch(self) -> str:
return get_cuda_arch()
def get_toolkit_version(self) -> str:
return get_cuda_version()
register_device_op_overrides("cuda", CUDADeviceOpOverrides())

View File

@ -10,9 +10,11 @@ from typing import Any, Optional
import torch._inductor.config as config
from torch._inductor.codecache import cutlass_key
from torch._inductor.codegen.cuda import cutlass_utils, serialization
from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
from torch._inductor.codegen.common import get_device_op_overrides
from torch._inductor.codegen.cutlass import serialization, utils
from torch._inductor.codegen.cutlass.serialization import (
get_cutlass_operation_serializer,
)
from torch._inductor.runtime.cache_dir_utils import cache_dir
from torch._inductor.utils import clear_on_fresh_cache
@ -39,7 +41,7 @@ def get_config_request_key(
return hashlib.sha256(f.read()).hexdigest()
serialization_hash = get_file_hash(serialization)
cutlass_utils_hash = get_file_hash(cutlass_utils)
cutlass_utils_hash = get_file_hash(utils)
hash_target = "-".join(
[
@ -63,7 +65,7 @@ def _generate_config_filename(request_key: str) -> str:
@clear_on_fresh_cache
@functools.cache
def maybe_fetch_ops() -> Optional[list[Any]]:
def maybe_fetch_ops(device_type: str) -> Optional[list[Any]]:
"""
Fetch ops from databases.
"""
@ -71,11 +73,14 @@ def maybe_fetch_ops() -> Optional[list[Any]]:
return None
# setup
arch: str = get_cuda_arch()
# get_cuda_version might return "12.4.0" or "12.4"
# but we want to use "12.4"
version: str = ".".join(get_cuda_version().split(".")[:2])
instantiation_level: str = config.cuda.cutlass_instantiation_level
device_op_overrides = get_device_op_overrides(device_type)
arch: str = device_op_overrides.get_device_arch()
version: str = device_op_overrides.get_toolkit_version()
if device_type == "cuda":
# get_cuda_version might return "12.4.0" or "12.4"
# but we want to use "12.4"
version = ".".join(version.split(".")[:2])
instantiation_level: str = config.cutlass.cutlass_instantiation_level
# filename and filepath
request_key: str = get_config_request_key(arch, version, instantiation_level)
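With the device-type argument, callers resolve the architecture and toolkit version through the registered device op overrides instead of CUDA-only helpers. A hedged sketch mirroring the lookup pattern used by CUTLASSGemmTemplate (device string is illustrative):
# Illustrative only: prefer the pre-generated op cache, fall back to full generation.
maybe_ops = maybe_fetch_ops("xpu")  # or "cuda"
if maybe_ops is None:
    # fall back to generating ops from the CUTLASS library for this device
    full_ops = utils.gen_ops("xpu")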

View File

@ -11,7 +11,7 @@ from typing import Any, Optional, Union
import torch
import torch.utils._pytree as pytree
from torch._inductor.autotune_process import TensorMeta
from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops
from torch._inductor.codegen.cutlass.cache import maybe_fetch_ops
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._inductor.scheduler import BaseSchedulerNode
@ -19,11 +19,11 @@ from torch._inductor.select_algorithm import create_inputs_key
from torch._inductor.utils import clear_on_fresh_cache
from ... import ir
from ...config import cuda as inductor_cuda_config
from ...config import cutlass as inductor_cutlass_config
from ...ir import (
Buffer,
ChoiceCaller,
CUDATemplateBuffer,
CUTLASSTemplateBuffer,
FixedLayout,
IRNode,
Layout,
@ -32,11 +32,12 @@ from ...ir import (
from ...utils import is_dynamic, Placeholder
from ...virtualized import V
from ..common import IndentedBuffer
from . import cutlass_utils
from .cuda_kernel import CUDATemplateKernel
from .cuda_template import CUTLASSTemplate
from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt
from .cutlass_utils import (
from ..cuda import cuda_env
from . import utils as cutlass_utils
from .kernel import CUTLASSTemplateKernel
from .python_evt import CutlassEVTCodegen, scaled_mm_evt
from .template import CUTLASSTemplate
from .utils import (
ACCUMULATOR_DTYPES,
dtype_match,
torch_dtype_to_cutlass_type,
@ -578,7 +579,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
for name, op in ops:
for (
swizzle
) in inductor_cuda_config.cutlass_max_profiling_swizzle_options:
) in inductor_cutlass_config.cutlass_max_profiling_swizzle_options:
description = f"{name} swizzle={swizzle}"
self.maybe_append_choice(
choices,
@ -621,7 +622,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_sparse.h"
//#include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
@ -635,7 +636,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
#include "cutlass/util/tensor_view_io.h"
"""
)
if inductor_cuda_config.generate_test_runner and not is_dynamic(
if inductor_cutlass_config.generate_test_runner and not is_dynamic(
*self.input_nodes, self.output_node
):
res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES)
@ -712,12 +713,14 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
bool: True if the alignment was successfully updated, False otherwise.
"""
alignment = cutlass_utils.get_max_alignment(torch_layout)
cuda_arch = cutlass_utils.get_cuda_arch()
if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment:
return False
else:
op_element.alignment = alignment
return True
if torch.cuda.is_available():
cuda_arch = cuda_env.get_cuda_arch()
cuda_arch = cutlass_utils._normalize_cutlass_arch(cuda_arch)
if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment:
return False
op_element.alignment = alignment
return True
@staticmethod
def should_swap_XW(
@ -953,7 +956,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
)
return None
if inductor_cuda_config.cutlass_tma_only and not self._has_tma_epilogue(op):
if inductor_cutlass_config.cutlass_tma_only and not self._has_tma_epilogue(op):
return None
# Set epilogue.
@ -975,14 +978,16 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
return None
# Apply regex filters at the end when configuration name doesn't change anymore
if inductor_cuda_config.cutlass_op_allowlist_regex:
if inductor_cutlass_config.cutlass_op_allowlist_regex:
if not re.search(
inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name()
inductor_cutlass_config.cutlass_op_allowlist_regex,
op.configuration_name(),
):
return None
if inductor_cuda_config.cutlass_op_denylist_regex is not None:
if inductor_cutlass_config.cutlass_op_denylist_regex is not None:
if re.search(
inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name()
inductor_cutlass_config.cutlass_op_denylist_regex,
op.configuration_name(),
):
return None
@ -1007,10 +1012,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
return self.filtered_ops_cache[self.cache_key]
with dynamo_timed("CUTLASSGemmTemplate.maybe_fetch_ops"):
maybe_ops = maybe_fetch_ops()
maybe_ops = maybe_fetch_ops(self.device_type)
if maybe_ops is None:
log.debug("Cannot fetch ops from cache, generating ops from scratch")
full_ops = cutlass_utils.gen_ops()
full_ops = cutlass_utils.gen_ops(self.device_type)
ops = pytree.tree_flatten(full_ops)[0]
else:
log.debug("Using cached ops from cache")
@ -1035,7 +1040,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
time.time() - start_time,
)
sorted_res = sorted(res.items())
ret_res = sorted_res[: inductor_cuda_config.cutlass_max_profiling_configs]
ret_res = sorted_res[: inductor_cutlass_config.cutlass_max_profiling_configs]
if len(self.filtered_ops_cache) < 50:
self.filtered_ops_cache[self.cache_key] = ret_res
else:
@ -1060,26 +1065,26 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
def render( # type: ignore[override]
self,
kernel: CUDATemplateKernel,
kernel: CUTLASSTemplateKernel,
op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821
template_buffer_node: Optional[CUDATemplateBuffer] = None,
template_buffer_node: Optional[CUTLASSTemplateBuffer] = None,
epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
**kwargs,
) -> str:
"""
The primary entry point for the code rendering process used in this template.
Renders the Cutlass based CUDA C++ code for the GEMM Kernel that this template is designed to implement,
Renders the Cutlass based CUDA/XPU C++ code for the GEMM Kernel that this template is designed to implement,
including potentially fused epilogues.
Args:
kernel (CUDATemplateKernel): The kernel to be rendered.
kernel (CUTLASSTemplateKernel): The kernel to be rendered.
op (cutlass_gemm_op.GemmOperation, optional): A GEMM operation that is required to be compatible with the
input and output definitions as well as a possible epilogue. Defaults to None.
**kwargs: Additional keyword arguments. Currently unused.
Returns:
str: Cutlass based CUDA C++ code fragment as a string, to be used by the current
CUDATemplateKernel or autotuning code.
str: Cutlass based CUDA/XPU C++ code fragment as a string, to be used by the current
CUTLASSTemplateKernel or autotuning code.
Note:
All inputs and their corresponding buffer addresses and names take precedence over previously
@ -1277,7 +1282,9 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
}
options.update(dict(zip(extra_names, extra_inputs)))
res = self._template_from_string(self._get_template()).render(**options)
if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias):
if inductor_cutlass_config.generate_test_runner and not is_dynamic(
X, W, Y, Bias
):
test_runner_code = self._template_from_string(
GEMM_STANDALONE_RUNNER_TEMPLATE
).render(**options)
@ -1295,7 +1302,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
names_str: str = "",
) -> str:
"""
Helper method to render the Cutlass CUDA C++ code required for calling the GEMM operation in the standalone
Helper method to render the Cutlass CUDA/XPU C++ code required for calling the GEMM operation in the standalone
test runner that might also be generated along with the rest of the code, if the corresponding config is
enabled.
@ -1483,7 +1490,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
output_dtype: torch.dtype,
accumulator_dtype: torch.dtype,
) -> tuple[str, str, str, EVTArgRenames]:
from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace
from .lib_extensions.evt_extensions import create_example_tensors, trace
acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
output_dtype = torch_dtype_to_cutlass_type(output_dtype)
@ -1554,7 +1561,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
op: GemmOperation,
evt_name: Optional[str] = None,
) -> tuple[str, str]:
"""Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance.
"""Defines and renders the Cutlass / CUDA/XPU C++ code for a given GEMM operation instance.
This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply
forms a core part of a number of scientific applications, so this efficient and adaptable implementation is
@ -1570,9 +1577,11 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib
from .cutlass_lib_extensions import gemm_operation_extensions as gemm_extensions
from .lib_extensions import gemm_operation_extensions as gemm_extensions
emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name) # type: ignore[call-arg]
emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(
evt_name=evt_name, device_type=self.device_type
) # type: ignore[call-arg]
if not hasattr(op, "epilogue_functor") or not isinstance(
op.epilogue_functor, enum.Enum
@ -1629,11 +1638,11 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
Y: IRNode,
alpha: float,
beta: float,
kernel: CUDATemplateKernel,
kernel: CUTLASSTemplateKernel,
epilogue_args,
) -> str:
"""
Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation.
Render the Cutlass CUDA/XPU C++ code required for passing arguments to the GEMM operation.
Args:
argument_template (str): Template for the GEMM operation arguments.
@ -1646,11 +1655,11 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
Y (IRNode): The output tensor.
alpha (float): Scaling factor for the product of the inputs.
beta (float): Scaling factor for the output tensor.
kernel (CUDATemplateKernel): CUDA Template kernel for the operation.
kernel (CUTLASSTemplateKernel): CUDA/XPU Template kernel for the operation.
epilogue_args (any): Additional arguments for the epilogue state.
Returns:
str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation.
str: A block of CUDA/XPU C++ code as a string, ready to be used as arguments for the GEMM operation.
Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y
tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped
@ -1710,6 +1719,8 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
"""CUTLASS 2x GEMM Template, which is used to generate CUTLASS GEMM kernels"""
def __init__(
self,
input_nodes: list[Buffer],
@ -1918,7 +1929,7 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
Y: IRNode,
alpha: float,
beta: float,
kernel: CUDATemplateKernel,
kernel: CUTLASSTemplateKernel,
epilogue_args,
) -> str:
"""
@ -1937,7 +1948,7 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
Y (IRNode): The output tensor.
alpha (float): Scaling factor for the product of the inputs.
beta (float): Scaling factor for the output tensor.
kernel (CUDATemplateKernel): CUDA Template kernel for the operation.
kernel (CUTLASSTemplateKernel): CUDA/XPU Template kernel for the operation.
epilogue_args (any): Additional arguments for the epilogue state.
Returns:

View File

@ -10,22 +10,23 @@ from sympy import Expr, symbols
import torch._inductor.config as config
from torch import dtype as torch_dtype
from torch._inductor.codegen.common import get_device_op_overrides
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder
from torch.utils._sympy.value_ranges import ValueRanges
from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
from .utils import DTYPE_TO_CUTLASS_TYPE
if TYPE_CHECKING:
from .cuda_template import ArgInfo
from .template import ArgInfo
from ...autotune_process import CUDABenchmarkRequest
from ...autotune_process import CUTLASSBenchmarkRequest
from ...ir import (
Buffer,
ChoiceCaller,
CUDATemplateBuffer,
CUTLASSTemplateBuffer,
IRNode,
Layout,
PrimitiveInfoType,
@ -46,7 +47,7 @@ from ..cpp_utils import CppPrinter, DTYPE_TO_CPP
if TYPE_CHECKING:
from torch._inductor.codegen.cuda.cuda_template import CUDATemplate
from torch._inductor.codegen.cutlass.template import CUTLASSTemplate
log = logging.getLogger(__name__)
@ -72,9 +73,9 @@ class LayoutArg:
return self.node == node and self.attr == attr and self.dim == dim
class CUDAKernel(Kernel):
class CUTLASSKernel(Kernel):
"""
Baseclass for CUDA / Cutlass based Kernels
Baseclass for Cutlass based Kernels
"""
overrides = OpOverrides # type: ignore[assignment]
@ -191,21 +192,20 @@ class CUDAKernel(Kernel):
return _normalize_idx(-1, len(strides))
class CUDATemplateKernel(CUDAKernel):
class CUTLASSTemplateKernel(CUTLASSKernel):
"""
Template kernels defined by CUDA / Cutlass in C++.
Template kernels defined by Cutlass in C++.
"""
_EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
def __init__(
self,
kernel_name: str,
runtime_arg_info: list["ArgInfo"],
runtime_arg_values: list[Any],
device_type: str = "cuda", # type: ignore[assignment]
) -> None:
"""
Initializes a new instance of the CUDATemplateKernel class.
Initializes a new instance of the CUTLASSTemplateKernel class.
Args:
kernel_name (str): The name of the kernel.
@ -214,6 +214,9 @@ class CUDATemplateKernel(CUDAKernel):
self.kernel_name = kernel_name
self.runtime_arg_info = runtime_arg_info
self.runtime_arg_values = runtime_arg_values
self.device_type = device_type
self.device_codegen = get_device_op_overrides(self.device_type)
self._EXTRA_CPP_ARGS = f"size_t* workspace_size, uint8_t* workspace, {self.device_codegen.cpp_stream_type()} stream"
def check_not_null(self, node: IRNode) -> str:
"""
@ -328,14 +331,14 @@ class CUDATemplateKernel(CUDAKernel):
def call_kernel(
self,
name: str,
node: "CUDATemplateBuffer", # type: ignore[name-defined]
node: "CUTLASSTemplateBuffer", # type: ignore[name-defined]
) -> None:
"""
Generates code to call the kernel through V.graph.wrapper_code.
used from within torch._inductor.wrapper.PythonWrapperCodegen
name: Name of kernel function.
node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes
node: The CUTLASSTemplateBuffer node which contains information about the kernel, its fused epilogue nodes
as well as all required inputs and outputs.
"""
wrapper = V.graph.wrapper_code
@ -423,7 +426,7 @@ class CUDATemplateKernel(CUDAKernel):
# Helper method, called into from CUTLASSGemmTemplate
if node is None:
return default_dtype
from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate
from torch._inductor.codegen.cutlass.template import CUTLASSTemplate
return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]
@ -562,16 +565,16 @@ class CUDATemplateKernel(CUDAKernel):
self.store_buffer_names.add(name)
class CUDATemplateCaller(ChoiceCaller):
class CUTLASSTemplateCaller(ChoiceCaller):
"""
CUDATemplateCaller
CUTLASSTemplateCaller
This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller.
This class represents a caller for CUTLASS template kernels. It is a subclass of ChoiceCaller.
Attributes:
name (str): The name of the caller.
category (str): The category of the caller.
bmreq (CUDABenchmarkRequest): The benchmark request for the caller.
template_buffer (CUDATemplateBuffer): The template buffer for the caller.
bmreq (CUTLASSBenchmarkRequest): The benchmark request for the caller.
template_buffer (CUTLASSTemplateBuffer): The template buffer for the caller.
"""
def __init__(
@ -581,12 +584,12 @@ class CUDATemplateCaller(ChoiceCaller):
input_nodes: list[Buffer],
layout: Layout,
make_kernel_render: Callable[
[CUDATemplateBuffer, Optional[list[BaseSchedulerNode]]],
tuple[CUDATemplateKernel, functools.partial[str]],
[CUTLASSTemplateBuffer, Optional[list[BaseSchedulerNode]]],
tuple[CUTLASSTemplateKernel, functools.partial[str]],
],
bmreq: CUDABenchmarkRequest,
bmreq: CUTLASSBenchmarkRequest,
supports_epilogue_fusion: bool,
template: "CUDATemplate", # type: ignore[name-defined]
template: "CUTLASSTemplate", # type: ignore[name-defined]
info_kwargs: Optional[
dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
], # type: ignore[type-arg]
@ -612,10 +615,10 @@ class CUDATemplateCaller(ChoiceCaller):
return self.bmreq.benchmark(*args, out=out)
def __str__(self) -> str:
return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"
return f"CUTLASSTemplateCaller(source_file={self.bmreq.source_file})"
def call_name(self) -> str:
return f"cuda_template_kernels.{self.name}"
return f"cutlass_template_kernels.{self.name}"
def kernel_hash_key(self) -> str:
"""
@ -675,7 +678,7 @@ class CUDATemplateCaller(ChoiceCaller):
def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]:
self.bmreq.update_workspace_size()
return TensorBox.create(
CUDATemplateBuffer(
CUTLASSTemplateBuffer(
layout=self.layout,
inputs=self.input_nodes,
make_kernel_render=self.make_kernel_render,

View File

@ -10,7 +10,7 @@ from torch._inductor.ir import (
)
from torch.utils._ordered_set import OrderedSet
from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass
from ..utils import torch_dtype_to_cutlass_type, try_import_cutlass
EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace
@ -237,7 +237,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
size_hint_fn: Callable[[Union[Expr, int]], int],
arg_renames: EVTArgRenames,
) -> str:
from ..cuda_template import CUTLASSTemplate
from ..template import CUTLASSTemplate
# Today, arguments are either a pointer to the
# node's memory, a stride tuple, the datatype

View File

@ -1,5 +1,5 @@
# mypy: ignore-errors
from ..cutlass_utils import try_import_cutlass
from ..utils import try_import_cutlass
# copied / modified from original at
@ -16,7 +16,7 @@ if try_import_cutlass():
class EmitGemmUniversal3xInstanceWithEVT:
"""Responsible for emitting a CUTLASS 3.x template definition"""
def __init__(self, operation_suffix="", evt_name=None):
def __init__(self, operation_suffix="", evt_name=None, device_type="cuda"):
self.operation_suffix = operation_suffix
self.includes = [
"cutlass/cutlass.h",
@ -32,6 +32,13 @@ if try_import_cutlass():
${element_c},
${element_epilogue}
>"""
if device_type == "xpu":
self.builtin_epilogue_functor_template = """${epilogue_functor}<
${element_accumulator},
${element_epilogue},
${element_c},
${element_epilogue}
>"""
self.evt_name = evt_name
self.gemm_template = """
@ -175,6 +182,8 @@ ${compile_guard_end}
f"cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(\
sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>"
)
if operation.arch == 11:
stage_count_string = "cutlass::gemm::collective::StageCountAuto"
epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto"
@ -350,6 +359,11 @@ cute::Layout<cute::Shape<int,int,int>, {operation_name_str}_StrideNarrow>{{}}));
if self.evt_name:
epilogue_functor = self.evt_name
arch = (
"cutlass::arch::IntelXe"
if operation.arch == 11
else f"cutlass::arch::Sm{operation.arch}"
)
values = {
"operation_name": operation_name_str,
"operation_suffix": self.operation_suffix,
@ -369,7 +383,7 @@ cute::Layout<cute::Shape<int,int,int>, {operation_name_str}_StrideNarrow>{{}}));
"element_accumulator": DataTypeTag[operation.accumulator_type()],
"opcode_class_main": OpcodeClassTag[opcode_class_main],
"opcode_class_epi": OpcodeClassTag[opcode_class_epi],
"arch": f"cutlass::arch::Sm{operation.arch}",
"arch": arch,
"tile_shape_m": str(tile_shape_m),
"tile_shape_n": str(tile_shape_n),
"tile_shape_k": str(tile_shape_k),

View File

@ -183,11 +183,11 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
@staticmethod
def ir_to_evt_python_code(
cuda_template_node_name: str,
cutlass_template_node_name: str,
epilogue_nodes: list[BaseSchedulerNode],
removed_buffers: OrderedSet[str],
) -> tuple[list[str], list[str], dict[str, Any], str]:
codegen = CutlassEVTCodegen(cuda_template_node_name, removed_buffers)
codegen = CutlassEVTCodegen(cutlass_template_node_name, removed_buffers)
handler = _AssignmentFormatter(codegen)
with virtualized.V.set_ops_handler(handler):

View File

@ -4,7 +4,7 @@ import logging
from collections.abc import Sequence
from typing import cast
from torch._inductor.codegen.cuda.cutlass_python_evt import (
from torch._inductor.codegen.cutlass.python_evt import (
CutlassEVTCodegen,
MockCutlassHandler,
)
@ -14,7 +14,7 @@ from torch.utils._ordered_set import OrderedSet
from ...._dynamo.utils import counters
from ... import config
from ...codecache import code_hash, get_path
from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise
from ...ir import Buffer, ComputedBuffer, CUTLASSTemplateBuffer, Pointwise
from ...scheduler import (
BaseSchedulerNode,
BaseScheduling,
@ -36,13 +36,13 @@ class WhyNoFuseNames(WhyNoFuse):
self.name2 = name2
class CUDACPPScheduling(BaseScheduling):
class CUTLASSScheduling(BaseScheduling):
"""
Partial Scheduling implementation for CUDA C++ Kernels.
Partial Scheduling implementation for CUTLASS C++ kernels.
This class is intended to be used in combination with TritonScheduling,
and delegated to by CUDACombinedScheduling.
and delegated to by CombinedScheduling.
It handles fusion decisions and CUDA C++ specific template code generation.
It handles fusion decisions and CUTLASS C++ specific template code generation.
"""
@classmethod
@ -53,25 +53,25 @@ class CUDACPPScheduling(BaseScheduling):
return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)
@staticmethod
def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
def is_cutlass_template(node: BaseSchedulerNode) -> bool:
return isinstance(node, SchedulerNode) and isinstance(
node.node, CUDATemplateBuffer
node.node, CUTLASSTemplateBuffer
)
def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node)
def is_cutlass_fused_template(self, node: BaseSchedulerNode) -> bool:
return isinstance(node, FusedSchedulerNode) and self.is_cutlass_template(node)
def can_fuse_vertical(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
if self.is_cuda_cpp_template(node1) and isinstance(node2, BaseSchedulerNode):
if self.is_cutlass_template(node1) and isinstance(node2, BaseSchedulerNode):
assert node1.node, "node1.node should not be None"
return self._can_fuse_epilogue_impl(
cast(CUDATemplateBuffer, node1.node),
cast(CUTLASSTemplateBuffer, node1.node),
[],
node2, # type: ignore[arg-type]
)
elif self.is_cuda_cpp_fused_template(node1) and isinstance(
elif self.is_cutlass_fused_template(node1) and isinstance(
node2, BaseSchedulerNode
):
assert node1.node, "node1.node should not be None"
@ -130,16 +130,16 @@ class CUDACPPScheduling(BaseScheduling):
prologue_nodes: Sequence[BaseSchedulerNode],
):
"""
Codegen a CUDA template, possibly with fused epilogues
Codegen a CUTLASS template, possibly with fused epilogues
"""
counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
assert self.is_cuda_cpp_template(template_node), (
"Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
counters["inductor"]["cutlass_epilogue_fusion_counter"] += len(epilogue_nodes)
assert self.is_cutlass_template(template_node), (
"Template node passed to CUTLASSScheduling.codegen_template must be a SchedulerNode that wraps a CUTLASSTemplateBuffer"
)
template_node = cast(SchedulerNode, template_node)
_, (_numel, rnumel) = template_node.group
assert rnumel == 1
ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
ctb: CUTLASSTemplateBuffer = cast(CUTLASSTemplateBuffer, template_node.node)
epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes] # type: ignore[misc]
assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), (
"Epilogue nodes must all be instances of ir.ComputedBuffer"
@ -197,7 +197,7 @@ class CUDACPPScheduling(BaseScheduling):
def _can_fuse_epilogue_impl(
self,
cuda_template_buffer: CUDATemplateBuffer,
cutlass_template_buffer: CUTLASSTemplateBuffer,
existing_epilogue_nodes: list[BaseSchedulerNode],
node_to_fuse: BaseSchedulerNode,
) -> bool:
@ -206,18 +206,20 @@ class CUDACPPScheduling(BaseScheduling):
support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.
Args:
cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer
cutlass_template_buffer : A CUTLASSTemplateBuffer object representing the CUTLASS template and its result buffer
existing_epilogue_nodes : List[SchedulerNode]: The list of already fused epilogue nodes.
node_to_fuse: The SchedulerNode node to be checked if it can be fused with the epilogue.
Returns:
- bool: True if the given node can be fused with the epilogue, False otherwise.
"""
why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())
why = WhyNoFuseNames(
cutlass_template_buffer.get_name(), node_to_fuse.get_name()
)
scheduler_nodes_to_fuse = node_to_fuse.get_nodes()
assert isinstance(cuda_template_buffer, CUDATemplateBuffer)
assert isinstance(cutlass_template_buffer, CUTLASSTemplateBuffer)
# Checks on constituent nodes
for s_node in scheduler_nodes_to_fuse:
@ -235,18 +237,18 @@ class CUDACPPScheduling(BaseScheduling):
name = node.get_computed_buffer_name() # type: ignore[attr-defined]
# dtype can differ, and strides can differ as long as they are broadcastable
if node.get_size() != cuda_template_buffer.get_size():
if node.get_size() != cutlass_template_buffer.get_size():
why(
f"{name}'s size: {node.get_size()} differs from {cuda_template_buffer.get_name()}'s \
size: {cuda_template_buffer.get_size()}"
f"{name}'s size: {node.get_size()} differs from {cutlass_template_buffer.get_name()}'s \
size: {cutlass_template_buffer.get_size()}"
)
return False
assert len(
existing_epilogue_nodes
) or cuda_template_buffer.get_name() in OrderedSet(
) or cutlass_template_buffer.get_name() in OrderedSet(
[rd.name for rd in node_to_fuse.read_writes.reads]
), "First epilogue node must read from cuda template buffer"
), "First epilogue node must read from cutlass template buffer"
if node_to_fuse.has_aliasing_or_mutation():
why(f"{node_to_fuse.get_name()} has aliasing or mutation")
@ -257,22 +259,20 @@ size: {cuda_template_buffer.get_size()}"
)
return False
elif (
not config.cuda.cutlass_epilogue_fusion_enabled
not config.cutlass.cutlass_epilogue_fusion_enabled
or not config.epilogue_fusion
):
why("cutlass epilogue fusion is not enabled")
return False
elif not cuda_template_buffer.supports_epilogue_fusion:
elif not cutlass_template_buffer.supports_epilogue_fusion:
why("epilogue fusion is only supported for TMA-enabled gemm ops")
return False
try:
from torch._inductor.codegen.cuda.cutlass_python_evt import (
CutlassEVTCodegen,
)
from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
CutlassEVTCodegen.ir_to_evt_python_code(
cuda_template_buffer.get_name(),
cutlass_template_buffer.get_name(),
existing_epilogue_nodes + list(node_to_fuse.get_nodes()),
OrderedSet(),
)
@ -282,13 +282,13 @@ size: {cuda_template_buffer.get_size()}"
if not_implemented_op.startswith("_op_"):
not_implemented_op = not_implemented_op[4:]
why(
f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \
f"Cannot fuse epilogue node {node_to_fuse} into {cutlass_template_buffer.name}, \
likely due to unsupported operation: {not_implemented_op}" # noqa: G004, B950
)
return False
else: # Likely due to unsupported dtype.
why(
f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \
f"Cannot fuse epilogue node {node_to_fuse} into {cutlass_template_buffer.name}. \
Reason: {not_implemented_op}" # noqa: G004, B950
)
return False

View File

@ -4,7 +4,7 @@ import json
from enum import Enum
from typing import Any, Optional
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
from torch._inductor.codegen.cutlass.utils import try_import_cutlass
class CUTLASSOperationSerializer:

View File

@ -4,7 +4,6 @@ import hashlib
import itertools
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING, Union
from typing_extensions import override
from unittest.mock import patch
import sympy
@ -14,13 +13,13 @@ from torch._inductor import config
from torch._inductor.utils import clear_on_fresh_cache, Placeholder
from torch._logging import getArtifactLogger
from ...autotune_process import CUDABenchmarkRequest, TensorMeta
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
from ...autotune_process import CUTLASSBenchmarkRequest, TensorMeta
from ...ir import Buffer, CUTLASSTemplateBuffer, IRNode, Layout
from ...utils import IndentedBuffer, unique
from ...virtualized import V
from ..common import KernelTemplate
from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
from .kernel import CUTLASSTemplateCaller, CUTLASSTemplateKernel
from .utils import DTYPE_TO_CUTLASS_TYPE
if TYPE_CHECKING:
@ -40,7 +39,12 @@ class ArgInfo:
@clear_on_fresh_cache
class CUDATemplate(KernelTemplate):
class CUTLASSTemplate(KernelTemplate):
"""
CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
"""
index_counter = itertools.count()
# dict of cache key to (code, size_args)
code_cache: dict[str, tuple[str, tuple[int, ...], tuple[int, ...]]] = {}
@ -54,11 +58,11 @@ class CUDATemplate(KernelTemplate):
input_reorder: Optional[list[int]] = None,
) -> None:
"""
Baseclass for CUDA C++ Templates, derived from KernelTemplate.
Baseclass for CUTLASS C++ Templates, derived from KernelTemplate.
Not to be instantiated directly.
Args:
name (str): The name of the CUDATemplate object.
name (str): The name of the CUTLASSTemplate object.
input_nodes (List[IRNode]): A list of input IRNodes.
layout (Layout): The layout of the output buffer / tensor.
input_reorder (Optional[List[int]]): An optional list that specifies
@ -69,6 +73,7 @@ class CUDATemplate(KernelTemplate):
self.output_node: Buffer = Buffer(name="buf_out", layout=layout)
self.input_reorder = input_reorder
self.layout = layout
self.device_type = layout.device.type if input_nodes else "cuda"
@classmethod
@functools.lru_cache(None)
@ -110,7 +115,7 @@ class CUDATemplate(KernelTemplate):
args are different.
"""
key: Optional[str] = None
if config.cuda.enable_caching_codegen:
if config.cutlass.enable_caching_codegen:
key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr)
if key is not None and key in self.code_cache:
@ -123,10 +128,11 @@ class CUDATemplate(KernelTemplate):
return code, extra_args
kernel_name = str(Placeholder.KERNEL_NAME)
kernel = CUDATemplateKernel(
kernel = CUTLASSTemplateKernel(
kernel_name=kernel_name,
runtime_arg_info=self.get_runtime_arg_info(),
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
device_type=self.device_type,
)
with patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)):
code = self.render(kernel=kernel, **kwargs)
@ -174,10 +180,10 @@ class CUDATemplate(KernelTemplate):
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
**kwargs,
) -> CUDATemplateCaller:
) -> CUTLASSTemplateCaller:
"""
Generates the CUDA template caller object for the given GEMM template and operation.
This CUDATemplateCaller may be used to call and benchmark the generated CUDA kernel
This CUTLASSTemplateCaller may be used to call and benchmark the generated CUDA/XPU kernel
in a standalone manner to enable Autotuning.
Args:
@ -185,7 +191,7 @@ class CUDATemplate(KernelTemplate):
kwargs: Additional keyword arguments.
Returns:
A CUDATemplateCaller object representing the generated CUDA template caller.
A CUTLASSTemplateCaller object representing the generated CUDA/XPU template caller.
"""
code, extra_args = self.generate_code_and_args(
name=name,
@ -200,12 +206,13 @@ class CUDATemplate(KernelTemplate):
code = code.replace(self.name, kernel_name)
# create the BenchmarkRequest
bmreq = CUDABenchmarkRequest(
bmreq = CUTLASSBenchmarkRequest(
kernel_name=kernel_name,
input_tensor_meta=input_tensor_meta,
output_tensor_meta=output_tensor_meta,
extra_args=extra_args,
source_code=code,
device_type=self.device_type,
)
# kwargs has "op" argument in case of CUTLASSGemmTemplate
@ -217,16 +224,17 @@ class CUDATemplate(KernelTemplate):
supports_epilogue_fusion = self.supports_epilogue_fusion(op)
def make_kernel_render(
template_node: CUDATemplateBuffer,
template_node: CUTLASSTemplateBuffer,
epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
) -> tuple[CUDATemplateKernel, functools.partial[str]]:
) -> tuple[CUTLASSTemplateKernel, functools.partial[str]]:
assert supports_epilogue_fusion or not epilogue_nodes, (
"epilogue fusion is not supported for this kernel"
)
kernel = CUDATemplateKernel(
kernel = CUTLASSTemplateKernel(
kernel_name=str(Placeholder.KERNEL_NAME),
runtime_arg_info=self.get_runtime_arg_info(),
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
device_type=self.device_type,
)
render = functools.partial(
self.render,
@ -237,7 +245,7 @@ class CUDATemplate(KernelTemplate):
)
return kernel, render
return CUDATemplateCaller(
return CUTLASSTemplateCaller(
kernel_name,
"cutlass_gemm",
self.input_nodes,
@ -261,6 +269,18 @@ class CUDATemplate(KernelTemplate):
#include <vector>
"""
)
res.splice(
"""
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
"""
)
return res
def globals(self) -> IndentedBuffer:
@ -281,42 +301,6 @@ class CUDATemplate(KernelTemplate):
#endif
"""
)
return res
def render(self, **kwargs) -> str:
raise NotImplementedError
def get_runtime_arg_info(self) -> list[ArgInfo]:
return []
def get_runtime_arg_values(self, **kwargs) -> list[Any]:
return []
class CUTLASSTemplate(CUDATemplate):
"""
CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
"""
def header(self) -> IndentedBuffer:
res = super().header()
res.splice(
"""
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
"""
)
return res
def globals(self) -> IndentedBuffer:
res = super().globals()
res.splice(
"""
using namespace cute;
@ -382,11 +366,12 @@ class CUTLASSTemplate(CUDATemplate):
f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})"
)
@override
def render(self, **kwargs) -> str:
raise NotImplementedError
def get_runtime_arg_info(self) -> list[ArgInfo]:
return [ArgInfo("swizzle", "const uint8_t")]
@override
def get_runtime_arg_values(self, **kwargs) -> list[Any]:
"""
Helper method to retrieve runtime args from generate kwargs

View File

@ -7,7 +7,6 @@ import shutil
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
from typing_extensions import TypeIs
@ -22,8 +21,7 @@ from ... import config
from ...ir import Layout
from ...runtime.runtime_utils import cache_dir
from ...virtualized import V
from ..cpp_utils import DTYPE_TO_CPP
from .cuda_env import get_cuda_arch, get_cuda_version
from ..common import get_device_op_overrides
log = logging.getLogger(__name__)
@ -98,14 +96,14 @@ def try_import_cutlass() -> bool:
# contains both cutlass and cutlass_library
# we need cutlass for eVT
cutlass_python_path = path_join(config.cuda.cutlass_dir, "python")
cutlass_python_path = path_join(config.cutlass.cutlass_dir, "python")
torch_root = os.path.abspath(os.path.dirname(torch.__file__))
mock_src_path = os.path.join(
torch_root,
"_inductor",
"codegen",
"cuda",
"cutlass_lib_extensions",
"cutlass",
"lib_extensions",
"cutlass_mock_imports",
)
@ -177,7 +175,10 @@ def try_import_cutlass() -> bool:
@functools.lru_cache(8)
def _normalize_cuda_arch(arch: str) -> str:
def _normalize_cutlass_arch(arch: str) -> str:
if torch.xpu.is_available():
return arch
if int(arch) >= 100:
log.warning(
"Detected CUDA architecture >= 100: %s. We will generate operations with "
@ -229,7 +230,7 @@ class CUTLASSArgs:
raise RuntimeError(
f"{self.architectures=} or {self.cuda_version=} is None!"
)
self.architectures = _normalize_cuda_arch(self.architectures)
self.architectures = _normalize_cutlass_arch(self.architectures)
@clear_on_fresh_cache
@ -251,8 +252,8 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]:
version,
)
return {}
arch = _normalize_cuda_arch(arch)
instantiation_level: str = config.cuda.cutlass_instantiation_level
arch = _normalize_cutlass_arch(arch)
instantiation_level: str = config.cutlass.cutlass_instantiation_level
args = CUTLASSArgs(
architectures=arch,
cuda_version=version,
@ -266,6 +267,13 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]:
if hasattr(cutlass_generator, "GenerateSM100"):
cutlass_generator.GenerateSM100(manifest, args.cuda_version)
cutlass_generator.GenerateSM90(manifest, args.cuda_version)
if arch == "11":
if hasattr(cutlass_generator, "GeneratePVC"):
cutlass_generator.GeneratePVC(manifest, args.cuda_version)
else:
raise NotImplementedError(
"Arch PVC is not supported by current cutlass lib."
)
else:
try:
func = getattr(cutlass_generator, "GenerateSM" + arch)
@ -283,22 +291,34 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]:
return manifest.operations
def gen_ops() -> dict[Any, Any]:
def gen_ops(device_type: str) -> dict[Any, Any]:
"""
Generates all supported CUTLASS operations.
"""
with dynamo_timed("cutlass_utils.gen_ops"):
arch = get_cuda_arch()
version = get_cuda_version()
device_op_overrides = get_device_op_overrides(device_type)
arch = device_op_overrides.get_device_arch()
version = device_op_overrides.get_toolkit_version()
return _gen_ops_cached(arch, version)
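With this change, gen_ops() no longer hard-codes CUDA queries; it asks the per-device op overrides for the architecture and toolkit version. A minimal usage sketch, assuming the relevant overrides (such as the XPU ones added later in this diff) are registered:
ops = gen_ops("xpu")   # arch/version come from XPUDeviceOpOverrides (e.g. "11" / torch.version.xpu)
ops = gen_ops("cuda")  # arch/version come from the CUDA overrides, assuming they expose the same two hooks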
DTYPE_TO_CUTLASS_TYPE = {
**DTYPE_TO_CPP,
torch.float16: "__half",
torch.bfloat16: "__nv_bfloat16",
torch.float8_e4m3fn: "__nv_fp8_e4m3",
}
from ..cpp_utils import DTYPE_TO_CPP
if torch.xpu.is_available():
DTYPE_TO_CUTLASS_TYPE = {
**DTYPE_TO_CPP,
torch.float16: "uint16_t",
torch.bfloat16: "uint16_t",
torch.float8_e4m3fn: "uint8_t",
}
else:
DTYPE_TO_CUTLASS_TYPE = {
**DTYPE_TO_CPP,
torch.float16: "__half",
torch.bfloat16: "__nv_bfloat16",
torch.float8_e4m3fn: "__nv_fp8_e4m3",
}
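The dtype mapping is now chosen at import time: on XPU builds, fp16/bf16/fp8 fall back to plain fixed-width integer storage types because the CUDA-specific __half/__nv_bfloat16/__nv_fp8_e4m3 types are unavailable under SYCL. An illustrative lookup, assuming the module lands at torch._inductor.codegen.cutlass.utils as the renamed imports elsewhere in this diff suggest:
import torch
from torch._inductor.codegen.cutlass.utils import DTYPE_TO_CUTLASS_TYPE  # path assumed from this diff
DTYPE_TO_CUTLASS_TYPE[torch.float16]  # "uint16_t" on XPU builds, "__half" on CUDA builds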
@functools.lru_cache(32)
@ -447,47 +467,3 @@ def get_max_alignment(inductor_layout: Layout) -> int:
):
return alignment
return 1
class CUDACompileSourceCapturingContext:
# Helper class for Benchmarking and Testing CUTLASS Kernels in isolation.
# Can be used to capture the sourcecode passed to CUDACodeCache.compile
def __init__(self):
self.sources = []
self._compile_patch = None
def __enter__(self, *args, **kwargs):
import unittest.mock as mock
import torch._inductor.codecache
_compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile
def my_compile(
source_code, dst_file_ext, extra_args: Optional[list[str]] = None
):
self.sources.append(source_code)
return _compile_method_orig(source_code, dst_file_ext)
# pyrefly: ignore [bad-assignment]
self._compile_patch = mock.patch(
"torch._inductor.codecache.CUDACodeCache.compile", my_compile
)
self._compile_patch.__enter__(*args, **kwargs) # type: ignore[union-attr]
return self
def __exit__(self, *args, **kwargs):
self._compile_patch.__exit__(*args, **kwargs) # type: ignore[union-attr]
def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path):
# returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run
# Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled.
from torch._inductor.codecache import cuda_compile_command
extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]
compile_command = cuda_compile_command(
[str(srcpath)], str(exepath), "exe", extra_args=extra_args
)
return compile_command

View File

@ -19,7 +19,7 @@ class ROCmCPPScheduling(BaseScheduling):
"""
Partial Scheduling implementation for ROCm C++ Kernels.
This class is intended to be used in combination with TritonScheduling,
and delegated to by CUDACombinedScheduling.
and delegated to by CombinedScheduling.
It handles fusion decisions and ROCm C++ specific template code generation.
"""

View File

@ -1764,7 +1764,7 @@ class SIMDScheduling(BaseScheduling):
partial_code.finalize_hook("<DEF_KERNEL>")
partial_code.finalize_hook("<ARGDEFS>", strict=False)
# TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion.
# TODO: Maybe unify CUTLASSTemplateKernel to also use PartialRender for flexible epilogue fusion.
for input_name in kernel.named_input_nodes.keys():
subgraph_name = f"<LOAD_INPUT_{input_name}>"

View File

@ -5886,14 +5886,12 @@ def debug_triton_code(node: BaseSchedulerNode) -> list[str]:
if multi_template and multi_template.make_kernel_render is None:
lines.append(f"{node.get_name()} Unfinalized multi template buffer")
else:
from torch._inductor.codegen.cuda_combined_scheduling import (
CUDACombinedScheduling,
)
from torch._inductor.codegen.combined_scheduling import CombinedScheduling
device = node.get_device()
assert device is not None
backend = node.scheduler.get_backend(device)
assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), (
assert isinstance(backend, (SIMDScheduling, CombinedScheduling)), (
f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
)

View File

@ -0,0 +1,114 @@
# mypy: allow-untyped-defs
import logging
from typing import Optional
from torch._inductor import config
from torch._inductor.utils import is_linux
from ..cuda.compile_utils import _cutlass_include_paths
from .xpu_env import get_xpu_arch
log = logging.getLogger(__name__)
def _sycl_compiler() -> Optional[str]:
return "icpx"
def _sycl_lib_options() -> list[str]:
"""
Utility function for the CUTLASS backend to find the correct XPU libraries.
"""
# _set_gpu_runtime_env() # cpp_extension consults the env
from torch.utils import cpp_extension
lpaths = cpp_extension.library_paths(device_type="xpu")
extra_ldflags: list[str] = []
if is_linux():
for path in lpaths:
if "torch/lib" in path:
# don't want to depend on pytorch
continue
# -rpath ensures the shared library can find its dependencies when loaded,
# even if the library path is non-standard.
extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"])
else:
raise NotImplementedError(
"Unsupported env, failed to find xpu libs! Currently only Linux is supported."
)
return extra_ldflags
def _sycl_host_compiler_options() -> list[str]:
return [
"-fPIC",
]
def _sycl_arch_as_compile_option() -> str:
arch_option_map = {"pvc": "intel_gpu_pvc", "bmg": "intel_gpu_bmg"}
arch = get_xpu_arch()
return arch_option_map.get(arch, "intel_gpu_pvc")
def _sycl_compiler_options() -> list[str]:
options = [
"-DCUTLASS_ENABLE_SYCL",
"-DCUTLASS_SYCL_PROFILING_ENABLED",
"-DSYCLCOMPAT_PROFILING_ENABLED",
"-DSYCL_INTEL_TARGET",
"-gline-tables-only",
"-DCUTLASS_VERSIONS_GENERATED",
"-O3",
"-DNDEBUG",
"-std=c++17",
"-fPIE",
"-fPIC",
"-fsycl",
f"-fsycl-targets={_sycl_arch_as_compile_option()}",
"-Xspirv-translator",
"-spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate",
"-fno-sycl-instrument-device-code",
"-DMKL_ILP64",
"-MD",
"-MT",
]
if config.cutlass.enable_debug_info:
options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"])
return options
def xpu_compile_command(
src_files: list[str],
dst_file: str,
dst_file_ext: str,
extra_args: Optional[list[str]] = None,
) -> str:
if extra_args is None:
extra_args = []
include_paths = _cutlass_include_paths()
sycl_lib_options = _sycl_lib_options()
sycl_host_compiler_options = _sycl_host_compiler_options()
sycl_compiler_options = _sycl_compiler_options()
options = (
["-I" + path for path in include_paths]
+ ["-isystem /include"]
+ sycl_compiler_options
+ extra_args
+ sycl_host_compiler_options
+ sycl_lib_options
)
src_file = " ".join(src_files)
res = ""
if dst_file_ext == "o":
res = f"{_sycl_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}"
elif dst_file_ext == "so":
options.append("-shared")
res = f"{_sycl_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
elif dst_file_ext == "exe":
res = f"{_sycl_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
else:
raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
log.debug("XPU command: %s", res)
return res
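For reference, the command this helper assembles is an icpx invocation combining the CUTLASS include paths, the SYCL target flags above, and the XPU library/rpath options. A rough usage sketch (file names are placeholders and the flag list is abbreviated):
cmd = xpu_compile_command(["gemm_kernel.cpp"], "gemm_kernel.so", "so")
# icpx -I<cutlass include paths> ... -fsycl -fsycl-targets=intel_gpu_pvc ... -shared -o gemm_kernel.so gemm_kernel.cpp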

View File

@ -7,6 +7,7 @@ from ..common import (
register_device_op_overrides,
TritonScratchWorkspace,
)
from .xpu_env import get_xpu_arch, get_xpu_version
class XPUDeviceOpOverrides(DeviceOpOverrides):
@ -63,5 +64,11 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
) -> Optional[tuple[list[str], str]]:
return [f"void *global_scratch_{idx} = 0;"], f"global_scratch_{idx}"
def get_device_arch(self) -> str:
return get_xpu_arch()
def get_toolkit_version(self) -> str:
return get_xpu_version()
register_device_op_overrides("xpu", XPUDeviceOpOverrides())
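These two hooks are what gen_ops() consults through get_device_op_overrides(). A small sketch of the lookup they enable (values are illustrative):
from torch._inductor.codegen.common import get_device_op_overrides
overrides = get_device_op_overrides("xpu")
overrides.get_device_arch()       # "11" for PVC, per get_xpu_arch()
overrides.get_toolkit_version()   # oneAPI version string such as "20250101"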

View File

@ -0,0 +1,34 @@
import functools
import logging
from typing import Optional
import torch
from torch._inductor.utils import clear_on_fresh_cache
log = logging.getLogger(__name__)
@clear_on_fresh_cache
@functools.lru_cache(1)
def get_xpu_arch() -> Optional[str]:
arch_name2code = {"pvc": "11"}
try:
assert len(torch.xpu.get_arch_list()) == 1
arch_name = torch.xpu.get_arch_list()[0]
return arch_name2code[arch_name]
except Exception as e:
log.error("Error getting xpu arch: %s", e)
return None
@clear_on_fresh_cache
@functools.lru_cache(1)
def get_xpu_version() -> Optional[str]:
# Version string, e.g. "20250101".
try:
xpu_version = torch.version.xpu
return xpu_version
except Exception as e:
log.error("Error getting xpu version: %s", e)
return None
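Both helpers are cached and degrade gracefully: if the build exposes anything other than a single known architecture, they log an error and return None instead of raising. A behavior sketch, assuming a PVC-only build:
# torch.xpu.get_arch_list() == ["pvc"]  ->  get_xpu_arch() returns "11"
# Any other arch list (empty, multiple entries, or unknown names) logs an error and
# returns None, which callers such as gen_ops() must be prepared to handle.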

View File

@ -1713,28 +1713,12 @@ class aot_inductor_mode:
compile_standalone: bool = False
class cuda:
"""Settings for cuda backend, today this consists of cutlass"""
# CUDA arch to use for CUDA template kernel compilation.
# e.g. "70", "75", "80", "90", etc.
# When arch is None, Inductor uses torch.cuda.get_device_capability(0).
arch: Optional[str] = None
# CUDA version to use for CUDA template kernel compilation.
# e.g. "11.4", "12.1", etc.
# When version is None, Inductor uses torch.version.cuda.
version: Optional[str] = None
class cutlass:
"""Settings for cutlass backend, today this consists of cutlass"""
# Optimization level for the host compiler.
compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1"
# Whether to enable device LTO (link-time-optimization).
enable_cuda_lto = False
# Whether to keep intermediate files during compilation.
enable_ptxas_info = False
# Whether to enable debug info, e.g. line number, cutlass debug info.
enable_debug_info = False
@ -1746,7 +1730,10 @@ class cuda:
cutlass_dir = os.path.realpath(
os.environ.get(
"TORCHINDUCTOR_CUTLASS_DIR",
os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/"),
os.path.join(
os.path.dirname(torch.__file__),
"../third_party/cutlass/",
),
)
)
@ -1766,14 +1753,6 @@ class cuda:
# Whether to only use TMA-compatible kernels in CUTLASS
cutlass_tma_only = False
# Path to CUDA NVCC.
# NVCC search order:
# 1) cuda_cxx set in this config
# 2) CUDACXX environment variable
# 3) CUDA_HOME environment variable
# 4) default system search PATH.
cuda_cxx: Optional[str] = None
# Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops.
cutlass_backend_min_gemm_size: int = 1
@ -1843,6 +1822,53 @@ class cuda:
enable_caching_codegen: bool = True
class cuda(cutlass):
"""Settings for cutlass backend, today this consists of cutlass"""
# CUDA arch to use for CUDA template kernel compilation.
# e.g. "70", "75", "80", "90", etc.
# When arch is None, Inductor uses torch.cuda.get_device_capability(0).
arch: Optional[str] = None
# CUDA version to use for CUDA template kernel compilation.
# e.g. "11.4", "12.1", etc.
# When version is None, Inductor uses torch.version.cuda.
version: Optional[str] = None
# Path to CUDA NVCC.
# NVCC search order:
# 1) cuda_cxx set in this config
# 2) CUDACXX environment variable
# 3) CUDA_HOME environment variable
# 4) default system search PATH.
cuda_cxx: Optional[str] = None
# Whether to enable device LTO (link-time-optimization).
enable_cuda_lto = False
# Whether to keep intermediate files during compilation.
enable_ptxas_info = False
class xpu(cutlass):
# Xe arch to use for SYCL template kernel compilation.
# e.g. 12, 20, which correspond to Xe12 (PVC) and Xe20 (BMG).
arch: Optional[str] = None
# oneAPI version to use for SYCL template kernel compilation.
# e.g. "20250201".
version: Optional[str] = None
cutlass_dir = os.path.realpath(
os.environ.get(
"TORCHINDUCTOR_CUTLASS_DIR",
os.path.join(
os.path.dirname(torch.__file__),
"../third_party/sycl-tla/",
),
)
)
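The config hierarchy after this hunk: shared CUTLASS knobs live on class cutlass, while cuda and xpu subclass it and add or override device-specific settings such as the compiler path and the default cutlass_dir. A sketch of how the split is consumed, using names that appear in this diff:
from torch._inductor import config
config.cutlass.cutlass_instantiation_level  # shared knob, read by _gen_ops_cached above
config.cuda.cuda_cxx                        # CUDA-only: explicit nvcc path override
config.xpu.cutlass_dir                      # XPU-only override: defaults to third_party/sycl-tla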
class rocm:
# Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"].
# If empty, the `native` arch is used

View File

@ -2196,7 +2196,9 @@ class CppBuilder:
)
if device_type == "cuda" and torch.version.hip is None:
from torch._inductor.codecache import _nvcc_arch_as_compile_option
from torch._inductor.codegen.cuda.compile_utils import (
_nvcc_arch_as_compile_option,
)
current_arch = _nvcc_arch_as_compile_option()
contents += textwrap.dedent(

View File

@ -489,7 +489,7 @@ MODULE_DEFAULTS: dict[str, ConfigType] = {
"aot_inductor.presets": DEFAULT, # Typing
"cuda.arch": DEFAULT, # Out of Scope
"cuda.version": DEFAULT, # Out of Scope
"cuda.cutlass_dir": DEFAULT, # Out of Scope
"cutlass.cutlass_dir": DEFAULT, # Out of Scope
"cuda.cuda_cxx": DEFAULT, # Out of Scope
"rocm.arch": DEFAULT, # Out of Scope
"rocm.ck_supported_arch": DEFAULT, # Out of Scope

View File

@ -124,13 +124,13 @@ if TYPE_CHECKING:
from torch.fx.experimental.symbolic_shapes import SympyBoolean
from torch.fx.node import Argument
from .codegen.cuda.cuda_template import CUDATemplate
from .codegen.cutlass.template import CUTLASSTemplate
from .codegen.wrapper import PythonWrapperCodegen
from .graph import GraphLowering
from .utils import IndentedBuffer
else:
CUDATemplate: TypeAlias = object
CUTLASSTemplate: TypeAlias = object
try:
@ -5017,7 +5017,7 @@ class ChoiceCaller:
During autotuning, self.benchmark() is first called to get benchmark result,
and if this choice is selected, self.output_node() is called to get the output_node.
Children classes: TritonTemplateCaller, CUDATemplateCaller.
Children classes: TritonTemplateCaller, CUTLASSTemplateCaller.
"""
def __init__(
@ -5180,14 +5180,14 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
self.make_kernel_render = self._make_kernel_renders[None]
class CUDATemplateBuffer(TemplateBuffer):
class CUTLASSTemplateBuffer(TemplateBuffer):
def __init__(
self,
layout: Layout,
inputs: Sequence[IRNode],
make_kernel_render: Callable[_P, _T],
workspace_size: int,
template: CUDATemplate,
template: CUTLASSTemplate,
supports_epilogue_fusion: bool,
) -> None:
super().__init__(layout, inputs, make_kernel_render)
@ -5210,7 +5210,7 @@ class CppTemplateBuffer(TemplateBuffer):
layout: Layout,
inputs: Sequence[IRNode],
make_kernel_render: Callable[_P, _T],
template: CUDATemplate,
template: CUTLASSTemplate,
choice: Any,
) -> None:
super().__init__(layout, inputs, make_kernel_render)

View File

@ -261,7 +261,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
and use_cutlass_template(layout, m, n, k)
and _use_cutlass_for_op(name)
):
from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
from ..codegen.cutlass.gemm_template import CUTLASS3xGemmTemplate
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices, layout, kernel_inputs.nodes()

View File

@ -20,7 +20,7 @@ from torch.nn.functional import ScalingType # type: ignore[attr-defined]
from torch.torch_version import TorchVersion
from .. import config as inductor_config
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.cutlass.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from ..codegen.subgraph import SubgraphChoiceCaller, SubgraphTemplate

View File

@ -5510,10 +5510,10 @@ class Scheduler:
node = typing.cast(ForeachKernelSchedulerNode, node)
# pyrefly: ignore [unbound-name]
backend_ = self.get_backend(device)
from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
from .codegen.combined_scheduling import CombinedScheduling
from .codegen.simd import SIMDScheduling
if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
if isinstance(backend_, (SIMDScheduling, CombinedScheduling)):
backend = backend_
else:
raise AssertionError(f"{type(self)=}")

View File

@ -2665,7 +2665,7 @@ class AlgorithmSelectorCache(PersistentCache):
return_multi_template=False,
best_config_future=None,
):
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
from .codegen.cutlass.kernel import CUTLASSTemplateCaller
# Run preprocessing functions on choices
for preprocessing_fn in self.preprocessing_fns:
@ -2701,8 +2701,8 @@ class AlgorithmSelectorCache(PersistentCache):
log.debug("Max autotune selects from %s choices.", str(len(choices)))
if len(choices) == 1:
if not isinstance(choices[0], CUDATemplateCaller):
# CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
if not isinstance(choices[0], CUTLASSTemplateCaller):
# CUTLASSTemplateCaller still needs to go through the autotuning process to retrieve its workspace size.
return choices[0].output_node()
if config.deterministic:
@ -3086,12 +3086,12 @@ class AlgorithmSelectorCache(PersistentCache):
"select_algorithm_num_precompilation_exceptions"
] += 1
exceptions.append((futures[future], e))
from torch._inductor.codegen.cuda.cuda_kernel import (
CUDATemplateCaller,
from torch._inductor.codegen.cutlass.kernel import (
CUTLASSTemplateCaller,
)
if isinstance(e, CUDACompileError) and isinstance(
futures[future], CUDATemplateCaller
futures[future], CUTLASSTemplateCaller
):
log.debug(
"Exception %s for benchmark choice %s",
@ -3272,19 +3272,20 @@ class AlgorithmSelectorCache(PersistentCache):
for choice in choices:
try:
timing = cls.benchmark_choice(choice, autotune_args)
except CUDACompileError:
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
except CUDACompileError as e:
from torch._inductor.codegen.cutlass.kernel import CUTLASSTemplateCaller
if not isinstance(choice, CUDATemplateCaller):
if not isinstance(choice, CUTLASSTemplateCaller):
log.exception(
"CUDA compilation error during autotuning: \n%s. \nIgnoring this choice."
"CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.",
e,
)
timing = float("inf")
except NotImplementedError:
log.warning("Not yet implemented", exc_info=True)
timing = float("inf")
except RuntimeError as e:
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
from torch._inductor.codegen.cutlass.kernel import CUTLASSTemplateCaller
msg = str(e)
if "invalid argument" in msg:
@ -3295,7 +3296,7 @@ class AlgorithmSelectorCache(PersistentCache):
msg += "\n\nAn unrecoverable unspecified launch failure was caught during autotuning."
msg += "\nPlease try re-running with TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1.\n\n"
if isinstance(choice, CUDATemplateCaller):
if isinstance(choice, CUTLASSTemplateCaller):
log.debug(
"Runtime error during autotuning: \n%s. \nIgnoring this choice.",
msg,
@ -3429,18 +3430,18 @@ class AlgorithmSelectorCache(PersistentCache):
return prescreen_winners
# prescreen cutlass
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
from .codegen.cutlass.kernel import CUTLASSTemplateCaller
candidates = []
if (
config.cuda.cutlass_prescreening
and len(config.cuda.cutlass_max_profiling_swizzle_options) > 1
config.cutlass.cutlass_prescreening
and len(config.cutlass.cutlass_max_profiling_swizzle_options) > 1
):
candidates.extend(
[
c
for c in choices
if isinstance(c, CUDATemplateCaller)
if isinstance(c, CUTLASSTemplateCaller)
# hardcoded to only look at swizzle=2
if c.info_dict().get("swizzle") == "2"
]
@ -3463,7 +3464,7 @@ class AlgorithmSelectorCache(PersistentCache):
"""
Prune the choices after prescreening.
"""
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
from .codegen.cutlass.kernel import CUTLASSTemplateCaller
prescreen_key = f"{name}:{inputs_key}"
@ -3477,7 +3478,7 @@ class AlgorithmSelectorCache(PersistentCache):
pruned_choices = [
choice
for choice in choices
if not isinstance(choice, CUDATemplateCaller)
if not isinstance(choice, CUTLASSTemplateCaller)
or choice.kernel_hash_key() in winner_kernel_hashes
]
return pruned_choices
@ -3523,7 +3524,7 @@ class AlgorithmSelectorCache(PersistentCache):
candidates_to_prune.add(candidate.kernel_hash_key())
else:
winner_hashes.add(candidate.hash_key())
if isinstance(candidate, CUDATemplateCaller):
if isinstance(candidate, CUTLASSTemplateCaller):
candidate.bmreq.ensure_dll_loaded()
pruned_choices = [

View File

@ -1872,9 +1872,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
from .virtualized import V
gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size:
if gemm_size <= 0 or gemm_size < config.cutlass.cutlass_backend_min_gemm_size:
return False
from .codegen.cuda.cutlass_utils import try_import_cutlass
from .codegen.cutlass.utils import try_import_cutlass
# Do not use cutlass template on ROCm
if torch.version.hip:
@ -1893,9 +1893,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
if not try_import_cutlass():
log.warning(
"Failed to import CUTLASS lib. Please check whether "
"_inductor.config.cuda.cutlass_dir %s is set correctly. "
"_inductor.config.cutlass.cutlass_dir %s is set correctly. "
"Skipping CUTLASS backend for now.",
config.cuda.cutlass_dir,
config.cutlass.cutlass_dir,
)
return False
return res
@ -1903,7 +1903,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
def _use_cutlass_for_op(op_name: str) -> bool:
"""Check if CUTLASS should be used for the given operation."""
enabled_ops = config.cuda.cutlass_enabled_ops.upper()
enabled_ops = config.cutlass.cutlass_enabled_ops.upper()
if enabled_ops == "ALL":
return True
return op_name.upper() in [x.strip() for x in enabled_ops.split(",")]
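The op gating now reads from the shared cutlass namespace. An illustrative evaluation of the function above (the setting value is made up for the example):
# With config.cutlass.cutlass_enabled_ops = "mm,addmm":
_use_cutlass_for_op("addmm")  # True  ("ADDMM" is in the parsed list)
_use_cutlass_for_op("bmm")    # False
# When the setting is "ALL", every op name is accepted.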