Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-01 04:54:55 +08:00.

Compare commits: 40 commits (csl/xml_st ... ciflow/ind)
| SHA1 | Author | Date | |
|---|---|---|---|
| 10fb096070 | |||
| d943cfe814 | |||
| 0839124dae | |||
| b7e073c4d0 | |||
| 08ceb31352 | |||
| 8c6cab4dad | |||
| 0a0edbbfcf | |||
| 690ed68946 | |||
| 76ab0f1a09 | |||
| 5bd370a702 | |||
| b1eb229d30 | |||
| 1c4b94c0d2 | |||
| 0f60e4c18d | |||
| 23e3d4cfe2 | |||
| 4124804146 | |||
| 663aa67934 | |||
| f0d3f7db52 | |||
| ec887d1962 | |||
| 2cae11a0d1 | |||
| c5b67af4e7 | |||
| 53e25e520e | |||
| 4381da8371 | |||
| 938b5b7424 | |||
| 1fc3e1abe8 | |||
| f7d1d70526 | |||
| 3ae7cacd26 | |||
| 2060620b08 | |||
| 501a28a3e7 | |||
| ac2acb2097 | |||
| 8bc69def41 | |||
| 874fc3199c | |||
| a0819a3a88 | |||
| e36aeb0a2e | |||
| c924349562 | |||
| c4a551d72f | |||
| 08f5fe8139 | |||
| 8604af4bfa | |||
| f4ef77b220 | |||
| 44789dc58d | |||
| 2ddfc76a10 |
@@ -125,7 +125,7 @@ class CutlassExperimentConfig(ExperimentConfig):
     def to_options(self) -> dict[str, Any]:
         return {
             **super().to_options(),
-            "cuda.cutlass_instantiation_level": self.cutlass_instantiation_level,
+            "cutlass.cutlass_instantiation_level": self.cutlass_instantiation_level,
         }
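The change that repeats throughout these commits is a rename of the CUTLASS-related Inductor options from the `cuda.*` config namespace to a new `cutlass.*` namespace (genuinely CUDA-specific knobs such as `cuda.cuda_version` stay under `cuda`). As a minimal sketch of what calling code looks like after the rename, using only keys that appear in the hunks below (tensor shapes and values are illustrative):

```python
import torch
from torch._inductor import config

a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 128, device="cuda", dtype=torch.float16)

# Options that previously lived under "cuda.*" are now under "cutlass.*".
with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cutlass.cutlass_max_profiling_configs": 2,  # was "cuda.cutlass_max_profiling_configs"
        "cutlass.cutlass_tma_only": True,  # was "cuda.cutlass_tma_only"
    }
):
    y = torch.compile(torch.mm)(a, b)
```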
@@ -24,7 +24,6 @@ from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
 from torch._inductor import config, metrics
 from torch._inductor.codecache import (
     BypassFxGraphCache,
-    cuda_compile_command,
     CUDACodeCache,
     FxGraphCachePickler,
     FxGraphHashDetails,
@@ -32,6 +31,7 @@ from torch._inductor.codecache import (
     TensorMetadata,
     TensorMetadataAndValues,
 )
+from torch._inductor.codegen.cuda.compile_utils import cuda_compile_command
 from torch._inductor.cpp_builder import normalize_path_separator
 from torch._inductor.custom_graph_pass import (
     CustomGraphModulePass,
@@ -14,7 +14,9 @@ from pathlib import Path
 from typing import Optional

 from torch._dynamo.exc import BackendCompilerFailed
-from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
+from torch._inductor.codegen.cutlass.serialization import (
+    get_cutlass_operation_serializer,
+)
 from torch._inductor.utils import clear_caches
 from torch.export import Dim
 from torch.testing._internal.logging_utils import log_settings
@@ -32,11 +34,8 @@ import torch.version
 from torch._dynamo import config as dynamo_config
 from torch._dynamo.utils import counters
 from torch._inductor import config
-from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
-from torch._inductor.codegen.cuda.cutlass_utils import (
-    _gen_ops_cached,
-    get_max_alignment,
-)
+from torch._inductor.codegen.cutlass.kernel import CUTLASSTemplateCaller
+from torch._inductor.codegen.cutlass.utils import _gen_ops_cached, get_max_alignment
 from torch._inductor.exc import InductorError
 from torch._inductor.ir import FixedLayout
 from torch._inductor.select_algorithm import NoValidChoicesError
@@ -133,10 +132,10 @@ use_evt_config = config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
-        "cuda.cutlass_max_profiling_configs": 1,
+        "cutlass.cutlass_max_profiling_configs": 1,
        "benchmark_epilogue_fusion": False,  # EVT doesn't support benchmark fusion yet
-        "cuda.cutlass_tma_only": True,
-        "cuda.cutlass_epilogue_fusion_enabled": True,
+        "cutlass.cutlass_tma_only": True,
+        "cutlass.cutlass_epilogue_fusion_enabled": True,
    }
)

@@ -144,9 +143,9 @@ fp8_config = config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
-        "cuda.cutlass_max_profiling_configs": 1,
+        "cutlass.cutlass_max_profiling_configs": 1,
        "benchmark_epilogue_fusion": False,  # EVT doesn't support benchmark fusion yet
-        "cuda.cutlass_tma_only": True,
+        "cutlass.cutlass_tma_only": True,
    }
)
@@ -198,7 +197,7 @@ class TestCutlassBackend(TestCase):
        ref_result = model(a, b, extra_args)

        self.assertEqual(
-            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
+            torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
            num_fusions,
        )
        torch.testing.assert_close(result, ref_result)
@@ -206,7 +205,7 @@ class TestCutlassBackend(TestCase):
    def test_check_paths(self):
        cutlass_mock_imports_path = os.path.join(
            os.path.dirname(torch.__file__),
-            "_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports",
+            "_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports",
        )
        cutlass_mock_cuda_path = os.path.join(cutlass_mock_imports_path, "cuda")
        cutlass_mock_pydot_path = os.path.join(cutlass_mock_imports_path, "pydot")
@ -234,8 +233,8 @@ class TestCutlassBackend(TestCase):
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"compile_threads": 4,
|
||||
"cuda.cutlass_backend_min_gemm_size": 100000,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_backend_min_gemm_size": 100000,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
):
|
||||
with mock.patch(
|
||||
@ -251,7 +250,7 @@ class TestCutlassBackend(TestCase):
|
||||
|
||||
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
|
||||
def test_import_cutlass(self):
|
||||
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
|
||||
from torch._inductor.codegen.cutlass.utils import try_import_cutlass
|
||||
|
||||
self.assertTrue(try_import_cutlass())
|
||||
|
||||
@ -259,7 +258,7 @@ class TestCutlassBackend(TestCase):
|
||||
import cutlass_library # noqa: F401
|
||||
|
||||
def test_cutlass_key(self):
|
||||
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
|
||||
from torch._inductor.codegen.cutlass.utils import try_import_cutlass
|
||||
|
||||
self.assertTrue(try_import_cutlass())
|
||||
from torch._inductor.codecache import cutlass_key
|
||||
@ -287,7 +286,7 @@ class TestCutlassBackend(TestCase):
|
||||
"autotune_in_subproc": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"compile_threads": 4,
|
||||
"cuda.cutlass_max_profiling_configs": 4,
|
||||
"cutlass.cutlass_max_profiling_configs": 4,
|
||||
}
|
||||
):
|
||||
Y_compiled = torch.compile(torch.mm)(a, b)
|
||||
@ -324,7 +323,7 @@ class TestCutlassBackend(TestCase):
|
||||
"autotune_in_subproc": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"compile_threads": 4,
|
||||
"cuda.cutlass_max_profiling_configs": 4,
|
||||
"cutlass.cutlass_max_profiling_configs": 4,
|
||||
}
|
||||
):
|
||||
for x_shape in x_shapes:
|
||||
@ -354,7 +353,7 @@ class TestCutlassBackend(TestCase):
|
||||
"autotune_in_subproc": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"compile_threads": 4,
|
||||
"cuda.cutlass_max_profiling_configs": 4,
|
||||
"cutlass.cutlass_max_profiling_configs": 4,
|
||||
}
|
||||
):
|
||||
Y_compiled = torch.compile(torch.bmm)(a, b)
|
||||
@ -386,7 +385,7 @@ class TestCutlassBackend(TestCase):
|
||||
"max_autotune": True,
|
||||
"autotune_in_subproc": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"cuda.cutlass_max_profiling_configs": 1,
|
||||
"cutlass.cutlass_max_profiling_configs": 1,
|
||||
}
|
||||
):
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
@ -428,8 +427,8 @@ class TestCutlassBackend(TestCase):
|
||||
"max_autotune": True,
|
||||
"autotune_in_subproc": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"cuda.cutlass_max_profiling_configs": 1,
|
||||
"cuda.cutlass_max_profiling_swizzle_options": [
|
||||
"cutlass.cutlass_max_profiling_configs": 1,
|
||||
"cutlass.cutlass_max_profiling_swizzle_options": [
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
@ -505,7 +504,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
),
|
||||
dynamo_config.patch({"error_on_recompile": dynamic}),
|
||||
@ -595,9 +594,9 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
"benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet
|
||||
"cuda.cutlass_tma_only": True,
|
||||
"cutlass.cutlass_tma_only": True,
|
||||
}
|
||||
),
|
||||
dynamo_config.patch({"error_on_recompile": dynamic}),
|
||||
@ -677,7 +676,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
),
|
||||
dynamo_config.patch({"error_on_recompile": dynamic}),
|
||||
@ -746,7 +745,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
):
|
||||
expected = [model(*input) for input in inputs]
|
||||
@ -775,8 +774,8 @@ class TestCutlassBackend(TestCase):
|
||||
"max_autotune": True,
|
||||
"autotune_in_subproc": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
|
||||
}
|
||||
):
|
||||
for M, K, N in (
|
||||
@ -819,7 +818,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
|
||||
"cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
|
||||
}
|
||||
):
|
||||
with self.assertRaisesRegex(InductorError, r".*NoValidChoicesError.*"):
|
||||
@ -849,8 +848,8 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 1,
|
||||
"cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
|
||||
"cutlass.cutlass_max_profiling_configs": 1,
|
||||
"cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
|
||||
}
|
||||
):
|
||||
_ = compiled_model(a, b)
|
||||
@@ -884,15 +883,15 @@ class TestCutlassBackend(TestCase):
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 4,
-                "cuda.version": "12.2",  # required to enable the Kernels we need
+                "cutlass.cutlass_max_profiling_configs": 4,
+                "cuda.cuda_version": "12.2",  # required to enable the Kernels we need
            }
        ):
-            counters["inductor"]["cuda_epilogue_fusion_counter"] = 0
+            counters["inductor"]["cutlass_epilogue_fusion_counter"] = 0
            assert mm is not None
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
-            actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"]
+            actual_count = counters["inductor"]["cutlass_epilogue_fusion_counter"]
            assert actual_count == expected_fuse_count, (
                f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
            )
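The fusion counter exposed through `torch._dynamo.utils.counters` is renamed the same way, from `cuda_epilogue_fusion_counter` to `cutlass_epilogue_fusion_counter`. A minimal sketch of how calling code resets and reads it (the compiled model call is a placeholder):

```python
import torch
from torch._dynamo.utils import counters

# Reset the renamed counter, run something that may trigger CUTLASS epilogue
# fusion, then read it back. `model`, `a`, and `b` are placeholders here.
counters["inductor"]["cutlass_epilogue_fusion_counter"] = 0
# result = torch.compile(model)(a, b)
fused = counters["inductor"]["cutlass_epilogue_fusion_counter"]
print("epilogue fusions observed:", fused)
```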
@ -983,7 +982,7 @@ class TestCutlassBackend(TestCase):
|
||||
"max_autotune": True,
|
||||
"autotune_in_subproc": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
):
|
||||
Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
|
||||
@ -1002,7 +1001,7 @@ class TestCutlassBackend(TestCase):
|
||||
"max_autotune": True,
|
||||
"autotune_in_subproc": False,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
):
|
||||
model = MyModel()
|
||||
@ -1040,7 +1039,7 @@ class TestCutlassBackend(TestCase):
|
||||
"max_autotune": True,
|
||||
"autotune_in_subproc": False,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
):
|
||||
model = MyModel()
|
||||
@ -1073,8 +1072,8 @@ class TestCutlassBackend(TestCase):
|
||||
"max_autotune": True,
|
||||
"autotune_in_subproc": False,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem",
|
||||
"cuda.cutlass_max_profiling_configs": 1,
|
||||
"cutlass.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem",
|
||||
"cutlass.cutlass_max_profiling_configs": 1,
|
||||
}
|
||||
):
|
||||
model = MyModel()
|
||||
@ -1117,7 +1116,7 @@ class TestCutlassBackend(TestCase):
|
||||
"max_autotune": True,
|
||||
"autotune_in_subproc": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
"autotune_local_cache": True,
|
||||
}
|
||||
):
|
||||
@ -1157,9 +1156,9 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cuda.cutlass_op_allowlist_regex": "",
|
||||
"cuda.cutlass_op_denylist_regex": "pingpong",
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_op_allowlist_regex": "",
|
||||
"cutlass.cutlass_op_denylist_regex": "pingpong",
|
||||
}
|
||||
):
|
||||
with mock.patch(
|
||||
@ -1175,7 +1174,7 @@ class TestCutlassBackend(TestCase):
|
||||
assert op_name == "addmm"
|
||||
cuda_template_count = 0
|
||||
for choice in choices:
|
||||
if isinstance(choice, CUDATemplateCaller):
|
||||
if isinstance(choice, CUTLASSTemplateCaller):
|
||||
choice_info = choice.info_dict()
|
||||
op_conf_name = choice_info.get("op_conf_name", "")
|
||||
assert isinstance(op_conf_name, str)
|
||||
@ -1183,7 +1182,7 @@ class TestCutlassBackend(TestCase):
|
||||
"All pingpong Kernels should have been filtered"
|
||||
)
|
||||
cuda_template_count += 1
|
||||
assert cuda_template_count > 0, "No CUDATemplateCaller choices"
|
||||
assert cuda_template_count > 0, "No CUTLASSTemplateCaller choices"
|
||||
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
|
||||
@ -1202,9 +1201,9 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cuda.cutlass_op_allowlist_regex": "pingpong",
|
||||
"cuda.cutlass_op_denylist_regex": None,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_op_allowlist_regex": "pingpong",
|
||||
"cutlass.cutlass_op_denylist_regex": None,
|
||||
}
|
||||
):
|
||||
with mock.patch(
|
||||
@ -1220,7 +1219,7 @@ class TestCutlassBackend(TestCase):
|
||||
assert op_name == "addmm"
|
||||
cuda_template_count = 0
|
||||
for choice in choices:
|
||||
if isinstance(choice, CUDATemplateCaller):
|
||||
if isinstance(choice, CUTLASSTemplateCaller):
|
||||
choice_info = choice.info_dict()
|
||||
op_conf_name = choice_info.get("op_conf_name", "")
|
||||
assert isinstance(op_conf_name, str)
|
||||
@ -1228,7 +1227,7 @@ class TestCutlassBackend(TestCase):
|
||||
"Only pingpong Kernels should have been allowed"
|
||||
)
|
||||
cuda_template_count += 1
|
||||
assert cuda_template_count > 0, "No CUDATemplateCaller choices"
|
||||
assert cuda_template_count > 0, "No CUTLASSTemplateCaller choices"
|
||||
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
|
||||
@ -1273,7 +1272,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
):
|
||||
with mock.patch(
|
||||
@ -1296,7 +1295,7 @@ class TestCutlassBackend(TestCase):
|
||||
_, choices, _, _ = args
|
||||
cuda_template_count = 0
|
||||
for choice in choices:
|
||||
if isinstance(choice, CUDATemplateCaller):
|
||||
if isinstance(choice, CUTLASSTemplateCaller):
|
||||
choice_info = choice.info_dict()
|
||||
op_conf_name = choice_info.get("op_conf_name", "")
|
||||
assert isinstance(op_conf_name, str)
|
||||
@ -1309,7 +1308,9 @@ class TestCutlassBackend(TestCase):
|
||||
"fastaccum Kernels should have been filtered"
|
||||
)
|
||||
cuda_template_count += 1
|
||||
assert cuda_template_count > 0, "No CUDATemplateCaller choices"
|
||||
assert cuda_template_count > 0, (
|
||||
"No CUTLASSTemplateCaller choices"
|
||||
)
|
||||
|
||||
run_test(True)
|
||||
run_test(False)
|
||||
@ -1350,7 +1351,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
),
|
||||
mock.patch(
|
||||
@ -1375,7 +1376,7 @@ class TestCutlassBackend(TestCase):
|
||||
assert op_name == "mm"
|
||||
cuda_template_count = 0
|
||||
for choice in choices:
|
||||
if isinstance(choice, CUDATemplateCaller):
|
||||
if isinstance(choice, CUTLASSTemplateCaller):
|
||||
choice_info = choice.info_dict()
|
||||
op_conf_name = choice_info.get("op_conf_name", "")
|
||||
assert isinstance(op_conf_name, str)
|
||||
@ -1384,7 +1385,7 @@ class TestCutlassBackend(TestCase):
|
||||
self.assertGreater(
|
||||
cuda_template_count,
|
||||
0,
|
||||
"No CUDATemplateCaller choices found for matmul with shape "
|
||||
"No CUTLASSTemplateCaller choices found for matmul with shape "
|
||||
f"M={M}, N={N}, K={K}",
|
||||
)
|
||||
|
||||
@ -1461,13 +1462,13 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cuda.generate_test_runner": True, # put standalone runner in the generated code
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.generate_test_runner": True, # put standalone runner in the generated code
|
||||
}
|
||||
):
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from torch._inductor.codegen.cuda.cutlass_utils import (
|
||||
from torch._inductor.codegen.cuda.compile_utils import (
|
||||
cuda_standalone_runner_compile_command,
|
||||
CUDACompileSourceCapturingContext,
|
||||
)
|
||||
@ -1544,7 +1545,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "ATEN,TRITON,CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
# needed for log searching
|
||||
"fx_graph_cache": False,
|
||||
"fx_graph_remote_cache": False,
|
||||
@ -1553,7 +1554,7 @@ class TestCutlassBackend(TestCase):
|
||||
with (
|
||||
log_settings("+inductor"),
|
||||
self.assertLogs(
|
||||
logger="torch._inductor.codegen.cuda", level=logging.DEBUG
|
||||
logger="torch._inductor.codegen.cutlass", level=logging.DEBUG
|
||||
) as test_log,
|
||||
):
|
||||
Y_compiled = torch.compile(mm, dynamic=False)(a, b)
|
||||
@ -1591,7 +1592,7 @@ class TestCutlassBackend(TestCase):
|
||||
expected = model(A, B)
|
||||
|
||||
# Track render calls
|
||||
from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate
|
||||
from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate
|
||||
|
||||
original_render = CUTLASSGemmTemplate.render
|
||||
render_call_count = 0
|
||||
@ -1608,8 +1609,8 @@ class TestCutlassBackend(TestCase):
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"fx_graph_cache": False,
|
||||
"fx_graph_remote_cache": False,
|
||||
"cuda.enable_caching_codegen": True,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.enable_caching_codegen": True,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
):
|
||||
compiled_model = torch.compile(model, fullgraph=True)
|
||||
@ -1645,7 +1646,7 @@ class TestCutlassBackend(TestCase):
|
||||
d = torch.randn(64, 128).cuda().half().t()
|
||||
|
||||
# Track render calls
|
||||
from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate
|
||||
from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate
|
||||
|
||||
original_render = CUTLASSGemmTemplate.render
|
||||
render_call_count = 0
|
||||
@ -1660,10 +1661,10 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
"fx_graph_cache": False,
|
||||
"fx_graph_remote_cache": False,
|
||||
"cuda.enable_caching_codegen": True,
|
||||
"cutlass.enable_caching_codegen": True,
|
||||
}
|
||||
):
|
||||
# Get expected results
|
||||
@ -1706,7 +1707,7 @@ class TestCutlassBackend(TestCase):
|
||||
b = torch.randn(32, 64).cuda().half().t()
|
||||
|
||||
# Track render calls
|
||||
from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate
|
||||
from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate
|
||||
|
||||
original_render = CUTLASSGemmTemplate.render
|
||||
render_call_count = 0
|
||||
@ -1721,10 +1722,10 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
"fx_graph_cache": False,
|
||||
"fx_graph_remote_cache": False,
|
||||
"cuda.enable_caching_codegen": True,
|
||||
"cutlass.enable_caching_codegen": True,
|
||||
}
|
||||
):
|
||||
# Get expected results
|
||||
@ -1752,7 +1753,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
):
|
||||
compiled = torch.compile(torch.mm)
|
||||
@ -1771,7 +1772,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"cutlass.cutlass_max_profiling_configs": 2,
|
||||
}
|
||||
):
|
||||
compiled = torch.compile(torch.mm)
|
||||
@ -1795,7 +1796,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 1,
|
||||
"cutlass.cutlass_max_profiling_configs": 1,
|
||||
}
|
||||
):
|
||||
_ = torch.compile(model)(B)
|
||||
@ -1817,13 +1818,14 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 1,
|
||||
"cutlass.cutlass_max_profiling_configs": 1,
|
||||
}
|
||||
):
|
||||
_ = torch.compile(model)(B)
|
||||
|
||||
self.assertEqual(
|
||||
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
|
||||
torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
|
||||
1,
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@ -1845,7 +1847,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 1,
|
||||
"cutlass.cutlass_max_profiling_configs": 1,
|
||||
}
|
||||
):
|
||||
_ = torch.compile(model)(B)
|
||||
@ -1871,7 +1873,7 @@ class TestCutlassBackend(TestCase):
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_max_profiling_configs": 1,
|
||||
"cutlass.cutlass_max_profiling_configs": 1,
|
||||
}
|
||||
):
|
||||
if use_aoti:
|
||||
@ -1917,7 +1919,8 @@ class TestCutlassBackend(TestCase):
|
||||
ref_result = model(a, b, extra_args)
|
||||
|
||||
self.assertEqual(
|
||||
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
|
||||
torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
|
||||
1,
|
||||
)
|
||||
torch.testing.assert_close(result, ref_result)
|
||||
|
||||
@ -1968,18 +1971,19 @@ class TestCutlassBackend(TestCase):
|
||||
|
||||
# baseline is cutlass kernel + triton
|
||||
# matches expected casting behavior
|
||||
with config.patch({"cuda.cutlass_epilogue_fusion_enabled": False}):
|
||||
with config.patch({"cutlass.cutlass_epilogue_fusion_enabled": False}):
|
||||
ref_result = torch.compile(model)(a, b, extra_args)
|
||||
|
||||
self.assertEqual(
|
||||
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 0
|
||||
torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
|
||||
0,
|
||||
)
|
||||
|
||||
torch._dynamo.reset()
|
||||
result = torch.compile(model)(a, b, extra_args)
|
||||
|
||||
self.assertEqual(
|
||||
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
|
||||
torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
|
||||
1,
|
||||
)
|
||||
|
||||
@ -2037,7 +2041,7 @@ class TestCutlassBackend(TestCase):
|
||||
|
||||
self.assertEqual(
|
||||
torch._dynamo.utils.counters["inductor"][
|
||||
"cuda_epilogue_fusion_counter"
|
||||
"cutlass_epilogue_fusion_counter"
|
||||
],
|
||||
2 * (i + 1),
|
||||
)
|
||||
@ -2064,7 +2068,8 @@ class TestCutlassBackend(TestCase):
|
||||
ref_result = model(a, b, extra_args)
|
||||
|
||||
self.assertEqual(
|
||||
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
|
||||
torch._dynamo.utils.counters["inductor"]["cutlass_epilogue_fusion_counter"],
|
||||
1,
|
||||
)
|
||||
torch.testing.assert_close(result, ref_result)
|
||||
|
||||
@ -2368,7 +2373,7 @@ class TestCutlassBackend(TestCase):
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
# needed for log searching
|
||||
"force_disable_caches": True,
|
||||
"cuda.cutlass_max_profiling_swizzle_options": [2],
|
||||
"cutlass.cutlass_max_profiling_swizzle_options": [2],
|
||||
}
|
||||
):
|
||||
with mock.patch(
|
||||
|
||||
@@ -5,7 +5,7 @@ import sympy

 import torch
 from torch._dynamo.test_case import TestCase
-from torch._inductor.codegen.cuda.cutlass_utils import (
+from torch._inductor.codegen.cutlass.utils import (
     torch_dtype_to_cutlass_type,
     try_import_cutlass,
 )
@@ -28,7 +28,7 @@ if try_import_cutlass():
    DataType = cutlass_lib.DataType
    from cutlass_cppgen.backend.evt.ir.tensor import Tensor as CutlassTensor

-    from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
+    from torch._inductor.codegen.cutlass.lib_extensions.evt_extensions import (
        _render_argument_type,
        _trace,
        trace,
@ -107,7 +107,7 @@ class TestCutlassEVT(TestCase):
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
||||
def test_py_codegen_accumulator_return(self):
|
||||
from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
|
||||
from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
size = (100, 300, 200)
|
||||
@ -164,7 +164,7 @@ return tmp_0, tmp_2, D""",
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
||||
def test_py_codegen_disjoint_read_indexing(self):
|
||||
from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
|
||||
from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
size = (100, 300, 200)
|
||||
@ -213,7 +213,7 @@ index strides [200, 60000, 1], and layout stride [60000, 200, 1]""",
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
||||
def test_py_codegen_broadcasting(self):
|
||||
from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
|
||||
from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
size = (100, 300, 200)
|
||||
@ -273,7 +273,7 @@ return tmp_0, tmp_2, D""",
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
||||
def test_py_codegen(self):
|
||||
from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
|
||||
from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
size = (100, 300, 200)
|
||||
@ -329,7 +329,7 @@ return tmp_1, D""",
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
||||
def test_example_tensor_creation(self):
|
||||
from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
|
||||
from torch._inductor.codegen.cutlass.lib_extensions.evt_extensions import (
|
||||
create_example_tensors,
|
||||
)
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
@@ -292,7 +292,7 @@ class TestPublicBindings(TestCase):
        # do not get imported by public code.
        # DO NOT add public modules here.
        private_allowlist = {
-            "torch._inductor.codegen.cuda.cuda_kernel",
+            "torch._inductor.codegen.cutlass.kernel",
            # TODO(#133647): Remove the onnx._internal entries after
            # onnx and onnxscript are installed in CI.
            "torch.onnx._internal.exporter",
@@ -357,8 +357,8 @@ class TestPublicBindings(TestCase):
            "torch.testing._internal.distributed.rpc.rpc_test",
            "torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture",
            "torch.testing._internal.distributed.rpc_utils",
-            "torch._inductor.codegen.cuda.cuda_template",
-            "torch._inductor.codegen.cuda.gemm_template",
+            "torch._inductor.codegen.cutlass.template",
+            "torch._inductor.codegen.cutlass.gemm_template",
            "torch._inductor.codegen.cpp_template",
            "torch._inductor.codegen.cpp_gemm_template",
            "torch._inductor.codegen.cpp_micro_gemm",
@ -264,4 +264,4 @@
|
||||
"torch/_inductor/utils.py": {
|
||||
"class IndentedBuffer": 145
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ from torch._inductor.codecache import (
     ROCmCodeCache,
     StaticAutotunerFuture,
     torch_key,
+    XPUCodeCache,
 )
 from torch._inductor.compile_worker.subproc_pool import (
     AnyPool,
@@ -557,6 +558,19 @@ class AsyncCompile:

        return self.submit(task)

+    def xpu(self, source_code, dst_file_ext, aot_compile=False):
+        kernel_code_log.info("XPU Kernel:\n%s", source_code)
+
+        def task():
+            if aot_compile:
+                # We rely on JITInductor to compile the CUDA code,
+                # so that we can load it into AOTInductor.
+                output_path, *_ = XPUCodeCache.compile(source_code, "o")
+                XPUCodeCache.aot_kernels_o.append(output_path)
+            return XPUCodeCache.load(source_code, dst_file_ext)[0]
+
+        return self.submit(task)
+
    def rocm(
        self,
        source_code,
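The new `xpu` method mirrors the existing `cuda`/`rocm` entry points on `AsyncCompile`: it logs the kernel source, submits a compile task to the worker pool, and loads the result through `XPUCodeCache`. A rough usage sketch, assuming it is called the same way as the other entry points (the source string is a placeholder, and the return value may be a direct result or a future depending on the pool):

```python
from torch._inductor.async_compile import AsyncCompile

async_compile = AsyncCompile()

# Placeholder source; real callers pass generated CUTLASS-for-XPU kernel code.
sycl_source = "/* generated kernel source */"

# JIT path: compile to a shared object via XPUCodeCache and load it.
handle = async_compile.xpu(sycl_source, "so")
```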
@@ -30,6 +30,7 @@ from torch._inductor.codecache import (
     DLLWrapper,
     get_hash,
     PyCodeCache,
+    XPUCodeCache,
 )
 from torch._inductor.utils import (
     get_gpu_type,
@@ -682,7 +683,7 @@ class TritonCPUBenchmarkRequest(CPUDeviceBenchmarkMixin, TritonBenchmarkRequest)
    pass


-class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
+class CUTLASSBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
    """
    A class to handle CUDA (CUTLASS) benchmark requests. This class is for
    managing the lifecycle of a CUDA kernel benchmark, including compiling
@@ -699,6 +700,7 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
+        device_type: str = "cuda",
    ) -> None:
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.source_code = source_code
@@ -708,7 +710,12 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
        self._workspace_size_updated = False
        self.hash_key: str = ""
        self.source_file: str = ""
-        self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")
+        self.device_type = device_type
+        self.codecache_cls = XPUCodeCache if device_type == "xpu" else CUDACodeCache
+        self.device_interface = get_interface_for_device(device_type)
+        self.hash_key, self.source_file = self.codecache_cls.write(
+            self.source_code, "so"
+        )

    def precompile(self):
        """
@@ -716,14 +723,14 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
        This may happen in a separate thread pool.
        """
        autotuning_log.debug("Precompiling %s", self)
-        CUDACodeCache.compile(self.source_code, "so")
+        self.codecache_cls.compile(self.source_code, "so")
        autotuning_log.debug("Done precompiling %s", self)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, out: torch.Tensor
    ) -> Callable[[], None]:
        """
-        Create a function to run the CUDA kernel with the given input and output tensors.
+        Create a function to run the CUDA/XPU kernel with the given input and output tensors.
        """

        self.ensure_dll_loaded()
@@ -738,7 +745,9 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
            args,
            self.extra_args,
        )
-        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
+        stream_ptr = c_void_p(
+            self.device_interface.get_raw_stream(self.device_interface.current_device())
+        )
        run_method = getattr(self.DLL, self.kernel_name)
        workspace_ptr = c_void_p(0)
        if self.workspace_size > 0:
@@ -781,7 +790,9 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
            dict.fromkeys(meta.name for meta in self.input_tensor_meta)
        )
        args = [c_void_p(None) for _ in range(unique_input_count + 1)]
-        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
+        stream_ptr = c_void_p(
+            self.device_interface.get_raw_stream(self.device_interface.current_device())
+        )

        run_method = getattr(self.DLL, self.kernel_name)
        # Retrieve workspace_size and initialize workspace.
@@ -795,7 +806,7 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
            None,  # null workspace ptr
            stream_ptr,
        )
-        torch.cuda.synchronize()  # shake out any CUDA errors
+        self.device_interface.synchronize()  # shake out any device errors
        self.workspace_size = c_workspace_size.value
        autotuning_log.debug(
            "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",  # noqa: B950
@@ -811,7 +822,7 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):

    def ensure_dll_loaded(self):
        if self.DLL is None:
-            self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
+            self.DLL, self.hash_key, self.source_file = self.codecache_cls.load(
                self.source_code, "so"
            )
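Replacing `torch.cuda.current_stream().cuda_stream` with the device-interface API is what lets the same benchmark request drive both CUDA and XPU kernels. A minimal sketch of the pattern (the device selection below is illustrative; `get_interface_for_device` is the helper Inductor already uses):

```python
from ctypes import c_void_p

import torch
from torch._dynamo.device_interface import get_interface_for_device

# Device-agnostic raw stream lookup: the same two calls work for "cuda" and
# "xpu", so nothing in the benchmark runner hard-codes torch.cuda anymore.
device_type = "cuda" if torch.cuda.is_available() else "xpu"
device_interface = get_interface_for_device(device_type)
stream_ptr = c_void_p(
    device_interface.get_raw_stream(device_interface.current_device())
)
print(stream_ptr)
```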
@ -34,7 +34,17 @@ from pathlib import Path
|
||||
from tempfile import _TemporaryFileWrapper
|
||||
from time import time, time_ns
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Generic,
|
||||
NoReturn,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import override, Self
|
||||
|
||||
import torch
|
||||
@@ -54,7 +64,7 @@ from torch._inductor.codegen.common import (
     custom_backend_passes,
     init_backend_registration,
 )
-from torch._inductor.codegen.cuda import cuda_env
+from torch._inductor.codegen.cuda import compile_utils as cuda_compile_utils
 from torch._inductor.codegen.rocm.compile_command import (
     rocm_compile_command,
     rocm_compiler,
@ -64,7 +74,6 @@ from torch._inductor.cpp_builder import (
|
||||
_LINKER_SCRIPT,
|
||||
_set_gpu_runtime_env,
|
||||
_TORCH_PATH,
|
||||
_transform_cuda_paths,
|
||||
convert_cubin_to_obj,
|
||||
CppBuilder,
|
||||
CppOptions,
|
||||
@ -119,10 +128,6 @@ from .triton_bundler import TritonBundler
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
if config.is_fbcode():
|
||||
from triton.fb.build import build_paths
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -148,17 +153,6 @@ autotuning_log = torch._logging.getArtifactLogger(__name__, "autotuning")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def use_re_build() -> bool:
|
||||
"""
|
||||
Use for CUTLASS compilation only right now.
|
||||
"""
|
||||
if config.is_fbcode() and not cuda_env.nvcc_exist(_cuda_compiler()):
|
||||
from triton.fb.re_build_helper import should_build_locally
|
||||
|
||||
return not should_build_locally()
|
||||
return False
|
||||
|
||||
|
||||
def get_cpp_wrapper_cubin_path_name() -> str:
|
||||
return "cubin_path" if torch.version.hip is None else "hsaco_path"
|
||||
|
||||
@ -2352,7 +2346,8 @@ end
|
||||
f.write(json.dumps(qual_name_to_id))
|
||||
generated_files.append(constants_config_json)
|
||||
|
||||
gpu_codecache: ROCmCodeCache | CUDACodeCache = (
|
||||
gpu_codecache: ROCmCodeCache | CUDACodeCache | XPUCodeCache = (
|
||||
XPUCodeCache() if device_type == "xpu" else
|
||||
ROCmCodeCache() if torch.version.hip else CUDACodeCache()
|
||||
)
|
||||
gpu_kernels_o = gpu_codecache.aot_kernels_o.copy()
|
||||
@ -2383,10 +2378,10 @@ end
|
||||
config.aot_inductor.emit_multi_arch_kernel
|
||||
and device_type == "cuda"
|
||||
):
|
||||
current_arch = _nvcc_arch_as_compile_option()
|
||||
current_arch = cuda_compile_utils._nvcc_arch_as_compile_option()
|
||||
cmd = (
|
||||
# pyrefly: ignore [unbound-name]
|
||||
f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} "
|
||||
f"{cuda_compile_utils._cuda_compiler()} -fatbin {asm_file} -o {cubin_file} "
|
||||
# Triton only allows generating PTX version as same as the current arch
|
||||
f"-gencode arch=compute_{current_arch},code=compute_{current_arch} "
|
||||
# Include SASS for the current specific arch
|
||||
@ -3686,55 +3681,6 @@ def _load_triton_kernel_from_source(
|
||||
return getattr(PyCodeCache.load(source_code), kernel_name)
|
||||
|
||||
|
||||
def _cuda_compiler() -> str | None:
|
||||
if cuda_env.nvcc_exist(config.cuda.cuda_cxx):
|
||||
return config.cuda.cuda_cxx
|
||||
if config.is_fbcode():
|
||||
return os.path.join(build_paths.sdk_home, "bin", "nvcc")
|
||||
if cuda_env.nvcc_exist(os.getenv("CUDACXX")):
|
||||
return os.getenv("CUDACXX", "")
|
||||
if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")):
|
||||
return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc"))
|
||||
return "nvcc"
|
||||
|
||||
|
||||
def _cutlass_path() -> str:
|
||||
if config.is_fbcode():
|
||||
from libfb.py import parutil
|
||||
|
||||
return parutil.get_dir_path("cutlass-4-headers")
|
||||
else:
|
||||
return config.cuda.cutlass_dir
|
||||
|
||||
|
||||
def _cutlass_paths() -> list[str]:
|
||||
return [
|
||||
"include",
|
||||
"tools/library/include",
|
||||
"tools/library/src",
|
||||
"tools/util/include",
|
||||
]
|
||||
|
||||
|
||||
def _clone_cutlass_paths(build_root: str) -> list[str]:
|
||||
paths = _cutlass_paths()
|
||||
cutlass_root = _cutlass_path()
|
||||
for path in _cutlass_paths():
|
||||
old_path = os.path.join(cutlass_root, path)
|
||||
new_path = os.path.join(build_root, path)
|
||||
shutil.copytree(old_path, new_path, dirs_exist_ok=True)
|
||||
return paths
|
||||
|
||||
|
||||
def _cutlass_include_paths() -> list[str]:
|
||||
cutlass_path = _cutlass_path()
|
||||
return [
|
||||
# Use realpath to get canonical absolute paths, in order not to mess up cache keys
|
||||
os.path.realpath(os.path.join(cutlass_path, path))
|
||||
for path in _cutlass_paths()
|
||||
]
|
||||
|
||||
|
||||
@torch_key_cache
|
||||
def cutlass_key() -> bytes:
|
||||
"""
|
||||
@ -3750,151 +3696,10 @@ def cutlass_key() -> bytes:
|
||||
return resource_file.read().encode()
|
||||
|
||||
combined_hash = hashlib.sha256()
|
||||
build_code_hash([config.cuda.cutlass_dir], "", combined_hash)
|
||||
build_code_hash([config.cutlass.cutlass_dir], "", combined_hash)
|
||||
return combined_hash.digest()
|
||||
|
||||
|
||||
def _cuda_lib_options() -> list[str]:
|
||||
"""
|
||||
Util function for CUTLASS backend to find the correct CUDA libraries.
|
||||
"""
|
||||
_set_gpu_runtime_env() # cpp_extension consults the env
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
lpaths = cpp_extension.library_paths(device_type="cuda")
|
||||
if use_re_build():
|
||||
lpaths += [
|
||||
build_paths.sdk_lib,
|
||||
os.path.join(build_paths.sdk_lib, "stubs"),
|
||||
]
|
||||
extra_ldflags: list[str] = []
|
||||
if is_linux():
|
||||
_transform_cuda_paths(lpaths)
|
||||
for path in lpaths:
|
||||
if "torch/lib" in path:
|
||||
# don't want to depend on pytorch
|
||||
continue
|
||||
extra_ldflags.append(f"-L{path}")
|
||||
# -rpath ensures the DLL can find its dependencies when loaded, even
|
||||
# if the library path is non-standard.
|
||||
# But do not add the stubs folder to rpath as the driver is expected to be found at runtime
|
||||
if os.path.basename(path) != "stubs":
|
||||
extra_ldflags.extend(["-Xlinker", f"-rpath={path}"])
|
||||
extra_ldflags.append("-lcuda")
|
||||
extra_ldflags.append("-lcudart")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unsupported env, failed to find cuda libs! Currently only Linux is supported."
|
||||
)
|
||||
return extra_ldflags
|
||||
|
||||
|
||||
def _nvcc_host_compiler_options() -> list[str]:
|
||||
return [
|
||||
"-fPIC",
|
||||
"-fno-strict-aliasing",
|
||||
"-fvisibility=hidden",
|
||||
"-Wconversion",
|
||||
]
|
||||
|
||||
|
||||
def _nvcc_arch_as_compile_option() -> str:
|
||||
arch = cuda_env.get_cuda_arch()
|
||||
if arch == "90":
|
||||
# Required by cutlass compilation.
|
||||
return "90a"
|
||||
if arch == "100":
|
||||
return "100a"
|
||||
return arch
|
||||
|
||||
|
||||
def _nvcc_compiler_options() -> list[str]:
|
||||
arch = _nvcc_arch_as_compile_option()
|
||||
code = [f"sm_{arch}", f"compute_{arch}"]
|
||||
if config.cuda.enable_cuda_lto:
|
||||
code += [f"lto_{arch}"]
|
||||
options = [
|
||||
"-t=0",
|
||||
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
|
||||
"-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1",
|
||||
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
||||
"-w",
|
||||
f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
|
||||
config.cuda.compile_opt_level,
|
||||
"-std=c++17",
|
||||
"--expt-relaxed-constexpr",
|
||||
"-DNDEBUG",
|
||||
]
|
||||
if config.is_fbcode():
|
||||
options.extend(["-ccbin", os.path.dirname(build_paths.gcc)])
|
||||
if config.cuda.enable_debug_info:
|
||||
options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"])
|
||||
if config.cuda.enable_ptxas_info:
|
||||
options.extend(
|
||||
[
|
||||
"--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.)
|
||||
"--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels
|
||||
"--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels
|
||||
"--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.)
|
||||
"--source-in-ptx",
|
||||
]
|
||||
) # Annotate the ptx file with source information
|
||||
if config.cuda.use_fast_math:
|
||||
options.extend(
|
||||
[
|
||||
"--use_fast_math",
|
||||
"-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
|
||||
]
|
||||
)
|
||||
return options
|
||||
|
||||
|
||||
def cuda_compile_command(
|
||||
src_files: list[str],
|
||||
dst_file: str,
|
||||
dst_file_ext: str,
|
||||
extra_args: list[str] | None = None,
|
||||
) -> str:
|
||||
if extra_args is None:
|
||||
extra_args = []
|
||||
if use_re_build():
|
||||
build_path = os.path.dirname(dst_file)
|
||||
include_paths = _clone_cutlass_paths(build_path)
|
||||
src_files = [os.path.basename(src_file) for src_file in src_files]
|
||||
dst_file = os.path.basename(dst_file)
|
||||
else:
|
||||
include_paths = _cutlass_include_paths()
|
||||
cuda_lib_options = _cuda_lib_options()
|
||||
nvcc_host_compiler_options = _nvcc_host_compiler_options()
|
||||
nvcc_compiler_options = _nvcc_compiler_options()
|
||||
options = (
|
||||
nvcc_compiler_options
|
||||
+ extra_args
|
||||
+ [
|
||||
f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}"
|
||||
for opt in nvcc_host_compiler_options
|
||||
]
|
||||
+ ["-I" + path for path in include_paths]
|
||||
+ cuda_lib_options
|
||||
)
|
||||
src_file = " ".join(src_files)
|
||||
res = ""
|
||||
if dst_file_ext == "o":
|
||||
res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}"
|
||||
elif dst_file_ext == "so":
|
||||
options.append("-shared")
|
||||
res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
|
||||
elif dst_file_ext == "exe":
|
||||
res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
|
||||
if log.isEnabledFor(logging.DEBUG):
|
||||
log.debug("CUDA command: %s", res)
|
||||
else:
|
||||
autotuning_log.debug("CUDA command: %s", res)
|
||||
return res
|
||||
|
||||
|
||||
class DLLWrapper:
|
||||
"""A wrapper for a dynamic library."""
|
||||
|
||||
@@ -3978,10 +3783,9 @@ def binary_error_path(output_path: str) -> str:
    return output_path + ".error"


-@clear_on_fresh_cache
-class CUDACodeCache:
+class CUTLASSCodeCache:
    """
-    A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS.
+    A cache for managing the compilation and loading source code specifically for CUTLASS.
    This class handles writing source code to files, compiling them into shared objects, and caching
    the results to avoid redundant compilations. It also manages error handling and logging for the
    compilation process.
@@ -3995,12 +3799,15 @@ class CUDACodeCache:
    cache: dict[str, CacheEntry] = {}
    aot_kernels_o: list[str] = []
-    _SOURCE_CODE_SUFFIX = "cu"

+    _SOURCE_CODE_SUFFIX: str = ""
+    _BACKEND: str = ""
+
    @staticmethod
    def cache_clear() -> None:
-        CUDACodeCache.cache.clear()
-        CUDACodeCache.aot_kernels_o.clear()
+        CUTLASSCodeCache.cache.clear()
+        CUTLASSCodeCache.aot_kernels_o.clear()
+        CUTLASSCodeCache.write.cache_clear()

    @staticmethod
    @lru_cache(maxsize=4)
@@ -4035,6 +3842,24 @@ class CUDACodeCache:
        )
        return None

+    @classmethod
+    def _use_re_build(cls) -> bool:
+        raise NotImplementedError
+
+    @classmethod
+    def _compile_command(
+        cls,
+        src_files: list[str],
+        dst_file: str,
+        dst_file_ext: str,
+        extra_args: Optional[list[str]] = None,
+    ) -> str:
+        raise NotImplementedError
+
+    @classmethod
+    def _source_code_extra(cls) -> str:
+        raise NotImplementedError
+
    @classmethod
    @lru_cache(None)
    def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]:
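These three hooks turn `CUTLASSCodeCache` into a small template-method base class: the shared write/compile/load flow stays in the base, and each backend only supplies its compiler command, source suffix, and hash inputs. A stripped-down sketch of the same pattern (standalone toy code, not the actual Inductor classes):

```python
from typing import Optional


class CodeCacheBase:
    """Generic compile flow; backends fill in the hooks (toy version)."""

    _SOURCE_CODE_SUFFIX: str = ""
    _BACKEND: str = ""

    @classmethod
    def _compile_command(
        cls,
        src_files: list[str],
        dst_file: str,
        dst_file_ext: str,
        extra_args: Optional[list[str]] = None,
    ) -> str:
        raise NotImplementedError

    @classmethod
    def compile(cls, src_file: str, dst_file_ext: str) -> str:
        # Shared flow: derive the output path, then ask the backend for its command.
        dst_file = src_file[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext
        cmd = cls._compile_command([src_file], dst_file, dst_file_ext)
        print(f"[{cls._BACKEND}] would run: {cmd}")
        return dst_file


class NvccCodeCache(CodeCacheBase):
    _SOURCE_CODE_SUFFIX = "cu"
    _BACKEND = "CUDA"

    @classmethod
    def _compile_command(cls, src_files, dst_file, dst_file_ext, extra_args=None):
        return f"nvcc -shared -o {dst_file} {' '.join(src_files)}"


NvccCodeCache.compile("kernel.cu", "so")
```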
@@ -4043,25 +3868,14 @@ class CUDACodeCache:
        Returns the hash key of source code, and the path to the file.
        """

-        if config.cuda.cutlass_hash_with_compile_cmd:
-            cuda_command = repr(
-                cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext)
+        if config.cutlass.cutlass_hash_with_compile_cmd:
+            compile_command = repr(
+                cls._compile_command(["dummy_input"], "dummy_output", dst_file_ext)
            )
-            extra = cuda_command
+            extra = compile_command
        else:
-            extra = repr(
-                [
-                    # nvcc and cuda hash
-                    _cuda_compiler(),
-                    # cutlass flags and gcc hash
-                    _nvcc_compiler_options(),
-                    # flags
-                    _nvcc_host_compiler_options(),
-                    # cutlass key
-                    cutlass_key(),
-                    # hack to deal with AOTI .o compilation
-                ]
-            )
+            extra = cls._source_code_extra()

        key, input_path = write(source_code, cls._SOURCE_CODE_SUFFIX, extra=extra)
        return key, input_path
@ -4094,7 +3908,7 @@ class CUDACodeCache:
|
||||
output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext
|
||||
error_path = binary_error_path(output_path)
|
||||
binary_remote_cache = cls.get_kernel_binary_remote_cache(
|
||||
caching_enabled=config.cuda.use_binary_remote_cache
|
||||
caching_enabled=config.cutlass.use_binary_remote_cache
|
||||
and not config.force_disable_caches,
|
||||
caching_available=config.is_fbcode(),
|
||||
)
|
||||
@ -4109,30 +3923,30 @@ class CUDACodeCache:
|
||||
cmd_parts, error_output = json.loads(error_json)
|
||||
if (
|
||||
binary_remote_cache is not None
|
||||
and config.cuda.upload_to_binary_remote_cache
|
||||
and config.cutlass.upload_to_binary_remote_cache
|
||||
):
|
||||
# This ensures that a local error is uploaded to the remote cache,
|
||||
# as we make no assumptions about the remote cache having the same
|
||||
# information as the local cache
|
||||
binary_remote_cache.put(
|
||||
error_path, config.cuda.binary_remote_cache_force_write
|
||||
error_path, config.cutlass.binary_remote_cache_force_write
|
||||
)
|
||||
cls.cache[key_with_ext] = CUDACodeCache.CacheEntry(
|
||||
cls.cache[key_with_ext] = CUTLASSCodeCache.CacheEntry(
|
||||
input_path, output_path, error_json
|
||||
)
|
||||
raise exc.CUDACompileError(cmd_parts, error_output)
|
||||
if not os.path.exists(output_path):
|
||||
cmd = cuda_compile_command(
|
||||
cmd = cls._compile_command(
|
||||
src_files, output_path, dst_file_ext, extra_args
|
||||
)
|
||||
with open(input_path, "a") as f:
|
||||
f.write("\n")
|
||||
f.write(f"// CUDA {operation_name} cmd\n// {cmd}\n")
|
||||
f.write(f"// {cls._BACKEND} {operation_name} cmd\n// {cmd}\n")
|
||||
start_time = time()
|
||||
log.debug("CUDA %s: %s", operation_name, cmd)
|
||||
log.debug("%s %s: %s", cls._BACKEND, operation_name, cmd)
|
||||
cmd_parts = cmd.split(" ")
|
||||
try:
|
||||
if use_re_build():
|
||||
if cls._use_re_build():
|
||||
from triton.fb.re_build_helper import run_build_command
|
||||
|
||||
run_build_command(
|
||||
@ -4145,7 +3959,7 @@ class CUDACodeCache:
|
||||
cmd_parts, stderr=subprocess.STDOUT, env=os.environ
|
||||
)
|
||||
except subprocess.CalledProcessError as error:
|
||||
cls._record_cuda_compile_error(
|
||||
cls._record_compile_error(
|
||||
error.output.decode("utf-8"),
|
||||
key_with_ext,
|
||||
cmd_parts,
|
||||
@ -4156,7 +3970,7 @@ class CUDACodeCache:
|
||||
raise exc.CUDACompileError(cmd_parts, error.output) from error
|
||||
except Exception as error:
|
||||
if "COMPILE FAILED WITH" in str(error):
|
||||
cls._record_cuda_compile_error(
|
||||
cls._record_compile_error(
|
||||
str(error),
|
||||
key_with_ext,
|
||||
cmd_parts,
|
||||
@ -4167,29 +3981,30 @@ class CUDACodeCache:
|
||||
raise exc.CUDACompileError(cmd_parts, str(error)) from error
|
||||
raise error
|
||||
end_time = time()
|
||||
log_duration_msg = f"CUDA {operation_name} took {end_time - start_time} seconds. Command: {cmd}"
|
||||
log_duration_msg = f"{cls._BACKEND} {operation_name} took {end_time - start_time} seconds. Command: {cmd}"
|
||||
log.info(log_duration_msg)
|
||||
|
||||
else:
|
||||
log.debug(
|
||||
"CUDA %s skipped: %s since output already exists",
|
||||
"%s %s skipped: %s since output already exists",
|
||||
cls._BACKEND,
|
||||
operation_name,
|
||||
output_path,
|
||||
)
|
||||
# Upload to remote cache if enabled
|
||||
if (
|
||||
binary_remote_cache is not None
|
||||
and config.cuda.upload_to_binary_remote_cache
|
||||
and config.cutlass.upload_to_binary_remote_cache
|
||||
):
|
||||
# will log on errors, but not fail out
|
||||
binary_remote_cache.put(
|
||||
output_path, config.cuda.binary_remote_cache_force_write
|
||||
output_path, config.cutlass.binary_remote_cache_force_write
|
||||
)
|
||||
cls.cache[key_with_ext] = CUDACodeCache.CacheEntry(
|
||||
cls.cache[key_with_ext] = CUTLASSCodeCache.CacheEntry(
|
||||
input_path, output_path, None
|
||||
)
|
||||
|
||||
cache_entry: CUDACodeCache.CacheEntry = cls.cache[key_with_ext]
|
||||
cache_entry: CUTLASSCodeCache.CacheEntry = cls.cache[key_with_ext]
|
||||
if cache_entry.error_json is not None:
|
||||
# Restore cached Exception and raise it as if we had compiled
|
||||
cmd_parts, error_output = json.loads(cache_entry.error_json)
|
||||
@ -4214,7 +4029,7 @@ class CUDACodeCache:
|
||||
return (DLLWrapper(dst_file_path), hash_key, source_code_path)
|
||||
|
||||
@classmethod
|
||||
def _record_cuda_compile_error(
|
||||
def _record_compile_error(
|
||||
cls,
|
||||
error_str: str,
|
||||
key_with_ext: str,
|
||||
@ -4226,7 +4041,7 @@ class CUDACodeCache:
|
||||
binary_remote_cache: Any = None,
|
||||
) -> None:
|
||||
error_json = json.dumps([cmd_parts, error_str])
|
||||
cls.cache[key_with_ext] = CUDACodeCache.CacheEntry(
|
||||
cls.cache[key_with_ext] = CUTLASSCodeCache.CacheEntry(
|
||||
input_path, output_path, error_json
|
||||
)
|
||||
error_path = binary_error_path(output_path)
|
||||
@ -4236,13 +4051,94 @@ class CUDACodeCache:
|
||||
# Upload to remote cache directly from memory if enabled
|
||||
if (
|
||||
binary_remote_cache is not None
|
||||
and config.cuda.upload_to_binary_remote_cache
|
||||
and config.cutlass.upload_to_binary_remote_cache
|
||||
):
|
||||
binary_remote_cache.put(
|
||||
error_path, config.cuda.binary_remote_cache_force_write
|
||||
error_path, config.cutlass.binary_remote_cache_force_write
|
||||
)
|
||||
|
||||
|
||||
@clear_on_fresh_cache
|
||||
class CUDACodeCache(CUTLASSCodeCache):
|
||||
_SOURCE_CODE_SUFFIX = "cu"
|
||||
_BACKEND = "CUDA"
|
||||
|
||||
@classmethod
|
||||
def _use_re_build(cls) -> bool:
|
||||
return cuda_compile_utils.use_re_build()
|
||||
|
||||
@classmethod
|
||||
def _compile_command(
|
||||
cls,
|
||||
src_files: list[str],
|
||||
dst_file: str,
|
||||
dst_file_ext: str,
|
||||
extra_args: Optional[list[str]] = None,
|
||||
) -> str:
|
||||
return cuda_compile_utils.cuda_compile_command(
|
||||
src_files, dst_file, dst_file_ext, extra_args=extra_args
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _source_code_extra(cls) -> str:
|
||||
extra = repr(
|
||||
[
|
||||
# nvcc and cuda hash
|
||||
cuda_compile_utils._cuda_compiler(),
|
||||
# cutlass flags and gcc hash
|
||||
cuda_compile_utils._nvcc_compiler_options(),
|
||||
# flags
|
||||
cuda_compile_utils._nvcc_host_compiler_options(),
|
||||
# cutlass key
|
||||
cutlass_key(),
|
||||
# hack to deal with AOTI .o compilation
|
||||
]
|
||||
)
|
||||
return extra
|
||||
|
||||
|
||||
from torch._inductor.codegen.xpu import compile_utils as xpu_compile_utils
|
||||
|
||||
|
||||
@clear_on_fresh_cache
|
||||
class XPUCodeCache(CUTLASSCodeCache):
|
||||
_SOURCE_CODE_SUFFIX = "cpp"
|
||||
_BACKEND = "XPU"
|
||||
|
||||
@classmethod
|
||||
def _use_re_build(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _compile_command(
|
||||
cls,
|
||||
src_files: list[str],
|
||||
dst_file: str,
|
||||
dst_file_ext: str,
|
||||
extra_args: Optional[list[str]] = None,
|
||||
) -> str:
|
||||
return xpu_compile_utils.xpu_compile_command(
|
||||
src_files, dst_file, dst_file_ext, extra_args=extra_args
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _source_code_extra(cls) -> str:
|
||||
extra = repr(
|
||||
[
|
||||
# nvcc and cuda hash
|
||||
xpu_compile_utils._sycl_compiler(),
|
||||
# cutlass flags and gcc hash
|
||||
xpu_compile_utils._sycl_compiler_options(),
|
||||
# flags
|
||||
xpu_compile_utils._sycl_host_compiler_options(),
|
||||
# cutlass key
|
||||
cutlass_key(),
|
||||
# hack to deal with AOTI .o compilation
|
||||
]
|
||||
)
|
||||
return extra
|
||||
|
||||
|
@clear_on_fresh_cache
class ROCmCodeCache:
    @dataclasses.dataclass
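# ---------------------------------------------------------------------------
# Editor's sketch (not part of the diff above): with the split into a generic
# CUTLASSCodeCache plus thin per-device subclasses, a backend only has to
# answer three questions: does it build via remote execution, what compile
# command does it run, and what extra state should perturb the cache key.
# A hypothetical additional backend would look roughly like this;
# `mybackend_compile_command` and `mybackend_compiler_version` are made-up
# placeholders, not real helpers in this PR.
@clear_on_fresh_cache
class MyBackendCodeCache(CUTLASSCodeCache):
    _SOURCE_CODE_SUFFIX = "cpp"  # source file extension handed to the compiler
    _BACKEND = "MYBACKEND"  # only used in log messages

    @classmethod
    def _use_re_build(cls) -> bool:
        return False  # always build locally, never via remote execution

    @classmethod
    def _compile_command(
        cls,
        src_files: list[str],
        dst_file: str,
        dst_file_ext: str,
        extra_args: Optional[list[str]] = None,
    ) -> str:
        # each backend supplies its own command builder (hypothetical helper)
        return mybackend_compile_command(
            src_files, dst_file, dst_file_ext, extra_args=extra_args
        )

    @classmethod
    def _source_code_extra(cls) -> str:
        # compiler identity plus CUTLASS key, so cached binaries are invalidated
        # when either changes (mirrors the CUDA/XPU subclasses above)
        return repr([mybackend_compiler_version(), cutlass_key()])
# ---------------------------------------------------------------------------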
@ -10,8 +10,8 @@ from ..scheduler import (
Scheduler,
SchedulerNode,
)
from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
from .cutedsl.cutedsl_scheduling import CuteDSLScheduling
from .cutlass.scheduling import CUTLASSScheduling
from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling
from .triton import TritonScheduling

@ -30,7 +30,7 @@ if TYPE_CHECKING:
_IntLike: TypeAlias = Union[int, Expr]


class CUDACombinedScheduling(BaseScheduling):
class CombinedScheduling(BaseScheduling):
"""
Scheduler for CUDA Kernels, which delegates calls as appropriate
to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices
@ -43,7 +43,7 @@ class CUDACombinedScheduling(BaseScheduling):
def __init__(self, scheduler: Optional[Scheduler]) -> None:
super().__init__(scheduler)
self._triton_scheduling = TritonScheduling(scheduler)
self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
self._cutlass_scheduling = CUTLASSScheduling(scheduler)
self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)
self._cutedsl_scheduling = CuteDSLScheduling(scheduler)

@ -51,8 +51,8 @@ class CUDACombinedScheduling(BaseScheduling):
return self._triton_scheduling.get_backend_features(device)

def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:
if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
return self._cuda_cpp_scheduling
if self._cutlass_scheduling.is_cutlass_template(node):
return self._cutlass_scheduling
if self._rocm_cpp_scheduling.is_rocm_cpp_template(node):
return self._rocm_cpp_scheduling
if self._cutedsl_scheduling.is_cutedsl_template(node):
@ -62,11 +62,11 @@ class CUDACombinedScheduling(BaseScheduling):
def can_fuse_vertical(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2):
if self._cutlass_scheduling.can_fuse_vertical(node1, node2):
return True
elif self._cuda_cpp_scheduling.is_cuda_cpp_template(
elif self._cutlass_scheduling.is_cutlass_template(
node1
) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2):
) or self._cutlass_scheduling.is_cutlass_template(node2):
return False
# CuteDSL doesn't support vertical fusion currently
elif self._cutedsl_scheduling.is_cutedsl_template(
@ -79,8 +79,8 @@ class CUDACombinedScheduling(BaseScheduling):
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
for node in (node1, node2):
if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
return self._cuda_cpp_scheduling.can_fuse_horizontal(
if self._cutlass_scheduling.is_cutlass_template(node):
return self._cutlass_scheduling.can_fuse_horizontal(
node1, node2
) # always False at the moment
if self._cutedsl_scheduling.is_cutedsl_template(node):
@ -100,9 +100,9 @@ class CUDACombinedScheduling(BaseScheduling):
epilogue_nodes: Sequence[BaseSchedulerNode],
prologue_nodes: Sequence[BaseSchedulerNode],
) -> Optional[str]:
if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
if self._cutlass_scheduling.is_cutlass_template(template_node):
assert not prologue_nodes
return self._cuda_cpp_scheduling.codegen_template(
return self._cutlass_scheduling.codegen_template(
template_node, epilogue_nodes, prologue_nodes
)
elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node):
@ -371,6 +371,12 @@ class DeviceOpOverrides:
# optionally return (scratch definition, arg name)
raise NotImplementedError

def get_device_arch(self) -> str:
raise NotImplementedError

def get_toolkit_version(self) -> str:
raise NotImplementedError


device_op_overrides_dict: dict[str, DeviceOpOverrides] = {}
custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {}
@ -495,12 +501,12 @@ def init_backend_registration() -> None:
Register the backend for different devices, including the scheduling
for kernel code generation and the host side wrapper code generation.
"""
from .combined_scheduling import CombinedScheduling
from .cpp import CppScheduling
from .cpp_wrapper_cpu import CppWrapperCpu
from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
from .cpp_wrapper_gpu import CppWrapperGpu
from .cpp_wrapper_mps import CppWrapperMps
from .cuda_combined_scheduling import CUDACombinedScheduling
from .halide import HalideScheduling
from .mps import MetalScheduling
from .python_wrapper_mtia import PythonWrapperMtia
@ -525,9 +531,9 @@ def init_backend_registration() -> None:
)

if get_scheduling_for_device("cuda") is None:
# CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
# CombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
cuda_backends = {
"triton": CUDACombinedScheduling,
"triton": CombinedScheduling,
"halide": HalideScheduling,
}
register_backend_for_device(
@ -2366,7 +2372,7 @@ class KernelTemplate:
"""
Base class for defining kernel templates.

Children classes: TritonTemplate, CUDATemplate
Children classes: TritonTemplate, CUTLASSTemplate
"""

@staticmethod
@ -2618,12 +2624,12 @@ class CSEProxy(DefaultHandler):
"""
from ..bounds import ValueRangeAnalysis
from ..select_algorithm import TritonTemplateKernel
from .cuda.cuda_kernel import CUDATemplateKernel
from .cutlass.kernel import CUTLASSTemplateKernel

if isinstance(V.kernel, TritonTemplateKernel):
return ValueRanges.unknown()

if isinstance(V.kernel, CUDATemplateKernel):
if isinstance(V.kernel, CUTLASSTemplateKernel):
return ValueRanges.unknown()

fx_node = V.interpreter.current_node
torch/_inductor/codegen/cuda/compile_utils.py (new file, 264 lines)
@ -0,0 +1,264 @@
# mypy: allow-untyped-defs
import logging
import os
import shutil
from pathlib import Path
from typing import Optional

import torch
from torch._inductor import config
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.cpp_builder import _set_gpu_runtime_env, _transform_cuda_paths
from torch._inductor.utils import is_linux


if config.is_fbcode():
    from triton.fb.build import build_paths


log = logging.getLogger(__name__)
autotuning_log = torch._logging.getArtifactLogger(__name__, "autotuning")


def use_re_build() -> bool:
    """
    Use for CUTLASS compilation only right now.
    """
    if config.is_fbcode() and not cuda_env.nvcc_exist(_cuda_compiler()):
        from triton.fb.re_build_helper import should_build_locally

        return not should_build_locally()
    return False


def _cutlass_path() -> str:
    if config.is_fbcode():
        from libfb.py import parutil

        return parutil.get_dir_path("cutlass-4-headers")
    else:
        return config.cutlass.cutlass_dir


def _cutlass_paths() -> list[str]:
    return [
        "include",
        "tools/library/include",
        "tools/library/src",
        "tools/util/include",
    ]


def _clone_cutlass_paths(build_root: str) -> list[str]:
    paths = _cutlass_paths()
    cutlass_root = _cutlass_path()
    for path in _cutlass_paths():
        old_path = os.path.join(cutlass_root, path)
        new_path = os.path.join(build_root, path)
        shutil.copytree(old_path, new_path, dirs_exist_ok=True)
    return paths


def _cutlass_include_paths() -> list[str]:
    cutlass_path = _cutlass_path()
    return [
        # Use realpath to get canonical absolute paths, in order not to mess up cache keys
        os.path.realpath(os.path.join(cutlass_path, path))
        for path in _cutlass_paths()
    ]


def _cuda_compiler() -> Optional[str]:
    if cuda_env.nvcc_exist(config.cutlass.cuda_cxx):
        return config.cutlass.cuda_cxx
    if config.is_fbcode():
        return os.path.join(build_paths.sdk_home, "bin", "nvcc")
    if cuda_env.nvcc_exist(os.getenv("CUDACXX")):
        return os.getenv("CUDACXX", "")
    if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")):
        return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc"))
    return "nvcc"


def _cuda_lib_options() -> list[str]:
    """
    Util function for CUTLASS backend to find the correct CUDA libraries.
    """
    _set_gpu_runtime_env()  # cpp_extension consults the env
    from torch.utils import cpp_extension

    lpaths = cpp_extension.library_paths(device_type="cuda")
    if use_re_build():
        lpaths += [
            build_paths.sdk_lib,
            os.path.join(build_paths.sdk_lib, "stubs"),
        ]
    extra_ldflags: list[str] = []
    if is_linux():
        _transform_cuda_paths(lpaths)
        for path in lpaths:
            if "torch/lib" in path:
                # don't want to depend on pytorch
                continue
            extra_ldflags.append(f"-L{path}")
            # -rpath ensures the DLL can find its dependencies when loaded, even
            # if the library path is non-standard.
            # But do not add the stubs folder to rpath as the driver is expected to be found at runtime
            if os.path.basename(path) != "stubs":
                extra_ldflags.extend(["-Xlinker", f"-rpath={path}"])
        extra_ldflags.append("-lcuda")
        extra_ldflags.append("-lcudart")
    else:
        raise NotImplementedError(
            "Unsupported env, failed to find cuda libs! Currently only Linux is supported."
        )
    return extra_ldflags


def _nvcc_host_compiler_options() -> list[str]:
    return [
        "-fPIC",
        "-fno-strict-aliasing",
        "-fvisibility=hidden",
        "-Wconversion",
    ]


def _nvcc_arch_as_compile_option() -> str:
    arch = cuda_env.get_cuda_arch()
    if arch == "90":
        # Required by cutlass compilation.
        return "90a"
    if arch == "100":
        return "100a"
    return arch


def _nvcc_compiler_options() -> list[str]:
    arch = _nvcc_arch_as_compile_option()
    code = [f"sm_{arch}", f"compute_{arch}"]
    if config.cutlass.enable_cuda_lto:
        code += [f"lto_{arch}"]
    options = [
        "-t=0",
        "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
        "-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1",
        "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
        "-w",
        f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
        config.cutlass.compile_opt_level,
        "-std=c++17",
        "--expt-relaxed-constexpr",
        "-DNDEBUG",
    ]
    if config.is_fbcode():
        options.extend(["-ccbin", os.path.dirname(build_paths.gcc)])
    if config.cutlass.enable_debug_info:
        options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"])
    if config.cutlass.enable_ptxas_info:
        options.extend(
            [
                "--keep",  # Keep the intermediate files for debugging (including ptx, sass, cubin etc.)
                "--ptxas-options=--warn-on-local-memory-usage",  # warn us if local memory is used in CUDA Kernels
                "--ptxas-options=--warn-on-spills",  # warn us if register spilling happens in CUDA Kernels
                "--resource-usage",  # Report on CUDA resource usage (shared mem, registers etc.)
                "--source-in-ptx",
            ]
        )  # Annotate the ptx file with source information
    if config.cutlass.use_fast_math:
        options.extend(
            [
                "--use_fast_math",
                "-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
            ]
        )
    return options


def cuda_compile_command(
    src_files: list[str],
    dst_file: str,
    dst_file_ext: str,
    extra_args: list[str] | None = None,
) -> str:
    if extra_args is None:
        extra_args = []
    if use_re_build():
        build_path = os.path.dirname(dst_file)
        include_paths = _clone_cutlass_paths(build_path)
        src_files = [os.path.basename(src_file) for src_file in src_files]
        dst_file = os.path.basename(dst_file)
    else:
        include_paths = _cutlass_include_paths()
    cuda_lib_options = _cuda_lib_options()
    nvcc_host_compiler_options = _nvcc_host_compiler_options()
    nvcc_compiler_options = _nvcc_compiler_options()
    options = (
        nvcc_compiler_options
        + extra_args
        + [
            f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}"
            for opt in nvcc_host_compiler_options
        ]
        + ["-I" + path for path in include_paths]
        + cuda_lib_options
    )
    src_file = " ".join(src_files)
    res = ""
    if dst_file_ext == "o":
        res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}"
    elif dst_file_ext == "so":
        options.append("-shared")
        res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
    elif dst_file_ext == "exe":
        res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
    else:
        raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
    if log.isEnabledFor(logging.DEBUG):
        log.debug("CUDA command: %s", res)
    else:
        autotuning_log.debug("CUDA command: %s", res)
    return res


class CUDACompileSourceCapturingContext:
    # Helper class for Benchmarking and Testing CUTLASS Kernels in isolation.
    # Can be used to capture the sourcecode passed to CUDACodeCache.compile

    def __init__(self):
        self.sources = []
        self._compile_patch = None

    def __enter__(self, *args, **kwargs):
        import unittest.mock as mock

        import torch._inductor.codecache

        _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile

        def my_compile(
            source_code, dst_file_ext, extra_args: Optional[list[str]] = None
        ):
            self.sources.append(source_code)
            return _compile_method_orig(source_code, dst_file_ext)

        # pyrefly: ignore [bad-assignment]
        self._compile_patch = mock.patch(
            "torch._inductor.codecache.CUDACodeCache.compile", my_compile
        )
        self._compile_patch.__enter__(*args, **kwargs)  # type: ignore[union-attr]
        return self

    def __exit__(self, *args, **kwargs):
        self._compile_patch.__exit__(*args, **kwargs)  # type: ignore[union-attr]


def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path):
    # returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run
    # Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled.
    from torch._inductor.codecache import cuda_compile_command

    extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]
    compile_command = cuda_compile_command(
        [str(srcpath)], str(exepath), "exe", extra_args=extra_args
    )
    return compile_command
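A rough usage sketch for the two debugging helpers above (an editor's illustration, not part of this patch): capture the CUTLASS sources that CUDACodeCache.compile sees during autotuning, then build one of them into a standalone runner. The config knobs follow the cutlass.* namespace this PR introduces; treat the exact flags and shapes as assumptions.

import subprocess
from pathlib import Path

import torch
from torch._inductor import config
from torch._inductor.codegen.cuda.compile_utils import (
    CUDACompileSourceCapturingContext,
    cuda_standalone_runner_compile_command,
)

config.max_autotune = True
config.max_autotune_gemm_backends = "CUTLASS"
config.cutlass.generate_test_runner = True  # embed the standalone runner in the kernel source

with CUDACompileSourceCapturingContext() as ctx:
    a = torch.randn(256, 256, device="cuda", dtype=torch.float16)
    b = torch.randn(256, 256, device="cuda", dtype=torch.float16)
    torch.compile(torch.mm)(a, b)  # autotuning compiles CUTLASS candidates

src = Path("captured_gemm.cu")
src.write_text(ctx.sources[-1])  # last CUTLASS source seen by CUDACodeCache.compile
exe = Path("captured_gemm_runner")
subprocess.run(
    cuda_standalone_runner_compile_command(src, exe), shell=True, check=True
)
# ./captured_gemm_runner now benchmarks the captured kernel outside Inductor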
@ -9,6 +9,7 @@ from ..common import (
register_device_op_overrides,
TritonScratchWorkspace,
)
from .cuda_env import get_cuda_arch, get_cuda_version


class CUDADeviceOpOverrides(DeviceOpOverrides):
@ -360,5 +361,11 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
else:
return [f"CUdeviceptr {var_name} = 0;"], var_name

def get_device_arch(self) -> str:
return get_cuda_arch()

def get_toolkit_version(self) -> str:
return get_cuda_version()


register_device_op_overrides("cuda", CUDADeviceOpOverrides())
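The two new hooks give device-agnostic CUTLASS code one place to ask for the target architecture and toolkit version instead of calling CUDA-only helpers. A small illustration of how they are meant to be consumed, mirroring the maybe_fetch_ops() change later in this PR; the concrete return values are examples, not guarantees:

from torch._inductor.codegen.common import get_device_op_overrides

overrides = get_device_op_overrides("cuda")  # or "xpu" once its overrides land
arch = overrides.get_device_arch()  # e.g. "90" on Hopper
toolkit = overrides.get_toolkit_version()  # e.g. "12.4.0"; CUDA callers trim it to "12.4"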
@ -10,9 +10,11 @@ from typing import Any, Optional
|
||||
|
||||
import torch._inductor.config as config
|
||||
from torch._inductor.codecache import cutlass_key
|
||||
from torch._inductor.codegen.cuda import cutlass_utils, serialization
|
||||
from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
|
||||
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
|
||||
from torch._inductor.codegen.common import get_device_op_overrides
|
||||
from torch._inductor.codegen.cutlass import serialization, utils
|
||||
from torch._inductor.codegen.cutlass.serialization import (
|
||||
get_cutlass_operation_serializer,
|
||||
)
|
||||
from torch._inductor.runtime.cache_dir_utils import cache_dir
|
||||
from torch._inductor.utils import clear_on_fresh_cache
|
||||
|
||||
@ -39,7 +41,7 @@ def get_config_request_key(
|
||||
return hashlib.sha256(f.read()).hexdigest()
|
||||
|
||||
serialization_hash = get_file_hash(serialization)
|
||||
cutlass_utils_hash = get_file_hash(cutlass_utils)
|
||||
cutlass_utils_hash = get_file_hash(utils)
|
||||
|
||||
hash_target = "-".join(
|
||||
[
|
||||
@ -63,7 +65,7 @@ def _generate_config_filename(request_key: str) -> str:
|
||||
|
||||
@clear_on_fresh_cache
|
||||
@functools.cache
|
||||
def maybe_fetch_ops() -> Optional[list[Any]]:
|
||||
def maybe_fetch_ops(device_type: str) -> Optional[list[Any]]:
|
||||
"""
|
||||
Fetch ops from databases.
|
||||
"""
|
||||
@ -71,11 +73,14 @@ def maybe_fetch_ops() -> Optional[list[Any]]:
|
||||
return None
|
||||
|
||||
# setup
|
||||
arch: str = get_cuda_arch()
|
||||
# get_cuda_version might return "12.4.0" or "12.4"
|
||||
# but we want to use "12.4"
|
||||
version: str = ".".join(get_cuda_version().split(".")[:2])
|
||||
instantiation_level: str = config.cuda.cutlass_instantiation_level
|
||||
device_op_overrides = get_device_op_overrides(device_type)
|
||||
arch: str = device_op_overrides.get_device_arch()
|
||||
version: str = device_op_overrides.get_toolkit_version()
|
||||
if device_type == "cuda":
|
||||
# get_cuda_version might return "12.4.0" or "12.4"
|
||||
# but we want to use "12.4"
|
||||
version = ".".join(version.split(".")[:2])
|
||||
instantiation_level: str = config.cutlass.cutlass_instantiation_level
|
||||
|
||||
# filename and filepath
|
||||
request_key: str = get_config_request_key(arch, version, instantiation_level)
|
||||
@ -11,7 +11,7 @@ from typing import Any, Optional, Union
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._inductor.autotune_process import TensorMeta
|
||||
from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops
|
||||
from torch._inductor.codegen.cutlass.cache import maybe_fetch_ops
|
||||
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
|
||||
from torch._inductor.runtime.runtime_utils import dynamo_timed
|
||||
from torch._inductor.scheduler import BaseSchedulerNode
|
||||
@ -19,11 +19,11 @@ from torch._inductor.select_algorithm import create_inputs_key
|
||||
from torch._inductor.utils import clear_on_fresh_cache
|
||||
|
||||
from ... import ir
|
||||
from ...config import cuda as inductor_cuda_config
|
||||
from ...config import cutlass as inductor_cutlass_config
|
||||
from ...ir import (
|
||||
Buffer,
|
||||
ChoiceCaller,
|
||||
CUDATemplateBuffer,
|
||||
CUTLASSTemplateBuffer,
|
||||
FixedLayout,
|
||||
IRNode,
|
||||
Layout,
|
||||
@ -32,11 +32,12 @@ from ...ir import (
|
||||
from ...utils import is_dynamic, Placeholder
|
||||
from ...virtualized import V
|
||||
from ..common import IndentedBuffer
|
||||
from . import cutlass_utils
|
||||
from .cuda_kernel import CUDATemplateKernel
|
||||
from .cuda_template import CUTLASSTemplate
|
||||
from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt
|
||||
from .cutlass_utils import (
|
||||
from ..cuda import cuda_env
|
||||
from . import utils as cutlass_utils
|
||||
from .kernel import CUTLASSTemplateKernel
|
||||
from .python_evt import CutlassEVTCodegen, scaled_mm_evt
|
||||
from .template import CUTLASSTemplate
|
||||
from .utils import (
|
||||
ACCUMULATOR_DTYPES,
|
||||
dtype_match,
|
||||
torch_dtype_to_cutlass_type,
|
||||
@ -578,7 +579,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||
for name, op in ops:
|
||||
for (
|
||||
swizzle
|
||||
) in inductor_cuda_config.cutlass_max_profiling_swizzle_options:
|
||||
) in inductor_cutlass_config.cutlass_max_profiling_swizzle_options:
|
||||
description = f"{name} swizzle={swizzle}"
|
||||
self.maybe_append_choice(
|
||||
choices,
|
||||
@ -621,7 +622,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/device/gemm_sparse.h"
|
||||
//#include "cutlass/gemm/device/gemm_sparse.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
@ -635,7 +636,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
"""
|
||||
)
|
||||
if inductor_cuda_config.generate_test_runner and not is_dynamic(
|
||||
if inductor_cutlass_config.generate_test_runner and not is_dynamic(
|
||||
*self.input_nodes, self.output_node
|
||||
):
|
||||
res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES)
|
||||
@ -712,12 +713,14 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||
bool: True if the alignment was successfully updated, False otherwise.
|
||||
"""
|
||||
alignment = cutlass_utils.get_max_alignment(torch_layout)
|
||||
cuda_arch = cutlass_utils.get_cuda_arch()
|
||||
if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment:
|
||||
return False
|
||||
else:
|
||||
op_element.alignment = alignment
|
||||
return True
|
||||
if torch.cuda.is_available():
|
||||
cuda_arch = cuda_env.get_cuda_arch()
|
||||
cuda_arch = cutlass_utils._normalize_cutlass_arch(cuda_arch)
|
||||
if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment:
|
||||
return False
|
||||
|
||||
op_element.alignment = alignment
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def should_swap_XW(
|
||||
@ -953,7 +956,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||
)
|
||||
return None
|
||||
|
||||
if inductor_cuda_config.cutlass_tma_only and not self._has_tma_epilogue(op):
|
||||
if inductor_cutlass_config.cutlass_tma_only and not self._has_tma_epilogue(op):
|
||||
return None
|
||||
|
||||
# Set epilogue.
|
||||
@ -975,14 +978,16 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||
return None
|
||||
|
||||
# Apply regex filters at the end when configuration name doesn't change anymore
|
||||
if inductor_cuda_config.cutlass_op_allowlist_regex:
|
||||
if inductor_cutlass_config.cutlass_op_allowlist_regex:
|
||||
if not re.search(
|
||||
inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name()
|
||||
inductor_cutlass_config.cutlass_op_allowlist_regex,
|
||||
op.configuration_name(),
|
||||
):
|
||||
return None
|
||||
if inductor_cuda_config.cutlass_op_denylist_regex is not None:
|
||||
if inductor_cutlass_config.cutlass_op_denylist_regex is not None:
|
||||
if re.search(
|
||||
inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name()
|
||||
inductor_cutlass_config.cutlass_op_denylist_regex,
|
||||
op.configuration_name(),
|
||||
):
|
||||
return None
|
||||
|
||||
@ -1007,10 +1012,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||
return self.filtered_ops_cache[self.cache_key]
|
||||
|
||||
with dynamo_timed("CUTLASSGemmTemplate.maybe_fetch_ops"):
|
||||
maybe_ops = maybe_fetch_ops()
|
||||
maybe_ops = maybe_fetch_ops(self.device_type)
|
||||
if maybe_ops is None:
|
||||
log.debug("Cannot fetch ops from cache, generating ops from scratch")
|
||||
full_ops = cutlass_utils.gen_ops()
|
||||
full_ops = cutlass_utils.gen_ops(self.device_type)
|
||||
ops = pytree.tree_flatten(full_ops)[0]
|
||||
else:
|
||||
log.debug("Using cached ops from cache")
|
||||
@ -1035,7 +1040,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||
time.time() - start_time,
|
||||
)
|
||||
sorted_res = sorted(res.items())
|
||||
ret_res = sorted_res[: inductor_cuda_config.cutlass_max_profiling_configs]
|
||||
ret_res = sorted_res[: inductor_cutlass_config.cutlass_max_profiling_configs]
|
||||
if len(self.filtered_ops_cache) < 50:
|
||||
self.filtered_ops_cache[self.cache_key] = ret_res
|
||||
else:
|
||||
@ -1060,26 +1065,26 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||
|
||||
def render( # type: ignore[override]
|
||||
self,
|
||||
kernel: CUDATemplateKernel,
|
||||
kernel: CUTLASSTemplateKernel,
|
||||
op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821
|
||||
template_buffer_node: Optional[CUDATemplateBuffer] = None,
|
||||
template_buffer_node: Optional[CUTLASSTemplateBuffer] = None,
|
||||
epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
The primary entry point for the code rendering process used in this template.
|
||||
Renders the Cutlass based CUDA C++ code for the GEMM Kernel that this template is designed to implement,
|
||||
Renders the Cutlass based CUDA/XPU C++ code for the GEMM Kernel that this template is designed to implement,
|
||||
including potentially fused epilogues.
|
||||
|
||||
Args:
|
||||
kernel (CUDATemplateKernel): The kernel to be rendered.
|
||||
kernel (CUTLASSTemplateKernel): The kernel to be rendered.
|
||||
op (cutlass_gemm_op.GemmOperation, optional): A GEMM operation that is required to be compatible with the
|
||||
input and output definitions as well as a possible epilogue. Defaults to None.
|
||||
**kwargs: Additional keyword arguments. Currently unused.
|
||||
|
||||
Returns:
|
||||
str: Cutlass based CUDA C++ code fragment as a string, to be used by the current
|
||||
CUDATemplateKernel or autotuning code.
|
||||
str: Cutlass based CUDA/XPU C++ code fragment as a string, to be used by the current
|
||||
CUTLASSTemplateKernel or autotuning code.
|
||||
|
||||
Note:
|
||||
All inputs and their corresponding buffer addresses and names take precedence over previously
|
||||
@ -1277,7 +1282,9 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||
}
|
||||
options.update(dict(zip(extra_names, extra_inputs)))
|
||||
res = self._template_from_string(self._get_template()).render(**options)
|
||||
if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias):
|
||||
if inductor_cutlass_config.generate_test_runner and not is_dynamic(
|
||||
X, W, Y, Bias
|
||||
):
|
||||
test_runner_code = self._template_from_string(
|
||||
GEMM_STANDALONE_RUNNER_TEMPLATE
|
||||
).render(**options)
|
||||
@ -1295,7 +1302,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||
names_str: str = "",
|
||||
) -> str:
|
||||
"""
|
||||
Helper method to render the Cutlass CUDA C++ code required for calling the GEMM operation in the standalone
|
||||
Helper method to render the Cutlass CUDA/XPU C++ code required for calling the GEMM operation in the standalone
|
||||
test runner that might also be generated along with the rest of the code, if the corresponding config is
|
||||
enabled.
|
||||
|
||||
@ -1483,7 +1490,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
|
||||
output_dtype: torch.dtype,
|
||||
accumulator_dtype: torch.dtype,
|
||||
) -> tuple[str, str, str, EVTArgRenames]:
|
||||
from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace
|
||||
from .lib_extensions.evt_extensions import create_example_tensors, trace
|
||||
|
||||
acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
|
||||
output_dtype = torch_dtype_to_cutlass_type(output_dtype)
|
||||
@ -1554,7 +1561,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
|
||||
op: GemmOperation,
|
||||
evt_name: Optional[str] = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance.
|
||||
"""Defines and renders the Cutlass / CUDA/XPU C++ code for a given GEMM operation instance.
|
||||
|
||||
This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply
|
||||
forms a core part of a number of scientific applications, so this efficient and adaptable implementation is
|
||||
@ -1570,9 +1577,11 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
|
||||
assert cutlass_utils.try_import_cutlass()
|
||||
import cutlass_library.library as cutlass_lib
|
||||
|
||||
from .cutlass_lib_extensions import gemm_operation_extensions as gemm_extensions
|
||||
from .lib_extensions import gemm_operation_extensions as gemm_extensions
|
||||
|
||||
emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name) # type: ignore[call-arg]
|
||||
emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(
|
||||
evt_name=evt_name, device_type=self.device_type
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
if not hasattr(op, "epilogue_functor") or not isinstance(
|
||||
op.epilogue_functor, enum.Enum
|
||||
@ -1629,11 +1638,11 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
|
||||
Y: IRNode,
|
||||
alpha: float,
|
||||
beta: float,
|
||||
kernel: CUDATemplateKernel,
|
||||
kernel: CUTLASSTemplateKernel,
|
||||
epilogue_args,
|
||||
) -> str:
|
||||
"""
|
||||
Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation.
|
||||
Render the Cutlass CUDA/XPU C++ code required for passing arguments to the GEMM operation.
|
||||
|
||||
Args:
|
||||
argument_template (str): Template for the GEMM operation arguments.
|
||||
@ -1646,11 +1655,11 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
|
||||
Y (IRNode): The output tensor.
|
||||
alpha (float): Scaling factor for the product of the inputs.
|
||||
beta (float): Scaling factor for the output tensor.
|
||||
kernel (CUDATemplateKernel): CUDA Template kernel for the operation.
|
||||
kernel (CUTLASSTemplateKernel): CUDA/XPU Template kernel for the operation.
|
||||
epilogue_args (any): Additional arguments for the epilogue state.
|
||||
|
||||
Returns:
|
||||
str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation.
|
||||
str: A block of CUDA/XPU C++ code as a string, ready to be used as arguments for the GEMM operation.
|
||||
|
||||
Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y
|
||||
tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped
|
||||
@ -1710,6 +1719,8 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
|
||||
|
||||
|
||||
class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
|
||||
"""CUTLASS 2x GEMM Template, which is used to generate CUTLASS GEMM kernels"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_nodes: list[Buffer],
|
||||
@ -1918,7 +1929,7 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
|
||||
Y: IRNode,
|
||||
alpha: float,
|
||||
beta: float,
|
||||
kernel: CUDATemplateKernel,
|
||||
kernel: CUTLASSTemplateKernel,
|
||||
epilogue_args,
|
||||
) -> str:
|
||||
"""
|
||||
@ -1937,7 +1948,7 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
|
||||
Y (IRNode): The output tensor.
|
||||
alpha (float): Scaling factor for the product of the inputs.
|
||||
beta (float): Scaling factor for the output tensor.
|
||||
kernel (CUDATemplateKernel): CUDA Template kernel for the operation.
|
||||
kernel (CUTLASSTemplateKernel): CUDA Template kernel for the operation.
|
||||
epilogue_args (any): Additional arguments for the epilogue state.
|
||||
|
||||
Returns:
|
||||
@ -10,22 +10,23 @@ from sympy import Expr, symbols
|
||||
|
||||
import torch._inductor.config as config
|
||||
from torch import dtype as torch_dtype
|
||||
from torch._inductor.codegen.common import get_device_op_overrides
|
||||
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
|
||||
from torch._inductor.scheduler import BaseSchedulerNode
|
||||
from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
|
||||
from .utils import DTYPE_TO_CUTLASS_TYPE
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .cuda_template import ArgInfo
|
||||
from .template import ArgInfo
|
||||
|
||||
from ...autotune_process import CUDABenchmarkRequest
|
||||
from ...autotune_process import CUTLASSBenchmarkRequest
|
||||
from ...ir import (
|
||||
Buffer,
|
||||
ChoiceCaller,
|
||||
CUDATemplateBuffer,
|
||||
CUTLASSTemplateBuffer,
|
||||
IRNode,
|
||||
Layout,
|
||||
PrimitiveInfoType,
|
||||
@ -46,7 +47,7 @@ from ..cpp_utils import CppPrinter, DTYPE_TO_CPP
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._inductor.codegen.cuda.cuda_template import CUDATemplate
|
||||
from torch._inductor.codegen.cutlass.template import CUTLASSTemplate
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -72,9 +73,9 @@ class LayoutArg:
|
||||
return self.node == node and self.attr == attr and self.dim == dim
|
||||
|
||||
|
||||
class CUDAKernel(Kernel):
|
||||
class CUTLASSKernel(Kernel):
|
||||
"""
|
||||
Baseclass for CUDA / Cutlass based Kernels
|
||||
Baseclass for Cutlass based Kernels
|
||||
"""
|
||||
|
||||
overrides = OpOverrides # type: ignore[assignment]
|
||||
@ -191,21 +192,20 @@ class CUDAKernel(Kernel):
|
||||
return _normalize_idx(-1, len(strides))
|
||||
|
||||
|
||||
class CUDATemplateKernel(CUDAKernel):
|
||||
class CUTLASSTemplateKernel(CUTLASSKernel):
|
||||
"""
|
||||
Template kernels defined by CUDA / Cutlass in C++.
|
||||
Template kernels defined by Cutlass in C++.
|
||||
"""
|
||||
|
||||
_EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name: str,
|
||||
runtime_arg_info: list["ArgInfo"],
|
||||
runtime_arg_values: list[Any],
|
||||
device_type: str = "cuda", # type: ignore[assignment]
|
||||
) -> None:
|
||||
"""
|
||||
Initializes a new instance of the CUDATemplateKernel class.
|
||||
Initializes a new instance of the CUTLASSTemplateKernel class.
|
||||
|
||||
Args:
|
||||
kernel_name (str): The name of the kernel.
|
||||
@ -214,6 +214,9 @@ class CUDATemplateKernel(CUDAKernel):
|
||||
self.kernel_name = kernel_name
|
||||
self.runtime_arg_info = runtime_arg_info
|
||||
self.runtime_arg_values = runtime_arg_values
|
||||
self.device_type = device_type
|
||||
self.device_codegen = get_device_op_overrides(self.device_type)
|
||||
self._EXTRA_CPP_ARGS = f"size_t* workspace_size, uint8_t* workspace, {self.device_codegen.cpp_stream_type()} stream"
|
||||
|
||||
def check_not_null(self, node: IRNode) -> str:
|
||||
"""
|
||||
@ -328,14 +331,14 @@ class CUDATemplateKernel(CUDAKernel):
|
||||
def call_kernel(
|
||||
self,
|
||||
name: str,
|
||||
node: "CUDATemplateBuffer", # type: ignore[name-defined]
|
||||
node: "CUTLASSTemplateBuffer", # type: ignore[name-defined]
|
||||
) -> None:
|
||||
"""
|
||||
Generates code to call the kernel through V.graph.wrapper_code.
|
||||
used from within torch._inductor.wrapper.PythonWrapperCodegen
|
||||
|
||||
name: Name of kernel function.
|
||||
node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes
|
||||
node: The CUTLASSTemplateBuffer node which contains information about the kernel, it's fused epilogue nodes
|
||||
as well as all required inputs and outputs.
|
||||
"""
|
||||
wrapper = V.graph.wrapper_code
|
||||
@ -423,7 +426,7 @@ class CUDATemplateKernel(CUDAKernel):
|
||||
# Helper method, called into from CUTLASSGemmTemplate
|
||||
if node is None:
|
||||
return default_dtype
|
||||
from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate
|
||||
from torch._inductor.codegen.cutlass.template import CUTLASSTemplate
|
||||
|
||||
return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]
|
||||
|
||||
@ -562,16 +565,16 @@ class CUDATemplateKernel(CUDAKernel):
|
||||
self.store_buffer_names.add(name)
|
||||
|
||||
|
||||
class CUDATemplateCaller(ChoiceCaller):
|
||||
class CUTLASSTemplateCaller(ChoiceCaller):
|
||||
"""
|
||||
CUDATemplateCaller
|
||||
CUTLASSTemplateCaller
|
||||
|
||||
This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller.
|
||||
This class represents a caller for CUTLASS template kernels. It is a subclass of ChoiceCaller.
|
||||
Attributes:
|
||||
name (str): The name of the caller.
|
||||
category (str): The category of the caller.
|
||||
bmreq (CUDABenchmarkRequest): The benchmark request for the caller.
|
||||
template_buffer (CUDATemplateBuffer): The template buffer for the caller.
|
||||
bmreq (CUTLASSBenchmarkRequest): The benchmark request for the caller.
|
||||
template_buffer (CUTLASSTemplateBuffer): The template buffer for the caller.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -581,12 +584,12 @@ class CUDATemplateCaller(ChoiceCaller):
|
||||
input_nodes: list[Buffer],
|
||||
layout: Layout,
|
||||
make_kernel_render: Callable[
|
||||
[CUDATemplateBuffer, Optional[list[BaseSchedulerNode]]],
|
||||
tuple[CUDATemplateKernel, functools.partial[str]],
|
||||
[CUTLASSTemplateBuffer, Optional[list[BaseSchedulerNode]]],
|
||||
tuple[CUTLASSTemplateKernel, functools.partial[str]],
|
||||
],
|
||||
bmreq: CUDABenchmarkRequest,
|
||||
bmreq: CUTLASSBenchmarkRequest,
|
||||
supports_epilogue_fusion: bool,
|
||||
template: "CUDATemplate", # type: ignore[name-defined]
|
||||
template: "CUTLASSTemplate", # type: ignore[name-defined]
|
||||
info_kwargs: Optional[
|
||||
dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
|
||||
], # type: ignore[type-arg]
|
||||
@ -612,10 +615,10 @@ class CUDATemplateCaller(ChoiceCaller):
|
||||
return self.bmreq.benchmark(*args, out=out)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"
|
||||
return f"CUTLASSTemplateCaller(source_file={self.bmreq.source_file})"
|
||||
|
||||
def call_name(self) -> str:
|
||||
return f"cuda_template_kernels.{self.name}"
|
||||
return f"cutlass_template_kernels.{self.name}"
|
||||
|
||||
def kernel_hash_key(self) -> str:
|
||||
"""
|
||||
@ -675,7 +678,7 @@ class CUDATemplateCaller(ChoiceCaller):
|
||||
def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]:
|
||||
self.bmreq.update_workspace_size()
|
||||
return TensorBox.create(
|
||||
CUDATemplateBuffer(
|
||||
CUTLASSTemplateBuffer(
|
||||
layout=self.layout,
|
||||
inputs=self.input_nodes,
|
||||
make_kernel_render=self.make_kernel_render,
|
||||
@ -10,7 +10,7 @@ from torch._inductor.ir import (
|
||||
)
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass
|
||||
from ..utils import torch_dtype_to_cutlass_type, try_import_cutlass
|
||||
|
||||
|
||||
EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace
|
||||
@ -237,7 +237,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||
size_hint_fn: Callable[[Union[Expr, int]], int],
|
||||
arg_renames: EVTArgRenames,
|
||||
) -> str:
|
||||
from ..cuda_template import CUTLASSTemplate
|
||||
from ..template import CUTLASSTemplate
|
||||
|
||||
# Today, arguments are either a pointer to the
|
||||
# node's memory, a stride tuple, the datatype
|
||||
@ -1,5 +1,5 @@
|
||||
# mypy: ignore-errors
|
||||
from ..cutlass_utils import try_import_cutlass
|
||||
from ..utils import try_import_cutlass
|
||||
|
||||
|
||||
# copied / modified from original at
|
||||
@ -16,7 +16,7 @@ if try_import_cutlass():
|
||||
class EmitGemmUniversal3xInstanceWithEVT:
|
||||
"""Responsible for emitting a CUTLASS 3.x template definition"""
|
||||
|
||||
def __init__(self, operation_suffix="", evt_name=None):
|
||||
def __init__(self, operation_suffix="", evt_name=None, device_type="cuda"):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.includes = [
|
||||
"cutlass/cutlass.h",
|
||||
@ -32,6 +32,13 @@ if try_import_cutlass():
|
||||
${element_c},
|
||||
${element_epilogue}
|
||||
>"""
|
||||
if device_type == "xpu":
|
||||
self.builtin_epilogue_functor_template = """${epilogue_functor}<
|
||||
${element_accumulator},
|
||||
${element_epilogue},
|
||||
${element_c},
|
||||
${element_epilogue}
|
||||
>"""
|
||||
|
||||
self.evt_name = evt_name
|
||||
self.gemm_template = """
|
||||
@ -175,6 +182,8 @@ ${compile_guard_end}
|
||||
f"cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(\
|
||||
sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>"
|
||||
)
|
||||
if operation.arch == 11:
|
||||
stage_count_string = "cutlass::gemm::collective::StageCountAuto"
|
||||
|
||||
epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto"
|
||||
|
||||
@ -350,6 +359,11 @@ cute::Layout<cute::Shape<int,int,int>, {operation_name_str}_StrideNarrow>{{}}));
|
||||
if self.evt_name:
|
||||
epilogue_functor = self.evt_name
|
||||
|
||||
arch = (
|
||||
"cutlass::arch::IntelXe"
|
||||
if operation.arch == 11
|
||||
else f"cutlass::arch::Sm{operation.arch}"
|
||||
)
|
||||
values = {
|
||||
"operation_name": operation_name_str,
|
||||
"operation_suffix": self.operation_suffix,
|
||||
@ -369,7 +383,7 @@ cute::Layout<cute::Shape<int,int,int>, {operation_name_str}_StrideNarrow>{{}}));
|
||||
"element_accumulator": DataTypeTag[operation.accumulator_type()],
|
||||
"opcode_class_main": OpcodeClassTag[opcode_class_main],
|
||||
"opcode_class_epi": OpcodeClassTag[opcode_class_epi],
|
||||
"arch": f"cutlass::arch::Sm{operation.arch}",
|
||||
"arch": arch,
|
||||
"tile_shape_m": str(tile_shape_m),
|
||||
"tile_shape_n": str(tile_shape_n),
|
||||
"tile_shape_k": str(tile_shape_k),
|
||||
@ -183,11 +183,11 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
|
||||
|
||||
@staticmethod
|
||||
def ir_to_evt_python_code(
|
||||
cuda_template_node_name: str,
|
||||
cutlass_template_node_name: str,
|
||||
epilogue_nodes: list[BaseSchedulerNode],
|
||||
removed_buffers: OrderedSet[str],
|
||||
) -> tuple[list[str], list[str], dict[str, Any], str]:
|
||||
codegen = CutlassEVTCodegen(cuda_template_node_name, removed_buffers)
|
||||
codegen = CutlassEVTCodegen(cutlass_template_node_name, removed_buffers)
|
||||
handler = _AssignmentFormatter(codegen)
|
||||
|
||||
with virtualized.V.set_ops_handler(handler):
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from torch._inductor.codegen.cuda.cutlass_python_evt import (
|
||||
from torch._inductor.codegen.cutlass.python_evt import (
|
||||
CutlassEVTCodegen,
|
||||
MockCutlassHandler,
|
||||
)
|
||||
@ -14,7 +14,7 @@ from torch.utils._ordered_set import OrderedSet
|
||||
from ...._dynamo.utils import counters
|
||||
from ... import config
|
||||
from ...codecache import code_hash, get_path
|
||||
from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise
|
||||
from ...ir import Buffer, ComputedBuffer, CUTLASSTemplateBuffer, Pointwise
|
||||
from ...scheduler import (
|
||||
BaseSchedulerNode,
|
||||
BaseScheduling,
|
||||
@ -36,13 +36,13 @@ class WhyNoFuseNames(WhyNoFuse):
|
||||
self.name2 = name2
|
||||
|
||||
|
||||
class CUDACPPScheduling(BaseScheduling):
|
||||
class CUTLASSScheduling(BaseScheduling):
|
||||
"""
|
||||
Partial Scheduling implementation for CUDA C++ Kernels.
|
||||
Partial Scheduling implementation for cutlass C++ Kernels.
|
||||
This class is intended to be used in combination with TritonScheduling,
|
||||
and delegated to by CUDACombinedScheduling.
|
||||
and delegated to by CombinedScheduling.
|
||||
|
||||
It handles fusion decisions and CUDA C++ specific template code generation.
|
||||
It handles fusion decisions and cutlass C++ specific template code generation.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@ -53,25 +53,25 @@ class CUDACPPScheduling(BaseScheduling):
|
||||
return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)
|
||||
|
||||
@staticmethod
|
||||
def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
|
||||
def is_cutlass_template(node: BaseSchedulerNode) -> bool:
|
||||
return isinstance(node, SchedulerNode) and isinstance(
|
||||
node.node, CUDATemplateBuffer
|
||||
node.node, CUTLASSTemplateBuffer
|
||||
)
|
||||
|
||||
def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
|
||||
return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node)
|
||||
def is_cutlass_fused_template(self, node: BaseSchedulerNode) -> bool:
|
||||
return isinstance(node, FusedSchedulerNode) and self.is_cutlass_template(node)
|
||||
|
||||
def can_fuse_vertical(
|
||||
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
|
||||
) -> bool:
|
||||
if self.is_cuda_cpp_template(node1) and isinstance(node2, BaseSchedulerNode):
|
||||
if self.is_cutlass_template(node1) and isinstance(node2, BaseSchedulerNode):
|
||||
assert node1.node, "node1.node should not be None"
|
||||
return self._can_fuse_epilogue_impl(
|
||||
cast(CUDATemplateBuffer, node1.node),
|
||||
cast(CUTLASSTemplateBuffer, node1.node),
|
||||
[],
|
||||
node2, # type: ignore[arg-type]
|
||||
)
|
||||
elif self.is_cuda_cpp_fused_template(node1) and isinstance(
|
||||
elif self.is_cutlass_fused_template(node1) and isinstance(
|
||||
node2, BaseSchedulerNode
|
||||
):
|
||||
assert node1.node, "node1.node should not be None"
|
||||
@ -130,16 +130,16 @@ class CUDACPPScheduling(BaseScheduling):
|
||||
prologue_nodes: Sequence[BaseSchedulerNode],
|
||||
):
|
||||
"""
|
||||
Codegen a CUDA template, possibly with fused epilogues
|
||||
Codegen a cutlass template, possibly with fused epilogues
|
||||
"""
|
||||
counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
|
||||
assert self.is_cuda_cpp_template(template_node), (
|
||||
"Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
|
||||
counters["inductor"]["cutlass_epilogue_fusion_counter"] += len(epilogue_nodes)
|
||||
assert self.is_cutlass_template(template_node), (
|
||||
"Template node passed to CUTLASSScheduling.codegen_template must be a SchedulerNode that wraps a CUTLASSTemplateBuffer"
|
||||
)
|
||||
template_node = cast(SchedulerNode, template_node)
|
||||
_, (_numel, rnumel) = template_node.group
|
||||
assert rnumel == 1
|
||||
ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
|
||||
ctb: CUTLASSTemplateBuffer = cast(CUTLASSTemplateBuffer, template_node.node)
|
||||
epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes] # type: ignore[misc]
|
||||
assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), (
|
||||
"Epilogue nodes must all be instances of ir.ComputedBuffer"
|
||||
@ -197,7 +197,7 @@ class CUDACPPScheduling(BaseScheduling):
|
||||
|
||||
def _can_fuse_epilogue_impl(
|
||||
self,
|
||||
cuda_template_buffer: CUDATemplateBuffer,
|
||||
cutlass_template_buffer: CUTLASSTemplateBuffer,
|
||||
existing_epilogue_nodes: list[BaseSchedulerNode],
|
||||
node_to_fuse: BaseSchedulerNode,
|
||||
) -> bool:
|
||||
@ -206,18 +206,20 @@ class CUDACPPScheduling(BaseScheduling):
|
||||
support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.
|
||||
|
||||
Args:
|
||||
cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer
|
||||
cutlass_template_buffer : A CUTLASSTemplateBuffer object representing the CUTLASS template and it's result buffer
|
||||
existing_epilogue_nodes : List[SchedulerNode]: The list of already fused epilogue nodes.
|
||||
node_to_fuse: The SchedulerNode node to be checked if it can be fused with the epilogue.
|
||||
Returns:
|
||||
- bool: True if the given node can be fused with the epilogue, False otherwise.
|
||||
|
||||
"""
|
||||
why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())
|
||||
why = WhyNoFuseNames(
|
||||
cutlass_template_buffer.get_name(), node_to_fuse.get_name()
|
||||
)
|
||||
|
||||
scheduler_nodes_to_fuse = node_to_fuse.get_nodes()
|
||||
|
||||
assert isinstance(cuda_template_buffer, CUDATemplateBuffer)
|
||||
assert isinstance(cutlass_template_buffer, CUTLASSTemplateBuffer)
|
||||
|
||||
# Checks on constituent nodes
|
||||
for s_node in scheduler_nodes_to_fuse:
|
||||
@ -235,18 +237,18 @@ class CUDACPPScheduling(BaseScheduling):
|
||||
|
||||
name = node.get_computed_buffer_name() # type: ignore[attr-defined]
|
||||
# dtype can differ, and strides can differ as long as they are broadcastable
|
||||
if node.get_size() != cuda_template_buffer.get_size():
|
||||
if node.get_size() != cutlass_template_buffer.get_size():
|
||||
why(
|
||||
f"{name}'s size: {node.get_size()} differs from {cuda_template_buffer.get_name()}'s \
|
||||
size: {cuda_template_buffer.get_size()}"
|
||||
f"{name}'s size: {node.get_size()} differs from {cutlass_template_buffer.get_name()}'s \
|
||||
size: {cutlass_template_buffer.get_size()}"
|
||||
)
|
||||
return False
|
||||
|
||||
assert len(
|
||||
existing_epilogue_nodes
|
||||
) or cuda_template_buffer.get_name() in OrderedSet(
|
||||
) or cutlass_template_buffer.get_name() in OrderedSet(
|
||||
[rd.name for rd in node_to_fuse.read_writes.reads]
|
||||
), "First epilogue node must read from cuda template buffer"
|
||||
), "First epilogue node must read from cutlass template buffer"
|
||||
|
||||
if node_to_fuse.has_aliasing_or_mutation():
|
||||
why(f"{node_to_fuse.get_name()} has aliasing or mutation")
|
||||
@ -257,22 +259,20 @@ size: {cuda_template_buffer.get_size()}"
|
||||
)
|
||||
return False
|
||||
elif (
|
||||
not config.cuda.cutlass_epilogue_fusion_enabled
|
||||
not config.cutlass.cutlass_epilogue_fusion_enabled
|
||||
or not config.epilogue_fusion
|
||||
):
|
||||
why("cutlass epilogue fusion is not enabled")
|
||||
return False
|
||||
elif not cuda_template_buffer.supports_epilogue_fusion:
|
||||
elif not cutlass_template_buffer.supports_epilogue_fusion:
|
||||
why("epilogue fusion is only supported for TMA-enabled gemm ops")
|
||||
return False
|
||||
|
||||
try:
|
||||
from torch._inductor.codegen.cuda.cutlass_python_evt import (
|
||||
CutlassEVTCodegen,
|
||||
)
|
||||
from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen
|
||||
|
||||
CutlassEVTCodegen.ir_to_evt_python_code(
|
||||
cuda_template_buffer.get_name(),
|
||||
cutlass_template_buffer.get_name(),
|
||||
existing_epilogue_nodes + list(node_to_fuse.get_nodes()),
|
||||
OrderedSet(),
|
||||
)
|
||||
@ -282,13 +282,13 @@ size: {cuda_template_buffer.get_size()}"
|
||||
if not_implemented_op.startswith("_op_"):
|
||||
not_implemented_op = not_implemented_op[4:]
|
||||
why(
|
||||
f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \
|
||||
f"Cannot fuse epilogue node {node_to_fuse} into {cutlass_template_buffer.name}, \
|
||||
likely due to unsupported operation: {not_implemented_op}" # noqa: G004, B950
|
||||
)
|
||||
return False
|
||||
else: # Likely due to unsupported dtype.
|
||||
why(
|
||||
f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \
|
||||
f"Cannot fuse epilogue node {node_to_fuse} into {cutlass_template_buffer.name}. \
|
||||
Reason: {not_implemented_op}" # noqa: G004, B950
|
||||
)
|
||||
return False
|
||||
@ -4,7 +4,7 @@ import json
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
|
||||
from torch._inductor.codegen.cutlass.utils import try_import_cutlass
|
||||
|
||||
|
||||
class CUTLASSOperationSerializer:
|
@ -4,7 +4,6 @@ import hashlib
import itertools
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING, Union
from typing_extensions import override
from unittest.mock import patch

import sympy
@ -14,13 +13,13 @@ from torch._inductor import config
from torch._inductor.utils import clear_on_fresh_cache, Placeholder
from torch._logging import getArtifactLogger

from ...autotune_process import CUDABenchmarkRequest, TensorMeta
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
from ...autotune_process import CUTLASSBenchmarkRequest, TensorMeta
from ...ir import Buffer, CUTLASSTemplateBuffer, IRNode, Layout
from ...utils import IndentedBuffer, unique
from ...virtualized import V
from ..common import KernelTemplate
from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
from .kernel import CUTLASSTemplateCaller, CUTLASSTemplateKernel
from .utils import DTYPE_TO_CUTLASS_TYPE


if TYPE_CHECKING:
@ -40,7 +39,12 @@ class ArgInfo:


@clear_on_fresh_cache
class CUDATemplate(KernelTemplate):
class CUTLASSTemplate(KernelTemplate):
"""
CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
"""

index_counter = itertools.count()
# dict of cache key to (code, size_args)
code_cache: dict[str, tuple[str, tuple[int, ...], tuple[int, ...]]] = {}
@ -54,11 +58,11 @@ class CUDATemplate(KernelTemplate):
input_reorder: Optional[list[int]] = None,
) -> None:
"""
Baseclass for CUDA C++ Templates, derived from KernelTemplate.
Baseclass for CUTLASS C++ Templates, derived from KernelTemplate.
Not to be instantiated directly.

Args:
name (str): The name of the CUDATemplate object.
name (str): The name of the CUTLASSTemplate object.
input_nodes (List[IRNode]): A list of input IRNodes.
layout (Layout): The layout of the output buffer / tensor.
input_reorder (Optional[List[int]]): An optional list that specifies
@ -69,6 +73,7 @@ class CUDATemplate(KernelTemplate):
self.output_node: Buffer = Buffer(name="buf_out", layout=layout)
self.input_reorder = input_reorder
self.layout = layout
self.device_type = layout.device.type if input_nodes else "cuda"

@classmethod
@functools.lru_cache(None)
@ -110,7 +115,7 @@ class CUDATemplate(KernelTemplate):
args are different.
"""
key: Optional[str] = None
if config.cuda.enable_caching_codegen:
if config.cutlass.enable_caching_codegen:
key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr)

if key is not None and key in self.code_cache:
@ -123,10 +128,11 @@ class CUDATemplate(KernelTemplate):
return code, extra_args

kernel_name = str(Placeholder.KERNEL_NAME)
kernel = CUDATemplateKernel(
kernel = CUTLASSTemplateKernel(
kernel_name=kernel_name,
runtime_arg_info=self.get_runtime_arg_info(),
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
device_type=self.device_type,
)
with patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)):
code = self.render(kernel=kernel, **kwargs)
@ -174,10 +180,10 @@ class CUDATemplate(KernelTemplate):
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
**kwargs,
) -> CUDATemplateCaller:
) -> CUTLASSTemplateCaller:
"""
Generates the CUDA template caller object for the given GEMM template and operation.
This CUDATemplateCaller may be used to call and benchmark the generated CUDA kernel
This CUTLASSTemplateCaller may be used to call and benchmark the generated CUDA kernel
in a standalone manner to enable Autotuning.

Args:
@ -185,7 +191,7 @@ class CUDATemplate(KernelTemplate):
kwargs: Additional keyword arguments.

Returns:
A CUDATemplateCaller object representing the generated CUDA template caller.
A CUTLASSTemplateCaller object representing the generated CUDA template caller.
"""
code, extra_args = self.generate_code_and_args(
name=name,
@ -200,12 +206,13 @@ class CUDATemplate(KernelTemplate):
code = code.replace(self.name, kernel_name)

# create the BenchmarkRequest
bmreq = CUDABenchmarkRequest(
bmreq = CUTLASSBenchmarkRequest(
kernel_name=kernel_name,
input_tensor_meta=input_tensor_meta,
output_tensor_meta=output_tensor_meta,
extra_args=extra_args,
source_code=code,
device_type=self.device_type,
)

# kwargs has "op" argument in case of CUTLASSGemmTemplate
@ -217,16 +224,17 @@ class CUDATemplate(KernelTemplate):
supports_epilogue_fusion = self.supports_epilogue_fusion(op)

def make_kernel_render(
template_node: CUDATemplateBuffer,
template_node: CUTLASSTemplateBuffer,
epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
) -> tuple[CUDATemplateKernel, functools.partial[str]]:
) -> tuple[CUTLASSTemplateKernel, functools.partial[str]]:
assert supports_epilogue_fusion or not epilogue_nodes, (
"epilogue fusion is not supported for this kernel"
)
kernel = CUDATemplateKernel(
kernel = CUTLASSTemplateKernel(
kernel_name=str(Placeholder.KERNEL_NAME),
runtime_arg_info=self.get_runtime_arg_info(),
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
device_type=self.device_type,
)
render = functools.partial(
self.render,
@ -237,7 +245,7 @@ class CUDATemplate(KernelTemplate):
)
return kernel, render

return CUDATemplateCaller(
return CUTLASSTemplateCaller(
kernel_name,
"cutlass_gemm",
self.input_nodes,
@ -261,6 +269,18 @@ class CUDATemplate(KernelTemplate):
#include <vector>
"""
)
res.splice(
"""
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
"""
)
return res

def globals(self) -> IndentedBuffer:
@ -281,42 +301,6 @@ class CUDATemplate(KernelTemplate):
#endif
"""
)
return res

def render(self, **kwargs) -> str:
raise NotImplementedError

def get_runtime_arg_info(self) -> list[ArgInfo]:
return []

def get_runtime_arg_values(self, **kwargs) -> list[Any]:
return []


class CUTLASSTemplate(CUDATemplate):
"""
CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
"""

def header(self) -> IndentedBuffer:
res = super().header()
res.splice(
"""
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
"""
)
return res

def globals(self) -> IndentedBuffer:
res = super().globals()
res.splice(
"""
using namespace cute;
@ -382,11 +366,12 @@ class CUTLASSTemplate(CUDATemplate):
f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})"
)

@override
def render(self, **kwargs) -> str:
raise NotImplementedError

def get_runtime_arg_info(self) -> list[ArgInfo]:
return [ArgInfo("swizzle", "const uint8_t")]

@override
def get_runtime_arg_values(self, **kwargs) -> list[Any]:
"""
Helper method to retrieve runtime args from generate kwargs
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
@ -22,8 +21,7 @@ from ... import config
|
||||
from ...ir import Layout
|
||||
from ...runtime.runtime_utils import cache_dir
|
||||
from ...virtualized import V
|
||||
from ..cpp_utils import DTYPE_TO_CPP
|
||||
from .cuda_env import get_cuda_arch, get_cuda_version
|
||||
from ..common import get_device_op_overrides
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -98,14 +96,14 @@ def try_import_cutlass() -> bool:
|
||||
|
||||
# contains both cutlass and cutlass_library
|
||||
# we need cutlass for eVT
|
||||
cutlass_python_path = path_join(config.cuda.cutlass_dir, "python")
|
||||
cutlass_python_path = path_join(config.cutlass.cutlass_dir, "python")
|
||||
torch_root = os.path.abspath(os.path.dirname(torch.__file__))
|
||||
mock_src_path = os.path.join(
|
||||
torch_root,
|
||||
"_inductor",
|
||||
"codegen",
|
||||
"cuda",
|
||||
"cutlass_lib_extensions",
|
||||
"cutlass",
|
||||
"lib_extensions",
|
||||
"cutlass_mock_imports",
|
||||
)
|
||||
|
||||
@ -177,7 +175,10 @@ def try_import_cutlass() -> bool:
|
||||
|
||||
|
||||
@functools.lru_cache(8)
|
||||
def _normalize_cuda_arch(arch: str) -> str:
|
||||
def _normalize_cutlass_arch(arch: str) -> str:
|
||||
if torch.xpu.is_available():
|
||||
return arch
|
||||
|
||||
if int(arch) >= 100:
|
||||
log.warning(
|
||||
"Detected CUDA architecture >= 100: %s. We will generate operations with "
|
||||
@ -229,7 +230,7 @@ class CUTLASSArgs:
|
||||
raise RuntimeError(
|
||||
f"{self.architectures=} or {self.cuda_version=} is None!"
|
||||
)
|
||||
self.architectures = _normalize_cuda_arch(self.architectures)
|
||||
self.architectures = _normalize_cutlass_arch(self.architectures)
|
||||
|
||||
|
||||
@clear_on_fresh_cache
|
||||
@ -251,8 +252,8 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]:
|
||||
version,
|
||||
)
|
||||
return {}
|
||||
arch = _normalize_cuda_arch(arch)
|
||||
instantiation_level: str = config.cuda.cutlass_instantiation_level
|
||||
arch = _normalize_cutlass_arch(arch)
|
||||
instantiation_level: str = config.cutlass.cutlass_instantiation_level
|
||||
args = CUTLASSArgs(
|
||||
architectures=arch,
|
||||
cuda_version=version,
|
||||
@ -266,6 +267,13 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]:
|
||||
if hasattr(cutlass_generator, "GenerateSM100"):
|
||||
cutlass_generator.GenerateSM100(manifest, args.cuda_version)
|
||||
cutlass_generator.GenerateSM90(manifest, args.cuda_version)
|
||||
if arch == "11":
|
||||
if hasattr(cutlass_generator, "GeneratePVC"):
|
||||
cutlass_generator.GeneratePVC(manifest, args.cuda_version)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Arch PVC is not supported by current cutlass lib."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
func = getattr(cutlass_generator, "GenerateSM" + arch)
|
||||
@ -283,22 +291,34 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]:
|
||||
return manifest.operations
|
||||
|
||||
|
||||
def gen_ops() -> dict[Any, Any]:
|
||||
def gen_ops(device_type: str) -> dict[Any, Any]:
|
||||
"""
|
||||
Generates all supported CUTLASS operations.
|
||||
"""
|
||||
with dynamo_timed("cutlass_utils.gen_ops"):
|
||||
arch = get_cuda_arch()
|
||||
version = get_cuda_version()
|
||||
device_op_overrides = get_device_op_overrides(device_type)
|
||||
arch = device_op_overrides.get_device_arch()
|
||||
version = device_op_overrides.get_toolkit_version()
|
||||
return _gen_ops_cached(arch, version)
|
||||
|
||||
|
||||
DTYPE_TO_CUTLASS_TYPE = {
|
||||
**DTYPE_TO_CPP,
|
||||
torch.float16: "__half",
|
||||
torch.bfloat16: "__nv_bfloat16",
|
||||
torch.float8_e4m3fn: "__nv_fp8_e4m3",
|
||||
}
|
||||
from ..cpp_utils import DTYPE_TO_CPP
|
||||
|
||||
|
||||
if torch.xpu.is_available():
|
||||
DTYPE_TO_CUTLASS_TYPE = {
|
||||
**DTYPE_TO_CPP,
|
||||
torch.float16: "uint16_t",
|
||||
torch.bfloat16: "uint16_t",
|
||||
torch.float8_e4m3fn: "uint8_t",
|
||||
}
|
||||
else:
|
||||
DTYPE_TO_CUTLASS_TYPE = {
|
||||
**DTYPE_TO_CPP,
|
||||
torch.float16: "__half",
|
||||
torch.bfloat16: "__nv_bfloat16",
|
||||
torch.float8_e4m3fn: "__nv_fp8_e4m3",
|
||||
}
|
||||
|
||||
|
||||
@functools.lru_cache(32)
|
||||
@ -447,47 +467,3 @@ def get_max_alignment(inductor_layout: Layout) -> int:
|
||||
):
|
||||
return alignment
|
||||
return 1
|
||||
|
||||
|
||||
class CUDACompileSourceCapturingContext:
|
||||
# Helper class for Benchmarking and Testing CUTLASS Kernels in isolation.
|
||||
# Can be used to capture the sourcecode passed to CUDACodeCache.compile
|
||||
|
||||
def __init__(self):
|
||||
self.sources = []
|
||||
self._compile_patch = None
|
||||
|
||||
def __enter__(self, *args, **kwargs):
|
||||
import unittest.mock as mock
|
||||
|
||||
import torch._inductor.codecache
|
||||
|
||||
_compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile
|
||||
|
||||
def my_compile(
|
||||
source_code, dst_file_ext, extra_args: Optional[list[str]] = None
|
||||
):
|
||||
self.sources.append(source_code)
|
||||
return _compile_method_orig(source_code, dst_file_ext)
|
||||
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
self._compile_patch = mock.patch(
|
||||
"torch._inductor.codecache.CUDACodeCache.compile", my_compile
|
||||
)
|
||||
self._compile_patch.__enter__(*args, **kwargs) # type: ignore[union-attr]
|
||||
return self
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
self._compile_patch.__exit__(*args, **kwargs) # type: ignore[union-attr]
|
||||
|
||||
|
||||
def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path):
|
||||
# returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run
|
||||
# Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled.
|
||||
from torch._inductor.codecache import cuda_compile_command
|
||||
|
||||
extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]
|
||||
compile_command = cuda_compile_command(
|
||||
[str(srcpath)], str(exepath), "exe", extra_args=extra_args
|
||||
)
|
||||
return compile_command
|
||||
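gen_ops() now takes the device type and asks the per-device op overrides for the architecture and toolkit version instead of calling the CUDA-only helpers. A rough sketch of that dispatch, assuming the CUDA overrides gain the same get_device_arch()/get_toolkit_version() hooks that this excerpt only shows being added on the XPU side:

# Sketch only: mirrors the new gen_ops() dispatch above.
from torch._inductor.codegen.common import get_device_op_overrides


def describe_cutlass_target(device_type: str) -> tuple[str, str]:
    # "cuda" -> (SM arch, toolkit version); "xpu" -> (Xe arch code, oneAPI version)
    overrides = get_device_op_overrides(device_type)
    return overrides.get_device_arch(), overrides.get_toolkit_version()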
@ -19,7 +19,7 @@ class ROCmCPPScheduling(BaseScheduling):
"""
Partial Scheduling implementation for ROCm C++ Kernels.
This class is intended to be used in combination with TritonScheduling,
and delegated to by CUDACombinedScheduling.
and delegated to by CombinedScheduling.

It handles fusion decisions and ROCm C++ specific template code generation.
"""

@ -1764,7 +1764,7 @@ class SIMDScheduling(BaseScheduling):
partial_code.finalize_hook("<DEF_KERNEL>")
partial_code.finalize_hook("<ARGDEFS>", strict=False)

# TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion.
# TODO: Maybe unify CUTLASSTemplateKernel to also use PartialRender for flexible epilogue fusion.

for input_name in kernel.named_input_nodes.keys():
subgraph_name = f"<LOAD_INPUT_{input_name}>"

@ -5886,14 +5886,12 @@ def debug_triton_code(node: BaseSchedulerNode) -> list[str]:
if multi_template and multi_template.make_kernel_render is None:
lines.append(f"{node.get_name()} Unfinalized multi template buffer")
else:
from torch._inductor.codegen.cuda_combined_scheduling import (
CUDACombinedScheduling,
)
from torch._inductor.codegen.combined_scheduling import CombinedScheduling

device = node.get_device()
assert device is not None
backend = node.scheduler.get_backend(device)
assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), (
assert isinstance(backend, (SIMDScheduling, CombinedScheduling)), (
f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
)
torch/_inductor/codegen/xpu/compile_utils.py (new file, 114 lines)
@ -0,0 +1,114 @@
# mypy: allow-untyped-defs
import logging
from typing import Optional

from torch._inductor import config
from torch._inductor.utils import is_linux

from ..cuda.compile_utils import _cutlass_include_paths
from .xpu_env import get_xpu_arch


log = logging.getLogger(__name__)


def _sycl_compiler() -> Optional[str]:
return "icpx"


def _sycl_lib_options() -> list[str]:
"""
Util function for CUTLASS backend to find the correct XPU libraries.
"""
# _set_gpu_runtime_env() # cpp_extension consults the env
from torch.utils import cpp_extension

lpaths = cpp_extension.library_paths(device_type="xpu")
extra_ldflags: list[str] = []
if is_linux():
for path in lpaths:
if "torch/lib" in path:
# don't want to depend on pytorch
continue
# -rpath ensures the DLL can find its dependencies when loaded, even
# if the library path is non-standard.
extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"])
else:
raise NotImplementedError(
"Unsupported env, failed to find xpu libs! Currently only Linux is supported."
)
return extra_ldflags


def _sycl_host_compiler_options() -> list[str]:
return [
"-fPIC",
]


def _sycl_arch_as_compile_option() -> str:
arc_option_map = {"pvc": "intel_gpu_pvc", "bmg": "intel_gpu_bmg"}
arch = get_xpu_arch()
return arc_option_map.get(arch, "intel_gpu_pvc")


def _sycl_compiler_options() -> list[str]:
options = [
"-DCUTLASS_ENABLE_SYCL",
"-DCUTLASS_SYCL_PROFILING_ENABLED",
"-DSYCLCOMPAT_PROFILING_ENABLED",
"-DSYCL_INTEL_TARGET",
"-gline-tables-only",
"-DCUTLASS_VERSIONS_GENERATED",
"-O3",
"-DNDEBUG",
"-std=c++17",
"-fPIE",
"-fPIC",
"-fsycl",
f"-fsycl-targets={_sycl_arch_as_compile_option()}",
"-Xspirv-translator",
"-spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate",
"-fno-sycl-instrument-device-code",
"-DMKL_ILP64",
"-MD",
"-MT",
]
if config.cutlass.enable_debug_info:
options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"])
return options


def xpu_compile_command(
src_files: list[str],
dst_file: str,
dst_file_ext: str,
extra_args: Optional[list[str]] = None,
) -> str:
if extra_args is None:
extra_args = []
include_paths = _cutlass_include_paths()
sycl_lib_options = _sycl_lib_options()
sycl_host_compiler_options = _sycl_host_compiler_options()
sycl_compiler_options = _sycl_compiler_options()
options = (
["-I" + path for path in include_paths]
+ ["-isystem /include"]
+ sycl_compiler_options
+ extra_args
+ sycl_host_compiler_options
+ sycl_lib_options
)
src_file = " ".join(src_files)
res = ""
if dst_file_ext == "o":
res = f"{_sycl_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}"
elif dst_file_ext == "so":
options.append("-shared")
res = f"{_sycl_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
elif dst_file_ext == "exe":
res = f"{_sycl_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
else:
raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
log.debug("XPU command: %s", res)
return res
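The new xpu_compile_command() plays the role cuda_compile_command() plays for NVCC, but drives icpx with the SYCL/CUTLASS flags above. A hedged usage sketch; the file names are made up:

# Sketch only; "kernel.cpp"/"kernel.so" are hypothetical paths.
from torch._inductor.codegen.xpu.compile_utils import xpu_compile_command

# Produces a single icpx command line, roughly:
#   icpx -I<cutlass includes> ... -fsycl -fsycl-targets=intel_gpu_pvc -shared -o kernel.so kernel.cpp
cmd = xpu_compile_command(["kernel.cpp"], "kernel.so", "so")
print(cmd)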
@ -7,6 +7,7 @@ from ..common import (
register_device_op_overrides,
TritonScratchWorkspace,
)
from .xpu_env import get_xpu_arch, get_xpu_version


class XPUDeviceOpOverrides(DeviceOpOverrides):
@ -63,5 +64,11 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
) -> Optional[tuple[list[str], str]]:
return [f"void *global_scratch_{idx} = 0;"], f"global_scratch_{idx}"

def get_device_arch(self) -> str:
return get_xpu_arch()

def get_toolkit_version(self) -> str:
return get_xpu_version()


register_device_op_overrides("xpu", XPUDeviceOpOverrides())
torch/_inductor/codegen/xpu/xpu_env.py (new file, 34 lines)
@ -0,0 +1,34 @@
import functools
import logging
from typing import Optional

import torch
from torch._inductor.utils import clear_on_fresh_cache


log = logging.getLogger(__name__)


@clear_on_fresh_cache
@functools.lru_cache(1)
def get_xpu_arch() -> Optional[str]:
arch_name2code = {"pvc": "11"}
try:
assert len(torch.xpu.get_arch_list()) == 1
arch_name = torch.xpu.get_arch_list()[0]
return arch_name2code[arch_name]
except Exception as e:
log.error("Error getting xpu arch: %s", e)
return None


@clear_on_fresh_cache
@functools.lru_cache(1)
def get_xpu_version() -> Optional[str]:
# string of version, like 20250101
try:
xpu_version = torch.version.xpu
return xpu_version
except Exception as e:
log.error("Error getting xpu version: %s", e)
return None
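get_xpu_arch() maps the detected device name to the arch code the CUTLASS generator path keys on ("pvc" -> "11", the value _gen_ops_cached() checks before calling GeneratePVC), and get_xpu_version() simply reports torch.version.xpu. A small sketch of how the two values feed the pipeline:

# Sketch only.
from torch._inductor.codegen.xpu.xpu_env import get_xpu_arch, get_xpu_version

arch = get_xpu_arch()        # "11" on PVC, None if detection fails
version = get_xpu_version()  # e.g. "20250101"
if arch == "11":
    pass  # _gen_ops_cached() takes the GeneratePVC branch for this arch code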
@ -1713,28 +1713,12 @@ class aot_inductor_mode:
compile_standalone: bool = False


class cuda:
"""Settings for cuda backend, today this consists of cutlass"""

# CUDA arch to use for CUDA template kernel compilation.
# e.g. "70", "75", "80", "90", etc.
# When arch is None, Inductor uses torch.cuda.get_device_capability(0).
arch: Optional[str] = None

# CUDA version to use for CUDA template kernel compilation.
# e.g. "11.4", "12.1", etc.
# When version is None, Inductor uses torch.version.cuda.
version: Optional[str] = None
class cutlass:
"""Settings for cutlass backend, today this consists of cutlass"""

# Optimization level for the host compiler.
compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1"

# Whether to enable device LTO (link-time-optimization).
enable_cuda_lto = False

# Whether to keep intermediate files during compilation.
enable_ptxas_info = False

# Whether to enable debug info, e.g. line number, cutlass debug info.
enable_debug_info = False

@ -1746,7 +1730,10 @@ class cuda:
cutlass_dir = os.path.realpath(
os.environ.get(
"TORCHINDUCTOR_CUTLASS_DIR",
os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/"),
os.path.join(
os.path.dirname(torch.__file__),
"../third_party/cutlass/",
),
)
)

@ -1766,14 +1753,6 @@ class cuda:
# Whether to only use TMA-compatible kernels in CUTLASS
cutlass_tma_only = False

# Path to CUDA NVCC.
# NVCC search order:
# 1) cuda_cxx set in this config
# 2) CUDACXX environment variable
# 3) CUDA_HOME environment variable
# 4) default system search PATH.
cuda_cxx: Optional[str] = None

# Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops.
cutlass_backend_min_gemm_size: int = 1

@ -1843,6 +1822,53 @@ class cuda:
enable_caching_codegen: bool = True


class cuda(cutlass):
"""Settings for cutlass backend, today this consists of cutlass"""

# CUDA arch to use for CUDA template kernel compilation.
# e.g. "70", "75", "80", "90", etc.
# When arch is None, Inductor uses torch.cuda.get_device_capability(0).
arch: Optional[str] = None

# CUDA version to use for CUDA template kernel compilation.
# e.g. "11.4", "12.1", etc.
# When version is None, Inductor uses torch.version.cuda.
version: Optional[str] = None

# Path to CUDA NVCC.
# NVCC search order:
# 1) cuda_cxx set in this config
# 2) CUDACXX environment variable
# 3) CUDA_HOME environment variable
# 4) default system search PATH.
cuda_cxx: Optional[str] = None

# Whether to enable device LTO (link-time-optimization).
enable_cuda_lto = False

# Whether to keep intermediate files during compilation.
enable_ptxas_info = False


class xpu(cutlass):
# Xe arch to use for SYCL template kernel compilation.
# e.g. 12, 20, which correspond to Xe12 (PVC) and Xe20 (BMG)
arch: Optional[str] = None
# oneAPI version to use for SYCL template kernel compilation.
# e.g. "20250201".
version: Optional[str] = None

cutlass_dir = os.path.realpath(
os.environ.get(
"TORCHINDUCTOR_CUTLASS_DIR",
os.path.join(
os.path.dirname(torch.__file__),
"../third_party/sycl-tla/",
),
)
)


class rocm:
# Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"].
# If empty, the `native` arch is used
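The shared CUTLASS knobs now live on config.cutlass, while config.cuda and config.xpu subclass it and keep only the toolchain-specific fields (arch/version/cuda_cxx for NVCC, and the sycl-tla checkout for XPU). A hedged example of the new spellings, using option names taken from this diff:

# Sketch only: exercising the renamed config namespace.
from torch._inductor import config as inductor_config

# Old spelling -> new spelling:
#   config.cuda.cutlass_backend_min_gemm_size -> config.cutlass.cutlass_backend_min_gemm_size
#   config.cuda.cutlass_enabled_ops           -> config.cutlass.cutlass_enabled_ops
with inductor_config.patch(
    {
        "cutlass.cutlass_backend_min_gemm_size": 1,
        "cutlass.cutlass_enabled_ops": "ALL",
    }
):
    assert inductor_config.cutlass.cutlass_enabled_ops == "ALL"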
@ -2196,7 +2196,9 @@ class CppBuilder:
)

if device_type == "cuda" and torch.version.hip is None:
from torch._inductor.codecache import _nvcc_arch_as_compile_option
from torch._inductor.codegen.cuda.compile_utils import (
_nvcc_arch_as_compile_option,
)

current_arch = _nvcc_arch_as_compile_option()
contents += textwrap.dedent(
@ -489,7 +489,7 @@ MODULE_DEFAULTS: dict[str, ConfigType] = {
"aot_inductor.presets": DEFAULT, # Typing
"cuda.arch": DEFAULT, # Out of Scope
"cuda.version": DEFAULT, # Out of Scope
"cuda.cutlass_dir": DEFAULT, # Out of Scope
"cutlass.cutlass_dir": DEFAULT, # Out of Scope
"cuda.cuda_cxx": DEFAULT, # Out of Scope
"rocm.arch": DEFAULT, # Out of Scope
"rocm.ck_supported_arch": DEFAULT, # Out of Scope
@ -124,13 +124,13 @@ if TYPE_CHECKING:
from torch.fx.experimental.symbolic_shapes import SympyBoolean
from torch.fx.node import Argument

from .codegen.cuda.cuda_template import CUDATemplate
from .codegen.cutlass.template import CUTLASSTemplate
from .codegen.wrapper import PythonWrapperCodegen
from .graph import GraphLowering
from .utils import IndentedBuffer

else:
CUDATemplate: TypeAlias = object
CUTLASSTemplate: TypeAlias = object


try:
@ -5017,7 +5017,7 @@ class ChoiceCaller:
During autotuning, self.benchmark() is first called to get benchmark result,
and if this choice is selected, self.output_node() is called to get the output_node.

Children classes: TritonTemplateCaller, CUDATemplateCaller.
Children classes: TritonTemplateCaller, CUTLASSTemplateCaller.
"""

def __init__(
@ -5180,14 +5180,14 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
self.make_kernel_render = self._make_kernel_renders[None]


class CUDATemplateBuffer(TemplateBuffer):
class CUTLASSTemplateBuffer(TemplateBuffer):
def __init__(
self,
layout: Layout,
inputs: Sequence[IRNode],
make_kernel_render: Callable[_P, _T],
workspace_size: int,
template: CUDATemplate,
template: CUTLASSTemplate,
supports_epilogue_fusion: bool,
) -> None:
super().__init__(layout, inputs, make_kernel_render)
@ -5210,7 +5210,7 @@ class CppTemplateBuffer(TemplateBuffer):
layout: Layout,
inputs: Sequence[IRNode],
make_kernel_render: Callable[_P, _T],
template: CUDATemplate,
template: CUTLASSTemplate,
choice: Any,
) -> None:
super().__init__(layout, inputs, make_kernel_render)
@ -261,7 +261,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
and use_cutlass_template(layout, m, n, k)
and _use_cutlass_for_op(name)
):
from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
from ..codegen.cutlass.gemm_template import CUTLASS3xGemmTemplate

CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices, layout, kernel_inputs.nodes()
@ -20,7 +20,7 @@ from torch.nn.functional import ScalingType # type: ignore[attr-defined]
from torch.torch_version import TorchVersion

from .. import config as inductor_config
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.cutlass.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from ..codegen.subgraph import SubgraphChoiceCaller, SubgraphTemplate
@ -5510,10 +5510,10 @@ class Scheduler:
node = typing.cast(ForeachKernelSchedulerNode, node)
# pyrefly: ignore [unbound-name]
backend_ = self.get_backend(device)
from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
from .codegen.combined_scheduling import CombinedScheduling
from .codegen.simd import SIMDScheduling

if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
if isinstance(backend_, (SIMDScheduling, CombinedScheduling)):
backend = backend_
else:
raise AssertionError(f"{type(self)=}")
@ -2665,7 +2665,7 @@ class AlgorithmSelectorCache(PersistentCache):
return_multi_template=False,
best_config_future=None,
):
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
from .codegen.cutlass.kernel import CUTLASSTemplateCaller

# Run preprocessing functions on choices
for preprocessing_fn in self.preprocessing_fns:
@ -2701,8 +2701,8 @@ class AlgorithmSelectorCache(PersistentCache):
log.debug("Max autotune selects from %s choices.", str(len(choices)))

if len(choices) == 1:
if not isinstance(choices[0], CUDATemplateCaller):
# CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
if not isinstance(choices[0], CUTLASSTemplateCaller):
# CUTLASSTemplateCaller still needs to go through autotuning process to retrieve workspace size.
return choices[0].output_node()

if config.deterministic:
@ -3086,12 +3086,12 @@ class AlgorithmSelectorCache(PersistentCache):
"select_algorithm_num_precompilation_exceptions"
] += 1
exceptions.append((futures[future], e))
from torch._inductor.codegen.cuda.cuda_kernel import (
CUDATemplateCaller,
from torch._inductor.codegen.cutlass.kernel import (
CUTLASSTemplateCaller,
)

if isinstance(e, CUDACompileError) and isinstance(
futures[future], CUDATemplateCaller
futures[future], CUTLASSTemplateCaller
):
log.debug(
"Exception %s for benchmark choice %s",
@ -3272,19 +3272,20 @@ class AlgorithmSelectorCache(PersistentCache):
for choice in choices:
try:
timing = cls.benchmark_choice(choice, autotune_args)
except CUDACompileError:
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
except CUDACompileError as e:
from torch._inductor.codegen.cutlass.kernel import CUTLASSTemplateCaller

if not isinstance(choice, CUDATemplateCaller):
if not isinstance(choice, CUTLASSTemplateCaller):
log.exception(
"CUDA compilation error during autotuning: \n%s. \nIgnoring this choice."
"CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.",
e,
)
timing = float("inf")
except NotImplementedError:
log.warning("Not yet implemented", exc_info=True)
timing = float("inf")
except RuntimeError as e:
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
from torch._inductor.codegen.cutlass.kernel import CUTLASSTemplateCaller

msg = str(e)
if "invalid argument" in msg:
@ -3295,7 +3296,7 @@ class AlgorithmSelectorCache(PersistentCache):
msg += "\n\nAn unrecoverable unspecified launch failure was caught during autotuning."
msg += "\nPlease try re-running with TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1.\n\n"

if isinstance(choice, CUDATemplateCaller):
if isinstance(choice, CUTLASSTemplateCaller):
log.debug(
"Runtime error during autotuning: \n%s. \nIgnoring this choice.",
msg,
@ -3429,18 +3430,18 @@ class AlgorithmSelectorCache(PersistentCache):
return prescreen_winners

# prescreen cutlass
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
from .codegen.cutlass.kernel import CUTLASSTemplateCaller

candidates = []
if (
config.cuda.cutlass_prescreening
and len(config.cuda.cutlass_max_profiling_swizzle_options) > 1
config.cutlass.cutlass_prescreening
and len(config.cutlass.cutlass_max_profiling_swizzle_options) > 1
):
candidates.extend(
[
c
for c in choices
if isinstance(c, CUDATemplateCaller)
if isinstance(c, CUTLASSTemplateCaller)
# hardcoded to only look at swizzle=2
if c.info_dict().get("swizzle") == "2"
]
@ -3463,7 +3464,7 @@ class AlgorithmSelectorCache(PersistentCache):
"""
Prune the choices after prescreening.
"""
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
from .codegen.cutlass.kernel import CUTLASSTemplateCaller

prescreen_key = f"{name}:{inputs_key}"

@ -3477,7 +3478,7 @@ class AlgorithmSelectorCache(PersistentCache):
pruned_choices = [
choice
for choice in choices
if not isinstance(choice, CUDATemplateCaller)
if not isinstance(choice, CUTLASSTemplateCaller)
or choice.kernel_hash_key() in winner_kernel_hashes
]
return pruned_choices
@ -3523,7 +3524,7 @@ class AlgorithmSelectorCache(PersistentCache):
candidates_to_prune.add(candidate.kernel_hash_key())
else:
winner_hashes.add(candidate.hash_key())
if isinstance(candidate, CUDATemplateCaller):
if isinstance(candidate, CUTLASSTemplateCaller):
candidate.bmreq.ensure_dll_loaded()

pruned_choices = [
|
||||
from .virtualized import V
|
||||
|
||||
gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
|
||||
if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size:
|
||||
if gemm_size <= 0 or gemm_size < config.cutlass.cutlass_backend_min_gemm_size:
|
||||
return False
|
||||
from .codegen.cuda.cutlass_utils import try_import_cutlass
|
||||
from .codegen.cutlass.utils import try_import_cutlass
|
||||
|
||||
# Do not use cutlass template on ROCm
|
||||
if torch.version.hip:
|
||||
@ -1893,9 +1893,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
|
||||
if not try_import_cutlass():
|
||||
log.warning(
|
||||
"Failed to import CUTLASS lib. Please check whether "
|
||||
"_inductor.config.cuda.cutlass_dir %s is set correctly. "
|
||||
"_inductor.config.cutlass.cutlass_dir %s is set correctly. "
|
||||
"Skipping CUTLASS backend for now.",
|
||||
config.cuda.cutlass_dir,
|
||||
config.cutlass.cutlass_dir,
|
||||
)
|
||||
return False
|
||||
return res
|
||||
@ -1903,7 +1903,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
|
||||
|
||||
def _use_cutlass_for_op(op_name: str) -> bool:
|
||||
"""Check if CUTLASS should be used for the given operation."""
|
||||
enabled_ops = config.cuda.cutlass_enabled_ops.upper()
|
||||
enabled_ops = config.cutlass.cutlass_enabled_ops.upper()
|
||||
if enabled_ops == "ALL":
|
||||
return True
|
||||
return op_name.upper() in [x.strip() for x in enabled_ops.split(",")]
|
||||
|
||||
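_use_cutlass_for_op() now reads config.cutlass.cutlass_enabled_ops: matching is case-insensitive over a comma-separated list, and "ALL" enables every op. A small usage sketch:

# Sketch only: behavior of the op filter above.
from torch._inductor import config
from torch._inductor.utils import _use_cutlass_for_op

with config.patch({"cutlass.cutlass_enabled_ops": "mm, addmm"}):
    assert _use_cutlass_for_op("MM")        # case-insensitive match
    assert not _use_cutlass_for_op("bmm")   # not in the list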