Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-07 18:04:58 +08:00

Compare commits
20 Commits
ciflow/tru...
lucaskabel
| Author | SHA1 | Date |
|---|---|---|
|  | 0f2f1f7e3e |  |
|  | ae8a9fa894 |  |
|  | 6052a01b71 |  |
|  | 14b153bcf2 |  |
|  | 641de23c96 |  |
|  | 89165c0a2b |  |
|  | dcc2ba4ca4 |  |
|  | ad5c7c20e0 |  |
|  | c86540f120 |  |
|  | c17aa0f113 |  |
|  | 4ff068c33a |  |
|  | 0c7a4a6b48 |  |
|  | f93ee16fb6 |  |
|  | 9c2c3dbc15 |  |
|  | d4dcd0354c |  |
|  | aba2fa3259 |  |
|  | d2d13bf62d |  |
|  | 7a6ff88196 |  |
|  | 59563dfe56 |  |
|  | 5c639466f7 |  |
@@ -271,6 +271,16 @@ case "$tag" in
    # from pytorch/llvm:9.0.1 is x86 specific
    SKIP_LLVM_SRC_BUILD_INSTALL=yes
    ;;
  pytorch-linux-jammy-aarch64-py3.10-clang21)
    ANACONDA_PYTHON_VERSION=3.10
    CLANG_VERSION=21
    ACL=yes
    VISION=yes
    OPENBLAS=yes
    # snadampal: skipping llvm src build install because the current version
    # from pytorch/llvm:9.0.1 is x86 specific
    SKIP_LLVM_SRC_BUILD_INSTALL=yes
    ;;
  pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
    ANACONDA_PYTHON_VERSION=3.10
    GCC_VERSION=11
@@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
  # work around ubuntu apt-get conflicts
  sudo apt-get -y -f install
  wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
  if [[ $CLANG_VERSION == 18 ]]; then
    apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
  if [[ $CLANG_VERSION -ge 18 ]]; then
    apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
  fi
fi
@@ -129,7 +129,7 @@ function install_129 {
}

function install_128 {
  CUDNN_VERSION=9.10.2.21
  CUDNN_VERSION=9.8.0.87
  echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
  # install CUDA 12.8.1 in the same container
  install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux
@@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -

OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="
CC=gcc
NUM_THREADS=128
USE_OPENMP=1
NO_SHARED=0
@@ -272,18 +272,6 @@ def smoke_test_cuda(
        torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
        print(f"Torch cuDNN version: {torch_cudnn_version}")

        torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
        print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
        torch_cudnn_runtime_version = tuple(
            [int(x) for x in torch_cudnn_version.split(".")]
        )
        if torch_cudnn_runtime_version != torch_cudnn_compile_version:
            raise RuntimeError(
                "cuDNN runtime version doesn't match comple version. "
                f"Loaded: {torch_cudnn_runtime_version} "
                f"Expected: {torch_cudnn_compile_version}"
            )

        if sys.platform in ["linux", "linux2"]:
            torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
            print(f"Torch nccl; version: {torch_nccl_version}")
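For reference, a minimal sketch of the check the removed block performed (illustration only, not the smoke-test code itself; assumes a CUDA-enabled torch build and cuDNN 9.x, whose version integer is encoded as major*10000 + minor*100 + patch):

```python
import torch

def check_cudnn_versions() -> None:
    # cuDNN version torch was compiled against, as a (major, minor, patch) tuple.
    compiled = torch._C._cudnn.getCompileVersion()
    # cuDNN version actually loaded at runtime, reported as a single integer.
    raw = torch.backends.cudnn.version()
    runtime = (raw // 10000, (raw // 100) % 100, raw % 100)
    if runtime != compiled:
        raise RuntimeError(
            f"cuDNN runtime {runtime} does not match compile-time {compiled}"
        )
```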
@@ -337,7 +337,7 @@ test_python() {

test_python_smoke() {
  # Smoke tests for H100/B200
  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
  assert_git_not_dirty
}
2  .github/workflows/docker-builds.yml (vendored)
@@ -79,6 +79,8 @@ jobs:
        include:
          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
            runner: linux.arm64.m7g.4xlarge
          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
            runner: linux.arm64.m7g.4xlarge
          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
            runner: linux.arm64.m7g.4xlarge
    timeout-minutes: 600
1  .gitignore (vendored)
@@ -127,7 +127,6 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
@@ -191,7 +191,7 @@ class Vectorized<BFloat16> {
    auto vals = svreinterpret_u16_bf16(values);
    vals = sveor_u16_x(ptrue, vals, mask);
    return svreinterpret_bf16_u16(vals);
  };
  }
  Vectorized<BFloat16> round() const;
  Vectorized<BFloat16> tan() const;
  Vectorized<BFloat16> tanh() const;
@@ -349,47 +349,47 @@ Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
  return convert_float_bfloat16(v1, v2); \
}

DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
DEFINE_BF16_FUNC_VIA_FLOAT(angle);
DEFINE_BF16_FUNC_VIA_FLOAT(acos);
DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
DEFINE_BF16_FUNC_VIA_FLOAT(asin);
DEFINE_BF16_FUNC_VIA_FLOAT(atan);
DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
DEFINE_BF16_FUNC_VIA_FLOAT(erf);
DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
DEFINE_BF16_FUNC_VIA_FLOAT(exp);
DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
DEFINE_BF16_FUNC_VIA_FLOAT(i0);
DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
DEFINE_BF16_FUNC_VIA_FLOAT(log);
DEFINE_BF16_FUNC_VIA_FLOAT(log2);
DEFINE_BF16_FUNC_VIA_FLOAT(log10);
DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
DEFINE_BF16_FUNC_VIA_FLOAT(sin);
DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
DEFINE_BF16_FUNC_VIA_FLOAT(cos);
DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
DEFINE_BF16_FUNC_VIA_FLOAT(floor);
DEFINE_BF16_FUNC_VIA_FLOAT(round);
DEFINE_BF16_FUNC_VIA_FLOAT(tan);
DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);
DEFINE_BF16_FUNC_VIA_FLOAT(isnan)
DEFINE_BF16_FUNC_VIA_FLOAT(angle)
DEFINE_BF16_FUNC_VIA_FLOAT(acos)
DEFINE_BF16_FUNC_VIA_FLOAT(acosh)
DEFINE_BF16_FUNC_VIA_FLOAT(asin)
DEFINE_BF16_FUNC_VIA_FLOAT(atan)
DEFINE_BF16_FUNC_VIA_FLOAT(atanh)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign)
DEFINE_BF16_FUNC_VIA_FLOAT(erf)
DEFINE_BF16_FUNC_VIA_FLOAT(erfc)
DEFINE_BF16_FUNC_VIA_FLOAT(exp)
DEFINE_BF16_FUNC_VIA_FLOAT(exp2)
DEFINE_BF16_FUNC_VIA_FLOAT(expm1)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot)
DEFINE_BF16_FUNC_VIA_FLOAT(i0)
DEFINE_BF16_FUNC_VIA_FLOAT(i0e)
DEFINE_BF16_FUNC_VIA_FLOAT(digamma)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter)
DEFINE_BF16_FUNC_VIA_FLOAT(log)
DEFINE_BF16_FUNC_VIA_FLOAT(log2)
DEFINE_BF16_FUNC_VIA_FLOAT(log10)
DEFINE_BF16_FUNC_VIA_FLOAT(log1p)
DEFINE_BF16_FUNC_VIA_FLOAT(sin)
DEFINE_BF16_FUNC_VIA_FLOAT(sinh)
DEFINE_BF16_FUNC_VIA_FLOAT(cos)
DEFINE_BF16_FUNC_VIA_FLOAT(cosh)
DEFINE_BF16_FUNC_VIA_FLOAT(ceil)
DEFINE_BF16_FUNC_VIA_FLOAT(floor)
DEFINE_BF16_FUNC_VIA_FLOAT(round)
DEFINE_BF16_FUNC_VIA_FLOAT(tan)
DEFINE_BF16_FUNC_VIA_FLOAT(tanh)
DEFINE_BF16_FUNC_VIA_FLOAT(trunc)
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma)
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt)
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal)
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow)

Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(
    const Vectorized<BFloat16>& other) const {
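Each DEFINE_BF16_FUNC_VIA_FLOAT entry above expands to a method that widens the BFloat16 lanes to float, applies the float kernel, and converts back (the convert_float_bfloat16 tail of the macro). A scalar-level Python sketch of that widen/compute/narrow pattern, for illustration only (not the SVE code):

```python
import torch

x = torch.randn(8, dtype=torch.bfloat16)
# Compute in fp32, then narrow back to bf16 -- the same structure the macro
# generates for each listed function (sin shown as an example).
y = x.to(torch.float32).sin().to(torch.bfloat16)
```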
@@ -293,7 +293,7 @@ struct ComputeLocationBase<scalar_t, /*align_corners=*/false> {
    , empty(size <= 0) {}

  inline Vec unnormalize(const Vec &in) const {
    return (in + Vec(1)) * Vec(scaling_factor) - Vec(0.5);
    return (in + Vec(static_cast<scalar_t>(1))) * Vec(scaling_factor) - Vec(static_cast<scalar_t>(0.5));
  }

  inline Vec clip_coordinates(const Vec &in) const {
@@ -831,7 +831,7 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bicubic,

  // constant used in cubic convolution
  // could be -0.5 or -0.75, use the same value in UpSampleBicubic2d.h
  const Vec A = Vec(-0.75);
  const Vec A = Vec(static_cast<scalar_t>(-0.75));

  ApplyGridSample(const TensorAccessor<const scalar_t, 4>& input)
      : inp_H(input.size(2))
@@ -22,6 +22,9 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>

#ifdef USE_FBGEMM_GENAI

@@ -666,12 +669,19 @@ std::optional<c10::ScalarType> out_dtype) {
  // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
  // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
  bool use_fast_path = false;
  if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
    use_fast_path = true;
  }
#endif
  const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
  if (use_fast_path) {
    // fast path, no d2h sync needed
#ifndef USE_ROCM
    at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
    at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
  } else {
    _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
  }
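The ROCm branch above routes grouped GEMMs through the new CK kernel on gfx942/gfx950. A hedged usage sketch of the 2D x 3D variant it serves (assumes a CUDA or ROCm device with bf16 support; the call pattern mirrors the tests later in this compare):

```python
import torch

G, K, N = 4, 64, 128
M_sizes = [32, 64, 32, 128]                     # rows of A owned by each group
A = torch.randn(sum(M_sizes), K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(G, K, N, device="cuda", dtype=torch.bfloat16)
# Offsets are cumulative row boundaries into A, with no leading zero.
offs = torch.cumsum(torch.tensor(M_sizes), dim=0).to(dtype=torch.int32, device="cuda")
out = torch._grouped_mm(A, B, offs=offs)        # shape: [sum(M_sizes), N]
```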
19  aten/src/ATen/native/hip/ck_group_gemm.h (Normal file)
@@ -0,0 +1,19 @@
#pragma once

#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>

namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
    const at::Tensor& mat_a,
    const at::Tensor& mat_b,
    const std::optional<at::Tensor>& offs,
    const std::optional<at::Tensor>& bias,
    at::Tensor& out);

} // namespace detail
} // namespace hip
} // namespace at
462  aten/src/ATen/native/hip/ck_group_gemm.hip (Normal file)
@@ -0,0 +1,462 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/hip/HIPContext.h>
#include <ATen/Tensor.h>
#include <ATen/TensorAccessor.h>
#include <c10/hip/HIPStream.h>
#include <iostream>
#include <vector>
#include <optional>
#include <type_traits>

#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/tuple.hpp>

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

namespace at {
namespace hip {
namespace detail {

namespace CkTypes {
  using BF16 = ck::bhalf_t;
  using F16 = ck::half_t;
  using F32 = float;
  using PassThrough = ck::tensor_operation::element_wise::PassThrough;
}

template <typename ALayout, typename BLayout, typename DataType>
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
    ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
    DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
    CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
    ck::tensor_operation::device::GemmSpecialization::MNKPadding,
    1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
    S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
    3, 8, 8, 1,
    S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
    3, 8, 8, 1,
    1, 1,
    S<1,32,1,8>, 4
>;

template <typename ALayout, typename BLayout, typename DataType>
void launch_grouped_bgemm_ck_impl_dispatch(
    const at::Tensor& mat_a,
    const at::Tensor& mat_b,
    const std::optional<at::Tensor>& offs,
    at::Tensor& out)
{
  using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
  using PassThrough = CkTypes::PassThrough;

  std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
  std::vector<const void*> p_a_ptrs, p_b_ptrs;
  std::vector<void*> p_e_ptrs;
  // Note: d_ptrs will be resized after we populate the other vectors

  const int mat_a_dim = mat_a.dim();
  const int mat_b_dim = mat_b.dim();

  const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
  const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
  char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
  const size_t a_element_size = mat_a.element_size();
  const size_t b_element_size = mat_b.element_size();
  const size_t out_element_size = out.element_size();

  // for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
  if (mat_a_dim == 2 && mat_b_dim == 2) {
    // 2D*2D case requires offset tensor
    auto offs_accessor = offs->accessor<int, 1>();
    int num_groups = offs_accessor.size(0);
    const int M = mat_a.size(0); // number of rows in A
    const int N = mat_b.size(1); // number of columns in B
    const int K = mat_a.size(1); // columns in A == rows in B
    // for 2d*2d input, output is 3d.
    // for each group, A columns (K) are sliced. M and N dimensions are not sliced.
    for (int i = 0; i < num_groups; ++i) {
      int start_k = (i == 0) ? 0 : offs_accessor[i-1];
      int end_k = offs_accessor[i];
      int k = end_k - start_k;

      //K dimension are sliced, hence select stride(1) always.
      //K dimension is always dimension 1, regardless of memory layout (row/column major)
      const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
      const void* group_b_ptr;
      int ldb;

      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
        group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
        // Leading dimension = distance between rows = stride(0)
        ldb = mat_b.stride(0);
      } else {
        // Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
        group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
        // Leading dimension = distance between columns = stride(1)
        ldb = mat_b.stride(1);
      }

      // Calculate output pointer for group i in 3D tensor [num_groups, M, N]
      // stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
      void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
      int lda, ldc;
      if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major A [M,K]: leading dimension = distance between rows = stride(0)
        lda = mat_a.stride(0);
      } else {
        // Column-major A [M,K]: leading dimension = distance between columns = stride(1)
        lda = mat_a.stride(1);
      }
      // Output is always row-major in 3D tensor [num_groups, M, N]
      // Leading dimension for each group's [M,N] slice = stride(1) = N
      ldc = out.stride(1);
      size_t output_group_bytes = M * N * out_element_size;
      void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;

      gemm_descs.push_back({
          static_cast<ck::index_t>(M),
          static_cast<ck::index_t>(N),
          static_cast<ck::index_t>(k),
          static_cast<ck::index_t>(lda),
          static_cast<ck::index_t>(ldb),
          static_cast<ck::index_t>(ldc),
          {} // --> stride_Ds_
      });
      p_a_ptrs.push_back(group_a_ptr);
      p_b_ptrs.push_back(group_b_ptr);
      p_e_ptrs.push_back(group_e_ptr);
    }
  } else if (mat_a_dim == 2 && mat_b_dim == 3) {
    // 2D*3D case requires offset tensor
    auto offs_accessor = offs->accessor<int, 1>();
    int num_groups = offs_accessor.size(0);

    // 2d*3d input, output is 2d.
    // A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
    // Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
    const int K = mat_a.size(1); // columns in A
    // For 2D-3D case: The output determines N (result width)
    const int N = out.size(1); // N is the width of the output tensor

    for (int i = 0; i < num_groups; ++i) {
      int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
      int end_m = offs_accessor[i];
      int m = end_m - start_m;

      // Skip zero-sized groups but continue processing subsequent groups
      if (m <= 0) {
        continue;
      }

      // Select A rows for group i: skip start_m rows
      const void* group_a_ptr;
      int lda;
      if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
        group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
        lda = mat_a.stride(0); // distance between rows
      } else {
        // Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
        group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;

        // Detect stride pattern for A tensor to determine appropriate lda calculation
        bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));

        if (a_is_strided_tensor) {
          // For strided A tensors: stride(0) gives the actual leading dimension
          lda = mat_a.stride(0);
        } else {
          // For non-strided A tensors: use the M dimension (total rows)
          lda = mat_a.size(0); // Total M dimension for column-major layout
        }
      }

      // Select B batch for group i: B[i, :, :]
      const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
      int ldb;

      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
        ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
      } else {
        // Detect stride pattern to determine appropriate ldb calculation
        bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));

        if (is_strided_tensor) {
          // For strided tensors: stride(2) gives the actual leading dimension
          ldb = mat_b.stride(2);
        } else {
          // For non-strided tensors: use the N dimension
          ldb = mat_b.size(1);
        }
      }

      // Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
      void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
      int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)

      gemm_descs.push_back({
          static_cast<ck::index_t>(m),
          static_cast<ck::index_t>(N),
          static_cast<ck::index_t>(K),
          static_cast<ck::index_t>(lda),
          static_cast<ck::index_t>(ldb),
          static_cast<ck::index_t>(ldc),
          {} // --> stride_Ds_
      });
      p_a_ptrs.push_back(group_a_ptr);
      p_b_ptrs.push_back(group_b_ptr);
      p_e_ptrs.push_back(group_e_ptr);
    }
  } else if (mat_a_dim == 3 && mat_b_dim == 3) {
    // 3d*3d input, output is 3d - batched matrix multiplication
    // A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
    // Each batch is processed as a separate GEMM operation
    const int batch_size = mat_a.size(0);
    const int M = mat_a.size(1); // rows in each A matrix
    const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)

    // Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
    int N;
    if (mat_b.size(1) == K) {
      // B is [batch, k, n] - normal layout
      N = mat_b.size(2);
    } else if (mat_b.size(2) == K) {
      // B is [batch, n, k] - transposed layout
      N = mat_b.size(1);
    } else {
      TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
          batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
    }

    for (int i = 0; i < batch_size; ++i) {
      // Select A batch for group i: A[i, :, :]
      const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;

      // Select B batch for group i: B[i, :, :]
      const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;

      // Select output batch for group i: Output[i, :, :]
      void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;

      int lda, ldb, ldc;

      if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major A: leading dimension = distance between rows = stride(1)
        lda = mat_a.stride(1);
      } else {
        // Column-major A: leading dimension = distance between columns = stride(2)
        lda = mat_a.stride(2);
      }

      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major B: leading dimension = distance between rows
        if (mat_b.size(1) == K) {
          // B is [batch, k, n] - normal layout
          ldb = mat_b.stride(1); // stride between K rows
        } else {
          // B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
          ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
        }
      } else {
        // Column-major B: leading dimension = distance between columns
        if (mat_b.size(1) == K) {
          // B is [batch, k, n] - normal layout
          ldb = mat_b.stride(2); // stride between N columns
        } else {
          // B is [batch, n, k] - transposed layout
          ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
        }
      }

      // Output is typically row-major: leading dimension = distance between rows = stride(1)
      ldc = out.stride(1);

      gemm_descs.push_back({
          static_cast<ck::index_t>(M),
          static_cast<ck::index_t>(N),
          static_cast<ck::index_t>(K),
          static_cast<ck::index_t>(lda),
          static_cast<ck::index_t>(ldb),
          static_cast<ck::index_t>(ldc),
          {} // --> stride_Ds_
      });
      p_a_ptrs.push_back(group_a_ptr);
      p_b_ptrs.push_back(group_b_ptr);
      p_e_ptrs.push_back(group_e_ptr);
    }
  } else if (mat_a_dim == 3 && mat_b_dim == 2) {
    // 3D*2D case requires offset tensor
    auto offs_accessor = offs->accessor<int, 1>();
    int num_groups = offs_accessor.size(0);
    // 3d*2d input, output is 3d.
    // A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
    // Offset divides N dimension of B, each group gets different slice of B and different batch of A
    const int batch_size = mat_a.size(0); // n_groups
    const int M = mat_a.size(1); // rows in each A matrix
    const int K = mat_a.size(2); // columns in A

    // For row-major A and B case: B should be [K, total_N]
    const int total_N = mat_b.size(1); // B is [K, total_N] for row-major

    for (int i = 0; i < num_groups; ++i) {
      int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
      int end_n = offs_accessor[i];
      int n = end_n - start_n;

      // Skip zero-sized groups but continue processing subsequent groups
      if (n <= 0) {
        continue;
      }

      // Select A batch for group i: A[i, :, :]
      const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;

      // Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
      const void* group_b_ptr;
      int ldb;

      // Check if B is row-major or column-major
      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major B [K, total_N]: slice columns [start_n:end_n]
        group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
        ldb = mat_b.stride(0); // distance between rows (should be total_N)
      } else {
        // Column-major B [K, total_N]: slice columns [start_n:end_n]
        group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
        ldb = mat_b.stride(1); // distance between columns (should be K)
      }

      // Select output slice for group i: Output[:, start_n:end_n]
      void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;

      int lda, ldc;

      // Row-major A: leading dimension = distance between rows = stride(1)
      lda = mat_a.stride(1);
      // Output is row-major: leading dimension = distance between rows = stride(0)
      ldc = out.stride(0);

      gemm_descs.push_back({
          static_cast<ck::index_t>(M),
          static_cast<ck::index_t>(n),
          static_cast<ck::index_t>(K),
          static_cast<ck::index_t>(lda),
          static_cast<ck::index_t>(ldb),
          static_cast<ck::index_t>(ldc),
          {} // --> stride_Ds_
      });
      p_a_ptrs.push_back(group_a_ptr);
      p_b_ptrs.push_back(group_b_ptr);
      p_e_ptrs.push_back(group_e_ptr);
    }
  } else {
    TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
  }

  TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");

  // Initialize d_ptrs with the correct size
  std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());

  static DeviceOp gemm_instance;
  auto argument = gemm_instance.MakeArgument(
      p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
      gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
  );
  TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
      "CK Group GEMM: argument unsupported (shape/strides/type config)");
  size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
  size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);

  void* gemm_arg_buf = nullptr;
  void* ws_buf = nullptr;

  hipMalloc(&gemm_arg_buf, arg_buf_size);
  hipMalloc(&ws_buf, ws_size);

  gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
  gemm_instance.SetWorkSpacePointer(&argument, ws_buf);

  auto invoker = gemm_instance.MakeInvoker();
  hipStream_t stream = c10::hip::getCurrentHIPStream();
  invoker.Run(argument, {stream});
  hipFree(gemm_arg_buf);
  hipFree(ws_buf);
}

void group_gemm_ck(
    const at::Tensor& input_a,
    const at::Tensor& input_b_colmajor,
    const std::optional<at::Tensor>& offs,
    const std::optional<at::Tensor>& /*bias*/,
    at::Tensor& out)
{
  // Detect if input_a is row-major based on stride pattern
  bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
  bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
  // Ensure tensor A is row-major and contiguous if not already
  at::Tensor mat_a = input_a;
  if (!a_row_major) {
    // If A is not row-major, make it contiguous (row-major)
    mat_a = input_a.contiguous();
  }
  // Force tensor B to be column-major using double transpose trick
  // This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
  at::Tensor mat_b = input_b_colmajor;
  if (!b_col_major) {
    mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
  }

  // For 3D tensors, check the last dimension stride for row-major detection
  a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
  bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);

  if (mat_a.dtype() == at::kBFloat16) {
    // bf16 path
    if (a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
    } else if (a_row_major && !b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
    } else if (!a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
    } else {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
    }
  } else if (mat_a.dtype() == at::kHalf) {
    // fp16 path
    if (a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
    } else if (a_row_major && !b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
    } else if (!a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
    } else {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
    }
  } else if (mat_a.dtype() == at::kFloat) {
    // fp32 path
    if (a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
    } else if (a_row_major && !b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
    } else if (!a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
    } else {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
    }
  } else {
    TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
  }

}

} // namespace detail
} // namespace hip
} // namespace at
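The wrapper above relies on a "double transpose trick" to obtain a column-major copy of B without changing its logical shape. A small Python sketch of the stride effect it depends on (plain dense tensors assumed):

```python
import torch

K, N = 64, 128
b = torch.randn(K, N)                                      # row-major: strides (N, 1)
b_col = b.transpose(-2, -1).contiguous().transpose(-2, -1)
assert b_col.shape == (K, N)
assert b_col.stride() == (1, K)                            # column-major: stride(0) == 1, stride(1) == K
```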
@@ -18,6 +18,7 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>

#include <array>
#include <cstddef>
@@ -40,200 +41,99 @@ namespace c10 {
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
///
/// NOTE: We have refactored out the headeronly parts of the ArrayRef struct
/// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of
/// the underlying constexpr calls, we rely on apparent-type dispatch for
/// inheritance. This should be fine because their memory format is the same,
/// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods.
/// However, you should prefer to use ArrayRef when possible, because its use
/// of TORCH_CHECK will lead to better user-facing error messages.
template <typename T>
class ArrayRef final {
class ArrayRef final : public HeaderOnlyArrayRef<T> {
 public:
  using iterator = const T*;
  using const_iterator = const T*;
  using size_type = size_t;
  using value_type = T;

  using reverse_iterator = std::reverse_iterator<iterator>;

 private:
  /// The start of the array, in an external buffer.
  const T* Data;

  /// The number of elements.
  size_type Length;

  void debugCheckNullptrInvariant() {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        Data != nullptr || Length == 0,
        "created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal");
  }

 public:
  /// @name Constructors
  /// @name Constructors, all inherited from HeaderOnlyArrayRef except for
  /// SmallVector. As inherited constructors won't work with class template
  /// argument deduction (CTAD) until C++23, we add deduction guides after
  /// the class definition to enable CTAD.
  /// @{

  /// Construct an empty ArrayRef.
  /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}

  /// Construct an ArrayRef from a single element.
  // TODO Make this explicit
  constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}

  /// Construct an ArrayRef from a pointer and length.
  constexpr ArrayRef(const T* data, size_t length)
      : Data(data), Length(length) {
    debugCheckNullptrInvariant();
  }

  /// Construct an ArrayRef from a range.
  constexpr ArrayRef(const T* begin, const T* end)
      : Data(begin), Length(end - begin) {
    debugCheckNullptrInvariant();
  }
  using HeaderOnlyArrayRef<T>::HeaderOnlyArrayRef;

  /// Construct an ArrayRef from a SmallVector. This is templated in order to
  /// avoid instantiating SmallVectorTemplateCommon<T> whenever we
  /// copy-construct an ArrayRef.
  /// NOTE: this is the only constructor that is not inherited from
  /// HeaderOnlyArrayRef.
  template <typename U>
  /* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
      : Data(Vec.data()), Length(Vec.size()) {
    debugCheckNullptrInvariant();
  }

  template <
      typename Container,
      typename U = decltype(std::declval<Container>().data()),
      typename = std::enable_if_t<
          (std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
  /* implicit */ ArrayRef(const Container& container)
      : Data(container.data()), Length(container.size()) {
    debugCheckNullptrInvariant();
  }

  /// Construct an ArrayRef from a std::vector.
  // The enable_if stuff here makes sure that this isn't used for
  // std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
  // bitfield.
  template <typename A>
  /* implicit */ ArrayRef(const std::vector<T, A>& Vec)
      : Data(Vec.data()), Length(Vec.size()) {
    static_assert(
        !std::is_same_v<T, bool>,
        "ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
  }

  /// Construct an ArrayRef from a std::array
  template <size_t N>
  /* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
      : Data(Arr.data()), Length(N) {}

  /// Construct an ArrayRef from a C array.
  template <size_t N>
  // NOLINTNEXTLINE(*c-arrays*)
  /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}

  /// Construct an ArrayRef from a std::initializer_list.
  /* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
      : Data(
            std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
                                             : std::begin(Vec)),
        Length(Vec.size()) {}
      : HeaderOnlyArrayRef<T>(Vec.data(), Vec.size()) {}

  /// @}
  /// @name Simple Operations
  /// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef
  /// @{

  constexpr iterator begin() const {
    return Data;
  }
  constexpr iterator end() const {
    return Data + Length;
  }

  // These are actually the same as iterator, since ArrayRef only
  // gives you const iterators.
  constexpr const_iterator cbegin() const {
    return Data;
  }
  constexpr const_iterator cend() const {
    return Data + Length;
  }

  constexpr reverse_iterator rbegin() const {
    return reverse_iterator(end());
  }
  constexpr reverse_iterator rend() const {
    return reverse_iterator(begin());
  }

  /// Check if all elements in the array satisfy the given expression
  constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
    return std::all_of(cbegin(), cend(), pred);
  }

  /// empty - Check if the array is empty.
  constexpr bool empty() const {
    return Length == 0;
  }

  constexpr const T* data() const {
    return Data;
  }

  /// size - Get the array size.
  constexpr size_t size() const {
    return Length;
  }

  /// front - Get the first element.
  /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  /// STD_TORCH_CHECK
  constexpr const T& front() const {
    TORCH_CHECK(
        !empty(), "ArrayRef: attempted to access front() of empty list");
    return Data[0];
        !this->empty(), "ArrayRef: attempted to access front() of empty list");
    return this->Data[0];
  }

  /// back - Get the last element.
  /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  /// STD_TORCH_CHECK
  constexpr const T& back() const {
    TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
    return Data[Length - 1];
  }

  /// equals - Check for element-wise equality.
  constexpr bool equals(ArrayRef RHS) const {
    return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
    TORCH_CHECK(
        !this->empty(), "ArrayRef: attempted to access back() of empty list");
    return this->Data[this->Length - 1];
  }

  /// slice(n, m) - Take M elements of the array starting at element N
  /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  /// STD_TORCH_CHECK
  constexpr ArrayRef<T> slice(size_t N, size_t M) const {
    TORCH_CHECK(
        N + M <= size(),
        N + M <= this->size(),
        "ArrayRef: invalid slice, N = ",
        N,
        "; M = ",
        M,
        "; size = ",
        size());
    return ArrayRef<T>(data() + N, M);
        this->size());
    return ArrayRef<T>(this->data() + N, M);
  }

  /// slice(n) - Chop off the first N elements of the array.
  /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  /// STD_TORCH_CHECK
  constexpr ArrayRef<T> slice(size_t N) const {
    TORCH_CHECK(
        N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size());
    return slice(N, size() - N);
        N <= this->size(),
        "ArrayRef: invalid slice, N = ",
        N,
        "; size = ",
        this->size());
    return slice(N, this->size() - N); // should this slice be this->slice?
  }

  /// @}
  /// @name Operator Overloads
  /// @{
  constexpr const T& operator[](size_t Index) const {
    return Data[Index];
  }

  /// Vector compatibility
  /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  /// STD_TORCH_CHECK
  constexpr const T& at(size_t Index) const {
    TORCH_CHECK(
        Index < Length,
        Index < this->Length,
        "ArrayRef: invalid index Index = ",
        Index,
        "; Length = ",
        Length);
    return Data[Index];
        this->Length);
    return this->Data[Index];
  }

  /// Disallow accidental assignment from a temporary.
@@ -253,16 +153,48 @@ class ArrayRef final {
  std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
      std::initializer_list<U>) = delete;

  /// @}
  /// @name Expensive Operations
  /// @{
  std::vector<T> vec() const {
    return std::vector<T>(Data, Data + Length);
  }

  /// @}
};

/// Deduction guides for ArrayRef to support CTAD with inherited constructors
/// These mirror the constructors inherited from HeaderOnlyArrayRef
/// @{

// Single element constructor
template <typename T>
ArrayRef(const T&) -> ArrayRef<T>;

// Pointer and length constructor
template <typename T>
ArrayRef(const T*, size_t) -> ArrayRef<T>;

// Range constructor (begin, end)
template <typename T>
ArrayRef(const T*, const T*) -> ArrayRef<T>;

// Generic container constructor (anything with .data() and .size())
template <typename Container>
ArrayRef(const Container&) -> ArrayRef<
    std::remove_pointer_t<decltype(std::declval<Container>().data())>>;

// std::vector constructor
template <typename T, typename A>
ArrayRef(const std::vector<T, A>&) -> ArrayRef<T>;

// std::array constructor
template <typename T, size_t N>
ArrayRef(const std::array<T, N>&) -> ArrayRef<T>;

// C array constructor
template <typename T, size_t N>
ArrayRef(const T (&)[N]) -> ArrayRef<T>;

// std::initializer_list constructor
template <typename T>
ArrayRef(const std::initializer_list<T>&) -> ArrayRef<T>;

/// @}

template <typename T>
std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
  int i = 0;
@@ -1307,7 +1307,7 @@ endif()

if(USE_MKLDNN_ACL)
  find_package(ACL REQUIRED)
  target_include_directories(torch_cpu PRIVATE ${ACL_INCLUDE_DIRS})
  target_include_directories(torch_cpu SYSTEM PRIVATE ${ACL_INCLUDE_DIRS})
endif()

target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE})
34  setup.py
@@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None:
        raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")


def mirror_inductor_external_kernels() -> None:
    """
    Copy external kernels into Inductor so they are importable.
    """
    paths = [
        (
            CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
            CWD
            / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
        ),
    ]
    for new_path, orig_path in paths:
        # Create the dirs involved in new_path if they don't exist
        if not new_path.exists():
            new_path.parent.mkdir(parents=True, exist_ok=True)

        # Copy the files from the orig location to the new location
        if orig_path.is_file():
            shutil.copyfile(orig_path, new_path)
            continue
        if orig_path.is_dir():
            if new_path.exists():
                # copytree fails if the tree exists already, so remove it.
                shutil.rmtree(new_path)
            shutil.copytree(orig_path, new_path)
            continue
        raise RuntimeError(
            "Check the file paths in `mirror_inductor_external_kernels()`"
        )


# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
    """Extract variant from version string, defaulting to 'cpu'."""

@@ -1647,8 +1616,6 @@ def main() -> None:
    if RUN_BUILD_DEPS:
        build_deps()

    mirror_inductor_external_kernels()

    (
        ext_modules,
        cmdclass,

@@ -1682,7 +1649,6 @@ def main() -> None:
                "_inductor/codegen/aoti_runtime/*.cpp",
                "_inductor/script.ld",
                "_inductor/kernel/flex/templates/*.jinja",
                "_inductor/kernel/templates/*.jinja",
                "_export/serde/*.yaml",
                "_export/serde/*.thrift",
                "share/cmake/ATen/*.cmake",
@@ -12,6 +12,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
52  test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp (Normal file)
@@ -0,0 +1,52 @@
#include <gtest/gtest.h>

#include <torch/headeronly/util/HeaderOnlyArrayRef.h>

#include <vector>

using torch::headeronly::HeaderOnlyArrayRef;

TEST(TestHeaderOnlyArrayRef, TestEmpty) {
  HeaderOnlyArrayRef<float> arr;
  ASSERT_TRUE(arr.empty());
}

TEST(TestHeaderOnlyArrayRef, TestSingleton) {
  float val = 5.0f;
  HeaderOnlyArrayRef<float> arr(val);
  ASSERT_FALSE(arr.empty());
  EXPECT_EQ(arr.size(), 1);
  EXPECT_EQ(arr[0], val);
}

TEST(TestHeaderOnlyArrayRef, TestAPIs) {
  std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
  HeaderOnlyArrayRef<int> arr(vec);
  ASSERT_FALSE(arr.empty());
  EXPECT_EQ(arr.size(), 7);
  for (size_t i = 0; i < arr.size(); i++) {
    EXPECT_EQ(arr[i], i + 1);
    EXPECT_EQ(arr.at(i), i + 1);
  }
  EXPECT_EQ(arr.front(), 1);
  EXPECT_EQ(arr.back(), 7);
  ASSERT_TRUE(arr.slice(3, 4).equals(arr.slice(3)));
}

TEST(TestHeaderOnlyArrayRef, TestFromInitializerList) {
  std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
  HeaderOnlyArrayRef<int> arr({1, 2, 3, 4, 5, 6, 7});
  auto res_vec = arr.vec();
  for (size_t i = 0; i < vec.size(); i++) {
    EXPECT_EQ(vec[i], res_vec[i]);
  }
}

TEST(TestHeaderOnlyArrayRef, TestFromRange) {
  std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
  HeaderOnlyArrayRef<int> arr(vec.data() + 3, vec.data() + 7);
  auto res_vec = arr.vec();
  for (size_t i = 0; i < res_vec.size(); i++) {
    EXPECT_EQ(vec[i + 3], res_vec[i]);
  }
}
@@ -311,10 +311,9 @@ void boxed_fill_infinity(
}

Tensor my_pad(Tensor t) {
  std::vector<int64_t> padding = {1, 2, 2, 1};
  std::string mode = "constant";
  double value = 0.0;
  return pad(t, padding, mode, value);
  return pad(t, {1, 2, 2, 1}, mode, value);
}

void boxed_my_pad(

@@ -342,6 +341,9 @@ void boxed_my_narrow(
}

Tensor my_new_empty_dtype_variant(Tensor t) {
  // Still using a std::vector below even though people can just pass in an
  // initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
  // directly.
  std::vector<int64_t> sizes = {2, 5};
  auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
  return new_empty(t, sizes, dtype);

@@ -353,9 +355,8 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
}

Tensor my_new_zeros_dtype_variant(Tensor t) {
  std::vector<int64_t> sizes = {2, 5};
  auto dtype = std::make_optional(at::ScalarType::Float);
  return new_zeros(t, sizes, dtype);
  return new_zeros(t, {2, 5}, dtype);
}

void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {

@@ -429,8 +430,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
}

Tensor my_amax_vec(Tensor t) {
  std::vector<int64_t> v = {0,1};
  return amax(t, v, false);
  return amax(t, {0,1}, false);
}

void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
@@ -5,8 +5,16 @@ import contextlib
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_tensor,
    DTensor,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,

@@ -426,6 +434,31 @@ class TestDTensorDebugMode(TestCase):
        ][-1]
        self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)

    def test_pretty_print_dtensor_make_fx(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        A = torch.randn(8, 32)
        B = torch.randn(32, 32)
        dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
        dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()

        def f(dA, dB):
            dy = dA @ dB
            loss = dy.sum()
            loss.backward()
            return dA.grad, dB.grad

        # We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
        # make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
        # so we won't stash our DTensors properly if they don't hold Fake inner tensors
        gm = make_fx(f, tracing_mode="fake")(dA, dB)
        # DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
        gm.graph.eliminate_dead_code()
        gm.recompile()
        # Colored is nice for actual viewing, not using in this test though
        gm_str = gm.print_readable(colored=False, print_output=False)
        self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)


instantiate_parametrized_tests(TestDTensorDebugMode)
@@ -13194,6 +13194,30 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):

        self.assertEqual(actual, expected)

    @parametrize_pytree_module
    def test_pytree_tree_map_dict_order(self, pytree):
        def fn(tree):
            new_tree = pytree.tree_map(lambda x: x, tree)
            return list(new_tree.keys()), list(new_tree.values())

        x = torch.randn(3, 2)
        fn_opt = torch.compile(fullgraph=True)(fn)

        tree1 = {"b": x + 2, "a": x, "c": x - 1}
        expected1 = fn(tree1)
        actual1 = fn_opt(tree1)
        self.assertEqual(actual1, expected1)

        tree2 = collections.OrderedDict([("b", x + 2), ("a", x), ("c", x - 1)])
        expected2 = fn(tree2)
        actual2 = fn_opt(tree2)
        self.assertEqual(actual2, expected2)

        tree3 = collections.defaultdict(int, {"b": x + 2, "a": x, "c": x - 1})
        expected3 = fn(tree3)
        actual3 = fn_opt(tree3)
        self.assertEqual(actual3, expected3)

    @parametrize_pytree_module
    def test_pytree_tree_map_only(self, pytree):
        if not callable(getattr(pytree, "tree_map_only", None)):
@@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None
@@ -1,154 +0,0 @@
# Owner(s): ["module: inductor"]


import unittest

import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)


@unittest.skipIf(
    not (ensure_cute_available() and is_datacenter_blackwell_arch()),
    "CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
    def _get_inputs(
        self,
        group_size: int,
        M_hint: int,
        K: int,
        N: int,
        device: str,
        dtype: torch.dtype,
        alignment: int = 16,
    ) -> tuple[Tensor, Tensor, Tensor]:
        # --- Random, tile-aligned M sizes ---
        M_sizes = (
            torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
            * alignment
        )

        M_total = torch.sum(M_sizes).item()

        # --- Construct input tensors ---
        A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
        B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01

        # --- Build offsets (no leading zero, strictly increasing) ---
        offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)

        return (A, B, offsets)

    @parametrize("group_size", (2, 8))
    @parametrize("M_hint", (256, 1024))
    @parametrize("K", (64, 128))
    @parametrize("N", (128, 256))
    def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
        device = "cuda"
        dtype = torch.bfloat16

        A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)

        def grouped_gemm_fn(A_packed, B_batched, offs):
            return torch._grouped_mm(A_packed, B_batched, offs=offs)

        # Eager execution
        c_eager = grouped_gemm_fn(A, B, offsets)

        # Test with Cute backend
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "CUTEDSL",
                "test_configs.autotune_choice_name_regex": "cutedsl",
                "autotune_fallback_to_aten": False,
            }
        ):
            grouped_gemm_compiled = torch.compile(
                grouped_gemm_fn, backend="inductor", dynamic=False
            )
            c_compiled = grouped_gemm_compiled(A, B, offsets)

        self.assertEqual(c_eager.dtype, dtype)
        self.assertEqual(c_compiled.dtype, dtype)
        torch.testing.assert_close(c_eager, c_compiled)
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
|
||||
@parametrize("layout_B", ("contiguous", "broadcasted"))
|
||||
def test_grouped_gemm_assorted_layouts(
|
||||
self,
|
||||
layout_A: str,
|
||||
layout_B: str,
|
||||
):
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
G, K, N = 8, 64, 128
|
||||
M_sizes = [128] * G
|
||||
sum_M = sum(M_sizes)
|
||||
offsets = torch.tensor(
|
||||
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
|
||||
A = A_base
|
||||
|
||||
if layout_A == "offset":
|
||||
# allocate bigger buffer than needed, use nonzero storage offset
|
||||
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
|
||||
offset = 128 # skip first 128 elements
|
||||
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
|
||||
elif layout_A == "padded":
|
||||
# simulate row pitch > K (row_stride = K + pad)
|
||||
row_pitch = K + 8
|
||||
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
|
||||
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
|
||||
elif layout_A == "view":
|
||||
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
|
||||
A = A_storage.view(sum_M, K)
|
||||
assert A._base is not None
|
||||
assert A.shape == (sum_M, K)
|
||||
|
||||
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
|
||||
|
||||
if layout_B == "broadcasted":
|
||||
# Broadcast B across groups (zero stride along G)
|
||||
B = B[0].expand(G, K, N)
|
||||
assert B.stride(0) == 0
|
||||
|
||||
def grouped_gemm_fn(A_packed, B_batched, offs):
|
||||
return torch._grouped_mm(A_packed, B_batched, offs=offs)
|
||||
|
||||
# --- eager ---
|
||||
c_eager = grouped_gemm_fn(A, B, offsets)
|
||||
|
||||
# --- compiled (CUTE backend) ---
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTEDSL",
|
||||
"test_configs.autotune_choice_name_regex": "cutedsl",
|
||||
"autotune_fallback_to_aten": False,
|
||||
}
|
||||
):
|
||||
grouped_gemm_compiled = torch.compile(
|
||||
grouped_gemm_fn, backend="inductor", dynamic=False
|
||||
)
|
||||
c_compiled = grouped_gemm_compiled(A, B, offsets)
|
||||
|
||||
self.assertEqual(c_eager.dtype, dtype)
|
||||
self.assertEqual(c_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(c_eager, c_compiled)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
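For readers skimming the diff, here is a minimal reference sketch (not part of the test file) of the offsets convention the grouped-GEMM tests above rely on: offsets is the cumulative sum of per-group row counts with no leading zero, so group g owns rows offsets[g-1]:offsets[g] of the packed A. The helper name grouped_mm_reference is illustrative only.

import torch

def grouped_mm_reference(A, B, offsets):
    # A: (sum_M, K) packed rows, B: (G, K, N), offsets: (G,) cumulative row ends
    outs = []
    start = 0
    for g in range(B.shape[0]):
        end = int(offsets[g])
        outs.append(A[start:end] @ B[g])  # (M_g, K) @ (K, N) -> (M_g, N)
        start = end
    return torch.cat(outs, dim=0)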
@ -6,6 +6,7 @@ from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._C import FileCheck
|
||||
from torch._inductor import config, utils
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
@ -29,7 +30,6 @@ from torch.testing._internal.inductor_utils import (
|
||||
HAS_CPU,
|
||||
HAS_CUDA_AND_TRITON,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import FileCheck
|
||||
from torch.utils._triton import has_triton_tma_device
|
||||
|
||||
|
||||
@ -953,6 +953,240 @@ class TestFP8Lowering(TestCase):
|
||||
self.assertEqual(y_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
@torch._inductor.config.patch("emulate_precision_casts", True)
|
||||
def test_mx_fusion(self):
|
||||
# Register fake_scaled_mm custom op scoped to this test
|
||||
with torch.library._scoped_library("test_fp8", "FRAGMENT") as lib:
|
||||
# Define the op schema
|
||||
lib.define(
|
||||
"fake_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, "
|
||||
"Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, "
|
||||
"bool use_fast_accum=False) -> Tensor"
|
||||
)
|
||||
input_values = []
|
||||
|
||||
# Register CUDA implementation
|
||||
@torch.library.impl(lib, "fake_scaled_mm", "CUDA")
|
||||
def fake_scaled_mm_impl(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=None,
|
||||
use_fast_accum=False,
|
||||
):
|
||||
"""Software-emulated scaled_mm for testing without CUDA 12.8"""
|
||||
out_dtype = out_dtype or torch.bfloat16
|
||||
# just using add, because without real dtypes,
|
||||
# we were seeing overflow/instability
|
||||
nonlocal input_values
|
||||
input_values.append((mat_a, mat_b, scale_a, scale_b))
|
||||
result = mat_a.to(torch.float32) + mat_b.to(torch.float32)
|
||||
if bias is not None:
|
||||
result = result + bias.to(torch.float32)
|
||||
return result.to(out_dtype)
|
||||
|
||||
# Register fake implementation
|
||||
@torch.library.impl(lib, "fake_scaled_mm", "Meta")
|
||||
def fake_scaled_mm_meta(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=None,
|
||||
use_fast_accum=False,
|
||||
):
|
||||
"""FakeTensor implementation"""
|
||||
out_dtype = out_dtype or torch.bfloat16
|
||||
M, K = mat_a.shape
|
||||
K2, N = mat_b.shape
|
||||
torch._check(
|
||||
K == K2,
|
||||
lambda: f"Incompatible shapes: {mat_a.shape} @ {mat_b.shape}",
|
||||
)
|
||||
return torch.empty((M, N), dtype=out_dtype, device=mat_a.device)
|
||||
|
||||
def forward(
|
||||
arg0_1,
|
||||
arg1_1,
|
||||
):
|
||||
view = torch.ops.aten.reshape.default(arg0_1, [8192, 256, 32])
|
||||
abs_1 = torch.ops.aten.abs.default(view)
|
||||
amax = torch.ops.aten.amax.default(abs_1, [-1])
|
||||
unsqueeze = torch.ops.aten.unsqueeze.default(amax, -1)
|
||||
view_1 = torch.ops.aten.view.dtype(unsqueeze, torch.int32)
|
||||
bitwise_right_shift = torch.ops.aten.bitwise_right_shift.Tensor_Scalar(
|
||||
view_1, 23
|
||||
)
|
||||
bitwise_and = torch.ops.aten.bitwise_and.Scalar(
|
||||
bitwise_right_shift, 255
|
||||
)
|
||||
sub = torch.ops.aten.sub.Tensor(bitwise_and, 127)
|
||||
sub_1 = torch.ops.aten.sub.Tensor(sub, 8)
|
||||
clamp_min = torch.ops.aten.clamp_min.default(sub_1, -127)
|
||||
clamp_max = torch.ops.aten.clamp_max.default(clamp_min, 128)
|
||||
add = torch.ops.aten.add.Tensor(clamp_max, 127)
|
||||
convert_element_type = torch.ops.prims.convert_element_type.default(
|
||||
add, torch.uint8
|
||||
)
|
||||
isnan = torch.ops.aten.isnan.default(unsqueeze)
|
||||
scalar_tensor = torch.ops.aten.scalar_tensor.default(
|
||||
255, dtype=torch.uint8, layout=torch.strided, device="cuda"
|
||||
)
|
||||
where = torch.ops.aten.where.self(
|
||||
isnan, scalar_tensor, convert_element_type
|
||||
)
|
||||
convert_element_type_1 = torch.ops.prims.convert_element_type.default(
|
||||
where, torch.int32
|
||||
)
|
||||
bitwise_left_shift = torch.ops.aten.bitwise_left_shift.Tensor_Scalar(
|
||||
convert_element_type_1, 23
|
||||
)
|
||||
view_2 = torch.ops.aten.view.dtype(bitwise_left_shift, torch.float32)
|
||||
clamp_min_1 = torch.ops.aten.clamp_min.default(
|
||||
view_2, 1.1754943508222875e-38
|
||||
)
|
||||
div = torch.ops.aten.div.Tensor(view, clamp_min_1)
|
||||
clamp_min_2 = torch.ops.aten.clamp_min.default(div, -448.0)
|
||||
clamp_max_1 = torch.ops.aten.clamp_max.default(clamp_min_2, 448.0)
|
||||
convert_element_type_2 = torch.ops.prims.convert_element_type.default(
|
||||
clamp_max_1, torch.float8_e4m3fn
|
||||
)
|
||||
view_3 = torch.ops.aten.reshape.default(
|
||||
convert_element_type_2, [8192, 8192]
|
||||
)
|
||||
convert_element_type_2 = None
|
||||
view_4 = torch.ops.aten.view.dtype(where, torch.float8_e8m0fnu)
|
||||
squeeze = torch.ops.aten.squeeze.dim(view_4, -1)
|
||||
|
||||
view_5 = torch.ops.aten.reshape.default(arg1_1, [8192, 256, 32])
|
||||
abs_2 = torch.ops.aten.abs.default(view_5)
|
||||
amax_1 = torch.ops.aten.amax.default(abs_2, [-1])
|
||||
unsqueeze_1 = torch.ops.aten.unsqueeze.default(amax_1, -1)
|
||||
view_6 = torch.ops.aten.view.dtype(unsqueeze_1, torch.int32)
|
||||
bitwise_right_shift_1 = (
|
||||
torch.ops.aten.bitwise_right_shift.Tensor_Scalar(view_6, 23)
|
||||
)
|
||||
bitwise_and_1 = torch.ops.aten.bitwise_and.Scalar(
|
||||
bitwise_right_shift_1, 255
|
||||
)
|
||||
sub_2 = torch.ops.aten.sub.Tensor(bitwise_and_1, 127)
|
||||
sub_3 = torch.ops.aten.sub.Tensor(sub_2, 8)
|
||||
clamp_min_3 = torch.ops.aten.clamp_min.default(sub_3, -127)
|
||||
clamp_max_2 = torch.ops.aten.clamp_max.default(clamp_min_3, 128)
|
||||
add_1 = torch.ops.aten.add.Tensor(clamp_max_2, 127)
|
||||
convert_element_type_3 = torch.ops.prims.convert_element_type.default(
|
||||
add_1, torch.uint8
|
||||
)
|
||||
isnan_1 = torch.ops.aten.isnan.default(unsqueeze_1)
|
||||
unsqueeze_1 = None
|
||||
scalar_tensor_1 = torch.ops.aten.scalar_tensor.default(
|
||||
255, dtype=torch.uint8, layout=torch.strided, device="cuda"
|
||||
)
|
||||
where_1 = torch.ops.aten.where.self(
|
||||
isnan_1, scalar_tensor_1, convert_element_type_3
|
||||
)
|
||||
convert_element_type_4 = torch.ops.prims.convert_element_type.default(
|
||||
where_1, torch.int32
|
||||
)
|
||||
bitwise_left_shift_1 = torch.ops.aten.bitwise_left_shift.Tensor_Scalar(
|
||||
convert_element_type_4, 23
|
||||
)
|
||||
convert_element_type_4 = None
|
||||
view_7 = torch.ops.aten.view.dtype(bitwise_left_shift_1, torch.float32)
|
||||
bitwise_left_shift_1 = None
|
||||
clamp_min_4 = torch.ops.aten.clamp_min.default(
|
||||
view_7, 1.1754943508222875e-38
|
||||
)
|
||||
div_1 = torch.ops.aten.div.Tensor(view_5, clamp_min_4)
|
||||
clamp_min_5 = torch.ops.aten.clamp_min.default(div_1, -448.0)
|
||||
clamp_max_3 = torch.ops.aten.clamp_max.default(clamp_min_5, 448.0)
|
||||
convert_element_type_5 = torch.ops.prims.convert_element_type.default(
|
||||
clamp_max_3, torch.float8_e4m3fn
|
||||
)
|
||||
view_8 = torch.ops.aten.reshape.default(
|
||||
convert_element_type_5, [8192, 8192]
|
||||
)
|
||||
view_9 = torch.ops.aten.view.dtype(where_1, torch.float8_e8m0fnu)
|
||||
squeeze_1 = torch.ops.aten.squeeze.dim(view_9, -1)
|
||||
|
||||
permute = torch.ops.aten.permute.default(view_8, [1, 0])
|
||||
|
||||
view_13 = torch.ops.aten.reshape.default(squeeze, [64, 128, 64, 4])
|
||||
permute_2 = torch.ops.aten.permute.default(view_13, [0, 2, 1, 3])
|
||||
clone = torch.ops.aten.clone.default(
|
||||
permute_2, memory_format=torch.contiguous_format
|
||||
)
|
||||
view_14 = torch.ops.aten.reshape.default(clone, [4096, 4, 32, 4])
|
||||
permute_3 = torch.ops.aten.permute.default(view_14, [0, 2, 1, 3])
|
||||
clone_1 = torch.ops.aten.clone.default(
|
||||
permute_3, memory_format=torch.contiguous_format
|
||||
)
|
||||
view_15 = torch.ops.aten.reshape.default(clone_1, [4096, 32, 16])
|
||||
|
||||
view_16 = torch.ops.aten.reshape.default(view_15, [2097152])
|
||||
|
||||
view_18 = torch.ops.aten.reshape.default(squeeze_1, [64, 128, 64, 4])
|
||||
permute_5 = torch.ops.aten.permute.default(view_18, [0, 2, 1, 3])
|
||||
clone_2 = torch.ops.aten.clone.default(
|
||||
permute_5, memory_format=torch.contiguous_format
|
||||
)
|
||||
view_19 = torch.ops.aten.reshape.default(clone_2, [4096, 4, 32, 4])
|
||||
permute_6 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3])
|
||||
clone_3 = torch.ops.aten.clone.default(
|
||||
permute_6, memory_format=torch.contiguous_format
|
||||
)
|
||||
view_20 = torch.ops.aten.reshape.default(clone_3, [4096, 32, 16])
|
||||
|
||||
view_21 = torch.ops.aten.reshape.default(view_20, [2097152])
|
||||
|
||||
_scaled_mm = torch.ops.test_fp8.fake_scaled_mm.default(
|
||||
view_3, permute, view_16, view_21, None, None, torch.float32
|
||||
)
|
||||
return (_scaled_mm,)
|
||||
|
||||
# Run with largest shape
|
||||
M, K, N = 8192, 8192, 8192
|
||||
device = "cuda"
|
||||
|
||||
A = torch.randn(M, K, dtype=torch.float32, device=device)
|
||||
B = torch.randn(K, N, dtype=torch.float32, device=device)
|
||||
f_c = torch.compile(fullgraph=True)(forward)
|
||||
|
||||
_, code = run_and_get_code(f_c, A, B)
|
||||
|
||||
FileCheck().check(".run(").check(".run(").check("fake_scaled_mm").run(
|
||||
code[0]
|
||||
)
|
||||
|
||||
for seed in range(5):
|
||||
input_values.clear()
|
||||
torch.manual_seed(seed)
|
||||
# without dividing, outputs get way too large
|
||||
A = torch.randn(M, K, dtype=torch.float32, device=device)
|
||||
B = torch.randn(K, N, dtype=torch.float32, device=device)
|
||||
|
||||
# Uses fake_scaled_mm custom op (no CUDA 12.8 needed!)
|
||||
torch._dynamo.reset()
|
||||
torch.compile(forward)(A, B)
|
||||
|
||||
torch._dynamo.reset()
|
||||
with config.patch({"loop_index_inversion_in_fusion": False}):
|
||||
torch.compile(forward)(A, B)
|
||||
|
||||
assert len(input_values) == 2
|
||||
for i in range(4):
|
||||
self.assertEqual(
|
||||
input_values[0][i],
|
||||
input_values[1][i],
|
||||
msg=f"idx {i} seed {seed}",
|
||||
)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
@parametrize("M", (1, 3, 33, 257, 1024))
|
||||
@parametrize("K", (16, 32, 1024))
|
||||
|
||||
@ -16,6 +16,7 @@ from torch._dynamo.utils import same
|
||||
from torch._inductor import config as inductor_config, ir, metrics
|
||||
from torch._inductor.codegen.triton import TritonScheduling
|
||||
from torch._inductor.graph import GraphLowering
|
||||
from torch._inductor.invert_expr_analysis import generate_inverse_formula
|
||||
from torch._inductor.scheduler import SchedulerNode
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.test_operators import realize
|
||||
@ -1188,6 +1189,113 @@ class TestTiling(TestCase):
|
||||
torch.compile(f)(x)
|
||||
|
||||
|
||||
class TestIndexInversion(TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
gm = torch.fx.symbolic_trace(lambda: 0)
|
||||
graph = GraphLowering(gm)
|
||||
graph.scheduler = MockScheduler
|
||||
cls._exit_stack = contextlib.ExitStack()
|
||||
cls._exit_stack.enter_context(V.set_graph_handler(graph))
|
||||
|
||||
def _check_expr(self, expr, reconstruction, val_range):
|
||||
import numpy as np
|
||||
from sympy import lambdify
|
||||
|
||||
assert len(expr.free_symbols) == 1
|
||||
p0 = next(iter(expr.free_symbols))
|
||||
|
||||
def floordiv_replacement(a, b):
|
||||
"""Replace FloorDiv(a, b) with a // b"""
|
||||
return a // b
|
||||
|
||||
def modularindexing_replacement(x, base, divisor):
|
||||
"""Replace ModularIndexing(x, base, divisor) with (x // base) % divisor"""
|
||||
return (x // base) % divisor
|
||||
|
||||
# Replace custom functions with sympy equivalents
|
||||
expr_numpy_ready = expr.replace(FloorDiv, floordiv_replacement).replace(
|
||||
ModularIndexing, modularindexing_replacement
|
||||
)
|
||||
reconstruction_numpy_ready = reconstruction.replace(
|
||||
FloorDiv, floordiv_replacement
|
||||
).replace(ModularIndexing, modularindexing_replacement)
|
||||
|
||||
# Now lambdify with standard numpy
|
||||
forward_func = lambdify(p0, expr_numpy_ready, modules="numpy")
|
||||
inverse_func = lambdify(p0, reconstruction_numpy_ready, modules="numpy")
|
||||
|
||||
test_values = np.arange(0, val_range, dtype=np.int64)
|
||||
forward_values = forward_func(test_values).astype(np.int64)
|
||||
recovered_values = inverse_func(forward_values).astype(np.int64)
|
||||
torch.testing.assert_close(test_values, recovered_values)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super().tearDownClass()
|
||||
cls._exit_stack.close()
|
||||
|
||||
def test_original_complex_expression(self):
|
||||
"""Test the original motivating complex expression."""
|
||||
p0 = sympy.Symbol("p0")
|
||||
expr = (
|
||||
32768 * FloorDiv(p0, 32768)
|
||||
+ 8192 * FloorDiv(ModularIndexing(p0, 1, 16), 4)
|
||||
+ ModularIndexing(p0, 1, 4)
|
||||
+ 256 * ModularIndexing(p0, 16, 32)
|
||||
+ 4 * ModularIndexing(p0, 512, 64)
|
||||
)
|
||||
|
||||
reconstruction = generate_inverse_formula(expr, p0)
|
||||
self.assertIsNotNone(reconstruction)
|
||||
self._check_expr(expr, reconstruction, 2097152)
|
||||
|
||||
def test_inversion_cases(self):
|
||||
"""Test various expressions for correct inversion behavior."""
|
||||
p = sympy.Symbol("p")
|
||||
|
||||
cases = [
|
||||
# (expression, should_be_invertible, test_range)
|
||||
# Simple 2-term base-10 style: 10 = 1 × 10 ✓
|
||||
(10 * ModularIndexing(p, 10, 10) + ModularIndexing(p, 1, 10), True, 100),
|
||||
# Simple 2-term base-2 style: 2 = 1 × 2 ✓
|
||||
(2 * ModularIndexing(p, 2, 2) + ModularIndexing(p, 1, 2), True, 4),
|
||||
# 3-term decimal: 100 = 10×10, 10 = 1×10 ✓
|
||||
(
|
||||
100 * FloorDiv(p, 100)
|
||||
+ 10 * FloorDiv(ModularIndexing(p, 1, 100), 10)
|
||||
+ ModularIndexing(p, 1, 10),
|
||||
True,
|
||||
1000,
|
||||
),
|
||||
(4 * p, False, 64), # expr and inverse not bijections
|
||||
# when sorted, invertible
|
||||
(ModularIndexing(p, 1, 10) + 10 * ModularIndexing(p, 10, 10), True, None),
|
||||
# Wrong coefficient ratios: 4 ≠ 1×2
|
||||
(4 * ModularIndexing(p, 1, 8) + ModularIndexing(p, 8, 2), False, None),
|
||||
(
|
||||
100 * FloorDiv(p, 100) + 7 * ModularIndexing(p, 1, 100),
|
||||
False,
|
||||
None,
|
||||
), # Wrong ratios
|
||||
(FloorDiv(p, 100) + FloorDiv(p, 10) + p, False, None), # Overlapping ranges
|
||||
(p**2 + 10 * p + 1, False, None), # Quadratic
|
||||
(sympy.sin(p) + sympy.cos(p), False, None), # Trigonometric
|
||||
]
|
||||
|
||||
for expr, should_invert, test_range in cases:
|
||||
reconstruction = generate_inverse_formula(expr, p)
|
||||
|
||||
if should_invert:
|
||||
self.assertIsNotNone(reconstruction, f"Expected invertible: {expr}")
|
||||
# Test correctness on sample values
|
||||
self._check_expr(expr, reconstruction, test_range)
|
||||
else:
|
||||
self.assertIsNone(reconstruction, f"Expected non-invertible: {expr}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if HAS_GPU:
|
||||
run_tests()
|
||||
|
||||
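A quick hand check of the first invertible case above (a sketch, not part of the test): for 0 <= p < 100, 10 * ModularIndexing(p, 10, 10) + ModularIndexing(p, 1, 10) evaluates to 10 * ((p // 10) % 10) + (p % 10), which is p itself, so an exact inverse exists on that range.

for p in range(100):
    assert 10 * ((p // 10) % 10) + (p % 10) == p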
@ -14424,6 +14424,20 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
|
||||
self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),))
|
||||
|
||||
@skip_if_halide
|
||||
@requires_cuda_and_triton
|
||||
def test_unbacked_float_item(self):
|
||||
def fn(x, max_val):
|
||||
return torch.clamp(x, 0, max_val.item())
|
||||
|
||||
self.common(
|
||||
fn,
|
||||
(
|
||||
torch.randn(10, 20, 30, device=self.device),
|
||||
torch.tensor(5.0, device=self.device),
|
||||
),
|
||||
)
|
||||
|
||||
# end of class CommonTemplate - add new tests here
|
||||
|
||||
|
||||
|
||||
@ -73,7 +73,22 @@ from tools.testing.test_selections import (
|
||||
ShardedTest,
|
||||
THRESHOLD,
|
||||
)
|
||||
from tools.testing.upload_artifacts import zip_and_upload_artifacts
|
||||
|
||||
|
||||
try:
|
||||
from tools.testing.upload_artifacts import (
|
||||
parse_xml_and_upload_json,
|
||||
zip_and_upload_artifacts,
|
||||
)
|
||||
except ImportError:
|
||||
# some imports in those files might fail, e.g., boto3 not installed. These
|
||||
# functions are only needed under specific circumstances (CI) so we can
|
||||
# define dummy functions here.
|
||||
def parse_xml_and_upload_json():
|
||||
pass
|
||||
|
||||
def zip_and_upload_artifacts(failed: bool):
|
||||
pass
|
||||
|
||||
|
||||
# Make sure to remove REPO_ROOT after import is done
|
||||
@ -1887,6 +1902,7 @@ def run_tests(
|
||||
def handle_complete(failure: Optional[TestFailure]):
|
||||
failed = failure is not None
|
||||
if IS_CI and options.upload_artifacts_while_running:
|
||||
parse_xml_and_upload_json()
|
||||
zip_and_upload_artifacts(failed)
|
||||
if not failed:
|
||||
return False
|
||||
|
||||
test/test_as_strided.py (new file, 176 lines)
@ -0,0 +1,176 @@
|
||||
# Owner(s): ["oncall: pt2"]
|
||||
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
"""Extract (sizes, strides) tuple from a tensor."""
|
||||
return (tuple(t.size()), tuple(t.stride()))
|
||||
|
||||
|
||||
def enumerate_reachable_states(
|
||||
initial_size: int,
|
||||
) -> set[tuple[tuple[int, ...], tuple[int, ...]]]:
|
||||
"""
|
||||
Use BFS with DP to enumerate all reachable (size, stride) states from
|
||||
a 1D contiguous tensor via valid view operations.
|
||||
|
||||
We only explore states with offset=0 (you can retroactively change the offset).
|
||||
We reject states with size=0 or size=1 dimensions as they are degenerate.
|
||||
"""
|
||||
# Create initial 1D contiguous tensor
|
||||
initial_tensor = torch.arange(initial_size)
|
||||
|
||||
initial_state = get_state(initial_tensor)
|
||||
|
||||
# Map from state to tensor for that state
|
||||
state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = {
|
||||
initial_state: initial_tensor
|
||||
}
|
||||
visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state}
|
||||
queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state])
|
||||
|
||||
while queue:
|
||||
state = queue.popleft()
|
||||
t = state_to_tensor[state]
|
||||
sizes, strides = state
|
||||
ndim = len(sizes)
|
||||
|
||||
def add_state(new_t: torch.Tensor) -> None:
|
||||
new_state = get_state(new_t)
|
||||
sizes, strides = new_state
|
||||
# Skip if has size-0 or size-1 dimensions
|
||||
if any(s == 0 or s == 1 for s in sizes):
|
||||
return
|
||||
# Only accept states where strides are in descending order
|
||||
if list(strides) != sorted(strides, reverse=True):
|
||||
return
|
||||
if new_state not in visited:
|
||||
visited.add(new_state)
|
||||
queue.append(new_state)
|
||||
state_to_tensor[new_state] = new_t
|
||||
|
||||
# 1. Unflatten: try factoring each dimension
|
||||
for dim in range(ndim):
|
||||
size = sizes[dim]
|
||||
assert size > 1
|
||||
# Try all factorizations x * y = size where both x, y >= 2
|
||||
# We only need to check x up to size // 2 since when x > size // 2,
|
||||
# y = size // x < 2, which we reject
|
||||
for x in range(2, size // 2 + 1):
|
||||
if size % x == 0:
|
||||
y = size // x
|
||||
add_state(t.unflatten(dim, (x, y)))
|
||||
|
||||
# 2. Slice: exhaustively check all possible slicing parameters
|
||||
for dim in range(ndim):
|
||||
size = sizes[dim]
|
||||
for start in range(size):
|
||||
for stop in range(start + 1, size + 1):
|
||||
for step in range(1, size + 1):
|
||||
slices = [slice(None)] * ndim
|
||||
slices[dim] = slice(start, stop, step)
|
||||
add_state(t[tuple(slices)])
|
||||
|
||||
# 3. Flatten: merge adjacent dimensions
|
||||
for dim in range(ndim - 1):
|
||||
add_state(t.flatten(dim, dim + 1))
|
||||
|
||||
return visited
|
||||
|
||||
|
||||
class TestAsStrided(TestCase):
|
||||
def test_size_10_exhaustive(self) -> None:
|
||||
"""Test that size 10 produces exactly the expected 54 states."""
|
||||
expected_states = {
|
||||
((2,), (1,)),
|
||||
((2,), (2,)),
|
||||
((2,), (3,)),
|
||||
((2,), (4,)),
|
||||
((2,), (5,)),
|
||||
((2,), (6,)),
|
||||
((2,), (7,)),
|
||||
((2,), (8,)),
|
||||
((2,), (9,)),
|
||||
((2, 2), (2, 1)),
|
||||
((2, 2), (3, 1)),
|
||||
((2, 2), (3, 2)),
|
||||
((2, 2), (4, 1)),
|
||||
((2, 2), (4, 2)),
|
||||
((2, 2), (4, 3)),
|
||||
((2, 2), (5, 1)),
|
||||
((2, 2), (5, 2)),
|
||||
((2, 2), (5, 3)),
|
||||
((2, 2), (5, 4)),
|
||||
((2, 2), (6, 1)),
|
||||
((2, 2), (6, 2)),
|
||||
((2, 2), (6, 3)),
|
||||
((2, 2), (8, 1)),
|
||||
((2, 2, 2), (4, 2, 1)),
|
||||
((2, 2, 2), (5, 2, 1)),
|
||||
((2, 3), (3, 1)),
|
||||
((2, 3), (4, 1)),
|
||||
((2, 3), (5, 1)),
|
||||
((2, 3), (5, 2)),
|
||||
((2, 3), (6, 1)),
|
||||
((2, 4), (4, 1)),
|
||||
((2, 4), (5, 1)),
|
||||
((2, 5), (5, 1)),
|
||||
((3,), (1,)),
|
||||
((3,), (2,)),
|
||||
((3,), (3,)),
|
||||
((3,), (4,)),
|
||||
((3, 2), (2, 1)),
|
||||
((3, 2), (3, 1)),
|
||||
((3, 2), (3, 2)),
|
||||
((3, 2), (4, 1)),
|
||||
((3, 3), (3, 1)),
|
||||
((4,), (1,)),
|
||||
((4,), (2,)),
|
||||
((4,), (3,)),
|
||||
((4, 2), (2, 1)),
|
||||
((5,), (1,)),
|
||||
((5,), (2,)),
|
||||
((5, 2), (2, 1)),
|
||||
((6,), (1,)),
|
||||
((7,), (1,)),
|
||||
((8,), (1,)),
|
||||
((9,), (1,)),
|
||||
((10,), (1,)),
|
||||
}
|
||||
|
||||
actual_states = enumerate_reachable_states(10)
|
||||
|
||||
self.assertEqual(len(actual_states), 54)
|
||||
self.assertEqual(actual_states, expected_states)
|
||||
|
||||
def test_subset_property(self) -> None:
|
||||
"""
|
||||
Test that for sizes 2..10, each smaller tensor results in a strict
|
||||
subset of possible states compared to the next one.
|
||||
"""
|
||||
prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None
|
||||
for size in range(2, 11):
|
||||
current_states = enumerate_reachable_states(size)
|
||||
|
||||
if prev_states is not None:
|
||||
# Check that prev_states is a strict subset of current_states
|
||||
self.assertTrue(
|
||||
prev_states.issubset(current_states),
|
||||
f"States from size {size - 1} are not a subset of size {size}",
|
||||
)
|
||||
# Check that it's a strict subset (not equal)
|
||||
self.assertTrue(
|
||||
len(prev_states) < len(current_states),
|
||||
f"States from size {size - 1} should be strictly fewer than size {size}",
|
||||
)
|
||||
|
||||
prev_states = current_states
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
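A minimal usage sketch of the enumeration helper added above (assuming the new file is importable as test_as_strided; the import path is illustrative):

from test_as_strided import enumerate_reachable_states

states = enumerate_reachable_states(4)
# Each entry is a (sizes, strides) pair reachable from torch.arange(4) via
# unflatten / slice / flatten, e.g. ((2, 2), (2, 1)) from arange(4).unflatten(0, (2, 2)).
print(sorted(states))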
test/test_fx.py (180 lines changed)
@ -75,12 +75,6 @@ from torch.testing._internal.common_utils import (
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
|
||||
from torch.autograd.profiler_util import _canonicalize_profiler_events
|
||||
|
||||
try:
|
||||
from torchvision import models as torchvision_models
|
||||
|
||||
@ -207,36 +201,6 @@ def side_effect_func(x: torch.Tensor):
|
||||
print(x)
|
||||
|
||||
|
||||
def _enrich_profiler_traces(prof):
|
||||
"""
|
||||
Helper function to extract and augment profiler events with stack traces.
|
||||
|
||||
Args:
|
||||
prof: A torch.profiler.profile object
|
||||
|
||||
Returns:
|
||||
A string representing enriched events
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
|
||||
trace_file = f.name
|
||||
prof.export_chrome_trace(trace_file)
|
||||
|
||||
with open(trace_file) as f:
|
||||
trace_data = json.load(f)
|
||||
|
||||
map_recorded_events_to_aten_ops_with_stack_trace(
|
||||
trace_data
|
||||
)
|
||||
|
||||
events = []
|
||||
for event in trace_data["traceEvents"]:
|
||||
if "args" in event and "stack_trace" in event["args"]:
|
||||
events.append(event)
|
||||
|
||||
actual_traces = _canonicalize_profiler_events(events)
|
||||
return actual_traces
|
||||
|
||||
|
||||
class TestFX(JitTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -4248,150 +4212,6 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
|
||||
# recover mutable checking flag
|
||||
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_stack_trace_augmentation(self):
|
||||
"""
|
||||
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
|
||||
augments profiler events with stack traces from FX metadata registry.
|
||||
"""
|
||||
|
||||
# Simple test model
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(10, 16)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.linear2 = torch.nn.Linear(16, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.relu(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
model = TestModel().cuda()
|
||||
|
||||
# Compile the model
|
||||
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
|
||||
|
||||
# Warmup
|
||||
for _ in range(3):
|
||||
_ = compiled_model(torch.randn(10, 10, device="cuda"))
|
||||
|
||||
# Profile with the compiled model
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
) as prof:
|
||||
result = compiled_model(torch.randn(10, 10, device="cuda"))
|
||||
|
||||
actual_traces = _enrich_profiler_traces(prof)
|
||||
|
||||
self.assertExpectedInline(actual_traces, """\
|
||||
event=aten::t node=t stack_trace=x = self.linear1(x)
|
||||
event=aten::transpose node=t stack_trace=x = self.linear1(x)
|
||||
event=aten::as_strided node=t stack_trace=x = self.linear1(x)
|
||||
event=aten::addmm node=addmm stack_trace=x = self.linear1(x)
|
||||
event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x)
|
||||
event=aten::relu node=relu stack_trace=x = self.relu(x)
|
||||
event=aten::clamp_min node=relu stack_trace=x = self.relu(x)
|
||||
event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x)
|
||||
event=aten::t node=t_1 stack_trace=x = self.linear2(x)
|
||||
event=aten::transpose node=t_1 stack_trace=x = self.linear2(x)
|
||||
event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x)
|
||||
event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x)
|
||||
event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_multiple_modules(self):
|
||||
"""
|
||||
Test that multiple compiled modules under the same profiler session
|
||||
have their events correctly augmented with stack traces.
|
||||
"""
|
||||
|
||||
class ModelA(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 1
|
||||
|
||||
class ModelB(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x - 1
|
||||
|
||||
model_a = ModelA().cuda()
|
||||
model_b = ModelB().cuda()
|
||||
|
||||
# Compile both models
|
||||
compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True)
|
||||
compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True)
|
||||
|
||||
# Warmup
|
||||
for _ in range(3):
|
||||
_ = compiled_a(torch.randn(10, 10, device="cuda"))
|
||||
_ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
|
||||
|
||||
# Profile both models in the same session
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
) as prof:
|
||||
result_a = compiled_a(torch.randn(10, 10, device="cuda"))
|
||||
result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
|
||||
|
||||
actual_traces = _enrich_profiler_traces(prof)
|
||||
self.assertExpectedInline(actual_traces, """\
|
||||
event=aten::add node=add stack_trace=return x + 1
|
||||
event=cudaLaunchKernel node=add stack_trace=return x + 1
|
||||
event=aten::sub node=sub stack_trace=return x - 1
|
||||
event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_nested_graph_modules(self):
|
||||
"""
|
||||
Test that nested graph modules (e.g., graph modules calling subgraphs)
|
||||
have their events correctly augmented with stack traces.
|
||||
"""
|
||||
|
||||
# Model with nested structure
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.c = 5
|
||||
|
||||
@torch.compiler.nested_compile_region
|
||||
def forward(self, x, y):
|
||||
m = torch.mul(x, y)
|
||||
s = m.sin()
|
||||
a = s + self.c
|
||||
return a
|
||||
|
||||
model = Mod().cuda()
|
||||
|
||||
# Compile the model (this may create nested graph modules)
|
||||
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
|
||||
|
||||
# Warmup
|
||||
for _ in range(3):
|
||||
_ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
|
||||
|
||||
# Profile
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
) as prof:
|
||||
result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
|
||||
|
||||
actual_traces = _enrich_profiler_traces(prof)
|
||||
self.assertExpectedInline(actual_traces, """\
|
||||
event=aten::mul node=mul stack_trace=m = torch.mul(x, y)
|
||||
event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y)
|
||||
event=aten::sin node=sin stack_trace=s = m.sin()
|
||||
event=cudaLaunchKernel node=sin stack_trace=s = m.sin()
|
||||
event=aten::add node=add stack_trace=a = s + self.c
|
||||
event=cudaLaunchKernel node=add stack_trace=a = s + self.c"""
|
||||
)
|
||||
|
||||
|
||||
def run_getitem_target():
|
||||
from torch.fx._symbolic_trace import _wrapped_methods_to_patch
|
||||
|
||||
@ -490,8 +490,6 @@ class TestMatmulCuda(InductorTestCase):
|
||||
@parametrize("b_row_major", [False, True])
|
||||
@dtypes(torch.bfloat16, torch.float32, torch.float16)
|
||||
def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
|
||||
if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
|
||||
self.skipTest("failed using hipblaslt on rocm 6.4.2")
|
||||
device = "cuda"
|
||||
s_int = int(strided)
|
||||
m, n, k, n_groups = 16, 32, 64, 4
|
||||
|
||||
@ -601,6 +601,24 @@ class TestGenericPytree(TestCase):
|
||||
for case in cases:
|
||||
run_test(case)
|
||||
|
||||
@parametrize_pytree_module
|
||||
def test_tree_map_dict_order(self, pytree):
|
||||
d = {"b": 2, "a": 1, "c": 3}
|
||||
od = OrderedDict([("b", 2), ("a", 1), ("c", 3)])
|
||||
dd = defaultdict(int, {"b": 2, "a": 1, "c": 3})
|
||||
for tree in (d, od, dd):
|
||||
result = pytree.tree_map(lambda x: x, tree)
|
||||
self.assertEqual(
|
||||
list(result.keys()),
|
||||
list(tree.keys()),
|
||||
msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. {result!r}",
|
||||
)
|
||||
self.assertEqual(
|
||||
list(result.values()),
|
||||
list(tree.values()),
|
||||
msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. {result!r}",
|
||||
)
|
||||
|
||||
@parametrize_pytree_module
|
||||
def test_tree_map_only(self, pytree):
|
||||
self.assertEqual(pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"])
|
||||
|
||||
@ -38,12 +38,14 @@ def parse_xml_report(
|
||||
report: Path,
|
||||
workflow_id: int,
|
||||
workflow_run_attempt: int,
|
||||
job_id: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert a test report xml file into a JSON-serializable list of test cases."""
|
||||
print(f"Parsing {tag}s for test report: {report}")
|
||||
|
||||
job_id = get_job_id(report)
|
||||
print(f"Found job id: {job_id}")
|
||||
if job_id is None:
|
||||
job_id = get_job_id(report)
|
||||
print(f"Found job id: {job_id}")
|
||||
|
||||
test_cases: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
@ -1,11 +1,16 @@
|
||||
import glob
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import zipfile
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from tools.stats.upload_test_stats import parse_xml_report
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
@ -140,3 +145,66 @@ def trigger_upload_test_stats_intermediate_workflow() -> None:
|
||||
},
|
||||
)
|
||||
print(x.text)
|
||||
|
||||
|
||||
def parse_xml_and_upload_json() -> None:
|
||||
"""
|
||||
Parse xml test reports that do not yet have a corresponding json report
|
||||
uploaded to s3, and upload the json reports to s3. Use filelock to avoid
|
||||
uploading the same file from multiple processes.
|
||||
"""
|
||||
try:
|
||||
job_id: Optional[int] = int(os.environ.get("JOB_ID", 0))
|
||||
if job_id == 0:
|
||||
job_id = None
|
||||
except (ValueError, TypeError):
|
||||
job_id = None
|
||||
|
||||
try:
|
||||
for xml_file in glob.glob(
|
||||
f"{REPO_ROOT}/test/test-reports/**/*.xml", recursive=True
|
||||
):
|
||||
xml_path = Path(xml_file)
|
||||
json_file = xml_path.with_suffix(".json")
|
||||
lock = FileLock(str(json_file) + ".lock")
|
||||
|
||||
try:
|
||||
lock.acquire(timeout=0) # immediately fails if already locked
|
||||
if json_file.exists():
|
||||
continue # already uploaded
|
||||
test_cases = parse_xml_report(
|
||||
"testcase",
|
||||
xml_path,
|
||||
int(os.environ.get("GITHUB_RUN_ID", "0")),
|
||||
int(os.environ.get("GITHUB_RUN_ATTEMPT", "0")),
|
||||
job_id,
|
||||
)
|
||||
line_by_line_jsons = "\n".join([json.dumps(tc) for tc in test_cases])
|
||||
|
||||
gzipped = gzip.compress(line_by_line_jsons.encode("utf-8"))
|
||||
s3_key = (
|
||||
json_file.relative_to(REPO_ROOT / "test/test-reports")
|
||||
.as_posix()
|
||||
.replace("/", "_")
|
||||
)
|
||||
|
||||
get_s3_resource().put_object(
|
||||
Body=gzipped,
|
||||
Bucket="gha-artifacts",
|
||||
Key=f"test_jsons_while_running/{os.environ.get('GITHUB_RUN_ID')}/{job_id}/{s3_key}",
|
||||
ContentType="application/json",
|
||||
ContentEncoding="gzip",
|
||||
)
|
||||
|
||||
# We don't need to save the json file locally, but doing so lets us
|
||||
# track which ones have been uploaded already. We could probably also
|
||||
# check S3
|
||||
with open(json_file, "w") as f:
|
||||
f.write(line_by_line_jsons)
|
||||
except Timeout:
|
||||
continue # another process is working on this file
|
||||
finally:
|
||||
if lock.is_locked:
|
||||
lock.release()
|
||||
except Exception as e:
|
||||
print(f"Failed to parse and upload json test reports: {e}")
|
||||
|
||||
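The upload helper above leans on a non-blocking file lock so only one process handles each report; a minimal sketch of that pattern, with an illustrative path:

from filelock import FileLock, Timeout

lock = FileLock("report.json.lock")
try:
    lock.acquire(timeout=0)  # fail immediately if another process already holds the lock
    ...  # parse the xml report and upload the json
except Timeout:
    pass  # another process is handling this file
finally:
    if lock.is_locked:
        lock.release()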
@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import Any, TYPE_CHECKING, TypeVar
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import torch.utils._pytree as python_pytree
|
||||
@ -24,9 +24,15 @@ if TYPE_CHECKING:
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_KT = TypeVar("_KT")
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
import optree
|
||||
import optree._C
|
||||
import optree.utils
|
||||
|
||||
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
|
||||
|
||||
@ -600,14 +606,47 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
|
||||
__all__ += ["tree_map_"]
|
||||
|
||||
_none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr, attr-defined]
|
||||
_none_registration = optree.register_pytree_node.get(type(None))
|
||||
assert _none_registration is not None
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
_none_unflatten,
|
||||
_none_registration.unflatten_func,
|
||||
can_constant_fold_through=True,
|
||||
skip_signature_check=True,
|
||||
)
|
||||
def none_unflatten(_: None, children: Iterable[Any], /) -> None:
|
||||
def none_unflatten(_: None, children: Iterable[_T], /) -> None:
|
||||
if len(list(children)) != 0:
|
||||
raise ValueError("Expected no children.")
|
||||
return None
|
||||
|
||||
with optree.dict_insertion_ordered(False, namespace="torch"):
|
||||
_dict_registration = optree.register_pytree_node.get(dict)
|
||||
assert _dict_registration is not None
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
_dict_registration.flatten_func,
|
||||
can_constant_fold_through=True,
|
||||
skip_signature_check=True,
|
||||
)
|
||||
def dict_flatten(
|
||||
dct: dict[_KT, _VT], /
|
||||
) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]:
|
||||
sorted_keys = optree.utils.total_order_sorted(dct)
|
||||
values = [dct[key] for key in sorted_keys]
|
||||
original_keys = list(dct)
|
||||
return values, (original_keys, sorted_keys), tuple(sorted_keys)
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
_dict_registration.unflatten_func,
|
||||
can_constant_fold_through=True,
|
||||
skip_signature_check=True,
|
||||
)
|
||||
def dict_unflatten(
|
||||
metadata: tuple[list[_KT], list[_KT]],
|
||||
values: Iterable[_VT],
|
||||
/,
|
||||
) -> dict[_KT, _VT]:
|
||||
original_keys, sorted_keys = metadata
|
||||
d = dict.fromkeys(original_keys)
|
||||
d.update(zip(sorted_keys, values))
|
||||
return d # type: ignore[return-value]
|
||||
|
||||
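To make the dict flatten/unflatten contract above easier to follow, here is a pure-Python illustration (no optree; function names are illustrative): values are flattened in sorted-key order while unflatten restores the original insertion order.

def flatten_like_above(dct):
    sorted_keys = sorted(dct)  # stand-in for optree.utils.total_order_sorted
    return [dct[k] for k in sorted_keys], (list(dct), sorted_keys)

def unflatten_like_above(metadata, values):
    original_keys, sorted_keys = metadata
    d = dict.fromkeys(original_keys)
    d.update(zip(sorted_keys, values))
    return d

vals, meta = flatten_like_above({"b": 2, "a": 1, "c": 3})
assert vals == [1, 2, 3]  # sorted-key order
assert list(unflatten_like_above(meta, vals)) == ["b", "a", "c"]  # insertion order kept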
@ -3320,7 +3320,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg]
|
||||
assert isinstance(obj, SetVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "add", [v], {})
|
||||
obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
|
||||
|
||||
def SET_UPDATE(self, inst: Instruction) -> None:
|
||||
v = self.pop()
|
||||
@ -3329,7 +3329,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg]
|
||||
assert isinstance(obj, SetVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "update", [v], {})
|
||||
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
|
||||
|
||||
def LIST_APPEND(self, inst: Instruction) -> None:
|
||||
v = self.pop()
|
||||
@ -3637,7 +3637,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg].realize()
|
||||
assert isinstance(obj, ConstDictVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "update", [v], {})
|
||||
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
|
||||
|
||||
DICT_UPDATE = DICT_MERGE
|
||||
|
||||
|
||||
@ -1991,7 +1991,7 @@ class BuiltinVariable(VariableTracker):
|
||||
# If the object implements a __getitem__ method, iter(...) will call obj.__getitem__()
|
||||
# with an integer argument starting at 0, until __getitem__ raises IndexError
|
||||
ret = variables.UserFunctionVariable(
|
||||
polyfills.builtins.iter_
|
||||
polyfills.builtins.iter_ # type: ignore[arg-type]
|
||||
).call_function(tx, [obj, *args], {})
|
||||
|
||||
if args:
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
"""
|
||||
Dictionary-related variable tracking classes for PyTorch Dynamo.
|
||||
|
||||
@ -26,7 +24,7 @@ import inspect
|
||||
import operator
|
||||
import types
|
||||
from collections.abc import Hashable as py_Hashable
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch._subclasses.fake_tensor import is_fake
|
||||
|
||||
@ -59,11 +57,13 @@ if TYPE_CHECKING:
|
||||
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
|
||||
|
||||
|
||||
def was_instancecheck_override(obj):
|
||||
def was_instancecheck_override(obj: Any) -> bool:
|
||||
return type(obj).__dict__.get("__instancecheck__", False)
|
||||
|
||||
|
||||
def raise_unhashable(arg, tx=None):
|
||||
def raise_unhashable(
|
||||
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
|
||||
) -> None:
|
||||
if tx is None:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
@ -75,7 +75,7 @@ def raise_unhashable(arg, tx=None):
|
||||
)
|
||||
|
||||
|
||||
def is_hashable(x):
|
||||
def is_hashable(x: VariableTracker) -> bool:
|
||||
# NB - performing an isinstance check on a LazyVT realizes the VT, accidentally
|
||||
# inserting the guard. To avoid this, the LazyVT `is_hashable` method looks at
|
||||
# the underlying value without realizing the VT. Consider updating the
|
||||
@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
|
||||
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
|
||||
"""
|
||||
|
||||
def __init__(self, vt) -> None:
|
||||
def __init__(self, vt: VariableTracker) -> None:
|
||||
# We specialize SymNodes
|
||||
vt = specialize_symnode(vt)
|
||||
# TODO Temporarily remove to figure out what keys are we breaking on
|
||||
@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
|
||||
self.vt = vt
|
||||
|
||||
@property
|
||||
def underlying_value(self):
|
||||
def underlying_value(self) -> Any:
|
||||
if (
|
||||
isinstance(self.vt, variables.LazyVariableTracker)
|
||||
and not self.vt.is_realized()
|
||||
@ -178,7 +178,8 @@ class ConstDictVariable(VariableTracker):
|
||||
elif isinstance(self.vt, variables.FrozenDataClassVariable):
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
fields_values = {
|
||||
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
|
||||
k: Hashable(v).underlying_value
|
||||
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
|
||||
}
|
||||
return variables.FrozenDataClassVariable.HashWrapper(
|
||||
self.vt.python_type(), fields_values
|
||||
@ -187,16 +188,16 @@ class ConstDictVariable(VariableTracker):
|
||||
# The re module in Python 3.13+ has a dictionary (_cache2) with
|
||||
# an object as key (`class _ZeroSentinel(int): ...`):
|
||||
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
|
||||
return self.vt.value
|
||||
return self.vt.value # type: ignore[attr-defined,union-attr]
|
||||
else:
|
||||
x = self.vt.as_python_constant()
|
||||
return x
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.underlying_value)
|
||||
|
||||
@staticmethod
|
||||
def _eq_impl(a, b):
|
||||
def _eq_impl(a: Any, b: Any) -> bool:
|
||||
# TODO: Put this in utils and share it between variables/builtin.py and here
|
||||
type_a, type_b = type(a), type(b)
|
||||
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
|
||||
@ -212,7 +213,7 @@ class ConstDictVariable(VariableTracker):
|
||||
else:
|
||||
return a == b
|
||||
|
||||
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
|
||||
type(other)
|
||||
@ -226,8 +227,8 @@ class ConstDictVariable(VariableTracker):
|
||||
def __init__(
|
||||
self,
|
||||
items: dict[VariableTracker, VariableTracker],
|
||||
user_cls=dict,
|
||||
**kwargs,
|
||||
user_cls: type = dict,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# .clone() pass these arguments in kwargs but they're recreated a few
|
||||
# lines below
|
||||
@ -247,18 +248,22 @@ class ConstDictVariable(VariableTracker):
|
||||
for x, v in items.items()
|
||||
)
|
||||
|
||||
def make_hashable(key):
|
||||
def make_hashable(
|
||||
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
|
||||
) -> "ConstDictVariable._HashableTracker":
|
||||
return key if isinstance(key, Hashable) else Hashable(key)
|
||||
|
||||
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
|
||||
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
|
||||
# need to reconstruct everything if the dictionary is an intermediate value
|
||||
# or if a pop/delitem was executed
|
||||
self.should_reconstruct_all = not is_from_local_source(self.source)
|
||||
self.should_reconstruct_all = (
|
||||
not is_from_local_source(self.source) if self.source else True
|
||||
)
|
||||
self.original_items = items.copy()
|
||||
self.user_cls = user_cls
|
||||
|
||||
def _get_dict_cls_from_user_cls(self, user_cls):
|
||||
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
|
||||
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
|
||||
|
||||
# avoid executing user code if user_cls is a dict subclass
|
||||
@ -277,10 +282,10 @@ class ConstDictVariable(VariableTracker):
|
||||
dict_cls = dict
|
||||
return dict_cls
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> dict[Any, Any]:
|
||||
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
return (
|
||||
"{"
|
||||
+ ", ".join(
|
||||
@ -289,20 +294,20 @@ class ConstDictVariable(VariableTracker):
|
||||
+ "}"
|
||||
)
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> dict[Any, Any]:
|
||||
return {
|
||||
k.vt.as_python_constant(): v.as_python_constant()
|
||||
for k, v in self.items.items()
|
||||
}
|
||||
|
||||
def keys_as_python_constant(self):
|
||||
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
|
||||
self.install_dict_keys_match_guard()
|
||||
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return self.user_cls
|
||||
|
||||
def __contains__(self, vt) -> bool:
|
||||
def __contains__(self, vt: VariableTracker) -> bool:
|
||||
assert isinstance(vt, VariableTracker)
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
return (
|
||||
@ -322,13 +327,15 @@ class ConstDictVariable(VariableTracker):
|
||||
for key, value in self.items.items()
|
||||
)
|
||||
|
||||
def is_new_item(self, value, other):
|
||||
def is_new_item(
|
||||
self, value: Optional[VariableTracker], other: VariableTracker
|
||||
) -> bool:
|
||||
# compare the id of the realized values if both values are not lazy VTs
|
||||
if value and value.is_realized() and other.is_realized():
|
||||
return id(value.realize()) != id(other.realize())
|
||||
return id(value) != id(other)
|
||||
|
||||
def reconstruct_kvs_into_new_dict(self, codegen):
|
||||
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
|
||||
# Build a dictionary that contains the keys and values.
|
||||
num_args = 0
|
||||
for key, value in self.items.items():
|
||||
@ -340,7 +347,7 @@ class ConstDictVariable(VariableTracker):
|
||||
num_args += 1
|
||||
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
if self.user_cls is collections.OrderedDict:
|
||||
# emit `OrderedDict(constructed_dict)`
|
||||
codegen.add_push_null(
|
||||
@ -358,19 +365,21 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
def getitem_const_raise_exception_if_absent(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
):
|
||||
) -> VariableTracker:
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
raise_observed_exception(KeyError, tx)
|
||||
return self.items[key]
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
msg = f"Dictionary key {arg.value} not found during tracing"
|
||||
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
|
||||
unimplemented_v2(
|
||||
gb_type="key not found in dict",
|
||||
context=f"Key {arg.value}",
|
||||
context=f"Key {arg.value}", # type: ignore[attr-defined]
|
||||
explanation=msg,
|
||||
hints=[
|
||||
"Check if the key exists in the dictionary before accessing it.",
|
||||
@ -379,13 +388,13 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
return self.items[key]
|
||||
|
||||
def maybe_getitem_const(self, arg: VariableTracker):
|
||||
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
return None
|
||||
return self.items[key]
|
||||
|
||||
def realize_key_vt(self, arg: VariableTracker):
|
||||
def realize_key_vt(self, arg: VariableTracker) -> None:
|
||||
# Realize the LazyVT on a particular index
|
||||
assert arg in self
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
@ -394,11 +403,13 @@ class ConstDictVariable(VariableTracker):
|
||||
if isinstance(original_key_vt, variables.LazyVariableTracker):
|
||||
original_key_vt.realize()
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
if self.source:
|
||||
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
# Key guarding - These are the cases to consider
|
||||
# 1) The dict has been mutated. In this case, we would have already
|
||||
# inserted a DICT_KEYS_MATCH guard, so we can skip.
|
||||
@ -439,11 +450,11 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
|
||||
# we have to insert guards when a dict method is accessed. For this to
|
||||
# be simple, we are conservative and overguard. We skip guard only for
|
||||
@ -462,7 +473,7 @@ class ConstDictVariable(VariableTracker):
|
||||
tx, *args, **kwargs
|
||||
)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.update(temp_dict_vt.items)
|
||||
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "__getitem__":
|
||||
# Key guarding - Nothing to do. LazyVT for value will take care.
|
||||
@ -526,7 +537,7 @@ class ConstDictVariable(VariableTracker):
|
||||
return ConstantVariable.create(len(self.items))
|
||||
elif name == "__setitem__" and self.is_mutable():
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
if kwargs or len(args) != 2:
|
||||
@ -550,7 +561,7 @@ class ConstDictVariable(VariableTracker):
|
||||
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
if args[0] not in self:
|
||||
self.install_dict_contains_guard(tx, args)
|
||||
@ -565,7 +576,7 @@ class ConstDictVariable(VariableTracker):
|
||||
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
if args[0] not in self:
|
||||
# missing item, return the default value. Install no DICT_CONTAINS guard.
|
||||
@ -599,7 +610,7 @@ class ConstDictVariable(VariableTracker):
|
||||
last = v.value
|
||||
else:
|
||||
raise_args_mismatch(tx, name)
|
||||
k, v = self.items.popitem(last=last)
|
||||
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
|
||||
else:
|
||||
k, v = self.items.popitem()
|
||||
|
||||
@ -632,17 +643,17 @@ class ConstDictVariable(VariableTracker):
|
||||
# NB - Guard on all the keys of the other dict to ensure
|
||||
# correctness.
|
||||
args[0].install_dict_keys_match_guard()
|
||||
dict_vt = args[0]
|
||||
dict_vt: ConstDictVariable = args[0]
|
||||
else:
|
||||
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
|
||||
self.items.update(dict_vt.items)
|
||||
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
|
||||
self.items.update(dict_vt.items) # type: ignore[attr-defined]
|
||||
if has_kwargs:
|
||||
# Handle kwargs
|
||||
kwargs = {
|
||||
kwargs_hashable = {
|
||||
Hashable(ConstantVariable.create(k)): v
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
self.items.update(kwargs)
|
||||
self.items.update(kwargs_hashable)
|
||||
return ConstantVariable.create(None)
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
@ -656,7 +667,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_contains_guard(tx, args)
|
||||
contains = args[0] in self
|
||||
@ -671,7 +682,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
if kwargs or len(args) > 2:
|
||||
@ -707,7 +718,7 @@ class ConstDictVariable(VariableTracker):
|
||||
and "last" in kwargs
|
||||
and isinstance(kwargs["last"], ConstantVariable)
|
||||
):
|
||||
last = kwargs.get("last").value
|
||||
last = kwargs.get("last").value # type: ignore[union-attr]
|
||||
|
||||
key = Hashable(args[0])
|
||||
self.items.move_to_end(key, last=last)
|
||||
@ -723,7 +734,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
elif name == "__ne__":
|
||||
return ConstantVariable.create(
|
||||
not self.call_method(tx, "__eq__", args, kwargs).value
|
||||
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
|
||||
)
|
||||
elif name == "__or__":
|
||||
if len(args) != 1:
|
||||
@ -750,14 +761,14 @@ class ConstDictVariable(VariableTracker):
|
||||
if not istype(
|
||||
other, (ConstDictVariable, variables.UserDefinedDictVariable)
|
||||
):
|
||||
msg = (
|
||||
err_msg = (
|
||||
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
|
||||
f"and '{other.python_type().__name__}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
raise_observed_exception(TypeError, tx, args=[err_msg])
|
||||
|
||||
# OrderedDict overloads __ror__
|
||||
ts = {self.user_cls, other.user_cls}
|
||||
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
|
||||
user_cls = (
|
||||
collections.OrderedDict
|
||||
if any(issubclass(t, collections.OrderedDict) for t in ts)
|
||||
@ -774,8 +785,8 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
# NB - Guard on all the keys of the other dict to ensure
|
||||
# correctness.
|
||||
args[0].install_dict_keys_match_guard()
|
||||
new_dict_vt.items.update(args[0].items)
|
||||
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
|
||||
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
|
||||
return new_dict_vt
|
||||
elif name == "__ior__":
|
||||
self.call_method(tx, "update", args, kwargs)
|
||||
@ -789,11 +800,13 @@ class ConstDictVariable(VariableTracker):
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
self.install_dict_keys_match_guard()
|
||||
return [x.vt for x in self.items.keys()]
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
# dict not allow setting arbitrary attributes. OrderedDict and
|
||||
# defaultdict allow arbitrary setattr, but not deletion of default attrs
|
||||
if any(
|
||||
@ -816,25 +829,25 @@ class ConstDictVariable(VariableTracker):
|
||||
],
|
||||
)
|
||||
|
||||
def clone(self, **kwargs):
|
||||
def clone(self, **kwargs: Any) -> VariableTracker:
|
||||
self.install_dict_keys_match_guard()
|
||||
return super().clone(**kwargs)
|
||||
|
||||
|
||||
class MappingProxyVariable(VariableTracker):
|
||||
# proxies to the original dict_vt
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(dv_dict, ConstDictVariable)
|
||||
self.dv_dict = dv_dict
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return types.MappingProxyType
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
return self.dv_dict.unpack_var_sequence(tx)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# load types.MappingProxyType
|
||||
if self.source:
|
||||
msg = (
|
||||
@ -863,11 +876,11 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if self.source and tx.output.side_effects.has_existing_dict_mutation():
|
||||
msg = (
|
||||
"A dict has been modified while we have an existing mappingproxy object. "
|
||||
@ -892,7 +905,7 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if self.python_type() is types.MappingProxyType:
|
||||
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
@ -900,35 +913,44 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
class NNModuleHooksDictVariable(ConstDictVariable):
|
||||
# Special class to avoid adding any guards on the nn module hook ids.
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DefaultDictVariable(ConstDictVariable):
|
||||
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
items: dict[VariableTracker, VariableTracker],
|
||||
user_cls: type,
|
||||
default_factory: Optional[VariableTracker] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, user_cls, **kwargs)
|
||||
assert user_cls is collections.defaultdict
|
||||
if default_factory is None:
|
||||
default_factory = ConstantVariable.create(None)
|
||||
self.default_factory = default_factory
|
||||
|
||||
def is_python_constant(self):
|
||||
def is_python_constant(self) -> bool:
|
||||
# Return false for unsupported defaults. This ensures that a bad handler
|
||||
# path is not taken in BuiltinVariable for getitem.
|
||||
if self.default_factory not in [list, tuple, dict] and not self.items:
|
||||
return False
|
||||
return super().is_python_constant()
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
assert self.default_factory is not None
|
||||
return (
|
||||
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_supported_arg(arg):
|
||||
def is_supported_arg(arg: VariableTracker) -> bool:
|
||||
if isinstance(arg, variables.BuiltinVariable):
|
||||
return arg.fn in (list, tuple, dict, set)
|
||||
else:
|
||||
@ -942,11 +964,11 @@ class DefaultDictVariable(ConstDictVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__getitem__":
|
||||
if len(args) != 1:
|
||||
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
|
||||
@ -962,13 +984,13 @@ class DefaultDictVariable(ConstDictVariable):
|
||||
else:
|
||||
default_var = self.default_factory.call_function(tx, [], {})
|
||||
super().call_method(
|
||||
tx, "__setitem__", (args[0], default_var), kwargs
|
||||
tx, "__setitem__", [args[0], default_var], kwargs
|
||||
)
|
||||
return default_var
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# emit `defaultdict(default_factory, new_dict)`
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
@ -994,40 +1016,48 @@ class SetVariable(ConstDictVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# pyrefly: ignore[bad-assignment]
|
||||
items = dict.fromkeys(items, SetVariable._default_value())
|
||||
# pyrefly: ignore[bad-argument-type]
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "set()"
|
||||
else:
|
||||
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
|
||||
return set(self.items.keys())
|
||||
|
||||
@staticmethod
|
||||
def _default_value():
|
||||
def _default_value() -> VariableTracker:
|
||||
# Variable to fill in the keys of the dictionary
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> Any:
|
||||
return {k.vt.as_proxy() for k in self.set_items}
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return set
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return {k.vt.as_python_constant() for k in self.set_items}
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
|
||||
|
||||
def _fast_set_method(self, tx, fn, args, kwargs):
|
||||
def _fast_set_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
fn: Any,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
try:
|
||||
res = fn(
|
||||
*[x.as_python_constant() for x in [self, *args]],
|
||||
@ -1037,15 +1067,16 @@ class SetVariable(ConstDictVariable):
|
||||
raise_observed_exception(
|
||||
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
|
||||
)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
return VariableTracker.build(tx, res)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
# We forward the calls to the dictionary model
|
||||
from ..utils import check_constant_args
|
||||
|
||||
@ -1065,10 +1096,10 @@ class SetVariable(ConstDictVariable):
|
||||
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
|
||||
|
||||
if name == "__init__":
|
||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
|
||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.clear()
|
||||
self.items.update(temp_set_vt.items)
|
||||
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "add":
|
||||
if kwargs or len(args) != 1:
|
||||
@ -1079,7 +1110,7 @@ class SetVariable(ConstDictVariable):
|
||||
f"{len(args)} args and {len(kwargs)} kwargs",
|
||||
)
|
||||
name = "__setitem__"
|
||||
args = (args[0], SetVariable._default_value())
|
||||
args = [args[0], SetVariable._default_value()]
|
||||
elif name == "pop":
|
||||
if kwargs or args:
|
||||
raise_args_mismatch(
|
||||
@ -1090,12 +1121,14 @@ class SetVariable(ConstDictVariable):
|
||||
)
|
||||
# Choose an item at random and pop it via the Dict.pop method
|
||||
try:
|
||||
result = self.set_items.pop().vt
|
||||
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
|
||||
except KeyError as e:
|
||||
raise_observed_exception(
|
||||
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
|
||||
)
|
||||
super().call_method(tx, name, (result,), kwargs)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
super().call_method(tx, name, [result], kwargs)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
return result
|
||||
elif name == "isdisjoint":
|
||||
if kwargs or len(args) != 1:
|
||||
@ -1217,6 +1250,7 @@ class SetVariable(ConstDictVariable):
|
||||
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
assert m is not None
|
||||
return self.call_method(tx, m, args, kwargs)
|
||||
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
@ -1230,29 +1264,34 @@ class SetVariable(ConstDictVariable):
|
||||
"__ixor__": "symmetric_difference_update",
|
||||
"__isub__": "difference_update",
|
||||
}.get(name)
|
||||
assert m is not None
|
||||
self.call_method(tx, m, args, kwargs)
|
||||
return self
|
||||
elif name == "__eq__":
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
return ConstantVariable.create(False)
|
||||
r = self.call_method(tx, "symmetric_difference", args, kwargs)
|
||||
return ConstantVariable.create(len(r.set_items) == 0)
|
||||
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
raise RuntimeError("Illegal to getitem on a set")
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
super().install_dict_contains_guard(tx, args)
|
||||
|
||||
|
||||
@ -1260,27 +1299,27 @@ class FrozensetVariable(SetVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "frozenset()"
|
||||
else:
|
||||
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
|
||||
return self.items.keys()
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return frozenset
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return frozenset({k.vt.as_python_constant() for k in self.set_items})
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
@ -1293,11 +1332,11 @@ class FrozensetVariable(SetVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
|
||||
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
|
||||
elif name == "__init__":
|
||||
@ -1316,7 +1355,7 @@ class FrozensetVariable(SetVariable):
|
||||
"symmetric_difference",
|
||||
):
|
||||
r = super().call_method(tx, name, args, kwargs)
|
||||
return FrozensetVariable(r.items)
|
||||
return FrozensetVariable(r.items) # type: ignore[attr-defined]
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
@ -1324,11 +1363,11 @@ class DictKeySetVariable(SetVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "dict_keys([])"
|
||||
else:
|
||||
@ -1338,33 +1377,35 @@ class DictKeySetVariable(SetVariable):
|
||||
+ "])"
|
||||
)
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> Any:
|
||||
return self.items
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_keys
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return dict.fromkeys(
|
||||
{k.vt.as_python_constant() for k in self.set_items}, None
|
||||
).keys()
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
|
||||
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
@ -1379,42 +1420,47 @@ class DictViewVariable(VariableTracker):
|
||||
|
||||
kv: Optional[str] = None
|
||||
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert self.kv in ("keys", "values", "items")
|
||||
assert isinstance(dv_dict, ConstDictVariable)
|
||||
self.dv_dict = dv_dict
|
||||
|
||||
@property
|
||||
def view_items(self):
|
||||
def view_items(self) -> Any:
|
||||
assert self.kv is not None
|
||||
return getattr(self.dv_dict.items, self.kv)()
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
# Implement in the subclasses
|
||||
raise NotImplementedError
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
return self.view_items_vt
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
assert self.kv is not None
|
||||
codegen(self.dv_dict)
|
||||
codegen.load_method(self.kv)
|
||||
codegen.call_method(0)
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
assert self.kv is not None
|
||||
if name in self.python_type().__dict__:
|
||||
return ConstantVariable.create(True)
|
||||
return ConstantVariable.create(False)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__len__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
elif name == "__iter__":
|
||||
@ -1428,24 +1474,24 @@ class DictKeysVariable(DictViewVariable):
|
||||
kv = "keys"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set[VariableTracker]:
|
||||
return set(self.view_items)
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
return [x.vt for x in self.view_items]
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_keys
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__contains__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
elif name in (
|
||||
@ -1460,13 +1506,13 @@ class DictKeysVariable(DictViewVariable):
|
||||
):
|
||||
# These methods always return a set
|
||||
m = getattr(self.set_items, name)
|
||||
r = m(args[0].set_items)
|
||||
r = m(args[0].set_items) # type: ignore[attr-defined]
|
||||
return SetVariable(r)
|
||||
if name in cmp_name_to_op_mapping:
|
||||
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
@ -1476,10 +1522,10 @@ class DictValuesVariable(DictViewVariable):
|
||||
kv = "values"
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
return list(self.view_items)
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_values
|
||||
|
||||
|
||||
@ -1487,14 +1533,20 @@ class DictItemsVariable(DictViewVariable):
|
||||
kv = "items"
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_items
|
||||
|
||||
def call_method(self, tx, name, args, kwargs):
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
# TODO(guilhermeleobas): This should actually check if args[0]
|
||||
# implements the mapping protocol.
|
||||
if name == "__eq__":
|
||||
|
||||
File diff suppressed because it is too large
@ -586,7 +586,7 @@ class FilterVariable(IteratorVariable):
|
||||
else:
|
||||
res = self.fn.call_function(tx, [item], {})
|
||||
pred_res = variables.UserFunctionVariable(
|
||||
polyfills.predicate
|
||||
polyfills.predicate # type: ignore[arg-type]
|
||||
).call_function(tx, [res], {})
|
||||
if pred_res.as_python_constant():
|
||||
return item
|
||||
|
||||
@ -472,7 +472,12 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
|
||||
)
|
||||
elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined]
|
||||
name_to_arg_map = bind_args_cached(
|
||||
self.value, tx, self.source, args, kwargs
|
||||
# pyrefly: ignore[bad-argument-type]
|
||||
self.value,
|
||||
tx,
|
||||
self.source,
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
backends = name_to_arg_map["backends"].as_python_constant()
|
||||
set_priority = name_to_arg_map["set_priority"].as_python_constant()
|
||||
@ -1349,7 +1354,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
packed_input_vt = TupleVariable.build(
|
||||
tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
|
||||
)
|
||||
out_vt = variables.UserFunctionVariable(tree_flatten).call_function(
|
||||
out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type]
|
||||
tx, [packed_input_vt], {}
|
||||
)
|
||||
assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2
|
||||
|
||||
@ -2970,6 +2970,12 @@ class CppPythonBindingsCodeCache(CppCodeCache):
|
||||
throw std::runtime_error("expected int arg");
|
||||
return reinterpret_cast<uintptr_t>(result);
|
||||
}}
|
||||
template <> inline float parse_arg<float>(PyObject* args, size_t n) {{
|
||||
auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
|
||||
if(unlikely(result == -1.0 && PyErr_Occurred()))
|
||||
throw std::runtime_error("expected float arg");
|
||||
return static_cast<float>(result);
|
||||
}}
|
||||
|
||||
{extra_parse_arg}
|
||||
|
||||
|
||||
@ -1732,9 +1732,15 @@ class KernelArgs:
|
||||
call_args.append(self.wrap_ptr_arg(outer, dtype))
|
||||
arg_types.append(f"{cpp_dtype}*")
|
||||
for outer, inner in self.sizevars.items():
|
||||
arg_defs.append(f"const {INDEX_TYPE} {inner}")
|
||||
if isinstance(outer, sympy.Symbol) and symbol_is_type(
|
||||
outer, (SymT.UNBACKED_FLOAT)
|
||||
):
|
||||
arg_defs.append(f"const float {inner}")
|
||||
arg_types.append("const float")
|
||||
else:
|
||||
arg_defs.append(f"const {INDEX_TYPE} {inner}")
|
||||
arg_types.append(f"const {INDEX_TYPE}")
|
||||
call_args.append(self.wrap_size_arg(outer))
|
||||
arg_types.append(f"const {INDEX_TYPE}")
|
||||
if V.graph.wrapper_code:
|
||||
V.graph.wrapper_code.ensure_size_computed(outer)
|
||||
assert not self.workspace_args, "Workspace not supported on CPU "
|
||||
@ -2353,6 +2359,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
||||
SymT.UNBACKED_INT,
|
||||
SymT.SIZE,
|
||||
SymT.PRECOMPUTED_SIZE,
|
||||
SymT.UNBACKED_FLOAT,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Optional
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
|
||||
from .. import config
|
||||
from ..runtime.hints import AttrsDescriptorWrapper
|
||||
@ -71,6 +72,10 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
|
||||
return "constexpr"
|
||||
elif isinstance(arg.expr, (float, sympy.Float)):
|
||||
return "fp32"
|
||||
elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type(
|
||||
arg.expr, (SymT.UNBACKED_FLOAT)
|
||||
):
|
||||
return "fp32"
|
||||
elif isinstance(arg.expr, bool):
|
||||
return "i1"
|
||||
|
||||
|
||||
@ -546,10 +546,6 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
|
||||
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
|
||||
).upper() # type: ignore[assignment]
|
||||
|
||||
cutedsl_enable_autotuning: bool = (
|
||||
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
|
||||
)
|
||||
|
||||
# DEPRECATED. This setting is ignored.
|
||||
autotune_fallback_to_aten = False
|
||||
|
||||
@ -678,6 +674,17 @@ loop_ordering_after_fusion: bool = (
|
||||
== "1"
|
||||
)
|
||||
|
||||
|
||||
# When trying to fuse two nodes, one with:
#   a[contiguous_writes] = fn(...)
# and another node:
#   b[contiguous_writes] = a[discontiguous_reads]
# If b is unary, and we can figure out an inverse formula for the
# discontiguous reads, invert b as:
#   b[inverse(discontiguous_reads)] = a[contiguous_reads]
# so that the nodes can fuse. For more details: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9
loop_index_inversion_in_fusion: bool = True
|
||||
|
||||
# If fusing two nodes only save less then score_fusion_memory_threshold memory,
|
||||
# we should not bother fusing the nodes.
|
||||
#
|
||||
|
||||
208
torch/_inductor/invert_expr_analysis.py
Normal file
@ -0,0 +1,208 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import sympy
|
||||
|
||||
from torch._inductor.utils import _IntLike, argsort_sym
|
||||
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
|
||||
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
def static_eq(a: _IntLike, b: _IntLike) -> bool:
|
||||
return V.graph.sizevars.statically_known_equals(a, b)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Term:
|
||||
coefficient: _IntLike
|
||||
range: Optional[_IntLike] # None for unbounded
|
||||
original_expr: sympy.Expr
|
||||
reconstruction_multiplier: _IntLike # The multiplier needed for reconstruction
|
||||
|
||||
|
||||
def generate_inverse_formula(
    expr: sympy.Expr, var: sympy.Symbol
) -> Optional[sympy.Expr]:
    """
    Analyze an expression to see if it matches a specific invertible pattern that we
    know how to reverse.

    We're looking for expressions that are sums of terms where each term extracts a
    distinct bounded range from the input variable, like:

        y = c₀*a₀ + c₁*a₁ + c₂*a₂ + ... + cₙ*aₙ

    where each aᵢ must be one of these specific patterns:
    - ModularIndexing(var, divisor, modulo)
    - FloorDiv(ModularIndexing(var, 1, modulo), divisor)
    - FloorDiv(var, divisor)
    - var (the variable itself)

    The key pattern we need is:
    - Coefficients are strictly decreasing: c₀ > c₁ > c₂ > ... > cₙ
    - Each coefficient matches the product of ranges of later terms (mixed-radix property)
    - Each term extracts a bounded range, creating non-overlapping "slots"

    If we find this pattern, we can generate the reconstruction transformation that
    decomposes the variable and rebuilds it using the correct multipliers.

    EXAMPLE:
    Input: 100*((p//100)) + 10*((p%100)//10) + (p%10)

    Returns the reconstruction expression:
        remainder₀ = p
        component₀ = remainder₀ // 100  # hundreds digit
        remainder₁ = remainder₀ % 100
        component₁ = remainder₁ // 10   # tens digit
        remainder₂ = remainder₁ % 10
        component₂ = remainder₂         # ones digit
        result = component₀*100 + component₁*10 + component₂*1

    This decomposes p into its components and rebuilds it using the original
    multipliers, which should equal the input expression.

    Args:
        expr: Expression to analyze (sum of terms with ModularIndexing, FloorDiv, etc.)
        var: The variable being decomposed

    Returns:
        None if not invertible, or the reconstruction expression

    References:
        Mixed-radix systems: https://en.wikipedia.org/wiki/Mixed_radix
    """
|
||||
# Step 1: Parse all terms
|
||||
terms = parse_terms(expr, var)
|
||||
if not terms:
|
||||
return None
|
||||
|
||||
# Step 2: Sort by coefficient (descending)
|
||||
coeffs = [t.coefficient for t in terms]
|
||||
idxs = reversed(argsort_sym(V.graph.sizevars.shape_env, coeffs))
|
||||
terms = [terms[i] for i in idxs]
|
||||
|
||||
# Step 3: Check invertibility conditions
|
||||
if not check_invertibility(terms):
|
||||
return None
|
||||
|
||||
return generate_reconstruction_expr(terms, var)
|
||||
|
||||
|
||||
def parse_terms(expr: sympy.Expr, var: sympy.Symbol) -> Optional[list[Term]]:
|
||||
"""Parse expression into terms."""
|
||||
if not isinstance(expr, sympy.Add):
|
||||
# Single term
|
||||
term = parse_single_term(expr, var)
|
||||
return [term] if term else []
|
||||
|
||||
terms = []
|
||||
for arg in expr.args:
|
||||
term = parse_single_term(arg, var)
|
||||
if term:
|
||||
terms.append(term)
|
||||
else:
|
||||
return None # If any term fails to parse, fail completely
|
||||
|
||||
return terms
|
||||
|
||||
|
||||
def parse_single_term(term: sympy.Expr, var: sympy.Symbol) -> Optional[Term]:
|
||||
"""Parse a single term and extract coefficient, range, and reconstruction multiplier."""
|
||||
# Extract coefficient and expression parts
|
||||
coefficient, expr_parts = term.as_coeff_mul()
|
||||
|
||||
if len(expr_parts) == 0:
|
||||
# Pure constant term
|
||||
return Term(
|
||||
coefficient=coefficient,
|
||||
range=1,
|
||||
original_expr=1,
|
||||
reconstruction_multiplier=0,
|
||||
)
|
||||
elif len(expr_parts) == 1:
|
||||
expr = expr_parts[0]
|
||||
else:
|
||||
# Multiple non-constant factors, too complex
|
||||
return None
|
||||
|
||||
# Now determine the range and reconstruction multiplier
|
||||
range_val, reconstruction_multiplier = analyze_expression_properties(expr, var)
|
||||
if reconstruction_multiplier is None:
|
||||
return None
|
||||
|
||||
return Term(
|
||||
coefficient=coefficient,
|
||||
range=range_val,
|
||||
original_expr=expr,
|
||||
reconstruction_multiplier=reconstruction_multiplier,
|
||||
)
|
||||
|
||||
|
||||
def analyze_expression_properties(
|
||||
expr: sympy.Expr, var: sympy.Symbol
|
||||
) -> tuple[Optional[_IntLike], Optional[_IntLike]]:
|
||||
"""Analyze an expression to determine its range and reconstruction multiplier."""
|
||||
# ModularIndexing(var, divisor, modulo) = (var // divisor) % modulo
|
||||
if isinstance(expr, ModularIndexing):
|
||||
x, div, mod = expr.args
|
||||
if static_eq(x, var):
|
||||
return mod, div # Range is mod, multiplier is div
|
||||
|
||||
# FloorDiv cases
|
||||
if isinstance(expr, FloorDiv):
|
||||
base, divisor = expr.args
|
||||
|
||||
# FloorDiv(ModularIndexing(var, 1, mod), div) = (var % mod) // div
|
||||
if isinstance(base, ModularIndexing):
|
||||
x, inner_div, mod = base.args
|
||||
if static_eq(x, var) and static_eq(inner_div, 1):
|
||||
range_val = FloorDiv(mod, divisor)
|
||||
return range_val, divisor # Range is mod//div, multiplier is div
|
||||
|
||||
# FloorDiv(var, divisor) = var // divisor (unbounded)
|
||||
elif static_eq(base, var):
|
||||
return None, divisor # Unbounded range, multiplier is div
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def check_invertibility(terms: list[Term]) -> bool:
|
||||
"""Check if the terms represent an invertible transformation."""
|
||||
if not terms:
|
||||
return False
|
||||
|
||||
# Coefficients must be strictly decreasing
|
||||
coeffs = [t.coefficient for t in terms]
|
||||
if argsort_sym(V.graph.sizevars.shape_env, coeffs) != list(
|
||||
reversed(range(len(coeffs)))
|
||||
):
|
||||
return False
|
||||
|
||||
# Check mixed-radix property: each coeff[i] = coeff[i+1] * range[i+1]
|
||||
expected_coeff = 1
|
||||
for term in reversed(terms):
|
||||
if not static_eq(term.coefficient, expected_coeff):
|
||||
return False
|
||||
if term.range is not None:
|
||||
expected_coeff *= term.range
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def generate_reconstruction_expr(terms: list[Term], var: sympy.Symbol) -> sympy.Expr:
|
||||
y = var
|
||||
reconstruction = sympy.S.Zero
|
||||
remainder = y
|
||||
|
||||
for i, term in enumerate(terms):
|
||||
if i < len(terms) - 1:
|
||||
component = FloorDiv(remainder, term.coefficient)
|
||||
remainder = ModularIndexing(remainder, 1, term.coefficient)
|
||||
else:
|
||||
# Last term should also divide by its coefficient
|
||||
component = FloorDiv(remainder, term.coefficient)
|
||||
|
||||
reconstruction += component * term.reconstruction_multiplier
|
||||
|
||||
return reconstruction
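The docstring's digit example can be checked end to end with plain Python integers. The sketch below is illustrative only and is not part of invert_expr_analysis.py; it assumes the identity-like case where each term's reconstruction multiplier equals its coefficient (as in the 100/10/1 example above), whereas the general path uses each Term's reconstruction_multiplier.

# Minimal sketch: peel off mixed-radix "digits" with strictly decreasing
# coefficients (each equal to the product of the ranges of the later terms),
# then rebuild the value the same way generate_reconstruction_expr does.
def decompose_and_rebuild(p: int, coefficients: list[int]) -> int:
    remainder = p
    result = 0
    for i, coeff in enumerate(coefficients):
        component = remainder // coeff     # FloorDiv(remainder, coeff)
        if i < len(coefficients) - 1:
            remainder = remainder % coeff  # ModularIndexing(remainder, 1, coeff)
        result += component * coeff        # general case: * reconstruction_multiplier
    return result

# 100*(p//100) + 10*((p%100)//10) + (p%10) round-trips back to p.
assert all(decompose_and_rebuild(p, [100, 10, 1]) == p for p in range(1000))

In the real pass the same round trip is expressed symbolically with FloorDiv and ModularIndexing, so the scheduler can rewrite node indexing without materializing values.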
@ -1,8 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@ -14,7 +12,6 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
|
||||
from .. import config
|
||||
from ..codegen.wrapper import PythonWrapperCodegen
|
||||
from ..ir import _IntLike, Layout, TensorBox
|
||||
from ..utils import load_template
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
|
||||
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
|
||||
from torch._inductor.runtime.triton_compat import tl
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.utils._triton import has_triton
|
||||
@ -19,25 +18,19 @@ from ..select_algorithm import (
|
||||
TritonTemplate,
|
||||
)
|
||||
from ..utils import (
|
||||
ensure_cute_available,
|
||||
get_gpu_shared_memory,
|
||||
get_num_sms,
|
||||
has_free_symbols,
|
||||
use_aten_gemm_kernels,
|
||||
use_blackwell_cutedsl_grouped_mm,
|
||||
use_triton_template,
|
||||
)
|
||||
from .mm_common import (
|
||||
_is_static_problem,
|
||||
check_supported_striding,
|
||||
load_kernel_template,
|
||||
persistent_grouped_mm_grid,
|
||||
)
|
||||
|
||||
|
||||
if ensure_cute_available():
|
||||
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
aten = torch.ops.aten
|
||||
|
||||
@ -520,11 +513,6 @@ triton_scaled_grouped_mm_template = TritonTemplate(
|
||||
source=triton_grouped_mm_source,
|
||||
)
|
||||
|
||||
cutedsl_grouped_mm_template = CuteDSLTemplate(
|
||||
name="grouped_gemm_cutedsl",
|
||||
source=load_kernel_template("cutedsl_mm_grouped"),
|
||||
)
|
||||
|
||||
|
||||
def grouped_mm_args(
|
||||
mat1: TensorBox,
|
||||
@ -726,44 +714,43 @@ def _tuned_grouped_mm_common(
|
||||
# Checking only for the equality of corresponding dims of
|
||||
# multiplicands here, relying on meta function checks for
|
||||
# everything else.
|
||||
if len(m1_size) == 2:
|
||||
if len(m2_size) == 2:
|
||||
m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g = offs.get_size()[0]
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, True
|
||||
else:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, False
|
||||
else:
|
||||
if len(m2_size) == 2:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
g2, m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, True
|
||||
else:
|
||||
g1, m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, False
|
||||
|
||||
if (
|
||||
is_nonzero
|
||||
and use_triton_template(layout)
|
||||
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
|
||||
):
|
||||
scaled = scale_a is not None
|
||||
if len(m1_size) == 2:
|
||||
if len(m2_size) == 2:
|
||||
m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g = offs.get_size()[0]
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, True
|
||||
else:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, False
|
||||
else:
|
||||
if len(m2_size) == 2:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
g2, m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, True
|
||||
else:
|
||||
g1, m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, False
|
||||
|
||||
a_is_k_major = mat_a.get_stride()[-1] == 1
|
||||
b_is_k_major = mat_b.get_stride()[-2] == 1
|
||||
@ -801,22 +788,6 @@ def _tuned_grouped_mm_common(
|
||||
**config.kwargs,
|
||||
)
|
||||
|
||||
if use_blackwell_cutedsl_grouped_mm(
|
||||
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
|
||||
):
|
||||
for config in get_groupgemm_configs():
|
||||
kwargs = dict(
|
||||
ACC_DTYPE="cutlass.Float32",
|
||||
)
|
||||
|
||||
cutedsl_grouped_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
**asdict(config),
|
||||
)
|
||||
|
||||
input_gen_fns = {
|
||||
4: lambda x: create_offsets(
|
||||
x, m1_size, m2_size, offs.get_size() if offs is not None else None
|
||||
|
||||
@ -1,333 +0,0 @@
|
||||
import functools
|
||||
from torch._inductor.runtime.runtime_utils import ceildiv
|
||||
from cutlass.utils import TensorMapUpdateMode
|
||||
{{gen_defines()}}
|
||||
# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
|
||||
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
|
||||
GroupedGemmKernel,
|
||||
)
|
||||
|
||||
|
||||
# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
#    Caches the compiled executor for build_group_ptrs_from_bases(). This
#    kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
#    and can therefore be safely reused across runs with different group
#    partitioning (`offs`).
#
# 2. gemm_cache
#    Caches the compiled Grouped GEMM executor. Its key extends the prep
#    cache key with hardware- and grid-specific parameters:
#    (prep_cache_key, max_active_clusters, total_num_clusters).
#    This is necessary because different `offs` tensors can change the
#    per-group problem sizes and thus alter `total_num_clusters`, which in
#    turn changes the grid shape and persistent scheduler configuration.
#    Kernels compiled for one grid cannot be safely reused for another.
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.
|
||||
|
||||
prep_cache = {}
|
||||
gemm_cache = {}
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_hardware_info():
|
||||
hw = cutlass.utils.HardwareInfo()
|
||||
sm_count = hw.get_max_active_clusters(1)
|
||||
max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)
|
||||
|
||||
return (sm_count, max_active_clusters)
|
||||
|
||||
|
||||
def get_prep_cache_key(input_a, input_b, output):
|
||||
"""
|
||||
Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
|
||||
shapes, strides, and dtypes of input/output tensors.
|
||||
"""
|
||||
return (
|
||||
tuple(input_a.shape),
|
||||
tuple(input_a.stride()),
|
||||
input_a.dtype,
|
||||
tuple(input_b.shape),
|
||||
tuple(input_b.stride()),
|
||||
input_b.dtype,
|
||||
tuple(output.shape),
|
||||
tuple(output.stride()),
|
||||
output.dtype,
|
||||
)
|
||||
|
||||
|
||||
def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
|
||||
"""
|
||||
Returns a tuple key for caching the gemm kernel executor by extending the
|
||||
prep cache key with hardware- and grid-specific parameters.
|
||||
"""
|
||||
return (
|
||||
prep_cache_key,
|
||||
max_active_clusters,
|
||||
total_num_clusters,
|
||||
)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def build_group_ptrs_from_bases_kernel(
|
||||
base_A_u64: cutlass.Int64, # device addr of input_a (bytes)
|
||||
base_B_u64: cutlass.Int64, # device addr of input_b (bytes)
|
||||
base_C_u64: cutlass.Int64, # device addr of Output (bytes)
|
||||
offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative
|
||||
K: cutlass.Constexpr,
|
||||
N: cutlass.Constexpr,
|
||||
sizeof_element: cutlass.Int32, # bytes
|
||||
# -------- STRIDES (in ELEMENTS) --------
|
||||
stride_A_m_elems: cutlass.Constexpr, # A.stride(0)
|
||||
stride_A_k_elems: cutlass.Constexpr, # A.stride(1)
|
||||
stride_B0_elems: cutlass.Constexpr, # B.stride(0)
|
||||
stride_Bk_elems: cutlass.Constexpr, # B.stride(1)
|
||||
stride_Bn_elems: cutlass.Constexpr, # B.stride(2)
|
||||
stride_C_m_elems: cutlass.Constexpr, # C.stride(0)
|
||||
stride_C_n_elems: cutlass.Constexpr, # C.stride(1)
|
||||
# -------- OUTPUTS --------
|
||||
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
|
||||
out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1)
|
||||
out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
g = tidx
|
||||
|
||||
m_beg_i32 = 0
|
||||
if g > 0:
|
||||
m_beg_i32 = offs[g - 1]
|
||||
m_end_i32 = offs[g]
|
||||
m_g_i32 = m_end_i32 - m_beg_i32
|
||||
|
||||
a_byte_off = (
|
||||
cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
|
||||
)
|
||||
c_byte_off = (
|
||||
cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
|
||||
)
|
||||
b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)
|
||||
|
||||
# ---- pointers ----
|
||||
out_ptrs[g, 0] = base_A_u64 + a_byte_off
|
||||
out_ptrs[g, 1] = base_B_u64 + b_byte_off
|
||||
out_ptrs[g, 2] = base_C_u64 + c_byte_off
|
||||
|
||||
# ---- (m, n, k, 1) ----
|
||||
out_problem[g, 0] = m_g_i32
|
||||
out_problem[g, 1] = N
|
||||
out_problem[g, 2] = K
|
||||
out_problem[g, 3] = cutlass.Int32(1)
|
||||
|
||||
# ---- strides ----
|
||||
out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
|
||||
out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
|
||||
out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
|
||||
out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
|
||||
out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
|
||||
out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def launch_build_group_ptrs_from_bases(
|
||||
base_A_u64: cutlass.Int64,
|
||||
base_B_u64: cutlass.Int64,
|
||||
base_C_u64: cutlass.Int64,
|
||||
offs: cute.Tensor,
|
||||
G: cutlass.Constexpr,
|
||||
K: cutlass.Constexpr,
|
||||
N: cutlass.Constexpr,
|
||||
sizeof_element: cutlass.Constexpr,
|
||||
stride_A_m_elems: cutlass.Constexpr,
|
||||
stride_A_k_elems: cutlass.Constexpr,
|
||||
stride_B0_elems: cutlass.Constexpr,
|
||||
stride_Bk_elems: cutlass.Constexpr,
|
||||
stride_Bn_elems: cutlass.Constexpr,
|
||||
stride_C_m_elems: cutlass.Constexpr,
|
||||
stride_C_n_elems: cutlass.Constexpr,
|
||||
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64
|
||||
out_problem: cute.Tensor, # [G,4] cutlass.Int32
|
||||
out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32
|
||||
stream: cuda.CUstream,
|
||||
):
|
||||
build_group_ptrs_from_bases_kernel(
|
||||
base_A_u64,
|
||||
base_B_u64,
|
||||
base_C_u64,
|
||||
offs,
|
||||
K,
|
||||
N,
|
||||
sizeof_element,
|
||||
stride_A_m_elems,
|
||||
stride_A_k_elems,
|
||||
stride_B0_elems,
|
||||
stride_Bk_elems,
|
||||
stride_Bn_elems,
|
||||
stride_C_m_elems,
|
||||
stride_C_n_elems,
|
||||
out_ptrs,
|
||||
out_problem,
|
||||
out_strides_abc,
|
||||
).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)
|
||||
|
||||
|
||||
{{def_kernel("input_a", "input_b", "input_a_offs")}}
|
||||
stream = cuda.CUstream(stream)
|
||||
|
||||
input_b = input_b.transpose(1, 2)
|
||||
|
||||
sumM, K = input_a.shape
|
||||
G, N, Kb = input_b.shape
|
||||
|
||||
dev = input_a.device
|
||||
|
||||
base_A_u64 = int(input_a.data_ptr())
|
||||
base_B_u64 = int(input_b.data_ptr())
|
||||
base_C_u64 = int({{get_output()}}.data_ptr())
|
||||
|
||||
ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
|
||||
probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
|
||||
strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
|
||||
ptrs = from_dlpack(ptrs_t)
|
||||
probs = from_dlpack(probs_t)
|
||||
strides = from_dlpack(strides_t)
|
||||
|
||||
prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
|
||||
prep_executor = prep_cache.get(prep_cache_key)
|
||||
|
||||
if prep_executor is None:
|
||||
sizeof_element = int(input_a.element_size())
|
||||
sA_m, sA_k = map(int, input_a.stride())
|
||||
sB_0, sB_n, sB_k = map(int, input_b.stride())
|
||||
sC_m, sC_n = map(int, {{get_output()}}.stride())
|
||||
|
||||
prep_executor = cute.compile(
|
||||
launch_build_group_ptrs_from_bases,
|
||||
base_A_u64=base_A_u64,
|
||||
base_B_u64=base_B_u64,
|
||||
base_C_u64=base_C_u64,
|
||||
offs=from_dlpack(input_a_offs),
|
||||
G=int(G),
|
||||
K=int(K),
|
||||
N=int(N),
|
||||
sizeof_element=sizeof_element,
|
||||
stride_A_m_elems=sA_m,
|
||||
stride_A_k_elems=sA_k,
|
||||
stride_B0_elems=sB_0,
|
||||
stride_Bk_elems=sB_k,
|
||||
stride_Bn_elems=sB_n,
|
||||
stride_C_m_elems=sC_m,
|
||||
stride_C_n_elems=sC_n,
|
||||
out_ptrs=ptrs,
|
||||
out_problem=probs,
|
||||
out_strides_abc=strides,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
prep_cache[prep_cache_key] = prep_executor
|
||||
|
||||
prep_executor(
|
||||
base_A_u64=base_A_u64,
|
||||
base_B_u64=base_B_u64,
|
||||
base_C_u64=base_C_u64,
|
||||
offs=from_dlpack(input_a_offs),
|
||||
out_ptrs=ptrs,
|
||||
out_problem=probs,
|
||||
out_strides_abc=strides,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
# --- Tensormap workspace per SM ---
|
||||
num_tensormap_buffers, max_active_clusters = get_hardware_info()
|
||||
tensormap_shape = (
|
||||
num_tensormap_buffers,
|
||||
GroupedGemmKernel.num_tensormaps,
|
||||
GroupedGemmKernel.bytes_per_tensormap // 8,
|
||||
)
|
||||
tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
|
||||
tensormap_workspace = from_dlpack(tensormap_workspace_t)
|
||||
|
||||
# --- Total clusters ---
|
||||
def compute_total_num_clusters(
|
||||
problem_sizes_mnkl,
|
||||
cluster_tile_shape_mn,
|
||||
):
|
||||
total_num_clusters = 0
|
||||
for m, n, _, _ in problem_sizes_mnkl:
|
||||
num_clusters_mn = tuple(
|
||||
ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
|
||||
)
|
||||
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
|
||||
return total_num_clusters
|
||||
|
||||
# Compute cluster tile shape
|
||||
def compute_cluster_tile_shape(
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
use_2cta_instrs,
|
||||
):
|
||||
cta_tile_shape_mn = list(mma_tiler_mn)
|
||||
if use_2cta_instrs:
|
||||
cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
|
||||
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
|
||||
|
||||
cluster_tile_shape_mn = compute_cluster_tile_shape(
|
||||
(TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
|
||||
)
|
||||
|
||||
total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))
|
||||
|
||||
gemm_cache_key = get_gemm_cache_key(
|
||||
prep_cache_key, max_active_clusters, total_num_clusters
|
||||
)
|
||||
gemm_executor = gemm_cache.get(gemm_cache_key)
|
||||
|
||||
if gemm_executor is None:
|
||||
grouped_gemm = GroupedGemmKernel(
|
||||
acc_dtype=ACC_DTYPE,
|
||||
use_2cta_instrs=USE_2_CTA,
|
||||
mma_tiler_mn=(TILE_M, TILE_N),
|
||||
cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
|
||||
tensormap_update_mode=TENSORMAP_UPDATE_MODE,
|
||||
)
|
||||
|
||||
gemm_executor = cute.compile(
|
||||
grouped_gemm,
|
||||
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
|
||||
G,
|
||||
probs,
|
||||
strides,
|
||||
ptrs,
|
||||
total_num_clusters,
|
||||
tensormap_workspace,
|
||||
max_active_clusters,
|
||||
stream,
|
||||
)
|
||||
|
||||
gemm_cache[gemm_cache_key] = gemm_executor
|
||||
|
||||
gemm_executor(
|
||||
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
|
||||
probs,
|
||||
strides,
|
||||
ptrs,
|
||||
tensormap_workspace,
|
||||
stream,
|
||||
)
|
||||
@ -95,7 +95,6 @@ class LoopBody:
|
||||
"""
|
||||
|
||||
indexing_exprs: dict[str, sympy.Expr]
|
||||
indexing_exprs_name: dict[sympy.Expr, str]
|
||||
submodules: dict[str, Any]
|
||||
subblocks: dict[str, LoopBodyBlock]
|
||||
indirect_vars: list[sympy.Symbol]
|
||||
@ -104,6 +103,9 @@ class LoopBody:
|
||||
memory_usage: dict[MemoryUsageType, list[MemoryEntry]]
|
||||
op_counts: collections.Counter[str]
|
||||
|
||||
# defined only temporarily
|
||||
indexing_exprs_name: dict[sympy.Expr, str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fn,
|
||||
|
||||
@ -3345,7 +3345,10 @@ class Scheduler:
|
||||
)
|
||||
break
|
||||
|
||||
if config.loop_ordering_after_fusion:
|
||||
if (
|
||||
config.loop_ordering_after_fusion
|
||||
or config.loop_index_inversion_in_fusion
|
||||
):
|
||||
nodes = self.fuse_nodes_once(nodes, is_reorder_round=True)
|
||||
return nodes
|
||||
|
||||
@ -4302,6 +4305,148 @@ class Scheduler:
|
||||
|
||||
return str(reasons)
|
||||
|
||||
def shared_data_after_inverting_indexing(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> int:
    """
    Attempts to enable fusion between two nodes by inverting indexing patterns.

    This optimization targets cases where node1 has a contiguous write and
    node2 has a contiguous write but a discontiguous read. By inverting the
    indexing in node2's read and write operations, we can make them compatible
    with node1 for potential fusion.

    Args:
        node1: First scheduler node (source)
        node2: Second scheduler node (target for inversion)

    Returns:
        int: Fusion score if successful, -1 if the optimization is not applicable
    """
|
||||
|
||||
if not config.loop_index_inversion_in_fusion:
|
||||
return -1
|
||||
|
||||
if any(n.is_cpu() for n in [node1, node2]):
|
||||
return -1
|
||||
|
||||
# Check for shared buffers between nodes
|
||||
node1_buffer_names = node1.read_writes.buffer_names()
|
||||
node2_buffer_names = node2.read_writes.buffer_names()
|
||||
common_buffer_names = node1_buffer_names & node2_buffer_names
|
||||
|
||||
if not common_buffer_names:
|
||||
return -1
|
||||
|
||||
# Only invert if node1 is node2's single unmet dependency
|
||||
node2_unmet_dependencies = OrderedSet(
|
||||
dep.name for dep in node2.unmet_dependencies
|
||||
)
|
||||
if node2_unmet_dependencies - node1_buffer_names:
|
||||
return -1
|
||||
|
||||
if len(node2_unmet_dependencies) > 1:
|
||||
return -1
|
||||
|
||||
# Currently only handle single read/write operations
|
||||
if len(node2.read_writes.reads) > 1 or len(node2.read_writes.writes) > 1:
|
||||
return -1
|
||||
|
||||
node2_read = next(iter(node2.read_writes.reads))
|
||||
node2_write = next(iter(node2.read_writes.writes))
|
||||
|
||||
if not isinstance(node2_read, MemoryDep) or not isinstance(
|
||||
node2_write, MemoryDep
|
||||
):
|
||||
return -1
|
||||
|
||||
node1_writes = {dep.name: dep for dep in node1.read_writes.writes}
|
||||
if node2_read.name not in node1_writes:
|
||||
return -1
|
||||
|
||||
node1_write = node1_writes[node2_read.name]
|
||||
|
||||
if not isinstance(node1_write, MemoryDep):
|
||||
return -1
|
||||
|
||||
# We are checking for compatibility with the normalized node1 write
|
||||
# then modifying node2 reads/writes. since the node1 write will be just used
|
||||
# for compatibility, while node2 will be used in actual modification, just
|
||||
# normalize node1 not node2.
|
||||
node1_write = node1_write.normalize()
|
||||
|
||||
if (
|
||||
node1_write.index != node2_write.index
|
||||
and node1_write.size != node2_write.size
|
||||
):
|
||||
return -1
|
||||
|
||||
if node2_read.size != node2_write.size or len(node2_read.var_names) != 1:
|
||||
return -1
|
||||
|
||||
# Verify we have exactly two indexing expressions (one read, one write)
|
||||
if len(node2._body.indexing_exprs) != 2: # type: ignore[attr-defined]
|
||||
return -1
|
||||
|
||||
# No subblocks allowed for this optimization
|
||||
if node2._body.subblocks: # type: ignore[attr-defined]
|
||||
return -1
|
||||
|
||||
assert (
|
||||
"index0" in node2._body.indexing_exprs # type: ignore[attr-defined]
|
||||
and "index1" in node2._body.indexing_exprs # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
# Extract and verify single read expression
|
||||
node2_read_exprs = OrderedSet(expr for expr in node2._body.get_read_exprs()) # type: ignore[attr-defined]
|
||||
if len(node2_read_exprs) != 1:
|
||||
return -1
|
||||
|
||||
read_expr = next(iter(node2_read_exprs))
|
||||
|
||||
# Determine which index is for reading vs writing
|
||||
if read_expr == node2._body.indexing_exprs["index0"]: # type: ignore[attr-defined]
|
||||
read_expr_index = "index0"
|
||||
write_expr_index = "index1"
|
||||
else:
|
||||
assert read_expr == node2._body.indexing_exprs["index1"] # type: ignore[attr-defined]
|
||||
read_expr_index = "index1"
|
||||
write_expr_index = "index0"
|
||||
|
||||
from torch._inductor.invert_expr_analysis import generate_inverse_formula
|
||||
|
||||
index_vars = node2._body.vars[0] # type: ignore[attr-defined]
|
||||
if len(index_vars) != 1:
|
||||
return -1
|
||||
|
||||
simplified_terms = []
|
||||
for term in sympy.Add.make_args(read_expr):
|
||||
simplified_terms.append(
|
||||
V.graph.sizevars.combine_modular_indexing_pairs(term)
|
||||
)
|
||||
simplified_read_expr = sum(simplified_terms)
|
||||
|
||||
inverse_formula = generate_inverse_formula(simplified_read_expr, index_vars[0])
|
||||
|
||||
# formula is not invertible
|
||||
if inverse_formula is None:
|
||||
return -1
|
||||
|
||||
# === Apply Inversion ===
|
||||
|
||||
# Swap the indexing expressions using the inverse formula
|
||||
node2._body.indexing_exprs[read_expr_index] = node2._body.indexing_exprs[ # type: ignore[attr-defined]
|
||||
write_expr_index
|
||||
]
|
||||
node2._body.indexing_exprs[write_expr_index] = inverse_formula # type: ignore[attr-defined]
|
||||
|
||||
# Refresh dependencies and calculate fusion score
|
||||
node2.refresh_dependencies(True, False) # type: ignore[attr-defined]
|
||||
score = self.score_fusion_memory(node1, node2)
|
||||
|
||||
fusion_log.info("Shared memory after inversion: %d", score)
|
||||
return score
|
||||
|
||||
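The inversion above only succeeds when node2's read index is an invertible permutation of its single iteration variable; generate_inverse_formula returns None otherwise. A self-contained toy (plain Python, independent of the inductor internals) showing the kind of index/inverse pair involved:

# node2's discontiguous read: j = 128*(i % 4) + i // 4 permutes range(512)
def read_index(i):
    return 128 * (i % 4) + i // 4

# the inverse the analysis needs to recover: i = 4*(j % 128) + j // 128
def inverse(j):
    return 4 * (j % 128) + j // 128

# inverting the read makes it contiguous; the write is re-indexed through `inverse`
assert all(inverse(read_index(i)) == i for i in range(512))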
    def shared_data_after_reordering_loop(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> int:
@@ -4686,6 +4831,7 @@ class Scheduler:
        del device2

        shared_data_score = self.score_fusion_memory(node1, node2)

        if (
            can_reorder
            and shared_data_score < config.score_fusion_memory_threshold
@@ -4702,6 +4848,16 @@ class Scheduler:
            smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size)
            shared_data_score = self.score_fusion_memory(node1, node2)

        if (
            config.loop_index_inversion_in_fusion
            and shared_data_score < config.score_fusion_memory_threshold
        ):
            new_shared_data_score = self.shared_data_after_inverting_indexing(
                node1, node2
            )
            if new_shared_data_score >= 0:
                shared_data_score = new_shared_data_score

        if loop_ordering_log.isEnabledFor(logging.DEBUG):
            loop_ordering_log.debug(
                "%s and %s has %s shared data",
@ -1,141 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from itertools import product
|
||||
|
||||
import torch._inductor.config as config
|
||||
|
||||
|
||||
class TensorMapUpdateMode(Enum):
|
||||
"""Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""
|
||||
|
||||
SMEM = auto()
|
||||
GMEM = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CuTeGemmConfig:
|
||||
TILE_M: int = 128
|
||||
TILE_N: int = 192
|
||||
CLUSTER_M: int = 2
|
||||
CLUSTER_N: int = 1
|
||||
USE_2_CTA: bool = False
|
||||
TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM
|
||||
|
||||
|
||||
def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
For information regarding valid config sets, see:
|
||||
https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
|
||||
"""
|
||||
|
||||
# Tile_n is always the same regardless of 2cta
|
||||
tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
# Valid clusters
|
||||
clusters_no_2cta = [
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(1, 4),
|
||||
(1, 8),
|
||||
(1, 16),
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(2, 8),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
(8, 1),
|
||||
(8, 2),
|
||||
(16, 1),
|
||||
]
|
||||
clusters_2cta = [
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(2, 8),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
(8, 1),
|
||||
(8, 2),
|
||||
(16, 1),
|
||||
]
|
||||
|
||||
configs: list[CuTeGemmConfig] = []
|
||||
|
||||
for use_2cta, cluster_set, tile_m_range in [
|
||||
(False, clusters_no_2cta, [64, 128]),
|
||||
(True, clusters_2cta, [128, 256]),
|
||||
]:
|
||||
for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
|
||||
[TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
|
||||
tile_m_range,
|
||||
tile_n_vals,
|
||||
cluster_set,
|
||||
):
|
||||
configs.append(
|
||||
CuTeGemmConfig(
|
||||
tile_m,
|
||||
tile_n,
|
||||
cluster_m,
|
||||
cluster_n,
|
||||
USE_2_CTA=use_2cta,
|
||||
TENSORMAP_UPDATE_MODE=tensormap_update_mode,
|
||||
)
|
||||
)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
"""
|
||||
|
||||
config_tuples = [
|
||||
(128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
|
||||
(128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
|
||||
(128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
|
||||
(128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
]
|
||||
|
||||
return [CuTeGemmConfig(*args) for args in config_tuples]
|
||||
|
||||
|
||||
def get_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
|
||||
Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
|
||||
or unstable results. By default, autotuning is disabled and we return only
|
||||
a single baseline config.
|
||||
"""
|
||||
if (
|
||||
config.cutedsl_enable_autotuning
|
||||
and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
|
||||
):
|
||||
return get_exhaustive_groupgemm_configs()
|
||||
elif config.cutedsl_enable_autotuning:
|
||||
return get_default_groupgemm_configs()
|
||||
else:
|
||||
return [get_default_groupgemm_configs()[0]]
|
||||
@@ -1975,84 +1975,6 @@ def use_triton_blackwell_tma_template(
    return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()


@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
    """Check if CuTeDSL is importable; cache the result for reuse.

    Call ensure_cute_available.cache_clear() after installing CuTeDSL
    in the same interpreter to retry the import.
    """
    try:
        return importlib.util.find_spec("cutlass.cute") is not None
    except ImportError:
        return False


def use_blackwell_cutedsl_grouped_mm(
    mat_a: Any,
    mat_b: Any,
    layout: Layout,
    a_is_2d: bool,
    b_is_2d: bool,
    offs: Optional[Any],
    bias: Optional[Any],
    scale_result: Optional[Any],
) -> bool:
    """
    Returns True if we can use the blackwell kernel for grouped mm.
    Required conditions:
        1. CuTeDSL backend is enabled
        2. CuTeDSL is available
        3. We are on a blackwell arch
        4. The dtype is bf16
        5. Max autotune or max autotune gemm is enabled
        6. A, B, and the output are 16B aligned
        7. We are not using dynamic shapes
        8. A is 2d
        9. B is 3d
        10. Offsets are provided
        11. Bias and Scale are not provided
    """
    if not ensure_cute_available():
        return False

    if not _use_autotune_backend("CUTEDSL"):
        return False

    from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch

    if not is_gpu(layout.device.type):
        return False

    if not is_datacenter_blackwell_arch():
        return False

    layout_dtypes = [torch.bfloat16]
    if not _use_template_for_gpu(layout, layout_dtypes):
        return False

    if not (config.max_autotune or config.max_autotune_gemm):
        return False

    # Checks for 16B ptr and stride alignment
    if not can_use_tma(mat_a, mat_b, output_layout=layout):
        return False

    if any(is_dynamic(x) for x in [mat_a, mat_b]):
        return False

    if not a_is_2d or b_is_2d:
        return False

    if offs is None:
        return False

    if bias is not None or scale_result is not None:
        return False

    return True


def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
    from .virtualized import V
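As a rough sketch of inputs that would satisfy the conditions listed in the use_blackwell_cutedsl_grouped_mm docstring above (shapes, offsets, and the torch._grouped_mm entry point are assumptions for illustration; the real gating happens inside inductor's lowering on a Blackwell GPU):

import torch

G, M, K, N = 4, 256, 512, 1024
mat_a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)     # A is 2d, bf16
mat_b = torch.randn(G, K, N, device="cuda", dtype=torch.bfloat16)  # B is 3d, bf16
offs = torch.tensor([64, 128, 192, 256], device="cuda", dtype=torch.int32)

# bias/scale_result must be None and shapes static; contiguous bf16 tensors are
# normally 16B aligned. Under torch.compile(mode="max-autotune") a grouped GEMM
# over these inputs (e.g. torch._grouped_mm(mat_a, mat_b, offs=offs), name assumed
# here) could then be routed to the CuTeDSL Blackwell kernel.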
@@ -1224,43 +1224,3 @@ def _build_table(
            f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
        )
    return "".join(result)


# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
    """
    Extract and format all events with stack traces in a canonical way
    for deterministic testing.
    """
    events_with_traces = []

    for event in events:
        # Extract relevant fields
        event_name = event.get("name", "")
        node_name = event["args"].get("node_name", "")
        stack_trace = event["args"].get("stack_trace", "")

        # Get the last non-empty line of the stack trace
        lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
        stack_trace = lines[-1] if lines else ""

        events_with_traces.append(
            {
                "event_name": event_name[:20],
                "node_name": node_name,
                "stack_trace": stack_trace,
                "start_time": event.get("ts", 0),
            }
        )

    # Sort by start time for deterministic ordering
    events_with_traces.sort(key=lambda x: x["start_time"])

    # Format as a string
    lines: list[str] = []
    for evt in events_with_traces:
        lines.append(
            f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
        )

    return "\n".join(lines)
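A hypothetical input/output pair for the helper above (event contents invented for illustration):

events = [
    {
        "name": "aten::mm",
        "ts": 10,
        "args": {
            "node_name": "mm_1",
            "stack_trace": 'File "model.py", line 12, in forward\n    return x @ w',
        },
    },
]
print(_canonicalize_profiler_events(events))
# event=aten::mm node=mm_1 stack_trace=return x @ w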
@@ -5,13 +5,13 @@
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/version.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>

HIDDEN_NAMESPACE_BEGIN(torch, stable)

@@ -68,7 +68,7 @@ inline torch::stable::Tensor narrow(
// only dtype information.
inline torch::stable::Tensor new_empty(
    const torch::stable::Tensor& self,
    std::vector<int64_t> size,
    torch::headeronly::IntHeaderOnlyArrayRef size,
    std::optional<c10::ScalarType> dtype = std::nullopt) {
  int32_t device_type;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@@ -107,7 +107,7 @@ inline torch::stable::Tensor new_empty(
// only dtype information.
inline torch::stable::Tensor new_zeros(
    const torch::stable::Tensor& self,
    std::vector<int64_t> size,
    torch::headeronly::IntHeaderOnlyArrayRef size,
    std::optional<c10::ScalarType> dtype = std::nullopt) {
  int32_t device_type;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@@ -144,12 +144,10 @@ inline torch::stable::Tensor new_zeros(

// We expect this to be the stable version of the pad.default op.
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
// use std::vector<int64_t> because
// (1) IntArrayRef is not yet header-only
// (2) SymInt is not yet header-only
// torch::headeronly::IntHeaderOnlyArrayRef as SymInt is not yet header-only.
inline torch::stable::Tensor pad(
    const torch::stable::Tensor& self,
    std::vector<int64_t> pad,
    torch::headeronly::IntHeaderOnlyArrayRef pad,
    const std::string& mode = "constant",
    double value = 0.0) {
  AtenTensorHandle ret0 = nullptr;
@@ -181,11 +179,10 @@ inline torch::stable::Tensor amax(
// This function is an overload to compute the maximum value along each slice of
// `self` reducing over all the dimensions in the vector `dims`. The
// amax.default op takes in a SymInt[] as the dims argument, however dims is
// typed as use std::vector<int64_t> here because (1) IntArrayRef is not yet
// header-only (2) SymInt is not yet header-only
// typed as use IntHeaderOnlyArrayRef here because SymInt is not yet header-only
inline torch::stable::Tensor amax(
    const torch::stable::Tensor& self,
    std::vector<int64_t> dims,
    torch::headeronly::IntHeaderOnlyArrayRef dims,
    bool keepdim = false) {
  AtenTensorHandle ret = nullptr;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(
|
||||
@ -443,7 +443,6 @@ class CodeGen:
|
||||
colored: bool = False,
|
||||
# Render each argument on its own line
|
||||
expanded_def: bool = False,
|
||||
record_func: bool = False,
|
||||
) -> PythonCode:
|
||||
free_vars: list[str] = []
|
||||
body: list[str] = []
|
||||
@ -648,6 +647,15 @@ class CodeGen:
|
||||
|
||||
if verbose:
|
||||
# override annotation with more detailed information
|
||||
try:
|
||||
from torch.distributed.tensor._api import DTensor, DTensorSpec
|
||||
|
||||
dtensorspec_format_shard_order_str = (
|
||||
DTensorSpec.format_shard_order_str
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
DTensor = None # type: ignore[assignment,misc]
|
||||
dtensorspec_format_shard_order_str = None
|
||||
from torch.fx.experimental.proxy_tensor import py_sym_types
|
||||
from torch.fx.passes.shape_prop import TensorMetadata
|
||||
|
||||
@ -678,6 +686,16 @@ class CodeGen:
|
||||
core = _tensor_annotation(meta_val)
|
||||
if is_plain:
|
||||
maybe_type_annotation = f': "{core}"'
|
||||
elif type(meta_val) is DTensor:
|
||||
assert dtensorspec_format_shard_order_str is not None
|
||||
dtensor_meta = dtensorspec_format_shard_order_str(
|
||||
meta_val._spec.placements, # type: ignore[attr-defined]
|
||||
meta_val._spec.shard_order, # type: ignore[attr-defined]
|
||||
)
|
||||
cls = meta_val.__class__.__name__
|
||||
maybe_type_annotation = (
|
||||
f': "{cls}({core}, {dim_green(dtensor_meta)})"'
|
||||
)
|
||||
else:
|
||||
cls = meta_val.__class__.__name__
|
||||
maybe_type_annotation = f': "{cls}({core})"'
|
||||
@ -799,10 +817,6 @@ class CodeGen:
|
||||
return
|
||||
raise NotImplementedError(f"node: {node.op} {node.target}")
|
||||
|
||||
if record_func:
|
||||
body.append(
|
||||
"_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n"
|
||||
)
|
||||
for i, node in enumerate(nodes):
|
||||
# NOTE: emit_node does not emit a string with newline. It depends
|
||||
# on delete_unused_values to append one
|
||||
@ -812,22 +826,8 @@ class CodeGen:
|
||||
# node index, which will be deleted later
|
||||
# after going through _body_transformer
|
||||
body.append(f"# COUNTER: {i}\n")
|
||||
do_record = record_func and node.op in (
|
||||
"call_function",
|
||||
"call_method",
|
||||
"call_module",
|
||||
)
|
||||
if do_record:
|
||||
# The double hash ## convention is used by post-processing to find the fx markers
|
||||
body.append(
|
||||
f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n"
|
||||
)
|
||||
emit_node(node)
|
||||
delete_unused_values(node)
|
||||
if do_record:
|
||||
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
|
||||
if record_func:
|
||||
body.append("_rf.__exit__(None, None, None)\n")
|
||||
|
||||
if len(body) == 0:
|
||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||
@ -1779,7 +1779,6 @@ class Graph:
|
||||
include_device: bool = False,
|
||||
colored: bool = False,
|
||||
expanded_def: bool = False,
|
||||
record_func: bool = False,
|
||||
) -> PythonCode:
|
||||
"""
|
||||
Turn this ``Graph`` into valid Python code.
|
||||
@ -1847,7 +1846,6 @@ class Graph:
|
||||
include_device=include_device,
|
||||
colored=colored,
|
||||
expanded_def=expanded_def,
|
||||
record_func=record_func,
|
||||
)
|
||||
|
||||
def _python_code(
|
||||
@ -1860,7 +1858,6 @@ class Graph:
|
||||
include_device: bool = False,
|
||||
colored: bool = False,
|
||||
expanded_def: bool = False,
|
||||
record_func: bool = False,
|
||||
) -> PythonCode:
|
||||
return self._codegen._gen_python_code(
|
||||
self.nodes,
|
||||
@ -1871,7 +1868,6 @@ class Graph:
|
||||
include_device=include_device,
|
||||
colored=colored,
|
||||
expanded_def=expanded_def,
|
||||
record_func=record_func,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@ -861,18 +861,14 @@ class {module_name}(torch.nn.Module):
|
||||
if isinstance(self._graph._codegen, _PyTreeCodeGen):
|
||||
self._in_spec = self._graph._codegen.pytree_info.in_spec
|
||||
self._out_spec = self._graph._codegen.pytree_info.out_spec
|
||||
|
||||
from torch._dynamo import config as dynamo_config
|
||||
|
||||
python_code = self._graph.python_code(
|
||||
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
|
||||
)
|
||||
python_code = self._graph.python_code(root_module="self")
|
||||
self._code = python_code.src
|
||||
self._lineno_map = python_code._lineno_map
|
||||
self._prologue_start = python_code._prologue_start
|
||||
|
||||
cls = type(self)
|
||||
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
|
||||
from torch._dynamo import config as dynamo_config
|
||||
|
||||
if dynamo_config.enrich_profiler_metadata:
|
||||
# Generate metadata and register for profiler augmentation
|
||||
@ -889,6 +885,7 @@ class {module_name}(torch.nn.Module):
|
||||
# This ensures the same code+metadata always generates the same filename
|
||||
hash_value = _metadata_hash(self._code, node_metadata)
|
||||
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
|
||||
|
||||
filename = f"{file_stem}.py"
|
||||
|
||||
# Only include co_filename to use it directly as the cache key
|
||||
@ -908,13 +905,6 @@ class {module_name}(torch.nn.Module):
|
||||
|
||||
_register_fx_metadata(filename, metadata)
|
||||
|
||||
# Replace the placeholder in generated code with actual filename
|
||||
# The double hash ## convention is used by post-processing to find the fx markers
|
||||
self._code = self._code.replace(
|
||||
"torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')",
|
||||
f"torch._C._profiler._RecordFunctionFast('## {filename} ##')",
|
||||
)
|
||||
|
||||
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
|
||||
|
||||
# Determine whether this class explicitly defines a __call__ implementation
|
||||
|
||||
@@ -42,6 +42,9 @@ fp16_ieee_to_fp32_value
# fp32_from_bits called from fp16_ieee_to_fp32_value
# fp32_to_bits called from fp16_ieee_from_fp32_value

# torch/headeronly/util/HeaderOnlyArrayRef.h
HeaderOnlyArrayRef

# c10/util/complex.h, torch/headeronly/util/complex.h
complex

torch/headeronly/util/HeaderOnlyArrayRef.h (new file, 247 lines)
@@ -0,0 +1,247 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <initializer_list>
|
||||
#include <iterator>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// HeaderOnlyArrayRef - A subset of ArrayRef that is implemented only
|
||||
/// in headers. This will be a base class from which ArrayRef inherits, so that
|
||||
/// we can keep much of the implementation shared.
|
||||
///
|
||||
/// [HeaderOnlyArrayRef vs ArrayRef note]
|
||||
/// As HeaderOnlyArrayRef is a subset of ArrayRef, it has slightly less
|
||||
/// functionality than ArrayRef. We document the minor differences below:
|
||||
/// 1. ArrayRef has an extra convenience constructor for SmallVector.
|
||||
/// 2. ArrayRef uses TORCH_CHECK. HeaderOnlyArrayRef uses header-only
|
||||
/// STD_TORCH_CHECK, which will output a std::runtime_error vs a
|
||||
/// c10::Error. Consequently, you should use ArrayRef when possible
|
||||
/// and HeaderOnlyArrayRef only when necessary to support headeronly code.
|
||||
/// In all other aspects, HeaderOnlyArrayRef is identical to ArrayRef, with the
|
||||
/// positive benefit of being header-only and thus independent of libtorch.so.
|
||||
template <typename T>
|
||||
class HeaderOnlyArrayRef {
|
||||
public:
|
||||
using iterator = const T*;
|
||||
using const_iterator = const T*;
|
||||
using size_type = size_t;
|
||||
using value_type = T;
|
||||
|
||||
using reverse_iterator = std::reverse_iterator<iterator>;
|
||||
|
||||
protected:
|
||||
/// The start of the array, in an external buffer.
|
||||
const T* Data;
|
||||
|
||||
/// The number of elements.
|
||||
size_type Length;
|
||||
|
||||
public:
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// Construct an empty HeaderOnlyArrayRef.
|
||||
/* implicit */ constexpr HeaderOnlyArrayRef() : Data(nullptr), Length(0) {}
|
||||
|
||||
/// Construct a HeaderOnlyArrayRef from a single element.
|
||||
// TODO Make this explicit
|
||||
constexpr HeaderOnlyArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
|
||||
|
||||
/// Construct a HeaderOnlyArrayRef from a pointer and length.
|
||||
constexpr HeaderOnlyArrayRef(const T* data, size_t length)
|
||||
: Data(data), Length(length) {}
|
||||
|
||||
/// Construct a HeaderOnlyArrayRef from a range.
|
||||
constexpr HeaderOnlyArrayRef(const T* begin, const T* end)
|
||||
: Data(begin), Length(end - begin) {}
|
||||
|
||||
template <
|
||||
typename Container,
|
||||
typename U = decltype(std::declval<Container>().data()),
|
||||
typename = std::enable_if_t<
|
||||
(std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
|
||||
/* implicit */ HeaderOnlyArrayRef(const Container& container)
|
||||
: Data(container.data()), Length(container.size()) {}
|
||||
|
||||
/// Construct a HeaderOnlyArrayRef from a std::vector.
|
||||
// The enable_if stuff here makes sure that this isn't used for
|
||||
// std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
|
||||
// bitfield.
|
||||
template <typename A>
|
||||
/* implicit */ HeaderOnlyArrayRef(const std::vector<T, A>& Vec)
|
||||
: Data(Vec.data()), Length(Vec.size()) {
|
||||
static_assert(
|
||||
!std::is_same_v<T, bool>,
|
||||
"HeaderOnlyArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
|
||||
}
|
||||
|
||||
/// Construct a HeaderOnlyArrayRef from a std::array
|
||||
template <size_t N>
|
||||
/* implicit */ constexpr HeaderOnlyArrayRef(const std::array<T, N>& Arr)
|
||||
: Data(Arr.data()), Length(N) {}
|
||||
|
||||
/// Construct a HeaderOnlyArrayRef from a C array.
|
||||
template <size_t N>
|
||||
// NOLINTNEXTLINE(*c-arrays*)
|
||||
/* implicit */ constexpr HeaderOnlyArrayRef(const T (&Arr)[N])
|
||||
: Data(Arr), Length(N) {}
|
||||
|
||||
/// Construct a HeaderOnlyArrayRef from a std::initializer_list.
|
||||
/* implicit */ constexpr HeaderOnlyArrayRef(
|
||||
const std::initializer_list<T>& Vec)
|
||||
: Data(
|
||||
std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
|
||||
: std::begin(Vec)),
|
||||
Length(Vec.size()) {}
|
||||
|
||||
/// @}
|
||||
/// @name Simple Operations
|
||||
/// @{
|
||||
|
||||
constexpr iterator begin() const {
|
||||
return this->Data;
|
||||
}
|
||||
constexpr iterator end() const {
|
||||
return this->Data + this->Length;
|
||||
}
|
||||
|
||||
// These are actually the same as iterator, since ArrayRef only
|
||||
// gives you const iterators.
|
||||
constexpr const_iterator cbegin() const {
|
||||
return this->Data;
|
||||
}
|
||||
constexpr const_iterator cend() const {
|
||||
return this->Data + this->Length;
|
||||
}
|
||||
|
||||
constexpr reverse_iterator rbegin() const {
|
||||
return reverse_iterator(end());
|
||||
}
|
||||
constexpr reverse_iterator rend() const {
|
||||
return reverse_iterator(begin());
|
||||
}
|
||||
|
||||
/// Check if all elements in the array satisfy the given expression
|
||||
constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
|
||||
return std::all_of(cbegin(), cend(), pred);
|
||||
}
|
||||
|
||||
/// empty - Check if the array is empty.
|
||||
constexpr bool empty() const {
|
||||
return this->Length == 0;
|
||||
}
|
||||
|
||||
constexpr const T* data() const {
|
||||
return this->Data;
|
||||
}
|
||||
|
||||
/// size - Get the array size.
|
||||
constexpr size_t size() const {
|
||||
return this->Length;
|
||||
}
|
||||
|
||||
/// front - Get the first element.
|
||||
constexpr const T& front() const {
|
||||
STD_TORCH_CHECK(
|
||||
!this->empty(),
|
||||
"HeaderOnlyArrayRef: attempted to access front() of empty list");
|
||||
return this->Data[0];
|
||||
}
|
||||
|
||||
/// back - Get the last element.
|
||||
constexpr const T& back() const {
|
||||
STD_TORCH_CHECK(
|
||||
!this->empty(),
|
||||
"HeaderOnlyArrayRef: attempted to access back() of empty list");
|
||||
return this->Data[this->Length - 1];
|
||||
}
|
||||
|
||||
/// equals - Check for element-wise equality.
|
||||
constexpr bool equals(HeaderOnlyArrayRef RHS) const {
|
||||
return this->Length == RHS.Length &&
|
||||
std::equal(begin(), end(), RHS.begin());
|
||||
}
|
||||
|
||||
/// slice(n, m) - Take M elements of the array starting at element N
|
||||
constexpr HeaderOnlyArrayRef<T> slice(size_t N, size_t M) const {
|
||||
STD_TORCH_CHECK(
|
||||
N + M <= this->size(),
|
||||
"HeaderOnlyArrayRef: invalid slice, N = ",
|
||||
N,
|
||||
"; M = ",
|
||||
M,
|
||||
"; size = ",
|
||||
this->size());
|
||||
return HeaderOnlyArrayRef<T>(this->data() + N, M);
|
||||
}
|
||||
|
||||
/// slice(n) - Chop off the first N elements of the array.
|
||||
constexpr HeaderOnlyArrayRef<T> slice(size_t N) const {
|
||||
STD_TORCH_CHECK(
|
||||
N <= this->size(),
|
||||
"HeaderOnlyArrayRef: invalid slice, N = ",
|
||||
N,
|
||||
"; size = ",
|
||||
this->size());
|
||||
return slice(N, this->size() - N);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Operator Overloads
|
||||
/// @{
|
||||
constexpr const T& operator[](size_t Index) const {
|
||||
return this->Data[Index];
|
||||
}
|
||||
|
||||
/// Vector compatibility
|
||||
constexpr const T& at(size_t Index) const {
|
||||
STD_TORCH_CHECK(
|
||||
Index < this->Length,
|
||||
"HeaderOnlyArrayRef: invalid index Index = ",
|
||||
Index,
|
||||
"; Length = ",
|
||||
this->Length);
|
||||
return this->Data[Index];
|
||||
}
|
||||
|
||||
/// Disallow accidental assignment from a temporary.
|
||||
///
|
||||
/// The declaration here is extra complicated so that "arrayRef = {}"
|
||||
/// continues to select the move assignment operator.
|
||||
template <typename U>
|
||||
std::enable_if_t<std::is_same_v<U, T>, HeaderOnlyArrayRef<T>>& operator=(
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
|
||||
U&& Temporary) = delete;
|
||||
|
||||
/// Disallow accidental assignment from a temporary.
|
||||
///
|
||||
/// The declaration here is extra complicated so that "arrayRef = {}"
|
||||
/// continues to select the move assignment operator.
|
||||
template <typename U>
|
||||
std::enable_if_t<std::is_same_v<U, T>, HeaderOnlyArrayRef<T>>& operator=(
|
||||
std::initializer_list<U>) = delete;
|
||||
|
||||
/// @}
|
||||
/// @name Expensive Operations
|
||||
/// @{
|
||||
std::vector<T> vec() const {
|
||||
return std::vector<T>(this->Data, this->Data + this->Length);
|
||||
}
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace torch::headeronly {
|
||||
using c10::HeaderOnlyArrayRef;
|
||||
using IntHeaderOnlyArrayRef = HeaderOnlyArrayRef<int64_t>;
|
||||
} // namespace torch::headeronly
|
||||
@ -4,7 +4,7 @@ import operator
|
||||
import re
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torch.autograd.profiler import profile
|
||||
from torch.profiler import DeviceType
|
||||
@ -400,170 +400,3 @@ def _init_for_cuda_graphs() -> None:
|
||||
|
||||
with profile():
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimelineEvent:
|
||||
"""Represents an event in the profiler timeline."""
|
||||
|
||||
timestamp: int
|
||||
event_type: Literal["start", "end", "regular"]
|
||||
marker_type: Optional[Literal["filename", "node"]]
|
||||
identifier: Optional[str | int]
|
||||
event: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextStackEntry:
|
||||
"""Represents a context (filename or node) in the stack."""
|
||||
|
||||
context_type: Literal["filename", "node"]
|
||||
identifier: str | int
|
||||
metadata: Optional[dict]
|
||||
tid: Optional[int] = None # Thread ID associated with this context
|
||||
|
||||
|
||||
def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
|
||||
"""
|
||||
Maps recorded profiler events to their corresponding fx nodes and adds stack traces.
|
||||
|
||||
Builds a timeline of all events (regular ops and FX markers for filenames/nodes),
|
||||
sorts by timestamp, then processes chronologically while maintaining a context stack of active
|
||||
filename/node scopes. Regular events are augmented with stack traces and node names from the
|
||||
innermost active context. Runtime is O(n log n) for n events.
|
||||
|
||||
Args:
|
||||
traced_data: Json of profiler events from Chrome trace
|
||||
|
||||
Returns:
|
||||
Dict mapping recorded event names to their aten operations with added stack traces
|
||||
"""
|
||||
from torch.fx.traceback import _FX_METADATA_REGISTRY
|
||||
|
||||
trace_events = traced_data.get("traceEvents", [])
|
||||
|
||||
# Create event timeline
|
||||
event_timeline: list[TimelineEvent] = []
|
||||
|
||||
def is_fx_marker_event(event):
|
||||
return (
|
||||
event.get("cat") == "cpu_op"
|
||||
and event.get("name", "").startswith("## ")
|
||||
and event.get("name", "").endswith(" ##")
|
||||
)
|
||||
|
||||
def append_fx_marker_event(event_type, identifier, event):
|
||||
start_ts = event["ts"]
|
||||
end_ts = start_ts + event["dur"]
|
||||
event_timeline.append(
|
||||
TimelineEvent(start_ts, "start", event_type, identifier, event)
|
||||
)
|
||||
event_timeline.append(
|
||||
TimelineEvent(end_ts, "end", event_type, identifier, event)
|
||||
)
|
||||
|
||||
for event in trace_events:
|
||||
if "ts" not in event or "dur" not in event:
|
||||
continue
|
||||
|
||||
if is_fx_marker_event(event):
|
||||
content = event["name"][3:-3]
|
||||
|
||||
if content.endswith(".py"):
|
||||
append_fx_marker_event("filename", content, event)
|
||||
else:
|
||||
try:
|
||||
node_index = int(content)
|
||||
except ValueError:
|
||||
pass
|
||||
append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined]
|
||||
|
||||
else:
|
||||
# Regular event that needs augmentation
|
||||
start_ts = event["ts"]
|
||||
event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event))
|
||||
|
||||
# Sort by timestamp
|
||||
event_timeline.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# Process events in chronological order with a stack
|
||||
context_stack: list[ContextStackEntry] = []
|
||||
|
||||
# Invariant: all start event has a corresponding end event
|
||||
for timeline_event in event_timeline:
|
||||
match timeline_event.event_type:
|
||||
case "start":
|
||||
assert timeline_event.identifier is not None
|
||||
|
||||
if timeline_event.marker_type == "filename":
|
||||
assert isinstance(timeline_event.identifier, str)
|
||||
# Push filename context - query metadata registry on-demand
|
||||
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
|
||||
tid = timeline_event.event.get("tid")
|
||||
context_stack.append(
|
||||
ContextStackEntry(
|
||||
"filename", timeline_event.identifier, metadata, tid
|
||||
)
|
||||
)
|
||||
elif timeline_event.marker_type == "node":
|
||||
# Find the current filename from stack
|
||||
current_file_metadata = None
|
||||
tid = timeline_event.event.get("tid")
|
||||
for ctx_entry in reversed(context_stack):
|
||||
if (
|
||||
ctx_entry.context_type == "filename"
|
||||
and ctx_entry.tid == tid
|
||||
):
|
||||
current_file_metadata = ctx_entry.metadata
|
||||
break
|
||||
|
||||
if current_file_metadata:
|
||||
node_metadata = current_file_metadata.get("node_metadata", {})
|
||||
if timeline_event.identifier in node_metadata:
|
||||
node_meta: Optional[dict] = node_metadata[
|
||||
timeline_event.identifier
|
||||
]
|
||||
context_stack.append(
|
||||
ContextStackEntry(
|
||||
"node", timeline_event.identifier, node_meta, tid
|
||||
)
|
||||
)
|
||||
|
||||
case "end":
|
||||
# Pop from stack - search backwards to find matching context
|
||||
for i in range(len(context_stack) - 1, -1, -1):
|
||||
ctx_entry = context_stack[i]
|
||||
if (
|
||||
timeline_event.marker_type == ctx_entry.context_type
|
||||
and timeline_event.identifier == ctx_entry.identifier
|
||||
):
|
||||
context_stack.pop(i)
|
||||
break
|
||||
|
||||
case "regular":
|
||||
# Apply metadata from current context stack
|
||||
# Find the most specific context (node takes precedence over filename)
|
||||
# Only augment events with the same tid as the file/node event matched
|
||||
current_stack_trace = None
|
||||
current_node_name = None
|
||||
event_tid = timeline_event.event.get("tid")
|
||||
|
||||
for ctx_entry in reversed(context_stack):
|
||||
# Only apply metadata from contexts with matching tid
|
||||
if ctx_entry.tid == event_tid:
|
||||
if ctx_entry.context_type == "node" and ctx_entry.metadata:
|
||||
current_stack_trace = ctx_entry.metadata.get(
|
||||
"stack_trace", "No model stack trace available"
|
||||
)
|
||||
current_node_name = ctx_entry.metadata.get("name", "")
|
||||
# Do we want to only attach the stack trace of the lowest node or stack trace of all nodes
|
||||
# if nodes are nested, e.g. in nested graph modules
|
||||
break
|
||||
|
||||
# Augment the event
|
||||
if current_stack_trace or current_node_name:
|
||||
args = timeline_event.event.setdefault("args", {})
|
||||
if current_stack_trace:
|
||||
args["stack_trace"] = current_stack_trace
|
||||
if current_node_name:
|
||||
args["node_name"] = current_node_name
|
||||
|
||||
@ -210,7 +210,8 @@ class _KinetoProfile:
|
||||
def start_trace(self) -> None:
|
||||
if self.execution_trace_observer:
|
||||
self.execution_trace_observer.start()
|
||||
assert self.profiler is not None
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before starting trace")
|
||||
self.profiler._start_trace()
|
||||
|
||||
if self.profile_memory:
|
||||
@ -256,7 +257,8 @@ class _KinetoProfile:
|
||||
def stop_trace(self) -> None:
|
||||
if self.execution_trace_observer:
|
||||
self.execution_trace_observer.stop()
|
||||
assert self.profiler is not None
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before stopping trace")
|
||||
self.profiler.__exit__(None, None, None)
|
||||
|
||||
def export_chrome_trace(self, path: str):
|
||||
@ -264,7 +266,10 @@ class _KinetoProfile:
|
||||
Exports the collected trace in Chrome JSON format. If kineto is enabled, only
|
||||
last cycle in schedule is exported.
|
||||
"""
|
||||
assert self.profiler
|
||||
if self.profiler is None:
|
||||
raise AssertionError(
|
||||
"Profiler must be initialized before exporting chrome trace"
|
||||
)
|
||||
if path.endswith(".gz"):
|
||||
fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False)
|
||||
fp.close()
|
||||
@ -284,7 +289,8 @@ class _KinetoProfile:
|
||||
path (str): save stacks file to this location;
|
||||
metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
|
||||
"""
|
||||
assert self.profiler
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before exporting stacks")
|
||||
return self.profiler.export_stacks(path, metric)
|
||||
|
||||
def toggle_collection_dynamic(
|
||||
@ -316,7 +322,7 @@ class _KinetoProfile:
|
||||
print(p.key_averages().table(
|
||||
sort_by="self_cuda_time_total", row_limit=-1))
|
||||
"""
|
||||
if not self.profiler:
|
||||
if self.profiler is None:
|
||||
return
|
||||
self.profiler.toggle_collection_dynamic(enable, activities)
|
||||
|
||||
@ -333,7 +339,10 @@ class _KinetoProfile:
|
||||
To use shape/stack functionality make sure to set record_shapes/with_stack
|
||||
when creating profiler context manager.
|
||||
"""
|
||||
assert self.profiler
|
||||
if self.profiler is None:
|
||||
raise AssertionError(
|
||||
"Profiler must be initialized before getting key averages"
|
||||
)
|
||||
return self.profiler.key_averages(
|
||||
group_by_input_shape, group_by_stack_n, group_by_overload_name
|
||||
)
|
||||
@ -343,7 +352,8 @@ class _KinetoProfile:
|
||||
Returns the list of unaggregated profiler events,
|
||||
to be used in the trace callback or after the profiling is finished
|
||||
"""
|
||||
assert self.profiler
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before accessing events")
|
||||
return self.profiler.function_events
|
||||
|
||||
def add_metadata(self, key: str, value: str) -> None:
|
||||
@ -395,7 +405,10 @@ class _KinetoProfile:
|
||||
if missing:
|
||||
raise ValueError(f"{', '.join(missing)} required for memory profiling.")
|
||||
|
||||
assert self.profiler is not None and self.profiler.kineto_results is not None
|
||||
if self.profiler is None or self.profiler.kineto_results is None:
|
||||
raise AssertionError(
|
||||
"Profiler and kineto_results must be initialized for memory profiling"
|
||||
)
|
||||
return MemoryProfile(self.profiler.kineto_results)
|
||||
|
||||
def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
|
||||
@ -485,7 +498,8 @@ def schedule(
|
||||
"""
|
||||
|
||||
def schedule_fn(step: int) -> ProfilerAction:
|
||||
assert step >= 0
|
||||
if step < 0:
|
||||
raise AssertionError(f"Step must be non-negative. Got {step}.")
|
||||
if step < skip_first:
|
||||
return ProfilerAction.NONE
|
||||
else:
|
||||
@ -508,9 +522,11 @@ def schedule(
|
||||
else ProfilerAction.RECORD_AND_SAVE
|
||||
)
|
||||
|
||||
assert (
|
||||
wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
|
||||
), "Invalid profiler schedule arguments"
|
||||
if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0:
|
||||
raise AssertionError(
|
||||
f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), "
|
||||
f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)."
|
||||
)
|
||||
if warmup == 0:
|
||||
warn(
|
||||
"Profiler won't be using warmup, this can skew profiler results",
|
||||
@ -717,7 +733,8 @@ class profile(_KinetoProfile):
|
||||
activities_set.add(ProfilerActivity.CUDA)
|
||||
elif ProfilerActivity.CUDA in activities_set:
|
||||
activities_set.remove(ProfilerActivity.CUDA)
|
||||
assert len(activities_set) > 0, "No valid profiler activities found"
|
||||
if len(activities_set) == 0:
|
||||
raise AssertionError("No valid profiler activities found")
|
||||
|
||||
super().__init__(
|
||||
activities=activities,
|
||||
|
||||