mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-14 22:25:03 +08:00
Update
[ghstack-poisoned]
This commit is contained in:
@ -337,7 +337,7 @@ test_python() {
|
||||
|
||||
test_python_smoke() {
|
||||
# Smoke tests for H100/B200
|
||||
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -127,6 +127,7 @@ torch/test/
|
||||
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
|
||||
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
|
||||
torch/version.py
|
||||
torch/_inductor/kernel/vendored_templates/*
|
||||
minifier_launcher.py
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
|
||||
|
||||
33
setup.py
33
setup.py
@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None:
|
||||
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
|
||||
|
||||
|
||||
def mirror_inductor_external_kernels() -> None:
|
||||
"""
|
||||
Copy external kernels into Inductor so they are importable.
|
||||
"""
|
||||
paths = [
|
||||
(
|
||||
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
|
||||
CWD
|
||||
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
|
||||
),
|
||||
]
|
||||
for new_path, orig_path in paths:
|
||||
# Create the dirs involved in new_path if they don't exist
|
||||
if not new_path.exists():
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Copy the files from the orig location to the new location
|
||||
if orig_path.is_file():
|
||||
shutil.copyfile(orig_path, new_path)
|
||||
continue
|
||||
if orig_path.is_dir():
|
||||
if new_path.exists():
|
||||
# copytree fails if the tree exists already, so remove it.
|
||||
shutil.rmtree(new_path)
|
||||
shutil.copytree(orig_path, new_path)
|
||||
continue
|
||||
raise RuntimeError(
|
||||
"Check the file paths in `mirror_inductor_external_kernels()`"
|
||||
)
|
||||
|
||||
|
||||
# ATTENTION: THIS IS AI SLOP
|
||||
def extract_variant_from_version(version: str) -> str:
|
||||
"""Extract variant from version string, defaulting to 'cpu'."""
|
||||
@ -1615,6 +1646,7 @@ def main() -> None:
|
||||
mirror_files_into_torchgen()
|
||||
if RUN_BUILD_DEPS:
|
||||
build_deps()
|
||||
mirror_inductor_external_kernels()
|
||||
|
||||
(
|
||||
ext_modules,
|
||||
@ -1649,6 +1681,7 @@ def main() -> None:
|
||||
"_inductor/codegen/aoti_runtime/*.cpp",
|
||||
"_inductor/script.ld",
|
||||
"_inductor/kernel/flex/templates/*.jinja",
|
||||
"_inductor/kernel/templates/*.jinja",
|
||||
"_export/serde/*.yaml",
|
||||
"_export/serde/*.thrift",
|
||||
"share/cmake/ATen/*.cmake",
|
||||
|
||||
154
test/inductor/test_cutedsl_grouped_mm.py
Normal file
154
test/inductor/test_cutedsl_grouped_mm.py
Normal file
@ -0,0 +1,154 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
|
||||
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
|
||||
from torch._inductor.utils import ensure_cute_available
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
|
||||
"CuTeDSL library or Blackwell device not available",
|
||||
)
|
||||
@instantiate_parametrized_tests
|
||||
class TestCuTeDSLGroupedGemm(InductorTestCase):
|
||||
def _get_inputs(
|
||||
self,
|
||||
group_size: int,
|
||||
M_hint: int,
|
||||
K: int,
|
||||
N: int,
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
alignment: int = 16,
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# --- Random, tile-aligned M sizes ---
|
||||
M_sizes = (
|
||||
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
|
||||
* alignment
|
||||
)
|
||||
|
||||
M_total = torch.sum(M_sizes).item()
|
||||
|
||||
# --- Construct input tensors ---
|
||||
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
|
||||
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
|
||||
|
||||
# --- Build offsets (no leading zero, strictly increasing) ---
|
||||
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
|
||||
|
||||
return (A, B, offsets)
|
||||
|
||||
@parametrize("group_size", (2, 8))
|
||||
@parametrize("M_hint", (256, 1024))
|
||||
@parametrize("K", (64, 128))
|
||||
@parametrize("N", (128, 256))
|
||||
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
|
||||
|
||||
def grouped_gemm_fn(A_packed, B_batched, offs):
|
||||
return torch._grouped_mm(A_packed, B_batched, offs=offs)
|
||||
|
||||
# Eager execution
|
||||
c_eager = grouped_gemm_fn(A, B, offsets)
|
||||
|
||||
# Test with Cute backend
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTEDSL",
|
||||
"test_configs.autotune_choice_name_regex": "cutedsl",
|
||||
"autotune_fallback_to_aten": False,
|
||||
}
|
||||
):
|
||||
grouped_gemm_compiled = torch.compile(
|
||||
grouped_gemm_fn, backend="inductor", dynamic=False
|
||||
)
|
||||
c_compiled = grouped_gemm_compiled(A, B, offsets)
|
||||
|
||||
self.assertEqual(c_eager.dtype, dtype)
|
||||
self.assertEqual(c_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(c_eager, c_compiled)
|
||||
|
||||
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
|
||||
@parametrize("layout_B", ("contiguous", "broadcasted"))
|
||||
def test_grouped_gemm_assorted_layouts(
|
||||
self,
|
||||
layout_A: str,
|
||||
layout_B: str,
|
||||
):
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
G, K, N = 8, 64, 128
|
||||
M_sizes = [128] * G
|
||||
sum_M = sum(M_sizes)
|
||||
offsets = torch.tensor(
|
||||
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
|
||||
A = A_base
|
||||
|
||||
if layout_A == "offset":
|
||||
# allocate bigger buffer than needed, use nonzero storage offset
|
||||
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
|
||||
offset = 128 # skip first 128 elements
|
||||
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
|
||||
elif layout_A == "padded":
|
||||
# simulate row pitch > K (row_stride = K + pad)
|
||||
row_pitch = K + 8
|
||||
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
|
||||
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
|
||||
elif layout_A == "view":
|
||||
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
|
||||
A = A_storage.view(sum_M, K)
|
||||
assert A._base is not None
|
||||
assert A.shape == (sum_M, K)
|
||||
|
||||
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
|
||||
|
||||
if layout_B == "broadcasted":
|
||||
# Broadcast B across groups (zero stride along G)
|
||||
B = B[0].expand(G, K, N)
|
||||
assert B.stride(0) == 0
|
||||
|
||||
def grouped_gemm_fn(A_packed, B_batched, offs):
|
||||
return torch._grouped_mm(A_packed, B_batched, offs=offs)
|
||||
|
||||
# --- eager ---
|
||||
c_eager = grouped_gemm_fn(A, B, offsets)
|
||||
|
||||
# --- compiled (CUTE backend) ---
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTEDSL",
|
||||
"test_configs.autotune_choice_name_regex": "cutedsl",
|
||||
"autotune_fallback_to_aten": False,
|
||||
}
|
||||
):
|
||||
grouped_gemm_compiled = torch.compile(
|
||||
grouped_gemm_fn, backend="inductor", dynamic=False
|
||||
)
|
||||
c_compiled = grouped_gemm_compiled(A, B, offsets)
|
||||
|
||||
self.assertEqual(c_eager.dtype, dtype)
|
||||
self.assertEqual(c_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(c_eager, c_compiled)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -9,7 +9,6 @@ import sys
|
||||
import time
|
||||
import unittest
|
||||
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import auto
|
||||
from typing import Any, NamedTuple, Optional
|
||||
@ -1531,23 +1530,21 @@ class TestCxxPytree(TestCase):
|
||||
if IS_FBCODE:
|
||||
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
|
||||
|
||||
def assertEqualSpecs(
|
||||
self,
|
||||
spec1,
|
||||
spec2,
|
||||
msg: str | Callable[[str], str] | None = None,
|
||||
):
|
||||
if TEST_WITH_TORCHDYNAMO:
|
||||
# The Dynamo polyfill returns a pure Python class for PyTreeSpec.
|
||||
# So we compare the type names and reprs instead because the types
|
||||
# themselves won't be equal.
|
||||
self.assertEqual(type(spec1).__name__, type(spec2).__name__, msg=msg)
|
||||
self.assertEqual(repr(spec1), repr(spec2), msg=msg)
|
||||
else:
|
||||
self.assertEqual(spec1, spec2, msg=msg)
|
||||
def assertEqual(self, x, y, *args, **kwargs):
|
||||
x_typename, y_typename = type(x).__name__, type(y).__name__
|
||||
if not ("treespec" in x_typename.lower() or "treespec" in y_typename.lower()):
|
||||
super().assertEqual(x, y, *args, **kwargs)
|
||||
|
||||
# The Dynamo polyfill returns a polyfilled Python class for C++ PyTreeSpec instead of the
|
||||
# C++ class. So we compare the type names and reprs instead because the types themselves
|
||||
# won't be equal.
|
||||
super().assertEqual(x_typename, y_typename, *args, **kwargs)
|
||||
super().assertEqual(repr(x), repr(y), *args, **kwargs)
|
||||
if not TEST_WITH_TORCHDYNAMO or type(x) is type(y):
|
||||
super().assertEqual(x, y, *args, **kwargs)
|
||||
|
||||
def test_treespec_equality(self):
|
||||
self.assertEqualSpecs(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
|
||||
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
|
||||
|
||||
def test_treespec_repr(self):
|
||||
# Check that it looks sane
|
||||
@ -1577,11 +1574,18 @@ class TestCxxPytree(TestCase):
|
||||
],
|
||||
)
|
||||
def test_pytree_serialize(self, spec):
|
||||
self.assertEqual(
|
||||
spec,
|
||||
cxx_pytree.tree_structure(
|
||||
cxx_pytree.tree_unflatten([0] * spec.num_leaves, spec)
|
||||
),
|
||||
)
|
||||
|
||||
serialized_spec = cxx_pytree.treespec_dumps(spec)
|
||||
self.assertIsInstance(serialized_spec, str)
|
||||
|
||||
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
|
||||
self.assertEqualSpecs(roundtrip_spec, spec)
|
||||
self.assertEqual(roundtrip_spec, spec)
|
||||
|
||||
def test_pytree_serialize_namedtuple(self):
|
||||
python_pytree._register_namedtuple(
|
||||
@ -1607,7 +1611,7 @@ class TestCxxPytree(TestCase):
|
||||
spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1))
|
||||
serialized_spec = cxx_pytree.treespec_dumps(spec)
|
||||
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
|
||||
self.assertEqualSpecs(roundtrip_spec, spec)
|
||||
self.assertEqual(roundtrip_spec, spec)
|
||||
|
||||
class LocalDummyType:
|
||||
def __init__(self, x, y):
|
||||
@ -1623,7 +1627,7 @@ class TestCxxPytree(TestCase):
|
||||
spec = cxx_pytree.tree_structure(LocalDummyType(0, 1))
|
||||
serialized_spec = cxx_pytree.treespec_dumps(spec)
|
||||
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
|
||||
self.assertEqualSpecs(roundtrip_spec, spec)
|
||||
self.assertEqual(roundtrip_spec, spec)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestGenericPytree)
|
||||
|
||||
@ -550,6 +550,10 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
|
||||
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
|
||||
).upper() # type: ignore[assignment]
|
||||
|
||||
cutedsl_enable_autotuning: bool = (
|
||||
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
|
||||
)
|
||||
|
||||
# DEPRECATED. This setting is ignored.
|
||||
autotune_fallback_to_aten = False
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@ -12,6 +14,7 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
|
||||
from .. import config
|
||||
from ..codegen.wrapper import PythonWrapperCodegen
|
||||
from ..ir import _IntLike, Layout, TensorBox
|
||||
from ..utils import load_template
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
|
||||
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
|
||||
from torch._inductor.runtime.triton_compat import tl
|
||||
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
@ -22,11 +24,13 @@ from ..utils import (
|
||||
get_num_sms,
|
||||
has_free_symbols,
|
||||
use_aten_gemm_kernels,
|
||||
use_blackwell_cutedsl_grouped_mm,
|
||||
use_triton_template,
|
||||
)
|
||||
from .mm_common import (
|
||||
_is_static_problem,
|
||||
check_supported_striding,
|
||||
load_kernel_template,
|
||||
persistent_grouped_mm_grid,
|
||||
)
|
||||
|
||||
@ -513,6 +517,11 @@ triton_scaled_grouped_mm_template = TritonTemplate(
|
||||
source=triton_grouped_mm_source,
|
||||
)
|
||||
|
||||
cutedsl_grouped_mm_template = CuteDSLTemplate(
|
||||
name="grouped_gemm_cutedsl",
|
||||
source=load_kernel_template("cutedsl_mm_grouped"),
|
||||
)
|
||||
|
||||
|
||||
def grouped_mm_args(
|
||||
mat1: TensorBox,
|
||||
@ -714,43 +723,44 @@ def _tuned_grouped_mm_common(
|
||||
# Checking only for the equality of corresponding dims of
|
||||
# multiplicands here, relying on meta function checks for
|
||||
# everything else.
|
||||
if len(m1_size) == 2:
|
||||
if len(m2_size) == 2:
|
||||
m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g = offs.get_size()[0]
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, True
|
||||
else:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, False
|
||||
else:
|
||||
if len(m2_size) == 2:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
g2, m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, True
|
||||
else:
|
||||
g1, m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, False
|
||||
|
||||
if (
|
||||
is_nonzero
|
||||
and use_triton_template(layout)
|
||||
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
|
||||
):
|
||||
scaled = scale_a is not None
|
||||
if len(m1_size) == 2:
|
||||
if len(m2_size) == 2:
|
||||
m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g = offs.get_size()[0]
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, True
|
||||
else:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, False
|
||||
else:
|
||||
if len(m2_size) == 2:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
g2, m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, True
|
||||
else:
|
||||
g1, m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, False
|
||||
|
||||
a_is_k_major = mat_a.get_stride()[-1] == 1
|
||||
b_is_k_major = mat_b.get_stride()[-2] == 1
|
||||
@ -788,6 +798,22 @@ def _tuned_grouped_mm_common(
|
||||
**config.kwargs,
|
||||
)
|
||||
|
||||
if use_blackwell_cutedsl_grouped_mm(
|
||||
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
|
||||
):
|
||||
for config in get_groupgemm_configs():
|
||||
kwargs = dict(
|
||||
ACC_DTYPE="cutlass.Float32",
|
||||
)
|
||||
|
||||
cutedsl_grouped_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
**asdict(config),
|
||||
)
|
||||
|
||||
input_gen_fns = {
|
||||
4: lambda x: create_offsets(
|
||||
x, m1_size, m2_size, offs.get_size() if offs is not None else None
|
||||
|
||||
333
torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja
Normal file
333
torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja
Normal file
@ -0,0 +1,333 @@
|
||||
import functools
|
||||
from torch._inductor.runtime.runtime_utils import ceildiv
|
||||
from cutlass.utils import TensorMapUpdateMode
|
||||
{{gen_defines()}}
|
||||
# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
|
||||
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
|
||||
GroupedGemmKernel,
|
||||
)
|
||||
|
||||
|
||||
# Note about caching:
|
||||
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
|
||||
# maintains its own local caching system. At this stage, all compile-time
|
||||
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
|
||||
# name itself ({{kernel_name}}) are permanently baked into the file, so they
|
||||
# do not need to be included in any cache key.
|
||||
#
|
||||
# The caching mechanism is split into two levels:
|
||||
#
|
||||
# 1. prep_cache
|
||||
# Caches the compiled executor for build_group_ptrs_from_bases(). This
|
||||
# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
|
||||
# and can therefore be safely reused across runs with different group
|
||||
# partitioning (`offs`).
|
||||
#
|
||||
# 2. gemm_cache
|
||||
# Caches the compiled Grouped GEMM executor. Its key extends the prep
|
||||
# cache key with hardware- and grid-specific parameters:
|
||||
# (prep_cache_key, max_active_clusters, total_num_clusters).
|
||||
# This is necessary because different `offs` tensors can change the
|
||||
# per-group problem sizes and thus alter `total_num_clusters`, which in
|
||||
# turn changes the grid shape and persistent scheduler configuration.
|
||||
# Kernels compiled for one grid cannot be safely reused for another.
|
||||
#
|
||||
#
|
||||
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
|
||||
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
|
||||
# despite depending only on the GPU type. We cache this function to mitigate
|
||||
# redundant recompiles even when shape/stride/dtype cache misses force kernel
|
||||
# regeneration. A follow-up study will investigate the root cause.
|
||||
|
||||
prep_cache = {}
|
||||
gemm_cache = {}
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_hardware_info():
|
||||
hw = cutlass.utils.HardwareInfo()
|
||||
sm_count = hw.get_max_active_clusters(1)
|
||||
max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)
|
||||
|
||||
return (sm_count, max_active_clusters)
|
||||
|
||||
|
||||
def get_prep_cache_key(input_a, input_b, output):
|
||||
"""
|
||||
Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
|
||||
shapes, strides, and dtypes of input/output tensors.
|
||||
"""
|
||||
return (
|
||||
tuple(input_a.shape),
|
||||
tuple(input_a.stride()),
|
||||
input_a.dtype,
|
||||
tuple(input_b.shape),
|
||||
tuple(input_b.stride()),
|
||||
input_b.dtype,
|
||||
tuple(output.shape),
|
||||
tuple(output.stride()),
|
||||
output.dtype,
|
||||
)
|
||||
|
||||
|
||||
def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
|
||||
"""
|
||||
Returns a tuple key for caching the gemm kernel executor by extending the
|
||||
prep cache key with hardware- and grid-specific parameters.
|
||||
"""
|
||||
return (
|
||||
prep_cache_key,
|
||||
max_active_clusters,
|
||||
total_num_clusters,
|
||||
)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def build_group_ptrs_from_bases_kernel(
|
||||
base_A_u64: cutlass.Int64, # device addr of input_a (bytes)
|
||||
base_B_u64: cutlass.Int64, # device addr of input_b (bytes)
|
||||
base_C_u64: cutlass.Int64, # device addr of Output (bytes)
|
||||
offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative
|
||||
K: cutlass.Constexpr,
|
||||
N: cutlass.Constexpr,
|
||||
sizeof_element: cutlass.Int32, # bytes
|
||||
# -------- STRIDES (in ELEMENTS) --------
|
||||
stride_A_m_elems: cutlass.Constexpr, # A.stride(0)
|
||||
stride_A_k_elems: cutlass.Constexpr, # A.stride(1)
|
||||
stride_B0_elems: cutlass.Constexpr, # B.stride(0)
|
||||
stride_Bk_elems: cutlass.Constexpr, # B.stride(1)
|
||||
stride_Bn_elems: cutlass.Constexpr, # B.stride(2)
|
||||
stride_C_m_elems: cutlass.Constexpr, # C.stride(0)
|
||||
stride_C_n_elems: cutlass.Constexpr, # C.stride(1)
|
||||
# -------- OUTPUTS --------
|
||||
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
|
||||
out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1)
|
||||
out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
g = tidx
|
||||
|
||||
m_beg_i32 = 0
|
||||
if g > 0:
|
||||
m_beg_i32 = offs[g - 1]
|
||||
m_end_i32 = offs[g]
|
||||
m_g_i32 = m_end_i32 - m_beg_i32
|
||||
|
||||
a_byte_off = (
|
||||
cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
|
||||
)
|
||||
c_byte_off = (
|
||||
cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
|
||||
)
|
||||
b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)
|
||||
|
||||
# ---- pointers ----
|
||||
out_ptrs[g, 0] = base_A_u64 + a_byte_off
|
||||
out_ptrs[g, 1] = base_B_u64 + b_byte_off
|
||||
out_ptrs[g, 2] = base_C_u64 + c_byte_off
|
||||
|
||||
# ---- (m, n, k, 1) ----
|
||||
out_problem[g, 0] = m_g_i32
|
||||
out_problem[g, 1] = N
|
||||
out_problem[g, 2] = K
|
||||
out_problem[g, 3] = cutlass.Int32(1)
|
||||
|
||||
# ---- strides ----
|
||||
out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
|
||||
out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
|
||||
out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
|
||||
out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
|
||||
out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
|
||||
out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def launch_build_group_ptrs_from_bases(
|
||||
base_A_u64: cutlass.Int64,
|
||||
base_B_u64: cutlass.Int64,
|
||||
base_C_u64: cutlass.Int64,
|
||||
offs: cute.Tensor,
|
||||
G: cutlass.Constexpr,
|
||||
K: cutlass.Constexpr,
|
||||
N: cutlass.Constexpr,
|
||||
sizeof_element: cutlass.Constexpr,
|
||||
stride_A_m_elems: cutlass.Constexpr,
|
||||
stride_A_k_elems: cutlass.Constexpr,
|
||||
stride_B0_elems: cutlass.Constexpr,
|
||||
stride_Bk_elems: cutlass.Constexpr,
|
||||
stride_Bn_elems: cutlass.Constexpr,
|
||||
stride_C_m_elems: cutlass.Constexpr,
|
||||
stride_C_n_elems: cutlass.Constexpr,
|
||||
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64
|
||||
out_problem: cute.Tensor, # [G,4] cutlass.Int32
|
||||
out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32
|
||||
stream: cuda.CUstream,
|
||||
):
|
||||
build_group_ptrs_from_bases_kernel(
|
||||
base_A_u64,
|
||||
base_B_u64,
|
||||
base_C_u64,
|
||||
offs,
|
||||
K,
|
||||
N,
|
||||
sizeof_element,
|
||||
stride_A_m_elems,
|
||||
stride_A_k_elems,
|
||||
stride_B0_elems,
|
||||
stride_Bk_elems,
|
||||
stride_Bn_elems,
|
||||
stride_C_m_elems,
|
||||
stride_C_n_elems,
|
||||
out_ptrs,
|
||||
out_problem,
|
||||
out_strides_abc,
|
||||
).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)
|
||||
|
||||
|
||||
{{def_kernel("input_a", "input_b", "input_a_offs")}}
|
||||
stream = cuda.CUstream(stream)
|
||||
|
||||
input_b = input_b.transpose(1, 2)
|
||||
|
||||
sumM, K = input_a.shape
|
||||
G, N, Kb = input_b.shape
|
||||
|
||||
dev = input_a.device
|
||||
|
||||
base_A_u64 = int(input_a.data_ptr())
|
||||
base_B_u64 = int(input_b.data_ptr())
|
||||
base_C_u64 = int({{get_output()}}.data_ptr())
|
||||
|
||||
ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
|
||||
probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
|
||||
strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
|
||||
ptrs = from_dlpack(ptrs_t)
|
||||
probs = from_dlpack(probs_t)
|
||||
strides = from_dlpack(strides_t)
|
||||
|
||||
prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
|
||||
prep_executor = prep_cache.get(prep_cache_key)
|
||||
|
||||
if prep_executor is None:
|
||||
sizeof_element = int(input_a.element_size())
|
||||
sA_m, sA_k = map(int, input_a.stride())
|
||||
sB_0, sB_n, sB_k = map(int, input_b.stride())
|
||||
sC_m, sC_n = map(int, {{get_output()}}.stride())
|
||||
|
||||
prep_executor = cute.compile(
|
||||
launch_build_group_ptrs_from_bases,
|
||||
base_A_u64=base_A_u64,
|
||||
base_B_u64=base_B_u64,
|
||||
base_C_u64=base_C_u64,
|
||||
offs=from_dlpack(input_a_offs),
|
||||
G=int(G),
|
||||
K=int(K),
|
||||
N=int(N),
|
||||
sizeof_element=sizeof_element,
|
||||
stride_A_m_elems=sA_m,
|
||||
stride_A_k_elems=sA_k,
|
||||
stride_B0_elems=sB_0,
|
||||
stride_Bk_elems=sB_k,
|
||||
stride_Bn_elems=sB_n,
|
||||
stride_C_m_elems=sC_m,
|
||||
stride_C_n_elems=sC_n,
|
||||
out_ptrs=ptrs,
|
||||
out_problem=probs,
|
||||
out_strides_abc=strides,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
prep_cache[prep_cache_key] = prep_executor
|
||||
|
||||
prep_executor(
|
||||
base_A_u64=base_A_u64,
|
||||
base_B_u64=base_B_u64,
|
||||
base_C_u64=base_C_u64,
|
||||
offs=from_dlpack(input_a_offs),
|
||||
out_ptrs=ptrs,
|
||||
out_problem=probs,
|
||||
out_strides_abc=strides,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
# --- Tensormap workspace per SM ---
|
||||
num_tensormap_buffers, max_active_clusters = get_hardware_info()
|
||||
tensormap_shape = (
|
||||
num_tensormap_buffers,
|
||||
GroupedGemmKernel.num_tensormaps,
|
||||
GroupedGemmKernel.bytes_per_tensormap // 8,
|
||||
)
|
||||
tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
|
||||
tensormap_workspace = from_dlpack(tensormap_workspace_t)
|
||||
|
||||
# --- Total clusters ---
|
||||
def compute_total_num_clusters(
|
||||
problem_sizes_mnkl,
|
||||
cluster_tile_shape_mn,
|
||||
):
|
||||
total_num_clusters = 0
|
||||
for m, n, _, _ in problem_sizes_mnkl:
|
||||
num_clusters_mn = tuple(
|
||||
ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
|
||||
)
|
||||
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
|
||||
return total_num_clusters
|
||||
|
||||
# Compute cluster tile shape
|
||||
def compute_cluster_tile_shape(
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
use_2cta_instrs,
|
||||
):
|
||||
cta_tile_shape_mn = list(mma_tiler_mn)
|
||||
if use_2cta_instrs:
|
||||
cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
|
||||
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
|
||||
|
||||
cluster_tile_shape_mn = compute_cluster_tile_shape(
|
||||
(TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
|
||||
)
|
||||
|
||||
total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))
|
||||
|
||||
gemm_cache_key = get_gemm_cache_key(
|
||||
prep_cache_key, max_active_clusters, total_num_clusters
|
||||
)
|
||||
gemm_executor = gemm_cache.get(gemm_cache_key)
|
||||
|
||||
if gemm_executor is None:
|
||||
grouped_gemm = GroupedGemmKernel(
|
||||
acc_dtype=ACC_DTYPE,
|
||||
use_2cta_instrs=USE_2_CTA,
|
||||
mma_tiler_mn=(TILE_M, TILE_N),
|
||||
cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
|
||||
tensormap_update_mode=TENSORMAP_UPDATE_MODE,
|
||||
)
|
||||
|
||||
gemm_executor = cute.compile(
|
||||
grouped_gemm,
|
||||
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
|
||||
G,
|
||||
probs,
|
||||
strides,
|
||||
ptrs,
|
||||
total_num_clusters,
|
||||
tensormap_workspace,
|
||||
max_active_clusters,
|
||||
stream,
|
||||
)
|
||||
|
||||
gemm_cache[gemm_cache_key] = gemm_executor
|
||||
|
||||
gemm_executor(
|
||||
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
|
||||
probs,
|
||||
strides,
|
||||
ptrs,
|
||||
tensormap_workspace,
|
||||
stream,
|
||||
)
|
||||
141
torch/_inductor/template_heuristics/cutedsl.py
Normal file
141
torch/_inductor/template_heuristics/cutedsl.py
Normal file
@ -0,0 +1,141 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from itertools import product
|
||||
|
||||
import torch._inductor.config as config
|
||||
|
||||
|
||||
class TensorMapUpdateMode(Enum):
|
||||
"""Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""
|
||||
|
||||
SMEM = auto()
|
||||
GMEM = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CuTeGemmConfig:
|
||||
TILE_M: int = 128
|
||||
TILE_N: int = 192
|
||||
CLUSTER_M: int = 2
|
||||
CLUSTER_N: int = 1
|
||||
USE_2_CTA: bool = False
|
||||
TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM
|
||||
|
||||
|
||||
def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
For information regarding valid config sets, see:
|
||||
https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
|
||||
"""
|
||||
|
||||
# Tile_n is always the same regardless of 2cta
|
||||
tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
# Valid clusters
|
||||
clusters_no_2cta = [
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(1, 4),
|
||||
(1, 8),
|
||||
(1, 16),
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(2, 8),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
(8, 1),
|
||||
(8, 2),
|
||||
(16, 1),
|
||||
]
|
||||
clusters_2cta = [
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(2, 8),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
(8, 1),
|
||||
(8, 2),
|
||||
(16, 1),
|
||||
]
|
||||
|
||||
configs: list[CuTeGemmConfig] = []
|
||||
|
||||
for use_2cta, cluster_set, tile_m_range in [
|
||||
(False, clusters_no_2cta, [64, 128]),
|
||||
(True, clusters_2cta, [128, 256]),
|
||||
]:
|
||||
for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
|
||||
[TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
|
||||
tile_m_range,
|
||||
tile_n_vals,
|
||||
cluster_set,
|
||||
):
|
||||
configs.append(
|
||||
CuTeGemmConfig(
|
||||
tile_m,
|
||||
tile_n,
|
||||
cluster_m,
|
||||
cluster_n,
|
||||
USE_2_CTA=use_2cta,
|
||||
TENSORMAP_UPDATE_MODE=tensormap_update_mode,
|
||||
)
|
||||
)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
"""
|
||||
|
||||
config_tuples = [
|
||||
(128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
|
||||
(128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
|
||||
(128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
|
||||
(128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
]
|
||||
|
||||
return [CuTeGemmConfig(*args) for args in config_tuples]
|
||||
|
||||
|
||||
def get_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
|
||||
Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
|
||||
or unstable results. By default, autotuning is disabled and we return only
|
||||
a single baseline config.
|
||||
"""
|
||||
if (
|
||||
config.cutedsl_enable_autotuning
|
||||
and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
|
||||
):
|
||||
return get_exhaustive_groupgemm_configs()
|
||||
elif config.cutedsl_enable_autotuning:
|
||||
return get_default_groupgemm_configs()
|
||||
else:
|
||||
return [get_default_groupgemm_configs()[0]]
|
||||
@ -1911,6 +1911,84 @@ def use_triton_blackwell_tma_template(
|
||||
return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def ensure_cute_available() -> bool:
|
||||
"""Check if CuTeDSL is importable; cache the result for reuse.
|
||||
|
||||
Call ensure_cute_available.cache_clear() after installing CuTeDSL
|
||||
in the same interpreter to retry the import.
|
||||
"""
|
||||
try:
|
||||
return importlib.util.find_spec("cutlass.cute") is not None
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def use_blackwell_cutedsl_grouped_mm(
|
||||
mat_a: Any,
|
||||
mat_b: Any,
|
||||
layout: Layout,
|
||||
a_is_2d: bool,
|
||||
b_is_2d: bool,
|
||||
offs: Optional[Any],
|
||||
bias: Optional[Any],
|
||||
scale_result: Optional[Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if we can use the blackwell kernel for grouped mm.
|
||||
Required conditions:
|
||||
1. CuTeDSL backend is enabled
|
||||
2. CuTeDSL is available
|
||||
3. We are on a blackwell arch
|
||||
4. The dtype is bf16
|
||||
5. Max autotune or max autotune gemm is enabled
|
||||
6. A, B, and the output are 16B aligned
|
||||
7. We are not using dynamic shapes
|
||||
8. A is 2d
|
||||
9. B is 3d
|
||||
10. Offsets are provided
|
||||
11. Bias and Scale are not provided
|
||||
"""
|
||||
if not ensure_cute_available():
|
||||
return False
|
||||
|
||||
if not _use_autotune_backend("CUTEDSL"):
|
||||
return False
|
||||
|
||||
from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
|
||||
|
||||
if not is_gpu(layout.device.type):
|
||||
return False
|
||||
|
||||
if not is_datacenter_blackwell_arch():
|
||||
return False
|
||||
|
||||
layout_dtypes = [torch.bfloat16]
|
||||
if not _use_template_for_gpu(layout, layout_dtypes):
|
||||
return False
|
||||
|
||||
if not (config.max_autotune or config.max_autotune_gemm):
|
||||
return False
|
||||
|
||||
# Checks for 16B ptr and stride alignment
|
||||
if not can_use_tma(mat_a, mat_b, output_layout=layout):
|
||||
return False
|
||||
|
||||
if any(is_dynamic(x) for x in [mat_a, mat_b]):
|
||||
return False
|
||||
|
||||
if not a_is_2d or b_is_2d:
|
||||
return False
|
||||
|
||||
if offs is None:
|
||||
return False
|
||||
|
||||
if bias is not None or scale_result is not None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
@ -122,29 +122,47 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT
|
||||
ID get_tensor_storage_ID(const c10::Storage& t_storage) {
|
||||
const std::lock_guard<std::recursive_mutex> lock(gMutex);
|
||||
|
||||
const void* raw_data_ptr = t_storage.data();
|
||||
auto iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr);
|
||||
if (iter == data_ptr_to_weak_storage_ptr.end()) {
|
||||
const void* raw_data_ptr = nullptr;
|
||||
bool should_track_liveness = false;
|
||||
// FakeTensor/FunctionalTensor may clear the Storage handle entirely or use
|
||||
// a nullptr data pointer. Treat both cases as a shared cache key but avoid
|
||||
// touching the weak-ref table so they can reuse the same ID without
|
||||
// tripping the liveness check.
|
||||
if (t_storage.unsafeGetStorageImpl()) {
|
||||
raw_data_ptr = t_storage.data();
|
||||
should_track_liveness = raw_data_ptr != nullptr;
|
||||
}
|
||||
|
||||
auto id_iter = data_ptr_to_storage_id.find(raw_data_ptr);
|
||||
if (!should_track_liveness) {
|
||||
if (id_iter != data_ptr_to_storage_id.end()) {
|
||||
return id_iter->second;
|
||||
}
|
||||
ID id = storage_id_++;
|
||||
data_ptr_to_storage_id.emplace(raw_data_ptr, id);
|
||||
return id;
|
||||
}
|
||||
|
||||
auto weak_iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr);
|
||||
if (weak_iter == data_ptr_to_weak_storage_ptr.end()) {
|
||||
ID id = storage_id_++;
|
||||
data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id);
|
||||
data_ptr_to_weak_storage_ptr.emplace(
|
||||
raw_data_ptr, t_storage.getWeakStorageImpl());
|
||||
return id;
|
||||
} else {
|
||||
// check if the storage is still alive
|
||||
if (iter->second.expired()) {
|
||||
ID id = storage_id_++;
|
||||
// std::unorder_map does not change if the key is already in the map.
|
||||
// So we need to remove the key and insert the key with the new value.
|
||||
data_ptr_to_storage_id.erase(raw_data_ptr);
|
||||
data_ptr_to_storage_id[raw_data_ptr] = id;
|
||||
data_ptr_to_weak_storage_ptr.insert_or_assign(
|
||||
raw_data_ptr, t_storage.getWeakStorageImpl());
|
||||
return id;
|
||||
} else {
|
||||
return data_ptr_to_storage_id[raw_data_ptr];
|
||||
}
|
||||
}
|
||||
|
||||
if (weak_iter->second.expired()) {
|
||||
ID id = storage_id_++;
|
||||
data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id);
|
||||
data_ptr_to_weak_storage_ptr.insert_or_assign(
|
||||
raw_data_ptr, t_storage.getWeakStorageImpl());
|
||||
return id;
|
||||
}
|
||||
|
||||
id_iter = data_ptr_to_storage_id.find(raw_data_ptr);
|
||||
TORCH_INTERNAL_ASSERT(id_iter != data_ptr_to_storage_id.end());
|
||||
return id_iter->second;
|
||||
}
|
||||
|
||||
// Observer run state.
|
||||
|
||||
Reference in New Issue
Block a user