mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-12 14:54:55 +08:00
Compare commits
1 Commits
documentat
...
ciflow/tru
| Author | SHA1 | Date | |
|---|---|---|---|
| 854b40f81f |
@ -36,7 +36,11 @@ case ${DOCKER_TAG_PREFIX} in
|
||||
;;
|
||||
rocm*)
|
||||
BASE_TARGET=rocm
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
|
||||
;;
|
||||
*)
|
||||
|
||||
@ -260,12 +260,6 @@ case "$tag" in
|
||||
HALIDE=yes
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-cuda12.8-py3.12-pallas)
|
||||
CUDA_VERSION=12.8.1
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
GCC_VERSION=11
|
||||
PALLAS=yes
|
||||
;;
|
||||
pytorch-linux-jammy-py3.12-triton-cpu)
|
||||
CUDA_VERSION=12.6
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
@ -387,7 +381,6 @@ docker build \
|
||||
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
|
||||
--build-arg "EXECUTORCH=${EXECUTORCH}" \
|
||||
--build-arg "HALIDE=${HALIDE}" \
|
||||
--build-arg "PALLAS=${PALLAS}" \
|
||||
--build-arg "XPU_VERSION=${XPU_VERSION}" \
|
||||
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
|
||||
--build-arg "ACL=${ACL:-}" \
|
||||
|
||||
@ -1 +0,0 @@
|
||||
0.8.0
|
||||
@ -1,40 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -ex
|
||||
|
||||
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
|
||||
|
||||
# Get the pinned JAX version (same for all CUDA versions)
|
||||
JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
|
||||
|
||||
function install_jax_12() {
|
||||
echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
|
||||
pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Verify installation
|
||||
python -c "import jax" # check for errors
|
||||
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
|
||||
}
|
||||
|
||||
function install_jax_13() {
|
||||
echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
|
||||
pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Verify installation
|
||||
python -c "import jax" # check for errors
|
||||
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
|
||||
}
|
||||
|
||||
# idiomatic parameter and option handling in sh
|
||||
while test $# -gt 0
|
||||
do
|
||||
case "$1" in
|
||||
12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
|
||||
;;
|
||||
13.0|13.0.*) install_jax_13;
|
||||
;;
|
||||
*) echo "bad argument $1"; exit 1
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
@ -49,7 +49,11 @@ case ${DOCKER_TAG_PREFIX} in
|
||||
fi
|
||||
BASE_TARGET=rocm
|
||||
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}"
|
||||
;;
|
||||
*)
|
||||
|
||||
@ -87,7 +87,11 @@ case ${image} in
|
||||
MANY_LINUX_VERSION="2_28"
|
||||
DEVTOOLSET_VERSION="11"
|
||||
GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
|
||||
;;
|
||||
manylinux2_28-builder:xpu)
|
||||
|
||||
@ -143,15 +143,6 @@ COPY ci_commit_pins/halide.txt halide.txt
|
||||
RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
|
||||
RUN rm install_halide.sh common_utils.sh halide.txt
|
||||
|
||||
ARG PALLAS
|
||||
ARG CUDA_VERSION
|
||||
# Install JAX with CUDA support (for Pallas)
|
||||
COPY ./common/install_jax.sh install_jax.sh
|
||||
COPY ./common/common_utils.sh common_utils.sh
|
||||
COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
|
||||
RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
|
||||
RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
|
||||
|
||||
ARG ONNX
|
||||
# Install ONNX dependencies
|
||||
COPY ./common/install_onnx.sh ./common/common_utils.sh ./
|
||||
|
||||
@ -8,11 +8,9 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
try:
|
||||
from collections.abc import Callable # Python 3.11+
|
||||
from typing import Any, Required, TypedDict
|
||||
from typing import Any, Callable, Required, TypedDict # Python 3.11+
|
||||
except ImportError:
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, Callable, TypedDict
|
||||
|
||||
from typing_extensions import Required # Fallback for Python <3.11
|
||||
|
||||
|
||||
@ -168,16 +168,14 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/compiler/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/umf/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/ccl/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/mpi/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/pti/latest/env/vars.sh
|
||||
# Enable XCCL build
|
||||
export USE_XCCL=1
|
||||
export USE_MPI=0
|
||||
# XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA
|
||||
export USE_KINETO=0
|
||||
export TORCH_XPU_ARCH_LIST=pvc
|
||||
fi
|
||||
|
||||
|
||||
@ -208,8 +208,6 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
|
||||
source /opt/intel/oneapi/ccl/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/mpi/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/pti/latest/env/vars.sh
|
||||
# Check XPU status before testing
|
||||
timeout 30 xpu-smi discovery || true
|
||||
fi
|
||||
@ -826,11 +824,6 @@ test_inductor_halide() {
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_inductor_pallas() {
|
||||
python test/run_test.py --include inductor/test_pallas.py --verbose
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_inductor_triton_cpu() {
|
||||
python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
|
||||
assert_git_not_dirty
|
||||
@ -1731,8 +1724,6 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
|
||||
test_inductor_distributed
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
|
||||
test_inductor_halide
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
|
||||
test_inductor_pallas
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
|
||||
test_inductor_triton_cpu
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
|
||||
|
||||
2
.github/ci_commit_pins/vision.txt
vendored
2
.github/ci_commit_pins/vision.txt
vendored
@ -1 +1 @@
|
||||
ccb801b88af136454798b945175c4c87e636ac33
|
||||
ca2212438fdd8ce29b66999ed70ed54b0f9372d1
|
||||
|
||||
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
||||
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
|
||||
c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9
|
||||
|
||||
9
.github/labeler.yml
vendored
9
.github/labeler.yml
vendored
@ -138,8 +138,7 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- torch/**/*cublas*
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
@ -149,8 +148,7 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- torch/**/*cublas*
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
@ -160,8 +158,7 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
- third_party/fbgemm
|
||||
|
||||
1
.github/nitpicks.yml
vendored
1
.github/nitpicks.yml
vendored
@ -10,4 +10,3 @@
|
||||
pathFilter:
|
||||
- 'torch/csrc/inductor/aoti_torch/c/*'
|
||||
- 'torch/csrc/inductor/aoti_torch/generated/*'
|
||||
- 'torch/csrc/stable/c/*'
|
||||
|
||||
3
.github/scripts/delete_old_branches.py
vendored
3
.github/scripts/delete_old_branches.py
vendored
@ -1,11 +1,10 @@
|
||||
# Delete old branches
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from github_utils import gh_fetch_json_dict, gh_graphql
|
||||
from gitutils import GitRepo
|
||||
|
||||
3
.github/scripts/filter_test_configs.py
vendored
3
.github/scripts/filter_test_configs.py
vendored
@ -8,11 +8,10 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from functools import cache
|
||||
from logging import info
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import yaml
|
||||
|
||||
3
.github/scripts/get_workflow_job_id.py
vendored
3
.github/scripts/get_workflow_job_id.py
vendored
@ -11,8 +11,7 @@ import sys
|
||||
import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
|
||||
|
||||
3
.github/scripts/github_utils.py
vendored
3
.github/scripts/github_utils.py
vendored
@ -3,9 +3,8 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast, Optional, Union
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
from urllib.error import HTTPError
|
||||
from urllib.parse import quote
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
4
.github/scripts/gitutils.py
vendored
4
.github/scripts/gitutils.py
vendored
@ -4,10 +4,10 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterator
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Any, cast, Optional, TypeVar, Union
|
||||
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
4
.github/scripts/trymerge.py
vendored
4
.github/scripts/trymerge.py
vendored
@ -17,12 +17,12 @@ import re
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterable
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from typing import Any, cast, NamedTuple, Optional
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional
|
||||
from warnings import warn
|
||||
|
||||
import yaml
|
||||
|
||||
1
.github/workflows/docker-builds.yml
vendored
1
.github/workflows/docker-builds.yml
vendored
@ -67,7 +67,6 @@ jobs:
|
||||
pytorch-linux-jammy-py3.10-gcc11,
|
||||
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
|
||||
pytorch-linux-jammy-py3.12-halide,
|
||||
pytorch-linux-jammy-cuda12.8-py3.12-pallas,
|
||||
pytorch-linux-jammy-xpu-n-1-py3,
|
||||
pytorch-linux-noble-xpu-n-py3,
|
||||
pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
|
||||
|
||||
26
.github/workflows/inductor-unittest.yml
vendored
26
.github/workflows/inductor-unittest.yml
vendored
@ -81,32 +81,6 @@ jobs:
|
||||
test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
inductor-pallas-build:
|
||||
name: inductor-pallas-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.12-gcc11
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
|
||||
cuda-arch-list: '8.9'
|
||||
runner: linux.8xlarge.memory
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
inductor-pallas-test:
|
||||
name: inductor-pallas-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: inductor-pallas-build
|
||||
with:
|
||||
build-environment: linux-jammy-py3.12-gcc11
|
||||
docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
inductor-triton-cpu-build:
|
||||
name: inductor-triton-cpu-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
|
||||
@ -1402,7 +1402,7 @@ init_command = [
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'usort==1.0.8.post1',
|
||||
'isort==6.0.1',
|
||||
'ruff==0.14.4', # sync with RUFF
|
||||
'ruff==0.13.1', # sync with RUFF
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
@ -1537,7 +1537,7 @@ init_command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'ruff==0.14.4', # sync with PYFMT
|
||||
'ruff==0.13.1', # sync with PYFMT
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
|
||||
@ -210,12 +210,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
|
||||
/test/inductor/test_flex_attention.py @drisspg
|
||||
/test/inductor/test_flex_decoding.py @drisspg
|
||||
|
||||
# Low Precision & Grouped GEMMs
|
||||
# Low Precision GEMMs
|
||||
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
|
||||
/test/test_scaled_matmul_cuda.py @drisspg @slayton58
|
||||
|
||||
@ -94,11 +94,6 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
|
||||
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
|
||||
}
|
||||
|
||||
TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
|
||||
c10::DeviceIndex device_index) {
|
||||
const auto device_type = getAccelerator(true).value();
|
||||
return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
|
||||
}
|
||||
} // namespace at::accelerator
|
||||
|
||||
namespace at {
|
||||
|
||||
@ -226,8 +226,8 @@ template <
|
||||
typename B = HostBlock<S>>
|
||||
struct CachingHostAllocatorImpl {
|
||||
virtual ~CachingHostAllocatorImpl() {
|
||||
if (active_) {
|
||||
active_ = false;
|
||||
active_ = false;
|
||||
if (pinned_use_background_threads()) {
|
||||
getBackgroundThreadPool()->waitWorkComplete();
|
||||
}
|
||||
}
|
||||
@ -260,7 +260,6 @@ struct CachingHostAllocatorImpl {
|
||||
if (pinned_use_background_threads()) {
|
||||
// Launch the background thread and process events in a loop.
|
||||
static bool background_thread_flag [[maybe_unused]] = [this] {
|
||||
active_ = true;
|
||||
getBackgroundThreadPool()->run([&]() {
|
||||
while (active_) {
|
||||
process_events();
|
||||
@ -684,9 +683,9 @@ struct CachingHostAllocatorImpl {
|
||||
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
|
||||
std::deque<std::pair<E, B*>> events_; // event queue paired with block
|
||||
|
||||
// Indicates whether the event-processing thread pool is active.
|
||||
// Indicates whether the object is active.
|
||||
// Set to false in the destructor to signal background threads to stop.
|
||||
std::atomic<bool> active_{false};
|
||||
std::atomic<bool> active_{true};
|
||||
protected:
|
||||
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
|
||||
};
|
||||
|
||||
@ -157,8 +157,6 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
|
||||
DispatchKey::Negative,
|
||||
DispatchKey::Conjugate,
|
||||
DispatchKey::XLA,
|
||||
DispatchKey::XPU,
|
||||
DispatchKey::HPU,
|
||||
DispatchKey::CUDA,
|
||||
DispatchKey::CPU,
|
||||
DispatchKey::PrivateUse1,
|
||||
|
||||
@ -141,9 +141,6 @@ static Tensor& addmv_out_mps_impl(const Tensor& self,
|
||||
};
|
||||
|
||||
MPSStream* stream = at::mps::getCurrentMPSStream();
|
||||
if (result.numel() == 0) {
|
||||
return result;
|
||||
}
|
||||
Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1);
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
@ -2803,7 +2803,7 @@
|
||||
- func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
dispatch:
|
||||
CPU, CUDA, MPS, MTIA: floor_divide_out
|
||||
CPU, CUDA, MPS: floor_divide_out
|
||||
SparseCPU, SparseCUDA, SparseMPS: floor_divide_out_sparse_zerodim
|
||||
|
||||
- func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor
|
||||
@ -4292,7 +4292,6 @@
|
||||
dispatch:
|
||||
SparseCPU: sparse_sparse_matmul_cpu
|
||||
SparseCUDA: sparse_sparse_matmul_cuda
|
||||
SparseMPS: sparse_sparse_matmul_mps
|
||||
autogen: _sparse_sparse_matmul.out
|
||||
|
||||
- func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
|
||||
@ -4384,7 +4383,7 @@
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: mv
|
||||
SparseCPU, SparseCUDA, SparseMPS: mv_sparse
|
||||
SparseCPU, SparseCUDA: mv_sparse
|
||||
|
||||
- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
@ -9833,7 +9832,7 @@
|
||||
structured_delegate: erfinv.out
|
||||
variants: method, function
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse
|
||||
SparseCPU, SparseCUDA: erfinv_sparse
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr
|
||||
tags: pointwise
|
||||
|
||||
@ -9842,7 +9841,7 @@
|
||||
structured_delegate: erfinv.out
|
||||
variants: method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_
|
||||
SparseCPU, SparseCUDA: erfinv_sparse_
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_
|
||||
tags: pointwise
|
||||
|
||||
@ -9852,7 +9851,7 @@
|
||||
structured_inherits: TensorIteratorBase
|
||||
dispatch:
|
||||
CPU, CUDA, MPS: erfinv_out
|
||||
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_out
|
||||
SparseCPU, SparseCUDA: erfinv_sparse_out
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out
|
||||
tags: pointwise
|
||||
|
||||
|
||||
@ -10,10 +10,6 @@
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_coalesce_native.h>
|
||||
#include <ATen/ops/repeat_interleave_native.h>
|
||||
#include <ATen/ops/cumsum.h>
|
||||
#include <ATen/ops/_sparse_sparse_matmul_native.h>
|
||||
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
|
||||
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
|
||||
#include <ATen/ops/cat.h>
|
||||
#include <ATen/ops/add_native.h>
|
||||
@ -892,114 +888,5 @@ static void sparse_mask_intersection_out_mps_kernel(
|
||||
/*coalesce_mask=*/false);
|
||||
}
|
||||
|
||||
Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
|
||||
TORCH_CHECK(mat1_.is_sparse() && mat2_.is_sparse(),
|
||||
"sparse_sparse_matmul_mps: both inputs must be sparse COO tensors");
|
||||
TORCH_CHECK(mat1_.is_mps() && mat2_.is_mps(),
|
||||
"sparse_sparse_matmul_mps: both inputs must be on MPS device");
|
||||
TORCH_CHECK(mat1_.dim() == 2 && mat2_.dim() == 2,
|
||||
"sparse_sparse_matmul_mps: both inputs must be 2D matrices");
|
||||
TORCH_CHECK(mat1_.dense_dim() == 0 && mat2_.dense_dim() == 0,
|
||||
"sparse_sparse_matmul_mps: only scalar values supported (dense_dim == 0)");
|
||||
TORCH_CHECK(mat1_.size(1) == mat2_.size(0),
|
||||
"mat1 and mat2 shapes cannot be multiplied (", mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");
|
||||
TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
|
||||
"sparse_sparse_matmul_mps: mat1 dtype ", mat1_.scalar_type(),
|
||||
" does not match mat2 dtype ", mat2_.scalar_type());
|
||||
|
||||
const auto device = mat1_.device();
|
||||
|
||||
auto A = mat1_.coalesce();
|
||||
auto B = mat2_.coalesce();
|
||||
|
||||
const auto I = A.size(0);
|
||||
const auto K = A.size(1);
|
||||
const auto N = B.size(1);
|
||||
|
||||
const auto nnzA = A._nnz();
|
||||
const auto nnzB = B._nnz();
|
||||
|
||||
// Early empty result, return an empty, coalesced tensor
|
||||
if (I == 0 || N == 0 || K == 0 || nnzA == 0 || nnzB == 0) {
|
||||
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
|
||||
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
|
||||
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
|
||||
out._coalesced_(true);
|
||||
return out;
|
||||
}
|
||||
|
||||
const auto computeDtype = at::result_type(mat1_, mat2_);
|
||||
|
||||
auto A_idx = A._indices().contiguous();
|
||||
auto A_val = A._values().to(computeDtype).contiguous();
|
||||
auto A_i = A_idx.select(0, 0).contiguous();
|
||||
auto A_k = A_idx.select(0, 1).contiguous();
|
||||
|
||||
auto B_idx = B._indices().contiguous();
|
||||
auto B_val = B._values().to(computeDtype).contiguous();
|
||||
auto B_k = B_idx.select(0, 0).contiguous();
|
||||
auto B_j = B_idx.select(0, 1).contiguous();
|
||||
|
||||
// csr-style row pointers for B by k (the shared dimension)
|
||||
Tensor row_ptr_B;
|
||||
{
|
||||
auto batch_ptr = at::tensor({0LL, nnzB}, at::device(device).dtype(at::kLong));
|
||||
row_ptr_B = at::empty({K + 1}, at::device(device).dtype(at::kLong));
|
||||
build_row_ptr_per_batch_mps(B_k, batch_ptr, /*B=*/1, /*I=*/K, row_ptr_B);
|
||||
}
|
||||
|
||||
auto row_ptr_B_lo = row_ptr_B.narrow(0, 0, K);
|
||||
auto row_ptr_B_hi = row_ptr_B.narrow(0, 1, K);
|
||||
auto deg_B = row_ptr_B_hi.sub(row_ptr_B_lo);
|
||||
|
||||
auto counts = deg_B.index_select(0, A_k);
|
||||
|
||||
const int64_t P = counts.sum().item<int64_t>();
|
||||
if (P == 0) {
|
||||
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
|
||||
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
|
||||
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
|
||||
out._coalesced_(true);
|
||||
return out;
|
||||
}
|
||||
|
||||
auto group_ids = repeat_interleave_mps(counts);
|
||||
|
||||
// exclusive cumsum of counts
|
||||
auto offsets = cumsum(counts, /*dim=*/0).sub(counts);
|
||||
auto offsets_gather = offsets.index_select(0, group_ids);
|
||||
auto within = at::arange(P, at::device(device).dtype(at::kLong)).sub(offsets_gather);
|
||||
|
||||
// Map each output element to its source B row and position
|
||||
auto k_per_out = A_k.index_select(0, group_ids);
|
||||
auto start_in_B = row_ptr_B.index_select(0, k_per_out);
|
||||
auto seg_index = start_in_B.add(within);
|
||||
|
||||
// Assemble candidate coo pairs and values
|
||||
auto i_out = A_i.index_select(0, group_ids).contiguous();
|
||||
auto j_out = B_j.index_select(0, seg_index).contiguous();
|
||||
auto vA_out = A_val.index_select(0, group_ids).contiguous();
|
||||
auto vB_out = B_val.index_select(0, seg_index).contiguous();
|
||||
auto v_out = vA_out.mul(vB_out);
|
||||
|
||||
// build (2, P) indices
|
||||
auto out_indices = at::empty({2, P}, at::device(device).dtype(at::kLong)).contiguous();
|
||||
out_indices.select(0, 0).copy_(i_out);
|
||||
out_indices.select(0, 1).copy_(j_out);
|
||||
|
||||
auto result = _sparse_coo_tensor_unsafe(
|
||||
out_indices, v_out, {I, N}, mat1_.options().dtype(computeDtype));
|
||||
|
||||
result = result.coalesce();
|
||||
|
||||
if (result.scalar_type() != mat1_.scalar_type()) {
|
||||
auto cast_vals = result._values().to(mat1_.scalar_type());
|
||||
auto out = _sparse_coo_tensor_unsafe(result._indices(), cast_vals, {I, N}, mat1_.options());
|
||||
out._coalesced_(true);
|
||||
return out;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
|
||||
} // namespace at::native
|
||||
@ -52,18 +52,19 @@ def test_sparse_coo_and_csr(m, n, k, nnz, test_count):
|
||||
start.record()
|
||||
coo.matmul(mat)
|
||||
stop.record()
|
||||
|
||||
times.append(start.elapsed_time(stop))
|
||||
|
||||
coo_mean_time = sum(times) / len(times)
|
||||
coo_mean_time = sum(times) / len(times)
|
||||
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start.record()
|
||||
csr.matmul(mat)
|
||||
stop.record()
|
||||
times.append(start.elapsed_time(stop))
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start.record()
|
||||
csr.matmul(mat)
|
||||
stop.record()
|
||||
times.append(start.elapsed_time(stop))
|
||||
|
||||
csr_mean_time = sum(times) / len(times)
|
||||
csr_mean_time = sum(times) / len(times)
|
||||
|
||||
return coo_mean_time, csr_mean_time
|
||||
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/SafePyObject.h>
|
||||
#include <c10/macros/Export.h>
|
||||
#include <optional>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
@ -17,8 +15,7 @@ struct C10_API AutogradState {
|
||||
bool inference_mode,
|
||||
bool fw_grad_mode,
|
||||
bool multithreading_enabled)
|
||||
: graph_exec_group_(std::nullopt),
|
||||
grad_mode_(grad_mode),
|
||||
: grad_mode_(grad_mode),
|
||||
inference_mode_(inference_mode),
|
||||
fw_grad_mode_(fw_grad_mode),
|
||||
multithreading_enabled_(multithreading_enabled),
|
||||
@ -44,10 +41,6 @@ struct C10_API AutogradState {
|
||||
view_replay_enabled_ = view_replay_enabled;
|
||||
}
|
||||
|
||||
void set_graph_exec_group(std::optional<SafePyObject> group) {
|
||||
graph_exec_group_ = std::move(group);
|
||||
}
|
||||
|
||||
bool get_grad_mode() const {
|
||||
return grad_mode_;
|
||||
}
|
||||
@ -68,12 +61,7 @@ struct C10_API AutogradState {
|
||||
return view_replay_enabled_;
|
||||
}
|
||||
|
||||
const std::optional<SafePyObject>& get_graph_exec_group() const {
|
||||
return graph_exec_group_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::optional<SafePyObject> graph_exec_group_;
|
||||
bool grad_mode_ : 1;
|
||||
bool inference_mode_ : 1;
|
||||
bool fw_grad_mode_ : 1;
|
||||
|
||||
@ -96,10 +96,6 @@ struct C10_API DeviceAllocator : public c10::Allocator {
|
||||
|
||||
// Resets peak memory usage statistics for the specified device
|
||||
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
|
||||
|
||||
// Return the free memory size and total memory size in bytes for the
|
||||
// specified device.
|
||||
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
|
||||
};
|
||||
|
||||
// This function is used to get the DeviceAllocator for a specific device type
|
||||
|
||||
@ -345,13 +345,6 @@ class CUDAAllocator : public DeviceAllocator {
|
||||
c10::DeviceIndex device,
|
||||
std::shared_ptr<AllocatorState> pps) = 0;
|
||||
virtual std::string name() = 0;
|
||||
std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
|
||||
c10::DeviceGuard device_guard({at::kCUDA, device});
|
||||
size_t free = 0;
|
||||
size_t total = 0;
|
||||
C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
|
||||
return {free, total};
|
||||
}
|
||||
};
|
||||
|
||||
// Allocator object, statically initialized
|
||||
|
||||
@ -66,15 +66,6 @@ def define_targets(rules):
|
||||
],
|
||||
)
|
||||
|
||||
rules.cc_test(
|
||||
name = "util/nofatal_test",
|
||||
srcs = ["util/nofatal_test.cpp"],
|
||||
deps = [
|
||||
"//c10/util:base",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
rules.cc_test(
|
||||
name = "util/ssize_test",
|
||||
srcs = ["util/ssize_test.cpp"],
|
||||
|
||||
@ -1,53 +0,0 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
inline void expectThrowsEq(T&& fn, const char* expected_msg) {
|
||||
try {
|
||||
std::forward<T>(fn)();
|
||||
} catch (const c10::Error& e) {
|
||||
EXPECT_TRUE(
|
||||
std::string(e.what_without_backtrace()).find(expected_msg) !=
|
||||
std::string::npos);
|
||||
return;
|
||||
}
|
||||
ADD_FAILURE() << "Expected to throw exception with message \"" << expected_msg
|
||||
<< "\" but didn't throw";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(NofatalTest, TorchCheckComparisons) {
|
||||
// quick make sure that no-op works as expected
|
||||
TORCH_CHECK_EQ(1, 1) << "i am a silly message " << 1;
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_EQ(1, 2) << "i am a silly message " << 1; },
|
||||
"Check failed: 1 == 2 (1 vs. 2). i am a silly message 1");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_NE(2, 2); }, "Check failed: 2 != 2 (2 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_LT(2, 2); }, "Check failed: 2 < 2 (2 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_LE(3, 2); }, "Check failed: 3 <= 2 (3 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_GT(2, 2); }, "Check failed: 2 > 2 (2 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_GE(2, 3); }, "Check failed: 2 >= 3 (2 vs. 3).");
|
||||
expectThrowsEq(
|
||||
[]() {
|
||||
void* p = nullptr;
|
||||
TORCH_CHECK_NOTNULL(p);
|
||||
},
|
||||
"Check failed: 'p' must be non NULL.");
|
||||
|
||||
#if GTEST_HAS_DEATH_TEST
|
||||
#ifndef NDEBUG
|
||||
// if dbg build, DCHECK should result in deth
|
||||
EXPECT_DEATH(TORCH_DCHECK_EQ(1, 2), "Check failed");
|
||||
#else
|
||||
TORCH_DCHECK_EQ(1, 2); // no-op
|
||||
#endif
|
||||
#endif // GTEST_HAS_DEATH_TEST
|
||||
}
|
||||
@ -702,98 +702,6 @@ namespace c10::detail {
|
||||
#define TORCH_CHECK_ARG(cond, argN, ...) \
|
||||
TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__)
|
||||
|
||||
#ifndef FATAL_IF
|
||||
#ifdef C10_USE_GLOG
|
||||
#define FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::google::GLOG_FATAL) \
|
||||
.stream()
|
||||
#else
|
||||
#define FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL).stream()
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef NON_FATAL_IF
|
||||
#ifdef C10_USE_GLOG
|
||||
#define NON_FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger( \
|
||||
__FILE__, __LINE__, ::google::GLOG_FATAL, false) \
|
||||
.stream()
|
||||
#else
|
||||
#define NON_FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL, false) \
|
||||
.stream()
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Binary comparison check macros
|
||||
#define TORCH_CHECK_OP(val1, val2, op) \
|
||||
NON_FATAL_IF(((val1)op(val2))) \
|
||||
<< "Check failed: " #val1 " " #op " " #val2 " (" << (val1) << " vs. " \
|
||||
<< (val2) << "). "
|
||||
|
||||
#define TORCH_DCHECK_OP(val1, val2, op) \
|
||||
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
|
||||
<< (val1) << " vs. " << (val2) << "). "
|
||||
|
||||
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
|
||||
|
||||
// Debug versions of TORCH_CHECK_OP macros
|
||||
#ifndef NDEBUG
|
||||
#define TORCH_DCHECK_EQ(val1, val2) TORCH_DCHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) TORCH_DCHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) TORCH_DCHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) TORCH_DCHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) TORCH_DCHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) TORCH_DCHECK_OP(val1, val2, >)
|
||||
#else // !NDEBUG
|
||||
// Optimized versions - generate no code
|
||||
#define TORCH_DCHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, >)
|
||||
#endif // NDEBUG
|
||||
|
||||
// Null pointer check macro
|
||||
#define TORCH_CHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), false)
|
||||
|
||||
#ifndef NDEBUG
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), true)
|
||||
#else // !NDEBUG
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
while (false) \
|
||||
TORCH_CHECK_NOTNULL(val)
|
||||
#endif // NDEBUG
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Deprecated macros
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
@ -291,32 +291,6 @@ namespace c10 {
|
||||
using fLB::FLAGS_logtostderr;
|
||||
using fLI::FLAGS_minloglevel;
|
||||
using fLI::FLAGS_v;
|
||||
|
||||
MessageLogger::MessageLogger(
|
||||
const char* file,
|
||||
int line,
|
||||
int severity,
|
||||
bool exit_on_fatal)
|
||||
: stream_(), severity_(severity), exit_on_fatal_(exit_on_fatal) {}
|
||||
|
||||
MessageLogger::~MessageLogger() noexcept(false) {
|
||||
if (severity_ == ::google::GLOG_FATAL) {
|
||||
DealWithFatal();
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream& MessageLogger::stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
void MessageLogger::DealWithFatal() {
|
||||
if (exit_on_fatal_) {
|
||||
LOG(FATAL) << stream_.str();
|
||||
} else {
|
||||
throw c10::Error(stream_.str(), nullptr, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DEFINE_int(
|
||||
@ -438,16 +412,17 @@ void ShowLogInfoToStderr() {
|
||||
FLAGS_caffe2_log_level = GLOG_INFO;
|
||||
}
|
||||
|
||||
MessageLogger::MessageLogger(
|
||||
const char* file,
|
||||
int line,
|
||||
int severity,
|
||||
bool exit_on_fatal)
|
||||
: severity_(severity), exit_on_fatal_(exit_on_fatal) {
|
||||
MessageLogger::MessageLogger(const char* file, int line, int severity)
|
||||
: severity_(severity) {
|
||||
if (severity_ < FLAGS_caffe2_log_level) {
|
||||
// Nothing needs to be logged.
|
||||
return;
|
||||
}
|
||||
#ifdef ANDROID
|
||||
tag_ = "native";
|
||||
#else // !ANDROID
|
||||
tag_ = "";
|
||||
#endif // ANDROID
|
||||
|
||||
time_t rawtime = 0;
|
||||
time(&rawtime);
|
||||
@ -483,7 +458,7 @@ MessageLogger::MessageLogger(
|
||||
}
|
||||
|
||||
// Output the contents of the stream to the proper channel on destruction.
|
||||
MessageLogger::~MessageLogger() noexcept(false) {
|
||||
MessageLogger::~MessageLogger() {
|
||||
if (severity_ < FLAGS_caffe2_log_level) {
|
||||
// Nothing needs to be logged.
|
||||
return;
|
||||
@ -523,18 +498,6 @@ MessageLogger::~MessageLogger() noexcept(false) {
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream& MessageLogger::stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
void MessageLogger::DealWithFatal() {
|
||||
if (exit_on_fatal_) {
|
||||
abort();
|
||||
} else {
|
||||
throw c10::Error(stream_.str(), nullptr, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#endif // !C10_USE_GLOG
|
||||
|
||||
@ -1,74 +0,0 @@
|
||||
#ifndef C10_UTIL_LOGGING_COMMON_H_
|
||||
#define C10_UTIL_LOGGING_COMMON_H_
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// MessageLogger that throws exceptions instead of aborting (glog version)
|
||||
// or logs and may abort (non-glog version).
|
||||
class C10_API MessageLogger {
|
||||
public:
|
||||
MessageLogger(
|
||||
const char* file,
|
||||
int line,
|
||||
int severity,
|
||||
bool exit_on_fatal = true);
|
||||
~MessageLogger() noexcept(false);
|
||||
|
||||
// Return the stream associated with the logger object.
|
||||
std::stringstream& stream();
|
||||
|
||||
private:
|
||||
// When there is a fatal log, and fatal == true, we abort
|
||||
// otherwise, we throw.
|
||||
void DealWithFatal();
|
||||
|
||||
#if defined(ANDROID) && !defined(C10_USE_GLOG)
|
||||
const char* tag_{"native"};
|
||||
#endif
|
||||
std::stringstream stream_;
|
||||
int severity_;
|
||||
bool exit_on_fatal_;
|
||||
};
|
||||
|
||||
// This class is used to explicitly ignore values in the conditional
|
||||
// logging macros. This avoids compiler warnings like "value computed
|
||||
// is not used" and "statement has no effect".
|
||||
class C10_API LoggerVoidify {
|
||||
public:
|
||||
LoggerVoidify() = default;
|
||||
// This has to be an operator with a precedence lower than << but
|
||||
// higher than ?:
|
||||
void operator&(const std::ostream& s [[maybe_unused]]) {}
|
||||
};
|
||||
|
||||
// Forward declarations for CheckNotNull functions
|
||||
template <typename T>
|
||||
T& CheckNotNullCommon(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal = true);
|
||||
|
||||
template <typename T>
|
||||
T* CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T* t,
|
||||
bool fatal = true);
|
||||
|
||||
template <typename T>
|
||||
T& CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal = true);
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#endif // C10_UTIL_LOGGING_COMMON_H_
|
||||
@ -47,53 +47,57 @@ INSTANTIATE_FOR_CONTAINER(set)
|
||||
|
||||
#endif
|
||||
|
||||
#include <c10/util/logging_common.h>
|
||||
#include <glog/logging.h>
|
||||
|
||||
namespace c10 {
|
||||
// Additional macros on top of glog
|
||||
#define TORCH_CHECK_EQ(val1, val2) CHECK_EQ(val1, val2)
|
||||
#define TORCH_CHECK_NE(val1, val2) CHECK_NE(val1, val2)
|
||||
#define TORCH_CHECK_LE(val1, val2) CHECK_LE(val1, val2)
|
||||
#define TORCH_CHECK_LT(val1, val2) CHECK_LT(val1, val2)
|
||||
#define TORCH_CHECK_GE(val1, val2) CHECK_GE(val1, val2)
|
||||
#define TORCH_CHECK_GT(val1, val2) CHECK_GT(val1, val2)
|
||||
|
||||
[[noreturn]] void ThrowEnforceNotMet(
|
||||
const char* file,
|
||||
const int line,
|
||||
const char* condition,
|
||||
const std::string& msg,
|
||||
const void* caller);
|
||||
#ifndef NDEBUG
|
||||
#define TORCH_DCHECK_EQ(val1, val2) DCHECK_EQ(val1, val2)
|
||||
#define TORCH_DCHECK_NE(val1, val2) DCHECK_NE(val1, val2)
|
||||
#define TORCH_DCHECK_LE(val1, val2) DCHECK_LE(val1, val2)
|
||||
#define TORCH_DCHECK_LT(val1, val2) DCHECK_LT(val1, val2)
|
||||
#define TORCH_DCHECK_GE(val1, val2) DCHECK_GE(val1, val2)
|
||||
#define TORCH_DCHECK_GT(val1, val2) DCHECK_GT(val1, val2)
|
||||
#else // !NDEBUG
|
||||
// These versions generate no code in optimized mode.
|
||||
#define TORCH_DCHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_EQ(val1, val2)
|
||||
#define TORCH_DCHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_NE(val1, val2)
|
||||
#define TORCH_DCHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_LE(val1, val2)
|
||||
#define TORCH_DCHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_LT(val1, val2)
|
||||
#define TORCH_DCHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_GE(val1, val2)
|
||||
#define TORCH_DCHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_GT(val1, val2)
|
||||
#endif // NDEBUG
|
||||
|
||||
template <typename T>
|
||||
T& CheckNotNullCommon(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
if (t == nullptr) {
|
||||
MessageLogger(file, line, ::google::GLOG_FATAL, fatal).stream()
|
||||
<< "Check failed: '" << names << "' must be non NULL. ";
|
||||
}
|
||||
return t;
|
||||
}
|
||||
// Check that a pointer is not null.
|
||||
#define TORCH_CHECK_NOTNULL(val) CHECK_NOTNULL(val)
|
||||
|
||||
template <typename T>
|
||||
T* CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T* t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T& CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
#ifndef NDEBUG
|
||||
// Debug only version of TORCH_CHECK_NOTNULL
|
||||
#define TORCH_DCHECK_NOTNULL(val) DCHECK_NOTNULL(val)
|
||||
#else // !NDEBUG
|
||||
// Optimized version - generates no code.
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
while (false) \
|
||||
DCHECK_NOTNULL(val)
|
||||
#endif // NDEBUG
|
||||
|
||||
// Log with source location information override (to be used in generic
|
||||
// warning/error handlers implemented as functions, not macros)
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
#include <vector>
|
||||
|
||||
#include <c10/util/Flags.h>
|
||||
#include <c10/util/logging_common.h>
|
||||
|
||||
const char CAFFE2_SEVERITY_PREFIX[] = "FEWIV";
|
||||
|
||||
@ -25,40 +24,61 @@ const int GLOG_ERROR = 2;
|
||||
const int GLOG_WARNING = 1;
|
||||
const int GLOG_INFO = 0;
|
||||
|
||||
class C10_API MessageLogger {
|
||||
public:
|
||||
MessageLogger(const char* file, int line, int severity);
|
||||
~MessageLogger();
|
||||
// Return the stream associated with the logger object.
|
||||
std::stringstream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
private:
|
||||
// When there is a fatal log, we simply abort.
|
||||
void DealWithFatal() {
|
||||
abort();
|
||||
}
|
||||
|
||||
const char* tag_;
|
||||
std::stringstream stream_;
|
||||
int severity_;
|
||||
};
|
||||
|
||||
// This class is used to explicitly ignore values in the conditional
|
||||
// logging macros. This avoids compiler warnings like "value computed
|
||||
// is not used" and "statement has no effect".
|
||||
class C10_API LoggerVoidify {
|
||||
public:
|
||||
LoggerVoidify() = default;
|
||||
// This has to be an operator with a precedence lower than << but
|
||||
// higher than ?:
|
||||
void operator&(const std::ostream& s [[maybe_unused]]) {}
|
||||
};
|
||||
|
||||
// Log a message and terminate.
|
||||
template <class T>
|
||||
void LogMessageFatal(const char* file, int line, const T& message) {
|
||||
MessageLogger(file, line, GLOG_FATAL).stream() << message;
|
||||
}
|
||||
|
||||
// Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw
|
||||
// pointers and smart pointers.
|
||||
template <typename T>
|
||||
T& CheckNotNullCommon(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
T& CheckNotNullCommon(const char* file, int line, const char* names, T& t) {
|
||||
if (t == nullptr) {
|
||||
MessageLogger(file, line, GLOG_FATAL, fatal).stream()
|
||||
<< "Check failed: '" << names << "' must be non NULL. ";
|
||||
LogMessageFatal(file, line, std::string(names));
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T* t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
T* CheckNotNull(const char* file, int line, const char* names, T* t) {
|
||||
return CheckNotNullCommon(file, line, names, t);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T& CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
T& CheckNotNull(const char* file, int line, const char* names, T& t) {
|
||||
return CheckNotNullCommon(file, line, names, t);
|
||||
}
|
||||
} // namespace c10
|
||||
|
||||
@ -116,6 +136,65 @@ static_assert(
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream()
|
||||
#endif // NDEBUG
|
||||
|
||||
#define TORCH_CHECK_OP(val1, val2, op) \
|
||||
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
|
||||
<< (val1) << " vs. " << (val2) << ") "
|
||||
|
||||
// TORCH_CHECK_OP macro definitions
|
||||
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Debug only versions of TORCH_CHECK_OP macros.
|
||||
#define TORCH_DCHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
|
||||
#else // !NDEBUG
|
||||
// These versions generate no code in optimized mode.
|
||||
#define TORCH_DCHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, >)
|
||||
#endif // NDEBUG
|
||||
|
||||
// Check that a pointer is not null.
|
||||
#define TORCH_CHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull( \
|
||||
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Debug only version of TORCH_CHECK_NOTNULL
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull( \
|
||||
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
|
||||
#else // !NDEBUG
|
||||
// Optimized version - generates no code.
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
while (false) \
|
||||
TORCH_CHECK_NOTNULL(val)
|
||||
#endif // NDEBUG
|
||||
|
||||
// ---------------------- Support for std objects --------------------------
|
||||
// These are adapted from glog to support a limited set of logging capability
|
||||
// for STL objects.
|
||||
|
||||
@ -926,14 +926,15 @@ class DeviceCachingAllocator {
|
||||
(release_cached_blocks() && alloc_block(params, true));
|
||||
}
|
||||
if (!block_found) {
|
||||
const auto& raw_device = c10::xpu::get_raw_device(device);
|
||||
const auto device_total =
|
||||
raw_device.get_info<sycl::info::device::global_mem_size>();
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device);
|
||||
auto device_total = device_prop.global_mem_size;
|
||||
// Estimate the available device memory when the SYCL runtime does not
|
||||
// support the corresponding aspect (ext_intel_free_memory).
|
||||
size_t device_free = device_total -
|
||||
size_t device_free = device_prop.global_mem_size -
|
||||
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
|
||||
.current;
|
||||
auto& raw_device = c10::xpu::get_raw_device(device);
|
||||
// TODO: Remove the aspect check once the SYCL runtime bug is fixed on
|
||||
// affected devices.
|
||||
if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
|
||||
@ -1051,37 +1052,21 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> getMemoryInfo() {
|
||||
const auto& device = c10::xpu::get_raw_device(device_index);
|
||||
const size_t total = device.get_info<sycl::info::device::global_mem_size>();
|
||||
TORCH_CHECK(
|
||||
device.has(sycl::aspect::ext_intel_free_memory),
|
||||
"The device (",
|
||||
device.get_info<sycl::info::device::name>(),
|
||||
") doesn't support querying the available free memory. ",
|
||||
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
|
||||
"to help us prioritize its implementation.");
|
||||
const size_t free =
|
||||
device.get_info<sycl::ext::intel::info::device::free_memory>();
|
||||
return {free, total};
|
||||
}
|
||||
|
||||
double getMemoryFraction() {
|
||||
if (!set_fraction) {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
const auto device_total =
|
||||
xpu::get_raw_device(device_index)
|
||||
.get_info<sycl::info::device::global_mem_size>();
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
return static_cast<double>(allowed_memory_maximum) /
|
||||
static_cast<double>(device_total);
|
||||
static_cast<double>(device_prop.global_mem_size);
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction) {
|
||||
const auto device_total =
|
||||
xpu::get_raw_device(device_index)
|
||||
.get_info<sycl::info::device::global_mem_size>();
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
auto device_total = device_prop.global_mem_size;
|
||||
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
|
||||
set_fraction = true;
|
||||
}
|
||||
@ -1255,11 +1240,6 @@ class XPUAllocator : public DeviceAllocator {
|
||||
c10::xpu::get_raw_device(dev_to_access));
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryInfo();
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryFraction();
|
||||
|
||||
@ -40,7 +40,6 @@
|
||||
:nosignatures:
|
||||
|
||||
empty_cache
|
||||
get_memory_info
|
||||
max_memory_allocated
|
||||
max_memory_reserved
|
||||
memory_allocated
|
||||
|
||||
@ -382,6 +382,20 @@ coverage_ignore_functions = [
|
||||
# torch.ao.quantization.backend_config.tensorrt
|
||||
"get_tensorrt_backend_config",
|
||||
"get_tensorrt_backend_config_dict",
|
||||
# torch.ao.quantization.backend_config.utils
|
||||
"entry_to_pretty_str",
|
||||
"get_fused_module_classes",
|
||||
"get_fuser_method_mapping",
|
||||
"get_fusion_pattern_to_extra_inputs_getter",
|
||||
"get_fusion_pattern_to_root_node_getter",
|
||||
"get_module_to_qat_module",
|
||||
"get_pattern_to_dtype_configs",
|
||||
"get_pattern_to_input_type_to_index",
|
||||
"get_qat_module_classes",
|
||||
"get_root_module_to_quantized_reference_module",
|
||||
"pattern_to_human_readable",
|
||||
"remove_boolean_dispatch_from_name",
|
||||
# torch.ao.quantization.backend_config.x86
|
||||
"get_x86_backend_config",
|
||||
# torch.ao.quantization.fuse_modules
|
||||
"fuse_known_modules",
|
||||
@ -412,6 +426,25 @@ coverage_ignore_functions = [
|
||||
"insert_observers_for_model",
|
||||
"prepare",
|
||||
"propagate_dtypes_for_known_nodes",
|
||||
# torch.ao.quantization.fx.utils
|
||||
"all_node_args_except_first",
|
||||
"all_node_args_have_no_tensors",
|
||||
"assert_and_get_unique_device",
|
||||
"collect_producer_nodes",
|
||||
"create_getattr_from_value",
|
||||
"create_node_from_old_node_preserve_meta",
|
||||
"get_custom_module_class_keys",
|
||||
"get_linear_prepack_op_for_dtype",
|
||||
"get_new_attr_name_with_prefix",
|
||||
"get_non_observable_arg_indexes_and_types",
|
||||
"get_qconv_prepack_op",
|
||||
"get_skipped_module_name_and_classes",
|
||||
"graph_module_from_producer_nodes",
|
||||
"maybe_get_next_module",
|
||||
"node_arg_is_bias",
|
||||
"node_arg_is_weight",
|
||||
"return_arg_list",
|
||||
# torch.ao.quantization.pt2e.graph_utils
|
||||
"bfs_trace_with_node_process",
|
||||
"find_sequential_partitions",
|
||||
"get_equivalent_types",
|
||||
@ -827,10 +860,80 @@ coverage_ignore_functions = [
|
||||
"get_latency_of_one_partition",
|
||||
"get_latency_of_partitioned_graph",
|
||||
"get_partition_to_latency_mapping",
|
||||
# torch.fx.experimental.proxy_tensor
|
||||
"decompose",
|
||||
"disable_autocast_cache",
|
||||
"disable_proxy_modes_tracing",
|
||||
"dispatch_trace",
|
||||
"extract_val",
|
||||
"fake_signature",
|
||||
"fetch_sym_proxy",
|
||||
"fetch_object_proxy",
|
||||
"get_innermost_proxy_mode",
|
||||
"get_isolated_graphmodule",
|
||||
"get_proxy_slot",
|
||||
"get_torch_dispatch_modes",
|
||||
"has_proxy_slot",
|
||||
"is_sym_node",
|
||||
"maybe_handle_decomp",
|
||||
"proxy_call",
|
||||
"set_meta",
|
||||
"set_original_aten_op",
|
||||
"set_proxy_slot",
|
||||
"snapshot_fake",
|
||||
"thunkify",
|
||||
"track_tensor",
|
||||
"track_tensor_tree",
|
||||
"wrap_key",
|
||||
"wrapper_and_args_for_make_fx",
|
||||
# torch.fx.experimental.recording
|
||||
"record_shapeenv_event",
|
||||
"replay_shape_env_events",
|
||||
"shape_env_check_state_equal",
|
||||
# torch.fx.experimental.sym_node
|
||||
"ceil_impl",
|
||||
"floor_ceil_helper",
|
||||
"floor_impl",
|
||||
"method_to_operator",
|
||||
"sympy_is_channels_last_contiguous_2d",
|
||||
"sympy_is_channels_last_contiguous_3d",
|
||||
"sympy_is_channels_last_strides_2d",
|
||||
"sympy_is_channels_last_strides_3d",
|
||||
"sympy_is_channels_last_strides_generic",
|
||||
"sympy_is_contiguous",
|
||||
"sympy_is_contiguous_generic",
|
||||
"to_node",
|
||||
"wrap_node",
|
||||
"sym_sqrt",
|
||||
# torch.fx.experimental.symbolic_shapes
|
||||
"bind_symbols",
|
||||
"cast_symbool_to_symint_guardless",
|
||||
"create_contiguous",
|
||||
"error",
|
||||
"eval_guards",
|
||||
"eval_is_non_overlapping_and_dense",
|
||||
"expect_true",
|
||||
"find_symbol_binding_fx_nodes",
|
||||
"free_symbols",
|
||||
"free_unbacked_symbols",
|
||||
"fx_placeholder_targets",
|
||||
"fx_placeholder_vals",
|
||||
"guard_bool",
|
||||
"guard_float",
|
||||
"guard_int",
|
||||
"guard_scalar",
|
||||
"has_hint",
|
||||
"has_symbolic_sizes_strides",
|
||||
"is_channels_last_contiguous_2d",
|
||||
"is_channels_last_contiguous_3d",
|
||||
"is_channels_last_strides_2d",
|
||||
"is_channels_last_strides_3d",
|
||||
"is_contiguous",
|
||||
"is_non_overlapping_and_dense_indicator",
|
||||
"is_nested_int",
|
||||
"is_symbol_binding_fx_node",
|
||||
"is_symbolic",
|
||||
# torch.fx.experimental.unification.core
|
||||
"reify",
|
||||
# torch.fx.experimental.unification.match
|
||||
"edge",
|
||||
@ -868,6 +971,24 @@ coverage_ignore_functions = [
|
||||
"reverse_dict",
|
||||
# torch.fx.experimental.unification.multipledispatch.variadic
|
||||
"isvariadic",
|
||||
# torch.fx.experimental.unification.unification_tools
|
||||
"assoc",
|
||||
"assoc_in",
|
||||
"dissoc",
|
||||
"first",
|
||||
"get_in",
|
||||
"getter",
|
||||
"groupby",
|
||||
"itemfilter",
|
||||
"itemmap",
|
||||
"keyfilter",
|
||||
"keymap",
|
||||
"merge",
|
||||
"merge_with",
|
||||
"update_in",
|
||||
"valfilter",
|
||||
"valmap",
|
||||
# torch.fx.experimental.unification.utils
|
||||
"freeze",
|
||||
"hashable",
|
||||
"raises",
|
||||
@ -1308,8 +1429,319 @@ coverage_ignore_functions = [
|
||||
# torch.onnx.symbolic_opset7
|
||||
"max",
|
||||
"min",
|
||||
# torch.onnx.symbolic_opset8
|
||||
"addmm",
|
||||
"bmm",
|
||||
"empty",
|
||||
"empty_like",
|
||||
"flatten",
|
||||
"full",
|
||||
"full_like",
|
||||
"gt",
|
||||
"lt",
|
||||
"matmul",
|
||||
"mm",
|
||||
"ones",
|
||||
"ones_like",
|
||||
"prelu",
|
||||
"repeat",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
# torch.onnx.symbolic_opset9
|
||||
"abs",
|
||||
"acos",
|
||||
"adaptive_avg_pool1d",
|
||||
"adaptive_avg_pool2d",
|
||||
"adaptive_avg_pool3d",
|
||||
"adaptive_max_pool1d",
|
||||
"adaptive_max_pool2d",
|
||||
"adaptive_max_pool3d",
|
||||
"add",
|
||||
"addcmul",
|
||||
"addmm",
|
||||
"alias",
|
||||
"amax",
|
||||
"amin",
|
||||
"aminmax",
|
||||
"arange",
|
||||
"argmax",
|
||||
"argmin",
|
||||
"as_strided",
|
||||
"as_tensor",
|
||||
"asin",
|
||||
"atan",
|
||||
"atan2",
|
||||
"avg_pool1d",
|
||||
"avg_pool2d",
|
||||
"avg_pool3d",
|
||||
"baddbmm",
|
||||
"batch_norm",
|
||||
"bernoulli",
|
||||
"bitwise_not",
|
||||
"bitwise_or",
|
||||
"bmm",
|
||||
"broadcast_tensors",
|
||||
"broadcast_to",
|
||||
"bucketize",
|
||||
"cat",
|
||||
"cdist",
|
||||
"ceil",
|
||||
"clamp",
|
||||
"clamp_max",
|
||||
"clamp_min",
|
||||
"clone",
|
||||
"constant_pad_nd",
|
||||
"contiguous",
|
||||
"conv1d",
|
||||
"conv2d",
|
||||
"conv3d",
|
||||
"conv_tbc",
|
||||
"conv_transpose1d",
|
||||
"conv_transpose2d",
|
||||
"conv_transpose3d",
|
||||
"convert_element_type",
|
||||
"convolution",
|
||||
"cos",
|
||||
"cosine_similarity",
|
||||
"cross",
|
||||
"cumsum",
|
||||
"detach",
|
||||
"dim",
|
||||
"div",
|
||||
"dot",
|
||||
"dropout",
|
||||
"elu",
|
||||
"embedding",
|
||||
"embedding_bag",
|
||||
"empty",
|
||||
"empty_like",
|
||||
"eq",
|
||||
"erf",
|
||||
"exp",
|
||||
"expand",
|
||||
"expand_as",
|
||||
"eye",
|
||||
"fill",
|
||||
"flatten",
|
||||
"floor",
|
||||
"floor_divide",
|
||||
"floordiv",
|
||||
"frobenius_norm",
|
||||
"full",
|
||||
"full_like",
|
||||
"gather",
|
||||
"ge",
|
||||
"gelu",
|
||||
"get_pool_ceil_padding",
|
||||
"glu",
|
||||
"group_norm",
|
||||
"gru",
|
||||
"gt",
|
||||
"hann_window",
|
||||
"hardshrink",
|
||||
"hardsigmoid",
|
||||
"hardswish",
|
||||
"hardtanh",
|
||||
"index",
|
||||
"index_add",
|
||||
"index_copy",
|
||||
"index_fill",
|
||||
"index_put",
|
||||
"index_select",
|
||||
"instance_norm",
|
||||
"is_floating_point",
|
||||
"is_pinned",
|
||||
"isnan",
|
||||
"item",
|
||||
"kl_div",
|
||||
"layer_norm",
|
||||
"le",
|
||||
"leaky_relu",
|
||||
"lerp",
|
||||
"lift",
|
||||
"linalg_cross",
|
||||
"linalg_matrix_norm",
|
||||
"linalg_norm",
|
||||
"linalg_vector_norm",
|
||||
"linear",
|
||||
"linspace",
|
||||
"log",
|
||||
"log10",
|
||||
"log1p",
|
||||
"log2",
|
||||
"log_sigmoid",
|
||||
"log_softmax",
|
||||
"logical_and",
|
||||
"logical_not",
|
||||
"logical_or",
|
||||
"logical_xor",
|
||||
"logit",
|
||||
"logsumexp",
|
||||
"lstm",
|
||||
"lstm_cell",
|
||||
"lt",
|
||||
"masked_fill",
|
||||
"masked_fill_",
|
||||
"matmul",
|
||||
"max",
|
||||
"max_pool1d",
|
||||
"max_pool1d_with_indices",
|
||||
"max_pool2d",
|
||||
"max_pool2d_with_indices",
|
||||
"max_pool3d",
|
||||
"max_pool3d_with_indices",
|
||||
"maximum",
|
||||
"meshgrid",
|
||||
"min",
|
||||
"minimum",
|
||||
"mish",
|
||||
"mm",
|
||||
"movedim",
|
||||
"mse_loss",
|
||||
"mul",
|
||||
"multinomial",
|
||||
"mv",
|
||||
"narrow",
|
||||
"native_layer_norm",
|
||||
"ne",
|
||||
"neg",
|
||||
"new_empty",
|
||||
"new_full",
|
||||
"new_ones",
|
||||
"new_zeros",
|
||||
"nonzero",
|
||||
"nonzero_numpy",
|
||||
"noop_complex_operators",
|
||||
"norm",
|
||||
"numel",
|
||||
"numpy_T",
|
||||
"one_hot",
|
||||
"ones",
|
||||
"ones_like",
|
||||
"onnx_placeholder",
|
||||
"overload_by_arg_count",
|
||||
"pad",
|
||||
"pairwise_distance",
|
||||
"permute",
|
||||
"pixel_shuffle",
|
||||
"pixel_unshuffle",
|
||||
"pow",
|
||||
"prelu",
|
||||
"prim_constant",
|
||||
"prim_constant_chunk",
|
||||
"prim_constant_split",
|
||||
"prim_data",
|
||||
"prim_device",
|
||||
"prim_dtype",
|
||||
"prim_if",
|
||||
"prim_layout",
|
||||
"prim_list_construct",
|
||||
"prim_list_unpack",
|
||||
"prim_loop",
|
||||
"prim_max",
|
||||
"prim_min",
|
||||
"prim_shape",
|
||||
"prim_tolist",
|
||||
"prim_tuple_construct",
|
||||
"prim_type",
|
||||
"prim_unchecked_cast",
|
||||
"prim_uninitialized",
|
||||
"rand",
|
||||
"rand_like",
|
||||
"randint",
|
||||
"randint_like",
|
||||
"randn",
|
||||
"randn_like",
|
||||
"reciprocal",
|
||||
"reflection_pad",
|
||||
"relu",
|
||||
"relu6",
|
||||
"remainder",
|
||||
"repeat",
|
||||
"repeat_interleave",
|
||||
"replication_pad",
|
||||
"reshape",
|
||||
"reshape_as",
|
||||
"rnn_relu",
|
||||
"rnn_tanh",
|
||||
"roll",
|
||||
"rrelu",
|
||||
"rsqrt",
|
||||
"rsub",
|
||||
"scalar_tensor",
|
||||
"scatter",
|
||||
"scatter_add",
|
||||
"select",
|
||||
"selu",
|
||||
"sigmoid",
|
||||
"sign",
|
||||
"silu",
|
||||
"sin",
|
||||
"size",
|
||||
"slice",
|
||||
"softmax",
|
||||
"softplus",
|
||||
"softshrink",
|
||||
"sort",
|
||||
"split",
|
||||
"split_with_sizes",
|
||||
"sqrt",
|
||||
"square",
|
||||
"squeeze",
|
||||
"stack",
|
||||
"std",
|
||||
"std_mean",
|
||||
"sub",
|
||||
"t",
|
||||
"take",
|
||||
"tan",
|
||||
"tanh",
|
||||
"tanhshrink",
|
||||
"tensor",
|
||||
"threshold",
|
||||
"to",
|
||||
"topk",
|
||||
"transpose",
|
||||
"true_divide",
|
||||
"type_as",
|
||||
"unbind",
|
||||
"unfold",
|
||||
"unsafe_chunk",
|
||||
"unsafe_split",
|
||||
"unsafe_split_with_sizes",
|
||||
"unsqueeze",
|
||||
"unsupported_complex_operators",
|
||||
"unused",
|
||||
"upsample_bilinear2d",
|
||||
"upsample_linear1d",
|
||||
"upsample_nearest1d",
|
||||
"upsample_nearest2d",
|
||||
"upsample_nearest3d",
|
||||
"upsample_trilinear3d",
|
||||
"var",
|
||||
"var_mean",
|
||||
"view",
|
||||
"view_as",
|
||||
"where",
|
||||
"wrap_logical_op_with_cast_to",
|
||||
"wrap_logical_op_with_negation",
|
||||
"zero",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
# torch.onnx.utils
|
||||
"disable_apex_o2_state_dict_hook",
|
||||
"export",
|
||||
"export_to_pretty_string",
|
||||
"exporter_context",
|
||||
"is_in_onnx_export",
|
||||
"model_signature",
|
||||
"register_custom_op_symbolic",
|
||||
"select_model_mode_for_export",
|
||||
"setup_onnx_logging",
|
||||
"unconvertible_ops",
|
||||
"unpack_quantized_tensor",
|
||||
"warn_on_static_input_change",
|
||||
# torch.onnx.verification
|
||||
"check_export_model_diff",
|
||||
"verify",
|
||||
"verify_aten_graph",
|
||||
@ -1400,6 +1832,32 @@ coverage_ignore_functions = [
|
||||
"noop_context_fn",
|
||||
"set_checkpoint_early_stop",
|
||||
"set_device_states",
|
||||
# torch.utils.collect_env
|
||||
"check_release_file",
|
||||
"get_cachingallocator_config",
|
||||
"get_clang_version",
|
||||
"get_cmake_version",
|
||||
"get_conda_packages",
|
||||
"get_cpu_info",
|
||||
"get_cuda_module_loading_config",
|
||||
"get_cudnn_version",
|
||||
"get_env_info",
|
||||
"get_gcc_version",
|
||||
"get_gpu_info",
|
||||
"get_libc_version",
|
||||
"get_lsb_version",
|
||||
"get_mac_version",
|
||||
"get_nvidia_driver_version",
|
||||
"get_nvidia_smi",
|
||||
"get_os",
|
||||
"get_pip_packages",
|
||||
"get_platform",
|
||||
"get_pretty_env_info",
|
||||
"get_python_platform",
|
||||
"get_running_cuda_version",
|
||||
"get_windows_version",
|
||||
"is_xnnpack_available",
|
||||
"pretty_str",
|
||||
# torch.utils.cpp_backtrace
|
||||
"get_cpp_backtrace",
|
||||
# torch.utils.cpp_extension
|
||||
@ -1463,6 +1921,52 @@ coverage_ignore_functions = [
|
||||
"apply_shuffle_seed",
|
||||
"apply_shuffle_settings",
|
||||
"get_all_graph_pipes",
|
||||
# torch.utils.flop_counter
|
||||
"addmm_flop",
|
||||
"baddbmm_flop",
|
||||
"bmm_flop",
|
||||
"conv_backward_flop",
|
||||
"conv_flop",
|
||||
"conv_flop_count",
|
||||
"convert_num_with_suffix",
|
||||
"get_shape",
|
||||
"get_suffix_str",
|
||||
"mm_flop",
|
||||
"normalize_tuple",
|
||||
"register_flop_formula",
|
||||
"sdpa_backward_flop",
|
||||
"sdpa_backward_flop_count",
|
||||
"sdpa_flop",
|
||||
"sdpa_flop_count",
|
||||
"shape_wrapper",
|
||||
"transpose_shape",
|
||||
# torch.utils.hipify.hipify_python
|
||||
"add_dim3",
|
||||
"compute_stats",
|
||||
"extract_arguments",
|
||||
"file_add_header",
|
||||
"file_specific_replacement",
|
||||
"find_bracket_group",
|
||||
"find_closure_group",
|
||||
"find_parentheses_group",
|
||||
"fix_static_global_kernels",
|
||||
"get_hip_file_path",
|
||||
"hip_header_magic",
|
||||
"hipify",
|
||||
"is_caffe2_gpu_file",
|
||||
"is_cusparse_file",
|
||||
"is_out_of_place",
|
||||
"is_pytorch_file",
|
||||
"is_special_file",
|
||||
"match_extensions",
|
||||
"matched_files_iter",
|
||||
"openf",
|
||||
"preprocess_file_and_save_result",
|
||||
"preprocessor",
|
||||
"processKernelLaunches",
|
||||
"replace_extern_shared",
|
||||
"replace_math_functions",
|
||||
"str2bool",
|
||||
# torch.utils.hooks
|
||||
"unserializable_hook",
|
||||
"warn_if_has_hooks",
|
||||
|
||||
@ -12,37 +12,6 @@ These APIs are experimental and subject to change without notice.
|
||||
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
|
||||
```
|
||||
|
||||
## torch.fx.experimental.sym_node
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.fx.experimental.sym_node
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.fx.experimental.sym_node
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
is_channels_last_contiguous_2d
|
||||
is_channels_last_contiguous_3d
|
||||
is_channels_last_strides_2d
|
||||
is_channels_last_strides_3d
|
||||
is_contiguous
|
||||
is_non_overlapping_and_dense_indicator
|
||||
method_to_operator
|
||||
sympy_is_channels_last_contiguous_2d
|
||||
sympy_is_channels_last_contiguous_3d
|
||||
sympy_is_channels_last_strides_2d
|
||||
sympy_is_channels_last_strides_3d
|
||||
sympy_is_channels_last_strides_generic
|
||||
sympy_is_contiguous
|
||||
sympy_is_contiguous_generic
|
||||
```
|
||||
|
||||
## torch.fx.experimental.symbolic_shapes
|
||||
|
||||
```{eval-rst}
|
||||
@ -100,25 +69,6 @@ These APIs are experimental and subject to change without notice.
|
||||
rebind_unbacked
|
||||
resolve_unbacked_bindings
|
||||
is_accessor_node
|
||||
cast_symbool_to_symint_guardless
|
||||
create_contiguous
|
||||
error
|
||||
eval_guards
|
||||
eval_is_non_overlapping_and_dense
|
||||
find_symbol_binding_fx_nodes
|
||||
free_symbols
|
||||
free_unbacked_symbols
|
||||
fx_placeholder_targets
|
||||
fx_placeholder_vals
|
||||
guard_bool
|
||||
guard_float
|
||||
guard_int
|
||||
guard_scalar
|
||||
has_hint
|
||||
has_symbolic_sizes_strides
|
||||
is_nested_int
|
||||
is_symbol_binding_fx_node
|
||||
is_symbolic
|
||||
```
|
||||
|
||||
## torch.fx.experimental.proxy_tensor
|
||||
@ -141,46 +91,4 @@ These APIs are experimental and subject to change without notice.
|
||||
get_proxy_mode
|
||||
maybe_enable_thunkify
|
||||
maybe_disable_thunkify
|
||||
decompose
|
||||
disable_autocast_cache
|
||||
disable_proxy_modes_tracing
|
||||
extract_val
|
||||
fake_signature
|
||||
fetch_object_proxy
|
||||
fetch_sym_proxy
|
||||
has_proxy_slot
|
||||
is_sym_node
|
||||
maybe_handle_decomp
|
||||
proxy_call
|
||||
set_meta
|
||||
set_original_aten_op
|
||||
set_proxy_slot
|
||||
snapshot_fake
|
||||
```
|
||||
|
||||
## torch.fx.experimental.unification.unification_tools
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.fx.experimental.unification.unification_tools
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.fx.experimental.unification.unification_tools
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
assoc
|
||||
assoc_in
|
||||
dissoc
|
||||
first
|
||||
keyfilter
|
||||
keymap
|
||||
merge
|
||||
merge_with
|
||||
update_in
|
||||
valfilter
|
||||
valmap
|
||||
|
||||
@ -1134,6 +1134,7 @@ The set of leaf modules can be customized by overriding
|
||||
.. py:module:: torch.fx.experimental.refinement_types
|
||||
.. py:module:: torch.fx.experimental.rewriter
|
||||
.. py:module:: torch.fx.experimental.schema_type_annotation
|
||||
.. py:module:: torch.fx.experimental.sym_node
|
||||
.. py:module:: torch.fx.experimental.unification.core
|
||||
.. py:module:: torch.fx.experimental.unification.dispatch
|
||||
.. py:module:: torch.fx.experimental.unification.match
|
||||
@ -1143,6 +1144,7 @@ The set of leaf modules can be customized by overriding
|
||||
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
|
||||
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
|
||||
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
|
||||
.. py:module:: torch.fx.experimental.unification.unification_tools
|
||||
.. py:module:: torch.fx.experimental.unification.utils
|
||||
.. py:module:: torch.fx.experimental.unification.variable
|
||||
.. py:module:: torch.fx.experimental.unify_refinements
|
||||
|
||||
@ -134,23 +134,6 @@ Quantization to work with this as well.
|
||||
ObservationType
|
||||
```
|
||||
|
||||
## torch.ao.quantization.backend_config.utils
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.ao.quantization.backend_config.utils
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
entry_to_pretty_str
|
||||
pattern_to_human_readable
|
||||
remove_boolean_dispatch_from_name
|
||||
|
||||
```
|
||||
|
||||
## torch.ao.quantization.fx.custom_config
|
||||
|
||||
This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization
|
||||
@ -171,30 +154,6 @@ This module contains a few CustomConfig classes that's used in both eager mode a
|
||||
StandaloneModuleConfigEntry
|
||||
```
|
||||
|
||||
## torch.ao.quantization.fx.utils
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.ao.quantization.fx.utils
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
all_node_args_except_first
|
||||
all_node_args_have_no_tensors
|
||||
collect_producer_nodes
|
||||
create_getattr_from_value
|
||||
create_node_from_old_node_preserve_meta
|
||||
graph_module_from_producer_nodes
|
||||
maybe_get_next_module
|
||||
node_arg_is_bias
|
||||
node_arg_is_weight
|
||||
return_arg_list
|
||||
```
|
||||
|
||||
## torch.ao.quantization.quantizer
|
||||
|
||||
```{eval-rst}
|
||||
|
||||
@ -19,91 +19,6 @@
|
||||
swap_tensors
|
||||
```
|
||||
|
||||
# torch.utils.collect_env
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.utils.collect_env
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.utils.collect_env
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
check_release_file
|
||||
is_xnnpack_available
|
||||
pretty_str
|
||||
```
|
||||
|
||||
# torch.utils.flop_counter
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.utils.flop_counter
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.utils.flop_counter
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
baddbmm_flop
|
||||
bmm_flop
|
||||
conv_backward_flop
|
||||
conv_flop
|
||||
conv_flop_count
|
||||
register_flop_formula
|
||||
sdpa_backward_flop
|
||||
sdpa_backward_flop_count
|
||||
sdpa_flop
|
||||
sdpa_flop_count
|
||||
shape_wrapper
|
||||
```
|
||||
|
||||
# torch.utils.hipify.hipify_python
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.utils.hipify.hipify_python
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.utils.hipify.hipify_python
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
compute_stats
|
||||
extract_arguments
|
||||
file_add_header
|
||||
file_specific_replacement
|
||||
find_bracket_group
|
||||
find_closure_group
|
||||
find_parentheses_group
|
||||
fix_static_global_kernels
|
||||
hip_header_magic
|
||||
hipify
|
||||
is_caffe2_gpu_file
|
||||
is_cusparse_file
|
||||
is_out_of_place
|
||||
is_pytorch_file
|
||||
is_special_file
|
||||
openf
|
||||
preprocess_file_and_save_result
|
||||
preprocessor
|
||||
processKernelLaunches
|
||||
replace_extern_shared
|
||||
replace_math_functions
|
||||
str2bool
|
||||
```
|
||||
|
||||
|
||||
<!-- This module needs to be documented. Adding here in the meantime
|
||||
for tracking purposes -->
|
||||
```{eval-rst}
|
||||
@ -128,6 +43,7 @@ for tracking purposes -->
|
||||
.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
|
||||
.. py:module:: torch.utils.bundled_inputs
|
||||
.. py:module:: torch.utils.checkpoint
|
||||
.. py:module:: torch.utils.collect_env
|
||||
.. py:module:: torch.utils.cpp_backtrace
|
||||
.. py:module:: torch.utils.cpp_extension
|
||||
.. py:module:: torch.utils.data.backward_compatibility
|
||||
@ -164,8 +80,10 @@ for tracking purposes -->
|
||||
.. py:module:: torch.utils.data.sampler
|
||||
.. py:module:: torch.utils.dlpack
|
||||
.. py:module:: torch.utils.file_baton
|
||||
.. py:module:: torch.utils.flop_counter
|
||||
.. py:module:: torch.utils.hipify.constants
|
||||
.. py:module:: torch.utils.hipify.cuda_to_hip_mappings
|
||||
.. py:module:: torch.utils.hipify.hipify_python
|
||||
.. py:module:: torch.utils.hipify.version
|
||||
.. py:module:: torch.utils.hooks
|
||||
.. py:module:: torch.utils.jit.log_extract
|
||||
|
||||
@ -172,9 +172,9 @@ ignore = [
|
||||
"SIM102", "SIM103", "SIM112", # flake8-simplify code styles
|
||||
"SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason
|
||||
"SIM108", # SIM108 ignored because we prefer if-else-block instead of ternary expression
|
||||
"SIM110", # Checks for for loops that can be replaced with a builtin function, like any or all.
|
||||
"SIM110",
|
||||
"SIM114", # Combine `if` branches using logical `or` operator
|
||||
"SIM115", # Checks for cases where files are opened without using a context manager.
|
||||
"SIM115",
|
||||
"SIM116", # Disable Use a dictionary instead of consecutive `if` statements
|
||||
"SIM117",
|
||||
"SIM118",
|
||||
@ -184,6 +184,7 @@ ignore = [
|
||||
"TC006",
|
||||
# TODO: Remove Python-3.10 specific suppressions
|
||||
"B905",
|
||||
"UP035",
|
||||
]
|
||||
select = [
|
||||
"B",
|
||||
|
||||
3
setup.py
3
setup.py
@ -1646,7 +1646,8 @@ def main() -> None:
|
||||
mirror_files_into_torchgen()
|
||||
if RUN_BUILD_DEPS:
|
||||
build_deps()
|
||||
mirror_inductor_external_kernels()
|
||||
|
||||
mirror_inductor_external_kernels()
|
||||
|
||||
(
|
||||
ext_modules,
|
||||
|
||||
@ -208,7 +208,7 @@ class _BaseDataSparsiferTestCase(TestCase):
|
||||
assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups)
|
||||
|
||||
state1 = state_dict1["state"]
|
||||
for name in state1:
|
||||
for name in state1.keys():
|
||||
# compare mask
|
||||
assert name in sparsifier2.state
|
||||
assert "mask" in sparsifier2.state[name]
|
||||
|
||||
@ -119,7 +119,7 @@ class TestBaseSparsifier(TestCase):
|
||||
for idx in range(len(sparsifier0.groups)):
|
||||
mg0 = sparsifier0.groups[idx]
|
||||
mg1 = sparsifier1.groups[idx]
|
||||
for key in mg0:
|
||||
for key in mg0.keys():
|
||||
assert key in mg1
|
||||
if key == "module":
|
||||
# We cannot compare modules as they are different
|
||||
|
||||
@ -67,13 +67,13 @@ Tensor sgd_out_of_place(
|
||||
|
||||
void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = sgd_out_of_place(
|
||||
torch::stable::detail::to<Tensor>(stack[0]),
|
||||
torch::stable::detail::to<Tensor>(stack[1]),
|
||||
float(torch::stable::detail::to<double>(stack[2])),
|
||||
torch::stable::detail::to<double>(stack[3]),
|
||||
torch::stable::detail::to<bool>(stack[4]));
|
||||
to<Tensor>(stack[0]),
|
||||
to<Tensor>(stack[1]),
|
||||
float(to<double>(stack[2])),
|
||||
to<double>(stack[3]),
|
||||
to<bool>(stack[4]));
|
||||
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
|
||||
@ -89,8 +89,8 @@ Tensor identity(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = identity(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
Tensor res = identity(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -108,14 +108,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
|
||||
Tensor my_abs(Tensor t) {
|
||||
const auto num_args = 1;
|
||||
StableIValue stack[num_args];
|
||||
stack[0] = torch::stable::detail::from(t);
|
||||
stack[0] = from(t);
|
||||
aoti_torch_call_dispatcher("aten::abs", "", stack);
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
return to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_abs(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
Tensor tensor_res = my_abs(to<Tensor>(stack[0]));
|
||||
stack[0] = from(tensor_res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -132,21 +132,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
|
||||
|
||||
auto mf = aoti_torch_memory_format_contiguous_format();
|
||||
|
||||
stack[0] = torch::stable::detail::from(t);
|
||||
stack[1] = torch::stable::detail::from(std::optional(t.scalar_type())); // dtype
|
||||
stack[2] = torch::stable::detail::from(std::nullopt); // layout
|
||||
stack[3] = torch::stable::detail::from(std::optional(device)); // device
|
||||
stack[4] = torch::stable::detail::from(std::optional(false)); // pin_memory
|
||||
stack[5] = torch::stable::detail::from(std::optional(mf)); // memory_format
|
||||
stack[0] = from(t);
|
||||
stack[1] = from(std::optional(t.scalar_type())); // dtype
|
||||
stack[2] = from(std::nullopt); // layout
|
||||
stack[3] = from(std::optional(device)); // device
|
||||
stack[4] = from(std::optional(false)); // pin_memory
|
||||
stack[5] = from(std::optional(mf)); // memory_format
|
||||
|
||||
aoti_torch_call_dispatcher("aten::ones_like", "", stack);
|
||||
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
return to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = my_ones_like(torch::stable::detail::to<Tensor>(stack[0]), stack[1]);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -159,28 +159,28 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
|
||||
std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
|
||||
StableIValue stack_exp[1];
|
||||
stack_exp[0] = torch::stable::detail::from(t1);
|
||||
stack_exp[0] = from(t1);
|
||||
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
|
||||
|
||||
StableIValue stack_neg[1];
|
||||
stack_neg[0] = torch::stable::detail::from(t2);
|
||||
stack_neg[0] = from(t2);
|
||||
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
|
||||
|
||||
StableIValue stack_is_leaf[1];
|
||||
stack_is_leaf[0] = torch::stable::detail::from(t3);
|
||||
stack_is_leaf[0] = from(t3);
|
||||
aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);
|
||||
|
||||
return std::make_tuple(
|
||||
torch::stable::detail::to<Tensor>(stack_exp[0]),
|
||||
torch::stable::detail::to<Tensor>(stack_neg[0]),
|
||||
torch::stable::detail::to<bool>(stack_is_leaf[0]));
|
||||
to<Tensor>(stack_exp[0]),
|
||||
to<Tensor>(stack_neg[0]),
|
||||
to<bool>(stack_is_leaf[0]));
|
||||
}
|
||||
|
||||
void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto tuple = exp_neg_is_leaf(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<Tensor>(stack[2]));
|
||||
stack[0] = torch::stable::detail::from(std::get<0>(tuple));
|
||||
stack[1] = torch::stable::detail::from(std::get<1>(tuple));
|
||||
stack[2] = torch::stable::detail::from(std::get<2>(tuple));
|
||||
auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
|
||||
stack[0] = from(std::get<0>(tuple));
|
||||
stack[1] = from(std::get<1>(tuple));
|
||||
stack[2] = from(std::get<2>(tuple));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -193,15 +193,15 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
|
||||
Tensor neg_exp(Tensor t) {
|
||||
StableIValue stack[1];
|
||||
stack[0] = torch::stable::detail::from(t);
|
||||
stack[0] = from(t);
|
||||
aoti_torch_call_dispatcher("aten::exp", "", stack);
|
||||
aoti_torch_call_dispatcher("aten::neg", "", stack);
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
return to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
Tensor res = neg_exp(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -214,10 +214,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
|
||||
Tensor divide_neg_exp(Tensor t) {
|
||||
StableIValue stack_neg[1];
|
||||
stack_neg[0] = torch::stable::detail::from(t);
|
||||
stack_neg[0] = from(t);
|
||||
|
||||
StableIValue stack_exp[1];
|
||||
stack_exp[0] = torch::stable::detail::from(t);
|
||||
stack_exp[0] = from(t);
|
||||
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
|
||||
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
|
||||
|
||||
@ -225,12 +225,12 @@ Tensor divide_neg_exp(Tensor t) {
|
||||
stack_div[0] = stack_neg[0];
|
||||
stack_div[1] = stack_exp[0];
|
||||
aoti_torch_call_dispatcher("aten::divide", "Tensor", stack_div);
|
||||
return torch::stable::detail::to<Tensor>(stack_div[0]);
|
||||
return to<Tensor>(stack_div[0]);
|
||||
}
|
||||
|
||||
void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = divide_neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
Tensor res = divide_neg_exp(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -246,8 +246,8 @@ bool is_contiguous(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
bool res = is_contiguous(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
bool res = is_contiguous(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -263,9 +263,9 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
|
||||
}
|
||||
|
||||
void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_transpose(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<int64_t>(stack[1]), torch::stable::detail::to<int64_t>(stack[2]));
|
||||
auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
|
||||
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
Tensor my_empty_like(Tensor t) {
|
||||
@ -273,8 +273,8 @@ Tensor my_empty_like(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_empty_like(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
auto res = my_empty_like(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
bool my_is_cpu(Tensor t) {
|
||||
@ -283,8 +283,8 @@ bool my_is_cpu(Tensor t) {
|
||||
|
||||
|
||||
void boxed_my_is_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_is_cpu(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
auto res = my_is_cpu(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
Tensor fill_infinity(Tensor t) {
|
||||
@ -296,8 +296,8 @@ void boxed_fill_infinity(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = fill_infinity(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
auto res = fill_infinity(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
Tensor my_pad(Tensor t) {
|
||||
@ -310,8 +310,8 @@ void boxed_my_pad(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = my_pad(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
auto res = my_pad(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
|
||||
@ -323,11 +323,11 @@ void boxed_my_narrow(
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = my_narrow(
|
||||
torch::stable::detail::to<Tensor>(stack[0]),
|
||||
torch::stable::detail::to<int64_t>(stack[1]),
|
||||
torch::stable::detail::to<int64_t>(stack[2]),
|
||||
torch::stable::detail::to<int64_t>(stack[3]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
to<Tensor>(stack[0]),
|
||||
to<int64_t>(stack[1]),
|
||||
to<int64_t>(stack[2]),
|
||||
to<int64_t>(stack[3]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||
@ -342,8 +342,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_new_empty_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
auto res = my_new_empty_dtype_variant(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
Tensor my_new_zeros_dtype_variant(Tensor t) {
|
||||
@ -352,8 +352,8 @@ Tensor my_new_zeros_dtype_variant(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_new_zeros_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
auto res = my_new_zeros_dtype_variant(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
|
||||
@ -361,8 +361,8 @@ Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
|
||||
}
|
||||
|
||||
void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_copy_(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<bool>(stack[2]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
Tensor tensor_res = my_copy_(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<bool>(stack[2]));
|
||||
stack[0] = from(tensor_res);
|
||||
}
|
||||
|
||||
Tensor my_clone(Tensor t) {
|
||||
@ -370,8 +370,8 @@ Tensor my_clone(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_clone(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_clone(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
Tensor tensor_res = my_clone(to<Tensor>(stack[0]));
|
||||
stack[0] = from(tensor_res);
|
||||
}
|
||||
|
||||
|
||||
@ -408,8 +408,8 @@ Tensor my_zero_(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_zero_(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
auto res = my_zero_(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
Tensor my_amax(Tensor t) {
|
||||
@ -417,8 +417,8 @@ Tensor my_amax(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_amax(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
auto res = my_amax(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
Tensor my_amax_vec(Tensor t) {
|
||||
@ -426,8 +426,8 @@ Tensor my_amax_vec(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_amax_vec(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
auto res = my_amax_vec(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -464,8 +464,8 @@ void boxed_test_default_constructor(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
bool res = test_default_constructor(torch::stable::detail::to<bool>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
bool res = test_default_constructor(to<bool>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -478,56 +478,6 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_amax_vec", &boxed_my_amax_vec);
|
||||
}
|
||||
|
||||
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
|
||||
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
|
||||
aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
|
||||
return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
// Why is the following NOT torch::stable::detail::to<HeaderOnlyArrayRef<Tensor>>(stack[0])? Because calling `to`
|
||||
// on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef
|
||||
// is not owning, so it cannot safely steward the result of the torch::stable::detail::to<>.
|
||||
auto res = my__foreach_mul(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
|
||||
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
|
||||
aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
|
||||
}
|
||||
|
||||
void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
my__foreach_mul_(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
|
||||
}
|
||||
|
||||
std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
|
||||
// This function tests that my__foreach_mul can take in std::initializer_lists
|
||||
// in addition to std::vectors.
|
||||
Tensor t1_1 = my_clone(t1);
|
||||
Tensor t1_2 = my_clone(t1);
|
||||
Tensor t2_1 = my_clone(t2);
|
||||
Tensor t2_2 = my_clone(t2);
|
||||
return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
|
||||
}
|
||||
|
||||
void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = make_tensor_clones_and_call_foreach(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
|
||||
m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
|
||||
m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my__foreach_mul", &boxed_my__foreach_mul);
|
||||
m.impl("my__foreach_mul_", &boxed_my__foreach_mul_);
|
||||
m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach);
|
||||
}
|
||||
|
||||
// Test functions for torch::stable::accelerator APIs
|
||||
|
||||
#ifdef LAE_USE_CUDA
|
||||
@ -550,8 +500,8 @@ void boxed_test_device_guard(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int res = test_device_guard(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
int res = test_device_guard(static_cast<int64_t>(to<int64_t>(stack[0])));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
int64_t test_device_guard_set_index() {
|
||||
@ -570,7 +520,7 @@ void boxed_test_device_guard_set_index(
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_device_guard_set_index();
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
int64_t test_stream(int32_t device_index) {
|
||||
@ -586,8 +536,8 @@ void boxed_test_stream(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_stream(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
int64_t res = test_stream(static_cast<int64_t>(to<int64_t>(stack[0])));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
int64_t test_get_current_device_index() {
|
||||
@ -599,7 +549,7 @@ void boxed_test_get_current_device_index(
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_get_current_device_index();
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -615,5 +565,4 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("test_stream", &boxed_test_stream);
|
||||
m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
|
||||
}
|
||||
|
||||
#endif // LAE_USE_CUDA
|
||||
|
||||
@ -333,45 +333,3 @@ def my_new_zeros_dtype_variant(t) -> Tensor:
|
||||
Returns: New zeros tensor
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t)
|
||||
|
||||
|
||||
def my__foreach_mul_(tensors, others) -> ():
|
||||
"""
|
||||
Updates tensors to be the result of pointwise multiplying with others.
|
||||
|
||||
Args:
|
||||
tensors: list of tensors
|
||||
others: list of tensors (with the same corresponding shapes as tensors)
|
||||
|
||||
Returns: nothing, tensors is updated in place.
|
||||
"""
|
||||
torch.ops.libtorch_agnostic.my__foreach_mul_.default(tensors, others)
|
||||
|
||||
|
||||
def my__foreach_mul(tensors, others) -> list[Tensor]:
|
||||
"""
|
||||
Returns a list of tensors that are the results of pointwise multiplying
|
||||
tensors and others.
|
||||
|
||||
Args:
|
||||
tensors: list of tensors
|
||||
others: list of tensors (with the same corresponding shapes as tensors)
|
||||
|
||||
Returns: list of multiplied tensors
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.my__foreach_mul.default(tensors, others)
|
||||
|
||||
|
||||
def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]:
|
||||
"""
|
||||
Returns a list of 2 tensors corresponding to the square of the inputs.
|
||||
|
||||
Args:
|
||||
t1: Tensor
|
||||
t2: Tensor
|
||||
|
||||
Returns: list of [t1^2, t2^2]
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.make_tensor_clones_and_call_foreach.default(
|
||||
t1, t2
|
||||
)
|
||||
|
||||
@ -367,57 +367,6 @@ if not IS_WINDOWS:
|
||||
self.assertNotEqual(result.data_ptr(), expected.data_ptr())
|
||||
self.assertEqual(result.stride(), expected.stride())
|
||||
|
||||
def test_my__foreach_mul_(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
N = 5
|
||||
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
tensors_c = [t.clone() for t in tensors]
|
||||
others = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
|
||||
libtorch_agnostic.ops.my__foreach_mul_(tensors, others)
|
||||
expected_values = torch._foreach_mul(tensors_c, others)
|
||||
|
||||
for tensor_t, expected_t in zip(tensors, expected_values):
|
||||
self.assertEqual(tensor_t, expected_t)
|
||||
|
||||
def test_my__foreach_mul(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
N = 5
|
||||
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
others = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
|
||||
result = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
|
||||
expected = torch._foreach_mul(tensors, others)
|
||||
|
||||
for result_t, expected_t in zip(result, expected):
|
||||
self.assertEqual(result_t, expected_t)
|
||||
|
||||
def _make_cuda_tensors(prior_mem):
|
||||
cuda_res = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
|
||||
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
|
||||
|
||||
expected = torch._foreach_mul(tensors, others)
|
||||
for result_t, expected_t in zip(cuda_res, expected):
|
||||
self.assertEqual(result_t, expected_t)
|
||||
|
||||
if tensors[0].is_cuda:
|
||||
init_mem = torch.cuda.memory_allocated(device)
|
||||
for _ in range(3):
|
||||
_make_cuda_tensors(init_mem)
|
||||
curr_mem = torch.cuda.memory_allocated(device)
|
||||
self.assertEqual(curr_mem, init_mem)
|
||||
|
||||
def test_make_tensor_clones_and_call_foreach(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
t1 = torch.rand(2, 5, device=device)
|
||||
t2 = torch.rand(3, 4, device=device)
|
||||
result = libtorch_agnostic.ops.make_tensor_clones_and_call_foreach(t1, t2)
|
||||
self.assertEqual(result[0], t1 * t1)
|
||||
self.assertEqual(result[1], t2 * t2)
|
||||
|
||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from backend import get_custom_backend_library_path, Model, to_custom_backend
|
||||
@ -40,11 +41,14 @@ class TestCustomBackend(TestCase):
|
||||
self.test_execute()
|
||||
|
||||
# Save and load.
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
f = tempfile.NamedTemporaryFile(delete=False)
|
||||
try:
|
||||
f.close()
|
||||
torch.jit.save(self.model, f.name)
|
||||
loaded = torch.jit.load(f.name)
|
||||
self.model = loaded
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
self.model = loaded
|
||||
|
||||
# Test execution again.
|
||||
self.test_execute()
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -143,13 +144,16 @@ def forward(self, arg0_1):
|
||||
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
|
||||
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
|
||||
# close the file after creation and try to remove it manually.
|
||||
with tempfile.NamedTemporaryFile() as file:
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
try:
|
||||
file.close()
|
||||
model.save(file.name)
|
||||
loaded = torch.jit.load(file.name)
|
||||
finally:
|
||||
os.unlink(file.name)
|
||||
|
||||
output = loaded.forward(torch.ones(5))
|
||||
self.assertTrue(output.allclose(torch.ones(5) + 1))
|
||||
output = loaded.forward(torch.ones(5))
|
||||
self.assertTrue(output.allclose(torch.ones(5) + 1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Owner(s): ["module: fsdp"]
|
||||
import functools
|
||||
import os
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch._dynamo.test_case import run_tests
|
||||
@ -37,9 +37,9 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed.fsdp import fully_shard
|
||||
logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
|
||||
logger = logging.getLogger("torch.distributed._composable.fsdp")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
device = '{device_type.type}'
|
||||
device = {device_type.type}
|
||||
torch.manual_seed(0)
|
||||
model = nn.Sequential(*[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)])
|
||||
for layer in model:
|
||||
|
||||
@ -76,7 +76,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
store=dist.FileStore(self.file_name, self.world_size),
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_replicate_transformer(self):
|
||||
"""
|
||||
This tests that replicate works on a transformer model with fully_shard and replicate layers
|
||||
@ -126,7 +126,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
for parameter in layer.parameters():
|
||||
self.assertEqual(parameter.placements, (Shard(dim=0),))
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_replicate_transformer_managed_modules(self):
|
||||
"""
|
||||
This tests that replicate managed modules works properly. In this test we use a Transformer Module with 3 layers,
|
||||
@ -178,7 +178,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
replicate_model = replicate(replicate_model)
|
||||
self.assertEqual(len(_get_managed_modules((replicate_model,))), 21)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_replicate_tp_device_mesh(self):
|
||||
"""
|
||||
This tests that a user can pass in a device mesh to replicate a module
|
||||
@ -206,7 +206,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
self.assertEqual(parameter.device_mesh.shape, (2,))
|
||||
self.assertEqual(parameter.placements, (Replicate(),))
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_train_replicate_fsdp(self):
|
||||
"""
|
||||
Tests that replicate_model has the same behavior as original model when training
|
||||
@ -253,7 +253,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
self.assertEqual(replicate_loss, loss)
|
||||
check_sharded_parity(self, model, replicate_model)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_train_parity_2d_mlp(self):
|
||||
"""
|
||||
Verifies when a device mesh is passed in, the model has the same behavior as the original model when training
|
||||
|
||||
@ -80,7 +80,7 @@ class TestSACILP(TestCase):
|
||||
# postprocessing due to the fact that for ModTracker, the post backward hook
|
||||
# is not being called for modules whose inputs don't require gradients
|
||||
# TODO: fix this in ModTracker and ensure it does not lead to any perf regression
|
||||
if _ModState.POST_BW not in mod_stats.snapshots:
|
||||
if _ModState.POST_BW not in mod_stats.snapshots.keys():
|
||||
mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append(
|
||||
copy.deepcopy(last_snapshot)
|
||||
)
|
||||
|
||||
@ -16,7 +16,7 @@ from torch.distributed.argparse_util import check_env, env
|
||||
class ArgParseUtilTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# remove any lingering environment variables
|
||||
for e in os.environ.keys(): # noqa: SIM118
|
||||
for e in os.environ.keys():
|
||||
if e.startswith("PET_"):
|
||||
del os.environ[e]
|
||||
|
||||
|
||||
@ -207,7 +207,7 @@ class TestDefaultStager(TestCase):
|
||||
for i, result in enumerate(staged_results):
|
||||
self.assertIsInstance(result, dict)
|
||||
# Verify the result contains the expected keys
|
||||
for key in state_dicts[i]:
|
||||
for key in state_dicts[i].keys():
|
||||
self.assertIn(key, result)
|
||||
|
||||
stager.close()
|
||||
|
||||
@ -299,7 +299,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_dtensor_checkpoint_with_uneven_shards(self) -> None:
|
||||
"""
|
||||
Saving a dtensor with uneven shards.
|
||||
@ -436,7 +436,6 @@ class TestCheckpointableReshard(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_uneven_reshard_with_checkpointable_api(self) -> None:
|
||||
"""
|
||||
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
|
||||
@ -499,7 +498,6 @@ class TestCheckpointableReshard(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_uneven_reshard_with_dtensor_shards_wrapper_api(self) -> None:
|
||||
"""
|
||||
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
|
||||
|
||||
@ -60,7 +60,7 @@ class TestSingleRankSaveLoad(TestCase):
|
||||
self.assertEqual(
|
||||
sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())
|
||||
)
|
||||
for key in state_dict_to_save:
|
||||
for key in state_dict_to_save.keys():
|
||||
self.assertTrue(
|
||||
torch.equal(state_dict_to_save[key], state_dict_loaded[key])
|
||||
)
|
||||
@ -89,7 +89,7 @@ class TestSingleRankSaveLoad(TestCase):
|
||||
self.assertEqual(
|
||||
sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())
|
||||
)
|
||||
for key in state_dict_to_save:
|
||||
for key in state_dict_to_save.keys():
|
||||
self.assertTrue(
|
||||
torch.equal(state_dict_to_save[key], state_dict_to_load[key])
|
||||
)
|
||||
@ -116,7 +116,7 @@ class TestSingleRankSaveLoad(TestCase):
|
||||
self.assertEqual(
|
||||
sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())
|
||||
)
|
||||
for key in state_dict_to_save:
|
||||
for key in state_dict_to_save.keys():
|
||||
self.assertTrue(
|
||||
torch.equal(state_dict_to_save[key], state_dict_loaded[key])
|
||||
)
|
||||
@ -156,7 +156,7 @@ class TestSingleRankSaveLoad(TestCase):
|
||||
self.assertEqual(
|
||||
sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())
|
||||
)
|
||||
for key in state_dict_to_save:
|
||||
for key in state_dict_to_save.keys():
|
||||
self.assertTrue(
|
||||
torch.equal(state_dict_to_save[key], state_dict_to_load[key])
|
||||
)
|
||||
|
||||
@ -18,7 +18,6 @@ from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
|
||||
from torch.distributed.checkpoint.api import CheckpointException
|
||||
from torch.distributed.checkpoint.default_planner import (
|
||||
_create_default_local_metadata,
|
||||
_validate_global_plan,
|
||||
create_default_global_save_plan,
|
||||
create_default_local_load_plan,
|
||||
create_default_local_save_plan,
|
||||
@ -29,7 +28,6 @@ from torch.distributed.checkpoint.filesystem import CURRENT_DCP_VERSION
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
BytesStorageMetadata,
|
||||
ChunkStorageMetadata,
|
||||
Metadata,
|
||||
MetadataIndex,
|
||||
TensorProperties,
|
||||
TensorStorageMetadata,
|
||||
@ -562,32 +560,6 @@ class TestPlannerHelpers(TestCase):
|
||||
self.assertTrue(_compare_save_plans(plan2, plan2))
|
||||
|
||||
|
||||
class TestValidateGlobalPlan(TestCase):
|
||||
def _make_metadata(self, chunks, size):
|
||||
storage = TensorStorageMetadata(
|
||||
properties=TensorProperties(dtype=torch.float32),
|
||||
size=torch.Size(size),
|
||||
chunks=chunks,
|
||||
)
|
||||
return Metadata(state_dict_metadata={"param": storage})
|
||||
|
||||
def test_non_overlapping_chunks(self):
|
||||
chunks = [
|
||||
ChunkStorageMetadata(offsets=torch.Size([i]), sizes=torch.Size([1]))
|
||||
for i in range(4)
|
||||
]
|
||||
metadata = self._make_metadata(chunks, [4])
|
||||
self.assertTrue(_validate_global_plan([SavePlan([])], metadata))
|
||||
|
||||
def test_detect_overlapping_chunks(self):
|
||||
chunks = [
|
||||
ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([2])),
|
||||
ChunkStorageMetadata(offsets=torch.Size([1]), sizes=torch.Size([2])),
|
||||
]
|
||||
metadata = self._make_metadata(chunks, [4])
|
||||
self.assertFalse(_validate_global_plan([SavePlan([])], metadata))
|
||||
|
||||
|
||||
class TestLoadPlanner(TestCase):
|
||||
@with_temp_dir
|
||||
def test_strict(self):
|
||||
|
||||
@ -769,7 +769,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
model_state_dict3 = copy.deepcopy(model_state_dict3)
|
||||
self.assertEqual(len(model_state_dict2), 2)
|
||||
self.assertEqual(len(model_state_dict3), 2)
|
||||
for key in model_state_dict3:
|
||||
for key in model_state_dict3.keys():
|
||||
full_fqn = f"l.{key}"
|
||||
value1 = model_state_dict1[full_fqn]
|
||||
value2 = model_state_dict2[full_fqn]
|
||||
@ -886,7 +886,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
self.assertEqual(cpu_model_value, meta_model_value)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_setting_meta_device_model_broadcasting_and_memory(self) -> None:
|
||||
# This test verifies that we can set model state dict by a meta device model
|
||||
# With the correlated changes in state_dict, meta device model should be accepted
|
||||
|
||||
@ -587,7 +587,9 @@ class TestFSDPStateDict(FSDPTest):
|
||||
model, cpu_offload.offload_params, fp16
|
||||
)
|
||||
|
||||
ignore_keys = [k for k in fsdp_state_dict if NON_ROOT_FSDP_PREFIX in k]
|
||||
ignore_keys = [
|
||||
k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
|
||||
]
|
||||
|
||||
self._validate_state_dict_contents(
|
||||
model,
|
||||
@ -908,7 +910,7 @@ class TestFSDPStateDict(FSDPTest):
|
||||
with sd_mgr:
|
||||
fsdp_state_dict = model.state_dict()
|
||||
|
||||
ignore_keys = [k for k in fsdp_state_dict if NON_ROOT_FSDP_PREFIX in k]
|
||||
ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k]
|
||||
self._validate_state_dict_contents(
|
||||
model,
|
||||
fsdp_state_dict,
|
||||
@ -957,7 +959,9 @@ class TestFSDPStateDict(FSDPTest):
|
||||
# Full name of linear_skip param tensors in SkipModel, as would be
|
||||
# stored in checkpoint.
|
||||
linear_skip_tensor_names = [
|
||||
k for k in dict(module.named_parameters()) if LINEAR_SKIP in k
|
||||
k
|
||||
for k in dict(module.named_parameters()).keys()
|
||||
if LINEAR_SKIP in k
|
||||
]
|
||||
# skip SkipModule
|
||||
linear_skip = getattr(module, LINEAR_SKIP)
|
||||
|
||||
@ -137,7 +137,7 @@ class ElasticLaunchTest(unittest.TestCase):
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
|
||||
# remove any lingering environment variables.
|
||||
for env in os.environ.keys(): # noqa: SIM118
|
||||
for env in os.environ.keys():
|
||||
if env.startswith("PET_"):
|
||||
del os.environ[env]
|
||||
|
||||
|
||||
@ -69,7 +69,7 @@ class ElasticLaunchTest(TestCase):
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
|
||||
# remove any lingering environment variables
|
||||
for env in os.environ.keys(): # noqa: SIM118
|
||||
for env in os.environ.keys():
|
||||
if env.startswith("PET_"):
|
||||
del os.environ[env]
|
||||
|
||||
|
||||
@ -39,7 +39,6 @@ from torch.nn.modules.loss import MSELoss
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcContinuousTest,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
check_leaked_tensors,
|
||||
@ -232,7 +231,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [_ScheduleForwardOnly])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_forward_only(self, ScheduleClass):
|
||||
mod, mod_ref, x, _, _ = setup_models_and_data(self.config)
|
||||
x_clone = x.clone()
|
||||
@ -276,7 +274,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
ScheduleInterleavedZeroBubble,
|
||||
],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_eval_inference_mode(self, ScheduleClass):
|
||||
num_microbatches = 4
|
||||
if ScheduleClass in [
|
||||
@ -354,7 +351,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
ScheduleInterleavedZeroBubble,
|
||||
],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_return_output(self, ScheduleClass):
|
||||
num_microbatches = 4
|
||||
if ScheduleClass in [
|
||||
@ -410,7 +406,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_multi_iter(self, ScheduleClass):
|
||||
mod, _, x, target, loss_fn = setup_models_and_data(self.config)
|
||||
chunks = 4
|
||||
@ -434,7 +429,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_kwargs_with_tracer(self, ScheduleClass):
|
||||
mod = ModelWithKwargs(d_hid, splits=self.world_size)
|
||||
mod.to(self.device)
|
||||
@ -487,7 +481,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_grad_with_tracer(self, ScheduleClass):
|
||||
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
|
||||
|
||||
@ -530,7 +523,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@parametrize("shape_inference", [True, False])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_grad_with_manual(self, ScheduleClass, shape_inference):
|
||||
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
|
||||
|
||||
@ -594,7 +586,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
ScheduleInterleavedZeroBubble,
|
||||
],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_grad_with_manual_interleaved(self, ScheduleClass):
|
||||
stages_per_rank = 2
|
||||
n_stages = stages_per_rank * self.world_size
|
||||
@ -659,7 +650,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
|
||||
stages_per_rank = 2
|
||||
n_stages = stages_per_rank * self.world_size
|
||||
@ -746,7 +736,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
"schedule_class",
|
||||
[ScheduleZBVZeroBubble, ScheduleDualPipeV],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_v_shape_schedules(self, schedule_class):
|
||||
n_stages = 8
|
||||
rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
|
||||
@ -791,7 +780,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_custom_function_callback(self):
|
||||
"""Test the custom function callback functionality with _PipelineScheduleRuntime."""
|
||||
n_stages = 8
|
||||
@ -991,7 +979,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
"ScheduleClass",
|
||||
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_zero_bubble_with_model_kwargs(self, ScheduleClass):
|
||||
stages_per_rank = 2
|
||||
n_stages = stages_per_rank * self.world_size
|
||||
@ -1085,7 +1072,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
|
||||
"schedule_class",
|
||||
[ScheduleVShaped, ScheduleUnbalanced],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_non_symmetric_stage_ids(self, schedule_class):
|
||||
n_stages = schedule_class.n_stages
|
||||
rank_stages = schedule_class.rank_stages
|
||||
@ -1135,7 +1121,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleWithReorderedB])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
|
||||
n_stages = 2
|
||||
stages_per_rank = 1
|
||||
@ -1196,7 +1181,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleWithW])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
|
||||
n_stages = ScheduleClass.n_stages
|
||||
num_microbatches = ScheduleClass.num_microbatches
|
||||
|
||||
@ -204,16 +204,14 @@ class DistConvolutionOpsTest(DTensorTestBase):
|
||||
self.assertTrue(b_dt.grad is not None)
|
||||
self.assertTrue(x_dt.grad is None)
|
||||
|
||||
def _run_single_arg_fwd(
|
||||
self, model, arg, placements=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Given model and arg, runs fwd model local and distbuted given device_mesh"""
|
||||
device_mesh = self.build_device_mesh()
|
||||
model_copy = copy.deepcopy(model).to(device=self.device_type)
|
||||
dist_model = distribute_module(model, device_mesh, _conv_fn)
|
||||
arg_dt = DTensor.from_local(arg, device_mesh, placements)
|
||||
arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
|
||||
out_dt = dist_model(arg_dt.to(device=self.device_type))
|
||||
out = model_copy(arg_dt.full_tensor())
|
||||
out = model_copy(arg)
|
||||
return (out_dt.full_tensor(), out)
|
||||
|
||||
@with_comms
|
||||
@ -221,20 +219,22 @@ class DistConvolutionOpsTest(DTensorTestBase):
|
||||
model = nn.Conv1d(64, 64, 3, padding=1)
|
||||
x = torch.randn(1, 64, 8, device=self.device_type)
|
||||
out_dt, out = self._run_single_arg_fwd(model, x)
|
||||
self.assertEqual(out_dt, out)
|
||||
self.assertEqual(out_dt.shape, out.shape)
|
||||
|
||||
@with_comms
|
||||
def test_conv3d(self):
|
||||
model = nn.Conv3d(64, 64, 3, padding=1)
|
||||
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
|
||||
out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)])
|
||||
self.assertEqual(out_dt, out)
|
||||
out_dt, out = self._run_single_arg_fwd(model, x)
|
||||
self.assertEqual(out_dt.shape, out.shape)
|
||||
|
||||
|
||||
DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
|
||||
DistConvolutionOpsTest,
|
||||
# Send / recv ops are not supported
|
||||
skipped_tests=[
|
||||
"test_conv1d",
|
||||
"test_conv3d",
|
||||
"test_conv_backward_none_grad_inp",
|
||||
"test_depthwise_convolution",
|
||||
"test_downsampling_convolution",
|
||||
|
||||
@ -464,25 +464,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
|
||||
run(g, 64, 8)
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
def test_dtensor_requires_grad_recompile(self):
|
||||
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||
|
||||
@torch.compile(backend=cnt, fullgraph=True)
|
||||
def f(x):
|
||||
y = x * x
|
||||
return y.to_local()
|
||||
|
||||
full_x = torch.randn(8, 8, requires_grad=False)
|
||||
x = distribute_tensor(full_x, mesh, [Shard(0)])
|
||||
f(x)
|
||||
|
||||
full_x = torch.randn(8, 8, requires_grad=True)
|
||||
x = distribute_tensor(full_x, mesh, [Shard(0)])
|
||||
f(x)
|
||||
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
def test_dtensor_attribute_access_on_intermediate(self):
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||
|
||||
|
||||
@ -535,19 +535,6 @@ class DTensorExportTest(TestCase):
|
||||
|
||||
self.assertEqual(fn(z), gm(z)[0])
|
||||
|
||||
def test_dtensor_data_dependent_index(self):
|
||||
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return x[y]
|
||||
|
||||
x = torch.randn(10)
|
||||
y = torch.randint(1, (10,)).bool()
|
||||
x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
|
||||
y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
|
||||
_dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(DTensorExportTest)
|
||||
|
||||
|
||||
@ -26,7 +26,6 @@ from torch.distributed.tensor.parallel import (
|
||||
RowwiseParallel,
|
||||
SequenceParallel,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
create_local_tensor_test_class,
|
||||
@ -765,7 +764,6 @@ class DistMathOpsTest(DTensorTestBase):
|
||||
self.assertEqual(grad1_norm.device_mesh, mesh_y)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_foreach_add_different_mesh(self):
|
||||
mesh_shape = (2, self.world_size // 2)
|
||||
mesh_2d = init_device_mesh(
|
||||
|
||||
@ -577,7 +577,7 @@ class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
|
||||
self.assertEqual(
|
||||
comm_mode.get_comm_counts(),
|
||||
{
|
||||
torch.ops.c10d_functional.all_gather_into_tensor: self.world_size,
|
||||
torch.ops.c10d_functional.all_gather_into_tensor: 4,
|
||||
},
|
||||
)
|
||||
expected_cost = [
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
@ -21,8 +22,9 @@ from torch.distributed.tensor import (
|
||||
)
|
||||
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
|
||||
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
|
||||
from torch.distributed.tensor._redistribute import redistribute_local_tensor
|
||||
from torch.distributed.tensor.debug import CommDebugMode
|
||||
from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial
|
||||
from torch.distributed.tensor.placement_types import _StridedShard
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -33,11 +35,7 @@ from torch.testing._internal.common_utils import (
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
create_local_tensor_test_class,
|
||||
DTensorTestBase,
|
||||
generate_shard_orders,
|
||||
make_full_tensor,
|
||||
map_local_tensor_for_rank,
|
||||
patched_distribute_tensor as _distribute_tensor,
|
||||
redistribute,
|
||||
with_comms,
|
||||
)
|
||||
from torch.utils._debug_mode import DebugMode
|
||||
@ -787,6 +785,88 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
else:
|
||||
return ""
|
||||
|
||||
# TODO(zpcore): remove once the native redistribute supports shard_order arg
|
||||
def redistribute(
|
||||
self,
|
||||
dtensor_input,
|
||||
device_mesh,
|
||||
placements,
|
||||
shard_order,
|
||||
use_graph_based_transform=True,
|
||||
):
|
||||
"""
|
||||
wrapper function to support shard_order for redistribution
|
||||
This is a simpler version of Redistribute, only considers the forward.
|
||||
"""
|
||||
if placements is None:
|
||||
placements = self._shard_order_to_placement(shard_order, device_mesh)
|
||||
placements = tuple(placements)
|
||||
old_spec = dtensor_input._spec
|
||||
new_spec = copy.deepcopy(old_spec)
|
||||
new_spec.placements = placements
|
||||
if shard_order is not None:
|
||||
new_spec.shard_order = shard_order
|
||||
else:
|
||||
new_spec.shard_order = ()
|
||||
if old_spec == new_spec:
|
||||
return dtensor_input
|
||||
dtensor_input = DTensor.from_local(
|
||||
redistribute_local_tensor(
|
||||
dtensor_input.to_local(),
|
||||
old_spec,
|
||||
new_spec,
|
||||
use_graph_based_transform=use_graph_based_transform,
|
||||
),
|
||||
device_mesh,
|
||||
)
|
||||
dtensor_input._spec = copy.deepcopy(new_spec)
|
||||
return dtensor_input # returns DTensor
|
||||
|
||||
# TODO(zpcore): remove once the native distribute_tensor supports
|
||||
# shard_order arg
|
||||
def distribute_tensor(
|
||||
self,
|
||||
input_tensor,
|
||||
device_mesh,
|
||||
placements,
|
||||
shard_order,
|
||||
use_graph_based_transform=True,
|
||||
):
|
||||
"""wrapper function to support shard_order for tensor distribution"""
|
||||
if placements is None:
|
||||
placements = self._shard_order_to_placement(shard_order, device_mesh)
|
||||
placements = tuple(placements)
|
||||
tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
|
||||
# fix the shard order
|
||||
return self.redistribute(
|
||||
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
|
||||
)
|
||||
|
||||
# TODO(zpcore): remove once the native redistribute supports shard_order arg
|
||||
def full_tensor(self, dtensor_input):
|
||||
"""wrapper function to support DTensor.full_tensor"""
|
||||
return self.redistribute(
|
||||
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
|
||||
).to_local()
|
||||
|
||||
def _shard_order_to_placement(self, shard_order, mesh):
|
||||
"""convert shard_order to placement with only Replicate() and Shard()"""
|
||||
placements = [Replicate() for _ in range(mesh.ndim)]
|
||||
if shard_order is not None:
|
||||
for entry in shard_order:
|
||||
tensor_dim = entry.tensor_dim
|
||||
mesh_dims = entry.mesh_dims
|
||||
for mesh_dim in mesh_dims:
|
||||
placements[mesh_dim] = Shard(tensor_dim)
|
||||
return tuple(placements)
|
||||
|
||||
def _convert_shard_order_dict_to_ShardOrder(self, shard_order):
|
||||
"""Convert shard_order dict to ShardOrder"""
|
||||
return tuple(
|
||||
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
|
||||
for tensor_dim, mesh_dims in shard_order.items()
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_ordered_redistribute(self):
|
||||
"""Test ordered redistribution with various sharding syntaxes"""
|
||||
@ -847,11 +927,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
|
||||
sharding_src_dst_pairs_with_expected_trace
|
||||
):
|
||||
sharded_dt = _distribute_tensor(
|
||||
sharded_dt = self.distribute_tensor(
|
||||
input_data.clone(), mesh, src_placement, shard_order=src_order
|
||||
)
|
||||
with DebugMode(record_torchfunction=False) as debug_mode:
|
||||
sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order)
|
||||
sharded_dt = self.redistribute(
|
||||
sharded_dt, mesh, dst_placement, dst_order
|
||||
)
|
||||
trace_str = self._extract_redistribute_trace_from_debug_mode(
|
||||
debug_mode.debug_string()
|
||||
)
|
||||
@ -875,11 +957,49 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
trace_str,
|
||||
"""S(0)[0]S(0)[1]R->S(0)S(1)R->RS(1)R->RS(1)S(0)""",
|
||||
)
|
||||
expected_dt = _distribute_tensor(
|
||||
expected_dt = self.distribute_tensor(
|
||||
input_data.clone(), mesh, dst_placement, shard_order=dst_order
|
||||
)
|
||||
self.assertEqual(sharded_dt.to_local(), expected_dt.to_local())
|
||||
|
||||
def generate_shard_orders(self, mesh, tensor_rank):
|
||||
# Generate all possible sharding placement of tensor with rank
|
||||
# `tensor_rank` over mesh.
|
||||
def _split_list(lst: list, N: int):
|
||||
def compositions(n, k):
|
||||
if k == 1:
|
||||
yield [n]
|
||||
else:
|
||||
for i in range(1, n - k + 2):
|
||||
for tail in compositions(n - i, k - 1):
|
||||
yield [i] + tail
|
||||
|
||||
length = len(lst)
|
||||
for comp in compositions(length, N):
|
||||
result = []
|
||||
start = 0
|
||||
for size in comp:
|
||||
result.append(lst[start : start + size])
|
||||
start += size
|
||||
yield result
|
||||
|
||||
all_mesh = list(range(mesh.ndim))
|
||||
all_device_order = list(itertools.permutations(all_mesh))
|
||||
for device_order in all_device_order:
|
||||
# split on device orders, and assign each device order segment to a tensor dim
|
||||
for num_split in range(1, mesh.ndim + 1):
|
||||
for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
|
||||
for tensor_dims in itertools.combinations(
|
||||
range(tensor_rank), len(splitted_list)
|
||||
):
|
||||
shard_order = {}
|
||||
assert len(tensor_dims) == len(splitted_list)
|
||||
for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
|
||||
shard_order[tensor_dim] = device_order[
|
||||
mesh_dims[0] : mesh_dims[-1] + 1
|
||||
]
|
||||
yield self._convert_shard_order_dict_to_ShardOrder(shard_order)
|
||||
|
||||
@with_comms
|
||||
def test_generate_shard_orders(self):
|
||||
"""Check if `generate_shard_orders` generates unique sharding combinations"""
|
||||
@ -892,7 +1012,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
]
|
||||
for test_input in test_inputs:
|
||||
all_combinations = []
|
||||
for shard_order in generate_shard_orders(
|
||||
for shard_order in self.generate_shard_orders(
|
||||
test_input["mesh"], test_input["tensor_rank"]
|
||||
):
|
||||
all_combinations.append(shard_order) # noqa: PERF402
|
||||
@ -942,12 +1062,12 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
input_data = torch.randn(tensor_shape, device=self.device_type)
|
||||
tensor_rank = input_data.ndim
|
||||
with maybe_disable_local_tensor_mode():
|
||||
shard_orders = generate_shard_orders(mesh, tensor_rank)
|
||||
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
|
||||
for shard_order in shard_orders:
|
||||
sharded_dt = _distribute_tensor(
|
||||
sharded_dt = self.distribute_tensor(
|
||||
input_data.clone(), mesh, placements=None, shard_order=shard_order
|
||||
)
|
||||
self.assertEqual(make_full_tensor(sharded_dt), input_data)
|
||||
self.assertEqual(self.full_tensor(sharded_dt), input_data)
|
||||
|
||||
# 2. Verify the correctness of redistribution from DTensor to DTensor.
|
||||
# This test repeatedly redistributes a DTensor to various ordered
|
||||
@ -958,20 +1078,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
tensor_rank = input_data.ndim
|
||||
prev_sharded_dt = None
|
||||
with maybe_disable_local_tensor_mode():
|
||||
shard_orders = generate_shard_orders(mesh, tensor_rank)
|
||||
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
|
||||
for shard_order in shard_orders:
|
||||
if prev_sharded_dt is None:
|
||||
prev_sharded_dt = _distribute_tensor(
|
||||
prev_sharded_dt = self.distribute_tensor(
|
||||
input_data.clone(),
|
||||
mesh,
|
||||
placements=None,
|
||||
shard_order=shard_order,
|
||||
)
|
||||
else:
|
||||
sharded_dt = redistribute(
|
||||
sharded_dt = self.redistribute(
|
||||
prev_sharded_dt, mesh, placements=None, shard_order=shard_order
|
||||
)
|
||||
self.assertEqual(make_full_tensor(sharded_dt), input_data)
|
||||
self.assertEqual(self.full_tensor(sharded_dt), input_data)
|
||||
prev_sharded_dt = sharded_dt
|
||||
|
||||
@with_comms
|
||||
@ -1016,13 +1136,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
local_tensor = torch.randn(shape, device=self.device_type)
|
||||
full_tensor = DTensor.from_local(local_tensor, mesh, placements)
|
||||
with maybe_disable_local_tensor_mode():
|
||||
shard_orders = generate_shard_orders(mesh, len(shape))
|
||||
shard_orders = self.generate_shard_orders(mesh, len(shape))
|
||||
for shard_order in shard_orders:
|
||||
sharded_dt = redistribute(
|
||||
sharded_dt = self.redistribute(
|
||||
full_tensor, mesh, placements=None, shard_order=shard_order
|
||||
)
|
||||
self.assertEqual(
|
||||
make_full_tensor(sharded_dt), make_full_tensor(full_tensor)
|
||||
self.full_tensor(sharded_dt), self.full_tensor(full_tensor)
|
||||
)
|
||||
|
||||
@unittest.skip(
|
||||
@ -1032,20 +1152,24 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_ordered_redistribute_for_special_placement(self):
|
||||
"""Test ordered redistribution with special placement"""
|
||||
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
|
||||
|
||||
torch.manual_seed(21)
|
||||
mesh = init_device_mesh(self.device_type, (8,))
|
||||
input_data = torch.randn((8, 8), device=self.device_type)
|
||||
src_placement = [Shard(1)]
|
||||
tgt_placement = [
|
||||
(MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
|
||||
(_MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
|
||||
]
|
||||
sharded_dt = _distribute_tensor(
|
||||
sharded_dt = self.distribute_tensor(
|
||||
input_data.clone(),
|
||||
mesh,
|
||||
src_placement,
|
||||
shard_order=(ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)),),
|
||||
)
|
||||
sharded_dt = redistribute(sharded_dt, mesh, tgt_placement, shard_order=None)
|
||||
sharded_dt = self.redistribute(
|
||||
sharded_dt, mesh, tgt_placement, shard_order=None
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_shard_order_same_data_as_strided_shard(self):
|
||||
@ -1055,7 +1179,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
|
||||
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
|
||||
# specify right-to-left order use ordered shard
|
||||
x_ordered_dt = _distribute_tensor(
|
||||
x_ordered_dt = self.distribute_tensor(
|
||||
x,
|
||||
device_mesh,
|
||||
placements=[Shard(0), Shard(0)],
|
||||
|
||||
@ -34,10 +34,6 @@ from torch.distributed.tensor.placement_types import (
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
generate_shard_orders,
|
||||
LocalDTensorTestBase,
|
||||
patched_distribute_tensor as _distribute_tensor,
|
||||
shard_order_to_placement,
|
||||
with_comms,
|
||||
)
|
||||
|
||||
@ -778,63 +774,6 @@ class TestStridedSharding(DTensorTestBase):
|
||||
self.assertEqual(dtensor.full_tensor(), tensor)
|
||||
|
||||
|
||||
class Test_StridedShard_with_shard_order(LocalDTensorTestBase):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 32
|
||||
|
||||
@with_comms
|
||||
def test_StridedShard_to_shard_order(self):
|
||||
with LocalTensorMode(ranks=self.world_size):
|
||||
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(2, 2, 2, 2, 2))
|
||||
shard_iter = generate_shard_orders(mesh, 3)
|
||||
# It takes ~4.8h to complete total 2520 shard order combinations here
|
||||
# using LocalTensor. So we only randomly pick 25 shard orders to test.
|
||||
all_shard_order = list(shard_iter)
|
||||
import random
|
||||
|
||||
random.seed(42)
|
||||
shard_order_choices = random.sample(
|
||||
all_shard_order, min(25, len(all_shard_order))
|
||||
)
|
||||
|
||||
x = torch.randn(32, 32, 32)
|
||||
for shard_order in shard_order_choices:
|
||||
a = _distribute_tensor(x, mesh, None, shard_order)
|
||||
|
||||
placement_without_stridedshard = shard_order_to_placement(
|
||||
shard_order, mesh
|
||||
)
|
||||
placements_with_stridedshard = (
|
||||
DTensorSpec._convert_shard_order_to_StridedShard(
|
||||
shard_order, placement_without_stridedshard, mesh
|
||||
)
|
||||
)
|
||||
b = distribute_tensor(x, mesh, placements_with_stridedshard)
|
||||
shard_order_from_stridedshard = (
|
||||
DTensorSpec._maybe_convert_StridedShard_to_shard_order(
|
||||
placements_with_stridedshard, mesh
|
||||
)
|
||||
)
|
||||
self.assertEqual(shard_order, shard_order_from_stridedshard)
|
||||
self.assertEqual(a.to_local(), b.to_local())
|
||||
|
||||
@with_comms
|
||||
def test_StridedShard_not_convertible_to_shard_order(self):
|
||||
with LocalTensorMode(ranks=self.world_size):
|
||||
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(4, 8))
|
||||
unconvertible_placements_list = [
|
||||
[_StridedShard(0, split_factor=2), _StridedShard(1, split_factor=2)],
|
||||
[_StridedShard(0, split_factor=2), Shard(1)],
|
||||
[_StridedShard(1, split_factor=16), Shard(1)],
|
||||
]
|
||||
for placements in unconvertible_placements_list:
|
||||
shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order(
|
||||
tuple(placements), mesh
|
||||
)
|
||||
self.assertIsNone(shard_order)
|
||||
|
||||
|
||||
class Test2DStridedLocalShard(DTensorTestBase):
|
||||
@property
|
||||
def world_size(self):
|
||||
|
||||
@ -54,7 +54,6 @@ def apply_reordering_and_get_graph(graph, out_li) -> None:
|
||||
"max_compute_pre_fetch",
|
||||
"custom_runtime_estimation",
|
||||
"insert_overlap_deps",
|
||||
"collective_estimator",
|
||||
)
|
||||
for key in config_keys:
|
||||
if (val := getattr(dist_opts, key)) is not None:
|
||||
@ -944,50 +943,6 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
|
||||
correct = func(inputs_a, inputs_b, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
def test_collective_benchmarking_with_real_pg(self):
|
||||
"""Test collective benchmarking with real process group (falls back on fake)."""
|
||||
|
||||
def func(a):
|
||||
# Test all three collective types with 8x8 (power of 2 size = 256 elements = 1024 bytes for fp32)
|
||||
ar = _functional_collectives.all_reduce(a, "sum", "0")
|
||||
ag = _functional_collectives.all_gather_tensor(
|
||||
a, 0, list(range(self.world_size))
|
||||
)
|
||||
rs = _functional_collectives.reduce_scatter_tensor(a, "sum", 0, "0")
|
||||
|
||||
b = torch.matmul(a, a)
|
||||
c = torch.matmul(ar, b)
|
||||
return c.sum() + ag.sum() + rs.sum()
|
||||
|
||||
patches = {
|
||||
**get_patches(),
|
||||
"aten_distributed_optimizations.collective_estimator": "benchmark",
|
||||
"aten_distributed_optimizations.custom_runtime_estimation": None, # Remove custom estimation so benchmarking happens
|
||||
}
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(8, 8, dtype=torch.float, device=device_type) + self.rank
|
||||
|
||||
with torch._inductor.config.patch(patches):
|
||||
compiled = torch.compile(func)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
|
||||
|
||||
# Verify all three collective types are present
|
||||
FileCheck().check("all_reduce").check("all_gather").check(
|
||||
"reduce_scatter"
|
||||
).run(aten_graph_str)
|
||||
|
||||
# Test passes if compilation succeeded with benchmarking enabled
|
||||
# Cache verification is tricky due to multiprocess test setup
|
||||
correct = func(inputs)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_multidtype_bucketing(self):
|
||||
|
||||
@ -485,7 +485,7 @@ elif TEST_XPU:
|
||||
def exit_if_lt_x_accelerators(x):
|
||||
if torch.accelerator.is_available():
|
||||
if torch.accelerator.device_count() < x:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
|
||||
sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
|
||||
|
||||
|
||||
def with_comms(func=None):
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
# flake8: noqa: B950
|
||||
# flake8: noqa: E731
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
@ -17,11 +15,7 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from functorch.compile import min_cut_rematerialization_partition
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch._dynamo.testing import (
|
||||
AotEagerAndRecordGraphs,
|
||||
CompileCounterWithBackend,
|
||||
normalize_gm,
|
||||
)
|
||||
from torch._dynamo.testing import CompileCounterWithBackend
|
||||
from torch._higher_order_ops.wrap import tag_activation_checkpoint
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
|
||||
@ -1655,43 +1649,6 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
self.assertEqual(opt_fn(x), fn(x))
|
||||
|
||||
def test_return_same_element_twice(self):
|
||||
def gn(x):
|
||||
y = torch.sin(x)
|
||||
return y, y
|
||||
|
||||
def fn(x):
|
||||
return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)
|
||||
|
||||
x = torch.randn(4, 4, requires_grad=True)
|
||||
ref = fn(x)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref[0], res[0])
|
||||
self.assertEqual(ref[1], res[1])
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[4, 4]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None
|
||||
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]
|
||||
getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None
|
||||
return (getitem, getitem_1)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[4, 4]"):
|
||||
y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
|
||||
return (y, y)
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
|
||||
def test_nonlocal_mutation(self):
|
||||
counter = 0
|
||||
@ -1715,114 +1672,6 @@ class GraphModule(torch.nn.Module):
|
||||
# The mutation is not reapplied in the backward because the flag was on.
|
||||
self.assertEqual(counter, 1)
|
||||
|
||||
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
|
||||
def test_nonlocal_list_mutation(self):
|
||||
def gn(x, z):
|
||||
out = x.sin()
|
||||
z.append(out)
|
||||
return torch.cos(torch.sin(torch.matmul(x, x) @ x)), out
|
||||
|
||||
def fn(x):
|
||||
z = []
|
||||
|
||||
out1, out2 = torch.utils.checkpoint.checkpoint(
|
||||
gn,
|
||||
x,
|
||||
z,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
return out1, z[0]
|
||||
|
||||
x = torch.randn(4, 4, requires_grad=True)
|
||||
ref = fn(x)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref[0], res[0])
|
||||
self.assertEqual(ref[1], res[1])
|
||||
|
||||
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
|
||||
def test_nonlocal_list_mutation_hidden(self):
|
||||
def gn(x, z):
|
||||
o = torch.matmul(x, x) @ x
|
||||
out = x.sin()
|
||||
z.append(out)
|
||||
return torch.cos(torch.sin(o)), torch.sin(x)
|
||||
|
||||
def fn(x):
|
||||
z = []
|
||||
|
||||
outs = torch.utils.checkpoint.checkpoint(
|
||||
gn,
|
||||
x,
|
||||
z,
|
||||
use_reentrant=False,
|
||||
)
|
||||
out1 = outs[0]
|
||||
# Check that the extra output pytree handling is done properly
|
||||
out2 = outs[-1]
|
||||
|
||||
return out1 + out2, z[0]
|
||||
|
||||
x = torch.randn(4, 4, requires_grad=True)
|
||||
ref = fn(x)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref[0], res[0])
|
||||
self.assertEqual(ref[1], res[1])
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[4, 4]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None
|
||||
out1: "f32[4, 4]" = tag_activation_checkpoint[0]
|
||||
out2: "f32[4, 4]" = tag_activation_checkpoint[1]
|
||||
getitem_4: "f32[4, 4]" = tag_activation_checkpoint[4]; tag_activation_checkpoint = None
|
||||
|
||||
add: "f32[4, 4]" = out1 + out2; out1 = out2 = None
|
||||
return (add, getitem_4)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[4, 4]"):
|
||||
matmul: "f32[4, 4]" = torch.matmul(l_x_, l_x_)
|
||||
o: "f32[4, 4]" = matmul @ l_x_
|
||||
|
||||
out: "f32[4, 4]" = l_x_.sin()
|
||||
|
||||
sin_1: "f32[4, 4]" = torch.sin(o)
|
||||
child: "f32[4, 4]" = torch.cos(sin_1)
|
||||
child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
|
||||
return (child, child_1, matmul, o, out, sin_1)
|
||||
""",
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[4, 4]"):
|
||||
mm: "f32[4, 4]" = torch.ops.aten.mm.default(primals_1, primals_1)
|
||||
mm_1: "f32[4, 4]" = torch.ops.aten.mm.default(mm, primals_1); mm = None
|
||||
|
||||
sin: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1)
|
||||
|
||||
sin_1: "f32[4, 4]" = torch.ops.aten.sin.default(mm_1); mm_1 = None
|
||||
cos: "f32[4, 4]" = torch.ops.aten.cos.default(sin_1); sin_1 = None
|
||||
sin_2: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1)
|
||||
|
||||
add: "f32[4, 4]" = torch.ops.aten.add.Tensor(cos, sin_2); cos = sin_2 = None
|
||||
return (add, sin, primals_1)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
devices = ["cuda", "hpu"]
|
||||
instantiate_device_type_tests(
|
||||
|
||||
@ -2109,89 +2109,6 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
|
||||
with self.assertRaises(Unsupported):
|
||||
outer_f2(inp)
|
||||
|
||||
def test_disable_recursive_flags(self):
|
||||
class SimpleLinear(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layer0 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, inp):
|
||||
return self.layer0(torch.sigmoid(inp))
|
||||
|
||||
class SimpleModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layer0 = SimpleLinear()
|
||||
self.layer1 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, inp):
|
||||
z = self.layer0(torch.sin(inp))
|
||||
return self.layer1(z)
|
||||
|
||||
for recursive_flag in [True, False]:
|
||||
model = SimpleModel()
|
||||
other_model = SimpleModel()
|
||||
|
||||
model.forward = torch._dynamo.disable(
|
||||
model.forward,
|
||||
recursive=recursive_flag,
|
||||
)
|
||||
self.assertEqual(
|
||||
torch._dynamo.is_dynamo_disable_recursive(model.forward),
|
||||
recursive_flag,
|
||||
)
|
||||
|
||||
other_model = torch._dynamo.disable(other_model, recursive=recursive_flag)
|
||||
self.assertEqual(
|
||||
torch._dynamo.is_dynamo_disable_recursive(
|
||||
other_model.forward
|
||||
if isinstance(other_model, torch.nn.Module)
|
||||
else other_model
|
||||
),
|
||||
recursive_flag,
|
||||
)
|
||||
|
||||
# check the model is compilable
|
||||
torch.compile(model)
|
||||
torch.compile(other_model)
|
||||
|
||||
def test_dynamo_disable_annotations(self):
|
||||
class SimpleModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_buffer("buffer", torch.rand(2, 2))
|
||||
|
||||
@torch._dynamo.disable()
|
||||
def f1(self, x) -> torch.Tensor:
|
||||
return x + self.buffer + 1
|
||||
|
||||
@torch._dynamo.disable()
|
||||
def f2(self, x) -> torch.Tensor:
|
||||
return x + self.buffer + 2
|
||||
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
return self.f1(x) + self.f2(x)
|
||||
|
||||
model = SimpleModel()
|
||||
inp = torch.rand(2, 2)
|
||||
with torch.fx.traceback.preserve_node_meta():
|
||||
exported_model = torch.export.export(model, (inp,))
|
||||
graph = exported_model.graph_module.graph
|
||||
found_f1 = False
|
||||
found_f2 = False
|
||||
for node in graph.nodes:
|
||||
if "custom" in node.meta:
|
||||
if "_torchdynamo_disable_method" in node.meta["custom"]:
|
||||
if node.meta["custom"]["_torchdynamo_disable_method"] == "f1":
|
||||
found_f1 = True
|
||||
elif node.meta["custom"]["_torchdynamo_disable_method"] == "f2":
|
||||
found_f2 = True
|
||||
self.assertTrue(found_f1)
|
||||
self.assertTrue(found_f2)
|
||||
model.forward = torch._dynamo.disable(model.forward, recursive=False)
|
||||
with self.assertRaises(RuntimeError):
|
||||
exported_model = torch.export.export(model, (inp,))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -422,41 +422,34 @@ from user code:
|
||||
import optree
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def fn1(x):
|
||||
tree = {"a": x, "b": (x - 1, 2 * x)}
|
||||
sin, cos = optree.tree_transpose_map(
|
||||
lambda t: (torch.sin(t), torch.cos(t)),
|
||||
tree,
|
||||
def fn(x):
|
||||
d = {"a": 1}
|
||||
optree.tree_flatten_with_path(d)
|
||||
return torch.sin(x)
|
||||
|
||||
def post_munge(s):
|
||||
s = re.sub(
|
||||
r"optree\.\S*\.flatten_with_path",
|
||||
"optree.<path>.flatten_with_path",
|
||||
s,
|
||||
)
|
||||
return sin, cos
|
||||
|
||||
fn1(torch.randn(4))
|
||||
self.assertEqual(len(counters["graph_break"]), 0)
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def fn2(x):
|
||||
spec = optree.treespec_deque([])
|
||||
return spec, x
|
||||
|
||||
fn2(torch.randn(4))
|
||||
self.assertGreaterEqual(len(counters["graph_break"]), 1)
|
||||
first_graph_break = next(iter(counters["graph_break"].keys()))
|
||||
|
||||
def post_munge(string):
|
||||
return re.sub(
|
||||
r"(optree\.|qualname: )\S*(\.make_from_collection)",
|
||||
r"\1<path>\2",
|
||||
string,
|
||||
r"qualname: \S*flatten_with_path",
|
||||
"qualname: <path>.flatten_with_path",
|
||||
s,
|
||||
)
|
||||
|
||||
fn(torch.randn(4))
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
first_graph_break = next(iter(counters["graph_break"].keys()))
|
||||
self.assertExpectedInline(
|
||||
post_munge(first_graph_break),
|
||||
"""\
|
||||
Attempted to call function marked as skipped
|
||||
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.make_from_collection.
|
||||
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
|
||||
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
|
||||
|
||||
Developer debug context: module: optree._C, qualname: <path>.make_from_collection, skip reason: <missing reason>
|
||||
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
|
||||
)
|
||||
@ -1050,7 +1043,7 @@ Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especiall
|
||||
msg = re.sub(r"line (\d+)", "line N", msg)
|
||||
msg = re.sub(
|
||||
r"""(?s)Traceback \(most recent call last\):.*
|
||||
File "exc.py", line N, in unimplemented
|
||||
File "exc.py", line N, in unimplemented_v2
|
||||
raise Unsupported\(msg\)""",
|
||||
"<Internal traceback>\n",
|
||||
msg,
|
||||
|
||||
@ -861,7 +861,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
|
||||
def test_logs_out(self):
|
||||
import tempfile
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=True) as tmp:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
||||
file_path = _as_posix_path(tmp.name)
|
||||
"""
|
||||
NamedTemporaryFile will include a file open operation.
|
||||
@ -888,6 +888,10 @@ fn(torch.randn(5))
|
||||
file_path, encoding="utf-8"
|
||||
) as fd: # encoding file to UTF-8 for Windows.
|
||||
lines = fd.read()
|
||||
fd.close()
|
||||
os.remove(
|
||||
file_path
|
||||
) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False.
|
||||
orig_maxDiff = unittest.TestCase.maxDiff
|
||||
unittest.TestCase.maxDiff = None
|
||||
try:
|
||||
@ -984,7 +988,6 @@ exclusions = {
|
||||
"hierarchical_compile",
|
||||
"compute_dependencies",
|
||||
"annotation",
|
||||
"node_runtime_estimation",
|
||||
}
|
||||
for name in torch._logging._internal.log_registry.artifact_names:
|
||||
if name not in exclusions:
|
||||
|
||||
@ -742,14 +742,11 @@ class TestExport(TestCase):
|
||||
self.assertExpectedInline(
|
||||
str(custom_metadata),
|
||||
"""\
|
||||
('placeholder', 'x', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
|
||||
('placeholder', 'y', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
|
||||
('call_function', 'cat', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
|
||||
('call_function', 'item', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
|
||||
('call_function', 'ge_1', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
|
||||
('call_function', '_assert_scalar_default', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
|
||||
('call_function', 'mul', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
|
||||
('output', 'output', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})""",
|
||||
('call_function', 'cat', {'moo': 0})
|
||||
('call_function', 'item', {'moo': 0})
|
||||
('call_function', 'ge_1', {'moo': 0})
|
||||
('call_function', '_assert_scalar_default', {'moo': 0})
|
||||
('call_function', 'mul', {'moo': 0})""",
|
||||
)
|
||||
|
||||
@requires_gpu
|
||||
@ -1224,14 +1221,8 @@ graph():
|
||||
%p_block_linear2_bias : [num_users=1] = placeholder[target=p_block_linear2_bias]
|
||||
%x : [num_users=1] = placeholder[target=x]
|
||||
%wrap_body0 : [num_users=1] = get_attr[target=wrap_body0]
|
||||
%tag_activation_checkpoint : [num_users=7] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {})
|
||||
%tag_activation_checkpoint : [num_users=1] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {})
|
||||
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 0), kwargs = {})
|
||||
%getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 1), kwargs = {})
|
||||
%getitem_2 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 2), kwargs = {})
|
||||
%getitem_3 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 3), kwargs = {})
|
||||
%getitem_4 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 4), kwargs = {})
|
||||
%getitem_5 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 5), kwargs = {})
|
||||
%getitem_6 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 6), kwargs = {})
|
||||
return (getitem,)""",
|
||||
)
|
||||
|
||||
@ -1240,14 +1231,14 @@ graph():
|
||||
"""\
|
||||
graph():
|
||||
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
|
||||
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
|
||||
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
|
||||
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
|
||||
%arg4_1 : [num_users=2] = placeholder[target=arg4_1]
|
||||
%linear : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {})
|
||||
%relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {})
|
||||
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
|
||||
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
|
||||
%arg3_1 : [num_users=1] = placeholder[target=arg3_1]
|
||||
%arg4_1 : [num_users=1] = placeholder[target=arg4_1]
|
||||
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {})
|
||||
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {})
|
||||
%linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %arg3_1, %arg4_1), kwargs = {})
|
||||
return (linear_1, arg1_1, arg2_1, linear, relu, arg3_1, arg4_1)""",
|
||||
return (linear_1,)""",
|
||||
)
|
||||
|
||||
stack = contextlib.ExitStack()
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
|
||||
import copy
|
||||
import pathlib
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@ -96,55 +97,55 @@ def run_with_nativert(ep):
|
||||
MODEL_NAME = "forward"
|
||||
|
||||
# TODO Does named tempfile have collision?
|
||||
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
|
||||
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
|
||||
torch.export.pt2_archive._package.package_pt2(
|
||||
f, exported_programs={MODEL_NAME: ep_infer}
|
||||
)
|
||||
filename = f.name
|
||||
|
||||
try:
|
||||
ep_args, ep_kwargs = ep_infer.example_inputs
|
||||
ep_args_copied, ep_kwargs_copied = (
|
||||
copy.deepcopy(ep_args),
|
||||
copy.deepcopy(ep_kwargs),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
try:
|
||||
ep_args, ep_kwargs = ep_infer.example_inputs
|
||||
ep_args_copied, ep_kwargs_copied = (
|
||||
copy.deepcopy(ep_args),
|
||||
copy.deepcopy(ep_kwargs),
|
||||
flat_expected = pytree.tree_leaves(
|
||||
ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
try:
|
||||
flat_expected = pytree.tree_leaves(
|
||||
ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
|
||||
)
|
||||
except Exception as e:
|
||||
raise unittest.case.SkipTest(str(e)) from e
|
||||
except Exception as e:
|
||||
raise unittest.case.SkipTest(str(e)) from e
|
||||
|
||||
model_runner = PyModelRunner(filename, MODEL_NAME)
|
||||
torch.manual_seed(0)
|
||||
if _is_supported_types((ep_args, ep_kwargs)):
|
||||
results = model_runner.run(*ep_args, **ep_kwargs)
|
||||
model_runner = PyModelRunner(filename, MODEL_NAME)
|
||||
torch.manual_seed(0)
|
||||
if _is_supported_types((ep_args, ep_kwargs)):
|
||||
results = model_runner.run(*ep_args, **ep_kwargs)
|
||||
else:
|
||||
results = model_runner.run_with_flat_inputs_and_outputs(
|
||||
*pytree.tree_leaves((ep_args, ep_kwargs))
|
||||
)
|
||||
flat_results = pytree.tree_leaves(results)
|
||||
assert len(flat_results) == len(flat_expected)
|
||||
for result, expected in zip(flat_results, flat_expected):
|
||||
assert type(result) is type(expected)
|
||||
if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
|
||||
assert result.shape == expected.shape
|
||||
assert result.dtype == expected.dtype
|
||||
assert result.device == expected.device
|
||||
torch.testing.assert_close(result, expected, equal_nan=True)
|
||||
else:
|
||||
results = model_runner.run_with_flat_inputs_and_outputs(
|
||||
*pytree.tree_leaves((ep_args, ep_kwargs))
|
||||
)
|
||||
flat_results = pytree.tree_leaves(results)
|
||||
assert len(flat_results) == len(flat_expected)
|
||||
for result, expected in zip(flat_results, flat_expected):
|
||||
assert type(result) is type(expected)
|
||||
if isinstance(result, torch.Tensor) and isinstance(
|
||||
expected, torch.Tensor
|
||||
):
|
||||
assert result.shape == expected.shape
|
||||
assert result.dtype == expected.dtype
|
||||
assert result.device == expected.device
|
||||
torch.testing.assert_close(result, expected, equal_nan=True)
|
||||
else:
|
||||
assert result == expected
|
||||
except RuntimeError as e:
|
||||
# User need to register pytree type on the cpp side, which
|
||||
# cannot be tested in python unittest.
|
||||
if "Unknown pytree node type" in str(e):
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
return ep
|
||||
assert result == expected
|
||||
except RuntimeError as e:
|
||||
# User need to register pytree type on the cpp side, which
|
||||
# cannot be tested in python unittest.
|
||||
if "Unknown pytree node type" in str(e):
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
finally:
|
||||
pathlib.Path(filename).unlink(missing_ok=True)
|
||||
return ep
|
||||
|
||||
|
||||
def mocked_nativert_export_strict(*args, **kwargs):
|
||||
@ -286,7 +287,7 @@ class TestNativeRT(TestCase):
|
||||
)
|
||||
|
||||
# package everything needed for the NativeRT to execute the AOTI delegate
|
||||
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
|
||||
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
|
||||
package_nativert_with_aoti_delegate(
|
||||
f,
|
||||
MODEL_NAME,
|
||||
@ -297,48 +298,50 @@ class TestNativeRT(TestCase):
|
||||
)
|
||||
filename = f.name
|
||||
|
||||
try:
|
||||
ep_args, ep_kwargs = aoti_delegate_ep.example_inputs
|
||||
ep_args_copied, ep_kwargs_copied = (
|
||||
copy.deepcopy(ep_args),
|
||||
copy.deepcopy(ep_kwargs),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
try:
|
||||
ep_args, ep_kwargs = aoti_delegate_ep.example_inputs
|
||||
ep_args_copied, ep_kwargs_copied = (
|
||||
copy.deepcopy(ep_args),
|
||||
copy.deepcopy(ep_kwargs),
|
||||
flat_expected = pytree.tree_leaves(
|
||||
aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied)
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
try:
|
||||
flat_expected = pytree.tree_leaves(
|
||||
aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied)
|
||||
)
|
||||
except Exception as e:
|
||||
raise unittest.case.SkipTest(str(e)) from e
|
||||
except Exception as e:
|
||||
raise unittest.case.SkipTest(str(e)) from e
|
||||
|
||||
model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}")
|
||||
torch.manual_seed(0)
|
||||
if _is_supported_types((ep_args, ep_kwargs)):
|
||||
results = model_runner.run(*ep_args, **ep_kwargs)
|
||||
model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}")
|
||||
torch.manual_seed(0)
|
||||
if _is_supported_types((ep_args, ep_kwargs)):
|
||||
results = model_runner.run(*ep_args, **ep_kwargs)
|
||||
else:
|
||||
results = model_runner.run_with_flat_inputs_and_outputs(
|
||||
*pytree.tree_leaves((ep_args, ep_kwargs))
|
||||
)
|
||||
flat_results = pytree.tree_leaves(results)
|
||||
assert len(flat_results) == len(flat_expected)
|
||||
for result, expected in zip(flat_results, flat_expected):
|
||||
assert type(result) is type(expected)
|
||||
if isinstance(result, torch.Tensor) and isinstance(
|
||||
expected, torch.Tensor
|
||||
):
|
||||
assert result.shape == expected.shape
|
||||
assert result.dtype == expected.dtype
|
||||
assert result.device == expected.device
|
||||
torch.testing.assert_close(result, expected, equal_nan=True)
|
||||
else:
|
||||
results = model_runner.run_with_flat_inputs_and_outputs(
|
||||
*pytree.tree_leaves((ep_args, ep_kwargs))
|
||||
)
|
||||
flat_results = pytree.tree_leaves(results)
|
||||
assert len(flat_results) == len(flat_expected)
|
||||
for result, expected in zip(flat_results, flat_expected):
|
||||
assert type(result) is type(expected)
|
||||
if isinstance(result, torch.Tensor) and isinstance(
|
||||
expected, torch.Tensor
|
||||
):
|
||||
assert result.shape == expected.shape
|
||||
assert result.dtype == expected.dtype
|
||||
assert result.device == expected.device
|
||||
torch.testing.assert_close(result, expected, equal_nan=True)
|
||||
else:
|
||||
assert result == expected
|
||||
except RuntimeError as e:
|
||||
# User need to register pytree type on the cpp side, which
|
||||
# cannot be tested in python unittest.
|
||||
if "Unknown pytree node type" in str(e):
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
assert result == expected
|
||||
except RuntimeError as e:
|
||||
# User need to register pytree type on the cpp side, which
|
||||
# cannot be tested in python unittest.
|
||||
if "Unknown pytree node type" in str(e):
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
finally:
|
||||
pathlib.Path(filename).unlink(missing_ok=True)
|
||||
|
||||
|
||||
if is_fbcode():
|
||||
|
||||
@ -4,7 +4,6 @@ from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
@ -40,56 +39,6 @@ class TestHopPrint(TestCase):
|
||||
|
||||
self.assertEqual(printed_output, "moo 1 2")
|
||||
|
||||
fx_f = make_fx(f)(x)
|
||||
new_inp = torch.randn(3, 3)
|
||||
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
fx_f(new_inp)
|
||||
ori_printed_output = mock_stdout.getvalue().strip()
|
||||
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
f(new_inp)
|
||||
fx_printed_output = mock_stdout.getvalue().strip()
|
||||
|
||||
self.assertEqual(ori_printed_output, fx_printed_output)
|
||||
|
||||
def test_print_with_proxy_graph(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
|
||||
torch._higher_order_ops.print("moo {x}", x=x)
|
||||
res = x + x
|
||||
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
|
||||
torch._higher_order_ops.print("yeehop {x}", x=x.shape[0])
|
||||
return (res,)
|
||||
|
||||
inputs = (torch.randn(3),)
|
||||
|
||||
# Without functionalization, print should just appear in the graph directly
|
||||
gm = make_fx(M(), tracing_mode="symbolic")(*inputs)
|
||||
|
||||
self.assertExpectedInline(
|
||||
str(gm.code).strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1):
|
||||
print_1 = torch.ops.higher_order.print('moo {x} {y}', x = 1, y = 2); print_1 = None
|
||||
print_2 = torch.ops.higher_order.print('moo {x}', x = arg0_1); print_2 = None
|
||||
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1)
|
||||
print_3 = torch.ops.higher_order.print('moo {x} {y}', x = 1, y = 2); print_3 = None
|
||||
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0); arg0_1 = None
|
||||
print_4 = torch.ops.higher_order.print('yeehop {x}', x = sym_size_int); sym_size_int = print_4 = None
|
||||
return (add,)""",
|
||||
)
|
||||
|
||||
new_inp = torch.randn(4)
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
gm(
|
||||
new_inp,
|
||||
)
|
||||
printed_output = mock_stdout.getvalue().strip()
|
||||
|
||||
self.assertEqual(printed_output, f"moo 1 2\nmoo {new_inp}\nmoo 1 2\nyeehop 4")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -206,10 +206,6 @@ class TestPyCodeCache(TestCase):
|
||||
.decode()
|
||||
.strip()
|
||||
)
|
||||
# XPU have extra lines, so get the last line, refer https://github.com/intel/torch-xpu-ops/issues/2261
|
||||
if torch.xpu.is_available():
|
||||
wrapper_path = wrapper_path.splitlines()[-1]
|
||||
hit = hit.splitlines()[-1]
|
||||
self.assertEqual(hit, "1")
|
||||
|
||||
with open(wrapper_path) as f:
|
||||
|
||||
@ -73,23 +73,6 @@ class TestCompileWorker(TestCase):
|
||||
finally:
|
||||
pool.shutdown()
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_quiesce_repeatedly(self):
|
||||
pool = SubprocPool(2)
|
||||
try:
|
||||
a = pool.submit(operator.add, 100, 1)
|
||||
pool.quiesce()
|
||||
pool.wakeup()
|
||||
b = pool.submit(operator.sub, 100, 1)
|
||||
pool.quiesce()
|
||||
pool.quiesce()
|
||||
pool.wakeup()
|
||||
b = pool.submit(operator.sub, 100, 1)
|
||||
self.assertEqual(a.result(), 101)
|
||||
self.assertEqual(b.result(), 99)
|
||||
finally:
|
||||
pool.shutdown()
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_logging(self):
|
||||
os.environ["MAST_HPC_JOB_NAME"] = "test_job"
|
||||
|
||||
@ -5222,7 +5222,6 @@ xfail_by_backend = {
|
||||
"test_reentrant_with_callbacks_both_depths", # queue_callback
|
||||
"test_reentrant_with_callbacks_depth_0", # queue_callback
|
||||
"test_reentrant_with_callbacks_depth_1", # queue_callback
|
||||
"test_checkpoint_graph_execution_group", # Attempted to call function marked as skipped
|
||||
"test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook
|
||||
"test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
|
||||
"test_post_accumulate_grad_hook_ordering", # accuracy error
|
||||
|
||||
@ -3278,15 +3278,6 @@ class CPUReproTests(TestCase):
|
||||
metrics.reset()
|
||||
self.common(fn, (x,))
|
||||
|
||||
def test_softmax_with_zero_dim(self):
|
||||
def fn(x):
|
||||
x = torch.softmax(x, 0)
|
||||
return x
|
||||
|
||||
x = torch.rand([], dtype=torch.bfloat16)
|
||||
metrics.reset()
|
||||
self.common(fn, (x,))
|
||||
|
||||
@config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False})
|
||||
def test_local_buffer_in_outer_loop_fusion(self):
|
||||
def fn(x):
|
||||
|
||||
@ -148,24 +148,6 @@ class FxirTestCase(InductorTestCase):
|
||||
args = [torch.randn(8, device=self.device) for _ in range(2)]
|
||||
self._compile_and_check(torch.add, args)
|
||||
|
||||
def test_device_type(self):
|
||||
"""
|
||||
Test that we allocate on a device type instead of a specific index.
|
||||
"""
|
||||
# Pass in a tensor on an indexed device.
|
||||
device_runtime = getattr(torch, self.device)
|
||||
indexed_device = torch.device(self.device, device_runtime.current_device())
|
||||
args = [torch.randn(8, device=indexed_device) for _ in range(2)]
|
||||
(gm,) = self._compile_and_check(torch.add, args)
|
||||
(empty_strided,) = gm.graph.find_nodes(
|
||||
op="call_function", target=torch.empty_strided
|
||||
)
|
||||
|
||||
# Check that the device of the output allocation is not indexed.
|
||||
output_device = torch.device(empty_strided.kwargs["device"])
|
||||
self.assertIs(output_device.index, None)
|
||||
self.assertEqual(output_device.type, indexed_device.type)
|
||||
|
||||
def test_multiple_kernels(self):
|
||||
def foo(x, y):
|
||||
return x.sum() + y.sum()
|
||||
|
||||
@ -3,8 +3,7 @@
|
||||
import functools
|
||||
import weakref
|
||||
from collections import Counter
|
||||
from collections.abc import Callable
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch._inductor.fx_passes.memory_estimator import (
|
||||
@ -29,7 +28,7 @@ def device_filter(device):
|
||||
|
||||
|
||||
class FakeTensorMemoryProfilerMode(TorchDispatchMode):
|
||||
def __init__(self, device_filter: Optional[Callable[[torch.device], bool]] = None):
|
||||
def __init__(self, device_filter: Optional[Callable[torch.device, bool]] = None):
|
||||
# counter of storage ids to live references
|
||||
self.storage_count: dict[int, int] = Counter()
|
||||
# live fake tensors
|
||||
|
||||
@ -52,12 +52,9 @@ def make_pallas(cls):
|
||||
return test_class
|
||||
|
||||
|
||||
class PallasTestsMixin:
|
||||
"""Basic tests for Pallas backend functionality (parameterized by DEVICE). Mixin only, not collected."""
|
||||
|
||||
def _compile(self, fn):
|
||||
key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"
|
||||
return torch.compile(fn, backend="inductor", options={key: "pallas"})
|
||||
@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
|
||||
class PallasTests(TestCase):
|
||||
"""Basic tests for Pallas backend functionality."""
|
||||
|
||||
def test_simple_add(self):
|
||||
"""Test basic element-wise addition."""
|
||||
@ -65,10 +62,12 @@ class PallasTestsMixin:
|
||||
def fn(a, b):
|
||||
return a + b
|
||||
|
||||
compiled = self._compile(fn)
|
||||
compiled = torch.compile(
|
||||
fn, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
a = torch.randn(1024, device=self.DEVICE)
|
||||
b = torch.randn(1024, device=self.DEVICE)
|
||||
a = torch.randn(1024, device="cuda")
|
||||
b = torch.randn(1024, device="cuda")
|
||||
result = compiled(a, b)
|
||||
expected = fn(a, b)
|
||||
self.assertEqual(result, expected)
|
||||
@ -79,10 +78,12 @@ class PallasTestsMixin:
|
||||
def fn(a, b):
|
||||
return a * b
|
||||
|
||||
compiled = self._compile(fn)
|
||||
compiled = torch.compile(
|
||||
fn, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
a = torch.randn(1024, device=self.DEVICE)
|
||||
b = torch.randn(1024, device=self.DEVICE)
|
||||
a = torch.randn(1024, device="cuda")
|
||||
b = torch.randn(1024, device="cuda")
|
||||
result = compiled(a, b)
|
||||
expected = fn(a, b)
|
||||
self.assertEqual(result, expected)
|
||||
@ -93,9 +94,11 @@ class PallasTestsMixin:
|
||||
def fn(x):
|
||||
return torch.sin(x)
|
||||
|
||||
compiled = self._compile(fn)
|
||||
compiled = torch.compile(
|
||||
fn, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
x = torch.randn(1024, device=self.DEVICE)
|
||||
x = torch.randn(1024, device="cuda")
|
||||
result = compiled(x)
|
||||
expected = fn(x)
|
||||
self.assertEqual(result, expected)
|
||||
@ -106,10 +109,12 @@ class PallasTestsMixin:
|
||||
def fn(x, y):
|
||||
return x.sin() + y
|
||||
|
||||
compiled = self._compile(fn)
|
||||
compiled = torch.compile(
|
||||
fn, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
x = torch.randn(1024, device=self.DEVICE)
|
||||
y = torch.randn(1024, device=self.DEVICE)
|
||||
x = torch.randn(1024, device="cuda")
|
||||
y = torch.randn(1024, device="cuda")
|
||||
result = compiled(x, y)
|
||||
expected = fn(x, y)
|
||||
self.assertEqual(result, expected)
|
||||
@ -120,9 +125,11 @@ class PallasTestsMixin:
|
||||
def fn(x):
|
||||
return torch.log(torch.exp(x))
|
||||
|
||||
compiled = self._compile(fn)
|
||||
compiled = torch.compile(
|
||||
fn, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
x = torch.randn(1024, device=self.DEVICE)
|
||||
x = torch.randn(1024, device="cuda")
|
||||
result = compiled(x)
|
||||
expected = fn(x)
|
||||
self.assertEqual(result, expected)
|
||||
@ -133,9 +140,11 @@ class PallasTestsMixin:
|
||||
def fn(x):
|
||||
return torch.sqrt(x)
|
||||
|
||||
compiled = self._compile(fn)
|
||||
compiled = torch.compile(
|
||||
fn, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
x = torch.randn(1024, device=self.DEVICE).abs() # Ensure positive for sqrt
|
||||
x = torch.randn(1024, device="cuda").abs() # Ensure positive for sqrt
|
||||
result = compiled(x)
|
||||
expected = fn(x)
|
||||
self.assertEqual(result, expected)
|
||||
@ -146,9 +155,11 @@ class PallasTestsMixin:
|
||||
def fn(x):
|
||||
return torch.tanh(x)
|
||||
|
||||
compiled = self._compile(fn)
|
||||
compiled = torch.compile(
|
||||
fn, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
x = torch.randn(1024, device=self.DEVICE)
|
||||
x = torch.randn(1024, device="cuda")
|
||||
result = compiled(x)
|
||||
expected = fn(x)
|
||||
self.assertEqual(result, expected)
|
||||
@ -159,9 +170,11 @@ class PallasTestsMixin:
|
||||
def fn(x):
|
||||
return torch.abs(-x)
|
||||
|
||||
compiled = self._compile(fn)
|
||||
compiled = torch.compile(
|
||||
fn, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
x = torch.randn(1024, device=self.DEVICE)
|
||||
x = torch.randn(1024, device="cuda")
|
||||
result = compiled(x)
|
||||
expected = fn(x)
|
||||
self.assertEqual(result, expected)
|
||||
@ -172,10 +185,12 @@ class PallasTestsMixin:
|
||||
def fn(a, b):
|
||||
return torch.maximum(a, b) + torch.minimum(a, b)
|
||||
|
||||
compiled = self._compile(fn)
|
||||
compiled = torch.compile(
|
||||
fn, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
a = torch.randn(1024, device=self.DEVICE)
|
||||
b = torch.randn(1024, device=self.DEVICE)
|
||||
a = torch.randn(1024, device="cuda")
|
||||
b = torch.randn(1024, device="cuda")
|
||||
result = compiled(a, b)
|
||||
expected = fn(a, b)
|
||||
self.assertEqual(result, expected)
|
||||
@ -213,17 +228,15 @@ class PallasTestsMixin:
|
||||
|
||||
@torch.compile(
|
||||
backend="inductor",
|
||||
options={
|
||||
("cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"): "pallas"
|
||||
},
|
||||
options={"cuda_backend": "pallas"},
|
||||
)
|
||||
def pallas_fn(a, b):
|
||||
return a.sin() + b.cos()
|
||||
|
||||
_, (code,) = run_and_get_code(
|
||||
pallas_fn,
|
||||
torch.randn(64, device=self.DEVICE),
|
||||
torch.randn(64, device=self.DEVICE),
|
||||
torch.randn(64, device="cuda"),
|
||||
torch.randn(64, device="cuda"),
|
||||
)
|
||||
# Verify Pallas-specific code generation
|
||||
self.assertIn("import jax", code)
|
||||
@ -236,10 +249,12 @@ class PallasTestsMixin:
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
compiled = self._compile(fn)
|
||||
compiled = torch.compile(
|
||||
fn, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
x = torch.randn(32, 32, device=self.DEVICE)
|
||||
y = torch.randn(32, 32, device=self.DEVICE)
|
||||
x = torch.randn(32, 32, device="cuda")
|
||||
y = torch.randn(32, 32, device="cuda")
|
||||
result = compiled(x, y)
|
||||
expected = fn(x, y)
|
||||
self.assertEqual(result, expected)
|
||||
@ -250,10 +265,12 @@ class PallasTestsMixin:
|
||||
def fn(x):
|
||||
return x * 2.0
|
||||
|
||||
compiled = self._compile(fn)
|
||||
compiled = torch.compile(
|
||||
fn, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
for shape in [(64,), (128,), (256,), (1024,)]:
|
||||
x = torch.randn(shape, device=self.DEVICE)
|
||||
x = torch.randn(shape, device="cuda")
|
||||
result = compiled(x)
|
||||
expected = fn(x)
|
||||
self.assertEqual(result, expected)
|
||||
@ -265,10 +282,12 @@ class PallasTestsMixin:
|
||||
def contiguous_add(a, b):
|
||||
return a + b
|
||||
|
||||
compiled = self._compile(contiguous_add)
|
||||
compiled = torch.compile(
|
||||
contiguous_add, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
a = torch.randn(1024, device=self.DEVICE)
|
||||
b = torch.randn(1024, device=self.DEVICE)
|
||||
a = torch.randn(1024, device="cuda")
|
||||
b = torch.randn(1024, device="cuda")
|
||||
result = compiled(a, b)
|
||||
expected = contiguous_add(a, b)
|
||||
self.assertEqual(result, expected)
|
||||
@ -277,9 +296,11 @@ class PallasTestsMixin:
|
||||
def contiguous_mul(x):
|
||||
return x * 2.0
|
||||
|
||||
compiled = self._compile(contiguous_mul)
|
||||
compiled = torch.compile(
|
||||
contiguous_mul, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
x = torch.randn(128, 8, device=self.DEVICE)
|
||||
x = torch.randn(128, 8, device="cuda")
|
||||
result = compiled(x)
|
||||
expected = contiguous_mul(x)
|
||||
self.assertEqual(result, expected)
|
||||
@ -289,10 +310,12 @@ class PallasTestsMixin:
|
||||
def operate_on_tensor(x):
|
||||
return x.sin()
|
||||
|
||||
compiled = self._compile(operate_on_tensor)
|
||||
compiled = torch.compile(
|
||||
operate_on_tensor, backend="inductor", options={"cuda_backend": "pallas"}
|
||||
)
|
||||
|
||||
# Create a transposed (non-contiguous) view
|
||||
x = torch.randn(64, 32, device=self.DEVICE)
|
||||
x = torch.randn(64, 32, device="cuda")
|
||||
x_t = x.t() # Non-contiguous view
|
||||
self.assertFalse(x_t.is_contiguous())
|
||||
|
||||
@ -309,24 +332,13 @@ class PallasTestsMixin:
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
|
||||
@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
|
||||
class PallasTestsCUDA(PallasTestsMixin, TestCase):
|
||||
DEVICE = "cuda"
|
||||
|
||||
|
||||
@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
|
||||
class PallasTestsCPU(PallasTestsMixin, TestCase):
|
||||
DEVICE = "cpu"
|
||||
|
||||
|
||||
# Create test variants using the main test suite
|
||||
# Note: Only enable GPU tests since Pallas primarily targets GPU
|
||||
if hasattr(sys.modules.get(__name__), "test_torchinductor") and HAS_PALLAS:
|
||||
if getattr(test_torchinductor, "HAS_GPU", False):
|
||||
# Uncomment these to run full test suite with Pallas backend
|
||||
# make_pallas(test_torchinductor.SweepInputsGPUTest)
|
||||
# make_pallas(test_torchinductor.GPUTests)
|
||||
pass
|
||||
if test_torchinductor.HAS_GPU and HAS_PALLAS:
|
||||
# Uncomment these to run full test suite with Pallas backend
|
||||
# make_pallas(test_torchinductor.SweepInputsGPUTest)
|
||||
# make_pallas(test_torchinductor.GPUTests)
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
if HAS_PALLAS:
|
||||
|
||||
@ -1217,43 +1217,6 @@ class TestPatternMatcher(TestCase):
|
||||
_, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
|
||||
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
|
||||
|
||||
def test_addmm_alpha_beta_with_pointwise(self):
|
||||
# Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops
|
||||
# See https://github.com/pytorch/pytorch/issues/167313
|
||||
x = torch.rand(2, device=GPU_TYPE)
|
||||
a = torch.rand(2, 3, device=GPU_TYPE)
|
||||
b = torch.rand(3, 2, device=GPU_TYPE)
|
||||
|
||||
def f(x, a, b):
|
||||
return torch.nn.functional.relu(torch.addmm(x, a, b, alpha=0.8, beta=0.2))
|
||||
|
||||
fc = torch.compile(f)
|
||||
|
||||
expected = f(x, a, b)
|
||||
actual = fc(x, a, b)
|
||||
|
||||
# The compiled version should produce the same result as eager
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
# Verify that addmm is unfused (should not use extern_kernels.addmm)
|
||||
# The pattern should be replaced with beta * x + alpha * (a @ b)
|
||||
_, (code) = run_and_get_code(fc, x, a, b)
|
||||
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
|
||||
|
||||
# Test with alpha=1, beta=1 (default) - should also unfuse
|
||||
def f_default(x, a, b):
|
||||
return torch.nn.functional.relu(torch.addmm(x, a, b))
|
||||
|
||||
fc_default = torch.compile(f_default)
|
||||
expected_default = f_default(x, a, b)
|
||||
actual_default = fc_default(x, a, b)
|
||||
|
||||
torch.testing.assert_close(actual_default, expected_default)
|
||||
|
||||
# Should unfuse and not use extern_kernels.addmm
|
||||
_, (code) = run_and_get_code(fc_default, x, a, b)
|
||||
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
|
||||
|
||||
def test_serialized_patterns_up_to_date(self):
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._inductor.fx_passes import joint_graph
|
||||
|
||||
@ -7,7 +7,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
class TestUpgraderModelGeneration(TestCase):
|
||||
def test_all_modules(self):
|
||||
for a_module in ALL_MODULES:
|
||||
for a_module in ALL_MODULES.keys():
|
||||
module_name = type(a_module).__name__
|
||||
self.assertTrue(
|
||||
isinstance(a_module, torch.nn.Module),
|
||||
|
||||
@ -2979,7 +2979,7 @@ class TestScriptList(JitTestCase):
|
||||
self.col2 = "b"
|
||||
|
||||
def forward(self):
|
||||
if self.col1 in self.segments_groupby_col:
|
||||
if self.col1 in self.segments_groupby_col.keys():
|
||||
return 1
|
||||
else:
|
||||
return 2
|
||||
|
||||
@ -78,7 +78,7 @@ class TestModuleContainers(JitTestCase):
|
||||
x = mod(x)
|
||||
values.append(x)
|
||||
|
||||
for key in self.moduledict:
|
||||
for key in self.moduledict.keys():
|
||||
names.append(key)
|
||||
|
||||
return x, names
|
||||
@ -306,7 +306,7 @@ class TestModuleContainers(JitTestCase):
|
||||
|
||||
assert "submod" in self.moduledict, "__contains__ fails for ModuleDict"
|
||||
|
||||
for key in self.moduledict:
|
||||
for key in self.moduledict.keys():
|
||||
assert key == "submod", "keys() fails for ModuleDict"
|
||||
|
||||
for item in self.moduledict.items():
|
||||
|
||||
@ -276,7 +276,7 @@ class TestPDT(JitTestCase):
|
||||
def test_multiple_class_with_same_method(self):
|
||||
class PDTModelOne:
|
||||
def test_find(self, a, b):
|
||||
return b in a
|
||||
return b in a.keys()
|
||||
|
||||
class PDTModelTwo:
|
||||
def test_find(self, a, b):
|
||||
|
||||
@ -342,7 +342,7 @@ class TestTyping(JitTestCase):
|
||||
# type: (Dict[str, int]) -> Tuple[str, int]
|
||||
key_str = ""
|
||||
sum = 0
|
||||
for key in x:
|
||||
for key in x.keys():
|
||||
key_str += key
|
||||
for val in x.values():
|
||||
sum += val
|
||||
|
||||
@ -310,7 +310,7 @@ class TestLoadStateDict(NNTestCase):
|
||||
|
||||
# Make sure parameters and persistent buffers were assigned
|
||||
net_meta_state_dict = net_meta.state_dict(keep_vars=True)
|
||||
for key in state_dict:
|
||||
for key in state_dict.keys():
|
||||
if key in net_meta._parameters:
|
||||
if keep_vars and not swap:
|
||||
# state_dict[key] is an nn.Parameter
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user