Compare commits


1 Commit

77 changed files with 896 additions and 1157 deletions

View File

@ -271,16 +271,6 @@ case "$tag" in
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-clang21)
ANACONDA_PYTHON_VERSION=3.10
CLANG_VERSION=21
ACL=yes
VISION=yes
OPENBLAS=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11

View File

@ -1 +1 @@
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7
7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd

View File

@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
# work around ubuntu apt-get conflicts
sudo apt-get -y -f install
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
if [[ $CLANG_VERSION -ge 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
if [[ $CLANG_VERSION == 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
fi
fi

View File

@ -129,7 +129,7 @@ function install_129 {
}
function install_128 {
CUDNN_VERSION=9.8.0.87
CUDNN_VERSION=9.10.2.21
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
# install CUDA 12.8.1 in the same container
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

View File

@ -10,7 +10,6 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="
CC=gcc
NUM_THREADS=128
USE_OPENMP=1
NO_SHARED=0

View File

@ -1 +1 @@
3.5.1
3.5.0

View File

@ -272,6 +272,18 @@ def smoke_test_cuda(
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
print(f"Torch cuDNN version: {torch_cudnn_version}")
torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
torch_cudnn_runtime_version = tuple(
[int(x) for x in torch_cudnn_version.split(".")]
)
if torch_cudnn_runtime_version != torch_cudnn_compile_version:
raise RuntimeError(
"cuDNN runtime version doesn't match comple version. "
f"Loaded: {torch_cudnn_runtime_version} "
f"Expected: {torch_cudnn_compile_version}"
)
if sys.platform in ["linux", "linux2"]:
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
print(f"Torch nccl; version: {torch_nccl_version}")

View File

@ -79,8 +79,6 @@ jobs:
include:
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
runner: linux.arm64.m7g.4xlarge
timeout-minutes: 600

View File

@ -93,7 +93,7 @@ jobs:
- linux-jammy-cuda12_8-py3_10-gcc11-build
- target-determination
with:
timeout-minutes: 360
timeout-minutes: 400
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }}

View File

@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
- [Running `pyrefly`](#running-pyrefly)
- [Running `mypy`](#running-mypy)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
- `mypy` - recommended for linting
- `pytest` - recommended to run tests more selectively
Running
```
@ -350,32 +350,15 @@ make lint
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
#### Running `pyrefly`
#### Running `mypy`
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
**Getting Started with Pyrefly:**
To run type checking on the PyTorch codebase:
```bash
pyrefly check
```
For more detailed error information with summaries:
```bash
pyrefly check --summarize-errors
```
**Learn More:**
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
`mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
for more information on how to set up `mypy` and tackle type annotation
tasks.
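As a quick, hypothetical illustration (not taken from the PyTorch codebase) of the kind of mismatch `mypy` reports once annotations are in place:
```python
# Hypothetical snippet: with the annotations below, mypy flags the
# commented-out call site for passing a str where a float is expected.
def scale(values: list[float], factor: float) -> list[float]:
    return [v * factor for v in values]

print(scale([1.0, 2.0], 3.0))   # OK: [3.0, 6.0]
# scale([1.0, 2.0], "3")        # mypy would flag: "str" is not "float"
```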
### C++ Unit Testing

View File

@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
});
}
template <typename func_t, typename vec_func_t, typename ident_t = double>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
template <typename func_t, typename vec_func_t>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
using traits = binary_function_traits<func_t>;
static_assert(
all_same<

View File

@ -339,13 +339,33 @@ void or_kernel_impl(TensorIterator& iter) {
}
}
template<typename scalar_t>
struct MinValuesOps: public at::native::MinOps<scalar_t> {
using arg_t = typename MinOps<scalar_t>::arg_t;
static scalar_t project(arg_t arg) {
return arg.first;
}
};
void min_values_kernel_impl(TensorIterator& iter) {
// This case is special because Vectorized<int64_t> does not
// handle upper_bound<int64_t>().
// See: https://github.com/pytorch/pytorch/issues/43254
if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce(
iter,
MinValuesOps<scalar_t>{},
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
}), kLong, kUInt64);
return;
}
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
upper_bound<scalar_t>());
static_cast<double>(upper_bound<scalar_t>()));
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
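For context on the `kLong`/`kUInt64` special case above, here is a minimal sketch (plain Python, assuming IEEE-754 doubles) of why the int64 identity value cannot be routed through a `double ident` parameter without corruption:
```python
# upper_bound<int64_t>() is INT64_MAX; the nearest representable double
# is 2**63, so passing the identity through a double shifts it by one.
INT64_MAX = 2**63 - 1
as_double = float(INT64_MAX)          # rounds up to 2**63
assert int(as_double) == 2**63        # INT64_MAX is not exactly representable
assert int(as_double) != INT64_MAX    # the identity value would be off by one
print(INT64_MAX, int(as_double))
```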

View File

@ -22,9 +22,6 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
@ -669,19 +666,12 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}

View File

@ -1,19 +0,0 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>
namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
at::Tensor& out);
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -1,462 +0,0 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/hip/HIPContext.h>
#include <ATen/Tensor.h>
#include <ATen/TensorAccessor.h>
#include <c10/hip/HIPStream.h>
#include <iostream>
#include <vector>
#include <optional>
#include <type_traits>
#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/tuple.hpp>
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
namespace at {
namespace hip {
namespace detail {
namespace CkTypes {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
}
template <typename ALayout, typename BLayout, typename DataType>
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
1, 1,
S<1,32,1,8>, 4
>;
template <typename ALayout, typename BLayout, typename DataType>
void launch_grouped_bgemm_ck_impl_dispatch(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
at::Tensor& out)
{
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
using PassThrough = CkTypes::PassThrough;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a_ptrs, p_b_ptrs;
std::vector<void*> p_e_ptrs;
// Note: d_ptrs will be resized after we populate the other vectors
const int mat_a_dim = mat_a.dim();
const int mat_b_dim = mat_b.dim();
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
const size_t a_element_size = mat_a.element_size();
const size_t b_element_size = mat_b.element_size();
const size_t out_element_size = out.element_size();
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
if (mat_a_dim == 2 && mat_b_dim == 2) {
// 2D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
const int M = mat_a.size(0); // number of rows in A
const int N = mat_b.size(1); // number of columns in B
const int K = mat_a.size(1); // columns in A == rows in B
// for 2d*2d input, output is 3d.
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
for (int i = 0; i < num_groups; ++i) {
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
int end_k = offs_accessor[i];
int k = end_k - start_k;
//K dimension are sliced, hence select stride(1) always.
//K dimension is always dimension 1, regardless of memory layout (row/column major)
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
const void* group_b_ptr;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
// Leading dimension = distance between rows = stride(0)
ldb = mat_b.stride(0);
} else {
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
// Leading dimension = distance between columns = stride(1)
ldb = mat_b.stride(1);
}
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
lda = mat_a.stride(0);
} else {
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
lda = mat_a.stride(1);
}
// Output is always row-major in 3D tensor [num_groups, M, N]
// Leading dimension for each group's [M,N] slice = stride(1) = N
ldc = out.stride(1);
size_t output_group_bytes = M * N * out_element_size;
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(k),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
// 2D*3D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 2d*3d input, output is 2d.
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
const int K = mat_a.size(1); // columns in A
// For 2D-3D case: The output determines N (result width)
const int N = out.size(1); // N is the width of the output tensor
for (int i = 0; i < num_groups; ++i) {
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
int end_m = offs_accessor[i];
int m = end_m - start_m;
// Skip zero-sized groups but continue processing subsequent groups
if (m <= 0) {
continue;
}
// Select A rows for group i: skip start_m rows
const void* group_a_ptr;
int lda;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
lda = mat_a.stride(0); // distance between rows
} else {
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
// Detect stride pattern for A tensor to determine appropriate lda calculation
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
if (a_is_strided_tensor) {
// For strided A tensors: stride(0) gives the actual leading dimension
lda = mat_a.stride(0);
} else {
// For non-strided A tensors: use the M dimension (total rows)
lda = mat_a.size(0); // Total M dimension for column-major layout
}
}
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
} else {
// Detect stride pattern to determine appropriate ldb calculation
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
if (is_strided_tensor) {
// For strided tensors: stride(2) gives the actual leading dimension
ldb = mat_b.stride(2);
} else {
// For non-strided tensors: use the N dimension
ldb = mat_b.size(1);
}
}
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
gemm_descs.push_back({
static_cast<ck::index_t>(m),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
// 3d*3d input, output is 3d - batched matrix multiplication
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
// Each batch is processed as a separate GEMM operation
const int batch_size = mat_a.size(0);
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
int N;
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
N = mat_b.size(2);
} else if (mat_b.size(2) == K) {
// B is [batch, n, k] - transposed layout
N = mat_b.size(1);
} else {
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
}
for (int i = 0; i < batch_size; ++i) {
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
// Select output batch for group i: Output[i, :, :]
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldb, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
} else {
// Column-major A: leading dimension = distance between columns = stride(2)
lda = mat_a.stride(2);
}
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B: leading dimension = distance between rows
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(1); // stride between K rows
} else {
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
}
} else {
// Column-major B: leading dimension = distance between columns
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(2); // stride between N columns
} else {
// B is [batch, n, k] - transposed layout
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
}
}
// Output is typically row-major: leading dimension = distance between rows = stride(1)
ldc = out.stride(1);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
// 3D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 3d*2d input, output is 3d.
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
const int batch_size = mat_a.size(0); // n_groups
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A
// For row-major A and B case: B should be [K, total_N]
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
for (int i = 0; i < num_groups; ++i) {
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
int end_n = offs_accessor[i];
int n = end_n - start_n;
// Skip zero-sized groups but continue processing subsequent groups
if (n <= 0) {
continue;
}
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
const void* group_b_ptr;
int ldb;
// Check if B is row-major or column-major
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(0); // distance between rows (should be total_N)
} else {
// Column-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(1); // distance between columns (should be K)
}
// Select output slice for group i: Output[:, start_n:end_n]
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
int lda, ldc;
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
// Output is row-major: leading dimension = distance between rows = stride(0)
ldc = out.stride(0);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(n),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
}
TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
// Initialize d_ptrs with the correct size
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
static DeviceOp gemm_instance;
auto argument = gemm_instance.MakeArgument(
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
);
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
"CK Group GEMM: argument unsupported (shape/strides/type config)");
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
void* gemm_arg_buf = nullptr;
void* ws_buf = nullptr;
hipMalloc(&gemm_arg_buf, arg_buf_size);
hipMalloc(&ws_buf, ws_size);
gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
auto invoker = gemm_instance.MakeInvoker();
hipStream_t stream = c10::hip::getCurrentHIPStream();
invoker.Run(argument, {stream});
hipFree(gemm_arg_buf);
hipFree(ws_buf);
}
void group_gemm_ck(
const at::Tensor& input_a,
const at::Tensor& input_b_colmajor,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& /*bias*/,
at::Tensor& out)
{
// Detect if input_a is row-major based on stride pattern
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
// Ensure tensor A is row-major and contiguous if not already
at::Tensor mat_a = input_a;
if (!a_row_major) {
// If A is not row-major, make it contiguous (row-major)
mat_a = input_a.contiguous();
}
// Force tensor B to be column-major using double transpose trick
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
at::Tensor mat_b = input_b_colmajor;
if (!b_col_major) {
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
}
// For 3D tensors, check the last dimension stride for row-major detection
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
if (mat_a.dtype() == at::kBFloat16) {
// bf16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kHalf) {
// fp16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kFloat) {
// fp32 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
}
}
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -47,10 +47,20 @@ Tensor sgd_out_of_place(
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
// testing Tensor strides + stride
STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);
aoti_torch_get_strides(param.get(), &param_strides);
auto out = new_empty(param, param.sizes());
int32_t param_dtype;
aoti_torch_get_dtype(param.get(), &param_dtype);
int32_t param_device_type;
aoti_torch_get_device_type(param.get(), &param_device_type);
AtenTensorHandle out_ath;
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
auto out = Tensor(out_ath);
sgd_math(
reinterpret_cast<float*>(param.data_ptr()),
@ -334,8 +344,6 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
// directly.
// This is to test that passing in a std::vector works for BC. (It gets
// implicitly converted to HeaderOnlyArrayRef too!)
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype);

View File

@ -5,16 +5,8 @@ import contextlib
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -50,22 +42,24 @@ class TestDTensorDebugMode(TestCase):
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)
with DebugMode(record_torchfunction=True) as debug_mode:
with DebugMode(
record_torchfunction=True, record_ids=True, record_output=True
) as debug_mode:
torch.mm(x_dtensor, y_dtensor).sum()
self.assertExpectedInline(
debug_mode.debug_string(),
"""\
torch.mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0))
aten::mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0))
torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0)
aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0))
redistribute_input(1, S(0) -> R)
redistribute_input(t: f32[1, 32], trace: S(0)->R)
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
_c10d_functional::wait_tensor(t: f32[8, 32])
aten::mm(t: f32[1, 8], t: f32[8, 32])
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 32]| S(0))
aten::sum(dt: f32[8, 32]| S(0))
aten::sum(t: f32[1, 32])""",
redistribute_input(t$2: f32[1, 32], trace: S(0)->R)
_c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32]
_c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32]
aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32]
<method 'sum' of 'torch._C.TensorBase' objects>(dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P
aten::sum(dt$6: f32[8, 32]| S(0))
aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""",
)
self.assertTrue(isinstance(debug_mode.operators[0], _OpCall))
@ -432,31 +426,6 @@ class TestDTensorDebugMode(TestCase):
][-1]
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
def test_pretty_print_dtensor_make_fx(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
A = torch.randn(8, 32)
B = torch.randn(32, 32)
dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()
def f(dA, dB):
dy = dA @ dB
loss = dy.sum()
loss.backward()
return dA.grad, dB.grad
# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode="fake")(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
# Colored is nice for actual viewing, not using in this test though
gm_str = gm.print_readable(colored=False, print_output=False)
self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@ -3,7 +3,8 @@
import itertools
import random
import unittest
from typing import Any, Callable, ClassVar, Optional
from collections.abc import Callable
from typing import Any, ClassVar, Optional
import torch
import torch.distributed as dist

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Callable, Sequence
from typing import Any, Union
from collections.abc import Sequence
from typing import Any, Callable, Union
import torch
import torch._dynamo

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import NamedTuple, Optional, TYPE_CHECKING
from typing import Callable, NamedTuple, Optional
import torch
import torch._dynamo
@ -7,10 +7,6 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
if TYPE_CHECKING:
from collections.abc import Callable
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo

View File

@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None

View File

@ -4,8 +4,9 @@
import functools
import unittest
from collections.abc import Callable
from contextlib import contextmanager, ExitStack
from typing import Any, Callable, Optional
from typing import Any, Optional
import torch
import torch._dynamo

View File

@ -13,7 +13,7 @@ from random import Random
from shutil import rmtree
from threading import Lock
from time import sleep, time
from typing import Any, Generator, Sequence, TYPE_CHECKING, Union
from typing import Any, TYPE_CHECKING, Union
from typing_extensions import TypeVar
from unittest.mock import patch
@ -37,6 +37,7 @@ from torch.testing._internal.common_utils import (
if TYPE_CHECKING:
from collections.abc import Generator, Sequence
from pathlib import Path

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: inductor"]
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
import torch
from torch._inductor.fx_passes.pre_grad import (

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: inductor"]
from typing import Callable
from collections.abc import Callable
import torch
from torch._dynamo.testing import rand_strided

View File

@ -204,7 +204,8 @@ import itertools
import operator
import unittest
import io
from typing import Callable, Optional
from typing import Optional
from collections.abc import Callable
class BinaryOp(torch.nn.Module):
def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar):

View File

@ -1,176 +0,0 @@
# Owner(s): ["oncall: pt2"]
from collections import deque
from typing import Optional
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Extract (sizes, strides) tuple from a tensor."""
return (tuple(t.size()), tuple(t.stride()))
def enumerate_reachable_states(
initial_size: int,
) -> set[tuple[tuple[int, ...], tuple[int, ...]]]:
"""
Use BFS with DP to enumerate all reachable (size, stride) states from
a 1D contiguous tensor via valid view operations.
We only explore states with offset=0 (you can retroactively change the offset).
We reject states with size=0 or size=1 dimensions as they are degenerate.
"""
# Create initial 1D contiguous tensor
initial_tensor = torch.arange(initial_size)
initial_state = get_state(initial_tensor)
# Map from state to tensor for that state
state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = {
initial_state: initial_tensor
}
visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state}
queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state])
while queue:
state = queue.popleft()
t = state_to_tensor[state]
sizes, strides = state
ndim = len(sizes)
def add_state(new_t: torch.Tensor) -> None:
new_state = get_state(new_t)
sizes, strides = new_state
# Skip if has size-0 or size-1 dimensions
if any(s == 0 or s == 1 for s in sizes):
return
# Only accept states where strides are in descending order
if list(strides) != sorted(strides, reverse=True):
return
if new_state not in visited:
visited.add(new_state)
queue.append(new_state)
state_to_tensor[new_state] = new_t
# 1. Unflatten: try factoring each dimension
for dim in range(ndim):
size = sizes[dim]
assert size > 1
# Try all factorizations x * y = size where both x, y >= 2
# We only need to check x up to size // 2 since when x > size // 2,
# y = size // x < 2, which we reject
for x in range(2, size // 2 + 1):
if size % x == 0:
y = size // x
add_state(t.unflatten(dim, (x, y)))
# 2. Slice: exhaustively check all possible slicing parameters
for dim in range(ndim):
size = sizes[dim]
for start in range(size):
for stop in range(start + 1, size + 1):
for step in range(1, size + 1):
slices = [slice(None)] * ndim
slices[dim] = slice(start, stop, step)
add_state(t[tuple(slices)])
# 3. Flatten: merge adjacent dimensions
for dim in range(ndim - 1):
add_state(t.flatten(dim, dim + 1))
return visited
class TestAsStrided(TestCase):
def test_size_10_exhaustive(self) -> None:
"""Test that size 10 produces exactly the expected 54 states."""
expected_states = {
((2,), (1,)),
((2,), (2,)),
((2,), (3,)),
((2,), (4,)),
((2,), (5,)),
((2,), (6,)),
((2,), (7,)),
((2,), (8,)),
((2,), (9,)),
((2, 2), (2, 1)),
((2, 2), (3, 1)),
((2, 2), (3, 2)),
((2, 2), (4, 1)),
((2, 2), (4, 2)),
((2, 2), (4, 3)),
((2, 2), (5, 1)),
((2, 2), (5, 2)),
((2, 2), (5, 3)),
((2, 2), (5, 4)),
((2, 2), (6, 1)),
((2, 2), (6, 2)),
((2, 2), (6, 3)),
((2, 2), (8, 1)),
((2, 2, 2), (4, 2, 1)),
((2, 2, 2), (5, 2, 1)),
((2, 3), (3, 1)),
((2, 3), (4, 1)),
((2, 3), (5, 1)),
((2, 3), (5, 2)),
((2, 3), (6, 1)),
((2, 4), (4, 1)),
((2, 4), (5, 1)),
((2, 5), (5, 1)),
((3,), (1,)),
((3,), (2,)),
((3,), (3,)),
((3,), (4,)),
((3, 2), (2, 1)),
((3, 2), (3, 1)),
((3, 2), (3, 2)),
((3, 2), (4, 1)),
((3, 3), (3, 1)),
((4,), (1,)),
((4,), (2,)),
((4,), (3,)),
((4, 2), (2, 1)),
((5,), (1,)),
((5,), (2,)),
((5, 2), (2, 1)),
((6,), (1,)),
((7,), (1,)),
((8,), (1,)),
((9,), (1,)),
((10,), (1,)),
}
actual_states = enumerate_reachable_states(10)
self.assertEqual(len(actual_states), 54)
self.assertEqual(actual_states, expected_states)
def test_subset_property(self) -> None:
"""
Test that for sizes 2..10, each smaller tensor results in a strict
subset of possible states compared to the next one.
"""
prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None
for size in range(2, 11):
current_states = enumerate_reachable_states(size)
if prev_states is not None:
# Check that prev_states is a strict subset of current_states
self.assertTrue(
prev_states.issubset(current_states),
f"States from size {size - 1} are not a subset of size {size}",
)
# Check that it's a strict subset (not equal)
self.assertTrue(
len(prev_states) < len(current_states),
f"States from size {size - 1} should be strictly fewer than size {size}",
)
prev_states = current_states
if __name__ == "__main__":
run_tests()

View File

@ -75,6 +75,12 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.jit_utils import JitTestCase
import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events
try:
from torchvision import models as torchvision_models
@ -201,6 +207,36 @@ def side_effect_func(x: torch.Tensor):
print(x)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
trace_file = f.name
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
class TestFX(JitTestCase):
def setUp(self):
super().setUp()
@ -4212,6 +4248,150 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
# recover mutable checking flag
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
augments profiler events with stack traces from FX metadata registry.
"""
# Simple test model
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(16, 10)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
model = TestModel().cuda()
# Compile the model
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"))
# Profile with the compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::t node=t stack_trace=x = self.linear1(x)
event=aten::transpose node=t stack_trace=x = self.linear1(x)
event=aten::as_strided node=t stack_trace=x = self.linear1(x)
event=aten::addmm node=addmm stack_trace=x = self.linear1(x)
event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x)
event=aten::relu node=relu stack_trace=x = self.relu(x)
event=aten::clamp_min node=relu stack_trace=x = self.relu(x)
event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x)
event=aten::t node=t_1 stack_trace=x = self.linear2(x)
event=aten::transpose node=t_1 stack_trace=x = self.linear2(x)
event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x)
event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x)
event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
Test that multiple compiled modules under the same profiler session
have their events correctly augmented with stack traces.
"""
class ModelA(torch.nn.Module):
def forward(self, x):
return x + 1
class ModelB(torch.nn.Module):
def forward(self, x):
return x - 1
model_a = ModelA().cuda()
model_b = ModelB().cuda()
# Compile both models
compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True)
compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_a(torch.randn(10, 10, device="cuda"))
_ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
# Profile both models in the same session
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result_a = compiled_a(torch.randn(10, 10, device="cuda"))
result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::add node=add stack_trace=return x + 1
event=cudaLaunchKernel node=add stack_trace=return x + 1
event=aten::sub node=sub stack_trace=return x - 1
event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""
Test that nested graph modules (e.g., graph modules calling subgraphs)
have their events correctly augmented with stack traces.
"""
# Model with nested structure
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 5
@torch.compiler.nested_compile_region
def forward(self, x, y):
m = torch.mul(x, y)
s = m.sin()
a = s + self.c
return a
model = Mod().cuda()
# Compile the model (this may create nested graph modules)
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
# Profile
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::mul node=mul stack_trace=m = torch.mul(x, y)
event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y)
event=aten::sin node=sin stack_trace=s = m.sin()
event=cudaLaunchKernel node=sin stack_trace=s = m.sin()
event=aten::add node=add stack_trace=a = s + self.c
event=cudaLaunchKernel node=add stack_trace=a = s + self.c"""
)
def run_getitem_target():
from torch.fx._symbolic_trace import _wrapped_methods_to_patch

View File

@ -5,7 +5,7 @@ import time
import unittest
from itertools import product
from functools import partial
from typing import Callable
from collections.abc import Callable
import torch
@ -490,6 +490,8 @@ class TestMatmulCuda(InductorTestCase):
@parametrize("b_row_major", [False, True])
@dtypes(torch.bfloat16, torch.float32, torch.float16)
def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
self.skipTest("failed using hipblaslt on rocm 6.4.2")
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4

View File

@ -1,5 +1,5 @@
from typing import TypeAlias, Union
from typing_extensions import assert_type
from typing import Union
from typing_extensions import assert_type, TypeAlias
from torch import randn, Tensor

View File

@ -313,4 +313,12 @@ if __name__ == "__main__":
remove_nan_inf(test_cases),
)
# Part of an experiment to see if we can handle all the data as is
upload_workflow_stats_to_s3(
args.workflow_run_id,
args.workflow_run_attempt,
"all_test_runs",
remove_nan_inf(test_cases),
)
upload_additional_info(args.workflow_run_id, args.workflow_run_attempt, test_cases)

View File

@ -1,9 +1,8 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from typing import Any, Optional, overload, Union
from typing import Any, Callable, Optional, overload, Union
import torch
from torch import Tensor

View File

@ -39,10 +39,11 @@ import types
import unittest
import warnings
import weakref
from collections.abc import Sized
from dataclasses import dataclass
from enum import Enum
from os.path import dirname, join
from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union
from unittest.mock import patch
import sympy

View File

@ -1,5 +1,6 @@
import weakref
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
from torch._dynamo.source import Source

View File

@ -3320,7 +3320,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
obj.call_method(self, "add", [v], {})
def SET_UPDATE(self, inst: Instruction) -> None:
v = self.pop()
@ -3329,7 +3329,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
obj.call_method(self, "update", [v], {})
def LIST_APPEND(self, inst: Instruction) -> None:
v = self.pop()
@ -3637,7 +3637,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg].realize()
assert isinstance(obj, ConstDictVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
obj.call_method(self, "update", [v], {})
DICT_UPDATE = DICT_MERGE

View File

@ -21,9 +21,9 @@ restoring state changes.
import inspect
import sys
import warnings
from collections.abc import Callable, Sequence, Sized
from collections.abc import Callable, Sequence
from contextlib import ExitStack
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union
import torch._C
from torch._guards import Guard

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
"""
Dictionary-related variable tracking classes for PyTorch Dynamo.
@ -24,7 +26,7 @@ import inspect
import operator
import types
from collections.abc import Hashable as py_Hashable
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING
from torch._subclasses.fake_tensor import is_fake
@ -57,13 +59,11 @@ if TYPE_CHECKING:
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def was_instancecheck_override(obj: Any) -> bool:
def was_instancecheck_override(obj):
return type(obj).__dict__.get("__instancecheck__", False)
def raise_unhashable(
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
) -> None:
def raise_unhashable(arg, tx=None):
if tx is None:
from torch._dynamo.symbolic_convert import InstructionTranslator
@ -75,7 +75,7 @@ def raise_unhashable(
)
def is_hashable(x: VariableTracker) -> bool:
def is_hashable(x):
# NB - performing isinstance check on a LazVT realizes the VT, accidentally
# inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at
# the underlying value without realizing the VT. Consider updating the
@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
"""
def __init__(self, vt: VariableTracker) -> None:
def __init__(self, vt) -> None:
# We specialize SymNodes
vt = specialize_symnode(vt)
# TODO Temporarily remove to figure out what keys are we breaking on
@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
self.vt = vt
@property
def underlying_value(self) -> Any:
def underlying_value(self):
if (
isinstance(self.vt, variables.LazyVariableTracker)
and not self.vt.is_realized()
@ -178,8 +178,7 @@ class ConstDictVariable(VariableTracker):
elif isinstance(self.vt, variables.FrozenDataClassVariable):
Hashable = ConstDictVariable._HashableTracker
fields_values = {
k: Hashable(v).underlying_value
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
}
return variables.FrozenDataClassVariable.HashWrapper(
self.vt.python_type(), fields_values
@ -188,16 +187,16 @@ class ConstDictVariable(VariableTracker):
# The re module in Python 3.13+ has a dictionary (_cache2) with
# an object as key (`class _ZeroSentinel(int): ...`):
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
return self.vt.value # type: ignore[attr-defined,union-attr]
return self.vt.value
else:
x = self.vt.as_python_constant()
return x
def __hash__(self) -> int:
def __hash__(self):
return hash(self.underlying_value)
@staticmethod
def _eq_impl(a: Any, b: Any) -> bool:
def _eq_impl(a, b):
# TODO: Put this in utils and share it between variables/builtin.py and here
type_a, type_b = type(a), type(b)
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
@ -213,7 +212,7 @@ class ConstDictVariable(VariableTracker):
else:
return a == b
def __eq__(self, other: object) -> bool:
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
Hashable = ConstDictVariable._HashableTracker
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
type(other)
@ -227,8 +226,8 @@ class ConstDictVariable(VariableTracker):
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls: type = dict,
**kwargs: Any,
user_cls=dict,
**kwargs,
) -> None:
# .clone() pass these arguments in kwargs but they're recreated a few
# lines below
@ -248,22 +247,18 @@ class ConstDictVariable(VariableTracker):
for x, v in items.items()
)
def make_hashable(
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
) -> "ConstDictVariable._HashableTracker":
def make_hashable(key):
return key if isinstance(key, Hashable) else Hashable(key)
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
# need to reconstruct everything if the dictionary is an intermediate value
# or if a pop/delitem was executed
self.should_reconstruct_all = (
not is_from_local_source(self.source) if self.source else True
)
self.should_reconstruct_all = not is_from_local_source(self.source)
self.original_items = items.copy()
self.user_cls = user_cls
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
def _get_dict_cls_from_user_cls(self, user_cls):
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
# avoid executing user code if user_cls is a dict subclass
@ -282,10 +277,10 @@ class ConstDictVariable(VariableTracker):
dict_cls = dict
return dict_cls
def as_proxy(self) -> dict[Any, Any]:
def as_proxy(self):
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
def debug_repr(self) -> str:
def debug_repr(self):
return (
"{"
+ ", ".join(
@ -294,20 +289,20 @@ class ConstDictVariable(VariableTracker):
+ "}"
)
def as_python_constant(self) -> dict[Any, Any]:
def as_python_constant(self):
return {
k.vt.as_python_constant(): v.as_python_constant()
for k, v in self.items.items()
}
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
def keys_as_python_constant(self):
self.install_dict_keys_match_guard()
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
def python_type(self) -> type:
def python_type(self):
return self.user_cls
def __contains__(self, vt: VariableTracker) -> bool:
def __contains__(self, vt) -> bool:
assert isinstance(vt, VariableTracker)
Hashable = ConstDictVariable._HashableTracker
return (
@ -327,15 +322,13 @@ class ConstDictVariable(VariableTracker):
for key, value in self.items.items()
)
def is_new_item(
self, value: Optional[VariableTracker], other: VariableTracker
) -> bool:
def is_new_item(self, value, other):
# compare the id of the realized values if both values are not lazy VTs
if value and value.is_realized() and other.is_realized():
return id(value.realize()) != id(other.realize())
return id(value) != id(other)
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
def reconstruct_kvs_into_new_dict(self, codegen):
# Build a dictionary that contains the keys and values.
num_args = 0
for key, value in self.items.items():
@ -347,7 +340,7 @@ class ConstDictVariable(VariableTracker):
num_args += 1
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
if self.user_cls is collections.OrderedDict:
# emit `OrderedDict(constructed_dict)`
codegen.add_push_null(
@ -365,21 +358,19 @@ class ConstDictVariable(VariableTracker):
def getitem_const_raise_exception_if_absent(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
raise_observed_exception(KeyError, tx)
return self.items[key]
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
msg = f"Dictionary key {arg.value} not found during tracing"
unimplemented_v2(
gb_type="key not found in dict",
context=f"Key {arg.value}", # type: ignore[attr-defined]
context=f"Key {arg.value}",
explanation=msg,
hints=[
"Check if the key exists in the dictionary before accessing it.",
@ -388,13 +379,13 @@ class ConstDictVariable(VariableTracker):
)
return self.items[key]
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
def maybe_getitem_const(self, arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
return None
return self.items[key]
def realize_key_vt(self, arg: VariableTracker) -> None:
def realize_key_vt(self, arg: VariableTracker):
# Realize the LazyVT on a particular index
assert arg in self
key = ConstDictVariable._HashableTracker(arg)
@ -403,13 +394,11 @@ class ConstDictVariable(VariableTracker):
if isinstance(original_key_vt, variables.LazyVariableTracker):
original_key_vt.realize()
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
if self.source:
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
# Key guarding - These are the cases to consider
# 1) The dict has been mutated. In this case, we would have already
# inserted a DICT_KEYS_MATCH guard, so we can skip.
@ -450,11 +439,11 @@ class ConstDictVariable(VariableTracker):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
# we have to insert guards when a dict method is accessed. For this to
# be simple, we are conservative and overguard. We skip guard only for
@ -473,7 +462,7 @@ class ConstDictVariable(VariableTracker):
tx, *args, **kwargs
)
tx.output.side_effects.mutation(self)
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
self.items.update(temp_dict_vt.items)
return ConstantVariable.create(None)
elif name == "__getitem__":
# Key guarding - Nothing to do. LazyVT for value will take care.
@ -537,7 +526,7 @@ class ConstDictVariable(VariableTracker):
return ConstantVariable.create(len(self.items))
elif name == "__setitem__" and self.is_mutable():
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
self.install_dict_keys_match_guard()
if kwargs or len(args) != 2:
@ -561,7 +550,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
if args[0] not in self:
self.install_dict_contains_guard(tx, args)
@ -576,7 +565,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
if args[0] not in self:
# missing item, return the default value. Install no DICT_CONTAINS guard.
@ -610,7 +599,7 @@ class ConstDictVariable(VariableTracker):
last = v.value
else:
raise_args_mismatch(tx, name)
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
k, v = self.items.popitem(last=last)
else:
k, v = self.items.popitem()
@ -643,17 +632,17 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard()
dict_vt: ConstDictVariable = args[0]
dict_vt = args[0]
else:
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
self.items.update(dict_vt.items) # type: ignore[attr-defined]
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
self.items.update(dict_vt.items)
if has_kwargs:
# Handle kwargs
kwargs_hashable = {
kwargs = {
Hashable(ConstantVariable.create(k)): v
for k, v in kwargs.items()
}
self.items.update(kwargs_hashable)
self.items.update(kwargs)
return ConstantVariable.create(None)
else:
return super().call_method(tx, name, args, kwargs)
@ -667,7 +656,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
self.install_dict_contains_guard(tx, args)
contains = args[0] in self
@ -682,7 +671,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
self.install_dict_keys_match_guard()
if kwargs or len(args) > 2:
@ -718,7 +707,7 @@ class ConstDictVariable(VariableTracker):
and "last" in kwargs
and isinstance(kwargs["last"], ConstantVariable)
):
last = kwargs.get("last").value # type: ignore[union-attr]
last = kwargs.get("last").value
key = Hashable(args[0])
self.items.move_to_end(key, last=last)
@ -734,7 +723,7 @@ class ConstDictVariable(VariableTracker):
)
elif name == "__ne__":
return ConstantVariable.create(
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
not self.call_method(tx, "__eq__", args, kwargs).value
)
elif name == "__or__":
if len(args) != 1:
@ -761,14 +750,14 @@ class ConstDictVariable(VariableTracker):
if not istype(
other, (ConstDictVariable, variables.UserDefinedDictVariable)
):
err_msg = (
msg = (
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
f"and '{other.python_type().__name__}'"
)
raise_observed_exception(TypeError, tx, args=[err_msg])
raise_observed_exception(TypeError, tx, args=[msg])
# OrderedDict overloads __ror__
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
ts = {self.user_cls, other.user_cls}
user_cls = (
collections.OrderedDict
if any(issubclass(t, collections.OrderedDict) for t in ts)
@ -785,8 +774,8 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
args[0].install_dict_keys_match_guard()
new_dict_vt.items.update(args[0].items)
return new_dict_vt
elif name == "__ior__":
self.call_method(tx, "update", args, kwargs)
@ -800,13 +789,11 @@ class ConstDictVariable(VariableTracker):
else:
return super().call_method(tx, name, args, kwargs)
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
self.install_dict_keys_match_guard()
return [x.vt for x in self.items.keys()]
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
def call_obj_hasattr(self, tx, name):
# dict does not allow setting arbitrary attributes. OrderedDict and
# defaultdict allow arbitrary setattr, but not deletion of default attrs
if any(
@ -829,25 +816,25 @@ class ConstDictVariable(VariableTracker):
],
)
def clone(self, **kwargs: Any) -> VariableTracker:
def clone(self, **kwargs):
self.install_dict_keys_match_guard()
return super().clone(**kwargs)
class MappingProxyVariable(VariableTracker):
# proxies to the original dict_vt
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
super().__init__(**kwargs)
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
def python_type(self) -> type:
def python_type(self):
return types.MappingProxyType
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
return self.dv_dict.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
# load types.MappingProxyType
if self.source:
msg = (
@ -876,11 +863,11 @@ class MappingProxyVariable(VariableTracker):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if self.source and tx.output.side_effects.has_existing_dict_mutation():
msg = (
"A dict has been modified while we have an existing mappingproxy object. "
@ -905,7 +892,7 @@ class MappingProxyVariable(VariableTracker):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
) -> "VariableTracker":
if self.python_type() is types.MappingProxyType:
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
return super().call_obj_hasattr(tx, name)
@ -913,44 +900,35 @@ class MappingProxyVariable(VariableTracker):
class NNModuleHooksDictVariable(ConstDictVariable):
# Special class to avoid adding any guards on the nn module hook ids.
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
pass
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
pass
class DefaultDictVariable(ConstDictVariable):
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls: type,
default_factory: Optional[VariableTracker] = None,
**kwargs: Any,
) -> None:
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
super().__init__(items, user_cls, **kwargs)
assert user_cls is collections.defaultdict
if default_factory is None:
default_factory = ConstantVariable.create(None)
self.default_factory = default_factory
def is_python_constant(self) -> bool:
def is_python_constant(self):
# Return false for unsupported defaults. This ensures that a bad handler
# path is not taken in BuiltinVariable for getitem.
if self.default_factory not in [list, tuple, dict] and not self.items:
return False
return super().is_python_constant()
def debug_repr(self) -> str:
assert self.default_factory is not None
def debug_repr(self):
return (
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
)
@staticmethod
def is_supported_arg(arg: VariableTracker) -> bool:
def is_supported_arg(arg):
if isinstance(arg, variables.BuiltinVariable):
return arg.fn in (list, tuple, dict, set)
else:
@ -964,11 +942,11 @@ class DefaultDictVariable(ConstDictVariable):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__getitem__":
if len(args) != 1:
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
@ -984,13 +962,13 @@ class DefaultDictVariable(ConstDictVariable):
else:
default_var = self.default_factory.call_function(tx, [], {})
super().call_method(
tx, "__setitem__", [args[0], default_var], kwargs
tx, "__setitem__", (args[0], default_var), kwargs
)
return default_var
else:
return super().call_method(tx, name, args, kwargs)
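For reference, a plain-Python sketch of the defaultdict behavior this __getitem__ handling models (the key name and factory below are illustrative): a missing key invokes default_factory and the produced value is stored back into the dict before being returned.

import collections

d = collections.defaultdict(list)
d["missing"].append(1)  # the factory runs on the miss; the new list is stored under "missing"
assert d == {"missing": [1]}
assert d["missing"] is d["missing"]  # later lookups return the stored value; no new factory call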
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen):
# emit `defaultdict(default_factory, new_dict)`
codegen.add_push_null(
lambda: codegen.extend_output(
@ -1016,48 +994,40 @@ class SetVariable(ConstDictVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs: Any,
**kwargs,
) -> None:
# pyrefly: ignore[bad-assignment]
items = dict.fromkeys(items, SetVariable._default_value())
# pyrefly: ignore[bad-argument-type]
super().__init__(items, **kwargs)
def debug_repr(self) -> str:
def debug_repr(self):
if not self.items:
return "set()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
def set_items(self):
return set(self.items.keys())
@staticmethod
def _default_value() -> VariableTracker:
def _default_value():
# Variable to fill in the keys of the dictionary
return ConstantVariable.create(None)
def as_proxy(self) -> Any:
def as_proxy(self):
return {k.vt.as_proxy() for k in self.set_items}
def python_type(self) -> type:
def python_type(self):
return set
def as_python_constant(self) -> Any:
def as_python_constant(self):
return {k.vt.as_python_constant() for k in self.set_items}
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
codegen.foreach([x.vt for x in self.set_items])
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
def _fast_set_method(
self,
tx: "InstructionTranslator",
fn: Any,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
def _fast_set_method(self, tx, fn, args, kwargs):
try:
res = fn(
*[x.as_python_constant() for x in [self, *args]],
@ -1067,16 +1037,15 @@ class SetVariable(ConstDictVariable):
raise_observed_exception(
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
)
# pyrefly: ignore[unbound-name]
return VariableTracker.build(tx, res)
def call_method(
self,
tx: "InstructionTranslator",
name: str,
tx,
name,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
) -> "VariableTracker":
# We forward the calls to the dictionary model
from ..utils import check_constant_args
@ -1096,10 +1065,10 @@ class SetVariable(ConstDictVariable):
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
if name == "__init__":
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
tx.output.side_effects.mutation(self)
self.items.clear()
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
self.items.update(temp_set_vt.items)
return ConstantVariable.create(None)
elif name == "add":
if kwargs or len(args) != 1:
@ -1110,7 +1079,7 @@ class SetVariable(ConstDictVariable):
f"{len(args)} args and {len(kwargs)} kwargs",
)
name = "__setitem__"
args = [args[0], SetVariable._default_value()]
args = (args[0], SetVariable._default_value())
elif name == "pop":
if kwargs or args:
raise_args_mismatch(
@ -1121,14 +1090,12 @@ class SetVariable(ConstDictVariable):
)
# Choose an item at random and pop it via the Dict.pop method
try:
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
result = self.set_items.pop().vt
except KeyError as e:
raise_observed_exception(
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
)
# pyrefly: ignore[unbound-name]
super().call_method(tx, name, [result], kwargs)
# pyrefly: ignore[unbound-name]
super().call_method(tx, name, (result,), kwargs)
return result
elif name == "isdisjoint":
if kwargs or len(args) != 1:
@ -1250,7 +1217,6 @@ class SetVariable(ConstDictVariable):
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
assert m is not None
return self.call_method(tx, m, args, kwargs)
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
@ -1264,34 +1230,29 @@ class SetVariable(ConstDictVariable):
"__ixor__": "symmetric_difference_update",
"__isub__": "difference_update",
}.get(name)
assert m is not None
self.call_method(tx, m, args, kwargs)
return self
elif name == "__eq__":
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(False)
r = self.call_method(tx, "symmetric_difference", args, kwargs)
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
return ConstantVariable.create(len(r.set_items) == 0)
elif name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
)
return super().call_method(tx, name, args, kwargs)
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
raise RuntimeError("Illegal to getitem on a set")
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
super().install_dict_contains_guard(tx, args)
@ -1299,27 +1260,27 @@ class FrozensetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs: Any,
**kwargs,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self) -> str:
def debug_repr(self):
if not self.items:
return "frozenset()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
def set_items(self):
return self.items.keys()
def python_type(self) -> type:
def python_type(self):
return frozenset
def as_python_constant(self) -> Any:
def as_python_constant(self):
return frozenset({k.vt.as_python_constant() for k in self.set_items})
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
codegen.foreach([x.vt for x in self.set_items])
codegen.add_push_null(
lambda: codegen.extend_output(
@ -1332,11 +1293,11 @@ class FrozensetVariable(SetVariable):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
tx,
name,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
) -> "VariableTracker":
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
elif name == "__init__":
@ -1355,7 +1316,7 @@ class FrozensetVariable(SetVariable):
"symmetric_difference",
):
r = super().call_method(tx, name, args, kwargs)
return FrozensetVariable(r.items) # type: ignore[attr-defined]
return FrozensetVariable(r.items)
return super().call_method(tx, name, args, kwargs)
@ -1363,11 +1324,11 @@ class DictKeySetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs: Any,
**kwargs,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self) -> str:
def debug_repr(self):
if not self.items:
return "dict_keys([])"
else:
@ -1377,35 +1338,33 @@ class DictKeySetVariable(SetVariable):
+ "])"
)
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
# Already EQUALS_MATCH guarded
pass
@property
def set_items(self) -> Any:
def set_items(self):
return self.items
def python_type(self) -> type:
def python_type(self):
return dict_keys
def as_python_constant(self) -> Any:
def as_python_constant(self):
return dict.fromkeys(
{k.vt.as_python_constant() for k in self.set_items}, None
).keys()
def call_method(
self,
tx: "InstructionTranslator",
name: str,
tx,
name,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
) -> "VariableTracker":
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
return super().call_method(tx, name, args, kwargs)
@ -1420,47 +1379,42 @@ class DictViewVariable(VariableTracker):
kv: Optional[str] = None
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
super().__init__(**kwargs)
assert self.kv in ("keys", "values", "items")
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
@property
def view_items(self) -> Any:
assert self.kv is not None
def view_items(self):
return getattr(self.dv_dict.items, self.kv)()
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
# Returns an iterable of the unpacked items
# Implement in the subclasses
raise NotImplementedError
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
return self.view_items_vt
def reconstruct(self, codegen: "PyCodegen") -> None:
assert self.kv is not None
def reconstruct(self, codegen: "PyCodegen"):
codegen(self.dv_dict)
codegen.load_method(self.kv)
codegen.call_method(0)
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
assert self.kv is not None
def call_obj_hasattr(self, tx, name):
if name in self.python_type().__dict__:
return ConstantVariable.create(True)
return ConstantVariable.create(False)
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if name == "__len__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name == "__iter__":
@ -1474,24 +1428,24 @@ class DictKeysVariable(DictViewVariable):
kv = "keys"
@property
def set_items(self) -> set[VariableTracker]:
def set_items(self):
return set(self.view_items)
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
# Returns an iterable of the unpacked items
return [x.vt for x in self.view_items]
def python_type(self) -> type:
def python_type(self):
return dict_keys
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if name == "__contains__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name in (
@ -1506,13 +1460,13 @@ class DictKeysVariable(DictViewVariable):
):
# These methods always returns a set
m = getattr(self.set_items, name)
r = m(args[0].set_items) # type: ignore[attr-defined]
r = m(args[0].set_items)
return SetVariable(r)
if name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
)
return super().call_method(tx, name, args, kwargs)
@ -1522,10 +1476,10 @@ class DictValuesVariable(DictViewVariable):
kv = "values"
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
return list(self.view_items)
def python_type(self) -> type:
def python_type(self):
return dict_values
@ -1533,20 +1487,14 @@ class DictItemsVariable(DictViewVariable):
kv = "items"
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
# Returns an iterable of the unpacked items
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
def python_type(self) -> type:
def python_type(self):
return dict_items
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
def call_method(self, tx, name, args, kwargs):
# TODO(guilhermeleobas): This should actually check if args[0]
# implements the mapping protocol.
if name == "__eq__":

View File

@ -20,7 +20,8 @@ checks and proper tracking of distributed state and operations across processes.
import functools
import inspect
from typing import Any, Sequence, TYPE_CHECKING
from collections.abc import Sequence
from typing import Any, TYPE_CHECKING
import torch
from torch.fx.experimental._backward_state import BackwardState

View File

@ -14,8 +14,8 @@ handling of iterator operations during code transformation and optimization.
"""
import itertools
from collections.abc import Callable
from typing import Any, Sequence, TYPE_CHECKING, Union
from collections.abc import Callable, Sequence
from typing import Any, TYPE_CHECKING, Union
from .. import graph_break_hints, polyfills, variables
from ..bytecode_transformation import (

View File

@ -22,7 +22,8 @@ optimizer-specific optimizations and safety guarantees.
import logging
import weakref
from typing import Any, Iterable, Optional, TYPE_CHECKING
from collections.abc import Iterable
from typing import Any, Optional, TYPE_CHECKING
import torch
from torch._dynamo.variables.tensor import TensorVariable

View File

@ -19,8 +19,8 @@ by limiting operations to known-safe patterns and failing fast for unsafe usage.
"""
import functools
from collections.abc import Callable
from typing import Any, Iterable, TYPE_CHECKING, TypeVar
from collections.abc import Callable, Iterable
from typing import Any, TYPE_CHECKING, TypeVar
from typing_extensions import ParamSpec
import torch

View File

@ -1,5 +1,6 @@
from collections.abc import Sequence
from inspect import getattr_static
from typing import Any, Sequence, TYPE_CHECKING, TypeGuard
from typing import Any, TYPE_CHECKING, TypeGuard
from torch._guards import Source
from torch.backends.cuda import SDPAParams

View File

@ -1,5 +1,6 @@
import collections
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any, Optional
import torch
from torch._dynamo.variables.dicts import ConstDictVariable

View File

@ -29,9 +29,9 @@ import contextlib
import functools
import inspect
import operator
from collections.abc import Sequence
from collections.abc import Generator, Iterable, Sequence
from types import TracebackType
from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import torch._C
import torch.utils._pytree as pytree

View File

@ -22,9 +22,10 @@ from __future__ import annotations
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import copy
from dataclasses import dataclass
from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar
import torch
from torch._dynamo.precompile_context import BackendCacheArtifact

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import hashlib
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
import sympy # noqa: TC002
@ -17,8 +17,6 @@ from .simd import SIMDKernel, SIMDScheduling
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ..ir import IRNode
from ..scheduler import BaseSchedulerNode

View File

@ -627,7 +627,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
filename=__file__,
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -1,6 +1,7 @@
from collections.abc import Callable
from threading import Lock, Thread
from time import monotonic, sleep
from typing import Callable, Optional, Union
from typing import Optional, Union
class Timer:

View File

@ -2,7 +2,8 @@ import collections
import logging
import operator
from collections import defaultdict
from typing import Any, Callable, Literal, TypeAlias
from collections.abc import Callable
from typing import Any, Literal, TypeAlias
import torch
import torch.distributed as dist

View File

@ -4,10 +4,10 @@ import inspect
import logging
import math
import operator
from collections.abc import Generator
from collections.abc import Callable, Generator
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, cast
from typing import Any, cast
import torch
import torch.fx as fx

View File

@ -1,5 +1,5 @@
import logging
from typing import Callable
from collections.abc import Callable
import torch
from torch._inductor.fx_passes.bucketing import (

View File

@ -1,8 +1,8 @@
import itertools
import logging
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from typing import Callable
import torch
import torch.fx as fx

View File

@ -2,7 +2,7 @@
import functools
import operator
from functools import reduce
from typing import Any, Callable
from typing import Any, TYPE_CHECKING
import torch
from torch._dynamo.utils import counters
@ -35,6 +35,10 @@ from .quantization import (
)
if TYPE_CHECKING:
from collections.abc import Callable
if torch._C._has_mkldnn:
aten = torch.ops.aten
mkldnn = torch.ops.mkldnn

View File

@ -4,9 +4,9 @@ import itertools
import logging
import sys
from collections import Counter, defaultdict
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any, Callable
from typing import Any
import torch
import torch.fx as fx

View File

@ -2,8 +2,8 @@ import functools
import itertools
import operator
import typing
from collections.abc import Sequence
from typing import Any, Callable
from collections.abc import Callable, Sequence
from typing import Any
import torch
import torch._inductor.runtime.runtime_utils

View File

@ -5,7 +5,8 @@ import itertools
import logging
import operator
from collections import Counter, defaultdict
from typing import Any, Callable, TypeVar
from collections.abc import Callable
from typing import Any, TypeVar
from typing_extensions import ParamSpec
import torch

View File

@ -3,10 +3,10 @@ import itertools
import logging
import operator
from collections import defaultdict
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, cast
from typing import Any, cast
import torch
import torch.fx.node

View File

@ -4,9 +4,8 @@ import logging
import operator
import os
from collections import defaultdict
from collections.abc import Sequence
from typing import Any, Callable
from typing_extensions import TypeAlias
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias
import torch
from torch._dynamo.utils import counters

View File

@ -2,7 +2,8 @@
import functools
import logging
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any, Optional, Union
import torch
from torch._inductor.codegen.subgraph import SubgraphTemplate

View File

@ -3,8 +3,9 @@
import functools
import importlib
from collections.abc import Callable, Sequence
from contextlib import contextmanager
from typing import Any, Callable, Optional, Sequence
from typing import Any, Optional
import sympy
from sympy import Expr, Integer

View File

@ -5,8 +5,8 @@ from collections.abc import Callable
from functools import cached_property, wraps
from itertools import chain
from statistics import median
from typing import Any, Optional, Union
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
from typing import Any, Concatenate, Optional, Union
from typing_extensions import ParamSpec, Self, TypeVar
import torch
import torch.utils._pytree as pytree

View File

@ -1,6 +1,6 @@
import os
from collections.abc import Callable
from functools import cache, partial
from typing import Callable
import torch
from torch._environment import is_fbcode

View File

@ -12,8 +12,8 @@ from os import PathLike
from pathlib import Path
from threading import Lock
from time import time
from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import override, TypeAlias
from typing import Any, TYPE_CHECKING, TypeAlias
from typing_extensions import override
from filelock import FileLock
@ -21,6 +21,8 @@ from . import config, context, exceptions, implementations as impls, locks
if TYPE_CHECKING:
from collections.abc import Callable
from .utils import P, R

View File

@ -12,8 +12,8 @@ The module offers both context manager and manual acquisition patterns:
from __future__ import annotations
from contextlib import _GeneratorContextManager, contextmanager, ExitStack
from typing import Generator, TYPE_CHECKING
from typing_extensions import Protocol, TypeAlias
from typing import TYPE_CHECKING, TypeAlias
from typing_extensions import Protocol
from filelock import FileLock, Timeout
@ -21,6 +21,7 @@ from . import exceptions, implementations as impls
if TYPE_CHECKING:
from collections.abc import Generator
from threading import Lock

View File

@ -3586,24 +3586,13 @@ def user_autotune(
)
def foreach(triton_meta, filename=None, inductor_meta=None):
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
configs,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,

View File

@ -52,7 +52,26 @@ __all__ = [
"MemRecordsAcc",
]
from contextlib import ContextDecorator
try:
# Available in Python >= 3.2
from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
import functools
class _ContextDecorator: # type: ignore[no-redef]
def __enter__(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError
def __call__(self, func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
with self:
return func(*args, **kwargs)
return wrapped
# global python state - whether profiler is currently enabled
@ -725,7 +744,8 @@ class profile:
return all_function_events
class record_function(ContextDecorator):
# pyrefly: ignore [invalid-inheritance]
class record_function(_ContextDecorator):
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
Label will only appear if CPU activity tracing is enabled.
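As a usage sketch (the labels and toy workload are illustrative, not part of this change), the context manager/decorator is typically driven under the profiler like this:

import torch
from torch.profiler import ProfilerActivity, profile, record_function

x = torch.randn(64, 64)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("my_matmul_block"):  # label shows up in the trace
        y = x @ x

    @record_function("my_relu_fn")  # instances also work as decorators
    def relu_twice(t):
        return t.relu().relu()

    relu_twice(y)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))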

View File

@ -1224,3 +1224,43 @@ def _build_table(
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by start time for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)

View File

@ -4,7 +4,6 @@
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>
@ -14,7 +13,6 @@
HIDDEN_NAMESPACE_BEGIN(torch, stable)
using accelerator::DeviceIndex;
using torch::headeronly::IntHeaderOnlyArrayRef;
using torch::headeronly::ScalarType;
// The torch::stable::Tensor class is a highlevel C++ wrapper around
@ -95,32 +93,6 @@ class Tensor {
return numel;
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the dimension sizes of
// a Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef sizes() const {
int64_t* sizes;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
return IntHeaderOnlyArrayRef(sizes, dim());
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the strides of a
// Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef strides() const {
int64_t* strides;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
return IntHeaderOnlyArrayRef(strides, dim());
}
// note: this is a subset of the original TensorBase API. It takes no
// arguments whereas the original API takes in a kwarg of memory format.
// Here, we assume the default contiguous memory format.

View File

@ -1,8 +1,9 @@
import functools
import math
import operator
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from datetime import timedelta
from typing import Callable
import torch
from torch._C import ScriptObject

View File

@ -10,9 +10,10 @@
import logging
import os
import time
from collections.abc import Callable
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event
from typing import Callable, Optional, TextIO, TYPE_CHECKING, Union
from typing import Optional, TextIO, TYPE_CHECKING, Union
if TYPE_CHECKING:

View File

@ -443,6 +443,7 @@ class CodeGen:
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@ -647,15 +648,6 @@ class CodeGen:
if verbose:
# override annotation with more detailed information
try:
from torch.distributed.tensor._api import DTensor, DTensorSpec
dtensorspec_format_shard_order_str = (
DTensorSpec.format_shard_order_str
)
except ModuleNotFoundError:
DTensor = None # type: ignore[assignment,misc]
dtensorspec_format_shard_order_str = None
from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.passes.shape_prop import TensorMetadata
@ -686,16 +678,6 @@ class CodeGen:
core = _tensor_annotation(meta_val)
if is_plain:
maybe_type_annotation = f': "{core}"'
elif type(meta_val) is DTensor:
assert dtensorspec_format_shard_order_str is not None
dtensor_meta = dtensorspec_format_shard_order_str(
meta_val._spec.placements, # type: ignore[attr-defined]
meta_val._spec.shard_order, # type: ignore[attr-defined]
)
cls = meta_val.__class__.__name__
maybe_type_annotation = (
f': "{cls}({core}, {dim_green(dtensor_meta)})"'
)
else:
cls = meta_val.__class__.__name__
maybe_type_annotation = f': "{cls}({core})"'
@ -817,6 +799,10 @@ class CodeGen:
return
raise NotImplementedError(f"node: {node.op} {node.target}")
if record_func:
body.append(
"_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n"
)
for i, node in enumerate(nodes):
# NOTE: emit_node does not emit a string with newline. It depends
# on delete_unused_values to append one
@ -826,8 +812,22 @@ class CodeGen:
# node index, which will be deleted later
# after going through _body_transformer
body.append(f"# COUNTER: {i}\n")
do_record = record_func and node.op in (
"call_function",
"call_method",
"call_module",
)
if do_record:
# The double hash ## convention is used by post-processing to find the fx markers
body.append(
f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n"
)
emit_node(node)
delete_unused_values(node)
if do_record:
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
if record_func:
body.append("_rf.__exit__(None, None, None)\n")
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
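To make the marker convention concrete, the emitted module code is roughly shaped like the sketch below when record_func is enabled (the op sequence, node names, and generated file name are illustrative; only the _RecordFunctionFast calls and the '## ... ##' markers come from this change):

import torch

def forward(self, x):
    # outer marker: the placeholder key is later replaced with the generated file name
    _rf = torch._C._profiler._RecordFunctionFast('## <generated_fx_file>.py ##'); _rf.__enter__()
    _rf_add = torch._C._profiler._RecordFunctionFast('## 0 ##'); _rf_add.__enter__()
    add = torch.ops.aten.add.Tensor(x, 1)
    _rf_add.__exit__(None, None, None)
    _rf_relu = torch._C._profiler._RecordFunctionFast('## 1 ##'); _rf_relu.__enter__()
    relu = torch.ops.aten.relu.default(add)
    _rf_relu.__exit__(None, None, None)
    _rf.__exit__(None, None, None)
    return relu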
@ -1779,6 +1779,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@ -1846,6 +1847,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def _python_code(
@ -1858,6 +1860,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@ -1868,6 +1871,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def __str__(self) -> str:

View File

@ -861,14 +861,18 @@ class {module_name}(torch.nn.Module):
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module="self")
from torch._dynamo import config as dynamo_config
python_code = self._graph.python_code(
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
)
self._code = python_code.src
self._lineno_map = python_code._lineno_map
self._prologue_start = python_code._prologue_start
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
from torch._dynamo import config as dynamo_config
if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
@ -885,7 +889,6 @@ class {module_name}(torch.nn.Module):
# This ensures the same code+metadata always generates the same filename
hash_value = _metadata_hash(self._code, node_metadata)
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
filename = f"{file_stem}.py"
# Only include co_filename to use it directly as the cache key
@ -905,6 +908,13 @@ class {module_name}(torch.nn.Module):
_register_fx_metadata(filename, metadata)
# Replace the placeholder in generated code with actual filename
# The double hash ## convention is used by post-processing to find the fx markers
self._code = self._code.replace(
"torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')",
f"torch._C._profiler._RecordFunctionFast('## {filename} ##')",
)
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
# Determine whether this class explicitly defines a __call__ implementation
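A hedged end-to-end sketch of how this path is exercised (the toy function, shapes, and trace file name are illustrative):

import torch
from torch.profiler import ProfilerActivity, profile

torch._dynamo.config.enrich_profiler_metadata = True  # flag referenced above

def fn(x):
    return torch.relu(x) + 1

compiled = torch.compile(fn)
x = torch.randn(4, 4)
compiled(x)  # first call compiles and recompiles the generated GraphModule

with profile(activities=[ProfilerActivity.CPU]) as prof:
    compiled(x)

prof.export_chrome_trace("trace_with_fx_markers.json")  # fx markers land in this trace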

View File

@ -4,7 +4,7 @@ import operator
import re
from collections import deque
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Any, Literal, Optional, TYPE_CHECKING
from torch.autograd.profiler import profile
from torch.profiler import DeviceType
@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None:
with profile():
pass
@dataclass
class TimelineEvent:
"""Represents an event in the profiler timeline."""
timestamp: int
event_type: Literal["start", "end", "regular"]
marker_type: Optional[Literal["filename", "node"]]
identifier: Optional[str | int]
event: dict[str, Any]
@dataclass
class ContextStackEntry:
"""Represents a context (filename or node) in the stack."""
context_type: Literal["filename", "node"]
identifier: str | int
metadata: Optional[dict]
tid: Optional[int] = None # Thread ID associated with this context
def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
"""
Maps recorded profiler events to their corresponding fx nodes and adds stack traces.
Builds a timeline of all events (regular ops and FX markers for filenames/nodes),
sorts by timestamp, then processes chronologically while maintaining a context stack of active
filename/node scopes. Regular events are augmented with stack traces and node names from the
innermost active context. Runtime is O(n log n) for n events.
Args:
traced_data: JSON of profiler events from a Chrome trace
Returns:
None. Matching events in traced_data are augmented in place with stack traces and node names.
"""
from torch.fx.traceback import _FX_METADATA_REGISTRY
trace_events = traced_data.get("traceEvents", [])
# Create event timeline
event_timeline: list[TimelineEvent] = []
def is_fx_marker_event(event):
return (
event.get("cat") == "cpu_op"
and event.get("name", "").startswith("## ")
and event.get("name", "").endswith(" ##")
)
def append_fx_marker_event(event_type, identifier, event):
start_ts = event["ts"]
end_ts = start_ts + event["dur"]
event_timeline.append(
TimelineEvent(start_ts, "start", event_type, identifier, event)
)
event_timeline.append(
TimelineEvent(end_ts, "end", event_type, identifier, event)
)
for event in trace_events:
if "ts" not in event or "dur" not in event:
continue
if is_fx_marker_event(event):
content = event["name"][3:-3]
if content.endswith(".py"):
append_fx_marker_event("filename", content, event)
else:
try:
node_index = int(content)
except ValueError:
# skip malformed markers instead of reusing a stale node_index
continue
append_fx_marker_event("node", node_index, event)
else:
# Regular event that needs augmentation
start_ts = event["ts"]
event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event))
# Sort by timestamp
event_timeline.sort(key=lambda x: x.timestamp)
# Process events in chronological order with a stack
context_stack: list[ContextStackEntry] = []
# Invariant: every start event has a corresponding end event
for timeline_event in event_timeline:
match timeline_event.event_type:
case "start":
assert timeline_event.identifier is not None
if timeline_event.marker_type == "filename":
assert isinstance(timeline_event.identifier, str)
# Push filename context - query metadata registry on-demand
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
tid = timeline_event.event.get("tid")
context_stack.append(
ContextStackEntry(
"filename", timeline_event.identifier, metadata, tid
)
)
elif timeline_event.marker_type == "node":
# Find the current filename from stack
current_file_metadata = None
tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
if (
ctx_entry.context_type == "filename"
and ctx_entry.tid == tid
):
current_file_metadata = ctx_entry.metadata
break
if current_file_metadata:
node_metadata = current_file_metadata.get("node_metadata", {})
if timeline_event.identifier in node_metadata:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
"node", timeline_event.identifier, node_meta, tid
)
)
case "end":
# Pop from stack - search backwards to find matching context
for i in range(len(context_stack) - 1, -1, -1):
ctx_entry = context_stack[i]
if (
timeline_event.marker_type == ctx_entry.context_type
and timeline_event.identifier == ctx_entry.identifier
):
context_stack.pop(i)
break
case "regular":
# Apply metadata from current context stack
# Find the most specific context (node takes precedence over filename)
# Only augment events with the same tid as the file/node event matched
current_stack_trace = None
current_node_name = None
event_tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
# Only apply metadata from contexts with matching tid
if ctx_entry.tid == event_tid:
if ctx_entry.context_type == "node" and ctx_entry.metadata:
current_stack_trace = ctx_entry.metadata.get(
"stack_trace", "No model stack trace available"
)
current_node_name = ctx_entry.metadata.get("name", "")
# TODO: decide whether to attach only the innermost node's stack trace or the
# stack traces of all enclosing nodes when graph modules are nested
break
# Augment the event
if current_stack_trace or current_node_name:
args = timeline_event.event.setdefault("args", {})
if current_stack_trace:
args["stack_trace"] = current_stack_trace
if current_node_name:
args["node_name"] = current_node_name

View File

@ -210,8 +210,7 @@ class _KinetoProfile:
def start_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.start()
if self.profiler is None:
raise AssertionError("Profiler must be initialized before starting trace")
assert self.profiler is not None
self.profiler._start_trace()
if self.profile_memory:
@ -257,8 +256,7 @@ class _KinetoProfile:
def stop_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.stop()
if self.profiler is None:
raise AssertionError("Profiler must be initialized before stopping trace")
assert self.profiler is not None
self.profiler.__exit__(None, None, None)
def export_chrome_trace(self, path: str):
@ -266,10 +264,7 @@ class _KinetoProfile:
Exports the collected trace in Chrome JSON format. If kineto is enabled, only the
last cycle in the schedule is exported.
"""
if self.profiler is None:
raise AssertionError(
"Profiler must be initialized before exporting chrome trace"
)
assert self.profiler
if path.endswith(".gz"):
fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False)
fp.close()
@ -289,8 +284,7 @@ class _KinetoProfile:
path (str): save stacks file to this location;
metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
"""
if self.profiler is None:
raise AssertionError("Profiler must be initialized before exporting stacks")
assert self.profiler
return self.profiler.export_stacks(path, metric)
def toggle_collection_dynamic(
@ -322,7 +316,7 @@ class _KinetoProfile:
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
"""
if self.profiler is None:
if not self.profiler:
return
self.profiler.toggle_collection_dynamic(enable, activities)
@ -339,10 +333,7 @@ class _KinetoProfile:
To use shape/stack functionality, make sure to set record_shapes/with_stack
when creating the profiler context manager.
"""
if self.profiler is None:
raise AssertionError(
"Profiler must be initialized before getting key averages"
)
assert self.profiler
return self.profiler.key_averages(
group_by_input_shape, group_by_stack_n, group_by_overload_name
)
@ -352,8 +343,7 @@ class _KinetoProfile:
Returns the list of unaggregated profiler events,
to be used in the trace callback or after the profiling is finished
"""
if self.profiler is None:
raise AssertionError("Profiler must be initialized before accessing events")
assert self.profiler
return self.profiler.function_events
def add_metadata(self, key: str, value: str) -> None:
@ -405,10 +395,7 @@ class _KinetoProfile:
if missing:
raise ValueError(f"{', '.join(missing)} required for memory profiling.")
if self.profiler is None or self.profiler.kineto_results is None:
raise AssertionError(
"Profiler and kineto_results must be initialized for memory profiling"
)
assert self.profiler is not None and self.profiler.kineto_results is not None
return MemoryProfile(self.profiler.kineto_results)
def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
@ -498,8 +485,7 @@ def schedule(
"""
def schedule_fn(step: int) -> ProfilerAction:
if step < 0:
raise AssertionError(f"Step must be non-negative. Got {step}.")
assert step >= 0
if step < skip_first:
return ProfilerAction.NONE
else:
@ -522,11 +508,9 @@ def schedule(
else ProfilerAction.RECORD_AND_SAVE
)
if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0:
raise AssertionError(
f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), "
f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)."
)
assert (
wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
), "Invalid profiler schedule arguments"
if warmup == 0:
warn(
"Profiler won't be using warmup, this can skew profiler results",
@ -733,8 +717,7 @@ class profile(_KinetoProfile):
activities_set.add(ProfilerActivity.CUDA)
elif ProfilerActivity.CUDA in activities_set:
activities_set.remove(ProfilerActivity.CUDA)
if len(activities_set) == 0:
raise AssertionError("No valid profiler activities found")
assert len(activities_set) > 0, "No valid profiler activities found"
super().__init__(
activities=activities,

View File

@ -15,8 +15,8 @@ collection support for PyTorch APIs.
import functools
import types
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Optional, overload, TypeVar, Union
from typing_extensions import deprecated, Self, TypeAlias, TypeIs
from typing import Any, Optional, overload, TypeAlias, TypeVar, Union
from typing_extensions import deprecated, Self, TypeIs
import torch.utils._pytree as python_pytree
from torch.torch_version import TorchVersion as _TorchVersion

View File

@ -2,7 +2,9 @@
import contextlib
import functools
import traceback
from typing import Any, Callable, Optional, TYPE_CHECKING
import weakref
from collections.abc import Callable
from typing import Any, Optional, TYPE_CHECKING
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@ -14,6 +16,7 @@ from torch.utils._python_dispatch import (
)
from torch.utils._pytree import tree_all, tree_map
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakIdRef
if TYPE_CHECKING:
@ -56,29 +59,48 @@ def _stringify_dtensor_spec(spec) -> str:
return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order)
def _tensor_debug_string(tensor, attributes) -> str:
class TensorIdTracker:
def __init__(self):
self.tensor_memo: dict[WeakIdRef, int] = {}
self.next_tensor_id = 0
def _id(self, tensor) -> int:
with torch._C._DisablePythonDispatcher():
o = WeakIdRef(tensor)
def del_memo():
self.tensor_memo.pop(o, None)
weakref.finalize(tensor, del_memo)
if o not in self.tensor_memo:
self.tensor_memo[o] = self.next_tensor_id
self.next_tensor_id += 1
return self.tensor_memo[o]
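A minimal illustration of the tracker's contract (the import path is assumed from this hunk; tensor values are arbitrary): the same tensor object always maps to the same integer, and distinct tensors get distinct ids.

import torch
from torch.utils._debug_mode import TensorIdTracker  # assumed module path

tracker = TensorIdTracker()
a, b = torch.randn(2), torch.randn(2)
assert tracker._id(a) == tracker._id(a)  # stable id for the same tensor object
assert tracker._id(a) != tracker._id(b)  # different objects get different ids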
def _tensor_debug_string(tensor, attributes, tensor_memo=None) -> str:
"""Convert tensor to debug string representation."""
if isinstance(tensor, torch.Tensor):
tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}"
id_str = f"${tensor_memo._id(tensor)}" if tensor_memo is not None else ""
if isinstance(tensor, torch.distributed.tensor.DTensor):
# omitted device mesh
return f"dt: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}"
return f"dt{id_str}: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}"
elif isinstance(tensor, FakeTensor):
return f"ft: {tensor_debug_str}"
return f"ft{id_str}: {tensor_debug_str}"
else:
return f"t: {tensor_debug_str}"
return f"t{id_str}: {tensor_debug_str}"
else:
raise RuntimeError(f"Unsupported tensor type: {type(tensor)}")
def _arg_to_str(arg, attributes) -> str:
def _arg_to_str(arg, attributes, tensor_memo=None) -> str:
from torch.distributed.tensor._dtensor_spec import DTensorSpec
def to_str(x):
if isinstance(x, torch.Tensor):
return _tensor_debug_string(x, attributes)
return _tensor_debug_string(x, attributes, tensor_memo)
elif isinstance(x, DTensorSpec):
return _stringify_dtensor_spec(x)
return x
@ -144,8 +166,11 @@ class _DebugCall:
# results from dispatch hooks
self.record = record
self.log = log
self.output_str: Optional[str] = None
def stringify_args(self, attributes: list[str]) -> None:
def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
) -> None:
"""
To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs.
"""
@ -153,6 +178,18 @@ class _DebugCall:
"Subclasses must implement stringify_args(), even if no-op"
)
def stringify_output(
self,
output: Any,
attributes: list[str],
tensor_memo: Optional[TensorIdTracker] = None,
) -> None:
"""Store stringified version of call output in self.output_str"""
if tree_all(lambda x: x is None, output):
return
output_str = tree_map(lambda x: _arg_to_str(x, attributes, tensor_memo), output)
self.output_str = f" -> {str(output_str)}"
def render(self, attributes: list[str]) -> str:
raise NotImplementedError("Subclasses must implement string render()")
@ -179,11 +216,16 @@ class _OpCall(_DebugCall):
self.args_str: Optional[str] = None
self.kwargs_str: Optional[str] = None
def stringify_args(self, attributes: list[str]) -> None:
self.args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args)
def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
) -> None:
self.args_str = ", ".join(
_arg_to_str(arg, attributes, tensor_memo) for arg in self.args
)
if self.kwargs:
self.kwargs_str = ", " + ", ".join(
f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items()
f"{k}={_arg_to_str(v, attributes, tensor_memo)}"
for k, v in self.kwargs.items()
)
else:
self.kwargs_str = ""
@ -215,6 +257,8 @@ class _OpCall(_DebugCall):
base_str = f"{op_name}({args_str}{kwargs_str})"
if self.output_str:
base_str += self.output_str
if self.log:
base_str += f" # {self.log}"
return base_str
@ -247,8 +291,10 @@ class _RedistributeCall(_DebugCall):
self.arg_str: Optional[str] = None
def stringify_args(self, attributes: list[str]) -> None:
self.arg_str = f"{_arg_to_str(self.arg, attributes)}"
def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
) -> None:
self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}"
del self.arg
def render(self, attributes: list[str]) -> str:
@ -263,7 +309,11 @@ class _RedistributeCall(_DebugCall):
src_placement_str = _arg_to_str(self.src_placement, attributes)
dst_placement_str = _arg_to_str(self.dst_placement, attributes)
placement_str = f"{src_placement_str} -> {dst_placement_str}"
return f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})"
base_str = f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})"
if self.output_str:
base_str += self.output_str
return base_str
def __iter__(self):
# for BC; tuple(self) returns (op, placement info, kwargs, call_depth)
@ -288,7 +338,9 @@ class _NNModuleCall(_DebugCall):
super().__init__(call_depth, stack=stack)
self.module_name = module_name
def stringify_args(self, attributes: list[str]) -> None:
def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
) -> None:
pass # nothing to stringify
def render(self, attributes: list[str]) -> str:
@ -341,6 +393,8 @@ class DebugMode(TorchDispatchMode):
record_nn_module=False,
store_original_args=False,
record_stack_trace=False,
record_output=False,
record_ids=False,
):
super().__init__()
import torch.distributed.tensor # noqa: F401
@ -378,8 +432,24 @@ class DebugMode(TorchDispatchMode):
# e.g. via DebugMode(record_stack_trace=True), or torch.autograd.set_detect_anomaly().
self.record_stack_trace = record_stack_trace
# Records call outputs in logs (e.g. for __torch_dispatch__, __torch_function__, redistribute_input)
self.record_output: bool = record_output
# Annotates string dumps with graph-style tensor ids, e.g. op($1, $2) -> $3.
self.record_ids: bool = record_ids
self.reset()
def reset(self):
self.operators = []
self.call_depth = 0
self._tensor_memo = TensorIdTracker()
self._output_info: dict[int, object] = {}
def _track_op_output(self, op_index, result):
"""Assign IDs to output tensors and store in output_info"""
# self._track_tensor_ids(result)
self._output_info[op_index] = result
# Without this override, running torch.compile under DebugMode
# will force torch.compile to always use the “eager” backend
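A usage sketch of the two new flags, assuming DebugMode is importable from torch.utils._debug_mode and that debug_string() is the existing dump API; the rendered lines are illustrative, but with record_output=True each call gains a " -> ..." suffix and with record_ids=True tensors carry graph-style $N ids.

import torch
from torch.utils._debug_mode import DebugMode   # assumed module path

x = torch.randn(8, 4)
w = torch.randn(4, 4)

with DebugMode(record_output=True, record_ids=True) as debug_mode:
    y = torch.mm(x, w)

print(debug_mode.debug_string())
# Illustrative output shape (exact text depends on the dispatch trace):
#   aten::mm(t$0: f32[8, 4], t$1: f32[4, 4]) -> t$2: f32[8, 4]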
@ -390,20 +460,35 @@ class DebugMode(TorchDispatchMode):
def _record_call(self, call):
if not self.store_original_args:
call.stringify_args(self.record_tensor_attributes)
call.stringify_args(
self.record_tensor_attributes,
self._tensor_memo if self.record_ids else None,
)
self.operators.append(call)
def _record_call_output(self, call, output):
if not self.record_output:
return
call.stringify_output(
output,
self.record_tensor_attributes,
self._tensor_memo if self.record_ids else None,
)
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
self._record_call(
_OpCall(func, args, kwargs, self.call_depth, stack=self.record_stack_trace)
call = _OpCall(
func, args, kwargs, self.call_depth, stack=self.record_stack_trace
)
self._record_call(call)
try:
self.call_depth += 1
return func(*args, **kwargs)
result = func(*args, **kwargs)
self._record_call_output(call, result)
return result
finally:
self.call_depth -= 1
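The record-then-run-then-attach-output flow used by __torch_function__ above (and by __torch_dispatch__ further down) can be seen in isolation with a small, purely illustrative recorder; the finally block guarantees the call depth unwinds even when the wrapped call raises.

class SketchRecorder:
    def __init__(self):
        self.calls = []
        self.call_depth = 0

    def run(self, func, *args, **kwargs):
        record = {"name": func.__name__, "depth": self.call_depth}
        self.calls.append(record)            # log the call before dispatch
        try:
            self.call_depth += 1             # nested calls render one level deeper
            result = func(*args, **kwargs)
            record["output"] = repr(result)  # mirrors _record_call_output
            return result
        finally:
            self.call_depth -= 1             # always restore, even on error

recorder = SketchRecorder()
recorder.run(max, 1, 2)
print(recorder.calls)   # [{'name': 'max', 'depth': 0, 'output': '2'}]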
@ -445,13 +530,13 @@ class DebugMode(TorchDispatchMode):
result = func(*args, **kwargs)
if call:
self._record_call_output(call, result)
_run_dispatch_hooks(call, func, types, args, kwargs, result)
return result
def __enter__(self):
self.operators = []
self.call_depth = 0
self.reset()
if self.record_torchfunction:
torch._C._push_on_torch_function_stack(self)

View File

@ -36,10 +36,11 @@ from typing import (
Optional,
overload,
Protocol,
TypeAlias,
TypeVar,
Union,
)
from typing_extensions import deprecated, NamedTuple, Self, TypeAlias
from typing_extensions import deprecated, NamedTuple, Self
from torch.torch_version import TorchVersion as _TorchVersion
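The import shuffle works because TypeAlias has been part of the standard typing module since Python 3.10, which is presumably the minimum interpreter this file now targets, so only the remaining names still need typing_extensions. A minimal illustration:

from typing import TypeAlias        # importable natively on Python 3.10+

Matrix: TypeAlias = list[list[float]]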