Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-23 14:59:34 +08:00
Compare commits
125 Commits
Commit SHA1s (the author, date, and message columns were not captured in this mirror view):

2e62318243 6e2a8d0625 75d8dab9be d6336b0d8b 1c5683a80f 082f195ac9 7bd7a3d806 493c900810 7b2e8c323c 111da77912
882b2abb80 1bc7ea17b2 214a676f6b c2223df578 7e95b89980 8fbefa06f6 5b5f398dd4 17b1faa2bf fe4170bda8 fc249c7924
a423817055 1affa7c32c 27dc595215 4fb442985a 529c2dfaa9 803f7bfaac e33ec3942e 293e35a87c 162ef02db6 5e776d8a45
2491dd50ee d93fc64776 3eefc5457f eb5040c205 21c229f4e1 bb51980766 a19b135fab 5835460309 c6ff830c2f 515e3b85da
5f36ce6792 33db4e02cb eeaef217b3 c864454a8f 5005f7bce7 f4f6d8dda5 631e2ee7a4 5cac738713 b45f1b9601 2dff0b6f6a
becf080e4a 8a38a53e4d 00e588290b 0b79f77a4d bb8983e936 4abfb5493e d125a83f98 7bbb2df6d9 b805b5dab8 38001fec1b
8e4031ddcf e09c97b113 98c02e6df3 14d29aeece f742ceaa46 6a09676442 1d2d59dd79 46539eee03 ec07d144ba 48cd66cd2b
4086f7432b 18eea8269a 5ccecabb04 3099732017 85abfc1992 dddae3f854 6a4ca9abec 209dc4c4ba ea414e4990 1d4d6b6f0f
9e3ba35500 bdcaf6334b ec7913afbd 990f4ca76d bf7ebc5a53 5e7549c3b5 fb8fd6dc73 3e20d9c0dc 47cd15c643 27d4b34ea6
9159a601ca e75a49341b 585a5975fb 4a8d899e18 73347e0ae5 b16358b251 e08338738c e367f605cd a252aee8c2 d91e490a9f
3d41ad507a 1afe3fc01e 1eaa9f89fb 275e0c1c8f d5298b6e66 32b0e8c980 84ee8ace12 7dc7075795 7c465aae32 2bcbc15023
a1513dced3 1a3997e0b8 4d7bec5f3e ee2c79d699 9e1cf99a16 2ccbdb79c8 a31fd5ea68 7d58060f49 18dd541b59 a173bea425
9f9ba3a900 310ebb45a8 8858f42aa4 0cd188035a cca3a369f1
@@ -14,7 +14,7 @@ from dataclasses import dataclass

 DOCKER_IMAGE_PATH_BASE = "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/"

-DOCKER_IMAGE_VERSION = 324
+DOCKER_IMAGE_VERSION = 315

 @dataclass
@@ -308,16 +308,16 @@ jobs:
       export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE})

       # TODO We may want to move the rebase logic to a separate step after checkout
-      # Rebase to v1.3.0 only if in xenial_py3_6_gcc5_4 case
-      if [[ "${CIRCLE_BRANCH}" != "v1.3.0" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then
-        echo "Merge v1.3.0 branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT"
+      # Rebase to master only if in xenial_py3_6_gcc5_4 case
+      if [[ "${CIRCLE_BRANCH}" != "master" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then
+        echo "Merge master branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT"
         set -x
         git config --global user.email "circleci.ossci@gmail.com"
         git config --global user.name "CircleCI"
         git config remote.origin.url https://github.com/pytorch/pytorch.git
-        git config --add remote.origin.fetch +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0
-        git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0 --depth=50 --quiet
-        export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/v1.3.0`
+        git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master
+        git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=50 --quiet
+        export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/master`
         echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
         export GIT_COMMIT=${CIRCLE_SHA1}
         echo "GIT_COMMIT: " ${GIT_COMMIT}
@@ -326,7 +326,7 @@ jobs:
         git merge --no-edit --no-ff ${GIT_MERGE_TARGET}
         set +x
       else
-        echo "Do NOT merge v1.3.0 branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT"
+        echo "Do NOT merge master branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT"
       fi

       git submodule sync && git submodule update -q --init --recursive
@@ -551,7 +551,12 @@ jobs:
       # Reinitialize path (see man page for path_helper(8))
       eval `/usr/libexec/path_helper -s`

-      export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH
+      # Use Homebrew Python if configured to do so
+      if [ "${PYTHON_INSTALLATION}" == "homebrew" ]; then
+        export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH
+      fi
+
+      pip -q install numpy

       # Install Anaconda if we need to
       if [ -n "${CAFFE2_USE_ANACONDA}" ]; then
@@ -564,8 +569,6 @@
         source ${TMPDIR}/anaconda/bin/activate
       fi

-      pip -q install numpy
-
       # Install sccache
       sudo curl https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache
       sudo chmod +x /usr/local/bin/sccache
@@ -603,6 +606,7 @@ jobs:
       if which sccache > /dev/null; then
         sccache --show-stats
       fi
+
   binary_linux_build:
     <<: *binary_linux_build_params
     steps:
@@ -961,11 +965,11 @@ jobs:
             export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/master master site") | docker exec -u jenkins -i "$id" bash) 2>&1'

           # stable release docs push. Due to some circleci limitations, we keep
-          # an eternal PR open for merging v1.3.0 -> master for this job.
-          # XXX: The following code is only run on the v1.3.0 branch, which might
+          # an eternal PR open for merging v1.2.0 -> master for this job.
+          # XXX: The following code is only run on the v1.2.0 branch, which might
           # not be exactly the same as what you see here.
-          elif [[ "${CIRCLE_BRANCH}" == "v1.3.0" ]]; then
-            export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.3.0 site-v1.3.0 dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1'
+          elif [[ "${CIRCLE_BRANCH}" == "v1.2.0" ]]; then
+            export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.2.0 site dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1'

           # For open PRs: Do a dry_run of the docs build, don't push build
           else
@@ -1930,7 +1934,7 @@ workflows:
                 - master
                 - /ci-all\/.*/
           build_environment: "caffe2-py2-gcc4.8-ubuntu14.04-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:315"
       - caffe2_linux_test:
           name: caffe2_py2_gcc4_8_ubuntu14_04_test
           requires:
@@ -1942,7 +1946,7 @@ workflows:
                 - master
                 - /ci-all\/.*/
           build_environment: "caffe2-py2-gcc4.8-ubuntu14.04-test"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:315"
           resource_class: large
       - caffe2_linux_build:
           name: caffe2_py2_cuda9_0_cudnn7_ubuntu16_04_build
@@ -1954,7 +1958,7 @@ workflows:
                 - master
                 - /ci-all\/.*/
           build_environment: "caffe2-py2-cuda9.0-cudnn7-ubuntu16.04-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315"
       - caffe2_linux_test:
           name: caffe2_py2_cuda9_0_cudnn7_ubuntu16_04_test
           requires:
@@ -1967,14 +1971,14 @@ workflows:
                 - /ci-all\/.*/
           build_environment: "caffe2-py2-cuda9.0-cudnn7-ubuntu16.04-test"
           use_cuda_docker_runtime: "1"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315"
           resource_class: gpu.medium
       - caffe2_linux_build:
           name: caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_build
           requires:
             - setup
           build_environment: "caffe2-cmake-cuda9.0-cudnn7-ubuntu16.04-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315"
       - caffe2_linux_test:
           name: caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_test
           requires:
@@ -1982,14 +1986,14 @@ workflows:
             - caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_build
           build_environment: "caffe2-cmake-cuda9.0-cudnn7-ubuntu16.04-test"
           use_cuda_docker_runtime: "1"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315"
           resource_class: gpu.medium
       - caffe2_linux_build:
           name: caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_build
           requires:
             - setup
           build_environment: "caffe2-py3.5-cuda10.1-cudnn7-ubuntu16.04-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:315"
       - caffe2_linux_test:
           name: caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_test
           requires:
@@ -1997,35 +2001,35 @@ workflows:
            - caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_build
           build_environment: "caffe2-py3.5-cuda10.1-cudnn7-ubuntu16.04-test"
           use_cuda_docker_runtime: "1"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:315"
           resource_class: gpu.medium
       - caffe2_linux_build:
           name: caffe2_py2_mkl_ubuntu16_04_build
           requires:
             - setup
           build_environment: "caffe2-py2-mkl-ubuntu16.04-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:315"
       - caffe2_linux_test:
           name: caffe2_py2_mkl_ubuntu16_04_test
           requires:
             - setup
             - caffe2_py2_mkl_ubuntu16_04_build
           build_environment: "caffe2-py2-mkl-ubuntu16.04-test"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:315"
           resource_class: large
       - caffe2_linux_build:
           name: caffe2_onnx_py2_gcc5_ubuntu16_04_build
           requires:
             - setup
           build_environment: "caffe2-onnx-py2-gcc5-ubuntu16.04-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:315"
       - caffe2_linux_test:
           name: caffe2_onnx_py2_gcc5_ubuntu16_04_test
           requires:
             - setup
             - caffe2_onnx_py2_gcc5_ubuntu16_04_build
           build_environment: "caffe2-onnx-py2-gcc5-ubuntu16.04-test"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:315"
           resource_class: large
       - caffe2_linux_build:
           name: caffe2_py2_clang3_8_ubuntu16_04_build
@@ -2037,7 +2041,7 @@ workflows:
                 - master
                 - /ci-all\/.*/
           build_environment: "caffe2-py2-clang3.8-ubuntu16.04-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.8-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.8-ubuntu16.04:315"
           build_only: "1"
       - caffe2_linux_build:
           name: caffe2_py2_clang3_9_ubuntu16_04_build
@@ -2049,35 +2053,35 @@ workflows:
                 - master
                 - /ci-all\/.*/
           build_environment: "caffe2-py2-clang3.9-ubuntu16.04-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.9-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.9-ubuntu16.04:315"
           build_only: "1"
       - caffe2_linux_build:
           name: caffe2_py2_clang7_ubuntu16_04_build
           requires:
             - setup
           build_environment: "caffe2-py2-clang7-ubuntu16.04-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang7-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang7-ubuntu16.04:315"
           build_only: "1"
       - caffe2_linux_build:
           name: caffe2_onnx_py3_6_clang7_ubuntu16_04_build
           requires:
             - setup
           build_environment: "caffe2-onnx-py3.6-clang7-ubuntu16.04-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:315"
       - caffe2_linux_test:
           name: caffe2_onnx_py3_6_clang7_ubuntu16_04_test
           requires:
             - setup
             - caffe2_onnx_py3_6_clang7_ubuntu16_04_build
           build_environment: "caffe2-onnx-py3.6-clang7-ubuntu16.04-test"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:315"
           resource_class: large
       - caffe2_linux_build:
           name: caffe2_py2_android_ubuntu16_04_build
           requires:
             - setup
           build_environment: "caffe2-py2-android-ubuntu16.04-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-android-ubuntu16.04:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-android-ubuntu16.04:315"
           build_only: "1"
       - caffe2_linux_build:
           name: caffe2_py2_cuda9_0_cudnn7_centos7_build
@@ -2089,7 +2093,7 @@ workflows:
                 - master
                 - /ci-all\/.*/
           build_environment: "caffe2-py2-cuda9.0-cudnn7-centos7-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:315"
       - caffe2_linux_test:
           name: caffe2_py2_cuda9_0_cudnn7_centos7_test
           requires:
@@ -2102,7 +2106,7 @@ workflows:
                 - /ci-all\/.*/
           build_environment: "caffe2-py2-cuda9.0-cudnn7-centos7-test"
           use_cuda_docker_runtime: "1"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:324"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:315"
           resource_class: gpu.medium
       - caffe2_macos_build:
           name: caffe2_py2_ios_macos10_13_build
@@ -146,7 +146,12 @@
       # Reinitialize path (see man page for path_helper(8))
       eval `/usr/libexec/path_helper -s`

-      export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH
+      # Use Homebrew Python if configured to do so
+      if [ "${PYTHON_INSTALLATION}" == "homebrew" ]; then
+        export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH
+      fi
+
+      pip -q install numpy

       # Install Anaconda if we need to
       if [ -n "${CAFFE2_USE_ANACONDA}" ]; then
@@ -159,8 +164,6 @@
         source ${TMPDIR}/anaconda/bin/activate
       fi

-      pip -q install numpy
-
       # Install sccache
       sudo curl https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache
       sudo chmod +x /usr/local/bin/sccache
@@ -198,3 +201,4 @@
       if which sccache > /dev/null; then
         sccache --show-stats
       fi
+
@@ -61,11 +61,11 @@
             export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/master master site") | docker exec -u jenkins -i "$id" bash) 2>&1'

           # stable release docs push. Due to some circleci limitations, we keep
-          # an eternal PR open for merging v1.3.0 -> master for this job.
-          # XXX: The following code is only run on the v1.3.0 branch, which might
+          # an eternal PR open for merging v1.2.0 -> master for this job.
+          # XXX: The following code is only run on the v1.2.0 branch, which might
           # not be exactly the same as what you see here.
-          elif [[ "${CIRCLE_BRANCH}" == "v1.3.0" ]]; then
-            export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.3.0 site-v1.3.0 dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1'
+          elif [[ "${CIRCLE_BRANCH}" == "v1.2.0" ]]; then
+            export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.2.0 site dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1'

           # For open PRs: Do a dry_run of the docs build, don't push build
           else
@@ -20,16 +20,16 @@ jobs:
       export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE})

       # TODO We may want to move the rebase logic to a separate step after checkout
-      # Rebase to v1.3.0 only if in xenial_py3_6_gcc5_4 case
-      if [[ "${CIRCLE_BRANCH}" != "v1.3.0" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then
-        echo "Merge v1.3.0 branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT"
+      # Rebase to master only if in xenial_py3_6_gcc5_4 case
+      if [[ "${CIRCLE_BRANCH}" != "master" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then
+        echo "Merge master branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT"
         set -x
         git config --global user.email "circleci.ossci@gmail.com"
         git config --global user.name "CircleCI"
         git config remote.origin.url https://github.com/pytorch/pytorch.git
-        git config --add remote.origin.fetch +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0
-        git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0 --depth=50 --quiet
-        export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/v1.3.0`
+        git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master
+        git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=50 --quiet
+        export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/master`
         echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
         export GIT_COMMIT=${CIRCLE_SHA1}
         echo "GIT_COMMIT: " ${GIT_COMMIT}
@@ -38,7 +38,7 @@ jobs:
         git merge --no-edit --no-ff ${GIT_MERGE_TARGET}
         set +x
       else
-        echo "Do NOT merge v1.3.0 branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT"
+        echo "Do NOT merge master branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT"
       fi

       git submodule sync && git submodule update -q --init --recursive
@@ -168,6 +168,9 @@ else
   python setup.py install
 fi

+# TODO: I'm not sure why, but somehow we lose verbose commands
+set -x
+
 if which sccache > /dev/null; then
   echo 'PyTorch Build Statistics'
   sccache --show-stats
@@ -205,19 +208,20 @@ if [[ "$BUILD_TEST_LIBTORCH" == "1" ]]; then
   pushd ../cpp-build/caffe2
   WERROR=1 VERBOSE=1 DEBUG=1 python $BUILD_LIBTORCH_PY
   popd

-  # Build custom operator tests.
-  CUSTOM_OP_BUILD="$PWD/../custom-op-build"
-  CUSTOM_OP_TEST="$PWD/test/custom_operator"
-  SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
-  mkdir "$CUSTOM_OP_BUILD"
-  pushd "$CUSTOM_OP_BUILD"
-  CMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" cmake "$CUSTOM_OP_TEST"
-  make VERBOSE=1
-  popd
-  assert_git_not_dirty
 fi

+# Build custom operator tests.
+CUSTOM_OP_BUILD="$PWD/../custom-op-build"
+CUSTOM_OP_TEST="$PWD/test/custom_operator"
+python --version
+SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
+mkdir "$CUSTOM_OP_BUILD"
+pushd "$CUSTOM_OP_BUILD"
+cmake "$CUSTOM_OP_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" -DPYTHON_EXECUTABLE="$(which python)"
+make VERBOSE=1
+popd
+assert_git_not_dirty
+
 # Test XLA build
 if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then
   # TODO: Move this to Dockerfile.
@@ -138,24 +138,23 @@ test_torchvision() {
 }

 test_libtorch() {
   if [[ "$BUILD_TEST_LIBTORCH" == "1" ]]; then
     if [[ "$BUILD_ENVIRONMENT" != *rocm* ]]; then
       echo "Testing libtorch"
       python test/cpp/jit/tests_setup.py setup
-      CPP_BUILD="$PWD/../cpp-build"
       if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
-        "$CPP_BUILD"/caffe2/build/bin/test_jit
+        build/bin/test_jit
       else
-        "$CPP_BUILD"/caffe2/build/bin/test_jit "[cpu]"
+        build/bin/test_jit "[cpu]"
       fi
       python test/cpp/jit/tests_setup.py shutdown
       python tools/download_mnist.py --quiet -d test/cpp/api/mnist
-      OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" "$CPP_BUILD"/caffe2/build/bin/test_api
+      OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" build/bin/test_api
       assert_git_not_dirty
     fi
 }

 test_custom_script_ops() {
   if [[ "$BUILD_TEST_LIBTORCH" == "1" ]]; then
     if [[ "$BUILD_ENVIRONMENT" != *rocm* ]] && [[ "$BUILD_ENVIRONMENT" != *asan* ]] ; then
       echo "Testing custom script operators"
       CUSTOM_OP_BUILD="$PWD/../custom-op-build"
       pushd test/custom_operator
@@ -1,6 +1,6 @@
 ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64

-VERSION_NAME=1.3.0
+VERSION_NAME=0.0.7-SNAPSHOT
 GROUP=org.pytorch
 MAVEN_GROUP=org.pytorch
 POM_URL=https://github.com/pytorch/pytorch/tree/master/android
@@ -1,30 +0,0 @@
-package org.pytorch;
-
-import com.facebook.soloader.nativeloader.NativeLoader;
-import com.facebook.soloader.nativeloader.SystemDelegate;
-import org.junit.BeforeClass;
-
-import java.io.IOException;
-import java.io.InputStream;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.StandardCopyOption;
-import java.util.Objects;
-
-public class PytorchHostTests extends PytorchTestBase {
-  @BeforeClass
-  public static void setUpClass() {
-    if (!NativeLoader.isInitialized()) {
-      NativeLoader.init(new SystemDelegate());
-    }
-  }
-
-  @Override
-  protected String assetFilePath(String assetName) throws IOException {
-    Path tempFile = Files.createTempFile("test", ".pt");
-    try (InputStream resource = Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {
-      Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING);
-    }
-    return tempFile.toAbsolutePath().toString();
-  }
-}
@@ -2,6 +2,8 @@ package org.pytorch;

 import android.content.Context;

+import org.junit.Before;
+import org.junit.Test;
 import org.junit.runner.RunWith;

 import java.io.File;
@@ -9,15 +11,294 @@ import java.io.FileOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.util.HashMap;
+import java.util.Map;

 import androidx.test.ext.junit.runners.AndroidJUnit4;
 import androidx.test.platform.app.InstrumentationRegistry;

-@RunWith(AndroidJUnit4.class)
-public class PytorchInstrumentedTests extends PytorchTestBase {
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;

-  @Override
-  protected String assetFilePath(String assetName) throws IOException {
+@RunWith(AndroidJUnit4.class)
+public class PytorchInstrumentedTests {
+
+  private static final String TEST_MODULE_ASSET_NAME = "test.pt";
+
+  @Before
+  public void setUp() {
+    System.loadLibrary("pytorch");
+  }
+
+  @Test
+  public void testForwardNull() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+    final IValue input =
+        IValue.tensor(Tensor.newInt8Tensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
+    assertTrue(input.isTensor());
+    final IValue output = module.forward(input);
+    assertTrue(output.isNull());
+  }
+
+  @Test
+  public void testEqBool() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+    for (boolean value : new boolean[] {false, true}) {
+      final IValue input = IValue.bool(value);
+      assertTrue(input.isBool());
+      assertTrue(value == input.getBool());
+      final IValue output = module.runMethod("eqBool", input);
+      assertTrue(output.isBool());
+      assertTrue(value == output.getBool());
+    }
+  }
+
+  @Test
+  public void testEqInt() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+    for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
+      final IValue input = IValue.long64(value);
+      assertTrue(input.isLong());
+      assertTrue(value == input.getLong());
+      final IValue output = module.runMethod("eqInt", input);
+      assertTrue(output.isLong());
+      assertTrue(value == output.getLong());
+    }
+  }
+
+  @Test
+  public void testEqFloat() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+    double[] values =
+        new double[] {
+          -Double.MAX_VALUE,
+          Double.MAX_VALUE,
+          -Double.MIN_VALUE,
+          Double.MIN_VALUE,
+          -Math.exp(1.d),
+          -Math.sqrt(2.d),
+          -3.1415f,
+          3.1415f,
+          -1,
+          0,
+          1,
+        };
+    for (double value : values) {
+      final IValue input = IValue.double64(value);
+      assertTrue(input.isDouble());
+      assertTrue(value == input.getDouble());
+      final IValue output = module.runMethod("eqFloat", input);
+      assertTrue(output.isDouble());
+      assertTrue(value == output.getDouble());
+    }
+  }
+
+  @Test
+  public void testEqTensor() throws IOException {
+    final long[] inputTensorShape = new long[] {1, 3, 224, 224};
+    final long numElements = Tensor.numel(inputTensorShape);
+    final float[] inputTensorData = new float[(int) numElements];
+    for (int i = 0; i < numElements; ++i) {
+      inputTensorData[i] = i;
+    }
+    final Tensor inputTensor = Tensor.newFloat32Tensor(inputTensorShape, inputTensorData);
+
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+    final IValue input = IValue.tensor(inputTensor);
+    assertTrue(input.isTensor());
+    assertTrue(inputTensor == input.getTensor());
+    final IValue output = module.runMethod("eqTensor", input);
+    assertTrue(output.isTensor());
+    final Tensor outputTensor = output.getTensor();
+    assertNotNull(outputTensor);
+    assertArrayEquals(inputTensorShape, outputTensor.shape);
+    float[] outputData = outputTensor.getDataAsFloatArray();
+    for (int i = 0; i < numElements; i++) {
+      assertTrue(inputTensorData[i] == outputData[i]);
+    }
+  }
+
+  @Test
+  public void testEqDictIntKeyIntValue() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+    final Map<Long, IValue> inputMap = new HashMap<>();
+
+    inputMap.put(Long.MIN_VALUE, IValue.long64(-Long.MIN_VALUE));
+    inputMap.put(Long.MAX_VALUE, IValue.long64(-Long.MAX_VALUE));
+    inputMap.put(0l, IValue.long64(0l));
+    inputMap.put(1l, IValue.long64(-1l));
+    inputMap.put(-1l, IValue.long64(1l));
+
+    final IValue input = IValue.dictLongKey(inputMap);
+    assertTrue(input.isDictLongKey());
+
+    final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
+    assertTrue(output.isDictLongKey());
+
+    final Map<Long, IValue> outputMap = output.getDictLongKey();
+    assertTrue(inputMap.size() == outputMap.size());
+    for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
+      assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong());
+    }
+  }
+
+  @Test
+  public void testEqDictStrKeyIntValue() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+    final Map<String, IValue> inputMap = new HashMap<>();
+
+    inputMap.put("long_min_value", IValue.long64(Long.MIN_VALUE));
+    inputMap.put("long_max_value", IValue.long64(Long.MAX_VALUE));
+    inputMap.put("long_0", IValue.long64(0l));
+    inputMap.put("long_1", IValue.long64(1l));
+    inputMap.put("long_-1", IValue.long64(-1l));
+
+    final IValue input = IValue.dictStringKey(inputMap);
+    assertTrue(input.isDictStringKey());
+
+    final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
+    assertTrue(output.isDictStringKey());
+
+    final Map<String, IValue> outputMap = output.getDictStringKey();
+    assertTrue(inputMap.size() == outputMap.size());
+    for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
+      assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong());
+    }
+  }
+
+  @Test
+  public void testListIntSumReturnTuple() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+
+    for (int n : new int[] {0, 1, 128}) {
+      long[] a = new long[n];
+      long sum = 0;
+      for (int i = 0; i < n; i++) {
+        a[i] = i;
+        sum += a[i];
+      }
+      final IValue input = IValue.longList(a);
+      assertTrue(input.isLongList());
+
+      final IValue output = module.runMethod("listIntSumReturnTuple", input);
+
+      assertTrue(output.isTuple());
+      assertTrue(2 == output.getTuple().length);
+
+      IValue output0 = output.getTuple()[0];
+      IValue output1 = output.getTuple()[1];
+
+      assertArrayEquals(a, output0.getLongList());
+      assertTrue(sum == output1.getLong());
+    }
+  }
+
+  @Test
+  public void testOptionalIntIsNone() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+
+    assertFalse(module.runMethod("optionalIntIsNone", IValue.long64(1l)).getBool());
+    assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).getBool());
+  }
+
+  @Test
+  public void testIntEq0None() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+
+    assertTrue(module.runMethod("intEq0None", IValue.long64(0l)).isNull());
+    assertTrue(module.runMethod("intEq0None", IValue.long64(1l)).getLong() == 1l);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testRunUndefinedMethod() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+    module.runMethod("test_undefined_method_throws_exception");
+  }
+
+  @Test
+  public void testTensorMethods() {
+    long[] shape = new long[] {1, 3, 224, 224};
+    final int numel = (int) Tensor.numel(shape);
+    int[] ints = new int[numel];
+    float[] floats = new float[numel];
+
+    byte[] bytes = new byte[numel];
+    for (int i = 0; i < numel; i++) {
+      bytes[i] = (byte) ((i % 255) - 128);
+      ints[i] = i;
+      floats[i] = i / 1000.f;
+    }
+
+    Tensor tensorBytes = Tensor.newInt8Tensor(shape, bytes);
+    assertTrue(tensorBytes.dtype() == Tensor.DTYPE_INT8);
+    assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
+
+    Tensor tensorInts = Tensor.newInt32Tensor(shape, ints);
+    assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32);
+    assertArrayEquals(ints, tensorInts.getDataAsIntArray());
+
+    Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
+    assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
+    float[] floatsOut = tensorFloats.getDataAsFloatArray();
+    assertTrue(floatsOut.length == numel);
+    for (int i = 0; i < numel; i++) {
+      assertTrue(floats[i] == floatsOut[i]);
+    }
+  }
+
+  @Test(expected = IllegalStateException.class)
+  public void testTensorIllegalStateOnWrongType() {
+    long[] shape = new long[] {1, 3, 224, 224};
+    final int numel = (int) Tensor.numel(shape);
+    float[] floats = new float[numel];
+    Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
+    assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
+    tensorFloats.getDataAsByteArray();
+  }
+
+  @Test
+  public void testEqString() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+    String[] values =
+        new String[] {
+          "smoketest",
+          "проверка не латинских символов", // not latin symbols check
+          "#@$!@#)($*!@#$)(!@*#$"
+        };
+    for (String value : values) {
+      final IValue input = IValue.string(value);
+      assertTrue(input.isString());
+      assertTrue(value.equals(input.getString()));
+      final IValue output = module.runMethod("eqStr", input);
+      assertTrue(output.isString());
+      assertTrue(value.equals(output.getString()));
+    }
+  }
+
+  @Test
+  public void testStr3Concat() throws IOException {
+    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
+    String[] values =
+        new String[] {
+          "smoketest",
+          "проверка не латинских символов", // not latin symbols check
+          "#@$!@#)($*!@#$)(!@*#$"
+        };
+    for (String value : values) {
+      final IValue input = IValue.string(value);
+      assertTrue(input.isString());
+      assertTrue(value.equals(input.getString()));
+      final IValue output = module.runMethod("str3Concat", input);
+      assertTrue(output.isString());
+      String expectedOutput =
+          new StringBuilder().append(value).append(value).append(value).toString();
+      assertTrue(expectedOutput.equals(output.getString()));
+    }
+  }
+
+  private static String assetFilePath(String assetName) throws IOException {
     final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
     File file = new File(appContext.getFilesDir(), assetName);
     if (file.exists() && file.length() > 0) {
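The two sides of this file's hunks capture the two naming generations of the Android binding API side by side: the standalone test class above uses the `IValue.tensor` / `IValue.bool` / `IValue.long64` factories, `getTensor`-style accessors, and `Tensor.newFloat32Tensor(shape, data)`, while `PytorchTestBase` below uses `IValue.from`, `toTensor`-style accessors, and `Tensor.fromBlob(data, shape)`. As a quick orientation — a minimal sketch, not part of the diff, and assuming a hypothetical module path `/tmp/test.pt` whose `eqTensor` method echoes its input as in the tests above — the same call site written against the `from`/`to` generation looks like this (note the argument order of `fromBlob` is `(data, shape)`, the reverse of `newFloat32Tensor`):

```java
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

public class ApiRenameSketch {
  public static void main(String[] args) {
    // Hypothetical path to a serialized TorchScript module (not from the diff).
    final String modulePath = "/tmp/test.pt";
    final Module module = Module.load(modulePath);

    final float[] data = new float[] {1f, 2f, 3f};
    final long[] shape = new long[] {3};

    // Older naming (standalone tests above):
    //   Tensor.newFloat32Tensor(shape, data); IValue.tensor(t); out.getTensor()
    // Newer naming (PytorchTestBase below):
    final Tensor t = Tensor.fromBlob(data, shape); // (data, shape) order
    final IValue in = IValue.from(t);
    final IValue out = module.runMethod("eqTensor", in);
    final Tensor result = out.toTensor();
    System.out.println(java.util.Arrays.toString(result.getDataAsFloatArray()));
  }
}
```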
@@ -1,290 +0,0 @@
-package org.pytorch;
-
-import org.junit.Before;
-import org.junit.Test;
-
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
-
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-
-public abstract class PytorchTestBase {
-  private static final String TEST_MODULE_ASSET_NAME = "test.pt";
-
-  @Before
-  public void setUp() {
-    System.loadLibrary("pytorch");
-  }
-
-  @Test
-  public void testForwardNull() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-    final IValue input =
-        IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
-    assertTrue(input.isTensor());
-    final IValue output = module.forward(input);
-    assertTrue(output.isNull());
-  }
-
-  @Test
-  public void testEqBool() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-    for (boolean value : new boolean[] {false, true}) {
-      final IValue input = IValue.from(value);
-      assertTrue(input.isBool());
-      assertTrue(value == input.toBool());
-      final IValue output = module.runMethod("eqBool", input);
-      assertTrue(output.isBool());
-      assertTrue(value == output.toBool());
-    }
-  }
-
-  @Test
-  public void testEqInt() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-    for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
-      final IValue input = IValue.from(value);
-      assertTrue(input.isLong());
-      assertTrue(value == input.toLong());
-      final IValue output = module.runMethod("eqInt", input);
-      assertTrue(output.isLong());
-      assertTrue(value == output.toLong());
-    }
-  }
-
-  @Test
-  public void testEqFloat() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-    double[] values =
-        new double[] {
-          -Double.MAX_VALUE,
-          Double.MAX_VALUE,
-          -Double.MIN_VALUE,
-          Double.MIN_VALUE,
-          -Math.exp(1.d),
-          -Math.sqrt(2.d),
-          -3.1415f,
-          3.1415f,
-          -1,
-          0,
-          1,
-        };
-    for (double value : values) {
-      final IValue input = IValue.from(value);
-      assertTrue(input.isDouble());
-      assertTrue(value == input.toDouble());
-      final IValue output = module.runMethod("eqFloat", input);
-      assertTrue(output.isDouble());
-      assertTrue(value == output.toDouble());
-    }
-  }
-
-  @Test
-  public void testEqTensor() throws IOException {
-    final long[] inputTensorShape = new long[] {1, 3, 224, 224};
-    final long numElements = Tensor.numel(inputTensorShape);
-    final float[] inputTensorData = new float[(int) numElements];
-    for (int i = 0; i < numElements; ++i) {
-      inputTensorData[i] = i;
-    }
-    final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
-
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-    final IValue input = IValue.from(inputTensor);
-    assertTrue(input.isTensor());
-    assertTrue(inputTensor == input.toTensor());
-    final IValue output = module.runMethod("eqTensor", input);
-    assertTrue(output.isTensor());
-    final Tensor outputTensor = output.toTensor();
-    assertNotNull(outputTensor);
-    assertArrayEquals(inputTensorShape, outputTensor.shape());
-    float[] outputData = outputTensor.getDataAsFloatArray();
-    for (int i = 0; i < numElements; i++) {
-      assertTrue(inputTensorData[i] == outputData[i]);
-    }
-  }
-
-  @Test
-  public void testEqDictIntKeyIntValue() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-    final Map<Long, IValue> inputMap = new HashMap<>();
-
-    inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
-    inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
-    inputMap.put(0l, IValue.from(0l));
-    inputMap.put(1l, IValue.from(-1l));
-    inputMap.put(-1l, IValue.from(1l));
-
-    final IValue input = IValue.dictLongKeyFrom(inputMap);
-    assertTrue(input.isDictLongKey());
-
-    final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
-    assertTrue(output.isDictLongKey());
-
-    final Map<Long, IValue> outputMap = output.toDictLongKey();
-    assertTrue(inputMap.size() == outputMap.size());
-    for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
-      assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
-    }
-  }
-
-  @Test
-  public void testEqDictStrKeyIntValue() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-    final Map<String, IValue> inputMap = new HashMap<>();
-
-    inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
-    inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
-    inputMap.put("long_0", IValue.from(0l));
-    inputMap.put("long_1", IValue.from(1l));
-    inputMap.put("long_-1", IValue.from(-1l));
-
-    final IValue input = IValue.dictStringKeyFrom(inputMap);
-    assertTrue(input.isDictStringKey());
-
-    final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
-    assertTrue(output.isDictStringKey());
-
-    final Map<String, IValue> outputMap = output.toDictStringKey();
-    assertTrue(inputMap.size() == outputMap.size());
-    for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
-      assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
-    }
-  }
-
-  @Test
-  public void testListIntSumReturnTuple() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-
-    for (int n : new int[] {0, 1, 128}) {
-      long[] a = new long[n];
-      long sum = 0;
-      for (int i = 0; i < n; i++) {
-        a[i] = i;
-        sum += a[i];
-      }
-      final IValue input = IValue.listFrom(a);
-      assertTrue(input.isLongList());
-
-      final IValue output = module.runMethod("listIntSumReturnTuple", input);
-
-      assertTrue(output.isTuple());
-      assertTrue(2 == output.toTuple().length);
-
-      IValue output0 = output.toTuple()[0];
-      IValue output1 = output.toTuple()[1];
-
-      assertArrayEquals(a, output0.toLongList());
-      assertTrue(sum == output1.toLong());
-    }
-  }
-
-  @Test
-  public void testOptionalIntIsNone() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-
-    assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
-    assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
-  }
-
-  @Test
-  public void testIntEq0None() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-
-    assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
-    assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
-  }
-
-  @Test(expected = IllegalArgumentException.class)
-  public void testRunUndefinedMethod() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-    module.runMethod("test_undefined_method_throws_exception");
-  }
-
-  @Test
-  public void testTensorMethods() {
-    long[] shape = new long[] {1, 3, 224, 224};
-    final int numel = (int) Tensor.numel(shape);
-    int[] ints = new int[numel];
-    float[] floats = new float[numel];
-
-    byte[] bytes = new byte[numel];
-    for (int i = 0; i < numel; i++) {
-      bytes[i] = (byte) ((i % 255) - 128);
-      ints[i] = i;
-      floats[i] = i / 1000.f;
-    }
-
-    Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
-    assertTrue(tensorBytes.dtype() == DType.INT8);
-    assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
-
-    Tensor tensorInts = Tensor.fromBlob(ints, shape);
-    assertTrue(tensorInts.dtype() == DType.INT32);
-    assertArrayEquals(ints, tensorInts.getDataAsIntArray());
-
-    Tensor tensorFloats = Tensor.fromBlob(floats, shape);
-    assertTrue(tensorFloats.dtype() == DType.FLOAT32);
-    float[] floatsOut = tensorFloats.getDataAsFloatArray();
-    assertTrue(floatsOut.length == numel);
-    for (int i = 0; i < numel; i++) {
-      assertTrue(floats[i] == floatsOut[i]);
-    }
-  }
-
-  @Test(expected = IllegalStateException.class)
-  public void testTensorIllegalStateOnWrongType() {
-    long[] shape = new long[] {1, 3, 224, 224};
-    final int numel = (int) Tensor.numel(shape);
-    float[] floats = new float[numel];
-    Tensor tensorFloats = Tensor.fromBlob(floats, shape);
-    assertTrue(tensorFloats.dtype() == DType.FLOAT32);
-    tensorFloats.getDataAsByteArray();
-  }
-
-  @Test
-  public void testEqString() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-    String[] values =
-        new String[] {
-          "smoketest",
-          "проверка не латинских символов", // not latin symbols check
-          "#@$!@#)($*!@#$)(!@*#$"
-        };
-    for (String value : values) {
-      final IValue input = IValue.from(value);
-      assertTrue(input.isString());
-      assertTrue(value.equals(input.toStr()));
-      final IValue output = module.runMethod("eqStr", input);
-      assertTrue(output.isString());
-      assertTrue(value.equals(output.toStr()));
-    }
-  }
-
-  @Test
-  public void testStr3Concat() throws IOException {
-    final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
-    String[] values =
-        new String[] {
-          "smoketest",
-          "проверка не латинских символов", // not latin symbols check
-          "#@$!@#)($*!@#$)(!@*#$"
-        };
-    for (String value : values) {
-      final IValue input = IValue.from(value);
-      assertTrue(input.isString());
-      assertTrue(value.equals(input.toStr()));
-      final IValue output = module.runMethod("str3Concat", input);
-      assertTrue(output.isString());
-      String expectedOutput =
-          new StringBuilder().append(value).append(value).append(value).toString();
-      assertTrue(expectedOutput.equals(output.toStr()));
-    }
-  }
-
-  protected abstract String assetFilePath(String assetName) throws IOException;
-}
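`PytorchTestBase` (present on only one side of this comparison) factors the whole suite into an abstract base class whose single environment-specific piece is asset resolution: `PytorchHostTests` copies `test.pt` out of classpath resources on the JVM, while `PytorchInstrumentedTests` extracts it from Android assets. A minimal sketch of the same template-method pattern, assuming names other than `assetFilePath` are illustrative and not part of the diff:

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.Objects;

// Shared test cases live in the base; each runner supplies asset resolution.
abstract class SuiteBaseSketch {
  protected abstract String assetFilePath(String assetName) throws IOException;

  // Representative shared test body (JUnit annotations omitted for brevity).
  public void checkModuleAssetResolves() throws IOException {
    String path = assetFilePath("test.pt");
    // ...load the module from `path` and assert on its outputs...
  }
}

// Host (JVM) variant: copy the classpath resource to a temp file,
// mirroring what the deleted PytorchHostTests does.
class HostSuiteSketch extends SuiteBaseSketch {
  @Override
  protected String assetFilePath(String assetName) throws IOException {
    Path tempFile = Files.createTempFile("test", ".pt");
    try (InputStream resource =
        Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(assetName))) {
      Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING);
    }
    return tempFile.toAbsolutePath().toString();
  }
}
```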
@ -10,8 +10,6 @@
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
// NOTE: Codes must be kept in sync with DType.java.
|
||||
// NOTE: Never serialize these, because they can change between releases.
|
||||
constexpr static int kTensorDTypeUInt8 = 1;
|
||||
constexpr static int kTensorDTypeInt8 = 2;
|
||||
constexpr static int kTensorDTypeInt32 = 3;
|
||||
@ -166,7 +164,7 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
|
||||
static at::Tensor newAtTensorFromJTensor(
|
||||
facebook::jni::alias_ref<JTensor> jtensor) {
|
||||
static const auto dtypeMethod =
|
||||
JTensor::javaClassStatic()->getMethod<jint()>("dtypeJniCode");
|
||||
JTensor::javaClassStatic()->getMethod<jint()>("dtype");
|
||||
jint jdtype = dtypeMethod(jtensor);
|
||||
|
||||
static const auto shapeField =
|
||||
@ -218,7 +216,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodTensor =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::local_ref<JTensor>)>("from");
|
||||
facebook::jni::local_ref<JTensor>)>("tensor");
|
||||
return jMethodTensor(
|
||||
JIValue::javaClassStatic(),
|
||||
JTensor::newJTensorFromAtTensor(ivalue.toTensor()));
|
||||
@ -226,26 +224,26 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodBool =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jboolean)>(
|
||||
"from");
|
||||
"bool");
|
||||
return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool());
|
||||
} else if (ivalue.isInt()) {
|
||||
static auto jMethodInt =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jlong)>(
|
||||
"from");
|
||||
"long64");
|
||||
return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt());
|
||||
} else if (ivalue.isDouble()) {
|
||||
static auto jMethodDouble =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jdouble)>(
|
||||
"from");
|
||||
"double64");
|
||||
return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble());
|
||||
} else if (ivalue.isString()) {
|
||||
static auto jMethodString =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JString::javaobject>)>("from");
|
||||
facebook::jni::JString::javaobject>)>("string");
|
||||
return jMethodString(
|
||||
JIValue::javaClassStatic(),
|
||||
facebook::jni::make_jstring(ivalue.toStringRef()));
|
||||
@ -255,7 +253,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JArrayClass<
|
||||
JIValue::javaobject>::javaobject>)>("tupleFrom");
|
||||
JIValue::javaobject>::javaobject>)>("tuple");
|
||||
auto jElementsArray =
|
||||
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
|
||||
elementsVec.size());
|
||||
@ -269,7 +267,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodBoolListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jbooleanArray>)>("listFrom");
|
||||
facebook::jni::alias_ref<jbooleanArray>)>("boolList");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_boolean_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
@ -283,7 +281,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodLongListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jlongArray>)>("listFrom");
|
||||
facebook::jni::alias_ref<jlongArray>)>("longList");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_long_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
@ -297,7 +295,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethoDoubleListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jdoubleArray>)>("listFrom");
|
||||
facebook::jni::alias_ref<jdoubleArray>)>("doubleList");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_double_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
@ -312,7 +310,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JArrayClass<
|
||||
JTensor::javaobject>::javaobject>)>("listFrom");
|
||||
JTensor::javaobject>::javaobject>)>("tensorList");
|
||||
auto jArray = facebook::jni::JArrayClass<JTensor::javaobject>::newArray(
|
||||
list.size());
|
||||
auto index = 0;
|
||||
@@ -326,7 +324,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
         JIValue::javaClassStatic()
             ->getStaticMethod<facebook::jni::local_ref<JIValue>(
                 facebook::jni::alias_ref<facebook::jni::JArrayClass<
-                    JIValue::javaobject>::javaobject>)>("listFrom");
+                    JIValue::javaobject>::javaobject>)>("list");
     auto jArray = facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
         list.size());
     auto index = 0;
@@ -353,7 +351,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
                 facebook::jni::alias_ref<
                     facebook::jni::JString::javaobject>,
                 facebook::jni::alias_ref<JIValue::javaobject>>>)>(
-                "dictStringKeyFrom");
+                "dictStringKey");

     auto jmap = JHashMap<
         facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
@@ -372,7 +370,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
                 facebook::jni::alias_ref<
                     facebook::jni::JLong::javaobject>,
                 facebook::jni::alias_ref<JIValue::javaobject>>>)>(
-                "dictLongKeyFrom");
+                "dictLongKey");
     auto jmap = JHashMap<
         facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
         facebook::jni::alias_ref<JIValue::javaobject>>::create();
@@ -406,32 +404,32 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
     static const auto jMethodGetTensor =
         JIValue::javaClassStatic()
             ->getMethod<facebook::jni::alias_ref<JTensor::javaobject>()>(
-                "toTensor");
+                "getTensor");
     return JTensor::newAtTensorFromJTensor(jMethodGetTensor(jivalue));
   } else if (JIValue::kTypeCodeBool == typeCode) {
     static const auto jMethodGetBool =
-        JIValue::javaClassStatic()->getMethod<jboolean()>("toBool");
+        JIValue::javaClassStatic()->getMethod<jboolean()>("getBool");
     // explicit cast to bool as jboolean is defined as uint8_t, IValue ctor
     // for int will be called for jboolean
     bool b = jMethodGetBool(jivalue);
     return at::IValue{b};
   } else if (JIValue::kTypeCodeLong == typeCode) {
     static const auto jMethodGetLong =
-        JIValue::javaClassStatic()->getMethod<jlong()>("toLong");
+        JIValue::javaClassStatic()->getMethod<jlong()>("getLong");
     return at::IValue{jMethodGetLong(jivalue)};
   } else if (JIValue::kTypeCodeDouble == typeCode) {
     static const auto jMethodGetDouble =
-        JIValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
+        JIValue::javaClassStatic()->getMethod<jdouble()>("getDouble");
     return at::IValue{jMethodGetDouble(jivalue)};
   } else if (JIValue::kTypeCodeString == typeCode) {
     static const auto jMethodGetString =
-        JIValue::javaClassStatic()->getMethod<jstring()>("toStr");
+        JIValue::javaClassStatic()->getMethod<jstring()>("getString");
     return at::IValue{jMethodGetString(jivalue)->toStdString()};
   } else if (JIValue::kTypeCodeTuple == typeCode) {
     static const auto jMethodGetTuple =
         JIValue::javaClassStatic()
             ->getMethod<facebook::jni::JArrayClass<
-                JIValue::javaobject>::javaobject()>("toTuple");
+                JIValue::javaobject>::javaobject()>("getTuple");
     auto jarray = jMethodGetTuple(jivalue);
     size_t n = jarray->size();

@@ -445,7 +443,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
     return c10::ivalue::Tuple::create(std::move(elements));
   } else if (JIValue::kTypeCodeBoolList == typeCode) {
     static const auto jMethodGetBoolList =
-        JIValue::javaClassStatic()->getMethod<jbooleanArray()>("toBoolList");
+        JIValue::javaClassStatic()->getMethod<jbooleanArray()>("getBoolList");
     auto jArray = jMethodGetBoolList(jivalue);
     auto jArrayPinned = jArray->pin();
     size_t n = jArrayPinned.size();
@@ -457,7 +455,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
     return at::IValue{std::move(list)};
   } else if (JIValue::kTypeCodeLongList == typeCode) {
     static const auto jMethodGetLongList =
-        JIValue::javaClassStatic()->getMethod<jlongArray()>("toLongList");
+        JIValue::javaClassStatic()->getMethod<jlongArray()>("getLongList");
     auto jArray = jMethodGetLongList(jivalue);
     auto jArrayPinned = jArray->pin();
     size_t n = jArrayPinned.size();
@@ -470,7 +468,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
   } else if (JIValue::kTypeCodeDoubleList == typeCode) {
     static const auto jMethodGetDoubleList =
         JIValue::javaClassStatic()->getMethod<jdoubleArray()>(
-            "toDoubleList");
+            "getDoubleList");
     auto jArray = jMethodGetDoubleList(jivalue);
     auto jArrayPinned = jArray->pin();
     size_t n = jArrayPinned.size();
@@ -484,7 +482,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
     static const auto jMethodGetTensorList =
         JIValue::javaClassStatic()
             ->getMethod<facebook::jni::JArrayClass<
-                JTensor::javaobject>::javaobject()>("toTensorList");
+                JTensor::javaobject>::javaobject()>("getTensorList");
     auto jArray = jMethodGetTensorList(jivalue);
     size_t n = jArray->size();
     c10::List<at::Tensor> list{};
@@ -497,7 +495,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
     static const auto jMethodGetList =
         JIValue::javaClassStatic()
             ->getMethod<facebook::jni::JArrayClass<
-                JIValue::javaobject>::javaobject()>("toList");
+                JIValue::javaobject>::javaobject()>("getList");
     auto jarray = jMethodGetList(jivalue);
     size_t n = jarray->size();
     if (n == 0) {
@@ -520,7 +518,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
     static const auto jMethodGetDictStringKey =
         JIValue::javaClassStatic()
             ->getMethod<facebook::jni::JMap<jstring, JIValue::javaobject>::
-                            javaobject()>("toDictStringKey");
+                            javaobject()>("getDictStringKey");
     auto jmap = jMethodGetDictStringKey(jivalue);
     auto it = jmap->begin();
     if (it == jmap->end()) {
@@ -543,7 +541,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
         JIValue::javaClassStatic()
             ->getMethod<facebook::jni::JMap<
                 facebook::jni::JLong::javaobject,
-                JIValue::javaobject>::javaobject()>("toDictLongKey");
+                JIValue::javaobject>::javaobject()>("getDictLongKey");
     auto jmap = jMethodGetDictLongKey(jivalue);
     auto it = jmap->begin();
     if (it == jmap->end()) {
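The string literals passed to getMethod/getStaticMethod above are resolved against IValue.java by name at runtime, so each rename in this hunk has to ship together with the matching Java rename further down; a mismatch only surfaces when the JNI lookup runs. A minimal Java sketch of the counterparts, using the right-hand names of this diff (the class and values here are hypothetical, not from the source):

import org.pytorch.IValue;

class JniNameSketch {
  static void roundTrip() {
    // Java methods that the JNI lookups above resolve by name.
    IValue answer = IValue.long64(42);                     // boxed as TYPE_CODE_LONG
    IValue values = IValue.list(answer, IValue.long64(7)); // found via getStaticMethod(..., "list")
    long unboxed = answer.getLong();                       // found via getMethod<jlong()>("getLong")
  }
}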
@@ -1,29 +0,0 @@
-package org.pytorch;
-
-/**
- * Codes representing tensor data types.
- */
-public enum DType {
-  // NOTE: "jniCode" must be kept in sync with pytorch_jni.cpp.
-  // NOTE: Never serialize "jniCode", because it can change between releases.
-
-  /** Code for dtype torch.uint8. {@link Tensor#dtype()} */
-  UINT8(1),
-  /** Code for dtype torch.int8. {@link Tensor#dtype()} */
-  INT8(2),
-  /** Code for dtype torch.int32. {@link Tensor#dtype()} */
-  INT32(3),
-  /** Code for dtype torch.float32. {@link Tensor#dtype()} */
-  FLOAT32(4),
-  /** Code for dtype torch.int64. {@link Tensor#dtype()} */
-  INT64(5),
-  /** Code for dtype torch.float64. {@link Tensor#dtype()} */
-  FLOAT64(6),
-  ;
-
-  final int jniCode;
-
-  DType(int jniCode) {
-    this.jniCode = jniCode;
-  }
-}
@@ -4,21 +4,10 @@ import java.util.Locale;
 import java.util.Map;
 
 /**
- * Java representation of a TorchScript value, which is implemented as tagged union that can be
- * one of the supported types: https://pytorch.org/docs/stable/jit.html#types .
- * <p>
- * Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}.
- * <p>
- * {@code IValue} objects are constructed with {@code IValue.from(value)},
- * {@code IValue.tupleFrom(value1, value2, ...)}, {@code IValue.listFrom(value1, value2, ...)},
- * or one of the {@code dict} methods, depending on the key type.
- * <p>
- * Data is retrieved from {@code IValue} objects with the {@code toX()} methods. Note that
- * {@code str}-type IValues must be extracted with {@link #toStr()},
- * rather than {@link #toString()}.
- * <p>
- * {@code IValue} objects may retain references to objects passed into their constructors,
- * and may return references to their internal state from {@code toX()}.
+ * Java representation of a torchscript variable, which is implemented as tagged union that can be
+ * one of the supported types: https://pytorch.org/docs/stable/jit.html#types.
+ * Calling getters for inappropriate types will throw IllegalStateException.
  */
 public class IValue {
   private static final int TYPE_CODE_NULL = 1;
@@ -102,98 +91,95 @@ public class IValue {
     return TYPE_CODE_DICT_LONG_KEY == this.mTypeCode;
   }
 
   /**
    * Creates a new {@code IValue} of type {@code Optional} that contains no value.
    */
   public static IValue optionalNull() {
     return new IValue(TYPE_CODE_NULL);
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code Tensor}.
+   * Creates a new IValue instance of torchscript Tensor type.
    */
-  public static IValue from(Tensor tensor) {
+  public static IValue tensor(Tensor tensor) {
     final IValue iv = new IValue(TYPE_CODE_TENSOR);
     iv.mData = tensor;
     return iv;
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code bool}.
+   * Creates a new IValue instance of torchscript bool type.
    */
-  public static IValue from(boolean value) {
+  public static IValue bool(boolean value) {
     final IValue iv = new IValue(TYPE_CODE_BOOL);
     iv.mData = value;
     return iv;
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code int}.
+   * Creates a new IValue instance of torchscript int type.
    */
-  public static IValue from(long value) {
+  public static IValue long64(long value) {
     final IValue iv = new IValue(TYPE_CODE_LONG);
     iv.mData = value;
     return iv;
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code float}.
+   * Creates a new IValue instance of torchscript float type.
    */
-  public static IValue from(double value) {
+  public static IValue double64(double value) {
     final IValue iv = new IValue(TYPE_CODE_DOUBLE);
     iv.mData = value;
     return iv;
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code str}.
+   * Creates new IValue instance of torchscript str type.
    */
-  public static IValue from(String value) {
+  public static IValue string(String value) {
     final IValue iv = new IValue(TYPE_CODE_STRING);
     iv.mData = value;
     return iv;
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code List[bool]}.
+   * Creates a new IValue instance of torchscript List[bool] type.
    */
-  public static IValue listFrom(boolean... list) {
+  public static IValue boolList(boolean... list) {
     final IValue iv = new IValue(TYPE_CODE_BOOL_LIST);
     iv.mData = list;
     return iv;
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code List[int]}.
+   * Creates a new IValue instance of torchscript List[int] type.
    */
-  public static IValue listFrom(long... list) {
+  public static IValue longList(long... list) {
     final IValue iv = new IValue(TYPE_CODE_LONG_LIST);
     iv.mData = list;
     return iv;
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code List[float]}.
+   * Creates a new IValue instance of torchscript List[float] type.
    */
-  public static IValue listFrom(double... list) {
+  public static IValue doubleList(double... list) {
     final IValue iv = new IValue(TYPE_CODE_DOUBLE_LIST);
     iv.mData = list;
     return iv;
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code List[Tensor]}.
+   * Creates a new IValue instance of torchscript List[Tensor] type.
    */
-  public static IValue listFrom(Tensor... list) {
+  public static IValue tensorList(Tensor... list) {
     final IValue iv = new IValue(TYPE_CODE_TENSOR_LIST);
     iv.mData = list;
     return iv;
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code List[T]}. All elements must have the same type.
+   * Creates a new IValue instance of torchscript List[T] type. All elements must have the same type.
    */
-  public static IValue listFrom(IValue... array) {
+  public static IValue list(IValue... array) {
     final int size = array.length;
     if (size > 0) {
       final int typeCode0 = array[0].mTypeCode;
@@ -210,93 +196,93 @@ public class IValue {
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code Tuple[T0, T1, ...]}.
+   * Creates a new IValue instance of torchscript Tuple[T0, T1, ...] type.
    */
-  public static IValue tupleFrom(IValue... array) {
+  public static IValue tuple(IValue... array) {
     final IValue iv = new IValue(TYPE_CODE_TUPLE);
     iv.mData = array;
     return iv;
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code Dict[str, V]}.
+   * Creates a new IValue instance of torchscript Dict[str, V] type.
    */
-  public static IValue dictStringKeyFrom(Map<String, IValue> map) {
+  public static IValue dictStringKey(Map<String, IValue> map) {
     final IValue iv = new IValue(TYPE_CODE_DICT_STRING_KEY);
     iv.mData = map;
     return iv;
   }
 
   /**
-   * Creates a new {@code IValue} of type {@code Dict[int, V]}.
+   * Creates a new IValue instance of torchscript Dict[int, V] type.
    */
-  public static IValue dictLongKeyFrom(Map<Long, IValue> map) {
+  public static IValue dictLongKey(Map<Long, IValue> map) {
     final IValue iv = new IValue(TYPE_CODE_DICT_LONG_KEY);
     iv.mData = map;
     return iv;
   }
 
-  public Tensor toTensor() {
+  public Tensor getTensor() {
     preconditionType(TYPE_CODE_TENSOR, mTypeCode);
     return (Tensor) mData;
   }
 
-  public boolean toBool() {
+  public boolean getBool() {
     preconditionType(TYPE_CODE_BOOL, mTypeCode);
     return (boolean) mData;
   }
 
-  public long toLong() {
+  public long getLong() {
     preconditionType(TYPE_CODE_LONG, mTypeCode);
     return (long) mData;
   }
 
-  public double toDouble() {
+  public double getDouble() {
     preconditionType(TYPE_CODE_DOUBLE, mTypeCode);
     return (double) mData;
   }
 
-  public String toStr() {
+  public String getString() {
     preconditionType(TYPE_CODE_STRING, mTypeCode);
     return (String) mData;
   }
 
-  public boolean[] toBoolList() {
+  public boolean[] getBoolList() {
     preconditionType(TYPE_CODE_BOOL_LIST, mTypeCode);
     return (boolean[]) mData;
   }
 
-  public long[] toLongList() {
+  public long[] getLongList() {
     preconditionType(TYPE_CODE_LONG_LIST, mTypeCode);
     return (long[]) mData;
   }
 
-  public double[] toDoubleList() {
+  public double[] getDoubleList() {
     preconditionType(TYPE_CODE_DOUBLE_LIST, mTypeCode);
     return (double[]) mData;
   }
 
-  public Tensor[] toTensorList() {
+  public Tensor[] getTensorList() {
    preconditionType(TYPE_CODE_TENSOR_LIST, mTypeCode);
     return (Tensor[]) mData;
   }
 
-  public IValue[] toList() {
+  public IValue[] getList() {
     preconditionType(TYPE_CODE_LIST, mTypeCode);
     return (IValue[]) mData;
   }
 
-  public IValue[] toTuple() {
+  public IValue[] getTuple() {
     preconditionType(TYPE_CODE_TUPLE, mTypeCode);
     return (IValue[]) mData;
   }
 
-  public Map<String, IValue> toDictStringKey() {
+  public Map<String, IValue> getDictStringKey() {
     preconditionType(TYPE_CODE_DICT_STRING_KEY, mTypeCode);
     return (Map<String, IValue>) mData;
   }
 
-  public Map<Long, IValue> toDictLongKey() {
+  public Map<Long, IValue> getDictLongKey() {
     preconditionType(TYPE_CODE_DICT_LONG_KEY, mTypeCode);
     return (Map<Long, IValue>) mData;
   }
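A hedged usage sketch of the tagged-union API above, using the right-hand names of this diff (tensor/getTensor, dictStringKey/getDictStringKey); on the other side of the rename the same calls are IValue.from(...) and toTensor()/toDictStringKey(). The shape and values here are arbitrary:

import java.util.HashMap;
import java.util.Map;
import org.pytorch.IValue;
import org.pytorch.Tensor;

class IValueSketch {
  static void demo() {
    Tensor t = Tensor.newFloat32Tensor(new long[] {2, 2}, new float[] {1f, 2f, 3f, 4f});
    Map<String, IValue> m = new HashMap<>();
    m.put("input", IValue.tensor(t));
    IValue dict = IValue.dictStringKey(m);   // tagged as TYPE_CODE_DICT_STRING_KEY

    // Getter matches the tag, so this round-trips the tensor.
    Tensor back = dict.getDictStringKey().get("input").getTensor();
    // dict.getLong() would throw IllegalStateException: wrong tag for this union.
  }
}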
@@ -5,20 +5,21 @@ package org.pytorch;
 import com.facebook.jni.HybridData;
 
 /**
- * Java wrapper for torch::jit::script::Module.
+ * Java holder for torch::jit::script::Module which owns it on jni side.
  */
 public class Module {
 
   private NativePeer mNativePeer;
 
   /**
-   * Loads a serialized TorchScript module from the specified path on the disk.
+   * Loads serialized torchscript module from the specified absolute path on the disk.
    *
-   * @param modelPath path to file that contains the serialized TorchScript module.
-   * @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module.
+   * @param modelAbsolutePath absolute path to file that contains the serialized torchscript module.
+   * @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module on jni
+   *     side.
    */
-  public static Module load(final String modelPath) {
-    return new Module(modelPath);
+  public static Module load(final String modelAbsolutePath) {
+    return new Module(modelAbsolutePath);
   }
 
   private Module(final String moduleAbsolutePath) {
@@ -26,32 +27,35 @@ public class Module {
   }
 
   /**
-   * Runs the 'forward' method of this module with the specified arguments.
+   * Runs 'forward' method of loaded torchscript module with specified arguments.
    *
-   * @param inputs arguments for the TorchScript module's 'forward' method.
-   * @return return value from the 'forward' method.
+   * @param inputs arguments for torchscript module 'forward' method.
+   * @return result of torchscript module 'forward' method evaluation
    */
   public IValue forward(IValue... inputs) {
     return mNativePeer.forward(inputs);
   }
 
   /**
-   * Runs the specified method of this module with the specified arguments.
+   * Runs specified method of loaded torchscript module with specified arguments.
    *
-   * @param methodName name of the TorchScript method to run.
-   * @param inputs arguments that will be passed to TorchScript method.
-   * @return return value from the method.
+   * @param methodName torchscript module method to run
+   * @param inputs arguments that will be specified to torchscript module method call
+   * @return result of torchscript module specified method evaluation
    */
   public IValue runMethod(String methodName, IValue... inputs) {
     return mNativePeer.runMethod(methodName, inputs);
   }
 
   /**
-   * Explicitly destroys the native torch::jit::script::Module.
-   * Calling this method is not required, as the native object will be destroyed
-   * when this object is garbage-collected. However, the timing of garbage collection
-   * is not guaranteed, so proactively calling {@code destroy} can free memory more quickly.
-   * See {@link com.facebook.jni.HybridData#resetNative}.
+   * Explicitly destructs native part. Current instance can not be used after this call. This
+   * method may be called multiple times safely. As fbjni library destructs native part
+   * automatically when current instance will be
+   * collected by Java GC, the instance will not leak if this method is not called,
+   * but timing of deletion and the thread will be at the whim of the Java GC.
+   * If you want to control the thread and timing of the destructor, you should call this method
+   * explicitly.
+   * {@link com.facebook.jni.HybridData#resetNative}
   */
   public void destroy() {
     mNativePeer.mHybridData.resetNative();
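A hedged end-to-end sketch of the Module API above; the model path is a placeholder, the input shape is arbitrary, and the IValue/Tensor factory names follow the right-hand side of this diff:

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

class ModuleSketch {
  static void run() {
    // "/data/local/tmp/model.pt" stands in for an already-serialized module file.
    Module module = Module.load("/data/local/tmp/model.pt");
    Tensor input = Tensor.newFloat32Tensor(new long[] {1, 4}, new float[] {1f, 2f, 3f, 4f});
    Tensor output = module.forward(IValue.tensor(input)).getTensor();
    module.destroy(); // optional: frees the native module before the GC would
  }
}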
@@ -11,24 +11,24 @@ import java.util.Arrays;
 import java.util.Locale;
 
 /**
- * Representation of a Tensor. Behavior is similar to PyTorch's tensor objects.
- * <p>
- * Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)},
- * where {@code data} can be an array or a direct {@link Buffer} (of the proper subclass).
- * Helper methods are provided to allocate buffers properly.
- * <p>
- * To access Tensor data, see {@link #dtype()}, {@link #shape()},
- * and various {@code getDataAs*} methods.
- * <p>
- * When constructing {@code Tensor} objects with {@code data} as an array,
- * it is not specified whether this data is copied or retained as a reference
- * so it is recommended not to modify it after constructing. {@code data} passed as a
- * {@link Buffer} is not copied, so it can be modified between {@link Module} calls
- * to avoid reallocation. Data retrieved from {@code Tensor} objects may be copied or
- * may be a reference to the {@code Tensor}'s internal data buffer.
- * {@code shape} is always copied.
+ * Representation of Tensor. Tensor shape is stored in {@link Tensor#shape}, elements are stored as
+ * {@link java.nio.DirectByteBuffer} of one of the supported types.
  */
 public abstract class Tensor {
 
+  /** Code for dtype torch.uint8. {@link Tensor#dtype()} */
+  public static final int DTYPE_UINT8 = 1;
+  /** Code for dtype torch.int8. {@link Tensor#dtype()} */
+  public static final int DTYPE_INT8 = 2;
+  /** Code for dtype torch.int32. {@link Tensor#dtype()} */
+  public static final int DTYPE_INT32 = 3;
+  /** Code for dtype torch.float32. {@link Tensor#dtype()} */
+  public static final int DTYPE_FLOAT32 = 4;
+  /** Code for dtype torch.int64. {@link Tensor#dtype()} */
+  public static final int DTYPE_INT64 = 5;
+  /** Code for dtype torch.float64. {@link Tensor#dtype()} */
+  public static final int DTYPE_FLOAT64 = 6;
+
   private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null";
   private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null";
   private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null";
@@ -39,7 +39,8 @@ public abstract class Tensor {
   private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
       "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)";
 
-  final long[] shape;
+  /** Shape of current tensor. */
+  public final long[] shape;
 
   private static final int INT_SIZE_BYTES = 4;
   private static final int FLOAT_SIZE_BYTES = 4;
@@ -48,8 +49,8 @@ public abstract class Tensor {
 
   /**
    * Allocates a new direct {@link java.nio.ByteBuffer} with native byte order with specified
-   * capacity that can be used in {@link Tensor#fromBlob(ByteBuffer, long[])},
-   * {@link Tensor#fromBlobUnsigned(ByteBuffer, long[])}.
+   * capacity that can be used in {@link Tensor#newInt8Tensor(long[], ByteBuffer)}, {@link
+   * Tensor#newUInt8Tensor(long[], ByteBuffer)}.
    *
    * @param numElements capacity (number of elements) of result buffer.
    */
@@ -57,12 +58,6 @@ public abstract class Tensor {
     return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder());
   }
 
-  /**
-   * Allocates a new direct {@link java.nio.IntBuffer} with native byte order with specified
-   * capacity that can be used in {@link Tensor#fromBlob(IntBuffer, long[])}.
-   *
-   * @param numElements capacity (number of elements) of result buffer.
-   */
   public static IntBuffer allocateIntBuffer(int numElements) {
     return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES)
         .order(ByteOrder.nativeOrder())
@@ -71,7 +66,7 @@ public abstract class Tensor {
 
   /**
    * Allocates a new direct {@link java.nio.FloatBuffer} with native byte order with specified
-   * capacity that can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}.
+   * capacity that can be used in {@link Tensor#newFloat32Tensor(long[], FloatBuffer)}.
    *
    * @param numElements capacity (number of elements) of result buffer.
    */
@@ -83,7 +78,7 @@ public abstract class Tensor {
 
   /**
    * Allocates a new direct {@link java.nio.LongBuffer} with native byte order with specified
-   * capacity that can be used in {@link Tensor#fromBlob(LongBuffer, long[])}.
+   * capacity that can be used in {@link Tensor#newInt64Tensor(long[], LongBuffer)}.
    *
    * @param numElements capacity (number of elements) of result buffer.
    */
@@ -95,7 +90,7 @@ public abstract class Tensor {
 
   /**
    * Allocates a new direct {@link java.nio.DoubleBuffer} with native byte order with specified
-   * capacity that can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}.
+   * capacity that can be used in {@link Tensor#newFloat64Tensor(long[], DoubleBuffer)}.
    *
    * @param numElements capacity (number of elements) of result buffer.
    */
@@ -109,10 +104,10 @@ public abstract class Tensor {
    * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of
    * bytes.
    *
-   * @param data Tensor elements
    * @param shape Tensor shape
+   * @param data Tensor elements
    */
-  public static Tensor fromBlobUnsigned(byte[] data, long[] shape) {
+  public static Tensor newUInt8Tensor(long[] shape, byte[] data) {
     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -126,10 +121,10 @@ public abstract class Tensor {
    * Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of
    * bytes.
    *
-   * @param data Tensor elements
    * @param shape Tensor shape
+   * @param data Tensor elements
   */
-  public static Tensor fromBlob(byte[] data, long[] shape) {
+  public static Tensor newInt8Tensor(long[] shape, byte[] data) {
     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -143,10 +138,10 @@ public abstract class Tensor {
    * Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of
    * ints.
    *
-   * @param data Tensor elements
    * @param shape Tensor shape
+   * @param data Tensor elements
   */
-  public static Tensor fromBlob(int[] data, long[] shape) {
+  public static Tensor newInt32Tensor(long[] shape, int[] data) {
     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -160,10 +155,10 @@ public abstract class Tensor {
    * Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array
    * of floats.
    *
-   * @param data Tensor elements
    * @param shape Tensor shape
+   * @param data Tensor elements
   */
-  public static Tensor fromBlob(float[] data, long[] shape) {
+  public static Tensor newFloat32Tensor(long[] shape, float[] data) {
     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -177,10 +172,10 @@ public abstract class Tensor {
    * Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of
    * longs.
    *
-   * @param data Tensor elements
    * @param shape Tensor shape
+   * @param data Tensor elements
   */
-  public static Tensor fromBlob(long[] data, long[] shape) {
+  public static Tensor newInt64Tensor(long[] shape, long[] data) {
     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -197,7 +192,7 @@ public abstract class Tensor {
    * @param shape Tensor shape
    * @param data Tensor elements
   */
-  public static Tensor fromBlob(long[] shape, double[] data) {
+  public static Tensor newFloat64Tensor(long[] shape, double[] data) {
     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -210,12 +205,12 @@ public abstract class Tensor {
   /**
    * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data.
    *
-   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
+   * @param shape Tensor shape
+   * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
    *     elements. The buffer is used directly without copying, and changes to its content will
    *     change the tensor.
-   * @param shape Tensor shape
   */
-  public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) {
+  public static Tensor newUInt8Tensor(long[] shape, ByteBuffer data) {
     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -230,12 +225,12 @@ public abstract class Tensor {
   /**
    * Creates a new Tensor instance with dtype torch.int8 with specified shape and data.
    *
-   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
+   * @param shape Tensor shape
+   * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
    *     elements. The buffer is used directly without copying, and changes to its content will
    *     change the tensor.
-   * @param shape Tensor shape
   */
-  public static Tensor fromBlob(ByteBuffer data, long[] shape) {
+  public static Tensor newInt8Tensor(long[] shape, ByteBuffer data) {
     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -250,12 +245,12 @@ public abstract class Tensor {
   /**
    * Creates a new Tensor instance with dtype torch.int32 with specified shape and data.
    *
-   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
+   * @param shape Tensor shape
+   * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
    *     elements. The buffer is used directly without copying, and changes to its content will
    *     change the tensor.
-   * @param shape Tensor shape
   */
-  public static Tensor fromBlob(IntBuffer data, long[] shape) {
+  public static Tensor newInt32Tensor(long[] shape, IntBuffer data) {
     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -270,12 +265,12 @@ public abstract class Tensor {
   /**
    * Creates a new Tensor instance with dtype torch.float32 with specified shape and data.
    *
-   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
+   * @param shape Tensor shape
+   * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
    *     elements. The buffer is used directly without copying, and changes to its content will
    *     change the tensor.
-   * @param shape Tensor shape
   */
-  public static Tensor fromBlob(FloatBuffer data, long[] shape) {
+  public static Tensor newFloat32Tensor(long[] shape, FloatBuffer data) {
     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -290,12 +285,12 @@ public abstract class Tensor {
   /**
    * Creates a new Tensor instance with dtype torch.int64 with specified shape and data.
    *
-   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
+   * @param shape Tensor shape
+   * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
    *     elements. The buffer is used directly without copying, and changes to its content will
    *     change the tensor.
-   * @param shape Tensor shape
   */
-  public static Tensor fromBlob(LongBuffer data, long[] shape) {
+  public static Tensor newInt64Tensor(long[] shape, LongBuffer data) {
     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -310,12 +305,12 @@ public abstract class Tensor {
   /**
    * Creates a new Tensor instance with dtype torch.float64 with specified shape and data.
    *
-   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
+   * @param shape Tensor shape
+   * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
    *     elements. The buffer is used directly without copying, and changes to its content will
    *     change the tensor.
-   * @param shape Tensor shape
   */
-  public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
+  public static Tensor newFloat64Tensor(long[] shape, DoubleBuffer data) {
     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
     checkShape(shape);
@@ -332,14 +327,12 @@ public abstract class Tensor {
     this.shape = Arrays.copyOf(shape, shape.length);
   }
 
-  /** Returns the number of elements in this tensor. */
+  /** Calculates number of elements in current tensor instance. */
   public long numel() {
     return numel(this.shape);
   }
 
-  /**
-   * Calculates the number of elements in a tensor with the specified shape.
-   */
+  /** Calculates number of elements in tensor with specified shape. */
   public static long numel(long[] shape) {
     checkShape(shape);
     int result = 1;
@@ -350,24 +343,14 @@ public abstract class Tensor {
   }
 
   /**
-   * Returns the shape of this tensor. (The array is a fresh copy.)
-   */
-  public long[] shape() {
-    return Arrays.copyOf(shape, shape.length);
-  }
-
-  /**
-   * @return data type of this tensor.
+   * Returns dtype of current tensor. Can be one of {@link Tensor#DTYPE_UINT8}, {@link
+   * Tensor#DTYPE_INT8}, {@link Tensor#DTYPE_INT32},{@link Tensor#DTYPE_FLOAT32}, {@link
+   * Tensor#DTYPE_INT64}, {@link Tensor#DTYPE_FLOAT64}.
    */
-  public abstract DType dtype();
-
-  // Called from native
-  int dtypeJniCode() {
-    return dtype().jniCode;
-  }
+  public abstract int dtype();
 
   /**
-   * @return a Java byte array that contains the tensor data. This may be a copy or reference.
+   * Returns newly allocated java byte array that contains a copy of tensor data.
    *
    * @throws IllegalStateException if it is called for a non-int8 tensor.
   */
@@ -377,7 +360,7 @@ public abstract class Tensor {
   }
 
   /**
-   * @return a Java byte array that contains the tensor data. This may be a copy or reference.
+   * Returns newly allocated java byte array that contains a copy of tensor data.
    *
    * @throws IllegalStateException if it is called for a non-uint8 tensor.
   */
@@ -387,7 +370,7 @@ public abstract class Tensor {
   }
 
   /**
-   * @return a Java int array that contains the tensor data. This may be a copy or reference.
+   * Returns newly allocated java byte array that contains a copy of tensor data.
    *
    * @throws IllegalStateException if it is called for a non-int32 tensor.
   */
@@ -397,7 +380,7 @@ public abstract class Tensor {
   }
 
   /**
-   * @return a Java float array that contains the tensor data. This may be a copy or reference.
+   * Returns newly allocated java byte array that contains a copy of tensor data.
    *
    * @throws IllegalStateException if it is called for a non-float32 tensor.
   */
@@ -407,7 +390,7 @@ public abstract class Tensor {
   }
 
   /**
-   * @return a Java long array that contains the tensor data. This may be a copy or reference.
+   * Returns newly allocated java byte array that contains a copy of tensor data.
    *
    * @throws IllegalStateException if it is called for a non-int64 tensor.
   */
@@ -417,7 +400,7 @@ public abstract class Tensor {
   }
 
   /**
-   * @return a Java double array that contains the tensor data. This may be a copy or reference.
+   * Returns newly allocated java byte array that contains a copy of tensor data.
    *
    * @throws IllegalStateException if it is called for a non-float64 tensor.
   */
@@ -440,8 +423,8 @@ public abstract class Tensor {
   }
 
   @Override
-  public DType dtype() {
-    return DType.UINT8;
+  public int dtype() {
+    return DTYPE_UINT8;
   }
 
   @Override
@@ -472,8 +455,8 @@ public abstract class Tensor {
   }
 
   @Override
-  public DType dtype() {
-    return DType.INT8;
+  public int dtype() {
+    return DTYPE_INT8;
   }
 
   @Override
@@ -504,8 +487,8 @@ public abstract class Tensor {
   }
 
   @Override
-  public DType dtype() {
-    return DType.INT32;
+  public int dtype() {
+    return DTYPE_INT32;
   }
 
   @Override
@@ -544,8 +527,8 @@ public abstract class Tensor {
   }
 
   @Override
-  public DType dtype() {
-    return DType.FLOAT32;
+  public int dtype() {
+    return DTYPE_FLOAT32;
  }
 
   @Override
@@ -568,8 +551,8 @@ public abstract class Tensor {
   }
 
   @Override
-  public DType dtype() {
-    return DType.INT64;
+  public int dtype() {
+    return DTYPE_INT64;
   }
 
   @Override
@@ -600,8 +583,8 @@ public abstract class Tensor {
   }
 
   @Override
-  public DType dtype() {
-    return DType.FLOAT64;
+  public int dtype() {
+    return DTYPE_FLOAT64;
   }
 
   @Override
@@ -651,17 +634,17 @@ public abstract class Tensor {
 
   // Called from native
   private static Tensor nativeNewTensor(ByteBuffer data, long[] shape, int dtype) {
-    if (DType.FLOAT32.jniCode == dtype) {
+    if (DTYPE_FLOAT32 == dtype) {
       return new Tensor_float32(data.asFloatBuffer(), shape);
-    } else if (DType.INT32.jniCode == dtype) {
+    } else if (DTYPE_INT32 == dtype) {
       return new Tensor_int32(data.asIntBuffer(), shape);
-    } else if (DType.INT64.jniCode == dtype) {
+    } else if (DTYPE_INT64 == dtype) {
       return new Tensor_int64(data.asLongBuffer(), shape);
-    } else if (DType.FLOAT64.jniCode == dtype) {
+    } else if (DTYPE_FLOAT64 == dtype) {
       return new Tensor_float64(data.asDoubleBuffer(), shape);
-    } else if (DType.UINT8.jniCode == dtype) {
+    } else if (DTYPE_UINT8 == dtype) {
       return new Tensor_uint8(data, shape);
-    } else if (DType.INT8.jniCode == dtype) {
+    } else if (DTYPE_INT8 == dtype) {
       return new Tensor_int8(data, shape);
     }
     throw new IllegalArgumentException("Unknown Tensor dtype");
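A hedged sketch of the direct-buffer constructors above, following the Javadoc's note that the buffer is used without copying (names from the right-hand side of this diff; the shape is arbitrary):

import java.nio.FloatBuffer;
import org.pytorch.Tensor;

class TensorSketch {
  static void demo() {
    // The direct buffer backs the tensor without a copy, so it can be
    // refilled between Module calls to avoid reallocation.
    FloatBuffer buf = Tensor.allocateFloatBuffer(6); // numel of shape {2, 3}
    for (int i = 0; i < 6; i++) {
      buf.put(i, (float) i);
    }
    Tensor t = Tensor.newFloat32Tensor(new long[] {2, 3}, buf);
    long n = t.numel();                              // 6
    boolean f32 = t.dtype() == Tensor.DTYPE_FLOAT32; // true
  }
}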
@@ -7,7 +7,6 @@ import android.media.Image;
 import org.pytorch.Tensor;
 
 import java.nio.ByteBuffer;
-import java.nio.FloatBuffer;
 import java.util.Locale;
 
 /**
@@ -27,7 +26,7 @@ public final class TensorImageUtils {
    * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
   */
   public static Tensor bitmapToFloat32Tensor(
-      final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[]) {
+      final Bitmap bitmap, float[] normMeanRGB, float normStdRGB[]) {
     checkNormMeanArg(normMeanRGB);
     checkNormStdArg(normStdRGB);
 
@@ -35,52 +34,6 @@ public final class TensorImageUtils {
         bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB);
   }
 
-  /**
-   * Writes tensor content from specified {@link android.graphics.Bitmap},
-   * normalized with specified in parameters mean and std to specified {@link java.nio.FloatBuffer}
-   * with specified offset.
-   *
-   * @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
-   * @param x - x coordinate of top left corner of bitmap's area
-   * @param y - y coordinate of top left corner of bitmap's area
-   * @param width - width of bitmap's area
-   * @param height - height of bitmap's area
-   * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
-   * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
-   */
-  public static void bitmapToFloatBuffer(
-      final Bitmap bitmap,
-      final int x,
-      final int y,
-      final int width,
-      final int height,
-      final float[] normMeanRGB,
-      final float[] normStdRGB,
-      final FloatBuffer outBuffer,
-      final int outBufferOffset) {
-    checkOutBufferCapacity(outBuffer, outBufferOffset, width, height);
-    checkNormMeanArg(normMeanRGB);
-    checkNormStdArg(normStdRGB);
-
-    final int pixelsCount = height * width;
-    final int[] pixels = new int[pixelsCount];
-    bitmap.getPixels(pixels, 0, width, x, y, width, height);
-    final int offset_g = pixelsCount;
-    final int offset_b = 2 * pixelsCount;
-    for (int i = 0; i < pixelsCount; i++) {
-      final int c = pixels[i];
-      float r = ((c >> 16) & 0xff) / 255.0f;
-      float g = ((c >> 8) & 0xff) / 255.0f;
-      float b = ((c) & 0xff) / 255.0f;
-      float rF = (r - normMeanRGB[0]) / normStdRGB[0];
-      float gF = (g - normMeanRGB[1]) / normStdRGB[1];
-      float bF = (b - normMeanRGB[2]) / normStdRGB[2];
-      outBuffer.put(outBufferOffset + i, rF);
-      outBuffer.put(outBufferOffset + offset_g + i, gF);
-      outBuffer.put(outBufferOffset + offset_b + i, bF);
-    }
-  }
-
   /**
    * Creates new {@link org.pytorch.Tensor} from specified area of {@link android.graphics.Bitmap},
    * normalized with specified in parameters mean and std.
@@ -104,9 +57,22 @@ public final class TensorImageUtils {
     checkNormMeanArg(normMeanRGB);
     checkNormStdArg(normStdRGB);
 
-    final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height);
-    bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0);
-    return Tensor.fromBlob(floatBuffer, new long[]{1, 3, height, width});
+    final int pixelsCount = height * width;
+    final int[] pixels = new int[pixelsCount];
+    bitmap.getPixels(pixels, 0, width, x, y, width, height);
+    final float[] floatArray = new float[3 * pixelsCount];
+    final int offset_g = pixelsCount;
+    final int offset_b = 2 * pixelsCount;
+    for (int i = 0; i < pixelsCount; i++) {
+      final int c = pixels[i];
+      float r = ((c >> 16) & 0xff) / 255.0f;
+      float g = ((c >> 8) & 0xff) / 255.0f;
+      float b = ((c) & 0xff) / 255.0f;
+      floatArray[i] = (r - normMeanRGB[0]) / normStdRGB[0];
+      floatArray[offset_g + i] = (g - normMeanRGB[1]) / normStdRGB[1];
+      floatArray[offset_b + i] = (b - normMeanRGB[2]) / normStdRGB[2];
+    }
+    return Tensor.newFloat32Tensor(new long[]{1, 3, height, width}, floatArray);
   }
 
   /**
@@ -139,52 +105,6 @@ public final class TensorImageUtils {
     checkRotateCWDegrees(rotateCWDegrees);
     checkTensorSize(tensorWidth, tensorHeight);
 
-    final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * tensorWidth * tensorHeight);
-    imageYUV420CenterCropToFloatBuffer(
-        image,
-        rotateCWDegrees,
-        tensorWidth,
-        tensorHeight,
-        normMeanRGB, normStdRGB, floatBuffer, 0);
-    return Tensor.fromBlob(floatBuffer, new long[]{1, 3, tensorHeight, tensorWidth});
-  }
-
-  /**
-   * Writes tensor content from specified {@link android.media.Image}, doing optional rotation,
-   * scaling (nearest) and center cropping to specified {@link java.nio.FloatBuffer} with specified offset.
-   *
-   * @param image {@link android.media.Image} as a source for Tensor data
-   * @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be
-   *     upright. Range of valid values: 0, 90, 180, 270
-   * @param tensorWidth return tensor width, must be positive
-   * @param tensorHeight return tensor height, must be positive
-   * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
-   * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
-   * @param outBuffer Output buffer, where tensor content will be written
-   * @param outBufferOffset Output buffer offset with which tensor content will be written
-   */
-  public static void imageYUV420CenterCropToFloatBuffer(
-      final Image image,
-      int rotateCWDegrees,
-      final int tensorWidth,
-      final int tensorHeight,
-      float[] normMeanRGB,
-      float[] normStdRGB,
-      final FloatBuffer outBuffer,
-      final int outBufferOffset) {
-    checkOutBufferCapacity(outBuffer, outBufferOffset, tensorWidth, tensorHeight);
-
-    if (image.getFormat() != ImageFormat.YUV_420_888) {
-      throw new IllegalArgumentException(
-          String.format(
-              Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat()));
-    }
-
-    checkNormMeanArg(normMeanRGB);
-    checkNormStdArg(normStdRGB);
-    checkRotateCWDegrees(rotateCWDegrees);
-    checkTensorSize(tensorWidth, tensorHeight);
-
     final int widthBeforeRotation = image.getWidth();
     final int heightBeforeRotation = image.getHeight();
 
@@ -238,6 +158,7 @@ public final class TensorImageUtils {
     final int channelSize = tensorHeight * tensorWidth;
     final int tensorInputOffsetG = channelSize;
     final int tensorInputOffsetB = 2 * channelSize;
+    final float[] floatArray = new float[3 * channelSize];
     for (int x = 0; x < tensorWidth; x++) {
       for (int y = 0; y < tensorHeight; y++) {
 
@@ -277,22 +198,13 @@ public final class TensorImageUtils {
         int r = clamp((a0 + a1) >> 10, 0, 255);
         int g = clamp((a0 - a2 - a3) >> 10, 0, 255);
         int b = clamp((a0 + a4) >> 10, 0, 255);
-        final int offset = outBufferOffset + y * tensorWidth + x;
-        float rF = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0];
-        float gF = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1];
-        float bF = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2];
-
-        outBuffer.put(offset, rF);
-        outBuffer.put(offset + tensorInputOffsetG, gF);
-        outBuffer.put(offset + tensorInputOffsetB, bF);
+        final int offset = y * tensorWidth + x;
+        floatArray[offset] = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0];
+        floatArray[tensorInputOffsetG + offset] = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1];
+        floatArray[tensorInputOffsetB + offset] = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2];
       }
     }
-  }
-
-  private static void checkOutBufferCapacity(FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) {
-    if (outBufferOffset + 3 * tensorWidth * tensorHeight > outBuffer.capacity()) {
-      throw new IllegalStateException("Buffer underflow");
-    }
+    return Tensor.newFloat32Tensor(new long[]{1, 3, tensorHeight, tensorWidth}, floatArray);
   }
 
   private static void checkTensorSize(int tensorWidth, int tensorHeight) {
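A hedged sketch of bitmapToFloat32Tensor above; the ImageNet mean/std constants and the file path are stand-ins, and TensorImageUtils's package declaration is not shown in this diff, so the snippet assumes the class is reachable without an explicit import:

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import org.pytorch.Tensor;

class PreprocessSketch {
  static Tensor toInput() {
    // Usual ImageNet normalization values, used only as stand-ins for
    // normMeanRGB/normStdRGB.
    float[] mean = {0.485f, 0.456f, 0.406f};
    float[] std = {0.229f, 0.224f, 0.225f};
    Bitmap bitmap = BitmapFactory.decodeFile("/sdcard/image.jpg"); // placeholder path
    return TensorImageUtils.bitmapToFloat32Tensor(bitmap, mean, std); // shape 1x3xHxW
  }
}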
android/run_tests.sh (new executable file, 83 lines)
@@ -0,0 +1,83 @@
+#!/bin/bash
+set -eux
+
+PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
+PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android
+
+echo "ANDROID_HOME:$ANDROID_HOME"
+if [ -z "$ANDROID_HOME" ]; then
+  echo "ANDROID_HOME not set; please set it to Android sdk directory"
+fi
+
+if [ ! -d "$ANDROID_HOME" ]; then
+  echo "ANDROID_HOME not a directory; did you install it under $ANDROID_HOME?"
+  exit 1
+fi
+
+echo "ANDROID_NDK:$ANDROID_NDK"
+if [ -z "$ANDROID_NDK" ]; then
+  echo "ANDROID_NDK not set; please set it to Android ndk directory"
+fi
+
+if [ ! -d "$ANDROID_NDK" ]; then
+  echo "ANDROID_NDK not a directory; did you install it under $ANDROID_NDK?"
+  exit 1
+fi
+
+GRADLE_PATH=gradle
+GRADLE_NOT_FOUND_MSG="Unable to find gradle, please add it to PATH or set GRADLE_HOME"
+
+if [ ! -x "$(command -v gradle)" ]; then
+  if [ -z "$GRADLE_HOME" ]; then
+    echo "$GRADLE_NOT_FOUND_MSG"
+    exit 1
+  fi
+  GRADLE_PATH=$GRADLE_HOME/bin/gradle
+  if [ ! -f "$GRADLE_PATH" ]; then
+    echo "$GRADLE_NOT_FOUND_MSG"
+    exit 1
+  fi
+fi
+echo "GRADLE_PATH:$GRADLE_PATH"
+
+
+# Run android instrumented tests on x86 emulator
+
+ADB_PATH=$ANDROID_HOME/platform-tools/adb
+
+echo "Expecting running emulator"
+$ADB_PATH devices
+
+DEVICES_COUNT=$($ADB_PATH devices | awk 'NF' | wc -l)
+echo "DEVICES_COUNT:$DEVICES_COUNT"
+
+if [ "$DEVICES_COUNT" -eq 1 ]; then
+  echo "Unable to find connected android emulators"
+  cat <<- EOF
+  To start android emulator:
+  1. Install android sdkmanager packages
+  $ANDROID_HOME/tools/bin/sdkmanager "system-images;android-25;google_apis;x86"
+
+  to specify proxy add params: --proxy=http --proxy_host=fwdproxy --proxy_port=8080
+
+  2. Create android virtual device
+  $ANDROID_HOME/tools/bin/avdmanager create avd --name "x86_android25" --package "system-images;android-25;google_apis;x86"
+
+  3. Start emulator in headless mode without audio
+  $ANDROID_HOME/tools/emulator -avd x86_android25 -no-audio -no-window
+
+  4. Check that emulator is running
+  $ANDROID_HOME/platform-tools/adb devices
+
+  If everything is ok the output will be:
+
+  List of devices attached
+  emulator-5554 device
+EOF
+  exit 1
+fi
+
+echo "Waiting for emulator boot completed"
+$ADB_PATH wait-for-device shell 'while [[ -z $(getprop sys.boot_completed) ]]; do sleep 1; done;'
+
+$GRADLE_PATH -PABI_FILTERS=x86 -p $PYTORCH_ANDROID_DIR connectedAndroidTest
@@ -55,11 +55,15 @@ add_subdirectory(src/THNN)
 IF(USE_ROCM)
   include(LoadHIP)
   if (NOT PYTORCH_FOUND_HIP)
-    MESSAGE(FATAL_ERROR
-      "Could not find HIP installation")
+    set(USE_ROCM OFF)
   endif()
 ENDIF()
 
+# Both CUDA and ROCM are enabled and found. Report an error.
+if(USE_CUDA AND USE_ROCM)
+  message(FATAL_ERROR "Both CUDA and ROCm are enabled and found. PyTorch can only be built with either of them. Please turn one off by using either USE_CUDA=OFF or USE_ROCM=OFF.")
+endif()
+
 IF(MSVC)
   # we want to respect the standard, and we are bored of those **** .
   ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1)
@@ -762,20 +762,6 @@
       output: True
     - THTensor* self
 ]]
 [[
-  name: _th_log10
-  cname: log10
-  types:
-    - floating_point
-  backends:
-    - CUDA
-  variants: function
-  return: argument 0
-  arguments:
-    - arg: THTensor* result
-      output: True
-    - THTensor* self
-]]
-[[
   name: _th_log1p
   cname: log1p
@ -8,6 +8,7 @@
|
||||
#include <ATen/core/OpsAlreadyMovedToC10.h>
|
||||
#include <ATen/core/Variadic.h>
|
||||
#include <ATen/core/TensorBody.h>
|
||||
#include <ATen/core/EnableNamedTensor.h>
|
||||
#include <c10/util/C++17.h>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
@ -64,12 +65,46 @@ namespace detail {
|
||||
}
|
||||
};
|
||||
|
||||
// NB: take by const reference (Don't do universal forwarding here! You
|
||||
// don't want to move into this function!)
|
||||
template <typename... Args>
|
||||
TensorTypeSet multi_dispatch_tensor_type_set(const Args&... args) {
|
||||
return MultiDispatchTensorTypeSet().apply(args...).ts;
|
||||
}
|
||||
}
|
||||
|
||||
using FallbackBoxedFunction = void(const char* schema, torch::jit::Stack*);
|
||||
|
||||
// Assume T is decayed
|
||||
template <typename T>
|
||||
using not_ok_to_box =
|
||||
c10::guts::disjunction<
|
||||
c10::guts::negation<
|
||||
c10::guts::disjunction<
|
||||
std::is_constructible<IValue, T>,
|
||||
// TensorOptions are not directly constructible into IValue,
|
||||
// but torch::jit::push knows how to handle them
|
||||
std::is_same<TensorOptions, T>
|
||||
>>
|
||||
#ifdef BUILD_NAMEDTENSOR
|
||||
,
|
||||
// some constructors are templated (and therefore pass
|
||||
// is_constructible), but do not actually work with all
|
||||
// template arguments, so we must blacklist them explicitly
|
||||
// TODO: The correct fix is to sfinae based on is_constructible of T
|
||||
std::is_same<optional<ArrayRef<Dimname>>, T>
|
||||
#endif
|
||||
>;
|
||||
|
||||
template <class Result, class... Args>
|
||||
using supports_boxed_fallback =
|
||||
c10::guts::negation<c10::guts::disjunction<
|
||||
std::is_lvalue_reference<Result>,
|
||||
not_ok_to_box<Result>,
|
||||
std::is_same<IntArrayRef, Result>,
|
||||
not_ok_to_box<guts::decay_t<Args>>...
|
||||
>>;
|
||||
|
||||
// ATenOpTable stores the implementations for each backend, in addition to
|
||||
// an implementation for variables.
|
||||
class CAFFE2_API ATenOpTable {
|
||||
@ -77,33 +112,12 @@ class CAFFE2_API ATenOpTable {
|
||||
ATenOpTable(std::string schema)
|
||||
: schema_(std::move(schema)) {}
|
||||
|
||||
// NB: No universal forwarding
|
||||
template<class Result, class... Args>
|
||||
Result callUnboxed(Args... args) const {
|
||||
using FuncType = Result(Args...);
|
||||
TensorTypeSet ts = detail::multi_dispatch_tensor_type_set(args...);
|
||||
TensorTypeId tid = impl::dispatchTypeId(ts);
|
||||
|
||||
// You might think we can eliminate the second branch by maintaining a
|
||||
// bitmask of registered operator keys, so we don't select dispatch ids
|
||||
// which don't have implementations here. But the net effect is that if you
|
||||
// get a Variable CPUTensor, if there is no variable registration, you'll
|
||||
// fall back to the CPU implementation. Is this what you want? Unlikely...
|
||||
|
||||
auto* unboxed_fn = reinterpret_cast<FuncType*>(function_table_[static_cast<int64_t>(tid)]);
|
||||
if (C10_LIKELY(unboxed_fn != nullptr)) {
|
||||
return (*unboxed_fn)(args...);
|
||||
}
|
||||
|
||||
auto* unboxed_fallback_fn = reinterpret_cast<FuncType*>(function_table_[static_cast<int64_t>(TensorTypeId::UndefinedTensorId)]);
|
||||
if (C10_LIKELY(unboxed_fallback_fn != nullptr)) {
|
||||
return (*unboxed_fallback_fn)(args...);
|
||||
}
|
||||
|
||||
reportError(tid);
|
||||
TORCH_INTERNAL_ASSERT(0);
|
||||
}
|
||||
Result callUnboxed(Args... args) const;
|
||||
|
||||
private:
|
||||
|
||||
void registerOp(TensorTypeId tid, void* fn) {
|
||||
TORCH_CHECK(function_table_[static_cast<int64_t>(tid)] == nullptr,
|
||||
"Attempting to register function for schema ", schema_,
|
||||
@ -132,6 +146,12 @@ class CAFFE2_API ATenDispatch {
|
||||
return *this;
|
||||
}
|
  ATenDispatch& registerFallbackBoxedOp(TensorTypeId id, FallbackBoxedFunction* fn) {
    std::lock_guard<std::mutex> lock(mutex_);
    boxed_fallback_table_[static_cast<size_t>(id)] = fn;
    return *this;
  }

  const ATenOpTable* getOpTable(const char* schema) const {
    auto iter = op_tables_.find(schema);
    TORCH_CHECK(iter != op_tables_.end(),
@@ -139,11 +159,91 @@ class CAFFE2_API ATenDispatch {
    return &iter->second;
  }

  FallbackBoxedFunction* getFallbackBoxedOp(TensorTypeId tid) const {
    return boxed_fallback_table_[static_cast<size_t>(tid)];
  }

 private:
  std::unordered_map<std::string, ATenOpTable> op_tables_;
  FallbackBoxedFunction* boxed_fallback_table_[static_cast<int64_t>(TensorTypeId::NumTensorIds)] = {nullptr};
  std::mutex mutex_;
};

CAFFE2_API ATenDispatch& globalATenDispatch();

template<class Result, class... Args>
Result callBoxedFallback(const char* schema, FallbackBoxedFunction* boxed_fallback_fn, Args&&... args,
                         // NB: the enable_if must occur in a function parameter, because MSVC
                         // doesn't like it when it's a template argument next to
                         // a parameter pack
                         typename c10::guts::enable_if_t<
                             !supports_boxed_fallback<Result, Args...>::value,
                             std::nullptr_t
                         > = nullptr) {
  // This is dead code, because we test the SFINAE condition before calling
  // boxed_fallback_fn. A more functional way of writing this with an
  // optional<Result> return value works poorly when void is involved.
  TORCH_INTERNAL_ASSERT(0);
}

template<class Result, class... Args>
Result callBoxedFallback(const char* schema, FallbackBoxedFunction* boxed_fallback_fn, Args&&... args,
                         typename c10::guts::enable_if_t<
                             supports_boxed_fallback<Result, Args...>::value,
                             std::nullptr_t
                         > = nullptr) {
  torch::jit::Stack stack;
  torch::jit::push(stack, std::forward<Args>(args)...);
  boxed_fallback_fn(schema, &stack);
  TORCH_INTERNAL_ASSERT(stack.size() == 1);
  return torch::jit::pop(stack).to<Result>();
}

// NB: No universal forwarding
template<class Result, class... Args>
Result ATenOpTable::callUnboxed(Args... args) const {
  using FuncType = Result(Args...);
  // NB: No universal forwarding (takes const& only)
  TensorTypeSet ts = detail::multi_dispatch_tensor_type_set(args...);
  TensorTypeId tid = impl::dispatchTypeId(ts);

  // You might think we could eliminate the second branch by maintaining a
  // bitmask of registered operator keys, so that we never select dispatch ids
  // which don't have implementations here. But consider the net effect: if you
  // get a Variable CPUTensor and there is no Variable registration, you'll
  // fall back to the CPU implementation. Is this what you want? Unlikely...

  auto* unboxed_fn = reinterpret_cast<FuncType*>(function_table_[static_cast<int64_t>(tid)]);
  if (C10_LIKELY(unboxed_fn != nullptr)) {
    return (*unboxed_fn)(std::forward<Args>(args)...);
  }

  // The supports_boxed_fallback condition test and the SFINAE on
  // callBoxedFallback do the same thing. But we perform this (compile-time)
  // test twice so that we can take advantage of the fact that
  // `return func_returning_void()` is OK. If we eliminated this condition in
  // exchange for having callBoxedFallback return an optional, we couldn't
  // conveniently handle the Result=void case anymore.
  //
  // (The SFINAE in callBoxedFallback, of course, is necessary to
  // prevent us from attempting to typecheck code that won't typecheck.)
  auto* boxed_fallback_fn = globalATenDispatch().getFallbackBoxedOp(tid);
  if (C10_UNLIKELY(boxed_fallback_fn)) {
    if (supports_boxed_fallback<Result, Args...>::value) {
      return callBoxedFallback<Result, Args...>(schema_.c_str(), boxed_fallback_fn, std::forward<Args>(args)...);
    } else {
      TORCH_INTERNAL_ASSERT(0, schema_, " does not support boxed fallback, but a boxed fallback for ", tid, " was available");
    }
  }

  auto* unboxed_fallback_fn = reinterpret_cast<FuncType*>(function_table_[static_cast<int64_t>(TensorTypeId::UndefinedTensorId)]);
  if (C10_LIKELY(unboxed_fallback_fn != nullptr)) {
    return (*unboxed_fallback_fn)(std::forward<Args>(args)...);
  }

  reportError(tid);
  TORCH_INTERNAL_ASSERT(0);
}

} // namespace at
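A boxed fallback receives the operator's schema string and a stack of IValues instead of typed arguments, so a single function can stand in for any operator on a given dispatch key: callUnboxed pushes its arguments onto a torch::jit::Stack, invokes the fallback, and pops the single result. A minimal sketch, assuming only the signatures visible above; the dispatch key named in the registration comment is illustrative, not taken from this change:

// Hypothetical fallback that rejects every op routed to it with a clear error.
void failing_boxed_fallback(const char* schema, torch::jit::Stack* stack) {
  TORCH_CHECK(false, "no kernel registered for ", schema);
}

// Registration sketch, e.g. at static-initialization time:
//   at::globalATenDispatch().registerFallbackBoxedOp(
//       at::TensorTypeId::XLATensorId, &failing_boxed_fallback);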
@@ -419,7 +419,6 @@ bool aten_op_is_already_moved_to_c10(const c10::OperatorName& opName) {
    {"aten::frobenius_norm", "dim"},
    {"aten::nuclear_norm", ""},
    {"aten::nuclear_norm", "dim"},
    {"aten::clone", ""},
    {"aten::resize_as_", ""},
    {"aten::pow", "Tensor_Scalar"},
    {"aten::zero_", ""},
@@ -1201,6 +1200,7 @@ bool aten_op_is_not_moved_to_c10_yet(const c10::OperatorName& opName) {
    {"aten::frobenius_norm", "out"},
    {"aten::nuclear_norm", "out"},
    {"aten::nuclear_norm", "dim_out"},
    {"aten::clone", ""},
    {"aten::pow", "Tensor_Scalar_out"},
    {"aten::sub", "out"},
    {"aten::addmm", "out"},
@@ -186,10 +186,15 @@ class CAFFE2_API Tensor {
  int64_t ndimension() const {
    return dim();
  }

  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
    return impl_->is_contiguous(memory_format);
  }

  bool is_non_overlapping_and_dense() const {
    return impl_->is_non_overlapping_and_dense();
  }

  at::MemoryFormat suggest_memory_format() const {
    if (impl_->is_strides_like_channels_last()) {
      return at::MemoryFormat::ChannelsLast;
@@ -734,7 +739,7 @@ class CAFFE2_API Tensor {
#ifdef BUILD_NAMEDTENSOR
  Tensor norm(c10::optional<Scalar> p, DimnameList dim, bool keepdim=false) const;
#endif
  Tensor clone() const;
  Tensor clone(c10::optional<MemoryFormat> memory_format=c10::nullopt) const;
  Tensor & resize_as_(const Tensor & the_template) const;
  Tensor pow(Scalar exponent) const;
  Tensor & zero_() const;

(File diff suppressed because it is too large.)
@@ -1,5 +1,6 @@
#pragma once

#include <c10/util/StringUtil.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
@@ -322,15 +323,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
  if (arg.default_value()) {
    out << "=";
    if (arg.type()->kind() == c10::TypeKind::StringType) {
      // TODO prettify the result, such as using \n to represent \012
      out << "\'";
      std::ios_base::fmtflags flags(out.flags());
      for (unsigned char c : arg.default_value().value().toStringRef()) {
        out << "\\" << std::oct << std::setfill('0') << std::setw(3)
            << static_cast<uint64_t>(c);
      }
      out.flags(flags);
      out << "\'";
      printQuotedString(out, arg.default_value().value().toStringRef());
    } else {
      out << arg.default_value().value();
    }
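The rewrite above replaces the manual octal-escape loop with printQuotedString, which writes a quoted, escaped copy of the default string value to the stream. A sketch of the observable difference (the exact quoting and escape rules of printQuotedString are assumed, not specified by this change):

std::ostringstream ss;
printQuotedString(ss, "valid");
// ss.str() now holds a readable quoted form of "valid", whereas the old
// code printed every byte in octal: '\166\141\154\151\144'.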
@@ -49,7 +49,7 @@ namespace c10 {
  _(prim, IgnoredPythonOp) \
  _(prim, Reverse) \
  _(prim, Return) \
  _(prim, ReturnStmt) \
  _(prim, BreakStmt) \
  _(prim, ContinueStmt) \
  _(prim, Store) \
@@ -93,6 +93,7 @@ namespace c10 {
  _(prim, enumerate) \
  _(prim, range) \
  _(prim, rangelist) \
  _(prim, isinstance) \
  _(aten, _grad_sum_to_size) \
  _(aten, _size_if_not_equal) \
  _(aten, _ncf_unsqueeze) \
@@ -221,7 +222,9 @@ namespace c10 {
  _(attr, beg) \
  _(attr, idx) \
  _(attr, split) \
  _(attr, slot)
  _(attr, slot) \
  _(attr, kinds) \
  _(attr, types)
#else
#define FORALL_NS_SYMBOLS(_) \
  _(namespaces, prim) \
@@ -1,8 +1,9 @@
#pragma once

#include <ATen/core/blob.h>
#include <c10/util/intrusive_ptr.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/blob.h>
#include <c10/util/C++17.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/WindowsTorchApiMacro.h>

namespace torch {
@@ -172,6 +173,16 @@ struct CAFFE2_API IValue final {

  // Tuple
  IValue(c10::intrusive_ptr<ivalue::Tuple> v);

  template <
      typename... Args,
      c10::guts::enable_if_t<
          !c10::guts::disjunction<
              std::is_lvalue_reference<Args>...,
              c10::guts::negation<std::is_constructible<IValue, Args>>...>::
              value,
          std::nullptr_t> = nullptr>
  IValue(const std::tuple<Args...>& t);
  bool isTuple() const { return Tag::Tuple == tag; }
  c10::intrusive_ptr<ivalue::Tuple> toTuple() &&;
  c10::intrusive_ptr<ivalue::Tuple> toTuple() const &;

@@ -150,6 +150,11 @@ struct CAFFE2_API Tuple : c10::intrusive_ptr_target {
    return c10::make_intrusive<Tuple>(std::move(elements_));
  }

  template <typename... Args>
  static c10::intrusive_ptr<Tuple> create(Args... elements_) {
    return c10::make_intrusive<Tuple>(std::vector<IValue>{IValue(elements_)...});
  }

  const std::vector<IValue>& elements() const & {
    return elements_;
  }
@@ -522,6 +527,30 @@ c10::optional<T> generic_to(
  return std::move(ivalue).to<T>();
}

namespace detail {
template <typename Tuple, std::size_t... I>
Tuple generic_to_tuple_impl(
    const std::vector<IValue>& t,
    c10::guts::index_sequence<I...>) {
  return std::make_tuple(
      t[I].to<typename std::tuple_element<I, Tuple>::type>()...);
}
}

template <
    typename... Args,
    typename Indices = c10::guts::make_index_sequence<sizeof...(Args)>,
    c10::guts::enable_if_t<
        !c10::guts::disjunction<
            std::is_lvalue_reference<Args>...,
            c10::guts::negation<std::is_constructible<IValue, Args>>...>::value,
        std::nullptr_t> = nullptr>
std::tuple<Args...> generic_to(IValue ivalue, _fake_type<std::tuple<Args...>>) {
  auto vals = ivalue.toTuple()->elements();
  TORCH_CHECK(vals.size() == sizeof...(Args));
  return detail::generic_to_tuple_impl<std::tuple<Args...>>(vals, Indices{});
}

template <typename T>
inline T IValue::to() && {
  return generic_to(std::move(*this), _fake_type<T>{});
@@ -609,6 +638,17 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
    : tag(Tag::Tuple), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.release();
}
template <
    typename... Args,
    c10::guts::enable_if_t<
        !c10::guts::disjunction<
            std::is_lvalue_reference<Args>...,
            c10::guts::negation<std::is_constructible<IValue, Args>>...>::value,
        std::nullptr_t>>
inline IValue::IValue(const std::tuple<Args...>& t)
    : IValue(
          std::move(c10::guts::apply(c10::ivalue::Tuple::create<Args...>, t))) {
}
inline IValue::IValue(c10::List<int64_t> v)
    : tag(Tag::IntList), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.impl_.release();
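Together, the new constructor and the tuple overload of generic_to make std::tuple round-trippable through IValue. A short sketch using only names from this change (the element types are chosen arbitrarily):

// Each element must be passable by value and constructible into an IValue;
// the enable_if above rejects lvalue references and unconvertible types.
c10::IValue iv(std::make_tuple(int64_t(7), at::ones({2, 2})));

// generic_to checks the arity, then converts element-wise.
auto t = std::move(iv).to<std::tuple<int64_t, at::Tensor>>();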
@@ -1458,6 +1458,28 @@ struct CAFFE2_API ClassType : public NamedType {
      TypePtr type,
      bool is_parameter = false);

  // Add attribute \p NAME if it doesn't exist, or verify that it has a
  // compatible type otherwise.
  size_t addOrCheckAttribute(
      const std::string& name,
      TypePtr ty,
      bool is_parameter = false) {
    auto slot_idx = findAttributeSlot(name);
    if (!slot_idx) {
      return addAttribute(name, ty, is_parameter);
    }

    TORCH_CHECK(
        is_parameter == this->is_parameter(*slot_idx),
        "Parameter field mismatch for the field '",
        name,
        "'");
    TypePtr atype = getAttribute(*slot_idx);
    TORCH_CHECK(ty->isSubtypeOf(atype));
    return *slot_idx;
  }

  at::ArrayRef<std::string> attributeNames() const {
    return attributeNames_;
  }
@@ -77,13 +77,24 @@ static inline void pop(Stack& stack, Types&... args) {
  (void)result;
  drop(stack, N);
}

template <typename Type>
static inline void push_one(Stack& stack, Type&& arg) {
  stack.emplace_back(std::forward<Type>(arg));
}

static inline void push_one(Stack& stack, c10::TensorOptions options) {
  stack.emplace_back(c10::typeMetaToScalarType(options.dtype()));
  stack.emplace_back(options.layout());
  stack.emplace_back(options.device());
  stack.emplace_back(options.pinned_memory());
}

template <typename... Types>
static inline void push(Stack& stack, Types&&... args) {
  (void)std::initializer_list<int>{(stack.emplace_back(std::forward<Types>(args)), 0)...};
  (void)std::initializer_list<int>{(push_one(stack, args), 0)...};
}
template <class T>
static inline void push_list_elements(Stack& stack, const c10::List<T>& elements) {
  stack.reserve(stack.size() + elements.size());
  for (T elem : elements) {
    stack.push_back(std::move(elem));
  }
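push now routes every argument through push_one, and the TensorOptions overload explodes one options argument into four stack entries (scalar type, layout, device, pinned-memory flag), matching how the JIT calling convention passes factory-function options. A minimal sketch:

torch::jit::Stack stack;
torch::jit::push(stack, at::TensorOptions().dtype(at::kFloat));
// One C++ argument, four IValues on the stack.
TORCH_INTERNAL_ASSERT(stack.size() == 4);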
@@ -239,10 +239,17 @@ static inline ${return_type} ${api_name}(${formals}) {

# Relying on the linker to strip unused ops requires us to dispatch statically
# in Functions.h and TensorMethods.h.
#
# NB: The default body also needs to apply a variable guard, as in some
# situations what we think is a default body actually does have an
# explicit derivative, and thereby would have gotten unwrapped by
# the time you get to the implementation.
STATIC_DISPATCH_FUNCTION_DEFAULT_BODY = CodeTemplate("""\
at::AutoNonVariableTypeMode _var_guard(true);
${return_call} TypeDefault::${native_type_method_dispatch}(${native_arguments});
""")
STATIC_DISPATCH_FUNCTION_SWITCH_BODY = CodeTemplate("""\
at::AutoNonVariableTypeMode _var_guard(true);
switch(tensorTypeIdToBackend(impl::dispatchTypeId(${type_set}))) {
    ${static_dispatch_function_switches}
    default:
@@ -31,7 +31,7 @@ Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const T
  auto denom = (mag_square1 * mag_square2).sqrt_();
  auto cos = prod_sum / denom;

  auto zeros = at::zeros_like(cos);
  auto zeros = at::zeros_like(target);
  auto pos = 1 - cos;
  auto neg = (cos - margin).clamp_min_(0);
  auto output_pos = at::where(target == 1, pos, zeros);
@@ -67,8 +67,8 @@ Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, const Ten
}

Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction) {
  auto zeros = at::zeros_like(target);
  auto output_pos = target * (at::log(target) - input);
  auto zeros = at::zeros_like(output_pos);
  auto output = at::where(target > 0, output_pos, zeros);
  return apply_loss_reduction(output, reduction);
}
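Pointwise, kl_div computes target * (log(target) - input) where target > 0 and zero elsewhere; allocating the zeros from output_pos (rather than from target) gives at::where operands with matching shape and dtype after broadcasting. A usage sketch (the Reduction enum is assumed from ATen's Reduction.h):

auto input = at::randn({3});   // log-probabilities
auto target = at::rand({3});   // probabilities
auto out = at::kl_div(input, target, at::Reduction::None);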
@@ -55,67 +55,53 @@ bool _nnpack_available() {

#include "nnpack.h"

#include "caffe2/utils/threadpool/ThreadPoolMobile.h"
#include <stdlib.h>

#include <ATen/Parallel.h>
#include <thread>

namespace at {
namespace native {

static bool init_nnpack() {
  static std::once_flag once_;
  static bool nnpack_successfully_initialized_ = false;

  std::call_once(once_, []() {
    const nnp_status nnpack_status = nnp_initialize();
    nnpack_successfully_initialized_ = (nnp_status_success == nnpack_status);
// Stolen from Caffe2
static pthreadpool_t nnpack_threadpool_ = nullptr;
static bool called_nnpack_threadpool_ = false;

pthreadpool_t nnpack_threadpool() {
  if (! called_nnpack_threadpool_) {
    called_nnpack_threadpool_ = true;
    enum nnp_status nnpack_status = nnp_initialize();
    if (nnpack_status != nnp_status_success) {
      if (nnpack_status == nnp_status_out_of_memory) {
        LOG(WARNING) << "Could not initialize NNPACK! Reason: Out of memory.";
        throw std::runtime_error("could not initialize NNPack (out of memory)");
      } else if (nnpack_status == nnp_status_unsupported_hardware) {
        LOG(WARNING) << "Could not initialize NNPACK! Reason: Unsupported hardware.";
        throw std::runtime_error("could not initialize NNPack (unsupported hardware)");
      } else {
        LOG(WARNING) << "Could not initialize NNPACK! Reason: Unknown error!";
        throw std::runtime_error("could not initialize NNPack (unknown error)");
      }
    }
  });

  return nnpack_successfully_initialized_;
}

static pthreadpool_t nnpack_threadpool() {
  // Try initializing a threadpool for NNPACK's use. If we fail to
  // successfully initialize an implementation, return nullptr, which will
  // instruct NNPACK to run single-threaded.

#ifdef C10_MOBILE
  // If building for mobile, use Caffe2's mobile-friendly threadpool.
  return caffe2::mobile_pthreadpool();
#else
  // Otherwise, try using pthreadpool if we manage to initialize it successfully.
  static pthreadpool_t nnpack_threadpool_ = nullptr;
  static bool called_nnpack_threadpool_ = false;

  if (!called_nnpack_threadpool_) {
    called_nnpack_threadpool_ = true;

    unsigned int threads;
#ifdef INTRA_OP_PARALLEL
    const uint32_t threads = at::get_num_threads();
    threads = at::get_num_threads();
#else
    const uint32_t threads = std::thread::hardware_concurrency();
    threads = std::thread::hardware_concurrency();
#endif

    nnpack_threadpool_ = pthreadpool_create(threads);
    if ( !nnpack_threadpool_ ) {
      LOG(WARNING) << "Failed to initialize pthreadpool! Running NNPACK in single-threaded mode.";
    if (nnpack_threadpool_ == nullptr) {
      throw std::runtime_error("could not initialize NNPack's pthreadpool");
    }
  }

  return nnpack_threadpool_;
#endif
}

bool _nnpack_available() {
  return init_nnpack();
  if (! called_nnpack_threadpool_) {
    try {
      return nnpack_threadpool() != nullptr;
    } catch (std::runtime_error e) {
    }
  }
  return nnpack_threadpool() != nullptr;
}

// Make thread_local for safety in cases where we have multiple threads running
@@ -122,7 +122,6 @@ std::vector<Tensor> where(const Tensor& condition) {
}

Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
  Tensor ret = at::empty(self.sizes(), self.options());
  AT_DISPATCH_ALL_TYPES(ret.scalar_type(), "where_cpu", [&] {
    where_cpu<scalar_t>(ret, condition, self, other);
@@ -867,8 +867,20 @@ Tensor from_file(std::string filename, c10::optional<bool> shared, c10::optional

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ clone ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Tensor clone(const Tensor& src) {
  auto self = at::empty_like(src);
Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
  auto memory_format =
      optional_memory_format.value_or(MemoryFormat::Contiguous);
  if (memory_format == MemoryFormat::Preserve) {
    if (src.is_non_overlapping_and_dense()) {
      // Copy all strides
      auto self = at::empty_strided(src.sizes(), src.strides(), src.options());
      self.copy_(src);
      return self;
    } else {
      memory_format = src.suggest_memory_format();
    }
  }
  auto self = at::empty_like(src, src.options(), memory_format);
  self.copy_(src);
  return self;
}
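With the overload above, cloning under MemoryFormat::Preserve copies the source strides verbatim whenever the source is non-overlapping and dense, and otherwise falls back to the suggested format. A sketch (assumes contiguous accepts a memory format in this revision):

auto src = at::ones({2, 3, 4, 5}).contiguous(at::MemoryFormat::ChannelsLast);
auto dst = src.clone(at::MemoryFormat::Preserve);
// dst was allocated via empty_strided, so the channels-last strides survive.
TORCH_INTERNAL_ASSERT(dst.strides() == src.strides());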
@@ -102,13 +102,11 @@ std::tuple<Device, ScalarType> TensorIterator::compute_common_type() {
  return compute_common_type_(operands_);
}

static void validate_dtype(OperandInfo& op, ScalarType common_dtype, CommonDTypeStrategy strategy) {
  bool check_output = strategy == CommonDTypeStrategy::CHECK || strategy == CommonDTypeStrategy::PROMOTE;
  bool promoting = strategy == CommonDTypeStrategy::PROMOTE || (strategy == CommonDTypeStrategy::PROMOTE_INPUTS && !op.is_output);
static void validate_dtype(OperandInfo& op, ScalarType common_dtype, int ninputs) {
  if (op.tensor.defined()) {
    // For binary_ops, we follow casting rules. For unary/nullary types
    // we require the type to match.
    if (op.is_output && check_output) {
    if (op.is_output) {
      if (!canCast(common_dtype, op.tensor.scalar_type()))
      {
        AT_ERROR("result type ", common_dtype,
@@ -116,7 +114,7 @@ static void validate_dtype(OperandInfo& op, ScalarType common_dtype, CommonDType
                 op.tensor.scalar_type());
      }
    }
    if (!promoting && op.dtype != op.tensor.scalar_type()) {
    if (ninputs < 2 && op.dtype != op.tensor.scalar_type()) {
      AT_ERROR("expected dtype ", op.dtype, " but got dtype ", op.tensor.scalar_type());
    }
  }
@@ -133,6 +131,14 @@ static void maybe_promote_common_dtype(OperandInfo& op, ScalarType common_dtype)
    op.tensor =
        at::empty_like(op.tensor, op.tensor.options().dtype(common_dtype));
  }
  auto original_element_size = op.original_tensor.element_size();
  auto new_element_size = op.tensor.element_size();

  // The stride size (in bytes) can change if we change the dtype.
  for (size_t i = 0; i < op.stride_bytes.size(); i++) {
    auto stride = op.stride_bytes[i] / original_element_size;
    op.stride_bytes[i] = stride * new_element_size;
  }
}
}

@@ -149,12 +155,12 @@ void TensorIterator::compute_types() {
    }
  }

  if (common_dtype_strategy_ == CommonDTypeStrategy::PROMOTE_INPUTS) {
  if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS) {
    TORCH_CHECK(!missing_output_dtypes, "unable to compute and promote common dtype based only on inputs if there are missing dtypes for outputs");
  }

  bool compute_common_dtype = (common_dtype_strategy_ != CommonDTypeStrategy::NONE);
  bool compute_common_dtype_only_for_inputs = (common_dtype_strategy_ == CommonDTypeStrategy::PROMOTE_INPUTS);
  bool compute_common_dtype = (compute_common_dtype_strategy_ != CommonDTypeStrategy::COMPUTE_NONE);
  bool compute_common_dtype_only_for_inputs = (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS);

  if (missing_dtypes || compute_common_dtype) {
    auto operands = compute_common_dtype_only_for_inputs ? at::ArrayRef<OperandInfo>(operands_).slice(noutputs()) : operands_;
@@ -192,25 +198,24 @@ void TensorIterator::compute_types() {
        }
      }
    }
  }
}

    if (!compute_common_dtype_only_for_inputs) {
      validate_dtype(op, common_dtype, ninputs());
    }
    if (!compute_common_dtype_only_for_inputs || !op.is_output) {
      maybe_promote_common_dtype(op, common_dtype);
    }

  for (auto &op : operands_) {
    validate_dtype(op, common_dtype, common_dtype_strategy_);
    if (compute_common_dtype && (!compute_common_dtype_only_for_inputs || !op.is_output)) {
      maybe_promote_common_dtype(op, common_dtype);
    }

    if (op.tensor.defined() && op.device != op.tensor.device()) {
      if (op.is_output) {
        AT_ERROR("output with device ", op.tensor.device(),
                 " doesn't match the desired device ", op.device);
      } else if (op.tensor.dim() == 0) {
        op.tensor = op.tensor.to(op.options());
      } else {
        AT_ERROR("expected device ", op.device,
                 " but got device ", op.tensor.device());
    if (op.tensor.defined() && op.device != op.tensor.device()) {
      if (op.is_output) {
        AT_ERROR("output with device ", op.tensor.device(),
                 " doesn't match the desired device ", op.device);
      } else if (op.tensor.dim() == 0) {
        op.tensor = op.tensor.to(op.options());
      } else {
        AT_ERROR("expected device ", op.device,
                 " but got device ", op.tensor.device());
      }
    }
  }
}
@@ -597,7 +602,6 @@ TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a,
  iter.add_input(a);
  iter.add_input(b);
  iter.allow_cpu_scalars_ = true;
  iter.promote_common_dtype();
  iter.build();
  return iter;
}
@@ -823,12 +827,12 @@ void TensorIterator::build() {
#endif
  // compute the broadcasted shape
  compute_shape();
  // compute the result dtype and device
  compute_types();
  // compute each tensor's stride after broadcasting
  compute_strides();
  // re-order dimensions to improve coalescing
  reorder_dimensions();
  // compute the result dtype and device
  compute_types();
  // allocate the output tensor if it's not provided
  allocate_outputs();
#ifdef BUILD_NAMEDTENSOR
@@ -126,10 +126,9 @@ struct CAFFE2_API OperandInfo {
struct SplitUntil32Bit;

enum class CommonDTypeStrategy : uint8_t {
  NONE, // Do not compute a common dtype
  CHECK, // Compute and validate a common dtype but don't promote.
  PROMOTE_INPUTS, // Promote common dtype but only validate inputs (comparison ops have boolean output)
  PROMOTE // Promote to common dtype.
  COMPUTE_ALL = 0, // Compute the common dtype from inputs and outputs, and try to promote both to it
  COMPUTE_INPUTS = 1, // Compute the common dtype from inputs only, and try to promote only the inputs to it
  COMPUTE_NONE = 2, // Do not compute or promote a common dtype
};

struct CAFFE2_API TensorIterator {
@@ -205,7 +204,7 @@ struct CAFFE2_API TensorIterator {
  }

  void cast_outputs() {
    if (common_dtype_strategy_ == CommonDTypeStrategy::PROMOTE) {
    if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_ALL) {
      for (int i = 0; i < noutputs(); i++) {
        if (operands_[i].original_tensor.defined() && dtype(i) != operands_[i].original_tensor.scalar_type()) {
          operands_[i].original_tensor.copy_(operands_[i].tensor);
@@ -307,16 +306,12 @@ struct CAFFE2_API TensorIterator {
    operands_.emplace_back(input, device, dtype);
  }

  void promote_common_dtype() {
    common_dtype_strategy_ = CommonDTypeStrategy::PROMOTE;
  }

  void dont_compute_common_dtype() {
    common_dtype_strategy_ = CommonDTypeStrategy::NONE;
    compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_NONE;
  }

  void compute_common_dtype_only_for_inputs() {
    common_dtype_strategy_ = CommonDTypeStrategy::PROMOTE_INPUTS;
    compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_INPUTS;
  }

  void dont_resize_outputs() {
@@ -349,7 +344,7 @@ protected:
#endif
  SmallVector<OperandInfo, 4> operands_;
  int num_outputs_ = 0;
  CommonDTypeStrategy common_dtype_strategy_ = CommonDTypeStrategy::CHECK;
  CommonDTypeStrategy compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_ALL;
  bool has_coalesced_dimensions_ = false;
  bool accumulate_ = false;
  bool resize_outputs_ = true;
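The renamed strategies separate what the common dtype is computed from and what it is promoted to. A comparison-style iterator, whose output is boolean, computes and promotes over inputs only; a sketch built from the calls in this change:

at::Tensor a = at::ones({2, 2}, at::kFloat);
at::Tensor b = at::ones({2, 2}, at::kDouble);
at::Tensor out = at::empty({2, 2}, at::kBool);

auto iter = at::TensorIterator();
iter.add_output(out);
iter.add_input(a);
iter.add_input(b);
iter.compute_common_dtype_only_for_inputs();  // COMPUTE_INPUTS
iter.build();
// Inputs are promoted to kDouble; the kBool output is left untouched.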
@@ -77,6 +77,10 @@ Tensor& log_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(r
Tensor log(const Tensor& self) { return unary_op_impl(self, at::log_out); }
Tensor& log_(Tensor& self) { return unary_op_impl_(self, at::log_out); }

Tensor& log10_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log10_stub); }
Tensor log10(const Tensor& self) { return unary_op_impl(self, at::log10_out); }
Tensor& log10_(Tensor& self) { return unary_op_impl_(self, at::log10_out); }

Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, round_stub); }
Tensor round(const Tensor& self) { return unary_op_impl(self, at::round_out); }
Tensor& round_(Tensor& self) { return unary_op_impl_(self, at::round_out); }
@@ -269,7 +273,6 @@ IMPLEMENT_UNARY_OP_VEC(erfc)
IMPLEMENT_UNARY_OP_VEC_CUDA(erfinv)
IMPLEMENT_UNARY_OP_VEC(exp)
IMPLEMENT_UNARY_OP_VEC(frac)
IMPLEMENT_UNARY_OP_VEC(log10)
IMPLEMENT_UNARY_OP_VEC(log1p)
IMPLEMENT_UNARY_OP_VEC(log2)
IMPLEMENT_UNARY_OP_VEC(reciprocal)
@@ -5,34 +5,58 @@ namespace at { namespace native { namespace {
// n: number of function arguments (arity)
// traits: function_traits (see FunctionTraits.h)
// s: index of scalar argument or -1
template <int n, typename traits, int s=-1>
template <int n, int stride_index, typename traits, int s=-1>
struct IsContiguous {
  static bool eval(const int64_t* strides) {
    using type = typename traits::template arg<n - 1>::type;
    return strides[n] == (s == n ? 0 : sizeof(type)) &&
           IsContiguous<n - 1, traits, s>::eval(strides);
    return strides[stride_index] == (s == n ? 0 : sizeof(type)) &&
           IsContiguous<n - 1, stride_index - 1, traits, s>::eval(strides);
  }
};

// will be called when an output exists
template <typename traits, int s>
struct IsContiguous<0, traits, s> {
struct IsContiguous<0, 0, traits, s> {
  static bool eval(const int64_t* strides) {
    return strides[0] == sizeof(typename traits::result_type);
  }
};

// will be called when there is no output
template <typename traits, int s>
struct IsContiguous<0, -1, traits, s> {
  static bool eval(const int64_t* strides) {
    return true;
  }
};

// output and all inputs are contiguous
template <typename traits>
template <typename traits,
          typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous(const int64_t* strides) {
  return IsContiguous<traits::arity, traits>::eval(strides);
  return IsContiguous<traits::arity, traits::arity - 1, traits>::eval(strides);
}

template <typename traits,
          typename std::enable_if<not std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous(const int64_t* strides) {
  return IsContiguous<traits::arity, traits::arity, traits>::eval(strides);
}

// input at `s` is scalar (stride 0); output and other inputs are contiguous
// NB: output is typically at strides[0] so first input corresponds to s=1
template <typename traits, int s>
template <typename traits, int s,
          typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous_scalar(const int64_t* strides) {
  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
  return IsContiguous<traits::arity, traits, s>::eval(strides);
  return IsContiguous<traits::arity, traits::arity - 1, traits, s>::eval(strides);
}

template <typename traits, int s,
          typename std::enable_if<not std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous_scalar(const int64_t* strides) {
  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
  return IsContiguous<traits::arity, traits::arity, traits, s>::eval(strides);
}

}}}
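IsContiguous walks the strides array right to left: when an output is present it sits at strides[0] with the inputs following, so for an arity-2 float op the fully contiguous case is three 4-byte strides. A compile-time sketch (the function-type form of function_traits is assumed from FunctionTraits.h):

using traits = function_traits<float(float, float)>;  // result plus two args

int64_t dense[3] = {sizeof(float), sizeof(float), sizeof(float)};
bool all_packed = is_contiguous<traits>(dense);  // true: out, in1, in2 all dense

int64_t with_scalar[3] = {sizeof(float), sizeof(float), 0};
bool scalar_rhs = is_contiguous_scalar<traits, 2>(with_scalar);  // second input broadcast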
@@ -80,13 +80,40 @@ dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& o
  return dereference_vec_impl<traits>(data, opt_scalar, S, i, Indices{});
}

template <typename func_t,
          typename std::enable_if<not std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
static inline void
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t op) {
  using traits = function_traits<func_t>;
  using result_type = typename traits::result_type;
  for (; i < n; i++) {
    result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
    *out_ptr = c10::guts::apply(op, dereference<traits>(
        &data[1],
        &strides[1],
        i));
  }
}

template <typename func_t,
          typename std::enable_if<std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
static inline void
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t op) {
  using traits = function_traits<func_t>;
  for (; i < n; i++) {
    c10::guts::apply(op, dereference<traits>(
        &data[0],
        &strides[0],
        i));
  }
}

// Basic loop operation (one output, N inputs). May be auto-vectorized
// by the compiler. Supports inputs and outputs of different types.
template <typename func_t>
static inline void
basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t op) {
  using traits = function_traits<func_t>;
  using result_type = typename traits::result_type;
  constexpr int ntensors = traits::arity + 1;

  // Copying strides to a temporary array helps auto-vectorization in older GCC
@@ -96,13 +123,7 @@ basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_
    strides[arg] = strides_[arg];
  }

  for (; i < n; i++) {
    result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
    *out_ptr = c10::guts::apply(op, dereference<traits>(
        &data[1],
        &strides[1],
        i));
  }
  execute_op(data, strides, i, n, op);
}

// Explicitly vectorized loop implementation. All inputs and outputs must be
@@ -205,7 +226,8 @@ void cpu_kernel_vec(TensorIterator& iter, func_t op, vec_func_t vop) {
template <typename func_t>
void cpu_serial_kernel(TensorIterator& iter, func_t op) {
  using traits = function_traits<func_t>;
  TORCH_INTERNAL_ASSERT(iter.ntensors() >= traits::arity + 1);
  TORCH_INTERNAL_ASSERT((std::is_void<typename traits::result_type>::value &&
      iter.noutputs() == 0 && iter.ntensors() == traits::arity) || (iter.ntensors() >= traits::arity + 1));

  iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) {
    if (is_contiguous<traits>(strides)) {
@@ -217,6 +239,7 @@ void cpu_serial_kernel(TensorIterator& iter, func_t op) {
      });
    }
  }, {0, iter.numel()});
  iter.cast_outputs();
}

}}} // namespace at::native::<anonymous>
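With the void-returning execute_op overload, cpu_serial_kernel also accepts an iterator that has no outputs at all, driving a purely side-effecting lambda; this is exactly what the NoOutput tests added further below exercise. A sketch:

auto in = at::ones({4});       // float tensor
auto iter = at::TensorIterator();
iter.add_input(in);
iter.build();

int64_t visits = 0;
at::native::cpu_serial_kernel(iter, [&](float a) -> void { visits++; });
// visits == in.numel()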
@@ -171,9 +171,8 @@ void avg_pool2d_out_cuda_template(

  output.resize_({nbatch, nInputPlane, outputHeight, outputWidth});

  const int32_t count = safe_downcast<int32_t, int64_t>(output.numel());
  const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
  const uint32_t num_blocks = cuda::ATenCeilDiv<uint32_t>(count, num_threads);
  const int count = safe_downcast<int, int64_t>(output.numel());
  const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

  if (divisor_override.has_value()) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
@@ -185,7 +184,7 @@ void avg_pool2d_out_cuda_template(
        scalar_t *input_data = input.data_ptr<scalar_t>();

        avg_pool2d_out_cuda_frame<scalar_t, accscalar_t, false, true>
          <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
            count,
            input_data,
            nbatch,
@@ -210,7 +209,7 @@ void avg_pool2d_out_cuda_template(
        scalar_t *input_data = input.data_ptr<scalar_t>();

        avg_pool2d_out_cuda_frame<scalar_t, accscalar_t, true, false>
          <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
            count,
            input_data,
            nbatch,
@@ -234,7 +233,7 @@ void avg_pool2d_out_cuda_template(
        scalar_t *input_data = input.data_ptr<scalar_t>();

        avg_pool2d_out_cuda_frame<scalar_t, accscalar_t, false, false>
          <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
            count,
            input_data,
            nbatch,
@@ -250,8 +249,10 @@ void avg_pool2d_out_cuda_template(
    }
  }

  THCudaCheck(cudaGetLastError());

  TORCH_CHECK(cudaGetLastError() == cudaSuccess,
    "avg_pool2d_out_cuda_frame failed with error code ",
    cudaGetLastError());

  if (input.ndimension() == 3) {
    output.resize_({nInputPlane, outputHeight, outputWidth});
@@ -321,9 +322,8 @@ Tensor& avg_pool2d_backward_out_cuda_template(

  gradInput.resize_as_(input);

  const int32_t count = safe_downcast<int32_t, int64_t>(input.numel());
  const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
  const uint32_t num_blocks = cuda::ATenCeilDiv<uint32_t>(count, num_threads);
  const int count = safe_downcast<int, int64_t>(input.numel());
  const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

  if (divisor_override.has_value()) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
@@ -335,7 +335,7 @@ Tensor& avg_pool2d_backward_out_cuda_template(
        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();

        avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t, false, true>
          <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
            count,
            gradOutput_data,
            nbatch,
@@ -360,7 +360,7 @@ Tensor& avg_pool2d_backward_out_cuda_template(
        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();

        avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t, true, false>
          <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
            count,
            gradOutput_data,
            nbatch,
@@ -384,7 +384,7 @@ Tensor& avg_pool2d_backward_out_cuda_template(
        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();

        avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t, false, false>
          <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
            count,
            gradOutput_data,
            nbatch,
@@ -400,7 +400,9 @@ Tensor& avg_pool2d_backward_out_cuda_template(
    }
  }

  THCudaCheck(cudaGetLastError());
  TORCH_CHECK(cudaGetLastError() == cudaSuccess,
    "avg_pool2d_backward_out_cuda failed with error code ",
    cudaGetLastError());

  return gradInput;
}
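Both spellings above size the grid by ceil-dividing the element count by the block size, one thread per output element; the change only moves the division between a named num_blocks and an inline launch argument. The arithmetic, with illustrative values:

const int count = 10000;        // elements to cover
const int num_threads = 1024;   // threads per block
// cuda::ATenCeilDiv(count, num_threads) computes:
const int num_blocks = (count + num_threads - 1) / num_threads;  // == 10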
@@ -74,7 +74,6 @@ IMPLEMENT_UNARY_OP_PREQUEL(erf)
IMPLEMENT_UNARY_OP_PREQUEL(erfc)
IMPLEMENT_UNARY_OP_PREQUEL(exp)
IMPLEMENT_UNARY_OP_PREQUEL(frac)
IMPLEMENT_UNARY_OP_PREQUEL(log10)
IMPLEMENT_UNARY_OP_PREQUEL(log1p)
IMPLEMENT_UNARY_OP_PREQUEL(log2)
IMPLEMENT_UNARY_OP_PREQUEL(reciprocal)
@@ -428,7 +428,7 @@ ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_da
    const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
    int64_t batch_size, int64_t num_labels, int64_t BLANK, bool zero_infinity) {
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  int64_t s = threadIdx.x + blockIdx.x * blockDim.x; // note, this directly indexes into targets, not targets prime!
  int64_t s = threadIdx.x + blockIdx.x * blockDim.y; // note, this directly indexes into targets, no targets prime!

  if (b >= batch_size)
    return;
@@ -141,80 +141,50 @@ void or_kernel_cuda(TensorIterator& iter) {
  }), false);
}

template <typename scalar_t, typename acc_t=scalar_t>
template <typename scalar_t>
void max_values_kernel_cuda_impl(TensorIterator& iter) {
  gpu_reduce_kernel<scalar_t, scalar_t>(
    iter, func_wrapper<acc_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
      return (THCNumerics<acc_t>::isnan(a) || a > b) ? a : b;
    }), at::numeric_limits<acc_t>::lower_bound());
    iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
      return (THCNumerics<scalar_t>::isnan(a) || a > b) ? a : b;
    }), at::numeric_limits<scalar_t>::lower_bound());
}

template <typename scalar_t, typename acc_t=scalar_t>
template <typename scalar_t>
void min_values_kernel_cuda_impl(TensorIterator& iter) {
  gpu_reduce_kernel<scalar_t, scalar_t>(
    iter, func_wrapper<acc_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
      return (THCNumerics<acc_t>::isnan(a) || a < b) ? a : b;
    }), at::numeric_limits<acc_t>::upper_bound());
    iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
      return (THCNumerics<scalar_t>::isnan(a) || a < b) ? a : b;
    }), at::numeric_limits<scalar_t>::upper_bound());
}

void max_values_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype(1) == kHalf) {
    max_values_kernel_cuda_impl<at::Half, float>(iter);
  } else {
    AT_DISPATCH_ALL_TYPES(iter.dtype(), "max_values_cuda", [&]() {
      max_values_kernel_cuda_impl<scalar_t>(iter);
    });
  }
  AT_DISPATCH_ALL_TYPES(iter.dtype(), "max_values_cuda", [&]() {
    max_values_kernel_cuda_impl<scalar_t>(iter);
  });
}

void min_values_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype(1) == kHalf) {
    min_values_kernel_cuda_impl<at::Half, float>(iter);
  } else {
    AT_DISPATCH_ALL_TYPES(iter.dtype(), "min_values_cuda", [&]() {
      min_values_kernel_cuda_impl<scalar_t>(iter);
    });
  }
  AT_DISPATCH_ALL_TYPES(iter.dtype(), "min_values_cuda", [&]() {
    min_values_kernel_cuda_impl<scalar_t>(iter);
  });
}

template <typename scalar_t, typename acc_t=scalar_t>
void argmax_kernel_cuda_impl(TensorIterator& iter) {
  gpu_reduce_kernel<scalar_t, int64_t>(
    iter,
    ArgMaxOps<acc_t>{},
    thrust::pair<acc_t, int64_t>(at::numeric_limits<acc_t>::lower_bound(), 0));
};

template <typename scalar_t, typename acc_t=scalar_t>
void argmin_kernel_cuda_impl(TensorIterator& iter) {
  gpu_reduce_kernel<scalar_t, int64_t>(
    iter,
    ArgMinOps<acc_t>{},
    thrust::pair<acc_t, int64_t>(at::numeric_limits<acc_t>::upper_bound(), 0));
};

void argmax_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype(1) == kHalf) {
    // Instead of implementing is_nan and warp_shfl_down,
    // we can convert halves to float and do all the operations in float.
    argmax_kernel_cuda_impl<at::Half, float>(iter);
  } else {
    AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_cuda", [&]() {
      argmax_kernel_cuda_impl<scalar_t>(iter);
    });
  }
  AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_cuda", [&]() {
    gpu_reduce_kernel<scalar_t, int64_t>(
      iter,
      ArgMaxOps<scalar_t>{},
      thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::lower_bound(), 0));
  });
}

void argmin_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype(1) == kHalf) {
    // Instead of implementing is_nan and warp_shfl_down,
    // we can convert halves to float and do all the operations in float.
    argmin_kernel_cuda_impl<at::Half, float>(iter);
  } else {
    AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() {
      argmin_kernel_cuda_impl<scalar_t>(iter);
    });
  }
  AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() {
    gpu_reduce_kernel<scalar_t, int64_t>(
      iter,
      ArgMinOps<scalar_t>{},
      thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::upper_bound(), 0));
  });
}

REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda);
@@ -65,6 +65,14 @@ void log_kernel_cuda(TensorIterator& iter) {
  });
}

void log10_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "log10_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return ::log10(a);
    });
  });
}

void neg_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND(ScalarType::Half, iter.dtype(), "neg_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
@@ -180,6 +188,7 @@ REGISTER_DISPATCH(ceil_stub, &ceil_kernel_cuda);
REGISTER_DISPATCH(expm1_stub, &expm1_kernel_cuda);
REGISTER_DISPATCH(floor_stub, &floor_kernel_cuda);
REGISTER_DISPATCH(log_stub, &log_kernel_cuda);
REGISTER_DISPATCH(log10_stub, &log10_kernel_cuda);
REGISTER_DISPATCH(neg_stub, &neg_kernel_cuda);
REGISTER_DISPATCH(round_stub, &round_kernel_cuda);
REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel_cuda);
@@ -16,7 +16,7 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
  AT_ERROR("mkldnn_reshape: ATen not compiled with MKLDNN support");
}

Tensor mkldnn_clone(const Tensor& self) {
Tensor mkldnn_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
  AT_ERROR("mkldnn_clone: ATen not compiled with MKLDNN support");
}

@@ -54,7 +54,11 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
  return new_with_itensor_mkldnn(std::move(y), self.options());
}

Tensor mkldnn_clone(const Tensor& self) {
Tensor mkldnn_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !optional_memory_format.has_value(),
      "unsupported memory format option ",
      optional_memory_format.value());
  ideep::tensor& src = itensor_from_mkldnn(self);
  ideep::tensor dst;
  ideep::direct_copy::compute<AllocForMKLDNN>(src, dst);
@@ -66,7 +66,7 @@
  supports_named_tensor: True

- func: align_to(Tensor(a) self, DimnameList names) -> Tensor(a)
  variants: method
  variants: function, method
  supports_named_tensor: True

- func: align_as(Tensor self, Tensor other) -> Tensor
@@ -1539,15 +1539,12 @@
  use_c10_dispatcher: unboxed_only
  supports_named_tensor: True
  variants: function, method
  dispatch:
    CPU: _log10__cpu
    CUDA: _log10__cuda

- func: log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  supports_named_tensor: True
  dispatch:
    CPU: _log10_out_cpu
    CUDA: _log10_out_cuda
    CPU: log10_out
    CUDA: log10_out

- func: log1p(Tensor self) -> Tensor
  use_c10_dispatcher: full
@@ -3108,8 +3105,7 @@
- func: nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  variants: function

- func: clone(Tensor self) -> Tensor
  use_c10_dispatcher: full
- func: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
  variants: function, method
  dispatch:
    CPU: clone
@@ -155,13 +155,29 @@ Tensor& set_quantizer_(Tensor& self, ConstQuantizerPtr quantizer) {
  return self;
}

Tensor quantized_clone(const Tensor& self) {
Tensor quantized_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
  // TODO: add per-channel support
  TORCH_INTERNAL_ASSERT(
      self.qscheme() == at::kPerTensorAffine,
      "clone for quantized Tensor only works for PerTensorAffine scheme right now");

  auto memory_format =
      optional_memory_format.value_or(MemoryFormat::Contiguous);

  // TODO: To support all features of MemoryFormat::Preserve we need to add
  // an _empty_affine_quantized_strided function and use it similarly to
  // Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format),
  // i.e. if (self.is_non_overlapping_and_dense()) -> _empty_affine_quantized_strided
  if (memory_format == MemoryFormat::Preserve) {
    memory_format = self.suggest_memory_format();
  }

  Tensor dst = at::_empty_affine_quantized(
      self.sizes(), self.options(), self.q_scale(), self.q_zero_point());
      self.sizes(),
      self.options(),
      self.q_scale(),
      self.q_zero_point(),
      memory_format);

  at::native::copy_(dst, self, false);
@@ -255,7 +255,11 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, A

// NB: Deleted newWithSizeNd variants

SparseTensor clone_sparse(const SparseTensor& self) {
SparseTensor clone_sparse(const SparseTensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !optional_memory_format.has_value(),
      "unsupported memory format option ",
      optional_memory_format.value());
  SparseTensor other = new_with_dims_sparse(self.sparse_dim(), self.dense_dim(), self.sizes(), self.options());
  copy_into_sparse(other, self._indices(), self._values(), true);
  return other._coalesced_(self.is_coalesced());
@@ -186,10 +186,15 @@ class CAFFE2_API Tensor {
  int64_t ndimension() const {
    return dim();
  }

  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
    return impl_->is_contiguous(memory_format);
  }

  bool is_non_overlapping_and_dense() const {
    return impl_->is_non_overlapping_and_dense();
  }

  at::MemoryFormat suggest_memory_format() const {
    if (impl_->is_strides_like_channels_last()) {
      return at::MemoryFormat::ChannelsLast;
@@ -11,6 +11,7 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/core/EnableNamedTensor.h>
#include <ATen/core/LegacyTypeDispatch.h>

#ifdef USE_STATIC_DISPATCH
#include <ATen/TypeDefault.h>
@@ -29,8 +29,7 @@ list(APPEND ATen_CPU_TEST_SRCS
  ${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/tensor_iterator_test.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce_ops_test.cpp)
  ${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp)

list(APPEND ATen_CUDA_TEST_SRCS
  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
@@ -1,24 +0,0 @@
#include <gtest/gtest.h>

#include <torch/types.h>
#include <torch/utils.h>

using namespace at;

TEST(ReduceOpsTest, MaxValuesAndMinValues) {
  const int W = 10;
  const int H = 10;
  if (hasCUDA()) {
    for (const auto dtype : {kHalf, kFloat, kDouble, kShort, kInt, kLong}) {
      auto a = at::rand({H, W}, TensorOptions(kCUDA).dtype(at::kHalf));
      ASSERT_FLOAT_EQ(
        a.max_values(c10::IntArrayRef{0, 1}).item<double>(),
        a.max().item<double>()
      );
      ASSERT_FLOAT_EQ(
        a.min_values(c10::IntArrayRef{0, 1}).item<double>(),
        a.min().item<double>()
      );
    }
  }
}
@@ -61,17 +61,41 @@ TEST(TensorIteratorTest, SerialLoopUnary_##name) { \
  ASSERT_ANY_THROW(out.equal(expected)); \
}

#define NO_OUTPUT_UNARY_TEST_ITER_FOR_TYPE(ctype,name) \
TEST(TensorIteratorTest, SerialLoopUnaryNoOutput_##name) { \
  auto in = random_tensor_for_type(k##name); \
  auto iter = at::TensorIterator(); \
  iter.add_input(in); \
  iter.build(); \
  int64_t acc = 0; \
  at::native::cpu_serial_kernel(iter, [&](ctype a) -> void { acc++; }); \
  EXPECT_TRUE(acc == in.numel()); \
}

#define BINARY_TEST_ITER_FOR_TYPE(ctype,name) \
TEST(TensorIteratorTest, SerialLoopBinary_##name) { \
  Tensor out; \
  auto in1 = random_tensor_for_type(k##name); \
  auto in2 = random_tensor_for_type(k##name); \
  auto expected = in1.add(in2); \
  auto iter = TensorIterator::binary_op(out, in1, in2); \
  at::native::cpu_serial_kernel(iter, [=](ctype a, ctype b) -> int { return a + b; }); \
  ASSERT_ANY_THROW(out.equal(expected)); \
}

#define NO_OUTPUT_BINARY_TEST_ITER_FOR_TYPE(ctype,name) \
TEST(TensorIteratorTest, SerialLoopBinaryNoOutput_##name) { \
  auto in1 = random_tensor_for_type(k##name); \
  auto in2 = random_tensor_for_type(k##name); \
  auto iter = at::TensorIterator(); \
  iter.add_input(in1); \
  iter.add_input(in2); \
  iter.build(); \
  int64_t acc = 0; \
  at::native::cpu_serial_kernel(iter, [&](ctype a, ctype b) -> void { acc++; }); \
  EXPECT_TRUE(acc == in1.numel()); \
}

#define POINTWISE_TEST_ITER_FOR_TYPE(ctype,name) \
TEST(TensorIteratorTest, SerialLoopPointwise_##name) { \
  Tensor out; \
@@ -89,6 +113,21 @@ TEST(TensorIteratorTest, SerialLoopPointwise_##name) {
  ASSERT_ANY_THROW(out.equal(expected)); \
}

#define NO_OUTPUT_POINTWISE_TEST_ITER_FOR_TYPE(ctype,name) \
TEST(TensorIteratorTest, SerialLoopPointwiseNoOutput_##name) { \
  auto in1 = random_tensor_for_type(k##name); \
  auto in2 = random_tensor_for_type(k##name); \
  auto in3 = random_tensor_for_type(k##name); \
  auto iter = at::TensorIterator(); \
  iter.add_input(in1); \
  iter.add_input(in2); \
  iter.add_input(in3); \
  iter.build(); \
  int64_t acc = 0; \
  at::native::cpu_serial_kernel(iter, [&](ctype a, ctype b, ctype c) -> void { acc++; }); \
  EXPECT_TRUE(acc == in1.numel()); \
}

// The alternative way to calculate a < b is (b - a).clamp(0).toBool().
// To prevent an overflow in the subtraction (b - a) for unsigned types (uint, bool),
// we convert the inputs to int first.
@@ -112,6 +151,9 @@ TEST(TensorIteratorTest, ComparisonLoopBinary_##name) {
AT_FORALL_SCALAR_TYPES(UNARY_TEST_ITER_FOR_TYPE)
AT_FORALL_SCALAR_TYPES(BINARY_TEST_ITER_FOR_TYPE)
AT_FORALL_SCALAR_TYPES(POINTWISE_TEST_ITER_FOR_TYPE)
AT_FORALL_SCALAR_TYPES(NO_OUTPUT_UNARY_TEST_ITER_FOR_TYPE)
AT_FORALL_SCALAR_TYPES(NO_OUTPUT_BINARY_TEST_ITER_FOR_TYPE)
AT_FORALL_SCALAR_TYPES(NO_OUTPUT_POINTWISE_TEST_ITER_FOR_TYPE)
AT_FORALL_SCALAR_TYPES_AND(Bool, COMPARISON_TEST_ITER_FOR_TYPE)

TEST(TensorIteratorTest, SerialLoopSingleThread) {
@@ -173,11 +215,3 @@ TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfOutputIsUndefined) {
  ASSERT_ANY_THROW(iter.build());
}

TEST(TensorIteratorTest, FailNonPromotingBinaryOp) {
  Tensor out;
  auto iter = at::TensorIterator();
  iter.add_output(out);
  iter.add_input(at::ones({1,1}, at::dtype(at::kDouble)));
  iter.add_input(at::ones({1,1}, at::dtype(at::kInt)));
  ASSERT_ANY_THROW(iter.build());
}
@ -188,11 +188,6 @@ void THFree(void *ptr)
|
||||
c10::free_cpu(ptr);
|
||||
}
|
||||
|
||||
double THLog10(const double x)
|
||||
{
|
||||
return log10(x);
|
||||
}
|
||||
|
||||
double THLog1p(const double x)
|
||||
{
|
||||
#if (defined(_MSC_VER) || defined(__MINGW32__))
|
||||
|
@ -7,7 +7,6 @@
|
||||
#define Real QInt32
|
||||
#define RealUnderlying Int
|
||||
#define THQUANTIZED
|
||||
#define THQINT32
|
||||
#define TH_REAL_IS_BYTE
|
||||
#line 1 TH_GENERIC_FILE
|
||||
#include TH_GENERIC_FILE
|
||||
@ -16,7 +15,6 @@
|
||||
#undef Real
|
||||
#undef RealUnderlying
|
||||
#undef TH_REAL_IS_BYTE
|
||||
#undef THQINT32
|
||||
#undef THQUANTIZED
|
||||
|
||||
#ifndef THGenerateManyTypes
|
||||
|
@ -7,7 +7,6 @@
|
||||
#define Real QInt8
|
||||
#define RealUnderlying Char
|
||||
#define THQUANTIZED
|
||||
#define THQINT8
|
||||
#define TH_REAL_IS_BYTE
|
||||
#line 1 TH_GENERIC_FILE
|
||||
#include TH_GENERIC_FILE
|
||||
@ -16,7 +15,6 @@
|
||||
#undef Real
|
||||
#undef RealUnderlying
|
||||
#undef TH_REAL_IS_BYTE
|
||||
#undef THQINT8
|
||||
#undef THQUANTIZED
|
||||
|
||||
#ifndef THGenerateManyTypes
|
||||
|
@ -7,7 +7,6 @@
|
||||
#define Real QUInt8
|
||||
#define RealUnderlying Byte
|
||||
#define THQUANTIZED
|
||||
#define THQUINT8
|
||||
#define TH_REAL_IS_BYTE
|
||||
#line 1 TH_GENERIC_FILE
|
||||
#include TH_GENERIC_FILE
|
||||
@ -16,7 +15,6 @@
|
||||
#undef Real
|
||||
#undef RealUnderlying
|
||||
#undef TH_REAL_IS_BYTE
|
||||
#undef THQUINT8
|
||||
#undef THQUANTIZED
|
||||
|
||||
#ifndef THGenerateManyTypes
|
||||
|
@ -35,9 +35,6 @@
|
||||
#define THLongStorage THStorage
|
||||
#define THBoolStorage THStorage
|
||||
#define THBFloat16Storage THStorage
|
||||
#define THQUInt8Storage THStorage
|
||||
#define THQInt8Storage THStorage
|
||||
#define THQInt32Storage THStorage
|
||||
|
||||
TH_API scalar_t* THStorage_(data)(const THStorage*);
|
||||
TH_API ptrdiff_t THStorage_(size)(const THStorage*);
|
||||
|
@ -35,14 +35,5 @@ IMPLEMENT_THStorage_COPY(Double)
|
||||
IMPLEMENT_THStorage_COPY(Half)
|
||||
IMPLEMENT_THStorage_COPY(Bool)
|
||||
IMPLEMENT_THStorage_COPY(BFloat16)
|
||||
#ifdef THQUINT8
|
||||
IMPLEMENT_THStorage_COPY(QUInt8)
|
||||
#endif
|
||||
#ifdef THQINT8
|
||||
IMPLEMENT_THStorage_COPY(QInt8)
|
||||
#endif
|
||||
#ifdef THQINT32
|
||||
IMPLEMENT_THStorage_COPY(QInt32)
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
@ -14,14 +14,5 @@ TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *s
|
||||
TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src);
|
||||
TH_API void THStorage_(copyBool)(THStorage *storage, struct THBoolStorage *src);
|
||||
TH_API void THStorage_(copyBFloat16)(THStorage *storage, struct THBFloat16Storage *src);
|
||||
#ifdef THQUINT8
|
||||
TH_API void THStorage_(copyQUInt8)(THStorage *storage, struct THQUInt8Storage *src);
|
||||
#endif
|
||||
#ifdef THQINT8
|
||||
TH_API void THStorage_(copyQInt8)(THStorage *storage, struct THQInt8Storage *src);
|
||||
#endif
|
||||
#ifdef THQINT32
|
||||
TH_API void THStorage_(copyQInt32)(THStorage *storage, struct THQInt32Storage *src);
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
@@ -151,7 +151,6 @@ TH_API void THTensor_(abs)(THTensor *r_, THTensor *t);
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)

TH_API void THTensor_(sigmoid)(THTensor *r_, THTensor *t);
TH_API void THTensor_(log10)(THTensor *r_, THTensor *t);
TH_API void THTensor_(log1p)(THTensor *r_, THTensor *t);
TH_API void THTensor_(log2)(THTensor *r_, THTensor *t);
TH_API void THTensor_(exp)(THTensor *r_, THTensor *t);

@@ -31,7 +31,6 @@ TH_API void THVector_(abs)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
/* floating point only now */
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)

TH_API void THVector_(log10)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(log1p)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(log2)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(sigmoid)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
@@ -49,12 +48,7 @@ TH_API void THVector_(atan)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(tanh)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(pow)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n);
TH_API void THVector_(sqrt)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(rsqrt)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(ceil)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(floor)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(round)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(abs)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(trunc)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(frac)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(cinv)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);

@@ -243,7 +243,6 @@ VECTOR_IMPLEMENT_FUNCTION(abs,)
#define TH_MATH_NAME(fn) fn
#endif

VECTOR_IMPLEMENT_FUNCTION(log10,TH_MATH_NAME(log10))
VECTOR_IMPLEMENT_FUNCTION(log1p,TH_MATH_NAME(log1p))
VECTOR_IMPLEMENT_FUNCTION(log2,TH_MATH_NAME(log2))
VECTOR_IMPLEMENT_FUNCTION(sigmoid_DEFAULT,TH_MATH_NAME(TH_sigmoid))

@@ -3,7 +3,6 @@

#include <TH/THGeneral.h>
#include <TH/THAllocator.h>
#undef log10
#undef log1p
#undef log2

@@ -204,7 +204,6 @@ struct THCNumerics<at::Half> {

static inline __host__ __device__ at::Half exp(at::Half a) { return std::exp(a); }
static inline __host__ __device__ at::Half exp10(at::Half a) { return ::exp10(a); }
static inline __host__ __device__ at::Half log10(at::Half a) { return ::log10(a); }
static inline __host__ __device__ at::Half log1p(at::Half a) { return ::log1p(a); }
static inline __host__ __device__ at::Half log2(at::Half a) { return ::log2(a); }
static inline __host__ __device__ at::Half cos(at::Half a) { return ::cos(a); }
@@ -279,7 +278,6 @@ struct THCNumerics<float> {

static inline __host__ __device__ float exp (float a) { return expf(a); }
static inline __host__ __device__ float exp10(float a) { return exp10f(a); }
static inline __host__ __device__ float log10(float a) { return log10f(a); }
static inline __host__ __device__ float log1p(float a) { return log1pf(a); }
static inline __host__ __device__ float log2 (float a) { return log2f(a); }
static inline __host__ __device__ float cos (float a) { return cosf(a); }
@@ -329,7 +327,6 @@ struct THCNumerics<double> {

static inline __host__ __device__ double exp (double a) { return ::exp(a); }
static inline __host__ __device__ double exp10(double a) { return ::exp10(a); }
static inline __host__ __device__ double log10(double a) { return ::log10(a); }
static inline __host__ __device__ double log1p(double a) { return ::log1p(a); }
static inline __host__ __device__ double log2 (double a) { return ::log2(a); }
static inline __host__ __device__ double cos (double a) { return ::cos(a); }

@@ -198,7 +198,6 @@ static void propagate_names_if_named_tensor_enabled(THCTensor* result, THCTensor

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)

IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(log10, THCNumerics<scalar_t>::log10, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(log1p, THCNumerics<scalar_t>::log1p, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( log2, THCNumerics<scalar_t>::log2, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( exp, THCNumerics<scalar_t>::exp, Real)

@@ -16,7 +16,6 @@ THC_API void THCTensor_(cminValue)(THCState *state, THCTensor *self, THCTensor *
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)

THC_API void THCTensor_(sigmoid)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(log10)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(log1p)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(log2)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(exp)(THCState *state, THCTensor *self, THCTensor *src);

@@ -39,10 +39,6 @@ inline __host__ __device__ THHalf exp(THHalf a) {
  return THCNumerics<THHalf>::exp(a);
}

inline __host__ __device__ THHalf log10(THHalf a) {
  return THCNumerics<THHalf>::log10(a);
}

inline __host__ __device__ THHalf log1p(THHalf a) {
  return THCNumerics<THHalf>::log1p(a);
}

@@ -6,6 +6,7 @@ from __future__ import unicode_literals
import argparse

from caffe2.python import workspace
import torch

import benchmark_core
import benchmark_utils
@@ -127,7 +128,18 @@ def main():
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    workspace.ClearGlobalNetObserver()
    if args.omp_num_threads:
        # benchmark_utils.set_omp_threads sets the env variable OMP_NUM_THREADS
        # which doesn't have any impact as C2 init logic has already been called
        # before setting the env var.

        # In general, OMP_NUM_THREADS (and other OMP env variables) needs to be set
        # before the program is started.
        # From Chapter 4 in OMP standard: https://www.openmp.org/wp-content/uploads/openmp-4.5.pdf
        # "Modifications to the environment variables after the program has started,
        # even if modified by the program itself, are ignored by the OpenMP implementation"
        benchmark_utils.set_omp_threads(args.omp_num_threads)
        if benchmark_utils.is_pytorch_enabled(args.framework):
            torch.set_num_threads(args.omp_num_threads)
    if args.mkl_num_threads:
        benchmark_utils.set_mkl_threads(args.mkl_num_threads)

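As the comment in the hunk above says, OpenMP reads OMP_NUM_THREADS once at startup, so the safe pattern is to set the variable before any OpenMP-backed library initializes, and to use the library's own runtime knob afterwards. A minimal sketch (illustrative, not part of the benchmark harness):

    import os

    # The OpenMP runtime snapshots OMP_NUM_THREADS when it starts; per the
    # OpenMP 4.5 spec (Ch. 4), later changes to the environment are ignored.
    os.environ["OMP_NUM_THREADS"] = "4"

    import torch  # imported only after the env var is set, on purpose

    # PyTorch's explicit knob still works after startup:
    torch.set_num_threads(4)
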
@@ -130,6 +130,35 @@ bool TensorImpl::compute_strides_like_channels_last() const {
  return false;
}

bool TensorImpl::compute_non_overlapping_and_dense() const {
  if (dim() == 1) {
    return size(0) < 2 || stride(0) == 1;
  }
  SmallVector<int64_t,5> perm;
  perm.resize(dim());
  for (int64_t i = 0; i < dim(); i ++) {
    perm[i] = i;
  }
  // Sort by strides, leaving 0 and 1 sized dims at the end of the array
  std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
    if (sizes_[a] < 2) {
      return false;
    }
    return strides_[a] < strides_[b];
  });
  auto require_stride = 1;
  for (int64_t i = 0; i < dim(); i ++) {
    if (sizes_[perm[i]] < 2) {
      return true;
    }
    if (strides_[perm[i]] != require_stride) {
      return false;
    }
    require_stride *= sizes_[perm[i]];
  }
  return true;
}

void TensorImpl::release_resources() {
  autograd_meta_.reset();
  if (storage_) {

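A rough Python sketch of the check compute_non_overlapping_and_dense() performs: sort the dimensions by stride, then verify each stride equals the product of the sizes of all faster-moving dimensions. Plain lists are used here, and the flag caching of the real class is omitted:

    def is_non_overlapping_and_dense(sizes, strides):
        # 1-D fast path: a single element, or unit stride, is trivially dense.
        if len(sizes) == 1:
            return sizes[0] < 2 or strides[0] == 1
        # Sort dims by stride, pushing size-0/1 dims to the end (they don't
        # constrain the layout), mirroring the comparator in the hunk above.
        perm = sorted(range(len(sizes)),
                      key=lambda d: (sizes[d] < 2, strides[d]))
        required = 1
        for d in perm:
            if sizes[d] < 2:
                return True  # remaining dims are all size 0/1
            if strides[d] != required:
                return False
            required *= sizes[d]
        return True

    # Contiguous 2x3 tensor: strides (3, 1) -> dense.
    assert is_non_overlapping_and_dense([2, 3], [3, 1])
    # Overlapping view: stride 0 repeats elements -> not dense.
    assert not is_non_overlapping_and_dense([2, 3], [0, 1])
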
@@ -1390,6 +1390,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    is_contiguous_ = false;
    is_channels_last_contiguous_ = false;
    is_channels_last_ = false;
    is_non_overlapping_and_dense_ = false;
    switch (memory_format) {
      case MemoryFormat::Contiguous: {
        strides_.resize(sizes_.size(), 0);
@@ -1401,6 +1402,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
          }
        }
        is_contiguous_ = true;
        is_non_overlapping_and_dense_ = true;
        return;
      }
      case MemoryFormat::ChannelsLast: {
@@ -1410,6 +1412,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
        set_sizes_and_strides(sizes(), get_channels_last_strides(sizes()));
        is_channels_last_contiguous_ = true;
        is_channels_last_ = true;
        is_non_overlapping_and_dense_ = true;
        return;
      }
      case MemoryFormat::Preserve:
@@ -1421,6 +1424,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    return is_channels_last_;
  }

  bool is_non_overlapping_and_dense() const {
    return is_non_overlapping_and_dense_;
  }

 private:

  // The Caffe2 Resize() method supports being called both as Resize({2,2}) as
@@ -1501,6 +1508,8 @@ private:

  bool compute_strides_like_channels_last() const;

  bool compute_non_overlapping_and_dense() const;

 protected:
  /**
   * Recompute the cached numel of a tensor. Call this if you modify sizes.
@@ -1517,6 +1526,7 @@ protected:
    is_contiguous_ = compute_contiguous();
    is_channels_last_contiguous_ = compute_channels_last_contiguous();
    is_channels_last_ = is_channels_last_contiguous_ || compute_strides_like_channels_last();
    is_non_overlapping_and_dense_ = is_contiguous_ || is_channels_last_contiguous_ || compute_non_overlapping_and_dense();
  }

  /**
@@ -1546,6 +1556,9 @@ protected:
      dest_impl->type_set_ = dest_impl->type_set_.remove(TensorTypeId::VariableTensorId);
    }
    dest_impl->is_contiguous_ = src_impl->is_contiguous_;
    dest_impl->is_channels_last_contiguous_ = src_impl->is_channels_last_contiguous_;
    dest_impl->is_channels_last_ = src_impl->is_channels_last_;
    dest_impl->is_non_overlapping_and_dense_ = src_impl->is_non_overlapping_and_dense_;
    dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
    dest_impl->reserved_ = src_impl->reserved_;
    dest_impl->set_version_counter(version_counter);
@@ -1641,6 +1654,11 @@ protected:
  // contiguous memory block.
  bool is_channels_last_contiguous_ = false;

  // Dense tensor is the tensor that store values in a contiguous block of memory.
  // Non-overlapping tensor is the tensor in which elements occupy individual
  // non-repetitive memory.
  bool is_non_overlapping_and_dense_ = false;

  bool is_wrapped_number_ = false;

  // NOTE [ Metadata Change for a Detached Tensor ]

@@ -74,6 +74,65 @@ struct C10_API SourceLocation {

std::ostream& operator<<(std::ostream& out, const SourceLocation& loc);

// unix isprint but insensitive to locale
inline static bool isPrint(char s) {
  return s > 0x1f && s < 0x7f;
}

inline void printQuotedString(std::ostream& stmt, const std::string& str) {
  stmt << "\"";
  for (auto s : str) {
    switch (s) {
      case '\\':
        stmt << "\\\\";
        break;
      case '\'':
        stmt << "\\'";
        break;
      case '\"':
        stmt << "\\\"";
        break;
      case '\a':
        stmt << "\\a";
        break;
      case '\b':
        stmt << "\\b";
        break;
      case '\f':
        stmt << "\\f";
        break;
      case '\n':
        stmt << "\\n";
        break;
      case '\r':
        stmt << "\\r";
        break;
      case '\t':
        stmt << "\\t";
        break;
      case '\v':
        stmt << "\\v";
        break;
      default:
        if (isPrint(s)) {
          stmt << s;
        } else {
          // C++ io has stateful formatting settings. Messing with
          // them is probably worse than doing this manually.
          char buf[4] = "000";
          buf[2] += s % 8;
          s /= 8;
          buf[1] += s % 8;
          s /= 8;
          buf[0] += s;
          stmt << "\\" << buf;
        }
        break;
    }
  }
  stmt << "\"";
}

} // namespace c10

#endif // C10_UTIL_STRINGUTIL_H_

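The default branch of printQuotedString above encodes a non-printable byte as a three-digit octal escape by peeling base-8 digits into buf in place. The same arithmetic, checked quickly in Python (illustrative only):

    def octal_escape(byte):
        # Mirror the buf[2]/buf[1]/buf[0] digit-peeling in printQuotedString.
        digits = ["0", "0", "0"]
        digits[2] = chr(ord("0") + byte % 8)
        byte //= 8
        digits[1] = chr(ord("0") + byte % 8)
        byte //= 8
        digits[0] = chr(ord("0") + byte)
        return "\\" + "".join(digits)

    assert octal_escape(0x07) == "\\007"  # bell
    assert octal_escape(0x1B) == "\\033"  # escape
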
@@ -21,7 +21,6 @@
#include <iterator>
#include <utility>
#include <type_traits>
#include <stdexcept>

#ifndef _MSC_VER
#pragma GCC diagnostic push

@@ -481,18 +481,27 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
  if (NOT INTERN_BUILD_MOBILE)
    list(APPEND TORCH_SRCS
      ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_container.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_context.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/autograd/functions/recvrpc_backward.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/autograd/functions/sendrpc_backward.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/autograd/utils.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/future_message.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/message.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_call.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_resp.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_with_autograd.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_call.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_remote_call.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_rref_proto.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_ret.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_resp.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/utils.cpp
      ${TORCH_SRC_DIR}/csrc/jit/export.cpp
      ${TORCH_SRC_DIR}/csrc/jit/import_legacy.cpp
      ${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp
      ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp
      ${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp
    )
  endif()

@@ -538,6 +547,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
    ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/init.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/activation.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/conv.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/dropout.cpp
@@ -550,6 +560,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/rnn.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/functional.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/named_any.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/options/activation.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/options/batchnorm.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/options/conv.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/options/dropout.cpp

@@ -40,8 +40,14 @@ class LayerModelHelper(model_helper.ModelHelper):
    """

    def __init__(self, name, input_feature_schema, trainer_extra_schema,
                 keep_blobs=False):
                 keep_blobs=False,
                 use_attribution=True):
        ''' TODO(amalevich): more documnetation on input args

        use_attribution:
            if True, will generate the atrribution net for feature importance
            calculation; Need to turn it to false when FC is quantized as FP16
            This attribute access will be consistent with MTML model.
        '''

        super(LayerModelHelper, self).__init__(name=name)
@@ -92,6 +98,7 @@ class LayerModelHelper(model_helper.ModelHelper):
        # TODO(xlwang): it's hack!
        self.ad_hoc_diagnose_blobs_and_operations = []
        self.ad_hoc_plot_blobs = []
        self.use_attribution = use_attribution

    def clear_output_schema(self):
        self._output_schema = None

@@ -5,6 +5,7 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python.helpers.arg_scope import get_current_scope
from caffe2.python import schema
from caffe2.python.layers.layers import ModelLayer
from caffe2.python.layers.sampling_trainable_mixin import SamplingTrainableMixin
@@ -12,6 +13,14 @@ import math
import numpy as np


def get_fc_predictor_version(fc_version):
    assert fc_version in ["fp32", "fp16"], (
        "Only support fp32 and fp16 for the fully connected layer "
        "in the predictor net, the provided FC precision is {}".format(fc_version)
    )
    return fc_version


class FC(SamplingTrainableMixin, ModelLayer):

    def __init__(self, model, input_record, output_dims, weight_init=None,
@@ -119,18 +128,47 @@ class FC(SamplingTrainableMixin, ModelLayer):

        return output_dim_vec

    def _add_ops(self, net, params):
    def _insert_fc_ops(self, net, params, outputs, version):
        """
        Args:
            net: the caffe2 net to insert operator
            params: weight and bias for FC
            outputs: the output blobs
            version: support fp32 and fp16 for now.
        """
        if version == "fp32":
            return net.FC(
                self.input_record.field_blobs() + params,
                outputs,
                axis=self.axis,
                **self.kwargs
            )
        elif version == "fp16":
            return net.FbFCPacked(
                self.input_record.field_blobs() + params,
                outputs,
                axis=self.axis,
                **self.kwargs
            )
        else:
            raise Exception("unsupported FC type version {}".format(version))

    def _add_ops(self, net, params, version):
        """
        Args:
            params : the weight and bias,
                passed by either add_ops or add_train_ops function
            version : fp16 or fp32, might support in8 in the future.
        """
        if self.clip_args is not None:
            clipped_params = [net.NextScopedBlob(
                'clipped_%s' % str(p)) for p in params]
            for p, cp in zip(params, clipped_params):
                net.Clip([p], [cp], **self.clip_args)

            params = clipped_params

        if self.output_dim_vec is None or len(self.output_dim_vec) == 1:
            net.FC(self.input_record.field_blobs() + params,
                   self.output_schema.field_blobs(), axis=self.axis, **self.kwargs)
            self._insert_fc_ops(net, params, self.output_schema.field_blobs(), version)
        else:
            w_vec = params[:int(len(params) / 2)]
            b_vec = params[int(len(params) / 2):]
@@ -142,15 +180,33 @@ class FC(SamplingTrainableMixin, ModelLayer):
            for i in range(len(self.output_dim_vec)):
                output_blob = net.NextScopedBlob(
                    'output_sub_{}'.format(i))
                output_blob_vec.append(
                    net.FC(self.input_record.field_blobs() +
                           [w_vec[i], b_vec[i]],
                           [output_blob], axis=self.axis, **self.kwargs))

                insert_ret = self._insert_fc_ops(
                    net, [w_vec[i], b_vec[i]], [output_blob], version
                )
                output_blob_vec.append(insert_ret)
            net.Concat(output_blob_vec,
                       self.output_schema.field_blobs() +
                       [self.output_schema.field_blobs()[0] + "_concat_dims"])

    def add_ops(self, net):
        """Both the predict net and the eval net will call this function
        """
        version_info = get_current_scope().get(
            get_fc_predictor_version.__name__, {'fc_version': 'fp32'}
        )
        predictor_fc_fp_version = version_info['fc_version']
        self._add_ops(net, self.param_blobs, predictor_fc_fp_version)

    def add_train_ops(self, net):
        # use the train_param_blobs to be consistent with the SamplingTrain unittest
        self._add_ops(net, self.train_param_blobs, "fp32")

    def get_fp16_compatible_parameters(self):
        if self.output_dim_vec is None or len(self.output_dim_vec) == 1:
            return [self.w]
        else:
            return self.w_vec

    @property
    def param_blobs(self):
        if self.output_dim_vec is None or len(self.output_dim_vec) == 1:

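The refactor routes both precisions through _insert_fc_ops, so training stays fp32 while the predictor picks its precision from the current arg scope. The dispatch shape, reduced to a plain-Python sketch with illustrative names and no caffe2 dependency:

    def insert_fc_ops(net, inputs, params, outputs, version):
        # One op per supported predictor precision; train nets always use fp32.
        ops = {
            "fp32": net.FC,          # standard fully connected op
            "fp16": net.FbFCPacked,  # fp16-packed variant
        }
        if version not in ops:
            raise Exception("unsupported FC type version {}".format(version))
        return ops[version](inputs + params, outputs)
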
@@ -42,6 +42,9 @@ class Regularizer(object):
    def _run_after_optimizer(self, net, param_init_net, param, grad):
        return None

    def _feature_grouping(self, param, net):
        return None

    def _ensure_clipped(
        self,
        net,
@@ -84,6 +87,52 @@ class L1Norm(Regularizer):
        net.Scale([output_blob], [output_blob], scale=self.reg_lambda)
        return output_blob

class FCInputLpNorm(Regularizer):
    def __init__(self, reg_lambda, p_value=0.5):
        super(FCInputLpNorm, self).__init__()
        assert reg_lambda >= 0, "factor ahead of regularization should be 0 or positive"
        assert p_value >= 0, "p_value factor should be 0 or positive"
        self.p_value = p_value
        self.reg_lambda = reg_lambda

    def _feature_grouping(self, param, net):
        # Possible alternative grouping method via summing over absolute values
        # Compute l2norm over feature weights
        # pow( sum_i { pow(theda_i, 2) } , 0.5)
        param_mul = net.Mul([param, param], [net.NextScopedBlob("param_mul")])
        param_reduced = net.ReduceFrontSum(
            [param_mul], [net.NextScopedBlob("param_reduced")]
        )
        grouped_feature_weight_vec = net.Pow(
            [param_reduced],
            [net.NextScopedBlob("grouped_feature_weight_vec")],
            exponent=0.5,
        )

        return grouped_feature_weight_vec

    def _run_on_loss(self, net, param_init_net, param, grad=None):
        # TODO: the second dim (num of input nodes) of param is after feature preproc,
        # and does not correspond to the original num of dense features.
        # In the future, will want to create a util to reduce the input dim of param to
        # match the num of dense features.

        output_blob = net.NextScopedBlob(param + "_dense_feature_regularization")
        grouped_feature_weight_vec = self._feature_grouping(param, net)

        # Compute Lpnorm over l2norm:
        # pow( sum_i { pow(theda_i, p) } , 1/p)
        lp_vec_raised = net.Pow(
            [grouped_feature_weight_vec], [net.NextScopedBlob("lp_vec_raised")], exponent=self.p_value
        )
        lp_vec_summed = net.ReduceFrontSum(
            [lp_vec_raised], [net.NextScopedBlob("lp_vec_summed")]
        )
        lp_vec = net.Pow(
            [lp_vec_summed], [net.NextScopedBlob("lp_vec")], exponent=(1 / self.p_value)
        )
        net.Scale([lp_vec], [output_blob], scale=self.reg_lambda)
        return output_blob

class L1NormTrimmed(Regularizer):
    """

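In NumPy terms, FCInputLpNorm above is an Lp norm taken over per-feature L2 norms of the FC weight. A sketch of the math the net ops assemble, assuming `w` has shape `(output_dim, num_features)` so that ReduceFrontSum reduces over the leading (output) dimension:

    import numpy as np

    def fc_input_lp_norm(w, p=0.5, reg_lambda=1e-4):
        # Per-input-feature L2 norm: Mul + ReduceFrontSum + Pow(0.5).
        feature_l2 = np.sqrt((w * w).sum(axis=0))      # shape: (num_features,)
        # Lp "norm" over the grouped norms: (sum_i x_i^p)^(1/p).
        lp = (feature_l2 ** p).sum() ** (1.0 / p)
        return reg_lambda * lp

    w = np.random.randn(8, 5).astype(np.float32)  # (output_dim, num_features)
    print(fc_input_lp_norm(w, p=0.5))
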
@@ -1,65 +0,0 @@
#include "caffe2/quantization/server/resize_nearest_3d_dnnlowp_op.h"

namespace caffe2 {

template <typename T>
bool ResizeNearest3DDNNLowPOp<T>::RunOnDevice() {
  using namespace dnnlowp;

  this->ParseDNNLowPOperatorArguments_();

  // Choose quantization params
  in_qparams_[0] = GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get());

  const auto& X = InputTensorCPU_(0);
  auto* Y = OutputTensorCPU_(0);

  CAFFE_ENFORCE_EQ(X.ndim(), 5);
  const int N = X.dim32(0);
  // input frames
  const int IF = X.dim32(1);
  const int IH = X.dim32(2);
  const int IW = X.dim32(3);
  const int C = X.dim32(4);
  const int OF = IF * temporal_scale_;
  const int OH = IH * height_scale_;
  const int OW = IW * width_scale_;

  vector<int> buffer_shape{N, OF, OH, OW, C};
  Y->Resize(buffer_shape);
  const T* X_data = X.template data<T>();
  T* Y_data = Y->template mutable_data<T>();

#ifdef _OPENMP
#pragma omp parallel for
#endif
  for (int n = 0; n < N; ++n) {
    for (int t = 0; t < OF; ++t) {
      const int in_f = std::min((int)(t / temporal_scale_), (IF - 1));
      for (int y = 0; y < OH; ++y) {
        const int in_y = std::min((int)(y / height_scale_), (IH - 1));
        for (int x = 0; x < OW; ++x) {
          const int in_x = std::min((int)(x / width_scale_), (IW - 1));
          std::memcpy(
              &Y_data[((((n * OF) + t) * OH + y) * OW + x) * C],
              &X_data[((((n * IF) + in_f) * IH + in_y) * IW + in_x) * C],
              C * sizeof(T));
        }
      }
    }
  }

  // Even if there is a pre-chosen quantization parameters for the output,
  // it is ignored because resize nearest output quantization should be same
  // as the input.
  PropagateOutputTensorQuantizationParams(this, 0, in_qparams_[0]);

  return true;
}

REGISTER_CPU_OPERATOR_WITH_ENGINE(
    Int8ResizeNearest3D,
    DNNLOWP,
    ResizeNearest3DDNNLowPOp<uint8_t>);

} // namespace caffe2
@@ -1,41 +0,0 @@
#pragma once

#include "caffe2/quantization/server/dnnlowp_op.h"
#include "vision/video_modeling/experimental/abc/ops/resize_3d_op.h"

namespace caffe2 {

using ResizeNearest3DFP32Op = ResizeNearest3DOp<float, CPUContext>;

template <typename T>
class ResizeNearest3DDNNLowPOp final
    : public DNNLowPOp<T, ResizeNearest3DFP32Op> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);
  USE_DNNLOWP_OPERATOR_BASE_FUNCTIONS(T, ResizeNearest3DFP32Op);

  ResizeNearest3DDNNLowPOp(const OperatorDef& operator_def, Workspace* ws)
      : BaseType(operator_def, ws),
        temporal_scale_(
            this->template GetSingleArgument<float>("temporal_scale", 1)),
        width_scale_(this->template GetSingleArgument<float>("width_scale", 1)),
        height_scale_(
            this->template GetSingleArgument<float>("height_scale", 1)) {
    CAFFE_ENFORCE_GT(temporal_scale_, 0);
    CAFFE_ENFORCE_GT(width_scale_, 0);
    CAFFE_ENFORCE_GT(height_scale_, 0);

    const auto& order = StringToStorageOrder(
        this->template GetSingleArgument<std::string>("order", "NHWC"));
    CAFFE_ENFORCE_EQ(order, StorageOrder::NHWC);
  }

  bool RunOnDevice() override;

 private:
  float temporal_scale_;
  float width_scale_;
  float height_scale_;
};

} // namespace caffe2
@@ -1,62 +0,0 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np
from caffe2.python import core, dyndep, workspace
from hypothesis import given


dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])


class DNNLowPResizeNearest3DOpTest(hu.HypothesisTestCase):
    @given(
        N=st.integers(1, 3),
        T=st.integers(1, 16),
        H=st.integers(10, 300),
        W=st.integers(10, 300),
        C=st.integers(1, 32),
        scale_t=st.floats(0.25, 4.0) | st.just(2.0),
        scale_w=st.floats(0.25, 4.0) | st.just(2.0),
        scale_h=st.floats(0.25, 4.0) | st.just(2.0),
        **hu.gcs_cpu_only
    )
    def test_resize_nearest(self, N, T, H, W, C, scale_t, scale_w, scale_h, gc, dc):
        X = np.round(np.random.rand(N, T, H, W, C) * 255).astype(np.float32)

        quantize = core.CreateOperator("Quantize", ["X"], ["X_q"], engine="DNNLOWP")
        resize_nearest = core.CreateOperator(
            "Int8ResizeNearest3D",
            ["X_q"],
            ["Y_q"],
            temporal_scale=scale_t,
            width_scale=scale_w,
            height_scale=scale_h,
            engine="DNNLOWP",
        )

        net = core.Net("test_net")
        net.Proto().op.extend([quantize, resize_nearest])

        workspace.FeedBlob("X", X)
        workspace.RunNetOnce(net)
        X_q = workspace.FetchInt8Blob("X_q").data
        Y_q = workspace.FetchInt8Blob("Y_q").data

        def resize_nearest_ref(X):
            outT = np.int32(T * scale_t)
            outH = np.int32(H * scale_h)
            outW = np.int32(W * scale_w)
            outT_idxs, outH_idxs, outW_idxs = np.meshgrid(
                np.arange(outT), np.arange(outH), np.arange(outW), indexing="ij"
            )
            inT_idxs = np.minimum(outT_idxs / scale_t, T - 1).astype(np.int32)
            inH_idxs = np.minimum(outH_idxs / scale_h, H - 1).astype(np.int32)
            inW_idxs = np.minimum(outW_idxs / scale_w, W - 1).astype(np.int32)
            Y = X[:, inT_idxs, inH_idxs, inW_idxs, :]
            return Y

        Y_q_ref = resize_nearest_ref(X_q)
        np.testing.assert_allclose(Y_q, Y_q_ref)
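Both the removed operator and the test's reference implementation use the same per-axis index rule, in = min(floor(out / scale), in_size - 1); in one dimension that is just:

    def nearest_indices(out_size, scale, in_size):
        # For each output position, pick the nearest source index, clamped so
        # that upscaling past the edge repeats the last frame/row/column.
        return [min(int(i / scale), in_size - 1) for i in range(out_size)]

    # Upscale 3 frames by 2x: each input frame is repeated twice.
    assert nearest_indices(6, 2.0, 3) == [0, 0, 1, 1, 2, 2]
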
@@ -131,10 +131,15 @@ void PyTorchStreamReader::init() {
      ". Your PyTorch installation may be too old.");
}

void PyTorchStreamReader::valid(const char* what) {
void PyTorchStreamReader::valid(const char* what, const char* info) {
  auto err = mz_zip_get_last_error(ar_.get());
  if (err != MZ_ZIP_NO_ERROR) {
    CAFFE_THROW("PytorchStreamReader failed ", what, ": ", mz_zip_get_error_string(err));
    CAFFE_THROW(
        "PytorchStreamReader failed ",
        what,
        info,
        ": ",
        mz_zip_get_error_string(err));
  }
}

@@ -173,7 +178,7 @@ bool PyTorchStreamReader::hasRecord(const std::string& name) {
  if (!result) {
    ar_->m_last_error = MZ_ZIP_NO_ERROR;
  }
  valid("attempting to locate file");
  valid("attempting to locate file ", name.c_str());
  return result;
}

@@ -184,7 +189,7 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) {
  if (ar_->m_last_error == MZ_ZIP_FILE_NOT_FOUND) {
    CAFFE_THROW("file not found: ", ss.str());
  }
  valid("locating file");
  valid("locating file ", name.c_str());
  return result;
}

@@ -193,10 +198,10 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  valid("retrieving file meta-data");
  valid("retrieving file meta-data for ", name.c_str());
  void * ptr = malloc(stat.m_uncomp_size);
  mz_zip_reader_extract_to_mem(ar_.get(), key, ptr, stat.m_uncomp_size, 0);
  valid("reading file");
  valid("reading file ", name.c_str());

  at::DataPtr retval(ptr, ptr, free, at::kCPU);
  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
@@ -209,7 +214,7 @@ static int64_t read_le_16(uint8_t* buf) {
size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
  valid("retriving file meta-data");
  valid("retrieving file meta-data for ", name.c_str());
  uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
  in_->read(
      stat.m_local_header_ofs,
@@ -224,7 +229,7 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {

PyTorchStreamReader::~PyTorchStreamReader() {
  mz_zip_reader_end(ar_.get());
  valid("closing reader");
  valid("closing reader for archive ", archive_name_.c_str());
}

size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) {
@@ -260,14 +265,14 @@ PyTorchStreamWriter::PyTorchStreamWriter(
  if (!out_) {
    file_stream_.open(file_name, std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
    out_ = &file_stream_;
    valid("opening archive");
    valid("opening archive ", file_name.c_str());
  }

  ar_->m_pIO_opaque = this;
  ar_->m_pWrite = ostream_write_func;

  mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
  valid("initializing archive");
  valid("initializing archive ", file_name.c_str());

  std::stringstream version;
  version << kMaxSupportedFileFormatVersion << "\n";
@@ -296,7 +301,7 @@ void PyTorchStreamWriter::writeRecord(const std::string& name, const void* data,
      padding.size(),
      nullptr,
      0);
  valid("writing file");
  valid("writing file ", name.c_str());
}

void PyTorchStreamWriter::writeEndOfFile() {
@@ -304,19 +309,23 @@ void PyTorchStreamWriter::writeEndOfFile() {
  finalized_ = true;
  mz_zip_writer_finalize_archive(ar_.get());
  mz_zip_writer_end(ar_.get());
  valid("writing central directory");
  valid("writing central directory for archive ", archive_name_.c_str());
  if (file_stream_.is_open())
    file_stream_.close();
}


void PyTorchStreamWriter::valid(const char* what) {
void PyTorchStreamWriter::valid(const char* what, const char* info) {
  auto err = mz_zip_get_last_error(ar_.get());
  if (err != MZ_ZIP_NO_ERROR) {
    CAFFE_THROW("PytorchStreamWriter failed ", what, ": ", mz_zip_get_error_string(err));
    CAFFE_THROW(
        "PytorchStreamWriter failed ",
        what,
        info,
        ": ",
        mz_zip_get_error_string(err));
  }
  if (!*out_) {
    CAFFE_THROW("PytorchStreamWriter failed ", what, ".");
    CAFFE_THROW("PytorchStreamWriter failed ", what, info, ".");
  }
}

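The change above threads the failing record or archive name into each valid() call, so miniz errors identify the member rather than only the operation. The shape of the pattern, sketched in Python (illustrative; not the C++ API):

    class StreamError(RuntimeError):
        pass

    def valid(last_error, what, info=""):
        # Mirrors PyTorchStreamReader::valid(what, info): the optional info
        # string (e.g. a record name) is spliced into the message so the
        # failing archive member is identifiable from the error alone.
        if last_error is not None:
            raise StreamError("PytorchStreamReader failed %s%s: %s"
                              % (what, info, last_error))

    valid(None, "reading file ", "model.json")  # no-op on success
    # valid("file corrupt", "reading file ", "model.json") raises:
    #   PytorchStreamReader failed reading file model.json: file corrupt
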
@@ -111,7 +111,7 @@ class CAFFE2_API PyTorchStreamReader final {
 private:
  void init();
  size_t read(uint64_t pos, char* buf, size_t n);
  void valid(const char* what);
  void valid(const char* what, const char* info = "");
  size_t getRecordID(const std::string& name);

  friend size_t
@@ -141,14 +141,18 @@ class CAFFE2_API PyTorchStreamWriter final {
  ~PyTorchStreamWriter();

 private:
  void valid(const char* what);
  size_t current_pos_ = 0;
  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::ostream* out_;
  std::ofstream file_stream_;
  bool finalized_ = false;
  friend size_t ostream_write_func(void *pOpaque, uint64_t file_ofs, const void *pBuf, size_t n);
  void valid(const char* what, const char* info = "");
  size_t current_pos_ = 0;
  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::ostream* out_;
  std::ofstream file_stream_;
  bool finalized_ = false;
  friend size_t ostream_write_func(
      void* pOpaque,
      uint64_t file_ofs,
      const void* pBuf,
      size_t n);
};

} // namespace serialize

@@ -117,12 +117,9 @@ if ((NOT TARGET protobuf::libprotobuf) AND (NOT TARGET protobuf::libprotobuf-lit
  # "Please set the proper paths so that I can find protobuf correctly.")
endif()

# Protobuf generated files use <> as inclusion path, so maybe we should use
# SYSTEM inclusion path. But we need these include dirs to be found before
# other protobuf include dirs in Anaconda
get_target_property(__tmp protobuf::libprotobuf INTERFACE_INCLUDE_DIRECTORIES)
message(STATUS "Caffe2 protobuf include directory: " ${__tmp})
include_directories(BEFORE ${__tmp})
include_directories(BEFORE SYSTEM ${__tmp})

# If Protobuf_VERSION is known (true in most cases, false if we are building
# local protobuf), then we will add a protobuf version check in

@@ -53,8 +53,6 @@ extensions = [
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinxcontrib.katex',
    'sphinx.ext.autosectionlabel',
    'javasphinx',
]

# katex options

@@ -16,7 +16,6 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
   :caption: Notes

   notes/*
   PyTorch on XLA Devices <http://pytorch.org/xla/>

.. toctree::
   :glob:
@@ -26,8 +25,8 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
   community/*

.. toctree::
   :maxdepth: 1
   :caption: Python API
   :maxdepth: 2
   :caption: Package Reference

   torch
   nn
@@ -43,7 +42,6 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
   nn.init
   onnx
   optim
   quantization
   torch.random <random>
   sparse
   storage
@@ -55,8 +53,6 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
   torch.utils.model_zoo <model_zoo>
   torch.utils.tensorboard <tensorboard>
   type_info
   named_tensor
   name_inference
   torch.__config__ <__config__>

.. toctree::
@@ -66,24 +62,9 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.

   torchvision/index

.. toctree::
   :maxdepth: 1
   :caption: torchaudio Reference
* `torchaudio <https://pytorch.org/audio>`_

   torchaudio <https://pytorch.org/audio>

.. toctree::
   :maxdepth: 1
   :caption: torchtext Reference

   torchtext <https://pytorch.org/text>

.. toctree::
   :maxdepth: 1
   :caption: Other Languages

   C++ API <https://pytorch.org/cppdocs/>
   packages
* `torchtext <https://pytorch.org/text>`_

Indices and tables
==================

Binary file not shown. (Removed image, 9.7 KiB)
@@ -1,468 +0,0 @@
.. currentmodule:: torch

.. _name_inference_reference-doc:

Named Tensors operator coverage
===============================

Please read :ref:`named_tensors-doc` first for an introduction to named tensors.

This document is a reference for *name inference*, a process that defines how
named tensors:

1. use names to provide additional automatic runtime correctness checks
2. propagate names from input tensors to output tensors

Below is a list of all operations that are supported with named tensors
and their associated name inference rules.

If you don't see an operation listed here, but it would help your use case, please
`search if an issue has already been filed <https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+named+tensor%22>`_ and if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_.

.. warning::
   The named tensor API is experimental and subject to change.

.. csv-table:: Supported Operations
   :header: API, Name inference rule
   :widths: 20, 20

   ":meth:`Tensor.abs`, :func:`torch.abs`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.abs_`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.acos`, :func:`torch.acos`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.acos_`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.add`, :func:`torch.add`",:ref:`unifies_names_from_inputs-doc`
   :meth:`Tensor.add_`,:ref:`unifies_names_from_inputs-doc`
   ":meth:`Tensor.addmm`, :func:`torch.addmm`",:ref:`contracts_away_dims-doc`
   :meth:`Tensor.addmm_`,:ref:`contracts_away_dims-doc`
   ":meth:`Tensor.addmv`, :func:`torch.addmv`",:ref:`contracts_away_dims-doc`
   :meth:`Tensor.addmv_`,:ref:`contracts_away_dims-doc`
   :meth:`Tensor.align_as`,See documentation
   :meth:`Tensor.align_to`,See documentation
   ":meth:`Tensor.all`, :func:`torch.all`",None
   ":meth:`Tensor.any`, :func:`torch.any`",None
   ":meth:`Tensor.asin`, :func:`torch.asin`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.asin_`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.atan`, :func:`torch.atan`",:ref:`keeps_input_names-doc`
   ":meth:`Tensor.atan2`, :func:`torch.atan2`",:ref:`unifies_names_from_inputs-doc`
   :meth:`Tensor.atan2_`,:ref:`unifies_names_from_inputs-doc`
   :meth:`Tensor.atan_`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.bernoulli`, :func:`torch.bernoulli`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.bernoulli_`,None
   :meth:`Tensor.bfloat16`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.bitwise_not`, :func:`torch.bitwise_not`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.bitwise_not_`,None
   ":meth:`Tensor.bmm`, :func:`torch.bmm`",:ref:`contracts_away_dims-doc`
   :meth:`Tensor.bool`,:ref:`keeps_input_names-doc`
   :meth:`Tensor.byte`,:ref:`keeps_input_names-doc`
   :func:`torch.cat`,:ref:`unifies_names_from_inputs-doc`
   :meth:`Tensor.cauchy_`,None
   ":meth:`Tensor.ceil`, :func:`torch.ceil`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.ceil_`,None
   :meth:`Tensor.char`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.chunk`, :func:`torch.chunk`",:ref:`keeps_input_names-doc`
   ":meth:`Tensor.clamp`, :func:`torch.clamp`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.clamp_`,None
   :meth:`Tensor.copy_`,:ref:`out_function_semantics-doc`
   ":meth:`Tensor.cos`, :func:`torch.cos`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.cos_`,None
   ":meth:`Tensor.cosh`, :func:`torch.cosh`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.cosh_`,None
   :meth:`Tensor.cpu`,:ref:`keeps_input_names-doc`
   :meth:`Tensor.cuda`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.cumprod`, :func:`torch.cumprod`",:ref:`removes_dimensions-doc`
   ":meth:`Tensor.cumsum`, :func:`torch.cumsum`",:ref:`removes_dimensions-doc`
   :meth:`Tensor.data_ptr`,None
   ":meth:`Tensor.detach`, :func:`torch.detach`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.detach_`,None
   ":attr:`Tensor.device`, :func:`torch.device`",None
   ":meth:`Tensor.digamma`, :func:`torch.digamma`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.digamma_`,None
   :meth:`Tensor.dim`,None
   ":meth:`Tensor.div`, :func:`torch.div`",:ref:`unifies_names_from_inputs-doc`
   :meth:`Tensor.div_`,:ref:`unifies_names_from_inputs-doc`
   ":meth:`Tensor.dot`, :func:`torch.dot`",None
   :meth:`Tensor.double`,:ref:`keeps_input_names-doc`
   :meth:`Tensor.element_size`,None
   :func:`torch.empty`,:ref:`factory-doc`
   :func:`torch.empty_like`,:ref:`factory-doc`
   ":meth:`Tensor.eq`, :func:`torch.eq`",:ref:`unifies_names_from_inputs-doc`
   ":meth:`Tensor.erf`, :func:`torch.erf`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.erf_`,None
   ":meth:`Tensor.erfc`, :func:`torch.erfc`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.erfc_`,None
   ":meth:`Tensor.erfinv`, :func:`torch.erfinv`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.erfinv_`,None
   ":meth:`Tensor.exp`, :func:`torch.exp`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.exp_`,None
   :meth:`Tensor.expand`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.expm1`, :func:`torch.expm1`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.expm1_`,None
   :meth:`Tensor.exponential_`,None
   :meth:`Tensor.fill_`,None
   ":meth:`Tensor.flatten`, :func:`torch.flatten`",See documentation
   :meth:`Tensor.float`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.floor`, :func:`torch.floor`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.floor_`,None
   ":meth:`Tensor.frac`, :func:`torch.frac`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.frac_`,None
   ":meth:`Tensor.ge`, :func:`torch.ge`",:ref:`unifies_names_from_inputs-doc`
   ":meth:`Tensor.get_device`, :func:`torch.get_device`",None
   :attr:`Tensor.grad`,None
   ":meth:`Tensor.gt`, :func:`torch.gt`",:ref:`unifies_names_from_inputs-doc`
   :meth:`Tensor.half`,:ref:`keeps_input_names-doc`
   :meth:`Tensor.has_names`,See documentation
   ":meth:`Tensor.index_fill`, :func:`torch.index_fill`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.index_fill_`,None
   :meth:`Tensor.int`,:ref:`keeps_input_names-doc`
   :meth:`Tensor.is_contiguous`,None
   :attr:`Tensor.is_cuda`,None
   ":meth:`Tensor.is_floating_point`, :func:`torch.is_floating_point`",None
   :attr:`Tensor.is_leaf`,None
   :meth:`Tensor.is_pinned`,None
   :meth:`Tensor.is_shared`,None
   ":meth:`Tensor.is_signed`, :func:`torch.is_signed`",None
   :attr:`Tensor.is_sparse`,None
   :func:`torch.is_tensor`,None
   :meth:`Tensor.item`,None
   ":meth:`Tensor.kthvalue`, :func:`torch.kthvalue`",:ref:`removes_dimensions-doc`
   ":meth:`Tensor.le`, :func:`torch.le`",:ref:`unifies_names_from_inputs-doc`
   ":meth:`Tensor.log`, :func:`torch.log`",:ref:`keeps_input_names-doc`
   ":meth:`Tensor.log10`, :func:`torch.log10`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.log10_`,None
   ":meth:`Tensor.log1p`, :func:`torch.log1p`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.log1p_`,None
   ":meth:`Tensor.log2`, :func:`torch.log2`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.log2_`,None
   :meth:`Tensor.log_`,None
   :meth:`Tensor.log_normal_`,None
   ":meth:`Tensor.logical_not`, :func:`torch.logical_not`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.logical_not_`,None
   ":meth:`Tensor.logsumexp`, :func:`torch.logsumexp`",:ref:`removes_dimensions-doc`
   :meth:`Tensor.long`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.lt`, :func:`torch.lt`",:ref:`unifies_names_from_inputs-doc`
   :func:`torch.manual_seed`,None
   ":meth:`Tensor.masked_fill`, :func:`torch.masked_fill`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.masked_fill_`,None
   ":meth:`Tensor.masked_select`, :func:`torch.masked_select`",Aligns mask up to input and then unifies_names_from_input_tensors
   ":meth:`Tensor.matmul`, :func:`torch.matmul`",:ref:`contracts_away_dims-doc`
   ":meth:`Tensor.mean`, :func:`torch.mean`",:ref:`removes_dimensions-doc`
   ":meth:`Tensor.median`, :func:`torch.median`",:ref:`removes_dimensions-doc`
   ":meth:`Tensor.mm`, :func:`torch.mm`",:ref:`contracts_away_dims-doc`
   ":meth:`Tensor.mode`, :func:`torch.mode`",:ref:`removes_dimensions-doc`
   ":meth:`Tensor.mul`, :func:`torch.mul`",:ref:`unifies_names_from_inputs-doc`
   :meth:`Tensor.mul_`,:ref:`unifies_names_from_inputs-doc`
   ":meth:`Tensor.mv`, :func:`torch.mv`",:ref:`contracts_away_dims-doc`
   :attr:`Tensor.names`,See documentation
   ":meth:`Tensor.narrow`, :func:`torch.narrow`",:ref:`keeps_input_names-doc`
   :attr:`Tensor.ndim`,None
   :meth:`Tensor.ndimension`,None
   ":meth:`Tensor.ne`, :func:`torch.ne`",:ref:`unifies_names_from_inputs-doc`
   ":meth:`Tensor.neg`, :func:`torch.neg`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.neg_`,None
   :func:`torch.normal`,:ref:`keeps_input_names-doc`
   :meth:`Tensor.normal_`,None
   ":meth:`Tensor.numel`, :func:`torch.numel`",None
   :func:`torch.ones`,:ref:`factory-doc`
   ":meth:`Tensor.pow`, :func:`torch.pow`",:ref:`unifies_names_from_inputs-doc`
   :meth:`Tensor.pow_`,None
   ":meth:`Tensor.prod`, :func:`torch.prod`",:ref:`removes_dimensions-doc`
   :func:`torch.rand`,:ref:`factory-doc`
   :func:`torch.rand`,:ref:`factory-doc`
   :func:`torch.randn`,:ref:`factory-doc`
   :func:`torch.randn`,:ref:`factory-doc`
   :meth:`Tensor.random_`,None
   ":meth:`Tensor.reciprocal`, :func:`torch.reciprocal`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.reciprocal_`,None
   :meth:`Tensor.refine_names`,See documentation
   :meth:`Tensor.register_hook`,None
   :meth:`Tensor.rename`,See documentation
   :meth:`Tensor.rename_`,See documentation
   :attr:`Tensor.requires_grad`,None
   :meth:`Tensor.requires_grad_`,None
   :meth:`Tensor.resize_`,Only allow resizes that do not change shape
   :meth:`Tensor.resize_as_`,Only allow resizes that do not change shape
   ":meth:`Tensor.round`, :func:`torch.round`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.round_`,None
   ":meth:`Tensor.rsqrt`, :func:`torch.rsqrt`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.rsqrt_`,None
   ":meth:`Tensor.select`, :func:`torch.select`",:ref:`removes_dimensions-doc`
   :meth:`Tensor.short`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.sigmoid`, :func:`torch.sigmoid`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.sigmoid_`,None
   ":meth:`Tensor.sign`, :func:`torch.sign`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.sign_`,None
   ":meth:`Tensor.sin`, :func:`torch.sin`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.sin_`,None
   ":meth:`Tensor.sinh`, :func:`torch.sinh`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.sinh_`,None
   :meth:`Tensor.size`,None
   ":meth:`Tensor.split`, :func:`torch.split`",:ref:`keeps_input_names-doc`
   ":meth:`Tensor.sqrt`, :func:`torch.sqrt`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.sqrt_`,None
   ":meth:`Tensor.squeeze`, :func:`torch.squeeze`",:ref:`removes_dimensions-doc`
   ":meth:`Tensor.std`, :func:`torch.std`",:ref:`removes_dimensions-doc`
   :func:`torch.std_mean`,:ref:`removes_dimensions-doc`
   :meth:`Tensor.stride`,None
   ":meth:`Tensor.sub`, :func:`torch.sub`",:ref:`unifies_names_from_inputs-doc`
   :meth:`Tensor.sub_`,:ref:`unifies_names_from_inputs-doc`
   ":meth:`Tensor.sum`, :func:`torch.sum`",:ref:`removes_dimensions-doc`
   ":meth:`Tensor.tan`, :func:`torch.tan`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.tan_`,None
   ":meth:`Tensor.tanh`, :func:`torch.tanh`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.tanh_`,None
   :func:`torch.tensor`,:ref:`factory-doc`
   :meth:`Tensor.to`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.topk`, :func:`torch.topk`",:ref:`removes_dimensions-doc`
   ":meth:`Tensor.transpose`, :func:`torch.transpose`",:ref:`permutes_dimensions-doc`
   ":meth:`Tensor.trunc`, :func:`torch.trunc`",:ref:`keeps_input_names-doc`
   :meth:`Tensor.trunc_`,None
   :meth:`Tensor.type`,None
   :meth:`Tensor.type_as`,:ref:`keeps_input_names-doc`
   ":meth:`Tensor.unbind`, :func:`torch.unbind`",:ref:`removes_dimensions-doc`
   :meth:`Tensor.unflatten`,See documentation
   :meth:`Tensor.uniform_`,None
   ":meth:`Tensor.var`, :func:`torch.var`",:ref:`removes_dimensions-doc`
   :func:`torch.var_mean`,:ref:`removes_dimensions-doc`
   :meth:`Tensor.zero_`,None
   :func:`torch.zeros`,:ref:`factory-doc`


.. _keeps_input_names-doc:

Keeps input names
^^^^^^^^^^^^^^^^^

All pointwise unary functions follow this rule as well as some other unary functions.

- Check names: None
- Propagate names: input tensor's names are propagated to the output.

::

    >>> x = torch.randn(3, 3, names=('N', 'C'))
    >>> x.abs().names
    ('N', 'C')

.. _removes_dimensions-doc:

Removes dimensions
^^^^^^^^^^^^^^^^^^

All reduction ops like :meth:`~Tensor.sum` remove dimensions by reducing
over the desired dimensions. Other operations like :meth:`~Tensor.select` and
:meth:`~Tensor.squeeze` remove dimensions.

Wherever one can pass an integer dimension index to an operator, one can also pass
a dimension name. Functions that take lists of dimension indices can also take in a
list of dimension names.

- Check names: If :attr:`dim` or :attr:`dims` is passed in as a list of names,
  check that those names exist in :attr:`self`.
- Propagate names: If the dimensions of the input tensor specified by :attr:`dim`
  or :attr:`dims` are not present in the output tensor, then the corresponding names
  of those dimensions do not appear in ``output.names``.

::

    >>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
    >>> x.squeeze('N').names
    ('C', 'H', 'W')

    >>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
    >>> x.sum(['N', 'C']).names
    ('H', 'W')

    # Reduction ops with keepdim=True don't actually remove dimensions.
    >>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
    >>> x.sum(['N', 'C'], keepdim=True).names
    ('N', 'C', 'H', 'W')


.. _unifies_names_from_inputs-doc:

Unifies names from inputs
^^^^^^^^^^^^^^^^^^^^^^^^^

All binary arithmetic ops follow this rule. Operations that broadcast still
broadcast positionally from the right to preserve compatibility with unnamed
tensors. To perform explicit broadcasting by names, use :meth:`Tensor.align_as`.

- Check names: All names must match positionally from the right. i.e., in
  ``tensor + other``, ``match(tensor.names[i], other.names[i])`` must be true for all
  ``i`` in ``(-min(tensor.dim(), other.dim()) + 1, -1]``.
- Check names: Furthermore, all named dimensions must be aligned from the right.
  During matching, if we match a named dimension ``A`` with an unnamed dimension
  ``None``, then ``A`` must not appear in the tensor with the unnamed dimension.
- Propagate names: unify pairs of names from the right from both tensors to
  produce output names.

For example,

::

    # tensor: Tensor[   N, None]
    # other:  Tensor[None,    C]
    >>> tensor = torch.randn(3, 3, names=('N', None))
    >>> other = torch.randn(3, 3, names=(None, 'C'))
    >>> (tensor + other).names
    ('N', 'C')

Check names:

- ``match(tensor.names[-1], other.names[-1])`` is ``True``
- ``match(tensor.names[-2], tensor.names[-2])`` is ``True``
- Because we matched ``None`` in :attr:`tensor` with ``'C'``,
  check to make sure ``'C'`` doesn't exist in :attr:`tensor` (it does not).
- Check to make sure ``'N'`` doesn't exist in :attr:`other` (it does not).

Finally, the output names are computed with
``[unify('N', None), unify(None, 'C')] = ['N', 'C']``

More examples::

    # Dimensions don't match from the right:
    # tensor: Tensor[N, C]
    # other:  Tensor[   N]
    >>> tensor = torch.randn(3, 3, names=('N', 'C'))
    >>> other = torch.randn(3, names=('N',))
    >>> (tensor + other).names
    RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims
    ['N']: dim 'C' and dim 'N' are at the same position from the right but do
    not match.

    # Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]:
    # tensor: Tensor[N, None]
    # other:  Tensor[   N]
    >>> tensor = torch.randn(3, 3, names=('N', None))
    >>> other = torch.randn(3, names=('N',))
    >>> (tensor + other).names
    RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and
    dims ['N', None]: dim 'N' appears in a different position from the right
    across both lists.

.. note::

   In both of the last examples, it is possible to align the tensors by names
   and then perform the addition. Use :meth:`Tensor.align_as` to align
   tensors by name or :meth:`Tensor.align_to` to align tensors to a custom
   dimension ordering.

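The ``match``/``unify`` checks described above are small functions over optional names; a Python sketch of the rule for a single pair of dimensions, treating ``None`` as an unnamed dim (illustrative only)::

    def match(a, b):
        # Two names match if they are equal or if either side is unnamed.
        return a is None or b is None or a == b

    def unify(a, b):
        if not match(a, b):
            raise RuntimeError("dim %r and dim %r do not match" % (a, b))
        return a if a is not None else b

    # Pairwise from the right, as in `tensor + other` above:
    names = [unify(x, y) for x, y in zip(('N', None), (None, 'C'))]
    assert names == ['N', 'C']
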
.. _permutes_dimensions-doc:
|
||||
|
||||
Permutes dimensions
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Some operations, like :meth:`Tensor.t()`, permute the order of dimensions. Dimension names
|
||||
are attached to individual dimensions so they get permuted as well.
|
||||
|
||||
If the operator takes in positional index :attr:`dim`, it is also able to take a dimension
|
||||
name as :attr:`dim`.
|
||||
|
||||
- Check names: If :attr:`dim` is passed as a name, check that it exists in the tensor.
|
||||
- Propagate names: Permute dimension names in the same way as the dimensions that are
|
||||
being permuted.
|
||||
|
||||
::
|
||||
|
||||
>>> x = torch.randn(3, 3, names=('N', 'C'))
|
||||
>>> x.transpose('N', 'C').names
|
||||
('C', 'N')

.. _contracts_away_dims-doc:

Contracts away dims
^^^^^^^^^^^^^^^^^^^

Matrix multiply functions follow some variant of this. Let's go through
:func:`torch.mm` first and then generalize the rule for batch matrix multiplication.

For ``torch.mm(tensor, other)``:

- Check names: None
- Propagate names: result names are ``(tensor.names[-2], other.names[-1])``.

::

    >>> x = torch.randn(3, 3, names=('N', 'D'))
    >>> y = torch.randn(3, 3, names=('in', 'out'))
    >>> x.mm(y).names
    ('N', 'out')

Inherently, a matrix multiplication performs a dot product over two dimensions,
collapsing them. When two tensors are matrix-multiplied, the contracted dimensions
disappear and do not show up in the output tensor.

:func:`torch.mv` and :func:`torch.dot` work in a similar way: name inference does not
check input names and removes the dimensions that are involved in the dot product:

::

    >>> x = torch.randn(3, 3, names=('N', 'D'))
    >>> y = torch.randn(3, names=('something',))
    >>> x.mv(y).names
    ('N',)

Now, let's take a look at ``torch.matmul(tensor, other)``. Assume that ``tensor.dim() >= 2``
and ``other.dim() >= 2``.

- Check names: Check that the batch dimensions of the inputs are aligned and broadcastable.
  See :ref:`unifies_names_from_inputs-doc` for what it means for the inputs to be aligned.
- Propagate names: result names are obtained by unifying the batch dimensions and removing
  the contracted dimensions:
  ``unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])``.

Examples::

    # Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F'].
    # 'A', 'B' are batch dimensions.
    >>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D'))
    >>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F'))
    >>> torch.matmul(x, y).names
    ('A', 'B', 'C', 'F')


Finally, there are fused ``add`` versions of many matmul functions, e.g. :func:`addmm`
and :func:`addmv`. These are treated as composing name inference for :func:`mm` and
name inference for :func:`add`.
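
For example, here is a sketch of how that composition plays out for
:func:`addmm`; the output names follow from the two rules above::

    >>> M = torch.randn(3, 3, names=('N', None))
    >>> a = torch.randn(3, 3, names=('N', 'in'))
    >>> b = torch.randn(3, 3, names=(None, 'out'))
    # mm(a, b) has names ('N', 'out'); adding M unifies to ('N', 'out')
    >>> torch.addmm(M, a, b).names
    ('N', 'out')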

.. _factory-doc:

Factory functions
^^^^^^^^^^^^^^^^^


Factory functions now take a new :attr:`names` argument that associates a name
with each dimension.

::

    >>> torch.zeros(2, 3, names=('N', 'C'))
    tensor([[0., 0., 0.],
            [0., 0., 0.]], names=('N', 'C'))

.. _out_function_semantics-doc:

out function and in-place variants
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A tensor specified as an ``out=`` tensor has the following behavior
(see the example after this list):

- If it has no named dimensions, then the names computed from the operation
  get propagated to it.
- If it has any named dimensions, then the names computed from the operation
  must be exactly equal to the existing names. Otherwise, the operation errors.
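
For example, a minimal sketch of the first rule; :func:`torch.abs` here is just
a stand-in for any operation with an ``out=`` variant::

    >>> x = torch.randn(3, 3, names=('N', 'C'))
    >>> out = torch.empty(3, 3)
    >>> torch.abs(x, out=out)
    >>> out.names
    ('N', 'C')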

All in-place methods modify inputs to have names equal to the computed names
from name inference. For example,

::

    >>> x = torch.randn(3, 3)
    >>> y = torch.randn(3, 3, names=('N', 'C'))
    >>> x.names
    (None, None)

    >>> x += y
    >>> x.names
    ('N', 'C')

@ -1,319 +0,0 @@
.. currentmodule:: torch

.. _named_tensors-doc:

Named Tensors
=============

Named Tensors aim to make tensors easier to use by allowing users to associate
explicit names with tensor dimensions. In most cases, operations that take
dimension parameters will accept dimension names, avoiding the need to track
dimensions by position. In addition, named tensors use names to automatically
check that APIs are being used correctly at runtime, providing extra safety.
Names can also be used to rearrange dimensions, for example, to support
"broadcasting by name" rather than "broadcasting by position".

.. warning::
    The named tensor API is experimental and subject to change.

Creating named tensors
----------------------

Factory functions now take a new :attr:`names` argument that associates a name
with each dimension.

::

    >>> torch.zeros(2, 3, names=('N', 'C'))
    tensor([[0., 0., 0.],
            [0., 0., 0.]], names=('N', 'C'))

Named dimensions, like regular Tensor dimensions, are ordered.
``tensor.names[i]`` is the name of dimension ``i`` of ``tensor``.

The following factory functions support named tensors:

- :func:`torch.empty`
- :func:`torch.rand`
- :func:`torch.randn`
- :func:`torch.ones`
- :func:`torch.tensor`
- :func:`torch.zeros`

Named dimensions
----------------

See :attr:`~Tensor.names` for restrictions on tensor names.

Use :attr:`~Tensor.names` to access the dimension names of a tensor and
:meth:`~Tensor.rename` to rename named dimensions.

::

    >>> imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
    >>> imgs.names
    ('N', 'C', 'H', 'W')

    >>> renamed_imgs = imgs.rename(H='height', W='width')
    >>> renamed_imgs.names
    ('N', 'C', 'height', 'width')


Named tensors can coexist with unnamed tensors; named tensors are instances of
:class:`torch.Tensor`. Unnamed tensors have ``None``-named dimensions. Named
tensors do not require all dimensions to be named.

::

    >>> imgs = torch.randn(1, 2, 2, 3, names=(None, 'C', 'H', 'W'))
    >>> imgs.names
    (None, 'C', 'H', 'W')

Name propagation semantics
--------------------------

Named tensors use names to automatically check that APIs are being called
correctly at runtime. This occurs in a process called *name inference*.
More formally, name inference consists of the following two steps:

- **Check names**: an operator may perform automatic checks at runtime that
  check that certain dimension names must match.
- **Propagate names**: name inference propagates names to output tensors.

All operations that support named tensors propagate names.

::

    >>> x = torch.randn(3, 3, names=('N', 'C'))
    >>> x.abs().names
    ('N', 'C')


.. _match_semantics-doc:

match semantics
^^^^^^^^^^^^^^^

Two names *match* if they are equal (string equality) or if at least one is ``None``.
Nones are essentially a special "wildcard" name.

``unify(A, B)`` determines which of the names ``A`` and ``B`` to propagate to the outputs.
It returns the more *specific* of the two names, if they match. If the names do not match,
then it errors.

.. note::
    In practice, when working with named tensors, one should avoid having unnamed
    dimensions because their handling can be complicated. It is recommended to lift
    all unnamed dimensions to be named dimensions by using :meth:`~Tensor.refine_names`,
    as in the example below.
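
A minimal sketch of lifting unnamed dimensions with :meth:`~Tensor.refine_names`::

    >>> imgs = torch.randn(32, 3, 128, 128)
    >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
    >>> named_imgs.names
    ('N', 'C', 'H', 'W')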


Basic name inference rules
^^^^^^^^^^^^^^^^^^^^^^^^^^

Let's see how ``match`` and ``unify`` are used in name inference in the case of
adding two one-dim tensors with no broadcasting.

::

    x = torch.randn(3, names=('X',))
    y = torch.randn(3)
    z = torch.randn(3, names=('Z',))

**Check names**: check that the names of the two tensors *match*.

For the following examples:

::

    >>> # x + y  # match('X', None) is True
    >>> # x + z  # match('X', 'Z') is False
    >>> # x + x  # match('X', 'X') is True

    >>> x + z
    Error when attempting to broadcast dims ['X'] and dims ['Z']: dim 'X' and dim 'Z' are at the same position from the right but do not match.

**Propagate names**: *unify* the names to select which one to propagate.
In the case of ``x + y``, ``unify('X', None) = 'X'`` because ``'X'`` is more
specific than ``None``.

::

    >>> (x + y).names
    ('X',)
    >>> (x + x).names
    ('X',)

For a comprehensive list of name inference rules, see :ref:`name_inference_reference-doc`.
Here are two common operations that may be useful to go over:

- Binary arithmetic ops: :ref:`unifies_names_from_inputs-doc`
- Matrix multiplication ops: :ref:`contracts_away_dims-doc`

Explicit alignment by names
---------------------------

Use :meth:`~Tensor.align_as` or :meth:`~Tensor.align_to` to align tensor dimensions
by name to a specified ordering. This is useful for performing "broadcasting by names".

::

    # This function is agnostic to the dimension ordering of `input`,
    # as long as it has a `C` dimension somewhere.
    def scale_channels(input, scale):
        scale = scale.refine_names('C')
        return input * scale.align_as(input)

    >>> num_channels = 3
    >>> scale = torch.randn(num_channels, names=('C',))
    >>> imgs = torch.rand(3, 3, 3, num_channels, names=('N', 'H', 'W', 'C'))
    >>> more_imgs = torch.rand(3, num_channels, 3, 3, names=('N', 'C', 'H', 'W'))
    >>> videos = torch.randn(3, num_channels, 3, 3, 3, names=('N', 'C', 'H', 'W', 'D'))

    >>> scale_channels(imgs, scale)
    >>> scale_channels(more_imgs, scale)
    >>> scale_channels(videos, scale)

Manipulating dimensions
-----------------------

Use :meth:`~Tensor.align_to` to permute large amounts of dimensions without
mentioning all of them as required by :meth:`~Tensor.permute`.

::

    >>> tensor = torch.randn(2, 2, 2, 2, 2, 2)
    >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F')

    # Move the F (dim 5) and E dimension (dim 4) to the front while keeping
    # the rest in the same order
    >>> tensor.permute(5, 4, 0, 1, 2, 3)
    >>> named_tensor.align_to('F', 'E', ...)  # Use '...' instead in Python 2

Use :meth:`~Tensor.flatten` and :meth:`~Tensor.unflatten` to flatten and unflatten
dimensions, respectively. These methods are more verbose than :meth:`~Tensor.view`
and :meth:`~Tensor.reshape`, but have more semantic meaning to someone reading the code.

::

    >>> imgs = torch.randn(32, 3, 128, 128)
    >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')

    >>> flat_imgs = imgs.view(32, -1)
    >>> named_flat_imgs = named_imgs.flatten(['C', 'H', 'W'], 'features')
    >>> named_flat_imgs.names
    ('N', 'features')

    >>> unflattened_imgs = imgs.view(32, 3, 128, 128)
    >>> unflattened_named_imgs = named_flat_imgs.unflatten(
            'features', [('C', 3), ('H', 128), ('W', 128)])

.. _named_tensors_autograd-doc:

Autograd support
----------------

Autograd currently supports named tensors in a limited manner: autograd ignores
names on all tensors. Gradient computation is still correct but we lose the
safety that names give us.

::

    >>> x = torch.randn(3, names=('D',))
    >>> weight = torch.randn(3, names=('D',), requires_grad=True)
    >>> loss = (x - weight).abs()
    >>> grad_loss = torch.randn(3)
    >>> loss.backward(grad_loss)
    >>> weight.grad  # Unnamed for now. Will be named in the future
    tensor([-1.8107, -0.6357,  0.0783])

    >>> weight.grad.zero_()
    >>> grad_loss = grad_loss.refine_names('C')
    >>> loss = (x - weight).abs()
    # Ideally we'd check that the names of loss and grad_loss match, but we don't yet.
    >>> loss.backward(grad_loss)
    >>> weight.grad
    tensor([-1.8107, -0.6357,  0.0783])

Currently supported operations and subsystems
---------------------------------------------

Operators
^^^^^^^^^

See :ref:`name_inference_reference-doc` for a full list of the supported torch and
tensor operations. We do not yet support the following, which is not covered by that link:

- indexing, advanced indexing.

For ``torch.nn.functional`` operators, we support the following:

- :func:`torch.nn.functional.relu`
- :func:`torch.nn.functional.softmax`
- :func:`torch.nn.functional.log_softmax`
- :func:`torch.nn.functional.tanh`
- :func:`torch.nn.functional.sigmoid`
- :func:`torch.nn.functional.dropout`

Subsystems
^^^^^^^^^^

Autograd is supported, see :ref:`named_tensors_autograd-doc`.
Because gradients are currently unnamed, optimizers may work but are untested.

NN modules are currently unsupported. This can lead to the following when calling
modules with named tensor inputs:

- NN module parameters are unnamed, so outputs may be partially named.
- NN module forward passes have code that doesn't support named tensors and will
  error out appropriately.

We also do not support the following subsystems, though some may work out
of the box:

- distributions
- serialization (:func:`torch.load`, :func:`torch.save`)
- multiprocessing
- JIT
- distributed
- ONNX

If any of these would help your use case, please
`search if an issue has already been filed <https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+named+tensor%22>`_
and if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_.

Named tensor API reference
--------------------------

This section contains the documentation for named-tensor-specific APIs.
For a comprehensive reference for how names are propagated through other PyTorch
operators, see :ref:`name_inference_reference-doc`.

.. class:: Tensor()
   :noindex:

   .. autoattribute:: names
   .. automethod:: rename
   .. automethod:: rename_
   .. automethod:: refine_names

   .. automethod:: align_as
   .. automethod:: align_to

   .. automethod:: unflatten
   .. py:method:: flatten(dims, out_dim) -> Tensor

      Flattens :attr:`dims` into a single dimension with name :attr:`out_dim`.

      All of ``dims`` must be consecutive in order in the :attr:`self` tensor,
      but not necessarily contiguous in memory.

      Examples::

          >>> imgs = torch.randn(32, 3, 128, 128, names=('N', 'C', 'H', 'W'))
          >>> flat_imgs = imgs.flatten(['C', 'H', 'W'], 'features')
          >>> flat_imgs.names, flat_imgs.shape
          (('N', 'features'), torch.Size([32, 49152]))

.. warning::
    The named tensor API is experimental and subject to change.

@ -877,10 +877,3 @@ Utilities

.. autoclass:: Flatten
    :members:


Quantized Functions
--------------------

Quantization refers to techniques for performing computations and storing tensors at lower bitwidths than
floating point precision. PyTorch supports both per tensor and per channel asymmetric linear quantization. To learn more about how to use quantized functions in PyTorch, please refer to the :ref:`Quantization` documentation.

@ -1,67 +0,0 @@
DType
=====

.. java:package:: org.pytorch
   :noindex:

.. java:type:: public enum DType

   Codes representing tensor data types.

Enum Constants
--------------
FLOAT32
^^^^^^^

.. java:field:: public static final DType FLOAT32
   :outertype: DType

   Code for dtype torch.float32. \ :java:ref:`Tensor.dtype()`\

FLOAT64
^^^^^^^

.. java:field:: public static final DType FLOAT64
   :outertype: DType

   Code for dtype torch.float64. \ :java:ref:`Tensor.dtype()`\

INT32
^^^^^

.. java:field:: public static final DType INT32
   :outertype: DType

   Code for dtype torch.int32. \ :java:ref:`Tensor.dtype()`\

INT64
^^^^^

.. java:field:: public static final DType INT64
   :outertype: DType

   Code for dtype torch.int64. \ :java:ref:`Tensor.dtype()`\

INT8
^^^^

.. java:field:: public static final DType INT8
   :outertype: DType

   Code for dtype torch.int8. \ :java:ref:`Tensor.dtype()`\

UINT8
^^^^^

.. java:field:: public static final DType UINT8
   :outertype: DType

   Code for dtype torch.uint8. \ :java:ref:`Tensor.dtype()`\

Fields
------
jniCode
^^^^^^^

.. java:field:: final int jniCode
   :outertype: DType
@ -1,297 +0,0 @@
.. java:import:: java.util Locale

.. java:import:: java.util Map

IValue
======

.. java:package:: org.pytorch
   :noindex:

.. java:type:: public class IValue

   Java representation of a TorchScript value, which is implemented as a tagged union that can be one of the supported types: https://pytorch.org/docs/stable/jit.html#types .

   Calling \ ``toX``\ methods for inappropriate types will throw \ :java:ref:`IllegalStateException`\ .

   \ ``IValue``\ objects are constructed with \ ``IValue.from(value)``\ , \ ``IValue.tupleFrom(value1, value2, ...)``\ , \ ``IValue.listFrom(value1, value2, ...)``\ , or one of the \ ``dict``\ methods, depending on the key type.

   Data is retrieved from \ ``IValue``\ objects with the \ ``toX()``\ methods. Note that \ ``str``\ -type IValues must be extracted with \ :java:ref:`toStr()`\ , rather than \ :java:ref:`toString()`\ .

   \ ``IValue``\ objects may retain references to objects passed into their constructors, and may return references to their internal state from \ ``toX()``\ .

Methods
-------
dictLongKeyFrom
^^^^^^^^^^^^^^^

.. java:method:: public static IValue dictLongKeyFrom(Map<Long, IValue> map)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``Dict[int, V]``\ .

dictStringKeyFrom
^^^^^^^^^^^^^^^^^

.. java:method:: public static IValue dictStringKeyFrom(Map<String, IValue> map)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``Dict[str, V]``\ .

from
^^^^

.. java:method:: public static IValue from(Tensor tensor)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``Tensor``\ .

from
^^^^

.. java:method:: public static IValue from(boolean value)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``bool``\ .

from
^^^^

.. java:method:: public static IValue from(long value)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``int``\ .

from
^^^^

.. java:method:: public static IValue from(double value)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``float``\ .

from
^^^^

.. java:method:: public static IValue from(String value)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``str``\ .

isBool
^^^^^^

.. java:method:: public boolean isBool()
   :outertype: IValue

isBoolList
^^^^^^^^^^

.. java:method:: public boolean isBoolList()
   :outertype: IValue

isDictLongKey
^^^^^^^^^^^^^

.. java:method:: public boolean isDictLongKey()
   :outertype: IValue

isDictStringKey
^^^^^^^^^^^^^^^

.. java:method:: public boolean isDictStringKey()
   :outertype: IValue

isDouble
^^^^^^^^

.. java:method:: public boolean isDouble()
   :outertype: IValue

isDoubleList
^^^^^^^^^^^^

.. java:method:: public boolean isDoubleList()
   :outertype: IValue

isList
^^^^^^

.. java:method:: public boolean isList()
   :outertype: IValue

isLong
^^^^^^

.. java:method:: public boolean isLong()
   :outertype: IValue

isLongList
^^^^^^^^^^

.. java:method:: public boolean isLongList()
   :outertype: IValue

isNull
^^^^^^

.. java:method:: public boolean isNull()
   :outertype: IValue

isString
^^^^^^^^

.. java:method:: public boolean isString()
   :outertype: IValue

isTensor
^^^^^^^^

.. java:method:: public boolean isTensor()
   :outertype: IValue

isTensorList
^^^^^^^^^^^^

.. java:method:: public boolean isTensorList()
   :outertype: IValue

isTuple
^^^^^^^

.. java:method:: public boolean isTuple()
   :outertype: IValue

listFrom
^^^^^^^^

.. java:method:: public static IValue listFrom(boolean... list)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``List[bool]``\ .

listFrom
^^^^^^^^

.. java:method:: public static IValue listFrom(long... list)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``List[int]``\ .

listFrom
^^^^^^^^

.. java:method:: public static IValue listFrom(double... list)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``List[float]``\ .

listFrom
^^^^^^^^

.. java:method:: public static IValue listFrom(Tensor... list)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``List[Tensor]``\ .

listFrom
^^^^^^^^

.. java:method:: public static IValue listFrom(IValue... array)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``List[T]``\ . All elements must have the same type.

optionalNull
^^^^^^^^^^^^

.. java:method:: public static IValue optionalNull()
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``Optional``\ that contains no value.

toBool
^^^^^^

.. java:method:: public boolean toBool()
   :outertype: IValue

toBoolList
^^^^^^^^^^

.. java:method:: public boolean[] toBoolList()
   :outertype: IValue

toDictLongKey
^^^^^^^^^^^^^

.. java:method:: public Map<Long, IValue> toDictLongKey()
   :outertype: IValue

toDictStringKey
^^^^^^^^^^^^^^^

.. java:method:: public Map<String, IValue> toDictStringKey()
   :outertype: IValue

toDouble
^^^^^^^^

.. java:method:: public double toDouble()
   :outertype: IValue

toDoubleList
^^^^^^^^^^^^

.. java:method:: public double[] toDoubleList()
   :outertype: IValue

toList
^^^^^^

.. java:method:: public IValue[] toList()
   :outertype: IValue

toLong
^^^^^^

.. java:method:: public long toLong()
   :outertype: IValue

toLongList
^^^^^^^^^^

.. java:method:: public long[] toLongList()
   :outertype: IValue

toStr
^^^^^

.. java:method:: public String toStr()
   :outertype: IValue

toTensor
^^^^^^^^

.. java:method:: public Tensor toTensor()
   :outertype: IValue

toTensorList
^^^^^^^^^^^^

.. java:method:: public Tensor[] toTensorList()
   :outertype: IValue

toTuple
^^^^^^^

.. java:method:: public IValue[] toTuple()
   :outertype: IValue

tupleFrom
^^^^^^^^^

.. java:method:: public static IValue tupleFrom(IValue... array)
   :outertype: IValue

   Creates a new \ ``IValue``\ of type \ ``Tuple[T0, T1, ...]``\ .
@ -1,55 +0,0 @@
.. java:import:: com.facebook.jni HybridData

Module
======

.. java:package:: org.pytorch
   :noindex:

.. java:type:: public class Module

   Java wrapper for torch::jit::script::Module.

Methods
-------
destroy
^^^^^^^

.. java:method:: public void destroy()
   :outertype: Module

   Explicitly destroys the native torch::jit::script::Module. Calling this method is not required, as the native object will be destroyed when this object is garbage-collected. However, the timing of garbage collection is not guaranteed, so proactively calling \ ``destroy``\ can free memory more quickly. See \ :java:ref:`com.facebook.jni.HybridData.resetNative`\ .

forward
^^^^^^^

.. java:method:: public IValue forward(IValue... inputs)
   :outertype: Module

   Runs the 'forward' method of this module with the specified arguments.

   :param inputs: arguments for the TorchScript module's 'forward' method.
   :return: return value from the 'forward' method.

load
^^^^

.. java:method:: public static Module load(String modelPath)
   :outertype: Module

   Loads a serialized TorchScript module from the specified path on the disk.

   :param modelPath: path to the file that contains the serialized TorchScript module.
   :return: new \ :java:ref:`org.pytorch.Module`\ object which owns the underlying torch::jit::script::Module.

runMethod
^^^^^^^^^

.. java:method:: public IValue runMethod(String methodName, IValue... inputs)
   :outertype: Module

   Runs the specified method of this module with the specified arguments.

   :param methodName: name of the TorchScript method to run.
   :param inputs: arguments that will be passed to the TorchScript method.
   :return: return value from the method.
@ -1,315 +0,0 @@
.. java:import:: java.nio Buffer

.. java:import:: java.nio ByteBuffer

.. java:import:: java.nio ByteOrder

.. java:import:: java.nio DoubleBuffer

.. java:import:: java.nio FloatBuffer

.. java:import:: java.nio IntBuffer

.. java:import:: java.nio LongBuffer

.. java:import:: java.util Arrays

.. java:import:: java.util Locale

Tensor
======

.. java:package:: org.pytorch
   :noindex:

.. java:type:: public abstract class Tensor

   Representation of a Tensor. Behavior is similar to PyTorch's tensor objects.

   Most tensors will be constructed as \ ``Tensor.fromBlob(data, shape)``\ , where \ ``data``\ can be an array or a direct \ :java:ref:`Buffer`\ (of the proper subclass). Helper methods are provided to allocate buffers properly.

   To access Tensor data, see \ :java:ref:`dtype()`\ , \ :java:ref:`shape()`\ , and various \ ``getDataAs*``\ methods.

   When constructing \ ``Tensor``\ objects with \ ``data``\ as an array, it is not specified whether this data is copied or retained as a reference, so it is recommended not to modify it after constructing. \ ``data``\ passed as a \ :java:ref:`Buffer`\ is not copied, so it can be modified between \ :java:ref:`Module`\ calls to avoid reallocation. Data retrieved from \ ``Tensor``\ objects may be copied or may be a reference to the \ ``Tensor``\ 's internal data buffer. \ ``shape``\ is always copied.

Methods
-------
allocateByteBuffer
^^^^^^^^^^^^^^^^^^

.. java:method:: public static ByteBuffer allocateByteBuffer(int numElements)
   :outertype: Tensor

   Allocates a new direct \ :java:ref:`java.nio.ByteBuffer`\ with native byte order and the specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(ByteBuffer,long[])`\ , \ :java:ref:`Tensor.fromBlobUnsigned(ByteBuffer,long[])`\ .

   :param numElements: capacity (number of elements) of the result buffer.

allocateDoubleBuffer
^^^^^^^^^^^^^^^^^^^^

.. java:method:: public static DoubleBuffer allocateDoubleBuffer(int numElements)
   :outertype: Tensor

   Allocates a new direct \ :java:ref:`java.nio.DoubleBuffer`\ with native byte order and the specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(DoubleBuffer,long[])`\ .

   :param numElements: capacity (number of elements) of the result buffer.

allocateFloatBuffer
^^^^^^^^^^^^^^^^^^^

.. java:method:: public static FloatBuffer allocateFloatBuffer(int numElements)
   :outertype: Tensor

   Allocates a new direct \ :java:ref:`java.nio.FloatBuffer`\ with native byte order and the specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(FloatBuffer,long[])`\ .

   :param numElements: capacity (number of elements) of the result buffer.

allocateIntBuffer
^^^^^^^^^^^^^^^^^

.. java:method:: public static IntBuffer allocateIntBuffer(int numElements)
   :outertype: Tensor

   Allocates a new direct \ :java:ref:`java.nio.IntBuffer`\ with native byte order and the specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(IntBuffer,long[])`\ .

   :param numElements: capacity (number of elements) of the result buffer.

allocateLongBuffer
^^^^^^^^^^^^^^^^^^

.. java:method:: public static LongBuffer allocateLongBuffer(int numElements)
   :outertype: Tensor

   Allocates a new direct \ :java:ref:`java.nio.LongBuffer`\ with native byte order and the specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(LongBuffer,long[])`\ .

   :param numElements: capacity (number of elements) of the result buffer.

dtype
^^^^^

.. java:method:: public abstract DType dtype()
   :outertype: Tensor

   :return: data type of this tensor.

dtypeJniCode
^^^^^^^^^^^^

.. java:method:: int dtypeJniCode()
   :outertype: Tensor

fromBlob
^^^^^^^^

.. java:method:: public static Tensor fromBlob(byte[] data, long[] shape)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.int8 with the specified shape and data as an array of bytes.

   :param data: Tensor elements
   :param shape: Tensor shape

fromBlob
^^^^^^^^

.. java:method:: public static Tensor fromBlob(int[] data, long[] shape)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.int32 with the specified shape and data as an array of ints.

   :param data: Tensor elements
   :param shape: Tensor shape

fromBlob
^^^^^^^^

.. java:method:: public static Tensor fromBlob(float[] data, long[] shape)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.float32 with the specified shape and data as an array of floats.

   :param data: Tensor elements
   :param shape: Tensor shape

fromBlob
^^^^^^^^

.. java:method:: public static Tensor fromBlob(long[] data, long[] shape)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.int64 with the specified shape and data as an array of longs.

   :param data: Tensor elements
   :param shape: Tensor shape

fromBlob
^^^^^^^^

.. java:method:: public static Tensor fromBlob(long[] shape, double[] data)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.float64 with the specified shape and data as an array of doubles.

   :param shape: Tensor shape
   :param data: Tensor elements

fromBlob
^^^^^^^^

.. java:method:: public static Tensor fromBlob(ByteBuffer data, long[] shape)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.int8 with the specified shape and data.

   :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
   :param shape: Tensor shape

fromBlob
^^^^^^^^

.. java:method:: public static Tensor fromBlob(IntBuffer data, long[] shape)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.int32 with the specified shape and data.

   :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
   :param shape: Tensor shape

fromBlob
^^^^^^^^

.. java:method:: public static Tensor fromBlob(FloatBuffer data, long[] shape)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.float32 with the specified shape and data.

   :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
   :param shape: Tensor shape

fromBlob
^^^^^^^^

.. java:method:: public static Tensor fromBlob(LongBuffer data, long[] shape)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.int64 with the specified shape and data.

   :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
   :param shape: Tensor shape

fromBlob
^^^^^^^^

.. java:method:: public static Tensor fromBlob(DoubleBuffer data, long[] shape)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.float64 with the specified shape and data.

   :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
   :param shape: Tensor shape

fromBlobUnsigned
^^^^^^^^^^^^^^^^

.. java:method:: public static Tensor fromBlobUnsigned(byte[] data, long[] shape)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.uint8 with the specified shape and data as an array of bytes.

   :param data: Tensor elements
   :param shape: Tensor shape

fromBlobUnsigned
^^^^^^^^^^^^^^^^

.. java:method:: public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape)
   :outertype: Tensor

   Creates a new Tensor instance with dtype torch.uint8 with the specified shape and data.

   :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\ elements. The buffer is used directly without copying, and changes to its content will change the tensor.
   :param shape: Tensor shape

getDataAsByteArray
^^^^^^^^^^^^^^^^^^

.. java:method:: public byte[] getDataAsByteArray()
   :outertype: Tensor

   :throws IllegalStateException: if it is called for a non-int8 tensor.
   :return: a Java byte array that contains the tensor data. This may be a copy or reference.

getDataAsDoubleArray
^^^^^^^^^^^^^^^^^^^^

.. java:method:: public double[] getDataAsDoubleArray()
   :outertype: Tensor

   :throws IllegalStateException: if it is called for a non-float64 tensor.
   :return: a Java double array that contains the tensor data. This may be a copy or reference.

getDataAsFloatArray
^^^^^^^^^^^^^^^^^^^

.. java:method:: public float[] getDataAsFloatArray()
   :outertype: Tensor

   :throws IllegalStateException: if it is called for a non-float32 tensor.
   :return: a Java float array that contains the tensor data. This may be a copy or reference.

getDataAsIntArray
^^^^^^^^^^^^^^^^^

.. java:method:: public int[] getDataAsIntArray()
   :outertype: Tensor

   :throws IllegalStateException: if it is called for a non-int32 tensor.
   :return: a Java int array that contains the tensor data. This may be a copy or reference.

getDataAsLongArray
^^^^^^^^^^^^^^^^^^

.. java:method:: public long[] getDataAsLongArray()
   :outertype: Tensor

   :throws IllegalStateException: if it is called for a non-int64 tensor.
   :return: a Java long array that contains the tensor data. This may be a copy or reference.

getDataAsUnsignedByteArray
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. java:method:: public byte[] getDataAsUnsignedByteArray()
   :outertype: Tensor

   :throws IllegalStateException: if it is called for a non-uint8 tensor.
   :return: a Java byte array that contains the tensor data. This may be a copy or reference.

getRawDataBuffer
^^^^^^^^^^^^^^^^

.. java:method:: Buffer getRawDataBuffer()
   :outertype: Tensor

numel
^^^^^

.. java:method:: public long numel()
   :outertype: Tensor

   Returns the number of elements in this tensor.

numel
^^^^^

.. java:method:: public static long numel(long[] shape)
   :outertype: Tensor

   Calculates the number of elements in a tensor with the specified shape.

shape
^^^^^

.. java:method:: public long[] shape()
   :outertype: Tensor

   Returns the shape of this tensor. (The array is a fresh copy.)
@ -1,114 +0,0 @@
.. java:import:: android.graphics Bitmap

.. java:import:: android.graphics ImageFormat

.. java:import:: android.media Image

.. java:import:: org.pytorch Tensor

.. java:import:: java.nio ByteBuffer

.. java:import:: java.nio FloatBuffer

.. java:import:: java.util Locale

TensorImageUtils
================

.. java:package:: org.pytorch.torchvision
   :noindex:

.. java:type:: public final class TensorImageUtils

   Contains utility functions for \ :java:ref:`org.pytorch.Tensor`\ creation from \ :java:ref:`android.graphics.Bitmap`\ or \ :java:ref:`android.media.Image`\ source.

Fields
------
TORCHVISION_NORM_MEAN_RGB
^^^^^^^^^^^^^^^^^^^^^^^^^

.. java:field:: public static float[] TORCHVISION_NORM_MEAN_RGB
   :outertype: TensorImageUtils

TORCHVISION_NORM_STD_RGB
^^^^^^^^^^^^^^^^^^^^^^^^

.. java:field:: public static float[] TORCHVISION_NORM_STD_RGB
   :outertype: TensorImageUtils

Methods
-------
bitmapToFloat32Tensor
^^^^^^^^^^^^^^^^^^^^^

.. java:method:: public static Tensor bitmapToFloat32Tensor(Bitmap bitmap, float[] normMeanRGB, float[] normStdRGB)
   :outertype: TensorImageUtils

   Creates a new \ :java:ref:`org.pytorch.Tensor`\ from the full \ :java:ref:`android.graphics.Bitmap`\ , normalized with the mean and std specified in the parameters.

   :param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order
   :param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order

bitmapToFloat32Tensor
^^^^^^^^^^^^^^^^^^^^^

.. java:method:: public static Tensor bitmapToFloat32Tensor(Bitmap bitmap, int x, int y, int width, int height, float[] normMeanRGB, float[] normStdRGB)
   :outertype: TensorImageUtils

   Creates a new \ :java:ref:`org.pytorch.Tensor`\ from the specified area of \ :java:ref:`android.graphics.Bitmap`\ , normalized with the mean and std specified in the parameters.

   :param bitmap: \ :java:ref:`android.graphics.Bitmap`\ as a source for Tensor data
   :param x: x coordinate of the top left corner of the bitmap's area
   :param y: y coordinate of the top left corner of the bitmap's area
   :param width: width of the bitmap's area
   :param height: height of the bitmap's area
   :param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order
   :param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order

bitmapToFloatBuffer
^^^^^^^^^^^^^^^^^^^

.. java:method:: public static void bitmapToFloatBuffer(Bitmap bitmap, int x, int y, int width, int height, float[] normMeanRGB, float[] normStdRGB, FloatBuffer outBuffer, int outBufferOffset)
   :outertype: TensorImageUtils

   Writes tensor content from the specified \ :java:ref:`android.graphics.Bitmap`\ , normalized with the mean and std specified in the parameters, to the specified \ :java:ref:`java.nio.FloatBuffer`\ with the specified offset.

   :param bitmap: \ :java:ref:`android.graphics.Bitmap`\ as a source for Tensor data
   :param x: x coordinate of the top left corner of the bitmap's area
   :param y: y coordinate of the top left corner of the bitmap's area
   :param width: width of the bitmap's area
   :param height: height of the bitmap's area
   :param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order
   :param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order

imageYUV420CenterCropToFloat32Tensor
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. java:method:: public static Tensor imageYUV420CenterCropToFloat32Tensor(Image image, int rotateCWDegrees, int tensorWidth, int tensorHeight, float[] normMeanRGB, float[] normStdRGB)
   :outertype: TensorImageUtils

   Creates a new \ :java:ref:`org.pytorch.Tensor`\ from the specified area of \ :java:ref:`android.media.Image`\ , doing optional rotation, scaling (nearest) and center cropping.

   :param image: \ :java:ref:`android.media.Image`\ as a source for Tensor data
   :param rotateCWDegrees: Clockwise angle through which the input image needs to be rotated to be upright. Range of valid values: 0, 90, 180, 270
   :param tensorWidth: return tensor width, must be positive
   :param tensorHeight: return tensor height, must be positive
   :param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order
   :param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order

imageYUV420CenterCropToFloatBuffer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. java:method:: public static void imageYUV420CenterCropToFloatBuffer(Image image, int rotateCWDegrees, int tensorWidth, int tensorHeight, float[] normMeanRGB, float[] normStdRGB, FloatBuffer outBuffer, int outBufferOffset)
   :outertype: TensorImageUtils

   Writes tensor content from the specified \ :java:ref:`android.media.Image`\ , doing optional rotation, scaling (nearest) and center cropping, to the specified \ :java:ref:`java.nio.FloatBuffer`\ with the specified offset.

   :param image: \ :java:ref:`android.media.Image`\ as a source for Tensor data
   :param rotateCWDegrees: Clockwise angle through which the input image needs to be rotated to be upright. Range of valid values: 0, 90, 180, 270
   :param tensorWidth: return tensor width, must be positive
   :param tensorHeight: return tensor height, must be positive
   :param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order
   :param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order
   :param outBuffer: Output buffer, where tensor content will be written
   :param outBufferOffset: Output buffer offset with which tensor content will be written
@ -1,18 +0,0 @@
org.pytorch
===========

.. java:package:: org.pytorch

.. toctree::
    :maxdepth: 1

    DType
    IValue
    Module
    Tensor
    Tensor-Tensor_float32
    Tensor-Tensor_float64
    Tensor-Tensor_int32
    Tensor-Tensor_int64
    Tensor-Tensor_int8
    Tensor-Tensor_uint8
@ -1,9 +0,0 @@
org.pytorch.torchvision
=======================

.. java:package:: org.pytorch.torchvision

.. toctree::
    :maxdepth: 1

    TensorImageUtils
@ -1,29 +0,0 @@
Javadoc
=========

The main part of the Java API includes 3 classes:

::

    org.pytorch.Tensor
    org.pytorch.IValue
    org.pytorch.Module

If the reader is familiar with the PyTorch Python API, we can think of
``org.pytorch.Tensor`` as representing ``torch.Tensor``, ``org.pytorch.Module`` as representing ``torch.nn.Module``,
while ``org.pytorch.IValue`` represents the value of a TorchScript variable, supporting all
its `types <https://pytorch.org/docs/stable/jit.html#types>`_.

Additionally, there is the DType class, which contains codes representing tensor data types, and
TensorImageUtils, which contains utility functions for creating \ :java:ref:`org.pytorch.Tensor`\ from \ :java:ref:`android.graphics.Bitmap`\ or \ :java:ref:`android.media.Image`\ sources.

You can find details on each of these classes linked below:

.. java:package:: org.pytorch

.. toctree::
    :maxdepth: 1

    org/pytorch/DType
    org/pytorch/Tensor
    org/pytorch/IValue
    org/pytorch/Module
    org/pytorch/TensorImageUtils
@ -1,583 +0,0 @@
Quantization
===========================

Introduction to Quantization
----------------------------

Quantization refers to techniques for performing computations and storing tensors at lower bitwidths than
floating point precision. A quantized model executes some or all of the operations on tensors with
integers rather than floating point values. This allows for a more
compact model representation and the use of high performance vectorized
operations on many hardware platforms. PyTorch supports INT8
quantization, which, compared to typical FP32 models, allows for a 4x reduction in the model size and
a 4x reduction in memory bandwidth requirements. Hardware support for INT8 computations
is typically 2 to 4 times faster compared to FP32 compute. Quantization is primarily a technique
to speed up inference, and only the forward pass is supported for quantized operators.

PyTorch supports multiple approaches to quantizing a deep learning model. In most cases the model is trained
in FP32 and then the model is converted to INT8. In addition, PyTorch also supports quantization-aware
training, which models quantization errors in both the forward and backward passes using fake-quantization
modules. Note that the entire computation is carried out in floating point. At the end of quantization-aware
training, PyTorch provides conversion functions to convert the trained model into lower precision.
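
For example, a minimal post-training static quantization sketch using the
``torch.quantization`` workflow functions (``MyModel`` is a hypothetical float
model with quantizable modules; the calibration step is elided)::

    import torch

    model = MyModel().eval()  # hypothetical FP32 model
    model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(model, inplace=True)   # insert observers
    # ... feed representative calibration data through ``model`` here ...
    torch.quantization.convert(model, inplace=True)   # swap in quantized modules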

At a lower level, PyTorch provides a way to represent quantized tensors and
perform operations with them. They can be used to directly construct models that
perform all or part of the computation in lower precision. Higher-level APIs are
provided that incorporate typical workflows of converting an FP32 model to lower
precision with minimal accuracy loss.

Today, PyTorch supports the following backends for running quantized operators efficiently:

* x86 CPUs with AVX2 support or higher (without AVX2 some operations have inefficient implementations)
* ARM CPUs (typically found in mobile/embedded devices)

The corresponding implementation is chosen automatically based on the PyTorch build mode.

.. note::

    PyTorch 1.3 doesn't provide quantized operator implementations on CUDA yet - this is a direction of future work.
    Move the model to CPU in order to test the quantized functionality.

    Quantization-aware training (through :class:`~torch.quantization.FakeQuantize`) supports both CPU and CUDA.

Quantized Tensors
---------------------------------------

PyTorch supports both per tensor and per channel asymmetric linear
quantization. Per tensor means that all the values within the tensor are
scaled the same way. Per channel means that for each dimension, typically
the channel dimension of a tensor, the values
in the tensor are scaled and offset by a different value (effectively
the scale and offset become vectors). This allows for less error in converting tensors
to quantized values.
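
For instance, a sketch of per-channel quantization along dim 0 (the scales and
zero points here are illustrative, not derived from the data)::

    >>> x = torch.randn(2, 3)
    >>> xq = torch.quantize_per_channel(x, scales=torch.tensor([0.1, 0.05]),
    ...                                 zero_points=torch.tensor([0, 10]),
    ...                                 axis=0, dtype=torch.quint8)
    >>> xq.q_per_channel_scales()
    tensor([0.1000, 0.0500], dtype=torch.float64)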

The mapping is performed by converting the floating point tensors using

.. image:: math-quantizer-equation.png
   :width: 40%
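
In symbols, this is the standard affine (asymmetric linear) mapping, sketched
here under the usual conventions of rounding to nearest and clamping to the
quantized dtype's range :math:`[q_{\min}, q_{\max}]`:

.. math::
    x_q = \mathrm{clamp}\left(\mathrm{round}\left(\frac{x}{\text{scale}}\right) + \text{zero\_point},\; q_{\min},\; q_{\max}\right)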

Note that we ensure that zero in floating point is represented with no error after quantization,
thereby ensuring that operations like padding do not cause additional quantization error.

In order to do quantization in PyTorch, we need to be able to represent
quantized data in Tensors. A Quantized Tensor allows for storing
quantized data (represented as int8/uint8/int32) along with quantization
parameters like scale and zero\_point. Quantized Tensors allow for many
useful operations making quantized arithmetic easy, in addition to
allowing for serialization of data in a quantized format.
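
For example, a quantized tensor can be created from a float tensor with
:func:`torch.quantize_per_tensor` (a minimal round-trip sketch)::

    >>> x = torch.tensor([-1.0, 0.0, 1.0, 2.0])
    >>> xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=10, dtype=torch.quint8)
    >>> xq.int_repr()
    tensor([ 0, 10, 20, 30], dtype=torch.uint8)
    >>> xq.dequantize()
    tensor([-1.,  0.,  1.,  2.])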

Operation coverage
------------------

Quantized Tensors support a limited subset of the data manipulation methods of the regular
full-precision tensor (see the list below).

For NN operators included in PyTorch, we restrict support to:

1. 8 bit weights (data\_type = qint8)
2. 8 bit activations (data\_type = quint8)

Note that operator implementations currently only
support per channel quantization for weights of the **conv** and **linear**
operators. Furthermore, the minimum and the maximum of the input data is
mapped linearly to the minimum and the maximum of the quantized data
type such that zero is represented with no quantization error.

Additional data types and quantization schemes can be implemented through
the `custom operator mechanism <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_.

Many operations for quantized tensors are available under the same API as the full
float version in ``torch`` or ``torch.nn``. Quantized versions of NN modules that
perform re-quantization are available in ``torch.nn.quantized``. Those
operations explicitly take output quantization parameters (scale and zero\_point) in
the operation signature.

In addition, we also support fused versions corresponding to common fusion patterns that impact quantization, in
``torch.nn.intrinsic.quantized``.

For quantization-aware training, we support modules prepared for quantization-aware training in
``torch.nn.qat`` and ``torch.nn.intrinsic.qat``.

The current list of quantized operations is sufficient to cover typical CNN and RNN
models:


Quantized ``torch.Tensor`` operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Operations that are available from the ``torch`` namespace or as methods on Tensor for quantized tensors:

* :func:`~torch.quantize_per_tensor` - Convert float tensor to quantized tensor with per-tensor scale and zero point
* :func:`~torch.quantize_per_channel` - Convert float tensor to quantized tensor with per-channel scale and zero point
* View-based operations like :meth:`~torch.Tensor.view`, :meth:`~torch.Tensor.as_strided`, :meth:`~torch.Tensor.expand`, :meth:`~torch.Tensor.flatten`, :meth:`~torch.Tensor.slice`, python-style indexing, etc - work as on regular tensor (if quantization is not per-channel)
* Comparators

  * :meth:`~torch.Tensor.ne` — Not equal
  * :meth:`~torch.Tensor.eq` — Equal
  * :meth:`~torch.Tensor.ge` — Greater or equal
  * :meth:`~torch.Tensor.le` — Less or equal
  * :meth:`~torch.Tensor.gt` — Greater
  * :meth:`~torch.Tensor.lt` — Less

* :meth:`~torch.Tensor.copy_` — Copies src to self in-place
* :meth:`~torch.Tensor.clone` — Returns a deep copy of the passed-in tensor
* :meth:`~torch.Tensor.dequantize` — Convert quantized tensor to float tensor
* :meth:`~torch.Tensor.equal` — Compares two tensors; returns true if quantization parameters and all integer elements are the same
* :meth:`~torch.Tensor.int_repr` — Returns the underlying integer representation of the quantized tensor
* :meth:`~torch.Tensor.max` — Returns the maximum value of the tensor (reduction only)
* :meth:`~torch.Tensor.mean` — Mean function. Supported variants: reduction, dim, out
* :meth:`~torch.Tensor.min` — Returns the minimum value of the tensor (reduction only)
* :meth:`~torch.Tensor.q_scale` — Returns the scale of the per-tensor quantized tensor
* :meth:`~torch.Tensor.q_zero_point` — Returns the zero_point of the per-tensor quantized tensor
* :meth:`~torch.Tensor.q_per_channel_scales` — Returns the scales of the per-channel quantized tensor
* :meth:`~torch.Tensor.q_per_channel_zero_points` — Returns the zero points of the per-channel quantized tensor
* :meth:`~torch.Tensor.q_per_channel_axis` — Returns the channel axis of the per-channel quantized tensor
* :meth:`~torch.Tensor.relu` — Rectified linear unit (copy)
* :meth:`~torch.Tensor.relu_` — Rectified linear unit (in-place)
* :meth:`~torch.Tensor.resize_` — In-place resize
* :meth:`~torch.Tensor.sort` — Sorts the tensor
* :meth:`~torch.Tensor.topk` — Returns the k largest values of a tensor


``torch.nn.intrinsic``
~~~~~~~~~~~~~~~~~~~~~~

Fused modules are provided for common patterns in CNNs. Combining several operations together (like convolution and relu) allows for better quantization accuracy.

* ``torch.nn.intrinsic`` — float versions of the modules, can be swapped with the quantized versions 1 to 1:

  * :class:`~torch.nn.intrinsic.ConvBn2d` — Conv2d + BatchNorm
  * :class:`~torch.nn.intrinsic.ConvBnReLU2d` — Conv2d + BatchNorm + ReLU
  * :class:`~torch.nn.intrinsic.ConvReLU2d` — Conv2d + ReLU
  * :class:`~torch.nn.intrinsic.LinearReLU` — Linear + ReLU

* ``torch.nn.intrinsic.qat`` — versions of layers for quantization-aware training:

  * :class:`~torch.nn.intrinsic.qat.ConvBn2d` — Conv2d + BatchNorm
  * :class:`~torch.nn.intrinsic.qat.ConvBnReLU2d` — Conv2d + BatchNorm + ReLU
  * :class:`~torch.nn.intrinsic.qat.ConvReLU2d` — Conv2d + ReLU
  * :class:`~torch.nn.intrinsic.qat.LinearReLU` — Linear + ReLU

* ``torch.nn.intrinsic.quantized`` — quantized versions of fused layers for inference (no BatchNorm variants, as BatchNorm is usually folded into convolution for inference):

  * :class:`~torch.nn.intrinsic.quantized.LinearReLU` — Linear + ReLU
  * :class:`~torch.nn.intrinsic.quantized.ConvReLU2d` — 2D Convolution + ReLU
|
||||
|
||||
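As a sketch of how these fused modules typically appear,
:func:`torch.quantization.fuse_modules` replaces a matched sequence of
submodules with the corresponding ``torch.nn.intrinsic`` module (the toy model
here is only for illustration)::

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU()).eval()
    # Fuse the Conv2d/ReLU pair; index "0" now holds a single
    # torch.nn.intrinsic.ConvReLU2d module.
    fused = torch.quantization.fuse_modules(model, [['0', '1']])
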
``torch.nn.qat``
~~~~~~~~~~~~~~~~

Layers for quantization-aware training:

* :class:`~torch.nn.qat.Linear` — Linear (fully-connected) layer
* :class:`~torch.nn.qat.Conv2d` — 2D convolution

``torch.quantization``
~~~~~~~~~~~~~~~~~~~~~~

* Functions for quantization

  * :func:`~torch.quantization.add_observer_` — Adds observers for the leaf modules (if a quantization configuration is provided)
  * :func:`~torch.quantization.add_quant_dequant` — Wraps the leaf child module using :class:`~torch.quantization.QuantWrapper`
  * :func:`~torch.quantization.convert` — Converts a float module with observers into its quantized counterpart; the module must have a quantization configuration
  * :func:`~torch.quantization.get_observer_dict` — Traverses the module children and collects all observers into a ``dict``
  * :func:`~torch.quantization.prepare` — Prepares a copy of a model for quantization
  * :func:`~torch.quantization.prepare_qat` — Prepares a copy of a model for quantization-aware training
  * :func:`~torch.quantization.propagate_qconfig_` — Propagates quantization configurations through the module hierarchy and assigns them to each leaf module
  * :func:`~torch.quantization.quantize` — Converts a float module to a quantized version
  * :func:`~torch.quantization.quantize_dynamic` — Converts a float module to a dynamically quantized version
  * :func:`~torch.quantization.quantize_qat` — Converts a float module to a quantized version used in quantization-aware training
  * :func:`~torch.quantization.swap_module` — Swaps the module with its quantized counterpart (if quantizable and if it has an observer)
  * :func:`~torch.quantization.default_eval_fn` — Default evaluation function used by :func:`torch.quantization.quantize`
  * :func:`~torch.quantization.fuse_modules` — Fuses a list of modules into a single module

* :class:`~torch.quantization.FakeQuantize` — Module for simulating the quantization/dequantization at training time
* Default observers. The rest of the observers are available from ``torch.quantization.observer``

  * :attr:`~torch.quantization.default_observer` — Same as ``MinMaxObserver.with_args(reduce_range=True)``
  * :attr:`~torch.quantization.default_weight_observer` — Same as ``MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)``

* :class:`~torch.quantization.Observer` — Abstract base class for observers
* Quantization configurations

  * :class:`~torch.quantization.QConfig` — Quantization configuration class
  * :attr:`~torch.quantization.default_qconfig` — Same as ``QConfig(activation=default_observer, weight=default_weight_observer)`` (see :class:`~torch.quantization.QConfig`)
  * :attr:`~torch.quantization.default_qat_qconfig` — Same as ``QConfig(activation=default_fake_quant, weight=default_weight_fake_quant)`` (see :class:`~torch.quantization.QConfig`)
  * :attr:`~torch.quantization.default_dynamic_qconfig` — Same as ``QConfigDynamic(weight=default_weight_observer)`` (see :class:`~torch.quantization.QConfigDynamic`)
  * :attr:`~torch.quantization.float16_dynamic_qconfig` — Same as ``QConfigDynamic(weight=NoopObserver.with_args(dtype=torch.float16))`` (see :class:`~torch.quantization.QConfigDynamic`)

* Stubs

  * :class:`~torch.quantization.DeQuantStub` — Placeholder module for the dequantize() operation in float-valued models
  * :class:`~torch.quantization.QuantStub` — Placeholder module for the quantize() operation in float-valued models
  * :class:`~torch.quantization.QuantWrapper` — Wraps the module to be quantized. Inserts the :class:`~torch.quantization.QuantStub` and :class:`~torch.quantization.DeQuantStub`

Observers for computing the quantization parameters:

* :class:`~torch.quantization.MinMaxObserver` — Derives the quantization parameters from the running minimum and maximum of the observed tensor inputs (per tensor variant)
* :class:`~torch.quantization.MovingAverageObserver` — Derives the quantization parameters from the running averages of the minimums and maximums of the observed tensor inputs (per tensor variant)
* :class:`~torch.quantization.PerChannelMinMaxObserver` — Derives the quantization parameters from the running minimum and maximum of the observed tensor inputs (per channel variant)
* :class:`~torch.quantization.MovingAveragePerChannelMinMaxObserver` — Derives the quantization parameters from the running averages of the minimums and maximums of the observed tensor inputs (per channel variant)
* :class:`~torch.quantization.HistogramObserver` — Derives the quantization parameters by creating a histogram of running minimums and maximums
* Observers that do not compute the quantization parameters:

  * :class:`~torch.quantization.RecordingObserver` — Records all incoming tensors. Used for debugging only
  * :class:`~torch.quantization.NoopObserver` — Pass-through observer. Used in situations where there are no quantization parameters (i.e. quantization to ``float16``)

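As a minimal sketch of how observers work, a :class:`~torch.quantization.MinMaxObserver`
tracks the running minimum and maximum of the tensors it sees and derives the
scale and zero point from them (random data here, for illustration only)::

    import torch
    from torch.quantization import MinMaxObserver

    obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
    obs(torch.randn(4, 4))  # observing a tensor updates the running min/max
    obs(torch.randn(4, 4))
    scale, zero_point = obs.calculate_qparams()
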
``torch.nn.quantized``
~~~~~~~~~~~~~~~~~~~~~~

Quantized versions of standard NN layers:

* :class:`~torch.nn.quantized.Quantize` — Quantization layer, used to automatically replace :class:`~torch.quantization.QuantStub`
* :class:`~torch.nn.quantized.DeQuantize` — Dequantization layer, used to replace :class:`~torch.quantization.DeQuantStub`
* :class:`~torch.nn.quantized.FloatFunctional` — Wrapper class to make stateless float operations stateful so that they can be replaced with quantized versions
* :class:`~torch.nn.quantized.QFunctional` — Wrapper class for quantized versions of stateless operations like ``torch.add``
* :class:`~torch.nn.quantized.Conv2d` — 2D convolution
* :class:`~torch.nn.quantized.Linear` — Linear (fully-connected) layer
* :class:`~torch.nn.MaxPool2d` — 2D max pooling
* :class:`~torch.nn.quantized.ReLU` — Rectified linear unit
* :class:`~torch.nn.quantized.ReLU6` — Rectified linear unit with cut-off at the quantized representation of 6

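For example, :class:`~torch.nn.quantized.FloatFunctional` turns a stateless
operation such as addition into a submodule that observers can attach to, so
that :func:`~torch.quantization.convert` can later swap it for
:class:`~torch.nn.quantized.QFunctional` (a minimal sketch; the module name
``AddBlock`` is hypothetical)::

    import torch

    class AddBlock(torch.nn.Module):
        def __init__(self):
            super(AddBlock, self).__init__()
            self.add = torch.nn.quantized.FloatFunctional()

        def forward(self, x, y):
            # Use self.add.add(x, y) instead of x + y so that the output
            # quantization parameters can be observed and computed.
            return self.add.add(x, y)
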
``torch.nn.quantized.dynamic``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Layers used in dynamically quantized models (i.e. quantized only on weights):

* :class:`~torch.nn.quantized.dynamic.Linear` — Linear (fully-connected) layer
* :class:`~torch.nn.quantized.dynamic.LSTM` — Long Short-Term Memory RNN module

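A minimal sketch of producing these layers (the one-layer model is a stand-in;
by default, dynamic quantization is applied to ``nn.Linear`` and ``nn.LSTM``
submodules)::

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(16, 8))
    # nn.Linear is swapped for torch.nn.quantized.dynamic.Linear: weights
    # are quantized ahead of time, activations on the fly at inference.
    qmodel = torch.quantization.quantize_dynamic(model)
    out = qmodel(torch.randn(2, 16))
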
``torch.nn.quantized.functional``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Functional versions of quantized NN layers (many of them accept explicit quantization output parameters):

* :func:`~torch.nn.quantized.functional.adaptive_avg_pool2d` — 2D adaptive average pooling
* :func:`~torch.nn.quantized.functional.avg_pool2d` — 2D average pooling
* :func:`~torch.nn.quantized.functional.conv2d` — 2D convolution
* :func:`~torch.nn.quantized.functional.interpolate` — Down-/up-sampler
* :func:`~torch.nn.quantized.functional.linear` — Linear (fully-connected) op
* :func:`~torch.nn.quantized.functional.max_pool2d` — 2D max pooling
* :func:`~torch.nn.quantized.functional.relu` — Rectified linear unit
* :func:`~torch.nn.quantized.functional.upsample` — Upsampler. Will be deprecated in favor of :func:`~torch.nn.quantized.functional.interpolate`
* :func:`~torch.nn.quantized.functional.upsample_bilinear` — Bilinear upsampler. Will be deprecated in favor of :func:`~torch.nn.quantized.functional.interpolate`
* :func:`~torch.nn.quantized.functional.upsample_nearest` — Nearest-neighbor upsampler. Will be deprecated in favor of :func:`~torch.nn.quantized.functional.interpolate`

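For illustration, a sketch of calling the functional variants directly on a
quantized tensor (quantization parameters chosen arbitrarily)::

    import torch
    import torch.nn.quantized.functional as qF

    x = torch.randn(1, 3, 8, 8)
    xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)
    yq = qF.relu(xq)                       # output stays quantized
    zq = qF.max_pool2d(yq, kernel_size=2)  # pooling on the quantized tensor
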
Quantized dtypes and quantization schemes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

* :attr:`torch.qscheme` — Type to describe the quantization scheme of a tensor. Supported types:

  * :attr:`torch.per_tensor_affine` — per tensor, asymmetric
  * :attr:`torch.per_channel_affine` — per channel, asymmetric
  * :attr:`torch.per_tensor_symmetric` — per tensor, symmetric
  * :attr:`torch.per_channel_symmetric` — per channel, symmetric

* ``torch.dtype`` — Type to describe the data. Supported types:

  * :attr:`torch.quint8` — 8-bit unsigned integer
  * :attr:`torch.qint8` — 8-bit signed integer
  * :attr:`torch.qint32` — 32-bit signed integer

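For illustration, a sketch of a per-channel quantized tensor and the scheme
attributes it exposes (all values chosen arbitrarily)::

    import torch

    x = torch.randn(2, 3)
    scales = torch.tensor([0.1, 0.05], dtype=torch.double)
    zero_points = torch.tensor([0, 10])
    # One (scale, zero_point) pair per slice along axis 0.
    xq = torch.quantize_per_channel(x, scales, zero_points, axis=0, dtype=torch.quint8)
    xq.qscheme()             # torch.per_channel_affine
    xq.q_per_channel_axis()  # 0
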
Quantization Workflows
----------------------

PyTorch provides three approaches to quantize models (a condensed code sketch
of all three follows this list).

1. Post Training Dynamic Quantization: This is the simplest form of
   quantization to apply, where the weights are quantized ahead of time but the
   activations are dynamically quantized during inference. This is used
   for situations where the model execution time is dominated by loading
   weights from memory rather than computing the matrix multiplications.
   This is true for LSTM and Transformer type models with small
   batch size. Applying dynamic quantization to a whole model can be
   done with a single call to :func:`torch.quantization.quantize_dynamic()`.
   See the `quantization tutorials <https://pytorch.org/tutorials/#quantization-experimental>`_
2. Post Training Static Quantization: This is the most commonly used form of
   quantization, where the weights are quantized ahead of time and the
   scale factor and bias for the activation tensors are pre-computed
   based on observing the behavior of the model during a calibration
   process. Post training static quantization is typically used when both memory bandwidth and compute
   savings are important, with CNNs being a typical use case.
   The general process for doing post training quantization is:

   1. Prepare the model:

      a. Specify where the activations are quantized and dequantized explicitly by adding QuantStub and DeQuantStub modules.
      b. Ensure that modules are not reused.
      c. Convert any operations that require requantization into modules.

   2. Fuse operations like conv + relu or conv + batchnorm + relu together to improve both model accuracy and performance.
   3. Specify the configuration of the quantization methods, such as
      selecting symmetric or asymmetric quantization and MinMax or
      L2Norm calibration techniques.
   4. Use :func:`torch.quantization.prepare` to insert modules
      that will observe activation tensors during calibration.
   5. Calibrate the model by running inference against a calibration
      dataset.
   6. Finally, convert the model itself with the
      torch.quantization.convert() method. This does several things: it
      quantizes the weights, computes and stores the scale and bias
      values to be used with each activation tensor, and replaces key
      operators with quantized implementations.

   See the `quantization tutorials <https://pytorch.org/tutorials/#quantization_experimental>`_

3. Quantization Aware Training: In the rare cases where post training
   quantization does not provide adequate accuracy, training can be done
   with simulated quantization using :class:`torch.quantization.FakeQuantize`. Computations
   will take place in FP32 but with values clamped and rounded to
   simulate the effects of INT8 quantization. The sequence of steps is
   very similar:

   1. Steps (1) and (2) are identical.
   3. Specify the configuration of the fake quantization methods, such as
      selecting symmetric or asymmetric quantization and MinMax or Moving Average
      or L2Norm calibration techniques.
   4. Use :func:`torch.quantization.prepare_qat` to insert modules
      that will simulate quantization during training.
   5. Train or fine-tune the model.
   6. Identical to step (6) for post training quantization.

   See the `quantization tutorials <https://pytorch.org/tutorials/#quantization_experimental>`_

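Putting the three approaches side by side, a condensed sketch (``float_model``
and ``calibration_data`` are placeholders, and a real model also needs the
preparation steps described below)::

    import torch

    # 1. Post-training dynamic quantization: a single call, weights only.
    qmodel = torch.quantization.quantize_dynamic(float_model)

    # 2. Post-training static quantization: calibrate, then convert.
    float_model.qconfig = torch.quantization.default_qconfig
    prepared = torch.quantization.prepare(float_model)
    for batch in calibration_data:  # run representative data through
        prepared(batch)             # the inserted observers
    qmodel = torch.quantization.convert(prepared)

    # 3. Quantization-aware training: fake-quantize while training.
    float_model.qconfig = torch.quantization.default_qat_qconfig
    prepared = torch.quantization.prepare_qat(float_model)
    # ... train or fine-tune `prepared` as usual ...
    qmodel = torch.quantization.convert(prepared.eval())
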
While default implementations of observers to select the scale factor and bias
based on observed tensor data are provided, developers can provide their own
quantization functions. Quantization can be applied selectively to different
parts of the model or configured differently for different parts of the model.

We also provide support for per-channel quantization for **conv2d()**
and **linear()**.

Quantization workflows work by adding (e.g. adding observers as
``.observer`` submodules) or replacing (e.g. converting ``nn.Conv2d`` to
``nn.quantized.Conv2d``) submodules in the model's module hierarchy. This
means that the model stays a regular ``nn.Module``-based instance throughout the
process and thus can work with the rest of the PyTorch APIs.

Model Preparation for Quantization
----------------------------------

It is currently necessary to make some modifications to the model definition
prior to quantization. This is because quantization currently works on a
module-by-module basis. Specifically, for all quantization techniques, the user needs to:

1. Convert any operations that require output requantization (and thus have additional parameters) from functionals to module form.
2. Specify which parts of the model need to be quantized, either by assigning ``.qconfig`` attributes on submodules or by specifying a ``qconfig_dict``.

For static quantization techniques which quantize activations, the user needs to do the following in addition:

1. Specify where activations are quantized and de-quantized. This is done using :class:`~torch.quantization.QuantStub` and :class:`~torch.quantization.DeQuantStub` modules.
2. Use :class:`torch.nn.quantized.FloatFunctional` to wrap tensor operations that require special handling for quantization into modules. Examples
   are operations like ``add`` and ``cat`` which require special handling to determine output quantization parameters.
3. Fuse modules: combine operations/modules into a single module to obtain higher accuracy and performance. This is done using the
   :func:`torch.quantization.fuse_modules` API, which takes in lists of modules to be fused. We currently support the following fusions:
   [Conv, Relu], [Conv, BatchNorm], [Conv, BatchNorm, Relu], [Linear, Relu]

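For example, a minimal sketch of a model modified for static quantization (the
architecture itself is just a stand-in)::

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            # Mark where activations enter and leave the quantized domain.
            self.quant = torch.quantization.QuantStub()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.relu = torch.nn.ReLU()
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.relu(self.conv(x))
            return self.dequant(x)

    model = M().eval()
    # Fuse Conv2d + ReLU into a single (intrinsic) module.
    model = torch.quantization.fuse_modules(model, [['conv', 'relu']])
    model.qconfig = torch.quantization.default_qconfig
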
torch.quantization
---------------------------
.. automodule:: torch.quantization

This module implements the functions you call
directly to convert your model from FP32 to quantized form. For
example, :func:`~torch.quantization.prepare` is used in post training
quantization to prepare your model for the calibration step, and
:func:`~torch.quantization.convert` actually converts the weights to int8 and
replaces the operations with their quantized counterparts. There are
other helper functions for things like quantizing the input to your
model and performing critical fusions like conv + relu.

Top-level quantization APIs
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: quantize
.. autofunction:: quantize_dynamic
.. autofunction:: quantize_qat
.. autofunction:: prepare
.. autofunction:: prepare_qat
.. autofunction:: convert
.. autoclass:: QConfig
.. autoclass:: QConfigDynamic
.. autodata:: default_qconfig

Preparing model for quantization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: fuse_modules
.. autoclass:: QuantStub
.. autoclass:: DeQuantStub
.. autoclass:: QuantWrapper
.. autofunction:: add_quant_dequant

Utility functions
~~~~~~~~~~~~~~~~~
.. autofunction:: add_observer_
.. autofunction:: swap_module
.. autofunction:: propagate_qconfig_
.. autofunction:: default_eval_fn

Observers
~~~~~~~~~~~~~~~
.. autoclass:: Observer
    :members:
.. autoclass:: MinMaxObserver
.. autoclass:: MovingAverageObserver
.. autoclass:: PerChannelMinMaxObserver
.. autoclass:: MovingAveragePerChannelMinMaxObserver
.. autoclass:: HistogramObserver
.. autoclass:: FakeQuantize
.. autoclass:: NoopObserver

Debugging utilities
~~~~~~~~~~~~~~~~~~~
.. autofunction:: get_observer_dict
.. autoclass:: RecordingObserver

torch.nn.intrinsic
--------------------------------

This module implements the combined (fused) modules conv + relu which can
then be quantized.

.. automodule:: torch.nn.intrinsic

ConvBn2d
~~~~~~~~~~~~~~~
.. autoclass:: ConvBn2d
    :members:

ConvBnReLU2d
~~~~~~~~~~~~~~~
.. autoclass:: ConvBnReLU2d
    :members:

ConvReLU2d
~~~~~~~~~~~~~~~
.. autoclass:: ConvReLU2d
    :members:

LinearReLU
~~~~~~~~~~~~~~~
.. autoclass:: LinearReLU
    :members:

torch.nn.intrinsic.qat
--------------------------------

This module implements the versions of those fused operations needed for quantization-aware training.

.. automodule:: torch.nn.intrinsic.qat

ConvBn2d
~~~~~~~~~~~~~~~
.. autoclass:: ConvBn2d
    :members:

ConvBnReLU2d
~~~~~~~~~~~~~~~
.. autoclass:: ConvBnReLU2d
    :members:

ConvReLU2d
~~~~~~~~~~~~~~~
.. autoclass:: ConvReLU2d
    :members:

LinearReLU
~~~~~~~~~~~~~~~
.. autoclass:: LinearReLU
    :members:

torch.nn.intrinsic.quantized
--------------------------------------

This module implements quantized implementations of fused operations like conv + relu.

.. automodule:: torch.nn.intrinsic.quantized

ConvReLU2d
~~~~~~~~~~~~~~~
.. autoclass:: ConvReLU2d
    :members:

LinearReLU
~~~~~~~~~~~~~~~
.. autoclass:: LinearReLU
    :members:

torch.nn.qat
---------------------------

This module implements versions of the key nn modules **Conv2d()** and **Linear()** which
run in FP32 but with rounding applied to simulate the effect of INT8 quantization.

.. automodule:: torch.nn.qat

Conv2d
~~~~~~~~~~~~~~~
.. autoclass:: Conv2d
    :members:

Linear
~~~~~~~~~~~~~~~
.. autoclass:: Linear
    :members:

torch.nn.quantized
----------------------------

This module implements the quantized versions of the nn layers such as **Conv2d** and **ReLU**.

Functional interface
~~~~~~~~~~~~~~~~~~~~
.. automodule:: torch.nn.quantized.functional

.. autofunction:: relu
.. autofunction:: linear
.. autofunction:: conv2d
.. autofunction:: max_pool2d

.. automodule:: torch.nn.quantized

ReLU
~~~~~~~~~~~~~~~
.. autoclass:: ReLU
    :members:

ReLU6
~~~~~~~~~~~~~~~
.. autoclass:: ReLU6
    :members:

Conv2d
~~~~~~~~~~~~~~~
.. autoclass:: Conv2d
    :members:

FloatFunctional
~~~~~~~~~~~~~~~
.. autoclass:: FloatFunctional
    :members:

QFunctional
~~~~~~~~~~~~~~~
.. autoclass:: QFunctional
    :members:

Quantize
~~~~~~~~~~~~~~~
.. autoclass:: Quantize
    :members:

DeQuantize
~~~~~~~~~~~~~~~
.. autoclass:: DeQuantize
    :members:

Linear
~~~~~~~~~~~~~~~
.. autoclass:: Linear
    :members:

torch.nn.quantized.dynamic
----------------------------

.. automodule:: torch.nn.quantized.dynamic

Linear
~~~~~~~~~~~~~~~
.. autoclass:: Linear
    :members:

LSTM
~~~~~~~~~~~~~~~
.. autoclass:: LSTM
    :members: