mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
Compare commits
98 Commits
Author | SHA1 | Date | |
---|---|---|---|
ee77ccbb6d | |||
f7f538566e | |||
432724b3e2 | |||
cc98c93bf3 | |||
eb8e8c1bcf | |||
f0a4ac3ee0 | |||
bd9766a36a | |||
80cca51adc | |||
9d45ee1d81 | |||
de394b672d | |||
92c6401bb9 | |||
b4f32dd292 | |||
3e451b4796 | |||
036a591556 | |||
86d9ee8dee | |||
aa44ffb4c9 | |||
162b054e39 | |||
7f044f7398 | |||
0c81d6ba4b | |||
d1752f2bf8 | |||
49fbeb8cc8 | |||
f0d3fc70b4 | |||
fb489555a9 | |||
b5144f1068 | |||
a5c08a6abd | |||
6cc759269f | |||
6742476ba3 | |||
067aee5f30 | |||
0a7f7e6d30 | |||
e9fc91cbca | |||
23df957e94 | |||
a7b161c08b | |||
6bae48c127 | |||
c248943743 | |||
e058a37fe4 | |||
aa7112a618 | |||
b728ffabc3 | |||
d67898a93b | |||
9a25673478 | |||
17613ad73c | |||
beaae6a2b6 | |||
328f49968c | |||
7f9096f868 | |||
7e94ee235f | |||
318fb8e8b9 | |||
43bb1b2356 | |||
87fbd27cc0 | |||
225c38b719 | |||
8074526e7f | |||
06a866de94 | |||
9b22a55499 | |||
f8d3eac4c3 | |||
68b4d22da7 | |||
66b73b0950 | |||
32eb3b8d7b | |||
5a2a34cd2d | |||
a8083e18e8 | |||
bc3fb36ed7 | |||
e0822f1089 | |||
cbfd4e05e9 | |||
b9a2c8ac5c | |||
6d7a73c0da | |||
15e4827617 | |||
024fa34700 | |||
95d2c7fc98 | |||
7ba2baee00 | |||
ccf3a6de3d | |||
1ba6fc4ca6 | |||
d4d4bf5686 | |||
cee965fae9 | |||
544c16cdbf | |||
494a5563b4 | |||
ba4c3a1c2c | |||
c556f9052f | |||
5c80dd3c1f | |||
2d8ee11139 | |||
ebc2519bec | |||
8c9e4b250d | |||
6a6f047fc6 | |||
deadc27c23 | |||
6276fda119 | |||
65ee8f2c23 | |||
6126cfab2c | |||
e2f6fed611 | |||
667deb92f7 | |||
0e88de5580 | |||
3c8ce2a57e | |||
f2080fb3f2 | |||
84afb7b0c1 | |||
b6e976ae2d | |||
8626a1cc81 | |||
831566ec90 | |||
f7b3b20457 | |||
8ce38cf27d | |||
a94f9c7246 | |||
2fc3bb8571 | |||
f694d4d872 | |||
d7b6d945eb |
@ -17,7 +17,7 @@ CONFIG_TREE_DATA = [
|
||||
X("py2"),
|
||||
XImportant("cmake"),
|
||||
]),
|
||||
(Ver("cuda", "9.1"), [XImportant("py2")]),
|
||||
(Ver("cuda", "10.1"), [XImportant("py3.5")]), # TensorRT 6 build
|
||||
(Ver("mkl"), [XImportant("py2")]),
|
||||
(Ver("gcc", "5"), [XImportant("onnx_py2")]),
|
||||
(Ver("clang", "3.8"), [X("py2")]),
|
||||
|
@ -14,7 +14,7 @@ from dataclasses import dataclass
|
||||
|
||||
DOCKER_IMAGE_PATH_BASE = "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/"
|
||||
|
||||
DOCKER_IMAGE_VERSION = 315
|
||||
DOCKER_IMAGE_VERSION = 324
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -8,7 +8,7 @@ CONFIG_TREE_DATA = [
|
||||
(None, [
|
||||
XImportant("2.7.9"),
|
||||
X("2.7"),
|
||||
X("3.5"),
|
||||
XImportant("3.5"), # Not run on all PRs, but should be included on [test all]
|
||||
X("nightly"),
|
||||
]),
|
||||
("gcc", [
|
||||
|
@ -304,20 +304,20 @@ jobs:
|
||||
set -e
|
||||
# Pull Docker image and run build
|
||||
echo "DOCKER_IMAGE: "${DOCKER_IMAGE}
|
||||
docker pull ${DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${DOCKER_IMAGE} >/dev/null
|
||||
export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE})
|
||||
|
||||
# TODO We may want to move the rebase logic to a separate step after checkout
|
||||
# Rebase to master only if in xenial_py3_6_gcc5_4 case
|
||||
if [[ "${CIRCLE_BRANCH}" != "master" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then
|
||||
echo "Merge master branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT"
|
||||
# Rebase to v1.3.0 only if in xenial_py3_6_gcc5_4 case
|
||||
if [[ "${CIRCLE_BRANCH}" != "v1.3.0" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then
|
||||
echo "Merge v1.3.0 branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT"
|
||||
set -x
|
||||
git config --global user.email "circleci.ossci@gmail.com"
|
||||
git config --global user.name "CircleCI"
|
||||
git config remote.origin.url https://github.com/pytorch/pytorch.git
|
||||
git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master
|
||||
git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=50 --quiet
|
||||
export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/master`
|
||||
git config --add remote.origin.fetch +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0
|
||||
git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0 --depth=50 --quiet
|
||||
export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/v1.3.0`
|
||||
echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
|
||||
export GIT_COMMIT=${CIRCLE_SHA1}
|
||||
echo "GIT_COMMIT: " ${GIT_COMMIT}
|
||||
@ -326,7 +326,7 @@ jobs:
|
||||
git merge --no-edit --no-ff ${GIT_MERGE_TARGET}
|
||||
set +x
|
||||
else
|
||||
echo "Do NOT merge master branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT"
|
||||
echo "Do NOT merge v1.3.0 branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT"
|
||||
fi
|
||||
|
||||
git submodule sync && git submodule update -q --init --recursive
|
||||
@ -363,7 +363,7 @@ jobs:
|
||||
export COMMIT_DOCKER_IMAGE=$output_image
|
||||
fi
|
||||
docker commit "$id" ${COMMIT_DOCKER_IMAGE}
|
||||
docker push ${COMMIT_DOCKER_IMAGE}
|
||||
time docker push ${COMMIT_DOCKER_IMAGE}
|
||||
fi
|
||||
|
||||
pytorch_linux_test:
|
||||
@ -391,7 +391,7 @@ jobs:
|
||||
export COMMIT_DOCKER_IMAGE=$output_image
|
||||
fi
|
||||
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
|
||||
docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
if [ -n "${USE_CUDA_DOCKER_RUNTIME}" ]; then
|
||||
export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
|
||||
else
|
||||
@ -445,7 +445,7 @@ jobs:
|
||||
chmod +x /home/circleci/project/ci_build_script.sh
|
||||
|
||||
echo "DOCKER_IMAGE: "${DOCKER_IMAGE}
|
||||
docker pull ${DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${DOCKER_IMAGE} >/dev/null
|
||||
export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE})
|
||||
docker cp /home/circleci/project/. $id:/var/lib/jenkins/workspace
|
||||
|
||||
@ -460,7 +460,7 @@ jobs:
|
||||
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1}
|
||||
fi
|
||||
docker commit "$id" ${COMMIT_DOCKER_IMAGE}
|
||||
docker push ${COMMIT_DOCKER_IMAGE}
|
||||
time docker push ${COMMIT_DOCKER_IMAGE}
|
||||
fi
|
||||
|
||||
caffe2_linux_test:
|
||||
@ -515,7 +515,7 @@ jobs:
|
||||
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1}
|
||||
fi
|
||||
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
|
||||
docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
if [ -n "${USE_CUDA_DOCKER_RUNTIME}" ]; then
|
||||
export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
|
||||
else
|
||||
@ -551,12 +551,7 @@ jobs:
|
||||
# Reinitialize path (see man page for path_helper(8))
|
||||
eval `/usr/libexec/path_helper -s`
|
||||
|
||||
# Use Homebrew Python if configured to do so
|
||||
if [ "${PYTHON_INSTALLATION}" == "homebrew" ]; then
|
||||
export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH
|
||||
fi
|
||||
|
||||
pip -q install numpy
|
||||
export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH
|
||||
|
||||
# Install Anaconda if we need to
|
||||
if [ -n "${CAFFE2_USE_ANACONDA}" ]; then
|
||||
@ -569,6 +564,8 @@ jobs:
|
||||
source ${TMPDIR}/anaconda/bin/activate
|
||||
fi
|
||||
|
||||
pip -q install numpy
|
||||
|
||||
# Install sccache
|
||||
sudo curl https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache
|
||||
sudo chmod +x /usr/local/bin/sccache
|
||||
@ -606,7 +603,6 @@ jobs:
|
||||
if which sccache > /dev/null; then
|
||||
sccache --show-stats
|
||||
fi
|
||||
|
||||
binary_linux_build:
|
||||
<<: *binary_linux_build_params
|
||||
steps:
|
||||
@ -923,7 +919,7 @@ jobs:
|
||||
set -e
|
||||
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1}
|
||||
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
|
||||
docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
|
||||
|
||||
docker cp $id:/var/lib/jenkins/workspace/env /home/circleci/project/env
|
||||
@ -957,7 +953,7 @@ jobs:
|
||||
set -ex
|
||||
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1}
|
||||
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
|
||||
docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
export id=$(docker run -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
|
||||
|
||||
# master branch docs push
|
||||
@ -965,11 +961,11 @@ jobs:
|
||||
export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/master master site") | docker exec -u jenkins -i "$id" bash) 2>&1'
|
||||
|
||||
# stable release docs push. Due to some circleci limitations, we keep
|
||||
# an eternal PR open for merging v1.2.0 -> master for this job.
|
||||
# XXX: The following code is only run on the v1.2.0 branch, which might
|
||||
# an eternal PR open for merging v1.3.0 -> master for this job.
|
||||
# XXX: The following code is only run on the v1.3.0 branch, which might
|
||||
# not be exactly the same as what you see here.
|
||||
elif [[ "${CIRCLE_BRANCH}" == "v1.2.0" ]]; then
|
||||
export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.2.0 site dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1'
|
||||
elif [[ "${CIRCLE_BRANCH}" == "v1.3.0" ]]; then
|
||||
export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.3.0 site-v1.3.0 dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1'
|
||||
|
||||
# For open PRs: Do a dry_run of the docs build, don't push build
|
||||
else
|
||||
@ -981,7 +977,7 @@ jobs:
|
||||
# Save the docs build so we can debug any problems
|
||||
export DEBUG_COMMIT_DOCKER_IMAGE=${COMMIT_DOCKER_IMAGE}-debug
|
||||
docker commit "$id" ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
docker push ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
time docker push ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
|
||||
pytorch_cpp_doc_push:
|
||||
environment:
|
||||
@ -1002,7 +998,7 @@ jobs:
|
||||
set -ex
|
||||
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1}
|
||||
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
|
||||
docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
export id=$(docker run -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
|
||||
|
||||
# master branch docs push
|
||||
@ -1026,7 +1022,7 @@ jobs:
|
||||
# Save the docs build so we can debug any problems
|
||||
export DEBUG_COMMIT_DOCKER_IMAGE=${COMMIT_DOCKER_IMAGE}-debug
|
||||
docker commit "$id" ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
docker push ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
time docker push ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
|
||||
pytorch_macos_10_13_py3_build:
|
||||
environment:
|
||||
@ -1171,14 +1167,14 @@ jobs:
|
||||
echo "docker_image_libtorch_android_arm_v8a: "${docker_image_libtorch_android_arm_v8a}
|
||||
|
||||
# x86_32
|
||||
docker pull ${docker_image_libtorch_android_x86_32} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_x86_32} >/dev/null
|
||||
export id_x86_32=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_x86_32" bash) 2>&1'
|
||||
echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts
|
||||
|
||||
# arm-v7a
|
||||
docker pull ${docker_image_libtorch_android_arm_v7a} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_arm_v7a} >/dev/null
|
||||
export id_arm_v7a=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_arm_v7a})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_arm_v7a" bash) 2>&1'
|
||||
@ -1188,7 +1184,7 @@ jobs:
|
||||
docker cp $id_arm_v7a:/var/lib/jenkins/workspace/build_android/install ~/workspace/build_android_install_arm_v7a
|
||||
|
||||
# x86_64
|
||||
docker pull ${docker_image_libtorch_android_x86_64} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_x86_64} >/dev/null
|
||||
export id_x86_64=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_64})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_x86_64" bash) 2>&1'
|
||||
@ -1198,7 +1194,7 @@ jobs:
|
||||
docker cp $id_x86_64:/var/lib/jenkins/workspace/build_android/install ~/workspace/build_android_install_x86_64
|
||||
|
||||
# arm-v8a
|
||||
docker pull ${docker_image_libtorch_android_arm_v8a} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_arm_v8a} >/dev/null
|
||||
export id_arm_v8a=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_arm_v8a})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_arm_v8a" bash) 2>&1'
|
||||
@ -1220,7 +1216,7 @@ jobs:
|
||||
|
||||
output_image=$docker_image_libtorch_android_x86_32-gradle
|
||||
docker commit "$id_x86_32" ${output_image}
|
||||
docker push ${output_image}
|
||||
time docker push ${output_image}
|
||||
- store_artifacts:
|
||||
path: ~/workspace/build_android_artifacts/artifacts.tgz
|
||||
destination: artifacts.tgz
|
||||
@ -1251,7 +1247,7 @@ jobs:
|
||||
echo "docker_image_libtorch_android_x86_32_gradle: "${docker_image_libtorch_android_x86_32_gradle}
|
||||
|
||||
# x86_32
|
||||
docker pull ${docker_image_libtorch_android_x86_32_gradle} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_x86_32_gradle} >/dev/null
|
||||
export id_x86_32=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32_gradle})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace" && echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export SONATYPE_NEXUS_USERNAME=${SONATYPE_NEXUS_USERNAME}" && echo "export SONATYPE_NEXUS_PASSWORD=${SONATYPE_NEXUS_PASSWORD}" && echo "export ANDROID_SIGN_KEY=${ANDROID_SIGN_KEY}" && echo "export ANDROID_SIGN_PASS=${ANDROID_SIGN_PASS}" && echo "sudo chown -R jenkins workspace && cd workspace && ./.circleci/scripts/publish_android_snapshot.sh") | docker exec -u jenkins -i "$id_x86_32" bash) 2>&1'
|
||||
@ -1259,7 +1255,7 @@ jobs:
|
||||
|
||||
output_image=${docker_image_libtorch_android_x86_32_gradle}-publish-snapshot
|
||||
docker commit "$id_x86_32" ${output_image}
|
||||
docker push ${output_image}
|
||||
time docker push ${output_image}
|
||||
|
||||
pytorch_android_gradle_build-x86_32:
|
||||
environment:
|
||||
@ -1291,7 +1287,7 @@ jobs:
|
||||
echo "docker_image_libtorch_android_x86_32: "${docker_image_libtorch_android_x86_32}
|
||||
|
||||
# x86
|
||||
docker pull ${docker_image_libtorch_android_x86_32} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_x86_32} >/dev/null
|
||||
export id=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "sudo chown -R jenkins workspace && cd workspace && ./.circleci/scripts/build_android_gradle.sh") | docker exec -u jenkins -i "$id" bash) 2>&1'
|
||||
@ -1302,7 +1298,7 @@ jobs:
|
||||
|
||||
output_image=${docker_image_libtorch_android_x86_32}-gradle
|
||||
docker commit "$id" ${output_image}
|
||||
docker push ${output_image}
|
||||
time docker push ${output_image}
|
||||
- store_artifacts:
|
||||
path: ~/workspace/build_android_x86_32_artifacts/artifacts.tgz
|
||||
destination: artifacts.tgz
|
||||
@ -1333,7 +1329,7 @@ jobs:
|
||||
export PATH="~/anaconda/bin:${PATH}"
|
||||
source ~/anaconda/bin/activate
|
||||
# Install dependencies
|
||||
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests
|
||||
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes
|
||||
# sync submodules
|
||||
cd ${PROJ_ROOT}
|
||||
git submodule sync
|
||||
@ -1518,11 +1514,6 @@ workflows:
|
||||
name: pytorch_linux_xenial_py3_5_build
|
||||
requires:
|
||||
- setup
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
build_environment: "pytorch-linux-xenial-py3.5-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:347"
|
||||
- pytorch_linux_test:
|
||||
@ -1530,11 +1521,6 @@ workflows:
|
||||
requires:
|
||||
- setup
|
||||
- pytorch_linux_xenial_py3_5_build
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
build_environment: "pytorch-linux-xenial-py3.5-test"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:347"
|
||||
resource_class: large
|
||||
@ -1944,7 +1930,7 @@ workflows:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
build_environment: "caffe2-py2-gcc4.8-ubuntu14.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:324"
|
||||
- caffe2_linux_test:
|
||||
name: caffe2_py2_gcc4_8_ubuntu14_04_test
|
||||
requires:
|
||||
@ -1956,7 +1942,7 @@ workflows:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
build_environment: "caffe2-py2-gcc4.8-ubuntu14.04-test"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:324"
|
||||
resource_class: large
|
||||
- caffe2_linux_build:
|
||||
name: caffe2_py2_cuda9_0_cudnn7_ubuntu16_04_build
|
||||
@ -1968,7 +1954,7 @@ workflows:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
build_environment: "caffe2-py2-cuda9.0-cudnn7-ubuntu16.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324"
|
||||
- caffe2_linux_test:
|
||||
name: caffe2_py2_cuda9_0_cudnn7_ubuntu16_04_test
|
||||
requires:
|
||||
@ -1981,14 +1967,14 @@ workflows:
|
||||
- /ci-all\/.*/
|
||||
build_environment: "caffe2-py2-cuda9.0-cudnn7-ubuntu16.04-test"
|
||||
use_cuda_docker_runtime: "1"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324"
|
||||
resource_class: gpu.medium
|
||||
- caffe2_linux_build:
|
||||
name: caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_build
|
||||
requires:
|
||||
- setup
|
||||
build_environment: "caffe2-cmake-cuda9.0-cudnn7-ubuntu16.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324"
|
||||
- caffe2_linux_test:
|
||||
name: caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_test
|
||||
requires:
|
||||
@ -1996,50 +1982,50 @@ workflows:
|
||||
- caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_build
|
||||
build_environment: "caffe2-cmake-cuda9.0-cudnn7-ubuntu16.04-test"
|
||||
use_cuda_docker_runtime: "1"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324"
|
||||
resource_class: gpu.medium
|
||||
- caffe2_linux_build:
|
||||
name: caffe2_py2_cuda9_1_cudnn7_ubuntu16_04_build
|
||||
name: caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_build
|
||||
requires:
|
||||
- setup
|
||||
build_environment: "caffe2-py2-cuda9.1-cudnn7-ubuntu16.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.1-cudnn7-ubuntu16.04:315"
|
||||
build_environment: "caffe2-py3.5-cuda10.1-cudnn7-ubuntu16.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:324"
|
||||
- caffe2_linux_test:
|
||||
name: caffe2_py2_cuda9_1_cudnn7_ubuntu16_04_test
|
||||
name: caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_test
|
||||
requires:
|
||||
- setup
|
||||
- caffe2_py2_cuda9_1_cudnn7_ubuntu16_04_build
|
||||
build_environment: "caffe2-py2-cuda9.1-cudnn7-ubuntu16.04-test"
|
||||
- caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_build
|
||||
build_environment: "caffe2-py3.5-cuda10.1-cudnn7-ubuntu16.04-test"
|
||||
use_cuda_docker_runtime: "1"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.1-cudnn7-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:324"
|
||||
resource_class: gpu.medium
|
||||
- caffe2_linux_build:
|
||||
name: caffe2_py2_mkl_ubuntu16_04_build
|
||||
requires:
|
||||
- setup
|
||||
build_environment: "caffe2-py2-mkl-ubuntu16.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:324"
|
||||
- caffe2_linux_test:
|
||||
name: caffe2_py2_mkl_ubuntu16_04_test
|
||||
requires:
|
||||
- setup
|
||||
- caffe2_py2_mkl_ubuntu16_04_build
|
||||
build_environment: "caffe2-py2-mkl-ubuntu16.04-test"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:324"
|
||||
resource_class: large
|
||||
- caffe2_linux_build:
|
||||
name: caffe2_onnx_py2_gcc5_ubuntu16_04_build
|
||||
requires:
|
||||
- setup
|
||||
build_environment: "caffe2-onnx-py2-gcc5-ubuntu16.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:324"
|
||||
- caffe2_linux_test:
|
||||
name: caffe2_onnx_py2_gcc5_ubuntu16_04_test
|
||||
requires:
|
||||
- setup
|
||||
- caffe2_onnx_py2_gcc5_ubuntu16_04_build
|
||||
build_environment: "caffe2-onnx-py2-gcc5-ubuntu16.04-test"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:324"
|
||||
resource_class: large
|
||||
- caffe2_linux_build:
|
||||
name: caffe2_py2_clang3_8_ubuntu16_04_build
|
||||
@ -2051,7 +2037,7 @@ workflows:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
build_environment: "caffe2-py2-clang3.8-ubuntu16.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.8-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.8-ubuntu16.04:324"
|
||||
build_only: "1"
|
||||
- caffe2_linux_build:
|
||||
name: caffe2_py2_clang3_9_ubuntu16_04_build
|
||||
@ -2063,35 +2049,35 @@ workflows:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
build_environment: "caffe2-py2-clang3.9-ubuntu16.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.9-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.9-ubuntu16.04:324"
|
||||
build_only: "1"
|
||||
- caffe2_linux_build:
|
||||
name: caffe2_py2_clang7_ubuntu16_04_build
|
||||
requires:
|
||||
- setup
|
||||
build_environment: "caffe2-py2-clang7-ubuntu16.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang7-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang7-ubuntu16.04:324"
|
||||
build_only: "1"
|
||||
- caffe2_linux_build:
|
||||
name: caffe2_onnx_py3_6_clang7_ubuntu16_04_build
|
||||
requires:
|
||||
- setup
|
||||
build_environment: "caffe2-onnx-py3.6-clang7-ubuntu16.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:324"
|
||||
- caffe2_linux_test:
|
||||
name: caffe2_onnx_py3_6_clang7_ubuntu16_04_test
|
||||
requires:
|
||||
- setup
|
||||
- caffe2_onnx_py3_6_clang7_ubuntu16_04_build
|
||||
build_environment: "caffe2-onnx-py3.6-clang7-ubuntu16.04-test"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:324"
|
||||
resource_class: large
|
||||
- caffe2_linux_build:
|
||||
name: caffe2_py2_android_ubuntu16_04_build
|
||||
requires:
|
||||
- setup
|
||||
build_environment: "caffe2-py2-android-ubuntu16.04-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-android-ubuntu16.04:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-android-ubuntu16.04:324"
|
||||
build_only: "1"
|
||||
- caffe2_linux_build:
|
||||
name: caffe2_py2_cuda9_0_cudnn7_centos7_build
|
||||
@ -2103,7 +2089,7 @@ workflows:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
build_environment: "caffe2-py2-cuda9.0-cudnn7-centos7-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:324"
|
||||
- caffe2_linux_test:
|
||||
name: caffe2_py2_cuda9_0_cudnn7_centos7_test
|
||||
requires:
|
||||
@ -2116,7 +2102,7 @@ workflows:
|
||||
- /ci-all\/.*/
|
||||
build_environment: "caffe2-py2-cuda9.0-cudnn7-centos7-test"
|
||||
use_cuda_docker_runtime: "1"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:315"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:324"
|
||||
resource_class: gpu.medium
|
||||
- caffe2_macos_build:
|
||||
name: caffe2_py2_ios_macos10_13_build
|
||||
|
@ -13,7 +13,7 @@ chmod +x ~/Downloads/conda.sh
|
||||
export PATH="~/anaconda/bin:${PATH}"
|
||||
source ~/anaconda/bin/activate
|
||||
# Install dependencies
|
||||
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests
|
||||
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes
|
||||
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
|
||||
# sync submodules
|
||||
cd ${PROJ_ROOT}
|
||||
|
@ -22,7 +22,7 @@ default_set = set([
|
||||
# Caffe2 CPU
|
||||
'caffe2-py2-mkl-ubuntu16.04',
|
||||
# Caffe2 CUDA
|
||||
'caffe2-py2-cuda9.1-cudnn7-ubuntu16.04',
|
||||
'caffe2-py3.5-cuda10.1-cudnn7-ubuntu16.04',
|
||||
# Caffe2 ONNX
|
||||
'caffe2-onnx-py2-gcc5-ubuntu16.04',
|
||||
'caffe2-onnx-py3.6-clang7-ubuntu16.04',
|
||||
|
@ -18,7 +18,7 @@ if ! [ -e "$SCRIPT_DIR/COMMIT_MSG" ]; then
|
||||
exit 1
|
||||
fi
|
||||
if [ -n "${CIRCLE_PULL_REQUEST:-}" ]; then
|
||||
if [[ $CIRCLE_BRANCH != "ci-all/"* ]]; then
|
||||
if [[ $CIRCLE_BRANCH != "ci-all/"* ]] && [[ $CIRCLE_BRANCH != "nightly" ]] && [[ $CIRCLE_BRANCH != "postnightly" ]] ; then
|
||||
# Don't swallow "script doesn't exist
|
||||
[ -e "$SCRIPT_DIR/should_run_job.py" ]
|
||||
if ! python "$SCRIPT_DIR/should_run_job.py" "${BUILD_ENVIRONMENT:-}" < "$SCRIPT_DIR/COMMIT_MSG" ; then
|
||||
|
@ -40,7 +40,7 @@
|
||||
chmod +x /home/circleci/project/ci_build_script.sh
|
||||
|
||||
echo "DOCKER_IMAGE: "${DOCKER_IMAGE}
|
||||
docker pull ${DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${DOCKER_IMAGE} >/dev/null
|
||||
export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE})
|
||||
docker cp /home/circleci/project/. $id:/var/lib/jenkins/workspace
|
||||
|
||||
@ -55,7 +55,7 @@
|
||||
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1}
|
||||
fi
|
||||
docker commit "$id" ${COMMIT_DOCKER_IMAGE}
|
||||
docker push ${COMMIT_DOCKER_IMAGE}
|
||||
time docker push ${COMMIT_DOCKER_IMAGE}
|
||||
fi
|
||||
|
||||
caffe2_linux_test:
|
||||
@ -110,7 +110,7 @@
|
||||
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1}
|
||||
fi
|
||||
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
|
||||
docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
if [ -n "${USE_CUDA_DOCKER_RUNTIME}" ]; then
|
||||
export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
|
||||
else
|
||||
@ -146,12 +146,7 @@
|
||||
# Reinitialize path (see man page for path_helper(8))
|
||||
eval `/usr/libexec/path_helper -s`
|
||||
|
||||
# Use Homebrew Python if configured to do so
|
||||
if [ "${PYTHON_INSTALLATION}" == "homebrew" ]; then
|
||||
export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH
|
||||
fi
|
||||
|
||||
pip -q install numpy
|
||||
export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH
|
||||
|
||||
# Install Anaconda if we need to
|
||||
if [ -n "${CAFFE2_USE_ANACONDA}" ]; then
|
||||
@ -164,6 +159,8 @@
|
||||
source ${TMPDIR}/anaconda/bin/activate
|
||||
fi
|
||||
|
||||
pip -q install numpy
|
||||
|
||||
# Install sccache
|
||||
sudo curl https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache
|
||||
sudo chmod +x /usr/local/bin/sccache
|
||||
@ -201,4 +198,3 @@
|
||||
if which sccache > /dev/null; then
|
||||
sccache --show-stats
|
||||
fi
|
||||
|
||||
|
@ -19,7 +19,7 @@
|
||||
set -e
|
||||
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1}
|
||||
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
|
||||
docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
|
||||
|
||||
docker cp $id:/var/lib/jenkins/workspace/env /home/circleci/project/env
|
||||
@ -53,7 +53,7 @@
|
||||
set -ex
|
||||
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1}
|
||||
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
|
||||
docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
export id=$(docker run -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
|
||||
|
||||
# master branch docs push
|
||||
@ -61,11 +61,11 @@
|
||||
export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/master master site") | docker exec -u jenkins -i "$id" bash) 2>&1'
|
||||
|
||||
# stable release docs push. Due to some circleci limitations, we keep
|
||||
# an eternal PR open for merging v1.2.0 -> master for this job.
|
||||
# XXX: The following code is only run on the v1.2.0 branch, which might
|
||||
# an eternal PR open for merging v1.3.0 -> master for this job.
|
||||
# XXX: The following code is only run on the v1.3.0 branch, which might
|
||||
# not be exactly the same as what you see here.
|
||||
elif [[ "${CIRCLE_BRANCH}" == "v1.2.0" ]]; then
|
||||
export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.2.0 site dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1'
|
||||
elif [[ "${CIRCLE_BRANCH}" == "v1.3.0" ]]; then
|
||||
export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.3.0 site-v1.3.0 dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1'
|
||||
|
||||
# For open PRs: Do a dry_run of the docs build, don't push build
|
||||
else
|
||||
@ -77,7 +77,7 @@
|
||||
# Save the docs build so we can debug any problems
|
||||
export DEBUG_COMMIT_DOCKER_IMAGE=${COMMIT_DOCKER_IMAGE}-debug
|
||||
docker commit "$id" ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
docker push ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
time docker push ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
|
||||
pytorch_cpp_doc_push:
|
||||
environment:
|
||||
@ -98,7 +98,7 @@
|
||||
set -ex
|
||||
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1}
|
||||
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
|
||||
docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
export id=$(docker run -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
|
||||
|
||||
# master branch docs push
|
||||
@ -122,7 +122,7 @@
|
||||
# Save the docs build so we can debug any problems
|
||||
export DEBUG_COMMIT_DOCKER_IMAGE=${COMMIT_DOCKER_IMAGE}-debug
|
||||
docker commit "$id" ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
docker push ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
time docker push ${DEBUG_COMMIT_DOCKER_IMAGE}
|
||||
|
||||
pytorch_macos_10_13_py3_build:
|
||||
environment:
|
||||
@ -267,14 +267,14 @@
|
||||
echo "docker_image_libtorch_android_arm_v8a: "${docker_image_libtorch_android_arm_v8a}
|
||||
|
||||
# x86_32
|
||||
docker pull ${docker_image_libtorch_android_x86_32} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_x86_32} >/dev/null
|
||||
export id_x86_32=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_x86_32" bash) 2>&1'
|
||||
echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts
|
||||
|
||||
# arm-v7a
|
||||
docker pull ${docker_image_libtorch_android_arm_v7a} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_arm_v7a} >/dev/null
|
||||
export id_arm_v7a=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_arm_v7a})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_arm_v7a" bash) 2>&1'
|
||||
@ -284,7 +284,7 @@
|
||||
docker cp $id_arm_v7a:/var/lib/jenkins/workspace/build_android/install ~/workspace/build_android_install_arm_v7a
|
||||
|
||||
# x86_64
|
||||
docker pull ${docker_image_libtorch_android_x86_64} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_x86_64} >/dev/null
|
||||
export id_x86_64=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_64})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_x86_64" bash) 2>&1'
|
||||
@ -294,7 +294,7 @@
|
||||
docker cp $id_x86_64:/var/lib/jenkins/workspace/build_android/install ~/workspace/build_android_install_x86_64
|
||||
|
||||
# arm-v8a
|
||||
docker pull ${docker_image_libtorch_android_arm_v8a} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_arm_v8a} >/dev/null
|
||||
export id_arm_v8a=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_arm_v8a})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_arm_v8a" bash) 2>&1'
|
||||
@ -316,7 +316,7 @@
|
||||
|
||||
output_image=$docker_image_libtorch_android_x86_32-gradle
|
||||
docker commit "$id_x86_32" ${output_image}
|
||||
docker push ${output_image}
|
||||
time docker push ${output_image}
|
||||
- store_artifacts:
|
||||
path: ~/workspace/build_android_artifacts/artifacts.tgz
|
||||
destination: artifacts.tgz
|
||||
@ -347,7 +347,7 @@
|
||||
echo "docker_image_libtorch_android_x86_32_gradle: "${docker_image_libtorch_android_x86_32_gradle}
|
||||
|
||||
# x86_32
|
||||
docker pull ${docker_image_libtorch_android_x86_32_gradle} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_x86_32_gradle} >/dev/null
|
||||
export id_x86_32=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32_gradle})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace" && echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export SONATYPE_NEXUS_USERNAME=${SONATYPE_NEXUS_USERNAME}" && echo "export SONATYPE_NEXUS_PASSWORD=${SONATYPE_NEXUS_PASSWORD}" && echo "export ANDROID_SIGN_KEY=${ANDROID_SIGN_KEY}" && echo "export ANDROID_SIGN_PASS=${ANDROID_SIGN_PASS}" && echo "sudo chown -R jenkins workspace && cd workspace && ./.circleci/scripts/publish_android_snapshot.sh") | docker exec -u jenkins -i "$id_x86_32" bash) 2>&1'
|
||||
@ -355,7 +355,7 @@
|
||||
|
||||
output_image=${docker_image_libtorch_android_x86_32_gradle}-publish-snapshot
|
||||
docker commit "$id_x86_32" ${output_image}
|
||||
docker push ${output_image}
|
||||
time docker push ${output_image}
|
||||
|
||||
pytorch_android_gradle_build-x86_32:
|
||||
environment:
|
||||
@ -387,7 +387,7 @@
|
||||
echo "docker_image_libtorch_android_x86_32: "${docker_image_libtorch_android_x86_32}
|
||||
|
||||
# x86
|
||||
docker pull ${docker_image_libtorch_android_x86_32} >/dev/null
|
||||
time docker pull ${docker_image_libtorch_android_x86_32} >/dev/null
|
||||
export id=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32})
|
||||
|
||||
export COMMAND='((echo "source ./workspace/env" && echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "sudo chown -R jenkins workspace && cd workspace && ./.circleci/scripts/build_android_gradle.sh") | docker exec -u jenkins -i "$id" bash) 2>&1'
|
||||
@ -398,7 +398,7 @@
|
||||
|
||||
output_image=${docker_image_libtorch_android_x86_32}-gradle
|
||||
docker commit "$id" ${output_image}
|
||||
docker push ${output_image}
|
||||
time docker push ${output_image}
|
||||
- store_artifacts:
|
||||
path: ~/workspace/build_android_x86_32_artifacts/artifacts.tgz
|
||||
destination: artifacts.tgz
|
||||
@ -429,7 +429,7 @@
|
||||
export PATH="~/anaconda/bin:${PATH}"
|
||||
source ~/anaconda/bin/activate
|
||||
# Install dependencies
|
||||
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests
|
||||
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes
|
||||
# sync submodules
|
||||
cd ${PROJ_ROOT}
|
||||
git submodule sync
|
||||
|
@ -16,20 +16,20 @@ jobs:
|
||||
set -e
|
||||
# Pull Docker image and run build
|
||||
echo "DOCKER_IMAGE: "${DOCKER_IMAGE}
|
||||
docker pull ${DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${DOCKER_IMAGE} >/dev/null
|
||||
export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE})
|
||||
|
||||
# TODO We may want to move the rebase logic to a separate step after checkout
|
||||
# Rebase to master only if in xenial_py3_6_gcc5_4 case
|
||||
if [[ "${CIRCLE_BRANCH}" != "master" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then
|
||||
echo "Merge master branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT"
|
||||
# Rebase to v1.3.0 only if in xenial_py3_6_gcc5_4 case
|
||||
if [[ "${CIRCLE_BRANCH}" != "v1.3.0" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then
|
||||
echo "Merge v1.3.0 branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT"
|
||||
set -x
|
||||
git config --global user.email "circleci.ossci@gmail.com"
|
||||
git config --global user.name "CircleCI"
|
||||
git config remote.origin.url https://github.com/pytorch/pytorch.git
|
||||
git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master
|
||||
git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=50 --quiet
|
||||
export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/master`
|
||||
git config --add remote.origin.fetch +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0
|
||||
git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0 --depth=50 --quiet
|
||||
export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/v1.3.0`
|
||||
echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
|
||||
export GIT_COMMIT=${CIRCLE_SHA1}
|
||||
echo "GIT_COMMIT: " ${GIT_COMMIT}
|
||||
@ -38,7 +38,7 @@ jobs:
|
||||
git merge --no-edit --no-ff ${GIT_MERGE_TARGET}
|
||||
set +x
|
||||
else
|
||||
echo "Do NOT merge master branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT"
|
||||
echo "Do NOT merge v1.3.0 branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT"
|
||||
fi
|
||||
|
||||
git submodule sync && git submodule update -q --init --recursive
|
||||
@ -75,7 +75,7 @@ jobs:
|
||||
export COMMIT_DOCKER_IMAGE=$output_image
|
||||
fi
|
||||
docker commit "$id" ${COMMIT_DOCKER_IMAGE}
|
||||
docker push ${COMMIT_DOCKER_IMAGE}
|
||||
time docker push ${COMMIT_DOCKER_IMAGE}
|
||||
fi
|
||||
|
||||
pytorch_linux_test:
|
||||
@ -103,7 +103,7 @@ jobs:
|
||||
export COMMIT_DOCKER_IMAGE=$output_image
|
||||
fi
|
||||
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
|
||||
docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
|
||||
if [ -n "${USE_CUDA_DOCKER_RUNTIME}" ]; then
|
||||
export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
|
||||
else
|
||||
|
@ -1,6 +1,6 @@
|
||||
ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64
|
||||
|
||||
VERSION_NAME=0.0.7-SNAPSHOT
|
||||
VERSION_NAME=1.3.0
|
||||
GROUP=org.pytorch
|
||||
MAVEN_GROUP=org.pytorch
|
||||
POM_URL=https://github.com/pytorch/pytorch/tree/master/android
|
||||
|
@ -0,0 +1,30 @@
|
||||
package org.pytorch;
|
||||
|
||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||
import org.junit.BeforeClass;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.Objects;
|
||||
|
||||
public class PytorchHostTests extends PytorchTestBase {
|
||||
@BeforeClass
|
||||
public static void setUpClass() {
|
||||
if (!NativeLoader.isInitialized()) {
|
||||
NativeLoader.init(new SystemDelegate());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String assetFilePath(String assetName) throws IOException {
|
||||
Path tempFile = Files.createTempFile("test", ".pt");
|
||||
try (InputStream resource = Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {
|
||||
Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING);
|
||||
}
|
||||
return tempFile.toAbsolutePath().toString();
|
||||
}
|
||||
}
|
@ -2,8 +2,6 @@ package org.pytorch;
|
||||
|
||||
import android.content.Context;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
import java.io.File;
|
||||
@ -11,294 +9,15 @@ import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import androidx.test.platform.app.InstrumentationRegistry;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class PytorchInstrumentedTests {
|
||||
public class PytorchInstrumentedTests extends PytorchTestBase {
|
||||
|
||||
private static final String TEST_MODULE_ASSET_NAME = "test.pt";
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
System.loadLibrary("pytorch");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testForwardNull() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final IValue input =
|
||||
IValue.tensor(Tensor.newInt8Tensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
|
||||
assertTrue(input.isTensor());
|
||||
final IValue output = module.forward(input);
|
||||
assertTrue(output.isNull());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqBool() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
for (boolean value : new boolean[] {false, true}) {
|
||||
final IValue input = IValue.bool(value);
|
||||
assertTrue(input.isBool());
|
||||
assertTrue(value == input.getBool());
|
||||
final IValue output = module.runMethod("eqBool", input);
|
||||
assertTrue(output.isBool());
|
||||
assertTrue(value == output.getBool());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqInt() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
|
||||
final IValue input = IValue.long64(value);
|
||||
assertTrue(input.isLong());
|
||||
assertTrue(value == input.getLong());
|
||||
final IValue output = module.runMethod("eqInt", input);
|
||||
assertTrue(output.isLong());
|
||||
assertTrue(value == output.getLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqFloat() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
double[] values =
|
||||
new double[] {
|
||||
-Double.MAX_VALUE,
|
||||
Double.MAX_VALUE,
|
||||
-Double.MIN_VALUE,
|
||||
Double.MIN_VALUE,
|
||||
-Math.exp(1.d),
|
||||
-Math.sqrt(2.d),
|
||||
-3.1415f,
|
||||
3.1415f,
|
||||
-1,
|
||||
0,
|
||||
1,
|
||||
};
|
||||
for (double value : values) {
|
||||
final IValue input = IValue.double64(value);
|
||||
assertTrue(input.isDouble());
|
||||
assertTrue(value == input.getDouble());
|
||||
final IValue output = module.runMethod("eqFloat", input);
|
||||
assertTrue(output.isDouble());
|
||||
assertTrue(value == output.getDouble());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqTensor() throws IOException {
|
||||
final long[] inputTensorShape = new long[] {1, 3, 224, 224};
|
||||
final long numElements = Tensor.numel(inputTensorShape);
|
||||
final float[] inputTensorData = new float[(int) numElements];
|
||||
for (int i = 0; i < numElements; ++i) {
|
||||
inputTensorData[i] = i;
|
||||
}
|
||||
final Tensor inputTensor = Tensor.newFloat32Tensor(inputTensorShape, inputTensorData);
|
||||
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final IValue input = IValue.tensor(inputTensor);
|
||||
assertTrue(input.isTensor());
|
||||
assertTrue(inputTensor == input.getTensor());
|
||||
final IValue output = module.runMethod("eqTensor", input);
|
||||
assertTrue(output.isTensor());
|
||||
final Tensor outputTensor = output.getTensor();
|
||||
assertNotNull(outputTensor);
|
||||
assertArrayEquals(inputTensorShape, outputTensor.shape);
|
||||
float[] outputData = outputTensor.getDataAsFloatArray();
|
||||
for (int i = 0; i < numElements; i++) {
|
||||
assertTrue(inputTensorData[i] == outputData[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictIntKeyIntValue() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Map<Long, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put(Long.MIN_VALUE, IValue.long64(-Long.MIN_VALUE));
|
||||
inputMap.put(Long.MAX_VALUE, IValue.long64(-Long.MAX_VALUE));
|
||||
inputMap.put(0l, IValue.long64(0l));
|
||||
inputMap.put(1l, IValue.long64(-1l));
|
||||
inputMap.put(-1l, IValue.long64(1l));
|
||||
|
||||
final IValue input = IValue.dictLongKey(inputMap);
|
||||
assertTrue(input.isDictLongKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
|
||||
assertTrue(output.isDictLongKey());
|
||||
|
||||
final Map<Long, IValue> outputMap = output.getDictLongKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictStrKeyIntValue() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Map<String, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put("long_min_value", IValue.long64(Long.MIN_VALUE));
|
||||
inputMap.put("long_max_value", IValue.long64(Long.MAX_VALUE));
|
||||
inputMap.put("long_0", IValue.long64(0l));
|
||||
inputMap.put("long_1", IValue.long64(1l));
|
||||
inputMap.put("long_-1", IValue.long64(-1l));
|
||||
|
||||
final IValue input = IValue.dictStringKey(inputMap);
|
||||
assertTrue(input.isDictStringKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
|
||||
assertTrue(output.isDictStringKey());
|
||||
|
||||
final Map<String, IValue> outputMap = output.getDictStringKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testListIntSumReturnTuple() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
|
||||
for (int n : new int[] {0, 1, 128}) {
|
||||
long[] a = new long[n];
|
||||
long sum = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
a[i] = i;
|
||||
sum += a[i];
|
||||
}
|
||||
final IValue input = IValue.longList(a);
|
||||
assertTrue(input.isLongList());
|
||||
|
||||
final IValue output = module.runMethod("listIntSumReturnTuple", input);
|
||||
|
||||
assertTrue(output.isTuple());
|
||||
assertTrue(2 == output.getTuple().length);
|
||||
|
||||
IValue output0 = output.getTuple()[0];
|
||||
IValue output1 = output.getTuple()[1];
|
||||
|
||||
assertArrayEquals(a, output0.getLongList());
|
||||
assertTrue(sum == output1.getLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOptionalIntIsNone() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
|
||||
assertFalse(module.runMethod("optionalIntIsNone", IValue.long64(1l)).getBool());
|
||||
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).getBool());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntEq0None() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
|
||||
assertTrue(module.runMethod("intEq0None", IValue.long64(0l)).isNull());
|
||||
assertTrue(module.runMethod("intEq0None", IValue.long64(1l)).getLong() == 1l);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testRunUndefinedMethod() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
module.runMethod("test_undefined_method_throws_exception");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorMethods() {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
int[] ints = new int[numel];
|
||||
float[] floats = new float[numel];
|
||||
|
||||
byte[] bytes = new byte[numel];
|
||||
for (int i = 0; i < numel; i++) {
|
||||
bytes[i] = (byte) ((i % 255) - 128);
|
||||
ints[i] = i;
|
||||
floats[i] = i / 1000.f;
|
||||
}
|
||||
|
||||
Tensor tensorBytes = Tensor.newInt8Tensor(shape, bytes);
|
||||
assertTrue(tensorBytes.dtype() == Tensor.DTYPE_INT8);
|
||||
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
|
||||
|
||||
Tensor tensorInts = Tensor.newInt32Tensor(shape, ints);
|
||||
assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32);
|
||||
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
|
||||
|
||||
Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
|
||||
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
|
||||
float[] floatsOut = tensorFloats.getDataAsFloatArray();
|
||||
assertTrue(floatsOut.length == numel);
|
||||
for (int i = 0; i < numel; i++) {
|
||||
assertTrue(floats[i] == floatsOut[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void testTensorIllegalStateOnWrongType() {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
float[] floats = new float[numel];
|
||||
Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
|
||||
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
|
||||
tensorFloats.getDataAsByteArray();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testEqString() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.string(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.getString()));
|
||||
final IValue output = module.runMethod("eqStr", input);
|
||||
assertTrue(output.isString());
|
||||
assertTrue(value.equals(output.getString()));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStr3Concat() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.string(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.getString()));
|
||||
final IValue output = module.runMethod("str3Concat", input);
|
||||
assertTrue(output.isString());
|
||||
String expectedOutput = new StringBuilder().append(value).append(value).append(value).toString();
|
||||
assertTrue(expectedOutput.equals(output.getString()));
|
||||
}
|
||||
}
|
||||
|
||||
private static String assetFilePath(String assetName) throws IOException {
|
||||
@Override
|
||||
protected String assetFilePath(String assetName) throws IOException {
|
||||
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||
File file = new File(appContext.getFilesDir(), assetName);
|
||||
if (file.exists() && file.length() > 0) {
|
||||
|
@ -0,0 +1,290 @@
|
||||
package org.pytorch;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public abstract class PytorchTestBase {
|
||||
private static final String TEST_MODULE_ASSET_NAME = "test.pt";
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
System.loadLibrary("pytorch");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testForwardNull() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final IValue input =
|
||||
IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
|
||||
assertTrue(input.isTensor());
|
||||
final IValue output = module.forward(input);
|
||||
assertTrue(output.isNull());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqBool() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
for (boolean value : new boolean[] {false, true}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isBool());
|
||||
assertTrue(value == input.toBool());
|
||||
final IValue output = module.runMethod("eqBool", input);
|
||||
assertTrue(output.isBool());
|
||||
assertTrue(value == output.toBool());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqInt() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isLong());
|
||||
assertTrue(value == input.toLong());
|
||||
final IValue output = module.runMethod("eqInt", input);
|
||||
assertTrue(output.isLong());
|
||||
assertTrue(value == output.toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqFloat() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
double[] values =
|
||||
new double[] {
|
||||
-Double.MAX_VALUE,
|
||||
Double.MAX_VALUE,
|
||||
-Double.MIN_VALUE,
|
||||
Double.MIN_VALUE,
|
||||
-Math.exp(1.d),
|
||||
-Math.sqrt(2.d),
|
||||
-3.1415f,
|
||||
3.1415f,
|
||||
-1,
|
||||
0,
|
||||
1,
|
||||
};
|
||||
for (double value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isDouble());
|
||||
assertTrue(value == input.toDouble());
|
||||
final IValue output = module.runMethod("eqFloat", input);
|
||||
assertTrue(output.isDouble());
|
||||
assertTrue(value == output.toDouble());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqTensor() throws IOException {
|
||||
final long[] inputTensorShape = new long[] {1, 3, 224, 224};
|
||||
final long numElements = Tensor.numel(inputTensorShape);
|
||||
final float[] inputTensorData = new float[(int) numElements];
|
||||
for (int i = 0; i < numElements; ++i) {
|
||||
inputTensorData[i] = i;
|
||||
}
|
||||
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
|
||||
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final IValue input = IValue.from(inputTensor);
|
||||
assertTrue(input.isTensor());
|
||||
assertTrue(inputTensor == input.toTensor());
|
||||
final IValue output = module.runMethod("eqTensor", input);
|
||||
assertTrue(output.isTensor());
|
||||
final Tensor outputTensor = output.toTensor();
|
||||
assertNotNull(outputTensor);
|
||||
assertArrayEquals(inputTensorShape, outputTensor.shape());
|
||||
float[] outputData = outputTensor.getDataAsFloatArray();
|
||||
for (int i = 0; i < numElements; i++) {
|
||||
assertTrue(inputTensorData[i] == outputData[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictIntKeyIntValue() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Map<Long, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
|
||||
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
|
||||
inputMap.put(0l, IValue.from(0l));
|
||||
inputMap.put(1l, IValue.from(-1l));
|
||||
inputMap.put(-1l, IValue.from(1l));
|
||||
|
||||
final IValue input = IValue.dictLongKeyFrom(inputMap);
|
||||
assertTrue(input.isDictLongKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
|
||||
assertTrue(output.isDictLongKey());
|
||||
|
||||
final Map<Long, IValue> outputMap = output.toDictLongKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictStrKeyIntValue() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Map<String, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
|
||||
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
|
||||
inputMap.put("long_0", IValue.from(0l));
|
||||
inputMap.put("long_1", IValue.from(1l));
|
||||
inputMap.put("long_-1", IValue.from(-1l));
|
||||
|
||||
final IValue input = IValue.dictStringKeyFrom(inputMap);
|
||||
assertTrue(input.isDictStringKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
|
||||
assertTrue(output.isDictStringKey());
|
||||
|
||||
final Map<String, IValue> outputMap = output.toDictStringKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testListIntSumReturnTuple() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
|
||||
for (int n : new int[] {0, 1, 128}) {
|
||||
long[] a = new long[n];
|
||||
long sum = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
a[i] = i;
|
||||
sum += a[i];
|
||||
}
|
||||
final IValue input = IValue.listFrom(a);
|
||||
assertTrue(input.isLongList());
|
||||
|
||||
final IValue output = module.runMethod("listIntSumReturnTuple", input);
|
||||
|
||||
assertTrue(output.isTuple());
|
||||
assertTrue(2 == output.toTuple().length);
|
||||
|
||||
IValue output0 = output.toTuple()[0];
|
||||
IValue output1 = output.toTuple()[1];
|
||||
|
||||
assertArrayEquals(a, output0.toLongList());
|
||||
assertTrue(sum == output1.toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOptionalIntIsNone() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
|
||||
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
|
||||
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntEq0None() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testRunUndefinedMethod() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
module.runMethod("test_undefined_method_throws_exception");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorMethods() {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
int[] ints = new int[numel];
|
||||
float[] floats = new float[numel];
|
||||
|
||||
byte[] bytes = new byte[numel];
|
||||
for (int i = 0; i < numel; i++) {
|
||||
bytes[i] = (byte) ((i % 255) - 128);
|
||||
ints[i] = i;
|
||||
floats[i] = i / 1000.f;
|
||||
}
|
||||
|
||||
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
|
||||
assertTrue(tensorBytes.dtype() == DType.INT8);
|
||||
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
|
||||
|
||||
Tensor tensorInts = Tensor.fromBlob(ints, shape);
|
||||
assertTrue(tensorInts.dtype() == DType.INT32);
|
||||
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
|
||||
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
float[] floatsOut = tensorFloats.getDataAsFloatArray();
|
||||
assertTrue(floatsOut.length == numel);
|
||||
for (int i = 0; i < numel; i++) {
|
||||
assertTrue(floats[i] == floatsOut[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void testTensorIllegalStateOnWrongType() {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
float[] floats = new float[numel];
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
tensorFloats.getDataAsByteArray();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testEqString() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("eqStr", input);
|
||||
assertTrue(output.isString());
|
||||
assertTrue(value.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStr3Concat() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("str3Concat", input);
|
||||
assertTrue(output.isString());
|
||||
String expectedOutput = new StringBuilder().append(value).append(value).append(value).toString();
|
||||
assertTrue(expectedOutput.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
||||
protected abstract String assetFilePath(String assetName) throws IOException;
|
||||
}
|
@ -10,6 +10,8 @@
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
// NOTE: Codes must be kept in sync with DType.java.
|
||||
// NOTE: Never serialize these, because they can change between releases.
|
||||
constexpr static int kTensorDTypeUInt8 = 1;
|
||||
constexpr static int kTensorDTypeInt8 = 2;
|
||||
constexpr static int kTensorDTypeInt32 = 3;
|
||||
@ -164,7 +166,7 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
|
||||
static at::Tensor newAtTensorFromJTensor(
|
||||
facebook::jni::alias_ref<JTensor> jtensor) {
|
||||
static const auto dtypeMethod =
|
||||
JTensor::javaClassStatic()->getMethod<jint()>("dtype");
|
||||
JTensor::javaClassStatic()->getMethod<jint()>("dtypeJniCode");
|
||||
jint jdtype = dtypeMethod(jtensor);
|
||||
|
||||
static const auto shapeField =
|
||||
@ -216,7 +218,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodTensor =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::local_ref<JTensor>)>("tensor");
|
||||
facebook::jni::local_ref<JTensor>)>("from");
|
||||
return jMethodTensor(
|
||||
JIValue::javaClassStatic(),
|
||||
JTensor::newJTensorFromAtTensor(ivalue.toTensor()));
|
||||
@ -224,26 +226,26 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodBool =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jboolean)>(
|
||||
"bool");
|
||||
"from");
|
||||
return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool());
|
||||
} else if (ivalue.isInt()) {
|
||||
static auto jMethodInt =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jlong)>(
|
||||
"long64");
|
||||
"from");
|
||||
return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt());
|
||||
} else if (ivalue.isDouble()) {
|
||||
static auto jMethodDouble =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jdouble)>(
|
||||
"double64");
|
||||
"from");
|
||||
return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble());
|
||||
} else if (ivalue.isString()) {
|
||||
static auto jMethodString =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JString::javaobject>)>("string");
|
||||
facebook::jni::JString::javaobject>)>("from");
|
||||
return jMethodString(
|
||||
JIValue::javaClassStatic(),
|
||||
facebook::jni::make_jstring(ivalue.toStringRef()));
|
||||
@ -253,7 +255,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JArrayClass<
|
||||
JIValue::javaobject>::javaobject>)>("tuple");
|
||||
JIValue::javaobject>::javaobject>)>("tupleFrom");
|
||||
auto jElementsArray =
|
||||
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
|
||||
elementsVec.size());
|
||||
@ -267,7 +269,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodBoolListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jbooleanArray>)>("boolList");
|
||||
facebook::jni::alias_ref<jbooleanArray>)>("listFrom");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_boolean_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
@ -281,7 +283,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodLongListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jlongArray>)>("longList");
|
||||
facebook::jni::alias_ref<jlongArray>)>("listFrom");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_long_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
@ -295,7 +297,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethoDoubleListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jdoubleArray>)>("doubleList");
|
||||
facebook::jni::alias_ref<jdoubleArray>)>("listFrom");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_double_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
@ -310,7 +312,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JArrayClass<
|
||||
JTensor::javaobject>::javaobject>)>("tensorList");
|
||||
JTensor::javaobject>::javaobject>)>("listFrom");
|
||||
auto jArray = facebook::jni::JArrayClass<JTensor::javaobject>::newArray(
|
||||
list.size());
|
||||
auto index = 0;
|
||||
@ -324,7 +326,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JArrayClass<
|
||||
JIValue::javaobject>::javaobject>)>("list");
|
||||
JIValue::javaobject>::javaobject>)>("listFrom");
|
||||
auto jArray = facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
|
||||
list.size());
|
||||
auto index = 0;
|
||||
@ -351,7 +353,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JString::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||
"dictStringKey");
|
||||
"dictStringKeyFrom");
|
||||
|
||||
auto jmap = JHashMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
|
||||
@ -370,7 +372,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JLong::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||
"dictLongKey");
|
||||
"dictLongKeyFrom");
|
||||
auto jmap = JHashMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
||||
@ -404,32 +406,32 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static const auto jMethodGetTensor =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::alias_ref<JTensor::javaobject>()>(
|
||||
"getTensor");
|
||||
"toTensor");
|
||||
return JTensor::newAtTensorFromJTensor(jMethodGetTensor(jivalue));
|
||||
} else if (JIValue::kTypeCodeBool == typeCode) {
|
||||
static const auto jMethodGetBool =
|
||||
JIValue::javaClassStatic()->getMethod<jboolean()>("getBool");
|
||||
JIValue::javaClassStatic()->getMethod<jboolean()>("toBool");
|
||||
// explicit cast to bool as jboolean is defined as uint8_t, IValue ctor
|
||||
// for int will be called for jboolean
|
||||
bool b = jMethodGetBool(jivalue);
|
||||
return at::IValue{b};
|
||||
} else if (JIValue::kTypeCodeLong == typeCode) {
|
||||
static const auto jMethodGetLong =
|
||||
JIValue::javaClassStatic()->getMethod<jlong()>("getLong");
|
||||
JIValue::javaClassStatic()->getMethod<jlong()>("toLong");
|
||||
return at::IValue{jMethodGetLong(jivalue)};
|
||||
} else if (JIValue::kTypeCodeDouble == typeCode) {
|
||||
static const auto jMethodGetDouble =
|
||||
JIValue::javaClassStatic()->getMethod<jdouble()>("getDouble");
|
||||
JIValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
|
||||
return at::IValue{jMethodGetDouble(jivalue)};
|
||||
} else if (JIValue::kTypeCodeString == typeCode) {
|
||||
static const auto jMethodGetString =
|
||||
JIValue::javaClassStatic()->getMethod<jstring()>("getString");
|
||||
JIValue::javaClassStatic()->getMethod<jstring()>("toStr");
|
||||
return at::IValue{jMethodGetString(jivalue)->toStdString()};
|
||||
} else if (JIValue::kTypeCodeTuple == typeCode) {
|
||||
static const auto jMethodGetTuple =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JArrayClass<
|
||||
JIValue::javaobject>::javaobject()>("getTuple");
|
||||
JIValue::javaobject>::javaobject()>("toTuple");
|
||||
auto jarray = jMethodGetTuple(jivalue);
|
||||
size_t n = jarray->size();
|
||||
|
||||
@ -443,7 +445,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
return c10::ivalue::Tuple::create(std::move(elements));
|
||||
} else if (JIValue::kTypeCodeBoolList == typeCode) {
|
||||
static const auto jMethodGetBoolList =
|
||||
JIValue::javaClassStatic()->getMethod<jbooleanArray()>("getBoolList");
|
||||
JIValue::javaClassStatic()->getMethod<jbooleanArray()>("toBoolList");
|
||||
auto jArray = jMethodGetBoolList(jivalue);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
size_t n = jArrayPinned.size();
|
||||
@ -455,7 +457,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
return at::IValue{std::move(list)};
|
||||
} else if (JIValue::kTypeCodeLongList == typeCode) {
|
||||
static const auto jMethodGetLongList =
|
||||
JIValue::javaClassStatic()->getMethod<jlongArray()>("getLongList");
|
||||
JIValue::javaClassStatic()->getMethod<jlongArray()>("toLongList");
|
||||
auto jArray = jMethodGetLongList(jivalue);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
size_t n = jArrayPinned.size();
|
||||
@ -468,7 +470,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
} else if (JIValue::kTypeCodeDoubleList == typeCode) {
|
||||
static const auto jMethodGetDoubleList =
|
||||
JIValue::javaClassStatic()->getMethod<jdoubleArray()>(
|
||||
"getDoubleList");
|
||||
"toDoubleList");
|
||||
auto jArray = jMethodGetDoubleList(jivalue);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
size_t n = jArrayPinned.size();
|
||||
@ -482,7 +484,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static const auto jMethodGetTensorList =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JArrayClass<
|
||||
JTensor::javaobject>::javaobject()>("getTensorList");
|
||||
JTensor::javaobject>::javaobject()>("toTensorList");
|
||||
auto jArray = jMethodGetTensorList(jivalue);
|
||||
size_t n = jArray->size();
|
||||
c10::List<at::Tensor> list{};
|
||||
@ -495,7 +497,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static const auto jMethodGetList =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JArrayClass<
|
||||
JIValue::javaobject>::javaobject()>("getList");
|
||||
JIValue::javaobject>::javaobject()>("toList");
|
||||
auto jarray = jMethodGetList(jivalue);
|
||||
size_t n = jarray->size();
|
||||
if (n == 0) {
|
||||
@ -518,7 +520,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static const auto jMethodGetDictStringKey =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JMap<jstring, JIValue::javaobject>::
|
||||
javaobject()>("getDictStringKey");
|
||||
javaobject()>("toDictStringKey");
|
||||
auto jmap = jMethodGetDictStringKey(jivalue);
|
||||
auto it = jmap->begin();
|
||||
if (it == jmap->end()) {
|
||||
@ -541,7 +543,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JMap<
|
||||
facebook::jni::JLong::javaobject,
|
||||
JIValue::javaobject>::javaobject()>("getDictLongKey");
|
||||
JIValue::javaobject>::javaobject()>("toDictLongKey");
|
||||
auto jmap = jMethodGetDictLongKey(jivalue);
|
||||
auto it = jmap->begin();
|
||||
if (it == jmap->end()) {
|
||||
|
29
android/pytorch_android/src/main/java/org/pytorch/DType.java
Normal file
29
android/pytorch_android/src/main/java/org/pytorch/DType.java
Normal file
@ -0,0 +1,29 @@
|
||||
package org.pytorch;
|
||||
|
||||
/**
|
||||
* Codes representing tensor data types.
|
||||
*/
|
||||
public enum DType {
|
||||
// NOTE: "jniCode" must be kept in sync with pytorch_jni.cpp.
|
||||
// NOTE: Never serialize "jniCode", because it can change between releases.
|
||||
|
||||
/** Code for dtype torch.uint8. {@link Tensor#dtype()} */
|
||||
UINT8(1),
|
||||
/** Code for dtype torch.int8. {@link Tensor#dtype()} */
|
||||
INT8(2),
|
||||
/** Code for dtype torch.int32. {@link Tensor#dtype()} */
|
||||
INT32(3),
|
||||
/** Code for dtype torch.float32. {@link Tensor#dtype()} */
|
||||
FLOAT32(4),
|
||||
/** Code for dtype torch.int64. {@link Tensor#dtype()} */
|
||||
INT64(5),
|
||||
/** Code for dtype torch.float64. {@link Tensor#dtype()} */
|
||||
FLOAT64(6),
|
||||
;
|
||||
|
||||
final int jniCode;
|
||||
|
||||
DType(int jniCode) {
|
||||
this.jniCode = jniCode;
|
||||
}
|
||||
}
|
@ -4,10 +4,21 @@ import java.util.Locale;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Java representation of a torchscript variable, which is implemented as tagged union that can be
|
||||
* one of the supported types: https://pytorch.org/docs/stable/jit.html#types.
|
||||
* Java representation of a TorchScript value, which is implemented as tagged union that can be
|
||||
* one of the supported types: https://pytorch.org/docs/stable/jit.html#types .
|
||||
* <p>
|
||||
* Calling getters for inappropriate types will throw IllegalStateException.
|
||||
* Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}.
|
||||
* <p>
|
||||
* {@code IValue} objects are constructed with {@code IValue.from(value)},
|
||||
* {@code IValue.tupleFrom(value1, value2, ...)}, {@code IValue.listFrom(value1, value2, ...)},
|
||||
* or one of the {@code dict} methods, depending on the key type.
|
||||
* <p>
|
||||
* Data is retrieved from {@code IValue} objects with the {@code toX()} methods. Note that
|
||||
* {@code str}-type IValues must be extracted with {@link #toStr()},
|
||||
* rather than {@link #toString()}.
|
||||
* <p>
|
||||
* {@code IValue} objects may retain references to objects passed into their constructors,
|
||||
* and may return references to their internal state from {@code toX()}.
|
||||
*/
|
||||
public class IValue {
|
||||
private static final int TYPE_CODE_NULL = 1;
|
||||
@ -91,95 +102,98 @@ public class IValue {
|
||||
return TYPE_CODE_DICT_LONG_KEY == this.mTypeCode;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new {@code IValue} of type {@code Optional} that contains no value.
|
||||
*/
|
||||
public static IValue optionalNull() {
|
||||
return new IValue(TYPE_CODE_NULL);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript Tensor type.
|
||||
* Creates a new {@code IValue} of type {@code Tensor}.
|
||||
*/
|
||||
public static IValue tensor(Tensor tensor) {
|
||||
public static IValue from(Tensor tensor) {
|
||||
final IValue iv = new IValue(TYPE_CODE_TENSOR);
|
||||
iv.mData = tensor;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript bool type.
|
||||
* Creates a new {@code IValue} of type {@code bool}.
|
||||
*/
|
||||
public static IValue bool(boolean value) {
|
||||
public static IValue from(boolean value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_BOOL);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript int type.
|
||||
* Creates a new {@code IValue} of type {@code int}.
|
||||
*/
|
||||
public static IValue long64(long value) {
|
||||
public static IValue from(long value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_LONG);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript float type.
|
||||
* Creates a new {@code IValue} of type {@code float}.
|
||||
*/
|
||||
public static IValue double64(double value) {
|
||||
public static IValue from(double value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DOUBLE);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates new IValue instance of torchscript str type.
|
||||
* Creates a new {@code IValue} of type {@code str}.
|
||||
*/
|
||||
public static IValue string(String value) {
|
||||
public static IValue from(String value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_STRING);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript List[bool] type.
|
||||
* Creates a new {@code IValue} of type {@code List[bool]}.
|
||||
*/
|
||||
public static IValue boolList(boolean... list) {
|
||||
public static IValue listFrom(boolean... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_BOOL_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript List[int] type.
|
||||
* Creates a new {@code IValue} of type {@code List[int]}.
|
||||
*/
|
||||
public static IValue longList(long... list) {
|
||||
public static IValue listFrom(long... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_LONG_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript List[float] type.
|
||||
* Creates a new {@code IValue} of type {@code List[float]}.
|
||||
*/
|
||||
public static IValue doubleList(double... list) {
|
||||
public static IValue listFrom(double... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DOUBLE_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript List[Tensor] type.
|
||||
* Creates a new {@code IValue} of type {@code List[Tensor]}.
|
||||
*/
|
||||
public static IValue tensorList(Tensor... list) {
|
||||
public static IValue listFrom(Tensor... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_TENSOR_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript List[T] type. All elements must have the same type.
|
||||
* Creates a new {@code IValue} of type {@code List[T]}. All elements must have the same type.
|
||||
*/
|
||||
public static IValue list(IValue... array) {
|
||||
public static IValue listFrom(IValue... array) {
|
||||
final int size = array.length;
|
||||
if (size > 0) {
|
||||
final int typeCode0 = array[0].mTypeCode;
|
||||
@ -196,93 +210,93 @@ public class IValue {
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript Tuple[T0, T1, ...] type.
|
||||
* Creates a new {@code IValue} of type {@code Tuple[T0, T1, ...]}.
|
||||
*/
|
||||
public static IValue tuple(IValue... array) {
|
||||
public static IValue tupleFrom(IValue... array) {
|
||||
final IValue iv = new IValue(TYPE_CODE_TUPLE);
|
||||
iv.mData = array;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance oftorchscript Dict[Str, V] type.
|
||||
* Creates a new {@code IValue} of type {@code Dict[str, V]}.
|
||||
*/
|
||||
public static IValue dictStringKey(Map<String, IValue> map) {
|
||||
public static IValue dictStringKeyFrom(Map<String, IValue> map) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DICT_STRING_KEY);
|
||||
iv.mData = map;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript Dict[int, V] type.
|
||||
* Creates a new {@code IValue} of type {@code Dict[int, V]}.
|
||||
*/
|
||||
public static IValue dictLongKey(Map<Long, IValue> map) {
|
||||
public static IValue dictLongKeyFrom(Map<Long, IValue> map) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DICT_LONG_KEY);
|
||||
iv.mData = map;
|
||||
return iv;
|
||||
}
|
||||
|
||||
public Tensor getTensor() {
|
||||
public Tensor toTensor() {
|
||||
preconditionType(TYPE_CODE_TENSOR, mTypeCode);
|
||||
return (Tensor) mData;
|
||||
}
|
||||
|
||||
public boolean getBool() {
|
||||
public boolean toBool() {
|
||||
preconditionType(TYPE_CODE_BOOL, mTypeCode);
|
||||
return (boolean) mData;
|
||||
}
|
||||
|
||||
public long getLong() {
|
||||
public long toLong() {
|
||||
preconditionType(TYPE_CODE_LONG, mTypeCode);
|
||||
return (long) mData;
|
||||
}
|
||||
|
||||
public double getDouble() {
|
||||
public double toDouble() {
|
||||
preconditionType(TYPE_CODE_DOUBLE, mTypeCode);
|
||||
return (double) mData;
|
||||
}
|
||||
|
||||
public String getString() {
|
||||
public String toStr() {
|
||||
preconditionType(TYPE_CODE_STRING, mTypeCode);
|
||||
return (String) mData;
|
||||
}
|
||||
|
||||
public boolean[] getBoolList() {
|
||||
public boolean[] toBoolList() {
|
||||
preconditionType(TYPE_CODE_BOOL_LIST, mTypeCode);
|
||||
return (boolean[]) mData;
|
||||
}
|
||||
|
||||
public long[] getLongList() {
|
||||
public long[] toLongList() {
|
||||
preconditionType(TYPE_CODE_LONG_LIST, mTypeCode);
|
||||
return (long[]) mData;
|
||||
}
|
||||
|
||||
public double[] getDoubleList() {
|
||||
public double[] toDoubleList() {
|
||||
preconditionType(TYPE_CODE_DOUBLE_LIST, mTypeCode);
|
||||
return (double[]) mData;
|
||||
}
|
||||
|
||||
public Tensor[] getTensorList() {
|
||||
public Tensor[] toTensorList() {
|
||||
preconditionType(TYPE_CODE_TENSOR_LIST, mTypeCode);
|
||||
return (Tensor[]) mData;
|
||||
}
|
||||
|
||||
public IValue[] getList() {
|
||||
public IValue[] toList() {
|
||||
preconditionType(TYPE_CODE_LIST, mTypeCode);
|
||||
return (IValue[]) mData;
|
||||
}
|
||||
|
||||
public IValue[] getTuple() {
|
||||
public IValue[] toTuple() {
|
||||
preconditionType(TYPE_CODE_TUPLE, mTypeCode);
|
||||
return (IValue[]) mData;
|
||||
}
|
||||
|
||||
public Map<String, IValue> getDictStringKey() {
|
||||
public Map<String, IValue> toDictStringKey() {
|
||||
preconditionType(TYPE_CODE_DICT_STRING_KEY, mTypeCode);
|
||||
return (Map<String, IValue>) mData;
|
||||
}
|
||||
|
||||
public Map<Long, IValue> getDictLongKey() {
|
||||
public Map<Long, IValue> toDictLongKey() {
|
||||
preconditionType(TYPE_CODE_DICT_LONG_KEY, mTypeCode);
|
||||
return (Map<Long, IValue>) mData;
|
||||
}
|
||||
|
@ -5,21 +5,20 @@ package org.pytorch;
|
||||
import com.facebook.jni.HybridData;
|
||||
|
||||
/**
|
||||
* Java holder for torch::jit::script::Module which owns it on jni side.
|
||||
* Java wrapper for torch::jit::script::Module.
|
||||
*/
|
||||
public class Module {
|
||||
|
||||
private NativePeer mNativePeer;
|
||||
|
||||
/**
|
||||
* Loads serialized torchscript module from the specified absolute path on the disk.
|
||||
* Loads a serialized TorchScript module from the specified path on the disk.
|
||||
*
|
||||
* @param modelAbsolutePath absolute path to file that contains the serialized torchscript module.
|
||||
* @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module on jni
|
||||
* side.
|
||||
* @param modelPath path to file that contains the serialized TorchScript module.
|
||||
* @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module.
|
||||
*/
|
||||
public static Module load(final String modelAbsolutePath) {
|
||||
return new Module(modelAbsolutePath);
|
||||
public static Module load(final String modelPath) {
|
||||
return new Module(modelPath);
|
||||
}
|
||||
|
||||
private Module(final String moduleAbsolutePath) {
|
||||
@ -27,26 +26,37 @@ public class Module {
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs 'forward' method of loaded torchscript module with specified arguments.
|
||||
* Runs the 'forward' method of this module with the specified arguments.
|
||||
*
|
||||
* @param inputs arguments for torchscript module 'forward' method.
|
||||
* @return result of torchscript module 'forward' method evaluation
|
||||
* @param inputs arguments for the TorchScript module's 'forward' method.
|
||||
* @return return value from the 'forward' method.
|
||||
*/
|
||||
public IValue forward(IValue... inputs) {
|
||||
return mNativePeer.forward(inputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs specified method of loaded torchscript module with specified arguments.
|
||||
* Runs the specified method of this module with the specified arguments.
|
||||
*
|
||||
* @param methodName torchscript module method to run
|
||||
* @param inputs arguments that will be specified to torchscript module method call
|
||||
* @return result of torchscript module specified method evaluation
|
||||
* @param methodName name of the TorchScript method to run.
|
||||
* @param inputs arguments that will be passed to TorchScript method.
|
||||
* @return return value from the method.
|
||||
*/
|
||||
public IValue runMethod(String methodName, IValue... inputs) {
|
||||
return mNativePeer.runMethod(methodName, inputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Explicitly destroys the native torch::jit::script::Module.
|
||||
* Calling this method is not required, as the native object will be destroyed
|
||||
* when this object is garbage-collected. However, the timing of garbage collection
|
||||
* is not guaranteed, so proactively calling {@code destroy} can free memory more quickly.
|
||||
* See {@link com.facebook.jni.HybridData#resetNative}.
|
||||
*/
|
||||
public void destroy() {
|
||||
mNativePeer.mHybridData.resetNative();
|
||||
}
|
||||
|
||||
private static class NativePeer {
|
||||
static {
|
||||
System.loadLibrary("pytorch");
|
||||
|
@ -11,24 +11,24 @@ import java.util.Arrays;
|
||||
import java.util.Locale;
|
||||
|
||||
/**
|
||||
* Representation of Tensor. Tensor shape is stored in {@link Tensor#shape}, elements are stored as
|
||||
* {@link java.nio.DirectByteBuffer} of one of the supported types.
|
||||
* Representation of a Tensor. Behavior is similar to PyTorch's tensor objects.
|
||||
* <p>
|
||||
* Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)},
|
||||
* where {@code data} can be an array or a direct {@link Buffer} (of the proper subclass).
|
||||
* Helper methods are provided to allocate buffers properly.
|
||||
* <p>
|
||||
* To access Tensor data, see {@link #dtype()}, {@link #shape()},
|
||||
* and various {@code getDataAs*} methods.
|
||||
* <p>
|
||||
* When constructing {@code Tensor} objects with {@code data} as an array,
|
||||
* it is not specified whether this data is is copied or retained as a reference
|
||||
* so it is recommended not to modify it after constructing. {@code data} passed as a
|
||||
* {@link Buffer} is not copied, so it can be modified between {@link Module} calls
|
||||
* to avoid reallocation. Data retrieved from {@code Tensor} objects may be copied or
|
||||
* may be a reference to the {@code Tensor}'s internal data buffer.
|
||||
* {@code shape} is always copied.
|
||||
*/
|
||||
public abstract class Tensor {
|
||||
|
||||
/** Code for dtype torch.uint8. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_UINT8 = 1;
|
||||
/** Code for dtype torch.int8. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_INT8 = 2;
|
||||
/** Code for dtype torch.int32. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_INT32 = 3;
|
||||
/** Code for dtype torch.float32. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_FLOAT32 = 4;
|
||||
/** Code for dtype torch.int64. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_INT64 = 5;
|
||||
/** Code for dtype torch.float64. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_FLOAT64 = 6;
|
||||
|
||||
private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null";
|
||||
private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null";
|
||||
private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null";
|
||||
@ -39,8 +39,7 @@ public abstract class Tensor {
|
||||
private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
|
||||
"Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)";
|
||||
|
||||
/** Shape of current tensor. */
|
||||
public final long[] shape;
|
||||
final long[] shape;
|
||||
|
||||
private static final int INT_SIZE_BYTES = 4;
|
||||
private static final int FLOAT_SIZE_BYTES = 4;
|
||||
@ -49,8 +48,8 @@ public abstract class Tensor {
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.ByteBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#newInt8Tensor(long[], ByteBuffer)}, {@link
|
||||
* Tensor#newUInt8Tensor(long[], ByteBuffer)}.
|
||||
* capacity that can be used in {@link Tensor#fromBlob(ByteBuffer, long[])},
|
||||
* {@link Tensor#fromBlobUnsigned(ByteBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
@ -58,6 +57,12 @@ public abstract class Tensor {
|
||||
return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder());
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.IntBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#fromBlob(IntBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
public static IntBuffer allocateIntBuffer(int numElements) {
|
||||
return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
@ -66,7 +71,7 @@ public abstract class Tensor {
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.FloatBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#newFloat32Tensor(long[], FloatBuffer)}.
|
||||
* capacity that can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
@ -78,7 +83,7 @@ public abstract class Tensor {
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.LongBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#newInt64Tensor(long[], LongBuffer)}.
|
||||
* capacity that can be used in {@link Tensor#fromBlob(LongBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
@ -90,7 +95,7 @@ public abstract class Tensor {
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.DoubleBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#newFloat64Tensor(long[], DoubleBuffer)}.
|
||||
* capacity that can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
@ -104,10 +109,10 @@ public abstract class Tensor {
|
||||
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of
|
||||
* bytes.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newUInt8Tensor(long[] shape, byte[] data) {
|
||||
public static Tensor fromBlobUnsigned(byte[] data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -121,10 +126,10 @@ public abstract class Tensor {
|
||||
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of
|
||||
* bytes.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt8Tensor(long[] shape, byte[] data) {
|
||||
public static Tensor fromBlob(byte[] data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -138,10 +143,10 @@ public abstract class Tensor {
|
||||
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of
|
||||
* ints.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt32Tensor(long[] shape, int[] data) {
|
||||
public static Tensor fromBlob(int[] data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -155,10 +160,10 @@ public abstract class Tensor {
|
||||
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array
|
||||
* of floats.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newFloat32Tensor(long[] shape, float[] data) {
|
||||
public static Tensor fromBlob(float[] data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -172,10 +177,10 @@ public abstract class Tensor {
|
||||
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of
|
||||
* longs.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt64Tensor(long[] shape, long[] data) {
|
||||
public static Tensor fromBlob(long[] data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -192,7 +197,7 @@ public abstract class Tensor {
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
*/
|
||||
public static Tensor newFloat64Tensor(long[] shape, double[] data) {
|
||||
public static Tensor fromBlob(long[] shape, double[] data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -205,12 +210,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newUInt8Tensor(long[] shape, ByteBuffer data) {
|
||||
public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -225,12 +230,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt8Tensor(long[] shape, ByteBuffer data) {
|
||||
public static Tensor fromBlob(ByteBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -245,12 +250,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt32Tensor(long[] shape, IntBuffer data) {
|
||||
public static Tensor fromBlob(IntBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -265,12 +270,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newFloat32Tensor(long[] shape, FloatBuffer data) {
|
||||
public static Tensor fromBlob(FloatBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -285,12 +290,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt64Tensor(long[] shape, LongBuffer data) {
|
||||
public static Tensor fromBlob(LongBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -305,12 +310,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.float64 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newFloat64Tensor(long[] shape, DoubleBuffer data) {
|
||||
public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -327,12 +332,14 @@ public abstract class Tensor {
|
||||
this.shape = Arrays.copyOf(shape, shape.length);
|
||||
}
|
||||
|
||||
/** Calculates number of elements in current tensor instance. */
|
||||
/** Returns the number of elements in this tensor. */
|
||||
public long numel() {
|
||||
return numel(this.shape);
|
||||
}
|
||||
|
||||
/** Calculates number of elements in tensor with specified shape. */
|
||||
/**
|
||||
* Calculates the number of elements in a tensor with the specified shape.
|
||||
*/
|
||||
public static long numel(long[] shape) {
|
||||
checkShape(shape);
|
||||
int result = 1;
|
||||
@ -343,14 +350,24 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns dtype of current tensor. Can be one of {@link Tensor#DTYPE_UINT8}, {@link
|
||||
* Tensor#DTYPE_INT8}, {@link Tensor#DTYPE_INT32},{@link Tensor#DTYPE_FLOAT32}, {@link
|
||||
* Tensor#DTYPE_INT64}, {@link Tensor#DTYPE_FLOAT64}.
|
||||
* Returns the shape of this tensor. (The array is a fresh copy.)
|
||||
*/
|
||||
public abstract int dtype();
|
||||
public long[] shape() {
|
||||
return Arrays.copyOf(shape, shape.length);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns newly allocated java byte array that contains a copy of tensor data.
|
||||
* @return data type of this tensor.
|
||||
*/
|
||||
public abstract DType dtype();
|
||||
|
||||
// Called from native
|
||||
int dtypeJniCode() {
|
||||
return dtype().jniCode;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a Java byte array that contains the tensor data. This may be a copy or reference.
|
||||
*
|
||||
* @throws IllegalStateException if it is called for a non-int8 tensor.
|
||||
*/
|
||||
@ -360,7 +377,7 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns newly allocated java byte array that contains a copy of tensor data.
|
||||
* @return a Java byte array that contains the tensor data. This may be a copy or reference.
|
||||
*
|
||||
* @throws IllegalStateException if it is called for a non-uint8 tensor.
|
||||
*/
|
||||
@ -370,7 +387,7 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns newly allocated java byte array that contains a copy of tensor data.
|
||||
* @return a Java int array that contains the tensor data. This may be a copy or reference.
|
||||
*
|
||||
* @throws IllegalStateException if it is called for a non-int32 tensor.
|
||||
*/
|
||||
@ -380,7 +397,7 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns newly allocated java byte array that contains a copy of tensor data.
|
||||
* @return a Java float array that contains the tensor data. This may be a copy or reference.
|
||||
*
|
||||
* @throws IllegalStateException if it is called for a non-float32 tensor.
|
||||
*/
|
||||
@ -390,7 +407,7 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns newly allocated java byte array that contains a copy of tensor data.
|
||||
* @return a Java long array that contains the tensor data. This may be a copy or reference.
|
||||
*
|
||||
* @throws IllegalStateException if it is called for a non-int64 tensor.
|
||||
*/
|
||||
@ -400,7 +417,7 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns newly allocated java byte array that contains a copy of tensor data.
|
||||
* @return a Java double array that contains the tensor data. This may be a copy or reference.
|
||||
*
|
||||
* @throws IllegalStateException if it is called for a non-float64 tensor.
|
||||
*/
|
||||
@ -423,8 +440,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_UINT8;
|
||||
public DType dtype() {
|
||||
return DType.UINT8;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -455,8 +472,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_INT8;
|
||||
public DType dtype() {
|
||||
return DType.INT8;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -487,8 +504,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_INT32;
|
||||
public DType dtype() {
|
||||
return DType.INT32;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -527,8 +544,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_FLOAT32;
|
||||
public DType dtype() {
|
||||
return DType.FLOAT32;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -551,8 +568,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_INT64;
|
||||
public DType dtype() {
|
||||
return DType.INT64;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -583,8 +600,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_FLOAT64;
|
||||
public DType dtype() {
|
||||
return DType.FLOAT64;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -634,17 +651,17 @@ public abstract class Tensor {
|
||||
|
||||
// Called from native
|
||||
private static Tensor nativeNewTensor(ByteBuffer data, long[] shape, int dtype) {
|
||||
if (DTYPE_FLOAT32 == dtype) {
|
||||
if (DType.FLOAT32.jniCode == dtype) {
|
||||
return new Tensor_float32(data.asFloatBuffer(), shape);
|
||||
} else if (DTYPE_INT32 == dtype) {
|
||||
} else if (DType.INT32.jniCode == dtype) {
|
||||
return new Tensor_int32(data.asIntBuffer(), shape);
|
||||
} else if (DTYPE_INT64 == dtype) {
|
||||
} else if (DType.INT64.jniCode == dtype) {
|
||||
return new Tensor_int64(data.asLongBuffer(), shape);
|
||||
} else if (DTYPE_FLOAT64 == dtype) {
|
||||
} else if (DType.FLOAT64.jniCode == dtype) {
|
||||
return new Tensor_float64(data.asDoubleBuffer(), shape);
|
||||
} else if (DTYPE_UINT8 == dtype) {
|
||||
} else if (DType.UINT8.jniCode == dtype) {
|
||||
return new Tensor_uint8(data, shape);
|
||||
} else if (DTYPE_INT8 == dtype) {
|
||||
} else if (DType.INT8.jniCode == dtype) {
|
||||
return new Tensor_int8(data, shape);
|
||||
}
|
||||
throw new IllegalArgumentException("Unknown Tensor dtype");
|
||||
|
@ -7,6 +7,7 @@ import android.media.Image;
|
||||
import org.pytorch.Tensor;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.Locale;
|
||||
|
||||
/**
|
||||
@ -26,7 +27,7 @@ public final class TensorImageUtils {
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
*/
|
||||
public static Tensor bitmapToFloat32Tensor(
|
||||
final Bitmap bitmap, float[] normMeanRGB, float normStdRGB[]) {
|
||||
final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[]) {
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
|
||||
@ -34,6 +35,52 @@ public final class TensorImageUtils {
|
||||
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB);
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes tensor content from specified {@link android.graphics.Bitmap},
|
||||
* normalized with specified in parameters mean and std to specified {@link java.nio.FloatBuffer}
|
||||
* with specified offset.
|
||||
*
|
||||
* @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
|
||||
* @param x - x coordinate of top left corner of bitmap's area
|
||||
* @param y - y coordinate of top left corner of bitmap's area
|
||||
* @param width - width of bitmap's area
|
||||
* @param height - height of bitmap's area
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
*/
|
||||
public static void bitmapToFloatBuffer(
|
||||
final Bitmap bitmap,
|
||||
final int x,
|
||||
final int y,
|
||||
final int width,
|
||||
final int height,
|
||||
final float[] normMeanRGB,
|
||||
final float[] normStdRGB,
|
||||
final FloatBuffer outBuffer,
|
||||
final int outBufferOffset) {
|
||||
checkOutBufferCapacity(outBuffer, outBufferOffset, width, height);
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
|
||||
final int pixelsCount = height * width;
|
||||
final int[] pixels = new int[pixelsCount];
|
||||
bitmap.getPixels(pixels, 0, width, x, y, width, height);
|
||||
final int offset_g = pixelsCount;
|
||||
final int offset_b = 2 * pixelsCount;
|
||||
for (int i = 0; i < pixelsCount; i++) {
|
||||
final int c = pixels[i];
|
||||
float r = ((c >> 16) & 0xff) / 255.0f;
|
||||
float g = ((c >> 8) & 0xff) / 255.0f;
|
||||
float b = ((c) & 0xff) / 255.0f;
|
||||
float rF = (r - normMeanRGB[0]) / normStdRGB[0];
|
||||
float gF = (g - normMeanRGB[1]) / normStdRGB[1];
|
||||
float bF = (b - normMeanRGB[2]) / normStdRGB[2];
|
||||
outBuffer.put(outBufferOffset + i, rF);
|
||||
outBuffer.put(outBufferOffset + offset_g + i, gF);
|
||||
outBuffer.put(outBufferOffset + offset_b + i, bF);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates new {@link org.pytorch.Tensor} from specified area of {@link android.graphics.Bitmap},
|
||||
* normalized with specified in parameters mean and std.
|
||||
@ -57,22 +104,9 @@ public final class TensorImageUtils {
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
|
||||
final int pixelsCount = height * width;
|
||||
final int[] pixels = new int[pixelsCount];
|
||||
bitmap.getPixels(pixels, 0, width, x, y, width, height);
|
||||
final float[] floatArray = new float[3 * pixelsCount];
|
||||
final int offset_g = pixelsCount;
|
||||
final int offset_b = 2 * pixelsCount;
|
||||
for (int i = 0; i < pixelsCount; i++) {
|
||||
final int c = pixels[i];
|
||||
float r = ((c >> 16) & 0xff) / 255.0f;
|
||||
float g = ((c >> 8) & 0xff) / 255.0f;
|
||||
float b = ((c) & 0xff) / 255.0f;
|
||||
floatArray[i] = (r - normMeanRGB[0]) / normStdRGB[0];
|
||||
floatArray[offset_g + i] = (g - normMeanRGB[1]) / normStdRGB[1];
|
||||
floatArray[offset_b + i] = (b - normMeanRGB[2]) / normStdRGB[2];
|
||||
}
|
||||
return Tensor.newFloat32Tensor(new long[]{1, 3, height, width}, floatArray);
|
||||
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height);
|
||||
bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0);
|
||||
return Tensor.fromBlob(floatBuffer, new long[]{1, 3, height, width});
|
||||
}
|
||||
|
||||
/**
|
||||
@ -105,6 +139,52 @@ public final class TensorImageUtils {
|
||||
checkRotateCWDegrees(rotateCWDegrees);
|
||||
checkTensorSize(tensorWidth, tensorHeight);
|
||||
|
||||
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * tensorWidth * tensorHeight);
|
||||
imageYUV420CenterCropToFloatBuffer(
|
||||
image,
|
||||
rotateCWDegrees,
|
||||
tensorWidth,
|
||||
tensorHeight,
|
||||
normMeanRGB, normStdRGB, floatBuffer, 0);
|
||||
return Tensor.fromBlob(floatBuffer, new long[]{1, 3, tensorHeight, tensorWidth});
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes tensor content from specified {@link android.media.Image}, doing optional rotation,
|
||||
* scaling (nearest) and center cropping to specified {@link java.nio.FloatBuffer} with specified offset.
|
||||
*
|
||||
* @param image {@link android.media.Image} as a source for Tensor data
|
||||
* @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be
|
||||
* upright. Range of valid values: 0, 90, 180, 270
|
||||
* @param tensorWidth return tensor width, must be positive
|
||||
* @param tensorHeight return tensor height, must be positive
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param outBuffer Output buffer, where tensor content will be written
|
||||
* @param outBufferOffset Output buffer offset with which tensor content will be written
|
||||
*/
|
||||
public static void imageYUV420CenterCropToFloatBuffer(
|
||||
final Image image,
|
||||
int rotateCWDegrees,
|
||||
final int tensorWidth,
|
||||
final int tensorHeight,
|
||||
float[] normMeanRGB,
|
||||
float[] normStdRGB,
|
||||
final FloatBuffer outBuffer,
|
||||
final int outBufferOffset) {
|
||||
checkOutBufferCapacity(outBuffer, outBufferOffset, tensorWidth, tensorHeight);
|
||||
|
||||
if (image.getFormat() != ImageFormat.YUV_420_888) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat()));
|
||||
}
|
||||
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
checkRotateCWDegrees(rotateCWDegrees);
|
||||
checkTensorSize(tensorWidth, tensorHeight);
|
||||
|
||||
final int widthBeforeRotation = image.getWidth();
|
||||
final int heightBeforeRotation = image.getHeight();
|
||||
|
||||
@ -158,7 +238,6 @@ public final class TensorImageUtils {
|
||||
final int channelSize = tensorHeight * tensorWidth;
|
||||
final int tensorInputOffsetG = channelSize;
|
||||
final int tensorInputOffsetB = 2 * channelSize;
|
||||
final float[] floatArray = new float[3 * channelSize];
|
||||
for (int x = 0; x < tensorWidth; x++) {
|
||||
for (int y = 0; y < tensorHeight; y++) {
|
||||
|
||||
@ -198,13 +277,22 @@ public final class TensorImageUtils {
|
||||
int r = clamp((a0 + a1) >> 10, 0, 255);
|
||||
int g = clamp((a0 - a2 - a3) >> 10, 0, 255);
|
||||
int b = clamp((a0 + a4) >> 10, 0, 255);
|
||||
final int offset = y * tensorWidth + x;
|
||||
floatArray[offset] = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0];
|
||||
floatArray[tensorInputOffsetG + offset] = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1];
|
||||
floatArray[tensorInputOffsetB + offset] = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2];
|
||||
final int offset = outBufferOffset + y * tensorWidth + x;
|
||||
float rF = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0];
|
||||
float gF = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1];
|
||||
float bF = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2];
|
||||
|
||||
outBuffer.put(offset, rF);
|
||||
outBuffer.put(offset + tensorInputOffsetG, gF);
|
||||
outBuffer.put(offset + tensorInputOffsetB, bF);
|
||||
}
|
||||
}
|
||||
return Tensor.newFloat32Tensor(new long[]{1, 3, tensorHeight, tensorWidth}, floatArray);
|
||||
}
|
||||
|
||||
private static void checkOutBufferCapacity(FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) {
|
||||
if (outBufferOffset + 3 * tensorWidth * tensorHeight > outBuffer.capacity()) {
|
||||
throw new IllegalStateException("Buffer underflow");
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkTensorSize(int tensorWidth, int tensorHeight) {
|
||||
|
@ -243,7 +243,20 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
|
||||
set(BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
|
||||
set(BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
|
||||
set(BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
|
||||
set(OLD_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE})
|
||||
if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND
|
||||
CMAKE_C_COMPILER_VERSION VERSION_GREATER 6.9 AND CMAKE_C_COMPILER_VERSION VERSION_LESS 8)
|
||||
set(GCC_7 True)
|
||||
else()
|
||||
set(GCC_7 False)
|
||||
endif()
|
||||
if(GCC_7)
|
||||
set(CMAKE_BUILD_TYPE Release) # Always build Sleef as a Release build to work around a gcc-7 bug
|
||||
endif()
|
||||
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/sleef" ${CMAKE_BINARY_DIR}/sleef)
|
||||
if(GCC_7)
|
||||
set(CMAKE_BUILD_TYPE ${OLD_CMAKE_BUILD_TYPE})
|
||||
endif()
|
||||
set_property(TARGET sleef PROPERTY FOLDER "dependencies")
|
||||
list(APPEND ATen_THIRD_PARTY_INCLUDE ${CMAKE_BINARY_DIR}/include)
|
||||
link_directories(${CMAKE_BINARY_DIR}/sleef/lib)
|
||||
|
@ -77,10 +77,8 @@ TaskThreadPoolBase& _get_intraop_pool() {
|
||||
|
||||
#endif // C10_MOBILE
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Run lambda function `fn` over `task_id` in [0, `range`) with threadpool.
|
||||
// `fn` will be called with params: (thread_pool_task_id, task_id).
|
||||
void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
|
||||
#ifndef C10_MOBILE
|
||||
for (size_t i = 1; i < range; ++i) {
|
||||
@ -101,14 +99,63 @@ void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
|
||||
#endif // C10_MOBILE
|
||||
}
|
||||
|
||||
ParallelRegionGuard::ParallelRegionGuard(int64_t task_id) {
|
||||
_set_thread_num(task_id);
|
||||
_set_in_parallel_region(true);
|
||||
}
|
||||
// RAII guard helps to support in_parallel_region() and get_thread_num() API.
|
||||
struct ParallelRegionGuard {
|
||||
ParallelRegionGuard(int64_t task_id) {
|
||||
_set_thread_num(task_id);
|
||||
_set_in_parallel_region(true);
|
||||
}
|
||||
|
||||
ParallelRegionGuard::~ParallelRegionGuard() {
|
||||
_set_in_parallel_region(false);
|
||||
_unset_thread_num();
|
||||
~ParallelRegionGuard() {
|
||||
_set_in_parallel_region(false);
|
||||
_unset_thread_num();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace internal {
|
||||
|
||||
void _parallel_run(
|
||||
const int64_t begin,
|
||||
const int64_t end,
|
||||
const int64_t grain_size,
|
||||
const std::function<void(int64_t, int64_t, size_t)>& f) {
|
||||
size_t num_tasks, chunk_size;
|
||||
std::tie(num_tasks, chunk_size) =
|
||||
internal::calc_num_tasks_and_chunk_size(begin, end, grain_size);
|
||||
|
||||
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
|
||||
std::exception_ptr eptr;
|
||||
std::vector<std::shared_ptr<c10::ivalue::Future>> futures(num_tasks);
|
||||
for (size_t task_id = 0; task_id < num_tasks; ++task_id) {
|
||||
futures[task_id] = std::make_shared<c10::ivalue::Future>(c10::NoneType::get());
|
||||
}
|
||||
auto task = [f, &eptr, &err_flag, &futures, begin, end, chunk_size]
|
||||
(int /* unused */, size_t task_id) {
|
||||
int64_t local_start = begin + task_id * chunk_size;
|
||||
if (local_start < end) {
|
||||
int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
|
||||
try {
|
||||
ParallelRegionGuard guard(task_id);
|
||||
f(local_start, local_end, task_id);
|
||||
} catch (...) {
|
||||
if (!err_flag.test_and_set()) {
|
||||
eptr = std::current_exception();
|
||||
}
|
||||
}
|
||||
}
|
||||
futures[task_id]->markCompleted();
|
||||
};
|
||||
_run_with_pool(task, num_tasks);
|
||||
|
||||
// Wait for all tasks to finish.
|
||||
for (size_t task_id = 0; task_id < num_tasks; ++task_id) {
|
||||
futures[task_id]->wait();
|
||||
}
|
||||
if (eptr) {
|
||||
std::rethrow_exception(eptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
@ -9,16 +9,6 @@
|
||||
namespace at {
|
||||
namespace internal {
|
||||
|
||||
// Run lambda function `fn` over `task_id` in [0, `range`) with threadpool.
|
||||
// `fn` will be called with params: (thread_pool_task_id, task_id).
|
||||
CAFFE2_API void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range);
|
||||
|
||||
// RAII guard helps to support in_parallel_region() and get_thread_num() API.
|
||||
struct CAFFE2_API ParallelRegionGuard {
|
||||
ParallelRegionGuard(int64_t task_id);
|
||||
~ParallelRegionGuard();
|
||||
};
|
||||
|
||||
inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
|
||||
int64_t begin, int64_t end, int64_t grain_size) {
|
||||
if ((end - begin) < grain_size) {
|
||||
@ -32,6 +22,12 @@ inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
|
||||
return std::make_tuple(num_tasks, chunk_size);
|
||||
}
|
||||
|
||||
CAFFE2_API void _parallel_run(
|
||||
const int64_t begin,
|
||||
const int64_t end,
|
||||
const int64_t grain_size,
|
||||
const std::function<void(int64_t, int64_t, size_t)>& f);
|
||||
|
||||
} // namespace internal
|
||||
|
||||
template <class F>
|
||||
@ -48,40 +44,14 @@ inline void parallel_for(
|
||||
f(begin, end);
|
||||
return;
|
||||
}
|
||||
size_t num_tasks, chunk_size;
|
||||
std::tie(num_tasks, chunk_size) =
|
||||
internal::calc_num_tasks_and_chunk_size(begin, end, grain_size);
|
||||
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
|
||||
std::exception_ptr eptr;
|
||||
std::vector<std::shared_ptr<c10::ivalue::Future>> futures(num_tasks);
|
||||
for (size_t task_id = 0; task_id < num_tasks; ++task_id) {
|
||||
futures[task_id] = std::make_shared<c10::ivalue::Future>(c10::NoneType::get());
|
||||
}
|
||||
auto task = [f, &eptr, &err_flag, &futures, begin, end, chunk_size]
|
||||
(int _, size_t task_id) {
|
||||
int64_t local_start = begin + task_id * chunk_size;
|
||||
if (local_start < end) {
|
||||
int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
|
||||
try {
|
||||
internal::ParallelRegionGuard guard(task_id);
|
||||
f(local_start, local_end);
|
||||
} catch (...) {
|
||||
if (!err_flag.test_and_set()) {
|
||||
eptr = std::current_exception();
|
||||
}
|
||||
internal::_parallel_run(
|
||||
begin,
|
||||
end,
|
||||
grain_size,
|
||||
[f](int64_t start, int64_t end, size_t /* unused */) {
|
||||
f(start, end);
|
||||
}
|
||||
}
|
||||
futures[task_id]->markCompleted();
|
||||
};
|
||||
internal::_run_with_pool(task, num_tasks);
|
||||
|
||||
// Wait for all tasks to finish.
|
||||
for (size_t task_id = 0; task_id < num_tasks; ++task_id) {
|
||||
futures[task_id]->wait();
|
||||
}
|
||||
if (eptr) {
|
||||
std::rethrow_exception(eptr);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
template <class scalar_t, class F, class SF>
|
||||
@ -104,39 +74,14 @@ inline scalar_t parallel_reduce(
|
||||
internal::calc_num_tasks_and_chunk_size(begin, end, grain_size);
|
||||
std::vector<scalar_t> results(num_tasks);
|
||||
scalar_t* results_data = results.data();
|
||||
|
||||
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
|
||||
std::exception_ptr eptr;
|
||||
std::vector<std::shared_ptr<c10::ivalue::Future>> futures(num_tasks);
|
||||
for (size_t task_id = 0; task_id < num_tasks; ++task_id) {
|
||||
futures[task_id] = std::make_shared<c10::ivalue::Future>(c10::NoneType::get());
|
||||
}
|
||||
auto task = [f, ident, results_data, &eptr, &err_flag, &futures, begin, end, chunk_size]
|
||||
(int _, size_t task_id) {
|
||||
int64_t local_start = begin + task_id * chunk_size;
|
||||
if (local_start < end) {
|
||||
int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
|
||||
try {
|
||||
internal::ParallelRegionGuard guard(task_id);
|
||||
results_data[task_id] = f(local_start, local_end, ident);
|
||||
} catch (...) {
|
||||
if (!err_flag.test_and_set()) {
|
||||
eptr = std::current_exception();
|
||||
}
|
||||
internal::_parallel_run(
|
||||
begin,
|
||||
end,
|
||||
grain_size,
|
||||
[f, ident, results_data](int64_t start, int64_t end, size_t task_id) {
|
||||
results_data[task_id] = f(start, end, ident);
|
||||
}
|
||||
}
|
||||
futures[task_id]->markCompleted();
|
||||
};
|
||||
internal::_run_with_pool(task, num_tasks);
|
||||
|
||||
// Wait for all tasks to finish.
|
||||
for (size_t task_id = 0; task_id < num_tasks; ++task_id) {
|
||||
futures[task_id]->wait();
|
||||
}
|
||||
if (eptr) {
|
||||
std::rethrow_exception(eptr);
|
||||
}
|
||||
|
||||
);
|
||||
scalar_t result = ident;
|
||||
for (auto partial_result : results) {
|
||||
result = sf(result, partial_result);
|
||||
|
@ -1218,6 +1218,8 @@ bool aten_op_is_not_moved_to_c10_yet(const c10::OperatorName& opName) {
|
||||
{"aten::quantize_per_channel", ""},
|
||||
{"aten::q_per_channel_axis", ""},
|
||||
{"aten::qscheme", ""},
|
||||
{"aten::fake_quantize_per_channel_affine", ""},
|
||||
{"aten::fake_quantize_per_channel_affine_backward", ""},
|
||||
{"aten::to", "dtype_layout"},
|
||||
{"aten::to", "device"},
|
||||
{"aten::to", "dtype"},
|
||||
|
@ -77,6 +77,7 @@ namespace c10 {
|
||||
_(prim, requires_grad) \
|
||||
_(prim, AutogradAdd) \
|
||||
_(prim, GradOf) \
|
||||
_(aten, grad) \
|
||||
_(aten, backward) \
|
||||
_(prim, Guard) \
|
||||
_(prim, BailOut) \
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
|
||||
#include <ATen/cudnn/cudnn-wrapper.h>
|
||||
#include <ATen/cudnn/Utils.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/ATenCUDAGeneral.h>
|
||||
@ -190,6 +191,7 @@ struct AT_CUDA_API DropoutDescriptor
|
||||
AT_ASSERT(options.device().type() == kCUDA);
|
||||
AT_ASSERT(options.dtype() == kByte);
|
||||
state = at::empty({static_cast<int64_t>(state_size)}, options);
|
||||
setCuDNNStreamToCurrent();
|
||||
AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
|
||||
}
|
||||
|
||||
@ -200,6 +202,7 @@ struct AT_CUDA_API DropoutDescriptor
|
||||
void *state_ptr = state.data_ptr();
|
||||
size_t state_size = state.size(0);
|
||||
// NB: The seed doesn't actually matter, so we give a dummy value
|
||||
setCuDNNStreamToCurrent();
|
||||
AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
|
||||
}
|
||||
|
||||
|
@ -31,7 +31,7 @@ Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const T
|
||||
auto denom = (mag_square1 * mag_square2).sqrt_();
|
||||
auto cos = prod_sum / denom;
|
||||
|
||||
auto zeros = at::zeros_like(target);
|
||||
auto zeros = at::zeros_like(cos);
|
||||
auto pos = 1 - cos;
|
||||
auto neg = (cos - margin).clamp_min_(0);
|
||||
auto output_pos = at::where(target == 1, pos, zeros);
|
||||
@ -67,8 +67,8 @@ Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, const Ten
|
||||
}
|
||||
|
||||
Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction) {
|
||||
auto zeros = at::zeros_like(target);
|
||||
auto output_pos = target * (at::log(target) - input);
|
||||
auto zeros = at::zeros_like(output_pos);
|
||||
auto output = at::where(target > 0, output_pos, zeros);
|
||||
return apply_loss_reduction(output, reduction);
|
||||
}
|
||||
|
@ -55,53 +55,67 @@ bool _nnpack_available() {
|
||||
|
||||
#include "nnpack.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <ATen/Parallel.h>
|
||||
#include <thread>
|
||||
#include "caffe2/utils/threadpool/ThreadPoolMobile.h"
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
// Stolen from Caffe2
|
||||
static pthreadpool_t nnpack_threadpool_ = nullptr;
|
||||
static bool called_nnpack_threadpool_ = false;
|
||||
static bool init_nnpack() {
|
||||
static std::once_flag once_;
|
||||
static bool nnpack_successfully_initialized_ = false;
|
||||
|
||||
std::call_once(once_, []() {
|
||||
const nnp_status nnpack_status = nnp_initialize();
|
||||
nnpack_successfully_initialized_ = (nnp_status_success == nnpack_status);
|
||||
|
||||
pthreadpool_t nnpack_threadpool() {
|
||||
if (! called_nnpack_threadpool_) {
|
||||
called_nnpack_threadpool_ = true;
|
||||
enum nnp_status nnpack_status = nnp_initialize();
|
||||
if (nnpack_status != nnp_status_success) {
|
||||
if (nnpack_status == nnp_status_out_of_memory) {
|
||||
throw std::runtime_error("could not initialize NNPack (out of memory)");
|
||||
LOG(WARNING) << "Could not initialize NNPACK! Reason: Out of memory.";
|
||||
} else if (nnpack_status == nnp_status_unsupported_hardware) {
|
||||
throw std::runtime_error("could not initialize NNPack (unsupported hardware)");
|
||||
LOG(WARNING) << "Could not initialize NNPACK! Reason: Unsupported hardware.";
|
||||
} else {
|
||||
throw std::runtime_error("could not initialize NNPack (unknown error)");
|
||||
LOG(WARNING) << "Could not initialize NNPACK! Reason: Unknown error!";
|
||||
}
|
||||
}
|
||||
unsigned int threads;
|
||||
#ifdef INTRA_OP_PARALLEL
|
||||
threads = at::get_num_threads();
|
||||
});
|
||||
|
||||
return nnpack_successfully_initialized_;
|
||||
}
|
||||
|
||||
static pthreadpool_t nnpack_threadpool() {
|
||||
// Try initializing a threadpool for NNPACK's use. If we fail to
|
||||
// successfully initialize an implementation, return nullptr which will
|
||||
// instruct NNPACK to run single threaded.
|
||||
|
||||
#ifdef C10_MOBILE
|
||||
// If building for mobile, use Caffe 2's mobile-friendly threadpool.
|
||||
return caffe2::mobile_pthreadpool();
|
||||
#else
|
||||
threads = std::thread::hardware_concurrency();
|
||||
// Otherwise, try using pthreadpool if we manage to initialize it successfully.
|
||||
static pthreadpool_t nnpack_threadpool_ = nullptr;
|
||||
static bool called_nnpack_threadpool_ = false;
|
||||
|
||||
if (!called_nnpack_threadpool_) {
|
||||
called_nnpack_threadpool_ = true;
|
||||
|
||||
#ifdef INTRA_OP_PARALLEL
|
||||
const uint32_t threads = at::get_num_threads();
|
||||
#else
|
||||
const uint32_t threads = std::thread::hardware_concurrency();
|
||||
#endif
|
||||
|
||||
nnpack_threadpool_ = pthreadpool_create(threads);
|
||||
if (nnpack_threadpool_ == nullptr) {
|
||||
throw std::runtime_error("could not initialize NNPack's pthreadpool");
|
||||
if ( !nnpack_threadpool_ ) {
|
||||
LOG(WARNING) << "Failed to initialize pthreadpool! Running NNPACK in single-threaded mode.";
|
||||
}
|
||||
}
|
||||
|
||||
return nnpack_threadpool_;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool _nnpack_available() {
|
||||
if (! called_nnpack_threadpool_) {
|
||||
try {
|
||||
return nnpack_threadpool() != nullptr;
|
||||
} catch (std::runtime_error e) {
|
||||
}
|
||||
}
|
||||
return nnpack_threadpool() != nullptr;
|
||||
return init_nnpack();
|
||||
}
|
||||
|
||||
// Make thread_local for safety in cases where we have multiple threads running
|
||||
|
@ -152,7 +152,7 @@ Tensor align_to(const Tensor& tensor, DimnameList names) {
|
||||
const auto& dim = tensor_names[idx];
|
||||
TORCH_CHECK(dim.isBasic(),
|
||||
"align_to: All input dims must be named. Found unnamed dim at index ",
|
||||
dim, " of Tensor", tensor_names);
|
||||
idx, " of Tensor", tensor_names);
|
||||
auto it = std::find(names.begin(), names.end(), dim);
|
||||
TORCH_CHECK(it != names.end(),
|
||||
"align_to: Cannot find dim ", dim, " from Tensor", names,
|
||||
|
@ -122,6 +122,7 @@ std::vector<Tensor> where(const Tensor& condition) {
|
||||
}
|
||||
|
||||
Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
|
||||
TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
|
||||
Tensor ret = at::empty(self.sizes(), self.options());
|
||||
AT_DISPATCH_ALL_TYPES(ret.scalar_type(), "where_cpu", [&] {
|
||||
where_cpu<scalar_t>(ret, condition, self, other);
|
||||
|
@ -102,11 +102,13 @@ std::tuple<Device, ScalarType> TensorIterator::compute_common_type() {
|
||||
return compute_common_type_(operands_);
|
||||
}
|
||||
|
||||
static void validate_dtype(OperandInfo& op, ScalarType common_dtype, int ninputs) {
|
||||
static void validate_dtype(OperandInfo& op, ScalarType common_dtype, CommonDTypeStrategy strategy) {
|
||||
bool check_output = strategy == CommonDTypeStrategy::CHECK || strategy == CommonDTypeStrategy::PROMOTE;
|
||||
bool promoting = strategy == CommonDTypeStrategy::PROMOTE || (strategy == CommonDTypeStrategy::PROMOTE_INPUTS && !op.is_output);
|
||||
if (op.tensor.defined()) {
|
||||
// For binary_ops, we follow casting rules. For unary/nullary types
|
||||
// we require the type to match.
|
||||
if (op.is_output) {
|
||||
if (op.is_output && check_output) {
|
||||
if (!canCast(common_dtype, op.tensor.scalar_type()))
|
||||
{
|
||||
AT_ERROR("result type ", common_dtype,
|
||||
@ -114,7 +116,7 @@ static void validate_dtype(OperandInfo& op, ScalarType common_dtype, int ninputs
|
||||
op.tensor.scalar_type());
|
||||
}
|
||||
}
|
||||
if (ninputs < 2 && op.dtype != op.tensor.scalar_type()) {
|
||||
if (!promoting && op.dtype != op.tensor.scalar_type()) {
|
||||
AT_ERROR("expected dtype ", op.dtype, " but got dtype ", op.tensor.scalar_type());
|
||||
}
|
||||
}
|
||||
@ -131,14 +133,6 @@ static void maybe_promote_common_dtype(OperandInfo& op, ScalarType common_dtype)
|
||||
op.tensor =
|
||||
at::empty_like(op.tensor, op.tensor.options().dtype(common_dtype));
|
||||
}
|
||||
auto original_element_size = op.original_tensor.element_size();
|
||||
auto new_element_size = op.tensor.element_size();
|
||||
|
||||
// stride size (in bytes) can change if we change the dtype.
|
||||
for( size_t i=0; i < op.stride_bytes.size(); i++ ) {
|
||||
auto stride = op.stride_bytes[i] / original_element_size;
|
||||
op.stride_bytes[i] = stride * new_element_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -155,12 +149,12 @@ void TensorIterator::compute_types() {
|
||||
}
|
||||
}
|
||||
|
||||
if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS) {
|
||||
if (common_dtype_strategy_ == CommonDTypeStrategy::PROMOTE_INPUTS) {
|
||||
TORCH_CHECK(!missing_output_dtypes, "unable to compute and promote common dtype based only on inputs if there are missing dtypes for outputs");
|
||||
}
|
||||
|
||||
bool compute_common_dtype = (compute_common_dtype_strategy_ != CommonDTypeStrategy::COMPUTE_NONE);
|
||||
bool compute_common_dtype_only_for_inputs = (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS);
|
||||
bool compute_common_dtype = (common_dtype_strategy_ != CommonDTypeStrategy::NONE);
|
||||
bool compute_common_dtype_only_for_inputs = (common_dtype_strategy_ == CommonDTypeStrategy::PROMOTE_INPUTS);
|
||||
|
||||
if (missing_dtypes || compute_common_dtype) {
|
||||
auto operands = compute_common_dtype_only_for_inputs ? at::ArrayRef<OperandInfo>(operands_).slice(noutputs()) : operands_;
|
||||
@ -198,24 +192,25 @@ void TensorIterator::compute_types() {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!compute_common_dtype_only_for_inputs) {
|
||||
validate_dtype(op, common_dtype, ninputs());
|
||||
}
|
||||
if (!compute_common_dtype_only_for_inputs || !op.is_output) {
|
||||
maybe_promote_common_dtype(op, common_dtype);
|
||||
}
|
||||
|
||||
if (op.tensor.defined() && op.device != op.tensor.device()) {
|
||||
if (op.is_output) {
|
||||
AT_ERROR("output with device ", op.tensor.device(),
|
||||
" doesn't match the desired device ", op.device);
|
||||
} else if (op.tensor.dim() == 0) {
|
||||
op.tensor = op.tensor.to(op.options());
|
||||
} else {
|
||||
AT_ERROR("expected device ", op.device,
|
||||
" but got device ", op.tensor.device());
|
||||
}
|
||||
for (auto &op : operands_) {
|
||||
validate_dtype(op, common_dtype, common_dtype_strategy_);
|
||||
if (compute_common_dtype && (!compute_common_dtype_only_for_inputs || !op.is_output)) {
|
||||
maybe_promote_common_dtype(op, common_dtype);
|
||||
}
|
||||
|
||||
if (op.tensor.defined() && op.device != op.tensor.device()) {
|
||||
if (op.is_output) {
|
||||
AT_ERROR("output with device ", op.tensor.device(),
|
||||
" doesn't match the desired device ", op.device);
|
||||
} else if (op.tensor.dim() == 0) {
|
||||
op.tensor = op.tensor.to(op.options());
|
||||
} else {
|
||||
AT_ERROR("expected device ", op.device,
|
||||
" but got device ", op.tensor.device());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -602,6 +597,7 @@ TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a,
|
||||
iter.add_input(a);
|
||||
iter.add_input(b);
|
||||
iter.allow_cpu_scalars_ = true;
|
||||
iter.promote_common_dtype();
|
||||
iter.build();
|
||||
return iter;
|
||||
}
|
||||
@ -827,12 +823,12 @@ void TensorIterator::build() {
|
||||
#endif
|
||||
// compute the broadcasted shape
|
||||
compute_shape();
|
||||
// compute the result dtype and device
|
||||
compute_types();
|
||||
// compute each tensor's stride after broadcasting
|
||||
compute_strides();
|
||||
// re-order dimensions to improve coalescing
|
||||
reorder_dimensions();
|
||||
// compute the result dtype and device
|
||||
compute_types();
|
||||
// allocate the output tensor if it's not provided
|
||||
allocate_outputs();
|
||||
#ifdef BUILD_NAMEDTENSOR
|
||||
|
@ -126,9 +126,10 @@ struct CAFFE2_API OperandInfo {
|
||||
struct SplitUntil32Bit;
|
||||
|
||||
enum class CommonDTypeStrategy : uint8_t {
|
||||
COMPUTE_ALL = 0, // Compute common dtype based on inputs and outputs. Try to promote common dtype to inputs and outputs
|
||||
COMPUTE_INPUTS = 1, // Compute common dtype based only on inputs. Try to promote common dtype only to inputs
|
||||
COMPUTE_NONE = 2, // Do not compute and promote common dtype
|
||||
NONE, // Do not compute a common dtype
|
||||
CHECK, // Compute and validate a common dtype but don't promote.
|
||||
PROMOTE_INPUTS, // Promote common dtype but only validate inputs (comparison ops have boolean output)
|
||||
PROMOTE // Promote to common dtype.
|
||||
};
|
||||
|
||||
struct CAFFE2_API TensorIterator {
|
||||
@ -204,7 +205,7 @@ struct CAFFE2_API TensorIterator {
|
||||
}
|
||||
|
||||
void cast_outputs() {
|
||||
if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_ALL) {
|
||||
if (common_dtype_strategy_ == CommonDTypeStrategy::PROMOTE) {
|
||||
for(int i=0; i < noutputs(); i++) {
|
||||
if (operands_[i].original_tensor.defined() && dtype(i) != operands_[i].original_tensor.scalar_type()) {
|
||||
operands_[i].original_tensor.copy_(operands_[i].tensor);
|
||||
@ -306,12 +307,16 @@ struct CAFFE2_API TensorIterator {
|
||||
operands_.emplace_back(input, device, dtype);
|
||||
}
|
||||
|
||||
void promote_common_dtype() {
|
||||
common_dtype_strategy_ = CommonDTypeStrategy::PROMOTE;
|
||||
}
|
||||
|
||||
void dont_compute_common_dtype() {
|
||||
compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_NONE;
|
||||
common_dtype_strategy_ = CommonDTypeStrategy::NONE;
|
||||
}
|
||||
|
||||
void compute_common_dtype_only_for_inputs() {
|
||||
compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_INPUTS;
|
||||
common_dtype_strategy_ = CommonDTypeStrategy::PROMOTE_INPUTS;
|
||||
}
|
||||
|
||||
void dont_resize_outputs() {
|
||||
@ -344,7 +349,7 @@ protected:
|
||||
#endif
|
||||
SmallVector<OperandInfo, 4> operands_;
|
||||
int num_outputs_ = 0;
|
||||
CommonDTypeStrategy compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_ALL;
|
||||
CommonDTypeStrategy common_dtype_strategy_ = CommonDTypeStrategy::CHECK;
|
||||
bool has_coalesced_dimensions_ = false;
|
||||
bool accumulate_ = false;
|
||||
bool resize_outputs_ = true;
|
||||
|
@ -171,8 +171,9 @@ void avg_pool2d_out_cuda_template(
|
||||
|
||||
output.resize_({nbatch, nInputPlane, outputHeight, outputWidth});
|
||||
|
||||
const int count = safe_downcast<int, int64_t>(output.numel());
|
||||
const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
|
||||
const int32_t count = safe_downcast<int32_t, int64_t>(output.numel());
|
||||
const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
|
||||
const uint32_t num_blocks = cuda::ATenCeilDiv<uint32_t>(count, num_threads);
|
||||
|
||||
if (divisor_override.has_value()) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
|
||||
@ -184,7 +185,7 @@ void avg_pool2d_out_cuda_template(
|
||||
scalar_t *input_data = input.data_ptr<scalar_t>();
|
||||
|
||||
avg_pool2d_out_cuda_frame<scalar_t, accscalar_t, false, true>
|
||||
<<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count,
|
||||
input_data,
|
||||
nbatch,
|
||||
@ -209,7 +210,7 @@ void avg_pool2d_out_cuda_template(
|
||||
scalar_t *input_data = input.data_ptr<scalar_t>();
|
||||
|
||||
avg_pool2d_out_cuda_frame<scalar_t, accscalar_t, true, false>
|
||||
<<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count,
|
||||
input_data,
|
||||
nbatch,
|
||||
@ -233,7 +234,7 @@ void avg_pool2d_out_cuda_template(
|
||||
scalar_t *input_data = input.data_ptr<scalar_t>();
|
||||
|
||||
avg_pool2d_out_cuda_frame<scalar_t, accscalar_t, false, false>
|
||||
<<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count,
|
||||
input_data,
|
||||
nbatch,
|
||||
@ -249,10 +250,8 @@ void avg_pool2d_out_cuda_template(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TORCH_CHECK(cudaGetLastError() == cudaSuccess,
|
||||
"avg_pool2d_out_cuda_frame failed with error code ",
|
||||
cudaGetLastError());
|
||||
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
if (input.ndimension() == 3) {
|
||||
output.resize_({nInputPlane, outputHeight, outputWidth});
|
||||
@ -322,8 +321,9 @@ Tensor& avg_pool2d_backward_out_cuda_template(
|
||||
|
||||
gradInput.resize_as_(input);
|
||||
|
||||
const int count = safe_downcast<int, int64_t>(input.numel());
|
||||
const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
|
||||
const int32_t count = safe_downcast<int32_t, int64_t>(input.numel());
|
||||
const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
|
||||
const uint32_t num_blocks = cuda::ATenCeilDiv<uint32_t>(count, num_threads);
|
||||
|
||||
if (divisor_override.has_value()) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
|
||||
@ -335,7 +335,7 @@ Tensor& avg_pool2d_backward_out_cuda_template(
|
||||
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
|
||||
|
||||
avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t, false, true>
|
||||
<<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count,
|
||||
gradOutput_data,
|
||||
nbatch,
|
||||
@ -360,7 +360,7 @@ Tensor& avg_pool2d_backward_out_cuda_template(
|
||||
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
|
||||
|
||||
avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t, true, false>
|
||||
<<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count,
|
||||
gradOutput_data,
|
||||
nbatch,
|
||||
@ -384,7 +384,7 @@ Tensor& avg_pool2d_backward_out_cuda_template(
|
||||
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
|
||||
|
||||
avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t, false, false>
|
||||
<<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count,
|
||||
gradOutput_data,
|
||||
nbatch,
|
||||
@ -400,9 +400,7 @@ Tensor& avg_pool2d_backward_out_cuda_template(
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_CHECK(cudaGetLastError() == cudaSuccess,
|
||||
"avg_pool2d_backward_out_cuda failed with error code ",
|
||||
cudaGetLastError());
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
return gradInput;
|
||||
}
|
||||
|
@ -428,7 +428,7 @@ ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_da
|
||||
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
|
||||
int64_t batch_size, int64_t num_labels, int64_t BLANK, bool zero_infinity) {
|
||||
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
|
||||
int64_t s = threadIdx.x + blockIdx.x * blockDim.y; // note, this directly indexes into targets, no targets prime!
|
||||
int64_t s = threadIdx.x + blockIdx.x * blockDim.x; // note, this directly indexes into targets, not targets prime!
|
||||
|
||||
if (b >= batch_size)
|
||||
return;
|
||||
|
@ -141,50 +141,80 @@ void or_kernel_cuda(TensorIterator& iter) {
|
||||
}), false);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename acc_t=scalar_t>
|
||||
void max_values_kernel_cuda_impl(TensorIterator& iter) {
|
||||
gpu_reduce_kernel<scalar_t, scalar_t>(
|
||||
iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
|
||||
return (THCNumerics<scalar_t>::isnan(a) || a > b) ? a : b;
|
||||
}), at::numeric_limits<scalar_t>::lower_bound());
|
||||
iter, func_wrapper<acc_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
|
||||
return (THCNumerics<acc_t>::isnan(a) || a > b) ? a : b;
|
||||
}), at::numeric_limits<acc_t>::lower_bound());
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename acc_t=scalar_t>
|
||||
void min_values_kernel_cuda_impl(TensorIterator& iter) {
|
||||
gpu_reduce_kernel<scalar_t, scalar_t>(
|
||||
iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
|
||||
return (THCNumerics<scalar_t>::isnan(a) || a < b) ? a : b;
|
||||
}), at::numeric_limits<scalar_t>::upper_bound());
|
||||
iter, func_wrapper<acc_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
|
||||
return (THCNumerics<acc_t>::isnan(a) || a < b) ? a : b;
|
||||
}), at::numeric_limits<acc_t>::upper_bound());
|
||||
}
|
||||
|
||||
void max_values_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES(iter.dtype(), "max_values_cuda", [&]() {
|
||||
max_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
});
|
||||
if (iter.dtype(1) == kHalf) {
|
||||
max_values_kernel_cuda_impl<at::Half, float>(iter);
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES(iter.dtype(), "max_values_cuda", [&]() {
|
||||
max_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void min_values_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES(iter.dtype(), "min_values_cuda", [&]() {
|
||||
min_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
});
|
||||
if (iter.dtype(1) == kHalf) {
|
||||
min_values_kernel_cuda_impl<at::Half, float>(iter);
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES(iter.dtype(), "min_values_cuda", [&]() {
|
||||
min_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename acc_t=scalar_t>
|
||||
void argmax_kernel_cuda_impl(TensorIterator& iter) {
|
||||
gpu_reduce_kernel<scalar_t, int64_t>(
|
||||
iter,
|
||||
ArgMaxOps<acc_t>{},
|
||||
thrust::pair<acc_t, int64_t>(at::numeric_limits<acc_t>::lower_bound(), 0));
|
||||
};
|
||||
|
||||
template <typename scalar_t, typename acc_t=scalar_t>
|
||||
void argmin_kernel_cuda_impl(TensorIterator& iter) {
|
||||
gpu_reduce_kernel<scalar_t, int64_t>(
|
||||
iter,
|
||||
ArgMinOps<acc_t>{},
|
||||
thrust::pair<acc_t, int64_t>(at::numeric_limits<acc_t>::upper_bound(), 0));
|
||||
};
|
||||
|
||||
void argmax_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_cuda", [&]() {
|
||||
gpu_reduce_kernel<scalar_t, int64_t>(
|
||||
iter,
|
||||
ArgMaxOps<scalar_t>{},
|
||||
thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::lower_bound(), 0));
|
||||
});
|
||||
if (iter.dtype(1) == kHalf) {
|
||||
// Instead of implementing is_nan and warp_shfl_down
|
||||
// we can convert halves to float and do all the operations in float
|
||||
argmax_kernel_cuda_impl<at::Half, float>(iter);
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_cuda", [&]() {
|
||||
argmax_kernel_cuda_impl<scalar_t>(iter);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void argmin_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() {
|
||||
gpu_reduce_kernel<scalar_t, int64_t>(
|
||||
iter,
|
||||
ArgMinOps<scalar_t>{},
|
||||
thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::upper_bound(), 0));
|
||||
});
|
||||
if (iter.dtype(1) == kHalf) {
|
||||
// Instead of implementing is_nan and warp_shfl_down
|
||||
// we can convert halves to float and do all the operations in float
|
||||
argmin_kernel_cuda_impl<at::Half, float>(iter);
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() {
|
||||
argmin_kernel_cuda_impl<scalar_t>(iter);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda);
|
||||
|
@ -778,6 +778,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
|
||||
&reserve_size
|
||||
));
|
||||
reserve = at::empty(reserve_size, input.options().dtype(kByte));
|
||||
setCuDNNStreamToCurrent();
|
||||
AT_CUDNN_CHECK(cudnnRNNForwardTraining(
|
||||
handle,
|
||||
descs.rnn_desc.desc(),
|
||||
@ -794,6 +795,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
|
||||
));
|
||||
} else { // inference
|
||||
reserve = at::empty({0}, input.options().dtype(kByte));
|
||||
setCuDNNStreamToCurrent();
|
||||
AT_CUDNN_CHECK(cudnnRNNForwardInference(
|
||||
handle,
|
||||
descs.rnn_desc.desc(),
|
||||
@ -912,7 +914,7 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
|
||||
));
|
||||
// TODO: put this in the correct device???
|
||||
Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
|
||||
|
||||
setCuDNNStreamToCurrent();
|
||||
AT_CUDNN_CHECK(cudnnRNNBackwardData(
|
||||
handle,
|
||||
descs.rnn_desc.desc(),
|
||||
@ -1016,7 +1018,7 @@ std::vector<Tensor> _cudnn_rnn_backward_weight(
|
||||
&workspace_size
|
||||
));
|
||||
Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
|
||||
|
||||
setCuDNNStreamToCurrent();
|
||||
AT_CUDNN_CHECK(cudnnRNNBackwardWeights(
|
||||
handle,
|
||||
descs.rnn_desc.desc(),
|
||||
|
@ -66,7 +66,7 @@
|
||||
supports_named_tensor: True
|
||||
|
||||
- func: align_to(Tensor(a) self, DimnameList names) -> Tensor(a)
|
||||
variants: function, method
|
||||
variants: method
|
||||
supports_named_tensor: True
|
||||
|
||||
- func: align_as(Tensor self, Tensor other) -> Tensor
|
||||
@ -1397,7 +1397,8 @@
|
||||
use_c10_dispatcher: full
|
||||
variants: function
|
||||
device_guard: False
|
||||
|
||||
supports_named_tensor: True
|
||||
|
||||
- func: is_distributed(Tensor self) -> bool
|
||||
use_c10_dispatcher: full
|
||||
variants: function, method
|
||||
@ -3681,6 +3682,17 @@
|
||||
CPU: fake_quantize_per_tensor_affine_backward_cpu
|
||||
CUDA: fake_quantize_per_tensor_affine_backward_cuda
|
||||
|
||||
- func: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: fake_quantize_per_channel_affine_cpu
|
||||
CUDA: fake_quantize_per_channel_affine_cuda
|
||||
|
||||
- func: fake_quantize_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: fake_quantize_per_channel_affine_backward_cpu
|
||||
CUDA: fake_quantize_per_channel_affine_backward_cuda
|
||||
# to(Device) must not exist because all constructors of Device also works for
|
||||
# TensorOptions. Otherwise, an ambiguity error is thrown.
|
||||
# See NOTE [ TensorOptions Constructors ].
|
||||
@ -4473,12 +4485,14 @@
|
||||
CUDA: legacy::cuda::_th_trace
|
||||
|
||||
- func: ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: ne_out
|
||||
CUDA: ne_out
|
||||
QuantizedCPU: ne_out_quantized_cpu
|
||||
|
||||
- func: ne.Scalar(Tensor self, Scalar other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4487,12 +4501,14 @@
|
||||
QuantizedCPU: ne_quantized_cpu
|
||||
|
||||
- func: ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: ne_out
|
||||
CUDA: ne_out
|
||||
QuantizedCPU: ne_out_quantized_cpu
|
||||
|
||||
- func: ne.Tensor(Tensor self, Tensor other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4501,12 +4517,14 @@
|
||||
QuantizedCPU: ne_quantized_cpu
|
||||
|
||||
- func: eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: eq_out
|
||||
CUDA: eq_out
|
||||
QuantizedCPU: eq_out_quantized_cpu
|
||||
|
||||
- func: eq.Scalar(Tensor self, Scalar other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4515,12 +4533,14 @@
|
||||
QuantizedCPU: eq_quantized_cpu
|
||||
|
||||
- func: eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: eq_out
|
||||
CUDA: eq_out
|
||||
QuantizedCPU: eq_out_quantized_cpu
|
||||
|
||||
- func: eq.Tensor(Tensor self, Tensor other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4529,12 +4549,14 @@
|
||||
QuantizedCPU: eq_quantized_cpu
|
||||
|
||||
- func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: ge_out
|
||||
CUDA: ge_out
|
||||
QuantizedCPU: ge_out_quantized_cpu
|
||||
|
||||
- func: ge.Scalar(Tensor self, Scalar other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4543,12 +4565,14 @@
|
||||
QuantizedCPU: ge_quantized_cpu
|
||||
|
||||
- func: ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: ge_out
|
||||
CUDA: ge_out
|
||||
QuantizedCPU: ge_out_quantized_cpu
|
||||
|
||||
- func: ge.Tensor(Tensor self, Tensor other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4557,12 +4581,14 @@
|
||||
QuantizedCPU: ge_quantized_cpu
|
||||
|
||||
- func: le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: le_out
|
||||
CUDA: le_out
|
||||
QuantizedCPU: le_out_quantized_cpu
|
||||
|
||||
- func: le.Scalar(Tensor self, Scalar other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4571,12 +4597,14 @@
|
||||
QuantizedCPU: le_quantized_cpu
|
||||
|
||||
- func: le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: le_out
|
||||
CUDA: le_out
|
||||
QuantizedCPU: le_out_quantized_cpu
|
||||
|
||||
- func: le.Tensor(Tensor self, Tensor other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4585,12 +4613,14 @@
|
||||
QuantizedCPU: le_quantized_cpu
|
||||
|
||||
- func: gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: gt_out
|
||||
CUDA: gt_out
|
||||
QuantizedCPU: gt_out_quantized_cpu
|
||||
|
||||
- func: gt.Scalar(Tensor self, Scalar other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4599,12 +4629,14 @@
|
||||
QuantizedCPU: gt_quantized_cpu
|
||||
|
||||
- func: gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: gt_out
|
||||
CUDA: gt_out
|
||||
QuantizedCPU: gt_out_quantized_cpu
|
||||
|
||||
- func: gt.Tensor(Tensor self, Tensor other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4613,12 +4645,14 @@
|
||||
QuantizedCPU: gt_quantized_cpu
|
||||
|
||||
- func: lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: lt_out
|
||||
CUDA: lt_out
|
||||
QuantizedCPU: lt_out_quantized_cpu
|
||||
|
||||
- func: lt.Scalar(Tensor self, Scalar other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4627,12 +4661,14 @@
|
||||
QuantizedCPU: lt_quantized_cpu
|
||||
|
||||
- func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
supports_named_tensor: True
|
||||
dispatch:
|
||||
CPU: lt_out
|
||||
CUDA: lt_out
|
||||
QuantizedCPU: lt_out_quantized_cpu
|
||||
|
||||
- func: lt.Tensor(Tensor self, Tensor other) -> Tensor
|
||||
supports_named_tensor: True
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
|
60
aten/src/ATen/native/quantized/cpu/fake_quantize_core.cpp
Normal file
60
aten/src/ATen/native/quantized/cpu/fake_quantize_core.cpp
Normal file
@ -0,0 +1,60 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cpu/Loops.h>
|
||||
|
||||
/* Core operations for fake-quantization shared between per tensor
|
||||
and per-channel fake quant */
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
/* Fake quantize a tensor, common block for per-channel & per-tensor fake quant
|
||||
Args:
|
||||
output: output tensor.
|
||||
input : input tensor.
|
||||
sc: scale to quantize the input tensor to
|
||||
zero_point: zero_point
|
||||
quant_min: minimum quantized value
|
||||
quant_max: maximum quantized value
|
||||
Returns:
|
||||
Fake quantized tensor (double dtype).
|
||||
*/
|
||||
void fake_quantize_slice(
|
||||
Tensor& output,
|
||||
const Tensor& input,
|
||||
float sc,
|
||||
int64_t z_point,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max) {
|
||||
float inv_scale = 1.0f / sc;
|
||||
auto iter = TensorIterator::unary_op(output, input);
|
||||
cpu_kernel(iter, [&](float self) -> float {
|
||||
return (std::fmin(
|
||||
std::fmax(
|
||||
static_cast<int64_t>(
|
||||
std::nearbyint(self * inv_scale + z_point)),
|
||||
quant_min),
|
||||
quant_max) -
|
||||
z_point) *
|
||||
sc;
|
||||
});
|
||||
}
|
||||
|
||||
void fake_quantize_grad_slice(
|
||||
Tensor& input_grad,
|
||||
const Tensor& input,
|
||||
const Tensor& output_grad,
|
||||
float sc,
|
||||
int64_t z_point,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max) {
|
||||
float inv_scale = 1.0f / sc;
|
||||
auto iter = TensorIterator::binary_op(input_grad, input, output_grad);
|
||||
cpu_kernel(iter, [&](float x, float dy) -> float {
|
||||
int64_t xq = static_cast<int64_t>(std::nearbyint(x * inv_scale + z_point));
|
||||
return dy * (xq >= quant_min && xq <= quant_max);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
27
aten/src/ATen/native/quantized/cpu/fake_quantize_core.h
Normal file
27
aten/src/ATen/native/quantized/cpu/fake_quantize_core.h
Normal file
@ -0,0 +1,27 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cpu/Loops.h>
|
||||
|
||||
/* FakeQuantize Op for PerChannelAffine quantization scheme */
|
||||
namespace at {
|
||||
namespace native {
|
||||
void fake_quantize_slice(
|
||||
Tensor& output,
|
||||
const Tensor& input,
|
||||
float sc,
|
||||
int64_t z_point,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max);
|
||||
|
||||
void fake_quantize_grad_slice(
|
||||
Tensor& input_grad,
|
||||
const Tensor& output_grad,
|
||||
const Tensor& input,
|
||||
float sc,
|
||||
int64_t z_point,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max);
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
@ -0,0 +1,147 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cpu/Loops.h>
|
||||
#include <ATen/native/quantized/cpu/fake_quantize_core.h>
|
||||
|
||||
/* FakeQuantize Op for PerChannelAffine quantization scheme */
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
/* Per channel fake-quantizes the 'inputs' tensor.
|
||||
Args:
|
||||
X: Forward input tensor.
|
||||
dY: Backward input tensor (_backward op only).
|
||||
scale: scale of per channel affine quantization
|
||||
zero_point: zero_point of per channel affine quantization
|
||||
axis: int specifying the axis to be quantized
|
||||
quant_min: minimum quantized value
|
||||
quant_max: maximum quantized value
|
||||
Returns:
|
||||
Fake quantized tensor (double dtype).
|
||||
|
||||
*/
|
||||
|
||||
Tensor fake_quantize_per_channel_affine_cpu(
|
||||
const Tensor& self,
|
||||
const Tensor& scale,
|
||||
const Tensor& zero_point,
|
||||
int64_t axis,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max) {
|
||||
// TODO: Use REGISTER_DISPATCH
|
||||
TORCH_CHECK(self.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 1, "scale should be a 1-D tensor");
|
||||
TORCH_CHECK(
|
||||
zero_point.dim() == 1, "zero point should be a 1-D tensor");
|
||||
TORCH_CHECK(
|
||||
scale.numel() == zero_point.numel(),
|
||||
"scale and zero-point need to have the same dimensions");
|
||||
TORCH_CHECK(
|
||||
scale.numel() == self.size(axis),
|
||||
"dimensions of scale and zero-point are not consistent with input tensor")
|
||||
|
||||
TORCH_CHECK(
|
||||
quant_min <= quant_max,
|
||||
"`quant_min` should be less than or \
|
||||
equal to `quant_max`.");
|
||||
|
||||
TORCH_CHECK(
|
||||
at::min(zero_point).item().toLong() >= quant_min &&
|
||||
at::max(zero_point).item().toLong() <= quant_max,
|
||||
"`zero_point` must be between `quant_min` and `quant_max`.");
|
||||
|
||||
TORCH_CHECK(
|
||||
axis >= 0 && axis <= self.dim(),
|
||||
"`axis` must be between 0 and number of dimensions of input");
|
||||
|
||||
auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
|
||||
for (int i = 0; i < self.size(axis); i++) {
|
||||
auto input_slice = self.slice(axis, i, i + 1);
|
||||
auto output_slice = Y.slice(axis, i, i + 1);
|
||||
float sc = scale[i].item().toFloat();
|
||||
int64_t z_point = zero_point[i].item().toLong();
|
||||
fake_quantize_slice(
|
||||
output_slice, input_slice, sc, z_point, quant_min, quant_max);
|
||||
}
|
||||
|
||||
return Y;
|
||||
}
|
||||
|
||||
/* Backward path for per-channel fake-quantization of the 'inputs' tensor.
|
||||
|
||||
Args:
|
||||
X: Forward input tensor.
|
||||
dY: Backward input tensor.
|
||||
scale: scale of per tensor affine quantization
|
||||
zero_point: zero_point of per tensor affine quantization
|
||||
axis: int ,the axis over which quantization parameters vary
|
||||
quant_min: int, minimum quantized value
|
||||
quant_max: int, maximum quantized value
|
||||
|
||||
Returns:
|
||||
Gradient for per channel fake quant (double dtype).
|
||||
|
||||
*/
|
||||
Tensor fake_quantize_per_channel_affine_backward_cpu(
|
||||
const Tensor& dY,
|
||||
const Tensor& X,
|
||||
const Tensor& scale,
|
||||
const Tensor& zero_point,
|
||||
int64_t axis,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max) {
|
||||
TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(X.scalar_type() == ScalarType::Float);
|
||||
|
||||
TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
|
||||
TORCH_CHECK(
|
||||
quant_min <= quant_max,
|
||||
"`quant_min` should be less than or \
|
||||
equal to `quant_max`.");
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 1, "scale should be a 1-D tensor");
|
||||
TORCH_CHECK(
|
||||
zero_point.dim() == 1, "zero point should be a 1-D tensor");
|
||||
TORCH_CHECK(
|
||||
scale.numel() == zero_point.numel(),
|
||||
"scale and zero-point need to have the same dimensions");
|
||||
TORCH_CHECK(
|
||||
scale.numel() == X.size(axis),
|
||||
"dimensions of scale and zero-point are not consistent with input tensor")
|
||||
|
||||
TORCH_CHECK(
|
||||
quant_min <= quant_max,
|
||||
"`quant_min` should be less than or \
|
||||
equal to `quant_max`.");
|
||||
|
||||
TORCH_CHECK(
|
||||
at::min(zero_point).item().toLong() >= quant_min &&
|
||||
at::max(zero_point).item().toLong() <= quant_max,
|
||||
"`zero_point` must be between `quant_min` and `quant_max`.");
|
||||
|
||||
TORCH_CHECK(
|
||||
axis >= 0 && axis <= X.dim(),
|
||||
"`axis` must be between 0 and number of dimensions of input");
|
||||
|
||||
if (X.numel() <= 0) {
|
||||
return X;
|
||||
}
|
||||
|
||||
auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
|
||||
|
||||
for (int i = 0; i < X.size(axis); i++) {
|
||||
auto X_slice = X.slice(axis, i, i + 1);
|
||||
auto dY_slice = dY.slice(axis, i, i + 1);
|
||||
auto dX_slice = dX.slice(axis, i, i + 1);
|
||||
float sc = scale[i].item().toFloat();
|
||||
int64_t z_point = zero_point[i].item().toLong();
|
||||
fake_quantize_grad_slice(dX_slice, X_slice, dY_slice,
|
||||
sc, z_point, quant_min, quant_max);
|
||||
}
|
||||
|
||||
return dX;
|
||||
}
|
||||
} // namespace native
|
||||
} // namespace at
|
@ -2,6 +2,7 @@
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cpu/Loops.h>
|
||||
#include <ATen/native/quantized/cpu/fake_quantize_core.h>
|
||||
|
||||
/* FakeQuantize Op for PerTensorAffine quantization scheme */
|
||||
namespace at {
|
||||
@ -15,16 +16,9 @@ Args:
|
||||
zero_point: zero_point of per tensor affine quantization
|
||||
quant_min: minimum quantized value
|
||||
quant_max: maximum quantized value
|
||||
quant_delay: Count of global steps for which to delay the quantization.
|
||||
See note below.
|
||||
iter: The current quantization iteration used for `quant_delay`.
|
||||
Returns:
|
||||
Quantized tensor (double dtype).
|
||||
|
||||
Notes:
|
||||
- quant_delay might be set to non-zero to help weights stabilize in the
|
||||
beginning of the training.
|
||||
- quantization range [quant_min, quant_max]
|
||||
*/
|
||||
Tensor fake_quantize_per_tensor_affine_cpu(
|
||||
const Tensor& self,
|
||||
@ -41,20 +35,8 @@ Tensor fake_quantize_per_tensor_affine_cpu(
|
||||
zero_point >= quant_min && zero_point <= quant_max,
|
||||
"`zero_point` must be between `quant_min` and `quant_max`.");
|
||||
|
||||
auto Y = at::empty_like(self);
|
||||
float inv_scale = 1.0f / scale;
|
||||
auto iter = TensorIterator::unary_op(Y, self);
|
||||
cpu_kernel(iter, [&](float self) -> float {
|
||||
return (std::fmin(
|
||||
std::fmax(
|
||||
static_cast<int64_t>(
|
||||
std::nearbyint(self * inv_scale + zero_point)),
|
||||
quant_min),
|
||||
quant_max) -
|
||||
zero_point) *
|
||||
scale;
|
||||
});
|
||||
|
||||
auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
|
||||
fake_quantize_slice(Y, self, scale, zero_point, quant_min, quant_max);
|
||||
return Y;
|
||||
}
|
||||
|
||||
@ -99,14 +81,8 @@ Tensor fake_quantize_per_tensor_affine_backward_cpu(
|
||||
return X;
|
||||
}
|
||||
|
||||
Tensor dX = at::zeros_like(X);
|
||||
auto iter = TensorIterator::binary_op(dX, X, dY);
|
||||
float inv_scale = 1.0f / scale;
|
||||
cpu_kernel(iter, [&](float x, float dy) -> float {
|
||||
int64_t xq =
|
||||
static_cast<int64_t>(std::nearbyint(x * inv_scale + zero_point));
|
||||
return dy * (xq >= quant_min && xq <= quant_max);
|
||||
});
|
||||
auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
|
||||
fake_quantize_grad_slice(dX, X, dY, scale, zero_point, quant_min, quant_max);
|
||||
return dX;
|
||||
}
|
||||
} // namespace native
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/SmallVector.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <ATen/cpp_custom_type_hack.h>
|
||||
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
|
||||
@ -153,8 +154,6 @@ class QConv2dInt8 final : public c10::OperatorKernel {
|
||||
{pad_l, pad_t, pad_l, pad_t},
|
||||
{static_cast<int>(dilation[0]), static_cast<int>(dilation[1])});
|
||||
|
||||
fbgemm::DoNothing<> NoOpObj{};
|
||||
|
||||
float act_scale = act.q_scale();
|
||||
int32_t act_zero_point = act.q_zero_point();
|
||||
|
||||
@ -208,63 +207,69 @@ class QConv2dInt8 final : public c10::OperatorKernel {
|
||||
// TODO: add MemoryFormat::ChannelsLast here once perf is fixed
|
||||
Tensor output = _empty_affine_quantized(
|
||||
outShape, device(kCPU).dtype(kQUInt8), output_scale, output_zero_point);
|
||||
auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));
|
||||
auto buffer = at::empty(output.sizes(), output.options().dtype(at::kInt));
|
||||
|
||||
if (pack_ptr.q_scheme == kPerTensorAffine) {
|
||||
fbgemm::ReQuantizeOutput<
|
||||
ReluFused,
|
||||
fbgemm::QuantizationGranularity::TENSOR,
|
||||
float>
|
||||
outputProcObj(
|
||||
NoOpObj,
|
||||
output_multiplier_float.data(),
|
||||
output_zero_point,
|
||||
act_zero_point,
|
||||
pack_ptr.w_zp.data(),
|
||||
nullptr, /* row offset buffer */
|
||||
col_offsets.data(),
|
||||
bias_ptr,
|
||||
K,
|
||||
groups,
|
||||
act_times_w_scale.data());
|
||||
fbgemm::fbgemmConv(
|
||||
conv_p,
|
||||
act_ptr,
|
||||
*packB,
|
||||
reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
|
||||
buffer.data_ptr<int32_t>(),
|
||||
outputProcObj,
|
||||
0 /* thread_id*/,
|
||||
1 /* num_threads */);
|
||||
int num_tasks = at::get_num_threads();
|
||||
at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
|
||||
fbgemm::DoNothing<> NoOpObj{};
|
||||
for (int task_id = begin; task_id < end; ++task_id) {
|
||||
if (pack_ptr.q_scheme == kPerTensorAffine) {
|
||||
fbgemm::ReQuantizeOutput<
|
||||
ReluFused,
|
||||
fbgemm::QuantizationGranularity::TENSOR,
|
||||
float>
|
||||
outputProcObj(
|
||||
NoOpObj,
|
||||
output_multiplier_float.data(),
|
||||
output_zero_point,
|
||||
act_zero_point,
|
||||
pack_ptr.w_zp.data(),
|
||||
nullptr, /* row offset buffer */
|
||||
col_offsets.data(),
|
||||
bias_ptr,
|
||||
K,
|
||||
groups,
|
||||
act_times_w_scale.data());
|
||||
fbgemm::fbgemmConv(
|
||||
conv_p,
|
||||
act_ptr,
|
||||
*packB,
|
||||
reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
|
||||
buffer.data_ptr<int32_t>(),
|
||||
outputProcObj,
|
||||
task_id /* thread_id*/,
|
||||
num_tasks /* num_threads */);
|
||||
|
||||
} else if (pack_ptr.q_scheme == kPerChannelAffine) {
|
||||
fbgemm::ReQuantizeOutput<
|
||||
ReluFused,
|
||||
fbgemm::QuantizationGranularity::OUT_CHANNEL,
|
||||
float>
|
||||
outputProcObj(
|
||||
NoOpObj,
|
||||
output_multiplier_float.data(),
|
||||
output_zero_point,
|
||||
act_zero_point,
|
||||
pack_ptr.w_zp.data(),
|
||||
nullptr, /* row offset buffer */
|
||||
col_offsets.data(),
|
||||
bias_ptr,
|
||||
K,
|
||||
groups,
|
||||
act_times_w_scale.data());
|
||||
} else if (pack_ptr.q_scheme == kPerChannelAffine) {
|
||||
fbgemm::ReQuantizeOutput<
|
||||
ReluFused,
|
||||
fbgemm::QuantizationGranularity::OUT_CHANNEL,
|
||||
float>
|
||||
outputProcObj(
|
||||
NoOpObj,
|
||||
output_multiplier_float.data(),
|
||||
output_zero_point,
|
||||
act_zero_point,
|
||||
pack_ptr.w_zp.data(),
|
||||
nullptr, /* row offset buffer */
|
||||
col_offsets.data(),
|
||||
bias_ptr,
|
||||
K,
|
||||
groups,
|
||||
act_times_w_scale.data());
|
||||
|
||||
fbgemm::fbgemmConv(
|
||||
conv_p,
|
||||
act_ptr,
|
||||
*packB,
|
||||
reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
|
||||
buffer.data_ptr<int32_t>(),
|
||||
outputProcObj,
|
||||
0 /* thread_id*/,
|
||||
1 /* num_threads */);
|
||||
}
|
||||
fbgemm::fbgemmConv(
|
||||
conv_p,
|
||||
act_ptr,
|
||||
*packB,
|
||||
reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
|
||||
buffer.data_ptr<int32_t>(),
|
||||
outputProcObj,
|
||||
task_id /* thread_id*/,
|
||||
num_tasks /* num_threads */);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// TODO: remove permute once MemoryLayout is added above
|
||||
return output.permute({0, 3, 1, 2});
|
||||
@ -355,7 +360,8 @@ class QConv2dInt8 final : public c10::OperatorKernel {
|
||||
kernel_zp,
|
||||
MemoryFormat::ChannelsLast);
|
||||
auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
|
||||
for (int i = 0; i < weight_contig.numel(); ++i) {
|
||||
auto wt_numel = weight_contig.numel();
|
||||
for (int i = 0; i < wt_numel; ++i) {
|
||||
qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
|
||||
}
|
||||
// Original bias was float, so we requantize it here.
|
||||
|
@ -218,7 +218,8 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel {
|
||||
weight.q_scale(),
|
||||
weight_zp);
|
||||
auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
|
||||
for (int i = 0; i < weight_contig.numel(); ++i) {
|
||||
auto wt_numel = weight_contig.numel();
|
||||
for (int i = 0; i < wt_numel; ++i) {
|
||||
qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
|
||||
}
|
||||
// We set the pre-packed conv weights to nullptr below as we call pre-pack
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <ATen/cpp_custom_type_hack.h>
|
||||
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
|
||||
@ -80,35 +81,6 @@ class QLinearInt8 final : public torch::OperatorKernel {
|
||||
}
|
||||
int32_t output_zero_point_int32 = static_cast<int32_t>(output_zero_point);
|
||||
|
||||
// This operation does the following:
|
||||
// 1) Creates a "row buffer" vector with offset values that must be added
|
||||
// to the integer matrix multiplication operation to ensure correctness.
|
||||
// This "row buffer" is also called the row offset, and it is needed when
|
||||
// we use affine quantization for weights.
|
||||
// 2) Packs the resulting quantized matrix into vector-register and cache
|
||||
// friendly tiles.
|
||||
//
|
||||
// Note this is not executed eagerly, but rather within the fbgemmPacked
|
||||
// call below.
|
||||
fbgemm::PackAWithRowOffset<uint8_t> packA(
|
||||
/*trans=*/fbgemm::matrix_op_t::NoTranspose,
|
||||
/*nRow=*/M,
|
||||
/*nCol=*/K,
|
||||
/*smat=*/input_ptr,
|
||||
/*ld=*/K,
|
||||
/*pmat=*/nullptr); // Currently, packA manages ownership of `pmat`.
|
||||
// TODO: Consider a way to pre-allocate and reuse
|
||||
// pmat buffer.
|
||||
|
||||
// ReQuantizeOutput requires pointers to the zero point values,
|
||||
// since in the case of rowwise quantization these will be arrays rather
|
||||
// than scalars. But in this case, we're doing whole-tensor quantization so
|
||||
// we just pass a pointer to the scale values (and internally
|
||||
// ReQuantizeOutput won't index past 0.
|
||||
|
||||
// This is the end of the pipeline, pass the resulting matrix through.
|
||||
fbgemm::DoNothing<> doNothingObj{};
|
||||
|
||||
const float* bias_ptr = nullptr;
|
||||
at::Tensor bias;
|
||||
if (pack_ptr.bias.has_value()) {
|
||||
@ -134,79 +106,114 @@ class QLinearInt8 final : public torch::OperatorKernel {
|
||||
output_scale,
|
||||
output_zero_point);
|
||||
|
||||
auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));
|
||||
auto buffer = at::empty(out_sizes, output.options().dtype(at::kInt));
|
||||
|
||||
if (pack_ptr.q_scheme == kPerTensorAffine) {
|
||||
// Process the per tensor quantization.
|
||||
//
|
||||
// After the uint8 * int8 matrix multiplication is performed, this
|
||||
// operation does:
|
||||
// 1) Add in row and column offsets to the rows and columns,
|
||||
// respectively.
|
||||
// 2) Add in the bias term.
|
||||
fbgemm::ReQuantizeOutput<
|
||||
ReluFused,
|
||||
fbgemm::QuantizationGranularity::TENSOR,
|
||||
float>
|
||||
outputProcObj(
|
||||
doNothingObj,
|
||||
output_multiplier_float.data(),
|
||||
output_zero_point_int32,
|
||||
input_zero_point_int32,
|
||||
pack_ptr.w_zp.data(),
|
||||
packA.getRowOffsetBuffer(),
|
||||
col_offsets.data(),
|
||||
bias_ptr,
|
||||
N, /* nCol */
|
||||
1 /* groups */,
|
||||
act_times_w_scale.data());
|
||||
int num_tasks = at::get_num_threads();
|
||||
at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
|
||||
for (int task_id = begin; task_id < end; ++task_id) {
|
||||
// This operation does the following:
|
||||
// 1) Creates a "row buffer" vector with offset values that must be
|
||||
// added to the integer matrix multiplication operation to ensure
|
||||
// correctness. This "row buffer" is also called the row offset, and
|
||||
// it is needed when we use affine quantization for weights.
|
||||
// 2) Packs the resulting quantized matrix into vector-register and
|
||||
// cache friendly tiles.
|
||||
//
|
||||
// Note this is not executed eagerly, but rather within the
|
||||
// fbgemmPacked call below.
|
||||
fbgemm::PackAWithRowOffset<uint8_t> packA(
|
||||
/*trans=*/fbgemm::matrix_op_t::NoTranspose,
|
||||
/*nRow=*/M,
|
||||
/*nCol=*/K,
|
||||
/*smat=*/input_ptr,
|
||||
/*ld=*/K,
|
||||
/*pmat=*/nullptr); // Currently, packA manages ownership of `pmat`.
|
||||
// TODO: Consider a way to pre-allocate and reuse
|
||||
// pmat buffer.
|
||||
|
||||
// Do the GEMM
|
||||
fbgemm::fbgemmPacked(
|
||||
/*packA=*/packA,
|
||||
/*packB=*/*packB,
|
||||
/*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
|
||||
/*C_buffer=*/buffer.data_ptr<int32_t>(),
|
||||
/*ldc=*/N,
|
||||
/*outProcess=*/outputProcObj,
|
||||
/*thread_id=*/0,
|
||||
/*num_threads=*/1);
|
||||
} else if (pack_ptr.q_scheme == kPerChannelAffine) {
|
||||
// Process the per channel quantization.
|
||||
//
|
||||
// After the uint8 * int8 matrix multiplication is performed, this
|
||||
// operation does:
|
||||
// 1) Add in row and column offsets to the rows and columns,
|
||||
// respectively.
|
||||
// 2) Add in the bias term.
|
||||
fbgemm::ReQuantizeOutput<
|
||||
ReluFused,
|
||||
fbgemm::QuantizationGranularity::OUT_CHANNEL,
|
||||
float>
|
||||
outputProcObj(
|
||||
doNothingObj,
|
||||
output_multiplier_float.data(),
|
||||
output_zero_point_int32,
|
||||
input_zero_point_int32,
|
||||
pack_ptr.w_zp.data(),
|
||||
packA.getRowOffsetBuffer(),
|
||||
col_offsets.data(),
|
||||
bias_ptr,
|
||||
N, /*nCol=*/
|
||||
1, /* groups*/
|
||||
act_times_w_scale.data());
|
||||
// ReQuantizeOutput requires pointers to the zero point values,
|
||||
// since in the case of rowwise quantization these will be arrays rather
|
||||
// than scalars. But in this case, we're doing whole-tensor quantization
|
||||
// so we just pass a pointer to the scale values (and internally
|
||||
// ReQuantizeOutput won't index past 0.
|
||||
|
||||
// This is the end of the pipeline, pass the resulting matrix through.
|
||||
fbgemm::DoNothing<> doNothingObj{};
|
||||
|
||||
if (pack_ptr.q_scheme == kPerTensorAffine) {
|
||||
// Process the per tensor quantization.
|
||||
//
|
||||
// After the uint8 * int8 matrix multiplication is performed, this
|
||||
// operation does:
|
||||
// 1) Add in row and column offsets to the rows and columns,
|
||||
// respectively.
|
||||
// 2) Add in the bias term.
|
||||
fbgemm::ReQuantizeOutput<
|
||||
ReluFused,
|
||||
fbgemm::QuantizationGranularity::TENSOR,
|
||||
float>
|
||||
outputProcObj(
|
||||
doNothingObj,
|
||||
output_multiplier_float.data(),
|
||||
output_zero_point_int32,
|
||||
input_zero_point_int32,
|
||||
pack_ptr.w_zp.data(),
|
||||
packA.getRowOffsetBuffer(),
|
||||
col_offsets.data(),
|
||||
bias_ptr,
|
||||
N, /* nCol */
|
||||
1 /* groups */,
|
||||
act_times_w_scale.data());
|
||||
|
||||
// Do the GEMM
|
||||
fbgemm::fbgemmPacked(
|
||||
/*packA=*/packA,
|
||||
/*packB=*/*packB,
|
||||
/*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
|
||||
/*C_buffer=*/buffer.data_ptr<int32_t>(),
|
||||
/*ldc=*/N,
|
||||
/*outProcess=*/outputProcObj,
|
||||
/*thread_id=*/task_id,
|
||||
/*num_threads=*/num_tasks);
|
||||
} else if (pack_ptr.q_scheme == kPerChannelAffine) {
|
||||
// Process the per channel quantization.
|
||||
//
|
||||
// After the uint8 * int8 matrix multiplication is performed, this
|
||||
// operation does:
|
||||
// 1) Add in row and column offsets to the rows and columns,
|
||||
// respectively.
|
||||
// 2) Add in the bias term.
|
||||
fbgemm::ReQuantizeOutput<
|
||||
ReluFused,
|
||||
fbgemm::QuantizationGranularity::OUT_CHANNEL,
|
||||
float>
|
||||
outputProcObj(
|
||||
doNothingObj,
|
||||
output_multiplier_float.data(),
|
||||
output_zero_point_int32,
|
||||
input_zero_point_int32,
|
||||
pack_ptr.w_zp.data(),
|
||||
packA.getRowOffsetBuffer(),
|
||||
col_offsets.data(),
|
||||
bias_ptr,
|
||||
N, /*nCol=*/
|
||||
1, /* groups*/
|
||||
act_times_w_scale.data());
|
||||
|
||||
// Do the GEMM
|
||||
fbgemm::fbgemmPacked(
|
||||
/*packA=*/packA,
|
||||
/*packB=*/*packB,
|
||||
/*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
|
||||
/*C_buffer=*/buffer.data_ptr<int32_t>(),
|
||||
/*ldc=*/N,
|
||||
/*outProcess=*/outputProcObj,
|
||||
/*thread_id=*/task_id,
|
||||
/*num_threads=*/num_tasks);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Do the GEMM
|
||||
fbgemm::fbgemmPacked(
|
||||
/*packA=*/packA,
|
||||
/*packB=*/*packB,
|
||||
/*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
|
||||
/*C_buffer=*/buffer.data_ptr<int32_t>(),
|
||||
/*ldc=*/N,
|
||||
/*outProcess=*/outputProcObj,
|
||||
/*thread_id=*/0,
|
||||
/*num_threads=*/1);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
#endif
|
||||
@ -242,7 +249,8 @@ class QLinearInt8 final : public torch::OperatorKernel {
|
||||
kernel_scale,
|
||||
kernel_zp);
|
||||
auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
|
||||
for (int i = 0; i < weight_contig.numel(); ++i) {
|
||||
auto wt_numel = weight_contig.numel();
|
||||
for (int i = 0; i < wt_numel; ++i) {
|
||||
qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
|
||||
}
|
||||
// Original bias was float, so we requantize it here.
|
||||
|
@ -127,8 +127,8 @@ class QLinearDynamicInt8 final : public torch::OperatorKernel {
|
||||
std::vector<int64_t> out_sizes = input.sizes().vec();
|
||||
out_sizes.back() = N;
|
||||
// Allocate output Tensor and a buffer for fbgemmPacked to use
|
||||
auto output = at::zeros(out_sizes, input.options().dtype(at::kFloat));
|
||||
auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));
|
||||
auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));
|
||||
auto buffer = at::empty_like(output, output.options().dtype(at::kInt));
|
||||
|
||||
if (pack_ptr.q_scheme == kPerTensorAffine) {
|
||||
// Process the per tensor quantization.
|
||||
|
@ -163,7 +163,8 @@ class QLinearPackWeightInt8 final : public c10::OperatorKernel {
|
||||
weight.q_scale(),
|
||||
weight_zp);
|
||||
auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
|
||||
for (int i = 0; i < weight_contig.numel(); ++i) {
|
||||
auto wt_numel = weight_contig.numel();
|
||||
for (int i = 0; i < wt_numel; ++i) {
|
||||
qnnp_w_data[i] = static_cast<c10::quint8>(inp_data[i] + 128);
|
||||
}
|
||||
initQNNPACK();
|
||||
|
60
aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
Normal file
60
aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
Normal file
@ -0,0 +1,60 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <cmath>
|
||||
|
||||
/* Fake quantize a tensor, common block for per-channel & per-tensor fake quant
|
||||
Args:
|
||||
output: output tensor.
|
||||
input : input tensor.
|
||||
sc: scale to quantize the input tensor to
|
||||
zero_point: zero_point
|
||||
quant_min: minimum quantized value
|
||||
quant_max: maximum quantized value
|
||||
Returns:
|
||||
Fake quantized tensor (double dtype).
|
||||
*/
|
||||
namespace at {
|
||||
namespace native {
|
||||
void fake_quantize_slice_cuda(
|
||||
Tensor& output,
|
||||
const Tensor& input,
|
||||
float scale,
|
||||
int64_t zero_point,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max) {
|
||||
float inv_scale = 1.0f / scale;
|
||||
at::cuda::CUDA_tensor_apply2<float, float>(
|
||||
input, output, [=] __device__(const float& input_val, float& result_val) {
|
||||
result_val = (fminf(
|
||||
quant_max,
|
||||
fmaxf(
|
||||
quant_min,
|
||||
static_cast<int64_t>(std::nearbyint(
|
||||
input_val * inv_scale + zero_point)))) -
|
||||
zero_point) *
|
||||
scale;
|
||||
});
|
||||
}
|
||||
|
||||
void fake_quantize_grad_slice_cuda(
|
||||
Tensor& input_grad,
|
||||
const Tensor& input,
|
||||
const Tensor& output_grad,
|
||||
float scale,
|
||||
int64_t zero_point,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max) {
|
||||
float inv_scale = 1.0f / scale;
|
||||
at::cuda::CUDA_tensor_apply3<float, float, float>(
|
||||
output_grad,
|
||||
input,
|
||||
input_grad,
|
||||
[=] __device__(const float& dy, const float& x, float& dx) {
|
||||
int64_t Xq = std::nearbyint(x * inv_scale + zero_point);
|
||||
dx = (Xq >= quant_min && Xq <= quant_max) * dy;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
28
aten/src/ATen/native/quantized/cuda/fake_quantize_core.h
Normal file
28
aten/src/ATen/native/quantized/cuda/fake_quantize_core.h
Normal file
@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <cmath>
|
||||
|
||||
/* FakeQuantize Op for PerChannelAffine quantization scheme */
|
||||
namespace at {
|
||||
namespace native {
|
||||
void fake_quantize_slice_cuda(
|
||||
Tensor& output,
|
||||
const Tensor& input,
|
||||
float sc,
|
||||
int64_t z_point,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max);
|
||||
|
||||
void fake_quantize_grad_slice_cuda(
|
||||
Tensor& input_grad,
|
||||
const Tensor& output_grad,
|
||||
const Tensor& input,
|
||||
float sc,
|
||||
int64_t z_point,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max);
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
@ -0,0 +1,144 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <cmath>
|
||||
#include <ATen/native/quantized/cuda/fake_quantize_core.h>
|
||||
|
||||
/* FakeQuantize Op for PerChannelAffine quantization scheme */
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
/* Per channel fake-quantizes the 'inputs' tensor.
|
||||
Args:
|
||||
self: Forward input tensor.
|
||||
scale: scale of per channel affine quantization
|
||||
zero_point: zero_point of per channel affine quantization
|
||||
axis: int specifying the axis to be quantized
|
||||
quant_min: minimum quantized value
|
||||
quant_max: maximum quantized value
|
||||
Returns:
|
||||
Fake quantized tensor (double dtype).
|
||||
|
||||
*/
|
||||
Tensor fake_quantize_per_channel_affine_cuda(
|
||||
const Tensor& self,
|
||||
const Tensor& scale,
|
||||
const Tensor& zero_point,
|
||||
int64_t axis,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max) {
|
||||
TORCH_CHECK(self.is_cuda());
|
||||
TORCH_CHECK(self.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 1, "scale should be a 1-D tensor");
|
||||
TORCH_CHECK(
|
||||
zero_point.dim() == 1, "zero point should be a 1-D tensor");
|
||||
TORCH_CHECK(
|
||||
scale.numel() == zero_point.numel(),
|
||||
"scale and zero-point need to have the same dimensions");
|
||||
TORCH_CHECK(
|
||||
scale.numel() == self.size(axis),
|
||||
"dimensions of scale and zero-point are not consistent with input tensor")
|
||||
|
||||
TORCH_CHECK(
|
||||
quant_min <= quant_max,
|
||||
"`quant_min` should be less than or \
|
||||
equal to `quant_max`.");
|
||||
|
||||
TORCH_CHECK(
|
||||
at::min(zero_point).item().toLong() >= quant_min &&
|
||||
at::max(zero_point).item().toLong() <= quant_max,
|
||||
"`zero_point` must be between `quant_min` and `quant_max`.");
|
||||
|
||||
TORCH_CHECK(
|
||||
axis >= 0 && axis <= self.dim(),
|
||||
"`axis` must be between 0 and number of dimensions of input");
|
||||
|
||||
auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
|
||||
for (int i = 0; i < self.size(axis); i++) {
|
||||
auto X_slice = self.slice(axis, i, i + 1);
|
||||
auto Y_slice = Y.slice(axis, i, i + 1);
|
||||
float sc = scale[i].item().toFloat();
|
||||
int64_t zp = zero_point[i].item().toLong();
|
||||
fake_quantize_slice_cuda(Y_slice, X_slice, sc, zp, quant_min, quant_max);
|
||||
}
|
||||
return Y;
|
||||
}
|
||||
|
||||
/* Backward path for per-channel fake-quantization of the 'inputs' tensor.
|
||||
|
||||
Args:
|
||||
dY: Backward input tensor.
|
||||
X: Forward input tensor.
|
||||
scale: scale of per tensor affine quantization
|
||||
zero_point: zero_point of per tensor affine quantization
|
||||
axis: int ,the axis over which quantization parameters vary
|
||||
quant_min: int, minimum quantized value
|
||||
quant_max: int, maximum quantized value
|
||||
|
||||
Returns:
|
||||
Gradient for per channel fake quant (double dtype).
|
||||
|
||||
*/
|
||||
Tensor fake_quantize_per_channel_affine_backward_cuda(
|
||||
const Tensor& dY,
|
||||
const Tensor& X,
|
||||
const Tensor& scale,
|
||||
const Tensor& zero_point,
|
||||
int64_t axis,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max) {
|
||||
TORCH_CHECK(dY.is_cuda());
|
||||
TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(X.scalar_type() == ScalarType::Float);
|
||||
|
||||
TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
|
||||
TORCH_CHECK(
|
||||
quant_min <= quant_max,
|
||||
"`quant_min` should be less than or \
|
||||
equal to `quant_max`.");
|
||||
TORCH_CHECK(
|
||||
scale.dim() == 1, "scale should be a 1-D tensor");
|
||||
TORCH_CHECK(
|
||||
zero_point.dim() == 1, "zero point should be a 1-D tensor");
|
||||
TORCH_CHECK(
|
||||
scale.numel() == zero_point.numel(),
|
||||
"scale and zero-point need to have the same dimensions");
|
||||
TORCH_CHECK(
|
||||
scale.numel() == X.size(axis),
|
||||
"dimensions of scale and zero-point are not consistent with input tensor")
|
||||
|
||||
TORCH_CHECK(
|
||||
quant_min <= quant_max,
|
||||
"`quant_min` should be less than or \
|
||||
equal to `quant_max`.");
|
||||
|
||||
TORCH_CHECK(
|
||||
at::min(zero_point).item().toLong() >= quant_min &&
|
||||
at::max(zero_point).item().toLong() <= quant_max,
|
||||
"`zero_point` must be between `quant_min` and `quant_max`.");
|
||||
|
||||
TORCH_CHECK(
|
||||
axis >= 0 && axis <= X.dim(),
|
||||
"`axis` must be between 0 and number of dimensions of input");
|
||||
|
||||
if (X.numel() <= 0) {
|
||||
return X;
|
||||
}
|
||||
|
||||
auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
|
||||
|
||||
for (int i = 0; i < X.size(axis); i++) {
|
||||
auto dY_slice = dY.slice(axis, i, i + 1);
|
||||
auto X_slice = X.slice(axis, i, i + 1);
|
||||
auto dX_slice = dX.slice(axis, i, i + 1);
|
||||
float sc = scale[i].item().toFloat();
|
||||
int64_t zp = zero_point[i].item().toLong();
|
||||
fake_quantize_grad_slice_cuda(
|
||||
dX_slice, X_slice, dY_slice, sc, zp, quant_min, quant_max);
|
||||
}
|
||||
return dX;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
@ -2,6 +2,7 @@
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <cmath>
|
||||
#include <ATen/native/quantized/cuda/fake_quantize_core.h>
|
||||
|
||||
/* FakeQuantize Op for PerTensorAffine quantization scheme */
|
||||
namespace at {
|
||||
@ -9,21 +10,13 @@ namespace native {
|
||||
|
||||
/* Fake-quantizes the 'inputs' tensor.
|
||||
Args:
|
||||
X: Forward input tensor.
|
||||
self: Forward input tensor.
|
||||
scale: scale of per tensor affine quantization
|
||||
zero_point: zero_point of per tensor affine quantization
|
||||
quant_min: minimum quantized value
|
||||
quant_max: maximum quantized value
|
||||
quant_delay: Count of global steps for which to delay the quantization.
|
||||
See note below.
|
||||
iter: The current quantization iteration used for `quant_delay`.
|
||||
Returns:
|
||||
Quantized tensor (double dtype).
|
||||
|
||||
Notes:
|
||||
- quant_delay might be set to non-zero to help weights stabilize in the
|
||||
beginning of the training.
|
||||
- quantization range [quant_min, quant_max]
|
||||
*/
|
||||
Tensor fake_quantize_per_tensor_affine_cuda(
|
||||
const Tensor& self,
|
||||
@ -40,42 +33,22 @@ Tensor fake_quantize_per_tensor_affine_cuda(
|
||||
TORCH_CHECK(
|
||||
zero_point >= quant_min && zero_point <= quant_max,
|
||||
"`zero_point` must be between `quant_min` and `quant_max`.");
|
||||
auto Y = at::empty_like(self);
|
||||
|
||||
float inv_scale = 1.0f / scale;
|
||||
at::cuda::CUDA_tensor_apply2<float, float>(
|
||||
self, Y, [=] __device__(const float& input_val, float& result_val) {
|
||||
result_val = (fminf(
|
||||
quant_max,
|
||||
fmaxf(
|
||||
quant_min,
|
||||
static_cast<int64_t>(std::nearbyint(
|
||||
input_val * inv_scale + zero_point)))) -
|
||||
zero_point) *
|
||||
scale;
|
||||
});
|
||||
auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
|
||||
fake_quantize_slice_cuda(Y, self, scale, zero_point, quant_min, quant_max);
|
||||
return Y;
|
||||
}
|
||||
|
||||
/* Backward path to fake-quantize the 'inputs' tensor.
|
||||
|
||||
Args:
|
||||
X: Forward input tensor.
|
||||
dY: Backward input tensor.
|
||||
X: Forward input tensor.
|
||||
scale: scale of per tensor affine quantization
|
||||
zero_point: zero_point of per tensor affine quantization
|
||||
quant_min: minimum quantized value
|
||||
quant_max: maximum quantized value
|
||||
quant_delay: Count of global steps for which to delay the quantization.
|
||||
See note in forward.
|
||||
iter: The current quantization iteration used for `quant_delay`.
|
||||
Returns:
|
||||
Quantized tensor (double dtype).
|
||||
|
||||
Notes:
|
||||
- quant_delay might be set to non-zero to help weights stabilize in the
|
||||
beginning of the training.
|
||||
- quantization range [quant_min, quant_max]
|
||||
*/
|
||||
Tensor fake_quantize_per_tensor_affine_backward_cuda(
|
||||
const Tensor& dY,
|
||||
@ -100,14 +73,9 @@ Tensor fake_quantize_per_tensor_affine_backward_cuda(
|
||||
return X;
|
||||
}
|
||||
|
||||
auto dX = dY.clone();
|
||||
|
||||
float inv_scale = 1.0f / scale;
|
||||
at::cuda::CUDA_tensor_apply3<float, float, float>(
|
||||
dY, X, dX, [=] __device__(const float& dy, const float& x, float& dx) {
|
||||
int64_t Xq = std::nearbyint(x * inv_scale + zero_point);
|
||||
dx = (Xq >= quant_min && Xq <= quant_max) * dy;
|
||||
});
|
||||
auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
|
||||
fake_quantize_grad_slice_cuda(
|
||||
dX, X, dY, scale, zero_point, quant_min, quant_max);
|
||||
return dX;
|
||||
}
|
||||
|
||||
|
@ -297,7 +297,8 @@ Tensor quantize_tensor(Tensor rtensor, Tensor qtensor, double scale, int64_t zer
|
||||
}
|
||||
#endif
|
||||
auto qdata = qtensor.data_ptr<T>();
|
||||
for (int i = 0; i < rtensor.numel(); ++i) {
|
||||
auto numel = rtensor.numel();
|
||||
for (int i = 0; i < numel; ++i) {
|
||||
qdata[i] = quantize_val<T>(scale, zero_point, rdata[i]);
|
||||
}
|
||||
return qtensor;
|
||||
@ -318,7 +319,8 @@ Tensor dequantize_tensor(Tensor qtensor, Tensor rtensor, double scale, int64_t z
|
||||
checkZeroPoint<typename T::underlying>(fn_name, zero_point);
|
||||
const auto* qd = qtensor.data_ptr<T>();
|
||||
float* rd = rtensor.data_ptr<float>();
|
||||
for (auto i = 0; i < qtensor.numel(); ++i) {
|
||||
auto numel = qtensor.numel();
|
||||
for (auto i = 0; i < numel; ++i) {
|
||||
rd[i] = dequantize_val<T>(scale, zero_point, qd[i]);
|
||||
}
|
||||
return rtensor;
|
||||
|
@ -29,7 +29,8 @@ list(APPEND ATen_CPU_TEST_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/tensor_iterator_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp)
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce_ops_test.cpp)
|
||||
|
||||
list(APPEND ATen_CUDA_TEST_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
|
||||
|
24
aten/src/ATen/test/reduce_ops_test.cpp
Normal file
24
aten/src/ATen/test/reduce_ops_test.cpp
Normal file
@ -0,0 +1,24 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
using namespace at;
|
||||
|
||||
TEST(ReduceOpsTest, MaxValuesAndMinValues) {
|
||||
const int W = 10;
|
||||
const int H = 10;
|
||||
if (hasCUDA()) {
|
||||
for (const auto dtype : {kHalf, kFloat, kDouble, kShort, kInt, kLong}) {
|
||||
auto a = at::rand({H, W}, TensorOptions(kCUDA).dtype(at::kHalf));
|
||||
ASSERT_FLOAT_EQ(
|
||||
a.max_values(c10::IntArrayRef{0, 1}).item<double>(),
|
||||
a.max().item<double>()
|
||||
);
|
||||
ASSERT_FLOAT_EQ(
|
||||
a.min_values(c10::IntArrayRef{0, 1}).item<double>(),
|
||||
a.min().item<double>()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
@ -172,3 +172,12 @@ TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfOutputIsUndefined) {
|
||||
iter.compute_common_dtype_only_for_inputs();
|
||||
ASSERT_ANY_THROW(iter.build());
|
||||
}
|
||||
|
||||
TEST(TensorIteratorTest, FailNonPromotingBinaryOp) {
|
||||
Tensor out;
|
||||
auto iter = at::TensorIterator();
|
||||
iter.add_output(out);
|
||||
iter.add_input(at::ones({1,1}, at::dtype(at::kDouble)));
|
||||
iter.add_input(at::ones({1,1}, at::dtype(at::kInt)));
|
||||
ASSERT_ANY_THROW(iter.build());
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
#define Real QInt32
|
||||
#define RealUnderlying Int
|
||||
#define THQUANTIZED
|
||||
#define THQINT32
|
||||
#define TH_REAL_IS_BYTE
|
||||
#line 1 TH_GENERIC_FILE
|
||||
#include TH_GENERIC_FILE
|
||||
@ -15,6 +16,7 @@
|
||||
#undef Real
|
||||
#undef RealUnderlying
|
||||
#undef TH_REAL_IS_BYTE
|
||||
#undef THQINT32
|
||||
#undef THQUANTIZED
|
||||
|
||||
#ifndef THGenerateManyTypes
|
||||
|
@ -7,6 +7,7 @@
|
||||
#define Real QInt8
|
||||
#define RealUnderlying Char
|
||||
#define THQUANTIZED
|
||||
#define THQINT8
|
||||
#define TH_REAL_IS_BYTE
|
||||
#line 1 TH_GENERIC_FILE
|
||||
#include TH_GENERIC_FILE
|
||||
@ -15,6 +16,7 @@
|
||||
#undef Real
|
||||
#undef RealUnderlying
|
||||
#undef TH_REAL_IS_BYTE
|
||||
#undef THQINT8
|
||||
#undef THQUANTIZED
|
||||
|
||||
#ifndef THGenerateManyTypes
|
||||
|
@ -7,6 +7,7 @@
|
||||
#define Real QUInt8
|
||||
#define RealUnderlying Byte
|
||||
#define THQUANTIZED
|
||||
#define THQUINT8
|
||||
#define TH_REAL_IS_BYTE
|
||||
#line 1 TH_GENERIC_FILE
|
||||
#include TH_GENERIC_FILE
|
||||
@ -15,6 +16,7 @@
|
||||
#undef Real
|
||||
#undef RealUnderlying
|
||||
#undef TH_REAL_IS_BYTE
|
||||
#undef THQUINT8
|
||||
#undef THQUANTIZED
|
||||
|
||||
#ifndef THGenerateManyTypes
|
||||
|
@ -35,6 +35,9 @@
|
||||
#define THLongStorage THStorage
|
||||
#define THBoolStorage THStorage
|
||||
#define THBFloat16Storage THStorage
|
||||
#define THQUInt8Storage THStorage
|
||||
#define THQInt8Storage THStorage
|
||||
#define THQInt32Storage THStorage
|
||||
|
||||
TH_API scalar_t* THStorage_(data)(const THStorage*);
|
||||
TH_API ptrdiff_t THStorage_(size)(const THStorage*);
|
||||
|
@ -35,5 +35,14 @@ IMPLEMENT_THStorage_COPY(Double)
|
||||
IMPLEMENT_THStorage_COPY(Half)
|
||||
IMPLEMENT_THStorage_COPY(Bool)
|
||||
IMPLEMENT_THStorage_COPY(BFloat16)
|
||||
#ifdef THQUINT8
|
||||
IMPLEMENT_THStorage_COPY(QUInt8)
|
||||
#endif
|
||||
#ifdef THQINT8
|
||||
IMPLEMENT_THStorage_COPY(QInt8)
|
||||
#endif
|
||||
#ifdef THQINT32
|
||||
IMPLEMENT_THStorage_COPY(QInt32)
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
@ -14,5 +14,14 @@ TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *s
|
||||
TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src);
|
||||
TH_API void THStorage_(copyBool)(THStorage *storage, struct THBoolStorage *src);
|
||||
TH_API void THStorage_(copyBFloat16)(THStorage *storage, struct THBFloat16Storage *src);
|
||||
#ifdef THQUINT8
|
||||
TH_API void THStorage_(copyQUInt8)(THStorage *storage, struct THQUInt8Storage *src);
|
||||
#endif
|
||||
#ifdef THQINT8
|
||||
TH_API void THStorage_(copyQInt8)(THStorage *storage, struct THQInt8Storage *src);
|
||||
#endif
|
||||
#ifdef THQINT32
|
||||
TH_API void THStorage_(copyQInt32)(THStorage *storage, struct THQInt32Storage *src);
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
#include <type_traits>
|
||||
#include <stdexcept>
|
||||
|
||||
#ifndef _MSC_VER
|
||||
#pragma GCC diagnostic push
|
||||
|
@ -117,9 +117,12 @@ if ((NOT TARGET protobuf::libprotobuf) AND (NOT TARGET protobuf::libprotobuf-lit
|
||||
# "Please set the proper paths so that I can find protobuf correctly.")
|
||||
endif()
|
||||
|
||||
# Protobuf generated files use <> as inclusion path, so maybe we should use
|
||||
# SYSTEM inclusion path. But we need these include dirs to be found before
|
||||
# other protobuf include dirs in Anaconda
|
||||
get_target_property(__tmp protobuf::libprotobuf INTERFACE_INCLUDE_DIRECTORIES)
|
||||
message(STATUS "Caffe2 protobuf include directory: " ${__tmp})
|
||||
include_directories(BEFORE SYSTEM ${__tmp})
|
||||
include_directories(BEFORE ${__tmp})
|
||||
|
||||
# If Protobuf_VERSION is known (true in most cases, false if we are building
|
||||
# local protobuf), then we will add a protobuf version check in
|
||||
|
@ -53,6 +53,8 @@ extensions = [
|
||||
'sphinx.ext.napoleon',
|
||||
'sphinx.ext.viewcode',
|
||||
'sphinxcontrib.katex',
|
||||
'sphinx.ext.autosectionlabel',
|
||||
'javasphinx',
|
||||
]
|
||||
|
||||
# katex options
|
||||
|
@ -16,6 +16,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
|
||||
:caption: Notes
|
||||
|
||||
notes/*
|
||||
PyTorch on XLA Devices <http://pytorch.org/xla/>
|
||||
|
||||
.. toctree::
|
||||
:glob:
|
||||
@ -25,8 +26,8 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
|
||||
community/*
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Package Reference
|
||||
:maxdepth: 1
|
||||
:caption: Python API
|
||||
|
||||
torch
|
||||
nn
|
||||
@ -42,6 +43,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
|
||||
nn.init
|
||||
onnx
|
||||
optim
|
||||
quantization
|
||||
torch.random <random>
|
||||
sparse
|
||||
storage
|
||||
@ -53,6 +55,8 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
|
||||
torch.utils.model_zoo <model_zoo>
|
||||
torch.utils.tensorboard <tensorboard>
|
||||
type_info
|
||||
named_tensor
|
||||
name_inference
|
||||
torch.__config__ <__config__>
|
||||
|
||||
.. toctree::
|
||||
@ -62,9 +66,24 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
|
||||
|
||||
torchvision/index
|
||||
|
||||
* `torchaudio <https://pytorch.org/audio>`_
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: torchaudio Reference
|
||||
|
||||
* `torchtext <https://pytorch.org/text>`_
|
||||
torchaudio <https://pytorch.org/audio>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: torchtext Reference
|
||||
|
||||
torchtext <https://pytorch.org/text>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Other Languages
|
||||
|
||||
C++ API <https://pytorch.org/cppdocs/>
|
||||
packages
|
||||
|
||||
Indices and tables
|
||||
==================
|
||||
|
BIN
docs/source/math-quantizer-equation.png
Normal file
BIN
docs/source/math-quantizer-equation.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 9.7 KiB |
468
docs/source/name_inference.rst
Normal file
468
docs/source/name_inference.rst
Normal file
@ -0,0 +1,468 @@
|
||||
.. currentmodule:: torch
|
||||
|
||||
.. _name_inference_reference-doc:
|
||||
|
||||
Named Tensors operator coverage
|
||||
===============================
|
||||
|
||||
Please read :ref:`named_tensors-doc` first for an introduction to named tensors.
|
||||
|
||||
This document is a reference for *name inference*, a process that defines how
|
||||
named tensors:
|
||||
|
||||
1. use names to provide additional automatic runtime correctness checks
|
||||
2. propagate names from input tensors to output tensors
|
||||
|
||||
Below is a list of all operations that are supported with named tensors
|
||||
and their associated name inference rules.
|
||||
|
||||
If you don't see an operation listed here, but it would help your use case, please
|
||||
`search if an issue has already been filed <https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+named+tensor%22>`_ and if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_.
|
||||
|
||||
.. warning::
|
||||
The named tensor API is experimental and subject to change.
|
||||
|
||||
.. csv-table:: Supported Operations
|
||||
:header: API, Name inference rule
|
||||
:widths: 20, 20
|
||||
|
||||
":meth:`Tensor.abs`, :func:`torch.abs`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.abs_`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.acos`, :func:`torch.acos`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.acos_`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.add`, :func:`torch.add`",:ref:`unifies_names_from_inputs-doc`
|
||||
:meth:`Tensor.add_`,:ref:`unifies_names_from_inputs-doc`
|
||||
":meth:`Tensor.addmm`, :func:`torch.addmm`",:ref:`contracts_away_dims-doc`
|
||||
:meth:`Tensor.addmm_`,:ref:`contracts_away_dims-doc`
|
||||
":meth:`Tensor.addmv`, :func:`torch.addmv`",:ref:`contracts_away_dims-doc`
|
||||
:meth:`Tensor.addmv_`,:ref:`contracts_away_dims-doc`
|
||||
:meth:`Tensor.align_as`,See documentation
|
||||
:meth:`Tensor.align_to`,See documentation
|
||||
":meth:`Tensor.all`, :func:`torch.all`",None
|
||||
":meth:`Tensor.any`, :func:`torch.any`",None
|
||||
":meth:`Tensor.asin`, :func:`torch.asin`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.asin_`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.atan`, :func:`torch.atan`",:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.atan2`, :func:`torch.atan2`",:ref:`unifies_names_from_inputs-doc`
|
||||
:meth:`Tensor.atan2_`,:ref:`unifies_names_from_inputs-doc`
|
||||
:meth:`Tensor.atan_`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.bernoulli`, :func:`torch.bernoulli`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.bernoulli_`,None
|
||||
:meth:`Tensor.bfloat16`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.bitwise_not`, :func:`torch.bitwise_not`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.bitwise_not_`,None
|
||||
":meth:`Tensor.bmm`, :func:`torch.bmm`",:ref:`contracts_away_dims-doc`
|
||||
:meth:`Tensor.bool`,:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.byte`,:ref:`keeps_input_names-doc`
|
||||
:func:`torch.cat`,:ref:`unifies_names_from_inputs-doc`
|
||||
:meth:`Tensor.cauchy_`,None
|
||||
":meth:`Tensor.ceil`, :func:`torch.ceil`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.ceil_`,None
|
||||
:meth:`Tensor.char`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.chunk`, :func:`torch.chunk`",:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.clamp`, :func:`torch.clamp`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.clamp_`,None
|
||||
:meth:`Tensor.copy_`,:ref:`out_function_semantics-doc`
|
||||
":meth:`Tensor.cos`, :func:`torch.cos`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.cos_`,None
|
||||
":meth:`Tensor.cosh`, :func:`torch.cosh`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.cosh_`,None
|
||||
:meth:`Tensor.cpu`,:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.cuda`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.cumprod`, :func:`torch.cumprod`",:ref:`removes_dimensions-doc`
|
||||
":meth:`Tensor.cumsum`, :func:`torch.cumsum`",:ref:`removes_dimensions-doc`
|
||||
:meth:`Tensor.data_ptr`,None
|
||||
":meth:`Tensor.detach`, :func:`torch.detach`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.detach_`,None
|
||||
":attr:`Tensor.device`, :func:`torch.device`",None
|
||||
":meth:`Tensor.digamma`, :func:`torch.digamma`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.digamma_`,None
|
||||
:meth:`Tensor.dim`,None
|
||||
":meth:`Tensor.div`, :func:`torch.div`",:ref:`unifies_names_from_inputs-doc`
|
||||
:meth:`Tensor.div_`,:ref:`unifies_names_from_inputs-doc`
|
||||
":meth:`Tensor.dot`, :func:`torch.dot`",None
|
||||
:meth:`Tensor.double`,:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.element_size`,None
|
||||
:func:`torch.empty`,:ref:`factory-doc`
|
||||
:func:`torch.empty_like`,:ref:`factory-doc`
|
||||
":meth:`Tensor.eq`, :func:`torch.eq`",:ref:`unifies_names_from_inputs-doc`
|
||||
":meth:`Tensor.erf`, :func:`torch.erf`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.erf_`,None
|
||||
":meth:`Tensor.erfc`, :func:`torch.erfc`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.erfc_`,None
|
||||
":meth:`Tensor.erfinv`, :func:`torch.erfinv`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.erfinv_`,None
|
||||
":meth:`Tensor.exp`, :func:`torch.exp`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.exp_`,None
|
||||
:meth:`Tensor.expand`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.expm1`, :func:`torch.expm1`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.expm1_`,None
|
||||
:meth:`Tensor.exponential_`,None
|
||||
:meth:`Tensor.fill_`,None
|
||||
":meth:`Tensor.flatten`, :func:`torch.flatten`",See documentation
|
||||
:meth:`Tensor.float`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.floor`, :func:`torch.floor`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.floor_`,None
|
||||
":meth:`Tensor.frac`, :func:`torch.frac`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.frac_`,None
|
||||
":meth:`Tensor.ge`, :func:`torch.ge`",:ref:`unifies_names_from_inputs-doc`
|
||||
":meth:`Tensor.get_device`, :func:`torch.get_device`",None
|
||||
:attr:`Tensor.grad`,None
|
||||
":meth:`Tensor.gt`, :func:`torch.gt`",:ref:`unifies_names_from_inputs-doc`
|
||||
:meth:`Tensor.half`,:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.has_names`,See documentation
|
||||
":meth:`Tensor.index_fill`, :func:`torch.index_fill`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.index_fill_`,None
|
||||
:meth:`Tensor.int`,:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.is_contiguous`,None
|
||||
:attr:`Tensor.is_cuda`,None
|
||||
":meth:`Tensor.is_floating_point`, :func:`torch.is_floating_point`",None
|
||||
:attr:`Tensor.is_leaf`,None
|
||||
:meth:`Tensor.is_pinned`,None
|
||||
:meth:`Tensor.is_shared`,None
|
||||
":meth:`Tensor.is_signed`, :func:`torch.is_signed`",None
|
||||
:attr:`Tensor.is_sparse`,None
|
||||
:func:`torch.is_tensor`,None
|
||||
:meth:`Tensor.item`,None
|
||||
":meth:`Tensor.kthvalue`, :func:`torch.kthvalue`",:ref:`removes_dimensions-doc`
|
||||
":meth:`Tensor.le`, :func:`torch.le`",:ref:`unifies_names_from_inputs-doc`
|
||||
":meth:`Tensor.log`, :func:`torch.log`",:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.log10`, :func:`torch.log10`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.log10_`,None
|
||||
":meth:`Tensor.log1p`, :func:`torch.log1p`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.log1p_`,None
|
||||
":meth:`Tensor.log2`, :func:`torch.log2`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.log2_`,None
|
||||
:meth:`Tensor.log_`,None
|
||||
:meth:`Tensor.log_normal_`,None
|
||||
":meth:`Tensor.logical_not`, :func:`torch.logical_not`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.logical_not_`,None
|
||||
":meth:`Tensor.logsumexp`, :func:`torch.logsumexp`",:ref:`removes_dimensions-doc`
|
||||
:meth:`Tensor.long`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.lt`, :func:`torch.lt`",:ref:`unifies_names_from_inputs-doc`
|
||||
:func:`torch.manual_seed`,None
|
||||
":meth:`Tensor.masked_fill`, :func:`torch.masked_fill`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.masked_fill_`,None
|
||||
":meth:`Tensor.masked_select`, :func:`torch.masked_select`",Aligns mask up to input and then unifies_names_from_input_tensors
|
||||
":meth:`Tensor.matmul`, :func:`torch.matmul`",:ref:`contracts_away_dims-doc`
|
||||
":meth:`Tensor.mean`, :func:`torch.mean`",:ref:`removes_dimensions-doc`
|
||||
":meth:`Tensor.median`, :func:`torch.median`",:ref:`removes_dimensions-doc`
|
||||
":meth:`Tensor.mm`, :func:`torch.mm`",:ref:`contracts_away_dims-doc`
|
||||
":meth:`Tensor.mode`, :func:`torch.mode`",:ref:`removes_dimensions-doc`
|
||||
":meth:`Tensor.mul`, :func:`torch.mul`",:ref:`unifies_names_from_inputs-doc`
|
||||
:meth:`Tensor.mul_`,:ref:`unifies_names_from_inputs-doc`
|
||||
":meth:`Tensor.mv`, :func:`torch.mv`",:ref:`contracts_away_dims-doc`
|
||||
:attr:`Tensor.names`,See documentation
|
||||
":meth:`Tensor.narrow`, :func:`torch.narrow`",:ref:`keeps_input_names-doc`
|
||||
:attr:`Tensor.ndim`,None
|
||||
:meth:`Tensor.ndimension`,None
|
||||
":meth:`Tensor.ne`, :func:`torch.ne`",:ref:`unifies_names_from_inputs-doc`
|
||||
":meth:`Tensor.neg`, :func:`torch.neg`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.neg_`,None
|
||||
:func:`torch.normal`,:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.normal_`,None
|
||||
":meth:`Tensor.numel`, :func:`torch.numel`",None
|
||||
:func:`torch.ones`,:ref:`factory-doc`
|
||||
":meth:`Tensor.pow`, :func:`torch.pow`",:ref:`unifies_names_from_inputs-doc`
|
||||
:meth:`Tensor.pow_`,None
|
||||
":meth:`Tensor.prod`, :func:`torch.prod`",:ref:`removes_dimensions-doc`
|
||||
:func:`torch.rand`,:ref:`factory-doc`
|
||||
:func:`torch.rand`,:ref:`factory-doc`
|
||||
:func:`torch.randn`,:ref:`factory-doc`
|
||||
:func:`torch.randn`,:ref:`factory-doc`
|
||||
:meth:`Tensor.random_`,None
|
||||
":meth:`Tensor.reciprocal`, :func:`torch.reciprocal`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.reciprocal_`,None
|
||||
:meth:`Tensor.refine_names`,See documentation
|
||||
:meth:`Tensor.register_hook`,None
|
||||
:meth:`Tensor.rename`,See documentation
|
||||
:meth:`Tensor.rename_`,See documentation
|
||||
:attr:`Tensor.requires_grad`,None
|
||||
:meth:`Tensor.requires_grad_`,None
|
||||
:meth:`Tensor.resize_`,Only allow resizes that do not change shape
|
||||
:meth:`Tensor.resize_as_`,Only allow resizes that do not change shape
|
||||
":meth:`Tensor.round`, :func:`torch.round`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.round_`,None
|
||||
":meth:`Tensor.rsqrt`, :func:`torch.rsqrt`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.rsqrt_`,None
|
||||
":meth:`Tensor.select`, :func:`torch.select`",:ref:`removes_dimensions-doc`
|
||||
:meth:`Tensor.short`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.sigmoid`, :func:`torch.sigmoid`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.sigmoid_`,None
|
||||
":meth:`Tensor.sign`, :func:`torch.sign`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.sign_`,None
|
||||
":meth:`Tensor.sin`, :func:`torch.sin`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.sin_`,None
|
||||
":meth:`Tensor.sinh`, :func:`torch.sinh`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.sinh_`,None
|
||||
:meth:`Tensor.size`,None
|
||||
":meth:`Tensor.split`, :func:`torch.split`",:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.sqrt`, :func:`torch.sqrt`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.sqrt_`,None
|
||||
":meth:`Tensor.squeeze`, :func:`torch.squeeze`",:ref:`removes_dimensions-doc`
|
||||
":meth:`Tensor.std`, :func:`torch.std`",:ref:`removes_dimensions-doc`
|
||||
:func:`torch.std_mean`,:ref:`removes_dimensions-doc`
|
||||
:meth:`Tensor.stride`,None
|
||||
":meth:`Tensor.sub`, :func:`torch.sub`",:ref:`unifies_names_from_inputs-doc`
|
||||
:meth:`Tensor.sub_`,:ref:`unifies_names_from_inputs-doc`
|
||||
":meth:`Tensor.sum`, :func:`torch.sum`",:ref:`removes_dimensions-doc`
|
||||
":meth:`Tensor.tan`, :func:`torch.tan`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.tan_`,None
|
||||
":meth:`Tensor.tanh`, :func:`torch.tanh`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.tanh_`,None
|
||||
:func:`torch.tensor`,:ref:`factory-doc`
|
||||
:meth:`Tensor.to`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.topk`, :func:`torch.topk`",:ref:`removes_dimensions-doc`
|
||||
":meth:`Tensor.transpose`, :func:`torch.transpose`",:ref:`permutes_dimensions-doc`
|
||||
":meth:`Tensor.trunc`, :func:`torch.trunc`",:ref:`keeps_input_names-doc`
|
||||
:meth:`Tensor.trunc_`,None
|
||||
:meth:`Tensor.type`,None
|
||||
:meth:`Tensor.type_as`,:ref:`keeps_input_names-doc`
|
||||
":meth:`Tensor.unbind`, :func:`torch.unbind`",:ref:`removes_dimensions-doc`
|
||||
:meth:`Tensor.unflatten`,See documentation
|
||||
:meth:`Tensor.uniform_`,None
|
||||
":meth:`Tensor.var`, :func:`torch.var`",:ref:`removes_dimensions-doc`
|
||||
:func:`torch.var_mean`,:ref:`removes_dimensions-doc`
|
||||
:meth:`Tensor.zero_`,None
|
||||
:func:`torch.zeros`,:ref:`factory-doc`
|
||||
|
||||
|
||||
.. _keeps_input_names-doc:
|
||||
|
||||
Keeps input names
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
All pointwise unary functions follow this rule as well as some other unary functions.
|
||||
|
||||
- Check names: None
|
||||
- Propagate names: input tensor's names are propagated to the output.
|
||||
|
||||
::
|
||||
|
||||
>>> x = torch.randn(3, 3, names=('N', 'C'))
|
||||
>>> x.abs().names
|
||||
('N', 'C')
|
||||
|
||||
.. _removes_dimensions-doc:
|
||||
|
||||
Removes dimensions
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
All reduction ops like :meth:`~Tensor.sum` remove dimensions by reducing
|
||||
over the desired dimensions. Other operations like :meth:`~Tensor.select` and
|
||||
:meth:`~Tensor.squeeze` remove dimensions.
|
||||
|
||||
Wherever one can pass an integer dimension index to an operator, one can also pass
|
||||
a dimension name. Functions that take lists of dimension indices can also take in a
|
||||
list of dimension names.
|
||||
|
||||
- Check names: If :attr:`dim` or :attr:`dims` is passed in as a list of names,
|
||||
check that those names exist in :attr:`self`.
|
||||
- Propagate names: If the dimensions of the input tensor specified by :attr:`dim`
|
||||
or :attr:`dims` are not present in the output tensor, then the corresponding names
|
||||
of those dimensions do not appear in ``output.names``.
|
||||
|
||||
::
|
||||
|
||||
>>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
|
||||
>>> x.squeeze('N').names
|
||||
('C', 'H', 'W')
|
||||
|
||||
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
|
||||
>>> x.sum(['N', 'C']).names
|
||||
('H', 'W')
|
||||
|
||||
# Reduction ops with keepdim=True don't actually remove dimensions.
|
||||
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
|
||||
>>> x.sum(['N', 'C'], keepdim=True).names
|
||||
('N', 'C', 'H', 'W')
|
||||
|
||||
|
||||
.. _unifies_names_from_inputs-doc:
|
||||
|
||||
Unifies names from inputs
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
All binary arithmetic ops follow this rule. Operations that broadcast still
|
||||
broadcast positionally from the right to preserve compatibility with unnamed
|
||||
tensors. To perform explicit broadcasting by names, use :meth:`Tensor.align_as`.
|
||||
|
||||
- Check names: All names must match positionally from the right. i.e., in
|
||||
``tensor + other``, ``match(tensor.names[i], other.names[i])`` must be true for all
|
||||
``i`` in ``(-min(tensor.dim(), other.dim()) + 1, -1]``.
|
||||
- Check names: Furthermore, all named dimensions must be aligned from the right.
|
||||
During matching, if we match a named dimension ``A`` with an unnamed dimension
|
||||
``None``, then ``A`` must not appear in the tensor with the unnamed dimension.
|
||||
- Propagate names: unify pairs of names from the right from both tensors to
|
||||
produce output names.
|
||||
|
||||
For example,
|
||||
|
||||
::
|
||||
|
||||
# tensor: Tensor[ N, None]
|
||||
# other: Tensor[None, C]
|
||||
>>> tensor = torch.randn(3, 3, names=('N', None))
|
||||
>>> other = torch.randn(3, 3, names=(None, 'C'))
|
||||
>>> (tensor + other).names
|
||||
('N', 'C')
|
||||
|
||||
Check names:
|
||||
|
||||
- ``match(tensor.names[-1], other.names[-1])`` is ``True``
|
||||
- ``match(tensor.names[-2], tensor.names[-2])`` is ``True``
|
||||
- Because we matched ``None`` in :attr:`tensor` with ``'C'``,
|
||||
check to make sure ``'C'`` doesn't exist in :attr:`tensor` (it does not).
|
||||
- Check to make sure ``'N'`` doesn't exist in :attr:`other` (it does not).
|
||||
|
||||
Finally, the output names are computed with
|
||||
``[unify('N', None), unify(None, 'C')] = ['N', 'C']``
|
||||
|
||||
More examples::
|
||||
|
||||
# Dimensions don't match from the right:
|
||||
# tensor: Tensor[N, C]
|
||||
# other: Tensor[ N]
|
||||
>>> tensor = torch.randn(3, 3, names=('N', 'C'))
|
||||
>>> other = torch.randn(3, names=('N',))
|
||||
>>> (tensor + other).names
|
||||
RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims
|
||||
['N']: dim 'C' and dim 'N' are at the same position from the right but do
|
||||
not match.
|
||||
|
||||
# Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]:
|
||||
# tensor: Tensor[N, None]
|
||||
# other: Tensor[ N]
|
||||
>>> tensor = torch.randn(3, 3, names=('N', None))
|
||||
>>> other = torch.randn(3, names=('N',))
|
||||
>>> (tensor + other).names
|
||||
RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and
|
||||
dims ['N', None]: dim 'N' appears in a different position from the right
|
||||
across both lists.
|
||||
|
||||
.. note::
|
||||
|
||||
In both of the last examples, it is possible to align the tensors by names
|
||||
and then perform the addition. Use :meth:`Tensor.align_as` to align
|
||||
tensors by name or :meth:`Tensor.align_to` to align tensors to a custom
|
||||
dimension ordering.
|
||||
|
||||
.. _permutes_dimensions-doc:
|
||||
|
||||
Permutes dimensions
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Some operations, like :meth:`Tensor.t()`, permute the order of dimensions. Dimension names
|
||||
are attached to individual dimensions so they get permuted as well.
|
||||
|
||||
If the operator takes in positional index :attr:`dim`, it is also able to take a dimension
|
||||
name as :attr:`dim`.
|
||||
|
||||
- Check names: If :attr:`dim` is passed as a name, check that it exists in the tensor.
|
||||
- Propagate names: Permute dimension names in the same way as the dimensions that are
|
||||
being permuted.
|
||||
|
||||
::
|
||||
|
||||
>>> x = torch.randn(3, 3, names=('N', 'C'))
|
||||
>>> x.transpose('N', 'C').names
|
||||
('C', 'N')
|
||||
|
||||
.. _contracts_away_dims-doc:
|
||||
|
||||
Contracts away dims
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Matrix multiply functions follow some variant of this. Let's go through
|
||||
:func:`torch.mm` first and then generalize the rule for batch matrix multiplication.
|
||||
|
||||
For ``torch.mm(tensor, other)``:
|
||||
|
||||
- Check names: None
|
||||
- Propagate names: result names are ``(tensor.names[-2], other.names[-1])``.
|
||||
|
||||
::
|
||||
|
||||
>>> x = torch.randn(3, 3, names=('N', 'D'))
|
||||
>>> y = torch.randn(3, 3, names=('in', 'out'))
|
||||
>>> x.mm(y).names
|
||||
('N', 'out')
|
||||
|
||||
Inherently, a matrix multiplication performs a dot product over two dimensions,
|
||||
collapsing them. When two tensors are matrix-multiplied, the contracted dimensions
|
||||
disappear and do not show up in the output tensor.
|
||||
|
||||
:func:`torch.mv`, :func:`torch.dot` work in a similar way: name inference does not
|
||||
check input names and removes the dimensions that are involved in the dot product:
|
||||
|
||||
::
|
||||
|
||||
>>> x = torch.randn(3, 3, names=('N', 'D'))
|
||||
>>> y = torch.randn(3, names=('something',))
|
||||
>>> x.mv(y).names
|
||||
('N',)
|
||||
|
||||
Now, let's take a look at ``torch.matmul(tensor, other)``. Assume that ``tensor.dim() >= 2``
|
||||
and ``other.dim() >= 2``.
|
||||
|
||||
- Check names: Check that the batch dimensions of the inputs are aligned and broadcastable.
|
||||
See :ref:`unifies_names_from_inputs-doc` for what it means for the inputs to be aligned.
|
||||
- Propagate names: result names are obtained by unifying the batch dimensions and removing
|
||||
the contracted dimensions:
|
||||
``unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])``.
|
||||
|
||||
Examples::
|
||||
|
||||
# Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F'].
|
||||
# 'A', 'B' are batch dimensions.
|
||||
    >>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D'))
|
||||
    >>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F'))
|
||||
>>> torch.matmul(x, y).names
|
||||
('A', 'B', 'C', 'F')
|
||||
|
||||
|
||||
Finally, there are fused ``add`` versions of many matmul functions. i.e., :func:`addmm`
|
||||
and :func:`addmv`. These are treated as composing name inference for i.e. :func:`mm` and
|
||||
name inference for :func:`add`.
|
||||
|
||||
.. _factory-doc:
|
||||
|
||||
Factory functions
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Factory functions now take a new :attr:`names` argument that associates a name
|
||||
with each dimension.
|
||||
|
||||
::
|
||||
|
||||
>>> torch.zeros(2, 3, names=('N', 'C'))
|
||||
tensor([[0., 0., 0.],
|
||||
[0., 0., 0.]], names=('N', 'C'))
|
||||
|
||||
.. _out_function_semantics-doc:
|
||||
|
||||
out function and in-place variants
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
A tensor specified as an ``out=`` tensor has the following behavior:
|
||||
|
||||
- If it has no named dimensions, then the names computed from the operation
|
||||
get propagated to it.
|
||||
- If it has any named dimensions, then the names computed from the operation
|
||||
must be exactly equal to the existing names. Otherwise, the operation errors.
|
||||
|
||||
All in-place methods modify inputs to have names equal to the computed names
|
||||
from name inference. For example,
|
||||
|
||||
::
|
||||
|
||||
>>> x = torch.randn(3, 3)
|
||||
>>> y = torch.randn(3, 3, names=('N', 'C'))
|
||||
>>> x.names
|
||||
(None, None)
|
||||
|
||||
>>> x += y
|
||||
>>> x.names
|
||||
('N', 'C')
|
||||
|
319
docs/source/named_tensor.rst
Normal file
319
docs/source/named_tensor.rst
Normal file
@ -0,0 +1,319 @@
|
||||
.. currentmodule:: torch
|
||||
|
||||
.. _named_tensors-doc:
|
||||
|
||||
Named Tensors
|
||||
=============
|
||||
|
||||
Named Tensors aim to make tensors easier to use by allowing users to associate
|
||||
explicit names with tensor dimensions. In most cases, operations that take
|
||||
dimension parameters will accept dimension names, avoiding the need to track
|
||||
dimensions by position. In addition, named tensors use names to automatically
|
||||
check that APIs are being used correctly at runtime, providing extra safety.
|
||||
Names can also be used to rearrange dimensions, for example, to support
|
||||
"broadcasting by name" rather than "broadcasting by position".
|
||||
|
||||
.. warning::
|
||||
The named tensor API is experimental and subject to change.
|
||||
|
||||
Creating named tensors
|
||||
----------------------
|
||||
|
||||
Factory functions now take a new :attr:`names` argument that associates a name
|
||||
with each dimension.
|
||||
|
||||
::
|
||||
|
||||
>>> torch.zeros(2, 3, names=('N', 'C'))
|
||||
tensor([[0., 0., 0.],
|
||||
[0., 0., 0.]], names=('N', 'C'))
|
||||
|
||||
Named dimensions, like regular Tensor dimensions, are ordered.
|
||||
``tensor.names[i]`` is the name of dimension ``i`` of ``tensor``.
|
||||
|
||||
The following factory functions support named tensors:
|
||||
|
||||
- :func:`torch.empty`
|
||||
- :func:`torch.rand`
|
||||
- :func:`torch.randn`
|
||||
- :func:`torch.ones`
|
||||
- :func:`torch.tensor`
|
||||
- :func:`torch.zeros`
|
||||
|
||||
Named dimensions
|
||||
----------------
|
||||
|
||||
See :attr:`~Tensor.names` for restrictions on tensor names.
|
||||
|
||||
Use :attr:`~Tensor.names` to access the dimension names of a tensor and
|
||||
:meth:`~Tensor.rename` to rename named dimensions.
|
||||
|
||||
::
|
||||
|
||||
>>> imgs = torch.randn(1, 2, 2, 3 , names=('N', 'C', 'H', 'W'))
|
||||
>>> imgs.names
|
||||
('N', 'C', 'H', 'W')
|
||||
|
||||
>>> renamed_imgs = imgs.rename(H='height', W='width')
|
||||
>>> renamed_imgs.names
|
||||
    ('N', 'C', 'height', 'width')
|
||||
|
||||
|
||||
Named tensors can coexist with unnamed tensors; named tensors are instances of
|
||||
:class:`torch.Tensor`. Unnamed tensors have ``None``-named dimensions. Named
|
||||
tensors do not require all dimensions to be named.
|
||||
|
||||
::
|
||||
|
||||
>>> imgs = torch.randn(1, 2, 2, 3 , names=(None, 'C', 'H', 'W'))
|
||||
>>> imgs.names
|
||||
(None, 'C', 'H', 'W')
|
||||
|
||||
Name propagation semantics
|
||||
--------------------------
|
||||
|
||||
Named tensors use names to automatically check that APIs are being called
|
||||
correctly at runtime. This occurs in a process called *name inference*.
|
||||
More formally, name inference consists of the following two steps:
|
||||
|
||||
- **Check names**: an operator may perform automatic checks at runtime that
|
||||
check that certain dimension names must match.
|
||||
- **Propagate names**: name inference propagates names to output tensors.
|
||||
|
||||
All operations that support named tensors propagate names.
|
||||
|
||||
::
|
||||
|
||||
>>> x = torch.randn(3, 3, names=('N', 'C'))
|
||||
>>> x.abs().names
|
||||
('N', 'C')
|
||||
|
||||
|
||||
.. _match_semantics-doc:
|
||||
|
||||
match semantics
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
Two names *match* if they are equal (string equality) or if at least one is ``None``.
|
||||
Nones are essentially a special "wildcard" name.
|
||||
|
||||
``unify(A, B)`` determines which of the names ``A`` and ``B`` to propagate to the outputs.
|
||||
It returns the more *specific* of the two names, if they match. If the names do not match,
|
||||
then it errors.
|
||||
|
||||
.. note::
|
||||
In practice, when working with named tensors, one should avoid having unnamed
|
||||
dimensions because their handling can be complicated. It is recommended to lift
|
||||
all unnamed dimensions to be named dimensions by using :meth:`~Tensor.refine_names`.
|
||||
|
||||
|
||||
Basic name inference rules
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Let's see how ``match`` and ``unify`` are used in name inference in the case of
|
||||
adding two one-dim tensors with no broadcasting.
|
||||
|
||||
::
|
||||
|
||||
x = torch.randn(3, names=('X',))
|
||||
y = torch.randn(3)
|
||||
z = torch.randn(3, names=('Z',))
|
||||
|
||||
**Check names**: check that the names of the two tensors *match*.
|
||||
|
||||
For the following examples:
|
||||
|
||||
::
|
||||
|
||||
>>> # x + y # match('X', None) is True
|
||||
>>> # x + z # match('X', 'Z') is False
|
||||
>>> # x + x # match('X', 'X') is True
|
||||
|
||||
>>> x + z
|
||||
Error when attempting to broadcast dims ['X'] and dims ['Z']: dim 'X' and dim 'Z' are at the same position from the right but do not match.
|
||||
|
||||
**Propagate names**: *unify* the names to select which one to propagate.
|
||||
In the case of ``x + y``, ``unify('X', None) = 'X'`` because ``'X'`` is more
|
||||
specific than ``None``.
|
||||
|
||||
::
|
||||
|
||||
>>> (x + y).names
|
||||
('X',)
|
||||
>>> (x + x).names
|
||||
('X',)
|
||||
|
||||
For a comprehensive list of name inference rules, see :ref:`name_inference_reference-doc`.
|
||||
Here are two common operations that may be useful to go over:
|
||||
|
||||
- Binary arithmetic ops: :ref:`unifies_names_from_inputs-doc`
|
||||
- Matrix multiplication ops: :ref:`contracts_away_dims-doc`
|
||||
|
||||
Explicit alignment by names
|
||||
---------------------------
|
||||
|
||||
Use :meth:`~Tensor.align_as` or :meth:`~Tensor.align_to` to align tensor dimensions
|
||||
by name to a specified ordering. This is useful for performing "broadcasting by names".
|
||||
|
||||
::
|
||||
|
||||
# This function is agnostic to the dimension ordering of `input`,
|
||||
# as long as it has a `C` dimension somewhere.
|
||||
def scale_channels(input, scale):
|
||||
scale = scale.refine_names('C')
|
||||
return input * scale.align_as(input)
|
||||
|
||||
>>> num_channels = 3
|
||||
    >>> scale = torch.randn(num_channels, names=('C',))
|
||||
>>> imgs = torch.rand(3, 3, 3, num_channels, names=('N', 'H', 'W', 'C'))
|
||||
>>> more_imgs = torch.rand(3, num_channels, 3, 3, names=('N', 'C', 'H', 'W'))
|
||||
    >>> videos = torch.randn(3, num_channels, 3, 3, 3, names=('N', 'C', 'H', 'W', 'D'))
|
||||
|
||||
>>> scale_channels(imgs, scale)
|
||||
>>> scale_channels(more_imgs, scale)
|
||||
>>> scale_channels(videos, scale)
|
||||
|
||||
Manipulating dimensions
|
||||
-----------------------
|
||||
|
||||
Use :meth:`~Tensor.align_to` to permute large amounts of dimensions without
|
||||
mentioning all of them as required by :meth:`~Tensor.permute`.
|
||||
|
||||
::
|
||||
|
||||
>>> tensor = torch.randn(2, 2, 2, 2, 2, 2)
|
||||
>>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F')
|
||||
|
||||
# Move the F (dim 5) and E dimension (dim 4) to the front while keeping
|
||||
# the rest in the same order
|
||||
>>> tensor.permute(5, 4, 0, 1, 2, 3)
|
||||
>>> named_tensor.align_to('F', 'E', ...) # Use '...' instead in Python 2
|
||||
|
||||
Use :meth:`~Tensor.flatten` and :meth:`~Tensor.unflatten` to flatten and unflatten
|
||||
dimensions, respectively. These methods are more verbose than :meth:`~Tensor.view`
|
||||
and :meth:`~Tensor.reshape`, but have more semantic meaning to someone reading the code.
|
||||
|
||||
::
|
||||
|
||||
>>> imgs = torch.randn(32, 3, 128, 128)
|
||||
>>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
|
||||
|
||||
>>> flat_imgs = imgs.view(32, -1)
|
||||
>>> named_flat_imgs = named_imgs.flatten(['C', 'H', 'W'], 'features')
|
||||
>>> named_flat_imgs.names
|
||||
('N', 'features')
|
||||
|
||||
>>> unflattened_imgs = imgs.view(32, 3, 128, 128)
|
||||
>>> unflattened_named_imgs = named_flat_imgs.unflatten(
|
||||
'features', [('C', 3), ('H', 128), ('W', 128)])
|
||||
|
||||
.. _named_tensors_autograd-doc:
|
||||
|
||||
Autograd support
|
||||
----------------
|
||||
|
||||
Autograd currently supports named tensors in a limited manner: autograd ignores
|
||||
names on all tensors. Gradient computation is still correct but we lose the
|
||||
safety that names give us.
|
||||
|
||||
::
|
||||
|
||||
>>> x = torch.randn(3, names=('D',))
|
||||
>>> weight = torch.randn(3, names=('D',), requires_grad=True)
|
||||
>>> loss = (x - weight).abs()
|
||||
>>> grad_loss = torch.randn(3)
|
||||
>>> loss.backward(grad_loss)
|
||||
>>> weight.grad # Unnamed for now. Will be named in the future
|
||||
tensor([-1.8107, -0.6357, 0.0783])
|
||||
|
||||
>>> weight.grad.zero_()
|
||||
>>> grad_loss = grad_loss.refine_names('C')
|
||||
>>> loss = (x - weight).abs()
|
||||
# Ideally we'd check that the names of loss and grad_loss match but we don't yet.
|
||||
>>> loss.backward(grad_loss)
|
||||
>>> weight.grad
|
||||
tensor([-1.8107, -0.6357, 0.0783])
|
||||
|
||||
Currently supported operations and subsystems
|
||||
---------------------------------------------
|
||||
|
||||
Operators
|
||||
^^^^^^^^^
|
||||
|
||||
See :ref:`name_inference_reference-doc` for a full list of the supported torch and
|
||||
tensor operations. We do not yet support the following that is not covered by the link:
|
||||
|
||||
- indexing, advanced indexing.
|
||||
|
||||
For ``torch.nn.functional`` operators, we support the following:
|
||||
|
||||
- :func:`torch.nn.functional.relu`
|
||||
- :func:`torch.nn.functional.softmax`
|
||||
- :func:`torch.nn.functional.log_softmax`
|
||||
- :func:`torch.nn.functional.tanh`
|
||||
- :func:`torch.nn.functional.sigmoid`
|
||||
- :func:`torch.nn.functional.dropout`
|
||||
|
||||
Subsystems
|
||||
^^^^^^^^^^
|
||||
|
||||
Autograd is supported, see :ref:`named_tensors_autograd-doc`.
|
||||
Because gradients are currently unnamed, optimizers may work but are untested.
|
||||
|
||||
NN modules are currently unsupported. This can lead to the following when calling
|
||||
modules with named tensor inputs:
|
||||
|
||||
- NN module parameters are unnamed, so outputs may be partially named.
|
||||
- NN module forward passes have code that don't support named tensors and will
|
||||
error out appropriately.
|
||||
|
||||
We also do not support the following subsystems, though some may work out
|
||||
of the box:
|
||||
|
||||
- distributions
|
||||
- serialization (:func:`torch.load`, :func:`torch.save`)
|
||||
- multiprocessing
|
||||
- JIT
|
||||
- distributed
|
||||
- ONNX
|
||||
|
||||
If any of these would help your use case, please
|
||||
`search if an issue has already been filed <https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+named+tensor%22>`_
|
||||
and if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_.
|
||||
|
||||
Named tensor API reference
|
||||
--------------------------
|
||||
|
||||
In this section please find the documentation for named tensor specific APIs.
|
||||
For a comprehensive reference for how names are propagated through other PyTorch
|
||||
operators, see :ref:`name_inference_reference-doc`.
|
||||
|
||||
.. class:: Tensor()
|
||||
:noindex:
|
||||
|
||||
.. autoattribute:: names
|
||||
.. automethod:: rename
|
||||
.. automethod:: rename_
|
||||
.. automethod:: refine_names
|
||||
|
||||
.. automethod:: align_as
|
||||
.. automethod:: align_to
|
||||
|
||||
.. automethod:: unflatten
|
||||
.. py:method:: flatten(dims, out_dim) -> Tensor
|
||||
|
||||
Flattens :attr:`dims` into a single dimension with name :attr:`out_dim`.
|
||||
|
||||
All of `dims` must be consecutive in order in the :attr:`self` tensor,
|
||||
    but not necessarily contiguous in memory.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> imgs = torch.randn(32, 3, 128, 128, names=('N', 'C', 'H', 'W'))
|
||||
>>> flat_imgs = imgs.flatten(['C', 'H', 'W'], 'features')
|
||||
>>> flat_imgs.names, flat_imgs.shape
|
||||
(('N', 'features'), torch.Size([32, 49152]))
|
||||
|
||||
.. warning::
|
||||
The named tensor API is experimental and subject to change.
|
||||
|
@ -877,3 +877,10 @@ Utilities
|
||||
|
||||
.. autoclass:: Flatten
|
||||
:members:
|
||||
|
||||
|
||||
Quantized Functions
|
||||
--------------------
|
||||
|
||||
Quantization refers to techniques for performing computations and storing tensors at lower bitwidths than
|
||||
floating point precision. PyTorch supports both per tensor and per channel asymmetric linear quantization. To learn more how to use quantized functions in PyTorch, please refer to the :ref:`Quantization` documentation.
|
||||
|
67
docs/source/org/pytorch/DType.rst
Normal file
67
docs/source/org/pytorch/DType.rst
Normal file
@ -0,0 +1,67 @@
|
||||
DType
|
||||
=====
|
||||
|
||||
.. java:package:: org.pytorch
|
||||
:noindex:
|
||||
|
||||
.. java:type:: public enum DType
|
||||
|
||||
Codes representing tensor data types.
|
||||
|
||||
Enum Constants
|
||||
--------------
|
||||
FLOAT32
|
||||
^^^^^^^
|
||||
|
||||
.. java:field:: public static final DType FLOAT32
|
||||
:outertype: DType
|
||||
|
||||
Code for dtype torch.float32. \ :java:ref:`Tensor.dtype()`\
|
||||
|
||||
FLOAT64
|
||||
^^^^^^^
|
||||
|
||||
.. java:field:: public static final DType FLOAT64
|
||||
:outertype: DType
|
||||
|
||||
Code for dtype torch.float64. \ :java:ref:`Tensor.dtype()`\
|
||||
|
||||
INT32
|
||||
^^^^^
|
||||
|
||||
.. java:field:: public static final DType INT32
|
||||
:outertype: DType
|
||||
|
||||
Code for dtype torch.int32. \ :java:ref:`Tensor.dtype()`\
|
||||
|
||||
INT64
|
||||
^^^^^
|
||||
|
||||
.. java:field:: public static final DType INT64
|
||||
:outertype: DType
|
||||
|
||||
Code for dtype torch.int64. \ :java:ref:`Tensor.dtype()`\
|
||||
|
||||
INT8
|
||||
^^^^
|
||||
|
||||
.. java:field:: public static final DType INT8
|
||||
:outertype: DType
|
||||
|
||||
Code for dtype torch.int8. \ :java:ref:`Tensor.dtype()`\
|
||||
|
||||
UINT8
|
||||
^^^^^
|
||||
|
||||
.. java:field:: public static final DType UINT8
|
||||
:outertype: DType
|
||||
|
||||
Code for dtype torch.uint8. \ :java:ref:`Tensor.dtype()`\
|
||||
|
||||
Fields
|
||||
------
|
||||
jniCode
|
||||
^^^^^^^
|
||||
|
||||
.. java:field:: final int jniCode
|
||||
:outertype: DType
|
297
docs/source/org/pytorch/IValue.rst
Normal file
297
docs/source/org/pytorch/IValue.rst
Normal file
@ -0,0 +1,297 @@
|
||||
.. java:import:: java.util Locale
|
||||
|
||||
.. java:import:: java.util Map
|
||||
|
||||
IValue
|
||||
======
|
||||
|
||||
.. java:package:: org.pytorch
|
||||
:noindex:
|
||||
|
||||
.. java:type:: public class IValue
|
||||
|
||||
Java representation of a TorchScript value, which is implemented as tagged union that can be one of the supported types: https://pytorch.org/docs/stable/jit.html#types .
|
||||
|
||||
Calling \ ``toX``\ methods for inappropriate types will throw \ :java:ref:`IllegalStateException`\ .
|
||||
|
||||
\ ``IValue``\ objects are constructed with \ ``IValue.from(value)``\ , \ ``IValue.tupleFrom(value1, value2, ...)``\ , \ ``IValue.listFrom(value1, value2, ...)``\ , or one of the \ ``dict``\ methods, depending on the key type.
|
||||
|
||||
Data is retrieved from \ ``IValue``\ objects with the \ ``toX()``\ methods. Note that \ ``str``\ -type IValues must be extracted with \ :java:ref:`toStr()`\ , rather than \ :java:ref:`toString()`\ .
|
||||
|
||||
\ ``IValue``\ objects may retain references to objects passed into their constructors, and may return references to their internal state from \ ``toX()``\ .
|
||||
|
||||
Methods
|
||||
-------
|
||||
dictLongKeyFrom
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static IValue dictLongKeyFrom(Map<Long, IValue> map)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``Dict[int, V]``\ .
|
||||
|
||||
dictStringKeyFrom
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static IValue dictStringKeyFrom(Map<String, IValue> map)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``Dict[str, V]``\ .
|
||||
|
||||
from
|
||||
^^^^
|
||||
|
||||
.. java:method:: public static IValue from(Tensor tensor)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``Tensor``\ .
|
||||
|
||||
from
|
||||
^^^^
|
||||
|
||||
.. java:method:: public static IValue from(boolean value)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``bool``\ .
|
||||
|
||||
from
|
||||
^^^^
|
||||
|
||||
.. java:method:: public static IValue from(long value)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``int``\ .
|
||||
|
||||
from
|
||||
^^^^
|
||||
|
||||
.. java:method:: public static IValue from(double value)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``float``\ .
|
||||
|
||||
from
|
||||
^^^^
|
||||
|
||||
.. java:method:: public static IValue from(String value)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``str``\ .
|
||||
|
||||
isBool
|
||||
^^^^^^
|
||||
|
||||
.. java:method:: public boolean isBool()
|
||||
:outertype: IValue
|
||||
|
||||
isBoolList
|
||||
^^^^^^^^^^
|
||||
|
||||
.. java:method:: public boolean isBoolList()
|
||||
:outertype: IValue
|
||||
|
||||
isDictLongKey
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public boolean isDictLongKey()
|
||||
:outertype: IValue
|
||||
|
||||
isDictStringKey
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public boolean isDictStringKey()
|
||||
:outertype: IValue
|
||||
|
||||
isDouble
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public boolean isDouble()
|
||||
:outertype: IValue
|
||||
|
||||
isDoubleList
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public boolean isDoubleList()
|
||||
:outertype: IValue
|
||||
|
||||
isList
|
||||
^^^^^^
|
||||
|
||||
.. java:method:: public boolean isList()
|
||||
:outertype: IValue
|
||||
|
||||
isLong
|
||||
^^^^^^
|
||||
|
||||
.. java:method:: public boolean isLong()
|
||||
:outertype: IValue
|
||||
|
||||
isLongList
|
||||
^^^^^^^^^^
|
||||
|
||||
.. java:method:: public boolean isLongList()
|
||||
:outertype: IValue
|
||||
|
||||
isNull
|
||||
^^^^^^
|
||||
|
||||
.. java:method:: public boolean isNull()
|
||||
:outertype: IValue
|
||||
|
||||
isString
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public boolean isString()
|
||||
:outertype: IValue
|
||||
|
||||
isTensor
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public boolean isTensor()
|
||||
:outertype: IValue
|
||||
|
||||
isTensorList
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public boolean isTensorList()
|
||||
:outertype: IValue
|
||||
|
||||
isTuple
|
||||
^^^^^^^
|
||||
|
||||
.. java:method:: public boolean isTuple()
|
||||
:outertype: IValue
|
||||
|
||||
listFrom
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static IValue listFrom(boolean... list)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``List[bool]``\ .
|
||||
|
||||
listFrom
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static IValue listFrom(long... list)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``List[int]``\ .
|
||||
|
||||
listFrom
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static IValue listFrom(double... list)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``List[float]``\ .
|
||||
|
||||
listFrom
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static IValue listFrom(Tensor... list)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``List[Tensor]``\ .
|
||||
|
||||
listFrom
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static IValue listFrom(IValue... array)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``List[T]``\ . All elements must have the same type.
|
||||
|
||||
optionalNull
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static IValue optionalNull()
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``Optional``\ that contains no value.
|
||||
|
||||
toBool
|
||||
^^^^^^
|
||||
|
||||
.. java:method:: public boolean toBool()
|
||||
:outertype: IValue
|
||||
|
||||
toBoolList
|
||||
^^^^^^^^^^
|
||||
|
||||
.. java:method:: public boolean[] toBoolList()
|
||||
:outertype: IValue
|
||||
|
||||
toDictLongKey
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public Map<Long, IValue> toDictLongKey()
|
||||
:outertype: IValue
|
||||
|
||||
toDictStringKey
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public Map<String, IValue> toDictStringKey()
|
||||
:outertype: IValue
|
||||
|
||||
toDouble
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public double toDouble()
|
||||
:outertype: IValue
|
||||
|
||||
toDoubleList
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public double[] toDoubleList()
|
||||
:outertype: IValue
|
||||
|
||||
toList
|
||||
^^^^^^
|
||||
|
||||
.. java:method:: public IValue[] toList()
|
||||
:outertype: IValue
|
||||
|
||||
toLong
|
||||
^^^^^^
|
||||
|
||||
.. java:method:: public long toLong()
|
||||
:outertype: IValue
|
||||
|
||||
toLongList
|
||||
^^^^^^^^^^
|
||||
|
||||
.. java:method:: public long[] toLongList()
|
||||
:outertype: IValue
|
||||
|
||||
toStr
|
||||
^^^^^
|
||||
|
||||
.. java:method:: public String toStr()
|
||||
:outertype: IValue
|
||||
|
||||
toTensor
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public Tensor toTensor()
|
||||
:outertype: IValue
|
||||
|
||||
toTensorList
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public Tensor[] toTensorList()
|
||||
:outertype: IValue
|
||||
|
||||
toTuple
|
||||
^^^^^^^
|
||||
|
||||
.. java:method:: public IValue[] toTuple()
|
||||
:outertype: IValue
|
||||
|
||||
tupleFrom
|
||||
^^^^^^^^^
|
||||
|
||||
.. java:method:: public static IValue tupleFrom(IValue... array)
|
||||
:outertype: IValue
|
||||
|
||||
Creates a new \ ``IValue``\ of type \ ``Tuple[T0, T1, ...]``\ .
|
55
docs/source/org/pytorch/Module.rst
Normal file
55
docs/source/org/pytorch/Module.rst
Normal file
@ -0,0 +1,55 @@
|
||||
.. java:import:: com.facebook.jni HybridData
|
||||
|
||||
Module
|
||||
======
|
||||
|
||||
.. java:package:: org.pytorch
|
||||
:noindex:
|
||||
|
||||
.. java:type:: public class Module
|
||||
|
||||
Java wrapper for torch::jit::script::Module.
|
||||
|
||||
Methods
|
||||
-------
|
||||
destroy
|
||||
^^^^^^^
|
||||
|
||||
.. java:method:: public void destroy()
|
||||
:outertype: Module
|
||||
|
||||
Explicitly destroys the native torch::jit::script::Module. Calling this method is not required, as the native object will be destroyed when this object is garbage-collected. However, the timing of garbage collection is not guaranteed, so proactively calling \ ``destroy``\ can free memory more quickly. See \ :java:ref:`com.facebook.jni.HybridData.resetNative`\ .
|
||||
|
||||
forward
|
||||
^^^^^^^
|
||||
|
||||
.. java:method:: public IValue forward(IValue... inputs)
|
||||
:outertype: Module
|
||||
|
||||
Runs the 'forward' method of this module with the specified arguments.
|
||||
|
||||
:param inputs: arguments for the TorchScript module's 'forward' method.
|
||||
:return: return value from the 'forward' method.
|
||||
|
||||
load
|
||||
^^^^
|
||||
|
||||
.. java:method:: public static Module load(String modelPath)
|
||||
:outertype: Module
|
||||
|
||||
Loads a serialized TorchScript module from the specified path on the disk.
|
||||
|
||||
:param modelPath: path to file that contains the serialized TorchScript module.
|
||||
:return: new \ :java:ref:`org.pytorch.Module`\ object which owns torch::jit::script::Module.
|
||||
|
||||
runMethod
|
||||
^^^^^^^^^
|
||||
|
||||
.. java:method:: public IValue runMethod(String methodName, IValue... inputs)
|
||||
:outertype: Module
|
||||
|
||||
Runs the specified method of this module with the specified arguments.
|
||||
|
||||
:param methodName: name of the TorchScript method to run.
|
||||
:param inputs: arguments that will be passed to TorchScript method.
|
||||
:return: return value from the method.
|
315
docs/source/org/pytorch/Tensor.rst
Normal file
315
docs/source/org/pytorch/Tensor.rst
Normal file
@ -0,0 +1,315 @@
|
||||
.. java:import:: java.nio Buffer
|
||||
|
||||
.. java:import:: java.nio ByteBuffer
|
||||
|
||||
.. java:import:: java.nio ByteOrder
|
||||
|
||||
.. java:import:: java.nio DoubleBuffer
|
||||
|
||||
.. java:import:: java.nio FloatBuffer
|
||||
|
||||
.. java:import:: java.nio IntBuffer
|
||||
|
||||
.. java:import:: java.nio LongBuffer
|
||||
|
||||
.. java:import:: java.util Arrays
|
||||
|
||||
.. java:import:: java.util Locale
|
||||
|
||||
Tensor
|
||||
======
|
||||
|
||||
.. java:package:: org.pytorch
|
||||
:noindex:
|
||||
|
||||
.. java:type:: public abstract class Tensor
|
||||
|
||||
Representation of a Tensor. Behavior is similar to PyTorch's tensor objects.
|
||||
|
||||
Most tensors will be constructed as \ ``Tensor.fromBlob(data, shape)``\ , where \ ``data``\ can be an array or a direct \ :java:ref:`Buffer`\ (of the proper subclass). Helper methods are provided to allocate buffers properly.
|
||||
|
||||
To access Tensor data, see \ :java:ref:`dtype()`\ , \ :java:ref:`shape()`\ , and various \ ``getDataAs*``\ methods.
|
||||
|
||||
When constructing \ ``Tensor``\ objects with \ ``data``\ as an array, it is not specified whether this data is copied or retained as a reference so it is recommended not to modify it after constructing. \ ``data``\ passed as a \ :java:ref:`Buffer`\ is not copied, so it can be modified between \ :java:ref:`Module`\ calls to avoid reallocation. Data retrieved from \ ``Tensor``\ objects may be copied or may be a reference to the \ ``Tensor``\ 's internal data buffer. \ ``shape``\ is always copied.
|
||||
|
||||
Methods
|
||||
-------
|
||||
allocateByteBuffer
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static ByteBuffer allocateByteBuffer(int numElements)
|
||||
:outertype: Tensor
|
||||
|
||||
Allocates a new direct \ :java:ref:`java.nio.ByteBuffer`\ with native byte order with specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(ByteBuffer,long[])`\ , \ :java:ref:`Tensor.fromBlobUnsigned(ByteBuffer,long[])`\ .
|
||||
|
||||
:param numElements: capacity (number of elements) of result buffer.
|
||||
|
||||
allocateDoubleBuffer
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static DoubleBuffer allocateDoubleBuffer(int numElements)
|
||||
:outertype: Tensor
|
||||
|
||||
Allocates a new direct \ :java:ref:`java.nio.DoubleBuffer`\ with native byte order with specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(DoubleBuffer,long[])`\ .
|
||||
|
||||
:param numElements: capacity (number of elements) of result buffer.
|
||||
|
||||
allocateFloatBuffer
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static FloatBuffer allocateFloatBuffer(int numElements)
|
||||
:outertype: Tensor
|
||||
|
||||
Allocates a new direct \ :java:ref:`java.nio.FloatBuffer`\ with native byte order with specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(FloatBuffer,long[])`\ .
|
||||
|
||||
:param numElements: capacity (number of elements) of result buffer.
|
||||
|
||||
allocateIntBuffer
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static IntBuffer allocateIntBuffer(int numElements)
|
||||
:outertype: Tensor
|
||||
|
||||
Allocates a new direct \ :java:ref:`java.nio.IntBuffer`\ with native byte order with specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(IntBuffer,long[])`\ .
|
||||
|
||||
:param numElements: capacity (number of elements) of result buffer.
|
||||
|
||||
allocateLongBuffer
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static LongBuffer allocateLongBuffer(int numElements)
|
||||
:outertype: Tensor
|
||||
|
||||
Allocates a new direct \ :java:ref:`java.nio.LongBuffer`\ with native byte order with specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(LongBuffer,long[])`\ .
|
||||
|
||||
:param numElements: capacity (number of elements) of result buffer.
|
||||
|
||||
dtype
|
||||
^^^^^
|
||||
|
||||
.. java:method:: public abstract DType dtype()
|
||||
:outertype: Tensor
|
||||
|
||||
:return: data type of this tensor.
|
||||
|
||||
dtypeJniCode
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: int dtypeJniCode()
|
||||
:outertype: Tensor
|
||||
|
||||
fromBlob
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlob(byte[] data, long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of bytes.
|
||||
|
||||
:param data: Tensor elements
|
||||
:param shape: Tensor shape
|
||||
|
||||
fromBlob
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlob(int[] data, long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of ints.
|
||||
|
||||
:param data: Tensor elements
|
||||
:param shape: Tensor shape
|
||||
|
||||
fromBlob
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlob(float[] data, long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array of floats.
|
||||
|
||||
:param data: Tensor elements
|
||||
:param shape: Tensor shape
|
||||
|
||||
fromBlob
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlob(long[] data, long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of longs.
|
||||
|
||||
:param data: Tensor elements
|
||||
:param shape: Tensor shape
|
||||
|
||||
fromBlob
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlob(long[] shape, double[] data)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.float64 with specified shape and data as array of doubles.
|
||||
|
||||
:param shape: Tensor shape
|
||||
:param data: Tensor elements
|
||||
|
||||
fromBlob
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlob(ByteBuffer data, long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.int8 with specified shape and data.
|
||||
|
||||
:param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
|
||||
:param shape: Tensor shape
|
||||
|
||||
fromBlob
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlob(IntBuffer data, long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.int32 with specified shape and data.
|
||||
|
||||
:param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
|
||||
:param shape: Tensor shape
|
||||
|
||||
fromBlob
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlob(FloatBuffer data, long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.float32 with specified shape and data.
|
||||
|
||||
:param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
|
||||
:param shape: Tensor shape
|
||||
|
||||
fromBlob
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlob(LongBuffer data, long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.int64 with specified shape and data.
|
||||
|
||||
:param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
|
||||
:param shape: Tensor shape
|
||||
|
||||
fromBlob
|
||||
^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlob(DoubleBuffer data, long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.float64 with specified shape and data.
|
||||
|
||||
:param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
|
||||
:param shape: Tensor shape
|
||||
|
||||
fromBlobUnsigned
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlobUnsigned(byte[] data, long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of bytes.
|
||||
|
||||
:param data: Tensor elements
|
||||
:param shape: Tensor shape
|
||||
|
||||
fromBlobUnsigned
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Creates a new Tensor instance with dtype torch.uint8 with specified shape and data.
|
||||
|
||||
:param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
|
||||
:param shape: Tensor shape
|
||||
|
||||
getDataAsByteArray
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public byte[] getDataAsByteArray()
|
||||
:outertype: Tensor
|
||||
|
||||
:throws IllegalStateException: if it is called for a non-int8 tensor.
|
||||
:return: a Java byte array that contains the tensor data. This may be a copy or reference.
|
||||
|
||||
getDataAsDoubleArray
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public double[] getDataAsDoubleArray()
|
||||
:outertype: Tensor
|
||||
|
||||
:throws IllegalStateException: if it is called for a non-float64 tensor.
|
||||
:return: a Java double array that contains the tensor data. This may be a copy or reference.
|
||||
|
||||
getDataAsFloatArray
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public float[] getDataAsFloatArray()
|
||||
:outertype: Tensor
|
||||
|
||||
:throws IllegalStateException: if it is called for a non-float32 tensor.
|
||||
:return: a Java float array that contains the tensor data. This may be a copy or reference.
|
||||
|
||||
getDataAsIntArray
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public int[] getDataAsIntArray()
|
||||
:outertype: Tensor
|
||||
|
||||
:throws IllegalStateException: if it is called for a non-int32 tensor.
|
||||
:return: a Java int array that contains the tensor data. This may be a copy or reference.
|
||||
|
||||
getDataAsLongArray
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public long[] getDataAsLongArray()
|
||||
:outertype: Tensor
|
||||
|
||||
:throws IllegalStateException: if it is called for a non-int64 tensor.
|
||||
:return: a Java long array that contains the tensor data. This may be a copy or reference.
|
||||
|
||||
getDataAsUnsignedByteArray
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public byte[] getDataAsUnsignedByteArray()
|
||||
:outertype: Tensor
|
||||
|
||||
:throws IllegalStateException: if it is called for a non-uint8 tensor.
|
||||
:return: a Java byte array that contains the tensor data. This may be a copy or reference.
|
||||
|
||||
getRawDataBuffer
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: Buffer getRawDataBuffer()
|
||||
:outertype: Tensor
|
||||
|
||||
numel
|
||||
^^^^^
|
||||
|
||||
.. java:method:: public long numel()
|
||||
:outertype: Tensor
|
||||
|
||||
Returns the number of elements in this tensor.
|
||||
|
||||
numel
|
||||
^^^^^
|
||||
|
||||
.. java:method:: public static long numel(long[] shape)
|
||||
:outertype: Tensor
|
||||
|
||||
Calculates the number of elements in a tensor with the specified shape.
|
||||
|
||||
shape
|
||||
^^^^^
|
||||
|
||||
.. java:method:: public long[] shape()
|
||||
:outertype: Tensor
|
||||
|
||||
Returns the shape of this tensor. (The array is a fresh copy.)
|
114
docs/source/org/pytorch/TensorImageUtils.rst
Normal file
114
docs/source/org/pytorch/TensorImageUtils.rst
Normal file
@ -0,0 +1,114 @@
|
||||
.. java:import:: android.graphics Bitmap
|
||||
|
||||
.. java:import:: android.graphics ImageFormat
|
||||
|
||||
.. java:import:: android.media Image
|
||||
|
||||
.. java:import:: org.pytorch Tensor
|
||||
|
||||
.. java:import:: java.nio ByteBuffer
|
||||
|
||||
.. java:import:: java.nio FloatBuffer
|
||||
|
||||
.. java:import:: java.util Locale
|
||||
|
||||
TensorImageUtils
|
||||
================
|
||||
|
||||
.. java:package:: org.pytorch.torchvision
|
||||
:noindex:
|
||||
|
||||
.. java:type:: public final class TensorImageUtils
|
||||
|
||||
Contains utility functions for \ :java:ref:`org.pytorch.Tensor`\ creation from \ :java:ref:`android.graphics.Bitmap`\ or \ :java:ref:`android.media.Image`\ source.
|
||||
|
||||
Fields
|
||||
------
|
||||
TORCHVISION_NORM_MEAN_RGB
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:field:: public static float[] TORCHVISION_NORM_MEAN_RGB
|
||||
:outertype: TensorImageUtils
|
||||
|
||||
TORCHVISION_NORM_STD_RGB
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:field:: public static float[] TORCHVISION_NORM_STD_RGB
|
||||
:outertype: TensorImageUtils
|
||||
|
||||
Methods
|
||||
-------
|
||||
bitmapToFloat32Tensor
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor bitmapToFloat32Tensor(Bitmap bitmap, float[] normMeanRGB, float[] normStdRGB)
|
||||
:outertype: TensorImageUtils
|
||||
|
||||
Creates new \ :java:ref:`org.pytorch.Tensor`\ from full \ :java:ref:`android.graphics.Bitmap`\ , normalized with specified in parameters mean and std.
|
||||
|
||||
:param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order
|
||||
:param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
|
||||
bitmapToFloat32Tensor
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor bitmapToFloat32Tensor(Bitmap bitmap, int x, int y, int width, int height, float[] normMeanRGB, float[] normStdRGB)
|
||||
:outertype: TensorImageUtils
|
||||
|
||||
Creates new \ :java:ref:`org.pytorch.Tensor`\ from specified area of \ :java:ref:`android.graphics.Bitmap`\ , normalized with specified in parameters mean and std.
|
||||
|
||||
:param bitmap: \ :java:ref:`android.graphics.Bitmap`\ as a source for Tensor data
|
||||
:param x: - x coordinate of top left corner of bitmap's area
|
||||
:param y: - y coordinate of top left corner of bitmap's area
|
||||
:param width: - width of bitmap's area
|
||||
:param height: - height of bitmap's area
|
||||
:param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order
|
||||
:param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
|
||||
bitmapToFloatBuffer
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static void bitmapToFloatBuffer(Bitmap bitmap, int x, int y, int width, int height, float[] normMeanRGB, float[] normStdRGB, FloatBuffer outBuffer, int outBufferOffset)
|
||||
:outertype: TensorImageUtils
|
||||
|
||||
Writes tensor content from specified \ :java:ref:`android.graphics.Bitmap`\ , normalized with specified in parameters mean and std to specified \ :java:ref:`java.nio.FloatBuffer`\ with specified offset.
|
||||
|
||||
:param bitmap: \ :java:ref:`android.graphics.Bitmap`\ as a source for Tensor data
|
||||
:param x: - x coordinate of top left corner of bitmap's area
|
||||
:param y: - y coordinate of top left corner of bitmap's area
|
||||
:param width: - width of bitmap's area
|
||||
:param height: - height of bitmap's area
|
||||
:param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order
|
||||
:param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
|
||||
imageYUV420CenterCropToFloat32Tensor
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static Tensor imageYUV420CenterCropToFloat32Tensor(Image image, int rotateCWDegrees, int tensorWidth, int tensorHeight, float[] normMeanRGB, float[] normStdRGB)
|
||||
:outertype: TensorImageUtils
|
||||
|
||||
Creates new \ :java:ref:`org.pytorch.Tensor`\ from specified area of \ :java:ref:`android.media.Image`\ , doing optional rotation, scaling (nearest) and center cropping.
|
||||
|
||||
:param image: \ :java:ref:`android.media.Image`\ as a source for Tensor data
|
||||
:param rotateCWDegrees: Clockwise angle through which the input image needs to be rotated to be upright. Range of valid values: 0, 90, 180, 270
|
||||
:param tensorWidth: return tensor width, must be positive
|
||||
:param tensorHeight: return tensor height, must be positive
|
||||
:param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order
|
||||
:param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
|
||||
imageYUV420CenterCropToFloatBuffer
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. java:method:: public static void imageYUV420CenterCropToFloatBuffer(Image image, int rotateCWDegrees, int tensorWidth, int tensorHeight, float[] normMeanRGB, float[] normStdRGB, FloatBuffer outBuffer, int outBufferOffset)
|
||||
:outertype: TensorImageUtils
|
||||
|
||||
Writes tensor content from specified \ :java:ref:`android.media.Image`\ , doing optional rotation, scaling (nearest) and center cropping to specified \ :java:ref:`java.nio.FloatBuffer`\ with specified offset.
|
||||
|
||||
:param image: \ :java:ref:`android.media.Image`\ as a source for Tensor data
|
||||
:param rotateCWDegrees: Clockwise angle through which the input image needs to be rotated to be upright. Range of valid values: 0, 90, 180, 270
|
||||
:param tensorWidth: return tensor width, must be positive
|
||||
:param tensorHeight: return tensor height, must be positive
|
||||
:param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order
|
||||
:param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
:param outBuffer: Output buffer, where tensor content will be written
|
||||
:param outBufferOffset: Output buffer offset with which tensor content will be written
|
18
docs/source/org/pytorch/package-index.rst
Normal file
18
docs/source/org/pytorch/package-index.rst
Normal file
@ -0,0 +1,18 @@
|
||||
org.pytorch
|
||||
===========
|
||||
|
||||
.. java:package:: org.pytorch
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
DType
|
||||
IValue
|
||||
Module
|
||||
Tensor
|
||||
Tensor-Tensor_float32
|
||||
Tensor-Tensor_float64
|
||||
Tensor-Tensor_int32
|
||||
Tensor-Tensor_int64
|
||||
Tensor-Tensor_int8
|
||||
Tensor-Tensor_uint8
|
9
docs/source/org/pytorch/torchvision/package-index.rst
Normal file
9
docs/source/org/pytorch/torchvision/package-index.rst
Normal file
@ -0,0 +1,9 @@
|
||||
org.pytorch.torchvision
|
||||
=======================
|
||||
|
||||
.. java:package:: org.pytorch.torchvision
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
TensorImageUtils
|
29
docs/source/packages.rst
Normal file
29
docs/source/packages.rst
Normal file
@ -0,0 +1,29 @@
|
||||
Javadoc
|
||||
=========
|
||||
|
||||
Main part of the java api includes 3 classes:
|
||||
::
|
||||
org.pytorch.Tensor
|
||||
org.pytorch.IValue
|
||||
org.pytorch.Module
|
||||
|
||||
If the reader is familiar with the PyTorch Python API, we can think that
|
||||
``org.pytorch.Tensor`` represents ``torch.tensor``, ``org.pytorch.Module`` represents ``torch.Module``,
|
||||
while ``org.pytorch.IValue`` represents value of TorchScript variable, supporting all
|
||||
its `types <https://pytorch.org/docs/stable/jit.html#types>`_.
|
||||
|
||||
Additionally, there is the DType class which contains code representing tensor data types and
|
||||
TensorImageUtils which contains utility functions for \ :java:ref:`org.pytorch.Tensor`\ creation from \ :java:ref:`android.graphics.Bitmap`\ or \ :java:ref:`android.media.Image`\ source.
|
||||
|
||||
You can find details to each of these classes linked below:
|
||||
|
||||
.. java:package:: org.pytorch
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
org/pytorch/DType
|
||||
org/pytorch/Tensor
|
||||
org/pytorch/IValue
|
||||
org/pytorch/Module
|
||||
org/pytorch/TensorImageUtils
|
583
docs/source/quantization.rst
Normal file
583
docs/source/quantization.rst
Normal file
@ -0,0 +1,583 @@
|
||||
Quantization
|
||||
===========================
|
||||
|
||||
Introduction to Quantization
|
||||
----------------------------
|
||||
|
||||
Quantization refers to techniques for performing computations and storing tensors at lower bitwidths than
|
||||
floating point precision. A quantized model executes some or all of the operations on tensors with
|
||||
integers rather than floating point values. This allows for a more
|
||||
compact model representation and the use of high performance vectorized
|
||||
operations on many hardware platforms. PyTorch supports INT8
|
||||
quantization compared to typical FP32 models allowing for a 4x reduction in the model size and
|
||||
a 4x reduction in memory bandwidth requirements. Hardware support for INT8 computations
|
||||
is typically 2 to 4 times faster compared to FP32 compute. Quantization is primarily a technique
|
||||
to speed up inference and only the forward pass is supported for quantized operators.
|
||||
|
||||
PyTorch supports multiple approaches to quantizing a deep learning model. In most cases the model is trained
|
||||
in FP32 and then the model is converted to INT8. In addition, PyTorch also supports quantization aware
|
||||
training, which models quantization errors in both the forward and backward passes using fake-quantization
|
||||
modules. Note that the entire computation is carried out in floating point. At the end of quantization aware
|
||||
training, PyTorch provides conversion functions to convert the trained model into lower precision.
|
||||
|
||||
At lower level, PyTorch provides a way to represent quantized tensors and
|
||||
perform operations with them. They can be used to directly construct models that
|
||||
perform all or part of the computation in lower precision. Higher-level APIs are
|
||||
provided that incorporate typical workflows of converting FP32 model to lower
|
||||
precision with minimal accuracy loss.
|
||||
|
||||
Today, PyTorch supports the following backends for running quantized operators efficiently:
|
||||
|
||||
* x86 CPUs with AVX2 support or higher (without AVX2 some operations have inefficient implementations)
|
||||
* ARM CPUs (typically found in mobile/embedded devices)
|
||||
|
||||
The corresponding implementation is chosen automatically based on the PyTorch build mode.
|
||||
|
||||
.. note::
|
||||
|
||||
PyTorch 1.3 doesn't provide quantized operator implementations on CUDA yet - this is a direction of future work.
|
||||
Move the model to CPU in order to test the quantized functionality.
|
||||
|
||||
Quantization-aware training (through :class:`~torch.quantization.FakeQuantize`) supports both CPU and CUDA.
|
||||
|
||||
Quantized Tensors
|
||||
---------------------------------------
|
||||
|
||||
PyTorch supports both per tensor and per channel asymmetric linear
|
||||
quantization. Per tensor means that all the values within the tensor are
|
||||
scaled the same way. Per channel means that for each dimension, typically
|
||||
the channel dimension of a tensor, the values
|
||||
in the tensor are scaled and offset by a different value (effectively
|
||||
the scale and offset become vectors). This allows for lesser error in converting tensors
|
||||
to quantized values.
|
||||
|
||||
The mapping is performed by converting the floating point tensors using
|
||||
|
||||
.. image:: math-quantizer-equation.png
|
||||
:width: 40%
|
||||
|
||||
Note that we ensure that zero in floating point is represented with no error after quantization,
|
||||
thereby ensuring that operations like padding do not cause additional quantization error.
|
||||
|
||||
In order to do quantization in PyTorch, we need to be able to represent
|
||||
quantized data in Tensors. A Quantized Tensor allows for storing
|
||||
quantized data (represented as int8/uint8/int32) along with quantization
|
||||
parameters like scale and zero\_point. Quantized Tensors allow for many
|
||||
useful operations making quantized arithmetic easy, in addition to
|
||||
allowing for serialization of data in a quantized format.
|
||||
|
||||
Operation coverage
|
||||
------------------
|
||||
|
||||
Quantized Tensors support a limited subset of data manipulation methods of the regular
|
||||
full-precision tensor. (see list below)
|
||||
|
||||
For NN operators included in PyTorch, we restrict support to:
|
||||
|
||||
1. 8 bit weights (data\_type = qint8)
|
||||
2. 8 bit activations (data\_type = quint8)
|
||||
|
||||
Note that operator implementations currently only
|
||||
support per channel quantization for weights of the **conv** and **linear**
|
||||
operators. Furthermore the minimum and the maximum of the input data is
|
||||
mapped linearly to the minimum and the maximum of the quantized data
|
||||
type such that zero is represented with no quantization error.
|
||||
|
||||
Additional data types and quantization schemes can be implemented through
|
||||
the `custom operator mechanism <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_.
|
||||
|
||||
Many operations for quantized tensors are available under the same API as full
|
||||
float version in ``torch`` or ``torch.nn``. Quantized version of NN modules that
|
||||
perform re-quantization are available in ``torch.nn.quantized``. Those
|
||||
operations explicitly take output quantization parameters (scale and zero\_point) in
|
||||
the operation signature.
|
||||
|
||||
In addition, we also support fused versions corresponding to common fusion patterns that impact quantization at:
|
||||
torch.nn.intrinsic.quantized.
|
||||
|
||||
For quantization aware training, we support modules prepared for quantization aware training at
|
||||
torch.nn.qat and torch.nn.intrinsic.qat
|
||||
|
||||
Current quantized operation list is sufficient to cover typical CNN and RNN
|
||||
models:
|
||||
|
||||
|
||||
Quantized ``torch.Tensor`` operations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Operations that are available from the ``torch`` namespace or as methods on Tensor for quantized tensors:
|
||||
|
||||
* :func:`~torch.quantize_per_tensor` - Convert float tensor to quantized tensor with per-tensor scale and zero point
|
||||
* :func:`~torch.quantize_per_channel` - Convert float tensor to quantized tensor with per-channel scale and zero point
|
||||
* View-based operations like :meth:`~torch.Tensor.view`, :meth:`~torch.Tensor.as_strided`, :meth:`~torch.Tensor.expand`, :meth:`~torch.Tensor.flatten`, :meth:`~torch.Tensor.slice`, python-style indexing, etc - work as on regular tensor (if quantization is not per-channel)
|
||||
* Comparators
|
||||
* :meth:`~torch.Tensor.ne` — Not equal
|
||||
* :meth:`~torch.Tensor.eq` — Equal
|
||||
* :meth:`~torch.Tensor.ge` — Greater or equal
|
||||
* :meth:`~torch.Tensor.le` — Less or equal
|
||||
* :meth:`~torch.Tensor.gt` — Greater
|
||||
* :meth:`~torch.Tensor.lt` — Less
|
||||
* :meth:`~torch.Tensor.copy_` — Copies src to self in-place
|
||||
* :meth:`~torch.Tensor.clone` — Returns a deep copy of the passed-in tensor
|
||||
* :meth:`~torch.Tensor.dequantize` — Convert quantized tensor to float tensor
|
||||
* :meth:`~torch.Tensor.equal` — Compares two tensors, returns true if quantization parameters and all integer elements are the same
|
||||
* :meth:`~torch.Tensor.int_repr` — Prints the underlying integer representation of the quantized tensor
|
||||
* :meth:`~torch.Tensor.max` — Returns the maximum value of the tensor (reduction only)
|
||||
* :meth:`~torch.Tensor.mean` — Mean function. Supported variants: reduction, dim, out
|
||||
* :meth:`~torch.Tensor.min` — Returns the minimum value of the tensor (reduction only)
|
||||
* :meth:`~torch.Tensor.q_scale` — Returns the scale of the per-tensor quantized tensor
|
||||
* :meth:`~torch.Tensor.q_zero_point` — Returns the zero_point of the per-tensor quantized zero point
|
||||
* :meth:`~torch.Tensor.q_per_channel_scales` — Returns the scales of the per-channel quantized tensor
|
||||
* :meth:`~torch.Tensor.q_per_channel_zero_points` — Returns the zero points of the per-channel quantized tensor
|
||||
* :meth:`~torch.Tensor.q_per_channel_axis` — Returns the channel axis of the per-channel quantized tensor
|
||||
* :meth:`~torch.Tensor.relu` — Rectified linear unit (copy)
|
||||
* :meth:`~torch.Tensor.relu_` — Rectified linear unit (inplace)
|
||||
* :meth:`~torch.Tensor.resize_` — In-place resize
|
||||
* :meth:`~torch.Tensor.sort` — Sorts the tensor
|
||||
* :meth:`~torch.Tensor.topk` — Returns k largest values of a tensor
|
||||
|
||||
|
||||
``torch.nn.intrinsic``
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Fused modules are provided for common patterns in CNNs. Combining several operations together (like convolution and relu) allows for better quantization accuracy
|
||||
|
||||
* ``torch.nn.intrinsic`` — float versions of the modules, can be swapped with quantized version 1 to 1
|
||||
* :class:`~torch.nn.intrinsic.ConvBn2d` — Conv2d + BatchNorm
|
||||
* :class:`~torch.nn.intrinsic.ConvBnReLU2d` — Conv2d + BatchNorm + ReLU
|
||||
* :class:`~torch.nn.intrinsic.ConvReLU2d` — Conv2d + Relu
|
||||
* :class:`~torch.nn.intrinsic.LinearReLU` — Linear + ReLU
|
||||
* ``torch.nn.intrinsic.qat`` — versions of layers for quantization-aware training
|
||||
* :class:`~torch.nn.intrinsic.qat.ConvBn2d` — Conv2d + BatchNorm
|
||||
* :class:`~torch.nn.intrinsic.qat.ConvBnReLU2d` — Conv2d + BatchNorm + ReLU
|
||||
* :class:`~torch.nn.intrinsic.qat.ConvReLU2d` — Conv2d + ReLU
|
||||
* :class:`~torch.nn.intrinsic.qat.LinearReLU` — Linear + ReLU
|
||||
* ``torch.nn.intrinsic.quantized`` — quantized version of fused layers for inference (no BatchNorm variants as it's usually folded into convolution for inference)
|
||||
* :class:`~torch.nn.intrinsic.quantized.LinearReLU` — Linear + ReLU
|
||||
* :class:`~torch.nn.intrinsic.quantized.ConvReLU2d` — 2D Convolution + ReLU
|
||||
|
||||
``torch.nn.qat``
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
Layers for the quantization-aware training
|
||||
|
||||
* :class:`~torch.nn.qat.Linear` — Linear (fully-connected) layer
|
||||
* :class:`~torch.nn.qat.Conv2d` — 2D convolution
|
||||
|
||||
``torch.quantization``
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
* Functions for quantization
|
||||
* :func:`~torch.quantization.add_observer_` — Adds observer for the leaf modules (if quantization configuration is provided)
|
||||
* :func:`~torch.quantization.add_quant_dequant`— Wraps the leaf child module using :class:`~torch.quantization.QuantWrapper`
|
||||
* :func:`~torch.quantization.convert` — Converts float module with observers into its quantized counterpart. Must have quantization configuration
|
||||
* :func:`~torch.quantization.get_observer_dict` — Traverses the module children and collects all observers into a ``dict``
|
||||
* :func:`~torch.quantization.prepare` — Prepares a copy of a model for quantization
|
||||
* :func:`~torch.quantization.prepare_qat` — Prepares a copy of a model for quantization aware training
|
||||
* :func:`~torch.quantization.propagate_qconfig_` — Propagates quantization configurations through the module hierarchy and assign them to each leaf module
|
||||
* :func:`~torch.quantization.quantize` — Converts a float module to quantized version
|
||||
* :func:`~torch.quantization.quantize_dynamic` — Converts a float module to dynamically quantized version
|
||||
* :func:`~torch.quantization.quantize_qat`— Converts a float module to quantized version used in quantization aware training
|
||||
* :func:`~torch.quantization.swap_module` — Swaps the module with its quantized counterpart (if quantizable and if it has an observer)
|
||||
* :func:`~torch.quantization.default_eval_fn` — Default evaluation function used by the :func:`torch.quantization.quantize`
|
||||
* :func:`~torch.quantization.fuse_modules`
|
||||
* :class:`~torch.quantization.FakeQuantize` — Module for simulating the quantization/dequantization at training time
|
||||
* Default Observers. The rest of observers are available from ``torch.quantization.observer``
|
||||
* :attr:`~torch.quantization.default_observer` — Same as ``MinMaxObserver.with_args(reduce_range=True)``
|
||||
* :attr:`~torch.quantization.default_weight_observer` — Same as ``MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)``
|
||||
* :class:`~torch.quantization.Observer` — Abstract base class for observers
|
||||
* Quantization configurations
|
||||
* :class:`~torch.quantization.QConfig` — Quantization configuration class
|
||||
* :attr:`~torch.quantization.default_qconfig` — Same as ``QConfig(activation=default_observer, weight=default_weight_observer)`` (See :class:`~torch.quantization.QConfig.QConfig`)
|
||||
* :attr:`~torch.quantization.default_qat_qconfig` — Same as ``QConfig(activation=default_fake_quant, weight=default_weight_fake_quant)`` (See :class:`~torch.quantization.QConfig.QConfig`)
|
||||
* :attr:`~torch.quantization.default_dynamic_qconfig` — Same as ``QConfigDynamic(weight=default_weight_observer)`` (See :class:`~torch.quantization.QConfig.QConfigDynamic`)
|
||||
* :attr:`~torch.quantization.float16_dynamic_qconfig` — Same as ``QConfigDynamic(weight=NoopObserver.with_args(dtype=torch.float16))`` (See :class:`~torch.quantization.QConfig.QConfigDynamic`)
|
||||
* Stubs
|
||||
* :class:`~torch.quantization.DeQuantStub` - placeholder module for dequantize() operation in float-valued models
|
||||
* :class:`~torch.quantization.QuantStub` - placeholder module for quantize() operation in float-valued models
|
||||
* :class:`~torch.quantization.QuantWrapper` — wraps the module to be quantized. Inserts the :class:`~torch.quantization.QuantStub` and :class:`~torch.quantization.DeQuantStub`
|
||||
|
||||
Observers for computing the quantization parameters
|
||||
|
||||
* :class:`~torch.quantization.MinMaxObserver` — Derives the quantization parameters from the running minimum and maximum of the observed tensor inputs (per tensor variant)
|
||||
* :class:`~torch.quantization.MovingAverageObserver` — Derives the quantization parameters from the running averages of the minimums and maximums of the observed tensor inputs (per tensor variant)
|
||||
* :class:`~torch.quantization.PerChannelMinMaxObserver`— Derives the quantization parameters from the running minimum and maximum of the observed tensor inputs (per channel variant)
|
||||
* :class:`~torch.quantization.MovingAveragePerChannelMinMaxObserver` — Derives the quantization parameters from the running averages of the minimums and maximums of the observed tensor inputs (per channel variant)
|
||||
* :class:`~torch.quantization.HistogramObserver` — Derives the quantization parameters by creating a histogram of running minimums and maximums.
|
||||
* Observers that do not compute the quantization parameters:
|
||||
* :class:`~torch.quantization.RecordingObserver` — Records all incoming tensors. Used for debugging only.
|
||||
* :class:`~torch.quantization.NoopObserver` — Pass-through observer. Used for situation when there are no quantization parameters (i.e. quantization to ``float16``)
|
||||
|
||||
``torch.nn.quantized``
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Quantized version of standard NN layers.
|
||||
|
||||
* :class:`~torch.nn.quantized.Quantize` — Quantization layer, used to automatically replace :class:`~torch.quantization.QuantStub`
|
||||
* :class:`~torch.nn.quantized.DeQuantize` — Dequantization layer, used to replace :class:`~torch.quantization.DeQuantStub`
|
||||
* :class:`~torch.nn.quantized.FloatFunctional` — Wrapper class to make stateless float operations stateful so that they can be replaced with quantized versions
|
||||
* :class:`~torch.nn.quantized.QFunctional` — Wrapper class for quantized versions of stateless operations like ``torch.add``
|
||||
* :class:`~torch.nn.quantized.Conv2d` — 2D convolution
|
||||
* :class:`~torch.nn.quantized.Linear` — Linear (fully-connected) layer
|
||||
* :class:`~torch.nn.MaxPool2d` — 2D max pooling
|
||||
* :class:`~torch.nn.quantized.ReLU` — Rectified linear unit
|
||||
* :class:`~torch.nn.quantized.ReLU6` — Rectified linear unit with cut-off at quantized representation of 6
|
||||
|
||||
``torch.nn.quantized.dynamic``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Layers used in dynamically quantized models (i.e. quantized only on weights)
|
||||
|
||||
* :class:`~torch.nn.quantized.dynamic.Linear` — Linear (fully-connected) layer
|
||||
* :class:`~torch.nn.quantized.dynamic.LSTM` — Long-Short Term Memory RNN module
|
||||
|
||||
``torch.nn.quantized.functional``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Functional versions of quantized NN layers (many of them accept explicit quantization output parameters)
|
||||
|
||||
* :func:`~torch.nn.quantized.functional.adaptive_avg_pool2d` — 2D adaptive average pooling
|
||||
* :func:`~torch.nn.quantized.functional.avg_pool2d` — 2D average pooling
|
||||
* :func:`~torch.nn.quantized.functional.conv2d` — 2D convolution
|
||||
* :func:`~torch.nn.quantized.functional.interpolate` — Down-/up- sampler
|
||||
* :func:`~torch.nn.quantized.functional.linear` — Linear (fully-connected) op
|
||||
* :func:`~torch.nn.quantized.functional.max_pool2d` — 2D max pooling
|
||||
* :func:`~torch.nn.quantized.functional.relu` — Rectified linear unit
|
||||
* :func:`~torch.nn.quantized.functional.upsample` — Upsampler. Will be deprecated in favor of :func:`~torch.nn.quantized.functional.interpolate`
|
||||
* :func:`~torch.nn.quantized.functional.upsample_bilinear` — Bilinear upsampler. Will be deprecated in favor of :func:`~torch.nn.quantized.functional.interpolate`
|
||||
* :func:`~torch.nn.quantized.functional.upsample_nearest` — Nearest neighbor upsampler. Will be deprecated in favor of :func:`~torch.nn.quantized.functional.interpolate`
|
||||
|
||||
Quantized dtypes and quantization schemes
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
* :attr:`torch.qscheme` — Type to describe the quantization scheme of a tensor. Supported types:
|
||||
* :attr:`torch.per_tensor_affine` — per tensor, asymmetric
|
||||
* :attr:`torch.per_channel_affine` — per channel, asymmetric
|
||||
* :attr:`torch.per_tensor_symmetric` — per tensor, symmetric
|
||||
    * :attr:`torch.per_channel_symmetric` — per channel, symmetric
|
||||
* ``torch.dtype`` — Type to describe the data. Supported types:
|
||||
* :attr:`torch.quint8` — 8-bit unsigned integer
|
||||
* :attr:`torch.qint8` — 8-bit signed integer
|
||||
* :attr:`torch.qint32` — 32-bit signed integer
|
||||
|
||||
|
||||
|
||||
Quantization Workflows
|
||||
----------------------
|
||||
|
||||
PyTorch provides three approaches to quantize models.
|
||||
|
||||
1. Post Training Dynamic Quantization: This is the simplest to apply form of
|
||||
quantization where the weights are quantized ahead of time but the
|
||||
activations are dynamically quantized during inference. This is used
|
||||
for situations where the model execution time is dominated by loading
|
||||
weights from memory rather than computing the matrix multiplications.
|
||||
   This is true for LSTM and Transformer type models with small
|
||||
batch size. Applying dynamic quantization to a whole model can be
|
||||
done with a single call to :func:`torch.quantization.quantize_dynamic()`.
|
||||
See the `quantization tutorials <https://pytorch.org/tutorials/#quantization-experimental>`_
|
||||
2. Post Training Static Quantization: This is the most commonly used form of
|
||||
quantization where the weights are quantized ahead of time and the
|
||||
scale factor and bias for the activation tensors is pre-computed
|
||||
based on observing the behavior of the model during a calibration
|
||||
   process. Post Training Quantization is typically used when both memory bandwidth and compute
|
||||
savings are important with CNNs being a typical use case.
|
||||
The general process for doing post training quantization is:
|
||||
|
||||
|
||||
|
||||
1. Prepare the model:
|
||||
a. Specify where the activations are quantized and dequantized explicitly by adding QuantStub and DeQuantStub modules.
|
||||
b. Ensure that modules are not reused.
|
||||
c. Convert any operations that require requantization into modules
|
||||
2. Fuse operations like conv + relu or conv+batchnorm + relu together to improve both model accuracy and performance.
|
||||
|
||||
   3. Specify the configuration of the quantization methods — such as
|
||||
selecting symmetric or asymmetric quantization and MinMax or
|
||||
L2Norm calibration techniques.
|
||||
4. Use the :func:`torch.quantization.prepare` to insert modules
|
||||
that will observe activation tensors during calibration
|
||||
5. Calibrate the model by running inference against a calibration
|
||||
dataset
|
||||
6. Finally, convert the model itself with the
|
||||
torch.quantization.convert() method. This does several things: it
|
||||
quantizes the weights, computes and stores the scale and bias
|
||||
      value to be used for each activation tensor, and replaces key
|
||||
      operators with quantized implementations.
|
||||
|
||||
See the `quantization tutorials <https://pytorch.org/tutorials/#quantization_experimental>`_
|
||||
|
||||
|
||||
3. Quantization Aware Training: In the rare cases where post training
|
||||
quantization does not provide adequate accuracy training can be done
|
||||
with simulated quantization using the :class:`torch.quantization.FakeQuantize`. Computations
|
||||
will take place in FP32 but with values clamped and rounded to
|
||||
simulate the effects of INT8 quantization. The sequence of steps is
|
||||
very similar.
|
||||
|
||||
|
||||
1. Steps (1) and (2) are identical.
|
||||
|
||||
   3. Specify the configuration of the fake quantization methods — such as
|
||||
selecting symmetric or asymmetric quantization and MinMax or Moving Average
|
||||
or L2Norm calibration techniques.
|
||||
4. Use the :func:`torch.quantization.prepare_qat` to insert modules
|
||||
that will simulate quantization during training.
|
||||
5. Train or fine tune the model.
|
||||
6. Identical to step (6) for post training quantization
|
||||
|
||||
See the `quantization tutorials <https://pytorch.org/tutorials/#quantization_experimental>`_
|
||||
|
||||
|
||||
While default implementations of observers to select the scale factor and bias
|
||||
based on observed tensor data are provided, developers can provide their own
|
||||
quantization functions. Quantization can be applied selectively to different
|
||||
parts of the model or configured differently for different parts of the model.
|
||||
|
||||
We also provide support for per channel quantization for **conv2d()**
|
||||
and **linear()**
|
||||
|
||||
Quantization workflows work by adding (e.g. adding observers as
|
||||
``.observer`` submodule) or replacing (e.g. converting ``nn.Conv2d`` to
|
||||
``nn.quantized.Conv2d``) submodules in the model's module hierarchy. It
|
||||
means that the model stays a regular ``nn.Module``-based instance throughout the
|
||||
process and thus can work with the rest of PyTorch APIs.
|
||||
|
||||
|
||||
Model Preparation for Quantization
|
||||
----------------------------------
|
||||
|
||||
It is necessary to currently make some modifications to the model definition
|
||||
prior to quantization. This is because currently quantization works on a module
|
||||
by module basis. Specifically, for all quantization techniques, the user needs to:
|
||||
|
||||
1. Convert any operations that require output requantization (and thus have additional parameters) from functionals to module form.
|
||||
2. Specify which parts of the model need to be quantized either by assigning ``.qconfig`` attributes on submodules or by specifying ``qconfig_dict``
|
||||
|
||||
For static quantization techniques which quantize activations, the user needs to do the following in addition:
|
||||
|
||||
1. Specify where activations are quantized and de-quantized. This is done using :class:`~torch.quantization.QuantStub` and :class:`~torch.quantization.DeQuantStub` modules.
|
||||
2. Use :class:`torch.nn.quantized.FloatFunctional` to wrap tensor operations that require special handling for quantization into modules. Examples
|
||||
are operations like ``add`` and ``cat`` which require special handling to determine output quantization parameters.
|
||||
3. Fuse modules: combine operations/modules into a single module to obtain higher accuracy and performance. This is done using the
|
||||
:func:`torch.quantization.fuse_modules` API, which takes in lists of modules to be fused. We currently support the following fusions:
|
||||
[Conv, Relu], [Conv, BatchNorm], [Conv, BatchNorm, Relu], [Linear, Relu]
|
||||
|
||||
|
||||
torch.quantization
|
||||
---------------------------
|
||||
.. automodule:: torch.quantization
|
||||
|
||||
This module implements the functions you call
|
||||
directly to convert your model from FP32 to quantized form. For
|
||||
example the :func:`~torch.quantization.prepare` is used in post training
|
||||
quantization to prepare your model for the calibration step and
|
||||
:func:`~torch.quantization.convert` actually converts the weights to int8 and
|
||||
replaces the operations with their quantized counterparts. There are
|
||||
other helper functions for things like quantizing the input to your
|
||||
model and performing critical fusions like conv+relu.
|
||||
|
||||
Top-level quantization APIs
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: quantize
|
||||
.. autofunction:: quantize_dynamic
|
||||
.. autofunction:: quantize_qat
|
||||
.. autofunction:: prepare
|
||||
.. autofunction:: prepare_qat
|
||||
.. autofunction:: convert
|
||||
.. autoclass:: QConfig
|
||||
.. autoclass:: QConfigDynamic
|
||||
.. autoattr:: default_qconfig
|
||||
|
||||
Preparing model for quantization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: fuse_modules
|
||||
.. autoclass:: QuantStub
|
||||
.. autoclass:: DeQuantStub
|
||||
.. autoclass:: QuantWrapper
|
||||
.. autofunction:: add_quant_dequant
|
||||
|
||||
Utility functions
|
||||
~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: add_observer_
|
||||
.. autofunction:: swap_module
|
||||
.. autofunction:: propagate_qconfig_
|
||||
.. autofunction:: default_eval_fn
|
||||
|
||||
Observers
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: Observer
|
||||
:members:
|
||||
.. autoclass:: MinMaxObserver
|
||||
.. autoclass:: MovingAverageObserver
|
||||
.. autoclass:: PerChannelMinMaxObserver
|
||||
.. autoclass:: MovingAveragePerChannelMinMaxObserver
|
||||
.. autoclass:: HistogramObserver
|
||||
.. autoclass:: FakeQuantize
|
||||
.. autoclass:: NoopObserver
|
||||
|
||||
Debugging utilities
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: get_observer_dict
|
||||
.. autoclass:: RecordingObserver
|
||||
|
||||
torch.nn.intrinsic
|
||||
--------------------------------
|
||||
|
||||
This module implements the combined (fused) modules conv + relu which can be then quantized.
|
||||
|
||||
.. automodule:: torch.nn.intrinsic
|
||||
|
||||
ConvBn2d
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: ConvBn2d
|
||||
:members:
|
||||
|
||||
ConvBnReLU2d
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: ConvBnReLU2d
|
||||
:members:
|
||||
|
||||
ConvReLU2d
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: ConvReLU2d
|
||||
:members:
|
||||
|
||||
LinearReLU
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: LinearReLU
|
||||
:members:
|
||||
|
||||
torch.nn.intrinsic.qat
|
||||
--------------------------------
|
||||
|
||||
This module implements the versions of those fused operations needed for quantization aware training.
|
||||
|
||||
.. automodule:: torch.nn.intrinsic.qat
|
||||
|
||||
ConvBn2d
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: ConvBn2d
|
||||
:members:
|
||||
|
||||
ConvBnReLU2d
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: ConvBnReLU2d
|
||||
:members:
|
||||
|
||||
ConvReLU2d
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: ConvReLU2d
|
||||
:members:
|
||||
|
||||
LinearReLU
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: LinearReLU
|
||||
:members:
|
||||
|
||||
torch.nn.intrinsic.quantized
|
||||
--------------------------------------
|
||||
|
||||
This module implements the quantized implementations of fused operations like conv + relu.
|
||||
|
||||
.. automodule:: torch.nn.intrinsic.quantized
|
||||
|
||||
ConvReLU2d
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: ConvReLU2d
|
||||
:members:
|
||||
|
||||
LinearReLU
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: LinearReLU
|
||||
:members:
|
||||
|
||||
torch.nn.qat
|
||||
---------------------------
|
||||
|
||||
This module implements versions of the key nn modules **Conv2d()** and **Linear()** which
|
||||
run in FP32 but with rounding applied to simulate the effect of INT8 quantization.
|
||||
|
||||
.. automodule:: torch.nn.qat
|
||||
|
||||
Conv2d
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: Conv2d
|
||||
:members:
|
||||
|
||||
Linear
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: Linear
|
||||
:members:
|
||||
|
||||
|
||||
torch.nn.quantized
|
||||
----------------------------
|
||||
|
||||
This module implements the quantized versions of the nn layers such as **Conv2d** and **ReLU**.
|
||||
|
||||
Functional interface
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: torch.nn.quantized.functional
|
||||
|
||||
.. autofunction:: relu
|
||||
.. autofunction:: linear
|
||||
.. autofunction:: conv2d
|
||||
.. autofunction:: max_pool2d
|
||||
|
||||
.. automodule:: torch.nn.quantized
|
||||
|
||||
ReLU
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: ReLU
|
||||
:members:
|
||||
|
||||
ReLU6
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: ReLU6
|
||||
:members:
|
||||
|
||||
Conv2d
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: Conv2d
|
||||
:members:
|
||||
|
||||
FloatFunctional
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: FloatFunctional
|
||||
:members:
|
||||
|
||||
QFunctional
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: QFunctional
|
||||
:members:
|
||||
|
||||
Quantize
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: Quantize
|
||||
:members:
|
||||
|
||||
DeQuantize
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: DeQuantize
|
||||
:members:
|
||||
|
||||
Linear
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: Linear
|
||||
:members:
|
||||
|
||||
torch.nn.quantized.dynamic
|
||||
----------------------------
|
||||
|
||||
.. automodule:: torch.nn.quantized.dynamic
|
||||
|
||||
Linear
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: Linear
|
||||
:members:
|
||||
|
||||
LSTM
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autoclass:: LSTM
|
||||
:members:
|
@ -91,5 +91,6 @@ Expected result:
|
||||
.. automethod:: add_pr_curve
|
||||
.. automethod:: add_custom_scalars
|
||||
.. automethod:: add_mesh
|
||||
.. automethod:: add_hparams
|
||||
.. automethod:: flush
|
||||
.. automethod:: close
|
||||
|
@ -304,6 +304,8 @@ view of a storage and defines numeric operations on it.
|
||||
.. automethod:: le_
|
||||
.. automethod:: lerp
|
||||
.. automethod:: lerp_
|
||||
.. automethod:: lgamma
|
||||
.. automethod:: lgamma_
|
||||
.. automethod:: log
|
||||
.. automethod:: log_
|
||||
.. automethod:: logdet
|
||||
@ -363,6 +365,8 @@ view of a storage and defines numeric operations on it.
|
||||
.. automethod:: permute
|
||||
.. automethod:: pin_memory
|
||||
.. automethod:: pinverse
|
||||
.. automethod:: polygamma
|
||||
.. automethod:: polygamma_
|
||||
.. automethod:: pow
|
||||
.. automethod:: pow_
|
||||
.. automethod:: prod
|
||||
|
@ -52,6 +52,8 @@ Creation Ops
|
||||
.. autofunction:: empty_strided
|
||||
.. autofunction:: full
|
||||
.. autofunction:: full_like
|
||||
.. autofunction:: quantize_per_tensor
|
||||
.. autofunction:: quantize_per_channel
|
||||
|
||||
Indexing, Slicing, Joining, Mutating Ops
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -207,6 +209,7 @@ Pointwise Ops
|
||||
.. autofunction:: fmod
|
||||
.. autofunction:: frac
|
||||
.. autofunction:: lerp
|
||||
.. autofunction:: lgamma
|
||||
.. autofunction:: log
|
||||
.. autofunction:: log10
|
||||
.. autofunction:: log1p
|
||||
@ -216,6 +219,7 @@ Pointwise Ops
|
||||
.. autofunction:: mul
|
||||
.. autofunction:: mvlgamma
|
||||
.. autofunction:: neg
|
||||
.. autofunction:: polygamma
|
||||
.. autofunction:: pow
|
||||
.. autofunction:: reciprocal
|
||||
.. autofunction:: remainder
|
||||
|
@ -17,7 +17,7 @@ For Objective-C developers, simply import the umbrella header
|
||||
#import <LibTorch/LibTorch.h>
|
||||
```
|
||||
|
||||
For Swift developers, you need to create an Objective-C class as a bridge to call the C++ APIs. We highly recommend you to follow the [Image Classification](https://github.com/pytorch/examples) demo where you can find out how C++, Objective-C and Swift work together.
|
||||
For Swift developers, you need to create an Objective-C class as a bridge to call the C++ APIs. We highly recommend you to follow the [Image Classification](https://github.com/pytorch/ios-demo-app/tree/master/PyTorchDemo) demo where you can find out how C++, Objective-C and Swift work together.
|
||||
|
||||
### Disable Bitcode
|
||||
|
||||
@ -25,4 +25,4 @@ Since PyTorch is not yet built with bitcode support, you need to disable bitcode
|
||||
|
||||
## LICENSE
|
||||
|
||||
PyTorch is BSD-style licensed, as found in the LICENSE file.
|
||||
PyTorch is BSD-style licensed, as found in the LICENSE file.
|
||||
|
10
setup.py
10
setup.py
@ -265,7 +265,11 @@ try:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# the list of runtime dependencies required by this built package
|
||||
install_requires = []
|
||||
|
||||
if os.getenv('PYTORCH_BUILD_VERSION'):
|
||||
install_requires += ['numpy']
|
||||
assert os.getenv('PYTORCH_BUILD_NUMBER') is not None
|
||||
build_number = int(os.getenv('PYTORCH_BUILD_NUMBER'))
|
||||
version = os.getenv('PYTORCH_BUILD_VERSION')
|
||||
@ -345,10 +349,8 @@ def build_deps():
|
||||
# Building dependent libraries
|
||||
################################################################################
|
||||
|
||||
# the list of runtime dependencies required by this built package
|
||||
install_requires = []
|
||||
|
||||
if sys.version_info <= (2, 7):
|
||||
if sys.version_info <= (3, 0):
|
||||
install_requires += ['future']
|
||||
|
||||
missing_pydep = '''
|
||||
@ -371,8 +373,6 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
||||
cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables())
|
||||
if cmake_cache_vars['USE_NUMPY']:
|
||||
report('-- Building with NumPy bindings')
|
||||
global install_requires
|
||||
install_requires += ['numpy']
|
||||
else:
|
||||
report('-- NumPy not found')
|
||||
if cmake_cache_vars['USE_CUDNN']:
|
||||
|
@ -14,8 +14,9 @@ import torch.nn.quantized as nnq
|
||||
import torch.nn.quantized.dynamic as nnqd
|
||||
from common_utils import TestCase
|
||||
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
|
||||
default_qconfig, QConfig, default_observer, default_weight_observer, \
|
||||
default_qat_qconfig, propagate_qconfig_, convert, DEFAULT_DYNAMIC_MODULE_MAPPING
|
||||
default_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
|
||||
propagate_qconfig_, convert
|
||||
from torch.quantization.default_mappings import DEFAULT_DYNAMIC_MODULE_MAPPING
|
||||
|
||||
def test_only_eval_fn(model, calib_data):
|
||||
r"""
|
||||
@ -214,7 +215,7 @@ class AnnotatedTwoLayerLinearModel(torch.nn.Module):
|
||||
super(AnnotatedTwoLayerLinearModel, self).__init__()
|
||||
self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
|
||||
self.fc2 = QuantWrapper(torch.nn.Linear(8, 5).to(dtype=torch.float))
|
||||
self.fc2.qconfig = default_qconfig
|
||||
self.fc2.qconfig = torch.quantization.get_default_qconfig("fbgemm")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
@ -252,7 +253,7 @@ class AnnotatedNestedModel(torch.nn.Module):
|
||||
self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
|
||||
self.fc3.qconfig = default_qconfig
|
||||
self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
|
||||
self.sub2.fc1.qconfig = default_qconfig
|
||||
self.sub2.fc1.qconfig = default_per_channel_qconfig
|
||||
|
||||
def forward(self, x):
|
||||
x = self.sub1(x)
|
||||
@ -346,7 +347,7 @@ class QuantStubModel(torch.nn.Module):
|
||||
"""
|
||||
def __init__(self):
|
||||
super(QuantStubModel, self).__init__()
|
||||
self.qconfig = default_qconfig
|
||||
self.qconfig = torch.quantization.get_default_qconfig("qnnpack")
|
||||
self.quant = QuantStub()
|
||||
self.dequant = DeQuantStub()
|
||||
self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
|
||||
@ -361,7 +362,7 @@ class ManualLinearQATModel(torch.nn.Module):
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ManualLinearQATModel, self).__init__()
|
||||
self.qconfig = default_qat_qconfig
|
||||
self.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
|
||||
self.quant = QuantStub()
|
||||
self.dequant = DeQuantStub()
|
||||
self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
|
||||
@ -379,7 +380,7 @@ class ManualConvLinearQATModel(torch.nn.Module):
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ManualConvLinearQATModel, self).__init__()
|
||||
self.qconfig = default_qat_qconfig
|
||||
self.qconfig = torch.quantization.get_default_qat_qconfig("qnnpack")
|
||||
self.quant = QuantStub()
|
||||
self.dequant = DeQuantStub()
|
||||
self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float)
|
||||
@ -444,6 +445,38 @@ class ModelForFusion(nn.Module):
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self):
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(3, 3, 1, 1, bias=False),
|
||||
nn.BatchNorm2d(3),
|
||||
nn.ReLU(inplace=False)
|
||||
)
|
||||
|
||||
class ModelWithSequentialFusion(nn.Module):
|
||||
def __init__(self):
|
||||
super(ModelWithSequentialFusion, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 3, 1)
|
||||
self.relu1 = nn.ReLU(inplace=False)
|
||||
layers = []
|
||||
for i in range(3):
|
||||
layers.append(ConvBNReLU())
|
||||
self.features = nn.Sequential(*layers)
|
||||
head = [nn.Linear(300, 10), nn.ReLU(inplace=False)]
|
||||
self.classifier = nn.Sequential(*head)
|
||||
self.quant = QuantStub()
|
||||
self.dequant = DeQuantStub()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.quant(x)
|
||||
x = self.conv1(x)
|
||||
x = self.relu1(x)
|
||||
x = self.features(x)
|
||||
x = torch.reshape(x, (-1, 3 * 10 * 10))
|
||||
x = self.classifier(x)
|
||||
x = self.dequant(x)
|
||||
return x
|
||||
|
||||
|
||||
class DummyObserver(torch.nn.Module):
|
||||
def calculate_qparams(self):
|
||||
@ -453,30 +486,27 @@ class DummyObserver(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class ModForWrapping(torch.nn.Module):
|
||||
def __init__(self, quantized=False):
|
||||
super(ModForWrapping, self).__init__()
|
||||
self.qconfig = default_qconfig
|
||||
if quantized:
|
||||
self.mycat = nnq.QFunctional()
|
||||
self.myadd = nnq.QFunctional()
|
||||
else:
|
||||
self.mycat = nnq.FloatFunctional()
|
||||
self.myadd = nnq.FloatFunctional()
|
||||
self.mycat.observer = DummyObserver()
|
||||
self.myadd.observer = DummyObserver()
|
||||
class ModelWithFunctionals(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(ModelWithFunctionals, self).__init__()
|
||||
self.mycat = nnq.FloatFunctional()
|
||||
self.myadd = nnq.FloatFunctional()
|
||||
self.myadd_relu = nnq.FloatFunctional()
|
||||
# Tracing doesnt work yet for c10 ops with scalar inputs
|
||||
# https://github.com/pytorch/pytorch/issues/27097
|
||||
# self.my_scalar_add = nnq.FloatFunctional()
|
||||
# self.my_scalar_mul = nnq.FloatFunctional()
|
||||
|
||||
def forward(self, x):
|
||||
y = self.mycat.cat([x, x, x])
|
||||
z = self.myadd.add(y, y)
|
||||
return z
|
||||
w = self.myadd_relu.add_relu(z, z)
|
||||
# Tracing doesnt work yet for c10 ops with scalar inputs
|
||||
# https://github.com/pytorch/pytorch/issues/27097
|
||||
# w = self.my_scalar_add.add_scalar(w, -0.5)
|
||||
# w = self.my_scalar_mul.mul_scalar(w, 0.5)
|
||||
return w
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, mod):
|
||||
new_mod = cls(quantized=True)
|
||||
new_mod.mycat = new_mod.mycat.from_float(mod.mycat)
|
||||
new_mod.myadd = new_mod.myadd.from_float(mod.myadd)
|
||||
return new_mod
|
||||
|
||||
class ResNetBase(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -532,3 +562,62 @@ class ModelMultipleOps(torch.nn.Module):
|
||||
out = out.view(-1, 3 * 2 * 2)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
# Model to ensure consistency of fake quant with true quant
|
||||
# Average pooling and mean operations are not modelled
|
||||
# accurately with fake-quant so this model does not
|
||||
# contain those operations
|
||||
class ModelMultipleOpsNoAvgPool(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(ModelMultipleOpsNoAvgPool, self).__init__()
|
||||
norm_layer = nn.BatchNorm2d
|
||||
inplanes = 3
|
||||
self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
|
||||
self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
|
||||
self.bn1 = norm_layer(inplanes)
|
||||
self.relu1 = nn.ReLU()
|
||||
self.relu2 = nn.ReLU()
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
self.cat = nn.quantized.FloatFunctional()
|
||||
self.maxpool = nn.MaxPool2d((4, 4))
|
||||
self.fc = nn.Linear(12, 6)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu1(out)
|
||||
skip = self.conv2(x)
|
||||
out = self.skip_add.add(out, skip)
|
||||
out = self.relu2(out)
|
||||
out = self.maxpool(out)
|
||||
out = self.conv2(out)
|
||||
out = torch.nn.functional.max_pool2d(out, 2, 2)
|
||||
out = self.cat.cat([out, out])
|
||||
out = out.view(-1, 3 * 2 * 2)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
"""Model to make sure that the observers are not inserted into custom modules.
|
||||
"""
|
||||
class ModelWithNoQconfigPropagation(nn.Module):
|
||||
class ListOutModule(nn.Module):
|
||||
def __init__(self):
|
||||
super(ModelWithNoQconfigPropagation.ListOutModule, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# returns a list of tensors, not supported by observers
|
||||
return [x]
|
||||
|
||||
def __init__(self):
|
||||
super(ModelWithNoQconfigPropagation, self).__init__()
|
||||
self.fc1 = nn.Linear(5, 5).to(dtype=torch.float)
|
||||
self.quant = QuantStub()
|
||||
self.dequant = DeQuantStub()
|
||||
self.no_quant_module = self.ListOutModule()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.quant(x)
|
||||
x = self.fc1(x)
|
||||
x = self.dequant(x)
|
||||
x = self.no_quant_module(x)
|
||||
return x
|
||||
|
@ -63,10 +63,36 @@ def _calculate_dynamic_qparams(X, dtype):
|
||||
zero_point = min(qmax, zero_point)
|
||||
return [float(scale), int(zero_point)]
|
||||
|
||||
def _calculate_dynamic_per_channel_qparams(X, dtype):
|
||||
"""Calculate the dynamic quantization parameters (scale, zero_point)
|
||||
according to the min and max element of the tensor"""
|
||||
if isinstance(X, torch.Tensor):
|
||||
X = X.numpy()
|
||||
qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
|
||||
n_levels = qmax - qmin
|
||||
scale = np.zeros(X.shape[0], dtype=np.float64)
|
||||
zero_point = np.zeros(X.shape[0], dtype=np.int64)
|
||||
for i in range(zero_point.shape[0]):
|
||||
min_val = X.min()
|
||||
max_val = X.max()
|
||||
if min_val == max_val:
|
||||
scale[i] = 1.0
|
||||
zero_point[i] = 0
|
||||
else:
|
||||
max_val = max(max_val, 0.0)
|
||||
min_val = min(min_val, 0.0)
|
||||
scale[i] = (max_val - min_val) / n_levels
|
||||
scale[i] = max(scale[i], np.finfo(np.float32).eps)
|
||||
zero_point[i] = qmin - round(min_val / scale[i])
|
||||
zero_point[i] = max(qmin, zero_point[i])
|
||||
zero_point[i] = min(qmax, zero_point[i])
|
||||
|
||||
return scale, zero_point
|
||||
|
||||
@contextmanager
|
||||
def enable_mobile_quantized_engine():
|
||||
def override_quantized_engine(qengine):
|
||||
previous = torch.backends.quantized.engine
|
||||
torch.backends.quantized.engine = 'qnnpack'
|
||||
torch.backends.quantized.engine = qengine
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
@ -185,7 +185,6 @@ TEST_WITH_ASAN = os.getenv('PYTORCH_TEST_WITH_ASAN', '0') == '1'
|
||||
TEST_WITH_TSAN = os.getenv('PYTORCH_TEST_WITH_TSAN', '0') == '1'
|
||||
TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1'
|
||||
TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1'
|
||||
TEST_WITH_QNNPACK = os.getenv('PYTORCH_TEST_WITH_QNNPACK', '0') == '1'
|
||||
# Enables tests that are slow to run (disabled by default)
|
||||
TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
|
||||
|
||||
|
@ -8,7 +8,7 @@ from hypothesis import strategies as st
|
||||
from hypothesis.extra import numpy as stnp
|
||||
from hypothesis.searchstrategy import SearchStrategy
|
||||
|
||||
from common_quantized import _calculate_dynamic_qparams
|
||||
from common_quantized import _calculate_dynamic_qparams, _calculate_dynamic_per_channel_qparams
|
||||
|
||||
# Setup for the hypothesis tests.
|
||||
# The tuples are (torch_quantized_dtype, zero_point_enforce), where the last
|
||||
@ -182,6 +182,38 @@ def tensor(draw, shapes=None, elements=None, qparams=None):
|
||||
zp = enforced_zp
|
||||
return X, (scale, zp, qparams[2])
|
||||
|
||||
@st.composite
|
||||
def per_channel_tensor(draw, shapes=None, elements=None, qparams=None):
|
||||
if isinstance(shapes, SearchStrategy):
|
||||
_shape = draw(shapes)
|
||||
else:
|
||||
_shape = draw(st.sampled_from(shapes))
|
||||
if qparams is None:
|
||||
if elements is None:
|
||||
elements = st.floats(-1e6, 1e6, allow_nan=False)
|
||||
X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
|
||||
assume(not (np.isnan(X).any() or np.isinf(X).any()))
|
||||
return X, None
|
||||
qparams = draw(qparams)
|
||||
if elements is None:
|
||||
min_value, max_value = _get_valid_min_max(qparams)
|
||||
elements = st.floats(min_value, max_value, allow_infinity=False,
|
||||
allow_nan=False)
|
||||
X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
|
||||
# Recompute the scale and zero_points according to the X statistics.
|
||||
scale, zp = _calculate_dynamic_per_channel_qparams(X, qparams[2])
|
||||
enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None)
|
||||
if enforced_zp is not None:
|
||||
zp = enforced_zp
|
||||
# Permute to model quantization along an axis
|
||||
axis = int(np.random.randint(0, X.ndim, 1))
|
||||
permute_axes = np.arange(X.ndim)
|
||||
permute_axes[0] = axis
|
||||
permute_axes[axis] = 0
|
||||
X = np.transpose(X, permute_axes)
|
||||
|
||||
return X, (scale, zp, axis, qparams[2])
|
||||
|
||||
"""Strategy for generating test cases for tensors used in Conv2D.
|
||||
The resulting tensors is in float32 format.
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
output: "1"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
|
@ -1,6 +1,6 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
output: "3"
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user