mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-04 08:00:58 +08:00

Compare commits

36 Commits

mlazos/use ... ciflow/ind
| SHA1 |
|---|
| 0eff316cd3 |
| 315beb2cdd |
| 6b7e6d7fd6 |
| 4054dde426 |
| 8878554e80 |
| 346fd135e2 |
| 9a819cc30d |
| 421d25d9a4 |
| 1f8bbd493a |
| f52e6427af |
| 3d86ea1534 |
| 0f6dbb69de |
| f0d3f7db52 |
| ec887d1962 |
| 2cae11a0d1 |
| c5b67af4e7 |
| 53e25e520e |
| 4381da8371 |
| 938b5b7424 |
| 1fc3e1abe8 |
| f7d1d70526 |
| 3ae7cacd26 |
| 2060620b08 |
| 501a28a3e7 |
| ac2acb2097 |
| 8bc69def41 |
| 874fc3199c |
| a0819a3a88 |
| e36aeb0a2e |
| c924349562 |
| c4a551d72f |
| 08f5fe8139 |
| 8604af4bfa |
| f4ef77b220 |
| 44789dc58d |
| 2ddfc76a10 |
@@ -125,7 +125,7 @@ class CutlassExperimentConfig(ExperimentConfig):
     def to_options(self) -> dict[str, Any]:
         return {
             **super().to_options(),
-            "cuda.cutlass_instantiation_level": self.cutlass_instantiation_level,
+            "cutlass.cutlass_instantiation_level": self.cutlass_instantiation_level,
         }

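Throughout this compare, CUTLASS-specific Inductor options move from the `cuda.*` config namespace into a new `cutlass.*` namespace. A minimal sketch of what a caller sees before and after, using only option names visible in the hunks (the value here is illustrative):

```python
from torch._inductor import config

# Old spelling: config.patch({"cuda.cutlass_instantiation_level": "0"})
# New spelling after this change:
with config.patch({"cutlass.cutlass_instantiation_level": "0"}):
    pass  # compile with the CUTLASS backend inside this scope
```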
@@ -24,7 +24,6 @@ from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
 from torch._inductor import config, metrics
 from torch._inductor.codecache import (
     BypassFxGraphCache,
-    cuda_compile_command,
     CUDACodeCache,
     FxGraphCachePickler,
     FxGraphHashDetails,
@@ -32,6 +31,7 @@ from torch._inductor.codecache import (
     TensorMetadata,
     TensorMetadataAndValues,
 )
+from torch._inductor.codegen.cuda.compile_utils import cuda_compile_command
 from torch._inductor.cpp_builder import normalize_path_separator
 from torch._inductor.custom_graph_pass import (
     CustomGraphModulePass,
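The hunks above also relocate `cuda_compile_command` out of `torch._inductor.codecache` into `torch._inductor.codegen.cuda.compile_utils`. A hypothetical compatibility import for code that must work across both layouts (only the two module paths come from this diff):

```python
try:
    # New location, per the import hunk above.
    from torch._inductor.codegen.cuda.compile_utils import cuda_compile_command
except ImportError:
    # Older layout kept the helper in codecache.
    from torch._inductor.codecache import cuda_compile_command
```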
@@ -14,7 +14,9 @@ from pathlib import Path
 from typing import Optional

 from torch._dynamo.exc import BackendCompilerFailed
-from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
+from torch._inductor.codegen.cutlass.serialization import (
+    get_cutlass_operation_serializer,
+)
 from torch._inductor.utils import clear_caches
 from torch.export import Dim
 from torch.testing._internal.logging_utils import log_settings
@@ -32,11 +34,8 @@ import torch.version
 from torch._dynamo import config as dynamo_config
 from torch._dynamo.utils import counters
 from torch._inductor import config
-from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
-from torch._inductor.codegen.cuda.cutlass_utils import (
-    _gen_ops_cached,
-    get_max_alignment,
-)
+from torch._inductor.codegen.cutlass.kernel import CUTLASSTemplateCaller
+from torch._inductor.codegen.cutlass.utils import _gen_ops_cached, get_max_alignment
 from torch._inductor.exc import InductorError
 from torch._inductor.ir import FixedLayout
 from torch._inductor.select_algorithm import NoValidChoicesError
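The CUTLASS codegen package itself moves from `torch._inductor.codegen.cuda.*` to `torch._inductor.codegen.cutlass.*`, and `CUDATemplateCaller` is renamed `CUTLASSTemplateCaller`. A hypothetical version-tolerant import shim, grounded only in the old and new paths shown above:

```python
try:
    from torch._inductor.codegen.cutlass.kernel import (
        CUTLASSTemplateCaller as TemplateCaller,
    )
    from torch._inductor.codegen.cutlass.utils import _gen_ops_cached, get_max_alignment
except ImportError:
    # Pre-rename layout.
    from torch._inductor.codegen.cuda.cuda_kernel import (
        CUDATemplateCaller as TemplateCaller,
    )
    from torch._inductor.codegen.cuda.cutlass_utils import (
        _gen_ops_cached,
        get_max_alignment,
    )
```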
@@ -133,10 +132,10 @@ use_evt_config = config.patch(
     {
         "max_autotune": True,
         "max_autotune_gemm_backends": "CUTLASS",
-        "cuda.cutlass_max_profiling_configs": 1,
+        "cutlass.cutlass_max_profiling_configs": 1,
         "benchmark_epilogue_fusion": False,  # EVT doesn't support benchmark fusion yet
-        "cuda.cutlass_tma_only": True,
-        "cuda.cutlass_epilogue_fusion_enabled": True,
+        "cutlass.cutlass_tma_only": True,
+        "cutlass.cutlass_epilogue_fusion_enabled": True,
     }
 )

@@ -144,9 +143,9 @@ fp8_config = config.patch(
     {
         "max_autotune": True,
         "max_autotune_gemm_backends": "CUTLASS",
-        "cuda.cutlass_max_profiling_configs": 1,
+        "cutlass.cutlass_max_profiling_configs": 1,
         "benchmark_epilogue_fusion": False,  # EVT doesn't support benchmark fusion yet
-        "cuda.cutlass_tma_only": True,
+        "cutlass.cutlass_tma_only": True,
     }
 )

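`use_evt_config` and `fp8_config` are module-level patchers: `config.patch({...})` returns an object that each test applies as a decorator or context manager. A short sketch of that pattern with the renamed keys (the decorated function is illustrative):

```python
from torch._inductor import config

evt_config = config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cutlass.cutlass_max_profiling_configs": 1,
    }
)

@evt_config  # `with evt_config:` works the same way
def run_one_case():
    ...
```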
@@ -198,7 +197,7 @@ class TestCutlassBackend(TestCase):
         ref_result = model(a, b, extra_args)

         self.assertEqual(
-            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
+            torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
             num_fusions,
         )
         torch.testing.assert_close(result, ref_result)
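The fusion counter is renamed on the same pattern: `cuda_epilogue_fusion_counter` becomes `cutlass_epilogue_fusion_counter` in the `inductor` counter table. A sketch of the reset/read idiom the tests use, with only names taken from the hunks:

```python
from torch._dynamo.utils import counters

counters["inductor"]["cutlass_epilogue_fusion_counter"] = 0  # reset before compiling
# ... run torch.compile(...) under a CUTLASS config here ...
num_fusions = counters["inductor"]["cutlass_epilogue_fusion_counter"]
```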
@@ -206,7 +205,7 @@ class TestCutlassBackend(TestCase):
     def test_check_paths(self):
         cutlass_mock_imports_path = os.path.join(
             os.path.dirname(torch.__file__),
-            "_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports",
+            "_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports",
         )
         cutlass_mock_cuda_path = os.path.join(cutlass_mock_imports_path, "cuda")
         cutlass_mock_pydot_path = os.path.join(cutlass_mock_imports_path, "pydot")
@@ -234,8 +233,8 @@ class TestCutlassBackend(TestCase):
                 "max_autotune": True,
                 "max_autotune_gemm_backends": "CUTLASS",
                 "compile_threads": 4,
-                "cuda.cutlass_backend_min_gemm_size": 100000,
-                "cuda.cutlass_max_profiling_configs": 2,
+                "cutlass.cutlass_backend_min_gemm_size": 100000,
+                "cutlass.cutlass_max_profiling_configs": 2,
             }
         ):
             with mock.patch(
@@ -251,7 +250,7 @@ class TestCutlassBackend(TestCase):

     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_import_cutlass(self):
-        from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
+        from torch._inductor.codegen.cutlass.utils import try_import_cutlass

         self.assertTrue(try_import_cutlass())

@@ -259,7 +258,7 @@ class TestCutlassBackend(TestCase):
         import cutlass_library  # noqa: F401

     def test_cutlass_key(self):
-        from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
+        from torch._inductor.codegen.cutlass.utils import try_import_cutlass

         self.assertTrue(try_import_cutlass())
         from torch._inductor.codecache import cutlass_key
@@ -287,7 +286,7 @@ class TestCutlassBackend(TestCase):
                 "autotune_in_subproc": True,
                 "max_autotune_gemm_backends": "CUTLASS",
                 "compile_threads": 4,
-                "cuda.cutlass_max_profiling_configs": 4,
+                "cutlass.cutlass_max_profiling_configs": 4,
             }
         ):
             Y_compiled = torch.compile(torch.mm)(a, b)
@@ -324,7 +323,7 @@ class TestCutlassBackend(TestCase):
                 "autotune_in_subproc": True,
                 "max_autotune_gemm_backends": "CUTLASS",
                 "compile_threads": 4,
-                "cuda.cutlass_max_profiling_configs": 4,
+                "cutlass.cutlass_max_profiling_configs": 4,
             }
         ):
             for x_shape in x_shapes:
@@ -354,7 +353,7 @@ class TestCutlassBackend(TestCase):
                 "autotune_in_subproc": True,
                 "max_autotune_gemm_backends": "CUTLASS",
                 "compile_threads": 4,
-                "cuda.cutlass_max_profiling_configs": 4,
+                "cutlass.cutlass_max_profiling_configs": 4,
             }
         ):
             Y_compiled = torch.compile(torch.bmm)(a, b)
@@ -386,7 +385,7 @@ class TestCutlassBackend(TestCase):
                 "max_autotune": True,
                 "autotune_in_subproc": True,
                 "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 1,
+                "cutlass.cutlass_max_profiling_configs": 1,
             }
         ):
             from torch._inductor.utils import run_and_get_code
@@ -428,8 +427,8 @@ class TestCutlassBackend(TestCase):
                 "max_autotune": True,
                 "autotune_in_subproc": True,
                 "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 1,
-                "cuda.cutlass_max_profiling_swizzle_options": [
+                "cutlass.cutlass_max_profiling_configs": 1,
+                "cutlass.cutlass_max_profiling_swizzle_options": [
                     1,
                     2,
                     4,
@@ -505,7 +504,7 @@ class TestCutlassBackend(TestCase):
                 {
                     "max_autotune": True,
                     "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                    "cuda.cutlass_max_profiling_configs": 2,
+                    "cutlass.cutlass_max_profiling_configs": 2,
                 }
             ),
             dynamo_config.patch({"error_on_recompile": dynamic}),
@@ -595,9 +594,9 @@ class TestCutlassBackend(TestCase):
                 {
                     "max_autotune": True,
                     "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                    "cuda.cutlass_max_profiling_configs": 2,
+                    "cutlass.cutlass_max_profiling_configs": 2,
                     "benchmark_epilogue_fusion": False,  # EVT doesn't support benchmark fusion yet
-                    "cuda.cutlass_tma_only": True,
+                    "cutlass.cutlass_tma_only": True,
                 }
             ),
             dynamo_config.patch({"error_on_recompile": dynamic}),
@@ -677,7 +676,7 @@ class TestCutlassBackend(TestCase):
                     {
                         "max_autotune": True,
                         "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                        "cuda.cutlass_max_profiling_configs": 2,
+                        "cutlass.cutlass_max_profiling_configs": 2,
                     }
                 ),
                 dynamo_config.patch({"error_on_recompile": dynamic}),
@@ -746,7 +745,7 @@ class TestCutlassBackend(TestCase):
             {
                 "max_autotune": True,
                 "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 2,
+                "cutlass.cutlass_max_profiling_configs": 2,
             }
         ):
             expected = [model(*input) for input in inputs]
@@ -775,8 +774,8 @@ class TestCutlassBackend(TestCase):
                 "max_autotune": True,
                 "autotune_in_subproc": True,
                 "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 2,
-                "cuda.cutlass_op_allowlist_regex": "stream_k",  # only stream-k GEMM Kernels
+                "cutlass.cutlass_max_profiling_configs": 2,
+                "cutlass.cutlass_op_allowlist_regex": "stream_k",  # only stream-k GEMM Kernels
             }
         ):
             for M, K, N in (
@@ -819,7 +818,7 @@ class TestCutlassBackend(TestCase):
             {
                 "max_autotune": True,
                 "max_autotune_gemm_backends": "CUTLASS",
-                "cuda.cutlass_op_allowlist_regex": "stream_k",  # only stream-k GEMM Kernels
+                "cutlass.cutlass_op_allowlist_regex": "stream_k",  # only stream-k GEMM Kernels
             }
         ):
             with self.assertRaisesRegex(InductorError, r".*NoValidChoicesError.*"):
@@ -849,8 +848,8 @@ class TestCutlassBackend(TestCase):
                 {
                     "max_autotune": True,
                     "max_autotune_gemm_backends": "CUTLASS",
-                    "cuda.cutlass_max_profiling_configs": 1,
-                    "cuda.cutlass_op_allowlist_regex": "stream_k",  # only stream-k GEMM Kernels
+                    "cutlass.cutlass_max_profiling_configs": 1,
+                    "cutlass.cutlass_op_allowlist_regex": "stream_k",  # only stream-k GEMM Kernels
                 }
             ):
                 _ = compiled_model(a, b)
@@ -884,15 +883,15 @@ class TestCutlassBackend(TestCase):
                 "max_autotune": True,
                 "autotune_in_subproc": True,
                 "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 4,
-                "cuda.version": "12.2",  # required to enable the Kernels we need
+                "cutlass.cutlass_max_profiling_configs": 4,
+                "cuda.cuda_version": "12.2",  # required to enable the Kernels we need
             }
         ):
-            counters["inductor"]["cuda_epilogue_fusion_counter"] = 0
+            counters["inductor"]["cutlass_epilogue_fusion_counter"] = 0
             assert mm is not None
             Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
             Y = mm(a, b)
-            actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"]
+            actual_count = counters["inductor"]["cutlass_epilogue_fusion_counter"]
             assert actual_count == expected_fuse_count, (
                 f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
             )
@@ -983,7 +982,7 @@ class TestCutlassBackend(TestCase):
                 "max_autotune": True,
                 "autotune_in_subproc": True,
                 "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 2,
+                "cutlass.cutlass_max_profiling_configs": 2,
             }
         ):
             Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
@@ -1002,7 +1001,7 @@ class TestCutlassBackend(TestCase):
                 "max_autotune": True,
                 "autotune_in_subproc": False,
                 "max_autotune_gemm_backends": "CUTLASS",
-                "cuda.cutlass_max_profiling_configs": 2,
+                "cutlass.cutlass_max_profiling_configs": 2,
             }
         ):
             model = MyModel()
@@ -1040,7 +1039,7 @@ class TestCutlassBackend(TestCase):
                 "max_autotune": True,
                 "autotune_in_subproc": False,
                 "max_autotune_gemm_backends": "CUTLASS",
-                "cuda.cutlass_max_profiling_configs": 2,
+                "cutlass.cutlass_max_profiling_configs": 2,
             }
         ):
             model = MyModel()
@@ -1073,8 +1072,8 @@ class TestCutlassBackend(TestCase):
                 "max_autotune": True,
                 "autotune_in_subproc": False,
                 "max_autotune_gemm_backends": "CUTLASS",
-                "cuda.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem",
-                "cuda.cutlass_max_profiling_configs": 1,
+                "cutlass.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem",
+                "cutlass.cutlass_max_profiling_configs": 1,
             }
         ):
             model = MyModel()
@@ -1117,7 +1116,7 @@ class TestCutlassBackend(TestCase):
                 "max_autotune": True,
                 "autotune_in_subproc": True,
                 "max_autotune_gemm_backends": "CUTLASS",
-                "cuda.cutlass_max_profiling_configs": 2,
+                "cutlass.cutlass_max_profiling_configs": 2,
                 "autotune_local_cache": True,
             }
         ):
@@ -1157,9 +1156,9 @@ class TestCutlassBackend(TestCase):
                 {
                     "max_autotune": True,
                     "max_autotune_gemm_backends": "CUTLASS",
-                    "cuda.cutlass_max_profiling_configs": 2,
-                    "cuda.cutlass_op_allowlist_regex": "",
-                    "cuda.cutlass_op_denylist_regex": "pingpong",
+                    "cutlass.cutlass_max_profiling_configs": 2,
+                    "cutlass.cutlass_op_allowlist_regex": "",
+                    "cutlass.cutlass_op_denylist_regex": "pingpong",
                 }
             ):
                 with mock.patch(
@@ -1175,7 +1174,7 @@ class TestCutlassBackend(TestCase):
                     assert op_name == "addmm"
                     cuda_template_count = 0
                     for choice in choices:
-                        if isinstance(choice, CUDATemplateCaller):
+                        if isinstance(choice, CUTLASSTemplateCaller):
                             choice_info = choice.info_dict()
                             op_conf_name = choice_info.get("op_conf_name", "")
                             assert isinstance(op_conf_name, str)
@@ -1183,7 +1182,7 @@ class TestCutlassBackend(TestCase):
                                 "All pingpong Kernels should have been filtered"
                             )
                             cuda_template_count += 1
-                    assert cuda_template_count > 0, "No CUDATemplateCaller choices"
+                    assert cuda_template_count > 0, "No CUTLASSTemplateCaller choices"

     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
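The allowlist/denylist tests above share one shape: intercept autotuning, then walk the candidate list and count `CUTLASSTemplateCaller` entries whose `op_conf_name` respects the filter. A condensed sketch of that loop (the mock plumbing around it is omitted; names follow the hunks):

```python
cutlass_template_count = 0
for choice in choices:
    if isinstance(choice, CUTLASSTemplateCaller):
        op_conf_name = choice.info_dict().get("op_conf_name", "")
        # With cutlass.cutlass_op_denylist_regex = "pingpong":
        assert "pingpong" not in op_conf_name, (
            "All pingpong Kernels should have been filtered"
        )
        cutlass_template_count += 1
assert cutlass_template_count > 0, "No CUTLASSTemplateCaller choices"
```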
@@ -1202,9 +1201,9 @@ class TestCutlassBackend(TestCase):
                 {
                     "max_autotune": True,
                     "max_autotune_gemm_backends": "CUTLASS",
-                    "cuda.cutlass_max_profiling_configs": 2,
-                    "cuda.cutlass_op_allowlist_regex": "pingpong",
-                    "cuda.cutlass_op_denylist_regex": None,
+                    "cutlass.cutlass_max_profiling_configs": 2,
+                    "cutlass.cutlass_op_allowlist_regex": "pingpong",
+                    "cutlass.cutlass_op_denylist_regex": None,
                 }
             ):
                 with mock.patch(
@@ -1220,7 +1219,7 @@ class TestCutlassBackend(TestCase):
                     assert op_name == "addmm"
                     cuda_template_count = 0
                     for choice in choices:
-                        if isinstance(choice, CUDATemplateCaller):
+                        if isinstance(choice, CUTLASSTemplateCaller):
                            choice_info = choice.info_dict()
                            op_conf_name = choice_info.get("op_conf_name", "")
                            assert isinstance(op_conf_name, str)
@@ -1228,7 +1227,7 @@ class TestCutlassBackend(TestCase):
                                 "Only pingpong Kernels should have been allowed"
                             )
                             cuda_template_count += 1
-                    assert cuda_template_count > 0, "No CUDATemplateCaller choices"
+                    assert cuda_template_count > 0, "No CUTLASSTemplateCaller choices"

     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@@ -1273,7 +1272,7 @@ class TestCutlassBackend(TestCase):
                     {
                         "max_autotune": True,
                         "max_autotune_gemm_backends": "CUTLASS",
-                        "cuda.cutlass_max_profiling_configs": 2,
+                        "cutlass.cutlass_max_profiling_configs": 2,
                     }
                 ):
                     with mock.patch(
@@ -1296,7 +1295,7 @@ class TestCutlassBackend(TestCase):
                         _, choices, _, _ = args
                         cuda_template_count = 0
                         for choice in choices:
-                            if isinstance(choice, CUDATemplateCaller):
+                            if isinstance(choice, CUTLASSTemplateCaller):
                                 choice_info = choice.info_dict()
                                 op_conf_name = choice_info.get("op_conf_name", "")
                                 assert isinstance(op_conf_name, str)
@@ -1309,7 +1308,9 @@ class TestCutlassBackend(TestCase):
                                         "fastaccum Kernels should have been filtered"
                                     )
                                 cuda_template_count += 1
-                        assert cuda_template_count > 0, "No CUDATemplateCaller choices"
+                        assert cuda_template_count > 0, (
+                            "No CUTLASSTemplateCaller choices"
+                        )

         run_test(True)
         run_test(False)
@@ -1350,7 +1351,7 @@ class TestCutlassBackend(TestCase):
                 {
                     "max_autotune": True,
                     "max_autotune_gemm_backends": "CUTLASS",
-                    "cuda.cutlass_max_profiling_configs": 2,
+                    "cutlass.cutlass_max_profiling_configs": 2,
                 }
             ),
             mock.patch(
@@ -1375,7 +1376,7 @@ class TestCutlassBackend(TestCase):
                 assert op_name == "mm"
                 cuda_template_count = 0
                 for choice in choices:
-                    if isinstance(choice, CUDATemplateCaller):
+                    if isinstance(choice, CUTLASSTemplateCaller):
                         choice_info = choice.info_dict()
                         op_conf_name = choice_info.get("op_conf_name", "")
                         assert isinstance(op_conf_name, str)
@@ -1384,7 +1385,7 @@ class TestCutlassBackend(TestCase):
                 self.assertGreater(
                     cuda_template_count,
                     0,
-                    "No CUDATemplateCaller choices found for matmul with shape "
+                    "No CUTLASSTemplateCaller choices found for matmul with shape "
                     f"M={M}, N={N}, K={K}",
                 )

@@ -1461,13 +1462,13 @@ class TestCutlassBackend(TestCase):
             {
                 "max_autotune": True,
                 "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 2,
-                "cuda.generate_test_runner": True,  # put standalone runner in the generated code
+                "cutlass.cutlass_max_profiling_configs": 2,
+                "cutlass.generate_test_runner": True,  # put standalone runner in the generated code
             }
         ):
             from tempfile import NamedTemporaryFile

-            from torch._inductor.codegen.cuda.cutlass_utils import (
+            from torch._inductor.codegen.cuda.compile_utils import (
                 cuda_standalone_runner_compile_command,
                 CUDACompileSourceCapturingContext,
             )
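The standalone-runner test flips `cutlass.generate_test_runner` on and then captures the generated CUDA source so it can be compiled outside of Inductor. A hedged sketch of the flow, assuming only the two helper names imported in the hunk above (their exact attributes are not shown in this diff):

```python
import torch

from torch._inductor.codegen.cuda.compile_utils import (
    cuda_standalone_runner_compile_command,
    CUDACompileSourceCapturingContext,
)

a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 32, device="cuda", dtype=torch.float16)

with CUDACompileSourceCapturingContext() as ctx:
    _ = torch.compile(torch.mm)(a, b)  # illustrative compile under the patched config
# `ctx` is assumed to hold the captured .cu source; cuda_standalone_runner_compile_command
# is then used to build the nvcc invocation for a standalone runner binary.
```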
@@ -1544,7 +1545,7 @@ class TestCutlassBackend(TestCase):
             {
                 "max_autotune": True,
                 "max_autotune_gemm_backends": "ATEN,TRITON,CUTLASS",
-                "cuda.cutlass_max_profiling_configs": 2,
+                "cutlass.cutlass_max_profiling_configs": 2,
                 # needed for log searching
                 "fx_graph_cache": False,
                 "fx_graph_remote_cache": False,
@@ -1553,7 +1554,7 @@ class TestCutlassBackend(TestCase):
             with (
                 log_settings("+inductor"),
                 self.assertLogs(
-                    logger="torch._inductor.codegen.cuda", level=logging.DEBUG
+                    logger="torch._inductor.codegen.cutlass", level=logging.DEBUG
                 ) as test_log,
             ):
                 Y_compiled = torch.compile(mm, dynamic=False)(a, b)
@@ -1591,7 +1592,7 @@ class TestCutlassBackend(TestCase):
         expected = model(A, B)

         # Track render calls
-        from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate
+        from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate

         original_render = CUTLASSGemmTemplate.render
         render_call_count = 0
@@ -1608,8 +1609,8 @@ class TestCutlassBackend(TestCase):
                     "max_autotune_gemm_backends": "CUTLASS",
                     "fx_graph_cache": False,
                     "fx_graph_remote_cache": False,
-                    "cuda.enable_caching_codegen": True,
-                    "cuda.cutlass_max_profiling_configs": 2,
+                    "cutlass.enable_caching_codegen": True,
+                    "cutlass.cutlass_max_profiling_configs": 2,
                 }
             ):
                 compiled_model = torch.compile(model, fullgraph=True)
@@ -1645,7 +1646,7 @@ class TestCutlassBackend(TestCase):
         d = torch.randn(64, 128).cuda().half().t()

         # Track render calls
-        from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate
+        from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate

         original_render = CUTLASSGemmTemplate.render
         render_call_count = 0
@@ -1660,10 +1661,10 @@ class TestCutlassBackend(TestCase):
                 {
                     "max_autotune": True,
                     "max_autotune_gemm_backends": "CUTLASS",
-                    "cuda.cutlass_max_profiling_configs": 2,
+                    "cutlass.cutlass_max_profiling_configs": 2,
                     "fx_graph_cache": False,
                     "fx_graph_remote_cache": False,
-                    "cuda.enable_caching_codegen": True,
+                    "cutlass.enable_caching_codegen": True,
                 }
             ):
                 # Get expected results
@@ -1706,7 +1707,7 @@ class TestCutlassBackend(TestCase):
         b = torch.randn(32, 64).cuda().half().t()

         # Track render calls
-        from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate
+        from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate

         original_render = CUTLASSGemmTemplate.render
         render_call_count = 0
@@ -1721,10 +1722,10 @@ class TestCutlassBackend(TestCase):
                 {
                     "max_autotune": True,
                     "max_autotune_gemm_backends": "CUTLASS",
-                    "cuda.cutlass_max_profiling_configs": 2,
+                    "cutlass.cutlass_max_profiling_configs": 2,
                     "fx_graph_cache": False,
                     "fx_graph_remote_cache": False,
-                    "cuda.enable_caching_codegen": True,
+                    "cutlass.enable_caching_codegen": True,
                 }
             ):
                 # Get expected results
@@ -1752,7 +1753,7 @@ class TestCutlassBackend(TestCase):
             {
                 "max_autotune": True,
                 "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 2,
+                "cutlass.cutlass_max_profiling_configs": 2,
             }
         ):
             compiled = torch.compile(torch.mm)
@@ -1771,7 +1772,7 @@ class TestCutlassBackend(TestCase):
             {
                 "max_autotune": True,
                 "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 2,
+                "cutlass.cutlass_max_profiling_configs": 2,
             }
         ):
             compiled = torch.compile(torch.mm)
@@ -1795,7 +1796,7 @@ class TestCutlassBackend(TestCase):
             {
                 "max_autotune": True,
                 "max_autotune_gemm_backends": "CUTLASS",
-                "cuda.cutlass_max_profiling_configs": 1,
+                "cutlass.cutlass_max_profiling_configs": 1,
             }
         ):
             _ = torch.compile(model)(B)
@@ -1817,13 +1818,14 @@ class TestCutlassBackend(TestCase):
             {
                 "max_autotune": True,
                 "max_autotune_gemm_backends": "CUTLASS",
-                "cuda.cutlass_max_profiling_configs": 1,
+                "cutlass.cutlass_max_profiling_configs": 1,
             }
         ):
             _ = torch.compile(model)(B)

         self.assertEqual(
-            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+            torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
+            1,
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -1845,7 +1847,7 @@ class TestCutlassBackend(TestCase):
             {
                 "max_autotune": True,
                 "max_autotune_gemm_backends": "CUTLASS",
-                "cuda.cutlass_max_profiling_configs": 1,
+                "cutlass.cutlass_max_profiling_configs": 1,
             }
         ):
             _ = torch.compile(model)(B)
@@ -1871,7 +1873,7 @@ class TestCutlassBackend(TestCase):
             {
                 "max_autotune": True,
                 "max_autotune_gemm_backends": "CUTLASS",
-                "cuda.cutlass_max_profiling_configs": 1,
+                "cutlass.cutlass_max_profiling_configs": 1,
             }
         ):
             if use_aoti:
@@ -1917,7 +1919,8 @@ class TestCutlassBackend(TestCase):
         ref_result = model(a, b, extra_args)

         self.assertEqual(
-            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+            torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
+            1,
         )
         torch.testing.assert_close(result, ref_result)

@@ -1968,18 +1971,19 @@ class TestCutlassBackend(TestCase):

         # baseline is cutlass kernel + triton
         # matches expected casting behavior
-        with config.patch({"cuda.cutlass_epilogue_fusion_enabled": False}):
+        with config.patch({"cutlass.cutlass_epilogue_fusion_enabled": False}):
             ref_result = torch.compile(model)(a, b, extra_args)

         self.assertEqual(
-            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 0
+            torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
+            0,
         )

         torch._dynamo.reset()
         result = torch.compile(model)(a, b, extra_args)

         self.assertEqual(
-            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
+            torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
             1,
         )

@@ -2037,7 +2041,7 @@ class TestCutlassBackend(TestCase):

             self.assertEqual(
                 torch._dynamo.utils.counters["inductor"][
-                    "cuda_epilogue_fusion_counter"
+                    "cutlass_epilogue_fusion_counter"
                 ],
                 2 * (i + 1),
             )
@@ -2064,7 +2068,8 @@ class TestCutlassBackend(TestCase):
         ref_result = model(a, b, extra_args)

         self.assertEqual(
-            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+            torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
+            1,
         )
         torch.testing.assert_close(result, ref_result)

@@ -2368,7 +2373,7 @@ class TestCutlassBackend(TestCase):
                         "max_autotune_gemm_backends": "CUTLASS",
                         # needed for log searching
                         "force_disable_caches": True,
-                        "cuda.cutlass_max_profiling_swizzle_options": [2],
+                        "cutlass.cutlass_max_profiling_swizzle_options": [2],
                     }
                 ):
                     with mock.patch(

@@ -5,7 +5,7 @@ import sympy

 import torch
 from torch._dynamo.test_case import TestCase
-from torch._inductor.codegen.cuda.cutlass_utils import (
+from torch._inductor.codegen.cutlass.utils import (
     torch_dtype_to_cutlass_type,
     try_import_cutlass,
 )
@@ -28,7 +28,7 @@ if try_import_cutlass():
     DataType = cutlass_lib.DataType
     from cutlass_cppgen.backend.evt.ir.tensor import Tensor as CutlassTensor

-    from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
+    from torch._inductor.codegen.cutlass.lib_extensions.evt_extensions import (
         _render_argument_type,
         _trace,
         trace,
@@ -107,7 +107,7 @@ class TestCutlassEVT(TestCase):
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
     def test_py_codegen_accumulator_return(self):
-        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
+        from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
         from torch._inductor.virtualized import V

         size = (100, 300, 200)
@@ -164,7 +164,7 @@ return tmp_0, tmp_2, D""",
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
     def test_py_codegen_disjoint_read_indexing(self):
-        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
+        from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
         from torch._inductor.virtualized import V

         size = (100, 300, 200)
@@ -213,7 +213,7 @@ index strides [200, 60000, 1], and layout stride [60000, 200, 1]""",
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
     def test_py_codegen_broadcasting(self):
-        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
+        from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
         from torch._inductor.virtualized import V

         size = (100, 300, 200)
@@ -273,7 +273,7 @@ return tmp_0, tmp_2, D""",
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
     def test_py_codegen(self):
-        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
+        from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
         from torch._inductor.virtualized import V

         size = (100, 300, 200)
@@ -329,7 +329,7 @@ return tmp_1, D""",
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
     def test_example_tensor_creation(self):
-        from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
+        from torch._inductor.codegen.cutlass.lib_extensions.evt_extensions import (
             create_example_tensors,
         )
         from torch._inductor.virtualized import V

@@ -292,7 +292,7 @@ class TestPublicBindings(TestCase):
         # do not get imported by public code.
         # DO NOT add public modules here.
         private_allowlist = {
-            "torch._inductor.codegen.cuda.cuda_kernel",
+            "torch._inductor.codegen.cutlass.kernel",
             # TODO(#133647): Remove the onnx._internal entries after
             # onnx and onnxscript are installed in CI.
             "torch.onnx._internal.exporter",
@@ -357,8 +357,8 @@ class TestPublicBindings(TestCase):
             "torch.testing._internal.distributed.rpc.rpc_test",
             "torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture",
             "torch.testing._internal.distributed.rpc_utils",
-            "torch._inductor.codegen.cuda.cuda_template",
-            "torch._inductor.codegen.cuda.gemm_template",
+            "torch._inductor.codegen.cutlass.template",
+            "torch._inductor.codegen.cutlass.gemm_template",
             "torch._inductor.codegen.cpp_template",
             "torch._inductor.codegen.cpp_gemm_template",
             "torch._inductor.codegen.cpp_micro_gemm",

@@ -264,4 +264,4 @@
   "torch/_inductor/utils.py": {
     "class IndentedBuffer": 145
   }
-}
+}

@@ -682,7 +682,7 @@ class TritonCPUBenchmarkRequest(CPUDeviceBenchmarkMixin, TritonBenchmarkRequest)
     pass


-class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
+class CUTLASSBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
     """
     A class to handle CUDA (CUTLASS) benchmark requests. This class is for
     managing the lifecycle of a CUDA kernel benchmark, including compiling
@@ -699,6 +699,7 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
         output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
         extra_args: Iterable[Any],
         source_code: str,
+        device_type: str = "cuda",
     ) -> None:
         super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
         self.source_code = source_code
@@ -708,7 +709,12 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
         self._workspace_size_updated = False
         self.hash_key: str = ""
         self.source_file: str = ""
-        self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")
+        self.device_type = device_type
+        self.codecache_cls = CUDACodeCache
+        self.device_interface = get_interface_for_device(device_type)
+        self.hash_key, self.source_file = self.codecache_cls.write(
+            self.source_code, "so"
+        )

     def precompile(self):
         """
@@ -723,7 +729,7 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
         self, *input_tensors: torch.Tensor, out: torch.Tensor
     ) -> Callable[[], None]:
         """
-        Create a function to run the CUDA kernel with the given input and output tensors.
+        Create a function to run the CUDA/XPU kernel with the given input and output tensors.
         """

         self.ensure_dll_loaded()
@@ -738,7 +744,8 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
             args,
             self.extra_args,
         )
-        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
+        current_stream = self.device_interface.current_stream()
+        stream_ptr = c_void_p(current_stream.cuda_stream)  # type: ignore[attr-defined]
         run_method = getattr(self.DLL, self.kernel_name)
         workspace_ptr = c_void_p(0)
         if self.workspace_size > 0:
@@ -781,7 +788,8 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
             dict.fromkeys(meta.name for meta in self.input_tensor_meta)
         )
         args = [c_void_p(None) for _ in range(unique_input_count + 1)]
-        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
+        current_stream = self.device_interface.current_stream()
+        stream_ptr = c_void_p(current_stream.cuda_stream)  # type: ignore[attr-defined]

         run_method = getattr(self.DLL, self.kernel_name)
         # Retrieve workspace_size and initialize workspace.
@@ -795,7 +803,7 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
             None,  # null workspace ptr
             stream_ptr,
         )
-        torch.cuda.synchronize()  # shake out any CUDA errors
+        self.device_interface.synchronize()  # shake out any device errors
         self.workspace_size = c_workspace_size.value
         autotuning_log.debug(
             "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",  # noqa: B950
@@ -811,7 +819,7 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):

     def ensure_dll_loaded(self):
         if self.DLL is None:
-            self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
+            self.DLL, self.hash_key, self.source_file = self.codecache_cls.load(
                 self.source_code, "so"
             )

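The benchmark-request hunks replace hard-coded `torch.cuda` calls with a per-device interface keyed by `device_type`, which is what the CUDA/XPU docstring edit refers to. A minimal sketch of the abstraction, restricted to the calls visible in the diff:

```python
from torch._dynamo.device_interface import get_interface_for_device

device_interface = get_interface_for_device("cuda")  # "xpu" follows the same path
stream = device_interface.current_stream()  # replaces torch.cuda.current_stream()
device_interface.synchronize()              # replaces torch.cuda.synchronize()
```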
@@ -34,7 +34,17 @@ from pathlib import Path
 from tempfile import _TemporaryFileWrapper
 from time import time, time_ns
 from types import ModuleType
-from typing import Any, Callable, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Generic,
+    NoReturn,
+    Optional,
+    TYPE_CHECKING,
+    TypeVar,
+    Union,
+)
 from typing_extensions import override, Self

 import torch
@@ -54,7 +64,7 @@ from torch._inductor.codegen.common import (
     custom_backend_passes,
     init_backend_registration,
 )
-from torch._inductor.codegen.cuda import cuda_env
+from torch._inductor.codegen.cuda import compile_utils as cuda_compile_utils
 from torch._inductor.codegen.rocm.compile_command import (
     rocm_compile_command,
     rocm_compiler,
@@ -64,7 +74,6 @@ from torch._inductor.cpp_builder import (
     _LINKER_SCRIPT,
     _set_gpu_runtime_env,
     _TORCH_PATH,
-    _transform_cuda_paths,
     convert_cubin_to_obj,
     CppBuilder,
     CppOptions,
@@ -119,10 +128,6 @@ from .triton_bundler import TritonBundler
 from .virtualized import V


-if config.is_fbcode():
-    from triton.fb.build import build_paths
-
-
 T = TypeVar("T")

 if TYPE_CHECKING:
@@ -148,17 +153,6 @@ autotuning_log = torch._logging.getArtifactLogger(__name__, "autotuning")
 log = logging.getLogger(__name__)


-def use_re_build() -> bool:
-    """
-    Use for CUTLASS compilation only right now.
-    """
-    if config.is_fbcode() and not cuda_env.nvcc_exist(_cuda_compiler()):
-        from triton.fb.re_build_helper import should_build_locally
-
-        return not should_build_locally()
-    return False
-
-
 def get_cpp_wrapper_cubin_path_name() -> str:
     return "cubin_path" if torch.version.hip is None else "hsaco_path"

@@ -2383,10 +2377,10 @@ end
                         config.aot_inductor.emit_multi_arch_kernel
                         and device_type == "cuda"
                     ):
-                        current_arch = _nvcc_arch_as_compile_option()
+                        current_arch = cuda_compile_utils._nvcc_arch_as_compile_option()
                         cmd = (
                             # pyrefly: ignore [unbound-name]
-                            f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} "
+                            f"{cuda_compile_utils._cuda_compiler()} -fatbin {asm_file} -o {cubin_file} "
                             # Triton only allows generating PTX version as same as the current arch
                             f"-gencode arch=compute_{current_arch},code=compute_{current_arch} "
                             # Include SASS for the current specific arch
@@ -3686,55 +3680,6 @@ def _load_triton_kernel_from_source(
     return getattr(PyCodeCache.load(source_code), kernel_name)


-def _cuda_compiler() -> str | None:
-    if cuda_env.nvcc_exist(config.cuda.cuda_cxx):
-        return config.cuda.cuda_cxx
-    if config.is_fbcode():
-        return os.path.join(build_paths.sdk_home, "bin", "nvcc")
-    if cuda_env.nvcc_exist(os.getenv("CUDACXX")):
-        return os.getenv("CUDACXX", "")
-    if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")):
-        return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc"))
-    return "nvcc"
-
-
-def _cutlass_path() -> str:
-    if config.is_fbcode():
-        from libfb.py import parutil
-
-        return parutil.get_dir_path("cutlass-4-headers")
-    else:
-        return config.cuda.cutlass_dir
-
-
-def _cutlass_paths() -> list[str]:
-    return [
-        "include",
-        "tools/library/include",
-        "tools/library/src",
-        "tools/util/include",
-    ]
-
-
-def _clone_cutlass_paths(build_root: str) -> list[str]:
-    paths = _cutlass_paths()
-    cutlass_root = _cutlass_path()
-    for path in _cutlass_paths():
-        old_path = os.path.join(cutlass_root, path)
-        new_path = os.path.join(build_root, path)
-        shutil.copytree(old_path, new_path, dirs_exist_ok=True)
-    return paths
-
-
-def _cutlass_include_paths() -> list[str]:
-    cutlass_path = _cutlass_path()
-    return [
-        # Use realpath to get canonical absolute paths, in order not to mess up cache keys
-        os.path.realpath(os.path.join(cutlass_path, path))
-        for path in _cutlass_paths()
-    ]
-
-
 @torch_key_cache
 def cutlass_key() -> bytes:
     """
@ -3750,151 +3695,10 @@ def cutlass_key() -> bytes:
 | 
			
		||||
                return resource_file.read().encode()
 | 
			
		||||
 | 
			
		||||
    combined_hash = hashlib.sha256()
 | 
			
		||||
    build_code_hash([config.cuda.cutlass_dir], "", combined_hash)
 | 
			
		||||
    build_code_hash([config.cutlass.cutlass_dir], "", combined_hash)
 | 
			
		||||
    return combined_hash.digest()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _cuda_lib_options() -> list[str]:
 | 
			
		||||
    """
 | 
			
		||||
    Util function for CUTLASS backend to find the correct CUDA libraries.
 | 
			
		||||
    """
 | 
			
		||||
    _set_gpu_runtime_env()  # cpp_extension consults the env
 | 
			
		||||
    from torch.utils import cpp_extension
 | 
			
		||||
 | 
			
		||||
    lpaths = cpp_extension.library_paths(device_type="cuda")
 | 
			
		||||
    if use_re_build():
 | 
			
		||||
        lpaths += [
 | 
			
		||||
            build_paths.sdk_lib,
 | 
			
		||||
            os.path.join(build_paths.sdk_lib, "stubs"),
 | 
			
		||||
        ]
 | 
			
		||||
    extra_ldflags: list[str] = []
 | 
			
		||||
    if is_linux():
 | 
			
		||||
        _transform_cuda_paths(lpaths)
 | 
			
		||||
        for path in lpaths:
 | 
			
		||||
            if "torch/lib" in path:
 | 
			
		||||
                # don't want to depend on pytorch
 | 
			
		||||
                continue
 | 
			
		||||
            extra_ldflags.append(f"-L{path}")
 | 
			
		||||
            # -rpath ensures the DLL can find its dependencies when loaded, even
 | 
			
		||||
            # if the library path is non-standard.
 | 
			
		||||
            # But do not add the stubs folder to rpath as the driver is expected to be found at runtime
 | 
			
		||||
            if os.path.basename(path) != "stubs":
 | 
			
		||||
                extra_ldflags.extend(["-Xlinker", f"-rpath={path}"])
 | 
			
		||||
        extra_ldflags.append("-lcuda")
 | 
			
		||||
        extra_ldflags.append("-lcudart")
 | 
			
		||||
    else:
 | 
			
		||||
        raise NotImplementedError(
 | 
			
		||||
            "Unsupported env, failed to find cuda libs! Currently only Linux is supported."
 | 
			
		||||
        )
 | 
			
		||||
    return extra_ldflags
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _nvcc_host_compiler_options() -> list[str]:
 | 
			
		||||
    return [
 | 
			
		||||
        "-fPIC",
 | 
			
		||||
        "-fno-strict-aliasing",
 | 
			
		||||
        "-fvisibility=hidden",
 | 
			
		||||
        "-Wconversion",
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _nvcc_arch_as_compile_option() -> str:
 | 
			
		||||
    arch = cuda_env.get_cuda_arch()
 | 
			
		||||
    if arch == "90":
 | 
			
		||||
        # Required by cutlass compilation.
 | 
			
		||||
        return "90a"
 | 
			
		||||
    if arch == "100":
 | 
			
		||||
        return "100a"
 | 
			
		||||
    return arch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _nvcc_compiler_options() -> list[str]:
 | 
			
		||||
    arch = _nvcc_arch_as_compile_option()
 | 
			
		||||
    code = [f"sm_{arch}", f"compute_{arch}"]
 | 
			
		||||
    if config.cuda.enable_cuda_lto:
 | 
			
		||||
        code += [f"lto_{arch}"]
 | 
			
		||||
    options = [
 | 
			
		||||
        "-t=0",
 | 
			
		||||
        "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
 | 
			
		||||
        "-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1",
 | 
			
		||||
        "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
 | 
			
		||||
        "-w",
 | 
			
		||||
        f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
 | 
			
		||||
        config.cuda.compile_opt_level,
 | 
			
		||||
        "-std=c++17",
 | 
			
		||||
        "--expt-relaxed-constexpr",
 | 
			
		||||
        "-DNDEBUG",
 | 
			
		||||
    ]
 | 
			
		||||
    if config.is_fbcode():
 | 
			
		||||
        options.extend(["-ccbin", os.path.dirname(build_paths.gcc)])
 | 
			
		||||
    if config.cuda.enable_debug_info:
 | 
			
		||||
        options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"])
 | 
			
		||||
    if config.cuda.enable_ptxas_info:
 | 
			
		||||
        options.extend(
 | 
			
		||||
            [
 | 
			
		||||
                "--keep",  # Keep the intermediate files for debugging (including ptx, sass, cubin etc.)
 | 
			
		||||
                "--ptxas-options=--warn-on-local-memory-usage",  # warn us if local memory is used in CUDA Kernels
 | 
			
		||||
                "--ptxas-options=--warn-on-spills",  # warn us if register spilling happens in CUDA Kernels
 | 
			
		||||
                "--resource-usage",  # Report on CUDA resource usage (shared mem, registers etc.)
 | 
			
		||||
                "--source-in-ptx",
 | 
			
		||||
            ]
 | 
			
		||||
        )  # Annotate the ptx file with source information
 | 
			
		||||
    if config.cuda.use_fast_math:
 | 
			
		||||
        options.extend(
 | 
			
		||||
            [
 | 
			
		||||
                "--use_fast_math",
 | 
			
		||||
                "-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
    return options
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cuda_compile_command(
 | 
			
		||||
    src_files: list[str],
 | 
			
		||||
    dst_file: str,
 | 
			
		||||
    dst_file_ext: str,
 | 
			
		||||
    extra_args: list[str] | None = None,
 | 
			
		||||
) -> str:
 | 
			
		||||
    if extra_args is None:
 | 
			
		||||
        extra_args = []
 | 
			
		||||
    if use_re_build():
 | 
			
		||||
        build_path = os.path.dirname(dst_file)
 | 
			
		||||
        include_paths = _clone_cutlass_paths(build_path)
 | 
			
		||||
        src_files = [os.path.basename(src_file) for src_file in src_files]
 | 
			
		||||
        dst_file = os.path.basename(dst_file)
 | 
			
		||||
    else:
 | 
			
		||||
        include_paths = _cutlass_include_paths()
 | 
			
		||||
    cuda_lib_options = _cuda_lib_options()
 | 
			
		||||
    nvcc_host_compiler_options = _nvcc_host_compiler_options()
 | 
			
		||||
    nvcc_compiler_options = _nvcc_compiler_options()
 | 
			
		||||
    options = (
 | 
			
		||||
        nvcc_compiler_options
 | 
			
		||||
        + extra_args
 | 
			
		||||
        + [
 | 
			
		||||
            f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}"
 | 
			
		||||
            for opt in nvcc_host_compiler_options
 | 
			
		||||
        ]
 | 
			
		||||
        + ["-I" + path for path in include_paths]
 | 
			
		||||
        + cuda_lib_options
 | 
			
		||||
    )
 | 
			
		||||
    src_file = " ".join(src_files)
 | 
			
		||||
    res = ""
 | 
			
		||||
    if dst_file_ext == "o":
 | 
			
		||||
        res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}"
 | 
			
		||||
    elif dst_file_ext == "so":
 | 
			
		||||
        options.append("-shared")
 | 
			
		||||
        res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
 | 
			
		||||
    elif dst_file_ext == "exe":
 | 
			
		||||
        res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
 | 
			
		||||
    else:
 | 
			
		||||
        raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
 | 
			
		||||
    if log.isEnabledFor(logging.DEBUG):
 | 
			
		||||
        log.debug("CUDA command: %s", res)
 | 
			
		||||
    else:
 | 
			
		||||
        autotuning_log.debug("CUDA command: %s", res)
 | 
			
		||||
    return res
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DLLWrapper:
 | 
			
		||||
    """A wrapper for a dynamic library."""
 | 
			
		||||
 | 
			
		||||
@ -3978,10 +3782,9 @@ def binary_error_path(output_path: str) -> str:
 | 
			
		||||
    return output_path + ".error"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@clear_on_fresh_cache
 | 
			
		||||
class CUDACodeCache:
 | 
			
		||||
class CUTLASSCodeCache:
 | 
			
		||||
    """
 | 
			
		||||
    A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS.
 | 
			
		||||
    A cache for managing the compilation and loading source code specifically for CUTLASS.
 | 
			
		||||
    This class handles writing source code to files, compiling them into shared objects, and caching
 | 
			
		||||
    the results to avoid redundant compilations. It also manages error handling and logging for the
 | 
			
		||||
    compilation process.
 | 
			
		||||
@ -3995,12 +3798,14 @@ class CUDACodeCache:
 | 
			
		||||
 | 
			
		||||
    cache: dict[str, CacheEntry] = {}
 | 
			
		||||
    aot_kernels_o: list[str] = []
 | 
			
		||||
    _SOURCE_CODE_SUFFIX = "cu"
 | 
			
		||||
 | 
			
		||||
    _SOURCE_CODE_SUFFIX: str = ""
 | 
			
		||||
    _BACKEND: str = ""
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def cache_clear() -> None:
 | 
			
		||||
        CUDACodeCache.cache.clear()
 | 
			
		||||
        CUDACodeCache.aot_kernels_o.clear()
 | 
			
		||||
        CUTLASSCodeCache.cache.clear()
 | 
			
		||||
        CUTLASSCodeCache.aot_kernels_o.clear()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @lru_cache(maxsize=4)
 | 
			
		||||
@ -4035,6 +3840,24 @@ class CUDACodeCache:
 | 
			
		||||
            )
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _use_re_build(cls) -> bool:
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _compile_command(
 | 
			
		||||
        cls,
 | 
			
		||||
        src_files: list[str],
 | 
			
		||||
        dst_file: str,
 | 
			
		||||
        dst_file_ext: str,
 | 
			
		||||
        extra_args: Optional[list[str]] = None,
 | 
			
		||||
    ) -> str:
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _source_code_extra(cls) -> str:
 | 
			
		||||
        raise NotImplementedError
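
The hunk above turns the old CUDA-only cache into a template-method base class: a subclass supplies the re-build policy, the compile command, and the extra material mixed into the hash key. A minimal sketch of how a hypothetical backend could plug into these hooks (the HipCodeCache name and the hip_compile_command helper are illustrative, not part of this diff):

# Illustrative only: the shape of a CUTLASSCodeCache subclass.
# "HipCodeCache" and "hip_compile_command" are hypothetical names.
class HipCodeCache(CUTLASSCodeCache):
    _SOURCE_CODE_SUFFIX = "hip"
    _BACKEND = "HIP"

    @classmethod
    def _use_re_build(cls) -> bool:
        return False  # assume local builds only

    @classmethod
    def _compile_command(
        cls,
        src_files: list[str],
        dst_file: str,
        dst_file_ext: str,
        extra_args: Optional[list[str]] = None,
    ) -> str:
        return hip_compile_command(src_files, dst_file, dst_file_ext, extra_args)

    @classmethod
    def _source_code_extra(cls) -> str:
        # anything that should perturb the cache key when it changes
        return repr(["hipcc", "--std=c++17"])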

    @classmethod
    @lru_cache(None)
    def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]:
@@ -4043,25 +3866,14 @@ class CUDACodeCache:
        Returns the hash key of source code, and the path to the file.
        """

        if config.cuda.cutlass_hash_with_compile_cmd:
            cuda_command = repr(
                cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext)
        if config.cutlass.cutlass_hash_with_compile_cmd:
            compile_command = repr(
                cls._compile_command(["dummy_input"], "dummy_output", dst_file_ext)
            )
            extra = cuda_command
            extra = compile_command
        else:
            extra = repr(
                [
                    # nvcc and cuda hash
                    _cuda_compiler(),
                    # cutlass flags and gcc hash
                    _nvcc_compiler_options(),
                    # flags
                    _nvcc_host_compiler_options(),
                    # cutlass key
                    cutlass_key(),
                    # hack to deal with AOTI .o compilation
                ]
            )
            extra = cls._source_code_extra()

        key, input_path = write(source_code, cls._SOURCE_CODE_SUFFIX, extra=extra)
        return key, input_path

@@ -4094,7 +3906,7 @@ class CUDACodeCache:
                output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext
                error_path = binary_error_path(output_path)
                binary_remote_cache = cls.get_kernel_binary_remote_cache(
                    caching_enabled=config.cuda.use_binary_remote_cache
                    caching_enabled=config.cutlass.use_binary_remote_cache
                    and not config.force_disable_caches,
                    caching_available=config.is_fbcode(),
                )
@@ -4109,30 +3921,30 @@ class CUDACodeCache:
                    cmd_parts, error_output = json.loads(error_json)
                    if (
                        binary_remote_cache is not None
                        and config.cuda.upload_to_binary_remote_cache
                        and config.cutlass.upload_to_binary_remote_cache
                    ):
                        # This ensures that a local error is uploaded to the remote cache,
                        # as we make no assumptions about the remote cache having the same
                        # information as the local cache
                        binary_remote_cache.put(
                            error_path, config.cuda.binary_remote_cache_force_write
                            error_path, config.cutlass.binary_remote_cache_force_write
                        )
                    cls.cache[key_with_ext] = CUDACodeCache.CacheEntry(
                    cls.cache[key_with_ext] = CUTLASSCodeCache.CacheEntry(
                        input_path, output_path, error_json
                    )
                    raise exc.CUDACompileError(cmd_parts, error_output)
                if not os.path.exists(output_path):
                    cmd = cuda_compile_command(
                    cmd = cls._compile_command(
                        src_files, output_path, dst_file_ext, extra_args
                    )
                    with open(input_path, "a") as f:
                        f.write("\n")
                        f.write(f"// CUDA {operation_name} cmd\n// {cmd}\n")
                        f.write(f"// {cls._BACKEND} {operation_name} cmd\n// {cmd}\n")
                    start_time = time()
                    log.debug("CUDA %s: %s", operation_name, cmd)
                    log.debug("%s %s: %s", cls._BACKEND, operation_name, cmd)
                    cmd_parts = cmd.split(" ")
                    try:
                        if use_re_build():
                        if cls._use_re_build():
                            from triton.fb.re_build_helper import run_build_command

                            run_build_command(
@@ -4145,7 +3957,7 @@ class CUDACodeCache:
                                cmd_parts, stderr=subprocess.STDOUT, env=os.environ
                            )
                    except subprocess.CalledProcessError as error:
                        cls._record_cuda_compile_error(
                        cls._record_compile_error(
                            error.output.decode("utf-8"),
                            key_with_ext,
                            cmd_parts,
@@ -4156,7 +3968,7 @@ class CUDACodeCache:
                        raise exc.CUDACompileError(cmd_parts, error.output) from error
                    except Exception as error:
                        if "COMPILE FAILED WITH" in str(error):
                            cls._record_cuda_compile_error(
                            cls._record_compile_error(
                                str(error),
                                key_with_ext,
                                cmd_parts,
@@ -4167,29 +3979,30 @@ class CUDACodeCache:
                            raise exc.CUDACompileError(cmd_parts, str(error)) from error
                        raise error
                    end_time = time()
                    log_duration_msg = f"CUDA {operation_name} took {end_time - start_time} seconds. Command: {cmd}"
                    log_duration_msg = f"{cls._BACKEND} {operation_name} took {end_time - start_time} seconds. Command: {cmd}"
                    log.info(log_duration_msg)

                else:
                    log.debug(
                        "CUDA %s skipped: %s since output already exists",
                        "%s %s skipped: %s since output already exists",
                        cls._BACKEND,
                        operation_name,
                        output_path,
                    )
                # Upload to remote cache if enabled
                if (
                    binary_remote_cache is not None
                    and config.cuda.upload_to_binary_remote_cache
                    and config.cutlass.upload_to_binary_remote_cache
                ):
                    # will log on errors, but not fail out
                    binary_remote_cache.put(
                        output_path, config.cuda.binary_remote_cache_force_write
                        output_path, config.cutlass.binary_remote_cache_force_write
                    )
                cls.cache[key_with_ext] = CUDACodeCache.CacheEntry(
                cls.cache[key_with_ext] = CUTLASSCodeCache.CacheEntry(
                    input_path, output_path, None
                )

        cache_entry: CUDACodeCache.CacheEntry = cls.cache[key_with_ext]
        cache_entry: CUTLASSCodeCache.CacheEntry = cls.cache[key_with_ext]
        if cache_entry.error_json is not None:
            # Restore cached Exception and raise it as if we had compiled
            cmd_parts, error_output = json.loads(cache_entry.error_json)
@@ -4214,7 +4027,7 @@ class CUDACodeCache:
        return (DLLWrapper(dst_file_path), hash_key, source_code_path)

    @classmethod
    def _record_cuda_compile_error(
    def _record_compile_error(
        cls,
        error_str: str,
        key_with_ext: str,
@@ -4226,7 +4039,7 @@ class CUDACodeCache:
        binary_remote_cache: Any = None,
    ) -> None:
        error_json = json.dumps([cmd_parts, error_str])
        cls.cache[key_with_ext] = CUDACodeCache.CacheEntry(
        cls.cache[key_with_ext] = CUTLASSCodeCache.CacheEntry(
            input_path, output_path, error_json
        )
        error_path = binary_error_path(output_path)
@@ -4236,13 +4049,52 @@ class CUDACodeCache:
        # Upload to remote cache directly from memory if enabled
        if (
            binary_remote_cache is not None
            and config.cuda.upload_to_binary_remote_cache
            and config.cutlass.upload_to_binary_remote_cache
        ):
            binary_remote_cache.put(
                error_path, config.cuda.binary_remote_cache_force_write
                error_path, config.cutlass.binary_remote_cache_force_write
            )


@clear_on_fresh_cache
class CUDACodeCache(CUTLASSCodeCache):
    _SOURCE_CODE_SUFFIX = "cu"
    _BACKEND = "CUDA"

    @classmethod
    def _use_re_build(cls) -> bool:
        return cuda_compile_utils.use_re_build()

    @classmethod
    def _compile_command(
        cls,
        src_files: list[str],
        dst_file: str,
        dst_file_ext: str,
        extra_args: Optional[list[str]] = None,
    ) -> str:
        return cuda_compile_utils.cuda_compile_command(
            src_files, dst_file, dst_file_ext, extra_args=extra_args
        )

    @classmethod
    def _source_code_extra(cls) -> str:
        extra = repr(
            [
                # nvcc and cuda hash
                cuda_compile_utils._cuda_compiler(),
                # cutlass flags and gcc hash
                cuda_compile_utils._nvcc_compiler_options(),
                # flags
                cuda_compile_utils._nvcc_host_compiler_options(),
                # cutlass key
                cutlass_key(),
                # hack to deal with AOTI .o compilation
            ]
        )
        return extra
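
With the hooks filled in, callers keep using the concrete subclass exactly as before the refactor. A minimal usage sketch, assuming a CUDA toolchain is available (the source string is a placeholder; per this diff, write() mixes _source_code_extra() into the hash key, and the load path ultimately returns a (DLLWrapper, hash_key, source_code_path) triple):

# Illustrative only; the source string is a stand-in, not a real kernel.
from torch._inductor.codecache import CUDACodeCache

placeholder_src = "// CUTLASS kernel source emitted by Inductor"
key, cu_path = CUDACodeCache.write(placeholder_src, "so")
print(key, cu_path)  # hash key and the path of the persisted .cu file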


@clear_on_fresh_cache
class ROCmCodeCache:
    @dataclasses.dataclass

@@ -10,8 +10,8 @@ from ..scheduler import (
    Scheduler,
    SchedulerNode,
)
from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
from .cutedsl.cutedsl_scheduling import CuteDSLScheduling
from .cutlass.scheduling import CUTLASSScheduling
from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling
from .triton import TritonScheduling

@@ -30,7 +30,7 @@ if TYPE_CHECKING:
    _IntLike: TypeAlias = Union[int, Expr]


class CUDACombinedScheduling(BaseScheduling):
class CombinedScheduling(BaseScheduling):
    """
    Scheduler for CUDA Kernels, which delegates calls as appropriate
    to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices
@@ -43,7 +43,7 @@ class CUDACombinedScheduling(BaseScheduling):
    def __init__(self, scheduler: Optional[Scheduler]) -> None:
        super().__init__(scheduler)
        self._triton_scheduling = TritonScheduling(scheduler)
        self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
        self._cutlass_scheduling = CUTLASSScheduling(scheduler)
        self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)
        self._cutedsl_scheduling = CuteDSLScheduling(scheduler)

@@ -51,8 +51,8 @@ class CUDACombinedScheduling(BaseScheduling):
        return self._triton_scheduling.get_backend_features(device)

    def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:
        if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
            return self._cuda_cpp_scheduling
        if self._cutlass_scheduling.is_cutlass_template(node):
            return self._cutlass_scheduling
        if self._rocm_cpp_scheduling.is_rocm_cpp_template(node):
            return self._rocm_cpp_scheduling
        if self._cutedsl_scheduling.is_cutedsl_template(node):
@@ -62,11 +62,11 @@ class CUDACombinedScheduling(BaseScheduling):
    def can_fuse_vertical(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2):
        if self._cutlass_scheduling.can_fuse_vertical(node1, node2):
            return True
        elif self._cuda_cpp_scheduling.is_cuda_cpp_template(
        elif self._cutlass_scheduling.is_cutlass_template(
            node1
        ) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2):
        ) or self._cutlass_scheduling.is_cutlass_template(node2):
            return False
        # CuteDSL doesn't support vertical fusion currently
        elif self._cutedsl_scheduling.is_cutedsl_template(
@@ -79,8 +79,8 @@ class CUDACombinedScheduling(BaseScheduling):
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        for node in (node1, node2):
            if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
                return self._cuda_cpp_scheduling.can_fuse_horizontal(
            if self._cutlass_scheduling.is_cutlass_template(node):
                return self._cutlass_scheduling.can_fuse_horizontal(
                    node1, node2
                )  # always False at the moment
            if self._cutedsl_scheduling.is_cutedsl_template(node):
@@ -100,9 +100,9 @@ class CUDACombinedScheduling(BaseScheduling):
        epilogue_nodes: Sequence[BaseSchedulerNode],
        prologue_nodes: Sequence[BaseSchedulerNode],
    ) -> Optional[str]:
        if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
        if self._cutlass_scheduling.is_cutlass_template(template_node):
            assert not prologue_nodes
            return self._cuda_cpp_scheduling.codegen_template(
            return self._cutlass_scheduling.codegen_template(
                template_node, epilogue_nodes, prologue_nodes
            )
        elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node):
@@ -495,12 +495,12 @@ def init_backend_registration() -> None:
    Register the backend for different devices, including the scheduling
    for kernel code generation and the host side wrapper code generation.
    """
    from .combined_scheduling import CombinedScheduling
    from .cpp import CppScheduling
    from .cpp_wrapper_cpu import CppWrapperCpu
    from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
    from .cpp_wrapper_gpu import CppWrapperGpu
    from .cpp_wrapper_mps import CppWrapperMps
    from .cuda_combined_scheduling import CUDACombinedScheduling
    from .halide import HalideScheduling
    from .mps import MetalScheduling
    from .python_wrapper_mtia import PythonWrapperMtia
@@ -525,9 +525,9 @@ def init_backend_registration() -> None:
        )

    if get_scheduling_for_device("cuda") is None:
        # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
        # CombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
        cuda_backends = {
            "triton": CUDACombinedScheduling,
            "triton": CombinedScheduling,
            "halide": HalideScheduling,
        }
        register_backend_for_device(
@@ -2366,7 +2366,7 @@ class KernelTemplate:
    """
    Base class for defining kernel templates.

    Children classes: TritonTemplate, CUDATemplate
    Children classes: TritonTemplate, CUTLASSTemplate
    """

    @staticmethod
@@ -2618,12 +2618,12 @@ class CSEProxy(DefaultHandler):
        """
        from ..bounds import ValueRangeAnalysis
        from ..select_algorithm import TritonTemplateKernel
        from .cuda.cuda_kernel import CUDATemplateKernel
        from .cutlass.kernel import CUTLASSTemplateKernel

        if isinstance(V.kernel, TritonTemplateKernel):
            return ValueRanges.unknown()

        if isinstance(V.kernel, CUDATemplateKernel):
        if isinstance(V.kernel, CUTLASSTemplateKernel):
            return ValueRanges.unknown()

        fx_node = V.interpreter.current_node

264 torch/_inductor/codegen/cuda/compile_utils.py Normal file
@@ -0,0 +1,264 @@
# mypy: allow-untyped-defs
import logging
import os
import shutil
from pathlib import Path
from typing import Optional

import torch
from torch._inductor import config
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.cpp_builder import _set_gpu_runtime_env, _transform_cuda_paths
from torch._inductor.utils import is_linux


if config.is_fbcode():
    from triton.fb.build import build_paths


log = logging.getLogger(__name__)
autotuning_log = torch._logging.getArtifactLogger(__name__, "autotuning")


def use_re_build() -> bool:
    """
    Use for CUTLASS compilation only right now.
    """
    if config.is_fbcode() and not cuda_env.nvcc_exist(_cuda_compiler()):
        from triton.fb.re_build_helper import should_build_locally

        return not should_build_locally()
    return False


def _cutlass_path() -> str:
    if config.is_fbcode():
        from libfb.py import parutil

        return parutil.get_dir_path("cutlass-4-headers")
    else:
        return config.cutlass.cutlass_dir


def _cutlass_paths() -> list[str]:
    return [
        "include",
        "tools/library/include",
        "tools/library/src",
        "tools/util/include",
    ]


def _clone_cutlass_paths(build_root: str) -> list[str]:
    paths = _cutlass_paths()
    cutlass_root = _cutlass_path()
    for path in _cutlass_paths():
        old_path = os.path.join(cutlass_root, path)
        new_path = os.path.join(build_root, path)
        shutil.copytree(old_path, new_path, dirs_exist_ok=True)
    return paths


def _cutlass_include_paths() -> list[str]:
    cutlass_path = _cutlass_path()
    return [
        # Use realpath to get canonical absolute paths, in order not to mess up cache keys
        os.path.realpath(os.path.join(cutlass_path, path))
        for path in _cutlass_paths()
    ]


def _cuda_compiler() -> Optional[str]:
    if cuda_env.nvcc_exist(config.cutlass.cuda_cxx):
        return config.cutlass.cuda_cxx
    if config.is_fbcode():
        return os.path.join(build_paths.sdk_home, "bin", "nvcc")
    if cuda_env.nvcc_exist(os.getenv("CUDACXX")):
        return os.getenv("CUDACXX", "")
    if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")):
        return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc"))
    return "nvcc"


def _cuda_lib_options() -> list[str]:
    """
    Util function for CUTLASS backend to find the correct CUDA libraries.
    """
    _set_gpu_runtime_env()  # cpp_extension consults the env
    from torch.utils import cpp_extension

    lpaths = cpp_extension.library_paths(device_type="cuda")
    if use_re_build():
        lpaths += [
            build_paths.sdk_lib,
            os.path.join(build_paths.sdk_lib, "stubs"),
        ]
    extra_ldflags: list[str] = []
    if is_linux():
        _transform_cuda_paths(lpaths)
        for path in lpaths:
            if "torch/lib" in path:
                # don't want to depend on pytorch
                continue
            extra_ldflags.append(f"-L{path}")
            # -rpath ensures the DLL can find its dependencies when loaded, even
            # if the library path is non-standard.
            # But do not add the stubs folder to rpath as the driver is expected to be found at runtime
            if os.path.basename(path) != "stubs":
                extra_ldflags.extend(["-Xlinker", f"-rpath={path}"])
        extra_ldflags.append("-lcuda")
        extra_ldflags.append("-lcudart")
    else:
        raise NotImplementedError(
            "Unsupported env, failed to find cuda libs! Currently only Linux is supported."
        )
    return extra_ldflags


def _nvcc_host_compiler_options() -> list[str]:
    return [
        "-fPIC",
        "-fno-strict-aliasing",
        "-fvisibility=hidden",
        "-Wconversion",
    ]


def _nvcc_arch_as_compile_option() -> str:
    arch = cuda_env.get_cuda_arch()
    if arch == "90":
        # Required by cutlass compilation.
        return "90a"
    if arch == "100":
        return "100a"
    return arch
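
Why the suffixed arch strings: CUTLASS SM90/SM100 kernels rely on arch-specific instructions that nvcc only enables for the "a" variants. Spelled out, the mapping above behaves as follows (an illustration, assuming cuda_env.get_cuda_arch() returns these strings):

# "90"  -> "90a"   (compute_90a/sm_90a, required for Hopper CUTLASS kernels)
# "100" -> "100a"  (compute_100a/sm_100a on Blackwell)
# other -> unchanged, e.g. "86" -> "86"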


def _nvcc_compiler_options() -> list[str]:
    arch = _nvcc_arch_as_compile_option()
    code = [f"sm_{arch}", f"compute_{arch}"]
    if config.cutlass.enable_cuda_lto:
        code += [f"lto_{arch}"]
    options = [
        "-t=0",
        "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
        "-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1",
        "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
        "-w",
        f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
        config.cutlass.compile_opt_level,
        "-std=c++17",
        "--expt-relaxed-constexpr",
        "-DNDEBUG",
    ]
    if config.is_fbcode():
        options.extend(["-ccbin", os.path.dirname(build_paths.gcc)])
    if config.cutlass.enable_debug_info:
        options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"])
    if config.cutlass.enable_ptxas_info:
        options.extend(
            [
                "--keep",  # Keep the intermediate files for debugging (including ptx, sass, cubin etc.)
                "--ptxas-options=--warn-on-local-memory-usage",  # warn us if local memory is used in CUDA Kernels
                "--ptxas-options=--warn-on-spills",  # warn us if register spilling happens in CUDA Kernels
                "--resource-usage",  # Report on CUDA resource usage (shared mem, registers etc.)
                "--source-in-ptx",
            ]
        )  # Annotate the ptx file with source information
    if config.cutlass.use_fast_math:
        options.extend(
            [
                "--use_fast_math",
                "-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
            ]
        )
    return options


def cuda_compile_command(
    src_files: list[str],
    dst_file: str,
    dst_file_ext: str,
    extra_args: list[str] | None = None,
) -> str:
    if extra_args is None:
        extra_args = []
    if use_re_build():
        build_path = os.path.dirname(dst_file)
        include_paths = _clone_cutlass_paths(build_path)
        src_files = [os.path.basename(src_file) for src_file in src_files]
        dst_file = os.path.basename(dst_file)
    else:
        include_paths = _cutlass_include_paths()
    cuda_lib_options = _cuda_lib_options()
    nvcc_host_compiler_options = _nvcc_host_compiler_options()
    nvcc_compiler_options = _nvcc_compiler_options()
    options = (
        nvcc_compiler_options
        + extra_args
        + [
            f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}"
            for opt in nvcc_host_compiler_options
        ]
        + ["-I" + path for path in include_paths]
        + cuda_lib_options
    )
    src_file = " ".join(src_files)
    res = ""
    if dst_file_ext == "o":
        res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}"
    elif dst_file_ext == "so":
        options.append("-shared")
        res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
    elif dst_file_ext == "exe":
        res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
    else:
        raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
    if log.isEnabledFor(logging.DEBUG):
        log.debug("CUDA command: %s", res)
    else:
        autotuning_log.debug("CUDA command: %s", res)
    return res
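
For orientation, a sketch of the command string this assembles for a shared object. The exact flag set depends on configuration (opt level, fast math, fbcode paths); the abbreviated form below is illustrative, not captured output:

# Hypothetical call and abbreviated result, assuming arch "90a":
# cuda_compile_command(["/tmp/k.cu"], "/tmp/k.so", "so")
# -> "nvcc -t=0 -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 ... -w "
#    "-gencode=arch=compute_90a,code=[sm_90a,compute_90a] <opt-level> "
#    "-std=c++17 --expt-relaxed-constexpr -DNDEBUG -Xcompiler=-fPIC ... "
#    "-I<cutlass>/include ... -L<cuda-libs> -lcuda -lcudart -shared "
#    "-o /tmp/k.so /tmp/k.cu"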

class CUDACompileSourceCapturingContext:
    # Helper class for Benchmarking and Testing CUTLASS Kernels in isolation.
    # Can be used to capture the sourcecode passed to CUDACodeCache.compile

    def __init__(self):
        self.sources = []
        self._compile_patch = None

    def __enter__(self, *args, **kwargs):
        import unittest.mock as mock

        import torch._inductor.codecache

        _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile

        def my_compile(
            source_code, dst_file_ext, extra_args: Optional[list[str]] = None
        ):
            self.sources.append(source_code)
            return _compile_method_orig(source_code, dst_file_ext)

        # pyrefly: ignore [bad-assignment]
        self._compile_patch = mock.patch(
            "torch._inductor.codecache.CUDACodeCache.compile", my_compile
        )
        self._compile_patch.__enter__(*args, **kwargs)  # type: ignore[union-attr]
        return self

    def __exit__(self, *args, **kwargs):
        self._compile_patch.__exit__(*args, **kwargs)  # type: ignore[union-attr]
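
A usage sketch for the capture helper; the workload is illustrative, and CUTLASS autotuning must actually select a CUTLASS choice for anything to be captured:

import torch
from torch._inductor.codegen.cuda.compile_utils import (
    CUDACompileSourceCapturingContext,
)

with CUDACompileSourceCapturingContext() as ctx:
    fn = torch.compile(lambda a, b: a @ b)
    fn(torch.randn(512, 512, device="cuda"), torch.randn(512, 512, device="cuda"))

# Every source string passed to CUDACodeCache.compile during the block:
print(len(ctx.sources))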


def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path):
    # returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run
    # Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled.
    from torch._inductor.codecache import cuda_compile_command

    extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]
    compile_command = cuda_compile_command(
        [str(srcpath)], str(exepath), "exe", extra_args=extra_args
    )
    return compile_command
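
Combining the two helpers, a hedged sketch of turning a captured kernel into a standalone executable (paths are illustrative; ctx is the capturing context from the example above):

import subprocess
from pathlib import Path

src = Path("/tmp/gemm_kernel.cu")
exe = Path("/tmp/gemm_runner")
src.write_text(ctx.sources[0])  # first captured CUTLASS source
cmd = cuda_standalone_runner_compile_command(src, exe)
subprocess.check_call(cmd.split(" "))  # then run /tmp/gemm_runner manually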
@@ -10,9 +10,11 @@ from typing import Any, Optional

import torch._inductor.config as config
from torch._inductor.codecache import cutlass_key
from torch._inductor.codegen.cuda import cutlass_utils, serialization
from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
from torch._inductor.codegen.cutlass import serialization, utils
from torch._inductor.codegen.cutlass.serialization import (
    get_cutlass_operation_serializer,
)
from torch._inductor.runtime.cache_dir_utils import cache_dir
from torch._inductor.utils import clear_on_fresh_cache

@@ -39,7 +41,7 @@ def get_config_request_key(
            return hashlib.sha256(f.read()).hexdigest()

    serialization_hash = get_file_hash(serialization)
    cutlass_utils_hash = get_file_hash(cutlass_utils)
    cutlass_utils_hash = get_file_hash(utils)

    hash_target = "-".join(
        [
@@ -75,7 +77,7 @@ def maybe_fetch_ops() -> Optional[list[Any]]:
    # get_cuda_version might return "12.4.0" or "12.4"
    # but we want to use "12.4"
    version: str = ".".join(get_cuda_version().split(".")[:2])
    instantiation_level: str = config.cuda.cutlass_instantiation_level
    instantiation_level: str = config.cutlass.cutlass_instantiation_level

    # filename and filepath
    request_key: str = get_config_request_key(arch, version, instantiation_level)
@@ -11,7 +11,7 @@ from typing import Any, Optional, Union
import torch
import torch.utils._pytree as pytree
from torch._inductor.autotune_process import TensorMeta
from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops
from torch._inductor.codegen.cutlass.cache import maybe_fetch_ops
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._inductor.scheduler import BaseSchedulerNode
@@ -19,11 +19,11 @@ from torch._inductor.select_algorithm import create_inputs_key
from torch._inductor.utils import clear_on_fresh_cache

from ... import ir
from ...config import cuda as inductor_cuda_config
from ...config import cutlass as inductor_cutlass_config
from ...ir import (
    Buffer,
    ChoiceCaller,
    CUDATemplateBuffer,
    CUTLASSTemplateBuffer,
    FixedLayout,
    IRNode,
    Layout,
@@ -32,11 +32,11 @@ from ...ir import (
from ...utils import is_dynamic, Placeholder
from ...virtualized import V
from ..common import IndentedBuffer
from . import cutlass_utils
from .cuda_kernel import CUDATemplateKernel
from .cuda_template import CUTLASSTemplate
from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt
from .cutlass_utils import (
from . import utils as cutlass_utils
from .kernel import CUTLASSTemplateKernel
from .python_evt import CutlassEVTCodegen, scaled_mm_evt
from .template import CUTLASSTemplate
from .utils import (
    ACCUMULATOR_DTYPES,
    dtype_match,
    torch_dtype_to_cutlass_type,
@@ -578,7 +578,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
            for name, op in ops:
                for (
                    swizzle
                ) in inductor_cuda_config.cutlass_max_profiling_swizzle_options:
                ) in inductor_cutlass_config.cutlass_max_profiling_swizzle_options:
                    description = f"{name} swizzle={swizzle}"
                    self.maybe_append_choice(
                        choices,
@@ -635,7 +635,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
                #include "cutlass/util/tensor_view_io.h"
            """
        )
        if inductor_cuda_config.generate_test_runner and not is_dynamic(
        if inductor_cutlass_config.generate_test_runner and not is_dynamic(
            *self.input_nodes, self.output_node
        ):
            res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES)
@@ -953,7 +953,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
            )
            return None

        if inductor_cuda_config.cutlass_tma_only and not self._has_tma_epilogue(op):
        if inductor_cutlass_config.cutlass_tma_only and not self._has_tma_epilogue(op):
            return None

        # Set epilogue.
@@ -975,14 +975,16 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
            return None

        # Apply regex filters at the end when configuration name doesn't change anymore
        if inductor_cuda_config.cutlass_op_allowlist_regex:
        if inductor_cutlass_config.cutlass_op_allowlist_regex:
            if not re.search(
                inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name()
                inductor_cutlass_config.cutlass_op_allowlist_regex,
                op.configuration_name(),
            ):
                return None
        if inductor_cuda_config.cutlass_op_denylist_regex is not None:
        if inductor_cutlass_config.cutlass_op_denylist_regex is not None:
            if re.search(
                inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name()
                inductor_cutlass_config.cutlass_op_denylist_regex,
                op.configuration_name(),
            ):
                return None
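
These filters are user-facing config knobs; after this diff they live under the cutlass namespace instead of cuda. A sketch of narrowing autotuning with them (the regex values are examples only):

import torch._inductor.config as inductor_config

# Only consider ops whose configuration name matches the allowlist...
inductor_config.cutlass.cutlass_op_allowlist_regex = "128x128"
# ...and drop any op matching the denylist.
inductor_config.cutlass.cutlass_op_denylist_regex = "pingpong"
# Restrict to ops with TMA epilogues (checked just above the filters).
inductor_config.cutlass.cutlass_tma_only = True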
 | 
			
		||||
 | 
			
		||||
@ -1035,7 +1037,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
 | 
			
		||||
            time.time() - start_time,
 | 
			
		||||
        )
 | 
			
		||||
        sorted_res = sorted(res.items())
 | 
			
		||||
        ret_res = sorted_res[: inductor_cuda_config.cutlass_max_profiling_configs]
 | 
			
		||||
        ret_res = sorted_res[: inductor_cutlass_config.cutlass_max_profiling_configs]
 | 
			
		||||
        if len(self.filtered_ops_cache) < 50:
 | 
			
		||||
            self.filtered_ops_cache[self.cache_key] = ret_res
 | 
			
		||||
        else:
 | 
			
		||||
@ -1060,9 +1062,9 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
 | 
			
		||||
 | 
			
		||||
    def render(  # type: ignore[override]
 | 
			
		||||
        self,
 | 
			
		||||
        kernel: CUDATemplateKernel,
 | 
			
		||||
        kernel: CUTLASSTemplateKernel,
 | 
			
		||||
        op: "cutlass_gemm_op.GemmOperation" = None,  # type: ignore[name-defined]  # noqa: F821
 | 
			
		||||
        template_buffer_node: Optional[CUDATemplateBuffer] = None,
 | 
			
		||||
        template_buffer_node: Optional[CUTLASSTemplateBuffer] = None,
 | 
			
		||||
        epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> str:
 | 
			
		||||
@ -1072,14 +1074,14 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
 | 
			
		||||
        including potentially fused epilogues.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            kernel (CUDATemplateKernel): The kernel to be rendered.
 | 
			
		||||
            kernel (CUTLASSTemplateKernel): The kernel to be rendered.
 | 
			
		||||
            op (cutlass_gemm_op.GemmOperation, optional): A GEMM operation that is required to be compatible with the
 | 
			
		||||
                input and output definitions as well as a possible epilogue. Defaults to None.
 | 
			
		||||
            **kwargs: Additional keyword arguments. Currently unused.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            str: Cutlass based CUDA C++ code fragment as a string, to be used by the current
 | 
			
		||||
            CUDATemplateKernel or autotuning code.
 | 
			
		||||
            CUTLASSTemplateKernel or autotuning code.
 | 
			
		||||
 | 
			
		||||
        Note:
 | 
			
		||||
            All inputs and their corresponding buffer addresses and names take precedence over previously
 | 
			
		||||
@ -1277,7 +1279,9 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
 | 
			
		||||
        }
 | 
			
		||||
        options.update(dict(zip(extra_names, extra_inputs)))
 | 
			
		||||
        res = self._template_from_string(self._get_template()).render(**options)
 | 
			
		||||
        if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias):
 | 
			
		||||
        if inductor_cutlass_config.generate_test_runner and not is_dynamic(
 | 
			
		||||
            X, W, Y, Bias
 | 
			
		||||
        ):
            test_runner_code = self._template_from_string(
                GEMM_STANDALONE_RUNNER_TEMPLATE
            ).render(**options)
@ -1483,7 +1487,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
        output_dtype: torch.dtype,
        accumulator_dtype: torch.dtype,
    ) -> tuple[str, str, str, EVTArgRenames]:
        from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace
        from .lib_extensions.evt_extensions import create_example_tensors, trace

        acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
        output_dtype = torch_dtype_to_cutlass_type(output_dtype)
@ -1570,7 +1574,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
        assert cutlass_utils.try_import_cutlass()
        import cutlass_library.library as cutlass_lib

        from .cutlass_lib_extensions import gemm_operation_extensions as gemm_extensions
        from .lib_extensions import gemm_operation_extensions as gemm_extensions

        emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name)  # type: ignore[call-arg]

@ -1629,7 +1633,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
        Y: IRNode,
        alpha: float,
        beta: float,
        kernel: CUDATemplateKernel,
        kernel: CUTLASSTemplateKernel,
        epilogue_args,
    ) -> str:
        """
@ -1646,7 +1650,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
            Y (IRNode): The output tensor.
            alpha (float): Scaling factor for the product of the inputs.
            beta (float): Scaling factor for the output tensor.
            kernel (CUDATemplateKernel): CUDA Template kernel for the operation.
            kernel (CUTLASSTemplateKernel): CUDA Template kernel for the operation.
            epilogue_args (any): Additional arguments for the epilogue state.

        Returns:
@ -1710,6 +1714,8 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):


class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
    """CUTLASS 2x GEMM Template, which is used to generate CUTLASS GEMM kernels"""

    def __init__(
        self,
        input_nodes: list[Buffer],
@ -1918,7 +1924,7 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
        Y: IRNode,
        alpha: float,
        beta: float,
        kernel: CUDATemplateKernel,
        kernel: CUTLASSTemplateKernel,
        epilogue_args,
    ) -> str:
        """
@ -1937,7 +1943,7 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
            Y (IRNode): The output tensor.
            alpha (float): Scaling factor for the product of the inputs.
            beta (float): Scaling factor for the output tensor.
            kernel (CUDATemplateKernel): CUDA Template kernel for the operation.
            kernel (CUTLASSTemplateKernel): CUDA Template kernel for the operation.
            epilogue_args (any): Additional arguments for the epilogue state.

        Returns:
@ -15,17 +15,17 @@ from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder
from torch.utils._sympy.value_ranges import ValueRanges

from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
from .utils import DTYPE_TO_CUTLASS_TYPE


if TYPE_CHECKING:
    from .cuda_template import ArgInfo
    from .template import ArgInfo

from ...autotune_process import CUDABenchmarkRequest
from ...autotune_process import CUTLASSBenchmarkRequest
from ...ir import (
    Buffer,
    ChoiceCaller,
    CUDATemplateBuffer,
    CUTLASSTemplateBuffer,
    IRNode,
    Layout,
    PrimitiveInfoType,
@ -46,7 +46,7 @@ from ..cpp_utils import CppPrinter, DTYPE_TO_CPP


if TYPE_CHECKING:
    from torch._inductor.codegen.cuda.cuda_template import CUDATemplate
    from torch._inductor.codegen.cutlass.template import CUTLASSTemplate

log = logging.getLogger(__name__)

@ -72,9 +72,9 @@ class LayoutArg:
        return self.node == node and self.attr == attr and self.dim == dim


class CUDAKernel(Kernel):
class CUTLASSKernel(Kernel):
    """
    Baseclass for CUDA / Cutlass based Kernels
    Baseclass for Cutlass based Kernels
    """

    overrides = OpOverrides  # type: ignore[assignment]
@ -191,9 +191,9 @@ class CUDAKernel(Kernel):
        return _normalize_idx(-1, len(strides))


class CUDATemplateKernel(CUDAKernel):
class CUTLASSTemplateKernel(CUTLASSKernel):
    """
    Template kernels defined by CUDA / Cutlass in C++.
    Template kernels defined by Cutlass in C++.
    """

    _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
@ -205,7 +205,7 @@ class CUDATemplateKernel(CUDAKernel):
        runtime_arg_values: list[Any],
    ) -> None:
        """
        Initializes a new instance of the CUDATemplateKernel class.
        Initializes a new instance of the CUTLASSTemplateKernel class.

        Args:
            kernel_name (str): The name of the kernel.
@ -328,14 +328,14 @@ class CUDATemplateKernel(CUDAKernel):
    def call_kernel(
        self,
        name: str,
        node: "CUDATemplateBuffer",  # type: ignore[name-defined]
        node: "CUTLASSTemplateBuffer",  # type: ignore[name-defined]
    ) -> None:
        """
        Generates code to call the kernel through V.graph.wrapper_code.
        used from within torch._inductor.wrapper.PythonWrapperCodegen

        name: Name of kernel function.
        node: The CUDATemplateBuffer node which contains information about the kernel, its fused epilogue nodes
        node: The CUTLASSTemplateBuffer node which contains information about the kernel, its fused epilogue nodes
        as well as all required inputs and outputs.
        """
        wrapper = V.graph.wrapper_code
@ -423,7 +423,7 @@ class CUDATemplateKernel(CUDAKernel):
        # Helper method, called into from CUTLASSGemmTemplate
        if node is None:
            return default_dtype
        from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate
        from torch._inductor.codegen.cutlass.template import CUTLASSTemplate

        return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]

@ -562,16 +562,16 @@ class CUDATemplateKernel(CUDAKernel):
        self.store_buffer_names.add(name)


class CUDATemplateCaller(ChoiceCaller):
class CUTLASSTemplateCaller(ChoiceCaller):
    """
    CUDATemplateCaller
    CUTLASSTemplateCaller

    This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller.
    This class represents a caller for CUTLASS template kernels. It is a subclass of ChoiceCaller.
    Attributes:
        name (str): The name of the caller.
        category (str): The category of the caller.
        bmreq (CUDABenchmarkRequest): The benchmark request for the caller.
        template_buffer (CUDATemplateBuffer): The template buffer for the caller.
        bmreq (CUTLASSBenchmarkRequest): The benchmark request for the caller.
        template_buffer (CUTLASSTemplateBuffer): The template buffer for the caller.
    """

    def __init__(
@ -581,12 +581,12 @@ class CUDATemplateCaller(ChoiceCaller):
        input_nodes: list[Buffer],
        layout: Layout,
        make_kernel_render: Callable[
            [CUDATemplateBuffer, Optional[list[BaseSchedulerNode]]],
            tuple[CUDATemplateKernel, functools.partial[str]],
            [CUTLASSTemplateBuffer, Optional[list[BaseSchedulerNode]]],
            tuple[CUTLASSTemplateKernel, functools.partial[str]],
        ],
        bmreq: CUDABenchmarkRequest,
        bmreq: CUTLASSBenchmarkRequest,
        supports_epilogue_fusion: bool,
        template: "CUDATemplate",  # type: ignore[name-defined]
        template: "CUTLASSTemplate",  # type: ignore[name-defined]
        info_kwargs: Optional[
            dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
        ],  # type: ignore[type-arg]
@ -612,10 +612,10 @@ class CUDATemplateCaller(ChoiceCaller):
        return self.bmreq.benchmark(*args, out=out)

    def __str__(self) -> str:
        return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"
        return f"CUTLASSTemplateCaller(source_file={self.bmreq.source_file})"

    def call_name(self) -> str:
        return f"cuda_template_kernels.{self.name}"
        return f"cutlass_template_kernels.{self.name}"

    def kernel_hash_key(self) -> str:
        """
@ -675,7 +675,7 @@ class CUDATemplateCaller(ChoiceCaller):
    def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]:
        self.bmreq.update_workspace_size()
        return TensorBox.create(
            CUDATemplateBuffer(
            CUTLASSTemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
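
A side note on how the renamed caller is consumed: the methods shown above are the whole autotuning surface. The sketch below is illustrative only (pick_best, callers, args, and out are hypothetical names); benchmark delegates to bmreq.benchmark, and output_node wraps the winning choice in a CUTLASSTemplateBuffer.

    # Hedged sketch; assumes the post-rename import path introduced by this diff.
    from torch._inductor.codegen.cutlass.kernel import CUTLASSTemplateCaller

    def pick_best(callers, args, out):
        # Benchmark each CUTLASS choice and materialize the fastest one.
        timings = {
            c: c.benchmark(*args, out=out)  # delegates to self.bmreq.benchmark
            for c in callers
            if isinstance(c, CUTLASSTemplateCaller)
        }
        best = min(timings, key=timings.get)
        return best.output_node()  # builds a CUTLASSTemplateBuffer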
@ -10,7 +10,7 @@ from torch._inductor.ir import (
)
from torch.utils._ordered_set import OrderedSet

from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass
from ..utils import torch_dtype_to_cutlass_type, try_import_cutlass


EpilogueFunctor = Any  # EpilogueFunctor local class defined in _trace
@ -237,7 +237,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
        size_hint_fn: Callable[[Union[Expr, int]], int],
        arg_renames: EVTArgRenames,
    ) -> str:
        from ..cuda_template import CUTLASSTemplate
        from ..template import CUTLASSTemplate

        # Today, arguments are either a pointer to the
        # node's memory, a stride tuple, the datatype
@ -1,5 +1,5 @@
# mypy: ignore-errors
from ..cutlass_utils import try_import_cutlass
from ..utils import try_import_cutlass


# copied / modified from original at
@ -183,11 +183,11 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):

    @staticmethod
    def ir_to_evt_python_code(
        cuda_template_node_name: str,
        cutlass_template_node_name: str,
        epilogue_nodes: list[BaseSchedulerNode],
        removed_buffers: OrderedSet[str],
    ) -> tuple[list[str], list[str], dict[str, Any], str]:
        codegen = CutlassEVTCodegen(cuda_template_node_name, removed_buffers)
        codegen = CutlassEVTCodegen(cutlass_template_node_name, removed_buffers)
        handler = _AssignmentFormatter(codegen)

        with virtualized.V.set_ops_handler(handler):
@ -4,7 +4,7 @@ import logging
from collections.abc import Sequence
from typing import cast

from torch._inductor.codegen.cuda.cutlass_python_evt import (
from torch._inductor.codegen.cutlass.python_evt import (
    CutlassEVTCodegen,
    MockCutlassHandler,
)
@ -14,7 +14,7 @@ from torch.utils._ordered_set import OrderedSet
from ...._dynamo.utils import counters
from ... import config
from ...codecache import code_hash, get_path
from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise
from ...ir import Buffer, ComputedBuffer, CUTLASSTemplateBuffer, Pointwise
from ...scheduler import (
    BaseSchedulerNode,
    BaseScheduling,
@ -36,13 +36,13 @@ class WhyNoFuseNames(WhyNoFuse):
        self.name2 = name2


class CUDACPPScheduling(BaseScheduling):
class CUTLASSScheduling(BaseScheduling):
    """
    Partial Scheduling implementation for CUDA C++ Kernels.
    Partial Scheduling implementation for cutlass C++ Kernels.
    This class is intended to be used in combination with TritonScheduling,
    and delegated to by CUDACombinedScheduling.
    and delegated to by CombinedScheduling.

    It handles fusion decisions and CUDA C++ specific template code generation.
    It handles fusion decisions and cutlass C++ specific template code generation.
    """

    @classmethod
@ -53,25 +53,25 @@ class CUDACPPScheduling(BaseScheduling):
        return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)

    @staticmethod
    def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
    def is_cutlass_template(node: BaseSchedulerNode) -> bool:
        return isinstance(node, SchedulerNode) and isinstance(
            node.node, CUDATemplateBuffer
            node.node, CUTLASSTemplateBuffer
        )

    def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
        return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node)
    def is_cutlass_fused_template(self, node: BaseSchedulerNode) -> bool:
        return isinstance(node, FusedSchedulerNode) and self.is_cutlass_template(node)

    def can_fuse_vertical(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        if self.is_cuda_cpp_template(node1) and isinstance(node2, BaseSchedulerNode):
        if self.is_cutlass_template(node1) and isinstance(node2, BaseSchedulerNode):
            assert node1.node, "node1.node should not be None"
            return self._can_fuse_epilogue_impl(
                cast(CUDATemplateBuffer, node1.node),
                cast(CUTLASSTemplateBuffer, node1.node),
                [],
                node2,  # type: ignore[arg-type]
            )
        elif self.is_cuda_cpp_fused_template(node1) and isinstance(
        elif self.is_cutlass_fused_template(node1) and isinstance(
            node2, BaseSchedulerNode
        ):
            assert node1.node, "node1.node should not be None"
@ -130,16 +130,16 @@ class CUDACPPScheduling(BaseScheduling):
        prologue_nodes: Sequence[BaseSchedulerNode],
    ):
        """
        Codegen a CUDA template, possibly with fused epilogues
        Codegen a cutlass template, possibly with fused epilogues
        """
        counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
        assert self.is_cuda_cpp_template(template_node), (
            "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
        counters["inductor"]["cutlass_epilogue_fusion_counter"] += len(epilogue_nodes)
        assert self.is_cutlass_template(template_node), (
            "Template node passed to CUTLASSScheduling.codegen_template must be a SchedulerNode that wraps a CUTLASSTemplateBuffer"
        )
        template_node = cast(SchedulerNode, template_node)
        _, (_numel, rnumel) = template_node.group
        assert rnumel == 1
        ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
        ctb: CUTLASSTemplateBuffer = cast(CUTLASSTemplateBuffer, template_node.node)
        epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes]  # type: ignore[misc]
        assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), (
            "Epilogue nodes must all be instances of ir.ComputedBuffer"
@ -197,7 +197,7 @@ class CUDACPPScheduling(BaseScheduling):

    def _can_fuse_epilogue_impl(
        self,
        cuda_template_buffer: CUDATemplateBuffer,
        cutlass_template_buffer: CUTLASSTemplateBuffer,
        existing_epilogue_nodes: list[BaseSchedulerNode],
        node_to_fuse: BaseSchedulerNode,
    ) -> bool:
@ -206,18 +206,20 @@ class CUDACPPScheduling(BaseScheduling):
        support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.

        Args:
            cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and its result buffer
            cutlass_template_buffer : A CUTLASSTemplateBuffer object representing the CUTLASS template and its result buffer
            existing_epilogue_nodes : List[SchedulerNode]: The list of already fused epilogue nodes.
            node_to_fuse: The SchedulerNode node to be checked if it can be fused with the epilogue.
        Returns:
        - bool: True if the given node can be fused with the epilogue, False otherwise.

        """
        why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())
        why = WhyNoFuseNames(
            cutlass_template_buffer.get_name(), node_to_fuse.get_name()
        )

        scheduler_nodes_to_fuse = node_to_fuse.get_nodes()

        assert isinstance(cuda_template_buffer, CUDATemplateBuffer)
        assert isinstance(cutlass_template_buffer, CUTLASSTemplateBuffer)

        # Checks on constituent nodes
        for s_node in scheduler_nodes_to_fuse:
@ -235,18 +237,18 @@ class CUDACPPScheduling(BaseScheduling):

            name = node.get_computed_buffer_name()  # type: ignore[attr-defined]
            # dtype can differ, and strides can differ as long as they are broadcastable
            if node.get_size() != cuda_template_buffer.get_size():
            if node.get_size() != cutlass_template_buffer.get_size():
                why(
                    f"{name}'s size: {node.get_size()} differs from {cuda_template_buffer.get_name()}'s \
size: {cuda_template_buffer.get_size()}"
                    f"{name}'s size: {node.get_size()} differs from {cutlass_template_buffer.get_name()}'s \
size: {cutlass_template_buffer.get_size()}"
                )
                return False

        assert len(
            existing_epilogue_nodes
        ) or cuda_template_buffer.get_name() in OrderedSet(
        ) or cutlass_template_buffer.get_name() in OrderedSet(
            [rd.name for rd in node_to_fuse.read_writes.reads]
        ), "First epilogue node must read from cuda template buffer"
        ), "First epilogue node must read from cutlass template buffer"

        if node_to_fuse.has_aliasing_or_mutation():
            why(f"{node_to_fuse.get_name()} has aliasing or mutation")
@ -257,22 +259,20 @@ size: {cuda_template_buffer.get_size()}"
            )
            return False
        elif (
            not config.cuda.cutlass_epilogue_fusion_enabled
            not config.cutlass.cutlass_epilogue_fusion_enabled
            or not config.epilogue_fusion
        ):
            why("cutlass epilogue fusion is not enabled")
            return False
        elif not cuda_template_buffer.supports_epilogue_fusion:
        elif not cutlass_template_buffer.supports_epilogue_fusion:
            why("epilogue fusion is only supported for TMA-enabled gemm ops")
            return False

        try:
            from torch._inductor.codegen.cuda.cutlass_python_evt import (
                CutlassEVTCodegen,
            )
            from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen

            CutlassEVTCodegen.ir_to_evt_python_code(
                cuda_template_buffer.get_name(),
                cutlass_template_buffer.get_name(),
                existing_epilogue_nodes + list(node_to_fuse.get_nodes()),
                OrderedSet(),
            )
@ -282,13 +282,13 @@ size: {cuda_template_buffer.get_size()}"
            if not_implemented_op.startswith("_op_"):
                not_implemented_op = not_implemented_op[4:]
                why(
                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \
                    f"Cannot fuse epilogue node {node_to_fuse} into {cutlass_template_buffer.name}, \
likely due to unsupported operation: {not_implemented_op}"  # noqa: G004, B950
                )
                return False
            else:  # Likely due to unsupported dtype.
                why(
                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \
                    f"Cannot fuse epilogue node {node_to_fuse} into {cutlass_template_buffer.name}. \
Reason: {not_implemented_op}"  # noqa: G004, B950
                )
                return False
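
The checks above reduce to a short decision chain. The following sketch is illustrative only (can_fuse_epilogue is a hypothetical condensation; the real method also records each refusal through WhyNoFuseNames), but the gate order matches the code: shape match, no aliasing or mutation, both fusion config flags on, TMA support, then a dry run of the EVT codegen.

    # Condensed, hypothetical version of the gates in _can_fuse_epilogue_impl.
    from torch._inductor import config
    from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
    from torch.utils._ordered_set import OrderedSet

    def can_fuse_epilogue(buf, node_to_fuse, existing_epilogue_nodes):
        if node_to_fuse.get_size() != buf.get_size():  # sizes must match
            return False
        if node_to_fuse.has_aliasing_or_mutation():    # no aliasing or mutation
            return False
        if not (config.cutlass.cutlass_epilogue_fusion_enabled and config.epilogue_fusion):
            return False                               # fusion disabled via config
        if not buf.supports_epilogue_fusion:           # TMA-enabled gemm ops only
            return False
        try:
            # Dry-run the EVT codegen; NotImplementedError flags unsupported ops/dtypes.
            CutlassEVTCodegen.ir_to_evt_python_code(
                buf.get_name(),
                existing_epilogue_nodes + list(node_to_fuse.get_nodes()),
                OrderedSet(),
            )
        except NotImplementedError:
            return False
        return True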
@ -4,7 +4,7 @@ import json
from enum import Enum
from typing import Any, Optional

from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
from torch._inductor.codegen.cutlass.utils import try_import_cutlass


class CUTLASSOperationSerializer:
@ -4,7 +4,6 @@ import hashlib
import itertools
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING, Union
from typing_extensions import override
from unittest.mock import patch

import sympy
@ -14,13 +13,13 @@ from torch._inductor import config
from torch._inductor.utils import clear_on_fresh_cache, Placeholder
from torch._logging import getArtifactLogger

from ...autotune_process import CUDABenchmarkRequest, TensorMeta
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
from ...autotune_process import CUTLASSBenchmarkRequest, TensorMeta
from ...ir import Buffer, CUTLASSTemplateBuffer, IRNode, Layout
from ...utils import IndentedBuffer, unique
from ...virtualized import V
from ..common import KernelTemplate
from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
from .kernel import CUTLASSTemplateCaller, CUTLASSTemplateKernel
from .utils import DTYPE_TO_CUTLASS_TYPE


if TYPE_CHECKING:
@ -40,7 +39,12 @@ class ArgInfo:


@clear_on_fresh_cache
class CUDATemplate(KernelTemplate):
class CUTLASSTemplate(KernelTemplate):
    """
    CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
    CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
    """

    index_counter = itertools.count()
    # dict of cache key to (code, size_args)
    code_cache: dict[str, tuple[str, tuple[int, ...], tuple[int, ...]]] = {}
@ -52,13 +56,14 @@ class CUDATemplate(KernelTemplate):
        input_nodes: list[Buffer],
        layout: Layout,
        input_reorder: Optional[list[int]] = None,
        device_type: str = "cuda",
    ) -> None:
        """
        Baseclass for CUDA C++ Templates, derived from KernelTemplate.
        Baseclass for CUTLASS C++ Templates, derived from KernelTemplate.
        Not to be instantiated directly.

        Args:
            name (str): The name of the CUDATemplate object.
            name (str): The name of the CUTLASSTemplate object.
            input_nodes (List[IRNode]): A list of input IRNodes.
            layout (Layout): The layout of the output buffer / tensor.
            input_reorder (Optional[List[int]]): An optional list that specifies
@ -69,6 +74,7 @@ class CUDATemplate(KernelTemplate):
        self.output_node: Buffer = Buffer(name="buf_out", layout=layout)
        self.input_reorder = input_reorder
        self.layout = layout
        self.device_type = device_type

    @classmethod
    @functools.lru_cache(None)
@ -110,7 +116,7 @@ class CUDATemplate(KernelTemplate):
        args are different.
        """
        key: Optional[str] = None
        if config.cuda.enable_caching_codegen:
        if config.cutlass.enable_caching_codegen:
            key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr)

        if key is not None and key in self.code_cache:
@ -123,7 +129,7 @@ class CUDATemplate(KernelTemplate):
            return code, extra_args

        kernel_name = str(Placeholder.KERNEL_NAME)
        kernel = CUDATemplateKernel(
        kernel = CUTLASSTemplateKernel(
            kernel_name=kernel_name,
            runtime_arg_info=self.get_runtime_arg_info(),
            runtime_arg_values=self.get_runtime_arg_values(**kwargs),
@ -174,10 +180,10 @@ class CUDATemplate(KernelTemplate):
        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        **kwargs,
    ) -> CUDATemplateCaller:
    ) -> CUTLASSTemplateCaller:
        """
        Generates the CUDA template caller object for the given GEMM template and operation.
        This CUDATemplateCaller may be used to call and benchmark the generated CUDA kernel
        This CUTLASSTemplateCaller may be used to call and benchmark the generated CUDA kernel
        in a standalone manner to enable Autotuning.

        Args:
@ -185,7 +191,7 @@ class CUDATemplate(KernelTemplate):
            kwargs: Additional keyword arguments.

        Returns:
            A CUDATemplateCaller object representing the generated CUDA template caller.
            A CUTLASSTemplateCaller object representing the generated CUDA template caller.
        """
        code, extra_args = self.generate_code_and_args(
            name=name,
@ -200,12 +206,13 @@ class CUDATemplate(KernelTemplate):
        code = code.replace(self.name, kernel_name)

        # create the BenchmarkRequest
        bmreq = CUDABenchmarkRequest(
        bmreq = CUTLASSBenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=input_tensor_meta,
            output_tensor_meta=output_tensor_meta,
            extra_args=extra_args,
            source_code=code,
            device_type=self.device_type,
        )

        # kwargs has "op" argument in case of CUTLASSGemmTemplate
@ -217,13 +224,13 @@ class CUDATemplate(KernelTemplate):
            supports_epilogue_fusion = self.supports_epilogue_fusion(op)

        def make_kernel_render(
            template_node: CUDATemplateBuffer,
            template_node: CUTLASSTemplateBuffer,
            epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
        ) -> tuple[CUDATemplateKernel, functools.partial[str]]:
        ) -> tuple[CUTLASSTemplateKernel, functools.partial[str]]:
            assert supports_epilogue_fusion or not epilogue_nodes, (
                "epilogue fusion is not supported for this kernel"
            )
            kernel = CUDATemplateKernel(
            kernel = CUTLASSTemplateKernel(
                kernel_name=str(Placeholder.KERNEL_NAME),
                runtime_arg_info=self.get_runtime_arg_info(),
                runtime_arg_values=self.get_runtime_arg_values(**kwargs),
@ -237,7 +244,7 @@ class CUDATemplate(KernelTemplate):
            )
            return kernel, render

        return CUDATemplateCaller(
        return CUTLASSTemplateCaller(
            kernel_name,
            "cutlass_gemm",
            self.input_nodes,
@ -261,6 +268,18 @@ class CUDATemplate(KernelTemplate):
                #include <vector>
            """
        )
        res.splice(
            """
                #include "cute/tensor.hpp"
                #include "cutlass/cutlass.h"
                #include "cutlass/numeric_types.h"
                #include "cutlass/tensor_ref.h"
                #include "cutlass/util/host_tensor.h"
                #include "cutlass/util/reference/host/tensor_fill.h"
                #include "cutlass/util/reference/device/tensor_fill.h"
                #include "cutlass/util/device_memory.h"
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
@ -281,42 +300,6 @@ class CUDATemplate(KernelTemplate):
                #endif
            """
        )
        return res

    def render(self, **kwargs) -> str:
        raise NotImplementedError

    def get_runtime_arg_info(self) -> list[ArgInfo]:
        return []

    def get_runtime_arg_values(self, **kwargs) -> list[Any]:
        return []


class CUTLASSTemplate(CUDATemplate):
    """
    CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
    CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
    """

    def header(self) -> IndentedBuffer:
        res = super().header()
        res.splice(
            """
                #include "cute/tensor.hpp"
                #include "cutlass/cutlass.h"
                #include "cutlass/numeric_types.h"
                #include "cutlass/tensor_ref.h"
                #include "cutlass/util/host_tensor.h"
                #include "cutlass/util/reference/host/tensor_fill.h"
                #include "cutlass/util/reference/device/tensor_fill.h"
                #include "cutlass/util/device_memory.h"
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        res = super().globals()
        res.splice(
            """
                using namespace cute;
@ -382,11 +365,12 @@ class CUTLASSTemplate(CUDATemplate):
                f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})"
            )

    @override
    def render(self, **kwargs) -> str:
        raise NotImplementedError

    def get_runtime_arg_info(self) -> list[ArgInfo]:
        return [ArgInfo("swizzle", "const uint8_t")]

    @override
    def get_runtime_arg_values(self, **kwargs) -> list[Any]:
        """
        Helper method to retrieve runtime args from generate kwargs
@ -7,7 +7,6 @@ import shutil
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
from typing_extensions import TypeIs

@ -22,8 +21,7 @@ from ... import config
from ...ir import Layout
from ...runtime.runtime_utils import cache_dir
from ...virtualized import V
from ..cpp_utils import DTYPE_TO_CPP
from .cuda_env import get_cuda_arch, get_cuda_version
from ..cuda.cuda_env import get_cuda_arch, get_cuda_version


log = logging.getLogger(__name__)
@ -98,14 +96,14 @@ def try_import_cutlass() -> bool:

    # contains both cutlass and cutlass_library
    # we need cutlass for eVT
    cutlass_python_path = path_join(config.cuda.cutlass_dir, "python")
    cutlass_python_path = path_join(config.cutlass.cutlass_dir, "python")
    torch_root = os.path.abspath(os.path.dirname(torch.__file__))
    mock_src_path = os.path.join(
        torch_root,
        "_inductor",
        "codegen",
        "cuda",
        "cutlass_lib_extensions",
        "cutlass",
        "lib_extensions",
        "cutlass_mock_imports",
    )

@ -252,7 +250,7 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]:
        )
        return {}
    arch = _normalize_cuda_arch(arch)
    instantiation_level: str = config.cuda.cutlass_instantiation_level
    instantiation_level: str = config.cutlass.cutlass_instantiation_level
    args = CUTLASSArgs(
        architectures=arch,
        cuda_version=version,
@ -293,6 +291,9 @@ def gen_ops() -> dict[Any, Any]:
        return _gen_ops_cached(arch, version)


from ..cpp_utils import DTYPE_TO_CPP


DTYPE_TO_CUTLASS_TYPE = {
    **DTYPE_TO_CPP,
    torch.float16: "__half",
@ -447,47 +448,3 @@ def get_max_alignment(inductor_layout: Layout) -> int:
        ):
            return alignment
    return 1


class CUDACompileSourceCapturingContext:
    # Helper class for Benchmarking and Testing CUTLASS Kernels in isolation.
    # Can be used to capture the sourcecode passed to CUDACodeCache.compile

    def __init__(self):
        self.sources = []
        self._compile_patch = None

    def __enter__(self, *args, **kwargs):
        import unittest.mock as mock

        import torch._inductor.codecache

        _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile

        def my_compile(
            source_code, dst_file_ext, extra_args: Optional[list[str]] = None
        ):
            self.sources.append(source_code)
            return _compile_method_orig(source_code, dst_file_ext)

        # pyrefly: ignore [bad-assignment]
        self._compile_patch = mock.patch(
            "torch._inductor.codecache.CUDACodeCache.compile", my_compile
        )
        self._compile_patch.__enter__(*args, **kwargs)  # type: ignore[union-attr]
        return self

    def __exit__(self, *args, **kwargs):
        self._compile_patch.__exit__(*args, **kwargs)  # type: ignore[union-attr]


def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path):
    # returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run
    # Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled.
    from torch._inductor.codecache import cuda_compile_command

    extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]
    compile_command = cuda_compile_command(
        [str(srcpath)], str(exepath), "exe", extra_args=extra_args
    )
    return compile_command
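
The two helpers above are removed from this module by the hunk; paired together they capture Inductor's generated CUTLASS source and turn it into a standalone binary. A usage sketch, assuming the helpers remain importable from their new location and using hypothetical gemm, a, and b inputs:

    # Illustrative workflow only; file paths and the compiled function are assumptions.
    import subprocess
    from pathlib import Path

    import torch

    ctx = CUDACompileSourceCapturingContext()
    with ctx:
        torch.compile(gemm)(a, b)        # triggers CUTLASS codegen and compilation

    src, exe = Path("/tmp/kernel.cu"), Path("/tmp/kernel_runner")
    src.write_text(ctx.sources[-1])      # last source passed to CUDACodeCache.compile
    cmd = cuda_standalone_runner_compile_command(src, exe)
    subprocess.run(cmd, shell=True, check=True)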
@ -19,7 +19,7 @@ class ROCmCPPScheduling(BaseScheduling):
    """
    Partial Scheduling implementation for ROCm C++ Kernels.
    This class is intended to be used in combination with TritonScheduling,
    and delegated to by CUDACombinedScheduling.
    and delegated to by CombinedScheduling.

    It handles fusion decisions and ROCm C++ specific template code generation.
    """

@ -1764,7 +1764,7 @@ class SIMDScheduling(BaseScheduling):
                    partial_code.finalize_hook("<DEF_KERNEL>")
                partial_code.finalize_hook("<ARGDEFS>", strict=False)

            # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion.
            # TODO: Maybe unify CUTLASSTemplateKernel to also use PartialRender for flexible epilogue fusion.

            for input_name in kernel.named_input_nodes.keys():
                subgraph_name = f"<LOAD_INPUT_{input_name}>"

@ -5886,14 +5886,12 @@ def debug_triton_code(node: BaseSchedulerNode) -> list[str]:
    if multi_template and multi_template.make_kernel_render is None:
        lines.append(f"{node.get_name()} Unfinalized multi template buffer")
    else:
        from torch._inductor.codegen.cuda_combined_scheduling import (
            CUDACombinedScheduling,
        )
        from torch._inductor.codegen.combined_scheduling import CombinedScheduling

        device = node.get_device()
        assert device is not None
        backend = node.scheduler.get_backend(device)
        assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), (
        assert isinstance(backend, (SIMDScheduling, CombinedScheduling)), (
            f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
        )


@ -1713,28 +1713,12 @@ class aot_inductor_mode:
    compile_standalone: bool = False


class cuda:
    """Settings for cuda backend, today this consists of cutlass"""

    # CUDA arch to use for CUDA template kernel compilation.
    # e.g. "70", "75", "80", "90", etc.
    # When arch is None, Inductor uses torch.cuda.get_device_capability(0).
    arch: Optional[str] = None

    # CUDA version to use for CUDA template kernel compilation.
    # e.g. "11.4", "12.1", etc.
    # When version is None, Inductor uses torch.version.cuda.
    version: Optional[str] = None
class cutlass:
    """Settings for cutlass backend, today this consists of cutlass"""

    # Optimization level for the host compiler.
    compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1"

    # Whether to enable device LTO (link-time-optimization).
    enable_cuda_lto = False

    # Whether to keep intermediate files during compilation.
    enable_ptxas_info = False

    # Whether to enable debug info, e.g. line number, cutlass debug info.
    enable_debug_info = False

@ -1746,7 +1730,10 @@ class cuda:
    cutlass_dir = os.path.realpath(
        os.environ.get(
            "TORCHINDUCTOR_CUTLASS_DIR",
            os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/"),
            os.path.join(
                os.path.dirname(torch.__file__),
                "../third_party/cutlass/",
            ),
        )
    )

@ -1766,14 +1753,6 @@ class cuda:
    # Whether to only use TMA-compatible kernels in CUTLASS
    cutlass_tma_only = False

    # Path to CUDA NVCC.
    # NVCC search order:
    # 1) cuda_cxx set in this config
    # 2) CUDACXX environment variable
    # 3) CUDA_HOME environment variable
    # 4) default system search PATH.
    cuda_cxx: Optional[str] = None

    # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops.
    cutlass_backend_min_gemm_size: int = 1

@ -1843,6 +1822,53 @@ class cuda:
    enable_caching_codegen: bool = True


class cuda(cutlass):
    """Settings for cutlass backend, today this consists of cutlass"""

    # CUDA arch to use for CUDA template kernel compilation.
    # e.g. "70", "75", "80", "90", etc.
    # When arch is None, Inductor uses torch.cuda.get_device_capability(0).
    arch: Optional[str] = None

    # CUDA version to use for CUDA template kernel compilation.
    # e.g. "11.4", "12.1", etc.
    # When version is None, Inductor uses torch.version.cuda.
    version: Optional[str] = None

    # Path to CUDA NVCC.
    # NVCC search order:
    # 1) cuda_cxx set in this config
    # 2) CUDACXX environment variable
    # 3) CUDA_HOME environment variable
    # 4) default system search PATH.
    cuda_cxx: Optional[str] = None

    # Whether to enable device LTO (link-time-optimization).
    enable_cuda_lto = False

    # Whether to keep intermediate files during compilation.
    enable_ptxas_info = False


class xpu(cutlass):
    # Xe arch to use for SYCL template kernel compilation.
    # e.g. 12, 20, which correspond to Xe12 (PVC) and Xe20 (BMG)
    arch: Optional[str] = None
    # oneAPI version to use for SYCL template kernel compilation.
    # e.g. "20250201".
    version: Optional[str] = None

    cutlass_dir = os.path.realpath(
        os.environ.get(
            "TORCHINDUCTOR_CUTLASS_DIR",
            os.path.join(
                os.path.dirname(torch.__file__),
                "../third_party/sycl-tla/",
            ),
        )
    )
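
After this split, shared CUTLASS knobs live on the cutlass base class while device-specific ones stay on the cuda and xpu subclasses. A hedged sketch of the resulting call sites (the values are examples only; the option names are the ones defined above):

    # Example values only; option names come from the config classes above.
    from torch._inductor import config

    config.cutlass.cutlass_instantiation_level = "0"       # shared CUTLASS knob
    config.cutlass.cutlass_epilogue_fusion_enabled = True  # shared CUTLASS knob
    config.cuda.arch = "90"      # CUDA-only: arch for template kernel compilation
    config.cuda.cuda_cxx = None  # CUDA-only: fall back to CUDACXX/CUDA_HOME/PATH
    config.xpu.arch = "20"       # XPU-only: Xe20 (BMG)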
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class rocm:
 | 
			
		||||
    # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"].
 | 
			
		||||
    # If empty, the `native` arch is used
 | 
			
		||||
 | 
			
		||||
@ -2196,7 +2196,9 @@ class CppBuilder:
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if device_type == "cuda" and torch.version.hip is None:
 | 
			
		||||
            from torch._inductor.codecache import _nvcc_arch_as_compile_option
 | 
			
		||||
            from torch._inductor.codegen.cuda.compile_utils import (
 | 
			
		||||
                _nvcc_arch_as_compile_option,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            current_arch = _nvcc_arch_as_compile_option()
 | 
			
		||||
            contents += textwrap.dedent(
 | 
			
		||||
 | 
			
		||||
@ -489,7 +489,7 @@ MODULE_DEFAULTS: dict[str, ConfigType] = {
 | 
			
		||||
        "aot_inductor.presets": DEFAULT,  # Typing
 | 
			
		||||
        "cuda.arch": DEFAULT,  # Out of Scope
 | 
			
		||||
        "cuda.version": DEFAULT,  # Out of Scope
 | 
			
		||||
        "cuda.cutlass_dir": DEFAULT,  # Out of Scope
 | 
			
		||||
        "cutlass.cutlass_dir": DEFAULT,  # Out of Scope
 | 
			
		||||
        "cuda.cuda_cxx": DEFAULT,  # Out of Scope
 | 
			
		||||
        "rocm.arch": DEFAULT,  # Out of Scope
 | 
			
		||||
        "rocm.ck_supported_arch": DEFAULT,  # Out of Scope
 | 
			
		||||
 | 
			
		||||
@ -124,13 +124,13 @@ if TYPE_CHECKING:
 | 
			
		||||
    from torch.fx.experimental.symbolic_shapes import SympyBoolean
 | 
			
		||||
    from torch.fx.node import Argument
 | 
			
		||||
 | 
			
		||||
    from .codegen.cuda.cuda_template import CUDATemplate
 | 
			
		||||
    from .codegen.cutlass.template import CUTLASSTemplate
 | 
			
		||||
    from .codegen.wrapper import PythonWrapperCodegen
 | 
			
		||||
    from .graph import GraphLowering
 | 
			
		||||
    from .utils import IndentedBuffer
 | 
			
		||||
 | 
			
		||||
else:
 | 
			
		||||
    CUDATemplate: TypeAlias = object
 | 
			
		||||
    CUTLASSTemplate: TypeAlias = object
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
@ -5017,7 +5017,7 @@ class ChoiceCaller:
 | 
			
		||||
    During autotuning, self.benchmark() is first called to get benchmark result,
 | 
			
		||||
    and if this choice is selected, self.output_node() is called to get the output_node.
 | 
			
		||||
 | 
			
		||||
    Children classes: TritonTemplateCaller, CUDATemplateCaller.
 | 
			
		||||
    Children classes: TritonTemplateCaller, CUTLASSTemplateCaller.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
@ -5180,14 +5180,14 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
 | 
			
		||||
        self.make_kernel_render = self._make_kernel_renders[None]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CUDATemplateBuffer(TemplateBuffer):
 | 
			
		||||
class CUTLASSTemplateBuffer(TemplateBuffer):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        layout: Layout,
 | 
			
		||||
        inputs: Sequence[IRNode],
 | 
			
		||||
        make_kernel_render: Callable[_P, _T],
 | 
			
		||||
        workspace_size: int,
 | 
			
		||||
        template: CUDATemplate,
 | 
			
		||||
        template: CUTLASSTemplate,
 | 
			
		||||
        supports_epilogue_fusion: bool,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__(layout, inputs, make_kernel_render)
 | 
			
		||||
@ -5210,7 +5210,7 @@ class CppTemplateBuffer(TemplateBuffer):
 | 
			
		||||
        layout: Layout,
 | 
			
		||||
        inputs: Sequence[IRNode],
 | 
			
		||||
        make_kernel_render: Callable[_P, _T],
 | 
			
		||||
        template: CUDATemplate,
 | 
			
		||||
        template: CUTLASSTemplate,
 | 
			
		||||
        choice: Any,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__(layout, inputs, make_kernel_render)
 | 
			
		||||
 | 
			
		||||
@ -261,7 +261,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
 | 
			
		||||
        and use_cutlass_template(layout, m, n, k)
 | 
			
		||||
        and _use_cutlass_for_op(name)
 | 
			
		||||
    ):
 | 
			
		||||
        from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
 | 
			
		||||
        from ..codegen.cutlass.gemm_template import CUTLASS3xGemmTemplate
 | 
			
		||||
 | 
			
		||||
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
 | 
			
		||||
            choices, layout, kernel_inputs.nodes()
 | 
			
		||||
 | 
			
		||||
@ -20,7 +20,7 @@ from torch.nn.functional import ScalingType  # type: ignore[attr-defined]
 | 
			
		||||
from torch.torch_version import TorchVersion
 | 
			
		||||
 | 
			
		||||
from .. import config as inductor_config
 | 
			
		||||
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
 | 
			
		||||
from ..codegen.cutlass.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
 | 
			
		||||
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
 | 
			
		||||
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
 | 
			
		||||
from ..codegen.subgraph import SubgraphChoiceCaller, SubgraphTemplate
 | 
			
		||||
 | 
			
		||||
@ -5510,10 +5510,10 @@ class Scheduler:
 | 
			
		||||
                node = typing.cast(ForeachKernelSchedulerNode, node)
 | 
			
		||||
                # pyrefly: ignore [unbound-name]
 | 
			
		||||
                backend_ = self.get_backend(device)
 | 
			
		||||
                from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
 | 
			
		||||
                from .codegen.combined_scheduling import CombinedScheduling
 | 
			
		||||
                from .codegen.simd import SIMDScheduling
 | 
			
		||||
 | 
			
		||||
                if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
 | 
			
		||||
                if isinstance(backend_, (SIMDScheduling, CombinedScheduling)):
 | 
			
		||||
                    backend = backend_
 | 
			
		||||
                else:
 | 
			
		||||
                    raise AssertionError(f"{type(self)=}")
 | 
			
		||||
 | 
			
		||||
@@ -2665,7 +2665,7 @@ class AlgorithmSelectorCache(PersistentCache):
         return_multi_template=False,
         best_config_future=None,
     ):
-        from .codegen.cuda.cuda_kernel import CUDATemplateCaller
+        from .codegen.cutlass.kernel import CUTLASSTemplateCaller
 
         # Run preprocessing functions on choices
         for preprocessing_fn in self.preprocessing_fns:
@@ -2701,8 +2701,8 @@ class AlgorithmSelectorCache(PersistentCache):
         log.debug("Max autotune selects from %s choices.", str(len(choices)))
 
         if len(choices) == 1:
-            if not isinstance(choices[0], CUDATemplateCaller):
-                # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
+            if not isinstance(choices[0], CUTLASSTemplateCaller):
+                # CUTLASSTemplateCaller still needs to go through autotuning process to retrieve workspace size.
                 return choices[0].output_node()
 
         if config.deterministic:
@@ -3086,12 +3086,12 @@ class AlgorithmSelectorCache(PersistentCache):
                         "select_algorithm_num_precompilation_exceptions"
                     ] += 1
                     exceptions.append((futures[future], e))
-                    from torch._inductor.codegen.cuda.cuda_kernel import (
-                        CUDATemplateCaller,
+                    from torch._inductor.codegen.cutlass.kernel import (
+                        CUTLASSTemplateCaller,
                     )
 
                     if isinstance(e, CUDACompileError) and isinstance(
-                        futures[future], CUDATemplateCaller
+                        futures[future], CUTLASSTemplateCaller
                     ):
                         log.debug(
                             "Exception %s for benchmark choice %s",
@@ -3272,19 +3272,20 @@ class AlgorithmSelectorCache(PersistentCache):
         for choice in choices:
             try:
                 timing = cls.benchmark_choice(choice, autotune_args)
-            except CUDACompileError:
-                from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
+            except CUDACompileError as e:
+                from torch._inductor.codegen.cutlass.kernel import CUTLASSTemplateCaller
 
-                if not isinstance(choice, CUDATemplateCaller):
+                if not isinstance(choice, CUTLASSTemplateCaller):
                     log.exception(
-                        "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice."
+                        "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.",
+                        e,
                     )
                 timing = float("inf")
             except NotImplementedError:
                 log.warning("Not yet implemented", exc_info=True)
                 timing = float("inf")
             except RuntimeError as e:
-                from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
+                from torch._inductor.codegen.cutlass.kernel import CUTLASSTemplateCaller
 
                 msg = str(e)
                 if "invalid argument" in msg:
@@ -3295,7 +3296,7 @@ class AlgorithmSelectorCache(PersistentCache):
                     msg += "\n\nAn unrecoverable unspecified launch failure was caught during autotuning."
                     msg += "\nPlease try re-running with TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1.\n\n"
 
-                if isinstance(choice, CUDATemplateCaller):
+                if isinstance(choice, CUTLASSTemplateCaller):
                     log.debug(
                         "Runtime error during autotuning: \n%s. \nIgnoring this choice.",
                         msg,
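Beyond the class rename, the hunk at -3272 above also fixes a latent logging bug: the old call passed a format string containing %s to log.exception without a corresponding argument, while the new code binds the exception as e and supplies it. A minimal stdlib-only illustration of the corrected pattern (the message text mirrors the diff; the rest is a stand-alone sketch):

# The hunk above fixes a logging call that had a "%s" placeholder but no
# argument. With stdlib logging, arguments are passed separately and
# formatted lazily; log.exception also attaches the active traceback.
import logging

log = logging.getLogger(__name__)

try:
    raise RuntimeError("nvcc exited with status 1")  # stand-in failure
except RuntimeError as e:
    # Correct: the placeholder and its argument are paired.
    log.exception(
        "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.",
        e,
    )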
@@ -3429,18 +3430,18 @@ class AlgorithmSelectorCache(PersistentCache):
             return prescreen_winners
 
         # prescreen cutlass
-        from .codegen.cuda.cuda_kernel import CUDATemplateCaller
+        from .codegen.cutlass.kernel import CUTLASSTemplateCaller
 
         candidates = []
         if (
-            config.cuda.cutlass_prescreening
-            and len(config.cuda.cutlass_max_profiling_swizzle_options) > 1
+            config.cutlass.cutlass_prescreening
+            and len(config.cutlass.cutlass_max_profiling_swizzle_options) > 1
         ):
             candidates.extend(
                 [
                     c
                     for c in choices
-                    if isinstance(c, CUDATemplateCaller)
+                    if isinstance(c, CUTLASSTemplateCaller)
                     # hardcoded to only look at swizzle=2
                     if c.info_dict().get("swizzle") == "2"
                 ]
@@ -3463,7 +3464,7 @@ class AlgorithmSelectorCache(PersistentCache):
         """
         Prune the choices after prescreening.
         """
-        from .codegen.cuda.cuda_kernel import CUDATemplateCaller
+        from .codegen.cutlass.kernel import CUTLASSTemplateCaller
 
         prescreen_key = f"{name}:{inputs_key}"
 
@@ -3477,7 +3478,7 @@ class AlgorithmSelectorCache(PersistentCache):
             pruned_choices = [
                 choice
                 for choice in choices
-                if not isinstance(choice, CUDATemplateCaller)
+                if not isinstance(choice, CUTLASSTemplateCaller)
                 or choice.kernel_hash_key() in winner_kernel_hashes
             ]
             return pruned_choices
@@ -3523,7 +3524,7 @@ class AlgorithmSelectorCache(PersistentCache):
                 candidates_to_prune.add(candidate.kernel_hash_key())
             else:
                 winner_hashes.add(candidate.hash_key())
-                if isinstance(candidate, CUDATemplateCaller):
+                if isinstance(candidate, CUTLASSTemplateCaller):
                     candidate.bmreq.ensure_dll_loaded()
 
         pruned_choices = [
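Taken together, the prescreening hunks above filter the candidate list down to CUTLASSTemplateCaller choices at swizzle=2, record the winners' kernel hashes, and later drop every CUTLASS choice whose hash did not win while leaving non-CUTLASS choices untouched. A minimal stand-alone sketch of that filter-then-prune shape (Choice, the fake timings, and top_k are hypothetical stand-ins, not the torch._inductor API):

# Stand-alone sketch of the prescreen-then-prune flow shown above.
from dataclasses import dataclass
import random

@dataclass(frozen=True)
class Choice:
    name: str
    is_cutlass: bool
    swizzle: str
    kernel_hash: str

def prescreen(choices: list[Choice], top_k: int = 3) -> set[str]:
    # Only CUTLASS choices at swizzle=2 enter the cheap first pass,
    # mirroring the hardcoded filter in the hunk at -3429.
    candidates = [c for c in choices if c.is_cutlass and c.swizzle == "2"]
    timings = {c.kernel_hash: random.random() for c in candidates}  # fake benchmark
    winners = sorted(candidates, key=lambda c: timings[c.kernel_hash])[:top_k]
    return {c.kernel_hash for c in winners}

def prune(choices: list[Choice], winner_hashes: set[str]) -> list[Choice]:
    # Non-CUTLASS choices always survive; CUTLASS choices must have won,
    # mirroring the filter in the hunk at -3477.
    return [c for c in choices if not c.is_cutlass or c.kernel_hash in winner_hashes]

choices = [
    Choice("triton_mm", False, "-", "t0"),
    Choice("cutlass_a", True, "2", "k1"),
    Choice("cutlass_b", True, "2", "k2"),
    Choice("cutlass_b_swz4", True, "4", "k2"),  # shares k2, kept only if k2 wins
]
kept = prune(choices, prescreen(choices, top_k=1))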
@@ -1872,9 +1872,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
     from .virtualized import V
 
     gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
-    if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size:
+    if gemm_size <= 0 or gemm_size < config.cutlass.cutlass_backend_min_gemm_size:
         return False
-    from .codegen.cuda.cutlass_utils import try_import_cutlass
+    from .codegen.cutlass.utils import try_import_cutlass
 
     # Do not use cutlass template on ROCm
     if torch.version.hip:
@@ -1893,9 +1893,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
         if not try_import_cutlass():
             log.warning(
                 "Failed to import CUTLASS lib. Please check whether "
-                "_inductor.config.cuda.cutlass_dir %s is set correctly. "
+                "_inductor.config.cutlass.cutlass_dir %s is set correctly. "
                 "Skipping CUTLASS backend for now.",
-                config.cuda.cutlass_dir,
+                config.cutlass.cutlass_dir,
             )
             return False
     return res
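Both hunks read CUTLASS knobs from the renamed config.cutlass namespace, so user code that pins these options moves with them. A hedged sketch of what configuration looks like after this change (only option names visible in the hunks above are used; the directory path is a placeholder):

# Hypothetical user configuration after the config.cuda -> config.cutlass
# rename; option names are taken from the hunks above.
import torch._inductor.config as inductor_config

inductor_config.cutlass.cutlass_dir = "/path/to/cutlass"  # checked via try_import_cutlass()
inductor_config.cutlass.cutlass_backend_min_gemm_size = 1024  # skip CUTLASS when m*n*k is below this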
@@ -1903,7 +1903,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
 
 def _use_cutlass_for_op(op_name: str) -> bool:
     """Check if CUTLASS should be used for the given operation."""
-    enabled_ops = config.cuda.cutlass_enabled_ops.upper()
+    enabled_ops = config.cutlass.cutlass_enabled_ops.upper()
     if enabled_ops == "ALL":
         return True
     return op_name.upper() in [x.strip() for x in enabled_ops.split(",")]
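Since _use_cutlass_for_op appears in full above, its matching rule can be read directly: "ALL" acts as a wildcard, otherwise the op name is matched case-insensitively against a comma-separated allowlist. A small stand-alone check of that logic (re-implemented here for illustration, not imported from torch):

# Re-implementation of the matching rule from _use_cutlass_for_op above.
def use_cutlass_for_op(op_name: str, enabled_ops: str) -> bool:
    enabled = enabled_ops.upper()
    if enabled == "ALL":  # wildcard: every op may use CUTLASS
        return True
    # Otherwise match case-insensitively against a comma-separated list.
    return op_name.upper() in [x.strip() for x in enabled.split(",")]

assert use_cutlass_for_op("mm", "ALL")
assert use_cutlass_for_op("addmm", "mm, addmm")
assert not use_cutlass_for_op("bmm", "mm, addmm")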