Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-24 15:44:58 +08:00)

Compare commits: codex-test ... v1.11.0-rc (13 commits)
| SHA1 |
|---|
| 503a0923d3 |
| 6641e9b75f |
| 4f9f0e7a13 |
| 0ea924fc98 |
| 5a78725c29 |
| f72151b900 |
| 8380187819 |
| 7cc129e60c |
| ff6c348762 |
| 03a283b2b1 |
| 614e765575 |
| 7b0e140ecc |
| 3fab33e1c9 |
@@ -120,9 +120,9 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
   # Install numba only on python-3.8 or below
   # For numba issue see https://github.com/pytorch/pytorch/issues/51511
   if [[ $(python -c "import sys; print(int(sys.version_info < (3, 9)))") == "1" ]]; then
-    as_jenkins pip install --progress-bar off numba==0.54.1 librosa>=0.6.2
+    as_jenkins pip install --progress-bar off numba==0.54.1 "librosa>=0.6.2,<0.9.0"
   else
-    as_jenkins pip install --progress-bar off numba==0.49.0 librosa>=0.6.2
+    as_jenkins pip install --progress-bar off numba==0.49.0 "librosa>=0.6.2,<0.9.0"
   fi

   # Update scikit-learn to a python-3.8 compatible version
@@ -61,7 +61,7 @@ git --no-pager log --max-count 1
 popd

 # Clone the Builder master repo
-retry git clone -q https://github.com/pytorch/builder.git "$BUILDER_ROOT"
+retry git clone -q https://github.com/pytorch/builder.git -b release/1.11 "$BUILDER_ROOT"
 pushd "$BUILDER_ROOT"
 echo "Using builder from "
 git --no-pager log --max-count 1
@@ -33,7 +33,7 @@ fi
 cp ${PROJ_ROOT}/LICENSE ${ZIP_DIR}/
 # zip the library
 export DATE="$(date -u +%Y%m%d)"
-export IOS_NIGHTLY_BUILD_VERSION="1.12.0.${DATE}"
+export IOS_NIGHTLY_BUILD_VERSION="1.11.0.${DATE}"
 if [ "${BUILD_LITE_INTERPRETER}" == "1" ]; then
     # libtorch_lite_ios_nightly_1.11.0.20210810.zip
     ZIPFILE="libtorch_lite_ios_nightly_${IOS_NIGHTLY_BUILD_VERSION}.zip"
@@ -5,7 +5,7 @@ export TZ=UTC
 tagged_version() {
   # Grabs version from either the env variable CIRCLE_TAG
   # or the pytorch git described version
-  if [[ "$OSTYPE" == "msys" ]]; then
+  if [[ "$OSTYPE" == "msys" && -z "${IS_GHA:-}" ]]; then
     GIT_DIR="${workdir}/p/.git"
   else
     GIT_DIR="${workdir}/pytorch/.git"
@@ -13,6 +13,9 @@ tagged_version() {
   GIT_DESCRIBE="git --git-dir ${GIT_DIR} describe --tags --match v[0-9]*.[0-9]*.[0-9]*"
   if [[ -n "${CIRCLE_TAG:-}" ]]; then
     echo "${CIRCLE_TAG}"
+  elif [[ ! -d "${GIT_DIR}" ]]; then
+    echo "Abort, abort! Git dir ${GIT_DIR} does not exists!"
+    kill $$
   elif ${GIT_DESCRIBE} --exact >/dev/null; then
     ${GIT_DESCRIBE}
   else
@@ -58,7 +61,12 @@ if [[ -z ${IS_GHA:-} ]]; then
   fi
 else
   envfile=${BINARY_ENV_FILE:-/tmp/env}
-  workdir="/pytorch"
+  if [[ -n "${PYTORCH_ROOT}" ]]; then
+    workdir=$(dirname "${PYTORCH_ROOT}")
+  else
+    # docker executor (binary builds)
+    workdir="/"
+  fi
 fi

 if [[ "$PACKAGE_TYPE" == 'libtorch' ]]; then
@@ -94,7 +102,7 @@ PIP_UPLOAD_FOLDER='nightly/'
 # We put this here so that OVERRIDE_PACKAGE_VERSION below can read from it
 export DATE="$(date -u +%Y%m%d)"
 #TODO: We should be pulling semver version from the base version.txt
-BASE_BUILD_VERSION="1.12.0.dev$DATE"
+BASE_BUILD_VERSION="1.11.0.dev$DATE"
 # Change BASE_BUILD_VERSION to git tag when on a git tag
 # Use 'git -C' to make doubly sure we're in the correct directory for checking
 # the git tag
@@ -157,7 +165,7 @@ if [[ "${BUILD_FOR_SYSTEM:-}" == "windows" ]]; then
 fi

 export DATE="$DATE"
-export NIGHTLIES_DATE_PREAMBLE=1.12.0.dev
+export NIGHTLIES_DATE_PREAMBLE=1.11.0.dev
 export PYTORCH_BUILD_VERSION="$PYTORCH_BUILD_VERSION"
 export PYTORCH_BUILD_NUMBER="$PYTORCH_BUILD_NUMBER"
 export OVERRIDE_PACKAGE_VERSION="$PYTORCH_BUILD_VERSION"
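For reference, the hunks above pin the release preamble to 1.11.0: nightly package versions are simply that preamble plus a UTC date stamp. A minimal Python sketch of the same assembly (illustrative names only, not part of the script):

```python
# Illustrative sketch: how BASE_BUILD_VERSION / NIGHTLIES_DATE_PREAMBLE combine
# a fixed "1.11.0.dev" preamble with a UTC date stamp, mirroring $(date -u +%Y%m%d).
from datetime import datetime, timezone

date_stamp = datetime.now(timezone.utc).strftime("%Y%m%d")
base_build_version = f"1.11.0.dev{date_stamp}"
print(base_build_version)  # e.g. 1.11.0.dev20220128
```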
.github/templates/linux_ci_workflow.yml.j2 (vendored, 21 changed lines)
@@ -11,8 +11,14 @@ on:
   pull_request:
 {%- endif %}
   push:
+{%- if enable_doc_jobs and is_scheduled %}
+    tags:
+      # NOTE: Binary build pipelines should only get triggered on release candidate builds
+      # Release candidate tags look like: v1.11.0-rc1
+      - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
+{%- endif %}
 {%- for label in ciflow_config.labels | sort %}
-{%- if loop.first %}
+{%- if loop.first and not (enable_doc_jobs and is_scheduled) %}
     tags:
 {%- endif %}
 {%- if label != "ciflow/default" %}
@@ -364,7 +370,7 @@ jobs:
     env:
       DOCKER_IMAGE: ${{ needs.build.outputs.docker_image }}
       DOCS_TYPE: ${{ matrix.docs_type }}
-      WITH_PUSH: ${{ github.event_name == 'schedule' }}
+      WITH_PUSH: ${{ github.event_name == 'schedule' || startsWith(github.event.ref, 'refs/tags/v') }}
     steps:
     !{{ common.setup_ec2_linux() }}
     !{{ common.checkout() }}
@@ -381,7 +387,7 @@ jobs:
           unzip -o artifacts.zip
 {%- if is_scheduled %}
       - name: Generate netrc (only for docs-push)
-        if: ${{ github.event_name == 'schedule' }}
+        if: ${{ github.event_name == 'schedule' || startsWith(github.event.ref, 'refs/tags/v') }}
         env:
           GITHUB_PYTORCHBOT_TOKEN: ${{ secrets.GH_PYTORCHBOT_TOKEN }}
         run: |
@@ -394,9 +400,12 @@ jobs:
         run: |
           set -ex
           time docker pull "${DOCKER_IMAGE}" > /dev/null
-          echo "${GITHUB_REF}"
-          # TODO: Set it correctly when workflows are scheduled on tags
-          target="master"
+          # Convert refs/tags/v1.12.0rc3 into 1.12
+          if [[ "${GITHUB_REF}" =~ ^refs/tags/v([0-9]+\.[0-9]+)\.* ]]; then
+            target="${BASH_REMATCH[1]}"
+          else
+            target="master"
+          fi
           # detached container should get cleaned up by teardown_ec2_linux
           container_name=$(docker run \
             -e BUILD_ENVIRONMENT \
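The step above maps a release tag ref to its major.minor docs target and leaves everything else on master. A minimal sketch of the same mapping in Python (the regex is copied from the diff; the function name is only illustrative):

```python
import re

def docs_target(github_ref: str) -> str:
    # Mirrors the bash BASH_REMATCH logic above:
    # refs/tags/v1.11.0-rc1 -> "1.11", anything else -> "master".
    match = re.match(r"^refs/tags/v(\d+\.\d+)\.*", github_ref)
    return match.group(1) if match else "master"

assert docs_target("refs/tags/v1.11.0-rc1") == "1.11"
assert docs_target("refs/heads/master") == "master"
```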
.github/workflows/generated-linux-docs-push.yml (generated, vendored, 16 changed lines)
@@ -6,6 +6,9 @@ name: linux-docs-push
 on:
   push:
     tags:
+      # NOTE: Binary build pipelines should only get triggered on release candidate builds
+      # Release candidate tags look like: v1.11.0-rc1
+      - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
       - 'ciflow/all/*'
       - 'ciflow/cpu/*'
       - 'ciflow/linux/*'
@@ -255,7 +258,7 @@ jobs:
     env:
       DOCKER_IMAGE: ${{ needs.build.outputs.docker_image }}
       DOCS_TYPE: ${{ matrix.docs_type }}
-      WITH_PUSH: ${{ github.event_name == 'schedule' }}
+      WITH_PUSH: ${{ github.event_name == 'schedule' || startsWith(github.event.ref, 'refs/tags/v') }}
     steps:
       - name: Display EC2 information
         shell: bash
@@ -324,7 +327,7 @@ jobs:
         run: |
           unzip -o artifacts.zip
       - name: Generate netrc (only for docs-push)
-        if: ${{ github.event_name == 'schedule' }}
+        if: ${{ github.event_name == 'schedule' || startsWith(github.event.ref, 'refs/tags/v') }}
        env:
          GITHUB_PYTORCHBOT_TOKEN: ${{ secrets.GH_PYTORCHBOT_TOKEN }}
        run: |
@@ -336,9 +339,12 @@ jobs:
        run: |
          set -ex
          time docker pull "${DOCKER_IMAGE}" > /dev/null
-          echo "${GITHUB_REF}"
-          # TODO: Set it correctly when workflows are scheduled on tags
-          target="master"
+          # Convert refs/tags/v1.12.0rc3 into 1.12
+          if [[ "${GITHUB_REF}" =~ ^refs/tags/v([0-9]+\.[0-9]+)\.* ]]; then
+            target="${BASH_REMATCH[1]}"
+          else
+            target="master"
+          fi
          # detached container should get cleaned up by teardown_ec2_linux
          container_name=$(docker run \
            -e BUILD_ENVIRONMENT \
.github/workflows/generated-linux-docs.yml (generated, vendored, 11 changed lines)
@@ -258,7 +258,7 @@ jobs:
     env:
       DOCKER_IMAGE: ${{ needs.build.outputs.docker_image }}
       DOCS_TYPE: ${{ matrix.docs_type }}
-      WITH_PUSH: ${{ github.event_name == 'schedule' }}
+      WITH_PUSH: ${{ github.event_name == 'schedule' || startsWith(github.event.ref, 'refs/tags/v') }}
     steps:
       - name: Display EC2 information
         shell: bash
@@ -330,9 +330,12 @@ jobs:
        run: |
          set -ex
          time docker pull "${DOCKER_IMAGE}" > /dev/null
-          echo "${GITHUB_REF}"
-          # TODO: Set it correctly when workflows are scheduled on tags
-          target="master"
+          # Convert refs/tags/v1.12.0rc3 into 1.12
+          if [[ "${GITHUB_REF}" =~ ^refs/tags/v([0-9]+\.[0-9]+)\.* ]]; then
+            target="${BASH_REMATCH[1]}"
+          else
+            target="master"
+          fi
          # detached container should get cleaned up by teardown_ec2_linux
          container_name=$(docker run \
            -e BUILD_ENVIRONMENT \
@@ -100,6 +100,6 @@ function checkout_install_torchvision() {

 function clone_pytorch_xla() {
   if [[ ! -d ./xla ]]; then
-    git clone --recursive https://github.com/pytorch/xla.git
+    git clone --recursive -b release/1.11 https://github.com/pytorch/xla.git
   fi
 }
@@ -7,7 +7,7 @@ source "$(dirname "${BASH_SOURCE[0]}")/macos-common.sh"
 export PYTORCH_TEST_SKIP_NOARCH=1

 conda install -y six
-pip install -q hypothesis "expecttest==0.1.3" "librosa>=0.6.2" "numba<=0.49.1" psutil "scipy==1.6.3"
+pip install -q hypothesis "expecttest==0.1.3" "librosa>=0.6.2,<0.9.0" "numba<=0.49.1" psutil "scipy==1.6.3"

 # TODO move this to docker
 pip install unittest-xml-reporting pytest
@@ -445,7 +445,7 @@ test_forward_backward_compatibility() {
   python -m venv venv
   # shellcheck disable=SC1091
   . venv/bin/activate
-  pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+  pip_install --pre torch -f https://download.pytorch.org/whl/test/cpu/torch_test.html
   pip show torch
   python dump_all_function_schemas.py --filename nightly_schemas.txt
   deactivate
@@ -34,7 +34,7 @@ popd

 :: The version is fixed to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
 =======
-pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "expecttest==0.1.3" "librosa>=0.6.2" psutil pillow unittest-xml-reporting pytest
+pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "expecttest==0.1.3" "librosa>=0.6.2,<0.9.0" psutil pillow unittest-xml-reporting pytest
 if errorlevel 1 exit /b
 if not errorlevel 0 exit /b

@@ -34,8 +34,8 @@ repositories {

 dependencies {
     ...
-    implementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT'
-    implementation 'org.pytorch:pytorch_android_torchvision:1.12.0-SNAPSHOT'
+    implementation 'org.pytorch:pytorch_android:1.11.0-SNAPSHOT'
+    implementation 'org.pytorch:pytorch_android_torchvision:1.11.0-SNAPSHOT'
     ...
 }
 ```
@@ -1,6 +1,6 @@
 ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64

-VERSION_NAME=1.12.0-SNAPSHOT
+VERSION_NAME=1.11.0-SNAPSHOT
 GROUP=org.pytorch
 MAVEN_GROUP=org.pytorch
 SONATYPE_STAGING_PROFILE=orgpytorch
@@ -149,8 +149,8 @@ dependencies {
     //nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
     //extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')

-    nightlyImplementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT'
-    nightlyImplementation 'org.pytorch:pytorch_android_torchvision:1.12.0-SNAPSHOT'
+    nightlyImplementation 'org.pytorch:pytorch_android:1.11.0-SNAPSHOT'
+    nightlyImplementation 'org.pytorch:pytorch_android_torchvision:1.11.0-SNAPSHOT'

     aarImplementation(name:'pytorch_android', ext:'aar')
     aarImplementation(name:'pytorch_android_torchvision', ext:'aar')
@@ -69,8 +69,8 @@ inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) {
 // and returns back a tensor with corresponding complex dtype
 Tensor view_as_complex(const Tensor& self) {
   TORCH_CHECK(
-    self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf,
-    "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type());
+    self.scalar_type() == kFloat || self.scalar_type() == kDouble,
+    "view_as_complex is only supported for float and double tensors, but got a tensor of scalar type: ", self.scalar_type());

   auto old_sizes = self.sizes();
   TORCH_CHECK(old_sizes.size() != 0, "Input tensor must have one or more dimensions");
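The hunk above drops half-precision inputs from view_as_complex for this release. Assuming a build that includes the change, the user-visible effect is roughly the following (a sketch, not an authoritative test):

```python
import torch

x = torch.randn(4, 2)              # float32, last dimension of size 2
z = torch.view_as_complex(x)       # still supported -> complex64 view of x
print(z.dtype)                     # torch.complex64

try:
    torch.view_as_complex(torch.randn(4, 2, dtype=torch.float16))
except RuntimeError as err:
    # With this change, half inputs fail the TORCH_CHECK above.
    print(err)
```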
@@ -314,6 +314,16 @@ static inline void singleCheckErrors(int64_t info, const c10::string_view name,
     batch_string = ": (Batch element " + std::to_string(batch_id) + ")";
   }
   if (info < 0) {
+    // Reference LAPACK 3.10+ changed `info` behavior for inputs with non-finite values
+    // Previously, it would return `info` > 0, but now it returns `info` = -4
+    // OpenBLAS 0.3.15+ uses the Reference LAPACK 3.10+.
+    // MKL 2022.0+ uses the Reference LAPACK 3.10+.
+    // Older version of MKL and OpenBLAS follow the old behavior (return `info` > 0).
+    // Here we check for the case where `info` is -4 and raise an error
+    if (name.find("svd") != name.npos) {
+      TORCH_CHECK_LINALG(info != -4, name, batch_string,
+          ": The algorithm failed to converge because the input matrix contained non-finite values.");
+    }
     TORCH_INTERNAL_ASSERT(false, name, batch_string,
         ": Argument ", -info, " has illegal value. Most certainly there is a bug in the implementation calling the backend library.");
   } else if (info > 0) {
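Together with the test_svd_nan_error test added further down, the practical effect of this hunk is that an SVD on a matrix with non-finite entries surfaces as a convergence error instead of an internal assert about an illegal argument. A rough illustration, assuming a PyTorch build with this change and a LAPACK 3.10+ backend:

```python
import torch

a = torch.full((3, 3), float("nan"))
try:
    torch.linalg.svd(a)
except torch.linalg.LinAlgError as err:
    # e.g. "...The algorithm failed to converge because the input matrix contained non-finite values."
    print(err)
```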
@@ -102,11 +102,7 @@ if(CUDA_FOUND)
   endif()

   # Find cuDNN.
-  if(CAFFE2_STATIC_LINK_CUDA AND NOT USE_STATIC_CUDNN)
-    message(WARNING "cuDNN will be linked statically because CAFFE2_STATIC_LINK_CUDA is ON. "
-      "Set USE_STATIC_CUDNN to ON to suppress this warning.")
-  endif()
-  if(CAFFE2_STATIC_LINK_CUDA OR USE_STATIC_CUDNN)
+  if(USE_STATIC_CUDNN)
     set(CUDNN_STATIC ON CACHE BOOL "")
   else()
     set(CUDNN_STATIC OFF CACHE BOOL "")
@@ -590,6 +590,7 @@ Torch functions specific to sparse Tensors
    sparse_csr_tensor
    sparse.sum
    sparse.addmm
+   sparse.sampled_addmm
    sparse.mm
    sspaddmm
    hspmm
@@ -21,7 +21,6 @@ Data type dtype
 64-bit floating point      ``torch.float64`` or ``torch.double``     :class:`torch.DoubleTensor`    :class:`torch.cuda.DoubleTensor`
 16-bit floating point [1]_ ``torch.float16`` or ``torch.half``       :class:`torch.HalfTensor`      :class:`torch.cuda.HalfTensor`
 16-bit floating point [2]_ ``torch.bfloat16``                        :class:`torch.BFloat16Tensor`  :class:`torch.cuda.BFloat16Tensor`
-32-bit complex             ``torch.complex32``
 64-bit complex             ``torch.complex64``
 128-bit complex            ``torch.complex128`` or ``torch.cdouble``
 8-bit integer (unsigned)   ``torch.uint8``                           :class:`torch.ByteTensor`      :class:`torch.cuda.ByteTensor`
@@ -610,5 +610,4 @@ Utilities
     get_deterministic_debug_mode
     set_warn_always
     is_warn_always_enabled
-    vmap
     _assert
@@ -5,7 +5,7 @@ from typing import Tuple
 from unittest.case import expectedFailure

 import torch
-from torch import complex32, float32, float64, int32, int64
+from torch import float32, int32, int64
 from torch.jit._passes import _property_propagation
 from torch.testing._internal.common_methods_invocations import (
     SampleInput,
@@ -191,7 +191,6 @@ class TestDtypeAnalysis(TestDtypeBase):
         input_dtypes = [
             (float32,),  # Simple Case
             (int64,),  # Test how some unary ops implicitly convert to float
-            (complex32,),  # Show we can handle complex vals as well
         ]

         for fn, in_shapes, in_dtypes in product(functions, input_shapes, input_dtypes):
@@ -220,7 +219,6 @@ class TestDtypeAnalysis(TestDtypeBase):
             (int32, int64),  # Size Promotion (compliated case for 0dim tensors)
             (float32, int32),  # type Promotion
             (int64, float32),  # Type promotion with size change
-            (float64, complex32),  # Show we can handle complex vals as well
         ]

         for fn, in_shapes, in_dtypes in product(functions, input_shapes, input_dtypes):
@@ -13,13 +13,11 @@ def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):

     for opset_version in opset_versions:
         self.opset_version = opset_version
+        self.onnx_shape_inference = True
         run_model_test(self, model, False,
                        input=inputs, rtol=rtol, atol=atol)

         if self.is_script_test_enabled and opset_version > 11:
-            TestModels.onnx_shape_inference = True
-
-            outputs = model(inputs)
             script_model = torch.jit.script(model)
             run_model_test(self, script_model, False,
                            input=inputs, rtol=rtol, atol=atol)
@@ -1835,7 +1835,6 @@ class TestTEFuser(JitTestCase):

         unsupported_dtypes = [
             torch.uint8,
-            torch.complex32,
             torch.complex64,
             torch.complex128,
             torch.qint8,
@@ -1121,8 +1121,6 @@ class TestLinalg(TestCase):
                 return torch.float
             elif dtype == torch.cdouble:
                 return torch.double
-            elif dtype == torch.complex32:
-                return torch.float16
             else:
                 return dtype

@@ -1685,11 +1683,10 @@ class TestLinalg(TestCase):
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     def test_norm_extreme_values(self, device):
-        if torch.device(device).type == 'cpu':
-            self.skipTest("Test broken on cpu (see gh-71645)")
-
         vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
-        matrix_ords = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf]
+        # matrix_ords 'nuc', 2, -2 are skipped currently
+        # See issue https://github.com/pytorch/pytorch/issues/71911
+        matrix_ords = ['fro', 1, inf, -1, -inf]
         vectors = []
         matrices = []
         for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
@@ -1727,8 +1724,8 @@ class TestLinalg(TestCase):
                 if is_broken_matrix_norm_case(ord, x):
                     continue
                 else:
-                    result = torch.linalg.norm(x, ord=ord)
                     result_n = np.linalg.norm(x_n, ord=ord)
+                    result = torch.linalg.norm(x, ord=ord)
                     self.assertEqual(result, result_n, msg=msg)

         # Test degenerate shape results match numpy for linalg.norm vector norms
@@ -2651,6 +2648,28 @@ class TestLinalg(TestCase):
         result = torch.linalg.svd(a, full_matrices=False)
         self.assertEqual(result.S, S)

+    # This test doesn't work with MAGMA backend https://github.com/pytorch/pytorch/issues/72106
+    @skipMeta
+    @skipCUDAIfRocm
+    @skipCUDAIfNoCusolver
+    @skipCPUIfNoLapack
+    @dtypes(*floating_and_complex_types())
+    def test_svd_nan_error(self, device, dtype):
+        for svd in [torch.svd, torch.linalg.svd]:
+            # if input contains NaN then an error is triggered for svd
+            # When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
+            # When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
+            error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|The algorithm failed to converge)'
+            a = torch.full((3, 3), float('nan'), dtype=dtype, device=device)
+            a[0] = float('nan')
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
+                svd(a)
+            error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|\(Batch element 1\): The algorithm failed to converge)'
+            a = torch.randn(3, 33, 33, dtype=dtype, device=device)
+            a[1, 0, 0] = float('nan')
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
+                svd(a)
+
     def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype):
         from torch.testing._internal.common_utils import random_hermitian_pd_matrix

@@ -2490,7 +2490,7 @@ class TestReductions(TestCase):
                 return
             self.fail("Failed to hit RuntimeError!")

-        exact_dtype = input.dtype not in (torch.bfloat16, torch.complex32, torch.complex64, torch.complex128)
+        exact_dtype = input.dtype not in (torch.bfloat16, torch.complex64, torch.complex128)
         self.assertEqual(torch_result, numpy_result, exact_dtype=exact_dtype)

     @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
@@ -727,7 +727,6 @@ class TestFFT(TestCase):
     # Legacy fft tests
     def _test_fft_ifft_rfft_irfft(self, device, dtype):
         complex_dtype = {
-            torch.float16: torch.complex32,
             torch.float32: torch.complex64,
             torch.float64: torch.complex128
         }[dtype]
@@ -1630,6 +1630,10 @@ else:
         self.assertEqual(a_with_output.dtype, y.dtype)
         self.assertEqual(a_with_output.size(), torch.Size([3, 2]))

+    def test_unsupported_complex_type(self, device):
+        with self.assertRaisesRegex(AttributeError, r'module \'torch\' has no attribute \'complex32\''):
+            torch.tensor(1j, dtype=torch.complex32, device=device)
+
     @dtypes(*get_all_fp_dtypes(include_half=False, include_bfloat16=False))
     @dtypesIfCPU(*(get_all_fp_dtypes(include_half=False, include_bfloat16=True)))
     @dtypesIfCUDA(*(get_all_fp_dtypes(include_bfloat16=False)))
@@ -302,7 +302,7 @@ class TestViewOps(TestCase):
         self.assertEqual(res.shape, torch.Size([0]))

     @onlyNativeDeviceTypes
-    @dtypes(*get_all_complex_dtypes(include_complex32=True))
+    @dtypes(*get_all_complex_dtypes())
     def test_view_as_real(self, device, dtype):
         def fn(contiguous_input=True):
             t = torch.randn(3, 4, dtype=dtype, device=device)
@@ -310,11 +310,7 @@ class TestViewOps(TestCase):
             res = torch.view_as_real(input)
             self.assertEqual(res[:, :, 0], input.real)
             self.assertEqual(res[:, :, 1], input.imag)
-            # TODO: Add torch.ComplexHalfStorage
-            if dtype != torch.complex32:
-                self.assertTrue(self.is_view_of(t, res))
-            else:
-                self.assertRaises(RuntimeError, lambda: self.is_view_of(t, res))
+            self.assertTrue(self.is_view_of(t, res))

         fn()
         fn(contiguous_input=False)
@@ -322,21 +318,12 @@ class TestViewOps(TestCase):
         # tensor with zero elements
         x = torch.tensor([], dtype=dtype, device=device)
         res = torch.view_as_real(x)
-        # TODO: Add torch.ComplexHalfStorage
-        if dtype != torch.complex32:
-            self.assertTrue(self.is_view_of(x, res))
-        else:
-            self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res))
-        self.assertEqual(res.shape, torch.Size([0, 2]))
+        self.assertTrue(self.is_view_of(x, res))

         # tensor with zero dim
         x = torch.tensor(2 + 3j, dtype=dtype, device=device)
         res = torch.view_as_real(x)
-        # TODO: Add torch.ComplexHalfStorage
-        if dtype != torch.complex32:
-            self.assertTrue(self.is_view_of(x, res))
-        else:
-            self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res))
+        self.assertTrue(self.is_view_of(x, res))
         self.assertEqual(res.shape, torch.Size([2]))

     @onlyNativeDeviceTypes
@@ -3,7 +3,8 @@
 from torch.testing._internal.common_utils import TestCase, run_tests
 import torch
 import torch.nn.functional as F
-from torch import Tensor, vmap
+from torch import Tensor
+from torch._vmap_internals import vmap
 import functools
 import itertools
 import warnings
@@ -577,7 +577,7 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
              for n in
              ['float32', 'float', 'float64', 'double', 'float16', 'bfloat16', 'half',
               'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
-              'complex32', 'complex64', 'cfloat', 'complex128', 'cdouble',
+              'complex64', 'cfloat', 'complex128', 'cdouble',
               'quint8', 'qint8', 'qint32', 'bool', 'quint4x2', 'quint2x4']]

     # Generate __all__ directive
@@ -856,8 +856,6 @@ del register_after_fork
 # torch.jit.script as a decorator, for instance):
 from ._lobpcg import lobpcg as lobpcg

-from ._vmap_internals import vmap as vmap
-
 # These were previously defined in native_functions.yaml and appeared on the
 # `torch` namespace, but we moved them to c10 dispatch to facilitate custom
 # class usage. We add these lines here to preserve backward compatibility.
@@ -73,6 +73,8 @@ void initializeDtypes() {
     std::string primary_name, legacy_name;
     std::tie(primary_name, legacy_name) = getDtypeNames(scalarType);
     PyObject *dtype = THPDtype_New(scalarType, primary_name);
+    // disable complex32 dtype
+    if (primary_name == "complex32") continue;
     torch::registerDtypeObject((THPDtype*)dtype, scalarType);
     Py_INCREF(dtype);
     if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) !=
@@ -944,7 +944,7 @@ scalar_type_to_pytorch_type = [
     torch.half,  # 5
     torch.float,  # 6
     torch.double,  # 7
-    torch.complex32,  # 8
+    None,  # 8
     torch.complex64,  # 9
     torch.complex128,  # 10
     torch.bool,  # 11
@@ -98,7 +98,9 @@ Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2`
 specified by the sparsity pattern of :attr:`input`. The matrix :attr:`input` is added to the final result.

 Mathematically this performs the following operation:
+
 .. math::
+
     \text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input}

 where :math:`\text{spy}(\text{input})` is the sparsity pattern matrix of :attr:`input`, :attr:`alpha`
@@ -45,7 +45,6 @@ _DTYPE_PRECISIONS = {
     torch.bfloat16: (0.016, 1e-5),
     torch.float32: (1.3e-6, 1e-5),
     torch.float64: (1e-7, 1e-7),
-    torch.complex32: (0.001, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
 }
@@ -1185,8 +1184,6 @@ def assert_close(
     +---------------------------+------------+----------+
     | :attr:`~torch.float64`    | ``1e-7``   | ``1e-7`` |
     +---------------------------+------------+----------+
-    | :attr:`~torch.complex32`  | ``1e-3``   | ``1e-5`` |
-    +---------------------------+------------+----------+
     | :attr:`~torch.complex64`  | ``1.3e-6`` | ``1e-5`` |
     +---------------------------+------------+----------+
     | :attr:`~torch.complex128` | ``1e-7``   | ``1e-7`` |
@@ -8888,7 +8888,7 @@ op_db: List[OpInfo] = [
            test_conjugated_samples=False,
            ),
     OpInfo('view_as_complex',
-           dtypes=floating_types_and(torch.half),
+           dtypes=floating_types(),
            supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -112,22 +112,21 @@ def all_types_and_half():
 def get_all_dtypes(include_half=True,
                    include_bfloat16=True,
                    include_bool=True,
-                   include_complex=True,
-                   include_complex32=False
+                   include_complex=True
                    ) -> List[torch.dtype]:
     dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
     if include_bool:
         dtypes.append(torch.bool)
     if include_complex:
-        dtypes += get_all_complex_dtypes(include_complex32)
+        dtypes += get_all_complex_dtypes()
     return dtypes

 def get_all_math_dtypes(device) -> List[torch.dtype]:
     return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
                                                     include_bfloat16=False) + get_all_complex_dtypes()

-def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
-    return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128]
+def get_all_complex_dtypes() -> List[torch.dtype]:
+    return [torch.complex64, torch.complex128]


 def get_all_int_dtypes() -> List[torch.dtype]:
@@ -1 +1 @@
-1.12.0a0
+1.11.0a0