Compare commits

..

1 Commits

Author SHA1 Message Date
854b40f81f hipsparselt support in cuda_to_hip_mappings.py 2025-11-07 16:39:12 +00:00
290 changed files with 3392 additions and 5822 deletions

View File

@ -36,7 +36,11 @@ case ${DOCKER_TAG_PREFIX} in
;;
rocm*)
BASE_TARGET=rocm
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
;;
*)

View File

@ -260,12 +260,6 @@ case "$tag" in
HALIDE=yes
TRITON=yes
;;
pytorch-linux-jammy-cuda12.8-py3.12-pallas)
CUDA_VERSION=12.8.1
ANACONDA_PYTHON_VERSION=3.12
GCC_VERSION=11
PALLAS=yes
;;
pytorch-linux-jammy-py3.12-triton-cpu)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.12
@ -387,7 +381,6 @@ docker build \
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
--build-arg "EXECUTORCH=${EXECUTORCH}" \
--build-arg "HALIDE=${HALIDE}" \
--build-arg "PALLAS=${PALLAS}" \
--build-arg "XPU_VERSION=${XPU_VERSION}" \
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
--build-arg "ACL=${ACL:-}" \

View File

@ -1 +0,0 @@
0.8.0

View File

@ -1,40 +0,0 @@
#!/bin/bash
set -ex
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
# Get the pinned JAX version (same for all CUDA versions)
JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
function install_jax_12() {
echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Verify installation
python -c "import jax" # check for errors
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
}
function install_jax_13() {
echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Verify installation
python -c "import jax" # check for errors
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
}
# idiomatic parameter and option handling in sh
while test $# -gt 0
do
case "$1" in
12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
;;
13.0|13.0.*) install_jax_13;
;;
*) echo "bad argument $1"; exit 1
;;
esac
shift
done

View File

@ -49,7 +49,11 @@ case ${DOCKER_TAG_PREFIX} in
fi
BASE_TARGET=rocm
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}"
;;
*)

View File

@ -87,7 +87,11 @@ case ${image} in
MANY_LINUX_VERSION="2_28"
DEVTOOLSET_VERSION="11"
GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
;;
manylinux2_28-builder:xpu)

View File

@ -143,15 +143,6 @@ COPY ci_commit_pins/halide.txt halide.txt
RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
RUN rm install_halide.sh common_utils.sh halide.txt
ARG PALLAS
ARG CUDA_VERSION
# Install JAX with CUDA support (for Pallas)
COPY ./common/install_jax.sh install_jax.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
ARG ONNX
# Install ONNX dependencies
COPY ./common/install_onnx.sh ./common/common_utils.sh ./

View File

@ -8,11 +8,9 @@ from abc import ABC, abstractmethod
try:
from collections.abc import Callable # Python 3.11+
from typing import Any, Required, TypedDict
from typing import Any, Callable, Required, TypedDict # Python 3.11+
except ImportError:
from collections.abc import Callable
from typing import Any, TypedDict
from typing import Any, Callable, TypedDict
from typing_extensions import Required # Fallback for Python <3.11

View File

@ -168,16 +168,14 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
# shellcheck disable=SC1091
source /opt/intel/oneapi/compiler/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/umf/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/ccl/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/mpi/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/pti/latest/env/vars.sh
# Enable XCCL build
export USE_XCCL=1
export USE_MPI=0
# XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA
export USE_KINETO=0
export TORCH_XPU_ARCH_LIST=pvc
fi

View File

@ -208,8 +208,6 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
source /opt/intel/oneapi/ccl/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/mpi/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/pti/latest/env/vars.sh
# Check XPU status before testing
timeout 30 xpu-smi discovery || true
fi
@ -826,11 +824,6 @@ test_inductor_halide() {
assert_git_not_dirty
}
test_inductor_pallas() {
python test/run_test.py --include inductor/test_pallas.py --verbose
assert_git_not_dirty
}
test_inductor_triton_cpu() {
python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
assert_git_not_dirty
@ -1731,8 +1724,6 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
test_inductor_halide
elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
test_inductor_pallas
elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
test_inductor_triton_cpu
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then

View File

@ -1 +1 @@
ccb801b88af136454798b945175c4c87e636ac33
ca2212438fdd8ce29b66999ed70ed54b0f9372d1

View File

@ -1 +1 @@
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9

9
.github/labeler.yml vendored
View File

@ -138,8 +138,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -149,8 +148,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -160,8 +158,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm

View File

@ -10,4 +10,3 @@
pathFilter:
- 'torch/csrc/inductor/aoti_torch/c/*'
- 'torch/csrc/inductor/aoti_torch/generated/*'
- 'torch/csrc/stable/c/*'

View File

@ -1,11 +1,10 @@
# Delete old branches
import os
import re
from collections.abc import Callable
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import Any, Callable
from github_utils import gh_fetch_json_dict, gh_graphql
from gitutils import GitRepo

View File

@ -8,11 +8,10 @@ import re
import subprocess
import sys
import warnings
from collections.abc import Callable
from enum import Enum
from functools import cache
from logging import info
from typing import Any, Optional
from typing import Any, Callable, Optional
from urllib.request import Request, urlopen
import yaml

View File

@ -11,8 +11,7 @@ import sys
import time
import urllib
import urllib.parse
from collections.abc import Callable
from typing import Any, Optional
from typing import Any, Callable, Optional
from urllib.request import Request, urlopen

View File

@ -3,9 +3,8 @@
import json
import os
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, cast, Optional, Union
from typing import Any, Callable, cast, Optional, Union
from urllib.error import HTTPError
from urllib.parse import quote
from urllib.request import Request, urlopen

View File

@ -4,10 +4,10 @@ import os
import re
import tempfile
from collections import defaultdict
from collections.abc import Callable, Iterator
from collections.abc import Iterator
from datetime import datetime
from functools import wraps
from typing import Any, cast, Optional, TypeVar, Union
from typing import Any, Callable, cast, Optional, TypeVar, Union
T = TypeVar("T")

View File

@ -17,12 +17,12 @@ import re
import time
import urllib.parse
from collections import defaultdict
from collections.abc import Callable, Iterable
from collections.abc import Iterable
from dataclasses import dataclass
from functools import cache
from pathlib import Path
from re import Pattern
from typing import Any, cast, NamedTuple, Optional
from typing import Any, Callable, cast, NamedTuple, Optional
from warnings import warn
import yaml

View File

@ -67,7 +67,6 @@ jobs:
pytorch-linux-jammy-py3.10-gcc11,
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-cuda12.8-py3.12-pallas,
pytorch-linux-jammy-xpu-n-1-py3,
pytorch-linux-noble-xpu-n-py3,
pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,

View File

@ -81,32 +81,6 @@ jobs:
test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
secrets: inherit
inductor-pallas-build:
name: inductor-pallas-build
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
build-environment: linux-jammy-cuda12.8-py3.12-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
cuda-arch-list: '8.9'
runner: linux.8xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" },
]}
secrets: inherit
inductor-pallas-test:
name: inductor-pallas-test
uses: ./.github/workflows/_linux-test.yml
needs: inductor-pallas-build
with:
build-environment: linux-jammy-py3.12-gcc11
docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
secrets: inherit
inductor-triton-cpu-build:
name: inductor-triton-cpu-build
uses: ./.github/workflows/_linux-build.yml

View File

@ -1402,7 +1402,7 @@ init_command = [
'--dry-run={{DRYRUN}}',
'usort==1.0.8.post1',
'isort==6.0.1',
'ruff==0.14.4', # sync with RUFF
'ruff==0.13.1', # sync with RUFF
]
is_formatter = true
@ -1537,7 +1537,7 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'ruff==0.14.4', # sync with PYFMT
'ruff==0.13.1', # sync with PYFMT
]
is_formatter = true

View File

@ -210,12 +210,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg
# Low Precision & Grouped GEMMs
# Low Precision GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58

View File

