Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-12 14:54:55 +08:00)

Compare commits: 8 commits (lucaskabel ... trunk/b629)

| SHA1 | Author | Date |
|---|---|---|
| b62935d1a5 | | |
| ccc8c117dc | | |
| 86db4de10f | | |
| 12860892f8 | | |
| 694592ac1e | | |
| 285748e838 | | |
| 192034c41b | | |
| 5bfce8f345 | | |
@@ -337,7 +337,7 @@ test_python() {

test_python_smoke() {
  # Smoke tests for H100/B200
  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
  assert_git_not_dirty
}
9 .github/labeler.yml (vendored)
@@ -138,7 +138,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@@ -148,7 +149,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@@ -158,7 +160,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm
1 .gitignore (vendored)
@@ -127,7 +127,6 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

@@ -210,8 +210,12 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg

# Low Precision GEMMs
# Low Precision & Grouped GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58
34 setup.py
@@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None:
        raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")


def mirror_inductor_external_kernels() -> None:
    """
    Copy external kernels into Inductor so they are importable.
    """
    paths = [
        (
            CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
            CWD
            / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
        ),
    ]
    for new_path, orig_path in paths:
        # Create the dirs involved in new_path if they don't exist
        if not new_path.exists():
            new_path.parent.mkdir(parents=True, exist_ok=True)

        # Copy the files from the orig location to the new location
        if orig_path.is_file():
            shutil.copyfile(orig_path, new_path)
            continue
        if orig_path.is_dir():
            if new_path.exists():
                # copytree fails if the tree exists already, so remove it.
                shutil.rmtree(new_path)
            shutil.copytree(orig_path, new_path)
            continue
        raise RuntimeError(
            "Check the file paths in `mirror_inductor_external_kernels()`"
        )


# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
    """Extract variant from version string, defaulting to 'cpu'."""
@@ -1647,8 +1616,6 @@ def main() -> None:
    if RUN_BUILD_DEPS:
        build_deps()

    mirror_inductor_external_kernels()

    (
        ext_modules,
        cmdclass,
@@ -1682,7 +1649,6 @@ def main() -> None:
                "_inductor/codegen/aoti_runtime/*.cpp",
                "_inductor/script.ld",
                "_inductor/kernel/flex/templates/*.jinja",
                "_inductor/kernel/templates/*.jinja",
                "_export/serde/*.yaml",
                "_export/serde/*.thrift",
                "share/cmake/ATen/*.cmake",
@@ -1,154 +0,0 @@
# Owner(s): ["module: inductor"]


import unittest

import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)


@unittest.skipIf(
    not (ensure_cute_available() and is_datacenter_blackwell_arch()),
    "CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
    def _get_inputs(
        self,
        group_size: int,
        M_hint: int,
        K: int,
        N: int,
        device: str,
        dtype: torch.dtype,
        alignment: int = 16,
    ) -> tuple[Tensor, Tensor, Tensor]:
        # --- Random, tile-aligned M sizes ---
        M_sizes = (
            torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
            * alignment
        )

        M_total = torch.sum(M_sizes).item()

        # --- Construct input tensors ---
        A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
        B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01

        # --- Build offsets (no leading zero, strictly increasing) ---
        offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)

        return (A, B, offsets)

    @parametrize("group_size", (2, 8))
    @parametrize("M_hint", (256, 1024))
    @parametrize("K", (64, 128))
    @parametrize("N", (128, 256))
    def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
        device = "cuda"
        dtype = torch.bfloat16

        A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)

        def grouped_gemm_fn(A_packed, B_batched, offs):
            return torch._grouped_mm(A_packed, B_batched, offs=offs)

        # Eager execution
        c_eager = grouped_gemm_fn(A, B, offsets)

        # Test with Cute backend
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "CUTEDSL",
                "test_configs.autotune_choice_name_regex": "cutedsl",
                "autotune_fallback_to_aten": False,
            }
        ):
            grouped_gemm_compiled = torch.compile(
                grouped_gemm_fn, backend="inductor", dynamic=False
            )
            c_compiled = grouped_gemm_compiled(A, B, offsets)

        self.assertEqual(c_eager.dtype, dtype)
        self.assertEqual(c_compiled.dtype, dtype)
        torch.testing.assert_close(c_eager, c_compiled)

    @parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
    @parametrize("layout_B", ("contiguous", "broadcasted"))
    def test_grouped_gemm_assorted_layouts(
        self,
        layout_A: str,
        layout_B: str,
    ):
        device = "cuda"
        dtype = torch.bfloat16

        G, K, N = 8, 64, 128
        M_sizes = [128] * G
        sum_M = sum(M_sizes)
        offsets = torch.tensor(
            [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
        )

        A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
        A = A_base

        if layout_A == "offset":
            # allocate bigger buffer than needed, use nonzero storage offset
            storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
            offset = 128  # skip first 128 elements
            A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
        elif layout_A == "padded":
            # simulate row pitch > K (row_stride = K + pad)
            row_pitch = K + 8
            storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
            A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
        elif layout_A == "view":
            A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
            A = A_storage.view(sum_M, K)
            assert A._base is not None
            assert A.shape == (sum_M, K)

        B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01

        if layout_B == "broadcasted":
            # Broadcast B across groups (zero stride along G)
            B = B[0].expand(G, K, N)
            assert B.stride(0) == 0

        def grouped_gemm_fn(A_packed, B_batched, offs):
            return torch._grouped_mm(A_packed, B_batched, offs=offs)

        # --- eager ---
        c_eager = grouped_gemm_fn(A, B, offsets)

        # --- compiled (CUTE backend) ---
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "CUTEDSL",
                "test_configs.autotune_choice_name_regex": "cutedsl",
                "autotune_fallback_to_aten": False,
            }
        ):
            grouped_gemm_compiled = torch.compile(
                grouped_gemm_fn, backend="inductor", dynamic=False
            )
            c_compiled = grouped_gemm_compiled(A, B, offsets)

        self.assertEqual(c_eager.dtype, dtype)
        self.assertEqual(c_compiled.dtype, dtype)
        torch.testing.assert_close(c_eager, c_compiled)


if __name__ == "__main__":
    run_tests()
@@ -1913,6 +1913,29 @@ class TestMaxAutotune(TestCase):
        # Check that contiguous transform was used
        FileCheck().check("contiguous_mm").run(code[0])

    @unittest.skipIf(config.cpp_wrapper, "out_dtype override not supported for AOTI")
    @unittest.skipIf(TEST_WITH_ROCM, "out_dtype override only available on NVIDIA")
    def test_bmm_out_dtype(self):
        def f(a, b):
            return torch.bmm(a, b, out_dtype=torch.float32)

        a = torch.randn(2, 3, 4, device=GPU_TYPE, dtype=torch.float16)
        b = torch.randn(2, 4, 5, device=GPU_TYPE, dtype=torch.float16)
        with config.patch(
            max_autotune=True,
            max_autotune_gemm_backends="TRITON",
        ):
            compiled_f = torch.compile(f)
            with self.assertRaisesRegex(
                torch._inductor.exc.InductorError,
                r"LoweringException: NoValidChoicesError: No choices to select",
            ):
                out, code = run_and_get_code(compiled_f, a, b)

        compiled_f = torch.compile(f)
        out, code = run_and_get_code(compiled_f, a, b)
        FileCheck().check("extern_kernels.bmm_dtype").run(code[0])

    def test_triton_template_generated_code_cache_key(self):
        generate_and_load_args = len(
            inspect.signature(
@@ -1217,6 +1217,43 @@ class TestPatternMatcher(TestCase):
        _, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
        FileCheck().check_not("extern_kernels.addmm(").run(code[0])

    def test_addmm_alpha_beta_with_pointwise(self):
        # Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops
        # See https://github.com/pytorch/pytorch/issues/167313
        x = torch.rand(2, device=GPU_TYPE)
        a = torch.rand(2, 3, device=GPU_TYPE)
        b = torch.rand(3, 2, device=GPU_TYPE)

        def f(x, a, b):
            return torch.nn.functional.relu(torch.addmm(x, a, b, alpha=0.8, beta=0.2))

        fc = torch.compile(f)

        expected = f(x, a, b)
        actual = fc(x, a, b)

        # The compiled version should produce the same result as eager
        torch.testing.assert_close(actual, expected)

        # Verify that addmm is unfused (should not use extern_kernels.addmm)
        # The pattern should be replaced with beta * x + alpha * (a @ b)
        _, (code) = run_and_get_code(fc, x, a, b)
        FileCheck().check_not("extern_kernels.addmm(").run(code[0])

        # Test with alpha=1, beta=1 (default) - should also unfuse
        def f_default(x, a, b):
            return torch.nn.functional.relu(torch.addmm(x, a, b))

        fc_default = torch.compile(f_default)
        expected_default = f_default(x, a, b)
        actual_default = fc_default(x, a, b)

        torch.testing.assert_close(actual_default, expected_default)

        # Should unfuse and not use extern_kernels.addmm
        _, (code) = run_and_get_code(fc_default, x, a, b)
        FileCheck().check_not("extern_kernels.addmm(").run(code[0])

    def test_serialized_patterns_up_to_date(self):
        import torch.utils._pytree as pytree
        from torch._inductor.fx_passes import joint_graph
@@ -7486,7 +7486,7 @@ class TestFXMemoryProfiler(TestCase):
        return fx_frames

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
    def test_fx_memory_profiler_augmentation(self):
        """Test that memory snapshots are augmented with FX debug information."""
@@ -4251,7 +4251,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
    def test_profiler_stack_trace_augmentation(self):
        """
        Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
@@ -4307,7 +4307,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
    def test_profiler_multiple_modules(self):
        """
        Test that multiple compiled modules under the same profiler session
@@ -4351,7 +4351,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
    def test_profiler_nested_graph_modules(self):
        """
        Test that nested graph modules (e.g., graph modules calling subgraphs)
@@ -740,11 +740,8 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None

# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config(  # type: ignore[var-annotated]
    default=False,
    env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
# Deprecated! Please use the config in torch/fx/experimental/_config instead.
enrich_profiler_metadata: bool = False

if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
@@ -4,6 +4,8 @@
import importlib
from typing import TYPE_CHECKING

import torch.utils._pytree as python_pytree

from .. import polyfills, trace_rules


@@ -19,12 +21,14 @@ POLYFILLED_MODULE_NAMES: tuple[str, ...] = (
    "itertools",
    "operator",
    "os",
    "pytree",
    "struct",
    "sys",
    "fx",
    "tensor",
)
if python_pytree._cxx_pytree_dynamo_traceable:
    POLYFILLED_MODULE_NAMES += ("pytree",)

POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple(
    importlib.import_module(f".{submodule}", package=polyfills.__name__)
    for submodule in POLYFILLED_MODULE_NAMES
File diff suppressed because it is too large
@@ -550,10 +550,6 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
    "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper()  # type: ignore[assignment]

cutedsl_enable_autotuning: bool = (
    os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)

# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False
@@ -913,6 +913,10 @@ def _get_optimization_cflags(
        if not config.is_fbcode():
            if platform.machine() == "ppc64le":
                cflags.append("mcpu=native")
            elif platform.machine() == "riscv64":
                cflags.append("march=rv64gc")
            elif platform.machine() == "riscv32":
                cflags.append("march=rv32gc")
            else:
                cflags.append("march=native")
@@ -1516,17 +1516,29 @@ def should_prefer_unfused_addmm(match):


@register_graph_pattern(
    CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
    CallFunction(
        aten.addmm,
        KeywordArg("inp"),
        Arg(),
        Arg(),
        beta=KeywordArg("beta"),
        alpha=KeywordArg("alpha"),
    ),
    # pyrefly: ignore [bad-argument-type]
    pass_dict=pass_patterns[2],
    extra_check=should_prefer_unfused_addmm,
)
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
    def repl(inp, x1, x2):
        return x1 @ x2 + inp
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
    def repl(inp, x1, x2, alpha, beta):
        mm_result = x1 @ x2
        if alpha != 1:
            mm_result = alpha * mm_result
        if beta != 1:
            inp = beta * inp
        return inp + mm_result

    # pyrefly: ignore [bad-argument-type]
    match.replace_by_example(repl, [inp, mat1, mat2])
    match.replace_by_example(repl, [inp, mat1, mat2, alpha, beta])


def is_valid_addmm_fusion(match):
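Editor's note: the rewritten pattern above unfuses `aten.addmm` into pointwise ops while now honoring `alpha` and `beta`. As a minimal sketch (not part of the diff; shapes chosen arbitrarily for illustration), the algebraic identity the replacement relies on can be checked in eager mode:

```python
import torch

# addmm(inp, m1, m2, alpha, beta) == beta * inp + alpha * (m1 @ m2)
inp = torch.randn(2)
m1 = torch.randn(2, 3)
m2 = torch.randn(3, 2)
alpha, beta = 0.8, 0.2

fused = torch.addmm(inp, m1, m2, alpha=alpha, beta=beta)
unfused = beta * inp + alpha * (m1 @ m2)
torch.testing.assert_close(fused, unfused)
```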
@@ -1,8 +1,6 @@
# mypy: allow-untyped-defs
import logging
from collections.abc import Sequence
from functools import partial
from pathlib import Path
from typing import Any

import torch
@@ -14,7 +12,6 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
from .. import config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import _IntLike, Layout, TensorBox
from ..utils import load_template


log = logging.getLogger(__name__)
@@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
        return False

    return True


_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)
@@ -1,13 +1,11 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from typing import Any, Optional

import torch
from torch._dynamo.utils import counters
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
from torch._inductor.runtime.triton_compat import tl
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
from torch._inductor.virtualized import V
from torch.utils._triton import has_triton

@@ -24,13 +22,11 @@ from ..utils import (
    get_num_sms,
    has_free_symbols,
    use_aten_gemm_kernels,
    use_blackwell_cutedsl_grouped_mm,
    use_triton_template,
)
from .mm_common import (
    _is_static_problem,
    check_supported_striding,
    load_kernel_template,
    persistent_grouped_mm_grid,
)

@@ -517,11 +513,6 @@ triton_scaled_grouped_mm_template = TritonTemplate(
    source=triton_grouped_mm_source,
)

cutedsl_grouped_mm_template = CuteDSLTemplate(
    name="grouped_gemm_cutedsl",
    source=load_kernel_template("cutedsl_mm_grouped"),
)


def grouped_mm_args(
    mat1: TensorBox,
@@ -723,44 +714,43 @@ def _tuned_grouped_mm_common(
    # Checking only for the equality of corresponding dims of
    # multiplicands here, relying on meta function checks for
    # everything else.
    if len(m1_size) == 2:
        if len(m2_size) == 2:
            m, k1 = m1_size
            k2, _ = m2_size
            # pyrefly: ignore [missing-attribute]
            g = offs.get_size()[0]
            V.graph.sizevars.check_equals(k1, k2)
            a_is_2d, b_is_2d = True, True
        else:
            # pyrefly: ignore [missing-attribute]
            g1 = offs.layout.size[0]
            m, k1 = m1_size
            g2, k2, _ = m2_size
            g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
            V.graph.sizevars.check_equals(k1, k2)
            a_is_2d, b_is_2d = True, False
    else:
        if len(m2_size) == 2:
            # pyrefly: ignore [missing-attribute]
            g1 = offs.layout.size[0]
            g2, m, k1 = m1_size
            k2, _ = m2_size
            g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
            V.graph.sizevars.check_equals(k1, k2)
            a_is_2d, b_is_2d = False, True
        else:
            g1, m, k1 = m1_size
            g2, k2, _ = m2_size
            g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
            V.graph.sizevars.check_equals(k1, k2)
            a_is_2d, b_is_2d = False, False

    if (
        is_nonzero
        and use_triton_template(layout)
        and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
    ):
        scaled = scale_a is not None
        if len(m1_size) == 2:
            if len(m2_size) == 2:
                m, k1 = m1_size
                k2, _ = m2_size
                # pyrefly: ignore [missing-attribute]
                g = offs.get_size()[0]
                V.graph.sizevars.check_equals(k1, k2)
                a_is_2d, b_is_2d = True, True
            else:
                # pyrefly: ignore [missing-attribute]
                g1 = offs.layout.size[0]
                m, k1 = m1_size
                g2, k2, _ = m2_size
                g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
                V.graph.sizevars.check_equals(k1, k2)
                a_is_2d, b_is_2d = True, False
        else:
            if len(m2_size) == 2:
                # pyrefly: ignore [missing-attribute]
                g1 = offs.layout.size[0]
                g2, m, k1 = m1_size
                k2, _ = m2_size
                g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
                V.graph.sizevars.check_equals(k1, k2)
                a_is_2d, b_is_2d = False, True
            else:
                g1, m, k1 = m1_size
                g2, k2, _ = m2_size
                g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
                V.graph.sizevars.check_equals(k1, k2)
                a_is_2d, b_is_2d = False, False

        a_is_k_major = mat_a.get_stride()[-1] == 1
        b_is_k_major = mat_b.get_stride()[-2] == 1
@@ -798,22 +788,6 @@ def _tuned_grouped_mm_common(
            **config.kwargs,
        )

    if use_blackwell_cutedsl_grouped_mm(
        mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
    ):
        for config in get_groupgemm_configs():
            kwargs = dict(
                ACC_DTYPE="cutlass.Float32",
            )

            cutedsl_grouped_mm_template.maybe_append_choice(
                choices,
                input_nodes=input_nodes,
                layout=layout,
                **kwargs,
                **asdict(config),
            )

    input_gen_fns = {
        4: lambda x: create_offsets(
            x, m1_size, m2_size, offs.get_size() if offs is not None else None
@@ -1,333 +0,0 @@
import functools

from torch._inductor.runtime.runtime_utils import ceildiv
from cutlass.utils import TensorMapUpdateMode

{{gen_defines()}}

# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
    GroupedGemmKernel,
)


# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
#    Caches the compiled executor for build_group_ptrs_from_bases(). This
#    kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
#    and can therefore be safely reused across runs with different group
#    partitioning (`offs`).
#
# 2. gemm_cache
#    Caches the compiled Grouped GEMM executor. Its key extends the prep
#    cache key with hardware- and grid-specific parameters:
#    (prep_cache_key, max_active_clusters, total_num_clusters).
#    This is necessary because different `offs` tensors can change the
#    per-group problem sizes and thus alter `total_num_clusters`, which in
#    turn changes the grid shape and persistent scheduler configuration.
#    Kernels compiled for one grid cannot be safely reused for another.
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.

prep_cache = {}
gemm_cache = {}


@functools.lru_cache
def get_hardware_info():
    hw = cutlass.utils.HardwareInfo()
    sm_count = hw.get_max_active_clusters(1)
    max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)

    return (sm_count, max_active_clusters)


def get_prep_cache_key(input_a, input_b, output):
    """
    Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
    shapes, strides, and dtypes of input/output tensors.
    """
    return (
        tuple(input_a.shape),
        tuple(input_a.stride()),
        input_a.dtype,
        tuple(input_b.shape),
        tuple(input_b.stride()),
        input_b.dtype,
        tuple(output.shape),
        tuple(output.stride()),
        output.dtype,
    )


def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
    """
    Returns a tuple key for caching the gemm kernel executor by extending the
    prep cache key with hardware- and grid-specific parameters.
    """
    return (
        prep_cache_key,
        max_active_clusters,
        total_num_clusters,
    )


@cute.kernel
def build_group_ptrs_from_bases_kernel(
    base_A_u64: cutlass.Int64,  # device addr of input_a (bytes)
    base_B_u64: cutlass.Int64,  # device addr of input_b (bytes)
    base_C_u64: cutlass.Int64,  # device addr of Output (bytes)
    offs: cute.Tensor,  # [G], cutlass.Int32/64 cumulative
    K: cutlass.Constexpr,
    N: cutlass.Constexpr,
    sizeof_element: cutlass.Int32,  # bytes
    # -------- STRIDES (in ELEMENTS) --------
    stride_A_m_elems: cutlass.Constexpr,  # A.stride(0)
    stride_A_k_elems: cutlass.Constexpr,  # A.stride(1)
    stride_B0_elems: cutlass.Constexpr,  # B.stride(0)
    stride_Bk_elems: cutlass.Constexpr,  # B.stride(1)
    stride_Bn_elems: cutlass.Constexpr,  # B.stride(2)
    stride_C_m_elems: cutlass.Constexpr,  # C.stride(0)
    stride_C_n_elems: cutlass.Constexpr,  # C.stride(1)
    # -------- OUTPUTS --------
    out_ptrs: cute.Tensor,  # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
    out_problem: cute.Tensor,  # [G,4] cutlass.Int32: (m_g, n, k, 1)
    out_strides_abc: cute.Tensor,  # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
):
    tidx, _, _ = cute.arch.thread_idx()
    g = tidx

    m_beg_i32 = 0
    if g > 0:
        m_beg_i32 = offs[g - 1]
    m_end_i32 = offs[g]
    m_g_i32 = m_end_i32 - m_beg_i32

    a_byte_off = (
        cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
    )
    c_byte_off = (
        cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
    )
    b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)

    # ---- pointers ----
    out_ptrs[g, 0] = base_A_u64 + a_byte_off
    out_ptrs[g, 1] = base_B_u64 + b_byte_off
    out_ptrs[g, 2] = base_C_u64 + c_byte_off

    # ---- (m, n, k, 1) ----
    out_problem[g, 0] = m_g_i32
    out_problem[g, 1] = N
    out_problem[g, 2] = K
    out_problem[g, 3] = cutlass.Int32(1)

    # ---- strides ----
    out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
    out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
    out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
    out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
    out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
    out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)


@cute.jit
def launch_build_group_ptrs_from_bases(
    base_A_u64: cutlass.Int64,
    base_B_u64: cutlass.Int64,
    base_C_u64: cutlass.Int64,
    offs: cute.Tensor,
    G: cutlass.Constexpr,
    K: cutlass.Constexpr,
    N: cutlass.Constexpr,
    sizeof_element: cutlass.Constexpr,
    stride_A_m_elems: cutlass.Constexpr,
    stride_A_k_elems: cutlass.Constexpr,
    stride_B0_elems: cutlass.Constexpr,
    stride_Bk_elems: cutlass.Constexpr,
    stride_Bn_elems: cutlass.Constexpr,
    stride_C_m_elems: cutlass.Constexpr,
    stride_C_n_elems: cutlass.Constexpr,
    out_ptrs: cute.Tensor,  # [G,3] cutlass.Int64
    out_problem: cute.Tensor,  # [G,4] cutlass.Int32
    out_strides_abc: cute.Tensor,  # [3,2] cutlass.Int32
    stream: cuda.CUstream,
):
    build_group_ptrs_from_bases_kernel(
        base_A_u64,
        base_B_u64,
        base_C_u64,
        offs,
        K,
        N,
        sizeof_element,
        stride_A_m_elems,
        stride_A_k_elems,
        stride_B0_elems,
        stride_Bk_elems,
        stride_Bn_elems,
        stride_C_m_elems,
        stride_C_n_elems,
        out_ptrs,
        out_problem,
        out_strides_abc,
    ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)


{{def_kernel("input_a", "input_b", "input_a_offs")}}
    stream = cuda.CUstream(stream)

    input_b = input_b.transpose(1, 2)

    sumM, K = input_a.shape
    G, N, Kb = input_b.shape

    dev = input_a.device

    base_A_u64 = int(input_a.data_ptr())
    base_B_u64 = int(input_b.data_ptr())
    base_C_u64 = int({{get_output()}}.data_ptr())

    ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
    probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
    strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
    ptrs = from_dlpack(ptrs_t)
    probs = from_dlpack(probs_t)
    strides = from_dlpack(strides_t)

    prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
    prep_executor = prep_cache.get(prep_cache_key)

    if prep_executor is None:
        sizeof_element = int(input_a.element_size())
        sA_m, sA_k = map(int, input_a.stride())
        sB_0, sB_n, sB_k = map(int, input_b.stride())
        sC_m, sC_n = map(int, {{get_output()}}.stride())

        prep_executor = cute.compile(
            launch_build_group_ptrs_from_bases,
            base_A_u64=base_A_u64,
            base_B_u64=base_B_u64,
            base_C_u64=base_C_u64,
            offs=from_dlpack(input_a_offs),
            G=int(G),
            K=int(K),
            N=int(N),
            sizeof_element=sizeof_element,
            stride_A_m_elems=sA_m,
            stride_A_k_elems=sA_k,
            stride_B0_elems=sB_0,
            stride_Bk_elems=sB_k,
            stride_Bn_elems=sB_n,
            stride_C_m_elems=sC_m,
            stride_C_n_elems=sC_n,
            out_ptrs=ptrs,
            out_problem=probs,
            out_strides_abc=strides,
            stream=stream,
        )

        prep_cache[prep_cache_key] = prep_executor

    prep_executor(
        base_A_u64=base_A_u64,
        base_B_u64=base_B_u64,
        base_C_u64=base_C_u64,
        offs=from_dlpack(input_a_offs),
        out_ptrs=ptrs,
        out_problem=probs,
        out_strides_abc=strides,
        stream=stream,
    )

    # --- Tensormap workspace per SM ---
    num_tensormap_buffers, max_active_clusters = get_hardware_info()
    tensormap_shape = (
        num_tensormap_buffers,
        GroupedGemmKernel.num_tensormaps,
        GroupedGemmKernel.bytes_per_tensormap // 8,
    )
    tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
    tensormap_workspace = from_dlpack(tensormap_workspace_t)

    # --- Total clusters ---
    def compute_total_num_clusters(
        problem_sizes_mnkl,
        cluster_tile_shape_mn,
    ):
        total_num_clusters = 0
        for m, n, _, _ in problem_sizes_mnkl:
            num_clusters_mn = tuple(
                ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
            )
            total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
        return total_num_clusters

    # Compute cluster tile shape
    def compute_cluster_tile_shape(
        mma_tiler_mn,
        cluster_shape_mn,
        use_2cta_instrs,
    ):
        cta_tile_shape_mn = list(mma_tiler_mn)
        if use_2cta_instrs:
            cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
        return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))

    cluster_tile_shape_mn = compute_cluster_tile_shape(
        (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
    )

    total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))

    gemm_cache_key = get_gemm_cache_key(
        prep_cache_key, max_active_clusters, total_num_clusters
    )
    gemm_executor = gemm_cache.get(gemm_cache_key)

    if gemm_executor is None:
        grouped_gemm = GroupedGemmKernel(
            acc_dtype=ACC_DTYPE,
            use_2cta_instrs=USE_2_CTA,
            mma_tiler_mn=(TILE_M, TILE_N),
            cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
            tensormap_update_mode=TENSORMAP_UPDATE_MODE,
        )

        gemm_executor = cute.compile(
            grouped_gemm,
            from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
            from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
            from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
            G,
            probs,
            strides,
            ptrs,
            total_num_clusters,
            tensormap_workspace,
            max_active_clusters,
            stream,
        )

        gemm_cache[gemm_cache_key] = gemm_executor

    gemm_executor(
        from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
        from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
        from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
        probs,
        strides,
        ptrs,
        tensormap_workspace,
        stream,
    )
@@ -1,141 +0,0 @@
from dataclasses import dataclass
from enum import auto, Enum
from itertools import product

import torch._inductor.config as config


class TensorMapUpdateMode(Enum):
    """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""

    SMEM = auto()
    GMEM = auto()


@dataclass(frozen=True)
class CuTeGemmConfig:
    TILE_M: int = 128
    TILE_N: int = 192
    CLUSTER_M: int = 2
    CLUSTER_N: int = 1
    USE_2_CTA: bool = False
    TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM


def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
    """
    Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
    For information regarding valid config sets, see:
    https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
    """

    # Tile_n is always the same regardless of 2cta
    tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]

    # Valid clusters
    clusters_no_2cta = [
        (1, 1),
        (1, 2),
        (1, 4),
        (1, 8),
        (1, 16),
        (2, 1),
        (2, 2),
        (2, 4),
        (2, 8),
        (4, 1),
        (4, 2),
        (4, 4),
        (8, 1),
        (8, 2),
        (16, 1),
    ]
    clusters_2cta = [
        (2, 1),
        (2, 2),
        (2, 4),
        (2, 8),
        (4, 1),
        (4, 2),
        (4, 4),
        (8, 1),
        (8, 2),
        (16, 1),
    ]

    configs: list[CuTeGemmConfig] = []

    for use_2cta, cluster_set, tile_m_range in [
        (False, clusters_no_2cta, [64, 128]),
        (True, clusters_2cta, [128, 256]),
    ]:
        for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
            [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
            tile_m_range,
            tile_n_vals,
            cluster_set,
        ):
            configs.append(
                CuTeGemmConfig(
                    tile_m,
                    tile_n,
                    cluster_m,
                    cluster_n,
                    USE_2_CTA=use_2cta,
                    TENSORMAP_UPDATE_MODE=tensormap_update_mode,
                )
            )

    return configs


def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
    """
    Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
    """

    config_tuples = [
        (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
        (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
        (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
        (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
        (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
        (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
        (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
        (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
        (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
        (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
        (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
        (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
        (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
        (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
        (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
        (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
        (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
        (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
    ]

    return [CuTeGemmConfig(*args) for args in config_tuples]


def get_groupgemm_configs() -> list[CuTeGemmConfig]:
    """
    Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.

    Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
    or unstable results. By default, autotuning is disabled and we return only
    a single baseline config.
    """
    if (
        config.cutedsl_enable_autotuning
        and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
    ):
        return get_exhaustive_groupgemm_configs()
    elif config.cutedsl_enable_autotuning:
        return get_default_groupgemm_configs()
    else:
        return [get_default_groupgemm_configs()[0]]
@@ -1911,84 +1911,6 @@ def use_triton_blackwell_tma_template(
    return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()


@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
    """Check if CuTeDSL is importable; cache the result for reuse.

    Call ensure_cute_available.cache_clear() after installing CuTeDSL
    in the same interpreter to retry the import.
    """
    try:
        return importlib.util.find_spec("cutlass.cute") is not None
    except ImportError:
        return False


def use_blackwell_cutedsl_grouped_mm(
    mat_a: Any,
    mat_b: Any,
    layout: Layout,
    a_is_2d: bool,
    b_is_2d: bool,
    offs: Optional[Any],
    bias: Optional[Any],
    scale_result: Optional[Any],
) -> bool:
    """
    Returns True if we can use the blackwell kernel for grouped mm.
    Required conditions:
    1. CuTeDSL backend is enabled
    2. CuTeDSL is available
    3. We are on a blackwell arch
    4. The dtype is bf16
    5. Max autotune or max autotune gemm is enabled
    6. A, B, and the output are 16B aligned
    7. We are not using dynamic shapes
    8. A is 2d
    9. B is 3d
    10. Offsets are provided
    11. Bias and Scale are not provided
    """
    if not ensure_cute_available():
        return False

    if not _use_autotune_backend("CUTEDSL"):
        return False

    from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch

    if not is_gpu(layout.device.type):
        return False

    if not is_datacenter_blackwell_arch():
        return False

    layout_dtypes = [torch.bfloat16]
    if not _use_template_for_gpu(layout, layout_dtypes):
        return False

    if not (config.max_autotune or config.max_autotune_gemm):
        return False

    # Checks for 16B ptr and stride alignment
    if not can_use_tma(mat_a, mat_b, output_layout=layout):
        return False

    if any(is_dynamic(x) for x in [mat_a, mat_b]):
        return False

    if not a_is_2d or b_is_2d:
        return False

    if offs is None:
        return False

    if bias is not None or scale_result is not None:
        return False

    return True


def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
    from .virtualized import V
@@ -1485,6 +1485,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        use_full_backward: Optional[bool] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # Init parent
        super().__init__(
@@ -1517,6 +1518,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
        # This will be set during init of derived schedules
        self.pipeline_order: dict[int, list[Optional[_Action]]] = {}

        # When using a custom backward function, we may or may not need autograd to be used
        # for the backward pass. This flag is used to determine whether or not the
        # torch.is_grad_enabled() check should be performed before the step function.
        self._backward_requires_autograd = backward_requires_autograd

        if use_full_backward is not None:
            logger.warning(
                "Deprecation warning: 'use_full_backward' is no longer supported. "
@@ -1609,7 +1615,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
            losses: a list to store the losses for each microbatch.
            return_outputs: whether to return the outputs from the last stage.
        """
        if self._has_backward and not torch.is_grad_enabled():
        if (
            self._has_backward
            and self._backward_requires_autograd
            and not torch.is_grad_enabled()
        ):
            raise RuntimeError(
                "step() requires gradients to be enabled for backward computation; "
                "it should not be used under torch.no_grad() context. "
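Editor's note: the hunk above relaxes the grad-mode guard in step() so that schedules built with backward_requires_autograd=False (custom backward functions that do not rely on autograd) can run under torch.no_grad(). A minimal sketch of the relaxed condition, illustrative only and not code from the diff:

```python
import torch

def step_needs_grad_error(has_backward: bool, backward_requires_autograd: bool) -> bool:
    # Mirrors the guard above: an error is warranted only when a backward pass
    # exists, it relies on autograd, and grad mode is currently disabled.
    return has_backward and backward_requires_autograd and not torch.is_grad_enabled()

with torch.no_grad():
    assert step_needs_grad_error(True, True)        # previous behavior: step() raises
    assert not step_needs_grad_error(True, False)   # new flag: custom backward allowed
```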
@@ -1891,7 +1901,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
        Args:
            computation_type: The computation type for which to register the custom function
            custom_function: The function to execute when this computation type is encountered.
                Must have signature: (stage: _PipelineStageBase, mb_index: int, *args, **kwargs) -> None
                Must have signature: (action: _Action, ctx: _PipelineContext) -> None
        """
        # Ensure that the computation type is valid
        if computation_type not in (
@@ -1900,10 +1910,13 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
            BACKWARD_INPUT,
            BACKWARD_WEIGHT,
            OVERLAP_F_B,
            UNSHARD,
            RESHARD,
            REDUCE_GRAD,
        ):
            raise ValueError(
                f"Invalid computation type {computation_type}. Only FORWARD, FULL_BACKWARD, \
                BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
                BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, UNSHARD, RESHARD and REDUCE_GRAD are supported."
            )

        # Check if computation_type is already registered
@@ -2296,6 +2309,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
        loss_fn: Optional[Union[Callable, _Loss]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        super().__init__(
            stages=stages,
@@ -2303,6 +2317,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
            loss_fn=loss_fn,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )

        # 1. Create the pipeline_order (all ranks do this calculation)
@@ -2510,6 +2525,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        self.pp_group_size = stages[0].group_size
        super().__init__(
@@ -2520,6 +2536,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
@@ -2622,6 +2639,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # TODO: we dont support input/weight backward split with torch.compile
        _check_torch_compile_compatibility(stages, self.__class__.__name__)
@@ -2634,6 +2652,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
@@ -2819,6 +2838,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # TODO: we dont support input/weight backward split with torch.compile
        _check_torch_compile_compatibility(stages, self.__class__.__name__)
@@ -2831,6 +2851,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
            self.pp_group_size, self._num_stages, style="v"
@@ -2995,6 +3016,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # TODO: we dont support input/weight backward split with torch.compile
        _check_torch_compile_compatibility(stages, self.__class__.__name__)
@@ -3007,6 +3029,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
            self.pp_group_size, self._num_stages, style="v"
@@ -2,6 +2,8 @@ import os
import sys
from typing import Optional

from torch.utils._config_module import Config, install_config_module


# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors.
no_data_dependent_graph_break = (
@@ -100,7 +102,11 @@ backed_size_oblivious = False
# Skip dtype check in meta registrations. Only used for systems that do their own dtype checking.
skip_dtype_check_in_meta_registrations = False

from torch.utils._config_module import install_config_module
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config(  # type: ignore[var-annotated]
    default=False,
    env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)


install_config_module(sys.modules[__name__])
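Editor's note: this set of commits moves enrich_profiler_metadata from torch._dynamo.config (now only a deprecated bool) into torch/fx/experimental/_config.py. A hedged usage sketch, assuming the usual config-module patch API that the test decorators above also use; the env var spelling (TORCH_ENRICH_RPOFILER_STACK_TRACE) is copied verbatim from the config:

```python
import torch
import torch.fx
from torch.fx.experimental import _config as fx_experimental_config

def f(x):
    return torch.relu(x) + 1

# Enable FX metadata registration during recompile() via the new config location.
with fx_experimental_config.patch(enrich_profiler_metadata=True):
    gm = torch.fx.symbolic_trace(f)
    gm.recompile()  # recompile() now consults the fx experimental config
```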
@@ -20,6 +20,7 @@ from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer

from ._compatibility import compatibility
from .experimental import _config as fx_experimental_config
from .graph import (
    _BoxedCodeGen,
    _custom_builtins,
@@ -858,14 +859,15 @@ class {module_name}(torch.nn.Module):
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.
        """
        # Do not import anything inside recompile, it might slow down the
        # function and cause perf regression. Import outside of the method instead.
        if isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec

        from torch._dynamo import config as dynamo_config

        python_code = self._graph.python_code(
            root_module="self", record_func=dynamo_config.enrich_profiler_metadata
            root_module="self",
            record_func=fx_experimental_config.enrich_profiler_metadata,
        )
        self._code = python_code.src
        self._lineno_map = python_code._lineno_map
@@ -874,7 +876,7 @@ class {module_name}(torch.nn.Module):
        cls = type(self)
        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}

        if dynamo_config.enrich_profiler_metadata:
        if fx_experimental_config.enrich_profiler_metadata:
            # Generate metadata and register for profiler augmentation
            node_metadata: dict[int, dict[str, Any]] = {}
            for i, node in enumerate(self._graph.nodes):