Compare commits


16 Commits

Author SHA1 Message Date
fa9d5c2dd7 Update on "conv: refactor for lookup table support"
# why

enable configuring conv operations through the lookup table

# what

- move kwargs etc into template_heuristics
- add conv specific kernel inputs
- add lookup table e2e test for conv

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v
python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v
```
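
For reference, a hedged sketch of what a conv entry in the lookup table might look like, mirroring the e2e test added in this stack (the template id is read from the Triton conv2d template; `lookup_key` would come from `LookupTableChoices.make_lookup_key`):

```
import torch._inductor.kernel.conv as conv_kernel

# Only per-config tunable parameters are stored; static kwargs such as
# KERNEL_H, STRIDE_H, GROUPS, UNROLL and ALLOW_TF32 are generated by the
# template heuristics and should not appear in the table.
conv2d_config = {
    "template_id": conv_kernel.conv2d_template.uid,
    "BLOCK_M": 64,
    "BLOCK_N": 64,
    "BLOCK_K": 32,
    "num_stages": 2,
    "num_warps": 4,
}
# torch._inductor.config.lookup_table.table = {lookup_key: [conv2d_config]}
```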

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

Differential Revision: [D86474839](https://our.internmc.facebook.com/intern/diff/D86474839)

[ghstack-poisoned]
2025-11-10 17:28:12 -08:00
f048cb1f3c Update on "conv: refactor for lookup table support"
# why

enable configuring conv operations through the lookup table

# what

- move kwargs etc into template_heuristics
- add conv specific kernel inputs
- add lookup table e2e test for conv

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v
python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
2025-11-06 16:29:43 -08:00
c277e07f77 conv: refactor for lookup table support
# why

enable configuring conv operations through the lookup table

# what

- move kwargs etc into template_heuristics
- add conv specific kernel inputs
- add lookup table e2e test for conv

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v
python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v
```

[ghstack-poisoned]
2025-11-05 18:57:57 -08:00
8e8cbb85ee Revert "[Inductor] Fix unbacked float symbol handling in kernel codegen (#166890)"
This reverts commit 0c7a4a6b48d49306eae8d0a9ee8d32b1899e5e23.

Reverted https://github.com/pytorch/pytorch/pull/166890 on behalf of https://github.com/malfet due to Looks like it broke torchfuzz tests, see fbd70fb84e/1 and same test on slow ([comment](https://github.com/pytorch/pytorch/pull/166890#issuecomment-3493011038))
2025-11-05 19:42:39 +00:00
fbd70fb84e Update typing docs to reference pyrefly (#166883)
Replacing the mypy documentation in the CONTRIBUTING.md file with pyrefly references. I have made initial changes to the https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch documentation, and will replace the script at the bottom with one tailored to the pyrefly tool as a follow-up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166883
Approved by: https://github.com/malfet
2025-11-05 19:35:38 +00:00
6c5db82584 [Inductor] Naive foreach autotune support (#162053)
Initial autotuning support for foreach kernels; 4x improvement for some kernels in an internal workload. More improvements can surely be made here in the future. num_warps is removed from the kernel definition to enable autotune support in the generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

num_warps=8 default due to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton_combo_kernel.py#L374
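
As a generic Triton sketch (not the inductor-generated combo kernel), leaving num_warps out of the kernel definition is what lets the autotuner choose it per launch:

```
import triton
import triton.language as tl

# Hedged illustration: with no hard-coded num_warps on the kernel itself,
# the autotuner is free to sweep it. The real foreach/combo kernels are
# generated by inductor's wrapper codegen, not hand-written like this.
@triton.autotune(
    configs=[triton.Config({"BLOCK": 1024}, num_warps=w) for w in (1, 2, 4, 8)],
    key=["n_elements"],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)
```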

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd, https://github.com/jeffdaily

Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
2025-11-05 19:27:23 +00:00
6052a01b71 [BE][Typing][Dynamo] Type torch/_dynamo/variables/dicts.py (#167022)
Provides type coverage to torch/_dynamo/variables/dicts.py

Coverage report:
`mypy torch/_dynamo/variables/dicts.py --linecount-report /tmp/coverage_log`

Compare before to after - we go from 0 lines and 0 funcs covered to 1547 lines and 89 funcs covered

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167022
Approved by: https://github.com/Skylion007
2025-11-05 19:18:35 +00:00
14b153bcf2 include DTensor metadata when pretty-printing fx.Graphs (#166750)
Example below. You need to trace your function with DTensor inputs in order for the graph proxies to run on DTensor (and not the inner local tensor). You also need to run with `tracing_mode="fake"`, or with your own `FakeTensorMode`, to see the nice DTensor printing. If this doesn't feel very ergonomic then maybe we can find some better UX for printing a graph with DTensor in it:

Example output (screenshot): https://github.com/user-attachments/assets/99ea5ce6-1008-4ba5-b58e-542cd34a340b

```
import torch
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.distributed.tensor import distribute_tensor, Shard, Replicate
from torch.utils._debug_mode import DebugMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils import _pytree as pytree

world_size = 8
device_type = "cpu"
fake_store = FakeStore()
torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=world_size)
device_mesh = torch.distributed.init_device_mesh(device_type, (world_size,))
dim = 128

A = torch.randn(8, dim)
B = torch.randn(dim, dim)
dA = distribute_tensor(A, device_mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, device_mesh, [Replicate()]).requires_grad_()

def f(dA, dB):
    dy = dA @ dB
    loss = dy.sum()
    loss.backward()
    return dA.grad, dB.grad

# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode='fake')(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
gm.print_readable(colored=True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166750
Approved by: https://github.com/ezyang, https://github.com/wconstab, https://github.com/Skylion007
2025-11-05 18:58:54 +00:00
641de23c96 ci: Add aarch64 docker builds for modern clang (#166416)
Should enable us to build using some arm optimizations that are only
available on the newest versions of clang.

Signed-off-by: Eli Uriegas <eliuriegas@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166416
Approved by: https://github.com/malfet
2025-11-05 18:55:56 +00:00
89165c0a2b Update triton to 3.5.1 release (#166968)
This includes the sm103 fix: https://github.com/triton-lang/triton/pull/8485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166968
Approved by: https://github.com/Lucaskabela, https://github.com/njriasan
2025-11-05 18:26:34 +00:00
dcc2ba4ca4 Add some code for exploring the space of accessible size/stride configs via plain views (#167076)
We are working on a translation from as_strided to view operations, but
only when the as_strided is representable as a plain view.  A useful
testing utility in this situation is the ability to enumerate all valid
views on an original tensor.  So we have a small test here that shows
it is possible.

To avoid an explosion of states, we don't handle permutes and size=1,
which are degenerate cases (you can always do a single permute and
a series of unsqueezes to get to the final desired state.)

Authored with claude code assistance.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167076
Approved by: https://github.com/albanD
ghstack dependencies: #166868, #166867
2025-11-05 18:25:19 +00:00
ad5c7c20e0 Revert "[cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922)"
This reverts commit 1d3f5e19da068ec1340db041b7105b287a513578.

Reverted https://github.com/pytorch/pytorch/pull/165922 on behalf of https://github.com/atalman due to Introduces Segfault in linux-jammy-cuda12.8-py3.10-gcc11 ([comment](https://github.com/pytorch/pytorch/pull/165922#issuecomment-3492667312))
2025-11-05 18:13:57 +00:00
c86540f120 Revert "Add model code stack trace to torch.profile (#166677)"
This reverts commit c00696144dae1f02e04ce345480b55e46c7d32a8.

Reverted https://github.com/pytorch/pytorch/pull/166677 on behalf of https://github.com/jeffdaily due to broke rocm ([comment](https://github.com/pytorch/pytorch/pull/166677#issuecomment-3492658160))
2025-11-05 18:11:11 +00:00
c17aa0f113 [ROCm] Enable group gemm through CK (#166334)
Fixes #161366
All four dimension combinations are supported: 2d-2d, 2d-3d, 3d-3d, and 3d-2d. The corresponding test cases in test_matmul_cuda pass for both the forward and backward pass.
The CK path is enabled for gfx942 and gfx950.
TODO: enable support on gfx90a. The CK kernel used in this commit produces a GPU error there and, based on the profiler results on gfx90a, may require a different CK kernel config.
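
A hedged sketch of the kind of call this enables, following the 2d-2d path implemented below where offs slices the shared K dimension (assumes the private `torch._grouped_mm` entry point exercised by test_matmul_cuda and a supported gfx942/gfx950 GPU):

```
import torch

device = "cuda"
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(m, k * n_groups, device=device, dtype=torch.bfloat16)
b = torch.randn(k * n_groups, n, device=device, dtype=torch.bfloat16)
# Cumulative K boundaries per group: [k, 2k, 3k, 4k]
offs = torch.arange(k, k * n_groups + 1, k, device=device, dtype=torch.int32)
out = torch._grouped_mm(a, b, offs=offs)  # expected shape: [n_groups, m, n]
```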

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166334
Approved by: https://github.com/atalman
2025-11-05 18:03:59 +00:00
4ff068c33a [Code Clean] Replace assert with if statement and raise AssertionError (#166935)
Including:
- `torch/profiler/profiler.py`

Fixes part of #164878
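
A minimal sketch of the replacement pattern (illustrative helper, not code from the PR):

```
def check_positive(value: int) -> None:
    # Before: `assert value > 0, "value must be positive"`, which is stripped
    # when Python runs with -O.
    # After: an explicit check that always raises.
    if value <= 0:
        raise AssertionError("value must be positive")
```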

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166935
Approved by: https://github.com/fffrog, https://github.com/albanD
2025-11-05 17:59:16 +00:00
0c7a4a6b48 [Inductor] Fix unbacked float symbol handling in kernel codegen (#166890)
When a fn compiled with `torch.compile` calls `.item()` on a float tensor arg (e.g., for thresholds in `torch.clamp`), the generated triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error.
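
A hedged minimal repro of the failure mode described above (assumes a CUDA device; scalar capture must be enabled so `.item()` is traced instead of graph-breaking):

```
import torch

torch._dynamo.config.capture_scalar_outputs = True

def clamp_with_tensor_bounds(x, lo_t, hi_t):
    # .item() on float scalar tensors produces unbacked float symbols during
    # compilation; before this fix the generated Triton kernel could reference
    # such a symbol without receiving it as a kernel argument.
    return torch.clamp(x, min=lo_t.item(), max=hi_t.item())

compiled = torch.compile(clamp_with_tensor_bounds)
out = compiled(
    torch.randn(16, device="cuda"),
    torch.tensor(-0.5, device="cuda"),
    torch.tensor(0.5, device="cuda"),
)
```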

Fixes: #166888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166890
Approved by: https://github.com/eellison
2025-11-05 17:50:08 +00:00
32 changed files with 1630 additions and 713 deletions

View File

@ -271,6 +271,16 @@ case "$tag" in
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-clang21)
ANACONDA_PYTHON_VERSION=3.10
CLANG_VERSION=21
ACL=yes
VISION=yes
OPENBLAS=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11

View File

@ -1 +1 @@
7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7

View File

@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
# work around ubuntu apt-get conflicts
sudo apt-get -y -f install
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
if [[ $CLANG_VERSION == 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
if [[ $CLANG_VERSION -ge 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
fi
fi

View File

@ -129,7 +129,7 @@ function install_129 {
}
function install_128 {
CUDNN_VERSION=9.10.2.21
CUDNN_VERSION=9.8.0.87
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
# install CUDA 12.8.1 in the same container
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

View File

@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="
CC=gcc
NUM_THREADS=128
USE_OPENMP=1
NO_SHARED=0

View File

@ -1 +1 @@
3.5.0
3.5.1

View File

@ -272,18 +272,6 @@ def smoke_test_cuda(
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
print(f"Torch cuDNN version: {torch_cudnn_version}")
torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
torch_cudnn_runtime_version = tuple(
[int(x) for x in torch_cudnn_version.split(".")]
)
if torch_cudnn_runtime_version != torch_cudnn_compile_version:
raise RuntimeError(
"cuDNN runtime version doesn't match comple version. "
f"Loaded: {torch_cudnn_runtime_version} "
f"Expected: {torch_cudnn_compile_version}"
)
if sys.platform in ["linux", "linux2"]:
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
print(f"Torch nccl; version: {torch_nccl_version}")

View File

@ -79,6 +79,8 @@ jobs:
include:
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
runner: linux.arm64.m7g.4xlarge
timeout-minutes: 600

View File

@ -93,7 +93,7 @@ jobs:
- linux-jammy-cuda12_8-py3_10-gcc11-build
- target-determination
with:
timeout-minutes: 400
timeout-minutes: 360
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }}

View File

@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
- [Running `mypy`](#running-mypy)
- [Running `pyrefly`](#running-pyrefly)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
- `mypy` - recommended for linting
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
- `pytest` - recommended to run tests more selectively
Running
```
@ -350,15 +350,32 @@ make lint
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
#### Running `mypy`
#### Running `pyrefly`
`mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
**Getting Started with Pyrefly:**
To run type checking on the PyTorch codebase:
```bash
pyrefly check
```
For more detailed error information with summaries:
```bash
pyrefly check --summarize-errors
```
**Learn More:**
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for more information on how to set up `mypy` and tackle type annotation
tasks.
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
### C++ Unit Testing

View File

@ -22,6 +22,9 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
@ -666,12 +669,19 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}

View File

@ -0,0 +1,19 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>
namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
at::Tensor& out);
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -0,0 +1,462 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/hip/HIPContext.h>
#include <ATen/Tensor.h>
#include <ATen/TensorAccessor.h>
#include <c10/hip/HIPStream.h>
#include <iostream>
#include <vector>
#include <optional>
#include <type_traits>
#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/tuple.hpp>
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
namespace at {
namespace hip {
namespace detail {
namespace CkTypes {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
}
template <typename ALayout, typename BLayout, typename DataType>
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
1, 1,
S<1,32,1,8>, 4
>;
template <typename ALayout, typename BLayout, typename DataType>
void launch_grouped_bgemm_ck_impl_dispatch(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
at::Tensor& out)
{
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
using PassThrough = CkTypes::PassThrough;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a_ptrs, p_b_ptrs;
std::vector<void*> p_e_ptrs;
// Note: d_ptrs will be resized after we populate the other vectors
const int mat_a_dim = mat_a.dim();
const int mat_b_dim = mat_b.dim();
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
const size_t a_element_size = mat_a.element_size();
const size_t b_element_size = mat_b.element_size();
const size_t out_element_size = out.element_size();
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
if (mat_a_dim == 2 && mat_b_dim == 2) {
// 2D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
const int M = mat_a.size(0); // number of rows in A
const int N = mat_b.size(1); // number of columns in B
const int K = mat_a.size(1); // columns in A == rows in B
// for 2d*2d input, output is 3d.
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
for (int i = 0; i < num_groups; ++i) {
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
int end_k = offs_accessor[i];
int k = end_k - start_k;
//K dimension are sliced, hence select stride(1) always.
//K dimension is always dimension 1, regardless of memory layout (row/column major)
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
const void* group_b_ptr;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
// Leading dimension = distance between rows = stride(0)
ldb = mat_b.stride(0);
} else {
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
// Leading dimension = distance between columns = stride(1)
ldb = mat_b.stride(1);
}
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
lda = mat_a.stride(0);
} else {
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
lda = mat_a.stride(1);
}
// Output is always row-major in 3D tensor [num_groups, M, N]
// Leading dimension for each group's [M,N] slice = stride(1) = N
ldc = out.stride(1);
size_t output_group_bytes = M * N * out_element_size;
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(k),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
// 2D*3D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 2d*3d input, output is 2d.
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
const int K = mat_a.size(1); // columns in A
// For 2D-3D case: The output determines N (result width)
const int N = out.size(1); // N is the width of the output tensor
for (int i = 0; i < num_groups; ++i) {
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
int end_m = offs_accessor[i];
int m = end_m - start_m;
// Skip zero-sized groups but continue processing subsequent groups
if (m <= 0) {
continue;
}
// Select A rows for group i: skip start_m rows
const void* group_a_ptr;
int lda;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
lda = mat_a.stride(0); // distance between rows
} else {
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
// Detect stride pattern for A tensor to determine appropriate lda calculation
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
if (a_is_strided_tensor) {
// For strided A tensors: stride(0) gives the actual leading dimension
lda = mat_a.stride(0);
} else {
// For non-strided A tensors: use the M dimension (total rows)
lda = mat_a.size(0); // Total M dimension for column-major layout
}
}
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
} else {
// Detect stride pattern to determine appropriate ldb calculation
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
if (is_strided_tensor) {
// For strided tensors: stride(2) gives the actual leading dimension
ldb = mat_b.stride(2);
} else {
// For non-strided tensors: use the N dimension
ldb = mat_b.size(1);
}
}
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
gemm_descs.push_back({
static_cast<ck::index_t>(m),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
// 3d*3d input, output is 3d - batched matrix multiplication
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
// Each batch is processed as a separate GEMM operation
const int batch_size = mat_a.size(0);
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
int N;
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
N = mat_b.size(2);
} else if (mat_b.size(2) == K) {
// B is [batch, n, k] - transposed layout
N = mat_b.size(1);
} else {
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
}
for (int i = 0; i < batch_size; ++i) {
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
// Select output batch for group i: Output[i, :, :]
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldb, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
} else {
// Column-major A: leading dimension = distance between columns = stride(2)
lda = mat_a.stride(2);
}
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B: leading dimension = distance between rows
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(1); // stride between K rows
} else {
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
}
} else {
// Column-major B: leading dimension = distance between columns
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(2); // stride between N columns
} else {
// B is [batch, n, k] - transposed layout
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
}
}
// Output is typically row-major: leading dimension = distance between rows = stride(1)
ldc = out.stride(1);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
// 3D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 3d*2d input, output is 3d.
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
const int batch_size = mat_a.size(0); // n_groups
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A
// For row-major A and B case: B should be [K, total_N]
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
for (int i = 0; i < num_groups; ++i) {
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
int end_n = offs_accessor[i];
int n = end_n - start_n;
// Skip zero-sized groups but continue processing subsequent groups
if (n <= 0) {
continue;
}
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
const void* group_b_ptr;
int ldb;
// Check if B is row-major or column-major
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(0); // distance between rows (should be total_N)
} else {
// Column-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(1); // distance between columns (should be K)
}
// Select output slice for group i: Output[:, start_n:end_n]
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
int lda, ldc;
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
// Output is row-major: leading dimension = distance between rows = stride(0)
ldc = out.stride(0);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(n),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
}
TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
// Initialize d_ptrs with the correct size
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
static DeviceOp gemm_instance;
auto argument = gemm_instance.MakeArgument(
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
);
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
"CK Group GEMM: argument unsupported (shape/strides/type config)");
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
void* gemm_arg_buf = nullptr;
void* ws_buf = nullptr;
hipMalloc(&gemm_arg_buf, arg_buf_size);
hipMalloc(&ws_buf, ws_size);
gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
auto invoker = gemm_instance.MakeInvoker();
hipStream_t stream = c10::hip::getCurrentHIPStream();
invoker.Run(argument, {stream});
hipFree(gemm_arg_buf);
hipFree(ws_buf);
}
void group_gemm_ck(
const at::Tensor& input_a,
const at::Tensor& input_b_colmajor,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& /*bias*/,
at::Tensor& out)
{
// Detect if input_a is row-major based on stride pattern
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
// Ensure tensor A is row-major and contiguous if not already
at::Tensor mat_a = input_a;
if (!a_row_major) {
// If A is not row-major, make it contiguous (row-major)
mat_a = input_a.contiguous();
}
// Force tensor B to be column-major using double transpose trick
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
at::Tensor mat_b = input_b_colmajor;
if (!b_col_major) {
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
}
// For 3D tensors, check the last dimension stride for row-major detection
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
if (mat_a.dtype() == at::kBFloat16) {
// bf16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kHalf) {
// fp16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kFloat) {
// fp32 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
}
}
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -5,8 +5,16 @@ import contextlib
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -426,6 +434,31 @@ class TestDTensorDebugMode(TestCase):
][-1]
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
def test_pretty_print_dtensor_make_fx(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
A = torch.randn(8, 32)
B = torch.randn(32, 32)
dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()
def f(dA, dB):
dy = dA @ dB
loss = dy.sum()
loss.backward()
return dA.grad, dB.grad
# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode="fake")(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
# Colored is nice for actual viewing, not using in this test though
gm_str = gm.print_readable(colored=False, print_output=False)
self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None

View File

@ -2,14 +2,18 @@
import re
import unittest
from functools import partial
from typing import Any, Optional, Union
from typing import Any, Optional
from unittest.mock import patch
import torch
import torch.nn as nn
from torch._inductor import config as inductor_config
from torch._inductor.choices import InductorChoices
from torch._inductor.kernel_inputs import MMKernelInputs
from torch._inductor.kernel_inputs import (
ConvKernelInputs,
MMKernelInputs,
SerializableValue,
)
from torch._inductor.lookup_table.choices import LookupTableChoices
from torch._inductor.select_algorithm import (
add_preprocessing_fn,
@ -54,7 +58,7 @@ class MockMMKernelInputs(MMKernelInputs):
def __init__(
self,
tensors: list[torch.Tensor],
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
mat1_idx: int = -2,
mat2_idx: int = -1,
):
@ -80,6 +84,37 @@ class MockMMKernelInputs(MMKernelInputs):
return self.tensors[0].device.type
class MockConvKernelInputs(ConvKernelInputs):
"""Mock ConvKernelInputs that subclasses the real class and uses real tensors"""
def __init__(
self,
tensors: list[torch.Tensor],
scalars: Optional[dict[str, SerializableValue]] = None,
x_idx: int = 0,
weight_idx: int = 1,
bias_idx: Optional[int] = None,
):
"""Initialize with real tensors, creating mock nodes for the base class"""
mock_nodes = [MockTensorNode(t) for t in tensors]
super().__init__(
mock_nodes, scalars, x_idx=x_idx, weight_idx=weight_idx, bias_idx=bias_idx
)
self.tensors = tensors # Keep reference to original tensors
def shapes_hinted(self) -> tuple[tuple[int, ...], ...]:
"""Delegate to symbolic since real tensors already have int shapes"""
return self.shapes_symbolic()
def strides_hinted(self) -> tuple[tuple[int, ...], ...]:
"""Delegate to symbolic since real tensors already have int strides"""
return self.strides_symbolic() # pyre-ignore
@property
def device_type(self) -> Optional[str]:
return self.tensors[0].device.type
class BaseLookupTableTest(TestCase):
"""Base class for lookup table tests with common setup and utilities"""
@ -103,7 +138,7 @@ class BaseLookupTableTest(TestCase):
shapes: Optional[list[tuple[int, ...]]] = None,
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.float32,
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
) -> MockMMKernelInputs:
"""Create MockMMKernelInputs with real tensors"""
if shapes is None:
@ -1055,6 +1090,119 @@ class TestLookupTableE2E(BaseE2ELookupTableTest):
with patch.object(inductor_config.lookup_table, "check_src_hash", True):
self.run_model("mm", tensors)
@fresh_cache()
def test_conv2d_lookup_table_entry_e2e(self):
"""Test end-to-end conv2d with lookup table entry - verifies config is picked up and produces valid results"""
import torch._inductor.kernel.conv
# Create input tensors with specific shapes for conv2d
# Input: [batch=2, in_channels=3, height=32, width=32]
# Weight: [out_channels=64, in_channels=3, kernel_h=3, kernel_w=3]
# Make them channels-last to match what conv lowering uses
x = torch.randn(2, 3, 32, 32, device=self.device, dtype=torch.float16).to(
memory_format=torch.channels_last
)
weight = torch.randn(64, 3, 3, 3, device=self.device, dtype=torch.float16).to(
memory_format=torch.channels_last
)
# Define conv parameters - use these SAME values everywhere
stride = (1, 1)
padding = (1, 1)
dilation = (1, 1)
groups = 1
# Create MockConvKernelInputs using the SAME tensors and SAME scalar values
mock_scalars = {
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": False,
"output_padding": (0, 0),
"groups": groups,
}
mock_kernel_inputs = MockConvKernelInputs([x, weight], mock_scalars)
# Create lookup key for "convolution" operation
choices_handler = LookupTableChoices()
lookup_key = choices_handler.make_lookup_key(mock_kernel_inputs, "convolution")
# Get the exact template UID from conv2d_template
template_uid = torch._inductor.kernel.conv.conv2d_template.uid
# Create a precisely configured conv2d config
# IMPORTANT: Only include per-config tunable parameters!
# Static parameters (KERNEL_H, STRIDE_H, GROUPS, UNROLL, ALLOW_TF32) are
# automatically generated by get_extra_kwargs() and should NOT be in the lookup table
conv2d_config = {
"template_id": template_uid,
# Per-config tunable parameters only (what you'd tune via autotuning)
"BLOCK_M": 64,
"BLOCK_N": 64,
"BLOCK_K": 32,
"num_stages": 2,
"num_warps": 4,
}
# Setup lookup table
inductor_config.lookup_table.table = {lookup_key: [conv2d_config]}
def validate_conv_choice(choices):
assert len(choices) == 1, (
f"Expected 1 choice from lookup table, got {len(choices)}"
)
assert isinstance(choices[0], TritonTemplateCaller), (
f"Expected TritonTemplateCaller, got {type(choices[0])}"
)
assert "convolution2d" in choices[0].name, (
f"Expected 'convolution2d' in name, got {choices[0].name}"
)
return choices
add_preprocessing_fn(validate_conv_choice)
# Create and compile the model using the SAME weight tensor
class SimpleConv2d(nn.Module):
def __init__(self, weight):
super().__init__()
self.register_buffer("weight", weight)
def forward(self, x):
return torch.conv2d(
x,
self.weight,
bias=None,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
model = SimpleConv2d(weight).to(self.device)
with inductor_config.patch({"max_autotune": True, "max_autotune_gemm": True}):
compiled_model = torch.compile(model)
result = compiled_model(x) # Use the SAME x tensor
# Output shape: [batch=2, out_channels=64, out_h=32, out_w=32]
# (same spatial dims due to padding=1, stride=1, kernel=3)
expected_shape = (2, 64, 32, 32)
self.assertEqual(
result.shape,
expected_shape,
f"Expected shape {expected_shape}, got {result.shape}",
)
self.assertFalse(
torch.isnan(result).any().item(),
"Output contains NaN values",
)
self.assertFalse(
torch.isinf(result).any().item(),
"Output contains Inf values",
)
if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu

test/test_as_strided.py (new file, 176 lines)
View File

@ -0,0 +1,176 @@
# Owner(s): ["oncall: pt2"]
from collections import deque
from typing import Optional
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Extract (sizes, strides) tuple from a tensor."""
return (tuple(t.size()), tuple(t.stride()))
def enumerate_reachable_states(
initial_size: int,
) -> set[tuple[tuple[int, ...], tuple[int, ...]]]:
"""
Use BFS with DP to enumerate all reachable (size, stride) states from
a 1D contiguous tensor via valid view operations.
We only explore states with offset=0 (you can retroactively change the offset).
We reject states with size=0 or size=1 dimensions as they are degenerate.
"""
# Create initial 1D contiguous tensor
initial_tensor = torch.arange(initial_size)
initial_state = get_state(initial_tensor)
# Map from state to tensor for that state
state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = {
initial_state: initial_tensor
}
visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state}
queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state])
while queue:
state = queue.popleft()
t = state_to_tensor[state]
sizes, strides = state
ndim = len(sizes)
def add_state(new_t: torch.Tensor) -> None:
new_state = get_state(new_t)
sizes, strides = new_state
# Skip if has size-0 or size-1 dimensions
if any(s == 0 or s == 1 for s in sizes):
return
# Only accept states where strides are in descending order
if list(strides) != sorted(strides, reverse=True):
return
if new_state not in visited:
visited.add(new_state)
queue.append(new_state)
state_to_tensor[new_state] = new_t
# 1. Unflatten: try factoring each dimension
for dim in range(ndim):
size = sizes[dim]
assert size > 1
# Try all factorizations x * y = size where both x, y >= 2
# We only need to check x up to size // 2 since when x > size // 2,
# y = size // x < 2, which we reject
for x in range(2, size // 2 + 1):
if size % x == 0:
y = size // x
add_state(t.unflatten(dim, (x, y)))
# 2. Slice: exhaustively check all possible slicing parameters
for dim in range(ndim):
size = sizes[dim]
for start in range(size):
for stop in range(start + 1, size + 1):
for step in range(1, size + 1):
slices = [slice(None)] * ndim
slices[dim] = slice(start, stop, step)
add_state(t[tuple(slices)])
# 3. Flatten: merge adjacent dimensions
for dim in range(ndim - 1):
add_state(t.flatten(dim, dim + 1))
return visited
class TestAsStrided(TestCase):
def test_size_10_exhaustive(self) -> None:
"""Test that size 10 produces exactly the expected 54 states."""
expected_states = {
((2,), (1,)),
((2,), (2,)),
((2,), (3,)),
((2,), (4,)),
((2,), (5,)),
((2,), (6,)),
((2,), (7,)),
((2,), (8,)),
((2,), (9,)),
((2, 2), (2, 1)),
((2, 2), (3, 1)),
((2, 2), (3, 2)),
((2, 2), (4, 1)),
((2, 2), (4, 2)),
((2, 2), (4, 3)),
((2, 2), (5, 1)),
((2, 2), (5, 2)),
((2, 2), (5, 3)),
((2, 2), (5, 4)),
((2, 2), (6, 1)),
((2, 2), (6, 2)),
((2, 2), (6, 3)),
((2, 2), (8, 1)),
((2, 2, 2), (4, 2, 1)),
((2, 2, 2), (5, 2, 1)),
((2, 3), (3, 1)),
((2, 3), (4, 1)),
((2, 3), (5, 1)),
((2, 3), (5, 2)),
((2, 3), (6, 1)),
((2, 4), (4, 1)),
((2, 4), (5, 1)),
((2, 5), (5, 1)),
((3,), (1,)),
((3,), (2,)),
((3,), (3,)),
((3,), (4,)),
((3, 2), (2, 1)),
((3, 2), (3, 1)),
((3, 2), (3, 2)),
((3, 2), (4, 1)),
((3, 3), (3, 1)),
((4,), (1,)),
((4,), (2,)),
((4,), (3,)),
((4, 2), (2, 1)),
((5,), (1,)),
((5,), (2,)),
((5, 2), (2, 1)),
((6,), (1,)),
((7,), (1,)),
((8,), (1,)),
((9,), (1,)),
((10,), (1,)),
}
actual_states = enumerate_reachable_states(10)
self.assertEqual(len(actual_states), 54)
self.assertEqual(actual_states, expected_states)
def test_subset_property(self) -> None:
"""
Test that for sizes 2..10, each smaller tensor results in a strict
subset of possible states compared to the next one.
"""
prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None
for size in range(2, 11):
current_states = enumerate_reachable_states(size)
if prev_states is not None:
# Check that prev_states is a strict subset of current_states
self.assertTrue(
prev_states.issubset(current_states),
f"States from size {size - 1} are not a subset of size {size}",
)
# Check that it's a strict subset (not equal)
self.assertTrue(
len(prev_states) < len(current_states),
f"States from size {size - 1} should be strictly fewer than size {size}",
)
prev_states = current_states
if __name__ == "__main__":
run_tests()

View File

@ -75,12 +75,6 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.jit_utils import JitTestCase
import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events
try:
from torchvision import models as torchvision_models
@ -207,36 +201,6 @@ def side_effect_func(x: torch.Tensor):
print(x)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
trace_file = f.name
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
class TestFX(JitTestCase):
def setUp(self):
super().setUp()
@ -4248,150 +4212,6 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
# recorver mutable checking flag
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
augments profiler events with stack traces from FX metadata registry.
"""
# Simple test model
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(16, 10)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
model = TestModel().cuda()
# Compile the model
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"))
# Profile with the compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::t node=t stack_trace=x = self.linear1(x)
event=aten::transpose node=t stack_trace=x = self.linear1(x)
event=aten::as_strided node=t stack_trace=x = self.linear1(x)
event=aten::addmm node=addmm stack_trace=x = self.linear1(x)
event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x)
event=aten::relu node=relu stack_trace=x = self.relu(x)
event=aten::clamp_min node=relu stack_trace=x = self.relu(x)
event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x)
event=aten::t node=t_1 stack_trace=x = self.linear2(x)
event=aten::transpose node=t_1 stack_trace=x = self.linear2(x)
event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x)
event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x)
event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
Test that multiple compiled modules under the same profiler session
have their events correctly augmented with stack traces.
"""
class ModelA(torch.nn.Module):
def forward(self, x):
return x + 1
class ModelB(torch.nn.Module):
def forward(self, x):
return x - 1
model_a = ModelA().cuda()
model_b = ModelB().cuda()
# Compile both models
compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True)
compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_a(torch.randn(10, 10, device="cuda"))
_ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
# Profile both models in the same session
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result_a = compiled_a(torch.randn(10, 10, device="cuda"))
result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::add node=add stack_trace=return x + 1
event=cudaLaunchKernel node=add stack_trace=return x + 1
event=aten::sub node=sub stack_trace=return x - 1
event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""
Test that nested graph modules (e.g., graph modules calling subgraphs)
have their events correctly augmented with stack traces.
"""
# Model with nested structure
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 5
@torch.compiler.nested_compile_region
def forward(self, x, y):
m = torch.mul(x, y)
s = m.sin()
a = s + self.c
return a
model = Mod().cuda()
# Compile the model (this may create nested graph modules)
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
# Profile
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::mul node=mul stack_trace=m = torch.mul(x, y)
event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y)
event=aten::sin node=sin stack_trace=s = m.sin()
event=cudaLaunchKernel node=sin stack_trace=s = m.sin()
event=aten::add node=add stack_trace=a = s + self.c
event=cudaLaunchKernel node=add stack_trace=a = s + self.c"""
)
def run_getitem_target():
from torch.fx._symbolic_trace import _wrapped_methods_to_patch

View File

@ -490,8 +490,6 @@ class TestMatmulCuda(InductorTestCase):
@parametrize("b_row_major", [False, True])
@dtypes(torch.bfloat16, torch.float32, torch.float16)
def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
self.skipTest("failed using hipblaslt on rocm 6.4.2")
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4

View File

@ -3320,7 +3320,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "add", [v], {})
obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
def SET_UPDATE(self, inst: Instruction) -> None:
v = self.pop()
@ -3329,7 +3329,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {})
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
def LIST_APPEND(self, inst: Instruction) -> None:
v = self.pop()
@ -3637,7 +3637,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg].realize()
assert isinstance(obj, ConstDictVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {})
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
DICT_UPDATE = DICT_MERGE

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
Dictionary-related variable tracking classes for PyTorch Dynamo.
@ -26,7 +24,7 @@ import inspect
import operator
import types
from collections.abc import Hashable as py_Hashable
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING, Union
from torch._subclasses.fake_tensor import is_fake
@ -59,11 +57,13 @@ if TYPE_CHECKING:
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def was_instancecheck_override(obj):
def was_instancecheck_override(obj: Any) -> bool:
return type(obj).__dict__.get("__instancecheck__", False)
def raise_unhashable(arg, tx=None):
def raise_unhashable(
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
) -> None:
if tx is None:
from torch._dynamo.symbolic_convert import InstructionTranslator
@ -75,7 +75,7 @@ def raise_unhashable(arg, tx=None):
)
def is_hashable(x):
def is_hashable(x: VariableTracker) -> bool:
# NB - performing isinstance check on a LazyVT realizes the VT, accidentally
# inserting the guard. To avoid this, the LazyVT `is_hashable` method looks at
# the underlying value without realizing the VT. Consider updating the
@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
"""
def __init__(self, vt) -> None:
def __init__(self, vt: VariableTracker) -> None:
# We specialize SymNodes
vt = specialize_symnode(vt)
# TODO Temporarily remove to figure out what keys are we breaking on
@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
self.vt = vt
@property
def underlying_value(self):
def underlying_value(self) -> Any:
if (
isinstance(self.vt, variables.LazyVariableTracker)
and not self.vt.is_realized()
@ -178,7 +178,8 @@ class ConstDictVariable(VariableTracker):
elif isinstance(self.vt, variables.FrozenDataClassVariable):
Hashable = ConstDictVariable._HashableTracker
fields_values = {
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
k: Hashable(v).underlying_value
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
}
return variables.FrozenDataClassVariable.HashWrapper(
self.vt.python_type(), fields_values
@ -187,16 +188,16 @@ class ConstDictVariable(VariableTracker):
# The re module in Python 3.13+ has a dictionary (_cache2) with
# an object as key (`class _ZeroSentinel(int): ...`):
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
return self.vt.value
return self.vt.value # type: ignore[attr-defined,union-attr]
else:
x = self.vt.as_python_constant()
return x
def __hash__(self):
def __hash__(self) -> int:
return hash(self.underlying_value)
@staticmethod
def _eq_impl(a, b):
def _eq_impl(a: Any, b: Any) -> bool:
# TODO: Put this in utils and share it between variables/builtin.py and here
type_a, type_b = type(a), type(b)
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
@ -212,7 +213,7 @@ class ConstDictVariable(VariableTracker):
else:
return a == b
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
def __eq__(self, other: object) -> bool:
Hashable = ConstDictVariable._HashableTracker
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
type(other)
@ -226,8 +227,8 @@ class ConstDictVariable(VariableTracker):
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls=dict,
**kwargs,
user_cls: type = dict,
**kwargs: Any,
) -> None:
# .clone() passes these arguments in kwargs but they're recreated a few
# lines below
@ -247,18 +248,22 @@ class ConstDictVariable(VariableTracker):
for x, v in items.items()
)
def make_hashable(key):
def make_hashable(
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
) -> "ConstDictVariable._HashableTracker":
return key if isinstance(key, Hashable) else Hashable(key)
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
# need to reconstruct everything if the dictionary is an intermediate value
# or if a pop/delitem was executed
self.should_reconstruct_all = not is_from_local_source(self.source)
self.should_reconstruct_all = (
not is_from_local_source(self.source) if self.source else True
)
self.original_items = items.copy()
self.user_cls = user_cls
def _get_dict_cls_from_user_cls(self, user_cls):
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
# avoid executing user code if user_cls is a dict subclass
@ -277,10 +282,10 @@ class ConstDictVariable(VariableTracker):
dict_cls = dict
return dict_cls
def as_proxy(self):
def as_proxy(self) -> dict[Any, Any]:
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
def debug_repr(self):
def debug_repr(self) -> str:
return (
"{"
+ ", ".join(
@ -289,20 +294,20 @@ class ConstDictVariable(VariableTracker):
+ "}"
)
def as_python_constant(self):
def as_python_constant(self) -> dict[Any, Any]:
return {
k.vt.as_python_constant(): v.as_python_constant()
for k, v in self.items.items()
}
def keys_as_python_constant(self):
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
self.install_dict_keys_match_guard()
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
def python_type(self):
def python_type(self) -> type:
return self.user_cls
def __contains__(self, vt) -> bool:
def __contains__(self, vt: VariableTracker) -> bool:
assert isinstance(vt, VariableTracker)
Hashable = ConstDictVariable._HashableTracker
return (
@ -322,13 +327,15 @@ class ConstDictVariable(VariableTracker):
for key, value in self.items.items()
)
def is_new_item(self, value, other):
def is_new_item(
self, value: Optional[VariableTracker], other: VariableTracker
) -> bool:
# compare the id of the realized values if both values are not lazy VTs
if value and value.is_realized() and other.is_realized():
return id(value.realize()) != id(other.realize())
return id(value) != id(other)
def reconstruct_kvs_into_new_dict(self, codegen):
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
# Build a dictionary that contains the keys and values.
num_args = 0
for key, value in self.items.items():
@ -340,7 +347,7 @@ class ConstDictVariable(VariableTracker):
num_args += 1
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
if self.user_cls is collections.OrderedDict:
# emit `OrderedDict(constructed_dict)`
codegen.add_push_null(
@ -358,19 +365,21 @@ class ConstDictVariable(VariableTracker):
def getitem_const_raise_exception_if_absent(
self, tx: "InstructionTranslator", arg: VariableTracker
):
) -> VariableTracker:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
raise_observed_exception(KeyError, tx)
return self.items[key]
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
msg = f"Dictionary key {arg.value} not found during tracing"
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
unimplemented_v2(
gb_type="key not found in dict",
context=f"Key {arg.value}",
context=f"Key {arg.value}", # type: ignore[attr-defined]
explanation=msg,
hints=[
"Check if the key exists in the dictionary before accessing it.",
@ -379,13 +388,13 @@ class ConstDictVariable(VariableTracker):
)
return self.items[key]
def maybe_getitem_const(self, arg: VariableTracker):
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
return None
return self.items[key]
def realize_key_vt(self, arg: VariableTracker):
def realize_key_vt(self, arg: VariableTracker) -> None:
# Realize the LazyVT on a particular index
assert arg in self
key = ConstDictVariable._HashableTracker(arg)
@ -394,11 +403,13 @@ class ConstDictVariable(VariableTracker):
if isinstance(original_key_vt, variables.LazyVariableTracker):
original_key_vt.realize()
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
if self.source:
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
# Key guarding - These are the cases to consider
# 1) The dict has been mutated. In this case, we would have already
# inserted a DICT_KEYS_MATCH guard, so we can skip.
@ -439,11 +450,11 @@ class ConstDictVariable(VariableTracker):
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
# we have to insert guards when a dict method is accessed. For this to
# be simple, we are conservative and overguard. We skip guard only for
@ -462,7 +473,7 @@ class ConstDictVariable(VariableTracker):
tx, *args, **kwargs
)
tx.output.side_effects.mutation(self)
self.items.update(temp_dict_vt.items)
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
return ConstantVariable.create(None)
elif name == "__getitem__":
# Key guarding - Nothing to do. LazyVT for value will take care.
@ -526,7 +537,7 @@ class ConstDictVariable(VariableTracker):
return ConstantVariable.create(len(self.items))
elif name == "__setitem__" and self.is_mutable():
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_keys_match_guard()
if kwargs or len(args) != 2:
@ -550,7 +561,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
if args[0] not in self:
self.install_dict_contains_guard(tx, args)
@ -565,7 +576,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
if args[0] not in self:
# missing item, return the default value. Install no DICT_CONTAINS guard.
@ -599,7 +610,7 @@ class ConstDictVariable(VariableTracker):
last = v.value
else:
raise_args_mismatch(tx, name)
k, v = self.items.popitem(last=last)
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
else:
k, v = self.items.popitem()
@ -632,17 +643,17 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard()
dict_vt = args[0]
dict_vt: ConstDictVariable = args[0]
else:
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
self.items.update(dict_vt.items)
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
self.items.update(dict_vt.items) # type: ignore[attr-defined]
if has_kwargs:
# Handle kwargs
kwargs = {
kwargs_hashable = {
Hashable(ConstantVariable.create(k)): v
for k, v in kwargs.items()
}
self.items.update(kwargs)
self.items.update(kwargs_hashable)
return ConstantVariable.create(None)
else:
return super().call_method(tx, name, args, kwargs)
@ -656,7 +667,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_contains_guard(tx, args)
contains = args[0] in self
@ -671,7 +682,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_keys_match_guard()
if kwargs or len(args) > 2:
@ -707,7 +718,7 @@ class ConstDictVariable(VariableTracker):
and "last" in kwargs
and isinstance(kwargs["last"], ConstantVariable)
):
last = kwargs.get("last").value
last = kwargs.get("last").value # type: ignore[union-attr]
key = Hashable(args[0])
self.items.move_to_end(key, last=last)
@ -723,7 +734,7 @@ class ConstDictVariable(VariableTracker):
)
elif name == "__ne__":
return ConstantVariable.create(
not self.call_method(tx, "__eq__", args, kwargs).value
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
)
elif name == "__or__":
if len(args) != 1:
@ -750,14 +761,14 @@ class ConstDictVariable(VariableTracker):
if not istype(
other, (ConstDictVariable, variables.UserDefinedDictVariable)
):
msg = (
err_msg = (
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
f"and '{other.python_type().__name__}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
raise_observed_exception(TypeError, tx, args=[err_msg])
# OrderedDict overloads __ror__
ts = {self.user_cls, other.user_cls}
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
user_cls = (
collections.OrderedDict
if any(issubclass(t, collections.OrderedDict) for t in ts)
@ -774,8 +785,8 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard()
new_dict_vt.items.update(args[0].items)
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
return new_dict_vt
elif name == "__ior__":
self.call_method(tx, "update", args, kwargs)
@ -789,11 +800,13 @@ class ConstDictVariable(VariableTracker):
else:
return super().call_method(tx, name, args, kwargs)
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
self.install_dict_keys_match_guard()
return [x.vt for x in self.items.keys()]
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
# dict not allow setting arbitrary attributes. OrderedDict and
# defaultdict allow arbitrary setattr, but not deletion of default attrs
if any(
@ -816,25 +829,25 @@ class ConstDictVariable(VariableTracker):
],
)
def clone(self, **kwargs):
def clone(self, **kwargs: Any) -> VariableTracker:
self.install_dict_keys_match_guard()
return super().clone(**kwargs)
class MappingProxyVariable(VariableTracker):
# proxies to the original dict_vt
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
super().__init__(**kwargs)
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
def python_type(self):
def python_type(self) -> type:
return types.MappingProxyType
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return self.dv_dict.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
# load types.MappingProxyType
if self.source:
msg = (
@ -863,11 +876,11 @@ class MappingProxyVariable(VariableTracker):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if self.source and tx.output.side_effects.has_existing_dict_mutation():
msg = (
"A dict has been modified while we have an existing mappingproxy object. "
@ -892,7 +905,7 @@ class MappingProxyVariable(VariableTracker):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is types.MappingProxyType:
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
return super().call_obj_hasattr(tx, name)
@ -900,35 +913,44 @@ class MappingProxyVariable(VariableTracker):
class NNModuleHooksDictVariable(ConstDictVariable):
# Special class to avoid adding any guards on the nn module hook ids.
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
pass
class DefaultDictVariable(ConstDictVariable):
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls: type,
default_factory: Optional[VariableTracker] = None,
**kwargs: Any,
) -> None:
super().__init__(items, user_cls, **kwargs)
assert user_cls is collections.defaultdict
if default_factory is None:
default_factory = ConstantVariable.create(None)
self.default_factory = default_factory
def is_python_constant(self):
def is_python_constant(self) -> bool:
# Return false for unsupported defaults. This ensures that a bad handler
# path is not taken in BuiltinVariable for getitem.
if self.default_factory not in [list, tuple, dict] and not self.items:
return False
return super().is_python_constant()
def debug_repr(self):
def debug_repr(self) -> str:
assert self.default_factory is not None
return (
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
)
@staticmethod
def is_supported_arg(arg):
def is_supported_arg(arg: VariableTracker) -> bool:
if isinstance(arg, variables.BuiltinVariable):
return arg.fn in (list, tuple, dict, set)
else:
@ -942,11 +964,11 @@ class DefaultDictVariable(ConstDictVariable):
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__getitem__":
if len(args) != 1:
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
@ -962,13 +984,13 @@ class DefaultDictVariable(ConstDictVariable):
else:
default_var = self.default_factory.call_function(tx, [], {})
super().call_method(
tx, "__setitem__", (args[0], default_var), kwargs
tx, "__setitem__", [args[0], default_var], kwargs
)
return default_var
else:
return super().call_method(tx, name, args, kwargs)
def reconstruct(self, codegen):
def reconstruct(self, codegen: "PyCodegen") -> None:
# emit `defaultdict(default_factory, new_dict)`
codegen.add_push_null(
lambda: codegen.extend_output(
@ -994,40 +1016,48 @@ class SetVariable(ConstDictVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
# pyrefly: ignore[bad-assignment]
items = dict.fromkeys(items, SetVariable._default_value())
# pyrefly: ignore[bad-argument-type]
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "set()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self):
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
return set(self.items.keys())
@staticmethod
def _default_value():
def _default_value() -> VariableTracker:
# Variable to fill in the keys of the dictionary
return ConstantVariable.create(None)
def as_proxy(self):
def as_proxy(self) -> Any:
return {k.vt.as_proxy() for k in self.set_items}
def python_type(self):
def python_type(self) -> type:
return set
def as_python_constant(self):
def as_python_constant(self) -> Any:
return {k.vt.as_python_constant() for k in self.set_items}
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach([x.vt for x in self.set_items])
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
def _fast_set_method(self, tx, fn, args, kwargs):
def _fast_set_method(
self,
tx: "InstructionTranslator",
fn: Any,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
try:
res = fn(
*[x.as_python_constant() for x in [self, *args]],
@ -1037,15 +1067,16 @@ class SetVariable(ConstDictVariable):
raise_observed_exception(
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
)
# pyrefly: ignore[unbound-name]
return VariableTracker.build(tx, res)
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
# We forward the calls to the dictionary model
from ..utils import check_constant_args
@ -1065,10 +1096,10 @@ class SetVariable(ConstDictVariable):
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
if name == "__init__":
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
tx.output.side_effects.mutation(self)
self.items.clear()
self.items.update(temp_set_vt.items)
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
return ConstantVariable.create(None)
elif name == "add":
if kwargs or len(args) != 1:
@ -1079,7 +1110,7 @@ class SetVariable(ConstDictVariable):
f"{len(args)} args and {len(kwargs)} kwargs",
)
name = "__setitem__"
args = (args[0], SetVariable._default_value())
args = [args[0], SetVariable._default_value()]
elif name == "pop":
if kwargs or args:
raise_args_mismatch(
@ -1090,12 +1121,14 @@ class SetVariable(ConstDictVariable):
)
# Choose an item at random and pop it via the Dict.pop method
try:
result = self.set_items.pop().vt
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
except KeyError as e:
raise_observed_exception(
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
)
super().call_method(tx, name, (result,), kwargs)
# pyrefly: ignore[unbound-name]
super().call_method(tx, name, [result], kwargs)
# pyrefly: ignore[unbound-name]
return result
elif name == "isdisjoint":
if kwargs or len(args) != 1:
@ -1217,6 +1250,7 @@ class SetVariable(ConstDictVariable):
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
assert m is not None
return self.call_method(tx, m, args, kwargs)
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
@ -1230,29 +1264,34 @@ class SetVariable(ConstDictVariable):
"__ixor__": "symmetric_difference_update",
"__isub__": "difference_update",
}.get(name)
assert m is not None
self.call_method(tx, m, args, kwargs)
return self
elif name == "__eq__":
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(False)
r = self.call_method(tx, "symmetric_difference", args, kwargs)
return ConstantVariable.create(len(r.set_items) == 0)
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
elif name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
)
return super().call_method(tx, name, args, kwargs)
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
raise RuntimeError("Illegal to getitem on a set")
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
super().install_dict_contains_guard(tx, args)
@ -1260,27 +1299,27 @@ class FrozensetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "frozenset()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self):
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
return self.items.keys()
def python_type(self):
def python_type(self) -> type:
return frozenset
def as_python_constant(self):
def as_python_constant(self) -> Any:
return frozenset({k.vt.as_python_constant() for k in self.set_items})
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach([x.vt for x in self.set_items])
codegen.add_push_null(
lambda: codegen.extend_output(
@ -1293,11 +1332,11 @@ class FrozensetVariable(SetVariable):
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
elif name == "__init__":
@ -1316,7 +1355,7 @@ class FrozensetVariable(SetVariable):
"symmetric_difference",
):
r = super().call_method(tx, name, args, kwargs)
return FrozensetVariable(r.items)
return FrozensetVariable(r.items) # type: ignore[attr-defined]
return super().call_method(tx, name, args, kwargs)
@ -1324,11 +1363,11 @@ class DictKeySetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "dict_keys([])"
else:
@ -1338,33 +1377,35 @@ class DictKeySetVariable(SetVariable):
+ "])"
)
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
# Already EQUALS_MATCH guarded
pass
@property
def set_items(self):
def set_items(self) -> Any:
return self.items
def python_type(self):
def python_type(self) -> type:
return dict_keys
def as_python_constant(self):
def as_python_constant(self) -> Any:
return dict.fromkeys(
{k.vt.as_python_constant() for k in self.set_items}, None
).keys()
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
return super().call_method(tx, name, args, kwargs)
@ -1379,42 +1420,47 @@ class DictViewVariable(VariableTracker):
kv: Optional[str] = None
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
super().__init__(**kwargs)
assert self.kv in ("keys", "values", "items")
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
@property
def view_items(self):
def view_items(self) -> Any:
assert self.kv is not None
return getattr(self.dv_dict.items, self.kv)()
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
# Implement in the subclasses
raise NotImplementedError
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return self.view_items_vt
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
assert self.kv is not None
codegen(self.dv_dict)
codegen.load_method(self.kv)
codegen.call_method(0)
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
assert self.kv is not None
if name in self.python_type().__dict__:
return ConstantVariable.create(True)
return ConstantVariable.create(False)
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__len__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name == "__iter__":
@ -1428,24 +1474,24 @@ class DictKeysVariable(DictViewVariable):
kv = "keys"
@property
def set_items(self):
def set_items(self) -> set[VariableTracker]:
return set(self.view_items)
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
return [x.vt for x in self.view_items]
def python_type(self):
def python_type(self) -> type:
return dict_keys
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__contains__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name in (
@ -1460,13 +1506,13 @@ class DictKeysVariable(DictViewVariable):
):
# These methods always returns a set
m = getattr(self.set_items, name)
r = m(args[0].set_items)
r = m(args[0].set_items) # type: ignore[attr-defined]
return SetVariable(r)
if name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
)
return super().call_method(tx, name, args, kwargs)
@ -1476,10 +1522,10 @@ class DictValuesVariable(DictViewVariable):
kv = "values"
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
return list(self.view_items)
def python_type(self):
def python_type(self) -> type:
return dict_values
@ -1487,14 +1533,20 @@ class DictItemsVariable(DictViewVariable):
kv = "items"
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
def python_type(self):
def python_type(self) -> type:
return dict_items
def call_method(self, tx, name, args, kwargs):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
# TODO(guilhermeleobas): This should actually check if args[0]
# implements the mapping protocol.
if name == "__eq__":

View File

@ -627,7 +627,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
num_warps={self.num_warps},
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -8,6 +8,7 @@ import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
from .. import config, ir
from ..kernel_inputs import ConvKernelInputs
from ..lowering import (
add_layout_constraint,
constrain_to_fx_strides,
@ -16,7 +17,9 @@ from ..lowering import (
)
from ..select_algorithm import (
autotune_select_algorithm,
ChoiceCaller,
ExternKernelChoice,
KernelTemplate,
SymbolicGridFn,
TritonTemplate,
)
@ -76,7 +79,7 @@ LOOP_BODY_2D = """
& (idx_x_h < IN_H)[:, None]
& (idx_x_w >= 0)[:, None]
& (idx_x_w < IN_W)[:, None]
& (idx_x_c < GROUP_IN_C)[None, :]
)
matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
@ -542,34 +545,40 @@ def convolution(
x = ir.ExternKernel.require_stride_order(x, req_stride_order) # type: ignore[assignment]
weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) # type: ignore[assignment]
ordered_kwargs_for_cpp_kernel = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
if bias is None:
args = [x, weight]
kwargs["bias"] = None # type: ignore[typeddict-unknown-key]
ordered_kwargs_for_cpp_kernel.insert(0, "bias")
else:
args = [x, weight, bias]
# Create ConvKernelInputs for unified template configuration
# Only include bias in input_nodes when it's not None
# - For Triton templates: bias is always None here (peeled off earlier), so input_nodes = [x, weight]
# - For ATEN: input_nodes = [x, weight] when bias is None, [x, weight, bias] when bias is present
if bias is not None:
bias.realize()
bias.freeze_layout()
V.graph.sizevars.guard_int_seq(bias.get_size())
input_nodes = [x, weight, bias]
bias_idx = 2
else:
input_nodes = [x, weight]
bias_idx = None
kernel_inputs = ConvKernelInputs(
input_nodes,
scalars={
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": transposed,
"output_padding": output_padding,
"groups": groups,
},
x_idx=0,
weight_idx=1,
bias_idx=bias_idx,
)
# Build list of templates to try
templates: list[ExternKernelChoice | KernelTemplate] = []
choices = []
if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
choices = [
aten_convolution.bind(
args,
layout,
ordered_kwargs_for_cpp_kernel,
**kwargs,
)
]
templates.append(aten_convolution)
if (
torch._inductor.utils._use_conv_autotune_backend("TRITON")
@ -587,60 +596,23 @@ def convolution(
and is_zeros(padding)
and groups == 1
):
choices.append(aten_conv1x1_via_mm.bind(args, layout))
templates.append(aten_conv1x1_via_mm)
conv_configs = V.choices.get_conv_configs(device_type)
# Add appropriate template based on ndim
if ndim == 2:
templates.append(conv2d_template)
elif ndim == 3:
templates.append(conv3d_template)
dtype_size = x.get_dtype().itemsize
for cfg in conv_configs(
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
out_chan,
in_chan,
dtype_size=dtype_size,
):
if ndim == 2:
conv2d_template.maybe_append_choice(
choices,
input_nodes=(x, weight),
layout=layout,
KERNEL_H=kernel_shape[0],
KERNEL_W=kernel_shape[1],
STRIDE_H=stride[0],
STRIDE_W=stride[1],
PADDING_H=padding[0],
PADDING_W=padding[1],
GROUPS=groups,
# TODO(jansel): try unroll for bigger kernels once fixed:
# https://github.com/triton-lang/triton/issues/1254
UNROLL=is_ones(kernel_shape),
ALLOW_TF32=torch.backends.cudnn.allow_tf32,
num_stages=cfg.num_stages,
num_warps=cfg.num_warps,
**cfg.kwargs,
)
elif ndim == 3:
conv3d_template.maybe_append_choice(
choices,
input_nodes=(x, weight),
layout=layout,
KERNEL_D=kernel_shape[0],
KERNEL_H=kernel_shape[1],
KERNEL_W=kernel_shape[2],
STRIDE_D=stride[0],
STRIDE_H=stride[1],
STRIDE_W=stride[2],
PADDING_D=padding[0],
PADDING_H=padding[1],
PADDING_W=padding[2],
GROUPS=groups,
# TODO(jansel): try unroll for bigger kernels once fixed:
# https://github.com/triton-lang/triton/issues/1254
UNROLL=is_ones(kernel_shape),
ALLOW_TF32=torch.backends.cudnn.allow_tf32,
num_stages=cfg.num_stages,
num_warps=cfg.num_warps,
**cfg.kwargs,
)
# Initialize choices list and extend with template configs
choices: list[ChoiceCaller] = []
choices.extend(
V.choices.get_template_configs(
kernel_inputs,
templates,
"convolution",
)
)
if use_ck_conv_template(layout):
CKGroupedConvFwdTemplate.add_ck_conv_choices(
choices,
@ -652,7 +624,9 @@ def convolution(
groups=groups,
n_spatial_dimensions=ndim,
)
return autotune_select_algorithm("convolution", choices, args, layout)
return autotune_select_algorithm(
"convolution", choices, kernel_inputs.nodes(), layout
)
@register_lowering(aten._convolution)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
@ -12,10 +13,12 @@ from .ir import FixedLayout, FlexibleLayout, Layout
if TYPE_CHECKING:
from collections.abc import Sequence
import sympy
# Type aliases for serializable scalar values
Serializable = Union[int, float, bool]
SerializableValue = Union[Serializable, Sequence[Serializable]]
class KernelInputs(ABC):
"""
@ -27,7 +30,7 @@ class KernelInputs(ABC):
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
out_dtype: Optional[torch.dtype] = None,
):
"""
@ -183,7 +186,7 @@ class KernelInputs(ABC):
The output dtype
"""
def get_scalar(self, name: str) -> Union[float, int]:
def get_scalar(self, name: str) -> SerializableValue:
"""
Get the scalar value for a given name.
@ -191,7 +194,7 @@ class KernelInputs(ABC):
name: Name of the scalar to get
Returns:
The scalar value
The scalar value (can be int, float, bool, or tuple of these types)
"""
assert name in self._scalars, f"Scalar {name} not found, but required"
return self._scalars[name]
@ -216,7 +219,7 @@ class MMKernelInputs(KernelInputs):
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
out_dtype: Optional[torch.dtype] = None,
mat1_idx: int = -2,
mat2_idx: int = -1,
@ -336,3 +339,113 @@ class MMKernelInputs(KernelInputs):
assert k == k_check, f"K dimensions don't match: {k} vs {k_check}"
return (m, n, k)
class ConvKernelInputs(KernelInputs):
"""
Specialized KernelInputs for convolution operations.
Stores input tensor, weight tensor, and optional bias, along with conv parameters.
"""
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[dict[str, SerializableValue]] = None,
out_dtype: Optional[torch.dtype] = None,
x_idx: int = 0,
weight_idx: int = 1,
bias_idx: Optional[int] = None,
):
"""
Initialize with convolution input nodes.
Args:
input_nodes: List containing [x, weight] or [x, weight, bias]
scalars: Dict with conv params (stride, padding, dilation, groups, transposed, output_padding)
out_dtype: Optional output dtype
x_idx: Index of input tensor (default: 0)
weight_idx: Index of weight tensor (default: 1)
bias_idx: Index of bias tensor if present (default: None)
"""
super().__init__(input_nodes, scalars, out_dtype)
assert len(input_nodes) >= 2, "Expected at least 2 input nodes (x, weight)"
self._x_idx = x_idx
self._weight_idx = weight_idx
self._bias_idx = bias_idx
# Validate that required scalars are present
required_scalars = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
for key in required_scalars:
assert key in self._scalars, f"Conv requires scalar '{key}'"
def out_dtype(self) -> torch.dtype:
"""
Get the output dtype, whether passed in or inferred from the nodes
Returns:
The output dtype
"""
if self._out_dtype is not None:
return self._out_dtype
return self._input_nodes[self._x_idx].get_dtype()
def output_layout(self, flexible: bool = True) -> Layout:
"""
Handle output layout generation for convolution.
Args:
flexible: If True, return FlexibleLayout, otherwise FixedLayout
Returns:
Layout for the convolution output
"""
from torch._inductor.kernel.conv import conv_layout
x = self._input_nodes[self._x_idx]
weight = self._input_nodes[self._weight_idx]
bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None
# Use existing conv_layout function
# We know the types here because conv requires these specific scalar types
layout = conv_layout(
x,
weight,
bias,
self._scalars["stride"], # type: ignore[arg-type]
self._scalars["padding"], # type: ignore[arg-type]
self._scalars["dilation"], # type: ignore[arg-type]
self._scalars["transposed"], # type: ignore[arg-type]
self._scalars["output_padding"], # type: ignore[arg-type]
self._scalars["groups"], # type: ignore[arg-type]
)
# TODO: Handle flexible vs fixed based on config if needed
return layout
def get_x_weight_bias(self) -> tuple[Any, Any, Optional[Any]]:
"""
Get x, weight, and optional bias nodes.
Returns:
Tuple of (x, weight, bias) where bias may be None
"""
bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None
return self._input_nodes[self._x_idx], self._input_nodes[self._weight_idx], bias
def spatial_dims(self) -> tuple[Any, ...]:
"""
Get spatial dimensions from input tensor (H, W for 2D, D, H, W for 3D).
Returns:
Tuple of spatial dimension sizes
"""
x_shape = self._input_nodes[self._x_idx].get_size()
return x_shape[2:] # Skip batch and channel dims
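
A minimal sketch of how a lowering would construct `ConvKernelInputs`, mirroring the rewritten convolution lowering above; `x` and `weight` stand in for already-realized IR nodes and the scalar values are illustrative placeholders.

```python
# Illustrative only: bias was peeled off earlier, so bias_idx stays None.
kernel_inputs = ConvKernelInputs(
    [x, weight],
    scalars={
        "stride": (1, 1),
        "padding": (0, 0),
        "dilation": (1, 1),
        "transposed": False,
        "output_padding": (0, 0),
        "groups": 1,
    },
    x_idx=0,
    weight_idx=1,
    bias_idx=None,
)
layout = kernel_inputs.output_layout()                       # delegates to conv_layout
x_node, w_node, b_node = kernel_inputs.get_x_weight_bias()   # b_node is None here
```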

View File

@ -3586,13 +3586,24 @@ def user_autotune(
)
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
def foreach(triton_meta, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,
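
A hedged usage sketch of the new behavior, assuming `inductor_meta` carries the `max_autotune` flags from `torch._inductor.config` and that Inductor fuses these foreach ops into a `triton_for_fused` kernel: without max-autotune the heuristic keeps the single `num_warps=8` config; with it, `num_warps` in {1, 2, 4, 8} are benchmarked.

```python
import torch

params = [torch.randn(1024, device="cuda") for _ in range(8)]
grads = [torch.randn(1024, device="cuda") for _ in range(8)]

@torch.compile(mode="max-autotune")  # enables the naive num_warps sweep above
def sgd_step(ps, gs):
    # foreach op that Inductor can lower to a combo/foreach kernel
    torch._foreach_add_(ps, gs, alpha=-0.1)

sgd_step(params, grads)
```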

View File

@ -1,6 +1,6 @@
# NOTE: add new template heuristics here, so they get imported and registered
# TODO: write a simple glob if there are many heuristics to auto import them in the right order
from . import aten, base, contiguous_mm, decompose_k, registry, triton
from . import aten, base, contiguous_mm, conv, decompose_k, registry, triton
# expose the entry function
from .registry import get_template_heuristic

View File

@ -0,0 +1,287 @@
from __future__ import annotations
from typing import Any, cast, TYPE_CHECKING
import torch
from ..kernel.conv import aten_convolution, conv2d_template, conv3d_template
from ..kernel_inputs import ConvKernelInputs
from ..utils import is_ones, sympy_product
from ..virtualized import V
from .base import TemplateConfigHeuristics
from .registry import register_template_heuristic
from .triton import (
CPUConfigHeuristic,
CUDAConfigHeuristic,
MTIAConfigHeuristic,
ROCmConfigHeuristic,
XPUConfigHeuristic,
)
if TYPE_CHECKING:
from collections.abc import Generator
from ..kernel_inputs import KernelInputs
class ConvTemplateConfigMixin(TemplateConfigHeuristics):
"""
Mixin for conv templates that converts config lists to template kwargs.
Similar to MMTemplateConfigMixin but for convolutions.
This handles generating both the static template kwargs (KERNEL_H, STRIDE_H, etc.)
and the per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
"""
# Type hint for methods from BaseConfigHeuristic
get_conv_configs: Any
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> dict[str, Any]:
"""
Return template kwargs that don't change per-config.
These are derived from kernel_inputs and must include all template parameters.
Args:
kernel_inputs: ConvKernelInputs containing input tensors and conv params
op_name: Operation name (e.g., "convolution")
Returns:
Dict of static template kwargs (KERNEL_H, STRIDE_H, GROUPS, etc.)
"""
assert isinstance(kernel_inputs, ConvKernelInputs), (
f"ConvTemplateConfigMixin requires ConvKernelInputs, got {type(kernel_inputs)}"
)
x, weight, bias = kernel_inputs.get_x_weight_bias()
# Extract kernel shape from weight: [out_chan, in_chan, *kernel_shape]
weight_size = V.graph.sizevars.guard_int_seq(weight.get_size())
kernel_shape = weight_size[2:] # Skip out_chan, in_chan
ndim = len(kernel_shape)
# Extract scalars
stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride"))
padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding"))
groups = cast(int, kernel_inputs.get_scalar("groups"))
# Check if we should unroll (only for 1x1 kernels)
unroll = is_ones(kernel_shape)
# Build kwargs dict based on ndim
kwargs: dict[str, Any] = {
"GROUPS": groups,
"UNROLL": unroll,
"ALLOW_TF32": torch.backends.cudnn.allow_tf32,
}
if ndim == 2:
kwargs.update(
{
"KERNEL_H": kernel_shape[0],
"KERNEL_W": kernel_shape[1],
"STRIDE_H": stride[0],
"STRIDE_W": stride[1],
"PADDING_H": padding[0],
"PADDING_W": padding[1],
}
)
elif ndim == 3:
kwargs.update(
{
"KERNEL_D": kernel_shape[0],
"KERNEL_H": kernel_shape[1],
"KERNEL_W": kernel_shape[2],
"STRIDE_D": stride[0],
"STRIDE_H": stride[1],
"STRIDE_W": stride[2],
"PADDING_D": padding[0],
"PADDING_H": padding[1],
"PADDING_W": padding[2],
}
)
return kwargs
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Yield per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
Args:
kernel_inputs: ConvKernelInputs containing input tensors
op_name: Operation name
Yields:
Dict of per-config kwargs for each configuration to try
"""
assert isinstance(kernel_inputs, ConvKernelInputs), (
"ConvTemplateConfigMixin requires ConvKernelInputs"
)
x, weight, bias = kernel_inputs.get_x_weight_bias()
# Calculate dimensions for heuristics
weight_size = weight.get_size()
out_chan = weight_size[0]
in_chan = weight_size[1]
# Batch * spatial dimensions product
x_size = x.get_size()
batch_spatial_product = sympy_product([x_size[0], *x_size[2:]])
# Get conv config generator from self (which is a BaseConfigHeuristic subclass)
conv_configs_generator = self.get_conv_configs()
dtype_size = x.get_dtype().itemsize
# Generate configs (reusing mm preprocess_mm_configs machinery)
for c in conv_configs_generator(
batch_spatial_product,
out_chan,
in_chan,
dtype_size=dtype_size,
op_name="conv",
):
# Yield per-config kwargs
yield {
"BLOCK_M": c.kwargs.get("BLOCK_M"),
"BLOCK_N": c.kwargs.get("BLOCK_N"),
"BLOCK_K": c.kwargs.get("BLOCK_K"),
"num_stages": c.num_stages,
"num_warps": c.num_warps,
}
# ATEN convolution heuristic (no per-config tuning)
@register_template_heuristic(aten_convolution.uid, None)
class ATenConvConfigHeuristic(TemplateConfigHeuristics):
"""
Pseudo heuristic for ATen convolution.
ATen doesn't have configs to tune - it's a single choice.
"""
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
# ATen doesn't have per-config kwargs to tune
yield dict()
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> dict[str, Any]:
"""
ATen gets stride, padding, etc. as ordered kwargs for the C++ kernel.
"""
assert isinstance(kernel_inputs, ConvKernelInputs)
# Extract scalar values from kernel_inputs
stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride"))
padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding"))
dilation = cast(tuple[int, ...], kernel_inputs.get_scalar("dilation"))
transposed = cast(bool, kernel_inputs.get_scalar("transposed"))
output_padding = cast(
tuple[int, ...], kernel_inputs.get_scalar("output_padding")
)
groups = cast(int, kernel_inputs.get_scalar("groups"))
# Check if bias is None to match old behavior
# When bias is None: input_nodes = [x, weight], add 'bias' to kwargs and ordered list
# When bias is present: input_nodes = [x, weight, bias], don't add 'bias' to kwargs
x, weight, bias = kernel_inputs.get_x_weight_bias()
kwargs: dict[str, Any] = {
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": transposed,
"output_padding": output_padding,
"groups": groups,
}
if bias is None:
# When bias is None, torch.convolution expects it as a kwarg
kwargs["bias"] = None
kwargs["ordered_kwargs_for_cpp_kernel"] = [
"bias",
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
else:
# When bias is present, it's passed as a positional arg (3rd in input_nodes)
kwargs["ordered_kwargs_for_cpp_kernel"] = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
return kwargs
# CUDA Conv2D/Conv3D heuristics
@register_template_heuristic(
conv2d_template.uid,
"cuda",
register=torch.version.hip is None,
)
@register_template_heuristic(
conv3d_template.uid,
"cuda",
register=torch.version.hip is None,
)
class CUDAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CUDAConfigHeuristic):
"""Conv template heuristic for CUDA."""
# ROCm Conv2D/Conv3D heuristics
@register_template_heuristic(
conv2d_template.uid,
"cuda",
register=torch.version.hip is not None,
)
@register_template_heuristic(
conv3d_template.uid,
"cuda",
register=torch.version.hip is not None,
)
class ROCmConvTemplateConfigHeuristic(ConvTemplateConfigMixin, ROCmConfigHeuristic):
"""Conv template heuristic for ROCm."""
# CPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "cpu")
@register_template_heuristic(conv3d_template.uid, "cpu")
class CPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CPUConfigHeuristic):
"""Conv template heuristic for CPU."""
# XPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "xpu")
@register_template_heuristic(conv3d_template.uid, "xpu")
class XPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, XPUConfigHeuristic):
"""Conv template heuristic for XPU."""
# MTIA Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "mtia")
@register_template_heuristic(conv3d_template.uid, "mtia")
class MTIAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, MTIAConfigHeuristic):
"""Conv template heuristic for MTIA."""

View File

@ -1224,43 +1224,3 @@ def _build_table(
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by node_name for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)

View File

@ -443,7 +443,6 @@ class CodeGen:
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@ -648,6 +647,15 @@ class CodeGen:
if verbose:
# override annotation with more detailed information
try:
from torch.distributed.tensor._api import DTensor, DTensorSpec
dtensorspec_format_shard_order_str = (
DTensorSpec.format_shard_order_str
)
except ModuleNotFoundError:
DTensor = None # type: ignore[assignment,misc]
dtensorspec_format_shard_order_str = None
from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.passes.shape_prop import TensorMetadata
@ -678,6 +686,16 @@ class CodeGen:
core = _tensor_annotation(meta_val)
if is_plain:
maybe_type_annotation = f': "{core}"'
elif type(meta_val) is DTensor:
assert dtensorspec_format_shard_order_str is not None
dtensor_meta = dtensorspec_format_shard_order_str(
meta_val._spec.placements, # type: ignore[attr-defined]
meta_val._spec.shard_order, # type: ignore[attr-defined]
)
cls = meta_val.__class__.__name__
maybe_type_annotation = (
f': "{cls}({core}, {dim_green(dtensor_meta)})"'
)
else:
cls = meta_val.__class__.__name__
maybe_type_annotation = f': "{cls}({core})"'
@ -799,10 +817,6 @@ class CodeGen:
return
raise NotImplementedError(f"node: {node.op} {node.target}")
if record_func:
body.append(
"_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n"
)
for i, node in enumerate(nodes):
# NOTE: emit_node does not emit a string with newline. It depends
# on delete_unused_values to append one
@ -812,22 +826,8 @@ class CodeGen:
# node index, which will be deleted later
# after going through _body_transformer
body.append(f"# COUNTER: {i}\n")
do_record = record_func and node.op in (
"call_function",
"call_method",
"call_module",
)
if do_record:
# The double hash ## convention is used by post-processing to find the fx markers
body.append(
f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n"
)
emit_node(node)
delete_unused_values(node)
if do_record:
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
if record_func:
body.append("_rf.__exit__(None, None, None)\n")
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@ -1779,7 +1779,6 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@ -1847,7 +1846,6 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def _python_code(
@ -1860,7 +1858,6 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@ -1871,7 +1868,6 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def __str__(self) -> str:

View File

@ -861,18 +861,14 @@ class {module_name}(torch.nn.Module):
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
from torch._dynamo import config as dynamo_config
python_code = self._graph.python_code(
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
)
python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
self._lineno_map = python_code._lineno_map
self._prologue_start = python_code._prologue_start
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
from torch._dynamo import config as dynamo_config
if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
@ -889,6 +885,7 @@ class {module_name}(torch.nn.Module):
# This ensures the same code+metadata always generates the same filename
hash_value = _metadata_hash(self._code, node_metadata)
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
filename = f"{file_stem}.py"
# Only include co_filename to use it directly as the cache key
@ -908,13 +905,6 @@ class {module_name}(torch.nn.Module):
_register_fx_metadata(filename, metadata)
# Replace the placeholder in generated code with actual filename
# The double hash ## convention is used by post-processing to find the fx markers
self._code = self._code.replace(
"torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')",
f"torch._C._profiler._RecordFunctionFast('## {filename} ##')",
)
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
# Determine whether this class explicitly defines a __call__ implementation

View File

@ -4,7 +4,7 @@ import operator
import re
from collections import deque
from dataclasses import dataclass
from typing import Any, Literal, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING
from torch.autograd.profiler import profile
from torch.profiler import DeviceType
@ -400,170 +400,3 @@ def _init_for_cuda_graphs() -> None:
with profile():
pass
@dataclass
class TimelineEvent:
"""Represents an event in the profiler timeline."""
timestamp: int
event_type: Literal["start", "end", "regular"]
marker_type: Optional[Literal["filename", "node"]]
identifier: Optional[str | int]
event: dict[str, Any]
@dataclass
class ContextStackEntry:
"""Represents a context (filename or node) in the stack."""
context_type: Literal["filename", "node"]
identifier: str | int
metadata: Optional[dict]
tid: Optional[int] = None # Thread ID associated with this context
def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
"""
Maps recorded profiler events to their corresponding fx nodes and adds stack traces.
Builds a timeline of all events (regular ops and FX markers for filenames/nodes),
sorts by timestamp, then processes chronologically while maintaining a context stack of active
filename/node scopes. Regular events are augmented with stack traces and node names from the
innermost active context. Runtime is O(n log n) for n events.
Args:
traced_data: Json of profiler events from Chrome trace
Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
from torch.fx.traceback import _FX_METADATA_REGISTRY
trace_events = traced_data.get("traceEvents", [])
# Create event timeline
event_timeline: list[TimelineEvent] = []
def is_fx_marker_event(event):
return (
event.get("cat") == "cpu_op"
and event.get("name", "").startswith("## ")
and event.get("name", "").endswith(" ##")
)
def append_fx_marker_event(event_type, identifier, event):
start_ts = event["ts"]
end_ts = start_ts + event["dur"]
event_timeline.append(
TimelineEvent(start_ts, "start", event_type, identifier, event)
)
event_timeline.append(
TimelineEvent(end_ts, "end", event_type, identifier, event)
)
for event in trace_events:
if "ts" not in event or "dur" not in event:
continue
if is_fx_marker_event(event):
content = event["name"][3:-3]
if content.endswith(".py"):
append_fx_marker_event("filename", content, event)
else:
try:
node_index = int(content)
except ValueError:
pass
append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined]
else:
# Regular event that needs augmentation
start_ts = event["ts"]
event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event))
# Sort by timestamp
event_timeline.sort(key=lambda x: x.timestamp)
# Process events in chronological order with a stack
context_stack: list[ContextStackEntry] = []
# Invariant: every start event has a corresponding end event
for timeline_event in event_timeline:
match timeline_event.event_type:
case "start":
assert timeline_event.identifier is not None
if timeline_event.marker_type == "filename":
assert isinstance(timeline_event.identifier, str)
# Push filename context - query metadata registry on-demand
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
tid = timeline_event.event.get("tid")
context_stack.append(
ContextStackEntry(
"filename", timeline_event.identifier, metadata, tid
)
)
elif timeline_event.marker_type == "node":
# Find the current filename from stack
current_file_metadata = None
tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
if (
ctx_entry.context_type == "filename"
and ctx_entry.tid == tid
):
current_file_metadata = ctx_entry.metadata
break
if current_file_metadata:
node_metadata = current_file_metadata.get("node_metadata", {})
if timeline_event.identifier in node_metadata:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
"node", timeline_event.identifier, node_meta, tid
)
)
case "end":
# Pop from stack - search backwards to find matching context
for i in range(len(context_stack) - 1, -1, -1):
ctx_entry = context_stack[i]
if (
timeline_event.marker_type == ctx_entry.context_type
and timeline_event.identifier == ctx_entry.identifier
):
context_stack.pop(i)
break
case "regular":
# Apply metadata from current context stack
# Find the most specific context (node takes precedence over filename)
# Only augment events with the same tid as the file/node event matched
current_stack_trace = None
current_node_name = None
event_tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
# Only apply metadata from contexts with matching tid
if ctx_entry.tid == event_tid:
if ctx_entry.context_type == "node" and ctx_entry.metadata:
current_stack_trace = ctx_entry.metadata.get(
"stack_trace", "No model stack trace available"
)
current_node_name = ctx_entry.metadata.get("name", "")
# Do we want to attach only the stack trace of the innermost node, or the stack traces of all nodes,
# when nodes are nested, e.g. in nested graph modules?
break
# Augment the event
if current_stack_trace or current_node_name:
args = timeline_event.event.setdefault("args", {})
if current_stack_trace:
args["stack_trace"] = current_stack_trace
if current_node_name:
args["node_name"] = current_node_name

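As a complement to the docstring above, here is a toy sketch (not the removed implementation; all event names and fields are made up) of the timeline/context-stack approach it describes: start and end markers open and close a context, and regular events inherit metadata from the innermost open context.

```
from dataclasses import dataclass
from typing import Optional

@dataclass
class Evt:
    ts: int
    kind: str             # "start", "end", or "regular"
    ident: Optional[str]  # context identifier for start/end markers
    name: str = ""

events = [
    Evt(0, "start", "fx_generated_abc.py"),
    Evt(1, "start", "node:0"),
    Evt(2, "regular", None, "aten::mm"),
    Evt(3, "end", "node:0"),
    Evt(4, "regular", None, "aten::add"),
    Evt(5, "end", "fx_generated_abc.py"),
]

stack: list[str] = []
for e in sorted(events, key=lambda e: e.ts):    # O(n log n) sort, then a linear scan
    if e.kind == "start":
        stack.append(e.ident)
    elif e.kind == "end":
        # search backwards for the matching context, mirroring the pop logic above
        for i in range(len(stack) - 1, -1, -1):
            if stack[i] == e.ident:
                del stack[i]
                break
    else:
        context = stack[-1] if stack else None  # innermost active context wins
        print(f"{e.name} attributed to {context}")
# aten::mm attributed to node:0
# aten::add attributed to fx_generated_abc.py
```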
View File

@@ -210,7 +210,8 @@ class _KinetoProfile:
def start_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.start()
assert self.profiler is not None
if self.profiler is None:
raise AssertionError("Profiler must be initialized before starting trace")
self.profiler._start_trace()
if self.profile_memory:
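The hunks in this file all apply the same refactor: bare `assert` statements become explicit `AssertionError` raises, so the checks still run under `python -O` (which strips asserts) and carry a readable message. A minimal sketch of the pattern, using a made-up class:

```
class Example:  # hypothetical class, only to illustrate the pattern
    def __init__(self) -> None:
        self.profiler = None

    def start_trace(self) -> None:
        # before: assert self.profiler is not None
        # after: an explicit check that survives `python -O`
        if self.profiler is None:
            raise AssertionError("Profiler must be initialized before starting trace")
        self.profiler._start_trace()
```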
@@ -256,7 +257,8 @@ class _KinetoProfile:
def stop_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.stop()
assert self.profiler is not None
if self.profiler is None:
raise AssertionError("Profiler must be initialized before stopping trace")
self.profiler.__exit__(None, None, None)
def export_chrome_trace(self, path: str):
@@ -264,7 +266,10 @@ class _KinetoProfile:
Exports the collected trace in Chrome JSON format. If kineto is enabled, only
the last cycle in the schedule is exported.
"""
assert self.profiler
if self.profiler is None:
raise AssertionError(
"Profiler must be initialized before exporting chrome trace"
)
if path.endswith(".gz"):
fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False)
fp.close()
@@ -284,7 +289,8 @@ class _KinetoProfile:
path (str): save stacks file to this location;
metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
"""
assert self.profiler
if self.profiler is None:
raise AssertionError("Profiler must be initialized before exporting stacks")
return self.profiler.export_stacks(path, metric)
def toggle_collection_dynamic(
@@ -316,7 +322,7 @@ class _KinetoProfile:
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
"""
if not self.profiler:
if self.profiler is None:
return
self.profiler.toggle_collection_dynamic(enable, activities)
@@ -333,7 +339,10 @@ class _KinetoProfile:
To use shape/stack functionality make sure to set record_shapes/with_stack
when creating profiler context manager.
"""
assert self.profiler
if self.profiler is None:
raise AssertionError(
"Profiler must be initialized before getting key averages"
)
return self.profiler.key_averages(
group_by_input_shape, group_by_stack_n, group_by_overload_name
)
@@ -343,7 +352,8 @@ class _KinetoProfile:
Returns the list of unaggregated profiler events,
to be used in the trace callback or after the profiling is finished
"""
assert self.profiler
if self.profiler is None:
raise AssertionError("Profiler must be initialized before accessing events")
return self.profiler.function_events
def add_metadata(self, key: str, value: str) -> None:
@@ -395,7 +405,10 @@ class _KinetoProfile:
if missing:
raise ValueError(f"{', '.join(missing)} required for memory profiling.")
assert self.profiler is not None and self.profiler.kineto_results is not None
if self.profiler is None or self.profiler.kineto_results is None:
raise AssertionError(
"Profiler and kineto_results must be initialized for memory profiling"
)
return MemoryProfile(self.profiler.kineto_results)
def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
@@ -485,7 +498,8 @@ def schedule(
"""
def schedule_fn(step: int) -> ProfilerAction:
assert step >= 0
if step < 0:
raise AssertionError(f"Step must be non-negative. Got {step}.")
if step < skip_first:
return ProfilerAction.NONE
else:
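The hunk below tightens the validation of the schedule arguments (wait, warmup, repeat, and skip_first must be >= 0; active must be > 0). For context, a usage example of the public `torch.profiler.schedule` API; the step count and tensor sizes here are arbitrary but satisfy those constraints.

```
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# skip 1 step, then per cycle: wait 1, warm up 1, record 3; repeat the cycle twice
sched = schedule(skip_first=1, wait=1, warmup=1, active=3, repeat=2)

with profile(activities=[ProfilerActivity.CPU], schedule=sched) as prof:
    for _ in range(12):
        torch.randn(8, 8) @ torch.randn(8, 8)
        prof.step()  # advance the schedule; only "active" steps are recorded
```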
@@ -508,9 +522,11 @@ def schedule(
else ProfilerAction.RECORD_AND_SAVE
)
assert (
wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
), "Invalid profiler schedule arguments"
if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0:
raise AssertionError(
f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), "
f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)."
)
if warmup == 0:
warn(
"Profiler won't be using warmup, this can skew profiler results",
@@ -717,7 +733,8 @@ class profile(_KinetoProfile):
activities_set.add(ProfilerActivity.CUDA)
elif ProfilerActivity.CUDA in activities_set:
activities_set.remove(ProfilerActivity.CUDA)
assert len(activities_set) > 0, "No valid profiler activities found"
if len(activities_set) == 0:
raise AssertionError("No valid profiler activities found")
super().__init__(
activities=activities,