@ -94,11 +94,6 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
}
TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
}
} // namespace at::accelerator
namespace at {

View File

@ -226,8 +226,8 @@ template <
typename B = HostBlock<S>>
struct CachingHostAllocatorImpl {
virtual ~CachingHostAllocatorImpl() {
if (active_) {
active_ = false;
active_ = false;
if (pinned_use_background_threads()) {
getBackgroundThreadPool()->waitWorkComplete();
}
}
@ -260,7 +260,6 @@ struct CachingHostAllocatorImpl {
if (pinned_use_background_threads()) {
// Launch the background thread and process events in a loop.
static bool background_thread_flag [[maybe_unused]] = [this] {
active_ = true;
getBackgroundThreadPool()->run([&]() {
while (active_) {
process_events();
@ -684,9 +683,9 @@ struct CachingHostAllocatorImpl {
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the event-processing thread pool is active.
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{false};
std::atomic<bool> active_{true};
protected:
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};

View File

@ -157,8 +157,6 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
DispatchKey::Negative,
DispatchKey::Conjugate,
DispatchKey::XLA,
DispatchKey::XPU,
DispatchKey::HPU,
DispatchKey::CUDA,
DispatchKey::CPU,
DispatchKey::PrivateUse1,

View File

@ -141,9 +141,6 @@ static Tensor& addmv_out_mps_impl(const Tensor& self,
};
MPSStream* stream = at::mps::getCurrentMPSStream();
if (result.numel() == 0) {
return result;
}
Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1);
@autoreleasepool {

View File

@ -2803,7 +2803,7 @@
- func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA, MPS, MTIA: floor_divide_out
CPU, CUDA, MPS: floor_divide_out
SparseCPU, SparseCUDA, SparseMPS: floor_divide_out_sparse_zerodim
- func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor
@ -4292,7 +4292,6 @@
dispatch:
SparseCPU: sparse_sparse_matmul_cpu
SparseCUDA: sparse_sparse_matmul_cuda
SparseMPS: sparse_sparse_matmul_mps
autogen: _sparse_sparse_matmul.out
- func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
@ -4384,7 +4383,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: mv
SparseCPU, SparseCUDA, SparseMPS: mv_sparse
SparseCPU, SparseCUDA: mv_sparse
- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
@ -9833,7 +9832,7 @@
structured_delegate: erfinv.out
variants: method, function
dispatch:
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse
SparseCPU, SparseCUDA: erfinv_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr
tags: pointwise
@ -9842,7 +9841,7 @@
structured_delegate: erfinv.out
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_
SparseCPU, SparseCUDA: erfinv_sparse_
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_
tags: pointwise
@ -9852,7 +9851,7 @@
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS: erfinv_out
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_out
SparseCPU, SparseCUDA: erfinv_sparse_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out
tags: pointwise

View File

@ -10,10 +10,6 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/repeat_interleave_native.h>
#include <ATen/ops/cumsum.h>
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/add_native.h>
@ -892,114 +888,5 @@ static void sparse_mask_intersection_out_mps_kernel(
/*coalesce_mask=*/false);
}
Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
TORCH_CHECK(mat1_.is_sparse() && mat2_.is_sparse(),
"sparse_sparse_matmul_mps: both inputs must be sparse COO tensors");
TORCH_CHECK(mat1_.is_mps() && mat2_.is_mps(),
"sparse_sparse_matmul_mps: both inputs must be on MPS device");
TORCH_CHECK(mat1_.dim() == 2 && mat2_.dim() == 2,
"sparse_sparse_matmul_mps: both inputs must be 2D matrices");
TORCH_CHECK(mat1_.dense_dim() == 0 && mat2_.dense_dim() == 0,
"sparse_sparse_matmul_mps: only scalar values supported (dense_dim == 0)");
TORCH_CHECK(mat1_.size(1) == mat2_.size(0),
"mat1 and mat2 shapes cannot be multiplied (", mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");
TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
"sparse_sparse_matmul_mps: mat1 dtype ", mat1_.scalar_type(),
" does not match mat2 dtype ", mat2_.scalar_type());
const auto device = mat1_.device();
auto A = mat1_.coalesce();
auto B = mat2_.coalesce();
const auto I = A.size(0);
const auto K = A.size(1);
const auto N = B.size(1);
const auto nnzA = A._nnz();
const auto nnzB = B._nnz();
// Early empty result, return an empty, coalesced tensor
if (I == 0 || N == 0 || K == 0 || nnzA == 0 || nnzB == 0) {
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
const auto computeDtype = at::result_type(mat1_, mat2_);
auto A_idx = A._indices().contiguous();
auto A_val = A._values().to(computeDtype).contiguous();
auto A_i = A_idx.select(0, 0).contiguous();
auto A_k = A_idx.select(0, 1).contiguous();
auto B_idx = B._indices().contiguous();
auto B_val = B._values().to(computeDtype).contiguous();
auto B_k = B_idx.select(0, 0).contiguous();
auto B_j = B_idx.select(0, 1).contiguous();
// csr-style row pointers for B by k (the shared dimension)
Tensor row_ptr_B;
{
auto batch_ptr = at::tensor({0LL, nnzB}, at::device(device).dtype(at::kLong));
row_ptr_B = at::empty({K + 1}, at::device(device).dtype(at::kLong));
build_row_ptr_per_batch_mps(B_k, batch_ptr, /*B=*/1, /*I=*/K, row_ptr_B);
}
auto row_ptr_B_lo = row_ptr_B.narrow(0, 0, K);
auto row_ptr_B_hi = row_ptr_B.narrow(0, 1, K);
auto deg_B = row_ptr_B_hi.sub(row_ptr_B_lo);
auto counts = deg_B.index_select(0, A_k);
const int64_t P = counts.sum().item<int64_t>();
if (P == 0) {
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
auto group_ids = repeat_interleave_mps(counts);
// exclusive cumsum of counts
auto offsets = cumsum(counts, /*dim=*/0).sub(counts);
auto offsets_gather = offsets.index_select(0, group_ids);
auto within = at::arange(P, at::device(device).dtype(at::kLong)).sub(offsets_gather);
// Map each output element to its source B row and position
auto k_per_out = A_k.index_select(0, group_ids);
auto start_in_B = row_ptr_B.index_select(0, k_per_out);
auto seg_index = start_in_B.add(within);
// Assemble candidate coo pairs and values
auto i_out = A_i.index_select(0, group_ids).contiguous();
auto j_out = B_j.index_select(0, seg_index).contiguous();
auto vA_out = A_val.index_select(0, group_ids).contiguous();
auto vB_out = B_val.index_select(0, seg_index).contiguous();
auto v_out = vA_out.mul(vB_out);
// build (2, P) indices
auto out_indices = at::empty({2, P}, at::device(device).dtype(at::kLong)).contiguous();
out_indices.select(0, 0).copy_(i_out);
out_indices.select(0, 1).copy_(j_out);
auto result = _sparse_coo_tensor_unsafe(
out_indices, v_out, {I, N}, mat1_.options().dtype(computeDtype));
result = result.coalesce();
if (result.scalar_type() != mat1_.scalar_type()) {
auto cast_vals = result._values().to(mat1_.scalar_type());
auto out = _sparse_coo_tensor_unsafe(result._indices(), cast_vals, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
return result;
}
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
} // namespace at::native

View File

@ -52,18 +52,19 @@ def test_sparse_coo_and_csr(m, n, k, nnz, test_count):
start.record()
coo.matmul(mat)
stop.record()
times.append(start.elapsed_time(stop))
coo_mean_time = sum(times) / len(times)
coo_mean_time = sum(times) / len(times)
times = []
for _ in range(test_count):
start.record()
csr.matmul(mat)
stop.record()
times.append(start.elapsed_time(stop))
times = []
for _ in range(test_count):
start.record()
csr.matmul(mat)
stop.record()
times.append(start.elapsed_time(stop))
csr_mean_time = sum(times) / len(times)
csr_mean_time = sum(times) / len(times)
return coo_mean_time, csr_mean_time

View File

@ -1,8 +1,6 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>
#include <optional>
namespace c10 {
@ -17,8 +15,7 @@ struct C10_API AutogradState {
bool inference_mode,
bool fw_grad_mode,
bool multithreading_enabled)
: graph_exec_group_(std::nullopt),
grad_mode_(grad_mode),
: grad_mode_(grad_mode),
inference_mode_(inference_mode),
fw_grad_mode_(fw_grad_mode),
multithreading_enabled_(multithreading_enabled),
@ -44,10 +41,6 @@ struct C10_API AutogradState {
view_replay_enabled_ = view_replay_enabled;
}
void set_graph_exec_group(std::optional<SafePyObject> group) {
graph_exec_group_ = std::move(group);
}
bool get_grad_mode() const {
return grad_mode_;
}
@ -68,12 +61,7 @@ struct C10_API AutogradState {
return view_replay_enabled_;
}
const std::optional<SafePyObject>& get_graph_exec_group() const {
return graph_exec_group_;
}
private:
std::optional<SafePyObject> graph_exec_group_;
bool grad_mode_ : 1;
bool inference_mode_ : 1;
bool fw_grad_mode_ : 1;

View File

@ -96,10 +96,6 @@ struct C10_API DeviceAllocator : public c10::Allocator {
// Resets peak memory usage statistics for the specified device
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
// Return the free memory size and total memory size in bytes for the
// specified device.
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
};
// This function is used to get the DeviceAllocator for a specific device type

View File

@ -345,13 +345,6 @@ class CUDAAllocator : public DeviceAllocator {
c10::DeviceIndex device,
std::shared_ptr<AllocatorState> pps) = 0;
virtual std::string name() = 0;
std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
c10::DeviceGuard device_guard({at::kCUDA, device});
size_t free = 0;
size_t total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
return {free, total};
}
};
// Allocator object, statically initialized

View File

@ -66,15 +66,6 @@ def define_targets(rules):
],
)
rules.cc_test(
name = "util/nofatal_test",
srcs = ["util/nofatal_test.cpp"],
deps = [
"//c10/util:base",
"@com_google_googletest//:gtest_main",
],
)
rules.cc_test(
name = "util/ssize_test",
srcs = ["util/ssize_test.cpp"],

View File

@ -1,53 +0,0 @@
#include <gtest/gtest.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
namespace {
template <typename T>
inline void expectThrowsEq(T&& fn, const char* expected_msg) {
try {
std::forward<T>(fn)();
} catch (const c10::Error& e) {
EXPECT_TRUE(
std::string(e.what_without_backtrace()).find(expected_msg) !=
std::string::npos);
return;
}
ADD_FAILURE() << "Expected to throw exception with message \"" << expected_msg
<< "\" but didn't throw";
}
} // namespace
TEST(NofatalTest, TorchCheckComparisons) {
// quick make sure that no-op works as expected
TORCH_CHECK_EQ(1, 1) << "i am a silly message " << 1;
expectThrowsEq(
[]() { TORCH_CHECK_EQ(1, 2) << "i am a silly message " << 1; },
"Check failed: 1 == 2 (1 vs. 2). i am a silly message 1");
expectThrowsEq(
[]() { TORCH_CHECK_NE(2, 2); }, "Check failed: 2 != 2 (2 vs. 2).");
expectThrowsEq(
[]() { TORCH_CHECK_LT(2, 2); }, "Check failed: 2 < 2 (2 vs. 2).");
expectThrowsEq(
[]() { TORCH_CHECK_LE(3, 2); }, "Check failed: 3 <= 2 (3 vs. 2).");
expectThrowsEq(
[]() { TORCH_CHECK_GT(2, 2); }, "Check failed: 2 > 2 (2 vs. 2).");
expectThrowsEq(
[]() { TORCH_CHECK_GE(2, 3); }, "Check failed: 2 >= 3 (2 vs. 3).");
expectThrowsEq(
[]() {
void* p = nullptr;
TORCH_CHECK_NOTNULL(p);
},
"Check failed: 'p' must be non NULL.");
#if GTEST_HAS_DEATH_TEST
#ifndef NDEBUG
// if dbg build, DCHECK should result in deth
EXPECT_DEATH(TORCH_DCHECK_EQ(1, 2), "Check failed");
#else
TORCH_DCHECK_EQ(1, 2); // no-op
#endif
#endif // GTEST_HAS_DEATH_TEST
}

View File

@ -702,98 +702,6 @@ namespace c10::detail {
#define TORCH_CHECK_ARG(cond, argN, ...) \
TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__)
#ifndef FATAL_IF
#ifdef C10_USE_GLOG
#define FATAL_IF(condition) \
condition ? (void)0 \
: ::c10::LoggerVoidify() & \
::c10::MessageLogger(__FILE__, __LINE__, ::google::GLOG_FATAL) \
.stream()
#else
#define FATAL_IF(condition) \
condition ? (void)0 \
: ::c10::LoggerVoidify() & \
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL).stream()
#endif
#endif
#ifndef NON_FATAL_IF
#ifdef C10_USE_GLOG
#define NON_FATAL_IF(condition) \
condition ? (void)0 \
: ::c10::LoggerVoidify() & \
::c10::MessageLogger( \
__FILE__, __LINE__, ::google::GLOG_FATAL, false) \
.stream()
#else
#define NON_FATAL_IF(condition) \
condition ? (void)0 \
: ::c10::LoggerVoidify() & \
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL, false) \
.stream()
#endif
#endif
// Binary comparison check macros
#define TORCH_CHECK_OP(val1, val2, op) \
NON_FATAL_IF(((val1)op(val2))) \
<< "Check failed: " #val1 " " #op " " #val2 " (" << (val1) << " vs. " \
<< (val2) << "). "
#define TORCH_DCHECK_OP(val1, val2, op) \
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
<< (val1) << " vs. " << (val2) << "). "
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
// Debug versions of TORCH_CHECK_OP macros
#ifndef NDEBUG
#define TORCH_DCHECK_EQ(val1, val2) TORCH_DCHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) TORCH_DCHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) TORCH_DCHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) TORCH_DCHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) TORCH_DCHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) TORCH_DCHECK_OP(val1, val2, >)
#else // !NDEBUG
// Optimized versions - generate no code
#define TORCH_DCHECK_EQ(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, >)
#endif // NDEBUG
// Null pointer check macro
#define TORCH_CHECK_NOTNULL(val) \
::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), false)
#ifndef NDEBUG
#define TORCH_DCHECK_NOTNULL(val) \
::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), true)
#else // !NDEBUG
#define TORCH_DCHECK_NOTNULL(val) \
while (false) \
TORCH_CHECK_NOTNULL(val)
#endif // NDEBUG
// ----------------------------------------------------------------------------
// Deprecated macros
// ----------------------------------------------------------------------------

View File

@ -291,32 +291,6 @@ namespace c10 {
using fLB::FLAGS_logtostderr;
using fLI::FLAGS_minloglevel;
using fLI::FLAGS_v;
MessageLogger::MessageLogger(
const char* file,
int line,
int severity,
bool exit_on_fatal)
: stream_(), severity_(severity), exit_on_fatal_(exit_on_fatal) {}
MessageLogger::~MessageLogger() noexcept(false) {
if (severity_ == ::google::GLOG_FATAL) {
DealWithFatal();
}
}
std::stringstream& MessageLogger::stream() {
return stream_;
}
void MessageLogger::DealWithFatal() {
if (exit_on_fatal_) {
LOG(FATAL) << stream_.str();
} else {
throw c10::Error(stream_.str(), nullptr, nullptr);
}
}
} // namespace c10
C10_DEFINE_int(
@ -438,16 +412,17 @@ void ShowLogInfoToStderr() {
FLAGS_caffe2_log_level = GLOG_INFO;
}
MessageLogger::MessageLogger(
const char* file,
int line,
int severity,
bool exit_on_fatal)
: severity_(severity), exit_on_fatal_(exit_on_fatal) {
MessageLogger::MessageLogger(const char* file, int line, int severity)
: severity_(severity) {
if (severity_ < FLAGS_caffe2_log_level) {
// Nothing needs to be logged.
return;
}
#ifdef ANDROID
tag_ = "native";
#else // !ANDROID
tag_ = "";
#endif // ANDROID
time_t rawtime = 0;
time(&rawtime);
@ -483,7 +458,7 @@ MessageLogger::MessageLogger(
}
// Output the contents of the stream to the proper channel on destruction.
MessageLogger::~MessageLogger() noexcept(false) {
MessageLogger::~MessageLogger() {
if (severity_ < FLAGS_caffe2_log_level) {
// Nothing needs to be logged.
return;
@ -523,18 +498,6 @@ MessageLogger::~MessageLogger() noexcept(false) {
}
}
std::stringstream& MessageLogger::stream() {
return stream_;
}
void MessageLogger::DealWithFatal() {
if (exit_on_fatal_) {
abort();
} else {
throw c10::Error(stream_.str(), nullptr, nullptr);
}
}
} // namespace c10
#endif // !C10_USE_GLOG

View File

@ -1,74 +0,0 @@
#ifndef C10_UTIL_LOGGING_COMMON_H_
#define C10_UTIL_LOGGING_COMMON_H_
#include <c10/macros/Export.h>
#include <sstream>
namespace c10 {
// MessageLogger that throws exceptions instead of aborting (glog version)
// or logs and may abort (non-glog version).
class C10_API MessageLogger {
public:
MessageLogger(
const char* file,
int line,
int severity,
bool exit_on_fatal = true);
~MessageLogger() noexcept(false);
// Return the stream associated with the logger object.
std::stringstream& stream();
private:
// When there is a fatal log, and fatal == true, we abort
// otherwise, we throw.
void DealWithFatal();
#if defined(ANDROID) && !defined(C10_USE_GLOG)
const char* tag_{"native"};
#endif
std::stringstream stream_;
int severity_;
bool exit_on_fatal_;
};
// This class is used to explicitly ignore values in the conditional
// logging macros. This avoids compiler warnings like "value computed
// is not used" and "statement has no effect".
class C10_API LoggerVoidify {
public:
LoggerVoidify() = default;
// This has to be an operator with a precedence lower than << but
// higher than ?:
void operator&(const std::ostream& s [[maybe_unused]]) {}
};
// Forward declarations for CheckNotNull functions
template <typename T>
T& CheckNotNullCommon(
const char* file,
int line,
const char* names,
T& t,
bool fatal = true);
template <typename T>
T* CheckNotNull(
const char* file,
int line,
const char* names,
T* t,
bool fatal = true);
template <typename T>
T& CheckNotNull(
const char* file,
int line,
const char* names,
T& t,
bool fatal = true);
} // namespace c10
#endif // C10_UTIL_LOGGING_COMMON_H_

View File

@ -47,53 +47,57 @@ INSTANTIATE_FOR_CONTAINER(set)
#endif
#include <c10/util/logging_common.h>
#include <glog/logging.h>
namespace c10 {
// Additional macros on top of glog
#define TORCH_CHECK_EQ(val1, val2) CHECK_EQ(val1, val2)
#define TORCH_CHECK_NE(val1, val2) CHECK_NE(val1, val2)
#define TORCH_CHECK_LE(val1, val2) CHECK_LE(val1, val2)
#define TORCH_CHECK_LT(val1, val2) CHECK_LT(val1, val2)
#define TORCH_CHECK_GE(val1, val2) CHECK_GE(val1, val2)
#define TORCH_CHECK_GT(val1, val2) CHECK_GT(val1, val2)
[[noreturn]] void ThrowEnforceNotMet(
const char* file,
const int line,
const char* condition,
const std::string& msg,
const void* caller);
#ifndef NDEBUG
#define TORCH_DCHECK_EQ(val1, val2) DCHECK_EQ(val1, val2)
#define TORCH_DCHECK_NE(val1, val2) DCHECK_NE(val1, val2)
#define TORCH_DCHECK_LE(val1, val2) DCHECK_LE(val1, val2)
#define TORCH_DCHECK_LT(val1, val2) DCHECK_LT(val1, val2)
#define TORCH_DCHECK_GE(val1, val2) DCHECK_GE(val1, val2)
#define TORCH_DCHECK_GT(val1, val2) DCHECK_GT(val1, val2)
#else // !NDEBUG
// These versions generate no code in optimized mode.
#define TORCH_DCHECK_EQ(val1, val2) \
while (false) \
DCHECK_EQ(val1, val2)
#define TORCH_DCHECK_NE(val1, val2) \
while (false) \
DCHECK_NE(val1, val2)
#define TORCH_DCHECK_LE(val1, val2) \
while (false) \
DCHECK_LE(val1, val2)
#define TORCH_DCHECK_LT(val1, val2) \
while (false) \
DCHECK_LT(val1, val2)
#define TORCH_DCHECK_GE(val1, val2) \
while (false) \
DCHECK_GE(val1, val2)
#define TORCH_DCHECK_GT(val1, val2) \
while (false) \
DCHECK_GT(val1, val2)
#endif // NDEBUG
template <typename T>
T& CheckNotNullCommon(
const char* file,
int line,
const char* names,
T& t,
bool fatal) {
if (t == nullptr) {
MessageLogger(file, line, ::google::GLOG_FATAL, fatal).stream()
<< "Check failed: '" << names << "' must be non NULL. ";
}
return t;
}
// Check that a pointer is not null.
#define TORCH_CHECK_NOTNULL(val) CHECK_NOTNULL(val)
template <typename T>
T* CheckNotNull(
const char* file,
int line,
const char* names,
T* t,
bool fatal) {
return CheckNotNullCommon(file, line, names, t, fatal);
}
template <typename T>
T& CheckNotNull(
const char* file,
int line,
const char* names,
T& t,
bool fatal) {
return CheckNotNullCommon(file, line, names, t, fatal);
}
} // namespace c10
#ifndef NDEBUG
// Debug only version of TORCH_CHECK_NOTNULL
#define TORCH_DCHECK_NOTNULL(val) DCHECK_NOTNULL(val)
#else // !NDEBUG
// Optimized version - generates no code.
#define TORCH_DCHECK_NOTNULL(val) \
while (false) \
DCHECK_NOTNULL(val)
#endif // NDEBUG
// Log with source location information override (to be used in generic
// warning/error handlers implemented as functions, not macros)

View File

@ -13,7 +13,6 @@
#include <vector>
#include <c10/util/Flags.h>
#include <c10/util/logging_common.h>
const char CAFFE2_SEVERITY_PREFIX[] = "FEWIV";
@ -25,40 +24,61 @@ const int GLOG_ERROR = 2;
const int GLOG_WARNING = 1;
const int GLOG_INFO = 0;
class C10_API MessageLogger {
public:
MessageLogger(const char* file, int line, int severity);
~MessageLogger();
// Return the stream associated with the logger object.
std::stringstream& stream() {
return stream_;
}
private:
// When there is a fatal log, we simply abort.
void DealWithFatal() {
abort();
}
const char* tag_;
std::stringstream stream_;
int severity_;
};
// This class is used to explicitly ignore values in the conditional
// logging macros. This avoids compiler warnings like "value computed
// is not used" and "statement has no effect".
class C10_API LoggerVoidify {
public:
LoggerVoidify() = default;
// This has to be an operator with a precedence lower than << but
// higher than ?:
void operator&(const std::ostream& s [[maybe_unused]]) {}
};
// Log a message and terminate.
template <class T>
void LogMessageFatal(const char* file, int line, const T& message) {
MessageLogger(file, line, GLOG_FATAL).stream() << message;
}
// Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw
// pointers and smart pointers.
template <typename T>
T& CheckNotNullCommon(
const char* file,
int line,
const char* names,
T& t,
bool fatal) {
T& CheckNotNullCommon(const char* file, int line, const char* names, T& t) {
if (t == nullptr) {
MessageLogger(file, line, GLOG_FATAL, fatal).stream()
<< "Check failed: '" << names << "' must be non NULL. ";
LogMessageFatal(file, line, std::string(names));
}
return t;
}
template <typename T>
T* CheckNotNull(
const char* file,
int line,
const char* names,
T* t,
bool fatal) {
return CheckNotNullCommon(file, line, names, t, fatal);
T* CheckNotNull(const char* file, int line, const char* names, T* t) {
return CheckNotNullCommon(file, line, names, t);
}
template <typename T>
T& CheckNotNull(
const char* file,
int line,
const char* names,
T& t,
bool fatal) {
return CheckNotNullCommon(file, line, names, t, fatal);
T& CheckNotNull(const char* file, int line, const char* names, T& t) {
return CheckNotNullCommon(file, line, names, t);
}
} // namespace c10
@ -116,6 +136,65 @@ static_assert(
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream()
#endif // NDEBUG
#define TORCH_CHECK_OP(val1, val2, op) \
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
<< (val1) << " vs. " << (val2) << ") "
// TORCH_CHECK_OP macro definitions
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
#ifndef NDEBUG
// Debug only versions of TORCH_CHECK_OP macros.
#define TORCH_DCHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
#else // !NDEBUG
// These versions generate no code in optimized mode.
#define TORCH_DCHECK_EQ(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, >)
#endif // NDEBUG
// Check that a pointer is not null.
#define TORCH_CHECK_NOTNULL(val) \
::c10::CheckNotNull( \
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
#ifndef NDEBUG
// Debug only version of TORCH_CHECK_NOTNULL
#define TORCH_DCHECK_NOTNULL(val) \
::c10::CheckNotNull( \
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
#else // !NDEBUG
// Optimized version - generates no code.
#define TORCH_DCHECK_NOTNULL(val) \
while (false) \
TORCH_CHECK_NOTNULL(val)
#endif // NDEBUG
// ---------------------- Support for std objects --------------------------
// These are adapted from glog to support a limited set of logging capability
// for STL objects.

View File

@ -926,14 +926,15 @@ class DeviceCachingAllocator {
(release_cached_blocks() && alloc_block(params, true));
}
if (!block_found) {
const auto& raw_device = c10::xpu::get_raw_device(device);
const auto device_total =
raw_device.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device);
auto device_total = device_prop.global_mem_size;
// Estimate the available device memory when the SYCL runtime does not
// support the corresponding aspect (ext_intel_free_memory).
size_t device_free = device_total -
size_t device_free = device_prop.global_mem_size -
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
auto& raw_device = c10::xpu::get_raw_device(device);
// TODO: Remove the aspect check once the SYCL runtime bug is fixed on
// affected devices.
if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
@ -1051,37 +1052,21 @@ class DeviceCachingAllocator {
}
}
std::pair<size_t, size_t> getMemoryInfo() {
const auto& device = c10::xpu::get_raw_device(device_index);
const size_t total = device.get_info<sycl::info::device::global_mem_size>();
TORCH_CHECK(
device.has(sycl::aspect::ext_intel_free_memory),
"The device (",
device.get_info<sycl::info::device::name>(),
") doesn't support querying the available free memory. ",
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
"to help us prioritize its implementation.");
const size_t free =
device.get_info<sycl::ext::intel::info::device::free_memory>();
return {free, total};
}
double getMemoryFraction() {
if (!set_fraction) {
return 1.0;
}
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
return static_cast<double>(allowed_memory_maximum) /
static_cast<double>(device_total);
static_cast<double>(device_prop.global_mem_size);
}
void setMemoryFraction(double fraction) {
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
auto device_total = device_prop.global_mem_size;
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
set_fraction = true;
}
@ -1255,11 +1240,6 @@ class XPUAllocator : public DeviceAllocator {
c10::xpu::get_raw_device(dev_to_access));
}
std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
assertValidDevice(device);
return device_allocators[device]->getMemoryInfo();
}
double getMemoryFraction(DeviceIndex device) {
assertValidDevice(device);
return device_allocators[device]->getMemoryFraction();

View File

@ -40,7 +40,6 @@
:nosignatures:
empty_cache
get_memory_info
max_memory_allocated
max_memory_reserved
memory_allocated

View File

@ -382,6 +382,20 @@ coverage_ignore_functions = [
# torch.ao.quantization.backend_config.tensorrt
"get_tensorrt_backend_config",
"get_tensorrt_backend_config_dict",
# torch.ao.quantization.backend_config.utils
"entry_to_pretty_str",
"get_fused_module_classes",
"get_fuser_method_mapping",
"get_fusion_pattern_to_extra_inputs_getter",
"get_fusion_pattern_to_root_node_getter",
"get_module_to_qat_module",
"get_pattern_to_dtype_configs",
"get_pattern_to_input_type_to_index",
"get_qat_module_classes",
"get_root_module_to_quantized_reference_module",
"pattern_to_human_readable",
"remove_boolean_dispatch_from_name",
# torch.ao.quantization.backend_config.x86
"get_x86_backend_config",
# torch.ao.quantization.fuse_modules
"fuse_known_modules",
@ -412,6 +426,25 @@ coverage_ignore_functions = [
"insert_observers_for_model",
"prepare",
"propagate_dtypes_for_known_nodes",
# torch.ao.quantization.fx.utils
"all_node_args_except_first",
"all_node_args_have_no_tensors",
"assert_and_get_unique_device",
"collect_producer_nodes",
"create_getattr_from_value",
"create_node_from_old_node_preserve_meta",
"get_custom_module_class_keys",
"get_linear_prepack_op_for_dtype",
"get_new_attr_name_with_prefix",
"get_non_observable_arg_indexes_and_types",
"get_qconv_prepack_op",
"get_skipped_module_name_and_classes",
"graph_module_from_producer_nodes",
"maybe_get_next_module",
"node_arg_is_bias",
"node_arg_is_weight",
"return_arg_list",
# torch.ao.quantization.pt2e.graph_utils
"bfs_trace_with_node_process",
"find_sequential_partitions",
"get_equivalent_types",
@ -827,10 +860,80 @@ coverage_ignore_functions = [
"get_latency_of_one_partition",
"get_latency_of_partitioned_graph",
"get_partition_to_latency_mapping",
# torch.fx.experimental.proxy_tensor
"decompose",
"disable_autocast_cache",
"disable_proxy_modes_tracing",
"dispatch_trace",
"extract_val",
"fake_signature",
"fetch_sym_proxy",
"fetch_object_proxy",
"get_innermost_proxy_mode",
"get_isolated_graphmodule",
"get_proxy_slot",
"get_torch_dispatch_modes",
"has_proxy_slot",
"is_sym_node",
"maybe_handle_decomp",
"proxy_call",
"set_meta",
"set_original_aten_op",
"set_proxy_slot",
"snapshot_fake",
"thunkify",
"track_tensor",
"track_tensor_tree",
"wrap_key",
"wrapper_and_args_for_make_fx",
# torch.fx.experimental.recording
"record_shapeenv_event",
"replay_shape_env_events",
"shape_env_check_state_equal",
# torch.fx.experimental.sym_node
"ceil_impl",
"floor_ceil_helper",
"floor_impl",
"method_to_operator",
"sympy_is_channels_last_contiguous_2d",
"sympy_is_channels_last_contiguous_3d",
"sympy_is_channels_last_strides_2d",
"sympy_is_channels_last_strides_3d",
"sympy_is_channels_last_strides_generic",
"sympy_is_contiguous",
"sympy_is_contiguous_generic",
"to_node",
"wrap_node",
"sym_sqrt",
# torch.fx.experimental.symbolic_shapes
"bind_symbols",
"cast_symbool_to_symint_guardless",
"create_contiguous",
"error",
"eval_guards",
"eval_is_non_overlapping_and_dense",
"expect_true",
"find_symbol_binding_fx_nodes",
"free_symbols",
"free_unbacked_symbols",
"fx_placeholder_targets",
"fx_placeholder_vals",
"guard_bool",
"guard_float",
"guard_int",
"guard_scalar",
"has_hint",
"has_symbolic_sizes_strides",
"is_channels_last_contiguous_2d",
"is_channels_last_contiguous_3d",
"is_channels_last_strides_2d",
"is_channels_last_strides_3d",
"is_contiguous",
"is_non_overlapping_and_dense_indicator",
"is_nested_int",
"is_symbol_binding_fx_node",
"is_symbolic",
# torch.fx.experimental.unification.core
"reify",
# torch.fx.experimental.unification.match
"edge",
@ -868,6 +971,24 @@ coverage_ignore_functions = [
"reverse_dict",
# torch.fx.experimental.unification.multipledispatch.variadic
"isvariadic",
# torch.fx.experimental.unification.unification_tools
"assoc",
"assoc_in",
"dissoc",
"first",
"get_in",
"getter",
"groupby",
"itemfilter",
"itemmap",
"keyfilter",
"keymap",
"merge",
"merge_with",
"update_in",
"valfilter",
"valmap",
# torch.fx.experimental.unification.utils
"freeze",
"hashable",
"raises",
@ -1308,8 +1429,319 @@ coverage_ignore_functions = [
# torch.onnx.symbolic_opset7
"max",
"min",
# torch.onnx.symbolic_opset8
"addmm",
"bmm",
"empty",
"empty_like",
"flatten",
"full",
"full_like",
"gt",
"lt",
"matmul",
"mm",
"ones",
"ones_like",
"prelu",
"repeat",
"zeros",
"zeros_like",
# torch.onnx.symbolic_opset9
"abs",
"acos",
"adaptive_avg_pool1d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"adaptive_max_pool1d",
"adaptive_max_pool2d",
"adaptive_max_pool3d",
"add",
"addcmul",
"addmm",
"alias",
"amax",
"amin",
"aminmax",
"arange",
"argmax",
"argmin",
"as_strided",
"as_tensor",
"asin",
"atan",
"atan2",
"avg_pool1d",
"avg_pool2d",
"avg_pool3d",
"baddbmm",
"batch_norm",
"bernoulli",
"bitwise_not",
"bitwise_or",
"bmm",
"broadcast_tensors",
"broadcast_to",
"bucketize",
"cat",
"cdist",
"ceil",
"clamp",
"clamp_max",
"clamp_min",
"clone",
"constant_pad_nd",
"contiguous",
"conv1d",
"conv2d",
"conv3d",
"conv_tbc",
"conv_transpose1d",
"conv_transpose2d",
"conv_transpose3d",
"convert_element_type",
"convolution",
"cos",
"cosine_similarity",
"cross",
"cumsum",
"detach",
"dim",
"div",
"dot",
"dropout",
"elu",
"embedding",
"embedding_bag",
"empty",
"empty_like",
"eq",
"erf",
"exp",
"expand",
"expand_as",
"eye",
"fill",
"flatten",
"floor",
"floor_divide",
"floordiv",
"frobenius_norm",
"full",
"full_like",
"gather",
"ge",
"gelu",
"get_pool_ceil_padding",
"glu",
"group_norm",
"gru",
"gt",
"hann_window",
"hardshrink",
"hardsigmoid",
"hardswish",
"hardtanh",
"index",
"index_add",
"index_copy",
"index_fill",
"index_put",
"index_select",
"instance_norm",
"is_floating_point",
"is_pinned",
"isnan",
"item",
"kl_div",
"layer_norm",
"le",
"leaky_relu",
"lerp",
"lift",
"linalg_cross",
"linalg_matrix_norm",
"linalg_norm",
"linalg_vector_norm",
"linear",
"linspace",
"log",
"log10",
"log1p",
"log2",
"log_sigmoid",
"log_softmax",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"logit",
"logsumexp",
"lstm",
"lstm_cell",
"lt",
"masked_fill",
"masked_fill_",
"matmul",
"max",
"max_pool1d",
"max_pool1d_with_indices",
"max_pool2d",
"max_pool2d_with_indices",
"max_pool3d",
"max_pool3d_with_indices",
"maximum",
"meshgrid",
"min",
"minimum",
"mish",
"mm",
"movedim",
"mse_loss",
"mul",
"multinomial",
"mv",
"narrow",
"native_layer_norm",
"ne",
"neg",
"new_empty",
"new_full",
"new_ones",
"new_zeros",
"nonzero",
"nonzero_numpy",
"noop_complex_operators",
"norm",
"numel",
"numpy_T",
"one_hot",
"ones",
"ones_like",
"onnx_placeholder",
"overload_by_arg_count",
"pad",
"pairwise_distance",
"permute",
"pixel_shuffle",
"pixel_unshuffle",
"pow",
"prelu",
"prim_constant",
"prim_constant_chunk",
"prim_constant_split",
"prim_data",
"prim_device",
"prim_dtype",
"prim_if",
"prim_layout",
"prim_list_construct",
"prim_list_unpack",
"prim_loop",
"prim_max",
"prim_min",
"prim_shape",
"prim_tolist",
"prim_tuple_construct",
"prim_type",
"prim_unchecked_cast",
"prim_uninitialized",
"rand",
"rand_like",
"randint",
"randint_like",
"randn",
"randn_like",
"reciprocal",
"reflection_pad",
"relu",
"relu6",
"remainder",
"repeat",
"repeat_interleave",
"replication_pad",
"reshape",
"reshape_as",
"rnn_relu",
"rnn_tanh",
"roll",
"rrelu",
"rsqrt",
"rsub",
"scalar_tensor",
"scatter",
"scatter_add",
"select",
"selu",
"sigmoid",
"sign",
"silu",
"sin",
"size",
"slice",
"softmax",
"softplus",
"softshrink",
"sort",
"split",
"split_with_sizes",
"sqrt",
"square",
"squeeze",
"stack",
"std",
"std_mean",
"sub",
"t",
"take",
"tan",
"tanh",
"tanhshrink",
"tensor",
"threshold",
"to",
"topk",
"transpose",
"true_divide",
"type_as",
"unbind",
"unfold",
"unsafe_chunk",
"unsafe_split",
"unsafe_split_with_sizes",
"unsqueeze",
"unsupported_complex_operators",
"unused",
"upsample_bilinear2d",
"upsample_linear1d",
"upsample_nearest1d",
"upsample_nearest2d",
"upsample_nearest3d",
"upsample_trilinear3d",
"var",
"var_mean",
"view",
"view_as",
"where",
"wrap_logical_op_with_cast_to",
"wrap_logical_op_with_negation",
"zero",
"zeros",
"zeros_like",
# torch.onnx.utils
"disable_apex_o2_state_dict_hook",
"export",
"export_to_pretty_string",
"exporter_context",
"is_in_onnx_export",
"model_signature",
"register_custom_op_symbolic",
"select_model_mode_for_export",
"setup_onnx_logging",
"unconvertible_ops",
"unpack_quantized_tensor",
"warn_on_static_input_change",
# torch.onnx.verification
"check_export_model_diff",
"verify",
"verify_aten_graph",
@ -1400,6 +1832,32 @@ coverage_ignore_functions = [
"noop_context_fn",
"set_checkpoint_early_stop",
"set_device_states",
# torch.utils.collect_env
"check_release_file",
"get_cachingallocator_config",
"get_clang_version",
"get_cmake_version",
"get_conda_packages",
"get_cpu_info",
"get_cuda_module_loading_config",
"get_cudnn_version",
"get_env_info",
"get_gcc_version",
"get_gpu_info",
"get_libc_version",
"get_lsb_version",
"get_mac_version",
"get_nvidia_driver_version",
"get_nvidia_smi",
"get_os",
"get_pip_packages",
"get_platform",
"get_pretty_env_info",
"get_python_platform",
"get_running_cuda_version",
"get_windows_version",
"is_xnnpack_available",
"pretty_str",
# torch.utils.cpp_backtrace
"get_cpp_backtrace",
# torch.utils.cpp_extension
@ -1463,6 +1921,52 @@ coverage_ignore_functions = [
"apply_shuffle_seed",
"apply_shuffle_settings",
"get_all_graph_pipes",
# torch.utils.flop_counter
"addmm_flop",
"baddbmm_flop",
"bmm_flop",
"conv_backward_flop",
"conv_flop",
"conv_flop_count",
"convert_num_with_suffix",
"get_shape",
"get_suffix_str",
"mm_flop",
"normalize_tuple",
"register_flop_formula",
"sdpa_backward_flop",
"sdpa_backward_flop_count",
"sdpa_flop",
"sdpa_flop_count",
"shape_wrapper",
"transpose_shape",
# torch.utils.hipify.hipify_python
"add_dim3",
"compute_stats",
"extract_arguments",
"file_add_header",
"file_specific_replacement",
"find_bracket_group",
"find_closure_group",
"find_parentheses_group",
"fix_static_global_kernels",
"get_hip_file_path",
"hip_header_magic",
"hipify",
"is_caffe2_gpu_file",
"is_cusparse_file",
"is_out_of_place",
"is_pytorch_file",
"is_special_file",
"match_extensions",
"matched_files_iter",
"openf",
"preprocess_file_and_save_result",
"preprocessor",
"processKernelLaunches",
"replace_extern_shared",
"replace_math_functions",
"str2bool",
# torch.utils.hooks
"unserializable_hook",
"warn_if_has_hooks",

View File

@ -12,37 +12,6 @@ These APIs are experimental and subject to change without notice.
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
```
## torch.fx.experimental.sym_node
```{eval-rst}
.. currentmodule:: torch.fx.experimental.sym_node
```
```{eval-rst}
.. automodule:: torch.fx.experimental.sym_node
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
is_channels_last_contiguous_2d
is_channels_last_contiguous_3d
is_channels_last_strides_2d
is_channels_last_strides_3d
is_contiguous
is_non_overlapping_and_dense_indicator
method_to_operator
sympy_is_channels_last_contiguous_2d
sympy_is_channels_last_contiguous_3d
sympy_is_channels_last_strides_2d
sympy_is_channels_last_strides_3d
sympy_is_channels_last_strides_generic
sympy_is_contiguous
sympy_is_contiguous_generic
```
## torch.fx.experimental.symbolic_shapes
```{eval-rst}
@ -100,25 +69,6 @@ These APIs are experimental and subject to change without notice.
rebind_unbacked
resolve_unbacked_bindings
is_accessor_node
cast_symbool_to_symint_guardless
create_contiguous
error
eval_guards
eval_is_non_overlapping_and_dense
find_symbol_binding_fx_nodes
free_symbols
free_unbacked_symbols
fx_placeholder_targets
fx_placeholder_vals
guard_bool
guard_float
guard_int
guard_scalar
has_hint
has_symbolic_sizes_strides
is_nested_int
is_symbol_binding_fx_node
is_symbolic
```
## torch.fx.experimental.proxy_tensor
@ -141,46 +91,4 @@ These APIs are experimental and subject to change without notice.
get_proxy_mode
maybe_enable_thunkify
maybe_disable_thunkify
decompose
disable_autocast_cache
disable_proxy_modes_tracing
extract_val
fake_signature
fetch_object_proxy
fetch_sym_proxy
has_proxy_slot
is_sym_node
maybe_handle_decomp
proxy_call
set_meta
set_original_aten_op
set_proxy_slot
snapshot_fake
```
## torch.fx.experimental.unification.unification_tools
```{eval-rst}
.. currentmodule:: torch.fx.experimental.unification.unification_tools
```
```{eval-rst}
.. automodule:: torch.fx.experimental.unification.unification_tools
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
assoc
assoc_in
dissoc
first
keyfilter
keymap
merge
merge_with
update_in
valfilter
valmap

View File

@ -1134,6 +1134,7 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.refinement_types
.. py:module:: torch.fx.experimental.rewriter
.. py:module:: torch.fx.experimental.schema_type_annotation
.. py:module:: torch.fx.experimental.sym_node
.. py:module:: torch.fx.experimental.unification.core
.. py:module:: torch.fx.experimental.unification.dispatch
.. py:module:: torch.fx.experimental.unification.match
@ -1143,6 +1144,7 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
.. py:module:: torch.fx.experimental.unification.unification_tools
.. py:module:: torch.fx.experimental.unification.utils
.. py:module:: torch.fx.experimental.unification.variable
.. py:module:: torch.fx.experimental.unify_refinements

View File

@ -134,23 +134,6 @@ Quantization to work with this as well.
ObservationType
```
## torch.ao.quantization.backend_config.utils
```{eval-rst}
.. currentmodule:: torch.ao.quantization.backend_config.utils
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
entry_to_pretty_str
pattern_to_human_readable
remove_boolean_dispatch_from_name
```
## torch.ao.quantization.fx.custom_config
This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization
@ -171,30 +154,6 @@ This module contains a few CustomConfig classes that's used in both eager mode a
StandaloneModuleConfigEntry
```
## torch.ao.quantization.fx.utils
```{eval-rst}
.. currentmodule:: torch.ao.quantization.fx.utils
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
all_node_args_except_first
all_node_args_have_no_tensors
collect_producer_nodes
create_getattr_from_value
create_node_from_old_node_preserve_meta
graph_module_from_producer_nodes
maybe_get_next_module
node_arg_is_bias
node_arg_is_weight
return_arg_list
```
## torch.ao.quantization.quantizer
```{eval-rst}

View File

@ -19,91 +19,6 @@
swap_tensors
```
# torch.utils.collect_env
```{eval-rst}
.. automodule:: torch.utils.collect_env
```
```{eval-rst}
.. currentmodule:: torch.utils.collect_env
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
check_release_file
is_xnnpack_available
pretty_str
```
# torch.utils.flop_counter
```{eval-rst}
.. automodule:: torch.utils.flop_counter
```
```{eval-rst}
.. currentmodule:: torch.utils.flop_counter
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
baddbmm_flop
bmm_flop
conv_backward_flop
conv_flop
conv_flop_count
register_flop_formula
sdpa_backward_flop
sdpa_backward_flop_count
sdpa_flop
sdpa_flop_count
shape_wrapper
```
# torch.utils.hipify.hipify_python
```{eval-rst}
.. automodule:: torch.utils.hipify.hipify_python
```
```{eval-rst}
.. currentmodule:: torch.utils.hipify.hipify_python
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
compute_stats
extract_arguments
file_add_header
file_specific_replacement
find_bracket_group
find_closure_group
find_parentheses_group
fix_static_global_kernels
hip_header_magic
hipify
is_caffe2_gpu_file
is_cusparse_file
is_out_of_place
is_pytorch_file
is_special_file
openf
preprocess_file_and_save_result
preprocessor
processKernelLaunches
replace_extern_shared
replace_math_functions
str2bool
```
<!-- This module needs to be documented. Adding here in the meantime
for tracking purposes -->
```{eval-rst}
@ -128,6 +43,7 @@ for tracking purposes -->
.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
.. py:module:: torch.utils.bundled_inputs
.. py:module:: torch.utils.checkpoint
.. py:module:: torch.utils.collect_env
.. py:module:: torch.utils.cpp_backtrace
.. py:module:: torch.utils.cpp_extension
.. py:module:: torch.utils.data.backward_compatibility
@ -164,8 +80,10 @@ for tracking purposes -->
.. py:module:: torch.utils.data.sampler
.. py:module:: torch.utils.dlpack
.. py:module:: torch.utils.file_baton
.. py:module:: torch.utils.flop_counter
.. py:module:: torch.utils.hipify.constants
.. py:module:: torch.utils.hipify.cuda_to_hip_mappings
.. py:module:: torch.utils.hipify.hipify_python
.. py:module:: torch.utils.hipify.version
.. py:module:: torch.utils.hooks
.. py:module:: torch.utils.jit.log_extract

View File

@ -172,9 +172,9 @@ ignore = [
"SIM102", "SIM103", "SIM112", # flake8-simplify code styles
"SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason
"SIM108", # SIM108 ignored because we prefer if-else-block instead of ternary expression
"SIM110", # Checks for for loops that can be replaced with a builtin function, like any or all.
"SIM110",
"SIM114", # Combine `if` branches using logical `or` operator
"SIM115", # Checks for cases where files are opened without using a context manager.
"SIM115",
"SIM116", # Disable Use a dictionary instead of consecutive `if` statements
"SIM117",
"SIM118",
@ -184,6 +184,7 @@ ignore = [
"TC006",
# TODO: Remove Python-3.10 specific suppressions
"B905",
"UP035",
]
select = [
"B",

View File

@ -1646,7 +1646,8 @@ def main() -> None:
mirror_files_into_torchgen()
if RUN_BUILD_DEPS:
build_deps()
mirror_inductor_external_kernels()
mirror_inductor_external_kernels()
(
ext_modules,

View File

@ -208,7 +208,7 @@ class _BaseDataSparsiferTestCase(TestCase):
assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups)
state1 = state_dict1["state"]
for name in state1:
for name in state1.keys():
# compare mask
assert name in sparsifier2.state
assert "mask" in sparsifier2.state[name]

View File

@ -119,7 +119,7 @@ class TestBaseSparsifier(TestCase):
for idx in range(len(sparsifier0.groups)):
mg0 = sparsifier0.groups[idx]
mg1 = sparsifier1.groups[idx]
for key in mg0:
for key in mg0.keys():
assert key in mg1
if key == "module":
# We cannot compare modules as they are different

View File

@ -67,13 +67,13 @@ Tensor sgd_out_of_place(
void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = sgd_out_of_place(
torch::stable::detail::to<Tensor>(stack[0]),
torch::stable::detail::to<Tensor>(stack[1]),
float(torch::stable::detail::to<double>(stack[2])),
torch::stable::detail::to<double>(stack[3]),
torch::stable::detail::to<bool>(stack[4]));
to<Tensor>(stack[0]),
to<Tensor>(stack[1]),
float(to<double>(stack[2])),
to<double>(stack[3]),
to<bool>(stack[4]));
stack[0] = torch::stable::detail::from(res);
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
@ -89,8 +89,8 @@ Tensor identity(Tensor t) {
}
void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = identity(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
Tensor res = identity(to<Tensor>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -108,14 +108,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
Tensor my_abs(Tensor t) {
const auto num_args = 1;
StableIValue stack[num_args];
stack[0] = torch::stable::detail::from(t);
stack[0] = from(t);
aoti_torch_call_dispatcher("aten::abs", "", stack);
return torch::stable::detail::to<Tensor>(stack[0]);
return to<Tensor>(stack[0]);
}
void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_abs(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(tensor_res);
Tensor tensor_res = my_abs(to<Tensor>(stack[0]));
stack[0] = from(tensor_res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -132,21 +132,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
auto mf = aoti_torch_memory_format_contiguous_format();
stack[0] = torch::stable::detail::from(t);
stack[1] = torch::stable::detail::from(std::optional(t.scalar_type())); // dtype
stack[2] = torch::stable::detail::from(std::nullopt); // layout
stack[3] = torch::stable::detail::from(std::optional(device)); // device
stack[4] = torch::stable::detail::from(std::optional(false)); // pin_memory
stack[5] = torch::stable::detail::from(std::optional(mf)); // memory_format
stack[0] = from(t);
stack[1] = from(std::optional(t.scalar_type())); // dtype
stack[2] = from(std::nullopt); // layout
stack[3] = from(std::optional(device)); // device
stack[4] = from(std::optional(false)); // pin_memory
stack[5] = from(std::optional(mf)); // memory_format
aoti_torch_call_dispatcher("aten::ones_like", "", stack);
return torch::stable::detail::to<Tensor>(stack[0]);
return to<Tensor>(stack[0]);
}
void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = my_ones_like(torch::stable::detail::to<Tensor>(stack[0]), stack[1]);
stack[0] = torch::stable::detail::from(res);
Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -159,28 +159,28 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
StableIValue stack_exp[1];
stack_exp[0] = torch::stable::detail::from(t1);
stack_exp[0] = from(t1);
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
StableIValue stack_neg[1];
stack_neg[0] = torch::stable::detail::from(t2);
stack_neg[0] = from(t2);
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
StableIValue stack_is_leaf[1];
stack_is_leaf[0] = torch::stable::detail::from(t3);
stack_is_leaf[0] = from(t3);
aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);
return std::make_tuple(
torch::stable::detail::to<Tensor>(stack_exp[0]),
torch::stable::detail::to<Tensor>(stack_neg[0]),
torch::stable::detail::to<bool>(stack_is_leaf[0]));
to<Tensor>(stack_exp[0]),
to<Tensor>(stack_neg[0]),
to<bool>(stack_is_leaf[0]));
}
void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto tuple = exp_neg_is_leaf(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<Tensor>(stack[2]));
stack[0] = torch::stable::detail::from(std::get<0>(tuple));
stack[1] = torch::stable::detail::from(std::get<1>(tuple));
stack[2] = torch::stable::detail::from(std::get<2>(tuple));
auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
stack[0] = from(std::get<0>(tuple));
stack[1] = from(std::get<1>(tuple));
stack[2] = from(std::get<2>(tuple));
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -193,15 +193,15 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
Tensor neg_exp(Tensor t) {
StableIValue stack[1];
stack[0] = torch::stable::detail::from(t);
stack[0] = from(t);
aoti_torch_call_dispatcher("aten::exp", "", stack);
aoti_torch_call_dispatcher("aten::neg", "", stack);
return torch::stable::detail::to<Tensor>(stack[0]);
return to<Tensor>(stack[0]);
}
void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
Tensor res = neg_exp(to<Tensor>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -214,10 +214,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
Tensor divide_neg_exp(Tensor t) {
StableIValue stack_neg[1];
stack_neg[0] = torch::stable::detail::from(t);
stack_neg[0] = from(t);
StableIValue stack_exp[1];
stack_exp[0] = torch::stable::detail::from(t);
stack_exp[0] = from(t);
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
@ -225,12 +225,12 @@ Tensor divide_neg_exp(Tensor t) {
stack_div[0] = stack_neg[0];
stack_div[1] = stack_exp[0];
aoti_torch_call_dispatcher("aten::divide", "Tensor", stack_div);
return torch::stable::detail::to<Tensor>(stack_div[0]);
return to<Tensor>(stack_div[0]);
}
void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = divide_neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
Tensor res = divide_neg_exp(to<Tensor>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -246,8 +246,8 @@ bool is_contiguous(Tensor t) {
}
void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
bool res = is_contiguous(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
bool res = is_contiguous(to<Tensor>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -263,9 +263,9 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
}
void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_transpose(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<int64_t>(stack[1]), torch::stable::detail::to<int64_t>(stack[2]));
auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
stack[0] = torch::stable::detail::from(res);
stack[0] = from(res);
}
Tensor my_empty_like(Tensor t) {
@ -273,8 +273,8 @@ Tensor my_empty_like(Tensor t) {
}
void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_empty_like(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_empty_like(to<Tensor>(stack[0]));
stack[0] = from(res);
}
bool my_is_cpu(Tensor t) {
@ -283,8 +283,8 @@ bool my_is_cpu(Tensor t) {
void boxed_my_is_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_is_cpu(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_is_cpu(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor fill_infinity(Tensor t) {
@ -296,8 +296,8 @@ void boxed_fill_infinity(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
auto res = fill_infinity(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = fill_infinity(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_pad(Tensor t) {
@ -310,8 +310,8 @@ void boxed_my_pad(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
auto res = my_pad(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_pad(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
@ -323,11 +323,11 @@ void boxed_my_narrow(
uint64_t num_args,
uint64_t num_outputs) {
auto res = my_narrow(
torch::stable::detail::to<Tensor>(stack[0]),
torch::stable::detail::to<int64_t>(stack[1]),
torch::stable::detail::to<int64_t>(stack[2]),
torch::stable::detail::to<int64_t>(stack[3]));
stack[0] = torch::stable::detail::from(res);
to<Tensor>(stack[0]),
to<int64_t>(stack[1]),
to<int64_t>(stack[2]),
to<int64_t>(stack[3]));
stack[0] = from(res);
}
Tensor my_new_empty_dtype_variant(Tensor t) {
@ -342,8 +342,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
}
void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_new_empty_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_new_empty_dtype_variant(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_new_zeros_dtype_variant(Tensor t) {
@ -352,8 +352,8 @@ Tensor my_new_zeros_dtype_variant(Tensor t) {
}
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_new_zeros_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_new_zeros_dtype_variant(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
@ -361,8 +361,8 @@ Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
}
void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_copy_(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<bool>(stack[2]));
stack[0] = torch::stable::detail::from(tensor_res);
Tensor tensor_res = my_copy_(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<bool>(stack[2]));
stack[0] = from(tensor_res);
}
Tensor my_clone(Tensor t) {
@ -370,8 +370,8 @@ Tensor my_clone(Tensor t) {
}
void boxed_my_clone(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_clone(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(tensor_res);
Tensor tensor_res = my_clone(to<Tensor>(stack[0]));
stack[0] = from(tensor_res);
}
@ -408,8 +408,8 @@ Tensor my_zero_(Tensor t) {
}
void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_zero_(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_zero_(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_amax(Tensor t) {
@ -417,8 +417,8 @@ Tensor my_amax(Tensor t) {
}
void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_amax(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_amax_vec(Tensor t) {
@ -426,8 +426,8 @@ Tensor my_amax_vec(Tensor t) {
}
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax_vec(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_amax_vec(to<Tensor>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -464,8 +464,8 @@ void boxed_test_default_constructor(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
bool res = test_default_constructor(torch::stable::detail::to<bool>(stack[0]));
stack[0] = torch::stable::detail::from(res);
bool res = test_default_constructor(to<bool>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -478,56 +478,6 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_amax_vec", &boxed_my_amax_vec);
}
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
}
void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
// Why is the following NOT torch::stable::detail::to<HeaderOnlyArrayRef<Tensor>>(stack[0])? Because calling `to`
// on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef
// is not owning, so it cannot safely steward the result of the torch::stable::detail::to<>.
auto res = my__foreach_mul(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
stack[0] = torch::stable::detail::from(res);
}
void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
}
void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
my__foreach_mul_(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
}
std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
// This function tests that my__foreach_mul can take in std::initializer_lists
// in addition to std::vectors.
Tensor t1_1 = my_clone(t1);
Tensor t1_2 = my_clone(t1);
Tensor t2_1 = my_clone(t2);
Tensor t2_2 = my_clone(t2);
return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
}
void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = make_tensor_clones_and_call_foreach(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my__foreach_mul", &boxed_my__foreach_mul);
m.impl("my__foreach_mul_", &boxed_my__foreach_mul_);
m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach);
}
// Test functions for torch::stable::accelerator APIs
#ifdef LAE_USE_CUDA
@ -550,8 +500,8 @@ void boxed_test_device_guard(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int res = test_device_guard(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
stack[0] = torch::stable::detail::from(res);
int res = test_device_guard(static_cast<int64_t>(to<int64_t>(stack[0])));
stack[0] = from(res);
}
int64_t test_device_guard_set_index() {
@ -570,7 +520,7 @@ void boxed_test_device_guard_set_index(
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_device_guard_set_index();
stack[0] = torch::stable::detail::from(res);
stack[0] = from(res);
}
int64_t test_stream(int32_t device_index) {
@ -586,8 +536,8 @@ void boxed_test_stream(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_stream(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
stack[0] = torch::stable::detail::from(res);
int64_t res = test_stream(static_cast<int64_t>(to<int64_t>(stack[0])));
stack[0] = from(res);
}
int64_t test_get_current_device_index() {
@ -599,7 +549,7 @@ void boxed_test_get_current_device_index(
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_get_current_device_index();
stack[0] = torch::stable::detail::from(res);
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -615,5 +565,4 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_stream", &boxed_test_stream);
m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
}
#endif // LAE_USE_CUDA

View File

@ -333,45 +333,3 @@ def my_new_zeros_dtype_variant(t) -> Tensor:
Returns: New zeros tensor
"""
return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t)
def my__foreach_mul_(tensors, others) -> ():
"""
Updates tensors to be the result of pointwise multiplying with others.
Args:
tensors: list of tensors
others: list of tensors (with the same corresponding shapes as tensors)
Returns: nothing, tensors is updated in place.
"""
torch.ops.libtorch_agnostic.my__foreach_mul_.default(tensors, others)
def my__foreach_mul(tensors, others) -> list[Tensor]:
"""
Returns a list of tensors that are the results of pointwise multiplying
tensors and others.
Args:
tensors: list of tensors
others: list of tensors (with the same corresponding shapes as tensors)
Returns: list of multiplied tensors
"""
return torch.ops.libtorch_agnostic.my__foreach_mul.default(tensors, others)
def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]:
"""
Returns a list of 2 tensors corresponding to the square of the inputs.
Args:
t1: Tensor
t2: Tensor
Returns: list of [t1^2, t2^2]
"""
return torch.ops.libtorch_agnostic.make_tensor_clones_and_call_foreach.default(
t1, t2
)

View File

@ -367,57 +367,6 @@ if not IS_WINDOWS:
self.assertNotEqual(result.data_ptr(), expected.data_ptr())
self.assertEqual(result.stride(), expected.stride())
def test_my__foreach_mul_(self, device):
import libtorch_agnostic
N = 5
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
tensors_c = [t.clone() for t in tensors]
others = [torch.rand(32, 16, device=device) for _ in range(N)]
libtorch_agnostic.ops.my__foreach_mul_(tensors, others)
expected_values = torch._foreach_mul(tensors_c, others)
for tensor_t, expected_t in zip(tensors, expected_values):
self.assertEqual(tensor_t, expected_t)
def test_my__foreach_mul(self, device):
import libtorch_agnostic
N = 5
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
others = [torch.rand(32, 16, device=device) for _ in range(N)]
result = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
expected = torch._foreach_mul(tensors, others)
for result_t, expected_t in zip(result, expected):
self.assertEqual(result_t, expected_t)
def _make_cuda_tensors(prior_mem):
cuda_res = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
expected = torch._foreach_mul(tensors, others)
for result_t, expected_t in zip(cuda_res, expected):
self.assertEqual(result_t, expected_t)
if tensors[0].is_cuda:
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
def test_make_tensor_clones_and_call_foreach(self, device):
import libtorch_agnostic
t1 = torch.rand(2, 5, device=device)
t2 = torch.rand(3, 4, device=device)
result = libtorch_agnostic.ops.make_tensor_clones_and_call_foreach(t1, t2)
self.assertEqual(result[0], t1 * t1)
self.assertEqual(result[1], t2 * t2)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: unknown"]
import os
import tempfile
from backend import get_custom_backend_library_path, Model, to_custom_backend
@ -40,11 +41,14 @@ class TestCustomBackend(TestCase):
self.test_execute()
# Save and load.
with tempfile.NamedTemporaryFile() as f:
f = tempfile.NamedTemporaryFile(delete=False)
try:
f.close()
torch.jit.save(self.model, f.name)
loaded = torch.jit.load(f.name)
self.model = loaded
finally:
os.unlink(f.name)
self.model = loaded
# Test execution again.
self.test_execute()

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: unknown"]
import os.path
import sys
import tempfile
import unittest
@ -143,13 +144,16 @@ def forward(self, arg0_1):
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
# close the file after creation and try to remove it manually.
with tempfile.NamedTemporaryFile() as file:
file = tempfile.NamedTemporaryFile(delete=False)
try:
file.close()
model.save(file.name)
loaded = torch.jit.load(file.name)
finally:
os.unlink(file.name)
output = loaded.forward(torch.ones(5))
self.assertTrue(output.allclose(torch.ones(5) + 1))
output = loaded.forward(torch.ones(5))
self.assertTrue(output.allclose(torch.ones(5) + 1))
if __name__ == "__main__":

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: fsdp"]
import functools
import os
import unittest
import unittest.mock
import torch.distributed as dist
from torch._dynamo.test_case import run_tests
@ -37,9 +37,9 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
logger = logging.getLogger("torch.distributed._composable.fsdp")
logger.setLevel(logging.DEBUG)
device = '{device_type.type}'
device = {device_type.type}
torch.manual_seed(0)
model = nn.Sequential(*[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)])
for layer in model:

View File

@ -76,7 +76,7 @@ class ReplicateTest(MultiProcessTestCase):
store=dist.FileStore(self.file_name, self.world_size),
)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_replicate_transformer(self):
"""
This tests that replicate works on a transformer model with fully_shard and replicate layers
@ -126,7 +126,7 @@ class ReplicateTest(MultiProcessTestCase):
for parameter in layer.parameters():
self.assertEqual(parameter.placements, (Shard(dim=0),))
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_replicate_transformer_managed_modules(self):
"""
This tests that replicate managed modules works properly. In this test we use a Transformer Module with 3 layers,
@ -178,7 +178,7 @@ class ReplicateTest(MultiProcessTestCase):
replicate_model = replicate(replicate_model)
self.assertEqual(len(_get_managed_modules((replicate_model,))), 21)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_replicate_tp_device_mesh(self):
"""
This tests that a user can pass in a device mesh to replicate a module
@ -206,7 +206,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(parameter.device_mesh.shape, (2,))
self.assertEqual(parameter.placements, (Replicate(),))
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_train_replicate_fsdp(self):
"""
Tests that replicate_model has the same behavior as original model when training
@ -253,7 +253,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(replicate_loss, loss)
check_sharded_parity(self, model, replicate_model)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_train_parity_2d_mlp(self):
"""
Verifies when a device mesh is passed in, the model has the same behavior as the original model when training

View File

@ -80,7 +80,7 @@ class TestSACILP(TestCase):
# postprocessing due to the fact that for ModTracker, the post backward hook
# is not being called for modules whose inputs don't require gradients
# TODO: fix this in ModTracker and ensure it does not lead to any perf regression
if _ModState.POST_BW not in mod_stats.snapshots:
if _ModState.POST_BW not in mod_stats.snapshots.keys():
mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append(
copy.deepcopy(last_snapshot)
)

View File

@ -16,7 +16,7 @@ from torch.distributed.argparse_util import check_env, env
class ArgParseUtilTest(unittest.TestCase):
def setUp(self):
# remove any lingering environment variables
for e in os.environ.keys(): # noqa: SIM118
for e in os.environ.keys():
if e.startswith("PET_"):
del os.environ[e]

View File

@ -207,7 +207,7 @@ class TestDefaultStager(TestCase):
for i, result in enumerate(staged_results):
self.assertIsInstance(result, dict)
# Verify the result contains the expected keys
for key in state_dicts[i]:
for key in state_dicts[i].keys():
self.assertIn(key, result)
stager.close()

View File

@ -299,7 +299,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_dtensor_checkpoint_with_uneven_shards(self) -> None:
"""
Saving a dtensor with uneven shards.
@ -436,7 +436,6 @@ class TestCheckpointableReshard(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_checkpointable_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
@ -499,7 +498,6 @@ class TestCheckpointableReshard(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_dtensor_shards_wrapper_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.

View File

@ -60,7 +60,7 @@ class TestSingleRankSaveLoad(TestCase):
self.assertEqual(
sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())
)
for key in state_dict_to_save:
for key in state_dict_to_save.keys():
self.assertTrue(
torch.equal(state_dict_to_save[key], state_dict_loaded[key])
)
@ -89,7 +89,7 @@ class TestSingleRankSaveLoad(TestCase):
self.assertEqual(
sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())
)
for key in state_dict_to_save:
for key in state_dict_to_save.keys():
self.assertTrue(
torch.equal(state_dict_to_save[key], state_dict_to_load[key])
)
@ -116,7 +116,7 @@ class TestSingleRankSaveLoad(TestCase):
self.assertEqual(
sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())
)
for key in state_dict_to_save:
for key in state_dict_to_save.keys():
self.assertTrue(
torch.equal(state_dict_to_save[key], state_dict_loaded[key])
)
@ -156,7 +156,7 @@ class TestSingleRankSaveLoad(TestCase):
self.assertEqual(
sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())
)
for key in state_dict_to_save:
for key in state_dict_to_save.keys():
self.assertTrue(
torch.equal(state_dict_to_save[key], state_dict_to_load[key])
)

View File

@ -18,7 +18,6 @@ from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint.api import CheckpointException
from torch.distributed.checkpoint.default_planner import (
_create_default_local_metadata,
_validate_global_plan,
create_default_global_save_plan,
create_default_local_load_plan,
create_default_local_save_plan,
@ -29,7 +28,6 @@ from torch.distributed.checkpoint.filesystem import CURRENT_DCP_VERSION
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
MetadataIndex,
TensorProperties,
TensorStorageMetadata,
@ -562,32 +560,6 @@ class TestPlannerHelpers(TestCase):
self.assertTrue(_compare_save_plans(plan2, plan2))
class TestValidateGlobalPlan(TestCase):
def _make_metadata(self, chunks, size):
storage = TensorStorageMetadata(
properties=TensorProperties(dtype=torch.float32),
size=torch.Size(size),
chunks=chunks,
)
return Metadata(state_dict_metadata={"param": storage})
def test_non_overlapping_chunks(self):
chunks = [
ChunkStorageMetadata(offsets=torch.Size([i]), sizes=torch.Size([1]))
for i in range(4)
]
metadata = self._make_metadata(chunks, [4])
self.assertTrue(_validate_global_plan([SavePlan([])], metadata))
def test_detect_overlapping_chunks(self):
chunks = [
ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([2])),
ChunkStorageMetadata(offsets=torch.Size([1]), sizes=torch.Size([2])),
]
metadata = self._make_metadata(chunks, [4])
self.assertFalse(_validate_global_plan([SavePlan([])], metadata))
class TestLoadPlanner(TestCase):
@with_temp_dir
def test_strict(self):

View File

@ -769,7 +769,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
model_state_dict3 = copy.deepcopy(model_state_dict3)
self.assertEqual(len(model_state_dict2), 2)
self.assertEqual(len(model_state_dict3), 2)
for key in model_state_dict3:
for key in model_state_dict3.keys():
full_fqn = f"l.{key}"
value1 = model_state_dict1[full_fqn]
value2 = model_state_dict2[full_fqn]
@ -886,7 +886,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self.assertEqual(cpu_model_value, meta_model_value)
@with_comms
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_setting_meta_device_model_broadcasting_and_memory(self) -> None:
# This test verifies that we can set model state dict by a meta device model
# With the correlated changes in state_dict, meta device model should be accepted

View File

@ -587,7 +587,9 @@ class TestFSDPStateDict(FSDPTest):
model, cpu_offload.offload_params, fp16
)
ignore_keys = [k for k in fsdp_state_dict if NON_ROOT_FSDP_PREFIX in k]
ignore_keys = [
k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
]
self._validate_state_dict_contents(
model,
@ -908,7 +910,7 @@ class TestFSDPStateDict(FSDPTest):
with sd_mgr:
fsdp_state_dict = model.state_dict()
ignore_keys = [k for k in fsdp_state_dict if NON_ROOT_FSDP_PREFIX in k]
ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k]
self._validate_state_dict_contents(
model,
fsdp_state_dict,
@ -957,7 +959,9 @@ class TestFSDPStateDict(FSDPTest):
# Full name of linear_skip param tensors in SkipModel, as would be
# stored in checkpoint.
linear_skip_tensor_names = [
k for k in dict(module.named_parameters()) if LINEAR_SKIP in k
k
for k in dict(module.named_parameters()).keys()
if LINEAR_SKIP in k
]
# skip SkipModule
linear_skip = getattr(module, LINEAR_SKIP)

View File

@ -137,7 +137,7 @@ class ElasticLaunchTest(unittest.TestCase):
self.test_dir = tempfile.mkdtemp()
# remove any lingering environment variables.
for env in os.environ.keys(): # noqa: SIM118
for env in os.environ.keys():
if env.startswith("PET_"):
del os.environ[env]

View File

@ -69,7 +69,7 @@ class ElasticLaunchTest(TestCase):
self.test_dir = tempfile.mkdtemp()
# remove any lingering environment variables
for env in os.environ.keys(): # noqa: SIM118
for env in os.environ.keys():
if env.startswith("PET_"):
del os.environ[env]

View File

@ -39,7 +39,6 @@ from torch.nn.modules.loss import MSELoss
from torch.testing._internal.common_distributed import (
MultiProcContinuousTest,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
check_leaked_tensors,
@ -232,7 +231,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [_ScheduleForwardOnly])
@skip_if_lt_x_gpu(4)
def test_forward_only(self, ScheduleClass):
mod, mod_ref, x, _, _ = setup_models_and_data(self.config)
x_clone = x.clone()
@ -276,7 +274,6 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_eval_inference_mode(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@ -354,7 +351,6 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_return_output(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@ -410,7 +406,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_multi_iter(self, ScheduleClass):
mod, _, x, target, loss_fn = setup_models_and_data(self.config)
chunks = 4
@ -434,7 +429,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_kwargs_with_tracer(self, ScheduleClass):
mod = ModelWithKwargs(d_hid, splits=self.world_size)
mod.to(self.device)
@ -487,7 +481,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_grad_with_tracer(self, ScheduleClass):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
@ -530,7 +523,6 @@ class ScheduleTest(MultiProcContinuousTest):
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@parametrize("shape_inference", [True, False])
@skip_if_lt_x_gpu(4)
def test_grad_with_manual(self, ScheduleClass, shape_inference):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
@ -594,7 +586,6 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_grad_with_manual_interleaved(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -659,7 +650,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
@skip_if_lt_x_gpu(4)
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -746,7 +736,6 @@ class ScheduleTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleZBVZeroBubble, ScheduleDualPipeV],
)
@skip_if_lt_x_gpu(4)
def test_v_shape_schedules(self, schedule_class):
n_stages = 8
rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
@ -791,7 +780,6 @@ class ScheduleTest(MultiProcContinuousTest):
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@skip_if_lt_x_gpu(4)
def test_custom_function_callback(self):
"""Test the custom function callback functionality with _PipelineScheduleRuntime."""
n_stages = 8
@ -991,7 +979,6 @@ class ScheduleTest(MultiProcContinuousTest):
"ScheduleClass",
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
)
@skip_if_lt_x_gpu(4)
def test_zero_bubble_with_model_kwargs(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -1085,7 +1072,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleVShaped, ScheduleUnbalanced],
)
@skip_if_lt_x_gpu(4)
def test_non_symmetric_stage_ids(self, schedule_class):
n_stages = schedule_class.n_stages
rank_stages = schedule_class.rank_stages
@ -1135,7 +1121,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithReorderedB])
@skip_if_lt_x_gpu(4)
def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
n_stages = 2
stages_per_rank = 1
@ -1196,7 +1181,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithW])
@skip_if_lt_x_gpu(4)
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
n_stages = ScheduleClass.n_stages
num_microbatches = ScheduleClass.num_microbatches

View File

@ -204,16 +204,14 @@ class DistConvolutionOpsTest(DTensorTestBase):
self.assertTrue(b_dt.grad is not None)
self.assertTrue(x_dt.grad is None)
def _run_single_arg_fwd(
self, model, arg, placements=None
) -> tuple[torch.Tensor, torch.Tensor]:
def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
"""Given model and arg, runs fwd model local and distbuted given device_mesh"""
device_mesh = self.build_device_mesh()
model_copy = copy.deepcopy(model).to(device=self.device_type)
dist_model = distribute_module(model, device_mesh, _conv_fn)
arg_dt = DTensor.from_local(arg, device_mesh, placements)
arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
out_dt = dist_model(arg_dt.to(device=self.device_type))
out = model_copy(arg_dt.full_tensor())
out = model_copy(arg)
return (out_dt.full_tensor(), out)
@with_comms
@ -221,20 +219,22 @@ class DistConvolutionOpsTest(DTensorTestBase):
model = nn.Conv1d(64, 64, 3, padding=1)
x = torch.randn(1, 64, 8, device=self.device_type)
out_dt, out = self._run_single_arg_fwd(model, x)
self.assertEqual(out_dt, out)
self.assertEqual(out_dt.shape, out.shape)
@with_comms
def test_conv3d(self):
model = nn.Conv3d(64, 64, 3, padding=1)
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)])
self.assertEqual(out_dt, out)
out_dt, out = self._run_single_arg_fwd(model, x)
self.assertEqual(out_dt.shape, out.shape)
DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
DistConvolutionOpsTest,
# Send / recv ops are not supported
skipped_tests=[
"test_conv1d",
"test_conv3d",
"test_conv_backward_none_grad_inp",
"test_depthwise_convolution",
"test_downsampling_convolution",

View File

@ -464,25 +464,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
run(g, 64, 8)
self.assertEqual(cnt.frame_count, 2)
def test_dtensor_requires_grad_recompile(self):
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@torch.compile(backend=cnt, fullgraph=True)
def f(x):
y = x * x
return y.to_local()
full_x = torch.randn(8, 8, requires_grad=False)
x = distribute_tensor(full_x, mesh, [Shard(0)])
f(x)
full_x = torch.randn(8, 8, requires_grad=True)
x = distribute_tensor(full_x, mesh, [Shard(0)])
f(x)
self.assertEqual(cnt.frame_count, 2)
def test_dtensor_attribute_access_on_intermediate(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

View File

@ -535,19 +535,6 @@ class DTensorExportTest(TestCase):
self.assertEqual(fn(z), gm(z)[0])
def test_dtensor_data_dependent_index(self):
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
class Foo(torch.nn.Module):
def forward(self, x, y):
return x[y]
x = torch.randn(10)
y = torch.randint(1, (10,)).bool()
x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
_dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
instantiate_parametrized_tests(DTensorExportTest)

View File

@ -26,7 +26,6 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
SequenceParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
@ -765,7 +764,6 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(grad1_norm.device_mesh, mesh_y)
@with_comms
@skip_if_lt_x_gpu(4)
def test_foreach_add_different_mesh(self):
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(

View File

@ -577,7 +577,7 @@ class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
self.assertEqual(
comm_mode.get_comm_counts(),
{
torch.ops.c10d_functional.all_gather_into_tensor: self.world_size,
torch.ops.c10d_functional.all_gather_into_tensor: 4,
},
)
expected_cost = [

View File

@ -2,6 +2,7 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import copy
import itertools
import unittest
@ -21,8 +22,9 @@ from torch.distributed.tensor import (
)
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -33,11 +35,7 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
generate_shard_orders,
make_full_tensor,
map_local_tensor_for_rank,
patched_distribute_tensor as _distribute_tensor,
redistribute,
with_comms,
)
from torch.utils._debug_mode import DebugMode
@ -787,6 +785,88 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
else:
return ""
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def redistribute(
self,
dtensor_input,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""
wrapper function to support shard_order for redistribution
This is a simpler version of Redistribute, only considers the forward.
"""
if placements is None:
placements = self._shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
old_spec = dtensor_input._spec
new_spec = copy.deepcopy(old_spec)
new_spec.placements = placements
if shard_order is not None:
new_spec.shard_order = shard_order
else:
new_spec.shard_order = ()
if old_spec == new_spec:
return dtensor_input
dtensor_input = DTensor.from_local(
redistribute_local_tensor(
dtensor_input.to_local(),
old_spec,
new_spec,
use_graph_based_transform=use_graph_based_transform,
),
device_mesh,
)
dtensor_input._spec = copy.deepcopy(new_spec)
return dtensor_input # returns DTensor
# TODO(zpcore): remove once the native distribute_tensor supports
# shard_order arg
def distribute_tensor(
self,
input_tensor,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""wrapper function to support shard_order for tensor distribution"""
if placements is None:
placements = self._shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
# fix the shard order
return self.redistribute(
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def full_tensor(self, dtensor_input):
"""wrapper function to support DTensor.full_tensor"""
return self.redistribute(
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
).to_local()
def _shard_order_to_placement(self, shard_order, mesh):
"""convert shard_order to placement with only Replicate() and Shard()"""
placements = [Replicate() for _ in range(mesh.ndim)]
if shard_order is not None:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for mesh_dim in mesh_dims:
placements[mesh_dim] = Shard(tensor_dim)
return tuple(placements)
def _convert_shard_order_dict_to_ShardOrder(self, shard_order):
"""Convert shard_order dict to ShardOrder"""
return tuple(
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
for tensor_dim, mesh_dims in shard_order.items()
)
@with_comms
def test_ordered_redistribute(self):
"""Test ordered redistribution with various sharding syntaxes"""
@ -847,11 +927,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
sharding_src_dst_pairs_with_expected_trace
):
sharded_dt = _distribute_tensor(
sharded_dt = self.distribute_tensor(
input_data.clone(), mesh, src_placement, shard_order=src_order
)
with DebugMode(record_torchfunction=False) as debug_mode:
sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order)
sharded_dt = self.redistribute(
sharded_dt, mesh, dst_placement, dst_order
)
trace_str = self._extract_redistribute_trace_from_debug_mode(
debug_mode.debug_string()
)
@ -875,11 +957,49 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
trace_str,
"""S(0)[0]S(0)[1]R->S(0)S(1)R->RS(1)R->RS(1)S(0)""",
)
expected_dt = _distribute_tensor(
expected_dt = self.distribute_tensor(
input_data.clone(), mesh, dst_placement, shard_order=dst_order
)
self.assertEqual(sharded_dt.to_local(), expected_dt.to_local())
def generate_shard_orders(self, mesh, tensor_rank):
# Generate all possible sharding placement of tensor with rank
# `tensor_rank` over mesh.
def _split_list(lst: list, N: int):
def compositions(n, k):
if k == 1:
yield [n]
else:
for i in range(1, n - k + 2):
for tail in compositions(n - i, k - 1):
yield [i] + tail
length = len(lst)
for comp in compositions(length, N):
result = []
start = 0
for size in comp:
result.append(lst[start : start + size])
start += size
yield result
all_mesh = list(range(mesh.ndim))
all_device_order = list(itertools.permutations(all_mesh))
for device_order in all_device_order:
# split on device orders, and assign each device order segment to a tensor dim
for num_split in range(1, mesh.ndim + 1):
for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
for tensor_dims in itertools.combinations(
range(tensor_rank), len(splitted_list)
):
shard_order = {}
assert len(tensor_dims) == len(splitted_list)
for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
shard_order[tensor_dim] = device_order[
mesh_dims[0] : mesh_dims[-1] + 1
]
yield self._convert_shard_order_dict_to_ShardOrder(shard_order)
@with_comms
def test_generate_shard_orders(self):
"""Check if `generate_shard_orders` generates unique sharding combinations"""
@ -892,7 +1012,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
]
for test_input in test_inputs:
all_combinations = []
for shard_order in generate_shard_orders(
for shard_order in self.generate_shard_orders(
test_input["mesh"], test_input["tensor_rank"]
):
all_combinations.append(shard_order) # noqa: PERF402
@ -942,12 +1062,12 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
input_data = torch.randn(tensor_shape, device=self.device_type)
tensor_rank = input_data.ndim
with maybe_disable_local_tensor_mode():
shard_orders = generate_shard_orders(mesh, tensor_rank)
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
sharded_dt = _distribute_tensor(
sharded_dt = self.distribute_tensor(
input_data.clone(), mesh, placements=None, shard_order=shard_order
)
self.assertEqual(make_full_tensor(sharded_dt), input_data)
self.assertEqual(self.full_tensor(sharded_dt), input_data)
# 2. Verify the correctness of redistribution from DTensor to DTensor.
# This test repeatedly redistributes a DTensor to various ordered
@ -958,20 +1078,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
tensor_rank = input_data.ndim
prev_sharded_dt = None
with maybe_disable_local_tensor_mode():
shard_orders = generate_shard_orders(mesh, tensor_rank)
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
if prev_sharded_dt is None:
prev_sharded_dt = _distribute_tensor(
prev_sharded_dt = self.distribute_tensor(
input_data.clone(),
mesh,
placements=None,
shard_order=shard_order,
)
else:
sharded_dt = redistribute(
sharded_dt = self.redistribute(
prev_sharded_dt, mesh, placements=None, shard_order=shard_order
)
self.assertEqual(make_full_tensor(sharded_dt), input_data)
self.assertEqual(self.full_tensor(sharded_dt), input_data)
prev_sharded_dt = sharded_dt
@with_comms
@ -1016,13 +1136,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
local_tensor = torch.randn(shape, device=self.device_type)
full_tensor = DTensor.from_local(local_tensor, mesh, placements)
with maybe_disable_local_tensor_mode():
shard_orders = generate_shard_orders(mesh, len(shape))
shard_orders = self.generate_shard_orders(mesh, len(shape))
for shard_order in shard_orders:
sharded_dt = redistribute(
sharded_dt = self.redistribute(
full_tensor, mesh, placements=None, shard_order=shard_order
)
self.assertEqual(
make_full_tensor(sharded_dt), make_full_tensor(full_tensor)
self.full_tensor(sharded_dt), self.full_tensor(full_tensor)
)
@unittest.skip(
@ -1032,20 +1152,24 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
@with_comms
def test_ordered_redistribute_for_special_placement(self):
"""Test ordered redistribution with special placement"""
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
torch.manual_seed(21)
mesh = init_device_mesh(self.device_type, (8,))
input_data = torch.randn((8, 8), device=self.device_type)
src_placement = [Shard(1)]
tgt_placement = [
(MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
(_MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
]
sharded_dt = _distribute_tensor(
sharded_dt = self.distribute_tensor(
input_data.clone(),
mesh,
src_placement,
shard_order=(ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)),),
)
sharded_dt = redistribute(sharded_dt, mesh, tgt_placement, shard_order=None)
sharded_dt = self.redistribute(
sharded_dt, mesh, tgt_placement, shard_order=None
)
@with_comms
def test_shard_order_same_data_as_strided_shard(self):
@ -1055,7 +1179,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
# specify right-to-left order use ordered shard
x_ordered_dt = _distribute_tensor(
x_ordered_dt = self.distribute_tensor(
x,
device_mesh,
placements=[Shard(0), Shard(0)],

View File

@ -34,10 +34,6 @@ from torch.distributed.tensor.placement_types import (
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
generate_shard_orders,
LocalDTensorTestBase,
patched_distribute_tensor as _distribute_tensor,
shard_order_to_placement,
with_comms,
)
@ -778,63 +774,6 @@ class TestStridedSharding(DTensorTestBase):
self.assertEqual(dtensor.full_tensor(), tensor)
class Test_StridedShard_with_shard_order(LocalDTensorTestBase):
@property
def world_size(self) -> int:
return 32
@with_comms
def test_StridedShard_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(2, 2, 2, 2, 2))
shard_iter = generate_shard_orders(mesh, 3)
# It takes ~4.8h to complete total 2520 shard order combinations here
# using LocalTensor. So we only randomly pick 25 shard orders to test.
all_shard_order = list(shard_iter)
import random
random.seed(42)
shard_order_choices = random.sample(
all_shard_order, min(25, len(all_shard_order))
)
x = torch.randn(32, 32, 32)
for shard_order in shard_order_choices:
a = _distribute_tensor(x, mesh, None, shard_order)
placement_without_stridedshard = shard_order_to_placement(
shard_order, mesh
)
placements_with_stridedshard = (
DTensorSpec._convert_shard_order_to_StridedShard(
shard_order, placement_without_stridedshard, mesh
)
)
b = distribute_tensor(x, mesh, placements_with_stridedshard)
shard_order_from_stridedshard = (
DTensorSpec._maybe_convert_StridedShard_to_shard_order(
placements_with_stridedshard, mesh
)
)
self.assertEqual(shard_order, shard_order_from_stridedshard)
self.assertEqual(a.to_local(), b.to_local())
@with_comms
def test_StridedShard_not_convertible_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(4, 8))
unconvertible_placements_list = [
[_StridedShard(0, split_factor=2), _StridedShard(1, split_factor=2)],
[_StridedShard(0, split_factor=2), Shard(1)],
[_StridedShard(1, split_factor=16), Shard(1)],
]
for placements in unconvertible_placements_list:
shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order(
tuple(placements), mesh
)
self.assertIsNone(shard_order)
class Test2DStridedLocalShard(DTensorTestBase):
@property
def world_size(self):

View File

@ -54,7 +54,6 @@ def apply_reordering_and_get_graph(graph, out_li) -> None:
"max_compute_pre_fetch",
"custom_runtime_estimation",
"insert_overlap_deps",
"collective_estimator",
)
for key in config_keys:
if (val := getattr(dist_opts, key)) is not None:
@ -944,50 +943,6 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
correct = func(inputs_a, inputs_b, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_collective_benchmarking_with_real_pg(self):
"""Test collective benchmarking with real process group (falls back on fake)."""
def func(a):
# Test all three collective types with 8x8 (power of 2 size = 256 elements = 1024 bytes for fp32)
ar = _functional_collectives.all_reduce(a, "sum", "0")
ag = _functional_collectives.all_gather_tensor(
a, 0, list(range(self.world_size))
)
rs = _functional_collectives.reduce_scatter_tensor(a, "sum", 0, "0")
b = torch.matmul(a, a)
c = torch.matmul(ar, b)
return c.sum() + ag.sum() + rs.sum()
patches = {
**get_patches(),
"aten_distributed_optimizations.collective_estimator": "benchmark",
"aten_distributed_optimizations.custom_runtime_estimation": None, # Remove custom estimation so benchmarking happens
}
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(8, 8, dtype=torch.float, device=device_type) + self.rank
with torch._inductor.config.patch(patches):
compiled = torch.compile(func)
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
# Verify all three collective types are present
FileCheck().check("all_reduce").check("all_gather").check(
"reduce_scatter"
).run(aten_graph_str)
# Test passes if compilation succeeded with benchmarking enabled
# Cache verification is tricky due to multiprocess test setup
correct = func(inputs)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_multidtype_bucketing(self):

View File

@ -485,7 +485,7 @@ elif TEST_XPU:
def exit_if_lt_x_accelerators(x):
if torch.accelerator.is_available():
if torch.accelerator.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
def with_comms(func=None):

View File

@ -1,6 +1,4 @@
# Owner(s): ["module: dynamo"]
# flake8: noqa: B950
# flake8: noqa: E731
import contextlib
import copy
import functools
@ -17,11 +15,7 @@ import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
CompileCounterWithBackend,
normalize_gm,
)
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
@ -1655,43 +1649,6 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
self.assertEqual(opt_fn(x), fn(x))
def test_return_same_element_twice(self):
def gn(x):
y = torch.sin(x)
return y, y
def fn(x):
return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)
x = torch.randn(4, 4, requires_grad=True)
ref = fn(x)
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[4, 4]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]
getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[4, 4]"):
y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (y, y)
""",
)
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
def test_nonlocal_mutation(self):
counter = 0
@ -1715,114 +1672,6 @@ class GraphModule(torch.nn.Module):
# The mutation is not reapplied in the backward because the flag was on.
self.assertEqual(counter, 1)
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
def test_nonlocal_list_mutation(self):
def gn(x, z):
out = x.sin()
z.append(out)
return torch.cos(torch.sin(torch.matmul(x, x) @ x)), out
def fn(x):
z = []
out1, out2 = torch.utils.checkpoint.checkpoint(
gn,
x,
z,
use_reentrant=False,
)
return out1, z[0]
x = torch.randn(4, 4, requires_grad=True)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
def test_nonlocal_list_mutation_hidden(self):
def gn(x, z):
o = torch.matmul(x, x) @ x
out = x.sin()
z.append(out)
return torch.cos(torch.sin(o)), torch.sin(x)
def fn(x):
z = []
outs = torch.utils.checkpoint.checkpoint(
gn,
x,
z,
use_reentrant=False,
)
out1 = outs[0]
# Check that the extra output pytree handling is done properly
out2 = outs[-1]
return out1 + out2, z[0]
x = torch.randn(4, 4, requires_grad=True)
ref = fn(x)
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[4, 4]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None
out1: "f32[4, 4]" = tag_activation_checkpoint[0]
out2: "f32[4, 4]" = tag_activation_checkpoint[1]
getitem_4: "f32[4, 4]" = tag_activation_checkpoint[4]; tag_activation_checkpoint = None
add: "f32[4, 4]" = out1 + out2; out1 = out2 = None
return (add, getitem_4)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[4, 4]"):
matmul: "f32[4, 4]" = torch.matmul(l_x_, l_x_)
o: "f32[4, 4]" = matmul @ l_x_
out: "f32[4, 4]" = l_x_.sin()
sin_1: "f32[4, 4]" = torch.sin(o)
child: "f32[4, 4]" = torch.cos(sin_1)
child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (child, child_1, matmul, o, out, sin_1)
""",
)
self.assertExpectedInline(
normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4]"):
mm: "f32[4, 4]" = torch.ops.aten.mm.default(primals_1, primals_1)
mm_1: "f32[4, 4]" = torch.ops.aten.mm.default(mm, primals_1); mm = None
sin: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1)
sin_1: "f32[4, 4]" = torch.ops.aten.sin.default(mm_1); mm_1 = None
cos: "f32[4, 4]" = torch.ops.aten.cos.default(sin_1); sin_1 = None
sin_2: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1)
add: "f32[4, 4]" = torch.ops.aten.add.Tensor(cos, sin_2); cos = sin_2 = None
return (add, sin, primals_1)
""",
)
devices = ["cuda", "hpu"]
instantiate_device_type_tests(

View File

@ -2109,89 +2109,6 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
with self.assertRaises(Unsupported):
outer_f2(inp)
def test_disable_recursive_flags(self):
class SimpleLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = torch.nn.Linear(4, 4)
def forward(self, inp):
return self.layer0(torch.sigmoid(inp))
class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = SimpleLinear()
self.layer1 = torch.nn.Linear(4, 4)
def forward(self, inp):
z = self.layer0(torch.sin(inp))
return self.layer1(z)
for recursive_flag in [True, False]:
model = SimpleModel()
other_model = SimpleModel()
model.forward = torch._dynamo.disable(
model.forward,
recursive=recursive_flag,
)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(model.forward),
recursive_flag,
)
other_model = torch._dynamo.disable(other_model, recursive=recursive_flag)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(
other_model.forward
if isinstance(other_model, torch.nn.Module)
else other_model
),
recursive_flag,
)
# check the model is compilable
torch.compile(model)
torch.compile(other_model)
def test_dynamo_disable_annotations(self):
class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("buffer", torch.rand(2, 2))
@torch._dynamo.disable()
def f1(self, x) -> torch.Tensor:
return x + self.buffer + 1
@torch._dynamo.disable()
def f2(self, x) -> torch.Tensor:
return x + self.buffer + 2
def forward(self, x) -> torch.Tensor:
return self.f1(x) + self.f2(x)
model = SimpleModel()
inp = torch.rand(2, 2)
with torch.fx.traceback.preserve_node_meta():
exported_model = torch.export.export(model, (inp,))
graph = exported_model.graph_module.graph
found_f1 = False
found_f2 = False
for node in graph.nodes:
if "custom" in node.meta:
if "_torchdynamo_disable_method" in node.meta["custom"]:
if node.meta["custom"]["_torchdynamo_disable_method"] == "f1":
found_f1 = True
elif node.meta["custom"]["_torchdynamo_disable_method"] == "f2":
found_f2 = True
self.assertTrue(found_f1)
self.assertTrue(found_f2)
model.forward = torch._dynamo.disable(model.forward, recursive=False)
with self.assertRaises(RuntimeError):
exported_model = torch.export.export(model, (inp,))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -422,41 +422,34 @@ from user code:
import optree
@torch.compile(backend="eager")
def fn1(x):
tree = {"a": x, "b": (x - 1, 2 * x)}
sin, cos = optree.tree_transpose_map(
lambda t: (torch.sin(t), torch.cos(t)),
tree,
def fn(x):
d = {"a": 1}
optree.tree_flatten_with_path(d)
return torch.sin(x)
def post_munge(s):
s = re.sub(
r"optree\.\S*\.flatten_with_path",
"optree.<path>.flatten_with_path",
s,
)
return sin, cos
fn1(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 0)
@torch.compile(backend="eager")
def fn2(x):
spec = optree.treespec_deque([])
return spec, x
fn2(torch.randn(4))
self.assertGreaterEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
def post_munge(string):
return re.sub(
r"(optree\.|qualname: )\S*(\.make_from_collection)",
r"\1<path>\2",
string,
r"qualname: \S*flatten_with_path",
"qualname: <path>.flatten_with_path",
s,
)
fn(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
self.assertExpectedInline(
post_munge(first_graph_break),
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.make_from_collection.
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: <path>.make_from_collection, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)
@ -1050,7 +1043,7 @@ Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especiall
msg = re.sub(r"line (\d+)", "line N", msg)
msg = re.sub(
r"""(?s)Traceback \(most recent call last\):.*
File "exc.py", line N, in unimplemented
File "exc.py", line N, in unimplemented_v2
raise Unsupported\(msg\)""",
"<Internal traceback>\n",
msg,

View File

@ -861,7 +861,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
def test_logs_out(self):
import tempfile
with tempfile.NamedTemporaryFile(delete=True) as tmp:
with tempfile.NamedTemporaryFile(delete=False) as tmp:
file_path = _as_posix_path(tmp.name)
"""
NamedTemporaryFile will include a file open operation.
@ -888,6 +888,10 @@ fn(torch.randn(5))
file_path, encoding="utf-8"
) as fd: # encoding file to UTF-8 for Windows.
lines = fd.read()
fd.close()
os.remove(
file_path
) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False.
orig_maxDiff = unittest.TestCase.maxDiff
unittest.TestCase.maxDiff = None
try:
@ -984,7 +988,6 @@ exclusions = {
"hierarchical_compile",
"compute_dependencies",
"annotation",
"node_runtime_estimation",
}
for name in torch._logging._internal.log_registry.artifact_names:
if name not in exclusions:

View File

@ -742,14 +742,11 @@ class TestExport(TestCase):
self.assertExpectedInline(
str(custom_metadata),
"""\
('placeholder', 'x', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('placeholder', 'y', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('call_function', 'cat', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'item', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'ge_1', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', '_assert_scalar_default', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'mul', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('output', 'output', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})""",
('call_function', 'cat', {'moo': 0})
('call_function', 'item', {'moo': 0})
('call_function', 'ge_1', {'moo': 0})
('call_function', '_assert_scalar_default', {'moo': 0})
('call_function', 'mul', {'moo': 0})""",
)
@requires_gpu
@ -1224,14 +1221,8 @@ graph():
%p_block_linear2_bias : [num_users=1] = placeholder[target=p_block_linear2_bias]
%x : [num_users=1] = placeholder[target=x]
%wrap_body0 : [num_users=1] = get_attr[target=wrap_body0]
%tag_activation_checkpoint : [num_users=7] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {})
%tag_activation_checkpoint : [num_users=1] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 0), kwargs = {})
%getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 1), kwargs = {})
%getitem_2 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 2), kwargs = {})
%getitem_3 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 3), kwargs = {})
%getitem_4 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 4), kwargs = {})
%getitem_5 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 5), kwargs = {})
%getitem_6 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 6), kwargs = {})
return (getitem,)""",
)
@ -1240,14 +1231,14 @@ graph():
"""\
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%arg4_1 : [num_users=2] = placeholder[target=arg4_1]
%linear : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {})
%relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {})
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%arg3_1 : [num_users=1] = placeholder[target=arg3_1]
%arg4_1 : [num_users=1] = placeholder[target=arg4_1]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {})
%linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %arg3_1, %arg4_1), kwargs = {})
return (linear_1, arg1_1, arg2_1, linear, relu, arg3_1, arg4_1)""",
return (linear_1,)""",
)
stack = contextlib.ExitStack()

View File

@ -2,6 +2,7 @@
import copy
import pathlib
import tempfile
import unittest
@ -96,55 +97,55 @@ def run_with_nativert(ep):
MODEL_NAME = "forward"
# TODO Does named tempfile have collision?
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
torch.export.pt2_archive._package.package_pt2(
f, exported_programs={MODEL_NAME: ep_infer}
)
filename = f.name
try:
ep_args, ep_kwargs = ep_infer.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
torch.manual_seed(0)
try:
ep_args, ep_kwargs = ep_infer.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
flat_expected = pytree.tree_leaves(
ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
)
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
model_runner = PyModelRunner(filename, MODEL_NAME)
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
model_runner = PyModelRunner(filename, MODEL_NAME)
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
else:
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(
expected, torch.Tensor
):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
else:
assert result == expected
except RuntimeError as e:
# User need to register pytree type on the cpp side, which
# cannot be tested in python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
return ep
assert result == expected
except RuntimeError as e:
# User need to register pytree type on the cpp side, which
# cannot be tested in python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
finally:
pathlib.Path(filename).unlink(missing_ok=True)
return ep
def mocked_nativert_export_strict(*args, **kwargs):
@ -286,7 +287,7 @@ class TestNativeRT(TestCase):
)
# package everything needed for the NativeRT to execute the AOTI delegate
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
package_nativert_with_aoti_delegate(
f,
MODEL_NAME,
@ -297,48 +298,50 @@ class TestNativeRT(TestCase):
)
filename = f.name
try:
ep_args, ep_kwargs = aoti_delegate_ep.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
torch.manual_seed(0)
try:
ep_args, ep_kwargs = aoti_delegate_ep.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
flat_expected = pytree.tree_leaves(
aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied)
)
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied)
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}")
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}")
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(
expected, torch.Tensor
):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
else:
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(
expected, torch.Tensor
):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
else:
assert result == expected
except RuntimeError as e:
# User need to register pytree type on the cpp side, which
# cannot be tested in python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
assert result == expected
except RuntimeError as e:
# User need to register pytree type on the cpp side, which
# cannot be tested in python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
finally:
pathlib.Path(filename).unlink(missing_ok=True)
if is_fbcode():

View File

@ -4,7 +4,6 @@ from unittest.mock import patch
import torch
from torch._dynamo.utils import counters
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
@ -40,56 +39,6 @@ class TestHopPrint(TestCase):
self.assertEqual(printed_output, "moo 1 2")
fx_f = make_fx(f)(x)
new_inp = torch.randn(3, 3)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
fx_f(new_inp)
ori_printed_output = mock_stdout.getvalue().strip()
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
f(new_inp)
fx_printed_output = mock_stdout.getvalue().strip()
self.assertEqual(ori_printed_output, fx_printed_output)
def test_print_with_proxy_graph(self):
class M(torch.nn.Module):
def forward(self, x):
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
torch._higher_order_ops.print("moo {x}", x=x)
res = x + x
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
torch._higher_order_ops.print("yeehop {x}", x=x.shape[0])
return (res,)
inputs = (torch.randn(3),)
# Without functionalization, print should just appear in the graph directly
gm = make_fx(M(), tracing_mode="symbolic")(*inputs)
self.assertExpectedInline(
str(gm.code).strip(),
"""\
def forward(self, arg0_1):
print_1 = torch.ops.higher_order.print('moo {x} {y}', x = 1, y = 2); print_1 = None
print_2 = torch.ops.higher_order.print('moo {x}', x = arg0_1); print_2 = None
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1)
print_3 = torch.ops.higher_order.print('moo {x} {y}', x = 1, y = 2); print_3 = None
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0); arg0_1 = None
print_4 = torch.ops.higher_order.print('yeehop {x}', x = sym_size_int); sym_size_int = print_4 = None
return (add,)""",
)
new_inp = torch.randn(4)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
gm(
new_inp,
)
printed_output = mock_stdout.getvalue().strip()
self.assertEqual(printed_output, f"moo 1 2\nmoo {new_inp}\nmoo 1 2\nyeehop 4")
if __name__ == "__main__":
run_tests()

View File

@ -206,10 +206,6 @@ class TestPyCodeCache(TestCase):
.decode()
.strip()
)
# XPU have extra lines, so get the last line, refer https://github.com/intel/torch-xpu-ops/issues/2261
if torch.xpu.is_available():
wrapper_path = wrapper_path.splitlines()[-1]
hit = hit.splitlines()[-1]
self.assertEqual(hit, "1")
with open(wrapper_path) as f:

View File

@ -73,23 +73,6 @@ class TestCompileWorker(TestCase):
finally:
pool.shutdown()
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_quiesce_repeatedly(self):
pool = SubprocPool(2)
try:
a = pool.submit(operator.add, 100, 1)
pool.quiesce()
pool.wakeup()
b = pool.submit(operator.sub, 100, 1)
pool.quiesce()
pool.quiesce()
pool.wakeup()
b = pool.submit(operator.sub, 100, 1)
self.assertEqual(a.result(), 101)
self.assertEqual(b.result(), 99)
finally:
pool.shutdown()
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_logging(self):
os.environ["MAST_HPC_JOB_NAME"] = "test_job"

View File

@ -5222,7 +5222,6 @@ xfail_by_backend = {
"test_reentrant_with_callbacks_both_depths", # queue_callback
"test_reentrant_with_callbacks_depth_0", # queue_callback
"test_reentrant_with_callbacks_depth_1", # queue_callback
"test_checkpoint_graph_execution_group", # Attempted to call function marked as skipped
"test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook
"test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
"test_post_accumulate_grad_hook_ordering", # accuracy error

View File

@ -3278,15 +3278,6 @@ class CPUReproTests(TestCase):
metrics.reset()
self.common(fn, (x,))
def test_softmax_with_zero_dim(self):
def fn(x):
x = torch.softmax(x, 0)
return x
x = torch.rand([], dtype=torch.bfloat16)
metrics.reset()
self.common(fn, (x,))
@config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False})
def test_local_buffer_in_outer_loop_fusion(self):
def fn(x):

View File

@ -148,24 +148,6 @@ class FxirTestCase(InductorTestCase):
args = [torch.randn(8, device=self.device) for _ in range(2)]
self._compile_and_check(torch.add, args)
def test_device_type(self):
"""
Test that we allocate on a device type instead of a specific index.
"""
# Pass in a tensor on an indexed device.
device_runtime = getattr(torch, self.device)
indexed_device = torch.device(self.device, device_runtime.current_device())
args = [torch.randn(8, device=indexed_device) for _ in range(2)]
(gm,) = self._compile_and_check(torch.add, args)
(empty_strided,) = gm.graph.find_nodes(
op="call_function", target=torch.empty_strided
)
# Check that the device of the output allocation is not indexed.
output_device = torch.device(empty_strided.kwargs["device"])
self.assertIs(output_device.index, None)
self.assertEqual(output_device.type, indexed_device.type)
def test_multiple_kernels(self):
def foo(x, y):
return x.sum() + y.sum()

View File

@ -3,8 +3,7 @@
import functools
import weakref
from collections import Counter
from collections.abc import Callable
from typing import Optional
from typing import Callable, Optional
import torch
from torch._inductor.fx_passes.memory_estimator import (
@ -29,7 +28,7 @@ def device_filter(device):
class FakeTensorMemoryProfilerMode(TorchDispatchMode):
def __init__(self, device_filter: Optional[Callable[[torch.device], bool]] = None):
def __init__(self, device_filter: Optional[Callable[torch.device, bool]] = None):
# counter of storage ids to live references
self.storage_count: dict[int, int] = Counter()
# live fake tensors

View File

@ -52,12 +52,9 @@ def make_pallas(cls):
return test_class
class PallasTestsMixin:
"""Basic tests for Pallas backend functionality (parameterized by DEVICE). Mixin only, not collected."""
def _compile(self, fn):
key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"
return torch.compile(fn, backend="inductor", options={key: "pallas"})
@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
class PallasTests(TestCase):
"""Basic tests for Pallas backend functionality."""
def test_simple_add(self):
"""Test basic element-wise addition."""
@ -65,10 +62,12 @@ class PallasTestsMixin:
def fn(a, b):
return a + b
compiled = self._compile(fn)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
a = torch.randn(1024, device=self.DEVICE)
b = torch.randn(1024, device=self.DEVICE)
a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")
result = compiled(a, b)
expected = fn(a, b)
self.assertEqual(result, expected)
@ -79,10 +78,12 @@ class PallasTestsMixin:
def fn(a, b):
return a * b
compiled = self._compile(fn)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
a = torch.randn(1024, device=self.DEVICE)
b = torch.randn(1024, device=self.DEVICE)
a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")
result = compiled(a, b)
expected = fn(a, b)
self.assertEqual(result, expected)
@ -93,9 +94,11 @@ class PallasTestsMixin:
def fn(x):
return torch.sin(x)
compiled = self._compile(fn)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
x = torch.randn(1024, device=self.DEVICE)
x = torch.randn(1024, device="cuda")
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@ -106,10 +109,12 @@ class PallasTestsMixin:
def fn(x, y):
return x.sin() + y
compiled = self._compile(fn)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
x = torch.randn(1024, device=self.DEVICE)
y = torch.randn(1024, device=self.DEVICE)
x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
result = compiled(x, y)
expected = fn(x, y)
self.assertEqual(result, expected)
@ -120,9 +125,11 @@ class PallasTestsMixin:
def fn(x):
return torch.log(torch.exp(x))
compiled = self._compile(fn)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
x = torch.randn(1024, device=self.DEVICE)
x = torch.randn(1024, device="cuda")
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@ -133,9 +140,11 @@ class PallasTestsMixin:
def fn(x):
return torch.sqrt(x)
compiled = self._compile(fn)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
x = torch.randn(1024, device=self.DEVICE).abs() # Ensure positive for sqrt
x = torch.randn(1024, device="cuda").abs() # Ensure positive for sqrt
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@ -146,9 +155,11 @@ class PallasTestsMixin:
def fn(x):
return torch.tanh(x)
compiled = self._compile(fn)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
x = torch.randn(1024, device=self.DEVICE)
x = torch.randn(1024, device="cuda")
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@ -159,9 +170,11 @@ class PallasTestsMixin:
def fn(x):
return torch.abs(-x)
compiled = self._compile(fn)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
x = torch.randn(1024, device=self.DEVICE)
x = torch.randn(1024, device="cuda")
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@ -172,10 +185,12 @@ class PallasTestsMixin:
def fn(a, b):
return torch.maximum(a, b) + torch.minimum(a, b)
compiled = self._compile(fn)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
a = torch.randn(1024, device=self.DEVICE)
b = torch.randn(1024, device=self.DEVICE)
a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")
result = compiled(a, b)
expected = fn(a, b)
self.assertEqual(result, expected)
@ -213,17 +228,15 @@ class PallasTestsMixin:
@torch.compile(
backend="inductor",
options={
("cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"): "pallas"
},
options={"cuda_backend": "pallas"},
)
def pallas_fn(a, b):
return a.sin() + b.cos()
_, (code,) = run_and_get_code(
pallas_fn,
torch.randn(64, device=self.DEVICE),
torch.randn(64, device=self.DEVICE),
torch.randn(64, device="cuda"),
torch.randn(64, device="cuda"),
)
# Verify Pallas-specific code generation
self.assertIn("import jax", code)
@ -236,10 +249,12 @@ class PallasTestsMixin:
def fn(x, y):
return x + y
compiled = self._compile(fn)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
x = torch.randn(32, 32, device=self.DEVICE)
y = torch.randn(32, 32, device=self.DEVICE)
x = torch.randn(32, 32, device="cuda")
y = torch.randn(32, 32, device="cuda")
result = compiled(x, y)
expected = fn(x, y)
self.assertEqual(result, expected)
@ -250,10 +265,12 @@ class PallasTestsMixin:
def fn(x):
return x * 2.0
compiled = self._compile(fn)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
for shape in [(64,), (128,), (256,), (1024,)]:
x = torch.randn(shape, device=self.DEVICE)
x = torch.randn(shape, device="cuda")
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@ -265,10 +282,12 @@ class PallasTestsMixin:
def contiguous_add(a, b):
return a + b
compiled = self._compile(contiguous_add)
compiled = torch.compile(
contiguous_add, backend="inductor", options={"cuda_backend": "pallas"}
)
a = torch.randn(1024, device=self.DEVICE)
b = torch.randn(1024, device=self.DEVICE)
a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")
result = compiled(a, b)
expected = contiguous_add(a, b)
self.assertEqual(result, expected)
@ -277,9 +296,11 @@ class PallasTestsMixin:
def contiguous_mul(x):
return x * 2.0
compiled = self._compile(contiguous_mul)
compiled = torch.compile(
contiguous_mul, backend="inductor", options={"cuda_backend": "pallas"}
)
x = torch.randn(128, 8, device=self.DEVICE)
x = torch.randn(128, 8, device="cuda")
result = compiled(x)
expected = contiguous_mul(x)
self.assertEqual(result, expected)
@ -289,10 +310,12 @@ class PallasTestsMixin:
def operate_on_tensor(x):
return x.sin()
compiled = self._compile(operate_on_tensor)
compiled = torch.compile(
operate_on_tensor, backend="inductor", options={"cuda_backend": "pallas"}
)
# Create a transposed (non-contiguous) view
x = torch.randn(64, 32, device=self.DEVICE)
x = torch.randn(64, 32, device="cuda")
x_t = x.t() # Non-contiguous view
self.assertFalse(x_t.is_contiguous())
@ -309,24 +332,13 @@ class PallasTestsMixin:
self.assertEqual(result, expected)
@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
class PallasTestsCUDA(PallasTestsMixin, TestCase):
DEVICE = "cuda"
@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
class PallasTestsCPU(PallasTestsMixin, TestCase):
DEVICE = "cpu"
# Create test variants using the main test suite
# Note: Only enable GPU tests since Pallas primarily targets GPU
if hasattr(sys.modules.get(__name__), "test_torchinductor") and HAS_PALLAS:
if getattr(test_torchinductor, "HAS_GPU", False):
# Uncomment these to run full test suite with Pallas backend
# make_pallas(test_torchinductor.SweepInputsGPUTest)
# make_pallas(test_torchinductor.GPUTests)
pass
if test_torchinductor.HAS_GPU and HAS_PALLAS:
# Uncomment these to run full test suite with Pallas backend
# make_pallas(test_torchinductor.SweepInputsGPUTest)
# make_pallas(test_torchinductor.GPUTests)
pass
if __name__ == "__main__":
if HAS_PALLAS:

View File

@ -1217,43 +1217,6 @@ class TestPatternMatcher(TestCase):
_, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_addmm_alpha_beta_with_pointwise(self):
# Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops
# See https://github.com/pytorch/pytorch/issues/167313
x = torch.rand(2, device=GPU_TYPE)
a = torch.rand(2, 3, device=GPU_TYPE)
b = torch.rand(3, 2, device=GPU_TYPE)
def f(x, a, b):
return torch.nn.functional.relu(torch.addmm(x, a, b, alpha=0.8, beta=0.2))
fc = torch.compile(f)
expected = f(x, a, b)
actual = fc(x, a, b)
# The compiled version should produce the same result as eager
torch.testing.assert_close(actual, expected)
# Verify that addmm is unfused (should not use extern_kernels.addmm)
# The pattern should be replaced with beta * x + alpha * (a @ b)
_, (code) = run_and_get_code(fc, x, a, b)
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
# Test with alpha=1, beta=1 (default) - should also unfuse
def f_default(x, a, b):
return torch.nn.functional.relu(torch.addmm(x, a, b))
fc_default = torch.compile(f_default)
expected_default = f_default(x, a, b)
actual_default = fc_default(x, a, b)
torch.testing.assert_close(actual_default, expected_default)
# Should unfuse and not use extern_kernels.addmm
_, (code) = run_and_get_code(fc_default, x, a, b)
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_serialized_patterns_up_to_date(self):
import torch.utils._pytree as pytree
from torch._inductor.fx_passes import joint_graph

View File

@ -7,7 +7,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
class TestUpgraderModelGeneration(TestCase):
def test_all_modules(self):
for a_module in ALL_MODULES:
for a_module in ALL_MODULES.keys():
module_name = type(a_module).__name__
self.assertTrue(
isinstance(a_module, torch.nn.Module),

View File

@ -2979,7 +2979,7 @@ class TestScriptList(JitTestCase):
self.col2 = "b"
def forward(self):
if self.col1 in self.segments_groupby_col:
if self.col1 in self.segments_groupby_col.keys():
return 1
else:
return 2

View File

@ -78,7 +78,7 @@ class TestModuleContainers(JitTestCase):
x = mod(x)
values.append(x)
for key in self.moduledict:
for key in self.moduledict.keys():
names.append(key)
return x, names
@ -306,7 +306,7 @@ class TestModuleContainers(JitTestCase):
assert "submod" in self.moduledict, "__contains__ fails for ModuleDict"
for key in self.moduledict:
for key in self.moduledict.keys():
assert key == "submod", "keys() fails for ModuleDict"
for item in self.moduledict.items():

View File

@ -276,7 +276,7 @@ class TestPDT(JitTestCase):
def test_multiple_class_with_same_method(self):
class PDTModelOne:
def test_find(self, a, b):
return b in a
return b in a.keys()
class PDTModelTwo:
def test_find(self, a, b):

View File

@ -342,7 +342,7 @@ class TestTyping(JitTestCase):
# type: (Dict[str, int]) -> Tuple[str, int]
key_str = ""
sum = 0
for key in x:
for key in x.keys():
key_str += key
for val in x.values():
sum += val

View File

@ -310,7 +310,7 @@ class TestLoadStateDict(NNTestCase):
# Make sure parameters and persistent buffers were assigned
net_meta_state_dict = net_meta.state_dict(keep_vars=True)
for key in state_dict:
for key in state_dict.keys():
if key in net_meta._parameters:
if keep_vars and not swap:
# state_dict[key] is an nn.Parameter

Some files were not shown because too many files have changed in this diff Show More