Compare commits


2 Commits

SHA1        Message  Date
fd7d074666  lint     2025-11-05 14:01:29 -08:00
1b3088d04f  tc       2025-11-05 14:01:08 -08:00
59 changed files with 773 additions and 1942 deletions

View File

@ -271,16 +271,6 @@ case "$tag" in
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-clang21)
ANACONDA_PYTHON_VERSION=3.10
CLANG_VERSION=21
ACL=yes
VISION=yes
OPENBLAS=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11

View File

@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
# work around ubuntu apt-get conflicts
sudo apt-get -y -f install
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
if [[ $CLANG_VERSION -ge 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
if [[ $CLANG_VERSION == 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
fi
fi

View File

@ -129,7 +129,7 @@ function install_129 {
}
function install_128 {
CUDNN_VERSION=9.8.0.87
CUDNN_VERSION=9.10.2.21
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
# install CUDA 12.8.1 in the same container
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

View File

@ -10,7 +10,6 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="
CC=gcc
NUM_THREADS=128
USE_OPENMP=1
NO_SHARED=0

View File

@ -272,6 +272,18 @@ def smoke_test_cuda(
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
print(f"Torch cuDNN version: {torch_cudnn_version}")
torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
torch_cudnn_runtime_version = tuple(
[int(x) for x in torch_cudnn_version.split(".")]
)
if torch_cudnn_runtime_version != torch_cudnn_compile_version:
raise RuntimeError(
"cuDNN runtime version doesn't match comple version. "
f"Loaded: {torch_cudnn_runtime_version} "
f"Expected: {torch_cudnn_compile_version}"
)
if sys.platform in ["linux", "linux2"]:
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
print(f"Torch nccl; version: {torch_nccl_version}")

View File

@ -79,8 +79,6 @@ jobs:
include:
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
runner: linux.arm64.m7g.4xlarge
timeout-minutes: 600

View File

@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
- [Running `pyrefly`](#running-pyrefly)
- [Running `mypy`](#running-mypy)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
- `mypy` - recommended for linting
- `pytest` - recommended to run tests more selectively
Running
```
@ -350,32 +350,15 @@ make lint
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
#### Running `pyrefly`
#### Running `mypy`
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
**Getting Started with Pyrefly:**
To run type checking on the PyTorch codebase:
```bash
pyrefly check
```
For more detailed error information with summaries:
```bash
pyrefly check --summarize-errors
```
**Learn More:**
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
`mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
for more information on how to set up `mypy` and tackle type annotation
tasks.
### C++ Unit Testing

View File

@ -226,8 +226,8 @@ template <
typename B = HostBlock<S>>
struct CachingHostAllocatorImpl {
virtual ~CachingHostAllocatorImpl() {
if (active_) {
active_ = false;
active_ = false;
if (pinned_use_background_threads()) {
getBackgroundThreadPool()->waitWorkComplete();
}
}
@ -260,7 +260,6 @@ struct CachingHostAllocatorImpl {
if (pinned_use_background_threads()) {
// Launch the background thread and process events in a loop.
static bool background_thread_flag [[maybe_unused]] = [this] {
active_ = true;
getBackgroundThreadPool()->run([&]() {
while (active_) {
process_events();
@ -684,9 +683,9 @@ struct CachingHostAllocatorImpl {
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the event-processing thread pool is active.
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{false};
std::atomic<bool> active_{true};
protected:
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};

View File

@ -24,13 +24,7 @@ namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
// We use native integer types for 1/2/4/8-byte values to reduce
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
template<typename key_t, int value_size>
void radix_sort_pairs_impl(

View File

@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
});
}
template <typename func_t, typename vec_func_t, typename ident_t = double>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
template <typename func_t, typename vec_func_t>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
using traits = binary_function_traits<func_t>;
static_assert(
all_same<

View File

@ -339,13 +339,33 @@ void or_kernel_impl(TensorIterator& iter) {
}
}
template<typename scalar_t>
struct MinValuesOps: public at::native::MinOps<scalar_t> {
using arg_t = typename MinOps<scalar_t>::arg_t;
static scalar_t project(arg_t arg) {
return arg.first;
}
};
void min_values_kernel_impl(TensorIterator& iter) {
// This case is special because Vectorized<int64_t> does not
// handle upper_bound<int64_t>().
// See: https://github.com/pytorch/pytorch/issues/43254
if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce(
iter,
MinValuesOps<scalar_t>{},
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
}), kLong, kUInt64);
return;
}
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
upper_bound<scalar_t>());
static_cast<double>(upper_bound<scalar_t>()));
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}

View File

@ -22,9 +22,6 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
@ -669,19 +666,12 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}

View File

@ -1,19 +0,0 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>
namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
at::Tensor& out);
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -1,462 +0,0 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/hip/HIPContext.h>
#include <ATen/Tensor.h>
#include <ATen/TensorAccessor.h>
#include <c10/hip/HIPStream.h>
#include <iostream>
#include <vector>
#include <optional>
#include <type_traits>
#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/tuple.hpp>
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
namespace at {
namespace hip {
namespace detail {
namespace CkTypes {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
}
template <typename ALayout, typename BLayout, typename DataType>
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
1, 1,
S<1,32,1,8>, 4
>;
template <typename ALayout, typename BLayout, typename DataType>
void launch_grouped_bgemm_ck_impl_dispatch(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
at::Tensor& out)
{
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
using PassThrough = CkTypes::PassThrough;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a_ptrs, p_b_ptrs;
std::vector<void*> p_e_ptrs;
// Note: d_ptrs will be resized after we populate the other vectors
const int mat_a_dim = mat_a.dim();
const int mat_b_dim = mat_b.dim();
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
const size_t a_element_size = mat_a.element_size();
const size_t b_element_size = mat_b.element_size();
const size_t out_element_size = out.element_size();
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
if (mat_a_dim == 2 && mat_b_dim == 2) {
// 2D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
const int M = mat_a.size(0); // number of rows in A
const int N = mat_b.size(1); // number of columns in B
const int K = mat_a.size(1); // columns in A == rows in B
// for 2d*2d input, output is 3d.
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
for (int i = 0; i < num_groups; ++i) {
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
int end_k = offs_accessor[i];
int k = end_k - start_k;
//K dimension are sliced, hence select stride(1) always.
//K dimension is always dimension 1, regardless of memory layout (row/column major)
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
const void* group_b_ptr;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
// Leading dimension = distance between rows = stride(0)
ldb = mat_b.stride(0);
} else {
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
// Leading dimension = distance between columns = stride(1)
ldb = mat_b.stride(1);
}
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
lda = mat_a.stride(0);
} else {
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
lda = mat_a.stride(1);
}
// Output is always row-major in 3D tensor [num_groups, M, N]
// Leading dimension for each group's [M,N] slice = stride(1) = N
ldc = out.stride(1);
size_t output_group_bytes = M * N * out_element_size;
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(k),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
// 2D*3D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 2d*3d input, output is 2d.
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
const int K = mat_a.size(1); // columns in A
// For 2D-3D case: The output determines N (result width)
const int N = out.size(1); // N is the width of the output tensor
for (int i = 0; i < num_groups; ++i) {
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
int end_m = offs_accessor[i];
int m = end_m - start_m;
// Skip zero-sized groups but continue processing subsequent groups
if (m <= 0) {
continue;
}
// Select A rows for group i: skip start_m rows
const void* group_a_ptr;
int lda;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
lda = mat_a.stride(0); // distance between rows
} else {
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
// Detect stride pattern for A tensor to determine appropriate lda calculation
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
if (a_is_strided_tensor) {
// For strided A tensors: stride(0) gives the actual leading dimension
lda = mat_a.stride(0);
} else {
// For non-strided A tensors: use the M dimension (total rows)
lda = mat_a.size(0); // Total M dimension for column-major layout
}
}
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
} else {
// Detect stride pattern to determine appropriate ldb calculation
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
if (is_strided_tensor) {
// For strided tensors: stride(2) gives the actual leading dimension
ldb = mat_b.stride(2);
} else {
// For non-strided tensors: use the N dimension
ldb = mat_b.size(1);
}
}
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
gemm_descs.push_back({
static_cast<ck::index_t>(m),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
// 3d*3d input, output is 3d - batched matrix multiplication
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
// Each batch is processed as a separate GEMM operation
const int batch_size = mat_a.size(0);
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
int N;
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
N = mat_b.size(2);
} else if (mat_b.size(2) == K) {
// B is [batch, n, k] - transposed layout
N = mat_b.size(1);
} else {
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
}
for (int i = 0; i < batch_size; ++i) {
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
// Select output batch for group i: Output[i, :, :]
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldb, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
} else {
// Column-major A: leading dimension = distance between columns = stride(2)
lda = mat_a.stride(2);
}
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B: leading dimension = distance between rows
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(1); // stride between K rows
} else {
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
}
} else {
// Column-major B: leading dimension = distance between columns
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(2); // stride between N columns
} else {
// B is [batch, n, k] - transposed layout
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
}
}
// Output is typically row-major: leading dimension = distance between rows = stride(1)
ldc = out.stride(1);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
// 3D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 3d*2d input, output is 3d.
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
const int batch_size = mat_a.size(0); // n_groups
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A
// For row-major A and B case: B should be [K, total_N]
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
for (int i = 0; i < num_groups; ++i) {
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
int end_n = offs_accessor[i];
int n = end_n - start_n;
// Skip zero-sized groups but continue processing subsequent groups
if (n <= 0) {
continue;
}
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
const void* group_b_ptr;
int ldb;
// Check if B is row-major or column-major
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(0); // distance between rows (should be total_N)
} else {
// Column-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(1); // distance between columns (should be K)
}
// Select output slice for group i: Output[:, start_n:end_n]
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
int lda, ldc;
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
// Output is row-major: leading dimension = distance between rows = stride(0)
ldc = out.stride(0);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(n),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
}
TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
// Initialize d_ptrs with the correct size
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
static DeviceOp gemm_instance;
auto argument = gemm_instance.MakeArgument(
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
);
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
"CK Group GEMM: argument unsupported (shape/strides/type config)");
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
void* gemm_arg_buf = nullptr;
void* ws_buf = nullptr;
hipMalloc(&gemm_arg_buf, arg_buf_size);
hipMalloc(&ws_buf, ws_size);
gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
auto invoker = gemm_instance.MakeInvoker();
hipStream_t stream = c10::hip::getCurrentHIPStream();
invoker.Run(argument, {stream});
hipFree(gemm_arg_buf);
hipFree(ws_buf);
}
void group_gemm_ck(
const at::Tensor& input_a,
const at::Tensor& input_b_colmajor,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& /*bias*/,
at::Tensor& out)
{
// Detect if input_a is row-major based on stride pattern
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
// Ensure tensor A is row-major and contiguous if not already
at::Tensor mat_a = input_a;
if (!a_row_major) {
// If A is not row-major, make it contiguous (row-major)
mat_a = input_a.contiguous();
}
// Force tensor B to be column-major using double transpose trick
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
at::Tensor mat_b = input_b_colmajor;
if (!b_col_major) {
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
}
// For 3D tensors, check the last dimension stride for row-major detection
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
if (mat_a.dtype() == at::kBFloat16) {
// bf16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kHalf) {
// fp16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kFloat) {
// fp32 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
}
}
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -47,10 +47,20 @@ Tensor sgd_out_of_place(
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
// testing Tensor strides + stride
STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);
aoti_torch_get_strides(param.get(), &param_strides);
auto out = new_empty(param, param.sizes());
int32_t param_dtype;
aoti_torch_get_dtype(param.get(), &param_dtype);
int32_t param_device_type;
aoti_torch_get_device_type(param.get(), &param_device_type);
AtenTensorHandle out_ath;
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
auto out = Tensor(out_ath);
sgd_math(
reinterpret_cast<float*>(param.data_ptr()),
@ -334,8 +344,6 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
// directly.
// This is to test that passing in a std::vector works for BC. (It gets
// implicitly converted to HeaderOnlyArrayRef too!)
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype);

View File

@ -5,16 +5,8 @@ import contextlib
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -434,31 +426,6 @@ class TestDTensorDebugMode(TestCase):
][-1]
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
def test_pretty_print_dtensor_make_fx(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
A = torch.randn(8, 32)
B = torch.randn(32, 32)
dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()
def f(dA, dB):
dy = dA @ dB
loss = dy.sum()
loss.backward()
return dA.grad, dB.grad
# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode="fake")(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
# Colored is nice for actual viewing, not using in this test though
gm_str = gm.print_readable(colored=False, print_output=False)
self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@ -5789,229 +5789,6 @@ class NCCLTraceTest(NCCLTraceTestBase):
else:
self.assertTrue("duration_ms" not in t["entries"][0])
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
"""
Test that when the circular buffer in entries_ is full and we call reset,
then fill the buffer with new entries, dump_entries returns only the new
entries and not the old ones.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely with 10 entries
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify buffer is full with 10 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Now reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add new entries after reset - fill the buffer completely again
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we get exactly 10 new entries, not 20
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Verify all entries have the expected properties (from after reset)
# After reset, record IDs should start from 0 again
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
self.assertIn("record_id", entry)
# Record IDs should be sequential starting from 0 after reset
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_partial_overwrite(self, timing_enabled):
"""
Test that when the circular buffer is full, we reset, and then add fewer
entries than the buffer size, we only get the new entries.
This tests that old entries at the end of the circular buffer are properly
filtered out based on reset_epoch.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add only 3 new entries (much less than buffer size)
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we only get the 3 new entries, not 10
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 3)
# Verify record IDs start from 0 after reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_wraparound(self, timing_enabled):
"""
Test that when we reset in the middle of the circular buffer and then
wrap around, dump_entries correctly returns only entries from the current
epoch in the correct order.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill half the buffer
for _ in range(5):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset at this point (reset happens at index 5)
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Now add 8 entries, which will wrap around
# (5->9 fills rest of buffer, then 0->2 wraps around)
for _ in range(8):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should get exactly 8 entries, properly ordered
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 8)
# Entries should be in chronological order
# The dump_entries() method returns entries from next_ to end, then 0 to next_
# After filtering old entries, we should have 8 entries in order
# Verify record IDs start from 0 after reset (id_ is reset in reset_all())
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_multiple_resets(self, timing_enabled):
"""
Test multiple consecutive resets to ensure each reset properly increments
the epoch and filters out entries from previous epochs.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# First batch: 2 entries
for _ in range(2):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# First reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Second batch: 3 entries
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Second reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Third batch: 4 entries
for _ in range(4):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should only see the last 4 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 4)
# Verify record IDs start from 0 after the last reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
def check_if_test_is_skipped(fn):
def wrapper(self, *args, **kwargs):

View File

@ -8,11 +8,21 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
extract_graph_and_tracker,
normalize_gm,
)
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet
def extract_graph(fn, *args, **kwargs):
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs
def graph_str(gm):
return normalize_gm(gm.print_readable(print_output=False))
@ -30,7 +40,7 @@ class GraphDededuplicationTests(TestCase):
super().tearDown()
def run_and_return_graphs(self, fn, *args, **kwargs):
return extract_graph(fn, *args, **kwargs)[0:3]
return extract_graph(fn, *args, **kwargs)
def run_and_get_simple_graph(self):
def fn(x, y):

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Callable, Sequence
from typing import Any, Union
from collections.abc import Sequence
from typing import Any, Callable, Union
import torch
import torch._dynamo

View File

@ -13194,30 +13194,6 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
self.assertEqual(actual, expected)
@parametrize_pytree_module
def test_pytree_tree_map_dict_order(self, pytree):
def fn(tree):
new_tree = pytree.tree_map(lambda x: x, tree)
return list(new_tree.keys()), list(new_tree.values())
x = torch.randn(3, 2)
fn_opt = torch.compile(fullgraph=True)(fn)
tree1 = {"b": x + 2, "a": x, "c": x - 1}
expected1 = fn(tree1)
actual1 = fn_opt(tree1)
self.assertEqual(actual1, expected1)
tree2 = collections.OrderedDict([("b", x + 2), ("a", x), ("c", x - 1)])
expected2 = fn(tree2)
actual2 = fn_opt(tree2)
self.assertEqual(actual2, expected2)
tree3 = collections.defaultdict(int, {"b": x + 2, "a": x, "c": x - 1})
expected3 = fn(tree3)
actual3 = fn_opt(tree3)
self.assertEqual(actual3, expected3)
@parametrize_pytree_module
def test_pytree_tree_map_only(self, pytree):
if not callable(getattr(pytree, "tree_map_only", None)):

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import NamedTuple, Optional, TYPE_CHECKING
from typing import Callable, NamedTuple, Optional
import torch
import torch._dynamo
@ -7,10 +7,6 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
if TYPE_CHECKING:
from collections.abc import Callable
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo

View File

@ -1,13 +1,11 @@
# Owner(s): ["module: dynamo"]
import functools
import re
import unittest
import weakref
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import extract_graph, remove_trailing_space
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import requires_cuda
@ -17,14 +15,6 @@ requires_multigpu = functools.partial(
)
def remove_file_comment(gm_str: str) -> str:
return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str))
def print_graph(graph: torch.fx.GraphModule) -> str:
return remove_file_comment(graph.print_readable())
class TestStreams(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
@ -46,7 +36,9 @@ class TestStreams(torch._dynamo.test_case.TestCase):
@requires_cuda
def test_stream_enter_exit(self):
def fn(x, y, s1, s2):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
@ -55,36 +47,13 @@ class TestStreams(torch._dynamo.test_case.TestCase):
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream())
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
@requires_cuda
@unittest.skip("Needs graph break support with annotation context")
def test_stream_context_graph_break(self):
def fn(x, y):
s2 = torch.Stream()
@ -101,16 +70,9 @@ class <lambda>(torch.nn.Module):
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
fn_opt = torch.compile(fn)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
self.assertEqual(len(fw_graphs), 2)
self.assertExpectedInline(print_graph(fw_graphs[0]), """""")
self.assertExpectedInline(print_graph(fw_graphs[1]), """""")
@requires_cuda
def test_stream_input(self):
@ -193,248 +155,22 @@ class <lambda>(torch.nn.Module):
self.assertEqual(s_act, s_exp)
def test_nested_stream_enter_exit(self):
def fn(x, y, s0, s1, s2):
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
pass
return z0, y
inp = (
torch.ones(2, 2) + 1,
torch.ones(2, 2),
torch.Stream(),
torch.Stream(),
torch.Stream(),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
@unittest.skip("Needs graph break support with annotation context")
def test_stream_enter_exit_graph_break(self):
pass
@unittest.skip("Needs graph break support with annotation context")
def test_nested_stream_enter_exit_graph_break(self):
pass
def test_local_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 1}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
pass
def test_local_stream_nested_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
pass
def test_stream_with_mutation(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
x.add_(y)
with s0:
z1 = torch.add(y, y)
z0 = torch.add(z1, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
# Annotation: {'stream': 2}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
#
copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
return (add_2, add_3)
""",
)
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
with s0:
y0 = 2 * x + y
with s2:
z = 2 * x + y
return y0, z
inp = (
torch.ones(2, 2, requires_grad=True) + 1,
torch.ones(2, 2, requires_grad=True),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
bw_graphs,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
# Annotation: {'stream': 1}
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
return (add, add_1)
""",
)
actual[1].sum().backward()
self.assertExpectedInline(
print_graph(bw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
# Annotation: {'stream': 0}
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
#
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
#
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",
)
pass
@requires_cuda
def test_run_opcheck(self):

View File

@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None

View File

@ -73,22 +73,7 @@ from tools.testing.test_selections import (
ShardedTest,
THRESHOLD,
)
try:
from tools.testing.upload_artifacts import (
parse_xml_and_upload_json,
zip_and_upload_artifacts,
)
except ImportError:
# some imports in those files might fail, e.g., boto3 not installed. These
# functions are only needed under specific circumstances (CI) so we can
# define dummy functions here.
def parse_xml_and_upload_json():
pass
def zip_and_upload_artifacts(failed: bool):
pass
from tools.testing.upload_artifacts import zip_and_upload_artifacts
# Make sure to remove REPO_ROOT after import is done
@ -797,6 +782,11 @@ def run_test_retries(
if not continue_through_error:
print_to_file("Stopping at first consistent failure")
break
if current_failure == f"'{test_file}'":
# Something failed consistently in the entire suite and we don't
# know how to skip only that one individual test, so stop here
print_to_file("Entire suite failed consistently, stopping here")
break
sc_command = f"--scs={stepcurrent_key}"
print_to_file(
"Test failed consistently, "
@ -1902,7 +1892,6 @@ def run_tests(
def handle_complete(failure: Optional[TestFailure]):
failed = failure is not None
if IS_CI and options.upload_artifacts_while_running:
parse_xml_and_upload_json()
zip_and_upload_artifacts(failed)
if not failed:
return False

View File

@ -1,176 +0,0 @@
# Owner(s): ["oncall: pt2"]
from collections import deque
from typing import Optional
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Extract (sizes, strides) tuple from a tensor."""
return (tuple(t.size()), tuple(t.stride()))
def enumerate_reachable_states(
initial_size: int,
) -> set[tuple[tuple[int, ...], tuple[int, ...]]]:
"""
Use BFS with DP to enumerate all reachable (size, stride) states from
a 1D contiguous tensor via valid view operations.
We only explore states with offset=0 (you can retroactively change the offset).
We reject states with size=0 or size=1 dimensions as they are degenerate.
"""
# Create initial 1D contiguous tensor
initial_tensor = torch.arange(initial_size)
initial_state = get_state(initial_tensor)
# Map from state to tensor for that state
state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = {
initial_state: initial_tensor
}
visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state}
queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state])
while queue:
state = queue.popleft()
t = state_to_tensor[state]
sizes, strides = state
ndim = len(sizes)
def add_state(new_t: torch.Tensor) -> None:
new_state = get_state(new_t)
sizes, strides = new_state
# Skip if has size-0 or size-1 dimensions
if any(s == 0 or s == 1 for s in sizes):
return
# Only accept states where strides are in descending order
if list(strides) != sorted(strides, reverse=True):
return
if new_state not in visited:
visited.add(new_state)
queue.append(new_state)
state_to_tensor[new_state] = new_t
# 1. Unflatten: try factoring each dimension
for dim in range(ndim):
size = sizes[dim]
assert size > 1
# Try all factorizations x * y = size where both x, y >= 2
# We only need to check x up to size // 2 since when x > size // 2,
# y = size // x < 2, which we reject
for x in range(2, size // 2 + 1):
if size % x == 0:
y = size // x
add_state(t.unflatten(dim, (x, y)))
# 2. Slice: exhaustively check all possible slicing parameters
for dim in range(ndim):
size = sizes[dim]
for start in range(size):
for stop in range(start + 1, size + 1):
for step in range(1, size + 1):
slices = [slice(None)] * ndim
slices[dim] = slice(start, stop, step)
add_state(t[tuple(slices)])
# 3. Flatten: merge adjacent dimensions
for dim in range(ndim - 1):
add_state(t.flatten(dim, dim + 1))
return visited
class TestAsStrided(TestCase):
def test_size_10_exhaustive(self) -> None:
"""Test that size 10 produces exactly the expected 54 states."""
expected_states = {
((2,), (1,)),
((2,), (2,)),
((2,), (3,)),
((2,), (4,)),
((2,), (5,)),
((2,), (6,)),
((2,), (7,)),
((2,), (8,)),
((2,), (9,)),
((2, 2), (2, 1)),
((2, 2), (3, 1)),
((2, 2), (3, 2)),
((2, 2), (4, 1)),
((2, 2), (4, 2)),
((2, 2), (4, 3)),
((2, 2), (5, 1)),
((2, 2), (5, 2)),
((2, 2), (5, 3)),
((2, 2), (5, 4)),
((2, 2), (6, 1)),
((2, 2), (6, 2)),
((2, 2), (6, 3)),
((2, 2), (8, 1)),
((2, 2, 2), (4, 2, 1)),
((2, 2, 2), (5, 2, 1)),
((2, 3), (3, 1)),
((2, 3), (4, 1)),
((2, 3), (5, 1)),
((2, 3), (5, 2)),
((2, 3), (6, 1)),
((2, 4), (4, 1)),
((2, 4), (5, 1)),
((2, 5), (5, 1)),
((3,), (1,)),
((3,), (2,)),
((3,), (3,)),
((3,), (4,)),
((3, 2), (2, 1)),
((3, 2), (3, 1)),
((3, 2), (3, 2)),
((3, 2), (4, 1)),
((3, 3), (3, 1)),
((4,), (1,)),
((4,), (2,)),
((4,), (3,)),
((4, 2), (2, 1)),
((5,), (1,)),
((5,), (2,)),
((5, 2), (2, 1)),
((6,), (1,)),
((7,), (1,)),
((8,), (1,)),
((9,), (1,)),
((10,), (1,)),
}
actual_states = enumerate_reachable_states(10)
self.assertEqual(len(actual_states), 54)
self.assertEqual(actual_states, expected_states)
def test_subset_property(self) -> None:
"""
Test that for sizes 2..10, each smaller tensor results in a strict
subset of possible states compared to the next one.
"""
prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None
for size in range(2, 11):
current_states = enumerate_reachable_states(size)
if prev_states is not None:
# Check that prev_states is a strict subset of current_states
self.assertTrue(
prev_states.issubset(current_states),
f"States from size {size - 1} are not a subset of size {size}",
)
# Check that it's a strict subset (not equal)
self.assertTrue(
len(prev_states) < len(current_states),
f"States from size {size - 1} should be strictly fewer than size {size}",
)
prev_states = current_states
if __name__ == "__main__":
run_tests()

View File

@ -75,6 +75,12 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.jit_utils import JitTestCase
import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events
try:
from torchvision import models as torchvision_models
@ -201,6 +207,36 @@ def side_effect_func(x: torch.Tensor):
print(x)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
trace_file = f.name
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
class TestFX(JitTestCase):
def setUp(self):
super().setUp()
@ -4212,6 +4248,150 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
# recover mutable checking flag
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
augments profiler events with stack traces from FX metadata registry.
"""
# Simple test model
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(16, 10)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
model = TestModel().cuda()
# Compile the model
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"))
# Profile with the compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::t node=t stack_trace=x = self.linear1(x)
event=aten::transpose node=t stack_trace=x = self.linear1(x)
event=aten::as_strided node=t stack_trace=x = self.linear1(x)
event=aten::addmm node=addmm stack_trace=x = self.linear1(x)
event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x)
event=aten::relu node=relu stack_trace=x = self.relu(x)
event=aten::clamp_min node=relu stack_trace=x = self.relu(x)
event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x)
event=aten::t node=t_1 stack_trace=x = self.linear2(x)
event=aten::transpose node=t_1 stack_trace=x = self.linear2(x)
event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x)
event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x)
event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
Test that multiple compiled modules under the same profiler session
have their events correctly augmented with stack traces.
"""
class ModelA(torch.nn.Module):
def forward(self, x):
return x + 1
class ModelB(torch.nn.Module):
def forward(self, x):
return x - 1
model_a = ModelA().cuda()
model_b = ModelB().cuda()
# Compile both models
compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True)
compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_a(torch.randn(10, 10, device="cuda"))
_ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
# Profile both models in the same session
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result_a = compiled_a(torch.randn(10, 10, device="cuda"))
result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::add node=add stack_trace=return x + 1
event=cudaLaunchKernel node=add stack_trace=return x + 1
event=aten::sub node=sub stack_trace=return x - 1
event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""
Test that nested graph modules (e.g., graph modules calling subgraphs)
have their events correctly augmented with stack traces.
"""
# Model with nested structure
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 5
@torch.compiler.nested_compile_region
def forward(self, x, y):
m = torch.mul(x, y)
s = m.sin()
a = s + self.c
return a
model = Mod().cuda()
# Compile the model (this may create nested graph modules)
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
# Profile
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::mul node=mul stack_trace=m = torch.mul(x, y)
event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y)
event=aten::sin node=sin stack_trace=s = m.sin()
event=cudaLaunchKernel node=sin stack_trace=s = m.sin()
event=aten::add node=add stack_trace=a = s + self.c
event=cudaLaunchKernel node=add stack_trace=a = s + self.c"""
)
def run_getitem_target():
from torch.fx._symbolic_trace import _wrapped_methods_to_patch

View File

@ -490,6 +490,8 @@ class TestMatmulCuda(InductorTestCase):
@parametrize("b_row_major", [False, True])
@dtypes(torch.bfloat16, torch.float32, torch.float16)
def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
self.skipTest("failed using hipblaslt on rocm 6.4.2")
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4

View File

@ -601,24 +601,6 @@ class TestGenericPytree(TestCase):
for case in cases:
run_test(case)
@parametrize_pytree_module
def test_tree_map_dict_order(self, pytree):
d = {"b": 2, "a": 1, "c": 3}
od = OrderedDict([("b", 2), ("a", 1), ("c", 3)])
dd = defaultdict(int, {"b": 2, "a": 1, "c": 3})
for tree in (d, od, dd):
result = pytree.tree_map(lambda x: x, tree)
self.assertEqual(
list(result.keys()),
list(tree.keys()),
msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. {result!r}",
)
self.assertEqual(
list(result.values()),
list(tree.values()),
msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. {result!r}",
)
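The test deleted above pinned down that tree_map keeps dictionary insertion order; a short sketch of the same check against the Python pytree implementation (illustrating what the removed test verified, not a guarantee after this change):

import collections
import torch.utils._pytree as pytree

od = collections.OrderedDict([("b", 2), ("a", 1), ("c", 3)])
mapped = pytree.tree_map(lambda x: x * 10, od)
# The deleted test required these key orders to match exactly.
print(list(od.keys()), list(mapped.keys()))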
@parametrize_pytree_module
def test_tree_map_only(self, pytree):
self.assertEqual(pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"])

View File

@ -1864,8 +1864,6 @@ class TestFP8Matmul(TestCase):
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if torch.version.hip and recipe == "nvfp4":
raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")

View File

@ -1914,7 +1914,6 @@ class TestSDPAFailureModes(NNTestCase):
q, k, v, None, 0.0, is_causal=True))
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
batch_size = 2**16
query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
@ -1936,7 +1935,6 @@ class TestSDPAFailureModes(NNTestCase):
self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
@ -1950,7 +1948,6 @@ class TestSDPAFailureModes(NNTestCase):
@largeTensorTest("15GB", "cuda")
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_large_seq_len_uniform_attention(self):
device = torch.device("cuda")
dtype = torch.bfloat16

View File

@ -1,5 +1,5 @@
from typing import TypeAlias, Union
from typing_extensions import assert_type
from typing import Union
from typing_extensions import assert_type, TypeAlias
from torch import randn, Tensor

View File

@ -38,14 +38,12 @@ def parse_xml_report(
report: Path,
workflow_id: int,
workflow_run_attempt: int,
job_id: int | None = None,
) -> list[dict[str, Any]]:
"""Convert a test report xml file into a JSON-serializable list of test cases."""
print(f"Parsing {tag}s for test report: {report}")
if job_id is None:
job_id = get_job_id(report)
print(f"Found job id: {job_id}")
job_id = get_job_id(report)
print(f"Found job id: {job_id}")
test_cases: list[dict[str, Any]] = []

View File

@ -1,16 +1,11 @@
import glob
import gzip
import json
import os
import time
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Any, Optional
from filelock import FileLock, Timeout
from tools.stats.upload_test_stats import parse_xml_report
from typing import Any
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
@ -145,66 +140,3 @@ def trigger_upload_test_stats_intermediate_workflow() -> None:
},
)
print(x.text)
def parse_xml_and_upload_json() -> None:
"""
Parse xml test reports that do not yet have a corresponding json report
uploaded to s3, and upload the json reports to s3. Use filelock to avoid
uploading the same file from multiple processes.
"""
try:
job_id: Optional[int] = int(os.environ.get("JOB_ID", 0))
if job_id == 0:
job_id = None
except (ValueError, TypeError):
job_id = None
try:
for xml_file in glob.glob(
f"{REPO_ROOT}/test/test-reports/**/*.xml", recursive=True
):
xml_path = Path(xml_file)
json_file = xml_path.with_suffix(".json")
lock = FileLock(str(json_file) + ".lock")
try:
lock.acquire(timeout=0) # immediately fails if already locked
if json_file.exists():
continue # already uploaded
test_cases = parse_xml_report(
"testcase",
xml_path,
int(os.environ.get("GITHUB_RUN_ID", "0")),
int(os.environ.get("GITHUB_RUN_ATTEMPT", "0")),
job_id,
)
line_by_line_jsons = "\n".join([json.dumps(tc) for tc in test_cases])
gzipped = gzip.compress(line_by_line_jsons.encode("utf-8"))
s3_key = (
json_file.relative_to(REPO_ROOT / "test/test-reports")
.as_posix()
.replace("/", "_")
)
get_s3_resource().put_object(
Body=gzipped,
Bucket="gha-artifacts",
Key=f"test_jsons_while_running/{os.environ.get('GITHUB_RUN_ID')}/{job_id}/{s3_key}",
ContentType="application/json",
ContentEncoding="gzip",
)
# We don't need to save the json file locally, but doing so lets us
# track which ones have been uploaded already. We could probably also
# check S3
with open(json_file, "w") as f:
f.write(line_by_line_jsons)
except Timeout:
continue # another process is working on this file
finally:
if lock.is_locked:
lock.release()
except Exception as e:
print(f"Failed to parse and upload json test reports: {e}")

View File

@ -1,9 +1,8 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from typing import Any, Optional, overload, Union
from typing import Any, Callable, Optional, overload, Union
import torch
from torch import Tensor

View File

@ -6,7 +6,7 @@ from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING, TypeVar
from typing import Any, TYPE_CHECKING
from typing_extensions import TypeIs
import torch.utils._pytree as python_pytree
@ -24,15 +24,9 @@ if TYPE_CHECKING:
__all__: list[str] = []
_T = TypeVar("_T")
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
if python_pytree._cxx_pytree_dynamo_traceable:
import optree
import optree._C
import optree.utils
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
@ -606,47 +600,14 @@ if python_pytree._cxx_pytree_dynamo_traceable:
__all__ += ["tree_map_"]
_none_registration = optree.register_pytree_node.get(type(None))
assert _none_registration is not None
_none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr, attr-defined]
@substitute_in_graph( # type: ignore[arg-type]
_none_registration.unflatten_func,
_none_unflatten,
can_constant_fold_through=True,
skip_signature_check=True,
)
def none_unflatten(_: None, children: Iterable[_T], /) -> None:
def none_unflatten(_: None, children: Iterable[Any], /) -> None:
if len(list(children)) != 0:
raise ValueError("Expected no children.")
return None
with optree.dict_insertion_ordered(False, namespace="torch"):
_dict_registration = optree.register_pytree_node.get(dict)
assert _dict_registration is not None
@substitute_in_graph( # type: ignore[arg-type]
_dict_registration.flatten_func,
can_constant_fold_through=True,
skip_signature_check=True,
)
def dict_flatten(
dct: dict[_KT, _VT], /
) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]:
sorted_keys = optree.utils.total_order_sorted(dct)
values = [dct[key] for key in sorted_keys]
original_keys = list(dct)
return values, (original_keys, sorted_keys), tuple(sorted_keys)
@substitute_in_graph( # type: ignore[arg-type]
_dict_registration.unflatten_func,
can_constant_fold_through=True,
skip_signature_check=True,
)
def dict_unflatten(
metadata: tuple[list[_KT], list[_KT]],
values: Iterable[_VT],
/,
) -> dict[_KT, _VT]:
original_keys, sorted_keys = metadata
d = dict.fromkeys(original_keys)
d.update(zip(sorted_keys, values))
return d # type: ignore[return-value]
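The removed polyfill flattens a dict in sorted-key order while stashing the original key order in the metadata, so unflatten can rebuild the dict with its insertion order intact. A standalone sketch of that round trip, with plain sorted() standing in for optree.utils.total_order_sorted:

def _flatten(dct):
    sorted_keys = sorted(dct)          # stand-in for total_order_sorted
    values = [dct[k] for k in sorted_keys]
    original_keys = list(dct)
    return values, (original_keys, sorted_keys)

def _unflatten(metadata, values):
    original_keys, sorted_keys = metadata
    d = dict.fromkeys(original_keys)   # reserves the original insertion order
    d.update(zip(sorted_keys, values))
    return d

d = {"b": 2, "a": 1, "c": 3}
values, meta = _flatten(d)
assert list(_unflatten(meta, values)) == ["b", "a", "c"]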

View File

@ -3320,7 +3320,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
obj.call_method(self, "add", [v], {})
def SET_UPDATE(self, inst: Instruction) -> None:
v = self.pop()
@ -3329,7 +3329,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
obj.call_method(self, "update", [v], {})
def LIST_APPEND(self, inst: Instruction) -> None:
v = self.pop()
@ -3637,7 +3637,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg].realize()
assert isinstance(obj, ConstDictVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
obj.call_method(self, "update", [v], {})
DICT_UPDATE = DICT_MERGE

View File

@ -87,12 +87,6 @@ def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-d
return gm.graph, region_tracker # type: ignore[union-attr]
def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def]
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs, backend.bw_graphs
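The removed extract_graph helper compiles fn with a recording backend and returns the captured graphs alongside the result; a usage sketch, assuming AotEagerAndRecordGraphs exposes graphs/fw_graphs/bw_graphs as referenced above:

import torch

def f(x):
    return (x * x).sum()

x = torch.randn(4, requires_grad=True)
result, graphs, fw_graphs, bw_graphs = extract_graph(f, x)
# graphs holds the Dynamo-captured FX graph(s); fw_graphs/bw_graphs hold
# whatever forward/backward graphs the AOT backend recorded.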
def collect_results(
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> list[Any]:

View File

@ -21,9 +21,9 @@ restoring state changes.
import inspect
import sys
import warnings
from collections.abc import Callable, Sequence, Sized
from collections.abc import Callable, Sequence
from contextlib import ExitStack
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union
import torch._C
from torch._guards import Guard

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
"""
Dictionary-related variable tracking classes for PyTorch Dynamo.
@ -24,7 +26,7 @@ import inspect
import operator
import types
from collections.abc import Hashable as py_Hashable
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING
from torch._subclasses.fake_tensor import is_fake
@ -57,13 +59,11 @@ if TYPE_CHECKING:
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def was_instancecheck_override(obj: Any) -> bool:
def was_instancecheck_override(obj):
return type(obj).__dict__.get("__instancecheck__", False)
def raise_unhashable(
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
) -> None:
def raise_unhashable(arg, tx=None):
if tx is None:
from torch._dynamo.symbolic_convert import InstructionTranslator
@ -75,7 +75,7 @@ def raise_unhashable(
)
def is_hashable(x: VariableTracker) -> bool:
def is_hashable(x):
# NB - performing an isinstance check on a LazyVT realizes the VT, accidentally
# inserting the guard. To avoid this, the LazyVT `is_hashable` method looks at
# the underlying value without realizing the VT. Consider updating the
@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
"""
def __init__(self, vt: VariableTracker) -> None:
def __init__(self, vt) -> None:
# We specialize SymNodes
vt = specialize_symnode(vt)
# TODO Temporarily remove to figure out what keys are we breaking on
@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
self.vt = vt
@property
def underlying_value(self) -> Any:
def underlying_value(self):
if (
isinstance(self.vt, variables.LazyVariableTracker)
and not self.vt.is_realized()
@ -178,8 +178,7 @@ class ConstDictVariable(VariableTracker):
elif isinstance(self.vt, variables.FrozenDataClassVariable):
Hashable = ConstDictVariable._HashableTracker
fields_values = {
k: Hashable(v).underlying_value
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
}
return variables.FrozenDataClassVariable.HashWrapper(
self.vt.python_type(), fields_values
@ -188,16 +187,16 @@ class ConstDictVariable(VariableTracker):
# The re module in Python 3.13+ has a dictionary (_cache2) with
# an object as key (`class _ZeroSentinel(int): ...`):
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
return self.vt.value # type: ignore[attr-defined,union-attr]
return self.vt.value
else:
x = self.vt.as_python_constant()
return x
def __hash__(self) -> int:
def __hash__(self):
return hash(self.underlying_value)
@staticmethod
def _eq_impl(a: Any, b: Any) -> bool:
def _eq_impl(a, b):
# TODO: Put this in utils and share it between variables/builtin.py and here
type_a, type_b = type(a), type(b)
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
@ -213,7 +212,7 @@ class ConstDictVariable(VariableTracker):
else:
return a == b
def __eq__(self, other: object) -> bool:
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
Hashable = ConstDictVariable._HashableTracker
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
type(other)
@ -227,8 +226,8 @@ class ConstDictVariable(VariableTracker):
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls: type = dict,
**kwargs: Any,
user_cls=dict,
**kwargs,
) -> None:
# .clone() passes these arguments in kwargs but they're recreated a few
# lines below
@ -248,22 +247,18 @@ class ConstDictVariable(VariableTracker):
for x, v in items.items()
)
def make_hashable(
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
) -> "ConstDictVariable._HashableTracker":
def make_hashable(key):
return key if isinstance(key, Hashable) else Hashable(key)
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
# need to reconstruct everything if the dictionary is an intermediate value
# or if a pop/delitem was executed
self.should_reconstruct_all = (
not is_from_local_source(self.source) if self.source else True
)
self.should_reconstruct_all = not is_from_local_source(self.source)
self.original_items = items.copy()
self.user_cls = user_cls
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
def _get_dict_cls_from_user_cls(self, user_cls):
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
# avoid executing user code if user_cls is a dict subclass
@ -282,10 +277,10 @@ class ConstDictVariable(VariableTracker):
dict_cls = dict
return dict_cls
def as_proxy(self) -> dict[Any, Any]:
def as_proxy(self):
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
def debug_repr(self) -> str:
def debug_repr(self):
return (
"{"
+ ", ".join(
@ -294,20 +289,20 @@ class ConstDictVariable(VariableTracker):
+ "}"
)
def as_python_constant(self) -> dict[Any, Any]:
def as_python_constant(self):
return {
k.vt.as_python_constant(): v.as_python_constant()
for k, v in self.items.items()
}
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
def keys_as_python_constant(self):
self.install_dict_keys_match_guard()
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
def python_type(self) -> type:
def python_type(self):
return self.user_cls
def __contains__(self, vt: VariableTracker) -> bool:
def __contains__(self, vt) -> bool:
assert isinstance(vt, VariableTracker)
Hashable = ConstDictVariable._HashableTracker
return (
@ -327,15 +322,13 @@ class ConstDictVariable(VariableTracker):
for key, value in self.items.items()
)
def is_new_item(
self, value: Optional[VariableTracker], other: VariableTracker
) -> bool:
def is_new_item(self, value, other):
# compare the id of the realized values if both values are not lazy VTs
if value and value.is_realized() and other.is_realized():
return id(value.realize()) != id(other.realize())
return id(value) != id(other)
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
def reconstruct_kvs_into_new_dict(self, codegen):
# Build a dictionary that contains the keys and values.
num_args = 0
for key, value in self.items.items():
@ -347,7 +340,7 @@ class ConstDictVariable(VariableTracker):
num_args += 1
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
if self.user_cls is collections.OrderedDict:
# emit `OrderedDict(constructed_dict)`
codegen.add_push_null(
@ -365,21 +358,19 @@ class ConstDictVariable(VariableTracker):
def getitem_const_raise_exception_if_absent(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
raise_observed_exception(KeyError, tx)
return self.items[key]
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
msg = f"Dictionary key {arg.value} not found during tracing"
unimplemented_v2(
gb_type="key not found in dict",
context=f"Key {arg.value}", # type: ignore[attr-defined]
context=f"Key {arg.value}",
explanation=msg,
hints=[
"Check if the key exists in the dictionary before accessing it.",
@ -388,13 +379,13 @@ class ConstDictVariable(VariableTracker):
)
return self.items[key]
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
def maybe_getitem_const(self, arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
return None
return self.items[key]
def realize_key_vt(self, arg: VariableTracker) -> None:
def realize_key_vt(self, arg: VariableTracker):
# Realize the LazyVT on a particular index
assert arg in self
key = ConstDictVariable._HashableTracker(arg)
@ -403,13 +394,11 @@ class ConstDictVariable(VariableTracker):
if isinstance(original_key_vt, variables.LazyVariableTracker):
original_key_vt.realize()
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
if self.source:
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
# Key guarding - These are the cases to consider
# 1) The dict has been mutated. In this case, we would have already
# inserted a DICT_KEYS_MATCH guard, so we can skip.
@ -450,11 +439,11 @@ class ConstDictVariable(VariableTracker):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
# we have to insert guards when a dict method is accessed. For this to
# be simple, we are conservative and overguard. We skip guard only for
@ -473,7 +462,7 @@ class ConstDictVariable(VariableTracker):
tx, *args, **kwargs
)
tx.output.side_effects.mutation(self)
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
self.items.update(temp_dict_vt.items)
return ConstantVariable.create(None)
elif name == "__getitem__":
# Key guarding - Nothing to do. LazyVT for value will take care.
@ -537,7 +526,7 @@ class ConstDictVariable(VariableTracker):
return ConstantVariable.create(len(self.items))
elif name == "__setitem__" and self.is_mutable():
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
self.install_dict_keys_match_guard()
if kwargs or len(args) != 2:
@ -561,7 +550,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
if args[0] not in self:
self.install_dict_contains_guard(tx, args)
@ -576,7 +565,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
if args[0] not in self:
# missing item, return the default value. Install no DICT_CONTAINS guard.
@ -610,7 +599,7 @@ class ConstDictVariable(VariableTracker):
last = v.value
else:
raise_args_mismatch(tx, name)
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
k, v = self.items.popitem(last=last)
else:
k, v = self.items.popitem()
@ -643,17 +632,17 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard()
dict_vt: ConstDictVariable = args[0]
dict_vt = args[0]
else:
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
self.items.update(dict_vt.items) # type: ignore[attr-defined]
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
self.items.update(dict_vt.items)
if has_kwargs:
# Handle kwargs
kwargs_hashable = {
kwargs = {
Hashable(ConstantVariable.create(k)): v
for k, v in kwargs.items()
}
self.items.update(kwargs_hashable)
self.items.update(kwargs)
return ConstantVariable.create(None)
else:
return super().call_method(tx, name, args, kwargs)
@ -667,7 +656,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
self.install_dict_contains_guard(tx, args)
contains = args[0] in self
@ -682,7 +671,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
self.install_dict_keys_match_guard()
if kwargs or len(args) > 2:
@ -718,7 +707,7 @@ class ConstDictVariable(VariableTracker):
and "last" in kwargs
and isinstance(kwargs["last"], ConstantVariable)
):
last = kwargs.get("last").value # type: ignore[union-attr]
last = kwargs.get("last").value
key = Hashable(args[0])
self.items.move_to_end(key, last=last)
@ -734,7 +723,7 @@ class ConstDictVariable(VariableTracker):
)
elif name == "__ne__":
return ConstantVariable.create(
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
not self.call_method(tx, "__eq__", args, kwargs).value
)
elif name == "__or__":
if len(args) != 1:
@ -761,14 +750,14 @@ class ConstDictVariable(VariableTracker):
if not istype(
other, (ConstDictVariable, variables.UserDefinedDictVariable)
):
err_msg = (
msg = (
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
f"and '{other.python_type().__name__}'"
)
raise_observed_exception(TypeError, tx, args=[err_msg])
raise_observed_exception(TypeError, tx, args=[msg])
# OrderedDict overloads __ror__
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
ts = {self.user_cls, other.user_cls}
user_cls = (
collections.OrderedDict
if any(issubclass(t, collections.OrderedDict) for t in ts)
@ -785,8 +774,8 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
args[0].install_dict_keys_match_guard()
new_dict_vt.items.update(args[0].items)
return new_dict_vt
elif name == "__ior__":
self.call_method(tx, "update", args, kwargs)
@ -800,13 +789,11 @@ class ConstDictVariable(VariableTracker):
else:
return super().call_method(tx, name, args, kwargs)
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
self.install_dict_keys_match_guard()
return [x.vt for x in self.items.keys()]
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
def call_obj_hasattr(self, tx, name):
# dict not allow setting arbitrary attributes. OrderedDict and
# defaultdict allow arbitrary setattr, but not deletion of default attrs
if any(
@ -829,25 +816,25 @@ class ConstDictVariable(VariableTracker):
],
)
def clone(self, **kwargs: Any) -> VariableTracker:
def clone(self, **kwargs):
self.install_dict_keys_match_guard()
return super().clone(**kwargs)
class MappingProxyVariable(VariableTracker):
# proxies to the original dict_vt
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
super().__init__(**kwargs)
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
def python_type(self) -> type:
def python_type(self):
return types.MappingProxyType
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
return self.dv_dict.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
# load types.MappingProxyType
if self.source:
msg = (
@ -876,11 +863,11 @@ class MappingProxyVariable(VariableTracker):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if self.source and tx.output.side_effects.has_existing_dict_mutation():
msg = (
"A dict has been modified while we have an existing mappingproxy object. "
@ -905,7 +892,7 @@ class MappingProxyVariable(VariableTracker):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
) -> "VariableTracker":
if self.python_type() is types.MappingProxyType:
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
return super().call_obj_hasattr(tx, name)
@ -913,44 +900,35 @@ class MappingProxyVariable(VariableTracker):
class NNModuleHooksDictVariable(ConstDictVariable):
# Special class to avoid adding any guards on the nn module hook ids.
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
pass
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
pass
class DefaultDictVariable(ConstDictVariable):
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls: type,
default_factory: Optional[VariableTracker] = None,
**kwargs: Any,
) -> None:
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
super().__init__(items, user_cls, **kwargs)
assert user_cls is collections.defaultdict
if default_factory is None:
default_factory = ConstantVariable.create(None)
self.default_factory = default_factory
def is_python_constant(self) -> bool:
def is_python_constant(self):
# Return false for unsupported defaults. This ensures that a bad handler
# path is not taken in BuiltinVariable for getitem.
if self.default_factory not in [list, tuple, dict] and not self.items:
return False
return super().is_python_constant()
def debug_repr(self) -> str:
assert self.default_factory is not None
def debug_repr(self):
return (
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
)
@staticmethod
def is_supported_arg(arg: VariableTracker) -> bool:
def is_supported_arg(arg):
if isinstance(arg, variables.BuiltinVariable):
return arg.fn in (list, tuple, dict, set)
else:
@ -964,11 +942,11 @@ class DefaultDictVariable(ConstDictVariable):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__getitem__":
if len(args) != 1:
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
@ -984,13 +962,13 @@ class DefaultDictVariable(ConstDictVariable):
else:
default_var = self.default_factory.call_function(tx, [], {})
super().call_method(
tx, "__setitem__", [args[0], default_var], kwargs
tx, "__setitem__", (args[0], default_var), kwargs
)
return default_var
else:
return super().call_method(tx, name, args, kwargs)
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen):
# emit `defaultdict(default_factory, new_dict)`
codegen.add_push_null(
lambda: codegen.extend_output(
@ -1016,48 +994,40 @@ class SetVariable(ConstDictVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs: Any,
**kwargs,
) -> None:
# pyrefly: ignore[bad-assignment]
items = dict.fromkeys(items, SetVariable._default_value())
# pyrefly: ignore[bad-argument-type]
super().__init__(items, **kwargs)
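SetVariable reuses the dict machinery by storing every element as a key mapped to a shared placeholder value; a plain-Python sketch of that representation (illustrative only, not Dynamo API):

elements = [1, 2, 3]
# Model the set as a dict whose keys are the elements and whose values are
# a common placeholder (the VT equivalent of the ConstantVariable None above).
as_dict = dict.fromkeys(elements, None)
assert set(as_dict.keys()) == {1, 2, 3}
# "add" then degenerates into a __setitem__ with the placeholder value.
as_dict[4] = None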
def debug_repr(self) -> str:
def debug_repr(self):
if not self.items:
return "set()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
def set_items(self):
return set(self.items.keys())
@staticmethod
def _default_value() -> VariableTracker:
def _default_value():
# Variable to fill in the keys of the dictionary
return ConstantVariable.create(None)
def as_proxy(self) -> Any:
def as_proxy(self):
return {k.vt.as_proxy() for k in self.set_items}
def python_type(self) -> type:
def python_type(self):
return set
def as_python_constant(self) -> Any:
def as_python_constant(self):
return {k.vt.as_python_constant() for k in self.set_items}
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
codegen.foreach([x.vt for x in self.set_items])
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
def _fast_set_method(
self,
tx: "InstructionTranslator",
fn: Any,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
def _fast_set_method(self, tx, fn, args, kwargs):
try:
res = fn(
*[x.as_python_constant() for x in [self, *args]],
@ -1067,16 +1037,15 @@ class SetVariable(ConstDictVariable):
raise_observed_exception(
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
)
# pyrefly: ignore[unbound-name]
return VariableTracker.build(tx, res)
def call_method(
self,
tx: "InstructionTranslator",
name: str,
tx,
name,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
) -> "VariableTracker":
# We forward the calls to the dictionary model
from ..utils import check_constant_args
@ -1096,10 +1065,10 @@ class SetVariable(ConstDictVariable):
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
if name == "__init__":
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
tx.output.side_effects.mutation(self)
self.items.clear()
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
self.items.update(temp_set_vt.items)
return ConstantVariable.create(None)
elif name == "add":
if kwargs or len(args) != 1:
@ -1110,7 +1079,7 @@ class SetVariable(ConstDictVariable):
f"{len(args)} args and {len(kwargs)} kwargs",
)
name = "__setitem__"
args = [args[0], SetVariable._default_value()]
args = (args[0], SetVariable._default_value())
elif name == "pop":
if kwargs or args:
raise_args_mismatch(
@ -1121,14 +1090,12 @@ class SetVariable(ConstDictVariable):
)
# Choose an item at random and pop it via the Dict.pop method
try:
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
result = self.set_items.pop().vt
except KeyError as e:
raise_observed_exception(
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
)
# pyrefly: ignore[unbound-name]
super().call_method(tx, name, [result], kwargs)
# pyrefly: ignore[unbound-name]
super().call_method(tx, name, (result,), kwargs)
return result
elif name == "isdisjoint":
if kwargs or len(args) != 1:
@ -1250,7 +1217,6 @@ class SetVariable(ConstDictVariable):
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
assert m is not None
return self.call_method(tx, m, args, kwargs)
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
@ -1264,34 +1230,29 @@ class SetVariable(ConstDictVariable):
"__ixor__": "symmetric_difference_update",
"__isub__": "difference_update",
}.get(name)
assert m is not None
self.call_method(tx, m, args, kwargs)
return self
elif name == "__eq__":
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(False)
r = self.call_method(tx, "symmetric_difference", args, kwargs)
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
return ConstantVariable.create(len(r.set_items) == 0)
elif name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
)
return super().call_method(tx, name, args, kwargs)
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
raise RuntimeError("Illegal to getitem on a set")
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
super().install_dict_contains_guard(tx, args)
@ -1299,27 +1260,27 @@ class FrozensetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs: Any,
**kwargs,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self) -> str:
def debug_repr(self):
if not self.items:
return "frozenset()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
def set_items(self):
return self.items.keys()
def python_type(self) -> type:
def python_type(self):
return frozenset
def as_python_constant(self) -> Any:
def as_python_constant(self):
return frozenset({k.vt.as_python_constant() for k in self.set_items})
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
codegen.foreach([x.vt for x in self.set_items])
codegen.add_push_null(
lambda: codegen.extend_output(
@ -1332,11 +1293,11 @@ class FrozensetVariable(SetVariable):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
tx,
name,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
) -> "VariableTracker":
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
elif name == "__init__":
@ -1355,7 +1316,7 @@ class FrozensetVariable(SetVariable):
"symmetric_difference",
):
r = super().call_method(tx, name, args, kwargs)
return FrozensetVariable(r.items) # type: ignore[attr-defined]
return FrozensetVariable(r.items)
return super().call_method(tx, name, args, kwargs)
@ -1363,11 +1324,11 @@ class DictKeySetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs: Any,
**kwargs,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self) -> str:
def debug_repr(self):
if not self.items:
return "dict_keys([])"
else:
@ -1377,35 +1338,33 @@ class DictKeySetVariable(SetVariable):
+ "])"
)
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
# Already EQUALS_MATCH guarded
pass
@property
def set_items(self) -> Any:
def set_items(self):
return self.items
def python_type(self) -> type:
def python_type(self):
return dict_keys
def as_python_constant(self) -> Any:
def as_python_constant(self):
return dict.fromkeys(
{k.vt.as_python_constant() for k in self.set_items}, None
).keys()
def call_method(
self,
tx: "InstructionTranslator",
name: str,
tx,
name,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
) -> "VariableTracker":
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
return super().call_method(tx, name, args, kwargs)
@ -1420,47 +1379,42 @@ class DictViewVariable(VariableTracker):
kv: Optional[str] = None
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
super().__init__(**kwargs)
assert self.kv in ("keys", "values", "items")
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
@property
def view_items(self) -> Any:
assert self.kv is not None
def view_items(self):
return getattr(self.dv_dict.items, self.kv)()
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
# Returns an iterable of the unpacked items
# Implement in the subclasses
raise NotImplementedError
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
return self.view_items_vt
def reconstruct(self, codegen: "PyCodegen") -> None:
assert self.kv is not None
def reconstruct(self, codegen: "PyCodegen"):
codegen(self.dv_dict)
codegen.load_method(self.kv)
codegen.call_method(0)
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
assert self.kv is not None
def call_obj_hasattr(self, tx, name):
if name in self.python_type().__dict__:
return ConstantVariable.create(True)
return ConstantVariable.create(False)
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if name == "__len__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name == "__iter__":
@ -1474,24 +1428,24 @@ class DictKeysVariable(DictViewVariable):
kv = "keys"
@property
def set_items(self) -> set[VariableTracker]:
def set_items(self):
return set(self.view_items)
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
# Returns an iterable of the unpacked items
return [x.vt for x in self.view_items]
def python_type(self) -> type:
def python_type(self):
return dict_keys
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if name == "__contains__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name in (
@ -1506,13 +1460,13 @@ class DictKeysVariable(DictViewVariable):
):
# These methods always returns a set
m = getattr(self.set_items, name)
r = m(args[0].set_items) # type: ignore[attr-defined]
r = m(args[0].set_items)
return SetVariable(r)
if name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
)
return super().call_method(tx, name, args, kwargs)
@ -1522,10 +1476,10 @@ class DictValuesVariable(DictViewVariable):
kv = "values"
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
return list(self.view_items)
def python_type(self) -> type:
def python_type(self):
return dict_values
@ -1533,20 +1487,14 @@ class DictItemsVariable(DictViewVariable):
kv = "items"
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
# Returns an iterable of the unpacked items
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
def python_type(self) -> type:
def python_type(self):
return dict_items
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
def call_method(self, tx, name, args, kwargs):
# TODO(guilhermeleobas): This should actually check if args[0]
# implements the mapping protocol.
if name == "__eq__":

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import hashlib
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
import sympy # noqa: TC002
@ -17,8 +17,6 @@ from .simd import SIMDKernel, SIMDScheduling
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ..ir import IRNode
from ..scheduler import BaseSchedulerNode

View File

@ -627,7 +627,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
filename=__file__,
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node(
fx_node: torch.fx.Node,
override_size: Optional[int] = None,
# TODO(ivankobzarev): NCCL estimator sometimes fails unexpectedly; re-enable after the fix.
use_nccl_estimator: bool = True,
use_nccl_estimator: bool = False,
) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).

View File

@ -1,6 +1,6 @@
import os
from collections.abc import Callable
from functools import cache, partial
from typing import Callable
import torch
from torch._environment import is_fbcode

View File

@ -3586,24 +3586,13 @@ def user_autotune(
)
def foreach(triton_meta, filename=None, inductor_meta=None):
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
configs,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,

View File

@ -52,7 +52,26 @@ __all__ = [
"MemRecordsAcc",
]
from contextlib import ContextDecorator
try:
# Available in Python >= 3.2
from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
import functools
class _ContextDecorator: # type: ignore[no-redef]
def __enter__(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError
def __call__(self, func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
with self:
return func(*args, **kwargs)
return wrapped
# global python state - whether profiler is currently enabled
@ -725,7 +744,8 @@ class profile:
return all_function_events
class record_function(ContextDecorator):
# pyrefly: ignore [invalid-inheritance]
class record_function(_ContextDecorator):
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
Label will only appear if CPU activity tracing is enabled.

View File

@ -1224,3 +1224,43 @@ def _build_table(
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by start time for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)

View File

@ -108,14 +108,12 @@ struct FlightRecorder {
capture_cpp_stack_ = getCvarBool(
{"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false);
enabled_ = max_entries_ > 0;
reset_epoch_start_idx_[0] = 0;
}
struct Entry {
size_t id_; // incremented id in the trace buffer
// used to figure out where in the circular entries
// buffer this entry will be located to
// update state information
size_t reset_epoch_; // epoch when this entry was created
size_t pg_id_;
std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>
@ -185,34 +183,11 @@ struct FlightRecorder {
size_t max_entries_ = 0;
size_t next_ = 0;
size_t id_ = 0;
size_t reset_epoch_ = 0;
std::unordered_map<size_t, size_t>
reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts
std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_;
std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
pg_name_to_ranks_;
std::string comm_lib_version_;
struct TraceIdentifier {
std::optional<size_t> id;
std::optional<size_t> reset_epoch;
};
TraceIdentifier recordWithResetEnabled(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
size_t collective_seq_id,
size_t p2p_seq_id,
size_t op_id,
std::string profiling_name,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs,
EventType* start,
EventType* end,
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P);
std::optional<size_t> record(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
@ -238,16 +213,8 @@ struct FlightRecorder {
std::vector<Entry> dump_entries();
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t getIdxFromId(size_t id, size_t reset_epoch) const;
// Returns the entry with the given id and reset_epoch, if it exists.
// Otherwise, returns std::nullopt.
TORCH_API std::optional<Entry> getEntry(
std::optional<size_t> id,
std::optional<size_t> reset_epoch);
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
TORCH_API std::optional<Entry> getEntry(std::optional<size_t> id);
/*
@ -260,11 +227,6 @@ struct FlightRecorder {
never hang. (timing must also be enabled for compute_duration - see
TORCH_NCCL_ENABLE_TIMING).
*/
TORCH_API void retire_id(
std::optional<size_t> id,
std::optional<size_t> reset_epoch,
bool compute_duration = true);
TORCH_API void retire_id(
std::optional<size_t> id,
bool compute_duration = true);

View File

@ -53,41 +53,8 @@ std::optional<size_t> FlightRecorder<EventType>::record(
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P) {
auto result = recordWithResetEnabled(
pg_id,
pg_name,
collective_seq_id,
p2p_seq_id,
op_id,
std::move(profiling_name),
inputs,
outputs,
start,
end,
timeout_ms,
std::move(pg_status),
isP2P);
return result.id;
}
template <typename EventType>
typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
recordWithResetEnabled(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
size_t collective_seq_id,
size_t p2p_seq_id,
size_t op_id,
std::string profiling_name,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs,
EventType* start,
EventType* end,
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P) {
if (!enabled_) {
return TraceIdentifier{std::nullopt, std::nullopt};
return std::nullopt;
}
if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
// Current pg_status is not in FR.
@ -97,13 +64,8 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
std::lock_guard<std::mutex> guard(mutex_);
TORCH_CHECK(
reset_epoch_start_idx_.find(reset_epoch_) !=
reset_epoch_start_idx_.end());
auto te = Entry{
id_,
reset_epoch_,
pg_id,
pg_name,
collective_seq_id,
@ -142,20 +104,15 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
}
const auto next = next_++;
if (entries_.size() < max_entries_) {
entries_.emplace_back(std::move(te));
} else {
entries_[next] = std::move(te);
entries_[next_++] = std::move(te);
if (next_ == max_entries_) {
next_ = 0;
}
}
if (next_ == max_entries_) {
next_ = 0;
}
const auto id = id_++;
return TraceIdentifier{id, reset_epoch_};
return id_++;
}
template <typename EventType>
@ -206,20 +163,15 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
std::vector<Entry> result;
{
std::lock_guard<std::mutex> guard(mutex_);
// Filter entries during insertion - only keep entries from current epoch
auto filter = [this](const Entry& e) {
return e.reset_epoch_ == reset_epoch_;
};
std::copy_if(
result.reserve(entries_.size());
result.insert(
result.end(),
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
entries_.end(),
std::back_inserter(result),
filter);
std::copy_if(
entries_.end());
result.insert(
result.end(),
entries_.begin(),
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
std::back_inserter(result),
filter);
entries_.begin() + static_cast<std::ptrdiff_t>(next_));
}
// query any remaining events
for (auto& r : result) {
@ -230,47 +182,28 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
}
template <typename EventType>
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t FlightRecorder<EventType>::getIdxFromId(size_t id, size_t reset_epoch)
const {
// Look up the starting idx for the given reset epoch
auto it = reset_epoch_start_idx_.find(reset_epoch);
TORCH_CHECK(it != reset_epoch_start_idx_.end());
// Calculate idx based on where the epoch started
return (it->second + id) % max_entries_;
}
template <typename EventType>
// Returns the entry with the given id and reset_epoch, if it exists. Otherwise,
// returns std::nullopt.
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
EventType>::
getEntry(std::optional<size_t> id, std::optional<size_t> reset_epoch) {
if (!enabled_ || !id || !reset_epoch) {
EventType>::getEntry(std::optional<size_t> id) {
if (!enabled_ || !id) {
return std::nullopt;
}
std::unique_lock<std::mutex> guard(mutex_);
Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch));
if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) {
Entry entry = entries_.at(*id % max_entries_);
if (entry.id_ == *id) {
return entry;
} else {
return std::nullopt;
}
return std::nullopt;
}
template <typename EventType>
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
EventType>::getEntry(std::optional<size_t> id) {
return getEntry(id, 0);
}
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
std::optional<size_t> id,
std::optional<size_t> reset_epoch,
bool compute_duration) {
if (!enabled_ || !id || !reset_epoch) {
if (!enabled_ || !id) {
return;
}
@ -281,8 +214,8 @@ void FlightRecorder<EventType>::retire_id(
std::unique_lock<std::mutex> guard(mutex_);
Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) {
Entry* entry = &entries_.at(*id % max_entries_);
if (entry->id_ == *id) {
update_state(*entry);
if (compute_duration) {
@ -304,8 +237,8 @@ void FlightRecorder<EventType>::retire_id(
guard.lock();
// Refresh the entry pointer, see if the entry has been overwritten
entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) {
entry = &entries_.at(*id % max_entries_);
if (entry->id_ != *id) {
LOG(INFO) << "retire_id abandoned for id " << *id
<< ", event was overwritten while waiting to compute duration.";
return;
@ -316,23 +249,12 @@ void FlightRecorder<EventType>::retire_id(
}
}
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
std::optional<size_t> id,
bool compute_duration) {
retire_id(id, 0, compute_duration);
}
template <typename EventType>
void FlightRecorder<EventType>::reset_all() {
std::lock_guard<std::mutex> guard(mutex_);
if (!entries_.empty()) {
// Soft delete: increment epoch to mark all existing entries as old
// Store where the new epoch starts in the circular buffer
reset_epoch_++;
reset_epoch_start_idx_[reset_epoch_] = next_;
id_ = 0;
}
next_ = 0;
id_ = 0;
entries_.clear();
}
template <typename EventType>
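Editor's note: the hunks above replace the hard reset of the flight-recorder ring buffer with an epoch-based soft reset. record() returns a TraceIdentifier of (id, reset_epoch_), reset_all() bumps the epoch and remembers where it starts in the ring (reset_epoch_start_idx_[reset_epoch_] = next_), and getIdxFromId() maps (id, epoch) back to a slot via (epoch_start + id) % max_entries_. A minimal Python sketch of that bookkeeping follows; it is illustrative only, not a PyTorch API, and all names are hypothetical.

class RingRecorder:
    """Toy model of the epoch-indexed circular buffer used by FlightRecorder."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.entries = [None] * capacity   # slots hold (epoch, id, payload)
        self.next = 0                      # next ring slot to write
        self.next_id = 0                   # per-epoch monotonically increasing id
        self.epoch = 0
        self.epoch_start = {0: 0}          # epoch -> ring index where that epoch began

    def record(self, payload):
        self.entries[self.next] = (self.epoch, self.next_id, payload)
        self.next = (self.next + 1) % self.capacity
        trace_id = (self.next_id, self.epoch)
        self.next_id += 1
        return trace_id                    # analogous to TraceIdentifier{id, reset_epoch}

    def reset_all(self):
        # Soft delete: old entries stay in the ring but belong to a stale epoch.
        self.epoch += 1
        self.epoch_start[self.epoch] = self.next
        self.next_id = 0

    def get(self, entry_id, epoch):
        # Mirrors getIdxFromId: (epoch start + id) % capacity, then verify the slot.
        idx = (self.epoch_start[epoch] + entry_id) % self.capacity
        slot = self.entries[idx]
        if slot is not None and slot[0] == epoch and slot[1] == entry_id:
            return slot[2]
        return None                        # overwritten or from a stale epoch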

View File

@ -708,8 +708,7 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
// TODO: We need to have numel of tensors for gloo as well.
pgStatus_->lastCompletedNumelIn = 0;
pgStatus_->lastCompletedNumelOut = 0;
FlightRecorder<c10::Event>::get()->retire_id(
work->trace_id_, work->trace_reset_epoch_, false);
FlightRecorder<c10::Event>::get()->retire_id(work->trace_id_, false);
lock.lock();
workInProgress_[workerIndex].reset();
}
@ -781,7 +780,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
pgStatus_->lastEnqueuedNumelOut = 0;
// using c10d::FlightRecorder;
// TODO: We need to have a way to use c10::Event inside gloo as well.
auto traceId = FlightRecorder<c10::Event>::get()->recordWithResetEnabled(
work->trace_id_ = FlightRecorder<c10::Event>::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
collectiveCounter_,
@ -796,8 +795,6 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
work->getTimeout(),
pgStatus_,
false);
work->trace_id_ = traceId.id;
work->trace_reset_epoch_ = traceId.reset_epoch;
workQueue_.push_back(std::move(work));
lock.unlock();

View File

@ -99,7 +99,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
// unique id used to tell the trace buffer that this
// work has completed
std::optional<uint64_t> trace_id_;
std::optional<uint64_t> trace_reset_epoch_;
std::shared_ptr<gloo::Context> context_;
const std::chrono::milliseconds timeout_;

View File

@ -575,7 +575,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
futureWorkResult_(w.futureWorkResult_),
timingEnabled_(w.timingEnabled_),
trace_id_(w.trace_id_),
trace_reset_epoch_(w.trace_reset_epoch_),
distDebugLevel_(w.distDebugLevel_) {
exception_ = w.exception_;
}
@ -705,9 +704,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
// Print the traceback of the collective at call time
std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const {
// First step we get the corresponding record entry from FR, based on work's
// trace_id_ and trace_reset_epoch_
// trace_id_
std::optional<FlightRecorderCUDA::Entry> entry =
FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_);
FlightRecorderCUDA::get()->getEntry(trace_id_);
if (entry.has_value()) {
auto entryVal = entry.value();
// Get stack trace from FR entry, in string format
@ -2395,8 +2394,7 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
FlightRecorderCUDA::get()->retire_id(
work.trace_id_, work.trace_reset_epoch_, true);
FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
if (pg_->onCompletionHook_) {
// Move Work object to completedWorkList_ to be consumed by the hook
// thread
@ -3362,7 +3360,7 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
// these objects to the Work because it has implications for keeping those
// tensors alive longer and adds overhead when copying Work objects
// between threads
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
r->trace_id_ = FlightRecorderCUDA::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -3376,8 +3374,6 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
options_->timeout,
pgStatus_,
isP2P);
r->trace_id_ = traceId.id;
r->trace_reset_epoch_ = traceId.reset_epoch;
}
return r;
}
@ -3597,7 +3593,6 @@ float ProcessGroupNCCL::endTimeEstimate() {
#ifdef NCCL_SIM_INFO_INITIALIZER
ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER;
C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt);
--ncclActiveGroupCounter_;
return simInfo.estimatedTime;
#else
TORCH_CHECK(
@ -3681,7 +3676,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// later in endCoalescing we record a 'coalesced' Work which has
// timing/state updates via watchdog thread, but lacks op metadata such as
// input/output sizes and profilingTitle per-op in the group.
FlightRecorderCUDA::get()->recordWithResetEnabled(
FlightRecorderCUDA::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -4173,7 +4168,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
// TODO(whc) because we don't pass output {tensor} to initWork, we tell
// initWork to not record, and then we manually call record passing all the
// information it wants.
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
work->trace_id_ = FlightRecorderCUDA::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -4187,8 +4182,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
options_->timeout,
pgStatus_,
/*isP2P=*/true);
work->trace_id_ = traceId.id;
work->trace_reset_epoch_ = traceId.reset_epoch;
}
// Only check for NaN for send ops, for recv ops `tensor` can be a random

View File

@ -505,7 +505,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// unique id used to tell the trace buffer that this
// work has completed
std::optional<uint64_t> trace_id_;
std::optional<uint64_t> trace_reset_epoch_;
DebugLevel distDebugLevel_;
friend class ProcessGroupNCCL;
};

View File

@ -4,7 +4,6 @@
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>
@ -14,7 +13,6 @@
HIDDEN_NAMESPACE_BEGIN(torch, stable)
using accelerator::DeviceIndex;
using torch::headeronly::IntHeaderOnlyArrayRef;
using torch::headeronly::ScalarType;
// The torch::stable::Tensor class is a highlevel C++ wrapper around
@ -95,32 +93,6 @@ class Tensor {
return numel;
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the dimension sizes of
// a Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef sizes() const {
int64_t* sizes;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
return IntHeaderOnlyArrayRef(sizes, dim());
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the strides of a
// Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef strides() const {
int64_t* strides;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
return IntHeaderOnlyArrayRef(strides, dim());
}
// note: this is a subset of the original TensorBase API. It takes no
// arguments whereas the original API takes in a kwarg of memory format.
// Here, we assume the default contiguous memory format.

View File

@ -1,8 +1,9 @@
import functools
import math
import operator
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from datetime import timedelta
from typing import Callable
import torch
from torch._C import ScriptObject

View File

@ -10,7 +10,6 @@ from ._context_parallel._attention import (
_enable_context_parallel_dispatcher,
_is_causal_behavior,
_RotateMethod,
_templated_ring_attention,
context_parallel,
context_parallel_unshard,
set_rotate_method,
@ -23,7 +22,6 @@ from ._context_parallel._load_balancer import (
)
# TODO(fegin): add deprecation message once the final interfaces are concluded.
__all__ = [
"_CausalBehavior",
"_context_parallel_shard",
@ -33,7 +31,6 @@ __all__ = [
"_enable_context_parallel_dispatcher",
"_is_causal_behavior",
"_RotateMethod",
"_templated_ring_attention",
"context_parallel",
"context_parallel_unshard",
"set_rotate_method",

View File

@ -443,6 +443,7 @@ class CodeGen:
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@ -647,15 +648,6 @@ class CodeGen:
if verbose:
# override annotation with more detailed information
try:
from torch.distributed.tensor._api import DTensor, DTensorSpec
dtensorspec_format_shard_order_str = (
DTensorSpec.format_shard_order_str
)
except ModuleNotFoundError:
DTensor = None # type: ignore[assignment,misc]
dtensorspec_format_shard_order_str = None
from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.passes.shape_prop import TensorMetadata
@ -686,16 +678,6 @@ class CodeGen:
core = _tensor_annotation(meta_val)
if is_plain:
maybe_type_annotation = f': "{core}"'
elif type(meta_val) is DTensor:
assert dtensorspec_format_shard_order_str is not None
dtensor_meta = dtensorspec_format_shard_order_str(
meta_val._spec.placements, # type: ignore[attr-defined]
meta_val._spec.shard_order, # type: ignore[attr-defined]
)
cls = meta_val.__class__.__name__
maybe_type_annotation = (
f': "{cls}({core}, {dim_green(dtensor_meta)})"'
)
else:
cls = meta_val.__class__.__name__
maybe_type_annotation = f': "{cls}({core})"'
@ -817,6 +799,10 @@ class CodeGen:
return
raise NotImplementedError(f"node: {node.op} {node.target}")
if record_func:
body.append(
"_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n"
)
for i, node in enumerate(nodes):
# NOTE: emit_node does not emit a string with newline. It depends
# on delete_unused_values to append one
@ -826,8 +812,22 @@ class CodeGen:
# node index, which will be deleted later
# after going through _body_transformer
body.append(f"# COUNTER: {i}\n")
do_record = record_func and node.op in (
"call_function",
"call_method",
"call_module",
)
if do_record:
# The double hash ## convention is used by post-processing to find the fx markers
body.append(
f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n"
)
emit_node(node)
delete_unused_values(node)
if do_record:
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
if record_func:
body.append("_rf.__exit__(None, None, None)\n")
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
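Editor's note: the record_func path above wraps the generated body in a graph-level _RecordFunctionFast marker ('## ENTER_GRAPH_PLACEHOLDER_KEY ##') and each call_function/call_method/call_module node in a '## <node index> ##' marker. A short sketch of how the markers can be inspected on this branch; record_func is the keyword added by this diff, and module M is a made-up example.

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = fx.symbolic_trace(M())
# With record_func=True the emitted source contains the graph-level and per-node
# '## ... ##' RecordFunctionFast markers described above.
print(gm.graph.python_code(root_module="self", record_func=True).src)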
@ -1779,6 +1779,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@ -1846,6 +1847,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def _python_code(
@ -1858,6 +1860,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@ -1868,6 +1871,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def __str__(self) -> str:

View File

@ -861,14 +861,18 @@ class {module_name}(torch.nn.Module):
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module="self")
from torch._dynamo import config as dynamo_config
python_code = self._graph.python_code(
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
)
self._code = python_code.src
self._lineno_map = python_code._lineno_map
self._prologue_start = python_code._prologue_start
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
from torch._dynamo import config as dynamo_config
if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
@ -885,7 +889,6 @@ class {module_name}(torch.nn.Module):
# This ensures the same code+metadata always generates the same filename
hash_value = _metadata_hash(self._code, node_metadata)
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
filename = f"{file_stem}.py"
# Only include co_filename to use it directly as the cache key
@ -905,6 +908,13 @@ class {module_name}(torch.nn.Module):
_register_fx_metadata(filename, metadata)
# Replace the placeholder in generated code with actual filename
# The double hash ## convention is used by post-processing to find the fx markers
self._code = self._code.replace(
"torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')",
f"torch._C._profiler._RecordFunctionFast('## {filename} ##')",
)
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
# Determine whether this class explicitly defines a __call__ implementation

View File

@ -4,7 +4,7 @@ import operator
import re
from collections import deque
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Any, Literal, Optional, TYPE_CHECKING
from torch.autograd.profiler import profile
from torch.profiler import DeviceType
@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None:
with profile():
pass
@dataclass
class TimelineEvent:
"""Represents an event in the profiler timeline."""
timestamp: int
event_type: Literal["start", "end", "regular"]
marker_type: Optional[Literal["filename", "node"]]
identifier: Optional[str | int]
event: dict[str, Any]
@dataclass
class ContextStackEntry:
"""Represents a context (filename or node) in the stack."""
context_type: Literal["filename", "node"]
identifier: str | int
metadata: Optional[dict]
tid: Optional[int] = None # Thread ID associated with this context
def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
"""
Maps recorded profiler events to their corresponding fx nodes and adds stack traces.
Builds a timeline of all events (regular ops and FX markers for filenames/nodes),
sorts by timestamp, then processes chronologically while maintaining a context stack of active
filename/node scopes. Regular events are augmented with stack traces and node names from the
innermost active context. Runtime is O(n log n) for n events.
Args:
traced_data: Parsed JSON dict of profiler events from a Chrome trace
Returns:
None. Matching regular events in ``traced_data`` are augmented in place with stack traces and node names
"""
from torch.fx.traceback import _FX_METADATA_REGISTRY
trace_events = traced_data.get("traceEvents", [])
# Create event timeline
event_timeline: list[TimelineEvent] = []
def is_fx_marker_event(event):
return (
event.get("cat") == "cpu_op"
and event.get("name", "").startswith("## ")
and event.get("name", "").endswith(" ##")
)
def append_fx_marker_event(event_type, identifier, event):
start_ts = event["ts"]
end_ts = start_ts + event["dur"]
event_timeline.append(
TimelineEvent(start_ts, "start", event_type, identifier, event)
)
event_timeline.append(
TimelineEvent(end_ts, "end", event_type, identifier, event)
)
for event in trace_events:
if "ts" not in event or "dur" not in event:
continue
if is_fx_marker_event(event):
content = event["name"][3:-3]
if content.endswith(".py"):
append_fx_marker_event("filename", content, event)
else:
try:
node_index = int(content)
except ValueError:
continue
append_fx_marker_event("node", node_index, event)
else:
# Regular event that needs augmentation
start_ts = event["ts"]
event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event))
# Sort by timestamp
event_timeline.sort(key=lambda x: x.timestamp)
# Process events in chronological order with a stack
context_stack: list[ContextStackEntry] = []
# Invariant: every start event has a corresponding end event
for timeline_event in event_timeline:
match timeline_event.event_type:
case "start":
assert timeline_event.identifier is not None
if timeline_event.marker_type == "filename":
assert isinstance(timeline_event.identifier, str)
# Push filename context - query metadata registry on-demand
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
tid = timeline_event.event.get("tid")
context_stack.append(
ContextStackEntry(
"filename", timeline_event.identifier, metadata, tid
)
)
elif timeline_event.marker_type == "node":
# Find the current filename from stack
current_file_metadata = None
tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
if (
ctx_entry.context_type == "filename"
and ctx_entry.tid == tid
):
current_file_metadata = ctx_entry.metadata
break
if current_file_metadata:
node_metadata = current_file_metadata.get("node_metadata", {})
if timeline_event.identifier in node_metadata:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
"node", timeline_event.identifier, node_meta, tid
)
)
case "end":
# Pop from stack - search backwards to find matching context
for i in range(len(context_stack) - 1, -1, -1):
ctx_entry = context_stack[i]
if (
timeline_event.marker_type == ctx_entry.context_type
and timeline_event.identifier == ctx_entry.identifier
):
context_stack.pop(i)
break
case "regular":
# Apply metadata from current context stack
# Find the most specific context (node takes precedence over filename)
# Only augment events with the same tid as the file/node event matched
current_stack_trace = None
current_node_name = None
event_tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
# Only apply metadata from contexts with matching tid
if ctx_entry.tid == event_tid:
if ctx_entry.context_type == "node" and ctx_entry.metadata:
current_stack_trace = ctx_entry.metadata.get(
"stack_trace", "No model stack trace available"
)
current_node_name = ctx_entry.metadata.get("name", "")
# TODO: decide whether to attach only the innermost node's stack trace or the
# stack traces of all enclosing nodes when graph modules are nested
break
# Augment the event
if current_stack_trace or current_node_name:
args = timeline_event.event.setdefault("args", {})
if current_stack_trace:
args["stack_trace"] = current_stack_trace
if current_node_name:
args["node_name"] = current_node_name

View File

@ -210,8 +210,7 @@ class _KinetoProfile:
def start_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.start()
if self.profiler is None:
raise AssertionError("Profiler must be initialized before starting trace")
assert self.profiler is not None
self.profiler._start_trace()
if self.profile_memory:
@ -257,8 +256,7 @@ class _KinetoProfile:
def stop_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.stop()
if self.profiler is None:
raise AssertionError("Profiler must be initialized before stopping trace")
assert self.profiler is not None
self.profiler.__exit__(None, None, None)
def export_chrome_trace(self, path: str):
@ -266,10 +264,7 @@ class _KinetoProfile:
Exports the collected trace in Chrome JSON format. If kineto is enabled, only
the last cycle in the schedule is exported.
"""
if self.profiler is None:
raise AssertionError(
"Profiler must be initialized before exporting chrome trace"
)
assert self.profiler
if path.endswith(".gz"):
fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False)
fp.close()
@ -289,8 +284,7 @@ class _KinetoProfile:
path (str): save stacks file to this location;
metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
"""
if self.profiler is None:
raise AssertionError("Profiler must be initialized before exporting stacks")
assert self.profiler
return self.profiler.export_stacks(path, metric)
def toggle_collection_dynamic(
@ -322,7 +316,7 @@ class _KinetoProfile:
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
"""
if self.profiler is None:
if not self.profiler:
return
self.profiler.toggle_collection_dynamic(enable, activities)
@ -339,10 +333,7 @@ class _KinetoProfile:
To use the shape/stack functionality, make sure to set record_shapes/with_stack
when creating the profiler context manager.
"""
if self.profiler is None:
raise AssertionError(
"Profiler must be initialized before getting key averages"
)
assert self.profiler
return self.profiler.key_averages(
group_by_input_shape, group_by_stack_n, group_by_overload_name
)
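Editor's note: as the docstring above says, record_shapes/with_stack must be requested when the profiler is constructed for the corresponding key_averages grouping to work. A short, self-contained example using the public torch.profiler API:

import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_stack=True) as p:
    torch.mm(torch.randn(64, 64), torch.randn(64, 64))

print(p.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))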
@ -352,8 +343,7 @@ class _KinetoProfile:
Returns the list of unaggregated profiler events,
to be used in the trace callback or after the profiling is finished
"""
if self.profiler is None:
raise AssertionError("Profiler must be initialized before accessing events")
assert self.profiler
return self.profiler.function_events
def add_metadata(self, key: str, value: str) -> None:
@ -405,10 +395,7 @@ class _KinetoProfile:
if missing:
raise ValueError(f"{', '.join(missing)} required for memory profiling.")
if self.profiler is None or self.profiler.kineto_results is None:
raise AssertionError(
"Profiler and kineto_results must be initialized for memory profiling"
)
assert self.profiler is not None and self.profiler.kineto_results is not None
return MemoryProfile(self.profiler.kineto_results)
def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
@ -498,8 +485,7 @@ def schedule(
"""
def schedule_fn(step: int) -> ProfilerAction:
if step < 0:
raise AssertionError(f"Step must be non-negative. Got {step}.")
assert step >= 0
if step < skip_first:
return ProfilerAction.NONE
else:
@ -522,11 +508,9 @@ def schedule(
else ProfilerAction.RECORD_AND_SAVE
)
if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0:
raise AssertionError(
f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), "
f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)."
)
assert (
wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
), "Invalid profiler schedule arguments"
if warmup == 0:
warn(
"Profiler won't be using warmup, this can skew profiler results",
@ -733,8 +717,7 @@ class profile(_KinetoProfile):
activities_set.add(ProfilerActivity.CUDA)
elif ProfilerActivity.CUDA in activities_set:
activities_set.remove(ProfilerActivity.CUDA)
if len(activities_set) == 0:
raise AssertionError("No valid profiler activities found")
assert len(activities_set) > 0, "No valid profiler activities found"
super().__init__(
activities=activities,