Compare commits

...

20 Commits

Author SHA1 Message Date
0f2f1f7e3e Fix signals 2025-11-06 12:04:07 -08:00
ae8a9fa894 Type functions 2025-11-05 15:50:25 -08:00
6052a01b71 [BE][Typing][Dynamo] Type torch/_dynamo/variables/dicts.py (#167022)
Provides type coverage to torch/_dynamo/variables/dicts.py

Coverage report:
`mypy torch/_dynamo/variables/dicts.py --linecount-report /tmp/coverage_log`

Comparing before and after, we go from 0 lines and 0 funcs covered to 1547 lines and 89 funcs covered.
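
A hedged sketch of the kind of annotation work such a typing PR involves (the function below is hypothetical and not one of the actual dicts.py methods):

```py
# Hypothetical before/after: adding parameter and return annotations so mypy
# can count the function and its lines as covered.
from typing import Any, Optional


def lookup_const_key(d: dict[Any, Any], key: Any, default: Optional[Any] = None) -> Any:
    # before: def lookup_const_key(d, key, default=None)
    return d.get(key, default)


print(lookup_const_key({"a": 1}, "a"))
```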

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167022
Approved by: https://github.com/Skylion007
2025-11-05 19:18:35 +00:00
14b153bcf2 include DTensor metadata when pretty-printing fx.Graphs (#166750)
Example below. You need to trace your function with DTensor inputs in order for the graph proxies to run on DTensor (and not the inner local tensor). You also need to run with `tracing_mode="fake"`, or with your own `FakeTensorMode`, to see the nice DTensor printing. If this doesn't feel very ergonomic, then maybe we can find some better UX for printing a graph with DTensor in it:

[Screenshot: pretty-printed fx.Graph showing DTensor placement annotations]

```
import torch
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.distributed.tensor import distribute_tensor, Shard, Replicate
from torch.utils._debug_mode import DebugMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils import _pytree as pytree

world_size = 8
device_type = "cpu"
fake_store = FakeStore()
torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=world_size)
device_mesh = torch.distributed.init_device_mesh(device_type, (world_size,))
dim = 128

A = torch.randn(8, dim)
B = torch.randn(dim, dim)
dA = distribute_tensor(A, device_mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, device_mesh, [Replicate()]).requires_grad_()

def f(dA, dB):
    dy = dA @ dB
    loss = dy.sum()
    loss.backward()
    return dA.grad, dB.grad

# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode='fake')(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
gm.print_readable(colored=True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166750
Approved by: https://github.com/ezyang, https://github.com/wconstab, https://github.com/Skylion007
2025-11-05 18:58:54 +00:00
641de23c96 ci: Add aarch64 docker builds for modern clang (#166416)
This should enable us to build with some Arm optimizations that are only
available in the newest versions of clang.

Signed-off-by: Eli Uriegas <eliuriegas@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166416
Approved by: https://github.com/malfet
2025-11-05 18:55:56 +00:00
89165c0a2b Update triton to 3.5.1 release (#166968)
This includes sm103 https://github.com/triton-lang/triton/pull/8485 fix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166968
Approved by: https://github.com/Lucaskabela, https://github.com/njriasan
2025-11-05 18:26:34 +00:00
dcc2ba4ca4 Add some code for exploring the space of accessible size/stride configs via plain views (#167076)
We are working on a translation from as_strided to view operations, but
only when the as_strided is representable as a plain view.  A useful
testing utility in this situation is the ability to enumerate all valid
views on an original tensor.  So we have a small test here that shows
it is possible.

To avoid an explosion of states, we don't handle permutes and size=1,
which are degenerate cases (you can always do a single permute and
a series of unsqueezes to get to the final desired state).
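
A minimal sketch of what such an enumeration can look like (illustrative, not the test added in this PR): a search over plain view ops, here `narrow()` and single-dimension `unflatten()`, collecting the reachable (size, stride, storage_offset) configs.

```py
import torch


# Illustrative sketch, not the PR's test: enumerate (size, stride, storage_offset)
# configurations reachable from a contiguous base tensor using only plain view
# ops -- narrow() and splitting one dimension with unflatten(). Permutes and
# size-1 dims are skipped, matching the simplification described above.
def enumerate_view_configs(base: torch.Tensor, max_states: int = 200) -> set:
    seen: set = set()
    frontier = [base]
    while frontier and len(seen) < max_states:
        t = frontier.pop()
        key = (tuple(t.size()), tuple(t.stride()), t.storage_offset())
        if key in seen:
            continue
        seen.add(key)
        for d in range(t.dim()):
            n = t.size(d)
            # every sub-range of length >= 2 along dim d (skips size-1 results)
            for start in range(n - 1):
                for length in range(2, n - start + 1):
                    if length < n or start > 0:  # skip the identity narrow
                        frontier.append(t.narrow(d, start, length))
            # split dim d of size n into (n // k, k); always expressible as a view
            for k in range(2, n):
                if n % k == 0:
                    frontier.append(t.unflatten(d, (n // k, k)))
    return seen


print(len(enumerate_view_configs(torch.arange(16.0).reshape(4, 4))))
```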

Authored with claude code assistance.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167076
Approved by: https://github.com/albanD
ghstack dependencies: #166868, #166867
2025-11-05 18:25:19 +00:00
ad5c7c20e0 Revert "[cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922)"
This reverts commit 1d3f5e19da068ec1340db041b7105b287a513578.

Reverted https://github.com/pytorch/pytorch/pull/165922 on behalf of https://github.com/atalman due to Introduces Segfault in linux-jammy-cuda12.8-py3.10-gcc11 ([comment](https://github.com/pytorch/pytorch/pull/165922#issuecomment-3492667312))
2025-11-05 18:13:57 +00:00
c86540f120 Revert "Add model code stack trace to torch.profile (#166677)"
This reverts commit c00696144dae1f02e04ce345480b55e46c7d32a8.

Reverted https://github.com/pytorch/pytorch/pull/166677 on behalf of https://github.com/jeffdaily due to broke rocm ([comment](https://github.com/pytorch/pytorch/pull/166677#issuecomment-3492658160))
2025-11-05 18:11:11 +00:00
c17aa0f113 [ROCm] Enable group gemm through CK (#166334)
Fixes #161366
All four input-dimension combinations are supported: 2d-2d, 2d-3d, 3d-3d, and 3d-2d. The corresponding
test cases in test_matmul_cuda pass for both the forward and backward pass.
The CK path is enabled for gfx942 and gfx950.
ToDo: enable support on gfx90a. The CK kernel used in this commit produces a GPU error there and,
based on profiler results on gfx90a, may require a different CK kernel config.
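
A hedged usage sketch for the 2d-3d case, modeled on the shapes used in the test_matmul_cuda-style tests (requires a GPU build where `torch._grouped_mm` is available; the offsets are cumulative group sizes along A's ragged M dimension):

```py
import torch

# Illustrative 2d-3d grouped GEMM: A packs the per-group rows, B holds one
# [K, N] matrix per group, and offs gives the cumulative row counts.
G, K, N = 4, 64, 32
M_sizes = torch.tensor([128, 64, 96, 32])
offs = torch.cumsum(M_sizes, dim=0).to(device="cuda", dtype=torch.int32)

A = torch.randn(int(M_sizes.sum()), K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(G, K, N, device="cuda", dtype=torch.bfloat16)
out = torch._grouped_mm(A, B, offs=offs)  # one [m_i, K] @ [K, N] GEMM per group
```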

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166334
Approved by: https://github.com/atalman
2025-11-05 18:03:59 +00:00
4ff068c33a [Code Clean] Replace assert with if statement and raise AssertionError (#166935)
Including:
- `torch/profiler/profiler.py`

Fixes part of #164878
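
A minimal sketch of the pattern being applied (illustrative; not the exact profiler.py hunks). Bare `assert`s are stripped under `python -O`, while an explicit check raises unconditionally:

```py
from typing import Optional


class _Session:
    """Toy stand-in for a profiler session; not the real torch.profiler class."""

    def __init__(self, profiler: Optional[object] = None) -> None:
        self.profiler = profiler

    def export(self) -> None:
        # before: assert self.profiler is not None, "profiler is not initialized"
        if self.profiler is None:
            raise AssertionError("profiler is not initialized")


_Session(object()).export()  # ok; _Session().export() would raise AssertionError
```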

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166935
Approved by: https://github.com/fffrog, https://github.com/albanD
2025-11-05 17:59:16 +00:00
0c7a4a6b48 [Inductor] Fix unbacked float symbol handling in kernel codegen (#166890)
When a function compiled with `torch.compile` calls `.item()` on a float tensor argument (e.g., for thresholds in `torch.clamp`), the generated Triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error.
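
A hedged repro sketch of that pattern (whether this exact snippet triggers #166888 is not claimed; `capture_scalar_outputs` is needed so `.item()` produces an unbacked symbol instead of a graph break):

```py
import torch

torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(fullgraph=True)
def clamp_by_threshold(x: torch.Tensor, lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
    # lo.item()/hi.item() become unbacked float symbols (e.g. zuf0) that the
    # generated Triton kernel must receive as arguments.
    return torch.clamp(x, lo.item(), hi.item())


if torch.cuda.is_available():
    x = torch.randn(1024, device="cuda")
    lo = torch.tensor(-0.5, device="cuda")
    hi = torch.tensor(0.5, device="cuda")
    print(clamp_by_threshold(x, lo, hi).shape)
```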

Fixes: #166888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166890
Approved by: https://github.com/eellison
2025-11-05 17:50:08 +00:00
f93ee16fb6 [CI] Parse xml and upload json while running (#166988)
Then we can point a ClickHouse ingestor at this S3 path and get the results into ClickHouse while the job is still running.

Use filelock to make sure each JSON file is uploaded only once, so we don't end up with duplicates in ClickHouse.
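
A hedged sketch of that upload-once guard (the real CI helper surely differs; `upload` and the marker file are illustrative):

```py
import json
from pathlib import Path

from filelock import FileLock


def upload_once(json_path: Path, marker: Path, upload) -> None:
    # Serialize concurrent uploaders and remember what was already sent, so a
    # retried or parallel step does not create duplicate rows in ClickHouse.
    with FileLock(str(marker) + ".lock"):
        done = set(marker.read_text().splitlines()) if marker.exists() else set()
        if str(json_path) in done:
            return
        upload(json.loads(json_path.read_text()))
        with marker.open("a") as f:
            f.write(f"{json_path}\n")
```
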
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166988
Approved by: https://github.com/izaitsevfb
2025-11-05 17:19:24 +00:00
9c2c3dbc15 Revert "Update triton to 3.5.1 release (#166968)"
This reverts commit b4e4ee81d386db922d8f63359f9870eff1f44052.

Reverted https://github.com/pytorch/pytorch/pull/166968 on behalf of https://github.com/malfet due to It might have caused deadlock/test timeouts, see d4dcd0354c/1 ([comment](https://github.com/pytorch/pytorch/pull/166968#issuecomment-3492399396))
2025-11-05 17:12:30 +00:00
d4dcd0354c [pytree][dynamo] add test to ensure tree_map preserves dict order (#166236)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166236
Approved by: https://github.com/mlazos
2025-11-05 17:04:40 +00:00
aba2fa3259 Fix clang-21 warnings (#166859)
Fixes compiler warnings thrown by Clang-21

Fixes #166755

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166859
Approved by: https://github.com/aditew01, https://github.com/fadara01, https://github.com/malfet
2025-11-05 16:55:51 +00:00
d2d13bf62d Invert unary read and write for fusion (#161404)
For [this repro](https://gist.github.com/eellison/75a99616a0fcca0436316bbfd8987fae), this change enables fusion of `to_blocked` with the prior `to_mx` calculation, so that there is only a single kernel per tensor, resulting in a 10% speedup of the non-conversion code (I need to update my local devserver to 12.9 to time the matmul as well).

The `to_mx` kernel has a contiguous write:

```Py
op6_op7: FusedSchedulerNode(SchedulerNode,SchedulerNode)
op6_op7.writes = [MemoryDep('buf6', c0, {c0: 2097152}), MemoryDep('buf7', c0, {c0: 67108864})]
op6_op7.unmet_dependencies = []
op6_op7.met_dependencies = [MemoryDep('arg1_1', c0, {c0: 67108864})]
op6_op7.outputs = [
    buf6: ComputedBuffer
    buf6.layout = FixedLayout('cuda:0', torch.float32, size=[8192, 256], stride=[256, 1])
    buf6.users = [
        NodeUser(node=SchedulerNode(name='op7'), can_inplace=False, is_weak=False),
        NodeUser(node=SchedulerNode(name='op9'), can_inplace=False, is_weak=False),
    ]
    buf7: ComputedBuffer
    buf7.layout = FixedLayout('cuda:0', torch.float8_e4m3fn, size=[8192, 256, 32], stride=[8192, 32, 1])
    buf7.users = [NodeUser(node=ExternKernelSchedulerNode(name='op10'), can_inplace=False, is_weak=False)]
]
```

The `to_blocked`, by contrast, has a single discontiguous read and a single contiguous write.

```Py
op9: SchedulerNode(ComputedBuffer)
op9.writes = [MemoryDep('buf9', c0, {c0: 2097152})]
op9.unmet_dependencies = [   MemoryDep('buf6', 32768*((c0//32768)) + 8192*(((ModularIndexing(c0, 1, 16))//4)) + 256*(ModularIndexing(c0, 16, 32)) + 4*(ModularIndexing(c0, 512, 64)) + (ModularIndexing(ModularIndexing(c0, 1, 16), 1, 4)), {c0: 2097152})]
op9.met_dependencies = []
op9.outputs = [
    buf9: ComputedBuffer
    buf9.layout = FixedLayout('cuda:0', torch.float8_e8m0fnu, size=[2097152], stride=[1])
    buf9.users = [NodeUser(node=ExternKernelSchedulerNode(name='op10'), can_inplace=False, is_weak=False)]
]
```

To enable fusion, we invert the read, giving op9 a contiguous read and a discontiguous write. More explanation here: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9
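
A toy illustration of the inversion (plain PyTorch, not Inductor internals): a gather with a discontiguous read index is equivalent to a scatter through the inverse index, which reads contiguously and writes discontiguously:

```py
import torch

src = torch.arange(12.0)
perm = torch.randperm(12)        # stands in for the ModularIndexing expression
inv = torch.empty_like(perm)
inv[perm] = torch.arange(12)     # inverse permutation

gathered = src[perm]             # discontiguous read, contiguous write
scattered = torch.empty_like(src)
scattered[inv] = src             # contiguous read, discontiguous write
assert torch.equal(gathered, scattered)
```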

[Tlparse with this optimization](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/eellison/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000).

[Tlparse without this optimization](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/eellison/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161404
Approved by: https://github.com/shunting314
2025-11-05 16:10:52 +00:00
7a6ff88196 Widen ops support to take in IntHOArrayRef vs only std::vec (#165152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165152
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991
2025-11-05 16:00:24 +00:00
59563dfe56 Refactor out headeronly ArrayRef (#164991)
Differential Revision: [D85091961](https://our.internmc.facebook.com/intern/diff/D85091961)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164991
Approved by: https://github.com/swolchok
2025-11-05 16:00:24 +00:00
5c639466f7 Revert "[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167003)"
This reverts commit 658c5f879c37142b1df51c7eb6c5a5bb06318597.

Reverted https://github.com/pytorch/pytorch/pull/167003 on behalf of https://github.com/atalman due to regressed vllm signal: [GH job link](https://github.com/pytorch/pytorch/actions/runs/19093785744/job/54553796743) [HUD commit link](658c5f879c) ([comment](https://github.com/pytorch/pytorch/pull/167003#issuecomment-3491527704))
2025-11-05 14:30:15 +00:00
60 changed files with 2874 additions and 1954 deletions

View File

@ -271,6 +271,16 @@ case "$tag" in
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-clang21)
ANACONDA_PYTHON_VERSION=3.10
CLANG_VERSION=21
ACL=yes
VISION=yes
OPENBLAS=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11

View File

@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
# work around ubuntu apt-get conflicts
sudo apt-get -y -f install
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
if [[ $CLANG_VERSION == 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
if [[ $CLANG_VERSION -ge 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
fi
fi

View File

@ -129,7 +129,7 @@ function install_129 {
}
function install_128 {
CUDNN_VERSION=9.10.2.21
CUDNN_VERSION=9.8.0.87
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
# install CUDA 12.8.1 in the same container
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

View File

@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="
CC=gcc
NUM_THREADS=128
USE_OPENMP=1
NO_SHARED=0

View File

@ -272,18 +272,6 @@ def smoke_test_cuda(
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
print(f"Torch cuDNN version: {torch_cudnn_version}")
torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
torch_cudnn_runtime_version = tuple(
[int(x) for x in torch_cudnn_version.split(".")]
)
if torch_cudnn_runtime_version != torch_cudnn_compile_version:
raise RuntimeError(
"cuDNN runtime version doesn't match comple version. "
f"Loaded: {torch_cudnn_runtime_version} "
f"Expected: {torch_cudnn_compile_version}"
)
if sys.platform in ["linux", "linux2"]:
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
print(f"Torch nccl; version: {torch_nccl_version}")

View File

@ -337,7 +337,7 @@ test_python() {
test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}

View File

@ -79,6 +79,8 @@ jobs:
include:
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
runner: linux.arm64.m7g.4xlarge
timeout-minutes: 600

1
.gitignore vendored
View File

@ -127,7 +127,6 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

View File

@ -191,7 +191,7 @@ class Vectorized<BFloat16> {
auto vals = svreinterpret_u16_bf16(values);
vals = sveor_u16_x(ptrue, vals, mask);
return svreinterpret_bf16_u16(vals);
};
}
Vectorized<BFloat16> round() const;
Vectorized<BFloat16> tan() const;
Vectorized<BFloat16> tanh() const;
@ -349,47 +349,47 @@ Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
return convert_float_bfloat16(v1, v2); \
}
DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
DEFINE_BF16_FUNC_VIA_FLOAT(angle);
DEFINE_BF16_FUNC_VIA_FLOAT(acos);
DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
DEFINE_BF16_FUNC_VIA_FLOAT(asin);
DEFINE_BF16_FUNC_VIA_FLOAT(atan);
DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
DEFINE_BF16_FUNC_VIA_FLOAT(erf);
DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
DEFINE_BF16_FUNC_VIA_FLOAT(exp);
DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
DEFINE_BF16_FUNC_VIA_FLOAT(i0);
DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
DEFINE_BF16_FUNC_VIA_FLOAT(log);
DEFINE_BF16_FUNC_VIA_FLOAT(log2);
DEFINE_BF16_FUNC_VIA_FLOAT(log10);
DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
DEFINE_BF16_FUNC_VIA_FLOAT(sin);
DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
DEFINE_BF16_FUNC_VIA_FLOAT(cos);
DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
DEFINE_BF16_FUNC_VIA_FLOAT(floor);
DEFINE_BF16_FUNC_VIA_FLOAT(round);
DEFINE_BF16_FUNC_VIA_FLOAT(tan);
DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);
DEFINE_BF16_FUNC_VIA_FLOAT(isnan)
DEFINE_BF16_FUNC_VIA_FLOAT(angle)
DEFINE_BF16_FUNC_VIA_FLOAT(acos)
DEFINE_BF16_FUNC_VIA_FLOAT(acosh)
DEFINE_BF16_FUNC_VIA_FLOAT(asin)
DEFINE_BF16_FUNC_VIA_FLOAT(atan)
DEFINE_BF16_FUNC_VIA_FLOAT(atanh)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign)
DEFINE_BF16_FUNC_VIA_FLOAT(erf)
DEFINE_BF16_FUNC_VIA_FLOAT(erfc)
DEFINE_BF16_FUNC_VIA_FLOAT(exp)
DEFINE_BF16_FUNC_VIA_FLOAT(exp2)
DEFINE_BF16_FUNC_VIA_FLOAT(expm1)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot)
DEFINE_BF16_FUNC_VIA_FLOAT(i0)
DEFINE_BF16_FUNC_VIA_FLOAT(i0e)
DEFINE_BF16_FUNC_VIA_FLOAT(digamma)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter)
DEFINE_BF16_FUNC_VIA_FLOAT(log)
DEFINE_BF16_FUNC_VIA_FLOAT(log2)
DEFINE_BF16_FUNC_VIA_FLOAT(log10)
DEFINE_BF16_FUNC_VIA_FLOAT(log1p)
DEFINE_BF16_FUNC_VIA_FLOAT(sin)
DEFINE_BF16_FUNC_VIA_FLOAT(sinh)
DEFINE_BF16_FUNC_VIA_FLOAT(cos)
DEFINE_BF16_FUNC_VIA_FLOAT(cosh)
DEFINE_BF16_FUNC_VIA_FLOAT(ceil)
DEFINE_BF16_FUNC_VIA_FLOAT(floor)
DEFINE_BF16_FUNC_VIA_FLOAT(round)
DEFINE_BF16_FUNC_VIA_FLOAT(tan)
DEFINE_BF16_FUNC_VIA_FLOAT(tanh)
DEFINE_BF16_FUNC_VIA_FLOAT(trunc)
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma)
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt)
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal)
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow)
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(
const Vectorized<BFloat16>& other) const {

View File

@ -293,7 +293,7 @@ struct ComputeLocationBase<scalar_t, /*align_corners=*/false> {
, empty(size <= 0) {}
inline Vec unnormalize(const Vec &in) const {
return (in + Vec(1)) * Vec(scaling_factor) - Vec(0.5);
return (in + Vec(static_cast<scalar_t>(1))) * Vec(scaling_factor) - Vec(static_cast<scalar_t>(0.5));
}
inline Vec clip_coordinates(const Vec &in) const {
@ -831,7 +831,7 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bicubic,
// constant used in cubic convolution
// could be -0.5 or -0.75, use the same value in UpSampleBicubic2d.h
const Vec A = Vec(-0.75);
const Vec A = Vec(static_cast<scalar_t>(-0.75));
ApplyGridSample(const TensorAccessor<const scalar_t, 4>& input)
: inp_H(input.size(2))

View File

@ -22,6 +22,9 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
@ -666,12 +669,19 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}

View File

@ -0,0 +1,19 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>
namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
at::Tensor& out);
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -0,0 +1,462 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/hip/HIPContext.h>
#include <ATen/Tensor.h>
#include <ATen/TensorAccessor.h>
#include <c10/hip/HIPStream.h>
#include <iostream>
#include <vector>
#include <optional>
#include <type_traits>
#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/tuple.hpp>
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
namespace at {
namespace hip {
namespace detail {
namespace CkTypes {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
}
template <typename ALayout, typename BLayout, typename DataType>
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
1, 1,
S<1,32,1,8>, 4
>;
template <typename ALayout, typename BLayout, typename DataType>
void launch_grouped_bgemm_ck_impl_dispatch(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
at::Tensor& out)
{
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
using PassThrough = CkTypes::PassThrough;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a_ptrs, p_b_ptrs;
std::vector<void*> p_e_ptrs;
// Note: d_ptrs will be resized after we populate the other vectors
const int mat_a_dim = mat_a.dim();
const int mat_b_dim = mat_b.dim();
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
const size_t a_element_size = mat_a.element_size();
const size_t b_element_size = mat_b.element_size();
const size_t out_element_size = out.element_size();
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
if (mat_a_dim == 2 && mat_b_dim == 2) {
// 2D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
const int M = mat_a.size(0); // number of rows in A
const int N = mat_b.size(1); // number of columns in B
const int K = mat_a.size(1); // columns in A == rows in B
// for 2d*2d input, output is 3d.
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
for (int i = 0; i < num_groups; ++i) {
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
int end_k = offs_accessor[i];
int k = end_k - start_k;
//K dimension are sliced, hence select stride(1) always.
//K dimension is always dimension 1, regardless of memory layout (row/column major)
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
const void* group_b_ptr;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
// Leading dimension = distance between rows = stride(0)
ldb = mat_b.stride(0);
} else {
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
// Leading dimension = distance between columns = stride(1)
ldb = mat_b.stride(1);
}
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
lda = mat_a.stride(0);
} else {
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
lda = mat_a.stride(1);
}
// Output is always row-major in 3D tensor [num_groups, M, N]
// Leading dimension for each group's [M,N] slice = stride(1) = N
ldc = out.stride(1);
size_t output_group_bytes = M * N * out_element_size;
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(k),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
// 2D*3D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 2d*3d input, output is 2d.
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
const int K = mat_a.size(1); // columns in A
// For 2D-3D case: The output determines N (result width)
const int N = out.size(1); // N is the width of the output tensor
for (int i = 0; i < num_groups; ++i) {
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
int end_m = offs_accessor[i];
int m = end_m - start_m;
// Skip zero-sized groups but continue processing subsequent groups
if (m <= 0) {
continue;
}
// Select A rows for group i: skip start_m rows
const void* group_a_ptr;
int lda;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
lda = mat_a.stride(0); // distance between rows
} else {
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
// Detect stride pattern for A tensor to determine appropriate lda calculation
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
if (a_is_strided_tensor) {
// For strided A tensors: stride(0) gives the actual leading dimension
lda = mat_a.stride(0);
} else {
// For non-strided A tensors: use the M dimension (total rows)
lda = mat_a.size(0); // Total M dimension for column-major layout
}
}
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
} else {
// Detect stride pattern to determine appropriate ldb calculation
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
if (is_strided_tensor) {
// For strided tensors: stride(2) gives the actual leading dimension
ldb = mat_b.stride(2);
} else {
// For non-strided tensors: use the N dimension
ldb = mat_b.size(1);
}
}
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
gemm_descs.push_back({
static_cast<ck::index_t>(m),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
// 3d*3d input, output is 3d - batched matrix multiplication
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
// Each batch is processed as a separate GEMM operation
const int batch_size = mat_a.size(0);
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
int N;
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
N = mat_b.size(2);
} else if (mat_b.size(2) == K) {
// B is [batch, n, k] - transposed layout
N = mat_b.size(1);
} else {
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
}
for (int i = 0; i < batch_size; ++i) {
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
// Select output batch for group i: Output[i, :, :]
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldb, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
} else {
// Column-major A: leading dimension = distance between columns = stride(2)
lda = mat_a.stride(2);
}
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B: leading dimension = distance between rows
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(1); // stride between K rows
} else {
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
}
} else {
// Column-major B: leading dimension = distance between columns
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(2); // stride between N columns
} else {
// B is [batch, n, k] - transposed layout
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
}
}
// Output is typically row-major: leading dimension = distance between rows = stride(1)
ldc = out.stride(1);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
// 3D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 3d*2d input, output is 3d.
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
const int batch_size = mat_a.size(0); // n_groups
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A
// For row-major A and B case: B should be [K, total_N]
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
for (int i = 0; i < num_groups; ++i) {
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
int end_n = offs_accessor[i];
int n = end_n - start_n;
// Skip zero-sized groups but continue processing subsequent groups
if (n <= 0) {
continue;
}
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
const void* group_b_ptr;
int ldb;
// Check if B is row-major or column-major
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(0); // distance between rows (should be total_N)
} else {
// Column-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(1); // distance between columns (should be K)
}
// Select output slice for group i: Output[:, start_n:end_n]
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
int lda, ldc;
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
// Output is row-major: leading dimension = distance between rows = stride(0)
ldc = out.stride(0);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(n),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
}
TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
// Initialize d_ptrs with the correct size
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
static DeviceOp gemm_instance;
auto argument = gemm_instance.MakeArgument(
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
);
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
"CK Group GEMM: argument unsupported (shape/strides/type config)");
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
void* gemm_arg_buf = nullptr;
void* ws_buf = nullptr;
hipMalloc(&gemm_arg_buf, arg_buf_size);
hipMalloc(&ws_buf, ws_size);
gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
auto invoker = gemm_instance.MakeInvoker();
hipStream_t stream = c10::hip::getCurrentHIPStream();
invoker.Run(argument, {stream});
hipFree(gemm_arg_buf);
hipFree(ws_buf);
}
void group_gemm_ck(
const at::Tensor& input_a,
const at::Tensor& input_b_colmajor,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& /*bias*/,
at::Tensor& out)
{
// Detect if input_a is row-major based on stride pattern
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
// Ensure tensor A is row-major and contiguous if not already
at::Tensor mat_a = input_a;
if (!a_row_major) {
// If A is not row-major, make it contiguous (row-major)
mat_a = input_a.contiguous();
}
// Force tensor B to be column-major using double transpose trick
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
at::Tensor mat_b = input_b_colmajor;
if (!b_col_major) {
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
}
// For 3D tensors, check the last dimension stride for row-major detection
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
if (mat_a.dtype() == at::kBFloat16) {
// bf16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kHalf) {
// fp16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kFloat) {
// fp32 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
}
}
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -18,6 +18,7 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <array>
#include <cstddef>
@ -40,200 +41,99 @@ namespace c10 {
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
///
/// NOTE: We have refactored out the headeronly parts of the ArrayRef struct
/// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of
/// the underlying constexpr calls, we rely on apparent-type dispatch for
/// inheritance. This should be fine because their memory format is the same,
/// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods.
/// However, you should prefer to use ArrayRef when possible, because its use
/// of TORCH_CHECK will lead to better user-facing error messages.
template <typename T>
class ArrayRef final {
class ArrayRef final : public HeaderOnlyArrayRef<T> {
public:
using iterator = const T*;
using const_iterator = const T*;
using size_type = size_t;
using value_type = T;
using reverse_iterator = std::reverse_iterator<iterator>;
private:
/// The start of the array, in an external buffer.
const T* Data;
/// The number of elements.
size_type Length;
void debugCheckNullptrInvariant() {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
Data != nullptr || Length == 0,
"created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal");
}
public:
/// @name Constructors
/// @name Constructors, all inherited from HeaderOnlyArrayRef except for
/// SmallVector. As inherited constructors won't work with class template
/// argument deduction (CTAD) until C++23, we add deduction guides after
/// the class definition to enable CTAD.
/// @{
/// Construct an empty ArrayRef.
/* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
/// Construct an ArrayRef from a single element.
// TODO Make this explicit
constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
/// Construct an ArrayRef from a pointer and length.
constexpr ArrayRef(const T* data, size_t length)
: Data(data), Length(length) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a range.
constexpr ArrayRef(const T* begin, const T* end)
: Data(begin), Length(end - begin) {
debugCheckNullptrInvariant();
}
using HeaderOnlyArrayRef<T>::HeaderOnlyArrayRef;
/// Construct an ArrayRef from a SmallVector. This is templated in order to
/// avoid instantiating SmallVectorTemplateCommon<T> whenever we
/// copy-construct an ArrayRef.
/// NOTE: this is the only constructor that is not inherited from
/// HeaderOnlyArrayRef.
template <typename U>
/* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
debugCheckNullptrInvariant();
}
template <
typename Container,
typename U = decltype(std::declval<Container>().data()),
typename = std::enable_if_t<
(std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
/* implicit */ ArrayRef(const Container& container)
: Data(container.data()), Length(container.size()) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a std::vector.
// The enable_if stuff here makes sure that this isn't used for
// std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
// bitfield.
template <typename A>
/* implicit */ ArrayRef(const std::vector<T, A>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
static_assert(
!std::is_same_v<T, bool>,
"ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
}
/// Construct an ArrayRef from a std::array
template <size_t N>
/* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
: Data(Arr.data()), Length(N) {}
/// Construct an ArrayRef from a C array.
template <size_t N>
// NOLINTNEXTLINE(*c-arrays*)
/* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
/// Construct an ArrayRef from a std::initializer_list.
/* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
: Data(
std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
: std::begin(Vec)),
Length(Vec.size()) {}
: HeaderOnlyArrayRef<T>(Vec.data(), Vec.size()) {}
/// @}
/// @name Simple Operations
/// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef
/// @{
constexpr iterator begin() const {
return Data;
}
constexpr iterator end() const {
return Data + Length;
}
// These are actually the same as iterator, since ArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return Data;
}
constexpr const_iterator cend() const {
return Data + Length;
}
constexpr reverse_iterator rbegin() const {
return reverse_iterator(end());
}
constexpr reverse_iterator rend() const {
return reverse_iterator(begin());
}
/// Check if all elements in the array satisfy the given expression
constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
return std::all_of(cbegin(), cend(), pred);
}
/// empty - Check if the array is empty.
constexpr bool empty() const {
return Length == 0;
}
constexpr const T* data() const {
return Data;
}
/// size - Get the array size.
constexpr size_t size() const {
return Length;
}
/// front - Get the first element.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& front() const {
TORCH_CHECK(
!empty(), "ArrayRef: attempted to access front() of empty list");
return Data[0];
!this->empty(), "ArrayRef: attempted to access front() of empty list");
return this->Data[0];
}
/// back - Get the last element.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& back() const {
TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
return Data[Length - 1];
}
/// equals - Check for element-wise equality.
constexpr bool equals(ArrayRef RHS) const {
return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
TORCH_CHECK(
!this->empty(), "ArrayRef: attempted to access back() of empty list");
return this->Data[this->Length - 1];
}
/// slice(n, m) - Take M elements of the array starting at element N
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr ArrayRef<T> slice(size_t N, size_t M) const {
TORCH_CHECK(
N + M <= size(),
N + M <= this->size(),
"ArrayRef: invalid slice, N = ",
N,
"; M = ",
M,
"; size = ",
size());
return ArrayRef<T>(data() + N, M);
this->size());
return ArrayRef<T>(this->data() + N, M);
}
/// slice(n) - Chop off the first N elements of the array.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr ArrayRef<T> slice(size_t N) const {
TORCH_CHECK(
N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size());
return slice(N, size() - N);
N <= this->size(),
"ArrayRef: invalid slice, N = ",
N,
"; size = ",
this->size());
return slice(N, this->size() - N); // should this slice be this->slice?
}
/// @}
/// @name Operator Overloads
/// @{
constexpr const T& operator[](size_t Index) const {
return Data[Index];
}
/// Vector compatibility
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& at(size_t Index) const {
TORCH_CHECK(
Index < Length,
Index < this->Length,
"ArrayRef: invalid index Index = ",
Index,
"; Length = ",
Length);
return Data[Index];
this->Length);
return this->Data[Index];
}
/// Disallow accidental assignment from a temporary.
@ -253,16 +153,48 @@ class ArrayRef final {
std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
std::initializer_list<U>) = delete;
/// @}
/// @name Expensive Operations
/// @{
std::vector<T> vec() const {
return std::vector<T>(Data, Data + Length);
}
/// @}
};
/// Deduction guides for ArrayRef to support CTAD with inherited constructors
/// These mirror the constructors inherited from HeaderOnlyArrayRef
/// @{
// Single element constructor
template <typename T>
ArrayRef(const T&) -> ArrayRef<T>;
// Pointer and length constructor
template <typename T>
ArrayRef(const T*, size_t) -> ArrayRef<T>;
// Range constructor (begin, end)
template <typename T>
ArrayRef(const T*, const T*) -> ArrayRef<T>;
// Generic container constructor (anything with .data() and .size())
template <typename Container>
ArrayRef(const Container&) -> ArrayRef<
std::remove_pointer_t<decltype(std::declval<Container>().data())>>;
// std::vector constructor
template <typename T, typename A>
ArrayRef(const std::vector<T, A>&) -> ArrayRef<T>;
// std::array constructor
template <typename T, size_t N>
ArrayRef(const std::array<T, N>&) -> ArrayRef<T>;
// C array constructor
template <typename T, size_t N>
ArrayRef(const T (&)[N]) -> ArrayRef<T>;
// std::initializer_list constructor
template <typename T>
ArrayRef(const std::initializer_list<T>&) -> ArrayRef<T>;
/// @}
template <typename T>
std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
int i = 0;

View File

@ -1307,7 +1307,7 @@ endif()
if(USE_MKLDNN_ACL)
find_package(ACL REQUIRED)
target_include_directories(torch_cpu PRIVATE ${ACL_INCLUDE_DIRS})
target_include_directories(torch_cpu SYSTEM PRIVATE ${ACL_INCLUDE_DIRS})
endif()
target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE})

View File

@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None:
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
def mirror_inductor_external_kernels() -> None:
"""
Copy external kernels into Inductor so they are importable.
"""
paths = [
(
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
CWD
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
),
]
for new_path, orig_path in paths:
# Create the dirs involved in new_path if they don't exist
if not new_path.exists():
new_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the files from the orig location to the new location
if orig_path.is_file():
shutil.copyfile(orig_path, new_path)
continue
if orig_path.is_dir():
if new_path.exists():
# copytree fails if the tree exists already, so remove it.
shutil.rmtree(new_path)
shutil.copytree(orig_path, new_path)
continue
raise RuntimeError(
"Check the file paths in `mirror_inductor_external_kernels()`"
)
# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
"""Extract variant from version string, defaulting to 'cpu'."""
@ -1647,8 +1616,6 @@ def main() -> None:
if RUN_BUILD_DEPS:
build_deps()
mirror_inductor_external_kernels()
(
ext_modules,
cmdclass,
@ -1682,7 +1649,6 @@ def main() -> None:
"_inductor/codegen/aoti_runtime/*.cpp",
"_inductor/script.ld",
"_inductor/kernel/flex/templates/*.jinja",
"_inductor/kernel/templates/*.jinja",
"_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake",

View File

@ -12,6 +12,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp

View File

@ -0,0 +1,52 @@
#include <gtest/gtest.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <vector>
using torch::headeronly::HeaderOnlyArrayRef;
TEST(TestHeaderOnlyArrayRef, TestEmpty) {
HeaderOnlyArrayRef<float> arr;
ASSERT_TRUE(arr.empty());
}
TEST(TestHeaderOnlyArrayRef, TestSingleton) {
float val = 5.0f;
HeaderOnlyArrayRef<float> arr(val);
ASSERT_FALSE(arr.empty());
EXPECT_EQ(arr.size(), 1);
EXPECT_EQ(arr[0], val);
}
TEST(TestHeaderOnlyArrayRef, TestAPIs) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr(vec);
ASSERT_FALSE(arr.empty());
EXPECT_EQ(arr.size(), 7);
for (size_t i = 0; i < arr.size(); i++) {
EXPECT_EQ(arr[i], i + 1);
EXPECT_EQ(arr.at(i), i + 1);
}
EXPECT_EQ(arr.front(), 1);
EXPECT_EQ(arr.back(), 7);
ASSERT_TRUE(arr.slice(3, 4).equals(arr.slice(3)));
}
TEST(TestHeaderOnlyArrayRef, TestFromInitializerList) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr({1, 2, 3, 4, 5, 6, 7});
auto res_vec = arr.vec();
for (size_t i = 0; i < vec.size(); i++) {
EXPECT_EQ(vec[i], res_vec[i]);
}
}
TEST(TestHeaderOnlyArrayRef, TestFromRange) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr(vec.data() + 3, vec.data() + 7);
auto res_vec = arr.vec();
for (size_t i = 0; i < res_vec.size(); i++) {
EXPECT_EQ(vec[i + 3], res_vec[i]);
}
}

View File

@ -311,10 +311,9 @@ void boxed_fill_infinity(
}
Tensor my_pad(Tensor t) {
std::vector<int64_t> padding = {1, 2, 2, 1};
std::string mode = "constant";
double value = 0.0;
return pad(t, padding, mode, value);
return pad(t, {1, 2, 2, 1}, mode, value);
}
void boxed_my_pad(
@ -342,6 +341,9 @@ void boxed_my_narrow(
}
Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
// directly.
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype);
@ -353,9 +355,8 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
}
Tensor my_new_zeros_dtype_variant(Tensor t) {
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(at::ScalarType::Float);
return new_zeros(t, sizes, dtype);
return new_zeros(t, {2, 5}, dtype);
}
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
@ -429,8 +430,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
}
Tensor my_amax_vec(Tensor t) {
std::vector<int64_t> v = {0,1};
return amax(t, v, false);
return amax(t, {0,1}, false);
}
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {

View File

@ -5,8 +5,16 @@ import contextlib
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -426,6 +434,31 @@ class TestDTensorDebugMode(TestCase):
][-1]
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
def test_pretty_print_dtensor_make_fx(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
A = torch.randn(8, 32)
B = torch.randn(32, 32)
dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()
def f(dA, dB):
dy = dA @ dB
loss = dy.sum()
loss.backward()
return dA.grad, dB.grad
# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode="fake")(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
# Colored is nice for actual viewing, not using in this test though
gm_str = gm.print_readable(colored=False, print_output=False)
self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@ -13194,6 +13194,30 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
self.assertEqual(actual, expected)
@parametrize_pytree_module
def test_pytree_tree_map_dict_order(self, pytree):
def fn(tree):
new_tree = pytree.tree_map(lambda x: x, tree)
return list(new_tree.keys()), list(new_tree.values())
x = torch.randn(3, 2)
fn_opt = torch.compile(fullgraph=True)(fn)
tree1 = {"b": x + 2, "a": x, "c": x - 1}
expected1 = fn(tree1)
actual1 = fn_opt(tree1)
self.assertEqual(actual1, expected1)
tree2 = collections.OrderedDict([("b", x + 2), ("a", x), ("c", x - 1)])
expected2 = fn(tree2)
actual2 = fn_opt(tree2)
self.assertEqual(actual2, expected2)
tree3 = collections.defaultdict(int, {"b": x + 2, "a": x, "c": x - 1})
expected3 = fn(tree3)
actual3 = fn_opt(tree3)
self.assertEqual(actual3, expected3)
@parametrize_pytree_module
def test_pytree_tree_map_only(self, pytree):
if not callable(getattr(pytree, "tree_map_only", None)):

View File

@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None

View File

@ -1,154 +0,0 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
@unittest.skipIf(
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
"CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
def _get_inputs(
self,
group_size: int,
M_hint: int,
K: int,
N: int,
device: str,
dtype: torch.dtype,
alignment: int = 16,
) -> tuple[Tensor, Tensor, Tensor]:
# --- Random, tile-aligned M sizes ---
M_sizes = (
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
* alignment
)
M_total = torch.sum(M_sizes).item()
# --- Construct input tensors ---
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
# --- Build offsets (no leading zero, strictly increasing) ---
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
return (A, B, offsets)
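A small hedged illustration of the offsets convention the helper above builds for torch._grouped_mm (cumulative per-group row counts, with no leading zero); the concrete group sizes here are made up for the example:
import torch

M_sizes = torch.tensor([128, 64, 256], dtype=torch.int)
offsets = torch.cumsum(M_sizes, dim=0).to(torch.int32)  # tensor([128, 192, 448], dtype=torch.int32)
# group 0 reads A[0:128], group 1 reads A[128:192], group 2 reads A[192:448]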
@parametrize("group_size", (2, 8))
@parametrize("M_hint", (256, 1024))
@parametrize("K", (64, 128))
@parametrize("N", (128, 256))
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
device = "cuda"
dtype = torch.bfloat16
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# Eager execution
c_eager = grouped_gemm_fn(A, B, offsets)
# Test with Cute backend
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
@parametrize("layout_B", ("contiguous", "broadcasted"))
def test_grouped_gemm_assorted_layouts(
self,
layout_A: str,
layout_B: str,
):
device = "cuda"
dtype = torch.bfloat16
G, K, N = 8, 64, 128
M_sizes = [128] * G
sum_M = sum(M_sizes)
offsets = torch.tensor(
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
)
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
A = A_base
if layout_A == "offset":
# allocate bigger buffer than needed, use nonzero storage offset
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
offset = 128 # skip first 128 elements
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
elif layout_A == "padded":
# simulate row pitch > K (row_stride = K + pad)
row_pitch = K + 8
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
elif layout_A == "view":
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
A = A_storage.view(sum_M, K)
assert A._base is not None
assert A.shape == (sum_M, K)
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
if layout_B == "broadcasted":
# Broadcast B across groups (zero stride along G)
B = B[0].expand(G, K, N)
assert B.stride(0) == 0
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# --- eager ---
c_eager = grouped_gemm_fn(A, B, offsets)
# --- compiled (CUTE backend) ---
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
if __name__ == "__main__":
run_tests()

View File

@ -6,6 +6,7 @@ from typing import Union
import torch
from torch import Tensor
from torch._C import FileCheck
from torch._inductor import config, utils
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._inductor.test_case import run_tests, TestCase
@ -29,7 +30,6 @@ from torch.testing._internal.inductor_utils import (
HAS_CPU,
HAS_CUDA_AND_TRITON,
)
from torch.testing._internal.jit_utils import FileCheck
from torch.utils._triton import has_triton_tma_device
@ -953,6 +953,240 @@ class TestFP8Lowering(TestCase):
self.assertEqual(y_compiled.dtype, dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@torch._inductor.config.patch("emulate_precision_casts", True)
def test_mx_fusion(self):
# Register fake_scaled_mm custom op scoped to this test
with torch.library._scoped_library("test_fp8", "FRAGMENT") as lib:
# Define the op schema
lib.define(
"fake_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, "
"Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, "
"bool use_fast_accum=False) -> Tensor"
)
input_values = []
# Register CUDA implementation
@torch.library.impl(lib, "fake_scaled_mm", "CUDA")
def fake_scaled_mm_impl(
mat_a,
mat_b,
scale_a,
scale_b,
bias=None,
scale_result=None,
out_dtype=None,
use_fast_accum=False,
):
"""Software-emulated scaled_mm for testing without CUDA 12.8"""
out_dtype = out_dtype or torch.bfloat16
# just using add because, without real dtypes,
# we were seeing overflow/instability
nonlocal input_values
input_values.append((mat_a, mat_b, scale_a, scale_b))
result = mat_a.to(torch.float32) + mat_b.to(torch.float32)
if bias is not None:
result = result + bias.to(torch.float32)
return result.to(out_dtype)
# Register fake implementation
@torch.library.impl(lib, "fake_scaled_mm", "Meta")
def fake_scaled_mm_meta(
mat_a,
mat_b,
scale_a,
scale_b,
bias=None,
scale_result=None,
out_dtype=None,
use_fast_accum=False,
):
"""FakeTensor implementation"""
out_dtype = out_dtype or torch.bfloat16
M, K = mat_a.shape
K2, N = mat_b.shape
torch._check(
K == K2,
lambda: f"Incompatible shapes: {mat_a.shape} @ {mat_b.shape}",
)
return torch.empty((M, N), dtype=out_dtype, device=mat_a.device)
def forward(
arg0_1,
arg1_1,
):
view = torch.ops.aten.reshape.default(arg0_1, [8192, 256, 32])
abs_1 = torch.ops.aten.abs.default(view)
amax = torch.ops.aten.amax.default(abs_1, [-1])
unsqueeze = torch.ops.aten.unsqueeze.default(amax, -1)
view_1 = torch.ops.aten.view.dtype(unsqueeze, torch.int32)
bitwise_right_shift = torch.ops.aten.bitwise_right_shift.Tensor_Scalar(
view_1, 23
)
bitwise_and = torch.ops.aten.bitwise_and.Scalar(
bitwise_right_shift, 255
)
sub = torch.ops.aten.sub.Tensor(bitwise_and, 127)
sub_1 = torch.ops.aten.sub.Tensor(sub, 8)
clamp_min = torch.ops.aten.clamp_min.default(sub_1, -127)
clamp_max = torch.ops.aten.clamp_max.default(clamp_min, 128)
add = torch.ops.aten.add.Tensor(clamp_max, 127)
convert_element_type = torch.ops.prims.convert_element_type.default(
add, torch.uint8
)
isnan = torch.ops.aten.isnan.default(unsqueeze)
scalar_tensor = torch.ops.aten.scalar_tensor.default(
255, dtype=torch.uint8, layout=torch.strided, device="cuda"
)
where = torch.ops.aten.where.self(
isnan, scalar_tensor, convert_element_type
)
convert_element_type_1 = torch.ops.prims.convert_element_type.default(
where, torch.int32
)
bitwise_left_shift = torch.ops.aten.bitwise_left_shift.Tensor_Scalar(
convert_element_type_1, 23
)
view_2 = torch.ops.aten.view.dtype(bitwise_left_shift, torch.float32)
clamp_min_1 = torch.ops.aten.clamp_min.default(
view_2, 1.1754943508222875e-38
)
div = torch.ops.aten.div.Tensor(view, clamp_min_1)
clamp_min_2 = torch.ops.aten.clamp_min.default(div, -448.0)
clamp_max_1 = torch.ops.aten.clamp_max.default(clamp_min_2, 448.0)
convert_element_type_2 = torch.ops.prims.convert_element_type.default(
clamp_max_1, torch.float8_e4m3fn
)
view_3 = torch.ops.aten.reshape.default(
convert_element_type_2, [8192, 8192]
)
convert_element_type_2 = None
view_4 = torch.ops.aten.view.dtype(where, torch.float8_e8m0fnu)
squeeze = torch.ops.aten.squeeze.dim(view_4, -1)
view_5 = torch.ops.aten.reshape.default(arg1_1, [8192, 256, 32])
abs_2 = torch.ops.aten.abs.default(view_5)
amax_1 = torch.ops.aten.amax.default(abs_2, [-1])
unsqueeze_1 = torch.ops.aten.unsqueeze.default(amax_1, -1)
view_6 = torch.ops.aten.view.dtype(unsqueeze_1, torch.int32)
bitwise_right_shift_1 = (
torch.ops.aten.bitwise_right_shift.Tensor_Scalar(view_6, 23)
)
bitwise_and_1 = torch.ops.aten.bitwise_and.Scalar(
bitwise_right_shift_1, 255
)
sub_2 = torch.ops.aten.sub.Tensor(bitwise_and_1, 127)
sub_3 = torch.ops.aten.sub.Tensor(sub_2, 8)
clamp_min_3 = torch.ops.aten.clamp_min.default(sub_3, -127)
clamp_max_2 = torch.ops.aten.clamp_max.default(clamp_min_3, 128)
add_1 = torch.ops.aten.add.Tensor(clamp_max_2, 127)
convert_element_type_3 = torch.ops.prims.convert_element_type.default(
add_1, torch.uint8
)
isnan_1 = torch.ops.aten.isnan.default(unsqueeze_1)
unsqueeze_1 = None
scalar_tensor_1 = torch.ops.aten.scalar_tensor.default(
255, dtype=torch.uint8, layout=torch.strided, device="cuda"
)
where_1 = torch.ops.aten.where.self(
isnan_1, scalar_tensor_1, convert_element_type_3
)
convert_element_type_4 = torch.ops.prims.convert_element_type.default(
where_1, torch.int32
)
bitwise_left_shift_1 = torch.ops.aten.bitwise_left_shift.Tensor_Scalar(
convert_element_type_4, 23
)
convert_element_type_4 = None
view_7 = torch.ops.aten.view.dtype(bitwise_left_shift_1, torch.float32)
bitwise_left_shift_1 = None
clamp_min_4 = torch.ops.aten.clamp_min.default(
view_7, 1.1754943508222875e-38
)
div_1 = torch.ops.aten.div.Tensor(view_5, clamp_min_4)
clamp_min_5 = torch.ops.aten.clamp_min.default(div_1, -448.0)
clamp_max_3 = torch.ops.aten.clamp_max.default(clamp_min_5, 448.0)
convert_element_type_5 = torch.ops.prims.convert_element_type.default(
clamp_max_3, torch.float8_e4m3fn
)
view_8 = torch.ops.aten.reshape.default(
convert_element_type_5, [8192, 8192]
)
view_9 = torch.ops.aten.view.dtype(where_1, torch.float8_e8m0fnu)
squeeze_1 = torch.ops.aten.squeeze.dim(view_9, -1)
permute = torch.ops.aten.permute.default(view_8, [1, 0])
view_13 = torch.ops.aten.reshape.default(squeeze, [64, 128, 64, 4])
permute_2 = torch.ops.aten.permute.default(view_13, [0, 2, 1, 3])
clone = torch.ops.aten.clone.default(
permute_2, memory_format=torch.contiguous_format
)
view_14 = torch.ops.aten.reshape.default(clone, [4096, 4, 32, 4])
permute_3 = torch.ops.aten.permute.default(view_14, [0, 2, 1, 3])
clone_1 = torch.ops.aten.clone.default(
permute_3, memory_format=torch.contiguous_format
)
view_15 = torch.ops.aten.reshape.default(clone_1, [4096, 32, 16])
view_16 = torch.ops.aten.reshape.default(view_15, [2097152])
view_18 = torch.ops.aten.reshape.default(squeeze_1, [64, 128, 64, 4])
permute_5 = torch.ops.aten.permute.default(view_18, [0, 2, 1, 3])
clone_2 = torch.ops.aten.clone.default(
permute_5, memory_format=torch.contiguous_format
)
view_19 = torch.ops.aten.reshape.default(clone_2, [4096, 4, 32, 4])
permute_6 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3])
clone_3 = torch.ops.aten.clone.default(
permute_6, memory_format=torch.contiguous_format
)
view_20 = torch.ops.aten.reshape.default(clone_3, [4096, 32, 16])
view_21 = torch.ops.aten.reshape.default(view_20, [2097152])
_scaled_mm = torch.ops.test_fp8.fake_scaled_mm.default(
view_3, permute, view_16, view_21, None, None, torch.float32
)
return (_scaled_mm,)
# Run with largest shape
M, K, N = 8192, 8192, 8192
device = "cuda"
A = torch.randn(M, K, dtype=torch.float32, device=device)
B = torch.randn(K, N, dtype=torch.float32, device=device)
f_c = torch.compile(fullgraph=True)(forward)
_, code = run_and_get_code(f_c, A, B)
FileCheck().check(".run(").check(".run(").check("fake_scaled_mm").run(
code[0]
)
for seed in range(5):
input_values.clear()
torch.manual_seed(seed)
# without dividing, outputs get way too large
A = torch.randn(M, K, dtype=torch.float32, device=device)
B = torch.randn(K, N, dtype=torch.float32, device=device)
# Uses fake_scaled_mm custom op (no CUDA 12.8 needed!)
torch._dynamo.reset()
torch.compile(forward)(A, B)
torch._dynamo.reset()
with config.patch({"loop_index_inversion_in_fusion": False}):
torch.compile(forward)(A, B)
assert len(input_values) == 2
for i in range(4):
self.assertEqual(
input_values[0][i],
input_values[1][i],
msg=f"idx {i} seed {seed}",
)
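A minimal standalone sketch of the bit trick the traced forward graph above relies on: the per-block scale comes from reading the float32 exponent field directly (reinterpret as int32, shift out the 23 mantissa bits, mask to 8 bits). The helper name below is made up for illustration:
import torch

def f32_biased_exponent(t: torch.Tensor) -> torch.Tensor:
    # Reinterpret float32 bits as int32, drop the mantissa, keep the 8 exponent bits.
    # Subtracting 127 (as the graph does) would give the unbiased power of two.
    return (t.view(torch.int32) >> 23) & 0xFF

x = torch.tensor([1.0, 2.0, 0.5])
assert f32_biased_exponent(x).tolist() == [127, 128, 126]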
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("M", (1, 3, 33, 257, 1024))
@parametrize("K", (16, 32, 1024))

View File

@ -16,6 +16,7 @@ from torch._dynamo.utils import same
from torch._inductor import config as inductor_config, ir, metrics
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.graph import GraphLowering
from torch._inductor.invert_expr_analysis import generate_inverse_formula
from torch._inductor.scheduler import SchedulerNode
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.test_operators import realize
@ -1188,6 +1189,113 @@ class TestTiling(TestCase):
torch.compile(f)(x)
class TestIndexInversion(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
gm = torch.fx.symbolic_trace(lambda: 0)
graph = GraphLowering(gm)
graph.scheduler = MockScheduler
cls._exit_stack = contextlib.ExitStack()
cls._exit_stack.enter_context(V.set_graph_handler(graph))
def _check_expr(self, expr, reconstruction, val_range):
import numpy as np
from sympy import lambdify
assert len(expr.free_symbols) == 1
p0 = next(iter(expr.free_symbols))
def floordiv_replacement(a, b):
"""Replace FloorDiv(a, b) with a // b"""
return a // b
def modularindexing_replacement(x, base, divisor):
"""Replace ModularIndexing(x, base, divisor) with (x // base) % divisor"""
return (x // base) % divisor
# Replace custom functions with sympy equivalents
expr_numpy_ready = expr.replace(FloorDiv, floordiv_replacement).replace(
ModularIndexing, modularindexing_replacement
)
reconstruction_numpy_ready = reconstruction.replace(
FloorDiv, floordiv_replacement
).replace(ModularIndexing, modularindexing_replacement)
# Now lambdify with standard numpy
forward_func = lambdify(p0, expr_numpy_ready, modules="numpy")
inverse_func = lambdify(p0, reconstruction_numpy_ready, modules="numpy")
test_values = np.arange(0, val_range, dtype=np.int64)
forward_values = forward_func(test_values).astype(np.int64)
recovered_values = inverse_func(forward_values).astype(np.int64)
torch.testing.assert_close(test_values, recovered_values)
@classmethod
def tearDownClass(cls):
super().tearDownClass()
cls._exit_stack.close()
def test_original_complex_expression(self):
"""Test the original motivating complex expression."""
p0 = sympy.Symbol("p0")
expr = (
32768 * FloorDiv(p0, 32768)
+ 8192 * FloorDiv(ModularIndexing(p0, 1, 16), 4)
+ ModularIndexing(p0, 1, 4)
+ 256 * ModularIndexing(p0, 16, 32)
+ 4 * ModularIndexing(p0, 512, 64)
)
reconstruction = generate_inverse_formula(expr, p0)
self.assertIsNotNone(reconstruction)
self._check_expr(expr, reconstruction, 2097152)
def test_inversion_cases(self):
"""Test various expressions for correct inversion behavior."""
p = sympy.Symbol("p")
cases = [
# (expression, should_be_invertible, test_range)
# Simple 2-term base-10 style: 10 = 1 × 10 ✓
(10 * ModularIndexing(p, 10, 10) + ModularIndexing(p, 1, 10), True, 100),
# Simple 2-term base-2 style: 2 = 1 × 2 ✓
(2 * ModularIndexing(p, 2, 2) + ModularIndexing(p, 1, 2), True, 4),
# 3-term decimal: 100 = 10×10, 10 = 1×10 ✓
(
100 * FloorDiv(p, 100)
+ 10 * FloorDiv(ModularIndexing(p, 1, 100), 10)
+ ModularIndexing(p, 1, 10),
True,
1000,
),
(4 * p, False, 64), # expr and inverse not bijections
# when sorted, invertible
(ModularIndexing(p, 1, 10) + 10 * ModularIndexing(p, 10, 10), True, None),
# Wrong coefficient ratios: 4 ≠ 1×2
(4 * ModularIndexing(p, 1, 8) + ModularIndexing(p, 8, 2), False, None),
(
100 * FloorDiv(p, 100) + 7 * ModularIndexing(p, 1, 100),
False,
None,
), # Wrong ratios
(FloorDiv(p, 100) + FloorDiv(p, 10) + p, False, None), # Overlapping ranges
(p**2 + 10 * p + 1, False, None), # Quadratic
(sympy.sin(p) + sympy.cos(p), False, None), # Trigonometric
]
for expr, should_invert, test_range in cases:
reconstruction = generate_inverse_formula(expr, p)
if should_invert:
self.assertIsNotNone(reconstruction, f"Expected invertible: {expr}")
# Test correctness on sample values
self._check_expr(expr, reconstruction, test_range)
else:
self.assertIsNone(reconstruction, f"Expected non-invertible: {expr}")
if __name__ == "__main__":
if HAS_GPU:
run_tests()
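As a quick sanity check of the coefficient-ratio idea behind the invertible cases above (a hand-worked example, not part of the test): when each coefficient equals the product of the next term's coefficient and range, the forward map is just the identity on its range, so it is trivially invertible.
import numpy as np

p = np.arange(100)
tens = (p // 10) % 10              # ModularIndexing(p, 10, 10)
ones = p % 10                      # ModularIndexing(p, 1, 10)
forward = 10 * tens + ones         # the first "invertible" case in the list above
assert np.array_equal(forward, p)  # 10 = 1 * 10, so the map is the identity on [0, 100)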

View File

@ -14424,6 +14424,20 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),))
@skip_if_halide
@requires_cuda_and_triton
def test_unbacked_float_item(self):
def fn(x, max_val):
return torch.clamp(x, 0, max_val.item())
self.common(
fn,
(
torch.randn(10, 20, 30, device=self.device),
torch.tensor(5.0, device=self.device),
),
)
# end of class CommonTemplate - add new tests here

View File

@ -73,7 +73,22 @@ from tools.testing.test_selections import (
ShardedTest,
THRESHOLD,
)
from tools.testing.upload_artifacts import zip_and_upload_artifacts
try:
from tools.testing.upload_artifacts import (
parse_xml_and_upload_json,
zip_and_upload_artifacts,
)
except ImportError:
# some imports in those files might fail, e.g., boto3 not installed. These
# functions are only needed under specific circumstances (CI) so we can
# define dummy functions here.
def parse_xml_and_upload_json():
pass
def zip_and_upload_artifacts(failed: bool):
pass
# Make sure to remove REPO_ROOT after import is done
@ -1887,6 +1902,7 @@ def run_tests(
def handle_complete(failure: Optional[TestFailure]):
failed = failure is not None
if IS_CI and options.upload_artifacts_while_running:
parse_xml_and_upload_json()
zip_and_upload_artifacts(failed)
if not failed:
return False

test/test_as_strided.py (new file, 176 additions)
View File

@ -0,0 +1,176 @@
# Owner(s): ["oncall: pt2"]
from collections import deque
from typing import Optional
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Extract (sizes, strides) tuple from a tensor."""
return (tuple(t.size()), tuple(t.stride()))
def enumerate_reachable_states(
initial_size: int,
) -> set[tuple[tuple[int, ...], tuple[int, ...]]]:
"""
Use BFS with DP to enumerate all reachable (size, stride) states from
a 1D contiguous tensor via valid view operations.
We only explore states with offset=0 (you can retroactively change the offset).
We reject states with size=0 or size=1 dimensions as they are degenerate.
"""
# Create initial 1D contiguous tensor
initial_tensor = torch.arange(initial_size)
initial_state = get_state(initial_tensor)
# Map from state to tensor for that state
state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = {
initial_state: initial_tensor
}
visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state}
queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state])
while queue:
state = queue.popleft()
t = state_to_tensor[state]
sizes, strides = state
ndim = len(sizes)
def add_state(new_t: torch.Tensor) -> None:
new_state = get_state(new_t)
sizes, strides = new_state
# Skip if has size-0 or size-1 dimensions
if any(s == 0 or s == 1 for s in sizes):
return
# Only accept states where strides are in descending order
if list(strides) != sorted(strides, reverse=True):
return
if new_state not in visited:
visited.add(new_state)
queue.append(new_state)
state_to_tensor[new_state] = new_t
# 1. Unflatten: try factoring each dimension
for dim in range(ndim):
size = sizes[dim]
assert size > 1
# Try all factorizations x * y = size where both x, y >= 2
# We only need to check x up to size // 2 since when x > size // 2,
# y = size // x < 2, which we reject
for x in range(2, size // 2 + 1):
if size % x == 0:
y = size // x
add_state(t.unflatten(dim, (x, y)))
# 2. Slice: exhaustively check all possible slicing parameters
for dim in range(ndim):
size = sizes[dim]
for start in range(size):
for stop in range(start + 1, size + 1):
for step in range(1, size + 1):
slices = [slice(None)] * ndim
slices[dim] = slice(start, stop, step)
add_state(t[tuple(slices)])
# 3. Flatten: merge adjacent dimensions
for dim in range(ndim - 1):
add_state(t.flatten(dim, dim + 1))
return visited
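An illustrative, hand-worked use of the helper above on a tiny tensor; the states in the comments were derived by applying the unflatten/slice/flatten rules by hand, so treat this as an informal check rather than part of the test:
states = enumerate_reachable_states(4)
assert ((4,), (1,)) in states        # the initial contiguous state
assert ((2, 2), (2, 1)) in states    # unflatten 4 = 2 * 2
assert ((2,), (3,)) in states        # slice with step 3 keeps indices 0 and 3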
class TestAsStrided(TestCase):
def test_size_10_exhaustive(self) -> None:
"""Test that size 10 produces exactly the expected 54 states."""
expected_states = {
((2,), (1,)),
((2,), (2,)),
((2,), (3,)),
((2,), (4,)),
((2,), (5,)),
((2,), (6,)),
((2,), (7,)),
((2,), (8,)),
((2,), (9,)),
((2, 2), (2, 1)),
((2, 2), (3, 1)),
((2, 2), (3, 2)),
((2, 2), (4, 1)),
((2, 2), (4, 2)),
((2, 2), (4, 3)),
((2, 2), (5, 1)),
((2, 2), (5, 2)),
((2, 2), (5, 3)),
((2, 2), (5, 4)),
((2, 2), (6, 1)),
((2, 2), (6, 2)),
((2, 2), (6, 3)),
((2, 2), (8, 1)),
((2, 2, 2), (4, 2, 1)),
((2, 2, 2), (5, 2, 1)),
((2, 3), (3, 1)),
((2, 3), (4, 1)),
((2, 3), (5, 1)),
((2, 3), (5, 2)),
((2, 3), (6, 1)),
((2, 4), (4, 1)),
((2, 4), (5, 1)),
((2, 5), (5, 1)),
((3,), (1,)),
((3,), (2,)),
((3,), (3,)),
((3,), (4,)),
((3, 2), (2, 1)),
((3, 2), (3, 1)),
((3, 2), (3, 2)),
((3, 2), (4, 1)),
((3, 3), (3, 1)),
((4,), (1,)),
((4,), (2,)),
((4,), (3,)),
((4, 2), (2, 1)),
((5,), (1,)),
((5,), (2,)),
((5, 2), (2, 1)),
((6,), (1,)),
((7,), (1,)),
((8,), (1,)),
((9,), (1,)),
((10,), (1,)),
}
actual_states = enumerate_reachable_states(10)
self.assertEqual(len(actual_states), 54)
self.assertEqual(actual_states, expected_states)
def test_subset_property(self) -> None:
"""
Test that for sizes 2..10, each smaller tensor results in a strict
subset of possible states compared to the next one.
"""
prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None
for size in range(2, 11):
current_states = enumerate_reachable_states(size)
if prev_states is not None:
# Check that prev_states is a strict subset of current_states
self.assertTrue(
prev_states.issubset(current_states),
f"States from size {size - 1} are not a subset of size {size}",
)
# Check that it's a strict subset (not equal)
self.assertTrue(
len(prev_states) < len(current_states),
f"States from size {size - 1} should be strictly fewer than size {size}",
)
prev_states = current_states
if __name__ == "__main__":
run_tests()

View File

@ -75,12 +75,6 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.jit_utils import JitTestCase
import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events
try:
from torchvision import models as torchvision_models
@ -207,36 +201,6 @@ def side_effect_func(x: torch.Tensor):
print(x)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
trace_file = f.name
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
class TestFX(JitTestCase):
def setUp(self):
super().setUp()
@ -4248,150 +4212,6 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
# recover mutable checking flag
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
augments profiler events with stack traces from FX metadata registry.
"""
# Simple test model
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(16, 10)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
model = TestModel().cuda()
# Compile the model
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"))
# Profile with the compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::t node=t stack_trace=x = self.linear1(x)
event=aten::transpose node=t stack_trace=x = self.linear1(x)
event=aten::as_strided node=t stack_trace=x = self.linear1(x)
event=aten::addmm node=addmm stack_trace=x = self.linear1(x)
event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x)
event=aten::relu node=relu stack_trace=x = self.relu(x)
event=aten::clamp_min node=relu stack_trace=x = self.relu(x)
event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x)
event=aten::t node=t_1 stack_trace=x = self.linear2(x)
event=aten::transpose node=t_1 stack_trace=x = self.linear2(x)
event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x)
event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x)
event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
Test that multiple compiled modules under the same profiler session
have their events correctly augmented with stack traces.
"""
class ModelA(torch.nn.Module):
def forward(self, x):
return x + 1
class ModelB(torch.nn.Module):
def forward(self, x):
return x - 1
model_a = ModelA().cuda()
model_b = ModelB().cuda()
# Compile both models
compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True)
compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_a(torch.randn(10, 10, device="cuda"))
_ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
# Profile both models in the same session
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result_a = compiled_a(torch.randn(10, 10, device="cuda"))
result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::add node=add stack_trace=return x + 1
event=cudaLaunchKernel node=add stack_trace=return x + 1
event=aten::sub node=sub stack_trace=return x - 1
event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""
Test that nested graph modules (e.g., graph modules calling subgraphs)
have their events correctly augmented with stack traces.
"""
# Model with nested structure
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 5
@torch.compiler.nested_compile_region
def forward(self, x, y):
m = torch.mul(x, y)
s = m.sin()
a = s + self.c
return a
model = Mod().cuda()
# Compile the model (this may create nested graph modules)
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
# Profile
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::mul node=mul stack_trace=m = torch.mul(x, y)
event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y)
event=aten::sin node=sin stack_trace=s = m.sin()
event=cudaLaunchKernel node=sin stack_trace=s = m.sin()
event=aten::add node=add stack_trace=a = s + self.c
event=cudaLaunchKernel node=add stack_trace=a = s + self.c"""
)
def run_getitem_target():
from torch.fx._symbolic_trace import _wrapped_methods_to_patch

View File

@ -490,8 +490,6 @@ class TestMatmulCuda(InductorTestCase):
@parametrize("b_row_major", [False, True])
@dtypes(torch.bfloat16, torch.float32, torch.float16)
def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
self.skipTest("failed using hipblaslt on rocm 6.4.2")
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4

View File

@ -601,6 +601,24 @@ class TestGenericPytree(TestCase):
for case in cases:
run_test(case)
@parametrize_pytree_module
def test_tree_map_dict_order(self, pytree):
d = {"b": 2, "a": 1, "c": 3}
od = OrderedDict([("b", 2), ("a", 1), ("c", 3)])
dd = defaultdict(int, {"b": 2, "a": 1, "c": 3})
for tree in (d, od, dd):
result = pytree.tree_map(lambda x: x, tree)
self.assertEqual(
list(result.keys()),
list(tree.keys()),
msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. {result!r}",
)
self.assertEqual(
list(result.values()),
list(tree.values()),
msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. {result!r}",
)
@parametrize_pytree_module
def test_tree_map_only(self, pytree):
self.assertEqual(pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"])

View File

@ -38,12 +38,14 @@ def parse_xml_report(
report: Path,
workflow_id: int,
workflow_run_attempt: int,
job_id: int | None = None,
) -> list[dict[str, Any]]:
"""Convert a test report xml file into a JSON-serializable list of test cases."""
print(f"Parsing {tag}s for test report: {report}")
job_id = get_job_id(report)
print(f"Found job id: {job_id}")
if job_id is None:
job_id = get_job_id(report)
print(f"Found job id: {job_id}")
test_cases: list[dict[str, Any]] = []

View File

@ -1,11 +1,16 @@
import glob
import gzip
import json
import os
import time
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import Any, Optional
from filelock import FileLock, Timeout
from tools.stats.upload_test_stats import parse_xml_report
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
@ -140,3 +145,66 @@ def trigger_upload_test_stats_intermediate_workflow() -> None:
},
)
print(x.text)
def parse_xml_and_upload_json() -> None:
"""
Parse xml test reports that do not yet have a corresponding json report
uploaded to s3, and upload the json reports to s3. Use filelock to avoid
uploading the same file from multiple processes.
"""
try:
job_id: Optional[int] = int(os.environ.get("JOB_ID", 0))
if job_id == 0:
job_id = None
except (ValueError, TypeError):
job_id = None
try:
for xml_file in glob.glob(
f"{REPO_ROOT}/test/test-reports/**/*.xml", recursive=True
):
xml_path = Path(xml_file)
json_file = xml_path.with_suffix(".json")
lock = FileLock(str(json_file) + ".lock")
try:
lock.acquire(timeout=0) # immediately fails if already locked
if json_file.exists():
continue # already uploaded
test_cases = parse_xml_report(
"testcase",
xml_path,
int(os.environ.get("GITHUB_RUN_ID", "0")),
int(os.environ.get("GITHUB_RUN_ATTEMPT", "0")),
job_id,
)
line_by_line_jsons = "\n".join([json.dumps(tc) for tc in test_cases])
gzipped = gzip.compress(line_by_line_jsons.encode("utf-8"))
s3_key = (
json_file.relative_to(REPO_ROOT / "test/test-reports")
.as_posix()
.replace("/", "_")
)
get_s3_resource().put_object(
Body=gzipped,
Bucket="gha-artifacts",
Key=f"test_jsons_while_running/{os.environ.get('GITHUB_RUN_ID')}/{job_id}/{s3_key}",
ContentType="application/json",
ContentEncoding="gzip",
)
# We don't need to save the json file locally, but doing so lets us
# track which ones have been uploaded already. We could probably also
# check S3
with open(json_file, "w") as f:
f.write(line_by_line_jsons)
except Timeout:
continue # another process is working on this file
finally:
if lock.is_locked:
lock.release()
except Exception as e:
print(f"Failed to parse and upload json test reports: {e}")

View File

@ -6,7 +6,7 @@ from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING
from typing import Any, TYPE_CHECKING, TypeVar
from typing_extensions import TypeIs
import torch.utils._pytree as python_pytree
@ -24,9 +24,15 @@ if TYPE_CHECKING:
__all__: list[str] = []
_T = TypeVar("_T")
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
if python_pytree._cxx_pytree_dynamo_traceable:
import optree
import optree._C
import optree.utils
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
@ -600,14 +606,47 @@ if python_pytree._cxx_pytree_dynamo_traceable:
__all__ += ["tree_map_"]
_none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr, attr-defined]
_none_registration = optree.register_pytree_node.get(type(None))
assert _none_registration is not None
@substitute_in_graph( # type: ignore[arg-type]
_none_unflatten,
_none_registration.unflatten_func,
can_constant_fold_through=True,
skip_signature_check=True,
)
def none_unflatten(_: None, children: Iterable[Any], /) -> None:
def none_unflatten(_: None, children: Iterable[_T], /) -> None:
if len(list(children)) != 0:
raise ValueError("Expected no children.")
return None
with optree.dict_insertion_ordered(False, namespace="torch"):
_dict_registration = optree.register_pytree_node.get(dict)
assert _dict_registration is not None
@substitute_in_graph( # type: ignore[arg-type]
_dict_registration.flatten_func,
can_constant_fold_through=True,
skip_signature_check=True,
)
def dict_flatten(
dct: dict[_KT, _VT], /
) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]:
sorted_keys = optree.utils.total_order_sorted(dct)
values = [dct[key] for key in sorted_keys]
original_keys = list(dct)
return values, (original_keys, sorted_keys), tuple(sorted_keys)
@substitute_in_graph( # type: ignore[arg-type]
_dict_registration.unflatten_func,
can_constant_fold_through=True,
skip_signature_check=True,
)
def dict_unflatten(
metadata: tuple[list[_KT], list[_KT]],
values: Iterable[_VT],
/,
) -> dict[_KT, _VT]:
original_keys, sorted_keys = metadata
d = dict.fromkeys(original_keys)
d.update(zip(sorted_keys, values))
return d # type: ignore[return-value]
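A plain-Python sketch of the flatten/unflatten contract above, assuming the same metadata layout of (original_keys, sorted_keys); sorted() stands in for optree.utils.total_order_sorted, and this is not the optree implementation itself:
def flatten(dct):
    sorted_keys = sorted(dct)
    return [dct[k] for k in sorted_keys], (list(dct), sorted_keys)

def unflatten(metadata, values):
    original_keys, sorted_keys = metadata
    d = dict.fromkeys(original_keys)   # recreate the original insertion order
    d.update(zip(sorted_keys, values))
    return d

tree = {"b": 2, "a": 1, "c": 3}
values, metadata = flatten(tree)
assert list(unflatten(metadata, [v * 10 for v in values])) == ["b", "a", "c"]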

View File

@ -3320,7 +3320,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "add", [v], {})
obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
def SET_UPDATE(self, inst: Instruction) -> None:
v = self.pop()
@ -3329,7 +3329,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {})
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
def LIST_APPEND(self, inst: Instruction) -> None:
v = self.pop()
@ -3637,7 +3637,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg].realize()
assert isinstance(obj, ConstDictVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {})
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
DICT_UPDATE = DICT_MERGE

View File

@ -1991,7 +1991,7 @@ class BuiltinVariable(VariableTracker):
# If the object implements a __getitem__ method, iter(...) will call obj.__getitem__()
# with an integer argument starting at 0, until __getitem__ raises IndexError
ret = variables.UserFunctionVariable(
polyfills.builtins.iter_
polyfills.builtins.iter_ # type: ignore[arg-type]
).call_function(tx, [obj, *args], {})
if args:

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
Dictionary-related variable tracking classes for PyTorch Dynamo.
@ -26,7 +24,7 @@ import inspect
import operator
import types
from collections.abc import Hashable as py_Hashable
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING, Union
from torch._subclasses.fake_tensor import is_fake
@ -59,11 +57,13 @@ if TYPE_CHECKING:
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def was_instancecheck_override(obj):
def was_instancecheck_override(obj: Any) -> bool:
return type(obj).__dict__.get("__instancecheck__", False)
def raise_unhashable(arg, tx=None):
def raise_unhashable(
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
) -> None:
if tx is None:
from torch._dynamo.symbolic_convert import InstructionTranslator
@ -75,7 +75,7 @@ def raise_unhashable(arg, tx=None):
)
def is_hashable(x):
def is_hashable(x: VariableTracker) -> bool:
# NB - performing isinstance check on a LazyVT realizes the VT, accidentally
# inserting the guard. To avoid this, the LazyVT `is_hashable` method looks at
# the underlying value without realizing the VT. Consider updating the
@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
"""
def __init__(self, vt) -> None:
def __init__(self, vt: VariableTracker) -> None:
# We specialize SymNodes
vt = specialize_symnode(vt)
# TODO Temporarily remove to figure out what keys are we breaking on
@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
self.vt = vt
@property
def underlying_value(self):
def underlying_value(self) -> Any:
if (
isinstance(self.vt, variables.LazyVariableTracker)
and not self.vt.is_realized()
@ -178,7 +178,8 @@ class ConstDictVariable(VariableTracker):
elif isinstance(self.vt, variables.FrozenDataClassVariable):
Hashable = ConstDictVariable._HashableTracker
fields_values = {
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
k: Hashable(v).underlying_value
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
}
return variables.FrozenDataClassVariable.HashWrapper(
self.vt.python_type(), fields_values
@ -187,16 +188,16 @@ class ConstDictVariable(VariableTracker):
# The re module in Python 3.13+ has a dictionary (_cache2) with
# an object as key (`class _ZeroSentinel(int): ...`):
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
return self.vt.value
return self.vt.value # type: ignore[attr-defined,union-attr]
else:
x = self.vt.as_python_constant()
return x
def __hash__(self):
def __hash__(self) -> int:
return hash(self.underlying_value)
@staticmethod
def _eq_impl(a, b):
def _eq_impl(a: Any, b: Any) -> bool:
# TODO: Put this in utils and share it between variables/builtin.py and here
type_a, type_b = type(a), type(b)
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
@ -212,7 +213,7 @@ class ConstDictVariable(VariableTracker):
else:
return a == b
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
def __eq__(self, other: object) -> bool:
Hashable = ConstDictVariable._HashableTracker
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
type(other)
@ -226,8 +227,8 @@ class ConstDictVariable(VariableTracker):
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls=dict,
**kwargs,
user_cls: type = dict,
**kwargs: Any,
) -> None:
# .clone() pass these arguments in kwargs but they're recreated a few
# lines below
@ -247,18 +248,22 @@ class ConstDictVariable(VariableTracker):
for x, v in items.items()
)
def make_hashable(key):
def make_hashable(
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
) -> "ConstDictVariable._HashableTracker":
return key if isinstance(key, Hashable) else Hashable(key)
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
# need to reconstruct everything if the dictionary is an intermediate value
# or if a pop/delitem was executed
self.should_reconstruct_all = not is_from_local_source(self.source)
self.should_reconstruct_all = (
not is_from_local_source(self.source) if self.source else True
)
self.original_items = items.copy()
self.user_cls = user_cls
def _get_dict_cls_from_user_cls(self, user_cls):
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
# avoid executing user code if user_cls is a dict subclass
@ -277,10 +282,10 @@ class ConstDictVariable(VariableTracker):
dict_cls = dict
return dict_cls
def as_proxy(self):
def as_proxy(self) -> dict[Any, Any]:
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
def debug_repr(self):
def debug_repr(self) -> str:
return (
"{"
+ ", ".join(
@ -289,20 +294,20 @@ class ConstDictVariable(VariableTracker):
+ "}"
)
def as_python_constant(self):
def as_python_constant(self) -> dict[Any, Any]:
return {
k.vt.as_python_constant(): v.as_python_constant()
for k, v in self.items.items()
}
def keys_as_python_constant(self):
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
self.install_dict_keys_match_guard()
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
def python_type(self):
def python_type(self) -> type:
return self.user_cls
def __contains__(self, vt) -> bool:
def __contains__(self, vt: VariableTracker) -> bool:
assert isinstance(vt, VariableTracker)
Hashable = ConstDictVariable._HashableTracker
return (
@ -322,13 +327,15 @@ class ConstDictVariable(VariableTracker):
for key, value in self.items.items()
)
def is_new_item(self, value, other):
def is_new_item(
self, value: Optional[VariableTracker], other: VariableTracker
) -> bool:
# compare the id of the realized values if both values are not lazy VTs
if value and value.is_realized() and other.is_realized():
return id(value.realize()) != id(other.realize())
return id(value) != id(other)
def reconstruct_kvs_into_new_dict(self, codegen):
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
# Build a dictionary that contains the keys and values.
num_args = 0
for key, value in self.items.items():
@ -340,7 +347,7 @@ class ConstDictVariable(VariableTracker):
num_args += 1
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
if self.user_cls is collections.OrderedDict:
# emit `OrderedDict(constructed_dict)`
codegen.add_push_null(
@ -358,19 +365,21 @@ class ConstDictVariable(VariableTracker):
def getitem_const_raise_exception_if_absent(
self, tx: "InstructionTranslator", arg: VariableTracker
):
) -> VariableTracker:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
raise_observed_exception(KeyError, tx)
return self.items[key]
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
msg = f"Dictionary key {arg.value} not found during tracing"
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
unimplemented_v2(
gb_type="key not found in dict",
context=f"Key {arg.value}",
context=f"Key {arg.value}", # type: ignore[attr-defined]
explanation=msg,
hints=[
"Check if the key exists in the dictionary before accessing it.",
@ -379,13 +388,13 @@ class ConstDictVariable(VariableTracker):
)
return self.items[key]
def maybe_getitem_const(self, arg: VariableTracker):
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
return None
return self.items[key]
def realize_key_vt(self, arg: VariableTracker):
def realize_key_vt(self, arg: VariableTracker) -> None:
# Realize the LazyVT on a particular index
assert arg in self
key = ConstDictVariable._HashableTracker(arg)
@ -394,11 +403,13 @@ class ConstDictVariable(VariableTracker):
if isinstance(original_key_vt, variables.LazyVariableTracker):
original_key_vt.realize()
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
if self.source:
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
# Key guarding - These are the cases to consider
# 1) The dict has been mutated. In this case, we would have already
# inserted a DICT_KEYS_MATCH guard, so we can skip.
@ -439,11 +450,11 @@ class ConstDictVariable(VariableTracker):
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
# we have to insert guards when a dict method is accessed. For this to
# be simple, we are conservative and overguard. We skip guard only for
@ -462,7 +473,7 @@ class ConstDictVariable(VariableTracker):
tx, *args, **kwargs
)
tx.output.side_effects.mutation(self)
self.items.update(temp_dict_vt.items)
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
return ConstantVariable.create(None)
elif name == "__getitem__":
# Key guarding - Nothing to do. LazyVT for value will take care.
@ -526,7 +537,7 @@ class ConstDictVariable(VariableTracker):
return ConstantVariable.create(len(self.items))
elif name == "__setitem__" and self.is_mutable():
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_keys_match_guard()
if kwargs or len(args) != 2:
@ -550,7 +561,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
if args[0] not in self:
self.install_dict_contains_guard(tx, args)
@ -565,7 +576,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
if args[0] not in self:
# missing item, return the default value. Install no DICT_CONTAINS guard.
@ -599,7 +610,7 @@ class ConstDictVariable(VariableTracker):
last = v.value
else:
raise_args_mismatch(tx, name)
k, v = self.items.popitem(last=last)
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
else:
k, v = self.items.popitem()
@ -632,17 +643,17 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard()
dict_vt = args[0]
dict_vt: ConstDictVariable = args[0]
else:
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
self.items.update(dict_vt.items)
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
self.items.update(dict_vt.items) # type: ignore[attr-defined]
if has_kwargs:
# Handle kwargs
kwargs = {
kwargs_hashable = {
Hashable(ConstantVariable.create(k)): v
for k, v in kwargs.items()
}
self.items.update(kwargs)
self.items.update(kwargs_hashable)
return ConstantVariable.create(None)
else:
return super().call_method(tx, name, args, kwargs)
@ -656,7 +667,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_contains_guard(tx, args)
contains = args[0] in self
@ -671,7 +682,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_keys_match_guard()
if kwargs or len(args) > 2:
@ -707,7 +718,7 @@ class ConstDictVariable(VariableTracker):
and "last" in kwargs
and isinstance(kwargs["last"], ConstantVariable)
):
last = kwargs.get("last").value
last = kwargs.get("last").value # type: ignore[union-attr]
key = Hashable(args[0])
self.items.move_to_end(key, last=last)
@ -723,7 +734,7 @@ class ConstDictVariable(VariableTracker):
)
elif name == "__ne__":
return ConstantVariable.create(
not self.call_method(tx, "__eq__", args, kwargs).value
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
)
elif name == "__or__":
if len(args) != 1:
@ -750,14 +761,14 @@ class ConstDictVariable(VariableTracker):
if not istype(
other, (ConstDictVariable, variables.UserDefinedDictVariable)
):
msg = (
err_msg = (
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
f"and '{other.python_type().__name__}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
raise_observed_exception(TypeError, tx, args=[err_msg])
# OrderedDict overloads __ror__
ts = {self.user_cls, other.user_cls}
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
user_cls = (
collections.OrderedDict
if any(issubclass(t, collections.OrderedDict) for t in ts)
@ -774,8 +785,8 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard()
new_dict_vt.items.update(args[0].items)
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
return new_dict_vt
elif name == "__ior__":
self.call_method(tx, "update", args, kwargs)
@ -789,11 +800,13 @@ class ConstDictVariable(VariableTracker):
else:
return super().call_method(tx, name, args, kwargs)
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
self.install_dict_keys_match_guard()
return [x.vt for x in self.items.keys()]
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
# dict does not allow setting arbitrary attributes. OrderedDict and
# defaultdict allow arbitrary setattr, but not deletion of default attrs
if any(
@ -816,25 +829,25 @@ class ConstDictVariable(VariableTracker):
],
)
def clone(self, **kwargs):
def clone(self, **kwargs: Any) -> VariableTracker:
self.install_dict_keys_match_guard()
return super().clone(**kwargs)
class MappingProxyVariable(VariableTracker):
# proxies to the original dict_vt
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
super().__init__(**kwargs)
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
def python_type(self):
def python_type(self) -> type:
return types.MappingProxyType
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return self.dv_dict.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
# load types.MappingProxyType
if self.source:
msg = (
@ -863,11 +876,11 @@ class MappingProxyVariable(VariableTracker):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if self.source and tx.output.side_effects.has_existing_dict_mutation():
msg = (
"A dict has been modified while we have an existing mappingproxy object. "
@ -892,7 +905,7 @@ class MappingProxyVariable(VariableTracker):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is types.MappingProxyType:
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
return super().call_obj_hasattr(tx, name)
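For context on why the mutation check above matters, a plain-Python reminder that types.MappingProxyType is a live read-only view of the underlying dict, not a snapshot:
import types

d = {"a": 1}
mp = types.MappingProxyType(d)
d["b"] = 2          # mutating the dict is immediately visible through the proxy
assert mp["b"] == 2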
@ -900,35 +913,44 @@ class MappingProxyVariable(VariableTracker):
class NNModuleHooksDictVariable(ConstDictVariable):
# Special class to avoid adding any guards on the nn module hook ids.
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
pass
class DefaultDictVariable(ConstDictVariable):
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls: type,
default_factory: Optional[VariableTracker] = None,
**kwargs: Any,
) -> None:
super().__init__(items, user_cls, **kwargs)
assert user_cls is collections.defaultdict
if default_factory is None:
default_factory = ConstantVariable.create(None)
self.default_factory = default_factory
def is_python_constant(self):
def is_python_constant(self) -> bool:
# Return false for unsupported defaults. This ensures that a bad handler
# path is not taken in BuiltinVariable for getitem.
if self.default_factory not in [list, tuple, dict] and not self.items:
return False
return super().is_python_constant()
def debug_repr(self):
def debug_repr(self) -> str:
assert self.default_factory is not None
return (
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
)
@staticmethod
def is_supported_arg(arg):
def is_supported_arg(arg: VariableTracker) -> bool:
if isinstance(arg, variables.BuiltinVariable):
return arg.fn in (list, tuple, dict, set)
else:
@ -942,11 +964,11 @@ class DefaultDictVariable(ConstDictVariable):
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__getitem__":
if len(args) != 1:
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
@ -962,13 +984,13 @@ class DefaultDictVariable(ConstDictVariable):
else:
default_var = self.default_factory.call_function(tx, [], {})
super().call_method(
tx, "__setitem__", (args[0], default_var), kwargs
tx, "__setitem__", [args[0], default_var], kwargs
)
return default_var
else:
return super().call_method(tx, name, args, kwargs)
def reconstruct(self, codegen):
def reconstruct(self, codegen: "PyCodegen") -> None:
# emit `defaultdict(default_factory, new_dict)`
codegen.add_push_null(
lambda: codegen.extend_output(
@ -994,40 +1016,48 @@ class SetVariable(ConstDictVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
# pyrefly: ignore[bad-assignment]
items = dict.fromkeys(items, SetVariable._default_value())
# pyrefly: ignore[bad-argument-type]
super().__init__(items, **kwargs)
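A plain-Python reminder of the trick SetVariable leans on here: dict.fromkeys both deduplicates and preserves insertion order, so a set can be modelled as a dict whose values are all None:
items = [3, 1, 3, 2]
as_dict = dict.fromkeys(items)      # {3: None, 1: None, 2: None}
assert list(as_dict) == [3, 1, 2]   # deduplicated, insertion order kept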
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "set()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self):
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
return set(self.items.keys())
@staticmethod
def _default_value():
def _default_value() -> VariableTracker:
# Variable to fill in the keys of the dictionary
return ConstantVariable.create(None)
def as_proxy(self):
def as_proxy(self) -> Any:
return {k.vt.as_proxy() for k in self.set_items}
def python_type(self):
def python_type(self) -> type:
return set
def as_python_constant(self):
def as_python_constant(self) -> Any:
return {k.vt.as_python_constant() for k in self.set_items}
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach([x.vt for x in self.set_items])
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
def _fast_set_method(self, tx, fn, args, kwargs):
def _fast_set_method(
self,
tx: "InstructionTranslator",
fn: Any,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
try:
res = fn(
*[x.as_python_constant() for x in [self, *args]],
@ -1037,15 +1067,16 @@ class SetVariable(ConstDictVariable):
raise_observed_exception(
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
)
# pyrefly: ignore[unbound-name]
return VariableTracker.build(tx, res)
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
# We forward the calls to the dictionary model
from ..utils import check_constant_args
@ -1065,10 +1096,10 @@ class SetVariable(ConstDictVariable):
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
if name == "__init__":
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
tx.output.side_effects.mutation(self)
self.items.clear()
self.items.update(temp_set_vt.items)
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
return ConstantVariable.create(None)
elif name == "add":
if kwargs or len(args) != 1:
@ -1079,7 +1110,7 @@ class SetVariable(ConstDictVariable):
f"{len(args)} args and {len(kwargs)} kwargs",
)
name = "__setitem__"
args = (args[0], SetVariable._default_value())
args = [args[0], SetVariable._default_value()]
elif name == "pop":
if kwargs or args:
raise_args_mismatch(
@ -1090,12 +1121,14 @@ class SetVariable(ConstDictVariable):
)
# Choose an item at random and pop it via the Dict.pop method
try:
result = self.set_items.pop().vt
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
except KeyError as e:
raise_observed_exception(
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
)
super().call_method(tx, name, (result,), kwargs)
# pyrefly: ignore[unbound-name]
super().call_method(tx, name, [result], kwargs)
# pyrefly: ignore[unbound-name]
return result
elif name == "isdisjoint":
if kwargs or len(args) != 1:
@ -1217,6 +1250,7 @@ class SetVariable(ConstDictVariable):
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
assert m is not None
return self.call_method(tx, m, args, kwargs)
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
@ -1230,29 +1264,34 @@ class SetVariable(ConstDictVariable):
"__ixor__": "symmetric_difference_update",
"__isub__": "difference_update",
}.get(name)
assert m is not None
self.call_method(tx, m, args, kwargs)
return self
elif name == "__eq__":
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(False)
r = self.call_method(tx, "symmetric_difference", args, kwargs)
return ConstantVariable.create(len(r.set_items) == 0)
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
elif name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
)
return super().call_method(tx, name, args, kwargs)
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
raise RuntimeError("Illegal to getitem on a set")
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
super().install_dict_contains_guard(tx, args)
@ -1260,27 +1299,27 @@ class FrozensetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "frozenset()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self):
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
return self.items.keys()
def python_type(self):
def python_type(self) -> type:
return frozenset
def as_python_constant(self):
def as_python_constant(self) -> Any:
return frozenset({k.vt.as_python_constant() for k in self.set_items})
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach([x.vt for x in self.set_items])
codegen.add_push_null(
lambda: codegen.extend_output(
@ -1293,11 +1332,11 @@ class FrozensetVariable(SetVariable):
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
elif name == "__init__":
@ -1316,7 +1355,7 @@ class FrozensetVariable(SetVariable):
"symmetric_difference",
):
r = super().call_method(tx, name, args, kwargs)
return FrozensetVariable(r.items)
return FrozensetVariable(r.items) # type: ignore[attr-defined]
return super().call_method(tx, name, args, kwargs)
@ -1324,11 +1363,11 @@ class DictKeySetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "dict_keys([])"
else:
@ -1338,33 +1377,35 @@ class DictKeySetVariable(SetVariable):
+ "])"
)
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
# Already EQUALS_MATCH guarded
pass
@property
def set_items(self):
def set_items(self) -> Any:
return self.items
def python_type(self):
def python_type(self) -> type:
return dict_keys
def as_python_constant(self):
def as_python_constant(self) -> Any:
return dict.fromkeys(
{k.vt.as_python_constant() for k in self.set_items}, None
).keys()
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
return super().call_method(tx, name, args, kwargs)
@ -1379,42 +1420,47 @@ class DictViewVariable(VariableTracker):
kv: Optional[str] = None
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
super().__init__(**kwargs)
assert self.kv in ("keys", "values", "items")
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
@property
def view_items(self):
def view_items(self) -> Any:
assert self.kv is not None
return getattr(self.dv_dict.items, self.kv)()
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
# Implement in the subclasses
raise NotImplementedError
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return self.view_items_vt
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
assert self.kv is not None
codegen(self.dv_dict)
codegen.load_method(self.kv)
codegen.call_method(0)
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
assert self.kv is not None
if name in self.python_type().__dict__:
return ConstantVariable.create(True)
return ConstantVariable.create(False)
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__len__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name == "__iter__":
@ -1428,24 +1474,24 @@ class DictKeysVariable(DictViewVariable):
kv = "keys"
@property
def set_items(self):
def set_items(self) -> set[VariableTracker]:
return set(self.view_items)
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
return [x.vt for x in self.view_items]
def python_type(self):
def python_type(self) -> type:
return dict_keys
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__contains__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name in (
@ -1460,13 +1506,13 @@ class DictKeysVariable(DictViewVariable):
):
# These methods always return a set
m = getattr(self.set_items, name)
r = m(args[0].set_items)
r = m(args[0].set_items) # type: ignore[attr-defined]
return SetVariable(r)
if name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
)
return super().call_method(tx, name, args, kwargs)
@ -1476,10 +1522,10 @@ class DictValuesVariable(DictViewVariable):
kv = "values"
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
return list(self.view_items)
def python_type(self):
def python_type(self) -> type:
return dict_values
@ -1487,14 +1533,20 @@ class DictItemsVariable(DictViewVariable):
kv = "items"
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
def python_type(self):
def python_type(self) -> type:
return dict_items
def call_method(self, tx, name, args, kwargs):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
# TODO(guilhermeleobas): This should actually check if args[0]
# implements the mapping protocol.
if name == "__eq__":

File diff suppressed because it is too large

View File

@ -586,7 +586,7 @@ class FilterVariable(IteratorVariable):
else:
res = self.fn.call_function(tx, [item], {})
pred_res = variables.UserFunctionVariable(
polyfills.predicate
polyfills.predicate # type: ignore[arg-type]
).call_function(tx, [res], {})
if pred_res.as_python_constant():
return item

View File

@ -472,7 +472,12 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
)
elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined]
name_to_arg_map = bind_args_cached(
self.value, tx, self.source, args, kwargs
# pyrefly: ignore[bad-argument-type]
self.value,
tx,
self.source,
args,
kwargs,
)
backends = name_to_arg_map["backends"].as_python_constant()
set_priority = name_to_arg_map["set_priority"].as_python_constant()
@ -1349,7 +1354,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
packed_input_vt = TupleVariable.build(
tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
)
out_vt = variables.UserFunctionVariable(tree_flatten).call_function(
out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type]
tx, [packed_input_vt], {}
)
assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2

View File

@ -2970,6 +2970,12 @@ class CppPythonBindingsCodeCache(CppCodeCache):
throw std::runtime_error("expected int arg");
return reinterpret_cast<uintptr_t>(result);
}}
template <> inline float parse_arg<float>(PyObject* args, size_t n) {{
auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
if(unlikely(result == -1.0 && PyErr_Occurred()))
throw std::runtime_error("expected float arg");
return static_cast<float>(result);
}}
{extra_parse_arg}

View File

@ -1732,9 +1732,15 @@ class KernelArgs:
call_args.append(self.wrap_ptr_arg(outer, dtype))
arg_types.append(f"{cpp_dtype}*")
for outer, inner in self.sizevars.items():
arg_defs.append(f"const {INDEX_TYPE} {inner}")
if isinstance(outer, sympy.Symbol) and symbol_is_type(
outer, (SymT.UNBACKED_FLOAT)
):
arg_defs.append(f"const float {inner}")
arg_types.append("const float")
else:
arg_defs.append(f"const {INDEX_TYPE} {inner}")
arg_types.append(f"const {INDEX_TYPE}")
call_args.append(self.wrap_size_arg(outer))
arg_types.append(f"const {INDEX_TYPE}")
if V.graph.wrapper_code:
V.graph.wrapper_code.ensure_size_computed(outer)
assert not self.workspace_args, "Workspace not supported on CPU "
@ -2353,6 +2359,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
SymT.UNBACKED_INT,
SymT.SIZE,
SymT.PRECOMPUTED_SIZE,
SymT.UNBACKED_FLOAT,
),
)
}

View File

@ -4,6 +4,7 @@ from typing import Any, Optional
import sympy
import torch
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import config
from ..runtime.hints import AttrsDescriptorWrapper
@ -71,6 +72,10 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
return "constexpr"
elif isinstance(arg.expr, (float, sympy.Float)):
return "fp32"
elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type(
arg.expr, (SymT.UNBACKED_FLOAT)
):
return "fp32"
elif isinstance(arg.expr, bool):
return "i1"

View File

@ -546,10 +546,6 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]
cutedsl_enable_autotuning: bool = (
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)
# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False
@ -678,6 +674,17 @@ loop_ordering_after_fusion: bool = (
== "1"
)
# When trying to fuse two nodes, one with:
# a[contiguous_writes] = fn(...)
# and another node:
# b[contiguous_writes] = a[discontiguous_reads]
# if b is unary and we can figure out an inverse formula for the
# discontiguous reads, rewrite b as:
# b[inverse(discontiguous_reads)] = a[contiguous_reads]
# so that the nodes can fuse. For more details: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9
loop_index_inversion_in_fusion: bool = True
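To make the inversion above concrete, here is a small editorial sketch in plain Python (not part of this diff; the 4x8 shape and the index formulas are invented for illustration). node1 fills buffer `a` with contiguous writes; node2 originally writes `b` contiguously while reading `a` transposed. Rewriting node2 with the inverse of its read index makes its read of `a` contiguous, so both nodes address `a` in the same order and become fusable.

```python
# Editorial sketch of loop index inversion (illustrative shapes, not from the diff).
M, N = 4, 8
a = [i * i for i in range(M * N)]      # node1: a[i] = fn(i), contiguous writes

def read_idx(i: int) -> int:           # node2's original discontiguous (transposed) read
    return N * (i % M) + i // M

def inv_idx(j: int) -> int:            # inverse of read_idx over [0, M*N)
    return M * (j % N) + j // N

# Original node2: b[contiguous_writes] = a[discontiguous_reads]
b_orig = [a[read_idx(i)] for i in range(M * N)]

# Inverted node2: b[inverse(discontiguous_reads)] = a[contiguous_reads]
b_inv = [0] * (M * N)
for j in range(M * N):                 # reads `a` in the same contiguous order node1 wrote it
    b_inv[inv_idx(j)] = a[j]

assert b_inv == b_orig
```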
# If fusing two nodes only saves less than score_fusion_memory_threshold memory,
# we should not bother fusing the nodes.
#

View File

@ -0,0 +1,208 @@
from dataclasses import dataclass
from typing import Optional
import sympy
from torch._inductor.utils import _IntLike, argsort_sym
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from .virtualized import V
def static_eq(a: _IntLike, b: _IntLike) -> bool:
return V.graph.sizevars.statically_known_equals(a, b)
@dataclass
class Term:
coefficient: _IntLike
range: Optional[_IntLike] # None for unbounded
original_expr: sympy.Expr
reconstruction_multiplier: _IntLike # The multiplier needed for reconstruction
def generate_inverse_formula(
expr: sympy.Expr, var: sympy.Symbol
) -> Optional[sympy.Expr]:
"""
Analyze an expression to see if it matches a specific invertible pattern that we
know how to reverse.
We're looking for expressions that are sums of terms where each term extracts a
distinct bounded range from the input variable, like:
y = c₀*a₀ + c₁*a₁ + c₂*a₂ + ... + cₙ*aₙ
where each aᵢ must be one of these specific patterns:
- ModularIndexing(var, divisor, modulo)
- FloorDiv(ModularIndexing(var, 1, modulo), divisor)
- FloorDiv(var, divisor)
- var (the variable itself)
The key pattern we need is:
- Coefficients are strictly decreasing: c₀ > c₁ > c₂ > ... > cₙ
- Each coefficient matches the product of ranges of later terms (mixed-radix property)
- Each term extracts a bounded range, creating non-overlapping "slots"
If we find this pattern, we can generate the reconstruction transformation that
decomposes the variable and rebuilds it using the correct multipliers.
EXAMPLE:
Input: 100*((p//100)) + 10*((p%100)//10) + (p%10)
Returns the reconstruction expression:
remainder₀ = p
component₀ = remainder₀ // 100 # hundreds digit
remainder₁ = remainder₀ % 100
component₁ = remainder₁ // 10 # tens digit
remainder₂ = remainder₁ % 10
component₂ = remainder₂ # ones digit
result = component₀*100 + component₁*10 + component₂*1
This decomposes p into its components and rebuilds it using the original
multipliers, which should equal the input expression.
Args:
expr: Expression to analyze (sum of terms with ModularIndexing, FloorDiv, etc.)
var: The variable being decomposed
Returns:
None if not invertible, or the reconstruction expression
References:
Mixed-radix systems: https://en.wikipedia.org/wiki/Mixed_radix
"""
# Step 1: Parse all terms
terms = parse_terms(expr, var)
if not terms:
return None
# Step 2: Sort by coefficient (descending)
coeffs = [t.coefficient for t in terms]
idxs = reversed(argsort_sym(V.graph.sizevars.shape_env, coeffs))
terms = [terms[i] for i in idxs]
# Step 3: Check invertibility conditions
if not check_invertibility(terms):
return None
return generate_reconstruction_expr(terms, var)
def parse_terms(expr: sympy.Expr, var: sympy.Symbol) -> Optional[list[Term]]:
"""Parse expression into terms."""
if not isinstance(expr, sympy.Add):
# Single term
term = parse_single_term(expr, var)
return [term] if term else []
terms = []
for arg in expr.args:
term = parse_single_term(arg, var)
if term:
terms.append(term)
else:
return None # If any term fails to parse, fail completely
return terms
def parse_single_term(term: sympy.Expr, var: sympy.Symbol) -> Optional[Term]:
"""Parse a single term and extract coefficient, range, and reconstruction multiplier."""
# Extract coefficient and expression parts
coefficient, expr_parts = term.as_coeff_mul()
if len(expr_parts) == 0:
# Pure constant term
return Term(
coefficient=coefficient,
range=1,
original_expr=1,
reconstruction_multiplier=0,
)
elif len(expr_parts) == 1:
expr = expr_parts[0]
else:
# Multiple non-constant factors, too complex
return None
# Now determine the range and reconstruction multiplier
range_val, reconstruction_multiplier = analyze_expression_properties(expr, var)
if reconstruction_multiplier is None:
return None
return Term(
coefficient=coefficient,
range=range_val,
original_expr=expr,
reconstruction_multiplier=reconstruction_multiplier,
)
def analyze_expression_properties(
expr: sympy.Expr, var: sympy.Symbol
) -> tuple[Optional[_IntLike], Optional[_IntLike]]:
"""Analyze an expression to determine its range and reconstruction multiplier."""
# ModularIndexing(var, divisor, modulo) = (var // divisor) % modulo
if isinstance(expr, ModularIndexing):
x, div, mod = expr.args
if static_eq(x, var):
return mod, div # Range is mod, multiplier is div
# FloorDiv cases
if isinstance(expr, FloorDiv):
base, divisor = expr.args
# FloorDiv(ModularIndexing(var, 1, mod), div) = (var % mod) // div
if isinstance(base, ModularIndexing):
x, inner_div, mod = base.args
if static_eq(x, var) and static_eq(inner_div, 1):
range_val = FloorDiv(mod, divisor)
return range_val, divisor # Range is mod//div, multiplier is div
# FloorDiv(var, divisor) = var // divisor (unbounded)
elif static_eq(base, var):
return None, divisor # Unbounded range, multiplier is div
return None, None
def check_invertibility(terms: list[Term]) -> bool:
"""Check if the terms represent an invertible transformation."""
if not terms:
return False
# Coefficients must be strictly decreasing
coeffs = [t.coefficient for t in terms]
if argsort_sym(V.graph.sizevars.shape_env, coeffs) != list(
reversed(range(len(coeffs)))
):
return False
# Check mixed-radix property: each coeff[i] = coeff[i+1] * range[i+1]
expected_coeff = 1
for term in reversed(terms):
if not static_eq(term.coefficient, expected_coeff):
return False
if term.range is not None:
expected_coeff *= term.range
return True
def generate_reconstruction_expr(terms: list[Term], var: sympy.Symbol) -> sympy.Expr:
y = var
reconstruction = sympy.S.Zero
remainder = y
for i, term in enumerate(terms):
if i < len(terms) - 1:
component = FloorDiv(remainder, term.coefficient)
remainder = ModularIndexing(remainder, 1, term.coefficient)
else:
# Last term should also divide by its coefficient
component = FloorDiv(remainder, term.coefficient)
reconstruction += component * term.reconstruction_multiplier
return reconstruction
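As an editorial aid (not part of the new file), here is a dependency-free brute-force check of the docstring's mixed-radix example: because the coefficients (100, 10, 1) each equal the product of the ranges of the later terms, decomposing the value by those coefficients and re-multiplying by each term's reconstruction multiplier reproduces the original expression's value.

```python
# Editorial sketch (plain Python, not from the diff): brute-force check of the
# docstring example 100*(p//100) + 10*((p%100)//10) + (p%10).
def original_expr(p: int) -> int:
    return 100 * (p // 100) + 10 * ((p % 100) // 10) + (p % 10)

def reconstruction(y: int) -> int:
    # Mirrors generate_reconstruction_expr: peel components off by coefficient
    # (largest first), then rebuild with each term's reconstruction multiplier.
    component0, remainder1 = y // 100, y % 100                  # term 100*(p//100), multiplier 100
    component1, remainder2 = remainder1 // 10, remainder1 % 10  # term 10*((p%100)//10), multiplier 10
    component2 = remainder2                                     # term (p%10), multiplier 1
    return component0 * 100 + component1 * 10 + component2 * 1

assert all(reconstruction(p) == original_expr(p) for p in range(1000))
```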

View File

@ -1,8 +1,6 @@
# mypy: allow-untyped-defs
import logging
from collections.abc import Sequence
from functools import partial
from pathlib import Path
from typing import Any
import torch
@ -14,7 +12,6 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
from .. import config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import _IntLike, Layout, TensorBox
from ..utils import load_template
log = logging.getLogger(__name__)
@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
return False
return True
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)

View File

@ -1,11 +1,10 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from typing import Any, Optional
import torch
from torch._dynamo.utils import counters
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
from torch._inductor.runtime.triton_compat import tl
from torch._inductor.virtualized import V
from torch.utils._triton import has_triton
@ -19,25 +18,19 @@ from ..select_algorithm import (
TritonTemplate,
)
from ..utils import (
ensure_cute_available,
get_gpu_shared_memory,
get_num_sms,
has_free_symbols,
use_aten_gemm_kernels,
use_blackwell_cutedsl_grouped_mm,
use_triton_template,
)
from .mm_common import (
_is_static_problem,
check_supported_striding,
load_kernel_template,
persistent_grouped_mm_grid,
)
if ensure_cute_available():
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
log = logging.getLogger(__name__)
aten = torch.ops.aten
@ -520,11 +513,6 @@ triton_scaled_grouped_mm_template = TritonTemplate(
source=triton_grouped_mm_source,
)
cutedsl_grouped_mm_template = CuteDSLTemplate(
name="grouped_gemm_cutedsl",
source=load_kernel_template("cutedsl_mm_grouped"),
)
def grouped_mm_args(
mat1: TensorBox,
@ -726,44 +714,43 @@ def _tuned_grouped_mm_common(
# Checking only for the equality of corresponding dims of
# multiplicands here, relying on meta function checks for
# everything else.
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
if (
is_nonzero
and use_triton_template(layout)
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
):
scaled = scale_a is not None
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
a_is_k_major = mat_a.get_stride()[-1] == 1
b_is_k_major = mat_b.get_stride()[-2] == 1
@ -801,22 +788,6 @@ def _tuned_grouped_mm_common(
**config.kwargs,
)
if use_blackwell_cutedsl_grouped_mm(
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
):
for config in get_groupgemm_configs():
kwargs = dict(
ACC_DTYPE="cutlass.Float32",
)
cutedsl_grouped_mm_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
**kwargs,
**asdict(config),
)
input_gen_fns = {
4: lambda x: create_offsets(
x, m1_size, m2_size, offs.get_size() if offs is not None else None

View File

@ -1,333 +0,0 @@
import functools
from torch._inductor.runtime.runtime_utils import ceildiv
from cutlass.utils import TensorMapUpdateMode
{{gen_defines()}}
# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
GroupedGemmKernel,
)
# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
# Caches the compiled executor for build_group_ptrs_from_bases(). This
# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
# and can therefore be safely reused across runs with different group
# partitioning (`offs`).
#
# 2. gemm_cache
# Caches the compiled Grouped GEMM executor. Its key extends the prep
# cache key with hardware- and grid-specific parameters:
# (prep_cache_key, max_active_clusters, total_num_clusters).
# This is necessary because different `offs` tensors can change the
# per-group problem sizes and thus alter `total_num_clusters`, which in
# turn changes the grid shape and persistent scheduler configuration.
# Kernels compiled for one grid cannot be safely reused for another.
#
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.
prep_cache = {}
gemm_cache = {}
@functools.lru_cache
def get_hardware_info():
hw = cutlass.utils.HardwareInfo()
sm_count = hw.get_max_active_clusters(1)
max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)
return (sm_count, max_active_clusters)
def get_prep_cache_key(input_a, input_b, output):
"""
Returns a tuple key for caching the preprocessing kernel executor, based on the
shapes, strides, and dtypes of the input/output tensors.
"""
return (
tuple(input_a.shape),
tuple(input_a.stride()),
input_a.dtype,
tuple(input_b.shape),
tuple(input_b.stride()),
input_b.dtype,
tuple(output.shape),
tuple(output.stride()),
output.dtype,
)
def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
"""
Returns a tuple key for caching the gemm kernel executor by extending the
prep cache key with hardware- and grid-specific parameters.
"""
return (
prep_cache_key,
max_active_clusters,
total_num_clusters,
)
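A quick editorial arithmetic illustration (numbers invented, not from the template) of why the gemm cache key extends the prep key with `total_num_clusters`: two different `offs` partitions of the same total M lead to different cluster counts, and hence different persistent-scheduler grids.

```python
# Editorial illustration: same total M, same N and cluster tile, but different
# group partitions -> different total_num_clusters -> different grid shape.
from math import ceil

def total_num_clusters(group_ms, n, cluster_tile_mn):
    tile_m, tile_n = cluster_tile_mn
    return sum(ceil(m / tile_m) * ceil(n / tile_n) for m in group_ms)

# sumM = 384, N = 256, cluster tile (128, 256) in both cases:
print(total_num_clusters([128, 128, 128], 256, (128, 256)))  # 3 clusters
print(total_num_clusters([64, 64, 256], 256, (128, 256)))    # 4 clusters
```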
@cute.kernel
def build_group_ptrs_from_bases_kernel(
base_A_u64: cutlass.Int64, # device addr of input_a (bytes)
base_B_u64: cutlass.Int64, # device addr of input_b (bytes)
base_C_u64: cutlass.Int64, # device addr of Output (bytes)
offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Int32, # bytes
# -------- STRIDES (in ELEMENTS) --------
stride_A_m_elems: cutlass.Constexpr, # A.stride(0)
stride_A_k_elems: cutlass.Constexpr, # A.stride(1)
stride_B0_elems: cutlass.Constexpr, # B.stride(0)
stride_Bk_elems: cutlass.Constexpr, # B.stride(1)
stride_Bn_elems: cutlass.Constexpr, # B.stride(2)
stride_C_m_elems: cutlass.Constexpr, # C.stride(0)
stride_C_n_elems: cutlass.Constexpr, # C.stride(1)
# -------- OUTPUTS --------
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1)
out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
):
tidx, _, _ = cute.arch.thread_idx()
g = tidx
m_beg_i32 = 0
if g > 0:
m_beg_i32 = offs[g - 1]
m_end_i32 = offs[g]
m_g_i32 = m_end_i32 - m_beg_i32
a_byte_off = (
cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
)
c_byte_off = (
cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
)
b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)
# ---- pointers ----
out_ptrs[g, 0] = base_A_u64 + a_byte_off
out_ptrs[g, 1] = base_B_u64 + b_byte_off
out_ptrs[g, 2] = base_C_u64 + c_byte_off
# ---- (m, n, k, 1) ----
out_problem[g, 0] = m_g_i32
out_problem[g, 1] = N
out_problem[g, 2] = K
out_problem[g, 3] = cutlass.Int32(1)
# ---- strides ----
out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
@cute.jit
def launch_build_group_ptrs_from_bases(
base_A_u64: cutlass.Int64,
base_B_u64: cutlass.Int64,
base_C_u64: cutlass.Int64,
offs: cute.Tensor,
G: cutlass.Constexpr,
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Constexpr,
stride_A_m_elems: cutlass.Constexpr,
stride_A_k_elems: cutlass.Constexpr,
stride_B0_elems: cutlass.Constexpr,
stride_Bk_elems: cutlass.Constexpr,
stride_Bn_elems: cutlass.Constexpr,
stride_C_m_elems: cutlass.Constexpr,
stride_C_n_elems: cutlass.Constexpr,
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64
out_problem: cute.Tensor, # [G,4] cutlass.Int32
out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32
stream: cuda.CUstream,
):
build_group_ptrs_from_bases_kernel(
base_A_u64,
base_B_u64,
base_C_u64,
offs,
K,
N,
sizeof_element,
stride_A_m_elems,
stride_A_k_elems,
stride_B0_elems,
stride_Bk_elems,
stride_Bn_elems,
stride_C_m_elems,
stride_C_n_elems,
out_ptrs,
out_problem,
out_strides_abc,
).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)
{{def_kernel("input_a", "input_b", "input_a_offs")}}
stream = cuda.CUstream(stream)
input_b = input_b.transpose(1, 2)
sumM, K = input_a.shape
G, N, Kb = input_b.shape
dev = input_a.device
base_A_u64 = int(input_a.data_ptr())
base_B_u64 = int(input_b.data_ptr())
base_C_u64 = int({{get_output()}}.data_ptr())
ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
ptrs = from_dlpack(ptrs_t)
probs = from_dlpack(probs_t)
strides = from_dlpack(strides_t)
prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
prep_executor = prep_cache.get(prep_cache_key)
if prep_executor is None:
sizeof_element = int(input_a.element_size())
sA_m, sA_k = map(int, input_a.stride())
sB_0, sB_n, sB_k = map(int, input_b.stride())
sC_m, sC_n = map(int, {{get_output()}}.stride())
prep_executor = cute.compile(
launch_build_group_ptrs_from_bases,
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
G=int(G),
K=int(K),
N=int(N),
sizeof_element=sizeof_element,
stride_A_m_elems=sA_m,
stride_A_k_elems=sA_k,
stride_B0_elems=sB_0,
stride_Bk_elems=sB_k,
stride_Bn_elems=sB_n,
stride_C_m_elems=sC_m,
stride_C_n_elems=sC_n,
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
prep_cache[prep_cache_key] = prep_executor
prep_executor(
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
# --- Tensormap workspace per SM ---
num_tensormap_buffers, max_active_clusters = get_hardware_info()
tensormap_shape = (
num_tensormap_buffers,
GroupedGemmKernel.num_tensormaps,
GroupedGemmKernel.bytes_per_tensormap // 8,
)
tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
tensormap_workspace = from_dlpack(tensormap_workspace_t)
# --- Total clusters ---
def compute_total_num_clusters(
problem_sizes_mnkl,
cluster_tile_shape_mn,
):
total_num_clusters = 0
for m, n, _, _ in problem_sizes_mnkl:
num_clusters_mn = tuple(
ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
)
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
return total_num_clusters
# Compute cluster tile shape
def compute_cluster_tile_shape(
mma_tiler_mn,
cluster_shape_mn,
use_2cta_instrs,
):
cta_tile_shape_mn = list(mma_tiler_mn)
if use_2cta_instrs:
cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
cluster_tile_shape_mn = compute_cluster_tile_shape(
(TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
)
total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))
gemm_cache_key = get_gemm_cache_key(
prep_cache_key, max_active_clusters, total_num_clusters
)
gemm_executor = gemm_cache.get(gemm_cache_key)
if gemm_executor is None:
grouped_gemm = GroupedGemmKernel(
acc_dtype=ACC_DTYPE,
use_2cta_instrs=USE_2_CTA,
mma_tiler_mn=(TILE_M, TILE_N),
cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
tensormap_update_mode=TENSORMAP_UPDATE_MODE,
)
gemm_executor = cute.compile(
grouped_gemm,
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
G,
probs,
strides,
ptrs,
total_num_clusters,
tensormap_workspace,
max_active_clusters,
stream,
)
gemm_cache[gemm_cache_key] = gemm_executor
gemm_executor(
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
probs,
strides,
ptrs,
tensormap_workspace,
stream,
)

View File

@ -95,7 +95,6 @@ class LoopBody:
"""
indexing_exprs: dict[str, sympy.Expr]
indexing_exprs_name: dict[sympy.Expr, str]
submodules: dict[str, Any]
subblocks: dict[str, LoopBodyBlock]
indirect_vars: list[sympy.Symbol]
@ -104,6 +103,9 @@ class LoopBody:
memory_usage: dict[MemoryUsageType, list[MemoryEntry]]
op_counts: collections.Counter[str]
# defined only temporarily
indexing_exprs_name: dict[sympy.Expr, str]
def __init__(
self,
fn,

View File

@ -3345,7 +3345,10 @@ class Scheduler:
)
break
if config.loop_ordering_after_fusion:
if (
config.loop_ordering_after_fusion
or config.loop_index_inversion_in_fusion
):
nodes = self.fuse_nodes_once(nodes, is_reorder_round=True)
return nodes
@ -4302,6 +4305,148 @@ class Scheduler:
return str(reasons)
def shared_data_after_inverting_indexing(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> int:
"""
Attempts to enable fusion between two nodes by inverting indexing patterns.
This optimization targets cases where node1 has a contiguous write and
node2 has a contiguous write but discontiguous read. By inverting the
indexing in node2's read and write operations, we can make them compatible
with node1 for potential fusion.
Args:
node1: First scheduler node (source)
node2: Second scheduler node (target for inversion)
Returns:
int: Fusion score if successful, -1 if the optimization is not applicable
"""
if not config.loop_index_inversion_in_fusion:
return -1
if any(n.is_cpu() for n in [node1, node2]):
return -1
# Check for shared buffers between nodes
node1_buffer_names = node1.read_writes.buffer_names()
node2_buffer_names = node2.read_writes.buffer_names()
common_buffer_names = node1_buffer_names & node2_buffer_names
if not common_buffer_names:
return -1
# Only invert if node2's single unmet dependency comes from node1
node2_unmet_dependencies = OrderedSet(
dep.name for dep in node2.unmet_dependencies
)
if node2_unmet_dependencies - node1_buffer_names:
return -1
if len(node2_unmet_dependencies) > 1:
return -1
# Currently only handle single read/write operations
if len(node2.read_writes.reads) > 1 or len(node2.read_writes.writes) > 1:
return -1
node2_read = next(iter(node2.read_writes.reads))
node2_write = next(iter(node2.read_writes.writes))
if not isinstance(node2_read, MemoryDep) or not isinstance(
node2_write, MemoryDep
):
return -1
node1_writes = {dep.name: dep for dep in node1.read_writes.writes}
if node2_read.name not in node1_writes:
return -1
node1_write = node1_writes[node2_read.name]
if not isinstance(node1_write, MemoryDep):
return -1
# We check compatibility against the normalized node1 write and then modify
# node2's reads/writes. Since the node1 write is only used for the
# compatibility check, while node2 is what actually gets modified, normalize
# node1 but not node2.
node1_write = node1_write.normalize()
if (
node1_write.index != node2_write.index
and node1_write.size != node2_write.size
):
return -1
if node2_read.size != node2_write.size or len(node2_read.var_names) != 1:
return -1
# Verify we have exactly two indexing expressions (one read, one write)
if len(node2._body.indexing_exprs) != 2: # type: ignore[attr-defined]
return -1
# No subblocks allowed for this optimization
if node2._body.subblocks: # type: ignore[attr-defined]
return -1
assert (
"index0" in node2._body.indexing_exprs # type: ignore[attr-defined]
and "index1" in node2._body.indexing_exprs # type: ignore[attr-defined]
)
# Extract and verify single read expression
node2_read_exprs = OrderedSet(expr for expr in node2._body.get_read_exprs()) # type: ignore[attr-defined]
if len(node2_read_exprs) != 1:
return -1
read_expr = next(iter(node2_read_exprs))
# Determine which index is for reading vs writing
if read_expr == node2._body.indexing_exprs["index0"]: # type: ignore[attr-defined]
read_expr_index = "index0"
write_expr_index = "index1"
else:
assert read_expr == node2._body.indexing_exprs["index1"] # type: ignore[attr-defined]
read_expr_index = "index1"
write_expr_index = "index0"
from torch._inductor.invert_expr_analysis import generate_inverse_formula
index_vars = node2._body.vars[0] # type: ignore[attr-defined]
if len(index_vars) != 1:
return -1
simplified_terms = []
for term in sympy.Add.make_args(read_expr):
simplified_terms.append(
V.graph.sizevars.combine_modular_indexing_pairs(term)
)
simplified_read_expr = sum(simplified_terms)
inverse_formula = generate_inverse_formula(simplified_read_expr, index_vars[0])
# formula is not invertible
if inverse_formula is None:
return -1
# === Apply Inversion ===
# Swap the indexing expressions using the inverse formula
node2._body.indexing_exprs[read_expr_index] = node2._body.indexing_exprs[ # type: ignore[attr-defined]
write_expr_index
]
node2._body.indexing_exprs[write_expr_index] = inverse_formula # type: ignore[attr-defined]
# Refresh dependencies and calculate fusion score
node2.refresh_dependencies(True, False) # type: ignore[attr-defined]
score = self.score_fusion_memory(node1, node2)
fusion_log.info("Shared memory after inversion: %d", score)
return score
def shared_data_after_reordering_loop(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> int:
@ -4686,6 +4831,7 @@ class Scheduler:
del device2
shared_data_score = self.score_fusion_memory(node1, node2)
if (
can_reorder
and shared_data_score < config.score_fusion_memory_threshold
@ -4702,6 +4848,16 @@ class Scheduler:
smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size)
shared_data_score = self.score_fusion_memory(node1, node2)
if (
config.loop_index_inversion_in_fusion
and shared_data_score < config.score_fusion_memory_threshold
):
new_shared_data_score = self.shared_data_after_inverting_indexing(
node1, node2
)
if new_shared_data_score >= 0:
shared_data_score = new_shared_data_score
if loop_ordering_log.isEnabledFor(logging.DEBUG):
loop_ordering_log.debug(
"%s and %s has %s shared data",

View File

@ -1,141 +0,0 @@
from dataclasses import dataclass
from enum import auto, Enum
from itertools import product
import torch._inductor.config as config
class TensorMapUpdateMode(Enum):
"""Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""
SMEM = auto()
GMEM = auto()
@dataclass(frozen=True)
class CuTeGemmConfig:
TILE_M: int = 128
TILE_N: int = 192
CLUSTER_M: int = 2
CLUSTER_N: int = 1
USE_2_CTA: bool = False
TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM
def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
For information regarding valid config sets, see:
https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
"""
# Tile_n is always the same regardless of 2cta
tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]
# Valid clusters
clusters_no_2cta = [
(1, 1),
(1, 2),
(1, 4),
(1, 8),
(1, 16),
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
clusters_2cta = [
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
configs: list[CuTeGemmConfig] = []
for use_2cta, cluster_set, tile_m_range in [
(False, clusters_no_2cta, [64, 128]),
(True, clusters_2cta, [128, 256]),
]:
for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
[TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
tile_m_range,
tile_n_vals,
cluster_set,
):
configs.append(
CuTeGemmConfig(
tile_m,
tile_n,
cluster_m,
cluster_n,
USE_2_CTA=use_2cta,
TENSORMAP_UPDATE_MODE=tensormap_update_mode,
)
)
return configs
def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
"""
config_tuples = [
(128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
(256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
(128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
(256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
(128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
(64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
(128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
]
return [CuTeGemmConfig(*args) for args in config_tuples]
def get_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
or unstable results. By default, autotuning is disabled and we return only
a single baseline config.
"""
if (
config.cutedsl_enable_autotuning
and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
):
return get_exhaustive_groupgemm_configs()
elif config.cutedsl_enable_autotuning:
return get_default_groupgemm_configs()
else:
return [get_default_groupgemm_configs()[0]]

View File

@ -1975,84 +1975,6 @@ def use_triton_blackwell_tma_template(
return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
"""Check if CuTeDSL is importable; cache the result for reuse.
Call ensure_cute_available.cache_clear() after installing CuTeDSL
in the same interpreter to retry the import.
"""
try:
return importlib.util.find_spec("cutlass.cute") is not None
except ImportError:
return False
def use_blackwell_cutedsl_grouped_mm(
mat_a: Any,
mat_b: Any,
layout: Layout,
a_is_2d: bool,
b_is_2d: bool,
offs: Optional[Any],
bias: Optional[Any],
scale_result: Optional[Any],
) -> bool:
"""
Returns True if we can use the blackwell kernel for grouped mm.
Required conditions:
1. CuTeDSL backend is enabled
2. CuTeDSL is available
3. We are on a blackwell arch
4. The dtype is bf16
5. Max autotune or max autotune gemm is enabled
6. A, B, and the output are 16B aligned
7. We are not using dynamic shapes
8. A is 2d
9. B is 3d
10. Offsets are provided
11. Bias and Scale are not provided
"""
if not ensure_cute_available():
return False
if not _use_autotune_backend("CUTEDSL"):
return False
from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
if not is_gpu(layout.device.type):
return False
if not is_datacenter_blackwell_arch():
return False
layout_dtypes = [torch.bfloat16]
if not _use_template_for_gpu(layout, layout_dtypes):
return False
if not (config.max_autotune or config.max_autotune_gemm):
return False
# Checks for 16B ptr and stride alignment
if not can_use_tma(mat_a, mat_b, output_layout=layout):
return False
if any(is_dynamic(x) for x in [mat_a, mat_b]):
return False
if not a_is_2d or b_is_2d:
return False
if offs is None:
return False
if bias is not None or scale_result is not None:
return False
return True
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
from .virtualized import V

View File

@ -1224,43 +1224,3 @@ def _build_table(
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by node_name for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)

View File

@ -5,13 +5,13 @@
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/version.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
HIDDEN_NAMESPACE_BEGIN(torch, stable)
@ -68,7 +68,7 @@ inline torch::stable::Tensor narrow(
// only dtype information.
inline torch::stable::Tensor new_empty(
const torch::stable::Tensor& self,
std::vector<int64_t> size,
torch::headeronly::IntHeaderOnlyArrayRef size,
std::optional<c10::ScalarType> dtype = std::nullopt) {
int32_t device_type;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@ -107,7 +107,7 @@ inline torch::stable::Tensor new_empty(
// only dtype information.
inline torch::stable::Tensor new_zeros(
const torch::stable::Tensor& self,
std::vector<int64_t> size,
torch::headeronly::IntHeaderOnlyArrayRef size,
std::optional<c10::ScalarType> dtype = std::nullopt) {
int32_t device_type;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@ -144,12 +144,10 @@ inline torch::stable::Tensor new_zeros(
// We expect this to be the stable version of the pad.default op.
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
// use std::vector<int64_t> because
// (1) IntArrayRef is not yet header-only
// (2) SymInt is not yet header-only
// torch::headeronly::IntHeaderOnlyArrayRef as SymInt is not yet header-only.
inline torch::stable::Tensor pad(
const torch::stable::Tensor& self,
std::vector<int64_t> pad,
torch::headeronly::IntHeaderOnlyArrayRef pad,
const std::string& mode = "constant",
double value = 0.0) {
AtenTensorHandle ret0 = nullptr;
@ -181,11 +179,10 @@ inline torch::stable::Tensor amax(
// This function is an overload to compute the maximum value along each slice of
// `self` reducing over all the dimensions in the vector `dims`. The
// amax.default op takes in a SymInt[] as the dims argument, however dims is
// typed as use std::vector<int64_t> here because (1) IntArrayRef is not yet
// header-only (2) SymInt is not yet header-only
// typed as use IntHeaderOnlyArrayRef here because SymInt is not yet header-only
inline torch::stable::Tensor amax(
const torch::stable::Tensor& self,
std::vector<int64_t> dims,
torch::headeronly::IntHeaderOnlyArrayRef dims,
bool keepdim = false) {
AtenTensorHandle ret = nullptr;
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(

View File

@ -443,7 +443,6 @@ class CodeGen:
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@ -648,6 +647,15 @@ class CodeGen:
if verbose:
# override annotation with more detailed information
try:
from torch.distributed.tensor._api import DTensor, DTensorSpec
dtensorspec_format_shard_order_str = (
DTensorSpec.format_shard_order_str
)
except ModuleNotFoundError:
DTensor = None # type: ignore[assignment,misc]
dtensorspec_format_shard_order_str = None
from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.passes.shape_prop import TensorMetadata
@ -678,6 +686,16 @@ class CodeGen:
core = _tensor_annotation(meta_val)
if is_plain:
maybe_type_annotation = f': "{core}"'
elif type(meta_val) is DTensor:
assert dtensorspec_format_shard_order_str is not None
dtensor_meta = dtensorspec_format_shard_order_str(
meta_val._spec.placements, # type: ignore[attr-defined]
meta_val._spec.shard_order, # type: ignore[attr-defined]
)
cls = meta_val.__class__.__name__
maybe_type_annotation = (
f': "{cls}({core}, {dim_green(dtensor_meta)})"'
)
else:
cls = meta_val.__class__.__name__
maybe_type_annotation = f': "{cls}({core})"'
@ -799,10 +817,6 @@ class CodeGen:
return
raise NotImplementedError(f"node: {node.op} {node.target}")
if record_func:
body.append(
"_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n"
)
for i, node in enumerate(nodes):
# NOTE: emit_node does not emit a string with newline. It depends
# on delete_unused_values to append one
@ -812,22 +826,8 @@ class CodeGen:
# node index, which will be deleted later
# after going through _body_transformer
body.append(f"# COUNTER: {i}\n")
do_record = record_func and node.op in (
"call_function",
"call_method",
"call_module",
)
if do_record:
# The double hash ## convention is used by post-processing to find the fx markers
body.append(
f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n"
)
emit_node(node)
delete_unused_values(node)
if do_record:
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
if record_func:
body.append("_rf.__exit__(None, None, None)\n")
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@ -1779,7 +1779,6 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@ -1847,7 +1846,6 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def _python_code(
@ -1860,7 +1858,6 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@ -1871,7 +1868,6 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def __str__(self) -> str:

View File

@ -861,18 +861,14 @@ class {module_name}(torch.nn.Module):
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
from torch._dynamo import config as dynamo_config
python_code = self._graph.python_code(
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
)
python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
self._lineno_map = python_code._lineno_map
self._prologue_start = python_code._prologue_start
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
from torch._dynamo import config as dynamo_config
if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
@ -889,6 +885,7 @@ class {module_name}(torch.nn.Module):
# This ensures the same code+metadata always generates the same filename
hash_value = _metadata_hash(self._code, node_metadata)
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
filename = f"{file_stem}.py"
# Only include co_filename to use it directly as the cache key
@ -908,13 +905,6 @@ class {module_name}(torch.nn.Module):
_register_fx_metadata(filename, metadata)
# Replace the placeholder in generated code with actual filename
# The double hash ## convention is used by post-processing to find the fx markers
self._code = self._code.replace(
"torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')",
f"torch._C._profiler._RecordFunctionFast('## {filename} ##')",
)
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
# Determine whether this class explicitly defines a __call__ implementation

View File

@ -42,6 +42,9 @@ fp16_ieee_to_fp32_value
# fp32_from_bits called from fp16_ieee_to_fp32_value
# fp32_to_bits called from fp16_ieee_from_fp32_value
# torch/headeronly/util/HeaderOnlyArrayRef.h
HeaderOnlyArrayRef
# c10/util/complex.h, torch/headeronly/util/complex.h
complex

View File

@ -0,0 +1,247 @@
#pragma once
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
#include <array>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <type_traits>
#include <vector>
namespace c10 {
/// HeaderOnlyArrayRef - A subset of ArrayRef that is implemented only
/// in headers. This will be a base class from which ArrayRef inherits, so that
/// we can keep much of the implementation shared.
///
/// [HeaderOnlyArrayRef vs ArrayRef note]
/// As HeaderOnlyArrayRef is a subset of ArrayRef, it has slightly less
/// functionality than ArrayRef. We document the minor differences below:
/// 1. ArrayRef has an extra convenience constructor for SmallVector.
/// 2. ArrayRef uses TORCH_CHECK. HeaderOnlyArrayRef uses header-only
/// STD_TORCH_CHECK, which will output a std::runtime_error vs a
/// c10::Error. Consequently, you should use ArrayRef when possible
/// and HeaderOnlyArrayRef only when necessary to support headeronly code.
/// In all other aspects, HeaderOnlyArrayRef is identical to ArrayRef, with the
/// positive benefit of being header-only and thus independent of libtorch.so.
template <typename T>
class HeaderOnlyArrayRef {
public:
using iterator = const T*;
using const_iterator = const T*;
using size_type = size_t;
using value_type = T;
using reverse_iterator = std::reverse_iterator<iterator>;
protected:
/// The start of the array, in an external buffer.
const T* Data;
/// The number of elements.
size_type Length;
public:
/// @name Constructors
/// @{
/// Construct an empty HeaderOnlyArrayRef.
/* implicit */ constexpr HeaderOnlyArrayRef() : Data(nullptr), Length(0) {}
/// Construct a HeaderOnlyArrayRef from a single element.
// TODO Make this explicit
constexpr HeaderOnlyArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
/// Construct a HeaderOnlyArrayRef from a pointer and length.
constexpr HeaderOnlyArrayRef(const T* data, size_t length)
: Data(data), Length(length) {}
/// Construct a HeaderOnlyArrayRef from a range.
constexpr HeaderOnlyArrayRef(const T* begin, const T* end)
: Data(begin), Length(end - begin) {}
template <
typename Container,
typename U = decltype(std::declval<Container>().data()),
typename = std::enable_if_t<
(std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
/* implicit */ HeaderOnlyArrayRef(const Container& container)
: Data(container.data()), Length(container.size()) {}
/// Construct a HeaderOnlyArrayRef from a std::vector.
// The static_assert below makes sure that this isn't used for
// std::vector<bool>, because HeaderOnlyArrayRef can't work on a
// std::vector<bool> bitfield.
template <typename A>
/* implicit */ HeaderOnlyArrayRef(const std::vector<T, A>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
static_assert(
!std::is_same_v<T, bool>,
"HeaderOnlyArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
}
/// Construct a HeaderOnlyArrayRef from a std::array
template <size_t N>
/* implicit */ constexpr HeaderOnlyArrayRef(const std::array<T, N>& Arr)
: Data(Arr.data()), Length(N) {}
/// Construct a HeaderOnlyArrayRef from a C array.
template <size_t N>
// NOLINTNEXTLINE(*c-arrays*)
/* implicit */ constexpr HeaderOnlyArrayRef(const T (&Arr)[N])
: Data(Arr), Length(N) {}
/// Construct a HeaderOnlyArrayRef from a std::initializer_list.
/* implicit */ constexpr HeaderOnlyArrayRef(
const std::initializer_list<T>& Vec)
: Data(
std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
: std::begin(Vec)),
Length(Vec.size()) {}
/// @}
/// @name Simple Operations
/// @{
constexpr iterator begin() const {
return this->Data;
}
constexpr iterator end() const {
return this->Data + this->Length;
}
// These are actually the same as iterator, since ArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return this->Data;
}
constexpr const_iterator cend() const {
return this->Data + this->Length;
}
constexpr reverse_iterator rbegin() const {
return reverse_iterator(end());
}
constexpr reverse_iterator rend() const {
return reverse_iterator(begin());
}
/// Check if all elements in the array satisfy the given expression
constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
return std::all_of(cbegin(), cend(), pred);
}
/// empty - Check if the array is empty.
constexpr bool empty() const {
return this->Length == 0;
}
constexpr const T* data() const {
return this->Data;
}
/// size - Get the array size.
constexpr size_t size() const {
return this->Length;
}
/// front - Get the first element.
constexpr const T& front() const {
STD_TORCH_CHECK(
!this->empty(),
"HeaderOnlyArrayRef: attempted to access front() of empty list");
return this->Data[0];
}
/// back - Get the last element.
constexpr const T& back() const {
STD_TORCH_CHECK(
!this->empty(),
"HeaderOnlyArrayRef: attempted to access back() of empty list");
return this->Data[this->Length - 1];
}
/// equals - Check for element-wise equality.
constexpr bool equals(HeaderOnlyArrayRef RHS) const {
return this->Length == RHS.Length &&
std::equal(begin(), end(), RHS.begin());
}
/// slice(n, m) - Take M elements of the array starting at element N
constexpr HeaderOnlyArrayRef<T> slice(size_t N, size_t M) const {
STD_TORCH_CHECK(
N + M <= this->size(),
"HeaderOnlyArrayRef: invalid slice, N = ",
N,
"; M = ",
M,
"; size = ",
this->size());
return HeaderOnlyArrayRef<T>(this->data() + N, M);
}
/// slice(n) - Chop off the first N elements of the array.
constexpr HeaderOnlyArrayRef<T> slice(size_t N) const {
STD_TORCH_CHECK(
N <= this->size(),
"HeaderOnlyArrayRef: invalid slice, N = ",
N,
"; size = ",
this->size());
return slice(N, this->size() - N);
}
/// @}
/// @name Operator Overloads
/// @{
constexpr const T& operator[](size_t Index) const {
return this->Data[Index];
}
/// Vector compatibility
constexpr const T& at(size_t Index) const {
STD_TORCH_CHECK(
Index < this->Length,
"HeaderOnlyArrayRef: invalid index Index = ",
Index,
"; Length = ",
this->Length);
return this->Data[Index];
}
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
std::enable_if_t<std::is_same_v<U, T>, HeaderOnlyArrayRef<T>>& operator=(
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
U&& Temporary) = delete;
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
std::enable_if_t<std::is_same_v<U, T>, HeaderOnlyArrayRef<T>>& operator=(
std::initializer_list<U>) = delete;
/// @}
/// @name Expensive Operations
/// @{
std::vector<T> vec() const {
return std::vector<T>(this->Data, this->Data + this->Length);
}
/// @}
};
} // namespace c10
namespace torch::headeronly {
using c10::HeaderOnlyArrayRef;
using IntHeaderOnlyArrayRef = HeaderOnlyArrayRef<int64_t>;
} // namespace torch::headeronly

View File

@@ -4,7 +4,7 @@ import operator
import re
from collections import deque
from dataclasses import dataclass
from typing import Any, Literal, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING
from torch.autograd.profiler import profile
from torch.profiler import DeviceType
@@ -400,170 +400,3 @@ def _init_for_cuda_graphs() -> None:
with profile():
pass
@dataclass
class TimelineEvent:
"""Represents an event in the profiler timeline."""
timestamp: int
event_type: Literal["start", "end", "regular"]
marker_type: Optional[Literal["filename", "node"]]
identifier: Optional[str | int]
event: dict[str, Any]
@dataclass
class ContextStackEntry:
"""Represents a context (filename or node) in the stack."""
context_type: Literal["filename", "node"]
identifier: str | int
metadata: Optional[dict]
tid: Optional[int] = None # Thread ID associated with this context
def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
"""
Maps recorded profiler events to their corresponding fx nodes and adds stack traces.
Builds a timeline of all events (regular ops and FX markers for filenames/nodes),
sorts by timestamp, then processes chronologically while maintaining a context stack of active
filename/node scopes. Regular events are augmented with stack traces and node names from the
innermost active context. Runtime is O(n log n) for n events.
Args:
traced_data: JSON of profiler events from a Chrome trace
Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
from torch.fx.traceback import _FX_METADATA_REGISTRY
trace_events = traced_data.get("traceEvents", [])
# Create event timeline
event_timeline: list[TimelineEvent] = []
def is_fx_marker_event(event):
return (
event.get("cat") == "cpu_op"
and event.get("name", "").startswith("## ")
and event.get("name", "").endswith(" ##")
)
def append_fx_marker_event(event_type, identifier, event):
start_ts = event["ts"]
end_ts = start_ts + event["dur"]
event_timeline.append(
TimelineEvent(start_ts, "start", event_type, identifier, event)
)
event_timeline.append(
TimelineEvent(end_ts, "end", event_type, identifier, event)
)
for event in trace_events:
if "ts" not in event or "dur" not in event:
continue
if is_fx_marker_event(event):
content = event["name"][3:-3]
if content.endswith(".py"):
append_fx_marker_event("filename", content, event)
else:
try:
node_index = int(content)
except ValueError:
continue
append_fx_marker_event("node", node_index, event)
else:
# Regular event that needs augmentation
start_ts = event["ts"]
event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event))
# Sort by timestamp
event_timeline.sort(key=lambda x: x.timestamp)
# Process events in chronological order with a stack
context_stack: list[ContextStackEntry] = []
# Invariant: every start event has a corresponding end event
for timeline_event in event_timeline:
match timeline_event.event_type:
case "start":
assert timeline_event.identifier is not None
if timeline_event.marker_type == "filename":
assert isinstance(timeline_event.identifier, str)
# Push filename context - query metadata registry on-demand
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
tid = timeline_event.event.get("tid")
context_stack.append(
ContextStackEntry(
"filename", timeline_event.identifier, metadata, tid
)
)
elif timeline_event.marker_type == "node":
# Find the current filename from stack
current_file_metadata = None
tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
if (
ctx_entry.context_type == "filename"
and ctx_entry.tid == tid
):
current_file_metadata = ctx_entry.metadata
break
if current_file_metadata:
node_metadata = current_file_metadata.get("node_metadata", {})
if timeline_event.identifier in node_metadata:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
"node", timeline_event.identifier, node_meta, tid
)
)
case "end":
# Pop from stack - search backwards to find matching context
for i in range(len(context_stack) - 1, -1, -1):
ctx_entry = context_stack[i]
if (
timeline_event.marker_type == ctx_entry.context_type
and timeline_event.identifier == ctx_entry.identifier
):
context_stack.pop(i)
break
case "regular":
# Apply metadata from current context stack
# Find the most specific context (node takes precedence over filename)
# Only augment events with the same tid as the file/node event matched
current_stack_trace = None
current_node_name = None
event_tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
# Only apply metadata from contexts with matching tid
if ctx_entry.tid == event_tid:
if ctx_entry.context_type == "node" and ctx_entry.metadata:
current_stack_trace = ctx_entry.metadata.get(
"stack_trace", "No model stack trace available"
)
current_node_name = ctx_entry.metadata.get("name", "")
# Do we want to only attach the stack trace of the lowest node or stack trace of all nodes
# if nodes are nested, e.g. in nested graph modules
break
# Augment the event
if current_stack_trace or current_node_name:
args = timeline_event.event.setdefault("args", {})
if current_stack_trace:
args["stack_trace"] = current_stack_trace
if current_node_name:
args["node_name"] = current_node_name

View File

@@ -210,7 +210,8 @@ class _KinetoProfile:
def start_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.start()
assert self.profiler is not None
if self.profiler is None:
raise AssertionError("Profiler must be initialized before starting trace")
self.profiler._start_trace()
if self.profile_memory:
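The change in this file is a mechanical pattern: each bare `assert self.profiler ...` becomes an explicit `None` check that raises `AssertionError` with a message. A schematic sketch of why that matters: under `python -O`, `assert` statements are stripped out, so the guard and its message disappear, while an explicit raise always runs. The function bodies below are illustrative, not the real methods.

```
# Schematic only: why an explicit raise is preferred over a bare assert.

def start_trace_with_assert(profiler):
    assert profiler is not None, "Profiler must be initialized"  # removed under `python -O`
    profiler._start_trace()

def start_trace_with_raise(profiler):
    if profiler is None:  # always executed, regardless of interpreter flags
        raise AssertionError("Profiler must be initialized before starting trace")
    profiler._start_trace()
```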
@@ -256,7 +257,8 @@ class _KinetoProfile:
def stop_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.stop()
assert self.profiler is not None
if self.profiler is None:
raise AssertionError("Profiler must be initialized before stopping trace")
self.profiler.__exit__(None, None, None)
def export_chrome_trace(self, path: str):
@@ -264,7 +266,10 @@ class _KinetoProfile:
Exports the collected trace in Chrome JSON format. If kineto is enabled, only
the last cycle in the schedule is exported.
"""
assert self.profiler
if self.profiler is None:
raise AssertionError(
"Profiler must be initialized before exporting chrome trace"
)
if path.endswith(".gz"):
fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False)
fp.close()
@@ -284,7 +289,8 @@ class _KinetoProfile:
path (str): save stacks file to this location;
metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
"""
assert self.profiler
if self.profiler is None:
raise AssertionError("Profiler must be initialized before exporting stacks")
return self.profiler.export_stacks(path, metric)
def toggle_collection_dynamic(
@@ -316,7 +322,7 @@ class _KinetoProfile:
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
"""
if not self.profiler:
if self.profiler is None:
return
self.profiler.toggle_collection_dynamic(enable, activities)
@@ -333,7 +339,10 @@ class _KinetoProfile:
To use shape/stack functionality make sure to set record_shapes/with_stack
when creating profiler context manager.
"""
assert self.profiler
if self.profiler is None:
raise AssertionError(
"Profiler must be initialized before getting key averages"
)
return self.profiler.key_averages(
group_by_input_shape, group_by_stack_n, group_by_overload_name
)
@@ -343,7 +352,8 @@ class _KinetoProfile:
Returns the list of unaggregated profiler events,
to be used in the trace callback or after the profiling is finished
"""
assert self.profiler
if self.profiler is None:
raise AssertionError("Profiler must be initialized before accessing events")
return self.profiler.function_events
def add_metadata(self, key: str, value: str) -> None:
@@ -395,7 +405,10 @@ class _KinetoProfile:
if missing:
raise ValueError(f"{', '.join(missing)} required for memory profiling.")
assert self.profiler is not None and self.profiler.kineto_results is not None
if self.profiler is None or self.profiler.kineto_results is None:
raise AssertionError(
"Profiler and kineto_results must be initialized for memory profiling"
)
return MemoryProfile(self.profiler.kineto_results)
def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
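As the `missing` check above implies, memory profiling only works when `record_shapes`, `profile_memory`, and `with_stack` are all enabled. A small usage sketch of the public API; the model and input are stand-ins, not anything from this diff.

```
# Usage sketch: collect a memory profile and export the timeline as HTML.
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(64, 64)   # stand-in model
inp = torch.randn(8, 64)

with profile(
    activities=[ProfilerActivity.CPU],
    record_shapes=True,    # required for memory profiling
    profile_memory=True,   # required for memory profiling
    with_stack=True,       # required for memory profiling
) as prof:
    model(inp)

prof.export_memory_timeline("memory_timeline.html")
```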
@@ -485,7 +498,8 @@ def schedule(
"""
def schedule_fn(step: int) -> ProfilerAction:
assert step >= 0
if step < 0:
raise AssertionError(f"Step must be non-negative. Got {step}.")
if step < skip_first:
return ProfilerAction.NONE
else:
@@ -508,9 +522,11 @@
else ProfilerAction.RECORD_AND_SAVE
)
assert (
wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
), "Invalid profiler schedule arguments"
if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0:
raise AssertionError(
f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), "
f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)."
)
if warmup == 0:
warn(
"Profiler won't be using warmup, this can skew profiler results",
@@ -717,7 +733,8 @@ class profile(_KinetoProfile):
activities_set.add(ProfilerActivity.CUDA)
elif ProfilerActivity.CUDA in activities_set:
activities_set.remove(ProfilerActivity.CUDA)
assert len(activities_set) > 0, "No valid profiler activities found"
if len(activities_set) == 0:
raise AssertionError("No valid profiler activities found")
super().__init__(
activities=activities,