Compare commits

..

1 Commit

Author SHA1 Message Date
0ad356fa6c move log from warning to debug 2025-11-04 10:55:51 -08:00
28 changed files with 1087 additions and 611 deletions

View File

@ -129,7 +129,7 @@ function install_129 {
}
function install_128 {
CUDNN_VERSION=9.10.2.21
CUDNN_VERSION=9.8.0.87
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
# install CUDA 12.8.1 in the same container
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

View File

@ -272,18 +272,6 @@ def smoke_test_cuda(
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
print(f"Torch cuDNN version: {torch_cudnn_version}")
torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
torch_cudnn_runtime_version = tuple(
[int(x) for x in torch_cudnn_version.split(".")]
)
if torch_cudnn_runtime_version != torch_cudnn_compile_version:
raise RuntimeError(
"cuDNN runtime version doesn't match comple version. "
f"Loaded: {torch_cudnn_runtime_version} "
f"Expected: {torch_cudnn_compile_version}"
)
if sys.platform in ["linux", "linux2"]:
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
print(f"Torch nccl; version: {torch_nccl_version}")

View File

@ -337,7 +337,7 @@ test_python() {
test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}

View File

@ -97,8 +97,8 @@ jobs:
shell: bash
run: |
ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus.
echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs"
if [[ $ngpu -lt 4 ]]; then
echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
exit 1
fi

.gitignore (vendored)
View File

@ -127,6 +127,7 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

View File

@ -1,7 +1,7 @@
# Security Policy
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using PyTorch Securely**](#using-pytorch-securely)
- [**Using Pytorch Securely**](#using-pytorch-securely)
- [Untrusted models](#untrusted-models)
- [TorchScript models](#torchscript-models)
- [Untrusted inputs](#untrusted-inputs)
@ -10,28 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If an advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create a [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If an advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create a [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat
## Using PyTorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
## Using Pytorch Securely
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
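As an illustration of the trade-off described above, a minimal sketch of the two loading paths (file names are placeholders; safetensors is an optional third-party dependency):

import torch
from safetensors.torch import load_file  # third-party: pip install safetensors

# Restricted but safer: plain tensors only, no arbitrary object deserialization.
weights = load_file("model.safetensors")

# More flexible, larger attack surface; weights_only=True narrows what can be deserialized.
state_dict = torch.load("model.pt", weights_only=True)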
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
### TorchScript models
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
### Untrusted inputs during training and prediction
@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
### Data privacy
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
### Using distributed features

View File

@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None:
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
def mirror_inductor_external_kernels() -> None:
"""
Copy external kernels into Inductor so they are importable.
"""
paths = [
(
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
CWD
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
),
]
for new_path, orig_path in paths:
# Create the dirs involved in new_path if they don't exist
if not new_path.exists():
new_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the files from the orig location to the new location
if orig_path.is_file():
shutil.copyfile(orig_path, new_path)
continue
if orig_path.is_dir():
if new_path.exists():
# copytree fails if the tree exists already, so remove it.
shutil.rmtree(new_path)
shutil.copytree(orig_path, new_path)
continue
raise RuntimeError(
"Check the file paths in `mirror_inductor_external_kernels()`"
)
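Once mirrored, the vendored kernel becomes importable from within Inductor, as the new grouped-GEMM template later in this diff does:

from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
    GroupedGemmKernel,
)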
# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
"""Extract variant from version string, defaulting to 'cpu'."""
@ -1616,6 +1647,8 @@ def main() -> None:
if RUN_BUILD_DEPS:
build_deps()
mirror_inductor_external_kernels()
(
ext_modules,
cmdclass,
@ -1649,6 +1682,7 @@ def main() -> None:
"_inductor/codegen/aoti_runtime/*.cpp",
"_inductor/script.ld",
"_inductor/kernel/flex/templates/*.jinja",
"_inductor/kernel/templates/*.jinja",
"_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake",

View File

@ -32,7 +32,6 @@ from torch.distributed.tensor._ops._einsum_strategy import (
)
from torch.distributed.tensor._ops.utils import (
register_op_strategy,
register_single_dim_strategy,
replicate_op_strategy,
)
from torch.distributed.tensor.debug import CommDebugMode
@ -656,202 +655,5 @@ TestStrategyHashingWithLocalTensor = create_local_tensor_test_class(
TestStrategyHashing,
)
class TestSingleDimStrategy(DTensorTestBase):
@with_comms
def test_register_single_dim_strategy_replaces_existing_rule(self):
"""
Test that calling register_single_dim_strategy works and replaces an existing registered rule.
"""
from torch.distributed.tensor._ops._matrix_ops import (
_mm_like_strategy,
gen_single_dim_einsum_strategies,
)
mesh = self.build_device_mesh()
# Create test inputs
lhs_tensor = torch.randn(6, 8)
rhs_tensor = torch.randn(8, 12)
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
# Test a specific input sharding combination
lhs_placement = (Shard(1),)
rhs_placement = (Shard(0),)
lhs_spec = DTensorSpec(mesh, lhs_placement, lhs_tensor_meta)
rhs_spec = DTensorSpec(mesh, rhs_placement, rhs_tensor_meta)
# Create the OpSchema for mm operation
op_schema = OpSchema(
torch.ops.aten.mm.default,
(
OpStrategy([OpSpec(lhs_spec)]),
OpStrategy([OpSpec(rhs_spec)]),
),
{},
)
# Get the strategies from the old mm_like_strategy (what was used before)
old_style_strategy = _mm_like_strategy("mk,kn->mn", mesh, op_schema)
# Get the strategies from the new register_single_dim_strategy approach
# First, we need to get the single dim strategy function
def mm_single_dim_strategy_func(op_schema: OpSchema):
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Now expand it to full strategy using the same logic as register_single_dim_strategy
single_dim_strategies = mm_single_dim_strategy_func(op_schema)
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
all_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
)
new_style_strategy = OpStrategy(all_strategies)
# Verify that both strategies produce the same set of shardings
old_strategy_set = {str(strategy) for strategy in old_style_strategy.strategies}
new_strategy_set = {str(strategy) for strategy in new_style_strategy.strategies}
self.assertEqual(
old_strategy_set,
new_strategy_set,
"Old and new strategies should produce the same shardings",
)
# Verify that the registration actually works by checking the propagator
propagator = DTensor._op_dispatcher.sharding_propagator
# Save the original strategy if it exists
original_strategy = None
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
original_strategy = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
try:
# Register a custom single-dim strategy
@register_single_dim_strategy(torch.ops.aten.mm.default)
def custom_mm_single_dim_strategy(op_schema: OpSchema):
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Verify the strategy was registered
self.assertIn(
torch.ops.aten.mm.default,
propagator.op_strategy_funcs,
"Strategy should be registered after calling register_single_dim_strategy",
)
# Verify it replaced any existing rule
registered_func = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
self.assertIsNotNone(
registered_func, "Registered strategy function should not be None"
)
# Test that the registered strategy produces valid output
result_strategy = registered_func(op_schema)
self.assertIsInstance(
result_strategy, OpStrategy, "Result should be an OpStrategy"
)
self.assertGreater(
len(result_strategy.strategies),
0,
"Strategy should contain at least one OpSpec",
)
finally:
# Restore original strategy if it existed
if original_strategy is not None:
propagator.op_strategy_funcs[torch.ops.aten.mm.default] = (
original_strategy
)
else:
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
del propagator.op_strategy_funcs[torch.ops.aten.mm.default]
# Clear the cache
propagator.propagate_op_sharding.cache.cache_clear()
@with_comms
def test_single_dim_strategy_shardings_match_full_strategy(self):
"""
Verify that the shardings produced by a single-dim strategy match those produced
by the full strategy implementation.
"""
from torch.distributed.tensor._ops._matrix_ops import (
gen_single_dim_einsum_strategies,
)
mesh = self.build_device_mesh()
# Create test inputs
lhs_tensor = torch.randn(6, 8)
rhs_tensor = torch.randn(8, 12)
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
# Test multiple input sharding combinations
mm_combs = (
(Shard(0), Replicate()),
(Replicate(), Shard(1)),
(Shard(1), Shard(0)),
(Replicate(), Replicate()),
)
for lhs_placement, rhs_placement in mm_combs:
lhs_spec = DTensorSpec(mesh, (lhs_placement,), lhs_tensor_meta)
rhs_spec = DTensorSpec(mesh, (rhs_placement,), rhs_tensor_meta)
op_schema = OpSchema(
torch.ops.aten.mm.default,
(
OpStrategy([OpSpec(lhs_spec)]),
OpStrategy([OpSpec(rhs_spec)]),
),
{},
)
# Get single-dim strategies
single_dim_strategies = gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Expand to full strategy (mimicking what register_single_dim_strategy does)
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
expanded_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
expanded_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
)
# Verify that for the given input shardings, we can find a matching strategy
# with zero redistribute cost
found_zero_cost_strategy = False
for strategy in expanded_strategies:
if strategy.input_specs == (lhs_spec, rhs_spec):
# This strategy should have zero redistribute cost since inputs match
found_zero_cost_strategy = True
# In a real strategy, redistribute costs would be computed
# Here we just verify the structure is correct
self.assertEqual(
len(strategy.input_specs),
2,
"MM should have exactly 2 input specs",
)
self.assertIsNotNone(
strategy.output_specs, "Output spec should not be None"
)
break
self.assertTrue(
found_zero_cost_strategy,
f"Should find a strategy matching input shardings {lhs_placement}, {rhs_placement}",
)
if __name__ == "__main__":
run_tests()

View File

@ -167,14 +167,6 @@ def _pack_fp8_wrap(x):
if not x.dtype.is_floating_point:
return x
if type(x) is not torch.Tensor:
# Check only during compilation
# Test calls hooks to get reference output
ctx = torch._functorch._aot_autograd.graph_compile._get_saved_tensor_hook_context()
assert ctx["_fw_graph"] is not None
assert ctx["_bw_graph"] is not None
assert ctx["_node"] is not None
return (x.dtype, x.to(torch.float8_e5m2))
@ -184,13 +176,6 @@ def _unpack_fp8_wrap(x):
return x
dtype, tensor = x
if type(tensor) is not torch.Tensor:
# Check only during compilation
# Test calls hooks to get reference output
ctx = torch._functorch._aot_autograd.graph_compile._get_saved_tensor_hook_context()
assert ctx["_fw_graph"] is not None
assert ctx["_bw_graph"] is not None
assert ctx["_node"] is not None
return tensor.to(dtype)
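For context, a minimal sketch of how pack/unpack hooks like these are installed around a compiled region (standard torch.autograd API; the model and input are placeholders):

import torch

def train_step(model, x):
    # Activations saved for backward are quantized to fp8 on pack and restored on unpack.
    with torch.autograd.graph.saved_tensors_hooks(_pack_fp8_wrap, _unpack_fp8_wrap):
        out = torch.compile(model)(x)
        out.sum().backward()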

View File

@ -0,0 +1,154 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
@unittest.skipIf(
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
"CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
def _get_inputs(
self,
group_size: int,
M_hint: int,
K: int,
N: int,
device: str,
dtype: torch.dtype,
alignment: int = 16,
) -> tuple[Tensor, Tensor, Tensor]:
# --- Random, tile-aligned M sizes ---
M_sizes = (
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
* alignment
)
M_total = torch.sum(M_sizes).item()
# --- Construct input tensors ---
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
# --- Build offsets (no leading zero, strictly increasing) ---
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
return (A, B, offsets)
@parametrize("group_size", (2, 8))
@parametrize("M_hint", (256, 1024))
@parametrize("K", (64, 128))
@parametrize("N", (128, 256))
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
device = "cuda"
dtype = torch.bfloat16
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# Eager execution
c_eager = grouped_gemm_fn(A, B, offsets)
# Test with Cute backend
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
@parametrize("layout_B", ("contiguous", "broadcasted"))
def test_grouped_gemm_assorted_layouts(
self,
layout_A: str,
layout_B: str,
):
device = "cuda"
dtype = torch.bfloat16
G, K, N = 8, 64, 128
M_sizes = [128] * G
sum_M = sum(M_sizes)
offsets = torch.tensor(
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
)
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
A = A_base
if layout_A == "offset":
# allocate bigger buffer than needed, use nonzero storage offset
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
offset = 128 # skip first 128 elements
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
elif layout_A == "padded":
# simulate row pitch > K (row_stride = K + pad)
row_pitch = K + 8
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
elif layout_A == "view":
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
A = A_storage.view(sum_M, K)
assert A._base is not None
assert A.shape == (sum_M, K)
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
if layout_B == "broadcasted":
# Broadcast B across groups (zero stride along G)
B = B[0].expand(G, K, N)
assert B.stride(0) == 0
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# --- eager ---
c_eager = grouped_gemm_fn(A, B, offsets)
# --- compiled (CUTE backend) ---
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
if __name__ == "__main__":
run_tests()

View File

@ -117,22 +117,6 @@ class MixOrderReductionTest(TestBase):
metrics.codegen_mix_order_reduction,
)
@inductor_config.patch(coordinate_descent_tuning=True)
def test_XBLOCK_coordest_tuning(self):
"""
We should skip XBLOCK coordinate descent tuning for
mix order reduction.
"""
if not inductor_config.triton.mix_order_reduction:
self.skipTest("Mix order reduction not enabled")
def f(x):
return x.sum(dim=-1), x.sum(dim=0)
x = torch.randn(32768, 256, dtype=torch.float, device=GPU_TYPE)
self.check_numeric(f, (x,))
self.assertEqual(metrics.codegen_mix_order_reduction, 1)
@inductor_config.patch(unroll_reductions_threshold=1)
def test_3layer_split_reduction(self):
"""

View File

@ -25,9 +25,6 @@ from typing import Any, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
from collections.abc import Sequence
import threading
from contextlib import contextmanager
import torch
import torch.utils._pytree as pytree
import torch.utils.dlpack
@ -100,43 +97,6 @@ from .utils import (
)
_thread_local = threading.local()
# Saved tensor hooks context
# Compiled saved tensor hooks are convenient way to inline some logic in the graphs
# for saved nodes from forward to backward. (E.g. activations quantization)
# In base implementation user does not have any additional information about saved value
# in the hook, except FakeTensor shape, dtype, device etc.
# _get_saved_tensor_hook_context gives additional graph information about that saved value,
# that can be used to make a decisions which pack/unpack to apply for particular saved value.
# This allows user to reuse saved tensors hooks api to apply selective pack/unpack in
# graph aware way.
# Alternative to this will be making user to write a custom pass that mucks with forward outputs,
# backward input metadata, which requires significantly more effort.
#
# As for now in context we expose forward graph, backward graph and current saved node,
# which contains node.meta with additional information about that fx.Node.
# Warning: This API may change without backward compatibility.
@contextmanager
def _saved_tensor_hook_context(state: dict[str, Any]):
previous_state = getattr(_thread_local, "state", None)
try:
_thread_local.state = state
yield
finally:
# Clean up: restore previous state or remove attribute
if previous_state is not None:
_thread_local.state = previous_state
else:
if hasattr(_thread_local, "state"):
delattr(_thread_local, "state")
def _get_saved_tensor_hook_context() -> dict[str, Any] | None:
return getattr(_thread_local, "state", None)
zip = strict_zip
log = logging.getLogger(__name__)
@ -1137,10 +1097,6 @@ def maybe_inline_graph_saved_tensors_hooks(
if not isinstance(val, torch.Tensor):
continue
def _get_extra_info() -> dict[str, Any]:
return {"_fw_graph": fw_g, "_bw_graph": bw_g, "_node": saved}
with _saved_tensor_hook_context(_get_extra_info()):
pack_out_val = pack_hook_gm(val)
requires_sc_handling = any(
@ -1153,7 +1109,6 @@ def maybe_inline_graph_saved_tensors_hooks(
" in the pack hook, and reconstructing the subclass in the unpack hook"
)
with _saved_tensor_hook_context(_get_extra_info()):
pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,))
pack_g = pack_gm.graph
maybe_log_graph(
@ -1233,7 +1188,6 @@ def maybe_inline_graph_saved_tensors_hooks(
# Install unpack hook graph as a prologue of backward graph
# Saved tensors inputs are replaced with packed tensors and packed sym scalars.
# The saved tensors inputs usages in the graph are replaced with unpack hook graph outputs.
with _saved_tensor_hook_context(_get_extra_info()):
unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,))
unpack_g = unpack_gm.graph
maybe_log_graph(

View File

@ -498,7 +498,6 @@ def generate_ttir(
# pyrefly: ignore # missing-attribute
codegen_fns = backend.get_codegen_implementation(*codegen_args)
module_map = backend.get_module_map()
# pyrefly: ignore[missing-argument,bad-argument-type]
ttir_module = src.make_ir(options, codegen_fns, module_map, context)
else:
codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []

View File

@ -98,7 +98,7 @@ def _default_custom_combo_kernel_horizontal_partition(
]
short_reduction = [n for n in reduction if n not in long_reduction]
if long_reduction:
log.warning(
log.debug(
"ComboKernels: %d long reduction nodes are separated",
len(long_reduction),
)
@ -112,7 +112,7 @@ def _default_custom_combo_kernel_horizontal_partition(
]
if large_pointwise:
# TODO benchmark the performance when large pointwise nodes combining with others
log.warning(
log.debug(
"ComboKernels: %d large pointwise nodes are separated",
len(large_pointwise),
)
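With these messages demoted to debug level, they can still be surfaced when needed; one possible way, using the standard logging API:

import logging

import torch

# Raise Inductor logging to DEBUG so the ComboKernels partitioning messages show up again.
torch._logging.set_logs(inductor=logging.DEBUG)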

View File

@ -546,6 +546,10 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]
cutedsl_enable_autotuning: bool = (
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)
# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False
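A sketch of opting into the experimental CuTeDSL autotuning path (the environment variable must be set before the Inductor config module is first imported, since the flag is read once at import time):

import os

os.environ["CUTEDSL_ENABLE_AUTOTUNING"] = "1"  # read once when the config module loads

import torch._inductor.config as inductor_config

assert inductor_config.cutedsl_enable_autotuning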

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
import logging
from collections.abc import Sequence
from functools import partial
from pathlib import Path
from typing import Any
import torch
@ -12,6 +14,7 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
from .. import config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import _IntLike, Layout, TensorBox
from ..utils import load_template
log = logging.getLogger(__name__)
@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
return False
return True
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)
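Usage sketch: the partial binds the template directory, so kernel files can load Jinja sources by name, as mm_grouped.py does further down in this diff:

# Loads the grouped-GEMM Jinja source from _KERNEL_TEMPLATE_DIR.
cutedsl_source = load_kernel_template("cutedsl_mm_grouped")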

View File

@ -1,10 +1,11 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Any, Optional
import torch
from torch._dynamo.utils import counters
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
from torch._inductor.runtime.triton_compat import tl
from torch._inductor.virtualized import V
from torch.utils._triton import has_triton
@ -18,19 +19,25 @@ from ..select_algorithm import (
TritonTemplate,
)
from ..utils import (
ensure_cute_available,
get_gpu_shared_memory,
get_num_sms,
has_free_symbols,
use_aten_gemm_kernels,
use_blackwell_cutedsl_grouped_mm,
use_triton_template,
)
from .mm_common import (
_is_static_problem,
check_supported_striding,
load_kernel_template,
persistent_grouped_mm_grid,
)
if ensure_cute_available():
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
log = logging.getLogger(__name__)
aten = torch.ops.aten
@ -513,6 +520,11 @@ triton_scaled_grouped_mm_template = TritonTemplate(
source=triton_grouped_mm_source,
)
cutedsl_grouped_mm_template = CuteDSLTemplate(
name="grouped_gemm_cutedsl",
source=load_kernel_template("cutedsl_mm_grouped"),
)
def grouped_mm_args(
mat1: TensorBox,
@ -714,12 +726,6 @@ def _tuned_grouped_mm_common(
# Checking only for the equality of corresponding dims of
# multiplicands here, relying on meta function checks for
# everything else.
if (
is_nonzero
and use_triton_template(layout)
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
):
scaled = scale_a is not None
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
@ -752,6 +758,13 @@ def _tuned_grouped_mm_common(
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
if (
is_nonzero
and use_triton_template(layout)
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
):
scaled = scale_a is not None
a_is_k_major = mat_a.get_stride()[-1] == 1
b_is_k_major = mat_b.get_stride()[-2] == 1
@ -788,6 +801,22 @@ def _tuned_grouped_mm_common(
**config.kwargs,
)
if use_blackwell_cutedsl_grouped_mm(
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
):
for config in get_groupgemm_configs():
kwargs = dict(
ACC_DTYPE="cutlass.Float32",
)
cutedsl_grouped_mm_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
**kwargs,
**asdict(config),
)
input_gen_fns = {
4: lambda x: create_offsets(
x, m1_size, m2_size, offs.get_size() if offs is not None else None

View File

@ -0,0 +1,333 @@
import functools
from torch._inductor.runtime.runtime_utils import ceildiv
from cutlass.utils import TensorMapUpdateMode
{{gen_defines()}}
# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
GroupedGemmKernel,
)
# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
# Caches the compiled executor for build_group_ptrs_from_bases(). This
# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
# and can therefore be safely reused across runs with different group
# partitioning (`offs`).
#
# 2. gemm_cache
# Caches the compiled Grouped GEMM executor. Its key extends the prep
# cache key with hardware- and grid-specific parameters:
# (prep_cache_key, max_active_clusters, total_num_clusters).
# This is necessary because different `offs` tensors can change the
# per-group problem sizes and thus alter `total_num_clusters`, which in
# turn changes the grid shape and persistent scheduler configuration.
# Kernels compiled for one grid cannot be safely reused for another.
#
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.
prep_cache = {}
gemm_cache = {}
@functools.lru_cache
def get_hardware_info():
hw = cutlass.utils.HardwareInfo()
sm_count = hw.get_max_active_clusters(1)
max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)
return (sm_count, max_active_clusters)
def get_prep_cache_key(input_a, input_b, output):
"""
Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
shapes, strides, and dtypes of input/output tensors.
"""
return (
tuple(input_a.shape),
tuple(input_a.stride()),
input_a.dtype,
tuple(input_b.shape),
tuple(input_b.stride()),
input_b.dtype,
tuple(output.shape),
tuple(output.stride()),
output.dtype,
)
def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
"""
Returns a tuple key for caching the gemm kernel executor by extending the
prep cache key with hardware- and grid-specific parameters.
"""
return (
prep_cache_key,
max_active_clusters,
total_num_clusters,
)
@cute.kernel
def build_group_ptrs_from_bases_kernel(
base_A_u64: cutlass.Int64, # device addr of input_a (bytes)
base_B_u64: cutlass.Int64, # device addr of input_b (bytes)
base_C_u64: cutlass.Int64, # device addr of Output (bytes)
offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Int32, # bytes
# -------- STRIDES (in ELEMENTS) --------
stride_A_m_elems: cutlass.Constexpr, # A.stride(0)
stride_A_k_elems: cutlass.Constexpr, # A.stride(1)
stride_B0_elems: cutlass.Constexpr, # B.stride(0)
stride_Bk_elems: cutlass.Constexpr, # B.stride(1)
stride_Bn_elems: cutlass.Constexpr, # B.stride(2)
stride_C_m_elems: cutlass.Constexpr, # C.stride(0)
stride_C_n_elems: cutlass.Constexpr, # C.stride(1)
# -------- OUTPUTS --------
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1)
out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
):
tidx, _, _ = cute.arch.thread_idx()
g = tidx
m_beg_i32 = 0
if g > 0:
m_beg_i32 = offs[g - 1]
m_end_i32 = offs[g]
m_g_i32 = m_end_i32 - m_beg_i32
a_byte_off = (
cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
)
c_byte_off = (
cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
)
b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)
# ---- pointers ----
out_ptrs[g, 0] = base_A_u64 + a_byte_off
out_ptrs[g, 1] = base_B_u64 + b_byte_off
out_ptrs[g, 2] = base_C_u64 + c_byte_off
# ---- (m, n, k, 1) ----
out_problem[g, 0] = m_g_i32
out_problem[g, 1] = N
out_problem[g, 2] = K
out_problem[g, 3] = cutlass.Int32(1)
# ---- strides ----
out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
@cute.jit
def launch_build_group_ptrs_from_bases(
base_A_u64: cutlass.Int64,
base_B_u64: cutlass.Int64,
base_C_u64: cutlass.Int64,
offs: cute.Tensor,
G: cutlass.Constexpr,
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Constexpr,
stride_A_m_elems: cutlass.Constexpr,
stride_A_k_elems: cutlass.Constexpr,
stride_B0_elems: cutlass.Constexpr,
stride_Bk_elems: cutlass.Constexpr,
stride_Bn_elems: cutlass.Constexpr,
stride_C_m_elems: cutlass.Constexpr,
stride_C_n_elems: cutlass.Constexpr,
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64
out_problem: cute.Tensor, # [G,4] cutlass.Int32
out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32
stream: cuda.CUstream,
):
build_group_ptrs_from_bases_kernel(
base_A_u64,
base_B_u64,
base_C_u64,
offs,
K,
N,
sizeof_element,
stride_A_m_elems,
stride_A_k_elems,
stride_B0_elems,
stride_Bk_elems,
stride_Bn_elems,
stride_C_m_elems,
stride_C_n_elems,
out_ptrs,
out_problem,
out_strides_abc,
).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)
{{def_kernel("input_a", "input_b", "input_a_offs")}}
stream = cuda.CUstream(stream)
input_b = input_b.transpose(1, 2)
sumM, K = input_a.shape
G, N, Kb = input_b.shape
dev = input_a.device
base_A_u64 = int(input_a.data_ptr())
base_B_u64 = int(input_b.data_ptr())
base_C_u64 = int({{get_output()}}.data_ptr())
ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
ptrs = from_dlpack(ptrs_t)
probs = from_dlpack(probs_t)
strides = from_dlpack(strides_t)
prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
prep_executor = prep_cache.get(prep_cache_key)
if prep_executor is None:
sizeof_element = int(input_a.element_size())
sA_m, sA_k = map(int, input_a.stride())
sB_0, sB_n, sB_k = map(int, input_b.stride())
sC_m, sC_n = map(int, {{get_output()}}.stride())
prep_executor = cute.compile(
launch_build_group_ptrs_from_bases,
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
G=int(G),
K=int(K),
N=int(N),
sizeof_element=sizeof_element,
stride_A_m_elems=sA_m,
stride_A_k_elems=sA_k,
stride_B0_elems=sB_0,
stride_Bk_elems=sB_k,
stride_Bn_elems=sB_n,
stride_C_m_elems=sC_m,
stride_C_n_elems=sC_n,
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
prep_cache[prep_cache_key] = prep_executor
prep_executor(
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
# --- Tensormap workspace per SM ---
num_tensormap_buffers, max_active_clusters = get_hardware_info()
tensormap_shape = (
num_tensormap_buffers,
GroupedGemmKernel.num_tensormaps,
GroupedGemmKernel.bytes_per_tensormap // 8,
)
tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
tensormap_workspace = from_dlpack(tensormap_workspace_t)
# --- Total clusters ---
def compute_total_num_clusters(
problem_sizes_mnkl,
cluster_tile_shape_mn,
):
total_num_clusters = 0
for m, n, _, _ in problem_sizes_mnkl:
num_clusters_mn = tuple(
ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
)
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
return total_num_clusters
# Compute cluster tile shape
def compute_cluster_tile_shape(
mma_tiler_mn,
cluster_shape_mn,
use_2cta_instrs,
):
cta_tile_shape_mn = list(mma_tiler_mn)
if use_2cta_instrs:
cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
cluster_tile_shape_mn = compute_cluster_tile_shape(
(TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
)
total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))
gemm_cache_key = get_gemm_cache_key(
prep_cache_key, max_active_clusters, total_num_clusters
)
gemm_executor = gemm_cache.get(gemm_cache_key)
if gemm_executor is None:
grouped_gemm = GroupedGemmKernel(
acc_dtype=ACC_DTYPE,
use_2cta_instrs=USE_2_CTA,
mma_tiler_mn=(TILE_M, TILE_N),
cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
tensormap_update_mode=TENSORMAP_UPDATE_MODE,
)
gemm_executor = cute.compile(
grouped_gemm,
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
G,
probs,
strides,
ptrs,
total_num_clusters,
tensormap_workspace,
max_active_clusters,
stream,
)
gemm_cache[gemm_cache_key] = gemm_executor
gemm_executor(
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
probs,
strides,
ptrs,
tensormap_workspace,
stream,
)

View File

@ -5,8 +5,6 @@ import logging
from collections.abc import Callable
from typing import TYPE_CHECKING
from torch.utils._ordered_set import OrderedSet
from .hints import TRITON_MAX_BLOCK
from .runtime_utils import red_text, triton_config_to_hashable
@ -56,7 +54,6 @@ class CoordescTuner:
name="unknown",
size_hints=None,
inductor_meta=None,
frozen_fields=None,
):
self.is_mm = is_mm # we will tune num_stages for mm
@ -69,9 +66,6 @@ class CoordescTuner:
self.name = name
self.size_hints = size_hints
self.inductor_meta = inductor_meta or {}
self.frozen_fields: OrderedSet[str] = (
OrderedSet(frozen_fields) if frozen_fields is not None else OrderedSet()
)
def get_config_max(self, prefix: str) -> int:
max_block = TRITON_MAX_BLOCK[prefix.upper()]
@ -123,7 +117,7 @@ class CoordescTuner:
out.append("num_stages")
out.remove("ZBLOCK") # ZBLOCK=1 always in native matmul
return [f for f in out if f not in self.frozen_fields]
return out
def value_too_large(self, name: str, val: int) -> bool:
block_suffix = "BLOCK"

View File

@ -336,7 +336,6 @@ class CachingAutotuner(KernelInterface):
name=self.fn.__name__,
size_hints=size_hints,
inductor_meta=self.inductor_meta,
frozen_fields=self.get_coordesc_frozen_fields(),
)
self.filename = filename
@ -366,13 +365,6 @@ class CachingAutotuner(KernelInterface):
# Mode for launch grid calculation
self.grid_mode: Literal["python", "cpp"] = "python"
def get_coordesc_frozen_fields(self) -> OrderedSet[str]:
out: OrderedSet[str] = OrderedSet()
if self.inductor_meta.get("RSPLIT_SIZE"):
# We fix XBLOCK for mix order reduction
out.add("XBLOCK")
return out
def is_statically_launchable(self):
"""
Checks if every compiled kernel is statically launchable, which

View File

@ -0,0 +1,141 @@
from dataclasses import dataclass
from enum import auto, Enum
from itertools import product
import torch._inductor.config as config
class TensorMapUpdateMode(Enum):
"""Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""
SMEM = auto()
GMEM = auto()
@dataclass(frozen=True)
class CuTeGemmConfig:
TILE_M: int = 128
TILE_N: int = 192
CLUSTER_M: int = 2
CLUSTER_N: int = 1
USE_2_CTA: bool = False
TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM
def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
For information regarding valid config sets, see:
https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
"""
# Tile_n is always the same regardless of 2cta
tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]
# Valid clusters
clusters_no_2cta = [
(1, 1),
(1, 2),
(1, 4),
(1, 8),
(1, 16),
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
clusters_2cta = [
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
configs: list[CuTeGemmConfig] = []
for use_2cta, cluster_set, tile_m_range in [
(False, clusters_no_2cta, [64, 128]),
(True, clusters_2cta, [128, 256]),
]:
for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
[TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
tile_m_range,
tile_n_vals,
cluster_set,
):
configs.append(
CuTeGemmConfig(
tile_m,
tile_n,
cluster_m,
cluster_n,
USE_2_CTA=use_2cta,
TENSORMAP_UPDATE_MODE=tensormap_update_mode,
)
)
return configs
def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
"""
config_tuples = [
(128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
(256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
(128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
(256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
(128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
(64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
(128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
]
return [CuTeGemmConfig(*args) for args in config_tuples]
def get_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
or unstable results. By default, autotuning is disabled and we return only
a single baseline config.
"""
if (
config.cutedsl_enable_autotuning
and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
):
return get_exhaustive_groupgemm_configs()
elif config.cutedsl_enable_autotuning:
return get_default_groupgemm_configs()
else:
return [get_default_groupgemm_configs()[0]]
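A small sketch of how the search space resolves under the flags above, assuming default settings (import path as used by mm_grouped.py earlier in this diff):

from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs

# Autotuning disabled (the default): only the single baseline config is returned.
assert len(get_groupgemm_configs()) == 1

# With CUTEDSL_ENABLE_AUTOTUNING=1 the default list is used; additionally setting
# max_autotune_gemm_search_space="EXHAUSTIVE" expands to the full sweep defined above.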

View File

@ -1975,6 +1975,77 @@ def use_triton_blackwell_tma_template(
return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
"""Check if CuTeDSL is importable; cache the result for reuse.
Call ensure_cute_available.cache_clear() after installing CuTeDSL
in the same interpreter to retry the import.
"""
try:
return importlib.util.find_spec("cutlass.cute") is not None
except ImportError:
return False
def use_blackwell_cutedsl_grouped_mm(
mat_a: Any,
mat_b: Any,
layout: Layout,
a_is_2d: bool,
b_is_2d: bool,
offs: Optional[Any],
bias: Optional[Any],
scale_result: Optional[Any],
) -> bool:
"""
Returns True if we can use the blackwell kernel for grouped mm.
Required conditions:
1. CuTeDSL is available
2. We are on a blackwell arch
3. The dtype is bf16
4. Max autotune or max autotune gemm is enabled
5. A, B, and the output are 16B aligned
6. We are not using dynamic shapes
7. A is 2d
8. B is 3d
9. Offsets are provided
10. Bias and Scale are not provided
"""
if not ensure_cute_available():
return False
from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
if not is_gpu(layout.device.type) and is_datacenter_blackwell_arch():
return False
layout_dtypes = [torch.bfloat16]
if not _use_template_for_gpu(layout, layout_dtypes):
return False
if not (config.max_autotune or config.max_autotune_gemm):
return False
# Checks for 16B ptr and stride alignment
if not can_use_tma(mat_a, mat_b, output_layout=layout):
return False
if any(is_dynamic(x) for x in [mat_a, mat_b]):
return False
if not a_is_2d or b_is_2d:
return False
if offs is None:
return False
if bias is not None or scale_result is not None:
return False
return True
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
from .virtualized import V

View File

@ -1228,7 +1228,7 @@ def _get_pynvml_handler(device: "Device" = None):
"nvidia-ml-py does not seem to be installed or it can't be imported."
# pyrefly: ignore [invalid-inheritance]
) from _PYNVML_ERR
# pyrefly: ignore [import-error,missing-module-attribute]
# pyrefly: ignore [import-error]
from pynvml import NVMLError_DriverNotLoaded
try:

View File

@ -828,7 +828,7 @@ def list_gpu_processes(device: "Device" = None) -> str:
import pynvml # type: ignore[import]
except ModuleNotFoundError:
return "pynvml module not found, please install nvidia-ml-py"
# pyrefly: ignore [import-error,missing-module-attribute]
# pyrefly: ignore [import-error]
from pynvml import NVMLError_DriverNotLoaded
try:

View File

@ -23,7 +23,6 @@ from torch.distributed.tensor._ops.utils import (
map_placements_after_broadcast,
prod,
register_op_strategy,
register_single_dim_strategy,
)
from torch.distributed.tensor._utils import (
compute_local_shape_and_global_offset,
@ -238,130 +237,10 @@ def dot_strategy(op_schema: OpSchema) -> OpStrategy:
return _mm_like_strategy("i,i->", mesh, op_schema)
# @register_op_strategy(aten.mm.default)
# def mm_strategy(op_schema: OpSchema) -> OpStrategy:
# mesh = op_schema.get_mesh_from_args()
# return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
from ._einsum_strategy import EinsumDims
def gen_single_dim_einsum_strategies(
equation: str,
mesh: DeviceMesh,
*,
linearity: bool = False,
) -> list[Placement]:
"""
Generate a strategy list for the ops that follow einsum style notation.
In principle, each mesh dim is independent of other device mesh dim when we
generate strategies. So we generate strategy over each device mesh dim and
do product combination on all mesh dims. We basically follow the below rule
for each device mesh dim:
1. Shard on contracting dim: When both inputs shard on contracting dim over
the same device dim. The result will be Partial over that device dim.
2. Shard on noncontracting dim:
2.1: Shard on batch dim: output, both inputs all should shard on batch
dim.
2.2: Shard on lhs only dim or rhs only dim: both output and lhs or rhs
input should shard on this free dim.
3. Linearity (Partial): If enabled, set Partial on output and inputs over
the same device mesh dim.
"""
# parse einop equation and extract dims
input_dims, output_dim = EinsumDims.parse_equation(equation)
edims = EinsumDims.parse_dims(input_dims, output_dim)
all_mesh_dim_strategies = []
# generate strategies for each mesh dim and do cartesian product for final strategy. E.g., for a 2D mesh, we can have [P(),R,R]
strategies_over_one_mesh_dim = []
# placement list stores placements of [output, input1, input2, ...]
# first we always have replicate all for inputs and output
placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
strategies_over_one_mesh_dim.append(placement_list)
# split batch dim
for batch_dim in edims.batch_dims:
output_batch_dim = output_dim.index(batch_dim)
placement_list = [Shard(output_batch_dim)]
for input_dim in input_dims:
input_batch_dim = input_dim.index(batch_dim)
placement_list.append(Shard(input_batch_dim))
strategies_over_one_mesh_dim.append(placement_list)
# split contracting dim
for contracting_dim in edims.contracting_dims:
# Contracting dim can shard on same device axis for both inputs. This
# results in the output being Partial on that device axis. For example:
# bmk_{x},k_{x}n -> bmn{Ux} (becomes partial over device axis x)
placement_list = [Partial()]
for input_dim in input_dims:
input_contracting_dim = input_dim.index(contracting_dim)
placement_list.append(Shard(input_contracting_dim))
strategies_over_one_mesh_dim.append(placement_list)
# split lhs free dim
for lhs_dim in edims.lhs_out_only_dims:
lhs_free_dim_output = output_dim.index(lhs_dim)
lhs_free_dim_input = input_dims[0].index(lhs_dim)
# this means split the lhs input and output
# i.e. S(0), R -> S(0)
lhs_placement_list: list[Placement] = [
Shard(lhs_free_dim_output),
Shard(lhs_free_dim_input),
Replicate(),
]
strategies_over_one_mesh_dim.append(lhs_placement_list)
# split rhs free dim
for rhs_dim in edims.rhs_out_only_dims:
rhs_free_dim_output = output_dim.index(rhs_dim)
rhs_free_dim_input = input_dims[1].index(rhs_dim)
rhs_placement_list: list[Placement] = [
Shard(rhs_free_dim_output),
Replicate(),
Shard(rhs_free_dim_input),
]
strategies_over_one_mesh_dim.append(rhs_placement_list)
# linearity strategy
if linearity:
linearity_placement_list: list[Placement] = [Partial()]
for _ in input_dims:
linearity_placement_list.append(Partial())
strategies_over_one_mesh_dim.append(linearity_placement_list)
# generate strategies for entire mesh
# all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
# strategy_combs = itertools.product(*all_mesh_dim_strategies)
# all_strategies = []
# for strategy_comb in strategy_combs:
# spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)]
# strat = OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
# all_strategies.append(strat)
# return OpStrategy(all_strategies)
return strategies_over_one_mesh_dim
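For illustration, a sketch of the single-mesh-dim placements this returns for the mm equation, derived from the rules in the docstring above (each list is ordered [output, lhs, rhs]):

# gen_single_dim_einsum_strategies("mk,kn->mn", mesh) yields, per mesh dim:
#   [Replicate(), Replicate(), Replicate()]   # replicate everything
#   [Partial(),   Shard(1),    Shard(0)]      # shard contracting dim k -> Partial output
#   [Shard(0),    Shard(0),    Replicate()]   # shard lhs-only free dim m
#   [Shard(1),    Replicate(), Shard(1)]      # shard rhs-only free dim n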
@register_single_dim_strategy(aten.mm.default)
def mm_single_dim_strategy(op_schema: OpSchema) -> list[Placement]:
self_strategy, mat2_strategy = op_schema.args_schema
if not isinstance(self_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
if not isinstance(mat2_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}")
# generate all possible strategies for mm
@register_op_strategy(aten.mm.default)
def mm_strategy(op_schema: OpSchema) -> OpStrategy:
mesh = op_schema.get_mesh_from_args()
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
@register_op_strategy(aten.addmm.default)

View File

@ -41,8 +41,6 @@ from torch.distributed.tensor.placement_types import (
aten = torch.ops.aten
# WHC- i think anywhere this is used, we can replace it with a corresponding single-dim passthrough strategy
# (anyshard, replicate, partial can all pass through- and then expand that to the mesh dims later)
def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType:
# For ops with a single tensor input, we perform a 1:1 mapping such that
# for each strategy that the input supports, we create a corresponding strategy.
@ -99,28 +97,6 @@ register_op_strategy(
)(propagate_single_input_strategy)
"""
WHC- equal_strategy is an example baking an optimization into the sharding rule.
The unoptimized equal strategy (for one mesh dim) should look like this
S, S -> S
R, R -> R
P, P -> P * - this could work, i think, if we supported a Partial of boolean and reduction?
And this should be expanded to the full mesh.
But what this rule actually does is
- compare the two tensor args to equal- look at the strategies for each, which represent the I-O sharding relationship for the
op that produced those tensor args. Pick the one that has the strategy (OpSpec) with the most Shard() placements in it.
Why? because converting the other arg from R->S is cheaper than converting S->R
- start with the assumption that the 'equal' op has the same strategy as the op that produced its max-shard input
- then adjust the placements from partial to replicate since we don't support partial in equal
- finally, produce an OpSpec that only populates the 'output_specs' of OpSpec
TODO: why is it ok to populate only the output_specs of an OpSpec? Is it defined to mean that all input specs are the same as the output spec?
"""
@register_op_strategy(
[
aten.equal.default,
@ -164,19 +140,6 @@ def equal_strategy(op_schema: OpSchema) -> StrategyType:
return equal_strategy
"""
WHC
seems like we could replace this with single-mesh strategy
S->S
R->R
P->R
The P->R thing is odd, but makes sense:
* can't support P->P since it would be incorrect to create a new 'partial' tensor from ones, which would no longer be ones if we replicated them
* don't want to omit the support for input Partial because we'd force a replication on the input which would be wasteful
"""
@register_op_strategy(
[
aten.empty_like.default,
@ -518,19 +481,6 @@ def replicate_tensor_dim(
)
"""
WHC- example of a complicated 'follow your inputs' strategy that would be useful to try out as a simple rule
seems very simple to write this way
assert input, src same ndim
for i in range(input.ndim):
if i != slice_dim:
Shard(i), Shard(i) -> Shard(i)
"""
@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2))
def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType:
# 1. number of dimensions in input and src need to match.

View File

@ -4,7 +4,8 @@ import functools
import itertools
import operator
from collections.abc import Callable, Iterable, Sequence
from typing import cast, Optional, Union
from typing import cast, Optional, TypeVar, Union
from typing_extensions import ParamSpec
import torch
from torch._prims_common import DimsSequenceType, DimsType
@ -29,6 +30,10 @@ from torch.distributed.tensor.placement_types import (
)
_T = TypeVar("_T")
_P = ParamSpec("_P")
# convenient wrapper to register sharding propagation rules
def register_prop_rule(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
@ -49,61 +54,11 @@ def register_prop_rule(
return wrapper
def register_single_dim_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[
[Callable[[OpSchema], list[Placement]]], Callable[[OpSchema], StrategyType]
]:
"""
Registers a simplified op strategy that only considers a single mesh dim, taking care to expand it
to cover all the mesh dims present in the runtime inputs.
"""
def expanded_registration_wrapper(
single_dim_strategy: Callable[[OpSchema], list[Placement]],
) -> Callable[[OpSchema], StrategyType]:
def _expanded_strategy(op_schema: OpSchema) -> StrategyType:
"""
Expands the single_mesh_dim impl across all mesh dims, and expands ShardingPlacholder into all
sharding types used by inputs.
"""
inputs_strategy = op_schema.args_strategy
mesh = inputs_strategy[0].mesh
strategies_over_one_mesh_dim = single_dim_strategy(op_schema)
# TODO: handle 'ShardingPlaceholder' expansion (doesn't exist yet)
# TODO: filter out 'invalid' placements
# - ShardVar needs to say whether 'even sharding' is required or not
# copied from einsum strategy..
# TODO: identify differences between this and 'expand_' util
all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
all_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
)
return OpStrategy(all_strategies)
# register_op_strategy returns another wrapper that actually does the strategy registration,
# we just add another layer of wrapping that expands the single_dim_strategy into a strategy that's
# compatible with register_op_strategy
register_op_strategy(op, schema_info)(_expanded_strategy)
return _expanded_strategy
return expanded_registration_wrapper
def register_op_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[[Callable[[OpSchema], StrategyType]], Callable[[OpSchema], StrategyType]]:
op, schema_info=None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
# pyre-fixme[2]: Parameter must be annotated.
# For every ATen op that accepts any args in this list,
# the arg itself can impact the strides (and potentially the sharding strategy)
# of the output tensor.
@ -113,9 +68,7 @@ def register_op_strategy(
"memory_format",
]
def wrapper(
impl: Callable[[OpSchema], StrategyType],
) -> Callable[[OpSchema], StrategyType]:
def wrapper(impl):
if isinstance(op, list):
overloads = op
else:
@ -206,10 +159,7 @@ def prod(xs: Iterable[int]) -> int:
def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
"""Check if the spec matches these criteria:
* any Shard placements in spec refer to valid tensor dims
* no empty local tensors (uneven sharding OK, as long as last rank has >0 size)
"""
"""Check if the shape is shardable according to the spec."""
# number of shards in each tensor dimension
shards_map = [1] * len(shape)
for i, placement in enumerate(spec.placements):

View File

@ -17,5 +17,230 @@ def is_stdlib_module(module: str) -> bool:
def _get_stdlib_modules():
assert sys.version_info >= (3, 10)
return sys.stdlib_module_names
if sys.version_info.major == 3: # noqa: UP036
if sys.version_info.minor == 9:
return stdlib3_9
if sys.version_info.minor >= 10: # noqa: YTT204
return sys.stdlib_module_names # type: ignore[attr-defined]
elif sys.version_info.major > 3: # noqa: UP036
return sys.stdlib_module_names # type: ignore[attr-defined]
raise RuntimeError(f"Unsupported Python version: {sys.version_info}")
stdlib3_9 = {
"_thread",
"abc",
"aifc",
"argparse",
"array",
"ast",
"asynchat",
"asyncio",
"asyncore",
"atexit",
"audioop",
"base64",
"bdb",
"binascii",
"binhex",
"bisect",
"builtins",
"bz2",
"cProfile",
"calendar",
"cgi",
"cgitb",
"chunk",
"cmath",
"cmd",
"code",
"codecs",
"codeop",
"collections",
"colorsys",
"compileall",
"concurrent",
"configparser",
"contextlib",
"contextvars",
"copy",
"copyreg",
"crypt",
"csv",
"ctypes",
"curses",
"dataclasses",
"datetime",
"dbm",
"decimal",
"difflib",
"dis",
"distutils",
"doctest",
"email",
"encodings",
"ensurepip",
"enum",
"errno",
"faulthandler",
"fcntl",
"filecmp",
"fileinput",
"fnmatch",
"formatter",
"fractions",
"ftplib",
"functools",
"gc",
"getopt",
"getpass",
"gettext",
"glob",
"graphlib",
"grp",
"gzip",
"hashlib",
"heapq",
"hmac",
"html",
"http",
"imaplib",
"imghdr",
"imp",
"importlib",
"inspect",
"io",
"ipaddress",
"itertools",
"json",
"keyword",
"lib2to3",
"linecache",
"locale",
"logging",
"lzma",
"mailbox",
"mailcap",
"marshal",
"math",
"mimetypes",
"mmap",
"modulefinder",
"msilib",
"msvcrt",
"multiprocessing",
"netrc",
"nis",
"nntplib",
"ntpath",
"numbers",
"operator",
"optparse",
"os",
"ossaudiodev",
"parser",
"pathlib",
"pdb",
"pickle",
"pickletools",
"pipes",
"pkgutil",
"platform",
"plistlib",
"poplib",
"posix",
"posixpath",
"pprint",
"profile",
"pstats",
"pty",
"pwd",
"py_compile",
"pyclbr",
"pydoc",
"queue",
"quopri",
"random",
"re",
"readline",
"reprlib",
"resource",
"rlcompleter",
"runpy",
"sched",
"secrets",
"select",
"selectors",
"shelve",
"shlex",
"shutil",
"signal",
"site",
"smtpd",
"smtplib",
"sndhdr",
"socket",
"socketserver",
"spwd",
"sqlite3",
"sre",
"sre_compile",
"sre_constants",
"sre_parse",
"ssl",
"stat",
"statistics",
"string",
"stringprep",
"struct",
"subprocess",
"sunau",
"symbol",
"symtable",
"sys",
"sysconfig",
"syslog",
"tabnanny",
"tarfile",
"telnetlib",
"tempfile",
"termios",
"test",
"textwrap",
"threading",
"time",
"timeit",
"tkinter",
"token",
"tokenize",
"trace",
"traceback",
"tracemalloc",
"tty",
"turtle",
"turtledemo",
"types",
"typing",
"unicodedata",
"unittest",
"urllib",
"uu",
"uuid",
"venv",
"warnings",
"wave",
"weakref",
"webbrowser",
"winreg",
"winsound",
"wsgiref",
"xdrlib",
"xml",
"xmlrpc",
"zipapp",
"zipfile",
"zipimport",
"zlib",
"zoneinfo",
}