Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-06 17:24:59 +08:00

Compare commits
12 Commits
ciflow/tru ... lucaskabel

| SHA1 |
|---|
| 7bbbd49976 |
| ae8a9fa894 |
| 6052a01b71 |
| 14b153bcf2 |
| 641de23c96 |
| 89165c0a2b |
| dcc2ba4ca4 |
| ad5c7c20e0 |
| c86540f120 |
| c17aa0f113 |
| 4ff068c33a |
| 0c7a4a6b48 |
@@ -271,6 +271,16 @@ case "$tag" in
    # from pytorch/llvm:9.0.1 is x86 specific
    SKIP_LLVM_SRC_BUILD_INSTALL=yes
    ;;
  pytorch-linux-jammy-aarch64-py3.10-clang21)
    ANACONDA_PYTHON_VERSION=3.10
    CLANG_VERSION=21
    ACL=yes
    VISION=yes
    OPENBLAS=yes
    # snadampal: skipping llvm src build install because the current version
    # from pytorch/llvm:9.0.1 is x86 specific
    SKIP_LLVM_SRC_BUILD_INSTALL=yes
    ;;
  pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
    ANACONDA_PYTHON_VERSION=3.10
    GCC_VERSION=11
@@ -1 +1 @@
7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7
@@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
  # work around ubuntu apt-get conflicts
  sudo apt-get -y -f install
  wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
  if [[ $CLANG_VERSION == 18 ]]; then
    apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
  if [[ $CLANG_VERSION -ge 18 ]]; then
    apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
  fi
fi
@@ -129,7 +129,7 @@ function install_129 {
}

function install_128 {
  CUDNN_VERSION=9.10.2.21
  CUDNN_VERSION=9.8.0.87
  echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
  # install CUDA 12.8.1 in the same container
  install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux
@@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -

OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="
CC=gcc
NUM_THREADS=128
USE_OPENMP=1
NO_SHARED=0
@@ -1 +1 @@
3.5.0
3.5.1
@@ -272,18 +272,6 @@ def smoke_test_cuda(
        torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
        print(f"Torch cuDNN version: {torch_cudnn_version}")

        torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
        print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
        torch_cudnn_runtime_version = tuple(
            [int(x) for x in torch_cudnn_version.split(".")]
        )
        if torch_cudnn_runtime_version != torch_cudnn_compile_version:
            raise RuntimeError(
                "cuDNN runtime version doesn't match compile version. "
                f"Loaded: {torch_cudnn_runtime_version} "
                f"Expected: {torch_cudnn_compile_version}"
            )

        if sys.platform in ["linux", "linux2"]:
            torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
            print(f"Torch nccl; version: {torch_nccl_version}")
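For reference, the block removed in this hunk compared the cuDNN version loaded at runtime against the version PyTorch was compiled with. A minimal standalone sketch of that kind of check, assuming a CUDA build of PyTorch; the inline integer-to-tuple conversion stands in for the script's `cudnn_to_version_str` helper and assumes the cuDNN 9.x version encoding:

```python
import torch

def check_cudnn_runtime_vs_compile() -> None:
    # Runtime version comes back as one integer, e.g. 91002 -> (9, 10, 2)
    # (major*10000 + minor*100 + patch is the assumed cuDNN 9.x encoding).
    raw = torch.backends.cudnn.version()
    runtime = (raw // 10000, (raw // 100) % 100, raw % 100)

    # Compile-time version is already a (major, minor, patch) tuple.
    compile_time = torch._C._cudnn.getCompileVersion()

    print(f"cuDNN runtime: {runtime}, compile-time: {compile_time}")
    if runtime != compile_time:
        raise RuntimeError(f"cuDNN mismatch: loaded {runtime}, expected {compile_time}")

if __name__ == "__main__":
    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        check_cudnn_runtime_vs_compile()
```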
2 .github/workflows/docker-builds.yml (vendored)
@@ -79,6 +79,8 @@ jobs:
        include:
          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
            runner: linux.arm64.m7g.4xlarge
          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
            runner: linux.arm64.m7g.4xlarge
          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
            runner: linux.arm64.m7g.4xlarge
    timeout-minutes: 600
@@ -22,6 +22,9 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>

#ifdef USE_FBGEMM_GENAI
@@ -666,12 +669,19 @@ std::optional<c10::ScalarType> out_dtype) {
  // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
  // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
  bool use_fast_path = false;
  if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
    use_fast_path = true;
  }
#endif
  const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
  if (use_fast_path) {
    // fast path, no d2h sync needed
#ifndef USE_ROCM
    at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
    at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
  } else {
    _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
  }
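The ROCm comments above note that the non-fast-path branch, `_grouped_mm_fallback`, just issues ordinary mm/bmm calls per group. As an illustration of those grouped-GEMM semantics only (not the actual fallback implementation), here is a hedged Python sketch of the 2D x 3D case, where `offs` holds the cumulative end row of each group in `mat_a`:

```python
import torch

def grouped_mm_reference(mat_a: torch.Tensor, mat_b: torch.Tensor,
                         offs: torch.Tensor) -> torch.Tensor:
    # mat_a: [total_m, k], mat_b: [num_groups, k, n],
    # offs: 1-D int tensor of cumulative row offsets into mat_a, one per group.
    num_groups, _, n = mat_b.shape
    out = mat_a.new_empty(mat_a.size(0), n)
    start = 0
    for g in range(num_groups):
        end = int(offs[g])
        # Each group is a plain matmul over its slice of A's rows.
        out[start:end] = mat_a[start:end] @ mat_b[g]
        start = end
    return out

# Example: 3 groups of 2 rows each.
a = torch.randn(6, 4)
b = torch.randn(3, 4, 5)
offs = torch.tensor([2, 4, 6], dtype=torch.int32)
print(grouped_mm_reference(a, b, offs).shape)  # torch.Size([6, 5])
```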
19 aten/src/ATen/native/hip/ck_group_gemm.h Normal file
@@ -0,0 +1,19 @@
#pragma once

#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>

namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
    const at::Tensor& mat_a,
    const at::Tensor& mat_b,
    const std::optional<at::Tensor>& offs,
    const std::optional<at::Tensor>& bias,
    at::Tensor& out);

} // namespace detail
} // namespace hip
} // namespace at
462 aten/src/ATen/native/hip/ck_group_gemm.hip Normal file
@ -0,0 +1,462 @@
|
||||
#undef __HIP_NO_HALF_CONVERSIONS__
|
||||
#include <ATen/hip/HIPContext.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/TensorAccessor.h>
|
||||
#include <c10/hip/HIPStream.h>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ck/ck.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/utility/tuple.hpp>
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
namespace at {
|
||||
namespace hip {
|
||||
namespace detail {
|
||||
|
||||
namespace CkTypes {
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename DataType>
|
||||
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
|
||||
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
|
||||
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
|
||||
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
|
||||
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
|
||||
3, 8, 8, 1,
|
||||
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
|
||||
3, 8, 8, 1,
|
||||
1, 1,
|
||||
S<1,32,1,8>, 4
|
||||
>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename DataType>
|
||||
void launch_grouped_bgemm_ck_impl_dispatch(
|
||||
const at::Tensor& mat_a,
|
||||
const at::Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
at::Tensor& out)
|
||||
{
|
||||
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
|
||||
using PassThrough = CkTypes::PassThrough;
|
||||
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
std::vector<const void*> p_a_ptrs, p_b_ptrs;
|
||||
std::vector<void*> p_e_ptrs;
|
||||
// Note: d_ptrs will be resized after we populate the other vectors
|
||||
|
||||
const int mat_a_dim = mat_a.dim();
|
||||
const int mat_b_dim = mat_b.dim();
|
||||
|
||||
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
|
||||
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
|
||||
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
|
||||
const size_t a_element_size = mat_a.element_size();
|
||||
const size_t b_element_size = mat_b.element_size();
|
||||
const size_t out_element_size = out.element_size();
|
||||
|
||||
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
|
||||
if (mat_a_dim == 2 && mat_b_dim == 2) {
|
||||
// 2D*2D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
const int M = mat_a.size(0); // number of rows in A
|
||||
const int N = mat_b.size(1); // number of columns in B
|
||||
const int K = mat_a.size(1); // columns in A == rows in B
|
||||
// for 2d*2d input, output is 3d.
|
||||
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
|
||||
int end_k = offs_accessor[i];
|
||||
int k = end_k - start_k;
|
||||
|
||||
// K dimension is sliced, hence select stride(1) always.
|
||||
//K dimension is always dimension 1, regardless of memory layout (row/column major)
|
||||
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
|
||||
const void* group_b_ptr;
|
||||
int ldb;
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
|
||||
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
|
||||
// Leading dimension = distance between rows = stride(0)
|
||||
ldb = mat_b.stride(0);
|
||||
} else {
|
||||
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
|
||||
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
|
||||
// Leading dimension = distance between columns = stride(1)
|
||||
ldb = mat_b.stride(1);
|
||||
}
|
||||
|
||||
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
|
||||
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
|
||||
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
|
||||
int lda, ldc;
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
|
||||
lda = mat_a.stride(0);
|
||||
} else {
|
||||
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
}
|
||||
// Output is always row-major in 3D tensor [num_groups, M, N]
|
||||
// Leading dimension for each group's [M,N] slice = stride(1) = N
|
||||
ldc = out.stride(1);
|
||||
size_t output_group_bytes = M * N * out_element_size;
|
||||
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(k),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
|
||||
// 2D*3D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
|
||||
// 2d*3d input, output is 2d.
|
||||
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
|
||||
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
|
||||
const int K = mat_a.size(1); // columns in A
|
||||
// For 2D-3D case: The output determines N (result width)
|
||||
const int N = out.size(1); // N is the width of the output tensor
|
||||
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
|
||||
int end_m = offs_accessor[i];
|
||||
int m = end_m - start_m;
|
||||
|
||||
// Skip zero-sized groups but continue processing subsequent groups
|
||||
if (m <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Select A rows for group i: skip start_m rows
|
||||
const void* group_a_ptr;
|
||||
int lda;
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
|
||||
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
|
||||
lda = mat_a.stride(0); // distance between rows
|
||||
} else {
|
||||
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
|
||||
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Detect stride pattern for A tensor to determine appropriate lda calculation
|
||||
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
|
||||
|
||||
if (a_is_strided_tensor) {
|
||||
// For strided A tensors: stride(0) gives the actual leading dimension
|
||||
lda = mat_a.stride(0);
|
||||
} else {
|
||||
// For non-strided A tensors: use the M dimension (total rows)
|
||||
lda = mat_a.size(0); // Total M dimension for column-major layout
|
||||
}
|
||||
}
|
||||
|
||||
// Select B batch for group i: B[i, :, :]
|
||||
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
|
||||
int ldb;
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
|
||||
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
|
||||
} else {
|
||||
// Detect stride pattern to determine appropriate ldb calculation
|
||||
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
|
||||
|
||||
if (is_strided_tensor) {
|
||||
// For strided tensors: stride(2) gives the actual leading dimension
|
||||
ldb = mat_b.stride(2);
|
||||
} else {
|
||||
// For non-strided tensors: use the N dimension
|
||||
ldb = mat_b.size(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
|
||||
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
|
||||
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(m),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
|
||||
// 3d*3d input, output is 3d - batched matrix multiplication
|
||||
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
|
||||
// Each batch is processed as a separate GEMM operation
|
||||
const int batch_size = mat_a.size(0);
|
||||
const int M = mat_a.size(1); // rows in each A matrix
|
||||
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
|
||||
|
||||
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
|
||||
int N;
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
N = mat_b.size(2);
|
||||
} else if (mat_b.size(2) == K) {
|
||||
// B is [batch, n, k] - transposed layout
|
||||
N = mat_b.size(1);
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
|
||||
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
|
||||
}
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
// Select A batch for group i: A[i, :, :]
|
||||
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Select B batch for group i: B[i, :, :]
|
||||
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
|
||||
|
||||
// Select output batch for group i: Output[i, :, :]
|
||||
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
|
||||
|
||||
int lda, ldb, ldc;
|
||||
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A: leading dimension = distance between rows = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
} else {
|
||||
// Column-major A: leading dimension = distance between columns = stride(2)
|
||||
lda = mat_a.stride(2);
|
||||
}
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B: leading dimension = distance between rows
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
ldb = mat_b.stride(1); // stride between K rows
|
||||
} else {
|
||||
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
|
||||
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
|
||||
}
|
||||
} else {
|
||||
// Column-major B: leading dimension = distance between columns
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
ldb = mat_b.stride(2); // stride between N columns
|
||||
} else {
|
||||
// B is [batch, n, k] - transposed layout
|
||||
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
|
||||
}
|
||||
}
|
||||
|
||||
// Output is typically row-major: leading dimension = distance between rows = stride(1)
|
||||
ldc = out.stride(1);
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
|
||||
// 3D*2D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
// 3d*2d input, output is 3d.
|
||||
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
|
||||
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
|
||||
const int batch_size = mat_a.size(0); // n_groups
|
||||
const int M = mat_a.size(1); // rows in each A matrix
|
||||
const int K = mat_a.size(2); // columns in A
|
||||
|
||||
// For row-major A and B case: B should be [K, total_N]
|
||||
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
|
||||
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
|
||||
int end_n = offs_accessor[i];
|
||||
int n = end_n - start_n;
|
||||
|
||||
// Skip zero-sized groups but continue processing subsequent groups
|
||||
if (n <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Select A batch for group i: A[i, :, :]
|
||||
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
|
||||
const void* group_b_ptr;
|
||||
int ldb;
|
||||
|
||||
// Check if B is row-major or column-major
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B [K, total_N]: slice columns [start_n:end_n]
|
||||
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
|
||||
ldb = mat_b.stride(0); // distance between rows (should be total_N)
|
||||
} else {
|
||||
// Column-major B [K, total_N]: slice columns [start_n:end_n]
|
||||
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
|
||||
ldb = mat_b.stride(1); // distance between columns (should be K)
|
||||
}
|
||||
|
||||
// Select output slice for group i: Output[:, start_n:end_n]
|
||||
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
|
||||
|
||||
int lda, ldc;
|
||||
|
||||
// Row-major A: leading dimension = distance between rows = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
// Output is row-major: leading dimension = distance between rows = stride(0)
|
||||
ldc = out.stride(0);
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(n),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
|
||||
|
||||
// Initialize d_ptrs with the correct size
|
||||
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
|
||||
|
||||
static DeviceOp gemm_instance;
|
||||
auto argument = gemm_instance.MakeArgument(
|
||||
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
|
||||
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
|
||||
);
|
||||
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
|
||||
"CK Group GEMM: argument unsupported (shape/strides/type config)");
|
||||
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
|
||||
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
|
||||
|
||||
void* gemm_arg_buf = nullptr;
|
||||
void* ws_buf = nullptr;
|
||||
|
||||
hipMalloc(&gemm_arg_buf, arg_buf_size);
|
||||
hipMalloc(&ws_buf, ws_size);
|
||||
|
||||
gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
|
||||
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
|
||||
|
||||
auto invoker = gemm_instance.MakeInvoker();
|
||||
hipStream_t stream = c10::hip::getCurrentHIPStream();
|
||||
invoker.Run(argument, {stream});
|
||||
hipFree(gemm_arg_buf);
|
||||
hipFree(ws_buf);
|
||||
}
|
||||
|
||||
void group_gemm_ck(
|
||||
const at::Tensor& input_a,
|
||||
const at::Tensor& input_b_colmajor,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& /*bias*/,
|
||||
at::Tensor& out)
|
||||
{
|
||||
// Detect if input_a is row-major based on stride pattern
|
||||
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
|
||||
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
|
||||
// Ensure tensor A is row-major and contiguous if not already
|
||||
at::Tensor mat_a = input_a;
|
||||
if (!a_row_major) {
|
||||
// If A is not row-major, make it contiguous (row-major)
|
||||
mat_a = input_a.contiguous();
|
||||
}
|
||||
// Force tensor B to be column-major using double transpose trick
|
||||
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
|
||||
at::Tensor mat_b = input_b_colmajor;
|
||||
if (!b_col_major) {
|
||||
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
|
||||
}
|
||||
|
||||
// For 3D tensors, check the last dimension stride for row-major detection
|
||||
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
|
||||
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
|
||||
|
||||
if (mat_a.dtype() == at::kBFloat16) {
|
||||
// bf16 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else if (mat_a.dtype() == at::kHalf) {
|
||||
// fp16 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else if (mat_a.dtype() == at::kFloat) {
|
||||
// fp32 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace hip
|
||||
} // namespace at
|
||||
@@ -5,8 +5,16 @@ import contextlib
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_tensor,
    DTensor,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
@@ -426,6 +434,31 @@ class TestDTensorDebugMode(TestCase):
        ][-1]
        self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)

    def test_pretty_print_dtensor_make_fx(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        A = torch.randn(8, 32)
        B = torch.randn(32, 32)
        dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
        dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()

        def f(dA, dB):
            dy = dA @ dB
            loss = dy.sum()
            loss.backward()
            return dA.grad, dB.grad

        # We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
        # make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
        # so we won't stash our DTensors properly if they don't hold Fake inner tensors
        gm = make_fx(f, tracing_mode="fake")(dA, dB)
        # DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
        gm.graph.eliminate_dead_code()
        gm.recompile()
        # Colored is nice for actual viewing, not using in this test though
        gm_str = gm.print_readable(colored=False, print_output=False)
        self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)


instantiate_parametrized_tests(TestDTensorDebugMode)
@@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None
@@ -14424,6 +14424,20 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar

        self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),))

    @skip_if_halide
    @requires_cuda_and_triton
    def test_unbacked_float_item(self):
        def fn(x, max_val):
            return torch.clamp(x, 0, max_val.item())

        self.common(
            fn,
            (
                torch.randn(10, 20, 30, device=self.device),
                torch.tensor(5.0, device=self.device),
            ),
        )

    # end of class CommonTemplate - add new tests here
176 test/test_as_strided.py Normal file
@ -0,0 +1,176 @@
|
||||
# Owner(s): ["oncall: pt2"]
|
||||
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
"""Extract (sizes, strides) tuple from a tensor."""
|
||||
return (tuple(t.size()), tuple(t.stride()))
|
||||
|
||||
|
||||
def enumerate_reachable_states(
|
||||
initial_size: int,
|
||||
) -> set[tuple[tuple[int, ...], tuple[int, ...]]]:
|
||||
"""
|
||||
Use BFS with DP to enumerate all reachable (size, stride) states from
|
||||
a 1D contiguous tensor via valid view operations.
|
||||
|
||||
We only explore states with offset=0 (you can retroactively change the offset).
|
||||
We reject states with size=0 or size=1 dimensions as they are degenerate.
|
||||
"""
|
||||
# Create initial 1D contiguous tensor
|
||||
initial_tensor = torch.arange(initial_size)
|
||||
|
||||
initial_state = get_state(initial_tensor)
|
||||
|
||||
# Map from state to tensor for that state
|
||||
state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = {
|
||||
initial_state: initial_tensor
|
||||
}
|
||||
visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state}
|
||||
queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state])
|
||||
|
||||
while queue:
|
||||
state = queue.popleft()
|
||||
t = state_to_tensor[state]
|
||||
sizes, strides = state
|
||||
ndim = len(sizes)
|
||||
|
||||
def add_state(new_t: torch.Tensor) -> None:
|
||||
new_state = get_state(new_t)
|
||||
sizes, strides = new_state
|
||||
# Skip if has size-0 or size-1 dimensions
|
||||
if any(s == 0 or s == 1 for s in sizes):
|
||||
return
|
||||
# Only accept states where strides are in descending order
|
||||
if list(strides) != sorted(strides, reverse=True):
|
||||
return
|
||||
if new_state not in visited:
|
||||
visited.add(new_state)
|
||||
queue.append(new_state)
|
||||
state_to_tensor[new_state] = new_t
|
||||
|
||||
# 1. Unflatten: try factoring each dimension
|
||||
for dim in range(ndim):
|
||||
size = sizes[dim]
|
||||
assert size > 1
|
||||
# Try all factorizations x * y = size where both x, y >= 2
|
||||
# We only need to check x up to size // 2 since when x > size // 2,
|
||||
# y = size // x < 2, which we reject
|
||||
for x in range(2, size // 2 + 1):
|
||||
if size % x == 0:
|
||||
y = size // x
|
||||
add_state(t.unflatten(dim, (x, y)))
|
||||
|
||||
# 2. Slice: exhaustively check all possible slicing parameters
|
||||
for dim in range(ndim):
|
||||
size = sizes[dim]
|
||||
for start in range(size):
|
||||
for stop in range(start + 1, size + 1):
|
||||
for step in range(1, size + 1):
|
||||
slices = [slice(None)] * ndim
|
||||
slices[dim] = slice(start, stop, step)
|
||||
add_state(t[tuple(slices)])
|
||||
|
||||
# 3. Flatten: merge adjacent dimensions
|
||||
for dim in range(ndim - 1):
|
||||
add_state(t.flatten(dim, dim + 1))
|
||||
|
||||
return visited
|
||||
|
||||
|
||||
class TestAsStrided(TestCase):
|
||||
def test_size_10_exhaustive(self) -> None:
|
||||
"""Test that size 10 produces exactly the expected 54 states."""
|
||||
expected_states = {
|
||||
((2,), (1,)),
|
||||
((2,), (2,)),
|
||||
((2,), (3,)),
|
||||
((2,), (4,)),
|
||||
((2,), (5,)),
|
||||
((2,), (6,)),
|
||||
((2,), (7,)),
|
||||
((2,), (8,)),
|
||||
((2,), (9,)),
|
||||
((2, 2), (2, 1)),
|
||||
((2, 2), (3, 1)),
|
||||
((2, 2), (3, 2)),
|
||||
((2, 2), (4, 1)),
|
||||
((2, 2), (4, 2)),
|
||||
((2, 2), (4, 3)),
|
||||
((2, 2), (5, 1)),
|
||||
((2, 2), (5, 2)),
|
||||
((2, 2), (5, 3)),
|
||||
((2, 2), (5, 4)),
|
||||
((2, 2), (6, 1)),
|
||||
((2, 2), (6, 2)),
|
||||
((2, 2), (6, 3)),
|
||||
((2, 2), (8, 1)),
|
||||
((2, 2, 2), (4, 2, 1)),
|
||||
((2, 2, 2), (5, 2, 1)),
|
||||
((2, 3), (3, 1)),
|
||||
((2, 3), (4, 1)),
|
||||
((2, 3), (5, 1)),
|
||||
((2, 3), (5, 2)),
|
||||
((2, 3), (6, 1)),
|
||||
((2, 4), (4, 1)),
|
||||
((2, 4), (5, 1)),
|
||||
((2, 5), (5, 1)),
|
||||
((3,), (1,)),
|
||||
((3,), (2,)),
|
||||
((3,), (3,)),
|
||||
((3,), (4,)),
|
||||
((3, 2), (2, 1)),
|
||||
((3, 2), (3, 1)),
|
||||
((3, 2), (3, 2)),
|
||||
((3, 2), (4, 1)),
|
||||
((3, 3), (3, 1)),
|
||||
((4,), (1,)),
|
||||
((4,), (2,)),
|
||||
((4,), (3,)),
|
||||
((4, 2), (2, 1)),
|
||||
((5,), (1,)),
|
||||
((5,), (2,)),
|
||||
((5, 2), (2, 1)),
|
||||
((6,), (1,)),
|
||||
((7,), (1,)),
|
||||
((8,), (1,)),
|
||||
((9,), (1,)),
|
||||
((10,), (1,)),
|
||||
}
|
||||
|
||||
actual_states = enumerate_reachable_states(10)
|
||||
|
||||
self.assertEqual(len(actual_states), 54)
|
||||
self.assertEqual(actual_states, expected_states)
|
||||
|
||||
def test_subset_property(self) -> None:
|
||||
"""
|
||||
Test that for sizes 2..10, each smaller tensor results in a strict
|
||||
subset of possible states compared to the next one.
|
||||
"""
|
||||
prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None
|
||||
for size in range(2, 11):
|
||||
current_states = enumerate_reachable_states(size)
|
||||
|
||||
if prev_states is not None:
|
||||
# Check that prev_states is a strict subset of current_states
|
||||
self.assertTrue(
|
||||
prev_states.issubset(current_states),
|
||||
f"States from size {size - 1} are not a subset of size {size}",
|
||||
)
|
||||
# Check that it's a strict subset (not equal)
|
||||
self.assertTrue(
|
||||
len(prev_states) < len(current_states),
|
||||
f"States from size {size - 1} should be strictly fewer than size {size}",
|
||||
)
|
||||
|
||||
prev_states = current_states
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
180 test/test_fx.py
@ -75,12 +75,6 @@ from torch.testing._internal.common_utils import (
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
|
||||
from torch.autograd.profiler_util import _canonicalize_profiler_events
|
||||
|
||||
try:
|
||||
from torchvision import models as torchvision_models
|
||||
|
||||
@ -207,36 +201,6 @@ def side_effect_func(x: torch.Tensor):
|
||||
print(x)
|
||||
|
||||
|
||||
def _enrich_profiler_traces(prof):
|
||||
"""
|
||||
Helper function to extract and augment profiler events with stack traces.
|
||||
|
||||
Args:
|
||||
prof: A torch.profiler.profile object
|
||||
|
||||
Returns:
|
||||
A string representing enriched events
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
|
||||
trace_file = f.name
|
||||
prof.export_chrome_trace(trace_file)
|
||||
|
||||
with open(trace_file) as f:
|
||||
trace_data = json.load(f)
|
||||
|
||||
map_recorded_events_to_aten_ops_with_stack_trace(
|
||||
trace_data
|
||||
)
|
||||
|
||||
events = []
|
||||
for event in trace_data["traceEvents"]:
|
||||
if "args" in event and "stack_trace" in event["args"]:
|
||||
events.append(event)
|
||||
|
||||
actual_traces = _canonicalize_profiler_events(events)
|
||||
return actual_traces
|
||||
|
||||
|
||||
class TestFX(JitTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -4248,150 +4212,6 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
|
||||
# recover mutable checking flag
|
||||
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_stack_trace_augmentation(self):
|
||||
"""
|
||||
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
|
||||
augments profiler events with stack traces from FX metadata registry.
|
||||
"""
|
||||
|
||||
# Simple test model
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(10, 16)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.linear2 = torch.nn.Linear(16, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.relu(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
model = TestModel().cuda()
|
||||
|
||||
# Compile the model
|
||||
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
|
||||
|
||||
# Warmup
|
||||
for _ in range(3):
|
||||
_ = compiled_model(torch.randn(10, 10, device="cuda"))
|
||||
|
||||
# Profile with the compiled model
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
) as prof:
|
||||
result = compiled_model(torch.randn(10, 10, device="cuda"))
|
||||
|
||||
actual_traces = _enrich_profiler_traces(prof)
|
||||
|
||||
self.assertExpectedInline(actual_traces, """\
|
||||
event=aten::t node=t stack_trace=x = self.linear1(x)
|
||||
event=aten::transpose node=t stack_trace=x = self.linear1(x)
|
||||
event=aten::as_strided node=t stack_trace=x = self.linear1(x)
|
||||
event=aten::addmm node=addmm stack_trace=x = self.linear1(x)
|
||||
event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x)
|
||||
event=aten::relu node=relu stack_trace=x = self.relu(x)
|
||||
event=aten::clamp_min node=relu stack_trace=x = self.relu(x)
|
||||
event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x)
|
||||
event=aten::t node=t_1 stack_trace=x = self.linear2(x)
|
||||
event=aten::transpose node=t_1 stack_trace=x = self.linear2(x)
|
||||
event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x)
|
||||
event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x)
|
||||
event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_multiple_modules(self):
|
||||
"""
|
||||
Test that multiple compiled modules under the same profiler session
|
||||
have their events correctly augmented with stack traces.
|
||||
"""
|
||||
|
||||
class ModelA(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 1
|
||||
|
||||
class ModelB(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x - 1
|
||||
|
||||
model_a = ModelA().cuda()
|
||||
model_b = ModelB().cuda()
|
||||
|
||||
# Compile both models
|
||||
compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True)
|
||||
compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True)
|
||||
|
||||
# Warmup
|
||||
for _ in range(3):
|
||||
_ = compiled_a(torch.randn(10, 10, device="cuda"))
|
||||
_ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
|
||||
|
||||
# Profile both models in the same session
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
) as prof:
|
||||
result_a = compiled_a(torch.randn(10, 10, device="cuda"))
|
||||
result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
|
||||
|
||||
actual_traces = _enrich_profiler_traces(prof)
|
||||
self.assertExpectedInline(actual_traces, """\
|
||||
event=aten::add node=add stack_trace=return x + 1
|
||||
event=cudaLaunchKernel node=add stack_trace=return x + 1
|
||||
event=aten::sub node=sub stack_trace=return x - 1
|
||||
event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_nested_graph_modules(self):
|
||||
"""
|
||||
Test that nested graph modules (e.g., graph modules calling subgraphs)
|
||||
have their events correctly augmented with stack traces.
|
||||
"""
|
||||
|
||||
# Model with nested structure
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.c = 5
|
||||
|
||||
@torch.compiler.nested_compile_region
|
||||
def forward(self, x, y):
|
||||
m = torch.mul(x, y)
|
||||
s = m.sin()
|
||||
a = s + self.c
|
||||
return a
|
||||
|
||||
model = Mod().cuda()
|
||||
|
||||
# Compile the model (this may create nested graph modules)
|
||||
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
|
||||
|
||||
# Warmup
|
||||
for _ in range(3):
|
||||
_ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
|
||||
|
||||
# Profile
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
) as prof:
|
||||
result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
|
||||
|
||||
actual_traces = _enrich_profiler_traces(prof)
|
||||
self.assertExpectedInline(actual_traces, """\
|
||||
event=aten::mul node=mul stack_trace=m = torch.mul(x, y)
|
||||
event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y)
|
||||
event=aten::sin node=sin stack_trace=s = m.sin()
|
||||
event=cudaLaunchKernel node=sin stack_trace=s = m.sin()
|
||||
event=aten::add node=add stack_trace=a = s + self.c
|
||||
event=cudaLaunchKernel node=add stack_trace=a = s + self.c"""
|
||||
)
|
||||
|
||||
|
||||
def run_getitem_target():
|
||||
from torch.fx._symbolic_trace import _wrapped_methods_to_patch
|
||||
|
||||
@@ -490,8 +490,6 @@ class TestMatmulCuda(InductorTestCase):
    @parametrize("b_row_major", [False, True])
    @dtypes(torch.bfloat16, torch.float32, torch.float16)
    def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
        if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
            self.skipTest("failed using hipblaslt on rocm 6.4.2")
        device = "cuda"
        s_int = int(strided)
        m, n, k, n_groups = 16, 32, 64, 4
@@ -3320,7 +3320,7 @@ class InstructionTranslatorBase(
        obj = self.stack[-inst.arg]
        assert isinstance(obj, SetVariable)
        assert obj.is_mutable()
        obj.call_method(self, "add", [v], {})
        obj.call_method(self, "add", [v], {})  # type: ignore[arg-type]

    def SET_UPDATE(self, inst: Instruction) -> None:
        v = self.pop()
@@ -3329,7 +3329,7 @@ class InstructionTranslatorBase(
        obj = self.stack[-inst.arg]
        assert isinstance(obj, SetVariable)
        assert obj.is_mutable()
        obj.call_method(self, "update", [v], {})
        obj.call_method(self, "update", [v], {})  # type: ignore[arg-type]

    def LIST_APPEND(self, inst: Instruction) -> None:
        v = self.pop()
@@ -3637,7 +3637,7 @@ class InstructionTranslatorBase(
        obj = self.stack[-inst.arg].realize()
        assert isinstance(obj, ConstDictVariable)
        assert obj.is_mutable()
        obj.call_method(self, "update", [v], {})
        obj.call_method(self, "update", [v], {})  # type: ignore[arg-type]

    DICT_UPDATE = DICT_MERGE
@@ -1991,7 +1991,7 @@ class BuiltinVariable(VariableTracker):
            # If the object implements a __getitem__ method, iter(...) will call obj.__getitem__()
            # with an integer argument starting at 0, until __getitem__ raises IndexError
            ret = variables.UserFunctionVariable(
                polyfills.builtins.iter_
                polyfills.builtins.iter_  # type: ignore[arg-type]
            ).call_function(tx, [obj, *args], {})

            if args:
@ -1,5 +1,3 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
"""
|
||||
Dictionary-related variable tracking classes for PyTorch Dynamo.
|
||||
|
||||
@ -26,7 +24,7 @@ import inspect
|
||||
import operator
|
||||
import types
|
||||
from collections.abc import Hashable as py_Hashable
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch._subclasses.fake_tensor import is_fake
|
||||
|
||||
@ -59,11 +57,13 @@ if TYPE_CHECKING:
|
||||
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
|
||||
|
||||
|
||||
def was_instancecheck_override(obj):
|
||||
def was_instancecheck_override(obj: Any) -> bool:
|
||||
return type(obj).__dict__.get("__instancecheck__", False)
|
||||
|
||||
|
||||
def raise_unhashable(arg, tx=None):
|
||||
def raise_unhashable(
|
||||
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
|
||||
) -> None:
|
||||
if tx is None:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
@ -75,7 +75,7 @@ def raise_unhashable(arg, tx=None):
|
||||
)
|
||||
|
||||
|
||||
def is_hashable(x):
|
||||
def is_hashable(x: VariableTracker) -> bool:
|
||||
# NB - performing isinstance check on a LazyVT realizes the VT, accidentally
|
||||
# inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at
|
||||
# the underlying value without realizing the VT. Consider updating the
|
||||
@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
|
||||
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
|
||||
"""
|
||||
|
||||
def __init__(self, vt) -> None:
|
||||
def __init__(self, vt: VariableTracker) -> None:
|
||||
# We specialize SymNodes
|
||||
vt = specialize_symnode(vt)
|
||||
# TODO Temporarily remove to figure out what keys are we breaking on
|
||||
@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
|
||||
self.vt = vt
|
||||
|
||||
@property
|
||||
def underlying_value(self):
|
||||
def underlying_value(self) -> Any:
|
||||
if (
|
||||
isinstance(self.vt, variables.LazyVariableTracker)
|
||||
and not self.vt.is_realized()
|
||||
@ -178,7 +178,8 @@ class ConstDictVariable(VariableTracker):
|
||||
elif isinstance(self.vt, variables.FrozenDataClassVariable):
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
fields_values = {
|
||||
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
|
||||
k: Hashable(v).underlying_value
|
||||
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
|
||||
}
|
||||
return variables.FrozenDataClassVariable.HashWrapper(
|
||||
self.vt.python_type(), fields_values
|
||||
@ -187,16 +188,16 @@ class ConstDictVariable(VariableTracker):
|
||||
# The re module in Python 3.13+ has a dictionary (_cache2) with
|
||||
# an object as key (`class _ZeroSentinel(int): ...`):
|
||||
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
|
||||
return self.vt.value
|
||||
return self.vt.value # type: ignore[attr-defined,union-attr]
|
||||
else:
|
||||
x = self.vt.as_python_constant()
|
||||
return x
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.underlying_value)
|
||||
|
||||
@staticmethod
|
||||
def _eq_impl(a, b):
|
||||
def _eq_impl(a: Any, b: Any) -> bool:
|
||||
# TODO: Put this in utils and share it between variables/builtin.py and here
|
||||
type_a, type_b = type(a), type(b)
|
||||
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
|
||||
@ -212,7 +213,7 @@ class ConstDictVariable(VariableTracker):
|
||||
else:
|
||||
return a == b
|
||||
|
||||
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
|
||||
type(other)
|
||||
@ -226,8 +227,8 @@ class ConstDictVariable(VariableTracker):
|
||||
def __init__(
|
||||
self,
|
||||
items: dict[VariableTracker, VariableTracker],
|
||||
user_cls=dict,
|
||||
**kwargs,
|
||||
user_cls: type = dict,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# .clone() pass these arguments in kwargs but they're recreated a few
|
||||
# lines below
|
||||
@ -247,18 +248,22 @@ class ConstDictVariable(VariableTracker):
|
||||
for x, v in items.items()
|
||||
)
|
||||
|
||||
def make_hashable(key):
|
||||
def make_hashable(
|
||||
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
|
||||
) -> "ConstDictVariable._HashableTracker":
|
||||
return key if isinstance(key, Hashable) else Hashable(key)
|
||||
|
||||
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
|
||||
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
|
||||
# need to reconstruct everything if the dictionary is an intermediate value
|
||||
# or if a pop/delitem was executed
|
||||
self.should_reconstruct_all = not is_from_local_source(self.source)
|
||||
self.should_reconstruct_all = (
|
||||
not is_from_local_source(self.source) if self.source else True
|
||||
)
|
||||
self.original_items = items.copy()
|
||||
self.user_cls = user_cls
|
||||
|
||||
def _get_dict_cls_from_user_cls(self, user_cls):
|
||||
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
|
||||
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
|
||||
|
||||
# avoid executing user code if user_cls is a dict subclass
|
||||
@ -277,10 +282,10 @@ class ConstDictVariable(VariableTracker):
|
||||
dict_cls = dict
|
||||
return dict_cls
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> dict[Any, Any]:
|
||||
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
return (
|
||||
"{"
|
||||
+ ", ".join(
|
||||
@ -289,20 +294,20 @@ class ConstDictVariable(VariableTracker):
|
||||
+ "}"
|
||||
)
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> dict[Any, Any]:
|
||||
return {
|
||||
k.vt.as_python_constant(): v.as_python_constant()
|
||||
for k, v in self.items.items()
|
||||
}
|
||||
|
||||
def keys_as_python_constant(self):
|
||||
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
|
||||
self.install_dict_keys_match_guard()
|
||||
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return self.user_cls
|
||||
|
||||
def __contains__(self, vt) -> bool:
|
||||
def __contains__(self, vt: VariableTracker) -> bool:
|
||||
assert isinstance(vt, VariableTracker)
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
return (
|
||||
@ -322,13 +327,15 @@ class ConstDictVariable(VariableTracker):
|
||||
for key, value in self.items.items()
|
||||
)
|
||||
|
||||
def is_new_item(self, value, other):
|
||||
def is_new_item(
|
||||
self, value: Optional[VariableTracker], other: VariableTracker
|
||||
) -> bool:
|
||||
# compare the id of the realized values if both values are not lazy VTs
|
||||
if value and value.is_realized() and other.is_realized():
|
||||
return id(value.realize()) != id(other.realize())
|
||||
return id(value) != id(other)
|
||||
|
||||
def reconstruct_kvs_into_new_dict(self, codegen):
|
||||
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
|
||||
# Build a dictionary that contains the keys and values.
|
||||
num_args = 0
|
||||
for key, value in self.items.items():
|
||||
@ -340,7 +347,7 @@ class ConstDictVariable(VariableTracker):
|
||||
num_args += 1
|
||||
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
if self.user_cls is collections.OrderedDict:
|
||||
# emit `OrderedDict(constructed_dict)`
|
||||
codegen.add_push_null(
|
||||
@ -358,19 +365,21 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
def getitem_const_raise_exception_if_absent(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
):
|
||||
) -> VariableTracker:
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
raise_observed_exception(KeyError, tx)
|
||||
return self.items[key]
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
msg = f"Dictionary key {arg.value} not found during tracing"
|
||||
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
|
||||
unimplemented_v2(
|
||||
gb_type="key not found in dict",
|
||||
context=f"Key {arg.value}",
|
||||
context=f"Key {arg.value}", # type: ignore[attr-defined]
|
||||
explanation=msg,
|
||||
hints=[
|
||||
"Check if the key exists in the dictionary before accessing it.",
|
||||
@ -379,13 +388,13 @@ class ConstDictVariable(VariableTracker):
)
return self.items[key]

def maybe_getitem_const(self, arg: VariableTracker):
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
return None
return self.items[key]

def realize_key_vt(self, arg: VariableTracker):
def realize_key_vt(self, arg: VariableTracker) -> None:
# Realize the LazyVT on a particular index
assert arg in self
key = ConstDictVariable._HashableTracker(arg)
@ -394,11 +403,13 @@ class ConstDictVariable(VariableTracker):
if isinstance(original_key_vt, variables.LazyVariableTracker):
original_key_vt.realize()

def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
if self.source:
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))

def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
# Key guarding - These are the cases to consider
# 1) The dict has been mutated. In this case, we would have already
# inserted a DICT_KEYS_MATCH guard, so we can skip.
@ -439,11 +450,11 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
|
||||
# we have to insert guards when a dict method is accessed. For this to
|
||||
# be simple, we are conservative and overguard. We skip guard only for
|
||||
@ -462,7 +473,7 @@ class ConstDictVariable(VariableTracker):
|
||||
tx, *args, **kwargs
|
||||
)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.update(temp_dict_vt.items)
|
||||
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "__getitem__":
|
||||
# Key guarding - Nothing to do. LazyVT for value will take care.
|
||||
@ -526,7 +537,7 @@ class ConstDictVariable(VariableTracker):
|
||||
return ConstantVariable.create(len(self.items))
|
||||
elif name == "__setitem__" and self.is_mutable():
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
if kwargs or len(args) != 2:
|
||||
@ -550,7 +561,7 @@ class ConstDictVariable(VariableTracker):
|
||||
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
if args[0] not in self:
|
||||
self.install_dict_contains_guard(tx, args)
|
||||
@ -565,7 +576,7 @@ class ConstDictVariable(VariableTracker):
|
||||
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
if args[0] not in self:
|
||||
# missing item, return the default value. Install no DICT_CONTAINS guard.
|
||||
@ -599,7 +610,7 @@ class ConstDictVariable(VariableTracker):
|
||||
last = v.value
|
||||
else:
|
||||
raise_args_mismatch(tx, name)
|
||||
k, v = self.items.popitem(last=last)
|
||||
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
|
||||
else:
|
||||
k, v = self.items.popitem()
|
||||
|
||||
@ -632,17 +643,17 @@ class ConstDictVariable(VariableTracker):
|
||||
# NB - Guard on all the keys of the other dict to ensure
|
||||
# correctness.
|
||||
args[0].install_dict_keys_match_guard()
|
||||
dict_vt = args[0]
|
||||
dict_vt: ConstDictVariable = args[0]
|
||||
else:
|
||||
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
|
||||
self.items.update(dict_vt.items)
|
||||
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
|
||||
self.items.update(dict_vt.items) # type: ignore[attr-defined]
|
||||
if has_kwargs:
|
||||
# Handle kwargs
|
||||
kwargs = {
|
||||
kwargs_hashable = {
|
||||
Hashable(ConstantVariable.create(k)): v
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
self.items.update(kwargs)
|
||||
self.items.update(kwargs_hashable)
|
||||
return ConstantVariable.create(None)
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
@ -656,7 +667,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_contains_guard(tx, args)
|
||||
contains = args[0] in self
|
||||
@ -671,7 +682,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
if kwargs or len(args) > 2:
|
||||
@ -707,7 +718,7 @@ class ConstDictVariable(VariableTracker):
|
||||
and "last" in kwargs
|
||||
and isinstance(kwargs["last"], ConstantVariable)
|
||||
):
|
||||
last = kwargs.get("last").value
|
||||
last = kwargs.get("last").value # type: ignore[union-attr]
|
||||
|
||||
key = Hashable(args[0])
|
||||
self.items.move_to_end(key, last=last)
|
||||
@ -723,7 +734,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
elif name == "__ne__":
|
||||
return ConstantVariable.create(
|
||||
not self.call_method(tx, "__eq__", args, kwargs).value
|
||||
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
|
||||
)
|
||||
elif name == "__or__":
|
||||
if len(args) != 1:
|
||||
@ -750,14 +761,14 @@ class ConstDictVariable(VariableTracker):
|
||||
if not istype(
|
||||
other, (ConstDictVariable, variables.UserDefinedDictVariable)
|
||||
):
|
||||
msg = (
|
||||
err_msg = (
|
||||
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
|
||||
f"and '{other.python_type().__name__}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
raise_observed_exception(TypeError, tx, args=[err_msg])
|
||||
|
||||
# OrderedDict overloads __ror__
|
||||
ts = {self.user_cls, other.user_cls}
|
||||
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
|
||||
user_cls = (
|
||||
collections.OrderedDict
|
||||
if any(issubclass(t, collections.OrderedDict) for t in ts)
|
||||
@ -774,8 +785,8 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
# NB - Guard on all the keys of the other dict to ensure
|
||||
# correctness.
|
||||
args[0].install_dict_keys_match_guard()
|
||||
new_dict_vt.items.update(args[0].items)
|
||||
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
|
||||
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
|
||||
return new_dict_vt
|
||||
elif name == "__ior__":
|
||||
self.call_method(tx, "update", args, kwargs)
|
||||
@ -789,11 +800,13 @@ class ConstDictVariable(VariableTracker):
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
self.install_dict_keys_match_guard()
|
||||
return [x.vt for x in self.items.keys()]
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
# dict not allow setting arbitrary attributes. OrderedDict and
|
||||
# defaultdict allow arbitrary setattr, but not deletion of default attrs
|
||||
if any(
|
||||
@ -816,25 +829,25 @@ class ConstDictVariable(VariableTracker):
|
||||
],
|
||||
)
|
||||
|
||||
def clone(self, **kwargs):
|
||||
def clone(self, **kwargs: Any) -> VariableTracker:
|
||||
self.install_dict_keys_match_guard()
|
||||
return super().clone(**kwargs)
|
||||
|
||||
|
||||
class MappingProxyVariable(VariableTracker):
|
||||
# proxies to the original dict_vt
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(dv_dict, ConstDictVariable)
|
||||
self.dv_dict = dv_dict
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return types.MappingProxyType
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
return self.dv_dict.unpack_var_sequence(tx)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# load types.MappingProxyType
|
||||
if self.source:
|
||||
msg = (
|
||||
@ -863,11 +876,11 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if self.source and tx.output.side_effects.has_existing_dict_mutation():
|
||||
msg = (
|
||||
"A dict has been modified while we have an existing mappingproxy object. "
|
||||
@ -892,7 +905,7 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if self.python_type() is types.MappingProxyType:
|
||||
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
@ -900,35 +913,44 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
class NNModuleHooksDictVariable(ConstDictVariable):
|
||||
# Special class to avoid adding any guards on the nn module hook ids.
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DefaultDictVariable(ConstDictVariable):
|
||||
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
items: dict[VariableTracker, VariableTracker],
|
||||
user_cls: type,
|
||||
default_factory: Optional[VariableTracker] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, user_cls, **kwargs)
|
||||
assert user_cls is collections.defaultdict
|
||||
if default_factory is None:
|
||||
default_factory = ConstantVariable.create(None)
|
||||
self.default_factory = default_factory
|
||||
|
||||
def is_python_constant(self):
|
||||
def is_python_constant(self) -> bool:
|
||||
# Return false for unsupported defaults. This ensures that a bad handler
|
||||
# path is not taken in BuiltinVariable for getitem.
|
||||
if self.default_factory not in [list, tuple, dict] and not self.items:
|
||||
return False
|
||||
return super().is_python_constant()
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
assert self.default_factory is not None
|
||||
return (
|
||||
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_supported_arg(arg):
|
||||
def is_supported_arg(arg: VariableTracker) -> bool:
|
||||
if isinstance(arg, variables.BuiltinVariable):
|
||||
return arg.fn in (list, tuple, dict, set)
|
||||
else:
|
||||
@ -942,11 +964,11 @@ class DefaultDictVariable(ConstDictVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__getitem__":
if len(args) != 1:
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
@ -962,13 +984,13 @@ class DefaultDictVariable(ConstDictVariable):
else:
default_var = self.default_factory.call_function(tx, [], {})
super().call_method(
tx, "__setitem__", (args[0], default_var), kwargs
tx, "__setitem__", [args[0], default_var], kwargs
)
return default_var
else:
return super().call_method(tx, name, args, kwargs)

def reconstruct(self, codegen):
def reconstruct(self, codegen: "PyCodegen") -> None:
# emit `defaultdict(default_factory, new_dict)`
codegen.add_push_null(
lambda: codegen.extend_output(
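
For context (a sketch, not part of this change): the __getitem__ branch above is what gives collections.defaultdict its eager semantics under torch.compile — a missing key calls a supported default_factory (list/tuple/dict/set, per is_supported_arg) and then __setitem__. Hypothetical usage:

```python
import collections
import torch

@torch.compile(fullgraph=True)
def bucket_by_parity(xs):
    buckets = collections.defaultdict(list)  # supported factory
    for i, x in enumerate(xs):
        buckets[i % 2].append(x * 2)  # missing key -> default_factory() then __setitem__
    return buckets[0], buckets[1]

even, odd = bucket_by_parity([torch.randn(3) for _ in range(4)])
```
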
@ -994,40 +1016,48 @@ class SetVariable(ConstDictVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# pyrefly: ignore[bad-assignment]
|
||||
items = dict.fromkeys(items, SetVariable._default_value())
|
||||
# pyrefly: ignore[bad-argument-type]
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "set()"
|
||||
else:
|
||||
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
|
||||
return set(self.items.keys())
|
||||
|
||||
@staticmethod
|
||||
def _default_value():
|
||||
def _default_value() -> VariableTracker:
|
||||
# Variable to fill in the keys of the dictionary
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> Any:
|
||||
return {k.vt.as_proxy() for k in self.set_items}
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return set
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return {k.vt.as_python_constant() for k in self.set_items}
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
|
||||
|
||||
def _fast_set_method(self, tx, fn, args, kwargs):
|
||||
def _fast_set_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
fn: Any,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
try:
|
||||
res = fn(
|
||||
*[x.as_python_constant() for x in [self, *args]],
|
||||
@ -1037,15 +1067,16 @@ class SetVariable(ConstDictVariable):
|
||||
raise_observed_exception(
|
||||
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
|
||||
)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
return VariableTracker.build(tx, res)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
# We forward the calls to the dictionary model
|
||||
from ..utils import check_constant_args
|
||||
|
||||
@ -1065,10 +1096,10 @@ class SetVariable(ConstDictVariable):
|
||||
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
|
||||
|
||||
if name == "__init__":
|
||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
|
||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.clear()
|
||||
self.items.update(temp_set_vt.items)
|
||||
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "add":
|
||||
if kwargs or len(args) != 1:
|
||||
@ -1079,7 +1110,7 @@ class SetVariable(ConstDictVariable):
|
||||
f"{len(args)} args and {len(kwargs)} kwargs",
|
||||
)
|
||||
name = "__setitem__"
|
||||
args = (args[0], SetVariable._default_value())
|
||||
args = [args[0], SetVariable._default_value()]
|
||||
elif name == "pop":
|
||||
if kwargs or args:
|
||||
raise_args_mismatch(
|
||||
@ -1090,12 +1121,14 @@ class SetVariable(ConstDictVariable):
|
||||
)
|
||||
# Choose an item at random and pop it via the Dict.pop method
|
||||
try:
|
||||
result = self.set_items.pop().vt
|
||||
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
|
||||
except KeyError as e:
|
||||
raise_observed_exception(
|
||||
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
|
||||
)
|
||||
super().call_method(tx, name, (result,), kwargs)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
super().call_method(tx, name, [result], kwargs)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
return result
|
||||
elif name == "isdisjoint":
|
||||
if kwargs or len(args) != 1:
|
||||
@ -1217,6 +1250,7 @@ class SetVariable(ConstDictVariable):
|
||||
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
assert m is not None
|
||||
return self.call_method(tx, m, args, kwargs)
|
||||
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
@ -1230,29 +1264,34 @@ class SetVariable(ConstDictVariable):
|
||||
"__ixor__": "symmetric_difference_update",
|
||||
"__isub__": "difference_update",
|
||||
}.get(name)
|
||||
assert m is not None
|
||||
self.call_method(tx, m, args, kwargs)
|
||||
return self
|
||||
elif name == "__eq__":
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
return ConstantVariable.create(False)
|
||||
r = self.call_method(tx, "symmetric_difference", args, kwargs)
|
||||
return ConstantVariable.create(len(r.set_items) == 0)
|
||||
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
raise RuntimeError("Illegal to getitem on a set")
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
super().install_dict_contains_guard(tx, args)
|
||||
|
||||
|
||||
@ -1260,27 +1299,27 @@ class FrozensetVariable(SetVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "frozenset()"
|
||||
else:
|
||||
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
|
||||
return self.items.keys()
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return frozenset
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return frozenset({k.vt.as_python_constant() for k in self.set_items})
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
@ -1293,11 +1332,11 @@ class FrozensetVariable(SetVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
|
||||
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
|
||||
elif name == "__init__":
|
||||
@ -1316,7 +1355,7 @@ class FrozensetVariable(SetVariable):
|
||||
"symmetric_difference",
|
||||
):
|
||||
r = super().call_method(tx, name, args, kwargs)
|
||||
return FrozensetVariable(r.items)
|
||||
return FrozensetVariable(r.items) # type: ignore[attr-defined]
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
@ -1324,11 +1363,11 @@ class DictKeySetVariable(SetVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "dict_keys([])"
|
||||
else:
|
||||
@ -1338,33 +1377,35 @@ class DictKeySetVariable(SetVariable):
|
||||
+ "])"
|
||||
)
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> Any:
|
||||
return self.items
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_keys
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return dict.fromkeys(
|
||||
{k.vt.as_python_constant() for k in self.set_items}, None
|
||||
).keys()
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
|
||||
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
@ -1379,42 +1420,47 @@ class DictViewVariable(VariableTracker):
|
||||
|
||||
kv: Optional[str] = None
|
||||
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert self.kv in ("keys", "values", "items")
|
||||
assert isinstance(dv_dict, ConstDictVariable)
|
||||
self.dv_dict = dv_dict
|
||||
|
||||
@property
|
||||
def view_items(self):
|
||||
def view_items(self) -> Any:
|
||||
assert self.kv is not None
|
||||
return getattr(self.dv_dict.items, self.kv)()
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
# Implement in the subclasses
|
||||
raise NotImplementedError
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
return self.view_items_vt
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
assert self.kv is not None
|
||||
codegen(self.dv_dict)
|
||||
codegen.load_method(self.kv)
|
||||
codegen.call_method(0)
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
assert self.kv is not None
|
||||
if name in self.python_type().__dict__:
|
||||
return ConstantVariable.create(True)
|
||||
return ConstantVariable.create(False)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__len__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
elif name == "__iter__":
|
||||
@ -1428,24 +1474,24 @@ class DictKeysVariable(DictViewVariable):
|
||||
kv = "keys"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set[VariableTracker]:
|
||||
return set(self.view_items)
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
return [x.vt for x in self.view_items]
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_keys
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__contains__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
elif name in (
|
||||
@ -1460,13 +1506,13 @@ class DictKeysVariable(DictViewVariable):
|
||||
):
|
||||
# These methods always returns a set
|
||||
m = getattr(self.set_items, name)
|
||||
r = m(args[0].set_items)
|
||||
r = m(args[0].set_items) # type: ignore[attr-defined]
|
||||
return SetVariable(r)
|
||||
if name in cmp_name_to_op_mapping:
|
||||
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
@ -1476,10 +1522,10 @@ class DictValuesVariable(DictViewVariable):
|
||||
kv = "values"
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
return list(self.view_items)
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_values
|
||||
|
||||
|
||||
@ -1487,14 +1533,20 @@ class DictItemsVariable(DictViewVariable):
|
||||
kv = "items"
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_items
|
||||
|
||||
def call_method(self, tx, name, args, kwargs):
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
# TODO(guilhermeleobas): This should actually check if args[0]
|
||||
# implements the mapping protocol.
|
||||
if name == "__eq__":
File diff suppressed because it is too large
@ -586,7 +586,7 @@ class FilterVariable(IteratorVariable):
else:
res = self.fn.call_function(tx, [item], {})
pred_res = variables.UserFunctionVariable(
polyfills.predicate
polyfills.predicate # type: ignore[arg-type]
).call_function(tx, [res], {})
if pred_res.as_python_constant():
return item

@ -472,7 +472,12 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
)
elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined]
name_to_arg_map = bind_args_cached(
self.value, tx, self.source, args, kwargs
# pyrefly: ignore[bad-argument-type]
self.value,
tx,
self.source,
args,
kwargs,
)
backends = name_to_arg_map["backends"].as_python_constant()
set_priority = name_to_arg_map["set_priority"].as_python_constant()
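
For reference, the context manager whose backends/set_priority arguments are bound above is used like the following minimal sketch (illustrative only; treating the list order as a preference order under set_priority=True is an assumption based on the argument name):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 4, 8, 16)

# Restrict scaled_dot_product_attention to these backends inside the block;
# with set_priority=True the list order is treated as the preference order.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH], set_priority=True):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```
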
@ -1349,7 +1354,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
packed_input_vt = TupleVariable.build(
|
||||
tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
|
||||
)
|
||||
out_vt = variables.UserFunctionVariable(tree_flatten).call_function(
|
||||
out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type]
|
||||
tx, [packed_input_vt], {}
|
||||
)
|
||||
assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2
|
||||
|
||||
@ -2970,6 +2970,12 @@ class CppPythonBindingsCodeCache(CppCodeCache):
|
||||
throw std::runtime_error("expected int arg");
|
||||
return reinterpret_cast<uintptr_t>(result);
|
||||
}}
|
||||
template <> inline float parse_arg<float>(PyObject* args, size_t n) {{
|
||||
auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
|
||||
if(unlikely(result == -1.0 && PyErr_Occurred()))
|
||||
throw std::runtime_error("expected float arg");
|
||||
return static_cast<float>(result);
|
||||
}}
|
||||
|
||||
{extra_parse_arg}
|
||||
|
||||
|
||||
@ -1732,9 +1732,15 @@ class KernelArgs:
|
||||
call_args.append(self.wrap_ptr_arg(outer, dtype))
|
||||
arg_types.append(f"{cpp_dtype}*")
|
||||
for outer, inner in self.sizevars.items():
|
||||
arg_defs.append(f"const {INDEX_TYPE} {inner}")
|
||||
if isinstance(outer, sympy.Symbol) and symbol_is_type(
|
||||
outer, (SymT.UNBACKED_FLOAT)
|
||||
):
|
||||
arg_defs.append(f"const float {inner}")
|
||||
arg_types.append("const float")
|
||||
else:
|
||||
arg_defs.append(f"const {INDEX_TYPE} {inner}")
|
||||
arg_types.append(f"const {INDEX_TYPE}")
|
||||
call_args.append(self.wrap_size_arg(outer))
|
||||
arg_types.append(f"const {INDEX_TYPE}")
|
||||
if V.graph.wrapper_code:
|
||||
V.graph.wrapper_code.ensure_size_computed(outer)
|
||||
assert not self.workspace_args, "Workspace not supported on CPU "
|
||||
@ -2353,6 +2359,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
||||
SymT.UNBACKED_INT,
|
||||
SymT.SIZE,
|
||||
SymT.PRECOMPUTED_SIZE,
|
||||
SymT.UNBACKED_FLOAT,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Optional
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
|
||||
from .. import config
|
||||
from ..runtime.hints import AttrsDescriptorWrapper
|
||||
@ -71,6 +72,10 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
|
||||
return "constexpr"
|
||||
elif isinstance(arg.expr, (float, sympy.Float)):
|
||||
return "fp32"
|
||||
elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type(
|
||||
arg.expr, (SymT.UNBACKED_FLOAT)
|
||||
):
|
||||
return "fp32"
|
||||
elif isinstance(arg.expr, bool):
|
||||
return "i1"
|
||||
|
||||
|
||||
@ -1224,43 +1224,3 @@ def _build_table(
|
||||
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
|
||||
)
|
||||
return "".join(result)
|
||||
|
||||
|
||||
# Collect all events with stack traces and format them canonically
|
||||
def _canonicalize_profiler_events(events):
|
||||
"""
|
||||
Extract and format all events with stack traces in a canonical way
|
||||
for deterministic testing.
|
||||
"""
|
||||
events_with_traces = []
|
||||
|
||||
for event in events:
|
||||
# Extract relevant fields
|
||||
event_name = event.get("name", "")
|
||||
node_name = event["args"].get("node_name", "")
|
||||
stack_trace = event["args"].get("stack_trace", "")
|
||||
|
||||
# Get the last non-empty line of the stack trace
|
||||
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
|
||||
stack_trace = lines[-1] if lines else ""
|
||||
|
||||
events_with_traces.append(
|
||||
{
|
||||
"event_name": event_name[:20],
|
||||
"node_name": node_name,
|
||||
"stack_trace": stack_trace,
|
||||
"start_time": event.get("ts", 0),
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by node_name for deterministic ordering
|
||||
events_with_traces.sort(key=lambda x: x["start_time"])
|
||||
|
||||
# Format as a string
|
||||
lines: list[str] = []
|
||||
for evt in events_with_traces:
|
||||
lines.append(
|
||||
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
@ -443,7 +443,6 @@ class CodeGen:
|
||||
colored: bool = False,
|
||||
# Render each argument on its own line
|
||||
expanded_def: bool = False,
|
||||
record_func: bool = False,
|
||||
) -> PythonCode:
|
||||
free_vars: list[str] = []
|
||||
body: list[str] = []
|
||||
@ -648,6 +647,15 @@ class CodeGen:
|
||||
|
||||
if verbose:
|
||||
# override annotation with more detailed information
|
||||
try:
|
||||
from torch.distributed.tensor._api import DTensor, DTensorSpec
|
||||
|
||||
dtensorspec_format_shard_order_str = (
|
||||
DTensorSpec.format_shard_order_str
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
DTensor = None # type: ignore[assignment,misc]
|
||||
dtensorspec_format_shard_order_str = None
|
||||
from torch.fx.experimental.proxy_tensor import py_sym_types
|
||||
from torch.fx.passes.shape_prop import TensorMetadata
|
||||
|
||||
@ -678,6 +686,16 @@ class CodeGen:
|
||||
core = _tensor_annotation(meta_val)
|
||||
if is_plain:
|
||||
maybe_type_annotation = f': "{core}"'
|
||||
elif type(meta_val) is DTensor:
|
||||
assert dtensorspec_format_shard_order_str is not None
|
||||
dtensor_meta = dtensorspec_format_shard_order_str(
|
||||
meta_val._spec.placements, # type: ignore[attr-defined]
|
||||
meta_val._spec.shard_order, # type: ignore[attr-defined]
|
||||
)
|
||||
cls = meta_val.__class__.__name__
|
||||
maybe_type_annotation = (
|
||||
f': "{cls}({core}, {dim_green(dtensor_meta)})"'
|
||||
)
|
||||
else:
|
||||
cls = meta_val.__class__.__name__
|
||||
maybe_type_annotation = f': "{cls}({core})"'
|
||||
@ -799,10 +817,6 @@ class CodeGen:
|
||||
return
|
||||
raise NotImplementedError(f"node: {node.op} {node.target}")
|
||||
|
||||
if record_func:
|
||||
body.append(
|
||||
"_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n"
|
||||
)
|
||||
for i, node in enumerate(nodes):
|
||||
# NOTE: emit_node does not emit a string with newline. It depends
|
||||
# on delete_unused_values to append one
|
||||
@ -812,22 +826,8 @@ class CodeGen:
|
||||
# node index, which will be deleted later
|
||||
# after going through _body_transformer
|
||||
body.append(f"# COUNTER: {i}\n")
|
||||
do_record = record_func and node.op in (
|
||||
"call_function",
|
||||
"call_method",
|
||||
"call_module",
|
||||
)
|
||||
if do_record:
|
||||
# The double hash ## convention is used by post-processing to find the fx markers
|
||||
body.append(
|
||||
f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n"
|
||||
)
|
||||
emit_node(node)
|
||||
delete_unused_values(node)
|
||||
if do_record:
|
||||
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
|
||||
if record_func:
|
||||
body.append("_rf.__exit__(None, None, None)\n")
|
||||
|
||||
if len(body) == 0:
|
||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||
@ -1779,7 +1779,6 @@ class Graph:
|
||||
include_device: bool = False,
|
||||
colored: bool = False,
|
||||
expanded_def: bool = False,
|
||||
record_func: bool = False,
|
||||
) -> PythonCode:
|
||||
"""
|
||||
Turn this ``Graph`` into valid Python code.
|
||||
@ -1847,7 +1846,6 @@ class Graph:
|
||||
include_device=include_device,
|
||||
colored=colored,
|
||||
expanded_def=expanded_def,
|
||||
record_func=record_func,
|
||||
)
|
||||
|
||||
def _python_code(
|
||||
@ -1860,7 +1858,6 @@ class Graph:
|
||||
include_device: bool = False,
|
||||
colored: bool = False,
|
||||
expanded_def: bool = False,
|
||||
record_func: bool = False,
|
||||
) -> PythonCode:
|
||||
return self._codegen._gen_python_code(
|
||||
self.nodes,
|
||||
@ -1871,7 +1868,6 @@ class Graph:
|
||||
include_device=include_device,
|
||||
colored=colored,
|
||||
expanded_def=expanded_def,
|
||||
record_func=record_func,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@ -861,18 +861,14 @@ class {module_name}(torch.nn.Module):
|
||||
if isinstance(self._graph._codegen, _PyTreeCodeGen):
|
||||
self._in_spec = self._graph._codegen.pytree_info.in_spec
|
||||
self._out_spec = self._graph._codegen.pytree_info.out_spec
|
||||
|
||||
from torch._dynamo import config as dynamo_config
|
||||
|
||||
python_code = self._graph.python_code(
|
||||
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
|
||||
)
|
||||
python_code = self._graph.python_code(root_module="self")
|
||||
self._code = python_code.src
|
||||
self._lineno_map = python_code._lineno_map
|
||||
self._prologue_start = python_code._prologue_start
|
||||
|
||||
cls = type(self)
|
||||
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
|
||||
from torch._dynamo import config as dynamo_config
|
||||
|
||||
if dynamo_config.enrich_profiler_metadata:
|
||||
# Generate metadata and register for profiler augmentation
|
||||
@ -889,6 +885,7 @@ class {module_name}(torch.nn.Module):
|
||||
# This ensures the same code+metadata always generates the same filename
|
||||
hash_value = _metadata_hash(self._code, node_metadata)
|
||||
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
|
||||
|
||||
filename = f"{file_stem}.py"
|
||||
|
||||
# Only include co_filename to use it directly as the cache key
|
||||
@ -908,13 +905,6 @@ class {module_name}(torch.nn.Module):
|
||||
|
||||
_register_fx_metadata(filename, metadata)
|
||||
|
||||
# Replace the placeholder in generated code with actual filename
|
||||
# The double hash ## convention is used by post-processing to find the fx markers
|
||||
self._code = self._code.replace(
|
||||
"torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')",
|
||||
f"torch._C._profiler._RecordFunctionFast('## {filename} ##')",
|
||||
)
|
||||
|
||||
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
|
||||
|
||||
# Determine whether this class explicitly defines a __call__ implementation
|
||||
|
||||
@ -4,7 +4,7 @@ import operator
|
||||
import re
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torch.autograd.profiler import profile
|
||||
from torch.profiler import DeviceType
|
||||
@ -400,170 +400,3 @@ def _init_for_cuda_graphs() -> None:
|
||||
|
||||
with profile():
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimelineEvent:
|
||||
"""Represents an event in the profiler timeline."""
|
||||
|
||||
timestamp: int
|
||||
event_type: Literal["start", "end", "regular"]
|
||||
marker_type: Optional[Literal["filename", "node"]]
|
||||
identifier: Optional[str | int]
|
||||
event: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextStackEntry:
|
||||
"""Represents a context (filename or node) in the stack."""
|
||||
|
||||
context_type: Literal["filename", "node"]
|
||||
identifier: str | int
|
||||
metadata: Optional[dict]
|
||||
tid: Optional[int] = None # Thread ID associated with this context
|
||||
|
||||
|
||||
def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
|
||||
"""
|
||||
Maps recorded profiler events to their corresponding fx nodes and adds stack traces.
|
||||
|
||||
Builds a timeline of all events (regular ops and FX markers for filenames/nodes),
|
||||
sorts by timestamp, then processes chronologically while maintaining a context stack of active
|
||||
filename/node scopes. Regular events are augmented with stack traces and node names from the
|
||||
innermost active context. Runtime is O(n log n) for n events.
|
||||
|
||||
Args:
|
||||
traced_data: Json of profiler events from Chrome trace
|
||||
|
||||
Returns:
|
||||
Dict mapping recorded event names to their aten operations with added stack traces
|
||||
"""
|
||||
from torch.fx.traceback import _FX_METADATA_REGISTRY
|
||||
|
||||
trace_events = traced_data.get("traceEvents", [])
|
||||
|
||||
# Create event timeline
|
||||
event_timeline: list[TimelineEvent] = []
|
||||
|
||||
def is_fx_marker_event(event):
|
||||
return (
|
||||
event.get("cat") == "cpu_op"
|
||||
and event.get("name", "").startswith("## ")
|
||||
and event.get("name", "").endswith(" ##")
|
||||
)
|
||||
|
||||
def append_fx_marker_event(event_type, identifier, event):
|
||||
start_ts = event["ts"]
|
||||
end_ts = start_ts + event["dur"]
|
||||
event_timeline.append(
|
||||
TimelineEvent(start_ts, "start", event_type, identifier, event)
|
||||
)
|
||||
event_timeline.append(
|
||||
TimelineEvent(end_ts, "end", event_type, identifier, event)
|
||||
)
|
||||
|
||||
for event in trace_events:
|
||||
if "ts" not in event or "dur" not in event:
|
||||
continue
|
||||
|
||||
if is_fx_marker_event(event):
|
||||
content = event["name"][3:-3]
|
||||
|
||||
if content.endswith(".py"):
|
||||
append_fx_marker_event("filename", content, event)
|
||||
else:
|
||||
try:
|
||||
node_index = int(content)
|
||||
except ValueError:
|
||||
pass
|
||||
append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined]
|
||||
|
||||
else:
|
||||
# Regular event that needs augmentation
|
||||
start_ts = event["ts"]
|
||||
event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event))
|
||||
|
||||
# Sort by timestamp
|
||||
event_timeline.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# Process events in chronological order with a stack
|
||||
context_stack: list[ContextStackEntry] = []
|
||||
|
||||
# Invariant: all start event has a corresponding end event
|
||||
for timeline_event in event_timeline:
|
||||
match timeline_event.event_type:
|
||||
case "start":
|
||||
assert timeline_event.identifier is not None
|
||||
|
||||
if timeline_event.marker_type == "filename":
|
||||
assert isinstance(timeline_event.identifier, str)
|
||||
# Push filename context - query metadata registry on-demand
|
||||
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
|
||||
tid = timeline_event.event.get("tid")
|
||||
context_stack.append(
|
||||
ContextStackEntry(
|
||||
"filename", timeline_event.identifier, metadata, tid
|
||||
)
|
||||
)
|
||||
elif timeline_event.marker_type == "node":
|
||||
# Find the current filename from stack
|
||||
current_file_metadata = None
|
||||
tid = timeline_event.event.get("tid")
|
||||
for ctx_entry in reversed(context_stack):
|
||||
if (
|
||||
ctx_entry.context_type == "filename"
|
||||
and ctx_entry.tid == tid
|
||||
):
|
||||
current_file_metadata = ctx_entry.metadata
|
||||
break
|
||||
|
||||
if current_file_metadata:
|
||||
node_metadata = current_file_metadata.get("node_metadata", {})
|
||||
if timeline_event.identifier in node_metadata:
|
||||
node_meta: Optional[dict] = node_metadata[
|
||||
timeline_event.identifier
|
||||
]
|
||||
context_stack.append(
|
||||
ContextStackEntry(
|
||||
"node", timeline_event.identifier, node_meta, tid
|
||||
)
|
||||
)
|
||||
|
||||
case "end":
|
||||
# Pop from stack - search backwards to find matching context
|
||||
for i in range(len(context_stack) - 1, -1, -1):
|
||||
ctx_entry = context_stack[i]
|
||||
if (
|
||||
timeline_event.marker_type == ctx_entry.context_type
|
||||
and timeline_event.identifier == ctx_entry.identifier
|
||||
):
|
||||
context_stack.pop(i)
|
||||
break
|
||||
|
||||
case "regular":
|
||||
# Apply metadata from current context stack
|
||||
# Find the most specific context (node takes precedence over filename)
|
||||
# Only augment events with the same tid as the file/node event matched
|
||||
current_stack_trace = None
|
||||
current_node_name = None
|
||||
event_tid = timeline_event.event.get("tid")
|
||||
|
||||
for ctx_entry in reversed(context_stack):
|
||||
# Only apply metadata from contexts with matching tid
|
||||
if ctx_entry.tid == event_tid:
|
||||
if ctx_entry.context_type == "node" and ctx_entry.metadata:
|
||||
current_stack_trace = ctx_entry.metadata.get(
|
||||
"stack_trace", "No model stack trace available"
|
||||
)
|
||||
current_node_name = ctx_entry.metadata.get("name", "")
|
||||
# Do we want to only attach the stack trace of the lowest node or stack trace of all nodes
|
||||
# if nodes are nested, e.g. in nested graph modules
|
||||
break
|
||||
|
||||
# Augment the event
|
||||
if current_stack_trace or current_node_name:
|
||||
args = timeline_event.event.setdefault("args", {})
|
||||
if current_stack_trace:
|
||||
args["stack_trace"] = current_stack_trace
|
||||
if current_node_name:
|
||||
args["node_name"] = current_node_name
|
||||
|
||||
@ -210,7 +210,8 @@ class _KinetoProfile:
|
||||
def start_trace(self) -> None:
|
||||
if self.execution_trace_observer:
|
||||
self.execution_trace_observer.start()
|
||||
assert self.profiler is not None
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before starting trace")
|
||||
self.profiler._start_trace()
|
||||
|
||||
if self.profile_memory:
|
||||
@ -256,7 +257,8 @@ class _KinetoProfile:
|
||||
def stop_trace(self) -> None:
|
||||
if self.execution_trace_observer:
|
||||
self.execution_trace_observer.stop()
|
||||
assert self.profiler is not None
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before stopping trace")
|
||||
self.profiler.__exit__(None, None, None)

def export_chrome_trace(self, path: str):
@ -264,7 +266,10 @@ class _KinetoProfile:
Exports the collected trace in Chrome JSON format. If kineto is enabled, only
last cycle in schedule is exported.
"""
assert self.profiler
if self.profiler is None:
raise AssertionError(
"Profiler must be initialized before exporting chrome trace"
)
if path.endswith(".gz"):
fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False)
fp.close()
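
Illustrative usage of the export path above (a sketch, not part of the diff): export_chrome_trace accepts a plain .json path or a .json.gz path, and with this change an uninitialized profiler raises AssertionError with a message instead of failing a bare assert.

```python
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.randn(128, 128) @ torch.randn(128, 128)

prof.export_chrome_trace("trace.json.gz")  # gzip-compressed Chrome trace
```
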
@ -284,7 +289,8 @@ class _KinetoProfile:
|
||||
path (str): save stacks file to this location;
|
||||
metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
|
||||
"""
|
||||
assert self.profiler
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before exporting stacks")
|
||||
return self.profiler.export_stacks(path, metric)
|
||||
|
||||
def toggle_collection_dynamic(
|
||||
@ -316,7 +322,7 @@ class _KinetoProfile:
|
||||
print(p.key_averages().table(
|
||||
sort_by="self_cuda_time_total", row_limit=-1))
|
||||
"""
|
||||
if not self.profiler:
|
||||
if self.profiler is None:
|
||||
return
|
||||
self.profiler.toggle_collection_dynamic(enable, activities)
|
||||
|
||||
@ -333,7 +339,10 @@ class _KinetoProfile:
|
||||
To use shape/stack functionality make sure to set record_shapes/with_stack
|
||||
when creating profiler context manager.
|
||||
"""
|
||||
assert self.profiler
|
||||
if self.profiler is None:
|
||||
raise AssertionError(
|
||||
"Profiler must be initialized before getting key averages"
|
||||
)
|
||||
return self.profiler.key_averages(
|
||||
group_by_input_shape, group_by_stack_n, group_by_overload_name
|
||||
)
|
||||
@ -343,7 +352,8 @@ class _KinetoProfile:
|
||||
Returns the list of unaggregated profiler events,
|
||||
to be used in the trace callback or after the profiling is finished
|
||||
"""
|
||||
assert self.profiler
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before accessing events")
|
||||
return self.profiler.function_events
|
||||
|
||||
def add_metadata(self, key: str, value: str) -> None:
|
||||
@ -395,7 +405,10 @@ class _KinetoProfile:
|
||||
if missing:
|
||||
raise ValueError(f"{', '.join(missing)} required for memory profiling.")
|
||||
|
||||
assert self.profiler is not None and self.profiler.kineto_results is not None
|
||||
if self.profiler is None or self.profiler.kineto_results is None:
|
||||
raise AssertionError(
|
||||
"Profiler and kineto_results must be initialized for memory profiling"
|
||||
)
|
||||
return MemoryProfile(self.profiler.kineto_results)
|
||||
|
||||
def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
|
||||
@ -485,7 +498,8 @@ def schedule(
"""

def schedule_fn(step: int) -> ProfilerAction:
assert step >= 0
if step < 0:
raise AssertionError(f"Step must be non-negative. Got {step}.")
if step < skip_first:
return ProfilerAction.NONE
else:
@ -508,9 +522,11 @@ def schedule(
else ProfilerAction.RECORD_AND_SAVE
)

assert (
wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
), "Invalid profiler schedule arguments"
if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0:
raise AssertionError(
f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), "
f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)."
)
if warmup == 0:
warn(
"Profiler won't be using warmup, this can skew profiler results",
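
As a usage reminder for the constraints now enforced above (wait/warmup/repeat/skip_first must be non-negative, active strictly positive, and warmup=0 only warns), a minimal sketch:

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Each cycle: skip 1 step, warm up for 1, record 3; repeat the cycle twice.
sched = schedule(wait=1, warmup=1, active=3, repeat=2, skip_first=0)

with profile(activities=[ProfilerActivity.CPU], schedule=sched) as prof:
    for _ in range(10):
        torch.randn(64, 64) @ torch.randn(64, 64)
        prof.step()
```
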
@ -717,7 +733,8 @@ class profile(_KinetoProfile):
|
||||
activities_set.add(ProfilerActivity.CUDA)
|
||||
elif ProfilerActivity.CUDA in activities_set:
|
||||
activities_set.remove(ProfilerActivity.CUDA)
|
||||
assert len(activities_set) > 0, "No valid profiler activities found"
|
||||
if len(activities_set) == 0:
|
||||
raise AssertionError("No valid profiler activities found")
|
||||
|
||||
super().__init__(
|
||||
activities=activities,