[ghstack-poisoned]
Yu, Guangye
2025-07-15 13:22:32 +00:00
122 changed files with 2240 additions and 669 deletions

View File

@@ -5,7 +5,7 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh"
if [[ ${BUILD_ENVIRONMENT} == *onnx* ]]; then
pip install click mock tabulate networkx==2.0
pip -q install --user "file:///var/lib/jenkins/workspace/third_party/onnx#egg=onnx"
pip -q install "file:///var/lib/jenkins/workspace/third_party/onnx#egg=onnx"
fi
# Skip tests in environments where they are not built/applicable
@@ -147,8 +147,8 @@ export DNNL_MAX_CPU_ISA=AVX2
if [[ "${SHARD_NUMBER:-1}" == "1" ]]; then
# TODO(sdym@meta.com) remove this when the linked issue is resolved.
# py is temporary until https://github.com/Teemu/pytest-sugar/issues/241 is fixed
pip install --user py==1.11.0
pip install --user pytest-sugar
pip install py==1.11.0
pip install pytest-sugar
# NB: Warnings are disabled because they make it harder to see what
# the actual erroring test is
"$PYTHON" \

View File

@@ -91,6 +91,17 @@ tag=$(echo $image | awk -F':' '{print $2}')
# configuration, so we hardcode everything here rather than do it
# from scratch
case "$tag" in
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11)
CUDA_VERSION=12.4
CUDNN_VERSION=9
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
;;
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11)
CUDA_VERSION=12.8.1
CUDNN_VERSION=9

View File

@@ -78,6 +78,19 @@ function install_nvshmem {
echo "nvSHMEM ${nvshmem_version} for CUDA ${cuda_major_version} (${arch_path}) installed."
}
function install_124 {
CUDNN_VERSION=9.1.0.70
echo "Installing CUDA 12.4.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.2"
install_cuda 12.4.1 cuda_12.4.1_550.54.15_linux
install_cudnn 12 $CUDNN_VERSION
CUDA_VERSION=12.4 bash install_nccl.sh
CUDA_VERSION=12.4 bash install_cusparselt.sh
ldconfig
}
function install_126 {
CUDNN_VERSION=9.10.2.21
@@ -113,6 +126,40 @@ function install_129 {
ldconfig
}
function prune_124 {
echo "Pruning CUDA 12.4"
#####################################################################################
# CUDA 12.4 prune static libs
#####################################################################################
export NVPRUNE="/usr/local/cuda-12.4/bin/nvprune"
export CUDA_LIB_DIR="/usr/local/cuda-12.4/lib64"
export GENCODE="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90"
export GENCODE_CUDNN="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90"
if [[ -n "$OVERRIDE_GENCODE" ]]; then
export GENCODE=$OVERRIDE_GENCODE
fi
if [[ -n "$OVERRIDE_GENCODE_CUDNN" ]]; then
export GENCODE_CUDNN=$OVERRIDE_GENCODE_CUDNN
fi
# all CUDA libs except CuDNN and CuBLAS
ls $CUDA_LIB_DIR/ | grep "\.a" | grep -v "culibos" | grep -v "cudart" | grep -v "cudnn" | grep -v "cublas" | grep -v "metis" \
| xargs -I {} bash -c \
"echo {} && $NVPRUNE $GENCODE $CUDA_LIB_DIR/{} -o $CUDA_LIB_DIR/{}"
# prune CuDNN and CuBLAS
$NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublas_static.a -o $CUDA_LIB_DIR/libcublas_static.a
$NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublasLt_static.a -o $CUDA_LIB_DIR/libcublasLt_static.a
#####################################################################################
# CUDA 12.4 prune visual tools
#####################################################################################
export CUDA_BASE="/usr/local/cuda-12.4/"
rm -rf $CUDA_BASE/libnvvp $CUDA_BASE/nsightee_plugins $CUDA_BASE/nsight-compute-2024.1.0 $CUDA_BASE/nsight-systems-2023.4.4/
}
function prune_126 {
echo "Pruning CUDA 12.6"
#####################################################################################
@@ -169,6 +216,8 @@ function install_128 {
while test $# -gt 0
do
case "$1" in
12.4) install_124; prune_124
;;
12.6|12.6.*) install_126; prune_126
;;
12.8|12.8.*) install_128;

View File

@@ -8,6 +8,8 @@ if [[ -n "${CUDNN_VERSION}" ]]; then
CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive"
elif [[ ${CUDA_VERSION:0:4} == "12.6" ]]; then
CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive"
elif [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then
CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive"
elif [[ ${CUDA_VERSION:0:2} == "11" ]]; then
CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda11-archive"
else

View File

@@ -13,6 +13,14 @@ if [[ ${CUDA_VERSION:0:4} =~ ^12\.[5-9]$ ]]; then
fi
CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.7.1.0-archive"
curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz
elif [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then
arch_path='sbsa'
export TARGETARCH=${TARGETARCH:-$(uname -m)}
if [ "${TARGETARCH}" = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then
arch_path='x86_64'
fi
CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.6.2.3-archive"
curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz
else
echo "Not sure which libcusparselt version to install for this ${CUDA_VERSION}"
fi

View File

@@ -19,7 +19,7 @@ git config --global --add safe.directory /var/lib/jenkins/workspace
if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
# TODO: This can be removed later once vision is also part of the Docker image
pip install -q --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"
pip install -q --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"
# JIT C++ extensions require ninja, so put it into PATH.
export PATH="/var/lib/jenkins/.local/bin:$PATH"
# NB: ONNX test is fast (~15m) so it's ok to retry it few more times to avoid any flaky issue, we

View File

@@ -127,9 +127,9 @@ function install_torchaudio() {
if [[ "$1" == "cuda" ]]; then
# TODO: This is better to be passed as a parameter from _linux-test workflow
# so that it can be consistent with what is set in build
TORCH_CUDA_ARCH_LIST="8.0;8.6" pip_install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${commit}"
TORCH_CUDA_ARCH_LIST="8.0;8.6" pip_install --no-use-pep517 "git+https://github.com/pytorch/audio.git@${commit}"
else
pip_install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${commit}"
pip_install --no-use-pep517 "git+https://github.com/pytorch/audio.git@${commit}"
fi
}
@@ -139,8 +139,8 @@ function install_torchtext() {
local text_commit
data_commit=$(get_pinned_commit data)
text_commit=$(get_pinned_commit text)
pip_install --no-use-pep517 --user "git+https://github.com/pytorch/data.git@${data_commit}"
pip_install --no-use-pep517 --user "git+https://github.com/pytorch/text.git@${text_commit}"
pip_install --no-use-pep517 "git+https://github.com/pytorch/data.git@${data_commit}"
pip_install --no-use-pep517 "git+https://github.com/pytorch/text.git@${text_commit}"
}
function install_torchvision() {
@@ -153,7 +153,7 @@ function install_torchvision() {
echo 'char* dlerror(void) { return "";}'|gcc -fpic -shared -o "${HOME}/dlerror.so" -x c -
LD_PRELOAD=${orig_preload}:${HOME}/dlerror.so
fi
pip_install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${commit}"
pip_install --no-use-pep517 "git+https://github.com/pytorch/vision.git@${commit}"
if [ -n "${LD_PRELOAD}" ]; then
LD_PRELOAD=${orig_preload}
fi
@@ -173,7 +173,7 @@ function install_torchrec_and_fbgemm() {
if [[ "$BUILD_ENVIRONMENT" == *rocm* ]] ; then
# install torchrec first because it installs fbgemm nightly on top of rocm fbgemm
pip_install --no-use-pep517 --user "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}"
pip_install --no-use-pep517 "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}"
pip_uninstall fbgemm-gpu-nightly
pip_install tabulate # needed for newer fbgemm
@@ -190,8 +190,8 @@ function install_torchrec_and_fbgemm() {
rm -rf fbgemm
else
# See https://github.com/pytorch/pytorch/issues/106971
CUDA_PATH=/usr/local/cuda-12.1 pip_install --no-use-pep517 --user "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#egg=fbgemm-gpu&subdirectory=fbgemm_gpu"
pip_install --no-use-pep517 --user "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}"
CUDA_PATH=/usr/local/cuda-12.1 pip_install --no-use-pep517 "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#egg=fbgemm-gpu&subdirectory=fbgemm_gpu"
pip_install --no-use-pep517 "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}"
fi
}
@@ -234,7 +234,7 @@ function checkout_install_torchbench() {
function install_torchao() {
local commit
commit=$(get_pinned_commit torchao)
pip_install --no-use-pep517 --user "git+https://github.com/pytorch/ao.git@${commit}"
pip_install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${commit}"
}
function print_sccache_stats() {

View File

@@ -201,7 +201,7 @@ fi
if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]] ; then
# JIT C++ extensions require ninja.
pip_install --user "ninja==1.10.2"
pip_install "ninja==1.10.2"
# ninja is installed in $HOME/.local/bin, e.g., /var/lib/jenkins/.local/bin for CI user jenkins
# but this script should be runnable by any user, including root
export PATH="$HOME/.local/bin:$PATH"
@@ -496,7 +496,7 @@ DYNAMO_BENCHMARK_FLAGS=()
pr_time_benchmarks() {
pip_install --user "fbscribelogger"
pip_install "fbscribelogger"
TEST_REPORTS_DIR=$(pwd)/test/test-reports
mkdir -p "$TEST_REPORTS_DIR"
@@ -1471,8 +1471,8 @@ test_bazel() {
test_benchmarks() {
if [[ "$BUILD_ENVIRONMENT" == *cuda* && $TEST_CONFIG != *nogpu* ]]; then
pip_install --user "pytest-benchmark==3.2.3"
pip_install --user "requests"
pip_install "pytest-benchmark==3.2.3"
pip_install "requests"
BENCHMARK_DATA="benchmarks/.data"
mkdir -p ${BENCHMARK_DATA}
pytest benchmarks/fastrnns/test_bench.py --benchmark-sort=Name --benchmark-json=${BENCHMARK_DATA}/fastrnns_default.json --fuser=default --executor=default

View File

@@ -37,10 +37,10 @@ IF "%CUDA_PATH_V129%"=="" (
)
IF "%BUILD_VISION%" == "" (
set TORCH_CUDA_ARCH_LIST=7.5;8.0;8.6;9.0;10.0;12.0
set TORCH_CUDA_ARCH_LIST=7.0;7.5;8.0;8.6;9.0;10.0;12.0
set TORCH_NVCC_FLAGS=-Xfatbin -compress-all
) ELSE (
set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_120,code=compute_120
set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_120,code=compute_120
)
set "CUDA_PATH=%CUDA_PATH_V129%"

View File

@@ -57,6 +57,7 @@ jobs:
pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
pytorch-linux-jammy-py3.9-clang12,
pytorch-linux-jammy-py3.11-clang12,
pytorch-linux-jammy-py3.12-clang12,

View File

@@ -51,6 +51,37 @@ jobs:
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-cuda12_4-py3_10-gcc11-sm89-build:
name: linux-jammy-cuda12.4-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit
linux-jammy-cuda12_4-py3_10-gcc11-sm89-test:
name: linux-jammy-cuda12.4-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_4-py3_10-gcc11-sm89-build
- target-determination
with:
build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89
docker-image: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-build:
name: linux-jammy-cuda12.8-py3.10-gcc11
uses: ./.github/workflows/_linux-build.yml

View File

@@ -315,14 +315,14 @@ jobs:
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-py3-clang12-mobile-build:
name: linux-jammy-py3-clang12-mobile-build
linux-jammy-py3-clang18-mobile-build:
name: linux-jammy-py3-clang18-mobile-build
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3-clang12-mobile-build
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang15-asan
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan
build-generates-artifacts: false
test-matrix: |
{ include: [

View File

@@ -162,7 +162,7 @@ struct CUDACachingHostAllocatorImpl
}
bool pinned_use_background_threads() override {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
pinned_use_background_threads();
}
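The option behind this change is one of the documented PYTORCH_CUDA_ALLOC_CONF keys; a minimal usage sketch, assuming a CUDA build (after this diff the host allocator reads it back from CUDAAllocatorConfig rather than AcceleratorAllocatorConfig):

import os

# Must be set before CUDA initializes; the CUDAAllocatorConfig singleton
# parses the string once at load time.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "pinned_use_background_threads:True"

import torch

# Pinned (page-locked) host memory goes through the caching host allocator,
# which consults pinned_use_background_threads() as overridden above.
t = torch.empty(1024, pin_memory=True)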

View File

@@ -17,7 +17,13 @@ __global__ static void compute_cuda_kernel(
index_t* result_ptr,
int64_t size,
int64_t result_size) {
CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]);
if (C10_UNLIKELY((result_size != cumsum_ptr[size - 1]))) {
printf("%s:%d:%s: block: [%d,%d,%d], thread: [%d,%d,%d] "
"Invalid input! In `repeat_interleave`, the `output_size` argument (%ld) must be the same as the sum of the elements in the `repeats` tensor (%ld).\n",
__FILE__, __LINE__, __func__,blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z, result_size, cumsum_ptr[size - 1 ]);
CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1])
}
int64_t idx = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE;
int warp_id = idx / C10_WARP_SIZE;
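A hedged repro sketch for the new check: with a mismatched output_size, repeat_interleave used to fail on a bare device-side assert, and after this change the kernel first prints the descriptive message. Values are illustrative:

import torch

repeats = torch.tensor([1, 2, 3], device="cuda")  # repeats.sum() == 6
try:
    # output_size disagrees with the sum of `repeats`, so the kernel-side
    # check fires and the printf explains the mismatch before asserting.
    torch.repeat_interleave(repeats, output_size=5)
    torch.cuda.synchronize()  # device-side asserts surface asynchronously
except RuntimeError as err:
    print(err)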

View File

@@ -114,8 +114,22 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
}
// Account for the case where all values are masked, causing division by 0 in softmax (issue #156707).
// Overwrites the expected NaNs in sm with zeros.
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:sm secondaryTensor:vTensor name:nil];
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
truePredicateTensor:zeroTensor
falsePredicateTensor:sm
name:nil];
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
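A sketch of the edge case this guards, assuming an MPS-capable machine: when every key is masked out for a query row, the softmax row comes back as NaN, and the selectWithPredicateTensor above rewrites such rows to zero:

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 8, device="mps")
k = torch.randn(1, 1, 4, 8, device="mps")
v = torch.randn(1, 1, 4, 8, device="mps")

# An all-False boolean mask excludes every key for every query row.
mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool, device="mps")
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.isnan().any())  # tensor(False) once the NaN rows are zeroed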

View File

@@ -1,121 +1,389 @@
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/irange.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/llvmMathExtras.h>
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif
#include <cuda_runtime_api.h>
namespace c10::cuda::CUDACachingAllocator {
size_t CUDAAllocatorConfig::parseAllocatorConfig(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
constexpr size_t kRoundUpPowerOfTwoIntervals = 16;
CUDAAllocatorConfig::CUDAAllocatorConfig()
: m_max_split_size(std::numeric_limits<size_t>::max()),
m_max_non_split_rounding_size(kLargeBuffer),
m_garbage_collection_threshold(0),
m_pinned_num_register_threads(1),
m_expandable_segments(false),
#if CUDA_VERSION >= 12030
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::UNSPECIFIED),
#else
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::POSIX_FD),
#endif
m_release_lock_on_cudamalloc(false),
m_pinned_use_cuda_host_register(false),
m_pinned_use_background_threads(false) {
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
}
size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) {
size_t log_size = (63 - llvm::countLeadingZeros(size));
// Our intervals start at 1MB and end at 64GB
const size_t interval_start =
63 - llvm::countLeadingZeros(static_cast<size_t>(1048576));
const size_t interval_end =
63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736));
TORCH_CHECK(
(interval_end - interval_start == kRoundUpPowerOfTwoIntervals),
"kRoundUpPowerOfTwoIntervals mismatch");
int index = static_cast<int>(log_size) - static_cast<int>(interval_start);
index = std::max(0, index);
index = std::min(index, static_cast<int>(kRoundUpPowerOfTwoIntervals) - 1);
return instance().m_roundup_power2_divisions[index];
}
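A pure-Python sketch of the lookup above, under the same assumption of 16 buckets spanning 1 MiB through 64 GiB (kRoundUpPowerOfTwoIntervals):

def roundup_power2_division(size, divisions):
    log_size = size.bit_length() - 1               # 63 - countLeadingZeros(size)
    interval_start = (1 << 20).bit_length() - 1    # the 1 MiB bucket is index 0
    index = max(0, min(log_size - interval_start, len(divisions) - 1))
    return divisions[index]

divisions = [0] * 16
divisions[1] = 4  # divisions for the 2 MiB..4 MiB bucket
print(roundup_power2_division(3 << 20, divisions))  # 3 MiB -> bucket 1 -> 4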
void CUDAAllocatorConfig::lexArgs(
const std::string& env,
std::vector<std::string>& config) {
std::vector<char> buf;
for (char ch : env) {
if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
if (!buf.empty()) {
config.emplace_back(buf.begin(), buf.end());
buf.clear();
}
config.emplace_back(1, ch);
} else if (ch != ' ') {
buf.emplace_back(ch);
}
}
if (!buf.empty()) {
config.emplace_back(buf.begin(), buf.end());
}
}
void CUDAAllocatorConfig::consumeToken(
const std::vector<std::string>& config,
size_t i,
const char c) {
TORCH_CHECK(
i < config.size() && config[i] == std::string(1, c),
"Error parsing CachingAllocator settings, expected ",
c,
"");
}
size_t CUDAAllocatorConfig::parseMaxSplitSize(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
constexpr int mb = 1024 * 1024;
if (++i < config.size()) {
size_t val1 = stoi(config[i]);
TORCH_CHECK(
val1 > kLargeBuffer / mb,
"CachingAllocator option max_split_size_mb too small, must be > ",
kLargeBuffer / mb,
"");
val1 = std::max(val1, kLargeBuffer / mb);
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
m_max_split_size = val1 * 1024 * 1024;
} else {
TORCH_CHECK(false, "Error, expecting max_split_size_mb value", "");
}
return i;
}
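A usage sketch covering this parser and the garbage-collection one below; both keys are documented PYTORCH_CUDA_ALLOC_CONF options and the values here are illustrative:

import os

# max_split_size_mb must exceed kLargeBuffer (20 MiB); the GC threshold is a
# fraction strictly between 0.0 and 1.0. Options are comma-separated.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
    "max_split_size_mb:128,garbage_collection_threshold:0.8"
)

import torch  # the string is lexed and parsed once, at allocator load time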
size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
constexpr int mb = 1024 * 1024;
if (++i < config.size()) {
size_t val1 = stoi(config[i]);
TORCH_CHECK(
val1 > kLargeBuffer / mb,
"CachingAllocator option max_non_split_rounding_mb too small, must be > ",
kLargeBuffer / mb,
"");
val1 = std::max(val1, kLargeBuffer / mb);
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
m_max_non_split_rounding_size = val1 * 1024 * 1024;
} else {
TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
double val1 = stod(config[i]);
TORCH_CHECK(
val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", "");
TORCH_CHECK(
val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", "");
m_garbage_collection_threshold = val1;
} else {
TORCH_CHECK(
false, "Error, expecting garbage_collection_threshold value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
bool first_value = true;
if (++i < config.size()) {
if (std::string_view(config[i]) == "[") {
size_t last_index = 0;
// NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
while (++i < config.size() && std::string_view(config[i]) != "]") {
const std::string& val1 = config[i];
size_t val2 = 0;
consumeToken(config, ++i, ':');
if (++i < config.size()) {
val2 = stoi(config[i]);
} else {
TORCH_CHECK(
false, "Error parsing roundup_power2_divisions value", "");
}
TORCH_CHECK(
val2 == 0 || llvm::isPowerOf2_64(val2),
"For roundups, the divisions has to be power of 2 or 0 to disable roundup ",
"");
if (std::string_view(val1) == ">") {
std::fill(
std::next(
m_roundup_power2_divisions.begin(),
static_cast<std::vector<unsigned long>::difference_type>(
last_index)),
m_roundup_power2_divisions.end(),
val2);
} else {
size_t val1_long = stoul(val1);
TORCH_CHECK(
llvm::isPowerOf2_64(val1_long),
"For roundups, the intervals have to be power of 2 ",
"");
size_t index = 63 - llvm::countLeadingZeros(val1_long);
index = std::max((size_t)0, index);
index = std::min(index, m_roundup_power2_divisions.size() - 1);
if (first_value) {
std::fill(
m_roundup_power2_divisions.begin(),
std::next(
m_roundup_power2_divisions.begin(),
static_cast<std::vector<unsigned long>::difference_type>(
index)),
val2);
first_value = false;
}
if (index < m_roundup_power2_divisions.size()) {
m_roundup_power2_divisions[index] = val2;
}
last_index = index;
}
if (std::string_view(config[i + 1]) != "]") {
consumeToken(config, ++i, ',');
}
}
} else { // Keep this for backwards compatibility
size_t val1 = stoi(config[i]);
TORCH_CHECK(
llvm::isPowerOf2_64(val1),
"For roundups, the divisions has to be power of 2 ",
"");
std::fill(
m_roundup_power2_divisions.begin(),
m_roundup_power2_divisions.end(),
val1);
}
} else {
TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", "");
}
return i;
}
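The bracketed form handled above, as a config-string sketch: keys are power-of-two size buckets in MB, values are division counts, and '>' fills every larger bucket:

import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
    "roundup_power2_divisions:[256:1,512:2,1024:4,>:8]"
)

import torch  # sizes near 256 MB round with 1 division, near 512 MB with 2, ...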
size_t CUDAAllocatorConfig::parseAllocatorConfig(
const std::vector<std::string>& config,
size_t i,
bool& used_cudaMallocAsync) {
// For ease of maintenance and understanding, the CUDA and ROCm
// implementations of this function are separated. This avoids having many
// #ifdef's throughout.
#ifdef USE_ROCM
// Ease burden on ROCm users by allowing either cuda or hip tokens.
// cuda token is broken up to prevent hipify matching it.
#define PYTORCH_TOKEN1 \
"cud" \
"aMallocAsync"
#define PYTORCH_TOKEN2 "hipMallocAsync"
tokenizer.checkToken(++i, ":");
i++; // Move to the value after the colon
TORCH_CHECK(
((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) ||
(tokenizer[i] == PYTORCH_TOKEN2)),
"Unknown allocator backend, "
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
if (m_is_allocator_loaded) {
bool async_allocator_at_runtime = (tokenizer[i] != "native");
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
async_allocator_at_runtime == m_use_async_allocator,
"Allocator async backend parsed at runtime != allocator async backend parsed at load time, ",
async_allocator_at_runtime,
((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) ||
(config[i] == PYTORCH_TOKEN2)),
"Unknown allocator backend, "
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
used_cudaMallocAsync =
(config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2);
TORCH_INTERNAL_ASSERT(
config[i] == get()->name() ||
(config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time, ",
config[i],
" != ",
m_use_async_allocator);
get()->name());
} else {
TORCH_CHECK(false, "Error parsing backend value", "");
}
m_use_async_allocator =
(tokenizer[i] == PYTORCH_TOKEN1 || tokenizer[i] == PYTORCH_TOKEN2);
// CUDA allocator is always loaded at the start of the program
m_is_allocator_loaded = true;
#if defined(CUDA_VERSION)
if (m_use_async_allocator) {
#if CUDA_VERSION >= 11040
int version = 0;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
TORCH_CHECK(
version >= 11040,
"backend:cudaMallocAsync requires CUDA runtime "
"11.4 or newer, but cudaDriverGetVersion returned ",
version);
#else
TORCH_CHECK(
false,
"backend:cudaMallocAsync requires PyTorch to be built with "
"CUDA 11.4 or newer, but CUDA_VERSION is ",
CUDA_VERSION);
#endif
}
#endif
return i;
#undef PYTORCH_TOKEN1
#undef PYTORCH_TOKEN2
#else // USE_ROCM
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
((config[i] == "native") || (config[i] == "cudaMallocAsync")),
"Unknown allocator backend, "
"options are native and cudaMallocAsync");
used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
if (used_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
int version = 0;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
TORCH_CHECK(
version >= 11040,
"backend:cudaMallocAsync requires CUDA runtime "
"11.4 or newer, but cudaDriverGetVersion returned ",
version);
#else
TORCH_CHECK(
false,
"backend:cudaMallocAsync requires PyTorch to be built with "
"CUDA 11.4 or newer, but CUDA_VERSION is ",
CUDA_VERSION);
#endif
}
TORCH_INTERNAL_ASSERT(
config[i] == get()->name(),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time");
} else {
TORCH_CHECK(false, "Error parsing backend value", "");
}
return i;
#endif // USE_ROCM
}
void CUDAAllocatorConfig::parseArgs(const std::string& env) {
void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
// If empty, set the default values
m_max_split_size = std::numeric_limits<size_t>::max();
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
m_garbage_collection_threshold = 0;
bool used_cudaMallocAsync = false;
bool used_native_specific_option = false;
c10::CachingAllocator::ConfigTokenizer tokenizer(env);
for (size_t i = 0; i < tokenizer.size(); i++) {
const auto& key = tokenizer[i];
if (key == "backend") {
i = parseAllocatorConfig(tokenizer, i);
if (!env.has_value()) {
return;
}
{
std::lock_guard<std::mutex> lock(m_last_allocator_settings_mutex);
m_last_allocator_settings = env.value();
}
std::vector<std::string> config;
lexArgs(env.value(), config);
for (size_t i = 0; i < config.size(); i++) {
std::string_view config_item_view(config[i]);
if (config_item_view == "max_split_size_mb") {
i = parseMaxSplitSize(config, i);
used_native_specific_option = true;
} else if (config_item_view == "max_non_split_rounding_mb") {
i = parseMaxNonSplitRoundingSize(config, i);
used_native_specific_option = true;
} else if (config_item_view == "garbage_collection_threshold") {
i = parseGarbageCollectionThreshold(config, i);
used_native_specific_option = true;
} else if (config_item_view == "roundup_power2_divisions") {
i = parseRoundUpPower2Divisions(config, i);
used_native_specific_option = true;
} else if (config_item_view == "backend") {
i = parseAllocatorConfig(config, i, used_cudaMallocAsync);
} else if (config_item_view == "expandable_segments") {
used_native_specific_option = true;
consumeToken(config, ++i, ':');
++i;
TORCH_CHECK(
i < config.size() &&
(std::string_view(config[i]) == "True" ||
std::string_view(config[i]) == "False"),
"Expected a single True/False argument for expandable_segments");
config_item_view = config[i];
m_expandable_segments = (config_item_view == "True");
} else if (
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
// use, accept both. We must break up the string to prevent hipify here.
key == "release_lock_on_hipmalloc" ||
key ==
config_item_view == "release_lock_on_hipmalloc" ||
config_item_view ==
"release_lock_on_c"
"udamalloc") {
used_native_specific_option = true;
tokenizer.checkToken(++i, ":");
m_release_lock_on_cudamalloc = tokenizer.toBool(++i);
consumeToken(config, ++i, ':');
++i;
TORCH_CHECK(
i < config.size() &&
(std::string_view(config[i]) == "True" ||
std::string_view(config[i]) == "False"),
"Expected a single True/False argument for release_lock_on_cudamalloc");
config_item_view = config[i];
m_release_lock_on_cudamalloc = (config_item_view == "True");
} else if (
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
// use, accept both. We must break up the string to prevent hipify here.
key == "pinned_use_hip_host_register" ||
key ==
config_item_view == "pinned_use_hip_host_register" ||
config_item_view ==
"pinned_use_c"
"uda_host_register") {
i = parsePinnedUseCudaHostRegister(tokenizer, i);
i = parsePinnedUseCudaHostRegister(config, i);
used_native_specific_option = true;
} else if (key == "pinned_num_register_threads") {
i = parsePinnedNumRegisterThreads(tokenizer, i);
} else if (config_item_view == "pinned_num_register_threads") {
i = parsePinnedNumRegisterThreads(config, i);
used_native_specific_option = true;
} else if (config_item_view == "pinned_use_background_threads") {
i = parsePinnedUseBackgroundThreads(config, i);
used_native_specific_option = true;
} else {
const auto& keys =
c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys();
TORCH_CHECK(
keys.find(key) != keys.end(),
"Unrecognized key '",
key,
"' in Accelerator allocator config.");
i = tokenizer.skipKey(i);
false, "Unrecognized CachingAllocator option: ", config_item_view);
}
if (i + 1 < tokenizer.size()) {
tokenizer.checkToken(++i, ",");
if (i + 1 < config.size()) {
consumeToken(config, ++i, ',');
}
}
if (m_use_async_allocator && used_native_specific_option) {
if (used_cudaMallocAsync && used_native_specific_option) {
TORCH_WARN(
"backend:cudaMallocAsync ignores max_split_size_mb,"
"roundup_power2_divisions, and garbage_collect_threshold.");
@@ -123,33 +391,64 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) {
}
size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i) {
tokenizer.checkToken(++i, ":");
m_pinned_use_cuda_host_register = tokenizer.toBool(++i);
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for pinned_use_cuda_host_register");
m_pinned_use_cuda_host_register = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting pinned_use_cuda_host_register value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i) {
tokenizer.checkToken(++i, ":");
size_t val2 = tokenizer.toSizeT(++i);
TORCH_CHECK(
llvm::isPowerOf2_64(val2),
"Number of register threads has to be power of 2 ",
"");
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
TORCH_CHECK(
val2 <= maxThreads,
"Number of register threads should be less than or equal to " +
std::to_string(maxThreads),
"");
m_pinned_num_register_threads = val2;
consumeToken(config, ++i, ':');
if (++i < config.size()) {
size_t val2 = stoi(config[i]);
TORCH_CHECK(
llvm::isPowerOf2_64(val2),
"Number of register threads has to be power of 2 ",
"");
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
TORCH_CHECK(
val2 <= maxThreads,
"Number of register threads should be less than or equal to " +
std::to_string(maxThreads),
"");
m_pinned_num_register_threads = val2;
} else {
TORCH_CHECK(
false, "Error, expecting pinned_num_register_threads value", "");
}
return i;
}
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig)
size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for pinned_use_background_threads");
m_pinned_use_background_threads = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting pinned_use_background_threads value", "");
}
return i;
}
// General caching allocator utilities
void setAllocatorSettings(const std::string& env) {
CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str());
}
} // namespace c10::cuda::CUDACachingAllocator
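setAllocatorSettings is reachable from Python through a private helper; a sketch assuming the internal torch.cuda.memory._set_allocator_settings entry point, which is not a stable API:

import torch

# Re-parse the allocator config at runtime instead of via the environment.
# Per the parseArgs warning above, native-only options are ignored while
# backend:cudaMallocAsync is active.
torch.cuda.memory._set_allocator_settings("max_split_size_mb:256")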

View File

@@ -1,11 +1,16 @@
#pragma once
#include <c10/core/AllocatorConfig.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <mutex>
#include <string>
#include <vector>
namespace c10::cuda::CUDACachingAllocator {
enum class Expandable_Segments_Handle_Type : int {
@@ -17,28 +22,21 @@ enum class Expandable_Segments_Handle_Type : int {
// Environment config parser
class C10_CUDA_API CUDAAllocatorConfig {
public:
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size() instead.")
static size_t max_split_size() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size();
return instance().m_max_split_size;
}
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::garbage_collection_threshold() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::garbage_collection_threshold() instead.")
static double garbage_collection_threshold() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
garbage_collection_threshold();
return instance().m_garbage_collection_threshold;
}
static bool expandable_segments() {
bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig::
use_expandable_segments();
#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED
if (enabled) {
if (instance().m_expandable_segments) {
TORCH_WARN_ONCE("expandable_segments not supported on this platform")
}
return false;
#else
return enabled;
return instance().m_expandable_segments;
#endif
}
@@ -64,11 +62,8 @@ class C10_CUDA_API CUDAAllocatorConfig {
return instance().m_pinned_num_register_threads;
}
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_background_threads() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::pinned_use_background_threads() instead.")
static bool pinned_use_background_threads() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
pinned_use_background_threads();
return instance().m_pinned_use_background_threads;
}
static size_t pinned_max_register_threads() {
@@ -78,105 +73,92 @@ class C10_CUDA_API CUDAAllocatorConfig {
return 128;
}
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
static size_t roundup_power2_divisions(size_t size) {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
roundup_power2_divisions(size);
}
// This is used to round-up allocation size to nearest power of 2 divisions.
// More description below in function roundup_power2_next_division
// As an example, if we want 4 divisions between 2's power, this can be done
// using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
static size_t roundup_power2_divisions(size_t size);
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
static std::vector<size_t> roundup_power2_divisions() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
roundup_power2_divisions();
return instance().m_roundup_power2_divisions;
}
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_non_split_rounding_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_non_split_rounding_size() instead.")
static size_t max_non_split_rounding_size() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
max_non_split_rounding_size();
return instance().m_max_non_split_rounding_size;
}
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::last_allocator_settings() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::last_allocator_settings() instead.")
static std::string last_allocator_settings() {
return c10::CachingAllocator::getAllocatorSettings();
}
static bool use_async_allocator() {
return instance().m_use_async_allocator;
}
static const std::unordered_set<std::string>& getKeys() {
return instance().keys_;
std::lock_guard<std::mutex> lock(
instance().m_last_allocator_settings_mutex);
return instance().m_last_allocator_settings;
}
static CUDAAllocatorConfig& instance() {
static CUDAAllocatorConfig* s_instance = ([]() {
auto inst = new CUDAAllocatorConfig();
auto env = c10::utils::get_env("PYTORCH_ALLOC_CONF");
if (!env.has_value()) {
// For backward compatibility, check for the old environment variable
// PYTORCH_CUDA_ALLOC_CONF.
env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
}
auto env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
#ifdef USE_ROCM
// convenience for ROCm users, allow alternative HIP token
if (!env.has_value()) {
env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
}
#endif
if (env.has_value()) {
inst->parseArgs(env.value());
}
inst->parseArgs(env);
return inst;
})();
return *s_instance;
}
void parseArgs(const std::string& env);
void parseArgs(const std::optional<std::string>& env);
private:
CUDAAllocatorConfig() = default;
CUDAAllocatorConfig();
size_t parseAllocatorConfig(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
static void lexArgs(const std::string& env, std::vector<std::string>& config);
static void consumeToken(
const std::vector<std::string>& config,
size_t i,
const char c);
size_t parseMaxSplitSize(const std::vector<std::string>& config, size_t i);
size_t parseMaxNonSplitRoundingSize(
const std::vector<std::string>& config,
size_t i);
size_t parseGarbageCollectionThreshold(
const std::vector<std::string>& config,
size_t i);
size_t parseRoundUpPower2Divisions(
const std::vector<std::string>& config,
size_t i);
size_t parseAllocatorConfig(
const std::vector<std::string>& config,
size_t i,
bool& used_cudaMallocAsync);
size_t parsePinnedUseCudaHostRegister(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i);
size_t parsePinnedNumRegisterThreads(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i);
size_t parsePinnedUseBackgroundThreads(
const std::vector<std::string>& config,
size_t i);
std::atomic<size_t> m_pinned_num_register_threads{1};
std::atomic<Expandable_Segments_Handle_Type> m_expandable_segments_handle_type
#if CUDA_VERSION >= 12030
{Expandable_Segments_Handle_Type::UNSPECIFIED};
#else
{Expandable_Segments_Handle_Type::POSIX_FD};
#endif
std::atomic<bool> m_release_lock_on_cudamalloc{false};
std::atomic<bool> m_pinned_use_cuda_host_register{false};
std::atomic<bool> m_use_async_allocator{false};
std::atomic<bool> m_is_allocator_loaded{false};
std::unordered_set<std::string> keys_{
"backend",
// keep BC for ROCm: `cuda` -> `cud` `a`, to avoid hipify issues
// NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_cud"
"amalloc",
"pinned_use_cud"
"a_host_register",
// NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_hipmalloc",
"pinned_use_hip_host_register",
"pinned_num_register_threads"};
std::atomic<size_t> m_max_split_size;
std::atomic<size_t> m_max_non_split_rounding_size;
std::vector<size_t> m_roundup_power2_divisions;
std::atomic<double> m_garbage_collection_threshold;
std::atomic<size_t> m_pinned_num_register_threads;
std::atomic<bool> m_expandable_segments;
std::atomic<Expandable_Segments_Handle_Type>
m_expandable_segments_handle_type;
std::atomic<bool> m_release_lock_on_cudamalloc;
std::atomic<bool> m_pinned_use_cuda_host_register;
std::atomic<bool> m_pinned_use_background_threads;
std::string m_last_allocator_settings;
std::mutex m_last_allocator_settings_mutex;
};
// Keep this for backwards compatibility
using c10::CachingAllocator::setAllocatorSettings;
// General caching allocator utilities
C10_CUDA_API void setAllocatorSettings(const std::string& env);
} // namespace c10::cuda::CUDACachingAllocator

View File

@@ -1,6 +1,7 @@
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
@@ -63,6 +64,10 @@ namespace cuda::CUDACachingAllocator {
using namespace c10::CachingAllocator;
using namespace c10::CachingDeviceAllocator;
// Included here as this is externally used in CUDAAllocatorConfig
const size_t kLargeBuffer =
20971520; // "large" allocations may be packed in 20 MiB blocks
namespace Native {
//
@@ -1226,7 +1231,7 @@ class DeviceCachingAllocator {
DeviceCachingAllocator()
: large_blocks(/*small=*/false), small_blocks(/*small=*/true) {
stats.max_split_size =
static_cast<int64_t>(AcceleratorAllocatorConfig::max_split_size());
static_cast<int64_t>(CUDAAllocatorConfig::max_split_size());
context_recorder_.store(nullptr);
}
@@ -1351,8 +1356,7 @@ class DeviceCachingAllocator {
// Do garbage collection if the flag is set.
if (C10_UNLIKELY(
set_fraction &&
AcceleratorAllocatorConfig::garbage_collection_threshold() >
0.0)) {
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
garbage_collect_cached_blocks(context);
}
// Attempt allocate
@@ -1604,7 +1608,7 @@ class DeviceCachingAllocator {
stats.active_bytes[stat_type].increase(block->size);
stats.requested_bytes[stat_type].increase(block->requested_size);
});
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
if (block->size >= CUDAAllocatorConfig::max_split_size())
stats.oversize_allocations.increase(1);
auto allocated_bytes_gauge =
@@ -1655,7 +1659,7 @@ class DeviceCachingAllocator {
block->pool->owner_MempoolId(),
context ? context : block->context_when_allocated);
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
if (block->size >= CUDAAllocatorConfig::max_split_size())
stats.oversize_allocations.decrease(1);
if (!block->stream_uses.empty()) {
@@ -2205,8 +2209,7 @@ class DeviceCachingAllocator {
if (size < kMinBlockSize) {
return kMinBlockSize;
} else {
auto divisions =
AcceleratorAllocatorConfig::roundup_power2_divisions(size);
auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size);
if (divisions > 1 && size > (kMinBlockSize * divisions)) {
return roundup_power2_next_division(size, divisions);
} else {
@@ -2696,7 +2699,7 @@ class DeviceCachingAllocator {
if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) {
return remaining >= kMinBlockSize;
} else {
return (size < AcceleratorAllocatorConfig::max_split_size()) &&
return (size < CUDAAllocatorConfig::max_split_size()) &&
(remaining > kSmallSize);
}
}
@@ -2716,7 +2719,7 @@ class DeviceCachingAllocator {
if (C10_UNLIKELY(
set_fraction &&
AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) {
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
// Track block reuse interval only when garbage collection is enabled.
++pool.get_free_blocks_call_count;
}
@@ -2758,13 +2761,13 @@ class DeviceCachingAllocator {
}
// Do not return an oversized block for a large request
if ((p.size() < AcceleratorAllocatorConfig::max_split_size()) &&
((*it)->size >= AcceleratorAllocatorConfig::max_split_size()))
if ((p.size() < CUDAAllocatorConfig::max_split_size()) &&
((*it)->size >= CUDAAllocatorConfig::max_split_size()))
return false;
// Allow oversized block size to be rounded up but within a limit
if ((p.size() >= AcceleratorAllocatorConfig::max_split_size()) &&
if ((p.size() >= CUDAAllocatorConfig::max_split_size()) &&
((*it)->size >=
p.size() + AcceleratorAllocatorConfig::max_non_split_rounding_size()))
p.size() + CUDAAllocatorConfig::max_non_split_rounding_size()))
return false;
p.block = *it;
pool.blocks.erase(it);
@@ -2787,7 +2790,7 @@ class DeviceCachingAllocator {
// therefore should be of less overheads.
size_t gc_threshold = static_cast<size_t>(
AcceleratorAllocatorConfig::garbage_collection_threshold() *
CUDAAllocatorConfig::garbage_collection_threshold() *
static_cast<double>(allowed_memory_maximum));
// No need to trigger GC yet
if (total_allocated_memory <= gc_threshold) {
@@ -2935,7 +2938,7 @@ class DeviceCachingAllocator {
stats.segment[stat_type].increase(1);
stats.reserved_bytes[stat_type].increase(size);
});
if (size >= AcceleratorAllocatorConfig::max_split_size())
if (size >= CUDAAllocatorConfig::max_split_size())
stats.oversize_segments.increase(1);
auto reserved_bytes_gauge =
STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
@@ -2964,7 +2967,7 @@ class DeviceCachingAllocator {
bool release_available_cached_blocks(
const AllocParams& p,
const std::shared_ptr<GatheredContext>& context) {
if (AcceleratorAllocatorConfig::max_split_size() ==
if (CUDAAllocatorConfig::max_split_size() ==
std::numeric_limits<size_t>::max())
return false;
BlockPool& pool = *p.pool;
@@ -2972,8 +2975,8 @@ class DeviceCachingAllocator {
// because of std::unique_ptr, block cannot be trivially copied
// Use constructor for search key.
Block key(p.search_key.device, p.search_key.stream, p.search_key.size);
key.size = (key.size < AcceleratorAllocatorConfig::max_split_size())
? AcceleratorAllocatorConfig::max_split_size()
key.size = (key.size < CUDAAllocatorConfig::max_split_size())
? CUDAAllocatorConfig::max_split_size()
: key.size;
auto it = pool.blocks.lower_bound(&key);
if (it == pool.blocks.end() || (*it)->stream != p.stream() ||
@@ -2986,7 +2989,7 @@ class DeviceCachingAllocator {
--it; // Back up one item. Now on the largest block for the correct
// stream
while ((totalReleased < key.size) &&
((*it)->size >= AcceleratorAllocatorConfig::max_split_size()) &&
((*it)->size >= CUDAAllocatorConfig::max_split_size()) &&
((*it)->stream == p.stream())) {
auto cur = it;
bool is_first = cur == pool.blocks.begin();
@@ -3111,7 +3114,7 @@ class DeviceCachingAllocator {
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current);
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
if (block->size >= CUDAAllocatorConfig::max_split_size())
stats.oversize_segments.decrease(1);
pool->blocks.erase(block);
delete block;
@@ -3738,8 +3741,8 @@ class NativeCachingAllocator : public CUDAAllocator {
auto& md = result.config_metadata;
md.garbage_collection_threshold =
AcceleratorAllocatorConfig::garbage_collection_threshold();
md.max_split_size = AcceleratorAllocatorConfig::max_split_size();
CUDAAllocatorConfig::garbage_collection_threshold();
md.max_split_size = CUDAAllocatorConfig::max_split_size();
md.pinned_num_register_threads =
CUDAAllocatorConfig::pinned_num_register_threads();
md.expandable_segments = CUDAAllocatorConfig::expandable_segments();
@@ -3747,10 +3750,9 @@ class NativeCachingAllocator : public CUDAAllocator {
CUDAAllocatorConfig::release_lock_on_cudamalloc();
md.pinned_use_host_register =
CUDAAllocatorConfig::pinned_use_cuda_host_register();
md.last_allocator_settings =
AcceleratorAllocatorConfig::last_allocator_settings();
md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings();
md.roundup_power2_divisions =
AcceleratorAllocatorConfig::roundup_power2_divisions();
CUDAAllocatorConfig::roundup_power2_divisions();
return result;
}
@@ -4128,10 +4130,49 @@ CUDAAllocator* allocator();
} // namespace CudaMallocAsync
struct BackendStaticInitializer {
// Parses env for backend at load time, duplicating some logic from
// CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at
// runtime). Defers verbose exceptions and error checks, including Cuda
// version checks, to CUDAAllocatorConfig's runtime doublecheck. If this
// works, maybe we should move all of CUDAAllocatorConfig here?
CUDAAllocator* parseEnvForBackend() {
// If the environment variable is set, we use the CudaMallocAsync allocator.
if (CUDAAllocatorConfig::use_async_allocator()) {
return CudaMallocAsync::allocator();
auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
if (!val.has_value()) {
val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
}
#endif
if (val.has_value()) {
const std::string& config = val.value();
std::regex exp("[\\s,]+");
std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
std::sregex_token_iterator end;
std::vector<std::string> options(it, end);
for (auto option : options) {
std::regex exp2("[:]+");
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
std::sregex_token_iterator end2;
std::vector<std::string> kv(it2, end2);
if (kv.size() >= 2) {
if (kv[0] == "backend") {
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
if (kv[1] ==
"cud"
"aMallocAsync" ||
kv[1] == "hipMallocAsync")
#else
if (kv[1] == "cudaMallocAsync")
#endif
return CudaMallocAsync::allocator();
if (kv[1] == "native")
return &Native::allocator;
}
}
}
}
return &Native::allocator;
}
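Because this parse happens in a static initializer, the backend choice has to be in the environment before the CUDA runtime is first touched; a usage sketch:

import os

# Must be set before `import torch` (or at least before libtorch_cuda loads),
# since BackendStaticInitializer reads it at load time.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

import torch

x = torch.ones(4, device="cuda")  # served by the cudaMallocAsync backend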

View File

@@ -1,7 +1,6 @@
#pragma once
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>
@@ -50,9 +49,10 @@ namespace c10::cuda::CUDACachingAllocator {
// Preserved only for BC reasons
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::CachingAllocator::kLargeBuffer;
using c10::CachingDeviceAllocator::DeviceStats;
extern const size_t kLargeBuffer;
typedef std::shared_ptr<GatheredContext> (*CreateContextFn)();
// Struct containing info of an allocation block (i.e. a fractional part of a

View File

@@ -48,18 +48,14 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
// Note: event destruction done on creating device to avoid creating a
// CUDA context on other devices.
~CUDAEvent() {
try {
if (is_created_) {
CUDAGuard guard(device_index_);
const c10::impl::PyInterpreter* interp =
c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
at::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
C10_CUDA_CHECK(cudaEventDestroy(event_));
if (is_created_) {
CUDAGuard guard(device_index_);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
at::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
} catch (...) { /* No throw */
C10_CUDA_CHECK_WARN(cudaEventDestroy(event_));
}
}
@@ -67,11 +63,11 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
CUDAEvent& operator=(const CUDAEvent&) = delete;
CUDAEvent(CUDAEvent&& other) noexcept {
moveHelper(std::move(other));
moveHelper(other);
}
CUDAEvent& operator=(CUDAEvent&& other) noexcept {
if (this != &other) {
moveHelper(std::move(other));
moveHelper(other);
}
return *this;
}
@@ -266,7 +262,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
is_created_ = true;
}
void moveHelper(CUDAEvent&& other) {
void moveHelper(CUDAEvent& other) {
std::swap(flags_, other.flags_);
std::swap(is_created_, other.is_created_);
std::swap(was_recorded_, other.was_recorded_);

View File

@@ -1,4 +1,3 @@
#include <c10/core/AllocatorConfig.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/xpu/XPUCachingAllocator.h>
@@ -21,6 +20,8 @@ constexpr size_t kMinBlockSize = 512;
constexpr size_t kSmallSize = 1048576;
// "small" allocations are packed in 2 MiB blocks
constexpr size_t kSmallBuffer = 2097152;
// "large" allocations may be packed in 20 MiB blocks
constexpr size_t kLargeBuffer = 20971520;
// allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kMinLargeAlloc = 10485760;
// round up large allocations to 2 MiB

View File

@@ -1001,7 +1001,7 @@ elseif(USE_CUDA)
# 3. Let CMake find it in the default system paths, e.g. /usr/local.
find_library(NVSHMEM_HOST_LIB
# In pip install case, the lib suffix is `.so.3` instead of `.so`
NAMES nvshmem_host nvshmem_host.so.3
NAMES nvshmem_host libnvshmem_host.so.3 NAMES_PER_DIR
HINTS $ENV{NVSHMEM_HOME} ${NVSHMEM_PY_DIR}
PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64
DOC "The location of NVSHMEM host library.")

View File

@@ -28,7 +28,7 @@ class NCCLTestBase {
NCCLTestBase(NCCLTestBase&& other) noexcept = default;
std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
return pg_;
}
@@ -39,7 +39,7 @@ class NCCLTestBase {
void initialize(
int rank,
size_t size,
std::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from =
std::optional<::c10::intrusive_ptr<::c10d::ProcessGroupNCCL>> split_from =
std::nullopt) {
store_ = c10::make_intrusive<::c10d::FileStore>(path_, size);
@@ -52,13 +52,13 @@ class NCCLTestBase {
opts->split_color = ++color_;
}
#endif
pg_ = std::make_unique<::c10d::ProcessGroupNCCL>(
pg_ = c10::make_intrusive<::c10d::ProcessGroupNCCL>(
store_, rank, size, std::move(opts));
}
protected:
std::string path_;
std::shared_ptr<::c10d::ProcessGroupNCCL> pg_;
::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> pg_;
std::chrono::milliseconds pgTimeout_;
::c10::intrusive_ptr<::c10d::Store> store_;
int color_{1};

View File

@@ -24,6 +24,15 @@ set(NATIVERT_TEST_SRCS
${TORCH_ROOT}/torch/nativert/executor/memory/LayoutPlanner.cpp
${TORCH_ROOT}/torch/nativert/executor/memory/LayoutManager.cpp
${TORCH_ROOT}/torch/nativert/executor/memory/AliasAnalyzer.cpp
${TORCH_ROOT}/torch/nativert/executor/Executor.cpp
${TORCH_ROOT}/torch/nativert/kernels/KernelFactory.cpp
${TORCH_ROOT}/torch/nativert/executor/ConstantFolder.cpp
${TORCH_ROOT}/torch/nativert/executor/GraphExecutorBase.cpp
${TORCH_ROOT}/torch/nativert/executor/SerialGraphExecutor.cpp
${TORCH_ROOT}/torch/nativert/executor/ParallelGraphExecutor.cpp
${TORCH_ROOT}/torch/nativert/kernels/AutoFunctionalizeKernel.cpp
${TORCH_ROOT}/torch/nativert/kernels/CallTorchBindKernel.cpp
${TORCH_ROOT}/torch/nativert/kernels/HigherOrderKernel.cpp
)
add_executable(test_nativert

View File

@@ -0,0 +1,182 @@
#include <gtest/gtest.h>
#include <fmt/format.h>
#include <torch/nativert/executor/memory/AliasAnalyzer.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/executor/Executor.h>
#include <torch/nativert/kernels/KernelFactory.h>
using namespace ::testing;
using namespace torch::nativert;
using AliasTestCase = std::tuple<
std::string /* value */,
AllocationLifetime,
bool /* is_alias */,
bool /* is_storage_associated_with_output */,
c10::FastSet<std::string> /* source(s) */>;
class AliasAnalyzerTests : public testing::Test {
void SetUp() override {}
void TearDown() override {
test_cases.clear();
model.clear();
}
public:
void setTestCases(std::vector<AliasTestCase> cases) {
test_cases = std::move(cases);
}
void setModel(std::string m) {
model = std::move(m);
}
void run() {
EXPECT_FALSE(test_cases.empty());
EXPECT_FALSE(model.empty());
ExecutorConfig cfg;
cfg.enableStaticCPUKernels = true;
auto graph = stringToGraph(model);
auto kernels = KernelFactory().initializeNodeKernels(
*graph, nullptr, cfg, {}, nullptr);
auto kernelSchemas = Executor::getKernelSchemas(kernels.nodeKernels);
AliasAnalyzer analyzer(*graph, kernelSchemas);
for (
auto& [value, lifetime, is_alias, is_storage_associated_with_output, srcs] :
test_cases) {
LOG(INFO) << fmt::format(
"running test: value={}, lifetime=({}, {}), is_alias={}, is_storage_associated_with_output={}, src={}",
value,
lifetime.start,
lifetime.end,
is_alias,
is_storage_associated_with_output,
srcs.empty() ? "{}"
: std::accumulate(
srcs.begin(),
srcs.end(),
std::string{},
[](std::string cur, const std::string& src) {
cur.append(",");
cur.append(src);
return cur;
}));
auto* v = graph->getValue(value);
EXPECT_EQ(analyzer.lifetime(v), lifetime);
EXPECT_EQ(analyzer.is_alias(v), is_alias);
EXPECT_EQ(
analyzer.is_storage_associated_with_output(v),
is_storage_associated_with_output);
const auto* resolved_srcs = analyzer.get_sources_of_alias(v);
if (resolved_srcs /* ensure set equality between *resolved_srcs and srcs */) {
EXPECT_FALSE(srcs.empty());
EXPECT_EQ(resolved_srcs->size(), srcs.size());
for (const auto& resolved_src : *resolved_srcs) {
EXPECT_TRUE(srcs.erase(std::string(resolved_src->name())) == 1);
}
EXPECT_TRUE(srcs.empty());
} else {
EXPECT_TRUE(srcs.empty());
}
}
}
private:
std::string model;
std::vector<AliasTestCase> test_cases;
};
TEST_F(AliasAnalyzerTests, TestNoAlias) {
setModel(R"(
graph(%y0, %y1):
%out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
%res = torch.ops.aten.clone.default(self=%out_t, memory_format=None)
return (%res))");
setTestCases({
{"out_t", AllocationLifetime(1, 2), false, false, {}},
{"res", AllocationLifetime(2, 3), false, true, {}},
});
run();
}
TEST_F(AliasAnalyzerTests, TestSimpleAlias) {
setModel(R"(
graph(%y0, %y1):
%out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
%res = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1)
return (%res))");
setTestCases({
{"out_t", AllocationLifetime(1, 3), false, true, {}},
{"res", AllocationLifetime(2, 3), true, false, {"out_t"}},
});
run();
}
TEST_F(AliasAnalyzerTests, TestDeepAlias) {
setModel(R"(
graph(%y0, %y1):
%out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
%a1 = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1)
%res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1)
return (%res))");
setTestCases({
{"out_t", AllocationLifetime(1, 4), false, true, {}},
{"a1", AllocationLifetime(2, 4), true, false, {"out_t"}},
{"res", AllocationLifetime(3, 4), true, false, {"out_t"}},
});
run();
}
TEST_F(AliasAnalyzerTests, TestPackedListUnpack) {
setModel(R"(
graph(%a, %b, %c, %d):
%input_list[] = prim.ListPack(l0=%a, l1=%b, l2=%c, l3=%d)
%x0, %x1, %x2, %x3 = prim.ListUnpack(input=%input_list)
return (%x1, %x3))");
setTestCases({
{"a", AllocationLifetime(0, 2), false, false, {}},
{"x0", AllocationLifetime(2, 2), true, false, {"a"}},
{"b", AllocationLifetime(0, 3), false, true, {}},
{"x1", AllocationLifetime(2, 3), true, false, {"b"}},
{"c", AllocationLifetime(0, 2), false, false, {}},
{"x2", AllocationLifetime(2, 2), true, false, {"c"}},
{"d", AllocationLifetime(0, 3), false, true, {}},
{"x3", AllocationLifetime(2, 3), true, false, {"d"}},
});
run();
}
TEST_F(AliasAnalyzerTests, TestAmbiguousSourceOfAlias) {
setModel(R"(
graph(%y0, %y1):
%out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
%out_t2 = torch.ops.aten.matmul.default(self=%y0, other=%y1)
%a1 = prim.VarStack(l0=%out_t, l1=%out_t2)
%res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1)
return (%res))");
setTestCases({
{"out_t", AllocationLifetime(1, 5), false, true, {}},
{"out_t2", AllocationLifetime(2, 5), false, true, {}},
{"a1", AllocationLifetime(3, 5), true, false, {"out_t", "out_t2"}},
{"res", AllocationLifetime(4, 5), true, false, {"out_t", "out_t2"}},
});
run();
}
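The cases above pin down the analyzer's contract: ops like slice borrow their input's storage (an alias), ops like matmul and clone allocate fresh storage, an alias keeps its transitive source alive until the last use of any value sharing that storage, and storage that escapes through a graph output is flagged as output-associated. A toy Python model of that lifetime bookkeeping, reproducing the TestDeepAlias expectations (the op table and helper are invented for illustration; the real analyzer is the C++ class under test):
# Toy model of the alias-lifetime bookkeeping exercised above. Ops are
# numbered from 1, graph inputs are born at step 0, and the return
# statement counts as one final step.
ALIASING = {"slice"}  # assumption: slice borrows storage; matmul/clone allocate
def analyze(ops, outputs):
    # ops: list of (op_name, output_value, input_values)
    birth, last_use, root = {}, {}, {}
    for step, (name, out, ins) in enumerate(ops, start=1):
        birth[out] = step
        for v in ins:
            last_use[v] = step
        if name in ALIASING:
            # An alias shares storage with its transitive source.
            root[out] = root.get(ins[0], ins[0])
    for v in outputs:
        last_use[v] = len(ops) + 1  # the return is the final use
    # Everything sharing one storage stays alive until the group's last use.
    for src in set(root.values()):
        group = [src] + [a for a, s in root.items() if s == src]
        end = max(last_use.get(v, 0) for v in group)
        for v in group:
            last_use[v] = end
    return birth, last_use, root
# Mirrors TestDeepAlias: matmul allocates; both slices alias its storage.
ops = [("matmul", "out_t", ["y0", "y1"]),
       ("slice", "a1", ["out_t"]),
       ("slice", "res", ["a1"])]
birth, last_use, root = analyze(ops, outputs=["res"])
assert (birth["out_t"], last_use["out_t"]) == (1, 4)
assert (birth["a1"], last_use["a1"]) == (2, 4)
assert (birth["res"], last_use["res"]) == (3, 4)
assert root == {"a1": "out_t", "res": "out_t"}
The same bookkeeping explains TestAmbiguousSourceOfAlias: when an alias can originate from more than one allocation (the VarStack case), every candidate source has to be kept alive, which is why both out_t and out_t2 are marked output-associated there.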

View File

@ -554,21 +554,6 @@ class TestNew2dParallelTraining(DTensorTestBase):
p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local()
self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")
@with_comms
@skip_if_lt_x_gpu(4)
def test_raise_invalid_tp_composition(self):
with self.assertRaisesRegex(
RuntimeError, r"Found TP device_mesh on the \d dimension of its parent mesh"
):
mesh_2d = init_device_mesh(
self.device_type, (2, self.world_size // 2), mesh_dim_names=("tp", "dp")
)
parallelize_plan = {
"net1": ColwiseParallel(),
"net2": RowwiseParallel(),
}
parallelize_module(SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan)
@with_comms
@skip_if_lt_x_gpu(4)
def test_2d_fsdp_state_enable_extension(self):

View File

@ -201,6 +201,17 @@ class Dist2MultiProcessTestCase(MultiProcessTestCase):
out_range = out[i * 10 : (i + 1) * 10]
self.assertEqual(out_range, torch.full_like(out_range, i + 1))
def test_group_split(self) -> None:
group = self.new_group()
subgroup = group.split_group([0], timeout=timedelta(seconds=30))
if self.rank == 0:
assert subgroup is not None
self.assertEqual(subgroup.size(), 1)
backend = subgroup._get_backend(self.device)
self.assertEqual(backend.options._timeout, timedelta(seconds=30))
else:
self.assertEqual(subgroup, None)
class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
device = torch.device("cpu")
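The new test fixes the split_group contract for dist2-style groups: only ranks named in the split get a handle back, the subgroup reports its own size, the requested timeout is plumbed through to the backend options, and every rank left out receives None. A minimal usage sketch under those semantics (here `group` is assumed to be an already-initialized four-rank process group and `rank` this process's rank; both come from whatever bootstrap the caller uses):
from datetime import timedelta

# Ranks 0 and 1 form the subgroup; ranks 2 and 3 are left out.
subgroup = group.split_group([0, 1], timeout=timedelta(seconds=30))

if rank in (0, 1):
    # Members get a live handle sized to the split.
    assert subgroup is not None
    assert subgroup.size() == 2
else:
    # Non-members get None rather than a dead handle.
    assert subgroup is None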

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py
index dbc5ef4f9f2..239b75f74cc 100644
index dbc5ef4f9f2..70e24036f74 100644
--- a/test/dynamo/cpython/3_13/list_tests.py
+++ b/test/dynamo/cpython/3_13/list_tests.py
@@ -1,3 +1,53 @@
@@ -1,3 +1,56 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/list_tests.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -56,7 +59,7 @@ index dbc5ef4f9f2..239b75f74cc 100644
"""
Tests common to list and UserList.UserList
"""
@@ -5,7 +55,7 @@ Tests common to list and UserList.UserList
@@ -5,7 +58,7 @@ Tests common to list and UserList.UserList
import sys
from functools import cmp_to_key
@ -65,7 +68,7 @@ index dbc5ef4f9f2..239b75f74cc 100644
from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit
@@ -119,10 +169,6 @@ class CommonTest(seq_tests.CommonTest):
@@ -119,10 +172,6 @@ class CommonTest(seq_tests.CommonTest):
a[-1] = 9
self.assertEqual(a, self.type2test([5,6,7,8,9]))

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/list_tests.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/mapping_tests.py b/test/dynamo/cpython/3_13/mapping_tests.py
index ed89a81a6ea..eed59a68e94 100644
index ed89a81a6ea..10fc6e7e467 100644
--- a/test/dynamo/cpython/3_13/mapping_tests.py
+++ b/test/dynamo/cpython/3_13/mapping_tests.py
@@ -1,10 +1,61 @@
@@ -1,10 +1,64 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/mapping_tests.py
+
+import sys
+import torch
+import torch._dynamo.test_case

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/mapping_tests.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py
index 719c9434a16..4325892276d 100644
index 719c9434a16..2c502cda4f6 100644
--- a/test/dynamo/cpython/3_13/seq_tests.py
+++ b/test/dynamo/cpython/3_13/seq_tests.py
@@ -1,3 +1,54 @@
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/seq_tests.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -57,7 +60,7 @@ index 719c9434a16..4325892276d 100644
"""
Tests common to tuple, list and UserList.UserList
"""
@@ -95,7 +146,7 @@ class LyingList(list):
@@ -95,7 +149,7 @@ class LyingList(list):
def __iter__(self):
yield 1

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/seq_tests.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_baseexception.py b/test/dynamo/cpython/3_13/test_baseexception.py
index e599b02c17d..3dc102e3b8a 100644
index e599b02c17d..750d7a84fb4 100644
--- a/test/dynamo/cpython/3_13/test_baseexception.py
+++ b/test/dynamo/cpython/3_13/test_baseexception.py
@@ -1,10 +1,61 @@
@@ -1,10 +1,64 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_baseexception.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -65,7 +68,7 @@ index e599b02c17d..3dc102e3b8a 100644
"""Tests for anything relating to exception objects themselves (e.g.,
inheritance hierarchy)"""
@@ -78,9 +129,6 @@ class ExceptionClassTests(unittest.TestCase):
@@ -78,9 +132,6 @@ class ExceptionClassTests(unittest.TestCase):
last_depth = depth
finally:
inheritance_tree.close()
@ -75,7 +78,7 @@ index e599b02c17d..3dc102e3b8a 100644
self.assertEqual(len(exc_set), 0, "%s not accounted for" % exc_set)
interface_tests = ("length", "args", "str", "repr")
@@ -142,7 +190,7 @@ class ExceptionClassTests(unittest.TestCase):
@@ -142,7 +193,7 @@ class ExceptionClassTests(unittest.TestCase):
gc.collect()
@ -84,7 +87,7 @@ index e599b02c17d..3dc102e3b8a 100644
"""Test usage of exceptions"""
@@ -208,5 +256,5 @@ class UsageTests(unittest.TestCase):
@@ -208,5 +259,5 @@ class UsageTests(unittest.TestCase):
self.catch_fails("spam")

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_baseexception.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_cmath.py b/test/dynamo/cpython/3_13/test_cmath.py
index a96a5780b31..883e87a0733 100644
index a96a5780b31..37fb665d97d 100644
--- a/test/dynamo/cpython/3_13/test_cmath.py
+++ b/test/dynamo/cpython/3_13/test_cmath.py
@@ -1,5 +1,55 @@
@@ -1,5 +1,58 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_cmath.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -59,7 +62,7 @@ index a96a5780b31..883e87a0733 100644
from test.test_math import parse_testfile, test_file
import test.test_math as test_math
import unittest
@@ -50,7 +100,7 @@ complex_nans = [complex(x, y) for x, y in [
@@ -50,7 +103,7 @@ complex_nans = [complex(x, y) for x, y in [
(INF, NAN)
]]
@ -68,7 +71,7 @@ index a96a5780b31..883e87a0733 100644
# list of all functions in cmath
test_functions = [getattr(cmath, fname) for fname in [
'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh',
@@ -66,6 +116,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase):
@@ -66,6 +119,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase):
def tearDown(self):
self.test_values.close()
@ -108,7 +111,7 @@ index a96a5780b31..883e87a0733 100644
def rAssertAlmostEqual(self, a, b, rel_err = 2e-15, abs_err = 5e-323,
msg=None):
"""Fail if the two floating-point numbers are not almost equal.
@@ -590,4 +673,4 @@ class IsCloseTests(test_math.IsCloseTests):
@@ -590,4 +676,4 @@ class IsCloseTests(test_math.IsCloseTests):
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_cmath.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py
index 6ff1a8ab29d..ab5bd3dab62 100644
index 6ff1a8ab29d..cda348d2f37 100644
--- a/test/dynamo/cpython/3_13/test_complex.py
+++ b/test/dynamo/cpython/3_13/test_complex.py
@@ -1,16 +1,143 @@
@@ -1,16 +1,146 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_complex.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -151,7 +154,7 @@ index 6ff1a8ab29d..ab5bd3dab62 100644
INF = float("inf")
NAN = float("nan")
DBL_MAX = sys.float_info.max
@@ -45,7 +172,40 @@ class WithComplex:
@@ -45,7 +175,40 @@ class WithComplex:
def __complex__(self):
return self.value
@ -193,7 +196,7 @@ index 6ff1a8ab29d..ab5bd3dab62 100644
def assertAlmostEqual(self, a, b):
if isinstance(a, complex):
@@ -74,6 +234,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
@@ -74,6 +237,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
# check that relative difference < eps
self.assertTrue(abs((x-y)/y) < eps)
@ -223,7 +226,7 @@ index 6ff1a8ab29d..ab5bd3dab62 100644
def assertClose(self, x, y, eps=1e-9):
"""Return true iff complexes x and y "are close"."""
self.assertCloseAbs(x.real, y.real, eps)
@@ -855,4 +1038,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
@@ -855,4 +1041,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_complex.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_contextlib.py b/test/dynamo/cpython/3_13/test_contextlib.py
index cf651959803..6a17bc719eb 100644
index cf651959803..51fd083b112 100644
--- a/test/dynamo/cpython/3_13/test_contextlib.py
+++ b/test/dynamo/cpython/3_13/test_contextlib.py
@@ -1,3 +1,54 @@
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_contextlib.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -57,7 +60,7 @@ index cf651959803..6a17bc719eb 100644
"""Unit tests for contextlib.py, and other context managers."""
import io
@@ -14,7 +65,7 @@ from test.support.testcase import ExceptionIsLikeMixin
@@ -14,7 +68,7 @@ from test.support.testcase import ExceptionIsLikeMixin
import weakref
@ -66,7 +69,7 @@ index cf651959803..6a17bc719eb 100644
def test_enter(self):
class DefaultEnter(AbstractContextManager):
@@ -67,7 +118,7 @@ class TestAbstractContextManager(unittest.TestCase):
@@ -67,7 +121,7 @@ class TestAbstractContextManager(unittest.TestCase):
self.assertFalse(issubclass(NoExit, AbstractContextManager))
@ -75,7 +78,7 @@ index cf651959803..6a17bc719eb 100644
def test_contextmanager_plain(self):
state = []
@@ -396,7 +447,7 @@ def woohoo():
@@ -396,7 +450,7 @@ def woohoo():
self.assertEqual(depth, 0)
@ -84,7 +87,7 @@ index cf651959803..6a17bc719eb 100644
@support.requires_docstrings
def test_instance_docs(self):
@@ -430,7 +481,7 @@ class ClosingTestCase(unittest.TestCase):
@@ -430,7 +484,7 @@ class ClosingTestCase(unittest.TestCase):
self.assertEqual(state, [1])
@ -93,7 +96,7 @@ index cf651959803..6a17bc719eb 100644
def test_nullcontext(self):
class C:
pass
@@ -439,7 +490,7 @@ class NullcontextTestCase(unittest.TestCase):
@@ -439,7 +493,7 @@ class NullcontextTestCase(unittest.TestCase):
self.assertIs(c_in, c)
@ -102,7 +105,7 @@ index cf651959803..6a17bc719eb 100644
def testWithOpen(self):
tfn = tempfile.mktemp()
@@ -457,7 +508,7 @@ class FileContextTestCase(unittest.TestCase):
@@ -457,7 +511,7 @@ class FileContextTestCase(unittest.TestCase):
finally:
os_helper.unlink(tfn)
@ -111,7 +114,7 @@ index cf651959803..6a17bc719eb 100644
def boilerPlate(self, lock, locked):
self.assertFalse(locked())
@@ -520,7 +571,7 @@ class mycontext(ContextDecorator):
@@ -520,7 +574,7 @@ class mycontext(ContextDecorator):
return self.catch
@ -120,7 +123,7 @@ index cf651959803..6a17bc719eb 100644
@support.requires_docstrings
def test_instance_docs(self):
@@ -680,7 +731,7 @@ class TestContextDecorator(unittest.TestCase):
@@ -680,7 +734,7 @@ class TestContextDecorator(unittest.TestCase):
self.assertEqual(state, [1, 'something else', 999])
@ -129,7 +132,7 @@ index cf651959803..6a17bc719eb 100644
exit_stack = None
@support.requires_docstrings
@@ -1141,7 +1192,7 @@ class TestBaseExitStack:
@@ -1141,7 +1195,7 @@ class TestBaseExitStack:
self.assertIs(exc.__cause__, exc.__context__)
@ -138,7 +141,7 @@ index cf651959803..6a17bc719eb 100644
exit_stack = ExitStack
callback_error_internal_frames = [
('__exit__', 'raise exc'),
@@ -1149,7 +1200,7 @@ class TestExitStack(TestBaseExitStack, unittest.TestCase):
@@ -1149,7 +1203,7 @@ class TestExitStack(TestBaseExitStack, unittest.TestCase):
]
@ -147,7 +150,7 @@ index cf651959803..6a17bc719eb 100644
redirect_stream = None
orig_stream = None
@@ -1206,19 +1257,19 @@ class TestRedirectStream:
@@ -1206,19 +1260,19 @@ class TestRedirectStream:
self.assertEqual(s, "Hello World!\n")
@ -170,7 +173,7 @@ index cf651959803..6a17bc719eb 100644
@support.requires_docstrings
def test_instance_docs(self):
@@ -1315,7 +1366,7 @@ class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase):
@@ -1315,7 +1369,7 @@ class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase):
)
@ -179,7 +182,7 @@ index cf651959803..6a17bc719eb 100644
def make_relative_path(self, *parts):
return os.path.join(
os.path.dirname(os.path.realpath(__file__)),
@@ -1331,6 +1382,7 @@ class TestChdir(unittest.TestCase):
@@ -1331,6 +1385,7 @@ class TestChdir(unittest.TestCase):
self.assertEqual(os.getcwd(), target)
self.assertEqual(os.getcwd(), old_cwd)
@ -187,7 +190,7 @@ index cf651959803..6a17bc719eb 100644
def test_reentrant(self):
old_cwd = os.getcwd()
target1 = self.make_relative_path('data')
@@ -1363,4 +1415,4 @@ class TestChdir(unittest.TestCase):
@@ -1363,4 +1418,4 @@ class TestChdir(unittest.TestCase):
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_contextlib.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_dict.py b/test/dynamo/cpython/3_13/test_dict.py
index 4729132c5a5..14f829c1715 100644
index 4c095464cbb..fcda6484ea6 100644
--- a/test/dynamo/cpython/3_13/test_dict.py
+++ b/test/dynamo/cpython/3_13/test_dict.py
@@ -1,3 +1,57 @@
@@ -1,3 +1,60 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_dict.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -60,7 +63,7 @@ index 4729132c5a5..14f829c1715 100644
import collections
import collections.abc
import gc
@@ -11,7 +65,7 @@ from test import support
@@ -11,7 +68,7 @@ from test import support
from test.support import import_helper, get_c_recursion_limit
@ -69,15 +72,48 @@ index 4729132c5a5..14f829c1715 100644
def test_invalid_keyword_arguments(self):
class Custom(dict):
@@ -265,6 +319,7 @@ class DictTest(unittest.TestCase):
@@ -265,39 +322,7 @@ class DictTest(unittest.TestCase):
self.assertRaises(ValueError, {}.update, [(1, 2, 3)])
- def test_update_shared_keys(self):
- class MyClass: pass
-
- # Subclass str to enable us to create an object during the
- # dict.update() call.
- class MyStr(str):
- def __hash__(self):
- return super().__hash__()
-
- def __eq__(self, other):
- # Create an object that shares the same PyDictKeysObject as
- # obj.__dict__.
- obj2 = MyClass()
- obj2.a = "a"
- obj2.b = "b"
- obj2.c = "c"
- return super().__eq__(other)
-
- obj = MyClass()
- obj.a = "a"
- obj.b = "b"
-
- x = {}
- x[MyStr("a")] = MyStr("a")
-
- # gh-132617: this previously raised "dict mutated during update" error
- x.update(obj.__dict__)
-
- self.assertEqual(x, {
- MyStr("a"): "a",
- "b": "b",
- })
-
+ @unittest.skip("test hangs")
def test_fromkeys(self):
self.assertEqual(dict.fromkeys('abc'), {'a':None, 'b':None, 'c':None})
d = {}
@@ -477,7 +532,7 @@ class DictTest(unittest.TestCase):
@@ -510,7 +535,7 @@ class DictTest(unittest.TestCase):
for copymode in -1, +1:
# -1: b has same structure as a
# +1: b is a.copy()
@ -86,7 +122,7 @@ index 4729132c5a5..14f829c1715 100644
size = 2**log2size
a = {}
b = {}
@@ -1006,18 +1061,6 @@ class DictTest(unittest.TestCase):
@@ -1039,18 +1064,6 @@ class DictTest(unittest.TestCase):
pass
self._tracked(MyDict())
@ -105,7 +141,7 @@ index 4729132c5a5..14f829c1715 100644
def make_shared_key_dict(self, n):
class C:
pass
@@ -1622,7 +1665,7 @@ class DictTest(unittest.TestCase):
@@ -1655,7 +1668,7 @@ class DictTest(unittest.TestCase):
self.assertGreaterEqual(eq_count, 1)
@ -114,7 +150,7 @@ index 4729132c5a5..14f829c1715 100644
# Test _PyDict_GetItem_KnownHash()
@support.cpython_only
@@ -1666,4 +1709,4 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
@@ -1699,4 +1712,4 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_dict.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_exception_variations.py b/test/dynamo/cpython/3_13/test_exception_variations.py
index a83a41d2975..be432089e3a 100644
index a83a41d2975..c2d6eb3a41a 100644
--- a/test/dynamo/cpython/3_13/test_exception_variations.py
+++ b/test/dynamo/cpython/3_13/test_exception_variations.py
@@ -1,7 +1,59 @@
@@ -1,7 +1,62 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exception_variations.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -53,17 +56,17 @@ index a83a41d2975..be432089e3a 100644
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
-class ExceptTestCases(unittest.TestCase):
+# ======= END DYNAMO PATCH =======
+
+import unittest
+
+class ExceptTestCases(__TestCase):
def test_try_except_else_finally(self):
hit_except = False
hit_else = False
@@ -294,282 +346,5 @@ class ExceptTestCases(unittest.TestCase):
@@ -294,282 +349,5 @@ class ExceptTestCases(unittest.TestCase):
self.assertTrue(hit_except)

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exception_variations.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -0,0 +1,152 @@
diff --git a/test/dynamo/cpython/3_13/test_exceptions.py b/test/dynamo/cpython/3_13/test_exceptions.py
index c91f6662948..0ded70db3c7 100644
--- a/test/dynamo/cpython/3_13/test_exceptions.py
+++ b/test/dynamo/cpython/3_13/test_exceptions.py
@@ -1,3 +1,59 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exceptions.py
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import (
+ run_tests,
+ xfailIfTorchDynamo,
+)
+
+__TestCase = CPythonTestCase
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
# Python test set -- part 5, built-in exceptions
import copy
@@ -45,7 +101,7 @@ class BrokenStrException(Exception):
# XXX This is not really enough, each *operation* should be tested!
-class ExceptionTests(unittest.TestCase):
+class ExceptionTests(__TestCase):
def raise_catch(self, exc, excname):
with self.subTest(exc=exc, excname=excname):
@@ -1844,7 +1900,7 @@ class ExceptionTests(unittest.TestCase):
self.assertIn(b'MemoryError', err)
-class NameErrorTests(unittest.TestCase):
+class NameErrorTests(__TestCase):
def test_name_error_has_name(self):
try:
bluch
@@ -1894,7 +1950,7 @@ class NameErrorTests(unittest.TestCase):
# Note: name suggestion tests live in `test_traceback`.
-class AttributeErrorTests(unittest.TestCase):
+class AttributeErrorTests(__TestCase):
def test_attributes(self):
# Setting 'attr' should not be a problem.
exc = AttributeError('Ouch!')
@@ -1937,7 +1993,7 @@ class AttributeErrorTests(unittest.TestCase):
# Note: name suggestion tests live in `test_traceback`.
-class ImportErrorTests(unittest.TestCase):
+class ImportErrorTests(__TestCase):
def test_attributes(self):
# Setting 'name' and 'path' should not be a problem.
@@ -2024,7 +2080,7 @@ def run_script(source):
_rc, _out, err = script_helper.assert_python_failure('-Wd', '-X', 'utf8', TESTFN)
return err.decode('utf-8').splitlines()
-class AssertionErrorTests(unittest.TestCase):
+class AssertionErrorTests(__TestCase):
def tearDown(self):
unlink(TESTFN)
@@ -2159,7 +2215,7 @@ class AssertionErrorTests(unittest.TestCase):
@support.force_not_colorized_test_class
-class SyntaxErrorTests(unittest.TestCase):
+class SyntaxErrorTests(__TestCase):
maxDiff = None
@force_not_colorized
@@ -2290,6 +2346,7 @@ class SyntaxErrorTests(unittest.TestCase):
err = run_script(b"\x89")
self.assertIn("SyntaxError: Non-UTF-8 code starting with '\\x89' in file", err[-1])
+
def test_string_source(self):
def try_compile(source):
with self.assertRaises(SyntaxError) as cm:
@@ -2405,7 +2462,7 @@ class SyntaxErrorTests(unittest.TestCase):
self.assertRaises(TypeError, SyntaxError, "bad bad", args)
-class TestInvalidExceptionMatcher(unittest.TestCase):
+class TestInvalidExceptionMatcher(__TestCase):
def test_except_star_invalid_exception_type(self):
with self.assertRaises(TypeError):
try:
@@ -2420,7 +2477,7 @@ class TestInvalidExceptionMatcher(unittest.TestCase):
pass
-class PEP626Tests(unittest.TestCase):
+class PEP626Tests(__TestCase):
def lineno_after_raise(self, f, *expected):
try:
@@ -2529,5 +2586,5 @@ class PEP626Tests(unittest.TestCase):
1/0
self.lineno_after_raise(after_with, 1, 1)
-if __name__ == '__main__':
- unittest.main()
+if __name__ == "__main__":
+ run_tests()
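The RedirectImportFinder block repeated in these headers is a stock importlib.abc.MetaPathFinder hook: when the copied test does `import test.test_math`, the finder imports the standalone sibling module, registers it in sys.modules under the original dotted name, and returns the standalone module's spec so the import machinery can finish. A self-contained sketch of the same pattern (the `legacy_json` alias is invented for the demo and simply redirects to the stdlib json module):
import importlib
import importlib.abc
import importlib.util
import sys

ALIASES = {"legacy_json": "json"}  # demo-only alias -> real module

class AliasFinder(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        real_name = ALIASES.get(fullname)
        if real_name is None:
            return None  # not ours; fall through to the normal finders
        # Import the real module and expose it under the aliased name,
        # mirroring what the Dynamo patch headers do for test.* modules.
        sys.modules[fullname] = importlib.import_module(real_name)
        return importlib.util.find_spec(real_name)

sys.meta_path.insert(0, AliasFinder())

import legacy_json  # resolved through AliasFinder to the stdlib json
print(legacy_json.dumps({"redirected": True}))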

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exceptions.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_float.py b/test/dynamo/cpython/3_13/test_float.py
index 97f951f1299..ce2c46777e0 100644
index 87af79eb446..9313a1a63d7 100644
--- a/test/dynamo/cpython/3_13/test_float.py
+++ b/test/dynamo/cpython/3_13/test_float.py
@@ -1,3 +1,54 @@
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_float.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -57,7 +60,7 @@ index 97f951f1299..ce2c46777e0 100644
import fractions
import operator
import os
@@ -8,11 +59,84 @@ import time
@@ -8,11 +62,84 @@ import time
import unittest
from test import support
@ -147,7 +150,7 @@ index 97f951f1299..ce2c46777e0 100644
from math import isinf, isnan, copysign, ldexp
import math
@@ -35,7 +159,7 @@ class FloatSubclass(float):
@@ -35,7 +162,7 @@ class FloatSubclass(float):
class OtherFloatSubclass(float):
pass
@ -156,7 +159,7 @@ index 97f951f1299..ce2c46777e0 100644
def test_float(self):
self.assertEqual(float(3.14), 3.14)
@@ -620,7 +744,7 @@ class GeneralFloatCases(unittest.TestCase):
@@ -620,7 +747,7 @@ class GeneralFloatCases(unittest.TestCase):
@unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__")
@ -165,7 +168,7 @@ index 97f951f1299..ce2c46777e0 100644
def test_getformat(self):
self.assertIn(float.__getformat__('double'),
['unknown', 'IEEE, big-endian', 'IEEE, little-endian'])
@@ -645,7 +769,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN))
@@ -645,7 +772,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN))
# is accident (today).
# let's also try to guarantee that -0.0 and 0.0 don't get confused.
@ -174,7 +177,7 @@ index 97f951f1299..ce2c46777e0 100644
@support.requires_IEEE_754
def test_double_specials_do_unpack(self):
@@ -670,7 +794,7 @@ class IEEEFormatTestCase(unittest.TestCase):
@@ -670,7 +797,7 @@ class IEEEFormatTestCase(unittest.TestCase):
self.assertEqual(struct.pack("<f", 3.40282356e38), struct.pack("<f", FLT_MAX))
self.assertEqual(struct.pack("<f", -3.40282356e38), struct.pack("<f", -FLT_MAX))
@ -183,7 +186,16 @@ index 97f951f1299..ce2c46777e0 100644
def test_format(self):
# these should be rewritten to use both format(x, spec) and
@@ -767,7 +891,7 @@ class FormatTestCase(unittest.TestCase):
@@ -724,8 +851,6 @@ class FormatTestCase(unittest.TestCase):
self.assertEqual(format(INF, 'F'), 'INF')
@support.requires_IEEE_754
- @unittest.skipUnless(sys.float_repr_style == 'short',
- "applies only when using short float repr style")
def test_format_testfile(self):
with open(format_testfile, encoding="utf-8") as testfile:
for line in testfile:
@@ -769,7 +894,7 @@ class FormatTestCase(unittest.TestCase):
self.assertEqual(format(-123.34, '00.10e'), '-1.2334000000e+02')
self.assertEqual(format(-123.34, '00.10g'), '-123.34')
@ -192,7 +204,7 @@ index 97f951f1299..ce2c46777e0 100644
def test_repr(self):
with open(os.path.join(os.path.split(__file__)[0],
'mathdata',
@@ -832,7 +956,29 @@ class ReprTestCase(unittest.TestCase):
@@ -834,7 +959,29 @@ class ReprTestCase(unittest.TestCase):
self.assertEqual(repr(float(negs)), str(float(negs)))
@support.requires_IEEE_754
@ -223,7 +235,7 @@ index 97f951f1299..ce2c46777e0 100644
def test_inf_nan(self):
self.assertRaises(OverflowError, round, INF)
@@ -955,7 +1101,7 @@ class RoundTestCase(unittest.TestCase, FloatsAreIdenticalMixin):
@@ -957,7 +1104,7 @@ class RoundTestCase(unittest.TestCase, FloatsAreIdenticalMixin):
# Beginning with Python 2.6 float has cross platform compatible
# ways to create and represent inf and nan
@ -232,7 +244,7 @@ index 97f951f1299..ce2c46777e0 100644
def test_inf_from_str(self):
self.assertTrue(isinf(float("inf")))
self.assertTrue(isinf(float("+inf")))
@@ -1056,12 +1202,35 @@ class InfNanTest(unittest.TestCase):
@@ -1058,12 +1205,35 @@ class InfNanTest(unittest.TestCase):
fromHex = float.fromhex
toHex = float.hex
@ -269,7 +281,7 @@ index 97f951f1299..ce2c46777e0 100644
def identical(self, x, y):
self.assertFloatsAreIdentical(x, y)
@@ -1500,5 +1669,5 @@ class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase):
@@ -1502,5 +1672,5 @@ class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase):
self.assertEqual(getattr(f, 'foo', 'none'), 'bar')
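The hex-float cases at the end rely on float.hex and float.fromhex being an exact, lossless round-trip, which is what makes them usable for bit-identical comparisons:
# float.hex prints the exact binary value, so the round-trip is lossless.
s = (0.1).hex()                         # '0x1.999999999999ap-4'
assert float.fromhex(s) == 0.1
assert float.fromhex("0x1.8p1") == 3.0  # 1.5 * 2**1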

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_float.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,8 +1,8 @@
diff --git a/test/dynamo/cpython/3_13/test_generator_stop.py b/test/dynamo/cpython/3_13/test_generator_stop.py
index bc235ceb00e..cb2a85255cb 100644
index bc235ceb00e..e3ff8d346a7 100644
--- a/test/dynamo/cpython/3_13/test_generator_stop.py
+++ b/test/dynamo/cpython/3_13/test_generator_stop.py
@@ -1,9 +1,60 @@
@@ -1,9 +1,63 @@
from __future__ import generator_stop
+# ======= BEGIN Dynamo patch =======
@ -11,6 +11,9 @@ index bc235ceb00e..cb2a85255cb 100644
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_generator_stop.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -64,7 +67,7 @@ index bc235ceb00e..cb2a85255cb 100644
def test_stopiteration_wrapping(self):
def f():
raise StopIteration
@@ -30,5 +81,5 @@ class TestPEP479(unittest.TestCase):
@@ -30,5 +84,5 @@ class TestPEP479(unittest.TestCase):
'were not properly set')

View File

@ -6,6 +6,9 @@ from __future__ import generator_stop
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_generator_stop.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_generators.py b/test/dynamo/cpython/3_13/test_generators.py
index e48d79d34f4..40a02d644a9 100644
index e48d79d34f4..a48da0914b9 100644
--- a/test/dynamo/cpython/3_13/test_generators.py
+++ b/test/dynamo/cpython/3_13/test_generators.py
@@ -1,3 +1,53 @@
@@ -1,3 +1,56 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_generators.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -56,7 +59,7 @@ index e48d79d34f4..40a02d644a9 100644
import copy
import gc
import pickle
@@ -22,7 +72,7 @@ except ImportError:
@@ -22,7 +75,7 @@ except ImportError:
@unittest.skipUnless(_testcapi is not None and
hasattr(_testcapi, "raise_SIGINT_then_send_None"),
"needs _testcapi.raise_SIGINT_then_send_None")
@ -65,7 +68,7 @@ index e48d79d34f4..40a02d644a9 100644
def generator1(self):
return (yield from self.generator2())
@@ -46,7 +96,7 @@ class SignalAndYieldFromTest(unittest.TestCase):
@@ -46,7 +99,7 @@ class SignalAndYieldFromTest(unittest.TestCase):
self.assertEqual(exc.value, "PASSED")
@ -74,7 +77,7 @@ index e48d79d34f4..40a02d644a9 100644
def test_frame_resurrect(self):
# A generator frame can be resurrected by a generator's finalization.
@@ -113,7 +163,7 @@ class FinalizationTest(unittest.TestCase):
@@ -113,7 +166,7 @@ class FinalizationTest(unittest.TestCase):
self.assertEqual(cm.exception.value, 2)
@ -83,7 +86,7 @@ index e48d79d34f4..40a02d644a9 100644
def test_name(self):
def func():
@@ -246,8 +296,31 @@ class GeneratorTest(unittest.TestCase):
@@ -246,8 +299,31 @@ class GeneratorTest(unittest.TestCase):
#This should not raise
loop()
@ -116,7 +119,7 @@ index e48d79d34f4..40a02d644a9 100644
iterables = [
range(0),
range(20),
@@ -319,7 +392,7 @@ class ModifyUnderlyingIterableTest(unittest.TestCase):
@@ -319,7 +395,7 @@ class ModifyUnderlyingIterableTest(unittest.TestCase):
self.process_tests(get_generator_genfunc)
@ -125,7 +128,7 @@ index e48d79d34f4..40a02d644a9 100644
# Tests for the issue #23353: check that the currently handled exception
# is correctly saved/restored in PyEval_EvalFrameEx().
@@ -528,7 +601,7 @@ class ExceptionTest(unittest.TestCase):
@@ -528,7 +604,7 @@ class ExceptionTest(unittest.TestCase):
self.assertEqual(cm.exception.value.value, 2)
@ -134,7 +137,7 @@ index e48d79d34f4..40a02d644a9 100644
def test_close_no_return_value(self):
def f():
@@ -630,90 +703,7 @@ class GeneratorCloseTest(unittest.TestCase):
@@ -630,90 +706,7 @@ class GeneratorCloseTest(unittest.TestCase):
self.assertIsNone(f_wr())
@ -226,7 +229,7 @@ index e48d79d34f4..40a02d644a9 100644
def test_exception_context_with_yield(self):
def f():
@@ -812,7 +802,7 @@ class GeneratorThrowTest(unittest.TestCase):
@@ -812,7 +805,7 @@ class GeneratorThrowTest(unittest.TestCase):
gen.throw(ValueError)
@ -235,7 +238,7 @@ index e48d79d34f4..40a02d644a9 100644
def check_stack_names(self, frame, expected):
names = []
@@ -861,7 +851,7 @@ class GeneratorStackTraceTest(unittest.TestCase):
@@ -861,7 +854,7 @@ class GeneratorStackTraceTest(unittest.TestCase):
self.check_yield_from_example(call_throw)
@ -244,7 +247,7 @@ index e48d79d34f4..40a02d644a9 100644
def test_generator_gi_yieldfrom(self):
def a():
self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING)
@@ -2752,21 +2742,27 @@ test_generators just happened to be the test that drew these out.
@@ -2752,21 +2745,27 @@ test_generators just happened to be the test that drew these out.
"""

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_generators.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_int.py b/test/dynamo/cpython/3_13/test_int.py
index 48825f46911..4ab200372ea 100644
index 48825f46911..ce115cd784c 100644
--- a/test/dynamo/cpython/3_13/test_int.py
+++ b/test/dynamo/cpython/3_13/test_int.py
@@ -1,13 +1,137 @@
@@ -1,13 +1,140 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_int.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -144,7 +147,7 @@ index 48825f46911..4ab200372ea 100644
try:
import _pylong
@@ -38,7 +162,7 @@ L = [
@@ -38,7 +165,7 @@ L = [
class IntSubclass(int):
pass
@ -153,7 +156,7 @@ index 48825f46911..4ab200372ea 100644
def test_basic(self):
self.assertEqual(int(314), 314)
@@ -566,6 +690,7 @@ class IntTestCases(unittest.TestCase):
@@ -566,6 +693,7 @@ class IntTestCases(unittest.TestCase):
self.assertEqual(n, 1)
self.assertIs(type(n), IntSubclass)
@ -161,7 +164,7 @@ index 48825f46911..4ab200372ea 100644
def test_error_message(self):
def check(s, base=None):
with self.assertRaises(ValueError,
@@ -607,7 +732,7 @@ class IntTestCases(unittest.TestCase):
@@ -607,7 +735,7 @@ class IntTestCases(unittest.TestCase):
self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807)
@ -170,7 +173,7 @@ index 48825f46911..4ab200372ea 100644
int_class = int # Override this in subclasses to reuse the suite.
@@ -818,7 +943,7 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests):
@@ -818,7 +946,7 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests):
int_class = IntSubclass
@ -179,7 +182,7 @@ index 48825f46911..4ab200372ea 100644
# Tests of the functions in _pylong.py. Those get used when the
# number of digits in the input values are large enough.
@@ -922,4 +1047,4 @@ class PyLongModuleTests(unittest.TestCase):
@@ -922,4 +1050,4 @@ class PyLongModuleTests(unittest.TestCase):
bits <<= 1
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_int.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_int_literal.py b/test/dynamo/cpython/3_13/test_int_literal.py
index bf725710d55..831d03666fb 100644
index bf725710d55..311b8713a36 100644
--- a/test/dynamo/cpython/3_13/test_int_literal.py
+++ b/test/dynamo/cpython/3_13/test_int_literal.py
@@ -1,3 +1,54 @@
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_int_literal.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -57,7 +60,7 @@ index bf725710d55..831d03666fb 100644
"""Test correct treatment of hex/oct constants.
This is complex because of changes due to PEP 237.
@@ -5,7 +56,7 @@ This is complex because of changes due to PEP 237.
@@ -5,7 +59,7 @@ This is complex because of changes due to PEP 237.
import unittest
@ -66,7 +69,7 @@ index bf725710d55..831d03666fb 100644
def test_hex_baseline(self):
# A few upper/lowercase tests
@@ -140,4 +191,4 @@ class TestHexOctBin(unittest.TestCase):
@@ -140,4 +194,4 @@ class TestHexOctBin(unittest.TestCase):
self.assertEqual(-0b1111111111111111111111111111111111111111111111111111111111111111, -18446744073709551615)
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_int_literal.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_iter.py b/test/dynamo/cpython/3_13/test_iter.py
index 1b9f3cf7624..d0c68f4314c 100644
index 1b9f3cf7624..d2fc26ddc72 100644
--- a/test/dynamo/cpython/3_13/test_iter.py
+++ b/test/dynamo/cpython/3_13/test_iter.py
@@ -1,3 +1,57 @@
@@ -1,3 +1,60 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_iter.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -60,7 +63,7 @@ index 1b9f3cf7624..d0c68f4314c 100644
# Test iterators.
import sys
@@ -104,7 +158,7 @@ class EmptyIterClass:
@@ -104,7 +161,7 @@ class EmptyIterClass:
# Main test suite
@ -69,7 +72,7 @@ index 1b9f3cf7624..d0c68f4314c 100644
# Helper to check that an iterator returns a given sequence
def check_iterator(self, it, seq, pickle=True):
@@ -635,6 +689,7 @@ class TestCase(unittest.TestCase):
@@ -635,6 +692,7 @@ class TestCase(unittest.TestCase):
pass
# Test zip()'s use of iterators.
@ -77,7 +80,7 @@ index 1b9f3cf7624..d0c68f4314c 100644
def test_builtin_zip(self):
self.assertEqual(list(zip()), [])
self.assertEqual(list(zip(*[])), [])
@@ -1187,4 +1242,4 @@ class TestCase(unittest.TestCase):
@@ -1187,4 +1245,4 @@ class TestCase(unittest.TestCase):
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_iter.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_list.py b/test/dynamo/cpython/3_13/test_list.py
index 23ef902aa0b..6e4c6d99d16 100644
index 23ef902aa0b..48e94062a45 100644
--- a/test/dynamo/cpython/3_13/test_list.py
+++ b/test/dynamo/cpython/3_13/test_list.py
@@ -1,6 +1,57 @@
@@ -1,6 +1,60 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_list.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -61,14 +64,14 @@ index 23ef902aa0b..6e4c6d99d16 100644
from test.support import cpython_only
from test.support.script_helper import assert_python_ok
import pickle
@@ -35,8 +86,6 @@ class ListTest(list_tests.CommonTest):
@@ -35,8 +89,6 @@ class ListTest(list_tests.CommonTest):
# Note: This test is expected to SEGV under Cygwin 1.3.12 or
# earlier due to a newlib bug. See the following mailing list
# thread for the details:
self.assertRaises(MemoryError, list, range(sys.maxsize // 2))
# This code used to segfault in Py2.4a3
@@ -324,6 +373,7 @@ class ListTest(list_tests.CommonTest):
@@ -324,6 +376,7 @@ class ListTest(list_tests.CommonTest):
a.append(4)
self.assertEqual(list(it), [])
@ -76,7 +79,7 @@ index 23ef902aa0b..6e4c6d99d16 100644
def test_deopt_from_append_list(self):
# gh-132011: it used to crash, because
# of `CALL_LIST_APPEND` specialization failure.
@@ -345,4 +395,4 @@ class ListTest(list_tests.CommonTest):
@@ -345,4 +398,4 @@ class ListTest(list_tests.CommonTest):
self.assertEqual(rc, 0)
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_list.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_math.py b/test/dynamo/cpython/3_13/test_math.py
index 5ee3055c871..51773d5f478 100644
index 5ee3055c871..6889f53b98f 100644
--- a/test/dynamo/cpython/3_13/test_math.py
+++ b/test/dynamo/cpython/3_13/test_math.py
@@ -1,3 +1,58 @@
@@ -1,3 +1,61 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_math.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -61,7 +64,7 @@ index 5ee3055c871..51773d5f478 100644
# Python test set -- math module
# XXXX Should not do tests around zero only
@@ -242,7 +297,7 @@ class BadDescr:
@@ -242,7 +300,7 @@ class BadDescr:
def __get__(self, obj, objtype=None):
raise ValueError
@ -70,7 +73,7 @@ index 5ee3055c871..51773d5f478 100644
def ftest(self, name, got, expected, ulp_tol=5, abs_tol=0.0):
"""Compare arguments expected and got, as floats, if either
@@ -533,6 +588,7 @@ class MathTests(unittest.TestCase):
@@ -533,6 +591,7 @@ class MathTests(unittest.TestCase):
self.ftest('fabs(0)', math.fabs(0), 0)
self.ftest('fabs(1)', math.fabs(1), 1)
@ -78,7 +81,7 @@ index 5ee3055c871..51773d5f478 100644
def testFactorial(self):
self.assertEqual(math.factorial(0), 1)
total = 1
@@ -1072,6 +1128,7 @@ class MathTests(unittest.TestCase):
@@ -1072,6 +1131,7 @@ class MathTests(unittest.TestCase):
with self.assertRaises(ValueError):
math.dist([1, 2], [3, 4, 5])
@ -86,7 +89,7 @@ index 5ee3055c871..51773d5f478 100644
def testIsqrt(self):
# Test a variety of inputs, large and small.
test_values = (
@@ -1202,12 +1259,6 @@ class MathTests(unittest.TestCase):
@@ -1202,12 +1262,6 @@ class MathTests(unittest.TestCase):
self.assertEqual(math.ldexp(NINF, n), NINF)
self.assertTrue(math.isnan(math.ldexp(NAN, n)))
@ -99,7 +102,7 @@ index 5ee3055c871..51773d5f478 100644
def testLog(self):
self.assertRaises(TypeError, math.log)
self.assertRaises(TypeError, math.log, 1, 2, 3)
@@ -1233,6 +1284,7 @@ class MathTests(unittest.TestCase):
@@ -1233,6 +1287,7 @@ class MathTests(unittest.TestCase):
self.assertRaises(ValueError, math.log1p, -1)
self.assertEqual(math.log1p(INF), INF)
@ -107,7 +110,7 @@ index 5ee3055c871..51773d5f478 100644
@requires_IEEE_754
def testLog2(self):
self.assertRaises(TypeError, math.log2)
@@ -1251,6 +1303,7 @@ class MathTests(unittest.TestCase):
@@ -1251,6 +1306,7 @@ class MathTests(unittest.TestCase):
self.assertRaises(ValueError, math.log2, NINF)
self.assertTrue(math.isnan(math.log2(NAN)))
@ -115,7 +118,7 @@ index 5ee3055c871..51773d5f478 100644
@requires_IEEE_754
# log2() is not accurate enough on Mac OS X Tiger (10.4)
@support.requires_mac_ver(10, 5)
@@ -1332,7 +1385,7 @@ class MathTests(unittest.TestCase):
@@ -1332,7 +1388,7 @@ class MathTests(unittest.TestCase):
with self.assertRaises(RuntimeError):
sumprod(raise_after(5), range(10))
@ -124,7 +127,7 @@ index 5ee3055c871..51773d5f478 100644
self.assertEqual(sumprod(BasicIterClass(1), [1]), 0)
self.assertEqual(sumprod([1], BasicIterClass(1)), 0)
@@ -2252,6 +2305,7 @@ class MathTests(unittest.TestCase):
@@ -2252,6 +2308,7 @@ class MathTests(unittest.TestCase):
self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
decimal.Decimal)
@ -132,7 +135,7 @@ index 5ee3055c871..51773d5f478 100644
def testPerm(self):
perm = math.perm
factorial = math.factorial
@@ -2316,6 +2370,7 @@ class MathTests(unittest.TestCase):
@@ -2316,6 +2373,7 @@ class MathTests(unittest.TestCase):
self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int)
self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int)
@ -140,7 +143,7 @@ index 5ee3055c871..51773d5f478 100644
def testComb(self):
comb = math.comb
factorial = math.factorial
@@ -2446,6 +2501,7 @@ class MathTests(unittest.TestCase):
@@ -2446,6 +2504,7 @@ class MathTests(unittest.TestCase):
math.nextafter(1.0, INF, steps=-1)
@ -148,7 +151,7 @@ index 5ee3055c871..51773d5f478 100644
@requires_IEEE_754
def test_ulp(self):
self.assertEqual(math.ulp(1.0), sys.float_info.epsilon)
@@ -2508,7 +2564,7 @@ class MathTests(unittest.TestCase):
@@ -2508,7 +2567,7 @@ class MathTests(unittest.TestCase):
self.assertEqual(math.copysign(1.0, x), math.copysign(1.0, y))
@ -157,7 +160,7 @@ index 5ee3055c871..51773d5f478 100644
isclose = math.isclose # subclasses should override this
def assertIsClose(self, a, b, *args, **kwargs):
@@ -2631,7 +2687,7 @@ class IsCloseTests(unittest.TestCase):
@@ -2631,7 +2690,7 @@ class IsCloseTests(unittest.TestCase):
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
@ -166,7 +169,7 @@ index 5ee3055c871..51773d5f478 100644
""" Tests for math.fma. """
def test_fma_nan_results(self):
@@ -2719,8 +2775,7 @@ class FMATests(unittest.TestCase):
@@ -2719,8 +2778,7 @@ class FMATests(unittest.TestCase):
# properly: it doesn't use the right sign when the result is zero.
@unittest.skipIf(
sys.platform.startswith(("freebsd", "wasi", "netbsd", "emscripten"))
@ -176,7 +179,7 @@ index 5ee3055c871..51773d5f478 100644
f"this platform doesn't implement IEE 754-2008 properly")
def test_fma_zero_result(self):
nonnegative_finites = [0.0, 1e-300, 2.3, 1e300]
@@ -2879,10 +2934,5 @@ class FMATests(unittest.TestCase):
@@ -2879,10 +2937,5 @@ class FMATests(unittest.TestCase):
)
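The FMA cases, including the platform skip being trimmed above, all come down to math.fma computing x*y + z with a single rounding. The difference from the naive two-rounding form shows up right at the edge of double precision (math.fma is new in Python 3.13):
import math

eps = 2.0 ** -52
# (1 + eps) * (1 - eps) rounds to exactly 1.0, so the naive form loses the
# -eps**2 term; the fused form keeps it through its single rounding.
assert (1.0 + eps) * (1.0 - eps) - 1.0 == 0.0
assert math.fma(1.0 + eps, 1.0 - eps, -1.0) == -(2.0 ** -104)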

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_math.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_ordered_dict.py b/test/dynamo/cpython/3_13/test_ordered_dict.py
index a9b6a84996e..b77eff70414 100644
index a9b6a84996e..d9fce736a10 100644
--- a/test/dynamo/cpython/3_13/test_ordered_dict.py
+++ b/test/dynamo/cpython/3_13/test_ordered_dict.py
@@ -1,3 +1,57 @@
@@ -1,3 +1,60 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_ordered_dict.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -60,7 +63,7 @@ index a9b6a84996e..b77eff70414 100644
import builtins
import contextlib
import copy
@@ -760,7 +814,7 @@ class _TriggerSideEffectOnEqual:
@@ -760,7 +817,7 @@ class _TriggerSideEffectOnEqual:
def side_effect(self):
raise NotImplementedError
@ -69,7 +72,7 @@ index a9b6a84996e..b77eff70414 100644
module = py_coll
OrderedDict = py_coll.OrderedDict
@@ -781,7 +835,7 @@ class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase):
@@ -781,7 +838,7 @@ class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase):
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
@ -78,7 +81,7 @@ index a9b6a84996e..b77eff70414 100644
"""Builtin dict preserves insertion order.
Reuse some of tests in OrderedDict selectively.
@@ -800,6 +854,7 @@ for method in (
@@ -800,6 +857,7 @@ for method in (
del method
@ -86,7 +89,7 @@ index a9b6a84996e..b77eff70414 100644
class CPythonOrderedDictSideEffects:
def check_runtime_error_issue119004(self, dict1, dict2):
@@ -878,7 +933,7 @@ class CPythonOrderedDictSideEffects:
@@ -878,7 +936,7 @@ class CPythonOrderedDictSideEffects:
@unittest.skipUnless(c_coll, 'requires the C version of the collections module')
class CPythonOrderedDictTests(OrderedDictTests,
CPythonOrderedDictSideEffects,
@ -95,7 +98,7 @@ index a9b6a84996e..b77eff70414 100644
module = c_coll
OrderedDict = c_coll.OrderedDict
@@ -986,7 +1041,7 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests):
@@ -986,7 +1044,7 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests):
pass
@ -104,7 +107,7 @@ index a9b6a84996e..b77eff70414 100644
module = py_coll
class OrderedDict(py_coll.OrderedDict):
@@ -995,7 +1050,7 @@ class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
@@ -995,7 +1053,7 @@ class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
@unittest.skipUnless(c_coll, 'requires the C version of the collections module')
@ -113,7 +116,7 @@ index a9b6a84996e..b77eff70414 100644
module = c_coll
class OrderedDict(c_coll.OrderedDict):
@@ -1008,6 +1063,7 @@ class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
@@ -1008,6 +1066,7 @@ class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
@classmethod
def setUpClass(cls):
cls.type2test = py_coll.OrderedDict
@ -121,7 +124,7 @@ index a9b6a84996e..b77eff70414 100644
def test_popitem(self):
d = self._empty_mapping()
@@ -1020,6 +1076,7 @@ class CPythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
@@ -1020,6 +1079,7 @@ class CPythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
@classmethod
def setUpClass(cls):
cls.type2test = c_coll.OrderedDict
@ -129,7 +132,7 @@ index a9b6a84996e..b77eff70414 100644
def test_popitem(self):
d = self._empty_mapping()
@@ -1033,6 +1090,7 @@ class PurePythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
@@ -1033,6 +1093,7 @@ class PurePythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
class MyOrderedDict(py_coll.OrderedDict):
pass
cls.type2test = MyOrderedDict
@ -137,7 +140,7 @@ index a9b6a84996e..b77eff70414 100644
def test_popitem(self):
d = self._empty_mapping()
@@ -1047,6 +1105,7 @@ class CPythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
@@ -1047,6 +1108,7 @@ class CPythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
class MyOrderedDict(c_coll.OrderedDict):
pass
cls.type2test = MyOrderedDict
@ -145,7 +148,7 @@ index a9b6a84996e..b77eff70414 100644
def test_popitem(self):
d = self._empty_mapping()
@@ -1120,21 +1179,22 @@ class SimpleLRUCacheTests:
@@ -1120,21 +1182,22 @@ class SimpleLRUCacheTests:
self.assertEqual(list(c), [1, 3, 2])

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_ordered_dict.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_raise.py b/test/dynamo/cpython/3_13/test_raise.py
index 6d26a61bee4..8a52b9bfc82 100644
index 6d26a61bee4..042d1ae3d7c 100644
--- a/test/dynamo/cpython/3_13/test_raise.py
+++ b/test/dynamo/cpython/3_13/test_raise.py
@@ -1,3 +1,55 @@
@@ -1,3 +1,58 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_raise.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -58,7 +61,7 @@ index 6d26a61bee4..8a52b9bfc82 100644
# Copyright 2007 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.
@@ -23,7 +75,7 @@ class Context:
@@ -23,7 +78,7 @@ class Context:
return True
@ -67,7 +70,7 @@ index 6d26a61bee4..8a52b9bfc82 100644
def test_invalid_reraise(self):
try:
raise
@@ -148,7 +200,7 @@ class TestRaise(unittest.TestCase):
@@ -148,7 +203,7 @@ class TestRaise(unittest.TestCase):
@ -76,7 +79,7 @@ index 6d26a61bee4..8a52b9bfc82 100644
def testCauseSyntax(self):
try:
@@ -221,7 +273,7 @@ class TestCause(unittest.TestCase):
@@ -221,7 +276,7 @@ class TestCause(unittest.TestCase):
self.fail("No exception raised")
@ -85,7 +88,7 @@ index 6d26a61bee4..8a52b9bfc82 100644
def test_sets_traceback(self):
try:
@@ -242,7 +294,7 @@ class TestTraceback(unittest.TestCase):
@@ -242,7 +297,7 @@ class TestTraceback(unittest.TestCase):
self.fail("No exception raised")
@ -94,7 +97,7 @@ index 6d26a61bee4..8a52b9bfc82 100644
def raiser(self):
raise ValueError
@@ -308,7 +360,7 @@ class TestTracebackType(unittest.TestCase):
@@ -308,7 +363,7 @@ class TestTracebackType(unittest.TestCase):
types.TracebackType(other_tb, frame, 1, "nuh-uh")
@ -103,7 +106,7 @@ index 6d26a61bee4..8a52b9bfc82 100644
def test_instance_context_instance_raise(self):
context = IndexError()
try:
@@ -498,7 +550,7 @@ class TestContext(unittest.TestCase):
@@ -498,7 +553,7 @@ class TestContext(unittest.TestCase):
self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type)
@ -112,7 +115,7 @@ index 6d26a61bee4..8a52b9bfc82 100644
def test_tuples(self):
try:
raise (IndexError, KeyError) # This should be a tuple!
@@ -517,4 +569,4 @@ class TestRemovedFunctionality(unittest.TestCase):
@@ -517,4 +572,4 @@ class TestRemovedFunctionality(unittest.TestCase):
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_raise.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_set.py b/test/dynamo/cpython/3_13/test_set.py
index d9102eb98a5..0b8e99a04c4 100644
index d9102eb98a5..3543d60751e 100644
--- a/test/dynamo/cpython/3_13/test_set.py
+++ b/test/dynamo/cpython/3_13/test_set.py
@@ -1,3 +1,53 @@
@@ -1,3 +1,56 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_set.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -56,7 +59,7 @@ index d9102eb98a5..0b8e99a04c4 100644
import unittest
from test import support
from test.support import warnings_helper
@@ -38,7 +88,7 @@ class HashCountingInt(int):
@@ -38,7 +91,7 @@ class HashCountingInt(int):
self.hash_count += 1
return int.__hash__(self)
@ -65,7 +68,7 @@ index d9102eb98a5..0b8e99a04c4 100644
# Tests common to both set and frozenset
def setUp(self):
@@ -47,6 +97,7 @@ class TestJointOps:
@@ -47,6 +100,7 @@ class TestJointOps:
self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
self.s = self.thetype(word)
self.d = dict.fromkeys(word)
@ -73,7 +76,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_new_or_init(self):
self.assertRaises(TypeError, self.thetype, [], 2)
@@ -355,7 +406,7 @@ class TestJointOps:
@@ -355,7 +409,7 @@ class TestJointOps:
def test_free_after_iterating(self):
support.check_free_after_iterating(self, iter, self.thetype)
@ -82,7 +85,7 @@ index d9102eb98a5..0b8e99a04c4 100644
thetype = set
basetype = set
@@ -675,7 +726,7 @@ class TestSetSubclass(TestSet):
@@ -675,7 +729,7 @@ class TestSetSubclass(TestSet):
subclass_with_new([1, 2], newarg=3)
@ -91,7 +94,7 @@ index d9102eb98a5..0b8e99a04c4 100644
thetype = frozenset
basetype = frozenset
@@ -811,10 +862,17 @@ class TestFrozenSetSubclass(TestFrozenSet):
@@ -811,10 +865,17 @@ class TestFrozenSetSubclass(TestFrozenSet):
class SetSubclassWithSlots(set):
__slots__ = ('x', 'y', '__dict__')
@ -112,7 +115,7 @@ index d9102eb98a5..0b8e99a04c4 100644
class FrozenSetSubclassWithSlots(frozenset):
__slots__ = ('x', 'y', '__dict__')
@@ -828,7 +886,7 @@ empty_set = set()
@@ -828,7 +889,7 @@ empty_set = set()
#==============================================================================
@ -121,7 +124,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_repr(self):
if self.repr is not None:
@@ -934,7 +992,7 @@ class TestBasicOps:
@@ -934,7 +995,7 @@ class TestBasicOps:
#------------------------------------------------------------------------------
@ -130,7 +133,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def setUp(self):
self.case = "empty set"
self.values = []
@@ -942,10 +1000,11 @@ class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
@@ -942,10 +1003,11 @@ class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
self.dup = set(self.values)
self.length = 0
self.repr = "set()"
@ -143,7 +146,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def setUp(self):
self.case = "unit set (number)"
self.values = [3]
@@ -953,6 +1012,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
@@ -953,6 +1015,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
self.dup = set(self.values)
self.length = 1
self.repr = "{3}"
@ -151,7 +154,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_in(self):
self.assertIn(3, self.set)
@@ -962,7 +1022,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
@@ -962,7 +1025,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
#------------------------------------------------------------------------------
@ -160,7 +163,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def setUp(self):
self.case = "unit set (tuple)"
self.values = [(0, "zero")]
@@ -970,6 +1030,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
@@ -970,6 +1033,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
self.dup = set(self.values)
self.length = 1
self.repr = "{(0, 'zero')}"
@ -168,7 +171,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_in(self):
self.assertIn((0, "zero"), self.set)
@@ -979,7 +1040,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
@@ -979,7 +1043,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
#------------------------------------------------------------------------------
@ -177,7 +180,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def setUp(self):
self.case = "triple set"
self.values = [0, "zero", operator.add]
@@ -987,36 +1048,39 @@ class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
@@ -987,36 +1051,39 @@ class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
self.dup = set(self.values)
self.length = 3
self.repr = None
@ -220,7 +223,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def setUp(self):
self.enterContext(warnings_helper.check_warnings())
warnings.simplefilter('ignore', BytesWarning)
@@ -1025,6 +1089,7 @@ class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
@@ -1025,6 +1092,7 @@ class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
self.set = set(self.values)
self.dup = set(self.values)
self.length = 4
@ -228,7 +231,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_repr(self):
self.check_repr_against_values()
@@ -1038,7 +1103,7 @@ def baditer():
@@ -1038,7 +1106,7 @@ def baditer():
def gooditer():
yield True
@ -237,7 +240,7 @@ index d9102eb98a5..0b8e99a04c4 100644
"""SF 628246: Set constructor should not trap iterator TypeErrors"""
def test_instanceWithException(self):
@@ -1065,7 +1130,7 @@ class TestExceptionPropagation(unittest.TestCase):
@@ -1065,7 +1133,7 @@ class TestExceptionPropagation(unittest.TestCase):
#==============================================================================
@ -246,7 +249,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_constructor(self):
inner = frozenset([1])
outer = set([inner])
@@ -1078,9 +1143,10 @@ class TestSetOfSets(unittest.TestCase):
@@ -1078,9 +1146,10 @@ class TestSetOfSets(unittest.TestCase):
#==============================================================================
@ -258,7 +261,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_eq(self): # SF bug 643115
self.assertEqual(self.set, set({2:1,4:3,6:5}))
@@ -1151,9 +1217,10 @@ class TestBinaryOps(unittest.TestCase):
@@ -1151,9 +1220,10 @@ class TestBinaryOps(unittest.TestCase):
#==============================================================================
@ -270,7 +273,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_union_subset(self):
self.set |= set([2])
@@ -1237,10 +1304,11 @@ class TestUpdateOps(unittest.TestCase):
@@ -1237,10 +1307,11 @@ class TestUpdateOps(unittest.TestCase):
#==============================================================================
@ -283,7 +286,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_add_present(self):
self.set.add("c")
@@ -1311,7 +1379,7 @@ class TestMutate(unittest.TestCase):
@@ -1311,7 +1382,7 @@ class TestMutate(unittest.TestCase):
#==============================================================================
@ -292,7 +295,7 @@ index d9102eb98a5..0b8e99a04c4 100644
case2method = {"<=": "issubset",
">=": "issuperset",
@@ -1334,22 +1402,22 @@ class TestSubsets:
@@ -1334,22 +1405,22 @@ class TestSubsets:
result = eval("x" + case + "y", locals())
self.assertEqual(result, expected)
# Test the "friendly" method-name spelling, if one exists.
@ -321,7 +324,7 @@ index d9102eb98a5..0b8e99a04c4 100644
left = set()
right = set()
name = "both empty"
@@ -1357,7 +1425,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
@@ -1357,7 +1428,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
@ -330,7 +333,7 @@ index d9102eb98a5..0b8e99a04c4 100644
left = set([1, 2])
right = set([1, 2])
name = "equal pair"
@@ -1365,7 +1433,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
@@ -1365,7 +1436,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
@ -339,7 +342,7 @@ index d9102eb98a5..0b8e99a04c4 100644
left = set()
right = set([1, 2])
name = "one empty, one non-empty"
@@ -1373,7 +1441,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
@@ -1373,7 +1444,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
@ -348,7 +351,7 @@ index d9102eb98a5..0b8e99a04c4 100644
left = set([1])
right = set([1, 2])
name = "one a non-empty proper subset of other"
@@ -1381,7 +1449,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase):
@@ -1381,7 +1452,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
@ -357,7 +360,7 @@ index d9102eb98a5..0b8e99a04c4 100644
left = set([1])
right = set([2])
name = "neither empty, neither contains"
@@ -1389,7 +1457,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
@@ -1389,7 +1460,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
#==============================================================================
@ -366,7 +369,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_eq_ne(self):
# Unlike the others, this is testing that == and != *are* allowed.
@@ -1505,47 +1573,52 @@ class TestOnlySetsInBinaryOps:
@@ -1505,47 +1576,52 @@ class TestOnlySetsInBinaryOps:
#------------------------------------------------------------------------------
@ -425,7 +428,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def setUp(self):
def gen():
for i in range(0, 10, 2):
@@ -1553,10 +1626,11 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
@@ -1553,10 +1629,11 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
self.set = set((1, 2, 3))
self.other = gen()
self.otherIsIterable = True
@ -438,7 +441,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_copy(self):
dup = self.set.copy()
@@ -1577,40 +1651,46 @@ class TestCopying:
@@ -1577,40 +1654,46 @@ class TestCopying:
#------------------------------------------------------------------------------
@ -491,7 +494,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_binopsVsSubsets(self):
a, b = self.a, self.b
@@ -1727,7 +1807,7 @@ def L(seqn):
@@ -1727,7 +1810,7 @@ def L(seqn):
'Test multiple tiers of iterators'
return chain(map(lambda x:x, R(Ig(G(seqn)))))
@ -500,7 +503,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_constructor(self):
for cons in (set, frozenset):
@@ -1785,7 +1865,7 @@ class bad_dict_clear:
@@ -1785,7 +1868,7 @@ class bad_dict_clear:
def __hash__(self):
return 0
@ -509,7 +512,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_8420_set_merge(self):
# This used to segfault
global be_bad, set2, dict2
@@ -1826,7 +1906,7 @@ class TestWeirdBugs(unittest.TestCase):
@@ -1826,7 +1909,7 @@ class TestWeirdBugs(unittest.TestCase):
s.update(other)
@ -518,7 +521,7 @@ index d9102eb98a5..0b8e99a04c4 100644
"""Regression test for bpo-46615"""
constructor1 = None
@@ -1862,7 +1942,7 @@ class TestOperationsMutating:
@@ -1862,7 +1945,7 @@ class TestOperationsMutating:
self.assertIn("changed size during iteration", str(e))
@ -527,7 +530,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_eq_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a == b)
@@ -1933,24 +2013,24 @@ class TestBinaryOpsMutating(TestOperationsMutating):
@@ -1933,24 +2016,24 @@ class TestBinaryOpsMutating(TestOperationsMutating):
self.check_set_op_does_not_crash(f3)
@ -557,7 +560,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_issubset_with_mutation(self):
self.check_set_op_does_not_crash(set.issubset)
@@ -1986,27 +2066,27 @@ class TestMethodsMutating(TestOperationsMutating):
@@ -1986,27 +2069,27 @@ class TestMethodsMutating(TestOperationsMutating):
self.check_set_op_does_not_crash(set.update)
@ -591,7 +594,7 @@ index d9102eb98a5..0b8e99a04c4 100644
constructor1 = set
constructor2 = list
@@ -2068,7 +2148,7 @@ def faces(G):
@@ -2068,7 +2151,7 @@ def faces(G):
return f
@ -600,7 +603,7 @@ index d9102eb98a5..0b8e99a04c4 100644
def test_cube(self):
@@ -2118,4 +2198,4 @@ class TestGraphs(unittest.TestCase):
@@ -2118,4 +2201,4 @@ class TestGraphs(unittest.TestCase):
#==============================================================================
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_set.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_sort.py b/test/dynamo/cpython/3_13/test_sort.py
index 2a7cfb7affa..d661ae544b9 100644
index 2a7cfb7affa..58b9b796362 100644
--- a/test/dynamo/cpython/3_13/test_sort.py
+++ b/test/dynamo/cpython/3_13/test_sort.py
@@ -1,3 +1,54 @@
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sort.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -57,7 +60,7 @@ index 2a7cfb7affa..d661ae544b9 100644
from test import support
import random
import unittest
@@ -39,7 +90,7 @@ def check(tag, expected, raw, compare=None):
@@ -39,7 +93,7 @@ def check(tag, expected, raw, compare=None):
nerrors += 1
return
@ -66,7 +69,7 @@ index 2a7cfb7affa..d661ae544b9 100644
def testStressfully(self):
# Try a variety of sizes at and around powers of 2, and at powers of 10.
sizes = [0]
@@ -151,7 +202,7 @@ class TestBase(unittest.TestCase):
@@ -151,7 +205,7 @@ class TestBase(unittest.TestCase):
self.assertEqual(forced, native)
#==============================================================================
@ -75,7 +78,7 @@ index 2a7cfb7affa..d661ae544b9 100644
def test_bug453523(self):
# bug 453523 -- list.sort() crasher.
@@ -188,7 +239,7 @@ class TestBugs(unittest.TestCase):
@@ -188,7 +242,7 @@ class TestBugs(unittest.TestCase):
#==============================================================================
@ -84,7 +87,7 @@ index 2a7cfb7affa..d661ae544b9 100644
def test_decorated(self):
data = 'The quick Brown fox Jumped over The lazy Dog'.split()
@@ -309,7 +360,7 @@ def check_against_PyObject_RichCompareBool(self, L):
@@ -309,7 +363,7 @@ def check_against_PyObject_RichCompareBool(self, L):
self.assertIs(opt, ref)
#note: not assertEqual! We want to ensure *identical* behavior.
@ -93,7 +96,7 @@ index 2a7cfb7affa..d661ae544b9 100644
def test_safe_object_compare(self):
heterogeneous_lists = [[0, 'foo'],
[0.0, 'foo'],
@@ -408,4 +459,4 @@ class TestOptimizedCompares(unittest.TestCase):
@@ -408,4 +462,4 @@ class TestOptimizedCompares(unittest.TestCase):
#==============================================================================
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sort.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_sys.py b/test/dynamo/cpython/3_13/test_sys.py
index 72d51361e0b..0b4c6882e62 100644
index 6b37094ed5f..c5e96a6a3dd 100644
--- a/test/dynamo/cpython/3_13/test_sys.py
+++ b/test/dynamo/cpython/3_13/test_sys.py
@@ -1,3 +1,55 @@
@@ -1,3 +1,58 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sys.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -58,7 +61,7 @@ index 72d51361e0b..0b4c6882e62 100644
import builtins
import codecs
import _datetime
@@ -35,7 +87,7 @@ def requires_subinterpreters(meth):
@@ -35,7 +90,7 @@ def requires_subinterpreters(meth):
DICT_KEY_STRUCT_FORMAT = 'n2BI2n'
@ -67,7 +70,7 @@ index 72d51361e0b..0b4c6882e62 100644
def test_original_displayhook(self):
dh = sys.__displayhook__
@@ -81,19 +133,8 @@ class DisplayHookTest(unittest.TestCase):
@@ -81,19 +136,8 @@ class DisplayHookTest(unittest.TestCase):
code = compile("42", "<string>", "single")
self.assertRaises(ValueError, eval, code)
@ -77,18 +80,18 @@ index 72d51361e0b..0b4c6882e62 100644
- sys.stdout = io.StringIO()
- support.gc_collect()
- return 'foo'
-
- with support.swap_attr(sys, 'stdout', None):
- sys.stdout = io.StringIO() # the only reference
- sys.displayhook(X()) # should not crash
-
-
-class ActiveExceptionTests(unittest.TestCase):
+class ActiveExceptionTests(__TestCase):
def test_exc_info_no_exception(self):
self.assertEqual(sys.exc_info(), (None, None, None))
@@ -157,7 +198,7 @@ class ActiveExceptionTests(unittest.TestCase):
@@ -157,7 +201,7 @@ class ActiveExceptionTests(unittest.TestCase):
self.assertIs(exc, e)
@ -97,7 +100,7 @@ index 72d51361e0b..0b4c6882e62 100644
@force_not_colorized
def test_original_excepthook(self):
@@ -200,7 +241,7 @@ class ExceptHookTest(unittest.TestCase):
@@ -200,7 +244,7 @@ class ExceptHookTest(unittest.TestCase):
# Python/pythonrun.c::PyErr_PrintEx() is tricky.
@ -106,7 +109,7 @@ index 72d51361e0b..0b4c6882e62 100644
def tearDown(self):
test.support.reap_children()
@@ -500,6 +541,7 @@ class SysModuleTest(unittest.TestCase):
@@ -500,6 +544,7 @@ class SysModuleTest(unittest.TestCase):
is sys._getframe().f_code
)
@ -114,16 +117,21 @@ index 72d51361e0b..0b4c6882e62 100644
def test_getframemodulename(self):
# Default depth gets ourselves
self.assertEqual(__name__, sys._getframemodulename())
@@ -808,7 +850,7 @@ class SysModuleTest(unittest.TestCase):
self.assertRaises(TypeError, sys.intern, S("abc"))
if has_is_interned:
self.assertIs(sys._is_interned(S("abc")), False)
-
+
@support.cpython_only
@requires_subinterpreters
def test_subinterp_intern_dynamically_allocated(self):
@@ -1359,7 +1401,7 @@ class SysModuleTest(unittest.TestCase):
@@ -894,7 +939,12 @@ class SysModuleTest(unittest.TestCase):
def assert_raise_on_new_sys_type(self, sys_attr):
# Users are intentionally prevented from creating new instances of
# sys.flags, sys.version_info, and sys.getwindowsversion.
- support.check_disallow_instantiation(self, type(sys_attr), sys_attr)
+ arg = sys_attr
+ attr_type = type(sys_attr)
+ with self.assertRaises(TypeError):
+ attr_type(arg)
+ with self.assertRaises(TypeError):
+ attr_type.__new__(attr_type, arg)
def test_sys_flags_no_instantiation(self):
self.assert_raise_on_new_sys_type(sys.flags)
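For context, the inlined assertions above reproduce what `support.check_disallow_instantiation` used to verify, presumably rewritten so Dynamo can trace it: structseq types such as `sys.flags` refuse instantiation from Python. A standalone sketch of the behavior under test:

```python
import sys

flags_type = type(sys.flags)
try:
    flags_type(sys.flags)  # structseq types disallow instantiation
except TypeError as exc:
    print(exc)  # e.g. "cannot create 'sys.flags' instances"
```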
@@ -1354,7 +1404,7 @@ class SysModuleTest(unittest.TestCase):
@test.support.cpython_only
@ -132,7 +140,7 @@ index 72d51361e0b..0b4c6882e62 100644
def test_original_unraisablehook(self):
_testcapi = import_helper.import_module('_testcapi')
from _testcapi import err_writeunraisable, err_formatunraisable
@@ -1516,7 +1558,7 @@ class UnraisableHookTest(unittest.TestCase):
@@ -1511,7 +1561,7 @@ class UnraisableHookTest(unittest.TestCase):
@test.support.cpython_only
@ -141,7 +149,7 @@ index 72d51361e0b..0b4c6882e62 100644
def setUp(self):
self.P = struct.calcsize('P')
@@ -1524,6 +1566,7 @@ class SizeofTest(unittest.TestCase):
@@ -1519,6 +1569,7 @@ class SizeofTest(unittest.TestCase):
_testinternalcapi = import_helper.import_module("_testinternalcapi")
self.gc_headsize = _testinternalcapi.SIZEOF_PYGC_HEAD
self.managed_pre_header_size = _testinternalcapi.SIZEOF_MANAGED_PRE_HEADER
@ -149,7 +157,7 @@ index 72d51361e0b..0b4c6882e62 100644
check_sizeof = test.support.check_sizeof
@@ -1960,4 +2003,4 @@ class SizeofTest(unittest.TestCase):
@@ -1955,4 +2006,4 @@ class SizeofTest(unittest.TestCase):
self.assertEqual(err, b"")
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sys.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,8 +1,8 @@
diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py
index 9ce80c5e8ea..e52c0cbc140 100644
index 9ce80c5e8ea..c6eab3ff1e9 100644
--- a/test/dynamo/cpython/3_13/test_tuple.py
+++ b/test/dynamo/cpython/3_13/test_tuple.py
@@ -1,4 +1,55 @@
@@ -1,4 +1,58 @@
-from test import support, seq_tests
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
@ -10,6 +10,9 @@ index 9ce80c5e8ea..e52c0cbc140 100644
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_tuple.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -59,7 +62,7 @@ index 9ce80c5e8ea..e52c0cbc140 100644
import unittest
import gc
@@ -510,4 +561,4 @@ class TupleTest(seq_tests.CommonTest):
@@ -510,4 +564,4 @@ class TupleTest(seq_tests.CommonTest):
# pileup 262,143 mean 8.0 coll 262,143 z +92683.6
if __name__ == "__main__":

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_tuple.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_userdict.py b/test/dynamo/cpython/3_13/test_userdict.py
index 61e79f553e8..c953390355e 100644
index 61e79f553e8..75b789633ed 100644
--- a/test/dynamo/cpython/3_13/test_userdict.py
+++ b/test/dynamo/cpython/3_13/test_userdict.py
@@ -1,3 +1,54 @@
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userdict.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -57,7 +60,7 @@ index 61e79f553e8..c953390355e 100644
# Check every path through every method of UserDict
from test import mapping_tests, support
@@ -215,10 +266,10 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol):
@@ -215,10 +269,10 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol):
# Decorate existing test with recursion limit, because
# the test is for C structure, but `UserDict` is a Python structure.

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userdict.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -1,14 +1,17 @@
diff --git a/test/dynamo/cpython/3_13/test_userlist.py b/test/dynamo/cpython/3_13/test_userlist.py
index 312702c8e39..a4532922f5d 100644
index 312702c8e39..5ede0c3b7f1 100644
--- a/test/dynamo/cpython/3_13/test_userlist.py
+++ b/test/dynamo/cpython/3_13/test_userlist.py
@@ -1,7 +1,58 @@
@@ -1,7 +1,61 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userlist.py
+
+import sys
+import torch
+import torch._dynamo.test_case
@ -62,7 +65,7 @@ index 312702c8e39..a4532922f5d 100644
import unittest
from test import support
@@ -69,9 +120,9 @@ class UserListTest(list_tests.CommonTest):
@@ -69,9 +123,9 @@ class UserListTest(list_tests.CommonTest):
# Decorate existing test with recursion limit, because
# the test is for C structure, but `UserList` is a Python structure.

View File

@ -4,6 +4,9 @@
# ruff: noqa
# flake8: noqa
# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userlist.py
import sys
import torch
import torch._dynamo.test_case

View File

@ -35,7 +35,11 @@ from torch.export import Dim, export, export_for_training
from torch.export.pt2_archive._package import load_pt2
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater
from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8,
SM80OrLater,
)
from torch.testing._internal.common_device_type import (
_has_sufficient_memory,
skipCUDAIf,
@ -188,6 +192,9 @@ class AOTInductorTestsTemplate:
# Skip embed_kernel_binary == True for now as it shows random
# failure on CI
@common_utils.parametrize("embed_kernel_binary", [False])
@unittest.skipIf(
_get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
)
def test_simple_multi_arch(self, embed_kernel_binary):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU_TYPE")

View File

@ -21,6 +21,7 @@ from torch._inductor.test_case import TestCase
from torch._inductor.utils import fresh_cache
from torch.export import Dim
from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_utils import (
IS_FBCODE,
skipIfRocm,
@ -249,6 +250,9 @@ class TestAOTInductorPackage(TestCase):
self.check_model(Model(), example_inputs)
@unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
@unittest.skipIf(
_get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
)
@skipIfXpu # build system may be different
def test_compile_after_package(self):
self.check_package_cpp_only()
@ -294,6 +298,9 @@ class TestAOTInductorPackage(TestCase):
actual = optimized(*example_inputs)
self.assertTrue(torch.allclose(actual, expected))
@unittest.skipIf(
_get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
)
@unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
@skipIfRocm # doesn't support multi-arch binary
@skipIfXpu # doesn't support multi-arch binary
@ -338,6 +345,9 @@ class TestAOTInductorPackage(TestCase):
actual = optimized(*example_inputs)
self.assertTrue(torch.allclose(actual, expected))
@unittest.skipIf(
_get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
)
@unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
@skipIfXpu # build system may be different
def test_compile_after_package_static(self):
@ -396,6 +406,9 @@ class TestAOTInductorPackage(TestCase):
with self.assertRaisesRegex(Exception, "Invalid AOTI model name"):
self.cmake_compile(model, example_inputs, options, "")
@unittest.skipIf(
_get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
)
@unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
@skipIfRocm # doesn't support multi-arch binary
@skipIfXpu # doesn't support multi-arch binary

View File

@ -1854,6 +1854,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"autograd_cache_normalize_inputs": True})
def test_split_module(self):
class Mod(torch.nn.Module):
def forward(self, x, a0, a1, b0, b1, c0, c1):
@ -1900,6 +1901,14 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
y = ca0(a0, x, a1)
y = ca1(b0, y, b1)
y = ca2(c0, y, c1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
# TODO: split_module causes ca1 and ca2 to have different type annotations
# for the parameter x, so AOTAutogradCache can only hit once instead of twice
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
expected = Mod()(*example_inputs)
self.assertEqual(y, expected)

View File

@ -30,7 +30,9 @@ from torch.testing._internal.common_device_type import (
)
from torch.testing._internal.common_methods_invocations import op_db, skipOps
from torch.testing._internal.common_utils import (
IS_CI,
IS_MACOS,
IS_WINDOWS,
IS_X86,
skipCUDAMemoryLeakCheckIf,
skipIfCrossRef,
@ -67,6 +69,15 @@ except (unittest.SkipTest, ImportError) as e:
sys.exit(0)
raise
if IS_WINDOWS and IS_CI:
# TODO(xuhancn): improve the compiler build performance on Windows.
sys.stderr.write(
"This test is too slow on Windows and would time out in CI, so skip it for now.\n"
)
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("skip slow test")
bf16 = torch.bfloat16 # not tested
f64 = torch.float64
f32 = torch.float32

View File

@ -9257,6 +9257,18 @@ class TestSDPA(TestCaseMPS):
def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
# Regression test from: https://github.com/pytorch/pytorch/issues/156707
@parametrize("dtype", [torch.float16, torch.float32])
def test_sdpa_full_mask(self, dtype):
q = torch.randn(1, 1, 2, 4, dtype=dtype)
k = torch.randn(1, 1, 2, 4, dtype=dtype)
v = torch.randn(1, 1, 2, 4, dtype=dtype)
mask = torch.tensor([[[[False, False], [True, True]]]], dtype=torch.bool)
out_cpu = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_mps = F.scaled_dot_product_attention(q.to('mps'), k.to('mps'), v.to('mps'), attn_mask=mask.to('mps'))
self._compare_tensors(out_mps.cpu(), out_cpu)
@parametrize("dtype", [torch.float16, torch.float32])
def test_sdpa_3d_input(self, dtype):
head_num, seq_len, embed_dim = 16, 16, 80

View File

@ -350,6 +350,13 @@ class ProcessGroup:
) -> None: ...
def rank(self) -> int: ...
def size(self) -> int: ...
def split_group(
self,
new_ranks: list[int],
timeout: Optional[timedelta] = None,
pg_options: Optional[Backend.Options] = None,
group_desc: Optional[str] = None,
) -> Optional[ProcessGroup]: ...
def abort(self) -> None: ...
def set_timeout(self, timeout: timedelta) -> None: ...
def shutdown(self) -> None: ...
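A hedged usage sketch for the `split_group` binding declared above (assumes a 4-rank NCCL job launched with the usual env vars; fetching the default group through the private `_get_default_group` helper is illustrative, not part of this patch). Every rank calls `split_group` collectively, and ranks left out of `new_ranks` get `None` back:

```python
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

dist.init_process_group("nccl")  # rank/world size come from the launcher env
pg = _get_default_group()
subgroup = pg.split_group([0, 1], group_desc="first-half")
if subgroup is not None:
    # Only global ranks 0 and 1 reach here, with group-local ranks 0 and 1.
    print(subgroup.rank(), subgroup.size())
```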

View File

@ -384,6 +384,57 @@ class AOTAutogradCachePickler(FxGraphCachePickler):
return (_ident, (metadata,))
@contextlib.contextmanager
def normalize_placeholder_names(gm: torch.fx.GraphModule):
"""
Context manager that normalizes the placeholder names in the graph module.
This is used while generating a cache key for AOTAutogradCache, so that two graphs
that are isomorphic up to placeholder names can hit the same cache entry.
This is safe because nothing underneath AOTAutograd uses the node names on the
original dynamo graph: AOTAutograd re-traces with its own nodes, and guards are
in terms of original sources rather than placeholder names.
"""
# Standalone inductor: we're bypassing AOTAutogradCache anyway, so return the graph
# as-is
if not config.autograd_cache_normalize_inputs or not hasattr(gm, "graph"):
yield
return
# Track all the old state of placeholders
old_placeholder_names = []
old_used_names = copy(gm.graph._graph_namespace._used_names)
i = 0
for n in gm.graph.find_nodes(op="placeholder", sort=True):
if n.type != torch.SymInt:
# _rename renames the node in the body of the function,
# but it doesn't change the raw name stored in node.target,
# so we also set node.target to the new placeholder name
new_placeholder_name = f"p_{i}"
old_placeholder_names.append((n.name, n.target))
n.target = new_placeholder_name
n._rename(new_placeholder_name)
i += 1
gm.recompile()
try:
yield
finally:
# Used_names contains all our old placeholder names,
# so we clear it temporarily when we put them back
gm.graph._graph_namespace._used_names = set()
# Restore the placeholder names
i = 0
for n in gm.graph.find_nodes(op="placeholder", sort=True):
if n.type != torch.SymInt:
(name, target) = old_placeholder_names[i]
n.target = target
n._rename(name)
i += 1
assert i == len(old_placeholder_names)
# Now restore the old namespace's used names
gm.graph._graph_namespace._used_names = old_used_names
gm.recompile()
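A minimal sketch of the idea behind the context manager above (it reuses the same `find_nodes`/`_rename` calls as the patch; the two toy functions are illustrative): traces that differ only in input names print identically once their placeholders are rewritten to `p_0`, `p_1`, ..., so they can hash to the same cache key.

```python
import torch
import torch.fx

def f(x, y):
    return x + y

def g(a, b):
    return a + b

for fn in (f, g):
    gm = torch.fx.symbolic_trace(fn)
    for i, n in enumerate(gm.graph.find_nodes(op="placeholder", sort=True)):
        n.target = f"p_{i}"   # raw name used for the argument in generated code
        n._rename(f"p_{i}")   # name used in the graph body
    gm.recompile()
    print(gm.code)  # identical output for f and g
```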
def autograd_cache_key(
gm: torch.fx.GraphModule,
example_inputs,
@ -407,7 +458,6 @@ def autograd_cache_key(
if triton.__version__ < "3.2.0":
raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0")
details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config)
pickler = AOTAutogradCachePickler(gm)
# The prefix distinguishes among the other kinds of objects we cache
@ -924,21 +974,22 @@ def sanitize_gm_for_cache(gm: torch.fx.GraphModule):
and then put them back before returning. This way, we generate a cache key based off of a canonical graph
without these fields, and also guarantee they aren't used to affect the cache's output.
"""
IGNORED_FIELDS = (
"meta", # metadata used by export
"compile_subgraph_reason", # Used by dynamo only for logging, no change in inductor/autograd behavior
"_param_name_to_source", # Encapsulated by aot_config.aot_autograd_arg_pos_to_source
"_backend_id",
)
# Mapping from each field to a default value
IGNORED_FIELDS: dict[str, Any] = {
"meta": {}, # metadata used by export
"compile_subgraph_reason": None, # Used by dynamo only for logging, no change in inductor/autograd behavior
"_param_name_to_source": None, # Encapsulated by aot_config.aot_autograd_arg_pos_to_source
"_backend_id": None,
}
saved_fields = {}
for field in IGNORED_FIELDS:
for field, default_value in IGNORED_FIELDS.items():
saved_fields[field] = getattr(gm, field, None)
# Clear the field
setattr(gm, field, None)
setattr(gm, field, default_value)
try:
yield
with normalize_placeholder_names(gm):
yield
finally:
# Put the fields back after dispatch_and_compile is complete
for field, value in saved_fields.items():
setattr(gm, field, value)

View File

@ -61,6 +61,10 @@ autograd_cache_allow_custom_autograd_functions: bool = Config(
# need to add env vars or make it configurable
bundled_autograd_cache: bool = False
# Whether or not to normalize placeholder names in graphs
# from dynamo in AOTAutogradCache
autograd_cache_normalize_inputs = not is_fbcode()
def remote_autograd_cache_default() -> Optional[bool]:
if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "1":
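An illustrative way to flip the new knob directly (assuming it lands in `torch._functorch.config`, which is what the `functorch_config.patch` decorators in the tests above point at):

```python
import torch._functorch.config as functorch_config

# Process-wide override (the default is already True outside fbcode).
functorch_config.autograd_cache_normalize_inputs = True

# Or scope it to a block, mirroring the test decorator.
with functorch_config.patch({"autograd_cache_normalize_inputs": True}):
    pass  # compile and run cached code here
```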

View File

@ -2,7 +2,6 @@
import itertools
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
import sympy
from sympy.parsing.sympy_parser import parse_expr
@ -19,7 +18,7 @@ from ..select_algorithm import PartialRender
from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
from ..virtualized import V
from .common import REMOVED
from .cpp import CppKernel, CppKernelProxy, KernelGroup, ParallelDepth
from .cpp import CppKernel, CppKernelProxy, KernelGroup
from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext
@ -289,15 +288,7 @@ class CppTemplateKernel(CppKernel):
var_sizes_list.append(var_sizes)
cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
def max_parallel_depth():
return ParallelDepth(parallel_depth=0, start_depth=0)
# This loop is not parallelized since it is not the outermost loop.
with patch.object(
cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth
):
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
return kernel_group.loops_code.getvalue()
def store_grouped_gemm_pointwise_nodes(
@ -351,15 +342,7 @@ class CppTemplateKernel(CppKernel):
var_sizes_list.append(var_sizes)
cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
def max_parallel_depth():
return ParallelDepth(parallel_depth=0, start_depth=0)
# This loop is not parallelized since it is not the outermost loop.
with patch.object(
cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth
):
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
return kernel_group.loops_code.getvalue()
def store_output(

View File

@ -45,7 +45,10 @@ def move_cutlass_compiled_cache() -> None:
else:
import cutlass as python_cutlass # type: ignore[import-not-found] # noqa: F401
if not os.path.exists(python_cutlass.CACHE_FILE):
# Check if the CACHE_FILE attribute exists in python_cutlass and if the file exists
if not hasattr(python_cutlass, "CACHE_FILE") or not os.path.exists(
python_cutlass.CACHE_FILE
):
return
try:

View File

@ -1019,7 +1019,7 @@ class cpp:
dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1"
simdlen: Optional[int] = None
min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "512"))
min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096"))
cxx: tuple[Literal[None], str] = (
None, # download gcc12 from conda-forge if conda is installed

View File

@ -20,8 +20,8 @@
#include <ATen/cuda/detail/CUDAHooks.h>
#include <ATen/cuda/jiterator.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <c10/core/AllocatorConfig.h>
#include <c10/core/StorageImpl.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
@ -426,7 +426,8 @@ PyObject* THCPModule_cudaCachingAllocator_set_allocator_settings(
PyObject* _unused,
PyObject* env) {
HANDLE_TH_ERRORS
c10::CachingAllocator::setAllocatorSettings(THPUtils_unpackString(env));
c10::cuda::CUDACachingAllocator::setAllocatorSettings(
THPUtils_unpackString(env));
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}

View File

@ -46,6 +46,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
// backend name
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::string backend;
std::string group_name;
};
explicit Backend(int rank, int size);
@ -105,6 +106,14 @@ class TORCH_API Backend : public torch::CustomClassHolder {
TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
}
// Subclasses must override this method to return backend-specific options
virtual c10::intrusive_ptr<Options> getBackendOptions() {
TORCH_CHECK(
false,
c10::str(
"Backend ", getBackendName(), " does not implement endCoalescing"));
}
virtual c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& /* tensors */,
const BroadcastOptions& /* opts */ = BroadcastOptions()) {
@ -379,6 +388,16 @@ class TORCH_API Backend : public torch::CustomClassHolder {
" is missing implementation of enableCollectivesTiming.");
}
virtual c10::intrusive_ptr<Backend> splitBackend(
const std::vector<int>& ranks,
const c10::intrusive_ptr<Options> opts) {
TORCH_CHECK(
false,
"Backend ",
getBackendName(),
" is missing implementation of splitBackend.");
}
bool hasHooks() const {
return onCompletionHook_ != nullptr;
}

View File

@ -573,6 +573,27 @@ size_t hashTensors(const std::vector<at::Tensor>& tensors) {
return hash;
}
// NCCL requires a non-negative int to mark a rank as in-group, per its API.
// We take a list of ranks, generate a hash value based on the list, and
// clamp the result to the non-negative 32-bit int range.
int genNcclSplitColor(const std::vector<int>& ranks) {
// Combine the hash values using a simple reducer (std::hash + fold)
std::size_t combined_hash = std::accumulate(
ranks.begin(),
ranks.end(),
std::size_t(0),
[](std::size_t acc, int rank) {
return acc ^
(std::hash<int>{}(rank) + 0x9e3779b9 + (acc << 6) + (acc >> 2));
});
// max positive value of int32_t
constexpr int32_t max_c_int = std::numeric_limits<int32_t>::max();
int color = static_cast<int>(
std::abs(static_cast<int64_t>(combined_hash)) % max_c_int);
return color;
}
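A rough Python rendition of the fold above (illustrative only: Python's `hash(int)` is not `std::hash<int>`, so values will not match the C++ output bit for bit). It shows how a rank list collapses deterministically into a non-negative 32-bit color:

```python
def nccl_split_color(ranks):
    acc = 0
    for r in sorted(ranks):  # splitGroup sorts the ranks before hashing
        h = hash(r) + 0x9E3779B9 + (acc << 6) + (acc >> 2)
        acc = (acc ^ h) & 0xFFFFFFFFFFFFFFFF  # emulate 64-bit size_t wraparound
    return acc % (2**31 - 1)  # clamp into the positive int32 range

assert nccl_split_color([4, 2, 0]) == nccl_split_color([0, 2, 4])
```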
// Default value: 30 minutes
int nccl_nonblocking_timeout() {
static int timeout = -2; // -2 means not initialized

View File

@ -231,6 +231,7 @@ static std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
};
TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
TORCH_API int genNcclSplitColor(const std::vector<int>& ranks);
TORCH_API std::string getNcclVersion();
TORCH_API std::tuple<int, int, int> getNcclVersionTuple();
TORCH_API int getNcclVersionNumber();

View File

@ -4,6 +4,7 @@
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <string_view>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
@ -158,6 +159,63 @@ void ProcessGroup::release_resources() {
backendTypeToBackend_.clear();
}
c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
const std::vector<int>& ranks,
const std::optional<std::chrono::milliseconds> timeout,
const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
const std::optional<std::string>& desc) {
TORCH_CHECK(
ranks.size() > 0,
"Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group.");
TORCH_CHECK(
ranks.size() < static_cast<size_t>(size_),
"the split group's size should be less than the world_size set by init_process_group");
std::set<int> ranks_set(ranks.begin(), ranks.end());
TORCH_CHECK(
ranks_set.size() == ranks.size(),
"Split ranks should not have duplicates. Please provide a list of unique ranks to split the group.");
std::vector<int> sorted_ranks = ranks;
std::sort(sorted_ranks.begin(), sorted_ranks.end());
c10::intrusive_ptr<ProcessGroup> newGroup;
// TODO: Figure out a better way for split group name.
std::string groupName =
c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks));
for (const auto& pair : deviceTypeToBackendType_) {
c10::DeviceType deviceType = pair.first;
BackendType backendType = pair.second;
auto parentBackend = getBackend(deviceType);
auto backendOpts =
opts.has_value() ? opts.value() : parentBackend->getBackendOptions();
backendOpts->group_name = groupName;
backendOpts->timeout =
timeout.has_value() ? timeout.value() : backendOpts->timeout;
auto splitBackend = parentBackend->splitBackend(sorted_ranks, backendOpts);
if (splitBackend == nullptr) {
continue;
}
// TODO: Figure out a better way for split group desc.
// TODO: We can add a new field in Backend::Options to specify the group
// desc
std::string groupDesc = desc.has_value()
? desc.value()
: c10::str(getGroupDesc(), ":split:", incrementSplitCount());
splitBackend->setGroupDesc(groupDesc);
if (!newGroup) {
newGroup = c10::make_intrusive<ProcessGroup>(
store_->clone(), splitBackend->getRank(), splitBackend->getSize());
newGroup->setDefaultBackend(backendType_);
newGroup->setGroupName(groupName);
newGroup->setGroupDesc(groupDesc);
}
newGroup->setBackend(deviceType, backendType, splitBackend);
}
return newGroup;
}
} // namespace c10d
namespace {

View File

@ -170,6 +170,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
}
}
int64_t incrementSplitCount() {
return splitCounter_++;
}
virtual void startCoalescing(c10::DeviceType deviceType) {
// only nccl has implemented startCoalescing so only execute for nccl
// backends
@ -955,6 +959,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
bound_device_id_ = device;
}
// This creates a new subgroup using the specified ranks.
// The current rank must be included in the list of new_ranks.
virtual c10::intrusive_ptr<ProcessGroup> splitGroup(
const std::vector<int>& ranks,
const std::optional<std::chrono::milliseconds> timeout,
const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
const std::optional<std::string>& groupDesc);
protected:
// Implementations of this interface need to call this to setup
// appropriate logging etc.
@ -968,6 +980,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
BackendType backendType_;
std::string pg_desc_;
int64_t splitCounter_;
// Debug level setting. It is parsed once when ProcessGroup is constructed and
// remains the same across use of this process group.

View File

@ -697,6 +697,35 @@ const std::vector<uint64_t>& ProcessGroupGloo::groupRanks() const {
return options_->global_ranks_in_group;
}
c10::intrusive_ptr<Backend> ProcessGroupGloo::splitBackend(
const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options> opts) {
auto it = std::find(ranks.begin(), ranks.end(), rank_);
int groupRank;
if (it == ranks.end()) {
return nullptr;
} else {
groupRank = std::distance(ranks.begin(), it);
}
auto glooOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options.");
// TODO: we need to get rid of globalRanksInGroup eventually.
std::vector<uint64_t> globalRanksInGroup;
for (auto rank : ranks) {
globalRanksInGroup.emplace_back(groupRanks()[rank]);
}
glooOpts->global_ranks_in_group = std::move(globalRanksInGroup);
auto store = std::dynamic_pointer_cast<GlooStore>(store_);
TORCH_CHECK(
store != nullptr,
"store inside ProcessGroupGloo not a ProcessGroupGloo::GlooStore.");
auto pg = c10::make_intrusive<ProcessGroupGloo>(
store->_getStore()->clone(), groupRank, ranks.size(), glooOpts);
return c10::static_intrusive_pointer_cast<Backend>(pg);
}
void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
std::unique_lock<std::mutex> lock(workMutex_);
pgStatus_->lastEnqueuedSeq = static_cast<int64_t>(work->seq_);

View File

@ -188,6 +188,10 @@ class TORCH_API ProcessGroupGloo : public Backend {
}
#endif
const c10::intrusive_ptr<::c10d::Store>& _getStore() const {
return store_;
}
protected:
c10::intrusive_ptr<::c10d::Store> store_;
};
@ -252,7 +256,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
}
std::vector<uint64_t> global_ranks_in_group;
std::string group_name;
std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
int threads;
};
@ -301,6 +304,14 @@ class TORCH_API ProcessGroupGloo : public Backend {
}
}
c10::intrusive_ptr<Backend::Options> getBackendOptions() override {
return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
}
c10::intrusive_ptr<Backend> splitBackend(
const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options> opts) override;
const std::vector<uint64_t>& groupRanks() const;
c10::intrusive_ptr<Work> broadcast(

View File

@ -1311,6 +1311,45 @@ void ProcessGroupNCCL::enableCollectivesTiming() {
enableTiming_.store(true);
}
c10::intrusive_ptr<Backend> ProcessGroupNCCL::splitBackend(
const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options> opts) {
auto deviceIdx = guessDeviceId();
TORCH_CHECK(
deviceIdx >= 0,
"ProcessGroupNCCL::splitBackend: rank ",
rank_,
" has no device is bound to this rank.");
auto device = at::Device(at::DeviceType::CUDA, deviceIdx);
auto it = std::find(ranks.begin(), ranks.end(), rank_);
int groupRank;
if (it == ranks.end()) {
// This rank is not in the new group, so no_color split should be called
performNocolorSplit(device);
return nullptr;
} else {
groupRank = std::distance(ranks.begin(), it);
}
auto ncclOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
TORCH_CHECK(ncclOpts != nullptr, "opts not a ProcessGroupNCCL::Options.");
// TODO: we need to get rid of globalRanksInGroup eventually.
std::vector<uint64_t> globalRanksInGroup;
for (auto rank : ranks) {
globalRanksInGroup.emplace_back(groupRanks()[rank]);
}
ncclOpts->split_from =
c10::intrusive_ptr<ProcessGroupNCCL>::unsafe_reclaim_from_nonowning(this);
ncclOpts->global_ranks_in_group = std::move(globalRanksInGroup);
auto color = genNcclSplitColor(ranks);
ncclOpts->split_color = color;
auto pg = c10::make_intrusive<ProcessGroupNCCL>(
store_->clone(), groupRank, ranks.size(), ncclOpts);
pg->eagerConnectSingleDevice(device);
return c10::static_intrusive_pointer_cast<Backend>(pg);
}
bool ProcessGroupNCCL::waitForFutureOrTimeout(
std::future<bool>& fut,
const std::chrono::milliseconds& timeOutMilSec,

View File

@ -541,7 +541,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Optional "parent" backend and color to create communicators from
// via `ncclCommSplit`
std::shared_ptr<ProcessGroupNCCL> split_from;
c10::intrusive_ptr<ProcessGroupNCCL> split_from;
// Color to use for `ncclCommSplit`, values:
// * Non-negative value: in group;
// * NCCL_SPLIT_NOCOLOR (-1): not in group;
@ -562,7 +562,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
int split_color{-2};
#endif
std::vector<uint64_t> global_ranks_in_group;
std::string group_name;
};
// Helper class related to TORCH_NCCL_DESYNC_DEBUG
@ -804,6 +803,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
return options_;
}
c10::intrusive_ptr<Backend::Options> getBackendOptions() override {
return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
}
const std::string getBackendName() const override {
return std::string(NCCL_BACKEND_NAME);
}
@ -972,6 +975,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
void enableCollectivesTiming() override;
c10::intrusive_ptr<Backend> splitBackend(
const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options> opts) override;
// Helper function for iteratively aborting communicators in the provided map
void abortCommsFromMap(
std::unordered_map<std::string, std::shared_ptr<NCCLComm>>& ncclCommsMap,

Some files were not shown because too many files have changed in this diff.