mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-28 18:54:57 +08:00 
			
		
		
		
	Compare commits
	
		
			98 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| ee77ccbb6d | |||
| f7f538566e | |||
| 432724b3e2 | |||
| cc98c93bf3 | |||
| eb8e8c1bcf | |||
| f0a4ac3ee0 | |||
| bd9766a36a | |||
| 80cca51adc | |||
| 9d45ee1d81 | |||
| de394b672d | |||
| 92c6401bb9 | |||
| b4f32dd292 | |||
| 3e451b4796 | |||
| 036a591556 | |||
| 86d9ee8dee | |||
| aa44ffb4c9 | |||
| 162b054e39 | |||
| 7f044f7398 | |||
| 0c81d6ba4b | |||
| d1752f2bf8 | |||
| 49fbeb8cc8 | |||
| f0d3fc70b4 | |||
| fb489555a9 | |||
| b5144f1068 | |||
| a5c08a6abd | |||
| 6cc759269f | |||
| 6742476ba3 | |||
| 067aee5f30 | |||
| 0a7f7e6d30 | |||
| e9fc91cbca | |||
| 23df957e94 | |||
| a7b161c08b | |||
| 6bae48c127 | |||
| c248943743 | |||
| e058a37fe4 | |||
| aa7112a618 | |||
| b728ffabc3 | |||
| d67898a93b | |||
| 9a25673478 | |||
| 17613ad73c | |||
| beaae6a2b6 | |||
| 328f49968c | |||
| 7f9096f868 | |||
| 7e94ee235f | |||
| 318fb8e8b9 | |||
| 43bb1b2356 | |||
| 87fbd27cc0 | |||
| 225c38b719 | |||
| 8074526e7f | |||
| 06a866de94 | |||
| 9b22a55499 | |||
| f8d3eac4c3 | |||
| 68b4d22da7 | |||
| 66b73b0950 | |||
| 32eb3b8d7b | |||
| 5a2a34cd2d | |||
| a8083e18e8 | |||
| bc3fb36ed7 | |||
| e0822f1089 | |||
| cbfd4e05e9 | |||
| b9a2c8ac5c | |||
| 6d7a73c0da | |||
| 15e4827617 | |||
| 024fa34700 | |||
| 95d2c7fc98 | |||
| 7ba2baee00 | |||
| ccf3a6de3d | |||
| 1ba6fc4ca6 | |||
| d4d4bf5686 | |||
| cee965fae9 | |||
| 544c16cdbf | |||
| 494a5563b4 | |||
| ba4c3a1c2c | |||
| c556f9052f | |||
| 5c80dd3c1f | |||
| 2d8ee11139 | |||
| ebc2519bec | |||
| 8c9e4b250d | |||
| 6a6f047fc6 | |||
| deadc27c23 | |||
| 6276fda119 | |||
| 65ee8f2c23 | |||
| 6126cfab2c | |||
| e2f6fed611 | |||
| 667deb92f7 | |||
| 0e88de5580 | |||
| 3c8ce2a57e | |||
| f2080fb3f2 | |||
| 84afb7b0c1 | |||
| b6e976ae2d | |||
| 8626a1cc81 | |||
| 831566ec90 | |||
| f7b3b20457 | |||
| 8ce38cf27d | |||
| a94f9c7246 | |||
| 2fc3bb8571 | |||
| f694d4d872 | |||
| d7b6d945eb | 
| @ -17,7 +17,7 @@ CONFIG_TREE_DATA = [ | ||||
|             X("py2"), | ||||
|             XImportant("cmake"), | ||||
|         ]), | ||||
|         (Ver("cuda", "9.1"), [XImportant("py2")]), | ||||
|         (Ver("cuda", "10.1"), [XImportant("py3.5")]),  # TensorRT 6 build | ||||
|         (Ver("mkl"), [XImportant("py2")]), | ||||
|         (Ver("gcc", "5"), [XImportant("onnx_py2")]), | ||||
|         (Ver("clang", "3.8"), [X("py2")]), | ||||
|  | ||||
| @ -14,7 +14,7 @@ from dataclasses import dataclass | ||||
|  | ||||
| DOCKER_IMAGE_PATH_BASE = "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/" | ||||
|  | ||||
| DOCKER_IMAGE_VERSION = 315 | ||||
| DOCKER_IMAGE_VERSION = 324 | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
|  | ||||
| @ -8,7 +8,7 @@ CONFIG_TREE_DATA = [ | ||||
|         (None, [ | ||||
|             XImportant("2.7.9"), | ||||
|             X("2.7"), | ||||
|             X("3.5"), | ||||
|             XImportant("3.5"),  # Not run on all PRs, but should be included on [test all] | ||||
|             X("nightly"), | ||||
|         ]), | ||||
|         ("gcc", [ | ||||
|  | ||||
| @ -304,20 +304,20 @@ jobs: | ||||
|           set -e | ||||
|           # Pull Docker image and run build | ||||
|           echo "DOCKER_IMAGE: "${DOCKER_IMAGE} | ||||
|           docker pull ${DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${DOCKER_IMAGE} >/dev/null | ||||
|           export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE}) | ||||
|  | ||||
|           # TODO We may want to move the rebase logic to a separate step after checkout | ||||
|           # Rebase to master only if in xenial_py3_6_gcc5_4 case | ||||
|           if [[ "${CIRCLE_BRANCH}" != "master" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then | ||||
|             echo "Merge master branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT" | ||||
|           # Rebase to v1.3.0 only if in xenial_py3_6_gcc5_4 case | ||||
|           if [[ "${CIRCLE_BRANCH}" != "v1.3.0" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then | ||||
|             echo "Merge v1.3.0 branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT" | ||||
|             set -x | ||||
|             git config --global user.email "circleci.ossci@gmail.com" | ||||
|             git config --global user.name "CircleCI" | ||||
|             git config remote.origin.url https://github.com/pytorch/pytorch.git | ||||
|             git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master | ||||
|             git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=50 --quiet | ||||
|             export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/master` | ||||
|             git config --add remote.origin.fetch +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0 | ||||
|             git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0 --depth=50 --quiet | ||||
|             export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/v1.3.0` | ||||
|             echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET} | ||||
|             export GIT_COMMIT=${CIRCLE_SHA1} | ||||
|             echo "GIT_COMMIT: " ${GIT_COMMIT} | ||||
| @ -326,7 +326,7 @@ jobs: | ||||
|             git merge --no-edit --no-ff ${GIT_MERGE_TARGET} | ||||
|             set +x | ||||
|           else | ||||
|             echo "Do NOT merge master branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT" | ||||
|             echo "Do NOT merge v1.3.0 branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT" | ||||
|           fi | ||||
|  | ||||
|           git submodule sync && git submodule update -q --init --recursive | ||||
| @ -363,7 +363,7 @@ jobs: | ||||
|               export COMMIT_DOCKER_IMAGE=$output_image | ||||
|             fi | ||||
|             docker commit "$id" ${COMMIT_DOCKER_IMAGE} | ||||
|             docker push ${COMMIT_DOCKER_IMAGE} | ||||
|             time docker push ${COMMIT_DOCKER_IMAGE} | ||||
|           fi | ||||
|  | ||||
|   pytorch_linux_test: | ||||
| @ -391,7 +391,7 @@ jobs: | ||||
|             export COMMIT_DOCKER_IMAGE=$output_image | ||||
|           fi | ||||
|           echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} | ||||
|           docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           if [ -n "${USE_CUDA_DOCKER_RUNTIME}" ]; then | ||||
|             export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) | ||||
|           else | ||||
| @ -445,7 +445,7 @@ jobs: | ||||
|           chmod +x /home/circleci/project/ci_build_script.sh | ||||
|  | ||||
|           echo "DOCKER_IMAGE: "${DOCKER_IMAGE} | ||||
|           docker pull ${DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${DOCKER_IMAGE} >/dev/null | ||||
|           export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE}) | ||||
|           docker cp /home/circleci/project/. $id:/var/lib/jenkins/workspace | ||||
|  | ||||
| @ -460,7 +460,7 @@ jobs: | ||||
|               export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1} | ||||
|             fi | ||||
|             docker commit "$id" ${COMMIT_DOCKER_IMAGE} | ||||
|             docker push ${COMMIT_DOCKER_IMAGE} | ||||
|             time docker push ${COMMIT_DOCKER_IMAGE} | ||||
|           fi | ||||
|  | ||||
|   caffe2_linux_test: | ||||
| @ -515,7 +515,7 @@ jobs: | ||||
|             export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1} | ||||
|           fi | ||||
|           echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} | ||||
|           docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           if [ -n "${USE_CUDA_DOCKER_RUNTIME}" ]; then | ||||
|             export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) | ||||
|           else | ||||
| @ -551,12 +551,7 @@ jobs: | ||||
|             # Reinitialize path (see man page for path_helper(8)) | ||||
|             eval `/usr/libexec/path_helper -s` | ||||
|  | ||||
|             # Use Homebrew Python if configured to do so | ||||
|             if [ "${PYTHON_INSTALLATION}" == "homebrew" ]; then | ||||
|               export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH | ||||
|             fi | ||||
|  | ||||
|             pip -q install numpy | ||||
|             export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH | ||||
|  | ||||
|             # Install Anaconda if we need to | ||||
|             if [ -n "${CAFFE2_USE_ANACONDA}" ]; then | ||||
| @ -569,6 +564,8 @@ jobs: | ||||
|               source ${TMPDIR}/anaconda/bin/activate | ||||
|             fi | ||||
|  | ||||
|             pip -q install numpy | ||||
|  | ||||
|             # Install sccache | ||||
|             sudo curl https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache | ||||
|             sudo chmod +x /usr/local/bin/sccache | ||||
| @ -606,7 +603,6 @@ jobs: | ||||
|             if which sccache > /dev/null; then | ||||
|               sccache --show-stats | ||||
|             fi | ||||
|  | ||||
|   binary_linux_build: | ||||
|     <<: *binary_linux_build_params | ||||
|     steps: | ||||
| @ -923,7 +919,7 @@ jobs: | ||||
|           set -e | ||||
|           export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1} | ||||
|           echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} | ||||
|           docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) | ||||
|  | ||||
|           docker cp $id:/var/lib/jenkins/workspace/env /home/circleci/project/env | ||||
| @ -957,7 +953,7 @@ jobs: | ||||
|           set -ex | ||||
|           export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1} | ||||
|           echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} | ||||
|           docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           export id=$(docker run -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) | ||||
|  | ||||
|           # master branch docs push | ||||
| @ -965,11 +961,11 @@ jobs: | ||||
|             export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/master master site") | docker exec -u jenkins -i "$id" bash) 2>&1' | ||||
|  | ||||
|           # stable release docs push. Due to some circleci limitations, we keep | ||||
|           # an eternal PR open for merging v1.2.0 -> master for this job. | ||||
|           # XXX: The following code is only run on the v1.2.0 branch, which might | ||||
|           # an eternal PR open for merging v1.3.0 -> master for this job. | ||||
|           # XXX: The following code is only run on the v1.3.0 branch, which might | ||||
|           # not be exactly the same as what you see here. | ||||
|           elif [[ "${CIRCLE_BRANCH}" == "v1.2.0" ]]; then | ||||
|             export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.2.0 site dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1' | ||||
|           elif [[ "${CIRCLE_BRANCH}" == "v1.3.0" ]]; then | ||||
|             export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.3.0 site-v1.3.0 dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1' | ||||
|  | ||||
|           # For open PRs: Do a dry_run of the docs build, don't push build | ||||
|           else | ||||
| @ -981,7 +977,7 @@ jobs: | ||||
|           # Save the docs build so we can debug any problems | ||||
|           export DEBUG_COMMIT_DOCKER_IMAGE=${COMMIT_DOCKER_IMAGE}-debug | ||||
|           docker commit "$id" ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|           docker push ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|           time docker push ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|  | ||||
|   pytorch_cpp_doc_push: | ||||
|     environment: | ||||
| @ -1002,7 +998,7 @@ jobs: | ||||
|           set -ex | ||||
|           export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1} | ||||
|           echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} | ||||
|           docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           export id=$(docker run -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) | ||||
|  | ||||
|           # master branch docs push | ||||
| @ -1026,7 +1022,7 @@ jobs: | ||||
|           # Save the docs build so we can debug any problems | ||||
|           export DEBUG_COMMIT_DOCKER_IMAGE=${COMMIT_DOCKER_IMAGE}-debug | ||||
|           docker commit "$id" ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|           docker push ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|           time docker push ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|  | ||||
|   pytorch_macos_10_13_py3_build: | ||||
|     environment: | ||||
| @ -1171,14 +1167,14 @@ jobs: | ||||
|           echo "docker_image_libtorch_android_arm_v8a: "${docker_image_libtorch_android_arm_v8a} | ||||
|  | ||||
|           # x86_32 | ||||
|           docker pull ${docker_image_libtorch_android_x86_32} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_x86_32} >/dev/null | ||||
|           export id_x86_32=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_x86_32" bash) 2>&1' | ||||
|           echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts | ||||
|  | ||||
|           # arm-v7a | ||||
|           docker pull ${docker_image_libtorch_android_arm_v7a} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_arm_v7a} >/dev/null | ||||
|           export id_arm_v7a=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_arm_v7a}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_arm_v7a" bash) 2>&1' | ||||
| @ -1188,7 +1184,7 @@ jobs: | ||||
|           docker cp $id_arm_v7a:/var/lib/jenkins/workspace/build_android/install ~/workspace/build_android_install_arm_v7a | ||||
|  | ||||
|           # x86_64 | ||||
|           docker pull ${docker_image_libtorch_android_x86_64} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_x86_64} >/dev/null | ||||
|           export id_x86_64=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_64}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_x86_64" bash) 2>&1' | ||||
| @ -1198,7 +1194,7 @@ jobs: | ||||
|           docker cp $id_x86_64:/var/lib/jenkins/workspace/build_android/install ~/workspace/build_android_install_x86_64 | ||||
|  | ||||
|           # arm-v8a | ||||
|           docker pull ${docker_image_libtorch_android_arm_v8a} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_arm_v8a} >/dev/null | ||||
|           export id_arm_v8a=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_arm_v8a}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_arm_v8a" bash) 2>&1' | ||||
| @ -1220,7 +1216,7 @@ jobs: | ||||
|  | ||||
|           output_image=$docker_image_libtorch_android_x86_32-gradle | ||||
|           docker commit "$id_x86_32" ${output_image} | ||||
|           docker push ${output_image} | ||||
|           time docker push ${output_image} | ||||
|     - store_artifacts: | ||||
|         path: ~/workspace/build_android_artifacts/artifacts.tgz | ||||
|         destination: artifacts.tgz | ||||
| @ -1251,7 +1247,7 @@ jobs: | ||||
|           echo "docker_image_libtorch_android_x86_32_gradle: "${docker_image_libtorch_android_x86_32_gradle} | ||||
|  | ||||
|           # x86_32 | ||||
|           docker pull ${docker_image_libtorch_android_x86_32_gradle} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_x86_32_gradle} >/dev/null | ||||
|           export id_x86_32=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32_gradle}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace" && echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export SONATYPE_NEXUS_USERNAME=${SONATYPE_NEXUS_USERNAME}" && echo "export SONATYPE_NEXUS_PASSWORD=${SONATYPE_NEXUS_PASSWORD}" && echo "export ANDROID_SIGN_KEY=${ANDROID_SIGN_KEY}" && echo "export ANDROID_SIGN_PASS=${ANDROID_SIGN_PASS}" && echo "sudo chown -R jenkins workspace && cd workspace && ./.circleci/scripts/publish_android_snapshot.sh") | docker exec -u jenkins -i "$id_x86_32" bash) 2>&1' | ||||
| @ -1259,7 +1255,7 @@ jobs: | ||||
|  | ||||
|           output_image=${docker_image_libtorch_android_x86_32_gradle}-publish-snapshot | ||||
|           docker commit "$id_x86_32" ${output_image} | ||||
|           docker push ${output_image} | ||||
|           time docker push ${output_image} | ||||
|  | ||||
|   pytorch_android_gradle_build-x86_32: | ||||
|     environment: | ||||
| @ -1291,7 +1287,7 @@ jobs: | ||||
|           echo "docker_image_libtorch_android_x86_32: "${docker_image_libtorch_android_x86_32} | ||||
|  | ||||
|           # x86 | ||||
|           docker pull ${docker_image_libtorch_android_x86_32} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_x86_32} >/dev/null | ||||
|           export id=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "sudo chown -R jenkins workspace && cd workspace && ./.circleci/scripts/build_android_gradle.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' | ||||
| @ -1302,7 +1298,7 @@ jobs: | ||||
|  | ||||
|           output_image=${docker_image_libtorch_android_x86_32}-gradle | ||||
|           docker commit "$id" ${output_image} | ||||
|           docker push ${output_image} | ||||
|           time docker push ${output_image} | ||||
|     - store_artifacts: | ||||
|         path: ~/workspace/build_android_x86_32_artifacts/artifacts.tgz | ||||
|         destination: artifacts.tgz | ||||
| @ -1333,7 +1329,7 @@ jobs: | ||||
|             export PATH="~/anaconda/bin:${PATH}" | ||||
|             source ~/anaconda/bin/activate | ||||
|             # Install dependencies | ||||
|             conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests | ||||
|             conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes | ||||
|             # sync submodules | ||||
|             cd ${PROJ_ROOT} | ||||
|             git submodule sync | ||||
| @ -1518,11 +1514,6 @@ workflows: | ||||
|           name: pytorch_linux_xenial_py3_5_build | ||||
|           requires: | ||||
|             - setup | ||||
|           filters: | ||||
|             branches: | ||||
|               only: | ||||
|                 - master | ||||
|                 - /ci-all\/.*/ | ||||
|           build_environment: "pytorch-linux-xenial-py3.5-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:347" | ||||
|       - pytorch_linux_test: | ||||
| @ -1530,11 +1521,6 @@ workflows: | ||||
|           requires: | ||||
|             - setup | ||||
|             - pytorch_linux_xenial_py3_5_build | ||||
|           filters: | ||||
|             branches: | ||||
|               only: | ||||
|                 - master | ||||
|                 - /ci-all\/.*/ | ||||
|           build_environment: "pytorch-linux-xenial-py3.5-test" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:347" | ||||
|           resource_class: large | ||||
| @ -1944,7 +1930,7 @@ workflows: | ||||
|                 - master | ||||
|                 - /ci-all\/.*/ | ||||
|           build_environment: "caffe2-py2-gcc4.8-ubuntu14.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:324" | ||||
|       - caffe2_linux_test: | ||||
|           name: caffe2_py2_gcc4_8_ubuntu14_04_test | ||||
|           requires: | ||||
| @ -1956,7 +1942,7 @@ workflows: | ||||
|                 - master | ||||
|                 - /ci-all\/.*/ | ||||
|           build_environment: "caffe2-py2-gcc4.8-ubuntu14.04-test" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:324" | ||||
|           resource_class: large | ||||
|       - caffe2_linux_build: | ||||
|           name: caffe2_py2_cuda9_0_cudnn7_ubuntu16_04_build | ||||
| @ -1968,7 +1954,7 @@ workflows: | ||||
|                 - master | ||||
|                 - /ci-all\/.*/ | ||||
|           build_environment: "caffe2-py2-cuda9.0-cudnn7-ubuntu16.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324" | ||||
|       - caffe2_linux_test: | ||||
|           name: caffe2_py2_cuda9_0_cudnn7_ubuntu16_04_test | ||||
|           requires: | ||||
| @ -1981,14 +1967,14 @@ workflows: | ||||
|                 - /ci-all\/.*/ | ||||
|           build_environment: "caffe2-py2-cuda9.0-cudnn7-ubuntu16.04-test" | ||||
|           use_cuda_docker_runtime: "1" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324" | ||||
|           resource_class: gpu.medium | ||||
|       - caffe2_linux_build: | ||||
|           name: caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_build | ||||
|           requires: | ||||
|             - setup | ||||
|           build_environment: "caffe2-cmake-cuda9.0-cudnn7-ubuntu16.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324" | ||||
|       - caffe2_linux_test: | ||||
|           name: caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_test | ||||
|           requires: | ||||
| @ -1996,50 +1982,50 @@ workflows: | ||||
|             - caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_build | ||||
|           build_environment: "caffe2-cmake-cuda9.0-cudnn7-ubuntu16.04-test" | ||||
|           use_cuda_docker_runtime: "1" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:324" | ||||
|           resource_class: gpu.medium | ||||
|       - caffe2_linux_build: | ||||
|           name: caffe2_py2_cuda9_1_cudnn7_ubuntu16_04_build | ||||
|           name: caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_build | ||||
|           requires: | ||||
|             - setup | ||||
|           build_environment: "caffe2-py2-cuda9.1-cudnn7-ubuntu16.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.1-cudnn7-ubuntu16.04:315" | ||||
|           build_environment: "caffe2-py3.5-cuda10.1-cudnn7-ubuntu16.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:324" | ||||
|       - caffe2_linux_test: | ||||
|           name: caffe2_py2_cuda9_1_cudnn7_ubuntu16_04_test | ||||
|           name: caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_test | ||||
|           requires: | ||||
|             - setup | ||||
|             - caffe2_py2_cuda9_1_cudnn7_ubuntu16_04_build | ||||
|           build_environment: "caffe2-py2-cuda9.1-cudnn7-ubuntu16.04-test" | ||||
|             - caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_build | ||||
|           build_environment: "caffe2-py3.5-cuda10.1-cudnn7-ubuntu16.04-test" | ||||
|           use_cuda_docker_runtime: "1" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.1-cudnn7-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:324" | ||||
|           resource_class: gpu.medium | ||||
|       - caffe2_linux_build: | ||||
|           name: caffe2_py2_mkl_ubuntu16_04_build | ||||
|           requires: | ||||
|             - setup | ||||
|           build_environment: "caffe2-py2-mkl-ubuntu16.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:324" | ||||
|       - caffe2_linux_test: | ||||
|           name: caffe2_py2_mkl_ubuntu16_04_test | ||||
|           requires: | ||||
|             - setup | ||||
|             - caffe2_py2_mkl_ubuntu16_04_build | ||||
|           build_environment: "caffe2-py2-mkl-ubuntu16.04-test" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:324" | ||||
|           resource_class: large | ||||
|       - caffe2_linux_build: | ||||
|           name: caffe2_onnx_py2_gcc5_ubuntu16_04_build | ||||
|           requires: | ||||
|             - setup | ||||
|           build_environment: "caffe2-onnx-py2-gcc5-ubuntu16.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:324" | ||||
|       - caffe2_linux_test: | ||||
|           name: caffe2_onnx_py2_gcc5_ubuntu16_04_test | ||||
|           requires: | ||||
|             - setup | ||||
|             - caffe2_onnx_py2_gcc5_ubuntu16_04_build | ||||
|           build_environment: "caffe2-onnx-py2-gcc5-ubuntu16.04-test" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:324" | ||||
|           resource_class: large | ||||
|       - caffe2_linux_build: | ||||
|           name: caffe2_py2_clang3_8_ubuntu16_04_build | ||||
| @ -2051,7 +2037,7 @@ workflows: | ||||
|                 - master | ||||
|                 - /ci-all\/.*/ | ||||
|           build_environment: "caffe2-py2-clang3.8-ubuntu16.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.8-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.8-ubuntu16.04:324" | ||||
|           build_only: "1" | ||||
|       - caffe2_linux_build: | ||||
|           name: caffe2_py2_clang3_9_ubuntu16_04_build | ||||
| @ -2063,35 +2049,35 @@ workflows: | ||||
|                 - master | ||||
|                 - /ci-all\/.*/ | ||||
|           build_environment: "caffe2-py2-clang3.9-ubuntu16.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.9-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.9-ubuntu16.04:324" | ||||
|           build_only: "1" | ||||
|       - caffe2_linux_build: | ||||
|           name: caffe2_py2_clang7_ubuntu16_04_build | ||||
|           requires: | ||||
|             - setup | ||||
|           build_environment: "caffe2-py2-clang7-ubuntu16.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang7-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang7-ubuntu16.04:324" | ||||
|           build_only: "1" | ||||
|       - caffe2_linux_build: | ||||
|           name: caffe2_onnx_py3_6_clang7_ubuntu16_04_build | ||||
|           requires: | ||||
|             - setup | ||||
|           build_environment: "caffe2-onnx-py3.6-clang7-ubuntu16.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:324" | ||||
|       - caffe2_linux_test: | ||||
|           name: caffe2_onnx_py3_6_clang7_ubuntu16_04_test | ||||
|           requires: | ||||
|             - setup | ||||
|             - caffe2_onnx_py3_6_clang7_ubuntu16_04_build | ||||
|           build_environment: "caffe2-onnx-py3.6-clang7-ubuntu16.04-test" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:324" | ||||
|           resource_class: large | ||||
|       - caffe2_linux_build: | ||||
|           name: caffe2_py2_android_ubuntu16_04_build | ||||
|           requires: | ||||
|             - setup | ||||
|           build_environment: "caffe2-py2-android-ubuntu16.04-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-android-ubuntu16.04:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-android-ubuntu16.04:324" | ||||
|           build_only: "1" | ||||
|       - caffe2_linux_build: | ||||
|           name: caffe2_py2_cuda9_0_cudnn7_centos7_build | ||||
| @ -2103,7 +2089,7 @@ workflows: | ||||
|                 - master | ||||
|                 - /ci-all\/.*/ | ||||
|           build_environment: "caffe2-py2-cuda9.0-cudnn7-centos7-build" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:324" | ||||
|       - caffe2_linux_test: | ||||
|           name: caffe2_py2_cuda9_0_cudnn7_centos7_test | ||||
|           requires: | ||||
| @ -2116,7 +2102,7 @@ workflows: | ||||
|                 - /ci-all\/.*/ | ||||
|           build_environment: "caffe2-py2-cuda9.0-cudnn7-centos7-test" | ||||
|           use_cuda_docker_runtime: "1" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:315" | ||||
|           docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:324" | ||||
|           resource_class: gpu.medium | ||||
|       - caffe2_macos_build: | ||||
|           name: caffe2_py2_ios_macos10_13_build | ||||
|  | ||||
| @ -13,7 +13,7 @@ chmod +x ~/Downloads/conda.sh | ||||
| export PATH="~/anaconda/bin:${PATH}" | ||||
| source ~/anaconda/bin/activate | ||||
| # Install dependencies | ||||
| conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests | ||||
| conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes | ||||
| export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} | ||||
| # sync submodules | ||||
| cd ${PROJ_ROOT} | ||||
|  | ||||
| @ -22,7 +22,7 @@ default_set = set([ | ||||
|     # Caffe2 CPU | ||||
|     'caffe2-py2-mkl-ubuntu16.04', | ||||
|     # Caffe2 CUDA | ||||
|     'caffe2-py2-cuda9.1-cudnn7-ubuntu16.04', | ||||
|     'caffe2-py3.5-cuda10.1-cudnn7-ubuntu16.04', | ||||
|     # Caffe2 ONNX | ||||
|     'caffe2-onnx-py2-gcc5-ubuntu16.04', | ||||
|     'caffe2-onnx-py3.6-clang7-ubuntu16.04', | ||||
|  | ||||
| @ -18,7 +18,7 @@ if ! [ -e "$SCRIPT_DIR/COMMIT_MSG" ]; then | ||||
|   exit 1 | ||||
| fi | ||||
| if [ -n "${CIRCLE_PULL_REQUEST:-}" ]; then | ||||
|   if [[ $CIRCLE_BRANCH != "ci-all/"* ]]; then | ||||
|   if [[ $CIRCLE_BRANCH != "ci-all/"* ]] && [[ $CIRCLE_BRANCH != "nightly" ]] &&  [[ $CIRCLE_BRANCH != "postnightly" ]] ; then | ||||
|     # Don't swallow "script doesn't exist | ||||
|     [ -e "$SCRIPT_DIR/should_run_job.py"  ] | ||||
|     if ! python "$SCRIPT_DIR/should_run_job.py" "${BUILD_ENVIRONMENT:-}" < "$SCRIPT_DIR/COMMIT_MSG" ; then | ||||
|  | ||||
| @ -40,7 +40,7 @@ | ||||
|           chmod +x /home/circleci/project/ci_build_script.sh | ||||
|  | ||||
|           echo "DOCKER_IMAGE: "${DOCKER_IMAGE} | ||||
|           docker pull ${DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${DOCKER_IMAGE} >/dev/null | ||||
|           export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE}) | ||||
|           docker cp /home/circleci/project/. $id:/var/lib/jenkins/workspace | ||||
|  | ||||
| @ -55,7 +55,7 @@ | ||||
|               export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1} | ||||
|             fi | ||||
|             docker commit "$id" ${COMMIT_DOCKER_IMAGE} | ||||
|             docker push ${COMMIT_DOCKER_IMAGE} | ||||
|             time docker push ${COMMIT_DOCKER_IMAGE} | ||||
|           fi | ||||
|  | ||||
|   caffe2_linux_test: | ||||
| @ -110,7 +110,7 @@ | ||||
|             export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1} | ||||
|           fi | ||||
|           echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} | ||||
|           docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           if [ -n "${USE_CUDA_DOCKER_RUNTIME}" ]; then | ||||
|             export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) | ||||
|           else | ||||
| @ -146,12 +146,7 @@ | ||||
|             # Reinitialize path (see man page for path_helper(8)) | ||||
|             eval `/usr/libexec/path_helper -s` | ||||
|  | ||||
|             # Use Homebrew Python if configured to do so | ||||
|             if [ "${PYTHON_INSTALLATION}" == "homebrew" ]; then | ||||
|               export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH | ||||
|             fi | ||||
|  | ||||
|             pip -q install numpy | ||||
|             export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH | ||||
|  | ||||
|             # Install Anaconda if we need to | ||||
|             if [ -n "${CAFFE2_USE_ANACONDA}" ]; then | ||||
| @ -164,6 +159,8 @@ | ||||
|               source ${TMPDIR}/anaconda/bin/activate | ||||
|             fi | ||||
|  | ||||
|             pip -q install numpy | ||||
|  | ||||
|             # Install sccache | ||||
|             sudo curl https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache | ||||
|             sudo chmod +x /usr/local/bin/sccache | ||||
| @ -201,4 +198,3 @@ | ||||
|             if which sccache > /dev/null; then | ||||
|               sccache --show-stats | ||||
|             fi | ||||
|  | ||||
|  | ||||
| @ -19,7 +19,7 @@ | ||||
|           set -e | ||||
|           export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1} | ||||
|           echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} | ||||
|           docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) | ||||
|  | ||||
|           docker cp $id:/var/lib/jenkins/workspace/env /home/circleci/project/env | ||||
| @ -53,7 +53,7 @@ | ||||
|           set -ex | ||||
|           export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1} | ||||
|           echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} | ||||
|           docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           export id=$(docker run -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) | ||||
|  | ||||
|           # master branch docs push | ||||
| @ -61,11 +61,11 @@ | ||||
|             export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/master master site") | docker exec -u jenkins -i "$id" bash) 2>&1' | ||||
|  | ||||
|           # stable release docs push. Due to some circleci limitations, we keep | ||||
|           # an eternal PR open for merging v1.2.0 -> master for this job. | ||||
|           # XXX: The following code is only run on the v1.2.0 branch, which might | ||||
|           # an eternal PR open for merging v1.3.0 -> master for this job. | ||||
|           # XXX: The following code is only run on the v1.3.0 branch, which might | ||||
|           # not be exactly the same as what you see here. | ||||
|           elif [[ "${CIRCLE_BRANCH}" == "v1.2.0" ]]; then | ||||
|             export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.2.0 site dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1' | ||||
|           elif [[ "${CIRCLE_BRANCH}" == "v1.3.0" ]]; then | ||||
|             export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export GITHUB_PYTORCHBOT_TOKEN=${GITHUB_PYTORCHBOT_TOKEN}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/stable 1.3.0 site-v1.3.0 dry_run") | docker exec -u jenkins -i "$id" bash) 2>&1' | ||||
|  | ||||
|           # For open PRs: Do a dry_run of the docs build, don't push build | ||||
|           else | ||||
| @ -77,7 +77,7 @@ | ||||
|           # Save the docs build so we can debug any problems | ||||
|           export DEBUG_COMMIT_DOCKER_IMAGE=${COMMIT_DOCKER_IMAGE}-debug | ||||
|           docker commit "$id" ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|           docker push ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|           time docker push ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|  | ||||
|   pytorch_cpp_doc_push: | ||||
|     environment: | ||||
| @ -98,7 +98,7 @@ | ||||
|           set -ex | ||||
|           export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}-${CIRCLE_SHA1} | ||||
|           echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} | ||||
|           docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           export id=$(docker run -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) | ||||
|  | ||||
|           # master branch docs push | ||||
| @ -122,7 +122,7 @@ | ||||
|           # Save the docs build so we can debug any problems | ||||
|           export DEBUG_COMMIT_DOCKER_IMAGE=${COMMIT_DOCKER_IMAGE}-debug | ||||
|           docker commit "$id" ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|           docker push ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|           time docker push ${DEBUG_COMMIT_DOCKER_IMAGE} | ||||
|  | ||||
|   pytorch_macos_10_13_py3_build: | ||||
|     environment: | ||||
| @ -267,14 +267,14 @@ | ||||
|           echo "docker_image_libtorch_android_arm_v8a: "${docker_image_libtorch_android_arm_v8a} | ||||
|  | ||||
|           # x86_32 | ||||
|           docker pull ${docker_image_libtorch_android_x86_32} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_x86_32} >/dev/null | ||||
|           export id_x86_32=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_x86_32" bash) 2>&1' | ||||
|           echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts | ||||
|  | ||||
|           # arm-v7a | ||||
|           docker pull ${docker_image_libtorch_android_arm_v7a} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_arm_v7a} >/dev/null | ||||
|           export id_arm_v7a=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_arm_v7a}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_arm_v7a" bash) 2>&1' | ||||
| @ -284,7 +284,7 @@ | ||||
|           docker cp $id_arm_v7a:/var/lib/jenkins/workspace/build_android/install ~/workspace/build_android_install_arm_v7a | ||||
|  | ||||
|           # x86_64 | ||||
|           docker pull ${docker_image_libtorch_android_x86_64} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_x86_64} >/dev/null | ||||
|           export id_x86_64=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_64}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_x86_64" bash) 2>&1' | ||||
| @ -294,7 +294,7 @@ | ||||
|           docker cp $id_x86_64:/var/lib/jenkins/workspace/build_android/install ~/workspace/build_android_install_x86_64 | ||||
|  | ||||
|           # arm-v8a | ||||
|           docker pull ${docker_image_libtorch_android_arm_v8a} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_arm_v8a} >/dev/null | ||||
|           export id_arm_v8a=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_arm_v8a}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace") | docker exec -u jenkins -i "$id_arm_v8a" bash) 2>&1' | ||||
| @ -316,7 +316,7 @@ | ||||
|  | ||||
|           output_image=$docker_image_libtorch_android_x86_32-gradle | ||||
|           docker commit "$id_x86_32" ${output_image} | ||||
|           docker push ${output_image} | ||||
|           time docker push ${output_image} | ||||
|     - store_artifacts: | ||||
|         path: ~/workspace/build_android_artifacts/artifacts.tgz | ||||
|         destination: artifacts.tgz | ||||
| @ -347,7 +347,7 @@ | ||||
|           echo "docker_image_libtorch_android_x86_32_gradle: "${docker_image_libtorch_android_x86_32_gradle} | ||||
|  | ||||
|           # x86_32 | ||||
|           docker pull ${docker_image_libtorch_android_x86_32_gradle} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_x86_32_gradle} >/dev/null | ||||
|           export id_x86_32=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32_gradle}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace" && echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export SONATYPE_NEXUS_USERNAME=${SONATYPE_NEXUS_USERNAME}" && echo "export SONATYPE_NEXUS_PASSWORD=${SONATYPE_NEXUS_PASSWORD}" && echo "export ANDROID_SIGN_KEY=${ANDROID_SIGN_KEY}" && echo "export ANDROID_SIGN_PASS=${ANDROID_SIGN_PASS}" && echo "sudo chown -R jenkins workspace && cd workspace && ./.circleci/scripts/publish_android_snapshot.sh") | docker exec -u jenkins -i "$id_x86_32" bash) 2>&1' | ||||
| @ -355,7 +355,7 @@ | ||||
|  | ||||
|           output_image=${docker_image_libtorch_android_x86_32_gradle}-publish-snapshot | ||||
|           docker commit "$id_x86_32" ${output_image} | ||||
|           docker push ${output_image} | ||||
|           time docker push ${output_image} | ||||
|  | ||||
|   pytorch_android_gradle_build-x86_32: | ||||
|     environment: | ||||
| @ -387,7 +387,7 @@ | ||||
|           echo "docker_image_libtorch_android_x86_32: "${docker_image_libtorch_android_x86_32} | ||||
|  | ||||
|           # x86 | ||||
|           docker pull ${docker_image_libtorch_android_x86_32} >/dev/null | ||||
|           time docker pull ${docker_image_libtorch_android_x86_32} >/dev/null | ||||
|           export id=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32}) | ||||
|  | ||||
|           export COMMAND='((echo "source ./workspace/env" && echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "sudo chown -R jenkins workspace && cd workspace && ./.circleci/scripts/build_android_gradle.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' | ||||
| @ -398,7 +398,7 @@ | ||||
|  | ||||
|           output_image=${docker_image_libtorch_android_x86_32}-gradle | ||||
|           docker commit "$id" ${output_image} | ||||
|           docker push ${output_image} | ||||
|           time docker push ${output_image} | ||||
|     - store_artifacts: | ||||
|         path: ~/workspace/build_android_x86_32_artifacts/artifacts.tgz | ||||
|         destination: artifacts.tgz | ||||
| @ -429,7 +429,7 @@ | ||||
|             export PATH="~/anaconda/bin:${PATH}" | ||||
|             source ~/anaconda/bin/activate | ||||
|             # Install dependencies | ||||
|             conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests | ||||
|             conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes | ||||
|             # sync submodules | ||||
|             cd ${PROJ_ROOT} | ||||
|             git submodule sync | ||||
|  | ||||
| @ -16,20 +16,20 @@ jobs: | ||||
|           set -e | ||||
|           # Pull Docker image and run build | ||||
|           echo "DOCKER_IMAGE: "${DOCKER_IMAGE} | ||||
|           docker pull ${DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${DOCKER_IMAGE} >/dev/null | ||||
|           export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE}) | ||||
|  | ||||
|           # TODO We may want to move the rebase logic to a separate step after checkout | ||||
|           # Rebase to master only if in xenial_py3_6_gcc5_4 case | ||||
|           if [[ "${CIRCLE_BRANCH}" != "master" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then | ||||
|             echo "Merge master branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT" | ||||
|           # Rebase to v1.3.0 only if in xenial_py3_6_gcc5_4 case | ||||
|           if [[ "${CIRCLE_BRANCH}" != "v1.3.0" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then | ||||
|             echo "Merge v1.3.0 branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT" | ||||
|             set -x | ||||
|             git config --global user.email "circleci.ossci@gmail.com" | ||||
|             git config --global user.name "CircleCI" | ||||
|             git config remote.origin.url https://github.com/pytorch/pytorch.git | ||||
|             git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master | ||||
|             git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=50 --quiet | ||||
|             export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/master` | ||||
|             git config --add remote.origin.fetch +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0 | ||||
|             git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/v1.3.0:refs/remotes/origin/v1.3.0 --depth=50 --quiet | ||||
|             export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/v1.3.0` | ||||
|             echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET} | ||||
|             export GIT_COMMIT=${CIRCLE_SHA1} | ||||
|             echo "GIT_COMMIT: " ${GIT_COMMIT} | ||||
| @ -38,7 +38,7 @@ jobs: | ||||
|             git merge --no-edit --no-ff ${GIT_MERGE_TARGET} | ||||
|             set +x | ||||
|           else | ||||
|             echo "Do NOT merge master branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT" | ||||
|             echo "Do NOT merge v1.3.0 branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT" | ||||
|           fi | ||||
|  | ||||
|           git submodule sync && git submodule update -q --init --recursive | ||||
| @ -75,7 +75,7 @@ jobs: | ||||
|               export COMMIT_DOCKER_IMAGE=$output_image | ||||
|             fi | ||||
|             docker commit "$id" ${COMMIT_DOCKER_IMAGE} | ||||
|             docker push ${COMMIT_DOCKER_IMAGE} | ||||
|             time docker push ${COMMIT_DOCKER_IMAGE} | ||||
|           fi | ||||
|  | ||||
|   pytorch_linux_test: | ||||
| @ -103,7 +103,7 @@ jobs: | ||||
|             export COMMIT_DOCKER_IMAGE=$output_image | ||||
|           fi | ||||
|           echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} | ||||
|           docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null | ||||
|           if [ -n "${USE_CUDA_DOCKER_RUNTIME}" ]; then | ||||
|             export id=$(docker run --runtime=nvidia -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) | ||||
|           else | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64 | ||||
|  | ||||
| VERSION_NAME=0.0.7-SNAPSHOT | ||||
| VERSION_NAME=1.3.0 | ||||
| GROUP=org.pytorch | ||||
| MAVEN_GROUP=org.pytorch | ||||
| POM_URL=https://github.com/pytorch/pytorch/tree/master/android | ||||
|  | ||||
| @ -0,0 +1,30 @@ | ||||
| package org.pytorch; | ||||
|  | ||||
| import com.facebook.soloader.nativeloader.NativeLoader; | ||||
| import com.facebook.soloader.nativeloader.SystemDelegate; | ||||
| import org.junit.BeforeClass; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.io.InputStream; | ||||
| import java.nio.file.Files; | ||||
| import java.nio.file.Path; | ||||
| import java.nio.file.StandardCopyOption; | ||||
| import java.util.Objects; | ||||
|  | ||||
| public class PytorchHostTests extends PytorchTestBase { | ||||
|   @BeforeClass | ||||
|   public static void setUpClass() { | ||||
|     if (!NativeLoader.isInitialized()) { | ||||
|       NativeLoader.init(new SystemDelegate()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Override | ||||
|   protected String assetFilePath(String assetName) throws IOException { | ||||
|     Path tempFile = Files.createTempFile("test", ".pt"); | ||||
|     try (InputStream resource = Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) { | ||||
|       Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING); | ||||
|     } | ||||
|     return tempFile.toAbsolutePath().toString(); | ||||
|   } | ||||
| } | ||||
| @ -2,8 +2,6 @@ package org.pytorch; | ||||
|  | ||||
| import android.content.Context; | ||||
|  | ||||
| import org.junit.Before; | ||||
| import org.junit.Test; | ||||
| import org.junit.runner.RunWith; | ||||
|  | ||||
| import java.io.File; | ||||
| @ -11,294 +9,15 @@ import java.io.FileOutputStream; | ||||
| import java.io.IOException; | ||||
| import java.io.InputStream; | ||||
| import java.io.OutputStream; | ||||
| import java.util.HashMap; | ||||
| import java.util.Map; | ||||
|  | ||||
| import androidx.test.ext.junit.runners.AndroidJUnit4; | ||||
| import androidx.test.platform.app.InstrumentationRegistry; | ||||
|  | ||||
| import static org.junit.Assert.assertArrayEquals; | ||||
| import static org.junit.Assert.assertFalse; | ||||
| import static org.junit.Assert.assertNotNull; | ||||
| import static org.junit.Assert.assertTrue; | ||||
|  | ||||
| @RunWith(AndroidJUnit4.class) | ||||
| public class PytorchInstrumentedTests { | ||||
| public class PytorchInstrumentedTests extends PytorchTestBase { | ||||
|  | ||||
|   private static final String TEST_MODULE_ASSET_NAME = "test.pt"; | ||||
|  | ||||
|   @Before | ||||
|   public void setUp() { | ||||
|     System.loadLibrary("pytorch"); | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testForwardNull() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     final IValue input = | ||||
|         IValue.tensor(Tensor.newInt8Tensor(new long[] {1}, Tensor.allocateByteBuffer(1))); | ||||
|     assertTrue(input.isTensor()); | ||||
|     final IValue output = module.forward(input); | ||||
|     assertTrue(output.isNull()); | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqBool() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     for (boolean value : new boolean[] {false, true}) { | ||||
|       final IValue input = IValue.bool(value); | ||||
|       assertTrue(input.isBool()); | ||||
|       assertTrue(value == input.getBool()); | ||||
|       final IValue output = module.runMethod("eqBool", input); | ||||
|       assertTrue(output.isBool()); | ||||
|       assertTrue(value == output.getBool()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqInt() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) { | ||||
|       final IValue input = IValue.long64(value); | ||||
|       assertTrue(input.isLong()); | ||||
|       assertTrue(value == input.getLong()); | ||||
|       final IValue output = module.runMethod("eqInt", input); | ||||
|       assertTrue(output.isLong()); | ||||
|       assertTrue(value == output.getLong()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqFloat() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     double[] values = | ||||
|         new double[] { | ||||
|           -Double.MAX_VALUE, | ||||
|           Double.MAX_VALUE, | ||||
|           -Double.MIN_VALUE, | ||||
|           Double.MIN_VALUE, | ||||
|           -Math.exp(1.d), | ||||
|           -Math.sqrt(2.d), | ||||
|           -3.1415f, | ||||
|           3.1415f, | ||||
|           -1, | ||||
|           0, | ||||
|           1, | ||||
|         }; | ||||
|     for (double value : values) { | ||||
|       final IValue input = IValue.double64(value); | ||||
|       assertTrue(input.isDouble()); | ||||
|       assertTrue(value == input.getDouble()); | ||||
|       final IValue output = module.runMethod("eqFloat", input); | ||||
|       assertTrue(output.isDouble()); | ||||
|       assertTrue(value == output.getDouble()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqTensor() throws IOException { | ||||
|     final long[] inputTensorShape = new long[] {1, 3, 224, 224}; | ||||
|     final long numElements = Tensor.numel(inputTensorShape); | ||||
|     final float[] inputTensorData = new float[(int) numElements]; | ||||
|     for (int i = 0; i < numElements; ++i) { | ||||
|       inputTensorData[i] = i; | ||||
|     } | ||||
|     final Tensor inputTensor = Tensor.newFloat32Tensor(inputTensorShape, inputTensorData); | ||||
|  | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     final IValue input = IValue.tensor(inputTensor); | ||||
|     assertTrue(input.isTensor()); | ||||
|     assertTrue(inputTensor == input.getTensor()); | ||||
|     final IValue output = module.runMethod("eqTensor", input); | ||||
|     assertTrue(output.isTensor()); | ||||
|     final Tensor outputTensor = output.getTensor(); | ||||
|     assertNotNull(outputTensor); | ||||
|     assertArrayEquals(inputTensorShape, outputTensor.shape); | ||||
|     float[] outputData = outputTensor.getDataAsFloatArray(); | ||||
|     for (int i = 0; i < numElements; i++) { | ||||
|       assertTrue(inputTensorData[i] == outputData[i]); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqDictIntKeyIntValue() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     final Map<Long, IValue> inputMap = new HashMap<>(); | ||||
|  | ||||
|     inputMap.put(Long.MIN_VALUE, IValue.long64(-Long.MIN_VALUE)); | ||||
|     inputMap.put(Long.MAX_VALUE, IValue.long64(-Long.MAX_VALUE)); | ||||
|     inputMap.put(0l, IValue.long64(0l)); | ||||
|     inputMap.put(1l, IValue.long64(-1l)); | ||||
|     inputMap.put(-1l, IValue.long64(1l)); | ||||
|  | ||||
|     final IValue input = IValue.dictLongKey(inputMap); | ||||
|     assertTrue(input.isDictLongKey()); | ||||
|  | ||||
|     final IValue output = module.runMethod("eqDictIntKeyIntValue", input); | ||||
|     assertTrue(output.isDictLongKey()); | ||||
|  | ||||
|     final Map<Long, IValue> outputMap = output.getDictLongKey(); | ||||
|     assertTrue(inputMap.size() == outputMap.size()); | ||||
|     for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) { | ||||
|       assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqDictStrKeyIntValue() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     final Map<String, IValue> inputMap = new HashMap<>(); | ||||
|  | ||||
|     inputMap.put("long_min_value", IValue.long64(Long.MIN_VALUE)); | ||||
|     inputMap.put("long_max_value", IValue.long64(Long.MAX_VALUE)); | ||||
|     inputMap.put("long_0", IValue.long64(0l)); | ||||
|     inputMap.put("long_1", IValue.long64(1l)); | ||||
|     inputMap.put("long_-1", IValue.long64(-1l)); | ||||
|  | ||||
|     final IValue input = IValue.dictStringKey(inputMap); | ||||
|     assertTrue(input.isDictStringKey()); | ||||
|  | ||||
|     final IValue output = module.runMethod("eqDictStrKeyIntValue", input); | ||||
|     assertTrue(output.isDictStringKey()); | ||||
|  | ||||
|     final Map<String, IValue> outputMap = output.getDictStringKey(); | ||||
|     assertTrue(inputMap.size() == outputMap.size()); | ||||
|     for (Map.Entry<String, IValue> entry : inputMap.entrySet()) { | ||||
|       assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testListIntSumReturnTuple() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|  | ||||
|     for (int n : new int[] {0, 1, 128}) { | ||||
|       long[] a = new long[n]; | ||||
|       long sum = 0; | ||||
|       for (int i = 0; i < n; i++) { | ||||
|         a[i] = i; | ||||
|         sum += a[i]; | ||||
|       } | ||||
|       final IValue input = IValue.longList(a); | ||||
|       assertTrue(input.isLongList()); | ||||
|  | ||||
|       final IValue output = module.runMethod("listIntSumReturnTuple", input); | ||||
|  | ||||
|       assertTrue(output.isTuple()); | ||||
|       assertTrue(2 == output.getTuple().length); | ||||
|  | ||||
|       IValue output0 = output.getTuple()[0]; | ||||
|       IValue output1 = output.getTuple()[1]; | ||||
|  | ||||
|       assertArrayEquals(a, output0.getLongList()); | ||||
|       assertTrue(sum == output1.getLong()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testOptionalIntIsNone() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|  | ||||
|     assertFalse(module.runMethod("optionalIntIsNone", IValue.long64(1l)).getBool()); | ||||
|     assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).getBool()); | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testIntEq0None() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|  | ||||
|     assertTrue(module.runMethod("intEq0None", IValue.long64(0l)).isNull()); | ||||
|     assertTrue(module.runMethod("intEq0None", IValue.long64(1l)).getLong() == 1l); | ||||
|   } | ||||
|  | ||||
|   @Test(expected = IllegalArgumentException.class) | ||||
|   public void testRunUndefinedMethod() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     module.runMethod("test_undefined_method_throws_exception"); | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testTensorMethods() { | ||||
|     long[] shape = new long[] {1, 3, 224, 224}; | ||||
|     final int numel = (int) Tensor.numel(shape); | ||||
|     int[] ints = new int[numel]; | ||||
|     float[] floats = new float[numel]; | ||||
|  | ||||
|     byte[] bytes = new byte[numel]; | ||||
|     for (int i = 0; i < numel; i++) { | ||||
|       bytes[i] = (byte) ((i % 255) - 128); | ||||
|       ints[i] = i; | ||||
|       floats[i] = i / 1000.f; | ||||
|     } | ||||
|  | ||||
|     Tensor tensorBytes = Tensor.newInt8Tensor(shape, bytes); | ||||
|     assertTrue(tensorBytes.dtype() == Tensor.DTYPE_INT8); | ||||
|     assertArrayEquals(bytes, tensorBytes.getDataAsByteArray()); | ||||
|  | ||||
|     Tensor tensorInts = Tensor.newInt32Tensor(shape, ints); | ||||
|     assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32); | ||||
|     assertArrayEquals(ints, tensorInts.getDataAsIntArray()); | ||||
|  | ||||
|     Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats); | ||||
|     assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32); | ||||
|     float[] floatsOut = tensorFloats.getDataAsFloatArray(); | ||||
|     assertTrue(floatsOut.length == numel); | ||||
|     for (int i = 0; i < numel; i++) { | ||||
|       assertTrue(floats[i] == floatsOut[i]); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test(expected = IllegalStateException.class) | ||||
|   public void testTensorIllegalStateOnWrongType() { | ||||
|     long[] shape = new long[] {1, 3, 224, 224}; | ||||
|     final int numel = (int) Tensor.numel(shape); | ||||
|     float[] floats = new float[numel]; | ||||
|     Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats); | ||||
|     assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32); | ||||
|     tensorFloats.getDataAsByteArray(); | ||||
|   } | ||||
|  | ||||
|  | ||||
|   @Test | ||||
|   public void testEqString() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     String[] values = | ||||
|         new String[] { | ||||
|             "smoketest", | ||||
|             "проверка не латинских символов", // not latin symbols check | ||||
|             "#@$!@#)($*!@#$)(!@*#$" | ||||
|         }; | ||||
|     for (String value : values) { | ||||
|       final IValue input = IValue.string(value); | ||||
|       assertTrue(input.isString()); | ||||
|       assertTrue(value.equals(input.getString())); | ||||
|       final IValue output = module.runMethod("eqStr", input); | ||||
|       assertTrue(output.isString()); | ||||
|       assertTrue(value.equals(output.getString())); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testStr3Concat() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     String[] values = | ||||
|         new String[] { | ||||
|             "smoketest", | ||||
|             "проверка не латинских символов", // not latin symbols check | ||||
|             "#@$!@#)($*!@#$)(!@*#$" | ||||
|         }; | ||||
|     for (String value : values) { | ||||
|       final IValue input = IValue.string(value); | ||||
|       assertTrue(input.isString()); | ||||
|       assertTrue(value.equals(input.getString())); | ||||
|       final IValue output = module.runMethod("str3Concat", input); | ||||
|       assertTrue(output.isString()); | ||||
|       String expectedOutput = new StringBuilder().append(value).append(value).append(value).toString(); | ||||
|       assertTrue(expectedOutput.equals(output.getString())); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   private static String assetFilePath(String assetName) throws IOException { | ||||
|   @Override | ||||
|   protected String assetFilePath(String assetName) throws IOException { | ||||
|     final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); | ||||
|     File file = new File(appContext.getFilesDir(), assetName); | ||||
|     if (file.exists() && file.length() > 0) { | ||||
|  | ||||
| @ -0,0 +1,290 @@ | ||||
| package org.pytorch; | ||||
|  | ||||
| import org.junit.Before; | ||||
| import org.junit.Test; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.util.HashMap; | ||||
| import java.util.Map; | ||||
|  | ||||
| import static org.junit.Assert.assertArrayEquals; | ||||
| import static org.junit.Assert.assertFalse; | ||||
| import static org.junit.Assert.assertNotNull; | ||||
| import static org.junit.Assert.assertTrue; | ||||
|  | ||||
| public abstract class PytorchTestBase { | ||||
|   private static final String TEST_MODULE_ASSET_NAME = "test.pt"; | ||||
|  | ||||
|   @Before | ||||
|   public void setUp() { | ||||
|     System.loadLibrary("pytorch"); | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testForwardNull() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     final IValue input = | ||||
|         IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1})); | ||||
|     assertTrue(input.isTensor()); | ||||
|     final IValue output = module.forward(input); | ||||
|     assertTrue(output.isNull()); | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqBool() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     for (boolean value : new boolean[] {false, true}) { | ||||
|       final IValue input = IValue.from(value); | ||||
|       assertTrue(input.isBool()); | ||||
|       assertTrue(value == input.toBool()); | ||||
|       final IValue output = module.runMethod("eqBool", input); | ||||
|       assertTrue(output.isBool()); | ||||
|       assertTrue(value == output.toBool()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqInt() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) { | ||||
|       final IValue input = IValue.from(value); | ||||
|       assertTrue(input.isLong()); | ||||
|       assertTrue(value == input.toLong()); | ||||
|       final IValue output = module.runMethod("eqInt", input); | ||||
|       assertTrue(output.isLong()); | ||||
|       assertTrue(value == output.toLong()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqFloat() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     double[] values = | ||||
|         new double[] { | ||||
|             -Double.MAX_VALUE, | ||||
|             Double.MAX_VALUE, | ||||
|             -Double.MIN_VALUE, | ||||
|             Double.MIN_VALUE, | ||||
|             -Math.exp(1.d), | ||||
|             -Math.sqrt(2.d), | ||||
|             -3.1415f, | ||||
|             3.1415f, | ||||
|             -1, | ||||
|             0, | ||||
|             1, | ||||
|         }; | ||||
|     for (double value : values) { | ||||
|       final IValue input = IValue.from(value); | ||||
|       assertTrue(input.isDouble()); | ||||
|       assertTrue(value == input.toDouble()); | ||||
|       final IValue output = module.runMethod("eqFloat", input); | ||||
|       assertTrue(output.isDouble()); | ||||
|       assertTrue(value == output.toDouble()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqTensor() throws IOException { | ||||
|     final long[] inputTensorShape = new long[] {1, 3, 224, 224}; | ||||
|     final long numElements = Tensor.numel(inputTensorShape); | ||||
|     final float[] inputTensorData = new float[(int) numElements]; | ||||
|     for (int i = 0; i < numElements; ++i) { | ||||
|       inputTensorData[i] = i; | ||||
|     } | ||||
|     final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape); | ||||
|  | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     final IValue input = IValue.from(inputTensor); | ||||
|     assertTrue(input.isTensor()); | ||||
|     assertTrue(inputTensor == input.toTensor()); | ||||
|     final IValue output = module.runMethod("eqTensor", input); | ||||
|     assertTrue(output.isTensor()); | ||||
|     final Tensor outputTensor = output.toTensor(); | ||||
|     assertNotNull(outputTensor); | ||||
|     assertArrayEquals(inputTensorShape, outputTensor.shape()); | ||||
|     float[] outputData = outputTensor.getDataAsFloatArray(); | ||||
|     for (int i = 0; i < numElements; i++) { | ||||
|       assertTrue(inputTensorData[i] == outputData[i]); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqDictIntKeyIntValue() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     final Map<Long, IValue> inputMap = new HashMap<>(); | ||||
|  | ||||
|     inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE)); | ||||
|     inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE)); | ||||
|     inputMap.put(0l, IValue.from(0l)); | ||||
|     inputMap.put(1l, IValue.from(-1l)); | ||||
|     inputMap.put(-1l, IValue.from(1l)); | ||||
|  | ||||
|     final IValue input = IValue.dictLongKeyFrom(inputMap); | ||||
|     assertTrue(input.isDictLongKey()); | ||||
|  | ||||
|     final IValue output = module.runMethod("eqDictIntKeyIntValue", input); | ||||
|     assertTrue(output.isDictLongKey()); | ||||
|  | ||||
|     final Map<Long, IValue> outputMap = output.toDictLongKey(); | ||||
|     assertTrue(inputMap.size() == outputMap.size()); | ||||
|     for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) { | ||||
|       assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testEqDictStrKeyIntValue() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     final Map<String, IValue> inputMap = new HashMap<>(); | ||||
|  | ||||
|     inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE)); | ||||
|     inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE)); | ||||
|     inputMap.put("long_0", IValue.from(0l)); | ||||
|     inputMap.put("long_1", IValue.from(1l)); | ||||
|     inputMap.put("long_-1", IValue.from(-1l)); | ||||
|  | ||||
|     final IValue input = IValue.dictStringKeyFrom(inputMap); | ||||
|     assertTrue(input.isDictStringKey()); | ||||
|  | ||||
|     final IValue output = module.runMethod("eqDictStrKeyIntValue", input); | ||||
|     assertTrue(output.isDictStringKey()); | ||||
|  | ||||
|     final Map<String, IValue> outputMap = output.toDictStringKey(); | ||||
|     assertTrue(inputMap.size() == outputMap.size()); | ||||
|     for (Map.Entry<String, IValue> entry : inputMap.entrySet()) { | ||||
|       assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testListIntSumReturnTuple() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|  | ||||
|     for (int n : new int[] {0, 1, 128}) { | ||||
|       long[] a = new long[n]; | ||||
|       long sum = 0; | ||||
|       for (int i = 0; i < n; i++) { | ||||
|         a[i] = i; | ||||
|         sum += a[i]; | ||||
|       } | ||||
|       final IValue input = IValue.listFrom(a); | ||||
|       assertTrue(input.isLongList()); | ||||
|  | ||||
|       final IValue output = module.runMethod("listIntSumReturnTuple", input); | ||||
|  | ||||
|       assertTrue(output.isTuple()); | ||||
|       assertTrue(2 == output.toTuple().length); | ||||
|  | ||||
|       IValue output0 = output.toTuple()[0]; | ||||
|       IValue output1 = output.toTuple()[1]; | ||||
|  | ||||
|       assertArrayEquals(a, output0.toLongList()); | ||||
|       assertTrue(sum == output1.toLong()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testOptionalIntIsNone() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|  | ||||
|     assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool()); | ||||
|     assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool()); | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testIntEq0None() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|  | ||||
|     assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull()); | ||||
|     assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l); | ||||
|   } | ||||
|  | ||||
|   @Test(expected = IllegalArgumentException.class) | ||||
|   public void testRunUndefinedMethod() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     module.runMethod("test_undefined_method_throws_exception"); | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testTensorMethods() { | ||||
|     long[] shape = new long[] {1, 3, 224, 224}; | ||||
|     final int numel = (int) Tensor.numel(shape); | ||||
|     int[] ints = new int[numel]; | ||||
|     float[] floats = new float[numel]; | ||||
|  | ||||
|     byte[] bytes = new byte[numel]; | ||||
|     for (int i = 0; i < numel; i++) { | ||||
|       bytes[i] = (byte) ((i % 255) - 128); | ||||
|       ints[i] = i; | ||||
|       floats[i] = i / 1000.f; | ||||
|     } | ||||
|  | ||||
|     Tensor tensorBytes = Tensor.fromBlob(bytes, shape); | ||||
|     assertTrue(tensorBytes.dtype() == DType.INT8); | ||||
|     assertArrayEquals(bytes, tensorBytes.getDataAsByteArray()); | ||||
|  | ||||
|     Tensor tensorInts = Tensor.fromBlob(ints, shape); | ||||
|     assertTrue(tensorInts.dtype() == DType.INT32); | ||||
|     assertArrayEquals(ints, tensorInts.getDataAsIntArray()); | ||||
|  | ||||
|     Tensor tensorFloats = Tensor.fromBlob(floats, shape); | ||||
|     assertTrue(tensorFloats.dtype() == DType.FLOAT32); | ||||
|     float[] floatsOut = tensorFloats.getDataAsFloatArray(); | ||||
|     assertTrue(floatsOut.length == numel); | ||||
|     for (int i = 0; i < numel; i++) { | ||||
|       assertTrue(floats[i] == floatsOut[i]); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test(expected = IllegalStateException.class) | ||||
|   public void testTensorIllegalStateOnWrongType() { | ||||
|     long[] shape = new long[] {1, 3, 224, 224}; | ||||
|     final int numel = (int) Tensor.numel(shape); | ||||
|     float[] floats = new float[numel]; | ||||
|     Tensor tensorFloats = Tensor.fromBlob(floats, shape); | ||||
|     assertTrue(tensorFloats.dtype() == DType.FLOAT32); | ||||
|     tensorFloats.getDataAsByteArray(); | ||||
|   } | ||||
|  | ||||
|  | ||||
|   @Test | ||||
|   public void testEqString() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     String[] values = | ||||
|         new String[] { | ||||
|             "smoketest", | ||||
|             "проверка не латинских символов", // not latin symbols check | ||||
|             "#@$!@#)($*!@#$)(!@*#$" | ||||
|         }; | ||||
|     for (String value : values) { | ||||
|       final IValue input = IValue.from(value); | ||||
|       assertTrue(input.isString()); | ||||
|       assertTrue(value.equals(input.toStr())); | ||||
|       final IValue output = module.runMethod("eqStr", input); | ||||
|       assertTrue(output.isString()); | ||||
|       assertTrue(value.equals(output.toStr())); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   @Test | ||||
|   public void testStr3Concat() throws IOException { | ||||
|     final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||||
|     String[] values = | ||||
|         new String[] { | ||||
|             "smoketest", | ||||
|             "проверка не латинских символов", // not latin symbols check | ||||
|             "#@$!@#)($*!@#$)(!@*#$" | ||||
|         }; | ||||
|     for (String value : values) { | ||||
|       final IValue input = IValue.from(value); | ||||
|       assertTrue(input.isString()); | ||||
|       assertTrue(value.equals(input.toStr())); | ||||
|       final IValue output = module.runMethod("str3Concat", input); | ||||
|       assertTrue(output.isString()); | ||||
|       String expectedOutput = new StringBuilder().append(value).append(value).append(value).toString(); | ||||
|       assertTrue(expectedOutput.equals(output.toStr())); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   protected abstract String assetFilePath(String assetName) throws IOException; | ||||
| } | ||||
| @ -10,6 +10,8 @@ | ||||
|  | ||||
| namespace pytorch_jni { | ||||
|  | ||||
| // NOTE: Codes must be kept in sync with DType.java. | ||||
| // NOTE: Never serialize these, because they can change between releases. | ||||
| constexpr static int kTensorDTypeUInt8 = 1; | ||||
| constexpr static int kTensorDTypeInt8 = 2; | ||||
| constexpr static int kTensorDTypeInt32 = 3; | ||||
| @ -164,7 +166,7 @@ class JTensor : public facebook::jni::JavaClass<JTensor> { | ||||
|   static at::Tensor newAtTensorFromJTensor( | ||||
|       facebook::jni::alias_ref<JTensor> jtensor) { | ||||
|     static const auto dtypeMethod = | ||||
|         JTensor::javaClassStatic()->getMethod<jint()>("dtype"); | ||||
|         JTensor::javaClassStatic()->getMethod<jint()>("dtypeJniCode"); | ||||
|     jint jdtype = dtypeMethod(jtensor); | ||||
|  | ||||
|     static const auto shapeField = | ||||
| @ -216,7 +218,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|       static auto jMethodTensor = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getStaticMethod<facebook::jni::local_ref<JIValue>( | ||||
|                   facebook::jni::local_ref<JTensor>)>("tensor"); | ||||
|                   facebook::jni::local_ref<JTensor>)>("from"); | ||||
|       return jMethodTensor( | ||||
|           JIValue::javaClassStatic(), | ||||
|           JTensor::newJTensorFromAtTensor(ivalue.toTensor())); | ||||
| @ -224,26 +226,26 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|       static auto jMethodBool = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getStaticMethod<facebook::jni::local_ref<JIValue>(jboolean)>( | ||||
|                   "bool"); | ||||
|                   "from"); | ||||
|       return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool()); | ||||
|     } else if (ivalue.isInt()) { | ||||
|       static auto jMethodInt = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getStaticMethod<facebook::jni::local_ref<JIValue>(jlong)>( | ||||
|                   "long64"); | ||||
|                   "from"); | ||||
|       return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt()); | ||||
|     } else if (ivalue.isDouble()) { | ||||
|       static auto jMethodDouble = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getStaticMethod<facebook::jni::local_ref<JIValue>(jdouble)>( | ||||
|                   "double64"); | ||||
|                   "from"); | ||||
|       return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble()); | ||||
|     } else if (ivalue.isString()) { | ||||
|       static auto jMethodString = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getStaticMethod<facebook::jni::local_ref<JIValue>( | ||||
|                   facebook::jni::alias_ref< | ||||
|                       facebook::jni::JString::javaobject>)>("string"); | ||||
|                       facebook::jni::JString::javaobject>)>("from"); | ||||
|       return jMethodString( | ||||
|           JIValue::javaClassStatic(), | ||||
|           facebook::jni::make_jstring(ivalue.toStringRef())); | ||||
| @ -253,7 +255,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getStaticMethod<facebook::jni::local_ref<JIValue>( | ||||
|                   facebook::jni::alias_ref<facebook::jni::JArrayClass< | ||||
|                       JIValue::javaobject>::javaobject>)>("tuple"); | ||||
|                       JIValue::javaobject>::javaobject>)>("tupleFrom"); | ||||
|       auto jElementsArray = | ||||
|           facebook::jni::JArrayClass<JIValue::javaobject>::newArray( | ||||
|               elementsVec.size()); | ||||
| @ -267,7 +269,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|       static auto jMethodBoolListArr = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getStaticMethod<facebook::jni::local_ref<JIValue>( | ||||
|                   facebook::jni::alias_ref<jbooleanArray>)>("boolList"); | ||||
|                   facebook::jni::alias_ref<jbooleanArray>)>("listFrom"); | ||||
|       size_t n = list.size(); | ||||
|       auto jArray = facebook::jni::make_boolean_array(n); | ||||
|       auto jArrayPinned = jArray->pin(); | ||||
| @ -281,7 +283,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|       static auto jMethodLongListArr = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getStaticMethod<facebook::jni::local_ref<JIValue>( | ||||
|                   facebook::jni::alias_ref<jlongArray>)>("longList"); | ||||
|                   facebook::jni::alias_ref<jlongArray>)>("listFrom"); | ||||
|       size_t n = list.size(); | ||||
|       auto jArray = facebook::jni::make_long_array(n); | ||||
|       auto jArrayPinned = jArray->pin(); | ||||
| @ -295,7 +297,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|       static auto jMethoDoubleListArr = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getStaticMethod<facebook::jni::local_ref<JIValue>( | ||||
|                   facebook::jni::alias_ref<jdoubleArray>)>("doubleList"); | ||||
|                   facebook::jni::alias_ref<jdoubleArray>)>("listFrom"); | ||||
|       size_t n = list.size(); | ||||
|       auto jArray = facebook::jni::make_double_array(n); | ||||
|       auto jArrayPinned = jArray->pin(); | ||||
| @ -310,7 +312,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getStaticMethod<facebook::jni::local_ref<JIValue>( | ||||
|                   facebook::jni::alias_ref<facebook::jni::JArrayClass< | ||||
|                       JTensor::javaobject>::javaobject>)>("tensorList"); | ||||
|                       JTensor::javaobject>::javaobject>)>("listFrom"); | ||||
|       auto jArray = facebook::jni::JArrayClass<JTensor::javaobject>::newArray( | ||||
|           list.size()); | ||||
|       auto index = 0; | ||||
| @ -324,7 +326,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getStaticMethod<facebook::jni::local_ref<JIValue>( | ||||
|                   facebook::jni::alias_ref<facebook::jni::JArrayClass< | ||||
|                       JIValue::javaobject>::javaobject>)>("list"); | ||||
|                       JIValue::javaobject>::javaobject>)>("listFrom"); | ||||
|       auto jArray = facebook::jni::JArrayClass<JIValue::javaobject>::newArray( | ||||
|           list.size()); | ||||
|       auto index = 0; | ||||
| @ -351,7 +353,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|                         facebook::jni::alias_ref< | ||||
|                             facebook::jni::JString::javaobject>, | ||||
|                         facebook::jni::alias_ref<JIValue::javaobject>>>)>( | ||||
|                     "dictStringKey"); | ||||
|                     "dictStringKeyFrom"); | ||||
|  | ||||
|         auto jmap = JHashMap< | ||||
|             facebook::jni::alias_ref<facebook::jni::JString::javaobject>, | ||||
| @ -370,7 +372,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|                         facebook::jni::alias_ref< | ||||
|                             facebook::jni::JLong::javaobject>, | ||||
|                         facebook::jni::alias_ref<JIValue::javaobject>>>)>( | ||||
|                     "dictLongKey"); | ||||
|                     "dictLongKeyFrom"); | ||||
|         auto jmap = JHashMap< | ||||
|             facebook::jni::alias_ref<facebook::jni::JLong::javaobject>, | ||||
|             facebook::jni::alias_ref<JIValue::javaobject>>::create(); | ||||
| @ -404,32 +406,32 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|       static const auto jMethodGetTensor = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getMethod<facebook::jni::alias_ref<JTensor::javaobject>()>( | ||||
|                   "getTensor"); | ||||
|                   "toTensor"); | ||||
|       return JTensor::newAtTensorFromJTensor(jMethodGetTensor(jivalue)); | ||||
|     } else if (JIValue::kTypeCodeBool == typeCode) { | ||||
|       static const auto jMethodGetBool = | ||||
|           JIValue::javaClassStatic()->getMethod<jboolean()>("getBool"); | ||||
|           JIValue::javaClassStatic()->getMethod<jboolean()>("toBool"); | ||||
|       // explicit cast to bool as jboolean is defined as uint8_t, IValue ctor | ||||
|       // for int will be called for jboolean | ||||
|       bool b = jMethodGetBool(jivalue); | ||||
|       return at::IValue{b}; | ||||
|     } else if (JIValue::kTypeCodeLong == typeCode) { | ||||
|       static const auto jMethodGetLong = | ||||
|           JIValue::javaClassStatic()->getMethod<jlong()>("getLong"); | ||||
|           JIValue::javaClassStatic()->getMethod<jlong()>("toLong"); | ||||
|       return at::IValue{jMethodGetLong(jivalue)}; | ||||
|     } else if (JIValue::kTypeCodeDouble == typeCode) { | ||||
|       static const auto jMethodGetDouble = | ||||
|           JIValue::javaClassStatic()->getMethod<jdouble()>("getDouble"); | ||||
|           JIValue::javaClassStatic()->getMethod<jdouble()>("toDouble"); | ||||
|       return at::IValue{jMethodGetDouble(jivalue)}; | ||||
|     } else if (JIValue::kTypeCodeString == typeCode) { | ||||
|       static const auto jMethodGetString = | ||||
|           JIValue::javaClassStatic()->getMethod<jstring()>("getString"); | ||||
|           JIValue::javaClassStatic()->getMethod<jstring()>("toStr"); | ||||
|       return at::IValue{jMethodGetString(jivalue)->toStdString()}; | ||||
|     } else if (JIValue::kTypeCodeTuple == typeCode) { | ||||
|       static const auto jMethodGetTuple = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getMethod<facebook::jni::JArrayClass< | ||||
|                   JIValue::javaobject>::javaobject()>("getTuple"); | ||||
|                   JIValue::javaobject>::javaobject()>("toTuple"); | ||||
|       auto jarray = jMethodGetTuple(jivalue); | ||||
|       size_t n = jarray->size(); | ||||
|  | ||||
| @ -443,7 +445,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|       return c10::ivalue::Tuple::create(std::move(elements)); | ||||
|     } else if (JIValue::kTypeCodeBoolList == typeCode) { | ||||
|       static const auto jMethodGetBoolList = | ||||
|           JIValue::javaClassStatic()->getMethod<jbooleanArray()>("getBoolList"); | ||||
|           JIValue::javaClassStatic()->getMethod<jbooleanArray()>("toBoolList"); | ||||
|       auto jArray = jMethodGetBoolList(jivalue); | ||||
|       auto jArrayPinned = jArray->pin(); | ||||
|       size_t n = jArrayPinned.size(); | ||||
| @ -455,7 +457,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|       return at::IValue{std::move(list)}; | ||||
|     } else if (JIValue::kTypeCodeLongList == typeCode) { | ||||
|       static const auto jMethodGetLongList = | ||||
|           JIValue::javaClassStatic()->getMethod<jlongArray()>("getLongList"); | ||||
|           JIValue::javaClassStatic()->getMethod<jlongArray()>("toLongList"); | ||||
|       auto jArray = jMethodGetLongList(jivalue); | ||||
|       auto jArrayPinned = jArray->pin(); | ||||
|       size_t n = jArrayPinned.size(); | ||||
| @ -468,7 +470,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|     } else if (JIValue::kTypeCodeDoubleList == typeCode) { | ||||
|       static const auto jMethodGetDoubleList = | ||||
|           JIValue::javaClassStatic()->getMethod<jdoubleArray()>( | ||||
|               "getDoubleList"); | ||||
|               "toDoubleList"); | ||||
|       auto jArray = jMethodGetDoubleList(jivalue); | ||||
|       auto jArrayPinned = jArray->pin(); | ||||
|       size_t n = jArrayPinned.size(); | ||||
| @ -482,7 +484,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|       static const auto jMethodGetTensorList = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getMethod<facebook::jni::JArrayClass< | ||||
|                   JTensor::javaobject>::javaobject()>("getTensorList"); | ||||
|                   JTensor::javaobject>::javaobject()>("toTensorList"); | ||||
|       auto jArray = jMethodGetTensorList(jivalue); | ||||
|       size_t n = jArray->size(); | ||||
|       c10::List<at::Tensor> list{}; | ||||
| @ -495,7 +497,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|       static const auto jMethodGetList = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getMethod<facebook::jni::JArrayClass< | ||||
|                   JIValue::javaobject>::javaobject()>("getList"); | ||||
|                   JIValue::javaobject>::javaobject()>("toList"); | ||||
|       auto jarray = jMethodGetList(jivalue); | ||||
|       size_t n = jarray->size(); | ||||
|       if (n == 0) { | ||||
| @ -518,7 +520,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|       static const auto jMethodGetDictStringKey = | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getMethod<facebook::jni::JMap<jstring, JIValue::javaobject>:: | ||||
|                               javaobject()>("getDictStringKey"); | ||||
|                               javaobject()>("toDictStringKey"); | ||||
|       auto jmap = jMethodGetDictStringKey(jivalue); | ||||
|       auto it = jmap->begin(); | ||||
|       if (it == jmap->end()) { | ||||
| @ -541,7 +543,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> { | ||||
|           JIValue::javaClassStatic() | ||||
|               ->getMethod<facebook::jni::JMap< | ||||
|                   facebook::jni::JLong::javaobject, | ||||
|                   JIValue::javaobject>::javaobject()>("getDictLongKey"); | ||||
|                   JIValue::javaobject>::javaobject()>("toDictLongKey"); | ||||
|       auto jmap = jMethodGetDictLongKey(jivalue); | ||||
|       auto it = jmap->begin(); | ||||
|       if (it == jmap->end()) { | ||||
|  | ||||
							
								
								
									
										29
									
								
								android/pytorch_android/src/main/java/org/pytorch/DType.java
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								android/pytorch_android/src/main/java/org/pytorch/DType.java
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,29 @@ | ||||
| package org.pytorch; | ||||
|  | ||||
| /** | ||||
|  * Codes representing tensor data types. | ||||
|  */ | ||||
| public enum DType { | ||||
|   // NOTE: "jniCode" must be kept in sync with pytorch_jni.cpp. | ||||
|   // NOTE: Never serialize "jniCode", because it can change between releases. | ||||
|  | ||||
|   /** Code for dtype torch.uint8. {@link Tensor#dtype()} */ | ||||
|   UINT8(1), | ||||
|   /** Code for dtype torch.int8. {@link Tensor#dtype()} */ | ||||
|   INT8(2), | ||||
|   /** Code for dtype torch.int32. {@link Tensor#dtype()} */ | ||||
|   INT32(3), | ||||
|   /** Code for dtype torch.float32. {@link Tensor#dtype()} */ | ||||
|   FLOAT32(4), | ||||
|   /** Code for dtype torch.int64. {@link Tensor#dtype()} */ | ||||
|   INT64(5), | ||||
|   /** Code for dtype torch.float64. {@link Tensor#dtype()} */ | ||||
|   FLOAT64(6), | ||||
|   ; | ||||
|  | ||||
|   final int jniCode; | ||||
|  | ||||
|   DType(int jniCode) { | ||||
|     this.jniCode = jniCode; | ||||
|   } | ||||
| } | ||||
| @ -4,10 +4,21 @@ import java.util.Locale; | ||||
| import java.util.Map; | ||||
|  | ||||
| /** | ||||
|  * Java representation of a torchscript variable, which is implemented as tagged union that can be | ||||
|  * one of the supported types: https://pytorch.org/docs/stable/jit.html#types. | ||||
|  * Java representation of a TorchScript value, which is implemented as tagged union that can be | ||||
|  * one of the supported types: https://pytorch.org/docs/stable/jit.html#types . | ||||
|  * <p> | ||||
|  * Calling getters for inappropriate types will throw IllegalStateException. | ||||
|  * Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}. | ||||
|  * <p> | ||||
|  * {@code IValue} objects are constructed with {@code IValue.from(value)}, | ||||
|  * {@code IValue.tupleFrom(value1, value2, ...)}, {@code IValue.listFrom(value1, value2, ...)}, | ||||
|  * or one of the {@code dict} methods, depending on the key type. | ||||
|  * <p> | ||||
|  * Data is retrieved from {@code IValue} objects with the {@code toX()} methods.  Note that | ||||
|  * {@code str}-type IValues must be extracted with {@link #toStr()}, | ||||
|  * rather than {@link #toString()}. | ||||
|  * <p> | ||||
|  * {@code IValue} objects may retain references to objects passed into their constructors, | ||||
|  * and may return references to their internal state from {@code toX()}. | ||||
|  */ | ||||
| public class IValue { | ||||
|   private static final int TYPE_CODE_NULL = 1; | ||||
| @ -91,95 +102,98 @@ public class IValue { | ||||
|     return TYPE_CODE_DICT_LONG_KEY == this.mTypeCode; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new {@code IValue} of type {@code Optional} that contains no value. | ||||
|    */ | ||||
|   public static IValue optionalNull() { | ||||
|     return new IValue(TYPE_CODE_NULL); | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance of torchscript Tensor type. | ||||
|    * Creates a new {@code IValue} of type {@code Tensor}. | ||||
|    */ | ||||
|   public static IValue tensor(Tensor tensor) { | ||||
|   public static IValue from(Tensor tensor) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_TENSOR); | ||||
|     iv.mData = tensor; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance of torchscript bool type. | ||||
|    * Creates a new {@code IValue} of type {@code bool}. | ||||
|    */ | ||||
|   public static IValue bool(boolean value) { | ||||
|   public static IValue from(boolean value) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_BOOL); | ||||
|     iv.mData = value; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance of torchscript int type. | ||||
|    * Creates a new {@code IValue} of type {@code int}. | ||||
|    */ | ||||
|   public static IValue long64(long value) { | ||||
|   public static IValue from(long value) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_LONG); | ||||
|     iv.mData = value; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance of torchscript float type. | ||||
|    * Creates a new {@code IValue} of type {@code float}. | ||||
|    */ | ||||
|   public static IValue double64(double value) { | ||||
|   public static IValue from(double value) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_DOUBLE); | ||||
|     iv.mData = value; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates new IValue instance of torchscript str type. | ||||
|    * Creates a new {@code IValue} of type {@code str}. | ||||
|    */ | ||||
|   public static IValue string(String value) { | ||||
|   public static IValue from(String value) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_STRING); | ||||
|     iv.mData = value; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance of torchscript List[bool] type. | ||||
|    * Creates a new {@code IValue} of type {@code List[bool]}. | ||||
|    */ | ||||
|   public static IValue boolList(boolean... list) { | ||||
|   public static IValue listFrom(boolean... list) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_BOOL_LIST); | ||||
|     iv.mData = list; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance of torchscript List[int] type. | ||||
|    * Creates a new {@code IValue} of type {@code List[int]}. | ||||
|    */ | ||||
|   public static IValue longList(long... list) { | ||||
|   public static IValue listFrom(long... list) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_LONG_LIST); | ||||
|     iv.mData = list; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance of torchscript List[float] type. | ||||
|    * Creates a new {@code IValue} of type {@code List[float]}. | ||||
|    */ | ||||
|   public static IValue doubleList(double... list) { | ||||
|   public static IValue listFrom(double... list) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_DOUBLE_LIST); | ||||
|     iv.mData = list; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance of torchscript List[Tensor] type. | ||||
|    * Creates a new {@code IValue} of type {@code List[Tensor]}. | ||||
|    */ | ||||
|   public static IValue tensorList(Tensor... list) { | ||||
|   public static IValue listFrom(Tensor... list) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_TENSOR_LIST); | ||||
|     iv.mData = list; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance of torchscript List[T] type. All elements must have the same type. | ||||
|    * Creates a new {@code IValue} of type {@code List[T]}.  All elements must have the same type. | ||||
|    */ | ||||
|   public static IValue list(IValue... array) { | ||||
|   public static IValue listFrom(IValue... array) { | ||||
|     final int size = array.length; | ||||
|     if (size > 0) { | ||||
|       final int typeCode0 = array[0].mTypeCode; | ||||
| @ -196,93 +210,93 @@ public class IValue { | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance of torchscript Tuple[T0, T1, ...] type. | ||||
|    * Creates a new {@code IValue} of type {@code Tuple[T0, T1, ...]}. | ||||
|    */ | ||||
|   public static IValue tuple(IValue... array) { | ||||
|   public static IValue tupleFrom(IValue... array) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_TUPLE); | ||||
|     iv.mData = array; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance oftorchscript Dict[Str, V] type. | ||||
|    * Creates a new {@code IValue} of type {@code Dict[str, V]}. | ||||
|    */ | ||||
|   public static IValue dictStringKey(Map<String, IValue> map) { | ||||
|   public static IValue dictStringKeyFrom(Map<String, IValue> map) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_DICT_STRING_KEY); | ||||
|     iv.mData = map; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates a new IValue instance of torchscript Dict[int, V] type. | ||||
|    * Creates a new {@code IValue} of type {@code Dict[int, V]}. | ||||
|    */ | ||||
|   public static IValue dictLongKey(Map<Long, IValue> map) { | ||||
|   public static IValue dictLongKeyFrom(Map<Long, IValue> map) { | ||||
|     final IValue iv = new IValue(TYPE_CODE_DICT_LONG_KEY); | ||||
|     iv.mData = map; | ||||
|     return iv; | ||||
|   } | ||||
|  | ||||
|   public Tensor getTensor() { | ||||
|   public Tensor toTensor() { | ||||
|     preconditionType(TYPE_CODE_TENSOR, mTypeCode); | ||||
|     return (Tensor) mData; | ||||
|   } | ||||
|  | ||||
|   public boolean getBool() { | ||||
|   public boolean toBool() { | ||||
|     preconditionType(TYPE_CODE_BOOL, mTypeCode); | ||||
|     return (boolean) mData; | ||||
|   } | ||||
|  | ||||
|   public long getLong() { | ||||
|   public long toLong() { | ||||
|     preconditionType(TYPE_CODE_LONG, mTypeCode); | ||||
|     return (long) mData; | ||||
|   } | ||||
|  | ||||
|   public double getDouble() { | ||||
|   public double toDouble() { | ||||
|     preconditionType(TYPE_CODE_DOUBLE, mTypeCode); | ||||
|     return (double) mData; | ||||
|   } | ||||
|  | ||||
|   public String getString() { | ||||
|   public String toStr() { | ||||
|     preconditionType(TYPE_CODE_STRING, mTypeCode); | ||||
|     return (String) mData; | ||||
|   } | ||||
|  | ||||
|   public boolean[] getBoolList() { | ||||
|   public boolean[] toBoolList() { | ||||
|     preconditionType(TYPE_CODE_BOOL_LIST, mTypeCode); | ||||
|     return (boolean[]) mData; | ||||
|   } | ||||
|  | ||||
|   public long[] getLongList() { | ||||
|   public long[] toLongList() { | ||||
|     preconditionType(TYPE_CODE_LONG_LIST, mTypeCode); | ||||
|     return (long[]) mData; | ||||
|   } | ||||
|  | ||||
|   public double[] getDoubleList() { | ||||
|   public double[] toDoubleList() { | ||||
|     preconditionType(TYPE_CODE_DOUBLE_LIST, mTypeCode); | ||||
|     return (double[]) mData; | ||||
|   } | ||||
|  | ||||
|   public Tensor[] getTensorList() { | ||||
|   public Tensor[] toTensorList() { | ||||
|     preconditionType(TYPE_CODE_TENSOR_LIST, mTypeCode); | ||||
|     return (Tensor[]) mData; | ||||
|   } | ||||
|  | ||||
|   public IValue[] getList() { | ||||
|   public IValue[] toList() { | ||||
|     preconditionType(TYPE_CODE_LIST, mTypeCode); | ||||
|     return (IValue[]) mData; | ||||
|   } | ||||
|  | ||||
|   public IValue[] getTuple() { | ||||
|   public IValue[] toTuple() { | ||||
|     preconditionType(TYPE_CODE_TUPLE, mTypeCode); | ||||
|     return (IValue[]) mData; | ||||
|   } | ||||
|  | ||||
|   public Map<String, IValue> getDictStringKey() { | ||||
|   public Map<String, IValue> toDictStringKey() { | ||||
|     preconditionType(TYPE_CODE_DICT_STRING_KEY, mTypeCode); | ||||
|     return (Map<String, IValue>) mData; | ||||
|   } | ||||
|  | ||||
|   public Map<Long, IValue> getDictLongKey() { | ||||
|   public Map<Long, IValue> toDictLongKey() { | ||||
|     preconditionType(TYPE_CODE_DICT_LONG_KEY, mTypeCode); | ||||
|     return (Map<Long, IValue>) mData; | ||||
|   } | ||||
|  | ||||
| @ -5,21 +5,20 @@ package org.pytorch; | ||||
| import com.facebook.jni.HybridData; | ||||
|  | ||||
| /** | ||||
|  * Java holder for torch::jit::script::Module which owns it on jni side. | ||||
|  * Java wrapper for torch::jit::script::Module. | ||||
|  */ | ||||
| public class Module { | ||||
|  | ||||
|   private NativePeer mNativePeer; | ||||
|  | ||||
|   /** | ||||
|    * Loads serialized torchscript module from the specified absolute path on the disk. | ||||
|    * Loads a serialized TorchScript module from the specified path on the disk. | ||||
|    * | ||||
|    * @param modelAbsolutePath absolute path to file that contains the serialized torchscript module. | ||||
|    * @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module on jni | ||||
|    * side. | ||||
|    * @param modelPath path to file that contains the serialized TorchScript module. | ||||
|    * @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module. | ||||
|    */ | ||||
|   public static Module load(final String modelAbsolutePath) { | ||||
|     return new Module(modelAbsolutePath); | ||||
|   public static Module load(final String modelPath) { | ||||
|     return new Module(modelPath); | ||||
|   } | ||||
|  | ||||
|   private Module(final String moduleAbsolutePath) { | ||||
| @ -27,26 +26,37 @@ public class Module { | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Runs 'forward' method of loaded torchscript module with specified arguments. | ||||
|    * Runs the 'forward' method of this module with the specified arguments. | ||||
|    * | ||||
|    * @param inputs arguments for torchscript module 'forward' method. | ||||
|    * @return result of torchscript module 'forward' method evaluation | ||||
|    * @param inputs arguments for the TorchScript module's 'forward' method. | ||||
|    * @return return value from the 'forward' method. | ||||
|    */ | ||||
|   public IValue forward(IValue... inputs) { | ||||
|     return mNativePeer.forward(inputs); | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Runs specified method of loaded torchscript module with specified arguments. | ||||
|    * Runs the specified method of this module with the specified arguments. | ||||
|    * | ||||
|    * @param methodName torchscript module method to run | ||||
|    * @param inputs     arguments that will be specified to torchscript module method call | ||||
|    * @return result of torchscript module specified method evaluation | ||||
|    * @param methodName name of the TorchScript method to run. | ||||
|    * @param inputs     arguments that will be passed to TorchScript method. | ||||
|    * @return return value from the method. | ||||
|    */ | ||||
|   public IValue runMethod(String methodName, IValue... inputs) { | ||||
|     return mNativePeer.runMethod(methodName, inputs); | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Explicitly destroys the native torch::jit::script::Module. | ||||
|    * Calling this method is not required, as the native object will be destroyed | ||||
|    * when this object is garbage-collected.  However, the timing of garbage collection | ||||
|    * is not guaranteed, so proactively calling {@code destroy} can free memory more quickly. | ||||
|    * See {@link com.facebook.jni.HybridData#resetNative}. | ||||
|    */ | ||||
|   public void destroy() { | ||||
|     mNativePeer.mHybridData.resetNative(); | ||||
|   } | ||||
|  | ||||
|   private static class NativePeer { | ||||
|     static { | ||||
|       System.loadLibrary("pytorch"); | ||||
|  | ||||
| @ -11,24 +11,24 @@ import java.util.Arrays; | ||||
| import java.util.Locale; | ||||
|  | ||||
| /** | ||||
|  * Representation of Tensor. Tensor shape is stored in {@link Tensor#shape}, elements are stored as | ||||
|  * {@link java.nio.DirectByteBuffer} of one of the supported types. | ||||
|  * Representation of a Tensor.  Behavior is similar to PyTorch's tensor objects. | ||||
|  * <p> | ||||
|  * Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)}, | ||||
|  * where {@code data} can be an array or a direct {@link Buffer} (of the proper subclass). | ||||
|  * Helper methods are provided to allocate buffers properly. | ||||
|  * <p> | ||||
|  * To access Tensor data, see {@link #dtype()}, {@link #shape()}, | ||||
|  * and various {@code getDataAs*} methods. | ||||
|  * <p> | ||||
|  * When constructing {@code Tensor} objects with {@code data} as an array, | ||||
|  * it is not specified whether this data is is copied or retained as a reference | ||||
|  * so it is recommended not to modify it after constructing.  {@code data} passed as a | ||||
|  * {@link Buffer} is not copied, so it can be modified between {@link Module} calls | ||||
|  * to avoid reallocation.  Data retrieved from {@code Tensor} objects may be copied or | ||||
|  * may be a reference to the {@code Tensor}'s internal data buffer. | ||||
|  * {@code shape} is always copied. | ||||
|  */ | ||||
| public abstract class Tensor { | ||||
|  | ||||
|   /** Code for dtype torch.uint8. {@link Tensor#dtype()} */ | ||||
|   public static final int DTYPE_UINT8 = 1; | ||||
|   /** Code for dtype torch.int8. {@link Tensor#dtype()} */ | ||||
|   public static final int DTYPE_INT8 = 2; | ||||
|   /** Code for dtype torch.int32. {@link Tensor#dtype()} */ | ||||
|   public static final int DTYPE_INT32 = 3; | ||||
|   /** Code for dtype torch.float32. {@link Tensor#dtype()} */ | ||||
|   public static final int DTYPE_FLOAT32 = 4; | ||||
|   /** Code for dtype torch.int64. {@link Tensor#dtype()} */ | ||||
|   public static final int DTYPE_INT64 = 5; | ||||
|   /** Code for dtype torch.float64. {@link Tensor#dtype()} */ | ||||
|   public static final int DTYPE_FLOAT64 = 6; | ||||
|  | ||||
|   private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null"; | ||||
|   private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null"; | ||||
|   private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null"; | ||||
| @ -39,8 +39,7 @@ public abstract class Tensor { | ||||
|   private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT = | ||||
|       "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)"; | ||||
|  | ||||
|   /** Shape of current tensor. */ | ||||
|   public final long[] shape; | ||||
|   final long[] shape; | ||||
|  | ||||
|   private static final int INT_SIZE_BYTES = 4; | ||||
|   private static final int FLOAT_SIZE_BYTES = 4; | ||||
| @ -49,8 +48,8 @@ public abstract class Tensor { | ||||
|  | ||||
|   /** | ||||
|    * Allocates a new direct {@link java.nio.ByteBuffer} with native byte order with specified | ||||
|    * capacity that can be used in {@link Tensor#newInt8Tensor(long[], ByteBuffer)}, {@link | ||||
|    * Tensor#newUInt8Tensor(long[], ByteBuffer)}. | ||||
|    * capacity that can be used in {@link Tensor#fromBlob(ByteBuffer, long[])}, | ||||
|    * {@link Tensor#fromBlobUnsigned(ByteBuffer, long[])}. | ||||
|    * | ||||
|    * @param numElements capacity (number of elements) of result buffer. | ||||
|    */ | ||||
| @ -58,6 +57,12 @@ public abstract class Tensor { | ||||
|     return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder()); | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Allocates a new direct {@link java.nio.IntBuffer} with native byte order with specified | ||||
|    * capacity that can be used in {@link Tensor#fromBlob(IntBuffer, long[])}. | ||||
|    * | ||||
|    * @param numElements capacity (number of elements) of result buffer. | ||||
|    */ | ||||
|   public static IntBuffer allocateIntBuffer(int numElements) { | ||||
|     return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES) | ||||
|         .order(ByteOrder.nativeOrder()) | ||||
| @ -66,7 +71,7 @@ public abstract class Tensor { | ||||
|  | ||||
|   /** | ||||
|    * Allocates a new direct {@link java.nio.FloatBuffer} with native byte order with specified | ||||
|    * capacity that can be used in {@link Tensor#newFloat32Tensor(long[], FloatBuffer)}. | ||||
|    * capacity that can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}. | ||||
|    * | ||||
|    * @param numElements capacity (number of elements) of result buffer. | ||||
|    */ | ||||
| @ -78,7 +83,7 @@ public abstract class Tensor { | ||||
|  | ||||
|   /** | ||||
|    * Allocates a new direct {@link java.nio.LongBuffer} with native byte order with specified | ||||
|    * capacity that can be used in {@link Tensor#newInt64Tensor(long[], LongBuffer)}. | ||||
|    * capacity that can be used in {@link Tensor#fromBlob(LongBuffer, long[])}. | ||||
|    * | ||||
|    * @param numElements capacity (number of elements) of result buffer. | ||||
|    */ | ||||
| @ -90,7 +95,7 @@ public abstract class Tensor { | ||||
|  | ||||
|   /** | ||||
|    * Allocates a new direct {@link java.nio.DoubleBuffer} with native byte order with specified | ||||
|    * capacity that can be used in {@link Tensor#newFloat64Tensor(long[], DoubleBuffer)}. | ||||
|    * capacity that can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}. | ||||
|    * | ||||
|    * @param numElements capacity (number of elements) of result buffer. | ||||
|    */ | ||||
| @ -104,10 +109,10 @@ public abstract class Tensor { | ||||
|    * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of | ||||
|    * bytes. | ||||
|    * | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Tensor elements | ||||
|    * @param shape Tensor shape | ||||
|    */ | ||||
|   public static Tensor newUInt8Tensor(long[] shape, byte[] data) { | ||||
|   public static Tensor fromBlobUnsigned(byte[] data, long[] shape) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -121,10 +126,10 @@ public abstract class Tensor { | ||||
|    * Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of | ||||
|    * bytes. | ||||
|    * | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Tensor elements | ||||
|    * @param shape Tensor shape | ||||
|    */ | ||||
|   public static Tensor newInt8Tensor(long[] shape, byte[] data) { | ||||
|   public static Tensor fromBlob(byte[] data, long[] shape) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -138,10 +143,10 @@ public abstract class Tensor { | ||||
|    * Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of | ||||
|    * ints. | ||||
|    * | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Tensor elements | ||||
|    * @param shape Tensor shape | ||||
|    */ | ||||
|   public static Tensor newInt32Tensor(long[] shape, int[] data) { | ||||
|   public static Tensor fromBlob(int[] data, long[] shape) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -155,10 +160,10 @@ public abstract class Tensor { | ||||
|    * Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array | ||||
|    * of floats. | ||||
|    * | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Tensor elements | ||||
|    * @param shape Tensor shape | ||||
|    */ | ||||
|   public static Tensor newFloat32Tensor(long[] shape, float[] data) { | ||||
|   public static Tensor fromBlob(float[] data, long[] shape) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -172,10 +177,10 @@ public abstract class Tensor { | ||||
|    * Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of | ||||
|    * longs. | ||||
|    * | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Tensor elements | ||||
|    * @param shape Tensor shape | ||||
|    */ | ||||
|   public static Tensor newInt64Tensor(long[] shape, long[] data) { | ||||
|   public static Tensor fromBlob(long[] data, long[] shape) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -192,7 +197,7 @@ public abstract class Tensor { | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Tensor elements | ||||
|    */ | ||||
|   public static Tensor newFloat64Tensor(long[] shape, double[] data) { | ||||
|   public static Tensor fromBlob(long[] shape, double[] data) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -205,12 +210,12 @@ public abstract class Tensor { | ||||
|   /** | ||||
|    * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data. | ||||
|    * | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} | ||||
|    *     elements. The buffer is used directly without copying, and changes to its content will | ||||
|    *     change the tensor. | ||||
|    * @param shape Tensor shape | ||||
|    */ | ||||
|   public static Tensor newUInt8Tensor(long[] shape, ByteBuffer data) { | ||||
|   public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -225,12 +230,12 @@ public abstract class Tensor { | ||||
|   /** | ||||
|    * Creates a new Tensor instance with dtype torch.int8 with specified shape and data. | ||||
|    * | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} | ||||
|    *     elements. The buffer is used directly without copying, and changes to its content will | ||||
|    *     change the tensor. | ||||
|    * @param shape Tensor shape | ||||
|    */ | ||||
|   public static Tensor newInt8Tensor(long[] shape, ByteBuffer data) { | ||||
|   public static Tensor fromBlob(ByteBuffer data, long[] shape) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -245,12 +250,12 @@ public abstract class Tensor { | ||||
|   /** | ||||
|    * Creates a new Tensor instance with dtype torch.int32 with specified shape and data. | ||||
|    * | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} | ||||
|    *     elements. The buffer is used directly without copying, and changes to its content will | ||||
|    *     change the tensor. | ||||
|    * @param shape Tensor shape | ||||
|    */ | ||||
|   public static Tensor newInt32Tensor(long[] shape, IntBuffer data) { | ||||
|   public static Tensor fromBlob(IntBuffer data, long[] shape) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -265,12 +270,12 @@ public abstract class Tensor { | ||||
|   /** | ||||
|    * Creates a new Tensor instance with dtype torch.float32 with specified shape and data. | ||||
|    * | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} | ||||
|    *     elements. The buffer is used directly without copying, and changes to its content will | ||||
|    *     change the tensor. | ||||
|    * @param shape Tensor shape | ||||
|    */ | ||||
|   public static Tensor newFloat32Tensor(long[] shape, FloatBuffer data) { | ||||
|   public static Tensor fromBlob(FloatBuffer data, long[] shape) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -285,12 +290,12 @@ public abstract class Tensor { | ||||
|   /** | ||||
|    * Creates a new Tensor instance with dtype torch.int64 with specified shape and data. | ||||
|    * | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} | ||||
|    *     elements. The buffer is used directly without copying, and changes to its content will | ||||
|    *     change the tensor. | ||||
|    * @param shape Tensor shape | ||||
|    */ | ||||
|   public static Tensor newInt64Tensor(long[] shape, LongBuffer data) { | ||||
|   public static Tensor fromBlob(LongBuffer data, long[] shape) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -305,12 +310,12 @@ public abstract class Tensor { | ||||
|   /** | ||||
|    * Creates a new Tensor instance with dtype torch.float64 with specified shape and data. | ||||
|    * | ||||
|    * @param shape Tensor shape | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} | ||||
|    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} | ||||
|    *     elements. The buffer is used directly without copying, and changes to its content will | ||||
|    *     change the tensor. | ||||
|    * @param shape Tensor shape | ||||
|    */ | ||||
|   public static Tensor newFloat64Tensor(long[] shape, DoubleBuffer data) { | ||||
|   public static Tensor fromBlob(DoubleBuffer data, long[] shape) { | ||||
|     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); | ||||
|     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); | ||||
|     checkShape(shape); | ||||
| @ -327,12 +332,14 @@ public abstract class Tensor { | ||||
|     this.shape = Arrays.copyOf(shape, shape.length); | ||||
|   } | ||||
|  | ||||
|   /** Calculates number of elements in current tensor instance. */ | ||||
|   /** Returns the number of elements in this tensor. */ | ||||
|   public long numel() { | ||||
|     return numel(this.shape); | ||||
|   } | ||||
|  | ||||
|   /** Calculates number of elements in tensor with specified shape. */ | ||||
|   /** | ||||
|    * Calculates the number of elements in a tensor with the specified shape. | ||||
|    */ | ||||
|   public static long numel(long[] shape) { | ||||
|     checkShape(shape); | ||||
|     int result = 1; | ||||
| @ -343,14 +350,24 @@ public abstract class Tensor { | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Returns dtype of current tensor. Can be one of {@link Tensor#DTYPE_UINT8}, {@link | ||||
|    * Tensor#DTYPE_INT8}, {@link Tensor#DTYPE_INT32},{@link Tensor#DTYPE_FLOAT32}, {@link | ||||
|    * Tensor#DTYPE_INT64}, {@link Tensor#DTYPE_FLOAT64}. | ||||
|    * Returns the shape of this tensor.  (The array is a fresh copy.) | ||||
|    */ | ||||
|   public abstract int dtype(); | ||||
|   public long[] shape() { | ||||
|     return Arrays.copyOf(shape, shape.length); | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Returns newly allocated java byte array that contains a copy of tensor data. | ||||
|    * @return data type of this tensor. | ||||
|    */ | ||||
|   public abstract DType dtype(); | ||||
|  | ||||
|   // Called from native | ||||
|   int dtypeJniCode() { | ||||
|     return dtype().jniCode; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * @return a Java byte array that contains the tensor data.  This may be a copy or reference. | ||||
|    * | ||||
|    * @throws IllegalStateException if it is called for a non-int8 tensor. | ||||
|    */ | ||||
| @ -360,7 +377,7 @@ public abstract class Tensor { | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Returns newly allocated java byte array that contains a copy of tensor data. | ||||
|    * @return a Java byte array that contains the tensor data.  This may be a copy or reference. | ||||
|    * | ||||
|    * @throws IllegalStateException if it is called for a non-uint8 tensor. | ||||
|    */ | ||||
| @ -370,7 +387,7 @@ public abstract class Tensor { | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Returns newly allocated java byte array that contains a copy of tensor data. | ||||
|    * @return a Java int array that contains the tensor data.  This may be a copy or reference. | ||||
|    * | ||||
|    * @throws IllegalStateException if it is called for a non-int32 tensor. | ||||
|    */ | ||||
| @ -380,7 +397,7 @@ public abstract class Tensor { | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Returns newly allocated java byte array that contains a copy of tensor data. | ||||
|    * @return a Java float array that contains the tensor data.  This may be a copy or reference. | ||||
|    * | ||||
|    * @throws IllegalStateException if it is called for a non-float32 tensor. | ||||
|    */ | ||||
| @ -390,7 +407,7 @@ public abstract class Tensor { | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Returns newly allocated java byte array that contains a copy of tensor data. | ||||
|    * @return a Java long array that contains the tensor data.  This may be a copy or reference. | ||||
|    * | ||||
|    * @throws IllegalStateException if it is called for a non-int64 tensor. | ||||
|    */ | ||||
| @ -400,7 +417,7 @@ public abstract class Tensor { | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Returns newly allocated java byte array that contains a copy of tensor data. | ||||
|    * @return a Java double array that contains the tensor data.  This may be a copy or reference. | ||||
|    * | ||||
|    * @throws IllegalStateException if it is called for a non-float64 tensor. | ||||
|    */ | ||||
| @ -423,8 +440,8 @@ public abstract class Tensor { | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public int dtype() { | ||||
|       return DTYPE_UINT8; | ||||
|     public DType dtype() { | ||||
|       return DType.UINT8; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
| @ -455,8 +472,8 @@ public abstract class Tensor { | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public int dtype() { | ||||
|       return DTYPE_INT8; | ||||
|     public DType dtype() { | ||||
|       return DType.INT8; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
| @ -487,8 +504,8 @@ public abstract class Tensor { | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public int dtype() { | ||||
|       return DTYPE_INT32; | ||||
|     public DType dtype() { | ||||
|       return DType.INT32; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
| @ -527,8 +544,8 @@ public abstract class Tensor { | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public int dtype() { | ||||
|       return DTYPE_FLOAT32; | ||||
|     public DType dtype() { | ||||
|       return DType.FLOAT32; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
| @ -551,8 +568,8 @@ public abstract class Tensor { | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public int dtype() { | ||||
|       return DTYPE_INT64; | ||||
|     public DType dtype() { | ||||
|       return DType.INT64; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
| @ -583,8 +600,8 @@ public abstract class Tensor { | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public int dtype() { | ||||
|       return DTYPE_FLOAT64; | ||||
|     public DType dtype() { | ||||
|       return DType.FLOAT64; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
| @ -634,17 +651,17 @@ public abstract class Tensor { | ||||
|  | ||||
|   // Called from native | ||||
|   private static Tensor nativeNewTensor(ByteBuffer data, long[] shape, int dtype) { | ||||
|     if (DTYPE_FLOAT32 == dtype) { | ||||
|     if (DType.FLOAT32.jniCode == dtype) { | ||||
|       return new Tensor_float32(data.asFloatBuffer(), shape); | ||||
|     } else if (DTYPE_INT32 == dtype) { | ||||
|     } else if (DType.INT32.jniCode == dtype) { | ||||
|       return new Tensor_int32(data.asIntBuffer(), shape); | ||||
|     } else if (DTYPE_INT64 == dtype) { | ||||
|     } else if (DType.INT64.jniCode == dtype) { | ||||
|       return new Tensor_int64(data.asLongBuffer(), shape); | ||||
|     } else if (DTYPE_FLOAT64 == dtype) { | ||||
|     } else if (DType.FLOAT64.jniCode == dtype) { | ||||
|       return new Tensor_float64(data.asDoubleBuffer(), shape); | ||||
|     } else if (DTYPE_UINT8 == dtype) { | ||||
|     } else if (DType.UINT8.jniCode == dtype) { | ||||
|       return new Tensor_uint8(data, shape); | ||||
|     } else if (DTYPE_INT8 == dtype) { | ||||
|     } else if (DType.INT8.jniCode == dtype) { | ||||
|       return new Tensor_int8(data, shape); | ||||
|     } | ||||
|     throw new IllegalArgumentException("Unknown Tensor dtype"); | ||||
|  | ||||
| @ -7,6 +7,7 @@ import android.media.Image; | ||||
| import org.pytorch.Tensor; | ||||
|  | ||||
| import java.nio.ByteBuffer; | ||||
| import java.nio.FloatBuffer; | ||||
| import java.util.Locale; | ||||
|  | ||||
| /** | ||||
| @ -26,7 +27,7 @@ public final class TensorImageUtils { | ||||
|    * @param normStdRGB  standard deviation for RGB channels normalization, length must equal 3, RGB order | ||||
|    */ | ||||
|   public static Tensor bitmapToFloat32Tensor( | ||||
|       final Bitmap bitmap, float[] normMeanRGB, float normStdRGB[]) { | ||||
|       final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[]) { | ||||
|     checkNormMeanArg(normMeanRGB); | ||||
|     checkNormStdArg(normStdRGB); | ||||
|  | ||||
| @ -34,6 +35,52 @@ public final class TensorImageUtils { | ||||
|         bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB); | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Writes tensor content from specified {@link android.graphics.Bitmap}, | ||||
|    * normalized with specified in parameters mean and std to specified {@link java.nio.FloatBuffer} | ||||
|    * with specified offset. | ||||
|    * | ||||
|    * @param bitmap      {@link android.graphics.Bitmap} as a source for Tensor data | ||||
|    * @param x           - x coordinate of top left corner of bitmap's area | ||||
|    * @param y           - y coordinate of top left corner of bitmap's area | ||||
|    * @param width       - width of bitmap's area | ||||
|    * @param height      - height of bitmap's area | ||||
|    * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order | ||||
|    * @param normStdRGB  standard deviation for RGB channels normalization, length must equal 3, RGB order | ||||
|    */ | ||||
|   public static void bitmapToFloatBuffer( | ||||
|       final Bitmap bitmap, | ||||
|       final int x, | ||||
|       final int y, | ||||
|       final int width, | ||||
|       final int height, | ||||
|       final float[] normMeanRGB, | ||||
|       final float[] normStdRGB, | ||||
|       final FloatBuffer outBuffer, | ||||
|       final int outBufferOffset) { | ||||
|     checkOutBufferCapacity(outBuffer, outBufferOffset, width, height); | ||||
|     checkNormMeanArg(normMeanRGB); | ||||
|     checkNormStdArg(normStdRGB); | ||||
|  | ||||
|     final int pixelsCount = height * width; | ||||
|     final int[] pixels = new int[pixelsCount]; | ||||
|     bitmap.getPixels(pixels, 0, width, x, y, width, height); | ||||
|     final int offset_g = pixelsCount; | ||||
|     final int offset_b = 2 * pixelsCount; | ||||
|     for (int i = 0; i < pixelsCount; i++) { | ||||
|       final int c = pixels[i]; | ||||
|       float r = ((c >> 16) & 0xff) / 255.0f; | ||||
|       float g = ((c >> 8) & 0xff) / 255.0f; | ||||
|       float b = ((c) & 0xff) / 255.0f; | ||||
|       float rF = (r - normMeanRGB[0]) / normStdRGB[0]; | ||||
|       float gF = (g - normMeanRGB[1]) / normStdRGB[1]; | ||||
|       float bF = (b - normMeanRGB[2]) / normStdRGB[2]; | ||||
|       outBuffer.put(outBufferOffset + i, rF); | ||||
|       outBuffer.put(outBufferOffset + offset_g + i, gF); | ||||
|       outBuffer.put(outBufferOffset + offset_b + i, bF); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Creates new {@link org.pytorch.Tensor} from specified area of {@link android.graphics.Bitmap}, | ||||
|    * normalized with specified in parameters mean and std. | ||||
| @ -57,22 +104,9 @@ public final class TensorImageUtils { | ||||
|     checkNormMeanArg(normMeanRGB); | ||||
|     checkNormStdArg(normStdRGB); | ||||
|  | ||||
|     final int pixelsCount = height * width; | ||||
|     final int[] pixels = new int[pixelsCount]; | ||||
|     bitmap.getPixels(pixels, 0, width, x, y, width, height); | ||||
|     final float[] floatArray = new float[3 * pixelsCount]; | ||||
|     final int offset_g = pixelsCount; | ||||
|     final int offset_b = 2 * pixelsCount; | ||||
|     for (int i = 0; i < pixelsCount; i++) { | ||||
|       final int c = pixels[i]; | ||||
|       float r = ((c >> 16) & 0xff) / 255.0f; | ||||
|       float g = ((c >> 8) & 0xff) / 255.0f; | ||||
|       float b = ((c) & 0xff) / 255.0f; | ||||
|       floatArray[i] = (r - normMeanRGB[0]) / normStdRGB[0]; | ||||
|       floatArray[offset_g + i] = (g - normMeanRGB[1]) / normStdRGB[1]; | ||||
|       floatArray[offset_b + i] = (b - normMeanRGB[2]) / normStdRGB[2]; | ||||
|     } | ||||
|     return Tensor.newFloat32Tensor(new long[]{1, 3, height, width}, floatArray); | ||||
|     final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height); | ||||
|     bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0); | ||||
|     return Tensor.fromBlob(floatBuffer, new long[]{1, 3, height, width}); | ||||
|   } | ||||
|  | ||||
|   /** | ||||
| @ -105,6 +139,52 @@ public final class TensorImageUtils { | ||||
|     checkRotateCWDegrees(rotateCWDegrees); | ||||
|     checkTensorSize(tensorWidth, tensorHeight); | ||||
|  | ||||
|     final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * tensorWidth * tensorHeight); | ||||
|     imageYUV420CenterCropToFloatBuffer( | ||||
|         image, | ||||
|         rotateCWDegrees, | ||||
|         tensorWidth, | ||||
|         tensorHeight, | ||||
|         normMeanRGB, normStdRGB, floatBuffer, 0); | ||||
|     return Tensor.fromBlob(floatBuffer, new long[]{1, 3, tensorHeight, tensorWidth}); | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Writes tensor content from specified {@link android.media.Image}, doing optional rotation, | ||||
|    * scaling (nearest) and center cropping to specified {@link java.nio.FloatBuffer} with specified offset. | ||||
|    * | ||||
|    * @param image           {@link android.media.Image} as a source for Tensor data | ||||
|    * @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be | ||||
|    *                        upright. Range of valid values: 0, 90, 180, 270 | ||||
|    * @param tensorWidth     return tensor width, must be positive | ||||
|    * @param tensorHeight    return tensor height, must be positive | ||||
|    * @param normMeanRGB     means for RGB channels normalization, length must equal 3, RGB order | ||||
|    * @param normStdRGB      standard deviation for RGB channels normalization, length must equal 3, RGB order | ||||
|    * @param outBuffer       Output buffer, where tensor content will be written | ||||
|    * @param outBufferOffset Output buffer offset with which tensor content will be written | ||||
|    */ | ||||
|   public static void imageYUV420CenterCropToFloatBuffer( | ||||
|       final Image image, | ||||
|       int rotateCWDegrees, | ||||
|       final int tensorWidth, | ||||
|       final int tensorHeight, | ||||
|       float[] normMeanRGB, | ||||
|       float[] normStdRGB, | ||||
|       final FloatBuffer outBuffer, | ||||
|       final int outBufferOffset) { | ||||
|     checkOutBufferCapacity(outBuffer, outBufferOffset, tensorWidth, tensorHeight); | ||||
|  | ||||
|     if (image.getFormat() != ImageFormat.YUV_420_888) { | ||||
|       throw new IllegalArgumentException( | ||||
|           String.format( | ||||
|               Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat())); | ||||
|     } | ||||
|  | ||||
|     checkNormMeanArg(normMeanRGB); | ||||
|     checkNormStdArg(normStdRGB); | ||||
|     checkRotateCWDegrees(rotateCWDegrees); | ||||
|     checkTensorSize(tensorWidth, tensorHeight); | ||||
|  | ||||
|     final int widthBeforeRotation = image.getWidth(); | ||||
|     final int heightBeforeRotation = image.getHeight(); | ||||
|  | ||||
| @ -158,7 +238,6 @@ public final class TensorImageUtils { | ||||
|     final int channelSize = tensorHeight * tensorWidth; | ||||
|     final int tensorInputOffsetG = channelSize; | ||||
|     final int tensorInputOffsetB = 2 * channelSize; | ||||
|     final float[] floatArray = new float[3 * channelSize]; | ||||
|     for (int x = 0; x < tensorWidth; x++) { | ||||
|       for (int y = 0; y < tensorHeight; y++) { | ||||
|  | ||||
| @ -198,13 +277,22 @@ public final class TensorImageUtils { | ||||
|         int r = clamp((a0 + a1) >> 10, 0, 255); | ||||
|         int g = clamp((a0 - a2 - a3) >> 10, 0, 255); | ||||
|         int b = clamp((a0 + a4) >> 10, 0, 255); | ||||
|         final int offset = y * tensorWidth + x; | ||||
|         floatArray[offset] = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0]; | ||||
|         floatArray[tensorInputOffsetG + offset] = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1]; | ||||
|         floatArray[tensorInputOffsetB + offset] = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2]; | ||||
|         final int offset = outBufferOffset + y * tensorWidth + x; | ||||
|         float rF = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0]; | ||||
|         float gF = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1]; | ||||
|         float bF = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2]; | ||||
|  | ||||
|         outBuffer.put(offset, rF); | ||||
|         outBuffer.put(offset + tensorInputOffsetG, gF); | ||||
|         outBuffer.put(offset + tensorInputOffsetB, bF); | ||||
|       } | ||||
|     } | ||||
|     return Tensor.newFloat32Tensor(new long[]{1, 3, tensorHeight, tensorWidth}, floatArray); | ||||
|   } | ||||
|  | ||||
|   private static void checkOutBufferCapacity(FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) { | ||||
|     if (outBufferOffset + 3 * tensorWidth * tensorHeight > outBuffer.capacity()) { | ||||
|       throw new IllegalStateException("Buffer underflow"); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   private static void checkTensorSize(int tensorWidth, int tensorHeight) { | ||||
|  | ||||
| @ -243,7 +243,20 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE) | ||||
|   set(BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE) | ||||
|   set(BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE) | ||||
|   set(BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE) | ||||
|   set(OLD_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE}) | ||||
|   if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND | ||||
|       CMAKE_C_COMPILER_VERSION VERSION_GREATER 6.9 AND CMAKE_C_COMPILER_VERSION VERSION_LESS 8) | ||||
|     set(GCC_7 True) | ||||
|   else() | ||||
|     set(GCC_7 False) | ||||
|   endif() | ||||
|   if(GCC_7) | ||||
|     set(CMAKE_BUILD_TYPE Release)  # Always build Sleef as a Release build to work around a gcc-7 bug | ||||
|   endif() | ||||
|   add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/sleef" ${CMAKE_BINARY_DIR}/sleef) | ||||
|   if(GCC_7) | ||||
|     set(CMAKE_BUILD_TYPE ${OLD_CMAKE_BUILD_TYPE}) | ||||
|   endif() | ||||
|   set_property(TARGET sleef PROPERTY FOLDER "dependencies") | ||||
|   list(APPEND ATen_THIRD_PARTY_INCLUDE ${CMAKE_BINARY_DIR}/include) | ||||
|   link_directories(${CMAKE_BINARY_DIR}/sleef/lib) | ||||
|  | ||||
| @ -77,10 +77,8 @@ TaskThreadPoolBase& _get_intraop_pool() { | ||||
|  | ||||
| #endif // C10_MOBILE | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| namespace internal { | ||||
|  | ||||
| // Run lambda function `fn` over `task_id` in [0, `range`) with threadpool. | ||||
| // `fn` will be called with params: (thread_pool_task_id, task_id). | ||||
| void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) { | ||||
| #ifndef C10_MOBILE | ||||
|   for (size_t i = 1; i < range; ++i) { | ||||
| @ -101,14 +99,63 @@ void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) { | ||||
| #endif // C10_MOBILE | ||||
| } | ||||
|  | ||||
| ParallelRegionGuard::ParallelRegionGuard(int64_t task_id) { | ||||
|   _set_thread_num(task_id); | ||||
|   _set_in_parallel_region(true); | ||||
| } | ||||
| // RAII guard helps to support in_parallel_region() and get_thread_num() API. | ||||
| struct ParallelRegionGuard { | ||||
|   ParallelRegionGuard(int64_t task_id) { | ||||
|     _set_thread_num(task_id); | ||||
|     _set_in_parallel_region(true); | ||||
|   } | ||||
|  | ||||
| ParallelRegionGuard::~ParallelRegionGuard() { | ||||
|   _set_in_parallel_region(false); | ||||
|   _unset_thread_num(); | ||||
|   ~ParallelRegionGuard() { | ||||
|     _set_in_parallel_region(false); | ||||
|     _unset_thread_num(); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| namespace internal { | ||||
|  | ||||
| void _parallel_run( | ||||
|   const int64_t begin, | ||||
|   const int64_t end, | ||||
|   const int64_t grain_size, | ||||
|   const std::function<void(int64_t, int64_t, size_t)>& f) { | ||||
|   size_t num_tasks, chunk_size; | ||||
|   std::tie(num_tasks, chunk_size) = | ||||
|       internal::calc_num_tasks_and_chunk_size(begin, end, grain_size); | ||||
|  | ||||
|   std::atomic_flag err_flag = ATOMIC_FLAG_INIT; | ||||
|   std::exception_ptr eptr; | ||||
|   std::vector<std::shared_ptr<c10::ivalue::Future>> futures(num_tasks); | ||||
|   for (size_t task_id = 0; task_id < num_tasks; ++task_id) { | ||||
|     futures[task_id] = std::make_shared<c10::ivalue::Future>(c10::NoneType::get()); | ||||
|   } | ||||
|   auto task = [f, &eptr, &err_flag, &futures, begin, end, chunk_size] | ||||
|       (int /* unused */, size_t task_id) { | ||||
|     int64_t local_start = begin + task_id * chunk_size; | ||||
|     if (local_start < end) { | ||||
|       int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start)); | ||||
|       try { | ||||
|         ParallelRegionGuard guard(task_id); | ||||
|         f(local_start, local_end, task_id); | ||||
|       } catch (...) { | ||||
|         if (!err_flag.test_and_set()) { | ||||
|           eptr = std::current_exception(); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     futures[task_id]->markCompleted(); | ||||
|   }; | ||||
|   _run_with_pool(task, num_tasks); | ||||
|  | ||||
|   // Wait for all tasks to finish. | ||||
|   for (size_t task_id = 0; task_id < num_tasks; ++task_id) { | ||||
|     futures[task_id]->wait(); | ||||
|   } | ||||
|   if (eptr) { | ||||
|     std::rethrow_exception(eptr); | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace internal | ||||
|  | ||||
| @ -9,16 +9,6 @@ | ||||
| namespace at { | ||||
| namespace internal { | ||||
|  | ||||
| // Run lambda function `fn` over `task_id` in [0, `range`) with threadpool. | ||||
| // `fn` will be called with params: (thread_pool_task_id, task_id). | ||||
| CAFFE2_API void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range); | ||||
|  | ||||
| // RAII guard helps to support in_parallel_region() and get_thread_num() API. | ||||
| struct CAFFE2_API ParallelRegionGuard { | ||||
|   ParallelRegionGuard(int64_t task_id); | ||||
|   ~ParallelRegionGuard(); | ||||
| }; | ||||
|  | ||||
| inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size( | ||||
|     int64_t begin, int64_t end, int64_t grain_size) { | ||||
|   if ((end - begin) < grain_size) { | ||||
| @ -32,6 +22,12 @@ inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size( | ||||
|   return std::make_tuple(num_tasks, chunk_size); | ||||
| } | ||||
|  | ||||
| CAFFE2_API void _parallel_run( | ||||
|   const int64_t begin, | ||||
|   const int64_t end, | ||||
|   const int64_t grain_size, | ||||
|   const std::function<void(int64_t, int64_t, size_t)>& f); | ||||
|  | ||||
| } // namespace internal | ||||
|  | ||||
| template <class F> | ||||
| @ -48,40 +44,14 @@ inline void parallel_for( | ||||
|     f(begin, end); | ||||
|     return; | ||||
|   } | ||||
|   size_t num_tasks, chunk_size; | ||||
|   std::tie(num_tasks, chunk_size) = | ||||
|       internal::calc_num_tasks_and_chunk_size(begin, end, grain_size); | ||||
|   std::atomic_flag err_flag = ATOMIC_FLAG_INIT; | ||||
|   std::exception_ptr eptr; | ||||
|   std::vector<std::shared_ptr<c10::ivalue::Future>> futures(num_tasks); | ||||
|   for (size_t task_id = 0; task_id < num_tasks; ++task_id) { | ||||
|     futures[task_id] = std::make_shared<c10::ivalue::Future>(c10::NoneType::get()); | ||||
|   } | ||||
|   auto task = [f, &eptr, &err_flag, &futures, begin, end, chunk_size] | ||||
|       (int _, size_t task_id) { | ||||
|     int64_t local_start = begin + task_id * chunk_size; | ||||
|     if (local_start < end) { | ||||
|       int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start)); | ||||
|       try { | ||||
|         internal::ParallelRegionGuard guard(task_id); | ||||
|         f(local_start, local_end); | ||||
|       } catch (...) { | ||||
|         if (!err_flag.test_and_set()) { | ||||
|           eptr = std::current_exception(); | ||||
|         } | ||||
|   internal::_parallel_run( | ||||
|       begin, | ||||
|       end, | ||||
|       grain_size, | ||||
|       [f](int64_t start, int64_t end, size_t /* unused */) { | ||||
|         f(start, end); | ||||
|       } | ||||
|     } | ||||
|     futures[task_id]->markCompleted(); | ||||
|   }; | ||||
|   internal::_run_with_pool(task, num_tasks); | ||||
|  | ||||
|   // Wait for all tasks to finish. | ||||
|   for (size_t task_id = 0; task_id < num_tasks; ++task_id) { | ||||
|     futures[task_id]->wait(); | ||||
|   } | ||||
|   if (eptr) { | ||||
|     std::rethrow_exception(eptr); | ||||
|   } | ||||
|   ); | ||||
| } | ||||
|  | ||||
| template <class scalar_t, class F, class SF> | ||||
| @ -104,39 +74,14 @@ inline scalar_t parallel_reduce( | ||||
|       internal::calc_num_tasks_and_chunk_size(begin, end, grain_size); | ||||
|   std::vector<scalar_t> results(num_tasks); | ||||
|   scalar_t* results_data = results.data(); | ||||
|  | ||||
|   std::atomic_flag err_flag = ATOMIC_FLAG_INIT; | ||||
|   std::exception_ptr eptr; | ||||
|   std::vector<std::shared_ptr<c10::ivalue::Future>> futures(num_tasks); | ||||
|   for (size_t task_id = 0; task_id < num_tasks; ++task_id) { | ||||
|     futures[task_id] = std::make_shared<c10::ivalue::Future>(c10::NoneType::get()); | ||||
|   } | ||||
|   auto task = [f, ident, results_data, &eptr, &err_flag, &futures, begin, end, chunk_size] | ||||
|       (int _, size_t task_id) { | ||||
|     int64_t local_start = begin + task_id * chunk_size; | ||||
|     if (local_start < end) { | ||||
|       int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start)); | ||||
|       try { | ||||
|         internal::ParallelRegionGuard guard(task_id); | ||||
|         results_data[task_id] = f(local_start, local_end, ident); | ||||
|       } catch (...) { | ||||
|         if (!err_flag.test_and_set()) { | ||||
|           eptr = std::current_exception(); | ||||
|         } | ||||
|   internal::_parallel_run( | ||||
|       begin, | ||||
|       end, | ||||
|       grain_size, | ||||
|       [f, ident, results_data](int64_t start, int64_t end, size_t task_id) { | ||||
|         results_data[task_id] = f(start, end, ident); | ||||
|       } | ||||
|     } | ||||
|     futures[task_id]->markCompleted(); | ||||
|   }; | ||||
|   internal::_run_with_pool(task, num_tasks); | ||||
|  | ||||
|   // Wait for all tasks to finish. | ||||
|   for (size_t task_id = 0; task_id < num_tasks; ++task_id) { | ||||
|     futures[task_id]->wait(); | ||||
|   } | ||||
|   if (eptr) { | ||||
|     std::rethrow_exception(eptr); | ||||
|   } | ||||
|  | ||||
|   ); | ||||
|   scalar_t result = ident; | ||||
|   for (auto partial_result : results) { | ||||
|     result = sf(result, partial_result); | ||||
|  | ||||
| @ -1218,6 +1218,8 @@ bool aten_op_is_not_moved_to_c10_yet(const c10::OperatorName& opName) { | ||||
|         {"aten::quantize_per_channel", ""}, | ||||
|         {"aten::q_per_channel_axis", ""}, | ||||
|         {"aten::qscheme", ""}, | ||||
|         {"aten::fake_quantize_per_channel_affine", ""}, | ||||
|         {"aten::fake_quantize_per_channel_affine_backward", ""}, | ||||
|         {"aten::to", "dtype_layout"}, | ||||
|         {"aten::to", "device"}, | ||||
|         {"aten::to", "dtype"}, | ||||
|  | ||||
| @ -77,6 +77,7 @@ namespace c10 { | ||||
|   _(prim, requires_grad)             \ | ||||
|   _(prim, AutogradAdd)               \ | ||||
|   _(prim, GradOf)                    \ | ||||
|   _(aten, grad)                      \ | ||||
|   _(aten, backward)                  \ | ||||
|   _(prim, Guard)                     \ | ||||
|   _(prim, BailOut)                   \ | ||||
|  | ||||
| @ -4,6 +4,7 @@ | ||||
| #include <ATen/cuda/Exceptions.h> | ||||
|  | ||||
| #include <ATen/cudnn/cudnn-wrapper.h> | ||||
| #include <ATen/cudnn/Utils.h> | ||||
| #include <ATen/ATen.h> | ||||
| #include <ATen/TensorUtils.h> | ||||
| #include <ATen/cuda/ATenCUDAGeneral.h> | ||||
| @ -190,6 +191,7 @@ struct AT_CUDA_API DropoutDescriptor | ||||
|     AT_ASSERT(options.device().type() == kCUDA); | ||||
|     AT_ASSERT(options.dtype() == kByte); | ||||
|     state = at::empty({static_cast<int64_t>(state_size)}, options); | ||||
|     setCuDNNStreamToCurrent(); | ||||
|     AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed)); | ||||
|   } | ||||
|  | ||||
| @ -200,6 +202,7 @@ struct AT_CUDA_API DropoutDescriptor | ||||
|     void *state_ptr = state.data_ptr(); | ||||
|     size_t state_size = state.size(0); | ||||
|     // NB: The seed doesn't actually matter, so we give a dummy value | ||||
|     setCuDNNStreamToCurrent(); | ||||
|     AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */)); | ||||
|   } | ||||
|  | ||||
|  | ||||
| @ -31,7 +31,7 @@ Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const T | ||||
|   auto denom = (mag_square1 * mag_square2).sqrt_(); | ||||
|   auto cos = prod_sum / denom; | ||||
|  | ||||
|   auto zeros = at::zeros_like(target); | ||||
|   auto zeros = at::zeros_like(cos); | ||||
|   auto pos = 1 - cos; | ||||
|   auto neg = (cos - margin).clamp_min_(0); | ||||
|   auto output_pos = at::where(target == 1, pos, zeros); | ||||
| @ -67,8 +67,8 @@ Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, const Ten | ||||
| } | ||||
|  | ||||
| Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction) { | ||||
|   auto zeros = at::zeros_like(target); | ||||
|   auto output_pos = target * (at::log(target) - input); | ||||
|   auto zeros = at::zeros_like(output_pos); | ||||
|   auto output = at::where(target > 0, output_pos, zeros); | ||||
|   return apply_loss_reduction(output, reduction); | ||||
| } | ||||
|  | ||||
| @ -55,53 +55,67 @@ bool _nnpack_available() { | ||||
|  | ||||
| #include "nnpack.h" | ||||
|  | ||||
| #include <stdlib.h> | ||||
|  | ||||
| #include <ATen/Parallel.h> | ||||
| #include <thread> | ||||
| #include "caffe2/utils/threadpool/ThreadPoolMobile.h" | ||||
|  | ||||
| namespace at { | ||||
| namespace native { | ||||
|  | ||||
| // Stolen from Caffe2 | ||||
| static pthreadpool_t nnpack_threadpool_ = nullptr; | ||||
| static bool called_nnpack_threadpool_ = false; | ||||
| static bool init_nnpack() { | ||||
|   static std::once_flag once_; | ||||
|   static bool nnpack_successfully_initialized_ = false; | ||||
|  | ||||
|   std::call_once(once_, []() { | ||||
|     const nnp_status nnpack_status = nnp_initialize(); | ||||
|     nnpack_successfully_initialized_ = (nnp_status_success == nnpack_status); | ||||
|  | ||||
| pthreadpool_t nnpack_threadpool() { | ||||
|   if (! called_nnpack_threadpool_) { | ||||
|     called_nnpack_threadpool_ = true; | ||||
|     enum nnp_status nnpack_status = nnp_initialize(); | ||||
|     if (nnpack_status != nnp_status_success) { | ||||
|       if (nnpack_status == nnp_status_out_of_memory) { | ||||
|         throw std::runtime_error("could not initialize NNPack (out of memory)"); | ||||
|         LOG(WARNING) << "Could not initialize NNPACK! Reason: Out of memory."; | ||||
|       } else if (nnpack_status == nnp_status_unsupported_hardware) { | ||||
|         throw std::runtime_error("could not initialize NNPack (unsupported hardware)"); | ||||
|         LOG(WARNING) << "Could not initialize NNPACK! Reason: Unsupported hardware."; | ||||
|       } else { | ||||
|         throw std::runtime_error("could not initialize NNPack (unknown error)"); | ||||
|         LOG(WARNING) << "Could not initialize NNPACK! Reason: Unknown error!"; | ||||
|       } | ||||
|     } | ||||
|     unsigned int threads; | ||||
| #ifdef INTRA_OP_PARALLEL | ||||
|     threads = at::get_num_threads(); | ||||
|   }); | ||||
|  | ||||
|   return nnpack_successfully_initialized_; | ||||
| } | ||||
|  | ||||
| static pthreadpool_t nnpack_threadpool() { | ||||
|   // Try initializing a threadpool for NNPACK's use.  If we fail to | ||||
|   // successfully initialize an implementation, return nullptr which will | ||||
|   // instruct NNPACK to run single threaded. | ||||
|  | ||||
| #ifdef C10_MOBILE | ||||
|   // If building for mobile, use Caffe 2's mobile-friendly threadpool. | ||||
|   return caffe2::mobile_pthreadpool(); | ||||
| #else | ||||
|     threads = std::thread::hardware_concurrency(); | ||||
|   // Otherwise, try using pthreadpool if we manage to initialize it successfully. | ||||
|   static pthreadpool_t nnpack_threadpool_ = nullptr; | ||||
|   static bool called_nnpack_threadpool_ = false; | ||||
|  | ||||
|   if (!called_nnpack_threadpool_) { | ||||
|     called_nnpack_threadpool_ = true; | ||||
|  | ||||
| #ifdef INTRA_OP_PARALLEL | ||||
|     const uint32_t threads = at::get_num_threads(); | ||||
| #else | ||||
|     const uint32_t threads = std::thread::hardware_concurrency(); | ||||
| #endif | ||||
|  | ||||
|     nnpack_threadpool_ = pthreadpool_create(threads); | ||||
|     if (nnpack_threadpool_ == nullptr) { | ||||
|       throw std::runtime_error("could not initialize NNPack's pthreadpool"); | ||||
|     if ( !nnpack_threadpool_ ) { | ||||
|       LOG(WARNING) << "Failed to initialize pthreadpool! Running NNPACK in single-threaded mode."; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return nnpack_threadpool_; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| bool _nnpack_available() { | ||||
|   if (! called_nnpack_threadpool_) { | ||||
|     try { | ||||
|       return nnpack_threadpool() != nullptr; | ||||
|     } catch (std::runtime_error e) { | ||||
|     } | ||||
|   } | ||||
|   return nnpack_threadpool() != nullptr; | ||||
|   return init_nnpack(); | ||||
| } | ||||
|  | ||||
| // Make thread_local for safety in cases where we have multiple threads running | ||||
|  | ||||
| @ -152,7 +152,7 @@ Tensor align_to(const Tensor& tensor, DimnameList names) { | ||||
|     const auto& dim = tensor_names[idx]; | ||||
|     TORCH_CHECK(dim.isBasic(), | ||||
|         "align_to: All input dims must be named. Found unnamed dim at index ", | ||||
|         dim, " of Tensor", tensor_names); | ||||
|         idx, " of Tensor", tensor_names); | ||||
|     auto it = std::find(names.begin(), names.end(), dim); | ||||
|     TORCH_CHECK(it != names.end(), | ||||
|         "align_to: Cannot find dim ", dim, " from Tensor", names, | ||||
|  | ||||
| @ -122,6 +122,7 @@ std::vector<Tensor> where(const Tensor& condition) { | ||||
| } | ||||
|  | ||||
| Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) { | ||||
|   TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype()); | ||||
|   Tensor ret = at::empty(self.sizes(), self.options()); | ||||
|   AT_DISPATCH_ALL_TYPES(ret.scalar_type(), "where_cpu", [&] { | ||||
|     where_cpu<scalar_t>(ret, condition, self, other); | ||||
|  | ||||
| @ -102,11 +102,13 @@ std::tuple<Device, ScalarType> TensorIterator::compute_common_type() { | ||||
|   return compute_common_type_(operands_); | ||||
| } | ||||
|  | ||||
| static void validate_dtype(OperandInfo& op, ScalarType common_dtype, int ninputs) { | ||||
| static void validate_dtype(OperandInfo& op, ScalarType common_dtype, CommonDTypeStrategy strategy) { | ||||
|   bool check_output = strategy == CommonDTypeStrategy::CHECK || strategy == CommonDTypeStrategy::PROMOTE; | ||||
|   bool promoting = strategy == CommonDTypeStrategy::PROMOTE || (strategy == CommonDTypeStrategy::PROMOTE_INPUTS && !op.is_output); | ||||
|   if (op.tensor.defined()) { | ||||
|     // For binary_ops, we follow casting rules. For unary/nullary types | ||||
|     // we require the type to match. | ||||
|     if (op.is_output) { | ||||
|     if (op.is_output && check_output) { | ||||
|       if (!canCast(common_dtype, op.tensor.scalar_type())) | ||||
|       { | ||||
|         AT_ERROR("result type ", common_dtype, | ||||
| @ -114,7 +116,7 @@ static void validate_dtype(OperandInfo& op, ScalarType common_dtype, int ninputs | ||||
|           op.tensor.scalar_type()); | ||||
|       } | ||||
|     } | ||||
|     if (ninputs < 2 && op.dtype != op.tensor.scalar_type()) { | ||||
|     if (!promoting && op.dtype != op.tensor.scalar_type()) { | ||||
|       AT_ERROR("expected dtype ", op.dtype, " but got dtype ", op.tensor.scalar_type()); | ||||
|     } | ||||
|   } | ||||
| @ -131,14 +133,6 @@ static void maybe_promote_common_dtype(OperandInfo& op, ScalarType common_dtype) | ||||
|       op.tensor = | ||||
|           at::empty_like(op.tensor, op.tensor.options().dtype(common_dtype)); | ||||
|     } | ||||
|     auto original_element_size = op.original_tensor.element_size(); | ||||
|     auto new_element_size = op.tensor.element_size(); | ||||
|  | ||||
|     // stride size (in bytes) can change if we change the dtype. | ||||
|     for( size_t i=0; i < op.stride_bytes.size(); i++ ) { | ||||
|       auto stride = op.stride_bytes[i] / original_element_size; | ||||
|       op.stride_bytes[i] = stride * new_element_size; | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| @ -155,12 +149,12 @@ void TensorIterator::compute_types() { | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS) { | ||||
|   if (common_dtype_strategy_ == CommonDTypeStrategy::PROMOTE_INPUTS) { | ||||
|     TORCH_CHECK(!missing_output_dtypes, "unable to compute and promote common dtype based only on inputs if there are missing dtypes for outputs"); | ||||
|   } | ||||
|  | ||||
|   bool compute_common_dtype = (compute_common_dtype_strategy_ != CommonDTypeStrategy::COMPUTE_NONE); | ||||
|   bool compute_common_dtype_only_for_inputs = (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS); | ||||
|   bool compute_common_dtype = (common_dtype_strategy_ != CommonDTypeStrategy::NONE); | ||||
|   bool compute_common_dtype_only_for_inputs = (common_dtype_strategy_ == CommonDTypeStrategy::PROMOTE_INPUTS); | ||||
|  | ||||
|   if (missing_dtypes || compute_common_dtype) { | ||||
|     auto operands = compute_common_dtype_only_for_inputs ? at::ArrayRef<OperandInfo>(operands_).slice(noutputs()) : operands_; | ||||
| @ -198,24 +192,25 @@ void TensorIterator::compute_types() { | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|       if (!compute_common_dtype_only_for_inputs) { | ||||
|         validate_dtype(op, common_dtype, ninputs()); | ||||
|       } | ||||
|       if (!compute_common_dtype_only_for_inputs || !op.is_output) { | ||||
|         maybe_promote_common_dtype(op, common_dtype); | ||||
|       } | ||||
|  | ||||
|       if (op.tensor.defined() && op.device != op.tensor.device()) { | ||||
|         if (op.is_output) { | ||||
|           AT_ERROR("output with device ", op.tensor.device(), | ||||
|                    " doesn't match the desired device ", op.device); | ||||
|         } else if (op.tensor.dim() == 0) { | ||||
|           op.tensor = op.tensor.to(op.options()); | ||||
|         } else { | ||||
|           AT_ERROR("expected device ", op.device, | ||||
|                    " but got device ", op.tensor.device()); | ||||
|         } | ||||
|   for (auto &op : operands_) { | ||||
|     validate_dtype(op, common_dtype, common_dtype_strategy_); | ||||
|     if (compute_common_dtype && (!compute_common_dtype_only_for_inputs || !op.is_output)) { | ||||
|       maybe_promote_common_dtype(op, common_dtype); | ||||
|     } | ||||
|  | ||||
|     if (op.tensor.defined() && op.device != op.tensor.device()) { | ||||
|       if (op.is_output) { | ||||
|         AT_ERROR("output with device ", op.tensor.device(), | ||||
|                   " doesn't match the desired device ", op.device); | ||||
|       } else if (op.tensor.dim() == 0) { | ||||
|         op.tensor = op.tensor.to(op.options()); | ||||
|       } else { | ||||
|         AT_ERROR("expected device ", op.device, | ||||
|                   " but got device ", op.tensor.device()); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| @ -602,6 +597,7 @@ TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, | ||||
|   iter.add_input(a); | ||||
|   iter.add_input(b); | ||||
|   iter.allow_cpu_scalars_ = true; | ||||
|   iter.promote_common_dtype(); | ||||
|   iter.build(); | ||||
|   return iter; | ||||
| } | ||||
| @ -827,12 +823,12 @@ void TensorIterator::build() { | ||||
| #endif | ||||
|   // compute the broadcasted shape | ||||
|   compute_shape(); | ||||
|   // compute the result dtype and device | ||||
|   compute_types(); | ||||
|   // compute each tensor's stride after broadcasting | ||||
|   compute_strides(); | ||||
|   // re-order dimensions to improve coalescing | ||||
|   reorder_dimensions(); | ||||
|   // compute the result dtype and device | ||||
|   compute_types(); | ||||
|   // allocate the output tensor if it's not provided | ||||
|   allocate_outputs(); | ||||
| #ifdef BUILD_NAMEDTENSOR | ||||
|  | ||||
| @ -126,9 +126,10 @@ struct CAFFE2_API OperandInfo { | ||||
| struct SplitUntil32Bit; | ||||
|  | ||||
| enum class CommonDTypeStrategy : uint8_t { | ||||
|   COMPUTE_ALL = 0, // Compute common dtype based on inputs and outputs. Try to promote common dtype to inputs and outputs | ||||
|   COMPUTE_INPUTS = 1, // Compute common dtype based only on inputs. Try to promote common dtype only to inputs | ||||
|   COMPUTE_NONE = 2, // Do not compute and promote common dtype | ||||
|   NONE, // Do not compute a common dtype | ||||
|   CHECK, // Compute and validate a common dtype but don't promote. | ||||
|   PROMOTE_INPUTS, // Promote common dtype but only validate inputs (comparison ops have boolean output) | ||||
|   PROMOTE // Promote to common dtype. | ||||
| }; | ||||
|  | ||||
| struct CAFFE2_API TensorIterator { | ||||
| @ -204,7 +205,7 @@ struct CAFFE2_API TensorIterator { | ||||
|   } | ||||
|  | ||||
|   void cast_outputs() { | ||||
|     if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_ALL) { | ||||
|     if (common_dtype_strategy_ == CommonDTypeStrategy::PROMOTE) { | ||||
|       for(int i=0; i < noutputs(); i++) { | ||||
|         if (operands_[i].original_tensor.defined() && dtype(i) != operands_[i].original_tensor.scalar_type()) { | ||||
|           operands_[i].original_tensor.copy_(operands_[i].tensor); | ||||
| @ -306,12 +307,16 @@ struct CAFFE2_API TensorIterator { | ||||
|     operands_.emplace_back(input, device, dtype); | ||||
|   } | ||||
|  | ||||
|   void promote_common_dtype() { | ||||
|     common_dtype_strategy_ = CommonDTypeStrategy::PROMOTE; | ||||
|   } | ||||
|  | ||||
|   void dont_compute_common_dtype() { | ||||
|     compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_NONE; | ||||
|     common_dtype_strategy_ = CommonDTypeStrategy::NONE; | ||||
|   } | ||||
|  | ||||
|   void compute_common_dtype_only_for_inputs() { | ||||
|     compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_INPUTS; | ||||
|     common_dtype_strategy_ = CommonDTypeStrategy::PROMOTE_INPUTS; | ||||
|   } | ||||
|  | ||||
|   void dont_resize_outputs() { | ||||
| @ -344,7 +349,7 @@ protected: | ||||
| #endif | ||||
|   SmallVector<OperandInfo, 4> operands_; | ||||
|   int num_outputs_ = 0; | ||||
|   CommonDTypeStrategy compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_ALL; | ||||
|   CommonDTypeStrategy common_dtype_strategy_ = CommonDTypeStrategy::CHECK; | ||||
|   bool has_coalesced_dimensions_ = false; | ||||
|   bool accumulate_ = false; | ||||
|   bool resize_outputs_ = true; | ||||
|  | ||||
| @ -171,8 +171,9 @@ void avg_pool2d_out_cuda_template( | ||||
|  | ||||
|   output.resize_({nbatch, nInputPlane, outputHeight, outputWidth}); | ||||
|  | ||||
|   const int count = safe_downcast<int, int64_t>(output.numel()); | ||||
|   const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); | ||||
|   const int32_t count = safe_downcast<int32_t, int64_t>(output.numel()); | ||||
|   const uint32_t  num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); | ||||
|   const uint32_t num_blocks = cuda::ATenCeilDiv<uint32_t>(count, num_threads); | ||||
|  | ||||
|   if (divisor_override.has_value()) { | ||||
|     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), | ||||
| @ -184,7 +185,7 @@ void avg_pool2d_out_cuda_template( | ||||
|         scalar_t *input_data = input.data_ptr<scalar_t>(); | ||||
|  | ||||
|         avg_pool2d_out_cuda_frame<scalar_t, accscalar_t, false, true> | ||||
|             <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|             <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|             count, | ||||
|                 input_data, | ||||
|                 nbatch, | ||||
| @ -209,7 +210,7 @@ void avg_pool2d_out_cuda_template( | ||||
|           scalar_t *input_data = input.data_ptr<scalar_t>(); | ||||
|  | ||||
|           avg_pool2d_out_cuda_frame<scalar_t, accscalar_t, true, false> | ||||
|               <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|               <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|               count, | ||||
|                   input_data, | ||||
|                   nbatch, | ||||
| @ -233,7 +234,7 @@ void avg_pool2d_out_cuda_template( | ||||
|           scalar_t *input_data = input.data_ptr<scalar_t>(); | ||||
|  | ||||
|           avg_pool2d_out_cuda_frame<scalar_t, accscalar_t, false, false> | ||||
|               <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|               <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|               count, | ||||
|                   input_data, | ||||
|                   nbatch, | ||||
| @ -249,10 +250,8 @@ void avg_pool2d_out_cuda_template( | ||||
|     } | ||||
|   } | ||||
|  | ||||
|    | ||||
|   TORCH_CHECK(cudaGetLastError() == cudaSuccess, | ||||
|      "avg_pool2d_out_cuda_frame failed with error code ", | ||||
|      cudaGetLastError()); | ||||
|  | ||||
|   THCudaCheck(cudaGetLastError()); | ||||
|  | ||||
|   if (input.ndimension() == 3) { | ||||
|     output.resize_({nInputPlane, outputHeight, outputWidth}); | ||||
| @ -322,8 +321,9 @@ Tensor& avg_pool2d_backward_out_cuda_template( | ||||
|  | ||||
|   gradInput.resize_as_(input); | ||||
|  | ||||
|   const int count =  safe_downcast<int, int64_t>(input.numel()); | ||||
|   const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); | ||||
|   const int32_t count =  safe_downcast<int32_t, int64_t>(input.numel()); | ||||
|   const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); | ||||
|   const uint32_t num_blocks = cuda::ATenCeilDiv<uint32_t>(count, num_threads); | ||||
|  | ||||
|   if (divisor_override.has_value()) { | ||||
|     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), | ||||
| @ -335,7 +335,7 @@ Tensor& avg_pool2d_backward_out_cuda_template( | ||||
|         scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>(); | ||||
|  | ||||
|         avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t, false, true> | ||||
|             <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|             <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|             count, | ||||
|                 gradOutput_data, | ||||
|                 nbatch, | ||||
| @ -360,7 +360,7 @@ Tensor& avg_pool2d_backward_out_cuda_template( | ||||
|           scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>(); | ||||
|  | ||||
|           avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t, true, false> | ||||
|             <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|             <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|                count, | ||||
|                gradOutput_data, | ||||
|                nbatch, | ||||
| @ -384,7 +384,7 @@ Tensor& avg_pool2d_backward_out_cuda_template( | ||||
|           scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>(); | ||||
|  | ||||
|           avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t, false, false> | ||||
|             <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|             <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|                count, | ||||
|                gradOutput_data, | ||||
|                nbatch, | ||||
| @ -400,9 +400,7 @@ Tensor& avg_pool2d_backward_out_cuda_template( | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK(cudaGetLastError() == cudaSuccess, | ||||
|     "avg_pool2d_backward_out_cuda failed with error code ", | ||||
|     cudaGetLastError()); | ||||
|   THCudaCheck(cudaGetLastError()); | ||||
|  | ||||
|   return gradInput; | ||||
| } | ||||
|  | ||||
| @ -428,7 +428,7 @@ ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_da | ||||
|                                                      const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, | ||||
|                                               int64_t batch_size, int64_t num_labels, int64_t BLANK, bool zero_infinity) { | ||||
|   int64_t b = threadIdx.y + blockIdx.y * blockDim.y; | ||||
|   int64_t s = threadIdx.x + blockIdx.x * blockDim.y; // note, this directly indexes into targets, no targets prime! | ||||
|   int64_t s = threadIdx.x + blockIdx.x * blockDim.x; // note, this directly indexes into targets, not targets prime! | ||||
|  | ||||
|   if (b >= batch_size) | ||||
|     return; | ||||
|  | ||||
| @ -141,50 +141,80 @@ void or_kernel_cuda(TensorIterator& iter) { | ||||
|     }), false); | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| template <typename scalar_t, typename acc_t=scalar_t> | ||||
| void max_values_kernel_cuda_impl(TensorIterator& iter) { | ||||
|   gpu_reduce_kernel<scalar_t, scalar_t>( | ||||
|     iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { | ||||
|       return (THCNumerics<scalar_t>::isnan(a) || a > b) ? a : b; | ||||
|     }), at::numeric_limits<scalar_t>::lower_bound()); | ||||
|     iter, func_wrapper<acc_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { | ||||
|       return (THCNumerics<acc_t>::isnan(a) || a > b) ? a : b; | ||||
|     }), at::numeric_limits<acc_t>::lower_bound()); | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| template <typename scalar_t, typename acc_t=scalar_t> | ||||
| void min_values_kernel_cuda_impl(TensorIterator& iter) { | ||||
|   gpu_reduce_kernel<scalar_t, scalar_t>( | ||||
|     iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { | ||||
|       return (THCNumerics<scalar_t>::isnan(a) || a < b) ? a : b; | ||||
|     }), at::numeric_limits<scalar_t>::upper_bound()); | ||||
|     iter, func_wrapper<acc_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { | ||||
|       return (THCNumerics<acc_t>::isnan(a) || a < b) ? a : b; | ||||
|     }), at::numeric_limits<acc_t>::upper_bound()); | ||||
| } | ||||
|  | ||||
| void max_values_kernel_cuda(TensorIterator& iter) { | ||||
|   AT_DISPATCH_ALL_TYPES(iter.dtype(), "max_values_cuda", [&]() { | ||||
|     max_values_kernel_cuda_impl<scalar_t>(iter); | ||||
|   }); | ||||
|   if (iter.dtype(1) == kHalf) { | ||||
|     max_values_kernel_cuda_impl<at::Half, float>(iter); | ||||
|   } else { | ||||
|     AT_DISPATCH_ALL_TYPES(iter.dtype(), "max_values_cuda", [&]() { | ||||
|       max_values_kernel_cuda_impl<scalar_t>(iter); | ||||
|     }); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void min_values_kernel_cuda(TensorIterator& iter) { | ||||
|   AT_DISPATCH_ALL_TYPES(iter.dtype(), "min_values_cuda", [&]() { | ||||
|     min_values_kernel_cuda_impl<scalar_t>(iter); | ||||
|   }); | ||||
|   if (iter.dtype(1) == kHalf) { | ||||
|     min_values_kernel_cuda_impl<at::Half, float>(iter); | ||||
|   } else { | ||||
|     AT_DISPATCH_ALL_TYPES(iter.dtype(), "min_values_cuda", [&]() { | ||||
|       min_values_kernel_cuda_impl<scalar_t>(iter); | ||||
|     }); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t, typename acc_t=scalar_t> | ||||
| void argmax_kernel_cuda_impl(TensorIterator& iter) { | ||||
|   gpu_reduce_kernel<scalar_t, int64_t>( | ||||
|     iter, | ||||
|     ArgMaxOps<acc_t>{}, | ||||
|     thrust::pair<acc_t, int64_t>(at::numeric_limits<acc_t>::lower_bound(), 0)); | ||||
| }; | ||||
|  | ||||
| template <typename scalar_t, typename acc_t=scalar_t> | ||||
| void argmin_kernel_cuda_impl(TensorIterator& iter) { | ||||
|   gpu_reduce_kernel<scalar_t, int64_t>( | ||||
|     iter, | ||||
|     ArgMinOps<acc_t>{}, | ||||
|     thrust::pair<acc_t, int64_t>(at::numeric_limits<acc_t>::upper_bound(), 0)); | ||||
| }; | ||||
|  | ||||
| void argmax_kernel_cuda(TensorIterator& iter) { | ||||
|   AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_cuda", [&]() { | ||||
|     gpu_reduce_kernel<scalar_t, int64_t>( | ||||
|       iter, | ||||
|       ArgMaxOps<scalar_t>{}, | ||||
|       thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::lower_bound(), 0)); | ||||
|   }); | ||||
|   if (iter.dtype(1) == kHalf) { | ||||
|     // Instead of implementing is_nan and warp_shfl_down | ||||
|     // we can convert halves to float and do all the operations in float | ||||
|     argmax_kernel_cuda_impl<at::Half, float>(iter); | ||||
|   } else { | ||||
|     AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_cuda", [&]() { | ||||
|       argmax_kernel_cuda_impl<scalar_t>(iter); | ||||
|     }); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void argmin_kernel_cuda(TensorIterator& iter) { | ||||
|   AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() { | ||||
|     gpu_reduce_kernel<scalar_t, int64_t>( | ||||
|       iter, | ||||
|       ArgMinOps<scalar_t>{}, | ||||
|       thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::upper_bound(), 0)); | ||||
|   }); | ||||
|   if (iter.dtype(1) == kHalf) { | ||||
|     // Instead of implementing is_nan and warp_shfl_down | ||||
|     // we can convert halves to float and do all the operations in float | ||||
|     argmin_kernel_cuda_impl<at::Half, float>(iter); | ||||
|   } else { | ||||
|     AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() { | ||||
|       argmin_kernel_cuda_impl<scalar_t>(iter); | ||||
|     }); | ||||
|   } | ||||
| } | ||||
|  | ||||
| REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda); | ||||
|  | ||||
| @ -778,6 +778,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn( | ||||
|           &reserve_size | ||||
|           )); | ||||
|     reserve = at::empty(reserve_size, input.options().dtype(kByte)); | ||||
|     setCuDNNStreamToCurrent(); | ||||
|     AT_CUDNN_CHECK(cudnnRNNForwardTraining( | ||||
|           handle, | ||||
|           descs.rnn_desc.desc(), | ||||
| @ -794,6 +795,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn( | ||||
|           )); | ||||
|   } else { // inference | ||||
|     reserve = at::empty({0}, input.options().dtype(kByte)); | ||||
|     setCuDNNStreamToCurrent(); | ||||
|     AT_CUDNN_CHECK(cudnnRNNForwardInference( | ||||
|           handle, | ||||
|           descs.rnn_desc.desc(), | ||||
| @ -912,7 +914,7 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input( | ||||
|         )); | ||||
|   // TODO: put this in the correct device??? | ||||
|   Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte)); | ||||
|  | ||||
|   setCuDNNStreamToCurrent(); | ||||
|   AT_CUDNN_CHECK(cudnnRNNBackwardData( | ||||
|         handle, | ||||
|         descs.rnn_desc.desc(), | ||||
| @ -1016,7 +1018,7 @@ std::vector<Tensor> _cudnn_rnn_backward_weight( | ||||
|         &workspace_size | ||||
|         )); | ||||
|   Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte)); | ||||
|  | ||||
|   setCuDNNStreamToCurrent(); | ||||
|   AT_CUDNN_CHECK(cudnnRNNBackwardWeights( | ||||
|         handle, | ||||
|         descs.rnn_desc.desc(), | ||||
|  | ||||
| @ -66,7 +66,7 @@ | ||||
|   supports_named_tensor: True | ||||
|  | ||||
| - func: align_to(Tensor(a) self, DimnameList names) -> Tensor(a) | ||||
|   variants: function, method | ||||
|   variants: method | ||||
|   supports_named_tensor: True | ||||
|  | ||||
| - func: align_as(Tensor self, Tensor other) -> Tensor | ||||
| @ -1397,7 +1397,8 @@ | ||||
|   use_c10_dispatcher: full | ||||
|   variants: function | ||||
|   device_guard: False | ||||
|  | ||||
|   supports_named_tensor: True | ||||
|    | ||||
| - func: is_distributed(Tensor self) -> bool | ||||
|   use_c10_dispatcher: full | ||||
|   variants: function, method | ||||
| @ -3681,6 +3682,17 @@ | ||||
|     CPU: fake_quantize_per_tensor_affine_backward_cpu | ||||
|     CUDA: fake_quantize_per_tensor_affine_backward_cuda | ||||
|  | ||||
| - func: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor | ||||
|   variants: function | ||||
|   dispatch: | ||||
|     CPU: fake_quantize_per_channel_affine_cpu | ||||
|     CUDA: fake_quantize_per_channel_affine_cuda | ||||
|  | ||||
| - func: fake_quantize_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor | ||||
|   variants: function | ||||
|   dispatch: | ||||
|     CPU: fake_quantize_per_channel_affine_backward_cpu | ||||
|     CUDA: fake_quantize_per_channel_affine_backward_cuda | ||||
| # to(Device) must not exist because all constructors of Device also works for | ||||
| # TensorOptions. Otherwise, an ambiguity error is thrown. | ||||
| # See NOTE [ TensorOptions Constructors ]. | ||||
| @ -4473,12 +4485,14 @@ | ||||
|     CUDA: legacy::cuda::_th_trace | ||||
|  | ||||
| - func: ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: ne_out | ||||
|     CUDA: ne_out | ||||
|     QuantizedCPU: ne_out_quantized_cpu | ||||
|  | ||||
| - func: ne.Scalar(Tensor self, Scalar other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
| @ -4487,12 +4501,14 @@ | ||||
|     QuantizedCPU: ne_quantized_cpu | ||||
|  | ||||
| - func: ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: ne_out | ||||
|     CUDA: ne_out | ||||
|     QuantizedCPU: ne_out_quantized_cpu | ||||
|  | ||||
| - func: ne.Tensor(Tensor self, Tensor other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
| @ -4501,12 +4517,14 @@ | ||||
|     QuantizedCPU: ne_quantized_cpu | ||||
|  | ||||
| - func: eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: eq_out | ||||
|     CUDA: eq_out | ||||
|     QuantizedCPU: eq_out_quantized_cpu | ||||
|  | ||||
| - func: eq.Scalar(Tensor self, Scalar other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
| @ -4515,12 +4533,14 @@ | ||||
|     QuantizedCPU: eq_quantized_cpu | ||||
|  | ||||
| - func: eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: eq_out | ||||
|     CUDA: eq_out | ||||
|     QuantizedCPU: eq_out_quantized_cpu | ||||
|  | ||||
| - func: eq.Tensor(Tensor self, Tensor other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
| @ -4529,12 +4549,14 @@ | ||||
|     QuantizedCPU: eq_quantized_cpu | ||||
|  | ||||
| - func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: ge_out | ||||
|     CUDA: ge_out | ||||
|     QuantizedCPU: ge_out_quantized_cpu | ||||
|  | ||||
| - func: ge.Scalar(Tensor self, Scalar other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
| @ -4543,12 +4565,14 @@ | ||||
|     QuantizedCPU: ge_quantized_cpu | ||||
|  | ||||
| - func: ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: ge_out | ||||
|     CUDA: ge_out | ||||
|     QuantizedCPU: ge_out_quantized_cpu | ||||
|  | ||||
| - func: ge.Tensor(Tensor self, Tensor other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
| @ -4557,12 +4581,14 @@ | ||||
|     QuantizedCPU: ge_quantized_cpu | ||||
|  | ||||
| - func: le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: le_out | ||||
|     CUDA: le_out | ||||
|     QuantizedCPU: le_out_quantized_cpu | ||||
|  | ||||
| - func: le.Scalar(Tensor self, Scalar other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
| @ -4571,12 +4597,14 @@ | ||||
|     QuantizedCPU: le_quantized_cpu | ||||
|  | ||||
| - func: le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: le_out | ||||
|     CUDA: le_out | ||||
|     QuantizedCPU: le_out_quantized_cpu | ||||
|  | ||||
| - func: le.Tensor(Tensor self, Tensor other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
| @ -4585,12 +4613,14 @@ | ||||
|     QuantizedCPU: le_quantized_cpu | ||||
|  | ||||
| - func: gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: gt_out | ||||
|     CUDA: gt_out | ||||
|     QuantizedCPU: gt_out_quantized_cpu | ||||
|  | ||||
| - func: gt.Scalar(Tensor self, Scalar other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
| @ -4599,12 +4629,14 @@ | ||||
|     QuantizedCPU: gt_quantized_cpu | ||||
|  | ||||
| - func: gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: gt_out | ||||
|     CUDA: gt_out | ||||
|     QuantizedCPU: gt_out_quantized_cpu | ||||
|  | ||||
| - func: gt.Tensor(Tensor self, Tensor other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
| @ -4613,12 +4645,14 @@ | ||||
|     QuantizedCPU: gt_quantized_cpu | ||||
|  | ||||
| - func: lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: lt_out | ||||
|     CUDA: lt_out | ||||
|     QuantizedCPU: lt_out_quantized_cpu | ||||
|  | ||||
| - func: lt.Scalar(Tensor self, Scalar other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
| @ -4627,12 +4661,14 @@ | ||||
|     QuantizedCPU: lt_quantized_cpu | ||||
|  | ||||
| - func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   supports_named_tensor: True | ||||
|   dispatch: | ||||
|     CPU: lt_out | ||||
|     CUDA: lt_out | ||||
|     QuantizedCPU: lt_out_quantized_cpu | ||||
|  | ||||
| - func: lt.Tensor(Tensor self, Tensor other) -> Tensor | ||||
|   supports_named_tensor: True | ||||
|   use_c10_dispatcher: full | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
|  | ||||
							
								
								
									
										60
									
								
								aten/src/ATen/native/quantized/cpu/fake_quantize_core.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								aten/src/ATen/native/quantized/cpu/fake_quantize_core.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,60 @@ | ||||
| #include <ATen/ATen.h> | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #include <ATen/native/TensorIterator.h> | ||||
| #include <ATen/native/cpu/Loops.h> | ||||
|  | ||||
| /* Core operations for fake-quantization shared between per tensor | ||||
|  and per-channel fake quant */ | ||||
| namespace at { | ||||
| namespace native { | ||||
|  | ||||
| /* Fake quantize a tensor, common block for per-channel & per-tensor fake quant | ||||
| Args: | ||||
|   output: output tensor. | ||||
|   input : input tensor. | ||||
|   sc:  scale to quantize the input tensor to | ||||
|   zero_point: zero_point | ||||
|   quant_min: minimum quantized value | ||||
|   quant_max: maximum quantized value | ||||
| Returns: | ||||
|   Fake quantized tensor (double dtype). | ||||
| */ | ||||
| void fake_quantize_slice( | ||||
|     Tensor& output, | ||||
|     const Tensor& input, | ||||
|     float sc, | ||||
|     int64_t z_point, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max) { | ||||
|   float inv_scale = 1.0f / sc; | ||||
|   auto iter = TensorIterator::unary_op(output, input); | ||||
|   cpu_kernel(iter, [&](float self) -> float { | ||||
|     return (std::fmin( | ||||
|                 std::fmax( | ||||
|                     static_cast<int64_t>( | ||||
|                         std::nearbyint(self * inv_scale + z_point)), | ||||
|                     quant_min), | ||||
|                 quant_max) - | ||||
|             z_point) * | ||||
|         sc; | ||||
|   }); | ||||
| } | ||||
|  | ||||
| void fake_quantize_grad_slice( | ||||
|     Tensor& input_grad, | ||||
|     const Tensor& input, | ||||
|     const Tensor& output_grad, | ||||
|     float sc, | ||||
|     int64_t z_point, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max) { | ||||
|   float inv_scale = 1.0f / sc; | ||||
|   auto iter = TensorIterator::binary_op(input_grad, input, output_grad); | ||||
|   cpu_kernel(iter, [&](float x, float dy) -> float { | ||||
|     int64_t xq = static_cast<int64_t>(std::nearbyint(x * inv_scale + z_point)); | ||||
|     return dy * (xq >= quant_min && xq <= quant_max); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| } // namespace native | ||||
| } // namespace at | ||||
							
								
								
									
										27
									
								
								aten/src/ATen/native/quantized/cpu/fake_quantize_core.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								aten/src/ATen/native/quantized/cpu/fake_quantize_core.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,27 @@ | ||||
| #include <ATen/ATen.h> | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #include <ATen/native/TensorIterator.h> | ||||
| #include <ATen/native/cpu/Loops.h> | ||||
|  | ||||
| /* FakeQuantize Op for PerChannelAffine quantization scheme */ | ||||
| namespace at { | ||||
| namespace native { | ||||
| void fake_quantize_slice( | ||||
|     Tensor& output, | ||||
|     const Tensor& input, | ||||
|     float sc, | ||||
|     int64_t z_point, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max); | ||||
|  | ||||
| void fake_quantize_grad_slice( | ||||
|     Tensor& input_grad, | ||||
|     const Tensor& output_grad, | ||||
|     const Tensor& input, | ||||
|     float sc, | ||||
|     int64_t z_point, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max); | ||||
|  | ||||
| } // namespace native | ||||
| } // namespace at | ||||
| @ -0,0 +1,147 @@ | ||||
| #include <ATen/ATen.h> | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #include <ATen/native/TensorIterator.h> | ||||
| #include <ATen/native/cpu/Loops.h> | ||||
| #include <ATen/native/quantized/cpu/fake_quantize_core.h> | ||||
|  | ||||
| /* FakeQuantize Op for PerChannelAffine quantization scheme */ | ||||
| namespace at { | ||||
| namespace native { | ||||
|  | ||||
| /* Per channel fake-quantizes the 'inputs' tensor. | ||||
| Args: | ||||
|   X: Forward input tensor. | ||||
|   dY: Backward input tensor (_backward op only). | ||||
|   scale: scale of per channel affine quantization | ||||
|   zero_point: zero_point of per channel affine quantization | ||||
|   axis: int specifying the axis to be quantized | ||||
|   quant_min: minimum quantized value | ||||
|   quant_max: maximum quantized value | ||||
| Returns: | ||||
|   Fake quantized tensor (double dtype). | ||||
|  | ||||
| */ | ||||
|  | ||||
| Tensor fake_quantize_per_channel_affine_cpu( | ||||
|     const Tensor& self, | ||||
|     const Tensor& scale, | ||||
|     const Tensor& zero_point, | ||||
|     int64_t axis, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max) { | ||||
|   // TODO: Use REGISTER_DISPATCH | ||||
|   TORCH_CHECK(self.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK( | ||||
|       scale.dim() == 1, "scale should be a 1-D tensor"); | ||||
|   TORCH_CHECK( | ||||
|       zero_point.dim() == 1, "zero point should be a 1-D tensor"); | ||||
|   TORCH_CHECK( | ||||
|       scale.numel() == zero_point.numel(), | ||||
|       "scale and zero-point need to have the same dimensions"); | ||||
|   TORCH_CHECK( | ||||
|       scale.numel() == self.size(axis), | ||||
|       "dimensions of scale and zero-point are not consistent with input tensor") | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       quant_min <= quant_max, | ||||
|       "`quant_min` should be less than or \ | ||||
|         equal to `quant_max`."); | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       at::min(zero_point).item().toLong() >= quant_min && | ||||
|           at::max(zero_point).item().toLong() <= quant_max, | ||||
|       "`zero_point` must be between `quant_min` and `quant_max`."); | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       axis >= 0 && axis <= self.dim(), | ||||
|       "`axis` must be between 0 and number of dimensions of input"); | ||||
|  | ||||
|   auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve); | ||||
|   for (int i = 0; i < self.size(axis); i++) { | ||||
|     auto input_slice = self.slice(axis, i, i + 1); | ||||
|     auto output_slice = Y.slice(axis, i, i + 1); | ||||
|     float sc = scale[i].item().toFloat(); | ||||
|     int64_t z_point = zero_point[i].item().toLong(); | ||||
|     fake_quantize_slice( | ||||
|         output_slice, input_slice, sc, z_point, quant_min, quant_max); | ||||
|   } | ||||
|  | ||||
|   return Y; | ||||
| } | ||||
|  | ||||
| /* Backward path for per-channel fake-quantization of the 'inputs' tensor. | ||||
|  | ||||
| Args: | ||||
|   X: Forward input tensor. | ||||
|   dY: Backward input tensor. | ||||
|   scale: scale of per tensor affine quantization | ||||
|   zero_point: zero_point of per tensor affine quantization | ||||
|   axis: int ,the axis over which quantization parameters vary | ||||
|   quant_min: int, minimum quantized value | ||||
|   quant_max: int, maximum quantized value | ||||
|  | ||||
| Returns: | ||||
|   Gradient for per channel fake quant (double dtype). | ||||
|  | ||||
| */ | ||||
| Tensor fake_quantize_per_channel_affine_backward_cpu( | ||||
|     const Tensor& dY, | ||||
|     const Tensor& X, | ||||
|     const Tensor& scale, | ||||
|     const Tensor& zero_point, | ||||
|     int64_t axis, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max) { | ||||
|   TORCH_CHECK(dY.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK(X.scalar_type() == ScalarType::Float); | ||||
|  | ||||
|   TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size"); | ||||
|   TORCH_CHECK( | ||||
|       quant_min <= quant_max, | ||||
|       "`quant_min` should be less than or \ | ||||
|         equal to `quant_max`."); | ||||
|   TORCH_CHECK( | ||||
|       scale.dim() == 1, "scale should be a 1-D tensor"); | ||||
|   TORCH_CHECK( | ||||
|       zero_point.dim() == 1, "zero point should be a 1-D tensor"); | ||||
|   TORCH_CHECK( | ||||
|       scale.numel() == zero_point.numel(), | ||||
|       "scale and zero-point need to have the same dimensions"); | ||||
|   TORCH_CHECK( | ||||
|       scale.numel() == X.size(axis), | ||||
|       "dimensions of scale and zero-point are not consistent with input tensor") | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       quant_min <= quant_max, | ||||
|       "`quant_min` should be less than or \ | ||||
|         equal to `quant_max`."); | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       at::min(zero_point).item().toLong() >= quant_min && | ||||
|           at::max(zero_point).item().toLong() <= quant_max, | ||||
|       "`zero_point` must be between `quant_min` and `quant_max`."); | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       axis >= 0 && axis <= X.dim(), | ||||
|       "`axis` must be between 0 and number of dimensions of input"); | ||||
|  | ||||
|   if (X.numel() <= 0) { | ||||
|     return X; | ||||
|   } | ||||
|  | ||||
|   auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve); | ||||
|  | ||||
|   for (int i = 0; i < X.size(axis); i++) { | ||||
|     auto X_slice = X.slice(axis, i, i + 1); | ||||
|     auto dY_slice = dY.slice(axis, i, i + 1); | ||||
|     auto dX_slice = dX.slice(axis, i, i + 1); | ||||
|     float sc = scale[i].item().toFloat(); | ||||
|     int64_t z_point = zero_point[i].item().toLong(); | ||||
|     fake_quantize_grad_slice(dX_slice, X_slice, dY_slice, | ||||
|                              sc, z_point, quant_min, quant_max); | ||||
|   } | ||||
|  | ||||
|   return dX; | ||||
| } | ||||
| } // namespace native | ||||
| } // namespace at | ||||
| @ -2,6 +2,7 @@ | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #include <ATen/native/TensorIterator.h> | ||||
| #include <ATen/native/cpu/Loops.h> | ||||
| #include <ATen/native/quantized/cpu/fake_quantize_core.h> | ||||
|  | ||||
| /* FakeQuantize Op for PerTensorAffine quantization scheme */ | ||||
| namespace at { | ||||
| @ -15,16 +16,9 @@ Args: | ||||
|   zero_point: zero_point of per tensor affine quantization | ||||
|   quant_min: minimum quantized value | ||||
|   quant_max: maximum quantized value | ||||
|   quant_delay: Count of global steps for which to delay the quantization. | ||||
|                See note below. | ||||
|   iter: The current quantization iteration used for `quant_delay`. | ||||
| Returns: | ||||
|   Quantized tensor (double dtype). | ||||
|  | ||||
| Notes: | ||||
|   - quant_delay might be set to non-zero to help weights stabilize in the | ||||
|     beginning of the training. | ||||
|   - quantization range [quant_min, quant_max] | ||||
| */ | ||||
| Tensor fake_quantize_per_tensor_affine_cpu( | ||||
|     const Tensor& self, | ||||
| @ -41,20 +35,8 @@ Tensor fake_quantize_per_tensor_affine_cpu( | ||||
|       zero_point >= quant_min && zero_point <= quant_max, | ||||
|       "`zero_point` must be between `quant_min` and `quant_max`."); | ||||
|  | ||||
|   auto Y = at::empty_like(self); | ||||
|   float inv_scale = 1.0f / scale; | ||||
|   auto iter = TensorIterator::unary_op(Y, self); | ||||
|   cpu_kernel(iter, [&](float self) -> float { | ||||
|     return (std::fmin( | ||||
|                 std::fmax( | ||||
|                     static_cast<int64_t>( | ||||
|                         std::nearbyint(self * inv_scale + zero_point)), | ||||
|                     quant_min), | ||||
|                 quant_max) - | ||||
|             zero_point) * | ||||
|         scale; | ||||
|   }); | ||||
|  | ||||
|   auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve); | ||||
|   fake_quantize_slice(Y, self, scale, zero_point, quant_min, quant_max); | ||||
|   return Y; | ||||
| } | ||||
|  | ||||
| @ -99,14 +81,8 @@ Tensor fake_quantize_per_tensor_affine_backward_cpu( | ||||
|     return X; | ||||
|   } | ||||
|  | ||||
|   Tensor dX = at::zeros_like(X); | ||||
|   auto iter = TensorIterator::binary_op(dX, X, dY); | ||||
|   float inv_scale = 1.0f / scale; | ||||
|   cpu_kernel(iter, [&](float x, float dy) -> float { | ||||
|     int64_t xq = | ||||
|         static_cast<int64_t>(std::nearbyint(x * inv_scale + zero_point)); | ||||
|     return dy * (xq >= quant_min && xq <= quant_max); | ||||
|   }); | ||||
|   auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve); | ||||
|   fake_quantize_grad_slice(dX, X, dY, scale, zero_point, quant_min, quant_max); | ||||
|   return dX; | ||||
| } | ||||
| } // namespace native | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| #include <ATen/ATen.h> | ||||
| #include <ATen/SmallVector.h> | ||||
| #include <ATen/Parallel.h> | ||||
| #include <ATen/core/op_registration/op_registration.h> | ||||
| #include <ATen/cpp_custom_type_hack.h> | ||||
| #include <ATen/native/quantized/cpu/fbgemm_utils.h> | ||||
| @ -153,8 +154,6 @@ class QConv2dInt8 final : public c10::OperatorKernel { | ||||
|         {pad_l, pad_t, pad_l, pad_t}, | ||||
|         {static_cast<int>(dilation[0]), static_cast<int>(dilation[1])}); | ||||
|  | ||||
|     fbgemm::DoNothing<> NoOpObj{}; | ||||
|  | ||||
|     float act_scale = act.q_scale(); | ||||
|     int32_t act_zero_point = act.q_zero_point(); | ||||
|  | ||||
| @ -208,63 +207,69 @@ class QConv2dInt8 final : public c10::OperatorKernel { | ||||
|     // TODO: add MemoryFormat::ChannelsLast here once perf is fixed | ||||
|     Tensor output = _empty_affine_quantized( | ||||
|         outShape, device(kCPU).dtype(kQUInt8), output_scale, output_zero_point); | ||||
|     auto buffer = at::zeros_like(output, output.options().dtype(at::kInt)); | ||||
|     auto buffer = at::empty(output.sizes(), output.options().dtype(at::kInt)); | ||||
|  | ||||
|     if (pack_ptr.q_scheme == kPerTensorAffine) { | ||||
|       fbgemm::ReQuantizeOutput< | ||||
|           ReluFused, | ||||
|           fbgemm::QuantizationGranularity::TENSOR, | ||||
|           float> | ||||
|           outputProcObj( | ||||
|               NoOpObj, | ||||
|               output_multiplier_float.data(), | ||||
|               output_zero_point, | ||||
|               act_zero_point, | ||||
|               pack_ptr.w_zp.data(), | ||||
|               nullptr, /* row offset buffer */ | ||||
|               col_offsets.data(), | ||||
|               bias_ptr, | ||||
|               K, | ||||
|               groups, | ||||
|               act_times_w_scale.data()); | ||||
|       fbgemm::fbgemmConv( | ||||
|           conv_p, | ||||
|           act_ptr, | ||||
|           *packB, | ||||
|           reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()), | ||||
|           buffer.data_ptr<int32_t>(), | ||||
|           outputProcObj, | ||||
|           0 /* thread_id*/, | ||||
|           1 /* num_threads */); | ||||
|     int num_tasks = at::get_num_threads(); | ||||
|     at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) { | ||||
|       fbgemm::DoNothing<> NoOpObj{}; | ||||
|       for (int task_id = begin; task_id < end; ++task_id) { | ||||
|         if (pack_ptr.q_scheme == kPerTensorAffine) { | ||||
|           fbgemm::ReQuantizeOutput< | ||||
|               ReluFused, | ||||
|               fbgemm::QuantizationGranularity::TENSOR, | ||||
|               float> | ||||
|               outputProcObj( | ||||
|                   NoOpObj, | ||||
|                   output_multiplier_float.data(), | ||||
|                   output_zero_point, | ||||
|                   act_zero_point, | ||||
|                   pack_ptr.w_zp.data(), | ||||
|                   nullptr, /* row offset buffer */ | ||||
|                   col_offsets.data(), | ||||
|                   bias_ptr, | ||||
|                   K, | ||||
|                   groups, | ||||
|                   act_times_w_scale.data()); | ||||
|           fbgemm::fbgemmConv( | ||||
|               conv_p, | ||||
|               act_ptr, | ||||
|               *packB, | ||||
|               reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()), | ||||
|               buffer.data_ptr<int32_t>(), | ||||
|               outputProcObj, | ||||
|               task_id /* thread_id*/, | ||||
|               num_tasks /* num_threads */); | ||||
|  | ||||
|     } else if (pack_ptr.q_scheme == kPerChannelAffine) { | ||||
|       fbgemm::ReQuantizeOutput< | ||||
|           ReluFused, | ||||
|           fbgemm::QuantizationGranularity::OUT_CHANNEL, | ||||
|           float> | ||||
|           outputProcObj( | ||||
|               NoOpObj, | ||||
|               output_multiplier_float.data(), | ||||
|               output_zero_point, | ||||
|               act_zero_point, | ||||
|               pack_ptr.w_zp.data(), | ||||
|               nullptr, /* row offset buffer */ | ||||
|               col_offsets.data(), | ||||
|               bias_ptr, | ||||
|               K, | ||||
|               groups, | ||||
|               act_times_w_scale.data()); | ||||
|         } else if (pack_ptr.q_scheme == kPerChannelAffine) { | ||||
|           fbgemm::ReQuantizeOutput< | ||||
|               ReluFused, | ||||
|               fbgemm::QuantizationGranularity::OUT_CHANNEL, | ||||
|               float> | ||||
|               outputProcObj( | ||||
|                   NoOpObj, | ||||
|                   output_multiplier_float.data(), | ||||
|                   output_zero_point, | ||||
|                   act_zero_point, | ||||
|                   pack_ptr.w_zp.data(), | ||||
|                   nullptr, /* row offset buffer */ | ||||
|                   col_offsets.data(), | ||||
|                   bias_ptr, | ||||
|                   K, | ||||
|                   groups, | ||||
|                   act_times_w_scale.data()); | ||||
|  | ||||
|       fbgemm::fbgemmConv( | ||||
|           conv_p, | ||||
|           act_ptr, | ||||
|           *packB, | ||||
|           reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()), | ||||
|           buffer.data_ptr<int32_t>(), | ||||
|           outputProcObj, | ||||
|           0 /* thread_id*/, | ||||
|           1 /* num_threads */); | ||||
|     } | ||||
|           fbgemm::fbgemmConv( | ||||
|               conv_p, | ||||
|               act_ptr, | ||||
|               *packB, | ||||
|               reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()), | ||||
|               buffer.data_ptr<int32_t>(), | ||||
|               outputProcObj, | ||||
|               task_id /* thread_id*/, | ||||
|               num_tasks /* num_threads */); | ||||
|         } | ||||
|       } | ||||
|     }); | ||||
|  | ||||
|     // TODO: remove permute once MemoryLayout is added above | ||||
|     return output.permute({0, 3, 1, 2}); | ||||
| @ -355,7 +360,8 @@ class QConv2dInt8 final : public c10::OperatorKernel { | ||||
|           kernel_zp, | ||||
|           MemoryFormat::ChannelsLast); | ||||
|       auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>(); | ||||
|       for (int i = 0; i < weight_contig.numel(); ++i) { | ||||
|       auto wt_numel = weight_contig.numel(); | ||||
|       for (int i = 0; i < wt_numel; ++i) { | ||||
|         qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128); | ||||
|       } | ||||
|       // Original bias was float, so we requantize it here. | ||||
|  | ||||
| @ -218,7 +218,8 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel { | ||||
|         weight.q_scale(), | ||||
|         weight_zp); | ||||
|     auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>(); | ||||
|     for (int i = 0; i < weight_contig.numel(); ++i) { | ||||
|     auto wt_numel = weight_contig.numel(); | ||||
|     for (int i = 0; i < wt_numel; ++i) { | ||||
|       qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128); | ||||
|     } | ||||
|     // We set the pre-packed conv weights to nullptr below as we call pre-pack | ||||
|  | ||||
| @ -1,4 +1,5 @@ | ||||
| #include <ATen/ATen.h> | ||||
| #include <ATen/Parallel.h> | ||||
| #include <ATen/core/op_registration/op_registration.h> | ||||
| #include <ATen/cpp_custom_type_hack.h> | ||||
| #include <ATen/native/quantized/cpu/fbgemm_utils.h> | ||||
| @ -80,35 +81,6 @@ class QLinearInt8 final : public torch::OperatorKernel { | ||||
|     } | ||||
|     int32_t output_zero_point_int32 = static_cast<int32_t>(output_zero_point); | ||||
|  | ||||
|     // This operation does the following: | ||||
|     // 1) Creates a "row buffer" vector with offset values that must be added | ||||
|     //    to the integer matrix multiplication operation to ensure correctness. | ||||
|     //    This "row buffer" is also called the row offset, and it is needed when | ||||
|     //    we use affine quantization for weights. | ||||
|     // 2) Packs the resulting quantized matrix into vector-register and cache | ||||
|     //    friendly tiles. | ||||
|     // | ||||
|     //  Note this is not executed eagerly, but rather within the fbgemmPacked | ||||
|     //  call below. | ||||
|     fbgemm::PackAWithRowOffset<uint8_t> packA( | ||||
|         /*trans=*/fbgemm::matrix_op_t::NoTranspose, | ||||
|         /*nRow=*/M, | ||||
|         /*nCol=*/K, | ||||
|         /*smat=*/input_ptr, | ||||
|         /*ld=*/K, | ||||
|         /*pmat=*/nullptr); // Currently, packA manages ownership of `pmat`. | ||||
|                            // TODO: Consider a way to pre-allocate and reuse | ||||
|                            // pmat buffer. | ||||
|  | ||||
|     // ReQuantizeOutput requires pointers to the zero point values, | ||||
|     // since in the case of rowwise quantization these will be arrays rather | ||||
|     // than scalars. But in this case, we're doing whole-tensor quantization so | ||||
|     // we just pass a pointer to the scale values (and internally | ||||
|     // ReQuantizeOutput won't index past 0. | ||||
|  | ||||
|     // This is the end of the pipeline, pass the resulting matrix through. | ||||
|     fbgemm::DoNothing<> doNothingObj{}; | ||||
|  | ||||
|     const float* bias_ptr = nullptr; | ||||
|     at::Tensor bias; | ||||
|     if (pack_ptr.bias.has_value()) { | ||||
| @ -134,79 +106,114 @@ class QLinearInt8 final : public torch::OperatorKernel { | ||||
|         output_scale, | ||||
|         output_zero_point); | ||||
|  | ||||
|     auto buffer = at::zeros_like(output, output.options().dtype(at::kInt)); | ||||
|     auto buffer = at::empty(out_sizes, output.options().dtype(at::kInt)); | ||||
|  | ||||
|     if (pack_ptr.q_scheme == kPerTensorAffine) { | ||||
|       // Process the per tensor quantization. | ||||
|       // | ||||
|       // After the uint8 * int8 matrix multiplication is performed, this | ||||
|       // operation does: | ||||
|       //  1) Add in row and column offsets to the rows and columns, | ||||
|       //  respectively. | ||||
|       //  2) Add in the bias term. | ||||
|       fbgemm::ReQuantizeOutput< | ||||
|           ReluFused, | ||||
|           fbgemm::QuantizationGranularity::TENSOR, | ||||
|           float> | ||||
|           outputProcObj( | ||||
|               doNothingObj, | ||||
|               output_multiplier_float.data(), | ||||
|               output_zero_point_int32, | ||||
|               input_zero_point_int32, | ||||
|               pack_ptr.w_zp.data(), | ||||
|               packA.getRowOffsetBuffer(), | ||||
|               col_offsets.data(), | ||||
|               bias_ptr, | ||||
|               N, /* nCol */ | ||||
|               1 /* groups */, | ||||
|               act_times_w_scale.data()); | ||||
|     int num_tasks = at::get_num_threads(); | ||||
|     at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) { | ||||
|       for (int task_id = begin; task_id < end; ++task_id) { | ||||
|         // This operation does the following: | ||||
|         // 1) Creates a "row buffer" vector with offset values that must be | ||||
|         //    added to the integer matrix multiplication operation to ensure | ||||
|         //    correctness. This "row buffer" is also called the row offset, and | ||||
|         //    it is needed when we use affine quantization for weights. | ||||
|         // 2) Packs the resulting quantized matrix into vector-register and | ||||
|         //    cache friendly tiles. | ||||
|         // | ||||
|         //  Note this is not executed eagerly, but rather within the | ||||
|         //  fbgemmPacked call below. | ||||
|         fbgemm::PackAWithRowOffset<uint8_t> packA( | ||||
|             /*trans=*/fbgemm::matrix_op_t::NoTranspose, | ||||
|             /*nRow=*/M, | ||||
|             /*nCol=*/K, | ||||
|             /*smat=*/input_ptr, | ||||
|             /*ld=*/K, | ||||
|             /*pmat=*/nullptr); // Currently, packA manages ownership of `pmat`. | ||||
|                                // TODO: Consider a way to pre-allocate and reuse | ||||
|                                // pmat buffer. | ||||
|  | ||||
|       // Do the GEMM | ||||
|       fbgemm::fbgemmPacked( | ||||
|           /*packA=*/packA, | ||||
|           /*packB=*/*packB, | ||||
|           /*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()), | ||||
|           /*C_buffer=*/buffer.data_ptr<int32_t>(), | ||||
|           /*ldc=*/N, | ||||
|           /*outProcess=*/outputProcObj, | ||||
|           /*thread_id=*/0, | ||||
|           /*num_threads=*/1); | ||||
|     } else if (pack_ptr.q_scheme == kPerChannelAffine) { | ||||
|       // Process the per channel quantization. | ||||
|       // | ||||
|       // After the uint8 * int8 matrix multiplication is performed, this | ||||
|       // operation does: | ||||
|       //  1) Add in row and column offsets to the rows and columns, | ||||
|       //  respectively. | ||||
|       //  2) Add in the bias term. | ||||
|       fbgemm::ReQuantizeOutput< | ||||
|           ReluFused, | ||||
|           fbgemm::QuantizationGranularity::OUT_CHANNEL, | ||||
|           float> | ||||
|           outputProcObj( | ||||
|               doNothingObj, | ||||
|               output_multiplier_float.data(), | ||||
|               output_zero_point_int32, | ||||
|               input_zero_point_int32, | ||||
|               pack_ptr.w_zp.data(), | ||||
|               packA.getRowOffsetBuffer(), | ||||
|               col_offsets.data(), | ||||
|               bias_ptr, | ||||
|               N, /*nCol=*/ | ||||
|               1, /* groups*/ | ||||
|               act_times_w_scale.data()); | ||||
|         // ReQuantizeOutput requires pointers to the zero point values, | ||||
|         // since in the case of rowwise quantization these will be arrays rather | ||||
|         // than scalars. But in this case, we're doing whole-tensor quantization | ||||
|         // so we just pass a pointer to the scale values (and internally | ||||
|         // ReQuantizeOutput won't index past 0. | ||||
|  | ||||
|         // This is the end of the pipeline, pass the resulting matrix through. | ||||
|         fbgemm::DoNothing<> doNothingObj{}; | ||||
|  | ||||
|         if (pack_ptr.q_scheme == kPerTensorAffine) { | ||||
|           // Process the per tensor quantization. | ||||
|           // | ||||
|           // After the uint8 * int8 matrix multiplication is performed, this | ||||
|           // operation does: | ||||
|           //  1) Add in row and column offsets to the rows and columns, | ||||
|           //  respectively. | ||||
|           //  2) Add in the bias term. | ||||
|           fbgemm::ReQuantizeOutput< | ||||
|               ReluFused, | ||||
|               fbgemm::QuantizationGranularity::TENSOR, | ||||
|               float> | ||||
|               outputProcObj( | ||||
|                   doNothingObj, | ||||
|                   output_multiplier_float.data(), | ||||
|                   output_zero_point_int32, | ||||
|                   input_zero_point_int32, | ||||
|                   pack_ptr.w_zp.data(), | ||||
|                   packA.getRowOffsetBuffer(), | ||||
|                   col_offsets.data(), | ||||
|                   bias_ptr, | ||||
|                   N, /* nCol */ | ||||
|                   1 /* groups */, | ||||
|                   act_times_w_scale.data()); | ||||
|  | ||||
|           // Do the GEMM | ||||
|           fbgemm::fbgemmPacked( | ||||
|               /*packA=*/packA, | ||||
|               /*packB=*/*packB, | ||||
|               /*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()), | ||||
|               /*C_buffer=*/buffer.data_ptr<int32_t>(), | ||||
|               /*ldc=*/N, | ||||
|               /*outProcess=*/outputProcObj, | ||||
|               /*thread_id=*/task_id, | ||||
|               /*num_threads=*/num_tasks); | ||||
|         } else if (pack_ptr.q_scheme == kPerChannelAffine) { | ||||
|           // Process the per channel quantization. | ||||
|           // | ||||
|           // After the uint8 * int8 matrix multiplication is performed, this | ||||
|           // operation does: | ||||
|           //  1) Add in row and column offsets to the rows and columns, | ||||
|           //  respectively. | ||||
|           //  2) Add in the bias term. | ||||
|           fbgemm::ReQuantizeOutput< | ||||
|               ReluFused, | ||||
|               fbgemm::QuantizationGranularity::OUT_CHANNEL, | ||||
|               float> | ||||
|               outputProcObj( | ||||
|                   doNothingObj, | ||||
|                   output_multiplier_float.data(), | ||||
|                   output_zero_point_int32, | ||||
|                   input_zero_point_int32, | ||||
|                   pack_ptr.w_zp.data(), | ||||
|                   packA.getRowOffsetBuffer(), | ||||
|                   col_offsets.data(), | ||||
|                   bias_ptr, | ||||
|                   N, /*nCol=*/ | ||||
|                   1, /* groups*/ | ||||
|                   act_times_w_scale.data()); | ||||
|  | ||||
|           // Do the GEMM | ||||
|           fbgemm::fbgemmPacked( | ||||
|               /*packA=*/packA, | ||||
|               /*packB=*/*packB, | ||||
|               /*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()), | ||||
|               /*C_buffer=*/buffer.data_ptr<int32_t>(), | ||||
|               /*ldc=*/N, | ||||
|               /*outProcess=*/outputProcObj, | ||||
|               /*thread_id=*/task_id, | ||||
|               /*num_threads=*/num_tasks); | ||||
|         } | ||||
|       } | ||||
|     }); | ||||
|  | ||||
|       // Do the GEMM | ||||
|       fbgemm::fbgemmPacked( | ||||
|           /*packA=*/packA, | ||||
|           /*packB=*/*packB, | ||||
|           /*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()), | ||||
|           /*C_buffer=*/buffer.data_ptr<int32_t>(), | ||||
|           /*ldc=*/N, | ||||
|           /*outProcess=*/outputProcObj, | ||||
|           /*thread_id=*/0, | ||||
|           /*num_threads=*/1); | ||||
|     } | ||||
|     return output; | ||||
|   } | ||||
| #endif | ||||
| @ -242,7 +249,8 @@ class QLinearInt8 final : public torch::OperatorKernel { | ||||
|           kernel_scale, | ||||
|           kernel_zp); | ||||
|       auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>(); | ||||
|       for (int i = 0; i < weight_contig.numel(); ++i) { | ||||
|       auto wt_numel = weight_contig.numel(); | ||||
|       for (int i = 0; i < wt_numel; ++i) { | ||||
|         qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128); | ||||
|       } | ||||
|       // Original bias was float, so we requantize it here. | ||||
|  | ||||
| @ -127,8 +127,8 @@ class QLinearDynamicInt8 final : public torch::OperatorKernel { | ||||
|     std::vector<int64_t> out_sizes = input.sizes().vec(); | ||||
|     out_sizes.back() = N; | ||||
|     // Allocate output Tensor and a buffer for fbgemmPacked to use | ||||
|     auto output = at::zeros(out_sizes, input.options().dtype(at::kFloat)); | ||||
|     auto buffer = at::zeros_like(output, output.options().dtype(at::kInt)); | ||||
|     auto output = at::empty(out_sizes, input.options().dtype(at::kFloat)); | ||||
|     auto buffer = at::empty_like(output, output.options().dtype(at::kInt)); | ||||
|  | ||||
|     if (pack_ptr.q_scheme == kPerTensorAffine) { | ||||
|       // Process the per tensor quantization. | ||||
|  | ||||
| @ -163,7 +163,8 @@ class QLinearPackWeightInt8 final : public c10::OperatorKernel { | ||||
|         weight.q_scale(), | ||||
|         weight_zp); | ||||
|     auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>(); | ||||
|     for (int i = 0; i < weight_contig.numel(); ++i) { | ||||
|     auto wt_numel = weight_contig.numel(); | ||||
|     for (int i = 0; i < wt_numel; ++i) { | ||||
|       qnnp_w_data[i] = static_cast<c10::quint8>(inp_data[i] + 128); | ||||
|     } | ||||
|     initQNNPACK(); | ||||
|  | ||||
							
								
								
									
										60
									
								
								aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,60 @@ | ||||
| #include <ATen/ATen.h> | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #include <ATen/cuda/CUDAApplyUtils.cuh> | ||||
| #include <cmath> | ||||
|  | ||||
| /* Fake quantize a tensor, common block for per-channel & per-tensor fake quant | ||||
| Args: | ||||
|   output: output tensor. | ||||
|   input : input tensor. | ||||
|   sc:  scale to quantize the input tensor to | ||||
|   zero_point: zero_point | ||||
|   quant_min: minimum quantized value | ||||
|   quant_max: maximum quantized value | ||||
| Returns: | ||||
|   Fake quantized tensor (double dtype). | ||||
| */ | ||||
| namespace at { | ||||
| namespace native { | ||||
| void fake_quantize_slice_cuda( | ||||
|     Tensor& output, | ||||
|     const Tensor& input, | ||||
|     float scale, | ||||
|     int64_t zero_point, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max) { | ||||
|   float inv_scale = 1.0f / scale; | ||||
|   at::cuda::CUDA_tensor_apply2<float, float>( | ||||
|       input, output, [=] __device__(const float& input_val, float& result_val) { | ||||
|         result_val = (fminf( | ||||
|                           quant_max, | ||||
|                           fmaxf( | ||||
|                               quant_min, | ||||
|                               static_cast<int64_t>(std::nearbyint( | ||||
|                                   input_val * inv_scale + zero_point)))) - | ||||
|                       zero_point) * | ||||
|             scale; | ||||
|       }); | ||||
| } | ||||
|  | ||||
| void fake_quantize_grad_slice_cuda( | ||||
|     Tensor& input_grad, | ||||
|     const Tensor& input, | ||||
|     const Tensor& output_grad, | ||||
|     float scale, | ||||
|     int64_t zero_point, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max) { | ||||
|   float inv_scale = 1.0f / scale; | ||||
|   at::cuda::CUDA_tensor_apply3<float, float, float>( | ||||
|       output_grad, | ||||
|       input, | ||||
|       input_grad, | ||||
|       [=] __device__(const float& dy, const float& x, float& dx) { | ||||
|         int64_t Xq = std::nearbyint(x * inv_scale + zero_point); | ||||
|         dx = (Xq >= quant_min && Xq <= quant_max) * dy; | ||||
|       }); | ||||
| } | ||||
|  | ||||
| } // namespace native | ||||
| } // namespace at | ||||
							
								
								
									
										28
									
								
								aten/src/ATen/native/quantized/cuda/fake_quantize_core.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								aten/src/ATen/native/quantized/cuda/fake_quantize_core.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,28 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/ATen.h> | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #include <cmath> | ||||
|  | ||||
| /* FakeQuantize Op for PerChannelAffine quantization scheme */ | ||||
| namespace at { | ||||
| namespace native { | ||||
| void fake_quantize_slice_cuda( | ||||
|     Tensor& output, | ||||
|     const Tensor& input, | ||||
|     float sc, | ||||
|     int64_t z_point, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max); | ||||
|  | ||||
| void fake_quantize_grad_slice_cuda( | ||||
|     Tensor& input_grad, | ||||
|     const Tensor& output_grad, | ||||
|     const Tensor& input, | ||||
|     float sc, | ||||
|     int64_t z_point, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max); | ||||
|  | ||||
| } // namespace native | ||||
| } // namespace at | ||||
| @ -0,0 +1,144 @@ | ||||
| #include <ATen/ATen.h> | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #include <ATen/cuda/CUDAApplyUtils.cuh> | ||||
| #include <cmath> | ||||
| #include <ATen/native/quantized/cuda/fake_quantize_core.h> | ||||
|  | ||||
| /* FakeQuantize Op for PerChannelAffine quantization scheme */ | ||||
| namespace at { | ||||
| namespace native { | ||||
|  | ||||
| /* Per channel fake-quantizes the 'inputs' tensor. | ||||
| Args: | ||||
|   self: Forward input tensor. | ||||
|   scale: scale of per channel affine quantization | ||||
|   zero_point: zero_point of per channel affine quantization | ||||
|   axis: int specifying the axis to be quantized | ||||
|   quant_min: minimum quantized value | ||||
|   quant_max: maximum quantized value | ||||
| Returns: | ||||
|   Fake quantized tensor (double dtype). | ||||
|  | ||||
| */ | ||||
| Tensor fake_quantize_per_channel_affine_cuda( | ||||
|     const Tensor& self, | ||||
|     const Tensor& scale, | ||||
|     const Tensor& zero_point, | ||||
|     int64_t axis, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max) { | ||||
|   TORCH_CHECK(self.is_cuda()); | ||||
|   TORCH_CHECK(self.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK( | ||||
|       scale.dim() == 1, "scale should be a 1-D tensor"); | ||||
|   TORCH_CHECK( | ||||
|       zero_point.dim() == 1, "zero point should be a 1-D tensor"); | ||||
|   TORCH_CHECK( | ||||
|       scale.numel() == zero_point.numel(), | ||||
|       "scale and zero-point need to have the same dimensions"); | ||||
|   TORCH_CHECK( | ||||
|       scale.numel() == self.size(axis), | ||||
|       "dimensions of scale and zero-point are not consistent with input tensor") | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       quant_min <= quant_max, | ||||
|       "`quant_min` should be less than or \ | ||||
|         equal to `quant_max`."); | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       at::min(zero_point).item().toLong() >= quant_min && | ||||
|           at::max(zero_point).item().toLong() <= quant_max, | ||||
|       "`zero_point` must be between `quant_min` and `quant_max`."); | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       axis >= 0 && axis <= self.dim(), | ||||
|       "`axis` must be between 0 and number of dimensions of input"); | ||||
|  | ||||
|   auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve); | ||||
|   for (int i = 0; i < self.size(axis); i++) { | ||||
|     auto X_slice = self.slice(axis, i, i + 1); | ||||
|     auto Y_slice = Y.slice(axis, i, i + 1); | ||||
|     float sc = scale[i].item().toFloat(); | ||||
|     int64_t zp = zero_point[i].item().toLong(); | ||||
|     fake_quantize_slice_cuda(Y_slice, X_slice, sc, zp, quant_min, quant_max); | ||||
|   } | ||||
|   return Y; | ||||
| } | ||||
|  | ||||
| /* Backward path for per-channel fake-quantization of the 'inputs' tensor. | ||||
|  | ||||
| Args: | ||||
|   dY: Backward input tensor. | ||||
|   X: Forward input tensor. | ||||
|   scale: scale of per tensor affine quantization | ||||
|   zero_point: zero_point of per tensor affine quantization | ||||
|   axis: int ,the axis over which quantization parameters vary | ||||
|   quant_min: int, minimum quantized value | ||||
|   quant_max: int, maximum quantized value | ||||
|  | ||||
| Returns: | ||||
|   Gradient for per channel fake quant (double dtype). | ||||
|  | ||||
| */ | ||||
| Tensor fake_quantize_per_channel_affine_backward_cuda( | ||||
|     const Tensor& dY, | ||||
|     const Tensor& X, | ||||
|     const Tensor& scale, | ||||
|     const Tensor& zero_point, | ||||
|     int64_t axis, | ||||
|     int64_t quant_min, | ||||
|     int64_t quant_max) { | ||||
|   TORCH_CHECK(dY.is_cuda()); | ||||
|   TORCH_CHECK(dY.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK(X.scalar_type() == ScalarType::Float); | ||||
|  | ||||
|   TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size"); | ||||
|   TORCH_CHECK( | ||||
|       quant_min <= quant_max, | ||||
|       "`quant_min` should be less than or \ | ||||
|         equal to `quant_max`."); | ||||
|   TORCH_CHECK( | ||||
|       scale.dim() == 1, "scale should be a 1-D tensor"); | ||||
|   TORCH_CHECK( | ||||
|       zero_point.dim() == 1, "zero point should be a 1-D tensor"); | ||||
|   TORCH_CHECK( | ||||
|       scale.numel() == zero_point.numel(), | ||||
|       "scale and zero-point need to have the same dimensions"); | ||||
|   TORCH_CHECK( | ||||
|       scale.numel() == X.size(axis), | ||||
|       "dimensions of scale and zero-point are not consistent with input tensor") | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       quant_min <= quant_max, | ||||
|       "`quant_min` should be less than or \ | ||||
|         equal to `quant_max`."); | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       at::min(zero_point).item().toLong() >= quant_min && | ||||
|           at::max(zero_point).item().toLong() <= quant_max, | ||||
|       "`zero_point` must be between `quant_min` and `quant_max`."); | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       axis >= 0 && axis <= X.dim(), | ||||
|       "`axis` must be between 0 and number of dimensions of input"); | ||||
|  | ||||
|   if (X.numel() <= 0) { | ||||
|     return X; | ||||
|   } | ||||
|  | ||||
|   auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve); | ||||
|  | ||||
|   for (int i = 0; i < X.size(axis); i++) { | ||||
|     auto dY_slice = dY.slice(axis, i, i + 1); | ||||
|     auto X_slice = X.slice(axis, i, i + 1); | ||||
|     auto dX_slice = dX.slice(axis, i, i + 1); | ||||
|     float sc = scale[i].item().toFloat(); | ||||
|     int64_t zp = zero_point[i].item().toLong(); | ||||
|     fake_quantize_grad_slice_cuda( | ||||
|         dX_slice, X_slice, dY_slice, sc, zp, quant_min, quant_max); | ||||
|   } | ||||
|   return dX; | ||||
| } | ||||
|  | ||||
| } // namespace native | ||||
| } // namespace at | ||||
| @ -2,6 +2,7 @@ | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #include <ATen/cuda/CUDAApplyUtils.cuh> | ||||
| #include <cmath> | ||||
| #include <ATen/native/quantized/cuda/fake_quantize_core.h> | ||||
|  | ||||
| /* FakeQuantize Op for PerTensorAffine quantization scheme */ | ||||
| namespace at { | ||||
| @ -9,21 +10,13 @@ namespace native { | ||||
|  | ||||
| /* Fake-quantizes the 'inputs' tensor. | ||||
| Args: | ||||
|   X: Forward input tensor. | ||||
|   self: Forward input tensor. | ||||
|   scale: scale of per tensor affine quantization | ||||
|   zero_point: zero_point of per tensor affine quantization | ||||
|   quant_min: minimum quantized value | ||||
|   quant_max: maximum quantized value | ||||
|   quant_delay: Count of global steps for which to delay the quantization. | ||||
|                See note below. | ||||
|   iter: The current quantization iteration used for `quant_delay`. | ||||
| Returns: | ||||
|   Quantized tensor (double dtype). | ||||
|  | ||||
| Notes: | ||||
|   - quant_delay might be set to non-zero to help weights stabilize in the | ||||
|     beginning of the training. | ||||
|   - quantization range [quant_min, quant_max] | ||||
| */ | ||||
| Tensor fake_quantize_per_tensor_affine_cuda( | ||||
|     const Tensor& self, | ||||
| @ -40,42 +33,22 @@ Tensor fake_quantize_per_tensor_affine_cuda( | ||||
|   TORCH_CHECK( | ||||
|       zero_point >= quant_min && zero_point <= quant_max, | ||||
|       "`zero_point` must be between `quant_min` and `quant_max`."); | ||||
|   auto Y = at::empty_like(self); | ||||
|  | ||||
|   float inv_scale = 1.0f / scale; | ||||
|   at::cuda::CUDA_tensor_apply2<float, float>( | ||||
|       self, Y, [=] __device__(const float& input_val, float& result_val) { | ||||
|         result_val = (fminf( | ||||
|                           quant_max, | ||||
|                           fmaxf( | ||||
|                               quant_min, | ||||
|                               static_cast<int64_t>(std::nearbyint( | ||||
|                                   input_val * inv_scale + zero_point)))) - | ||||
|                       zero_point) * | ||||
|             scale; | ||||
|       }); | ||||
|   auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve); | ||||
|   fake_quantize_slice_cuda(Y, self, scale, zero_point, quant_min, quant_max); | ||||
|   return Y; | ||||
| } | ||||
|  | ||||
| /* Backward path to fake-quantize the 'inputs' tensor. | ||||
|  | ||||
| Args: | ||||
|   X: Forward input tensor. | ||||
|   dY: Backward input tensor. | ||||
|   X: Forward input tensor. | ||||
|   scale: scale of per tensor affine quantization | ||||
|   zero_point: zero_point of per tensor affine quantization | ||||
|   quant_min: minimum quantized value | ||||
|   quant_max: maximum quantized value | ||||
|   quant_delay: Count of global steps for which to delay the quantization. | ||||
|                See note in forward. | ||||
|   iter: The current quantization iteration used for `quant_delay`. | ||||
| Returns: | ||||
|   Quantized tensor (double dtype). | ||||
|  | ||||
| Notes: | ||||
|   - quant_delay might be set to non-zero to help weights stabilize in the | ||||
|     beginning of the training. | ||||
|   - quantization range [quant_min, quant_max] | ||||
| */ | ||||
| Tensor fake_quantize_per_tensor_affine_backward_cuda( | ||||
|     const Tensor& dY, | ||||
| @ -100,14 +73,9 @@ Tensor fake_quantize_per_tensor_affine_backward_cuda( | ||||
|     return X; | ||||
|   } | ||||
|  | ||||
|   auto dX = dY.clone(); | ||||
|  | ||||
|   float inv_scale = 1.0f / scale; | ||||
|   at::cuda::CUDA_tensor_apply3<float, float, float>( | ||||
|       dY, X, dX, [=] __device__(const float& dy, const float& x, float& dx) { | ||||
|         int64_t Xq = std::nearbyint(x * inv_scale + zero_point); | ||||
|         dx = (Xq >= quant_min && Xq <= quant_max) * dy; | ||||
|       }); | ||||
|   auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve); | ||||
|   fake_quantize_grad_slice_cuda( | ||||
|       dX, X, dY, scale, zero_point, quant_min, quant_max); | ||||
|   return dX; | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -297,7 +297,8 @@ Tensor quantize_tensor(Tensor rtensor, Tensor qtensor, double scale, int64_t zer | ||||
|   } | ||||
| #endif | ||||
|   auto qdata = qtensor.data_ptr<T>(); | ||||
|   for (int i = 0; i < rtensor.numel(); ++i) { | ||||
|   auto numel = rtensor.numel(); | ||||
|   for (int i = 0; i < numel; ++i) { | ||||
|     qdata[i] = quantize_val<T>(scale, zero_point, rdata[i]); | ||||
|   } | ||||
|   return qtensor; | ||||
| @ -318,7 +319,8 @@ Tensor dequantize_tensor(Tensor qtensor, Tensor rtensor, double scale, int64_t z | ||||
|   checkZeroPoint<typename T::underlying>(fn_name, zero_point); | ||||
|   const auto* qd = qtensor.data_ptr<T>(); | ||||
|   float* rd = rtensor.data_ptr<float>(); | ||||
|   for (auto i = 0; i < qtensor.numel(); ++i) { | ||||
|   auto numel = qtensor.numel(); | ||||
|   for (auto i = 0; i < numel; ++i) { | ||||
|     rd[i] = dequantize_val<T>(scale, zero_point, qd[i]); | ||||
|   } | ||||
|   return rtensor; | ||||
|  | ||||
| @ -29,7 +29,8 @@ list(APPEND ATen_CPU_TEST_SRCS | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/tensor_iterator_test.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp) | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/reduce_ops_test.cpp) | ||||
|  | ||||
| list(APPEND ATen_CUDA_TEST_SRCS | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu | ||||
|  | ||||
							
								
								
									
										24
									
								
								aten/src/ATen/test/reduce_ops_test.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								aten/src/ATen/test/reduce_ops_test.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,24 @@ | ||||
| #include <gtest/gtest.h> | ||||
|  | ||||
| #include <torch/types.h> | ||||
| #include <torch/utils.h> | ||||
|  | ||||
| using namespace at; | ||||
|  | ||||
| TEST(ReduceOpsTest, MaxValuesAndMinValues) { | ||||
|   const int W = 10; | ||||
|   const int H = 10; | ||||
|   if (hasCUDA()) { | ||||
|     for (const auto dtype : {kHalf, kFloat, kDouble, kShort, kInt, kLong}) { | ||||
|       auto a = at::rand({H, W}, TensorOptions(kCUDA).dtype(at::kHalf)); | ||||
|       ASSERT_FLOAT_EQ( | ||||
|         a.max_values(c10::IntArrayRef{0, 1}).item<double>(), | ||||
|         a.max().item<double>() | ||||
|       ); | ||||
|       ASSERT_FLOAT_EQ( | ||||
|         a.min_values(c10::IntArrayRef{0, 1}).item<double>(), | ||||
|         a.min().item<double>() | ||||
|       ); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| @ -172,3 +172,12 @@ TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfOutputIsUndefined) { | ||||
|   iter.compute_common_dtype_only_for_inputs(); | ||||
|   ASSERT_ANY_THROW(iter.build()); | ||||
| } | ||||
|  | ||||
| TEST(TensorIteratorTest, FailNonPromotingBinaryOp) { | ||||
|   Tensor out; | ||||
|   auto iter = at::TensorIterator(); | ||||
|   iter.add_output(out); | ||||
|   iter.add_input(at::ones({1,1}, at::dtype(at::kDouble))); | ||||
|   iter.add_input(at::ones({1,1}, at::dtype(at::kInt))); | ||||
|   ASSERT_ANY_THROW(iter.build()); | ||||
| } | ||||
|  | ||||
| @ -7,6 +7,7 @@ | ||||
| #define Real QInt32 | ||||
| #define RealUnderlying Int | ||||
| #define THQUANTIZED | ||||
| #define THQINT32 | ||||
| #define TH_REAL_IS_BYTE | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| @ -15,6 +16,7 @@ | ||||
| #undef Real | ||||
| #undef RealUnderlying | ||||
| #undef TH_REAL_IS_BYTE | ||||
| #undef THQINT32 | ||||
| #undef THQUANTIZED | ||||
|  | ||||
| #ifndef THGenerateManyTypes | ||||
|  | ||||
| @ -7,6 +7,7 @@ | ||||
| #define Real QInt8 | ||||
| #define RealUnderlying Char | ||||
| #define THQUANTIZED | ||||
| #define THQINT8 | ||||
| #define TH_REAL_IS_BYTE | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| @ -15,6 +16,7 @@ | ||||
| #undef Real | ||||
| #undef RealUnderlying | ||||
| #undef TH_REAL_IS_BYTE | ||||
| #undef THQINT8 | ||||
| #undef THQUANTIZED | ||||
|  | ||||
| #ifndef THGenerateManyTypes | ||||
|  | ||||
| @ -7,6 +7,7 @@ | ||||
| #define Real QUInt8 | ||||
| #define RealUnderlying Byte | ||||
| #define THQUANTIZED | ||||
| #define THQUINT8 | ||||
| #define TH_REAL_IS_BYTE | ||||
| #line 1 TH_GENERIC_FILE | ||||
| #include TH_GENERIC_FILE | ||||
| @ -15,6 +16,7 @@ | ||||
| #undef Real | ||||
| #undef RealUnderlying | ||||
| #undef TH_REAL_IS_BYTE | ||||
| #undef THQUINT8 | ||||
| #undef THQUANTIZED | ||||
|  | ||||
| #ifndef THGenerateManyTypes | ||||
|  | ||||
| @ -35,6 +35,9 @@ | ||||
| #define THLongStorage THStorage | ||||
| #define THBoolStorage THStorage | ||||
| #define THBFloat16Storage THStorage | ||||
| #define THQUInt8Storage THStorage | ||||
| #define THQInt8Storage THStorage | ||||
| #define THQInt32Storage THStorage | ||||
|  | ||||
| TH_API scalar_t* THStorage_(data)(const THStorage*); | ||||
| TH_API ptrdiff_t THStorage_(size)(const THStorage*); | ||||
|  | ||||
| @ -35,5 +35,14 @@ IMPLEMENT_THStorage_COPY(Double) | ||||
| IMPLEMENT_THStorage_COPY(Half) | ||||
| IMPLEMENT_THStorage_COPY(Bool) | ||||
| IMPLEMENT_THStorage_COPY(BFloat16) | ||||
| #ifdef THQUINT8 | ||||
| IMPLEMENT_THStorage_COPY(QUInt8) | ||||
| #endif | ||||
| #ifdef THQINT8 | ||||
| IMPLEMENT_THStorage_COPY(QInt8) | ||||
| #endif | ||||
| #ifdef THQINT32 | ||||
| IMPLEMENT_THStorage_COPY(QInt32) | ||||
| #endif | ||||
|  | ||||
| #endif | ||||
|  | ||||
| @ -14,5 +14,14 @@ TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *s | ||||
| TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src); | ||||
| TH_API void THStorage_(copyBool)(THStorage *storage, struct THBoolStorage *src); | ||||
| TH_API void THStorage_(copyBFloat16)(THStorage *storage, struct THBFloat16Storage *src); | ||||
| #ifdef THQUINT8 | ||||
| TH_API void THStorage_(copyQUInt8)(THStorage *storage, struct THQUInt8Storage *src); | ||||
| #endif | ||||
| #ifdef THQINT8 | ||||
| TH_API void THStorage_(copyQInt8)(THStorage *storage, struct THQInt8Storage *src); | ||||
| #endif | ||||
| #ifdef THQINT32 | ||||
| TH_API void THStorage_(copyQInt32)(THStorage *storage, struct THQInt32Storage *src); | ||||
| #endif | ||||
|  | ||||
| #endif | ||||
|  | ||||
| @ -21,6 +21,7 @@ | ||||
| #include <iterator> | ||||
| #include <utility> | ||||
| #include <type_traits> | ||||
| #include <stdexcept> | ||||
|  | ||||
| #ifndef _MSC_VER | ||||
| #pragma GCC diagnostic push | ||||
|  | ||||
| @ -117,9 +117,12 @@ if ((NOT TARGET protobuf::libprotobuf) AND (NOT TARGET protobuf::libprotobuf-lit | ||||
|   #     "Please set the proper paths so that I can find protobuf correctly.") | ||||
| endif() | ||||
|  | ||||
| # Protobuf generated files use <> as inclusion path, so maybe we should use | ||||
| # SYSTEM inclusion path. But we need these include dirs to be found before | ||||
| # other protobuf include dirs in Anaconda | ||||
| get_target_property(__tmp protobuf::libprotobuf INTERFACE_INCLUDE_DIRECTORIES) | ||||
| message(STATUS "Caffe2 protobuf include directory: " ${__tmp}) | ||||
| include_directories(BEFORE SYSTEM ${__tmp}) | ||||
| include_directories(BEFORE ${__tmp}) | ||||
|  | ||||
| # If Protobuf_VERSION is known (true in most cases, false if we are building | ||||
| # local protobuf), then we will add a protobuf version check in | ||||
|  | ||||
| @ -53,6 +53,8 @@ extensions = [ | ||||
|     'sphinx.ext.napoleon', | ||||
|     'sphinx.ext.viewcode', | ||||
|     'sphinxcontrib.katex', | ||||
|     'sphinx.ext.autosectionlabel', | ||||
|     'javasphinx', | ||||
| ] | ||||
|  | ||||
| # katex options | ||||
|  | ||||
| @ -16,6 +16,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. | ||||
|    :caption: Notes | ||||
|  | ||||
|    notes/* | ||||
|    PyTorch on XLA Devices <http://pytorch.org/xla/> | ||||
|  | ||||
| .. toctree:: | ||||
|   :glob: | ||||
| @ -25,8 +26,8 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. | ||||
|   community/* | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 2 | ||||
|    :caption: Package Reference | ||||
|    :maxdepth: 1 | ||||
|    :caption: Python API | ||||
|  | ||||
|    torch | ||||
|    nn | ||||
| @ -42,6 +43,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. | ||||
|    nn.init | ||||
|    onnx | ||||
|    optim | ||||
|    quantization | ||||
|    torch.random <random> | ||||
|    sparse | ||||
|    storage | ||||
| @ -53,6 +55,8 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. | ||||
|    torch.utils.model_zoo <model_zoo> | ||||
|    torch.utils.tensorboard <tensorboard> | ||||
|    type_info | ||||
|    named_tensor | ||||
|    name_inference | ||||
|    torch.__config__ <__config__> | ||||
|  | ||||
| .. toctree:: | ||||
| @ -62,9 +66,24 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. | ||||
|  | ||||
|    torchvision/index | ||||
|  | ||||
| * `torchaudio <https://pytorch.org/audio>`_ | ||||
| .. toctree:: | ||||
|    :maxdepth: 1 | ||||
|    :caption: torchaudio Reference | ||||
|  | ||||
| * `torchtext <https://pytorch.org/text>`_ | ||||
|    torchaudio <https://pytorch.org/audio> | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 1 | ||||
|    :caption: torchtext Reference | ||||
|  | ||||
|    torchtext <https://pytorch.org/text> | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 1 | ||||
|    :caption: Other Languages | ||||
|  | ||||
|    C++ API <https://pytorch.org/cppdocs/> | ||||
|    packages | ||||
|  | ||||
| Indices and tables | ||||
| ================== | ||||
|  | ||||
							
								
								
									
										
											BIN
										
									
								
								docs/source/math-quantizer-equation.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/source/math-quantizer-equation.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 9.7 KiB | 
							
								
								
									
										468
									
								
								docs/source/name_inference.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										468
									
								
								docs/source/name_inference.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,468 @@ | ||||
| .. currentmodule:: torch | ||||
|  | ||||
| .. _name_inference_reference-doc: | ||||
|  | ||||
| Named Tensors operator coverage | ||||
| =============================== | ||||
|  | ||||
| Please read :ref:`named_tensors-doc` first for an introduction to named tensors. | ||||
|  | ||||
| This document is a reference for *name inference*, a process that defines how | ||||
| named tensors: | ||||
|  | ||||
| 1. use names to provide additional automatic runtime correctness checks | ||||
| 2. propagate names from input tensors to output tensors | ||||
|  | ||||
| Below is a list of all operations that are supported with named tensors | ||||
| and their associated name inference rules. | ||||
|  | ||||
| If you don't see an operation listed here, but it would help your use case, please | ||||
| `search if an issue has already been filed <https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+named+tensor%22>`_ and if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_. | ||||
|  | ||||
| .. warning:: | ||||
|     The named tensor API is experimental and subject to change. | ||||
|  | ||||
| .. csv-table:: Supported Operations | ||||
|    :header: API, Name inference rule | ||||
|    :widths: 20, 20 | ||||
|  | ||||
|    ":meth:`Tensor.abs`, :func:`torch.abs`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.abs_`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.acos`, :func:`torch.acos`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.acos_`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.add`, :func:`torch.add`",:ref:`unifies_names_from_inputs-doc` | ||||
|    :meth:`Tensor.add_`,:ref:`unifies_names_from_inputs-doc` | ||||
|    ":meth:`Tensor.addmm`, :func:`torch.addmm`",:ref:`contracts_away_dims-doc` | ||||
|    :meth:`Tensor.addmm_`,:ref:`contracts_away_dims-doc` | ||||
|    ":meth:`Tensor.addmv`, :func:`torch.addmv`",:ref:`contracts_away_dims-doc` | ||||
|    :meth:`Tensor.addmv_`,:ref:`contracts_away_dims-doc` | ||||
|    :meth:`Tensor.align_as`,See documentation | ||||
|    :meth:`Tensor.align_to`,See documentation | ||||
|    ":meth:`Tensor.all`, :func:`torch.all`",None | ||||
|    ":meth:`Tensor.any`, :func:`torch.any`",None | ||||
|    ":meth:`Tensor.asin`, :func:`torch.asin`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.asin_`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.atan`, :func:`torch.atan`",:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.atan2`, :func:`torch.atan2`",:ref:`unifies_names_from_inputs-doc` | ||||
|    :meth:`Tensor.atan2_`,:ref:`unifies_names_from_inputs-doc` | ||||
|    :meth:`Tensor.atan_`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.bernoulli`, :func:`torch.bernoulli`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.bernoulli_`,None | ||||
|    :meth:`Tensor.bfloat16`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.bitwise_not`, :func:`torch.bitwise_not`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.bitwise_not_`,None | ||||
|    ":meth:`Tensor.bmm`, :func:`torch.bmm`",:ref:`contracts_away_dims-doc` | ||||
|    :meth:`Tensor.bool`,:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.byte`,:ref:`keeps_input_names-doc` | ||||
|    :func:`torch.cat`,:ref:`unifies_names_from_inputs-doc` | ||||
|    :meth:`Tensor.cauchy_`,None | ||||
|    ":meth:`Tensor.ceil`, :func:`torch.ceil`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.ceil_`,None | ||||
|    :meth:`Tensor.char`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.chunk`, :func:`torch.chunk`",:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.clamp`, :func:`torch.clamp`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.clamp_`,None | ||||
|    :meth:`Tensor.copy_`,:ref:`out_function_semantics-doc` | ||||
|    ":meth:`Tensor.cos`, :func:`torch.cos`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.cos_`,None | ||||
|    ":meth:`Tensor.cosh`, :func:`torch.cosh`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.cosh_`,None | ||||
|    :meth:`Tensor.cpu`,:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.cuda`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.cumprod`, :func:`torch.cumprod`",:ref:`removes_dimensions-doc` | ||||
|    ":meth:`Tensor.cumsum`, :func:`torch.cumsum`",:ref:`removes_dimensions-doc` | ||||
|    :meth:`Tensor.data_ptr`,None | ||||
|    ":meth:`Tensor.detach`, :func:`torch.detach`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.detach_`,None | ||||
|    ":attr:`Tensor.device`, :func:`torch.device`",None | ||||
|    ":meth:`Tensor.digamma`, :func:`torch.digamma`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.digamma_`,None | ||||
|    :meth:`Tensor.dim`,None | ||||
|    ":meth:`Tensor.div`, :func:`torch.div`",:ref:`unifies_names_from_inputs-doc` | ||||
|    :meth:`Tensor.div_`,:ref:`unifies_names_from_inputs-doc` | ||||
|    ":meth:`Tensor.dot`, :func:`torch.dot`",None | ||||
|    :meth:`Tensor.double`,:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.element_size`,None | ||||
|    :func:`torch.empty`,:ref:`factory-doc` | ||||
|    :func:`torch.empty_like`,:ref:`factory-doc` | ||||
|    ":meth:`Tensor.eq`, :func:`torch.eq`",:ref:`unifies_names_from_inputs-doc` | ||||
|    ":meth:`Tensor.erf`, :func:`torch.erf`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.erf_`,None | ||||
|    ":meth:`Tensor.erfc`, :func:`torch.erfc`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.erfc_`,None | ||||
|    ":meth:`Tensor.erfinv`, :func:`torch.erfinv`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.erfinv_`,None | ||||
|    ":meth:`Tensor.exp`, :func:`torch.exp`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.exp_`,None | ||||
|    :meth:`Tensor.expand`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.expm1`, :func:`torch.expm1`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.expm1_`,None | ||||
|    :meth:`Tensor.exponential_`,None | ||||
|    :meth:`Tensor.fill_`,None | ||||
|    ":meth:`Tensor.flatten`, :func:`torch.flatten`",See documentation | ||||
|    :meth:`Tensor.float`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.floor`, :func:`torch.floor`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.floor_`,None | ||||
|    ":meth:`Tensor.frac`, :func:`torch.frac`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.frac_`,None | ||||
|    ":meth:`Tensor.ge`, :func:`torch.ge`",:ref:`unifies_names_from_inputs-doc` | ||||
|    ":meth:`Tensor.get_device`, :func:`torch.get_device`",None | ||||
|    :attr:`Tensor.grad`,None | ||||
|    ":meth:`Tensor.gt`, :func:`torch.gt`",:ref:`unifies_names_from_inputs-doc` | ||||
|    :meth:`Tensor.half`,:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.has_names`,See documentation | ||||
|    ":meth:`Tensor.index_fill`, :func:`torch.index_fill`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.index_fill_`,None | ||||
|    :meth:`Tensor.int`,:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.is_contiguous`,None | ||||
|    :attr:`Tensor.is_cuda`,None | ||||
|    ":meth:`Tensor.is_floating_point`, :func:`torch.is_floating_point`",None | ||||
|    :attr:`Tensor.is_leaf`,None | ||||
|    :meth:`Tensor.is_pinned`,None | ||||
|    :meth:`Tensor.is_shared`,None | ||||
|    ":meth:`Tensor.is_signed`, :func:`torch.is_signed`",None | ||||
|    :attr:`Tensor.is_sparse`,None | ||||
|    :func:`torch.is_tensor`,None | ||||
|    :meth:`Tensor.item`,None | ||||
|    ":meth:`Tensor.kthvalue`, :func:`torch.kthvalue`",:ref:`removes_dimensions-doc` | ||||
|    ":meth:`Tensor.le`, :func:`torch.le`",:ref:`unifies_names_from_inputs-doc` | ||||
|    ":meth:`Tensor.log`, :func:`torch.log`",:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.log10`, :func:`torch.log10`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.log10_`,None | ||||
|    ":meth:`Tensor.log1p`, :func:`torch.log1p`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.log1p_`,None | ||||
|    ":meth:`Tensor.log2`, :func:`torch.log2`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.log2_`,None | ||||
|    :meth:`Tensor.log_`,None | ||||
|    :meth:`Tensor.log_normal_`,None | ||||
|    ":meth:`Tensor.logical_not`, :func:`torch.logical_not`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.logical_not_`,None | ||||
|    ":meth:`Tensor.logsumexp`, :func:`torch.logsumexp`",:ref:`removes_dimensions-doc` | ||||
|    :meth:`Tensor.long`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.lt`, :func:`torch.lt`",:ref:`unifies_names_from_inputs-doc` | ||||
|    :func:`torch.manual_seed`,None | ||||
|    ":meth:`Tensor.masked_fill`, :func:`torch.masked_fill`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.masked_fill_`,None | ||||
|    ":meth:`Tensor.masked_select`, :func:`torch.masked_select`",Aligns mask up to input and then unifies_names_from_input_tensors | ||||
|    ":meth:`Tensor.matmul`, :func:`torch.matmul`",:ref:`contracts_away_dims-doc` | ||||
|    ":meth:`Tensor.mean`, :func:`torch.mean`",:ref:`removes_dimensions-doc` | ||||
|    ":meth:`Tensor.median`, :func:`torch.median`",:ref:`removes_dimensions-doc` | ||||
|    ":meth:`Tensor.mm`, :func:`torch.mm`",:ref:`contracts_away_dims-doc` | ||||
|    ":meth:`Tensor.mode`, :func:`torch.mode`",:ref:`removes_dimensions-doc` | ||||
|    ":meth:`Tensor.mul`, :func:`torch.mul`",:ref:`unifies_names_from_inputs-doc` | ||||
|    :meth:`Tensor.mul_`,:ref:`unifies_names_from_inputs-doc` | ||||
|    ":meth:`Tensor.mv`, :func:`torch.mv`",:ref:`contracts_away_dims-doc` | ||||
|    :attr:`Tensor.names`,See documentation | ||||
|    ":meth:`Tensor.narrow`, :func:`torch.narrow`",:ref:`keeps_input_names-doc` | ||||
|    :attr:`Tensor.ndim`,None | ||||
|    :meth:`Tensor.ndimension`,None | ||||
|    ":meth:`Tensor.ne`, :func:`torch.ne`",:ref:`unifies_names_from_inputs-doc` | ||||
|    ":meth:`Tensor.neg`, :func:`torch.neg`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.neg_`,None | ||||
|    :func:`torch.normal`,:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.normal_`,None | ||||
|    ":meth:`Tensor.numel`, :func:`torch.numel`",None | ||||
|    :func:`torch.ones`,:ref:`factory-doc` | ||||
|    ":meth:`Tensor.pow`, :func:`torch.pow`",:ref:`unifies_names_from_inputs-doc` | ||||
|    :meth:`Tensor.pow_`,None | ||||
|    ":meth:`Tensor.prod`, :func:`torch.prod`",:ref:`removes_dimensions-doc` | ||||
|    :func:`torch.rand`,:ref:`factory-doc` | ||||
|    :func:`torch.rand`,:ref:`factory-doc` | ||||
|    :func:`torch.randn`,:ref:`factory-doc` | ||||
|    :func:`torch.randn`,:ref:`factory-doc` | ||||
|    :meth:`Tensor.random_`,None | ||||
|    ":meth:`Tensor.reciprocal`, :func:`torch.reciprocal`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.reciprocal_`,None | ||||
|    :meth:`Tensor.refine_names`,See documentation | ||||
|    :meth:`Tensor.register_hook`,None | ||||
|    :meth:`Tensor.rename`,See documentation | ||||
|    :meth:`Tensor.rename_`,See documentation | ||||
|    :attr:`Tensor.requires_grad`,None | ||||
|    :meth:`Tensor.requires_grad_`,None | ||||
|    :meth:`Tensor.resize_`,Only allow resizes that do not change shape | ||||
|    :meth:`Tensor.resize_as_`,Only allow resizes that do not change shape | ||||
|    ":meth:`Tensor.round`, :func:`torch.round`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.round_`,None | ||||
|    ":meth:`Tensor.rsqrt`, :func:`torch.rsqrt`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.rsqrt_`,None | ||||
|    ":meth:`Tensor.select`, :func:`torch.select`",:ref:`removes_dimensions-doc` | ||||
|    :meth:`Tensor.short`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.sigmoid`, :func:`torch.sigmoid`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.sigmoid_`,None | ||||
|    ":meth:`Tensor.sign`, :func:`torch.sign`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.sign_`,None | ||||
|    ":meth:`Tensor.sin`, :func:`torch.sin`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.sin_`,None | ||||
|    ":meth:`Tensor.sinh`, :func:`torch.sinh`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.sinh_`,None | ||||
|    :meth:`Tensor.size`,None | ||||
|    ":meth:`Tensor.split`, :func:`torch.split`",:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.sqrt`, :func:`torch.sqrt`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.sqrt_`,None | ||||
|    ":meth:`Tensor.squeeze`, :func:`torch.squeeze`",:ref:`removes_dimensions-doc` | ||||
|    ":meth:`Tensor.std`, :func:`torch.std`",:ref:`removes_dimensions-doc` | ||||
|    :func:`torch.std_mean`,:ref:`removes_dimensions-doc` | ||||
|    :meth:`Tensor.stride`,None | ||||
|    ":meth:`Tensor.sub`, :func:`torch.sub`",:ref:`unifies_names_from_inputs-doc` | ||||
|    :meth:`Tensor.sub_`,:ref:`unifies_names_from_inputs-doc` | ||||
|    ":meth:`Tensor.sum`, :func:`torch.sum`",:ref:`removes_dimensions-doc` | ||||
|    ":meth:`Tensor.tan`, :func:`torch.tan`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.tan_`,None | ||||
|    ":meth:`Tensor.tanh`, :func:`torch.tanh`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.tanh_`,None | ||||
|    :func:`torch.tensor`,:ref:`factory-doc` | ||||
|    :meth:`Tensor.to`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.topk`, :func:`torch.topk`",:ref:`removes_dimensions-doc` | ||||
|    ":meth:`Tensor.transpose`, :func:`torch.transpose`",:ref:`permutes_dimensions-doc` | ||||
|    ":meth:`Tensor.trunc`, :func:`torch.trunc`",:ref:`keeps_input_names-doc` | ||||
|    :meth:`Tensor.trunc_`,None | ||||
|    :meth:`Tensor.type`,None | ||||
|    :meth:`Tensor.type_as`,:ref:`keeps_input_names-doc` | ||||
|    ":meth:`Tensor.unbind`, :func:`torch.unbind`",:ref:`removes_dimensions-doc` | ||||
|    :meth:`Tensor.unflatten`,See documentation | ||||
|    :meth:`Tensor.uniform_`,None | ||||
|    ":meth:`Tensor.var`, :func:`torch.var`",:ref:`removes_dimensions-doc` | ||||
|    :func:`torch.var_mean`,:ref:`removes_dimensions-doc` | ||||
|    :meth:`Tensor.zero_`,None | ||||
|    :func:`torch.zeros`,:ref:`factory-doc` | ||||
|  | ||||
|  | ||||
| .. _keeps_input_names-doc: | ||||
|  | ||||
| Keeps input names | ||||
| ^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| All pointwise unary functions follow this rule as well as some other unary functions. | ||||
|  | ||||
| - Check names: None | ||||
| - Propagate names: input tensor's names are propagated to the output. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> x = torch.randn(3, 3, names=('N', 'C')) | ||||
|     >>> x.abs().names | ||||
|     ('N', 'C') | ||||
|  | ||||
| .. _removes_dimensions-doc: | ||||
|  | ||||
| Removes dimensions | ||||
| ^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| All reduction ops like :meth:`~Tensor.sum` remove dimensions by reducing | ||||
| over the desired dimensions. Other operations like :meth:`~Tensor.select` and | ||||
| :meth:`~Tensor.squeeze` remove dimensions. | ||||
|  | ||||
| Wherever one can pass an integer dimension index to an operator, one can also pass | ||||
| a dimension name. Functions that take lists of dimension indices can also take in a | ||||
| list of dimension names. | ||||
|  | ||||
| - Check names: If :attr:`dim` or :attr:`dims` is passed in as a list of names, | ||||
|   check that those names exist in :attr:`self`. | ||||
| - Propagate names: If the dimensions of the input tensor specified by :attr:`dim` | ||||
|   or :attr:`dims` are not present in the output tensor, then the corresponding names | ||||
|   of those dimensions do not appear in ``output.names``. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W')) | ||||
|     >>> x.squeeze('N').names | ||||
|     ('C', 'H', 'W') | ||||
|  | ||||
|     >>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W')) | ||||
|     >>> x.sum(['N', 'C']).names | ||||
|     ('H', 'W') | ||||
|  | ||||
|     # Reduction ops with keepdim=True don't actually remove dimensions. | ||||
|     >>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W')) | ||||
|     >>> x.sum(['N', 'C'], keepdim=True).names | ||||
|     ('N', 'C', 'H', 'W') | ||||
|  | ||||
|  | ||||
| .. _unifies_names_from_inputs-doc: | ||||
|  | ||||
| Unifies names from inputs | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| All binary arithmetic ops follow this rule. Operations that broadcast still | ||||
| broadcast positionally from the right to preserve compatibility with unnamed | ||||
| tensors. To perform explicit broadcasting by names, use :meth:`Tensor.align_as`. | ||||
|  | ||||
| - Check names: All names must match positionally from the right. i.e., in | ||||
|   ``tensor + other``, ``match(tensor.names[i], other.names[i])`` must be true for all | ||||
|   ``i`` in ``(-min(tensor.dim(), other.dim()) + 1, -1]``. | ||||
| - Check names: Furthermore, all named dimensions must be aligned from the right. | ||||
|   During matching, if we match a named dimension ``A`` with an unnamed dimension | ||||
|   ``None``, then ``A`` must not appear in the tensor with the unnamed dimension. | ||||
| - Propagate names: unify pairs of names from the right from both tensors to | ||||
|   produce output names. | ||||
|  | ||||
| For example, | ||||
|  | ||||
| :: | ||||
|  | ||||
|     # tensor: Tensor[   N, None] | ||||
|     # other:  Tensor[None,    C] | ||||
|     >>> tensor = torch.randn(3, 3, names=('N', None)) | ||||
|     >>> other = torch.randn(3, 3, names=(None, 'C')) | ||||
|     >>> (tensor + other).names | ||||
|     ('N', 'C') | ||||
|  | ||||
| Check names: | ||||
|  | ||||
| - ``match(tensor.names[-1], other.names[-1])`` is ``True`` | ||||
| - ``match(tensor.names[-2], other.names[-2])`` is ``True`` | ||||
| - Because we matched ``None`` in :attr:`tensor` with ``'C'``, | ||||
|   check to make sure ``'C'`` doesn't exist in :attr:`tensor` (it does not). | ||||
| - Check to make sure ``'N'`` doesn't exist in :attr:`other` (it does not). | ||||
|  | ||||
| Finally, the output names are computed with | ||||
| ``[unify('N', None), unify(None, 'C')] = ['N', 'C']`` | ||||
|  | ||||
| More examples:: | ||||
|  | ||||
|     # Dimensions don't match from the right: | ||||
|     # tensor: Tensor[N, C] | ||||
|     # other:  Tensor[   N] | ||||
|     >>> tensor = torch.randn(3, 3, names=('N', 'C')) | ||||
|     >>> other = torch.randn(3, names=('N',)) | ||||
|     >>> (tensor + other).names | ||||
|     RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims | ||||
|     ['N']: dim 'C' and dim 'N' are at the same position from the right but do | ||||
|     not match. | ||||
|  | ||||
|     # Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]: | ||||
|     # tensor: Tensor[N, None] | ||||
|     # other:  Tensor[      N] | ||||
|     >>> tensor = torch.randn(3, 3, names=('N', None)) | ||||
|     >>> other = torch.randn(3, names=('N',)) | ||||
|     >>> (tensor + other).names | ||||
|     RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and | ||||
|     dims ['N', None]: dim 'N' appears in a different position from the right | ||||
|     across both lists. | ||||
|  | ||||
| .. note:: | ||||
|  | ||||
|     In both of the last examples, it is possible to align the tensors by names | ||||
|     and then perform the addition. Use :meth:`Tensor.align_as` to align | ||||
|     tensors by name or :meth:`Tensor.align_to` to align tensors to a custom | ||||
|     dimension ordering. | ||||
|  | ||||
| .. _permutes_dimensions-doc: | ||||
|  | ||||
| Permutes dimensions | ||||
| ^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| Some operations, like :meth:`Tensor.t()`, permute the order of dimensions. Dimension names | ||||
| are attached to individual dimensions so they get permuted as well. | ||||
|  | ||||
| If the operator takes in positional index :attr:`dim`, it is also able to take a dimension | ||||
| name as :attr:`dim`. | ||||
|  | ||||
| - Check names: If :attr:`dim` is passed as a name, check that it exists in the tensor. | ||||
| - Propagate names: Permute dimension names in the same way as the dimensions that are | ||||
|   being permuted. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> x = torch.randn(3, 3, names=('N', 'C')) | ||||
|     >>> x.transpose('N', 'C').names | ||||
|     ('C', 'N') | ||||
|  | ||||
| .. _contracts_away_dims-doc: | ||||
|  | ||||
| Contracts away dims | ||||
| ^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| Matrix multiply functions follow some variant of this. Let's go through | ||||
| :func:`torch.mm` first and then generalize the rule for batch matrix multiplication. | ||||
|  | ||||
| For ``torch.mm(tensor, other)``: | ||||
|  | ||||
| - Check names: None | ||||
| - Propagate names: result names are ``(tensor.names[-2], other.names[-1])``. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> x = torch.randn(3, 3, names=('N', 'D')) | ||||
|     >>> y = torch.randn(3, 3, names=('in', 'out')) | ||||
|     >>> x.mm(y).names | ||||
|     ('N', 'out') | ||||
|  | ||||
| Inherently, a matrix multiplication performs a dot product over two dimensions, | ||||
| collapsing them. When two tensors are matrix-multiplied, the contracted dimensions | ||||
| disappear and do not show up in the output tensor. | ||||
|  | ||||
| :func:`torch.mv`, :func:`torch.dot` work in a similar way: name inference does not | ||||
| check input names and removes the dimensions that are involved in the dot product: | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> x = torch.randn(3, 3, names=('N', 'D')) | ||||
|     >>> y = torch.randn(3, names=('something',)) | ||||
|     >>> x.mv(y).names | ||||
|     ('N',) | ||||
|  | ||||
| Now, let's take a look at ``torch.matmul(tensor, other)``. Assume that ``tensor.dim() >= 2`` | ||||
| and ``other.dim() >= 2``. | ||||
|  | ||||
| - Check names: Check that the batch dimensions of the inputs are aligned and broadcastable. | ||||
|   See :ref:`unifies_names_from_inputs-doc` for what it means for the inputs to be aligned. | ||||
| - Propagate names: result names are obtained by unifying the batch dimensions and removing | ||||
|   the contracted dimensions: | ||||
|   ``unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])``. | ||||
|  | ||||
| Examples:: | ||||
|  | ||||
|     # Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F']. | ||||
|     # 'A', 'B' are batch dimensions. | ||||
|     >>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D')) | ||||
|     >>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F')) | ||||
|     >>> torch.matmul(x, y).names | ||||
|     ('A', 'B', 'C', 'F') | ||||
|  | ||||
|  | ||||
| Finally, there are fused ``add`` versions of many matmul functions, e.g., :func:`addmm` | ||||
| and :func:`addmv`. These are treated as composing name inference for, e.g., :func:`mm` and | ||||
| name inference for :func:`add`. | ||||
|  | ||||
| .. _factory-doc: | ||||
|  | ||||
| Factory functions | ||||
| ^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
|  | ||||
| Factory functions now take a new :attr:`names` argument that associates a name | ||||
| with each dimension. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> torch.zeros(2, 3, names=('N', 'C')) | ||||
|     tensor([[0., 0., 0.], | ||||
|             [0., 0., 0.]], names=('N', 'C')) | ||||
|  | ||||
| .. _out_function_semantics-doc: | ||||
|  | ||||
| out function and in-place variants | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| A tensor specified as an ``out=`` tensor has the following behavior: | ||||
|  | ||||
| - If it has no named dimensions, then the names computed from the operation | ||||
|   get propagated to it. | ||||
| - If it has any named dimensions, then the names computed from the operation | ||||
|   must be exactly equal to the existing names. Otherwise, the operation errors. | ||||
|  | ||||
| All in-place methods modify inputs to have names equal to the computed names | ||||
| from name inference. For example, | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> x = torch.randn(3, 3) | ||||
|     >>> y = torch.randn(3, 3, names=('N', 'C')) | ||||
|     >>> x.names | ||||
|     (None, None) | ||||
|  | ||||
|     >>> x += y | ||||
|     >>> x.names | ||||
|     ('N', 'C') | ||||
|  | ||||
							
								
								
									
										319
									
								
								docs/source/named_tensor.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										319
									
								
								docs/source/named_tensor.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,319 @@ | ||||
| .. currentmodule:: torch | ||||
|  | ||||
| .. _named_tensors-doc: | ||||
|  | ||||
| Named Tensors | ||||
| ============= | ||||
|  | ||||
| Named Tensors aim to make tensors easier to use by allowing users to associate | ||||
| explicit names with tensor dimensions. In most cases, operations that take | ||||
| dimension parameters will accept dimension names, avoiding the need to track | ||||
| dimensions by position. In addition, named tensors use names to automatically | ||||
| check that APIs are being used correctly at runtime, providing extra safety. | ||||
| Names can also be used to rearrange dimensions, for example, to support | ||||
| "broadcasting by name" rather than "broadcasting by position". | ||||
|  | ||||
| .. warning:: | ||||
|     The named tensor API is experimental and subject to change. | ||||
|  | ||||
| Creating named tensors | ||||
| ---------------------- | ||||
|  | ||||
| Factory functions now take a new :attr:`names` argument that associates a name | ||||
| with each dimension. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> torch.zeros(2, 3, names=('N', 'C')) | ||||
|     tensor([[0., 0., 0.], | ||||
|             [0., 0., 0.]], names=('N', 'C')) | ||||
|  | ||||
| Named dimensions, like regular Tensor dimensions, are ordered. | ||||
| ``tensor.names[i]`` is the name of dimension ``i`` of ``tensor``. | ||||
|  | ||||
| The following factory functions support named tensors: | ||||
|  | ||||
| - :func:`torch.empty` | ||||
| - :func:`torch.rand` | ||||
| - :func:`torch.randn` | ||||
| - :func:`torch.ones` | ||||
| - :func:`torch.tensor` | ||||
| - :func:`torch.zeros` | ||||
|  | ||||
| Named dimensions | ||||
| ---------------- | ||||
|  | ||||
| See :attr:`~Tensor.names` for restrictions on tensor names. | ||||
|  | ||||
| Use :attr:`~Tensor.names` to access the dimension names of a tensor and | ||||
| :meth:`~Tensor.rename` to rename named dimensions. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W')) | ||||
|     >>> imgs.names | ||||
|     ('N', 'C', 'H', 'W') | ||||
|  | ||||
|     >>> renamed_imgs = imgs.rename(H='height', W='width') | ||||
|     >>> renamed_imgs.names | ||||
|     ('N', 'C', 'height', 'width') | ||||
|  | ||||
|  | ||||
| Named tensors can coexist with unnamed tensors; named tensors are instances of | ||||
| :class:`torch.Tensor`. Unnamed tensors have ``None``-named dimensions. Named | ||||
| tensors do not require all dimensions to be named. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> imgs = torch.randn(1, 2, 2, 3, names=(None, 'C', 'H', 'W')) | ||||
|     >>> imgs.names | ||||
|     (None, 'C', 'H', 'W') | ||||
|  | ||||
| Name propagation semantics | ||||
| -------------------------- | ||||
|  | ||||
| Named tensors use names to automatically check that APIs are being called | ||||
| correctly at runtime. This occurs in a process called *name inference*. | ||||
| More formally, name inference consists of the following two steps: | ||||
|  | ||||
| - **Check names**: an operator may perform automatic checks at runtime that | ||||
|   check that certain dimension names must match. | ||||
| - **Propagate names**: name inference propagates names to output tensors. | ||||
|  | ||||
| All operations that support named tensors propagate names. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> x = torch.randn(3, 3, names=('N', 'C')) | ||||
|     >>> x.abs().names | ||||
|     ('N', 'C') | ||||
|  | ||||
|  | ||||
| .. _match_semantics-doc: | ||||
|  | ||||
| match semantics | ||||
| ^^^^^^^^^^^^^^^ | ||||
|  | ||||
| Two names *match* if they are equal (string equality) or if at least one is ``None``. | ||||
| Nones are essentially a special "wildcard" name. | ||||
|  | ||||
| ``unify(A, B)`` determines which of the names ``A`` and ``B`` to propagate to the outputs. | ||||
| It returns the more *specific* of the two names, if they match. If the names do not match, | ||||
| then it errors. | ||||
|  | ||||
| .. note:: | ||||
|     In practice, when working with named tensors, one should avoid having unnamed | ||||
|     dimensions because their handling can be complicated. It is recommended to lift | ||||
|     all unnamed dimensions to be named dimensions by using :meth:`~Tensor.refine_names`. | ||||
|  | ||||
|  | ||||
| Basic name inference rules | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| Let's see how ``match`` and ``unify`` are used in name inference in the case of | ||||
| adding two one-dim tensors with no broadcasting. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     x = torch.randn(3, names=('X',)) | ||||
|     y = torch.randn(3) | ||||
|     z = torch.randn(3, names=('Z',)) | ||||
|  | ||||
| **Check names**: check that the names of the two tensors *match*. | ||||
|  | ||||
| For the following examples: | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> # x + y  # match('X', None) is True | ||||
|     >>> # x + z  # match('X', 'Z') is False | ||||
|     >>> # x + x  # match('X', 'X') is True | ||||
|  | ||||
|     >>> x + z | ||||
|     Error when attempting to broadcast dims ['X'] and dims ['Z']: dim 'X' and dim 'Z' are at the same position from the right but do not match. | ||||
|  | ||||
| **Propagate names**: *unify* the names to select which one to propagate. | ||||
| In the case of ``x + y``, ``unify('X', None) = 'X'`` because ``'X'`` is more | ||||
| specific than ``None``. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> (x + y).names | ||||
|     ('X',) | ||||
|     >>> (x + x).names | ||||
|     ('X',) | ||||
|  | ||||
| For a comprehensive list of name inference rules, see :ref:`name_inference_reference-doc`. | ||||
| Here are two common operations that may be useful to go over: | ||||
|  | ||||
| - Binary arithmetic ops: :ref:`unifies_names_from_inputs-doc` | ||||
| - Matrix multiplication ops: :ref:`contracts_away_dims-doc` | ||||
|  | ||||
| Explicit alignment by names | ||||
| --------------------------- | ||||
|  | ||||
| Use :meth:`~Tensor.align_as` or :meth:`~Tensor.align_to` to align tensor dimensions | ||||
| by name to a specified ordering. This is useful for performing "broadcasting by names". | ||||
|  | ||||
| :: | ||||
|  | ||||
|     # This function is agnostic to the dimension ordering of `input`, | ||||
|     # as long as it has a `C` dimension somewhere. | ||||
|     def scale_channels(input, scale): | ||||
|         scale = scale.refine_names('C') | ||||
|         return input * scale.align_as(input) | ||||
|  | ||||
|     >>> num_channels = 3 | ||||
|     >>> scale = torch.randn(num_channels, names=('C',)) | ||||
|     >>> imgs = torch.rand(3, 3, 3, num_channels, names=('N', 'H', 'W', 'C')) | ||||
|     >>> more_imgs = torch.rand(3, num_channels, 3, 3, names=('N', 'C', 'H', 'W')) | ||||
|     >>> videos = torch.randn(3, num_channels, 3, 3, 3, names=('N', 'C', 'H', 'W', 'D')) | ||||
|  | ||||
|     >>> scale_channels(imgs, scale) | ||||
|     >>> scale_channels(more_imgs, scale) | ||||
|     >>> scale_channels(videos, scale) | ||||
|  | ||||
| Manipulating dimensions | ||||
| ----------------------- | ||||
|  | ||||
| Use :meth:`~Tensor.align_to` to permute large amounts of dimensions without | ||||
| mentioning all of them as is required by :meth:`~Tensor.permute`. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> tensor = torch.randn(2, 2, 2, 2, 2, 2) | ||||
|     >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F') | ||||
|  | ||||
|     # Move the F (dim 5) and E dimension (dim 4) to the front while keeping | ||||
|     # the rest in the same order | ||||
|     >>> tensor.permute(5, 4, 0, 1, 2, 3) | ||||
|     >>> named_tensor.align_to('F', 'E', ...)  # Use '...' instead in Python 2 | ||||
|  | ||||
| Use :meth:`~Tensor.flatten` and :meth:`~Tensor.unflatten` to flatten and unflatten | ||||
| dimensions, respectively. These methods are more verbose than :meth:`~Tensor.view` | ||||
| and :meth:`~Tensor.reshape`, but have more semantic meaning to someone reading the code. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> imgs = torch.randn(32, 3, 128, 128) | ||||
|     >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W') | ||||
|  | ||||
|     >>> flat_imgs = imgs.view(32, -1) | ||||
|     >>> named_flat_imgs = named_imgs.flatten(['C', 'H', 'W'], 'features') | ||||
|     >>> named_flat_imgs.names | ||||
|     ('N', 'features') | ||||
|  | ||||
|     >>> unflattened_imgs = imgs.view(32, 3, 128, 128) | ||||
|     >>> unflattened_named_imgs = named_flat_imgs.unflatten( | ||||
|             'features', [('C', 3), ('H', 128), ('W', 128)]) | ||||
|  | ||||
| .. _named_tensors_autograd-doc: | ||||
|  | ||||
| Autograd support | ||||
| ---------------- | ||||
|  | ||||
| Autograd currently supports named tensors in a limited manner: autograd ignores | ||||
| names on all tensors. Gradient computation is still correct but we lose the | ||||
| safety that names give us. | ||||
|  | ||||
| :: | ||||
|  | ||||
|     >>> x = torch.randn(3, names=('D',)) | ||||
|     >>> weight = torch.randn(3, names=('D',), requires_grad=True) | ||||
|     >>> loss = (x - weight).abs() | ||||
|     >>> grad_loss = torch.randn(3) | ||||
|     >>> loss.backward(grad_loss) | ||||
|     >>> weight.grad  # Unnamed for now. Will be named in the future | ||||
|     tensor([-1.8107, -0.6357,  0.0783]) | ||||
|  | ||||
|     >>> weight.grad.zero_() | ||||
|     >>> grad_loss = grad_loss.refine_names('C') | ||||
|     >>> loss = (x - weight).abs() | ||||
|     # Ideally we'd check that the names of loss and grad_loss match but we don't yet. | ||||
|     >>> loss.backward(grad_loss) | ||||
|     >>> weight.grad | ||||
|     tensor([-1.8107, -0.6357,  0.0783]) | ||||
|  | ||||
| Currently supported operations and subsystems | ||||
| --------------------------------------------- | ||||
|  | ||||
| Operators | ||||
| ^^^^^^^^^ | ||||
|  | ||||
| See :ref:`name_inference_reference-doc` for a full list of the supported torch and | ||||
| tensor operations. We do not yet support the following, which is not covered by the link: | ||||
|  | ||||
| - indexing, advanced indexing. | ||||
|  | ||||
| For ``torch.nn.functional`` operators, we support the following: | ||||
|  | ||||
| - :func:`torch.nn.functional.relu` | ||||
| - :func:`torch.nn.functional.softmax` | ||||
| - :func:`torch.nn.functional.log_softmax` | ||||
| - :func:`torch.nn.functional.tanh` | ||||
| - :func:`torch.nn.functional.sigmoid` | ||||
| - :func:`torch.nn.functional.dropout` | ||||
|  | ||||
| Subsystems | ||||
| ^^^^^^^^^^ | ||||
|  | ||||
| Autograd is supported, see :ref:`named_tensors_autograd-doc`. | ||||
| Because gradients are currently unnamed, optimizers may work but are untested. | ||||
|  | ||||
| NN modules are currently unsupported. This can lead to the following when calling | ||||
| modules with named tensor inputs: | ||||
|  | ||||
| - NN module parameters are unnamed, so outputs may be partially named. | ||||
| - NN module forward passes have code that doesn't support named tensors and will | ||||
|   error out appropriately. | ||||
|  | ||||
| We also do not support the following subsystems, though some may work out | ||||
| of the box: | ||||
|  | ||||
| - distributions | ||||
| - serialization (:func:`torch.load`, :func:`torch.save`) | ||||
| - multiprocessing | ||||
| - JIT | ||||
| - distributed | ||||
| - ONNX | ||||
|  | ||||
| If any of these would help your use case, please | ||||
| `search if an issue has already been filed <https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+named+tensor%22>`_ | ||||
| and if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_. | ||||
|  | ||||
| Named tensor API reference | ||||
| -------------------------- | ||||
|  | ||||
| In this section please find the documentation for named tensor specific APIs. | ||||
| For a comprehensive reference for how names are propagated through other PyTorch | ||||
| operators, see :ref:`name_inference_reference-doc`. | ||||
|  | ||||
| .. class:: Tensor() | ||||
|    :noindex: | ||||
|  | ||||
|    .. autoattribute:: names | ||||
|    .. automethod:: rename | ||||
|    .. automethod:: rename_ | ||||
|    .. automethod:: refine_names | ||||
|  | ||||
|    .. automethod:: align_as | ||||
|    .. automethod:: align_to | ||||
|  | ||||
|    .. automethod:: unflatten | ||||
|    .. py:method:: flatten(dims, out_dim) -> Tensor | ||||
|  | ||||
|       Flattens :attr:`dims` into a single dimension with name :attr:`out_dim`. | ||||
|  | ||||
|       All of `dims` must be consecutive in order in the :attr:`self` tensor, | ||||
|       but not necessarily contiguous in memory. | ||||
|  | ||||
|       Examples:: | ||||
|  | ||||
|           >>> imgs = torch.randn(32, 3, 128, 128, names=('N', 'C', 'H', 'W')) | ||||
|           >>> flat_imgs = imgs.flatten(['C', 'H', 'W'], 'features') | ||||
|           >>> flat_imgs.names, flat_imgs.shape | ||||
|           (('N', 'features'), torch.Size([32, 49152])) | ||||
|  | ||||
|       .. warning:: | ||||
|           The named tensor API is experimental and subject to change. | ||||
|  | ||||
| @ -877,3 +877,10 @@ Utilities | ||||
|  | ||||
| .. autoclass:: Flatten | ||||
|     :members: | ||||
|  | ||||
|  | ||||
| Quantized Functions | ||||
| -------------------- | ||||
|  | ||||
| Quantization refers to techniques for performing computations and storing tensors at lower bitwidths than | ||||
| floating point precision. PyTorch supports both per tensor and per channel asymmetric linear quantization. To learn more how to use quantized functions in PyTorch, please refer to the :ref:`Quantization` documentation. | ||||
|  | ||||
							
								
								
									
										67
									
								
								docs/source/org/pytorch/DType.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								docs/source/org/pytorch/DType.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,67 @@ | ||||
| DType | ||||
| ===== | ||||
|  | ||||
| .. java:package:: org.pytorch | ||||
|    :noindex: | ||||
|  | ||||
| .. java:type:: public enum DType | ||||
|  | ||||
|    Codes representing tensor data types. | ||||
|  | ||||
| Enum Constants | ||||
| -------------- | ||||
| FLOAT32 | ||||
| ^^^^^^^ | ||||
|  | ||||
| .. java:field:: public static final DType FLOAT32 | ||||
|    :outertype: DType | ||||
|  | ||||
|    Code for dtype torch.float32. \ :java:ref:`Tensor.dtype()`\ | ||||
|  | ||||
| FLOAT64 | ||||
| ^^^^^^^ | ||||
|  | ||||
| .. java:field:: public static final DType FLOAT64 | ||||
|    :outertype: DType | ||||
|  | ||||
|    Code for dtype torch.float64. \ :java:ref:`Tensor.dtype()`\ | ||||
|  | ||||
| INT32 | ||||
| ^^^^^ | ||||
|  | ||||
| .. java:field:: public static final DType INT32 | ||||
|    :outertype: DType | ||||
|  | ||||
|    Code for dtype torch.int32. \ :java:ref:`Tensor.dtype()`\ | ||||
|  | ||||
| INT64 | ||||
| ^^^^^ | ||||
|  | ||||
| .. java:field:: public static final DType INT64 | ||||
|    :outertype: DType | ||||
|  | ||||
|    Code for dtype torch.int64. \ :java:ref:`Tensor.dtype()`\ | ||||
|  | ||||
| INT8 | ||||
| ^^^^ | ||||
|  | ||||
| .. java:field:: public static final DType INT8 | ||||
|    :outertype: DType | ||||
|  | ||||
|    Code for dtype torch.int8. \ :java:ref:`Tensor.dtype()`\ | ||||
|  | ||||
| UINT8 | ||||
| ^^^^^ | ||||
|  | ||||
| .. java:field:: public static final DType UINT8 | ||||
|    :outertype: DType | ||||
|  | ||||
|    Code for dtype torch.uint8. \ :java:ref:`Tensor.dtype()`\ | ||||
|  | ||||
| Fields | ||||
| ------ | ||||
| jniCode | ||||
| ^^^^^^^ | ||||
|  | ||||
| .. java:field:: final int jniCode | ||||
|    :outertype: DType | ||||
							
								
								
									
										297
									
								
								docs/source/org/pytorch/IValue.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										297
									
								
								docs/source/org/pytorch/IValue.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,297 @@ | ||||
| .. java:import:: java.util Locale | ||||
|  | ||||
| .. java:import:: java.util Map | ||||
|  | ||||
| IValue | ||||
| ====== | ||||
|  | ||||
| .. java:package:: org.pytorch | ||||
|    :noindex: | ||||
|  | ||||
| .. java:type:: public class IValue | ||||
|  | ||||
|    Java representation of a TorchScript value, which is implemented as tagged union that can be one of the supported types: https://pytorch.org/docs/stable/jit.html#types . | ||||
|  | ||||
|    Calling \ ``toX``\  methods for inappropriate types will throw \ :java:ref:`IllegalStateException`\ . | ||||
|  | ||||
|    \ ``IValue``\  objects are constructed with \ ``IValue.from(value)``\ , \ ``IValue.tupleFrom(value1, value2, ...)``\ , \ ``IValue.listFrom(value1, value2, ...)``\ , or one of the \ ``dict``\  methods, depending on the key type. | ||||
|  | ||||
|    Data is retrieved from \ ``IValue``\  objects with the \ ``toX()``\  methods. Note that \ ``str``\ -type IValues must be extracted with \ :java:ref:`toStr()`\ , rather than \ :java:ref:`toString()`\ . | ||||
|  | ||||
|    \ ``IValue``\  objects may retain references to objects passed into their constructors, and may return references to their internal state from \ ``toX()``\ . | ||||
|  | ||||
| Methods | ||||
| ------- | ||||
| dictLongKeyFrom | ||||
| ^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue dictLongKeyFrom(Map<Long, IValue> map) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``Dict[int, V]``\ . | ||||
|  | ||||
| dictStringKeyFrom | ||||
| ^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue dictStringKeyFrom(Map<String, IValue> map) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``Dict[str, V]``\ . | ||||
|  | ||||
| from | ||||
| ^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue from(Tensor tensor) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``Tensor``\ . | ||||
|  | ||||
| from | ||||
| ^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue from(boolean value) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``bool``\ . | ||||
|  | ||||
| from | ||||
| ^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue from(long value) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``int``\ . | ||||
|  | ||||
| from | ||||
| ^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue from(double value) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``float``\ . | ||||
|  | ||||
| from | ||||
| ^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue from(String value) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``str``\ . | ||||
|  | ||||
| isBool | ||||
| ^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isBool() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isBoolList | ||||
| ^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isBoolList() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isDictLongKey | ||||
| ^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isDictLongKey() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isDictStringKey | ||||
| ^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isDictStringKey() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isDouble | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isDouble() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isDoubleList | ||||
| ^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isDoubleList() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isList | ||||
| ^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isList() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isLong | ||||
| ^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isLong() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isLongList | ||||
| ^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isLongList() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isNull | ||||
| ^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isNull() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isString | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isString() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isTensor | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isTensor() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isTensorList | ||||
| ^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isTensorList() | ||||
|    :outertype: IValue | ||||
|  | ||||
| isTuple | ||||
| ^^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean isTuple() | ||||
|    :outertype: IValue | ||||
|  | ||||
| listFrom | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue listFrom(boolean... list) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``List[bool]``\ . | ||||
|  | ||||
| listFrom | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue listFrom(long... list) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``List[int]``\ . | ||||
|  | ||||
| listFrom | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue listFrom(double... list) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``List[float]``\ . | ||||
|  | ||||
| listFrom | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue listFrom(Tensor... list) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``List[Tensor]``\ . | ||||
|  | ||||
| listFrom | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue listFrom(IValue... array) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``List[T]``\ . All elements must have the same type. | ||||
|  | ||||
| optionalNull | ||||
| ^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue optionalNull() | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``Optional``\  that contains no value. | ||||
|  | ||||
| toBool | ||||
| ^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean toBool() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toBoolList | ||||
| ^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public boolean[] toBoolList() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toDictLongKey | ||||
| ^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public Map<Long, IValue> toDictLongKey() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toDictStringKey | ||||
| ^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public Map<String, IValue> toDictStringKey() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toDouble | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public double toDouble() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toDoubleList | ||||
| ^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public double[] toDoubleList() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toList | ||||
| ^^^^^^ | ||||
|  | ||||
| .. java:method:: public IValue[] toList() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toLong | ||||
| ^^^^^^ | ||||
|  | ||||
| .. java:method:: public long toLong() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toLongList | ||||
| ^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public long[] toLongList() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toStr | ||||
| ^^^^^ | ||||
|  | ||||
| .. java:method:: public String toStr() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toTensor | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public Tensor toTensor() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toTensorList | ||||
| ^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public Tensor[] toTensorList() | ||||
|    :outertype: IValue | ||||
|  | ||||
| toTuple | ||||
| ^^^^^^^ | ||||
|  | ||||
| .. java:method:: public IValue[] toTuple() | ||||
|    :outertype: IValue | ||||
|  | ||||
| tupleFrom | ||||
| ^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static IValue tupleFrom(IValue... array) | ||||
|    :outertype: IValue | ||||
|  | ||||
|    Creates a new \ ``IValue``\  of type \ ``Tuple[T0, T1, ...]``\ . | ||||
							
								
								
									
										55
									
								
								docs/source/org/pytorch/Module.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								docs/source/org/pytorch/Module.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,55 @@ | ||||
| .. java:import:: com.facebook.jni HybridData | ||||
|  | ||||
| Module | ||||
| ====== | ||||
|  | ||||
| .. java:package:: org.pytorch | ||||
|    :noindex: | ||||
|  | ||||
| .. java:type:: public class Module | ||||
|  | ||||
|    Java wrapper for torch::jit::script::Module. | ||||
|  | ||||
| Methods | ||||
| ------- | ||||
| destroy | ||||
| ^^^^^^^ | ||||
|  | ||||
| .. java:method:: public void destroy() | ||||
|    :outertype: Module | ||||
|  | ||||
|    Explicitly destroys the native torch::jit::script::Module. Calling this method is not required, as the native object will be destroyed when this object is garbage-collected. However, the timing of garbage collection is not guaranteed, so proactively calling \ ``destroy``\  can free memory more quickly. See \ :java:ref:`com.facebook.jni.HybridData.resetNative`\ . | ||||
|  | ||||
| forward | ||||
| ^^^^^^^ | ||||
|  | ||||
| .. java:method:: public IValue forward(IValue... inputs) | ||||
|    :outertype: Module | ||||
|  | ||||
|    Runs the 'forward' method of this module with the specified arguments. | ||||
|  | ||||
|    :param inputs: arguments for the TorchScript module's 'forward' method. | ||||
|    :return: return value from the 'forward' method. | ||||
|  | ||||
| load | ||||
| ^^^^ | ||||
|  | ||||
| .. java:method:: public static Module load(String modelPath) | ||||
|    :outertype: Module | ||||
|  | ||||
|    Loads a serialized TorchScript module from the specified path on the disk. | ||||
|  | ||||
|    :param modelPath: path to file that contains the serialized TorchScript module. | ||||
|    :return: new \ :java:ref:`org.pytorch.Module`\  object which owns torch::jit::script::Module. | ||||
|  | ||||
| runMethod | ||||
| ^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public IValue runMethod(String methodName, IValue... inputs) | ||||
|    :outertype: Module | ||||
|  | ||||
|    Runs the specified method of this module with the specified arguments. | ||||
|  | ||||
|    :param methodName: name of the TorchScript method to run. | ||||
|    :param inputs: arguments that will be passed to TorchScript method. | ||||
|    :return: return value from the method. | ||||
							
								
								
									
										315
									
								
								docs/source/org/pytorch/Tensor.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										315
									
								
								docs/source/org/pytorch/Tensor.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,315 @@ | ||||
| .. java:import:: java.nio Buffer | ||||
|  | ||||
| .. java:import:: java.nio ByteBuffer | ||||
|  | ||||
| .. java:import:: java.nio ByteOrder | ||||
|  | ||||
| .. java:import:: java.nio DoubleBuffer | ||||
|  | ||||
| .. java:import:: java.nio FloatBuffer | ||||
|  | ||||
| .. java:import:: java.nio IntBuffer | ||||
|  | ||||
| .. java:import:: java.nio LongBuffer | ||||
|  | ||||
| .. java:import:: java.util Arrays | ||||
|  | ||||
| .. java:import:: java.util Locale | ||||
|  | ||||
| Tensor | ||||
| ====== | ||||
|  | ||||
| .. java:package:: org.pytorch | ||||
|    :noindex: | ||||
|  | ||||
| .. java:type:: public abstract class Tensor | ||||
|  | ||||
|    Representation of a Tensor. Behavior is similar to PyTorch's tensor objects. | ||||
|  | ||||
|    Most tensors will be constructed as \ ``Tensor.fromBlob(data, shape)``\ , where \ ``data``\  can be an array or a direct \ :java:ref:`Buffer`\  (of the proper subclass). Helper methods are provided to allocate buffers properly. | ||||
|  | ||||
|    To access Tensor data, see \ :java:ref:`dtype()`\ , \ :java:ref:`shape()`\ , and various \ ``getDataAs*``\  methods. | ||||
|  | ||||
|    When constructing \ ``Tensor``\  objects with \ ``data``\  as an array, it is not specified whether this data is copied or retained as a reference, so it is recommended not to modify it after constructing. \ ``data``\  passed as a \ :java:ref:`Buffer`\  is not copied, so it can be modified between \ :java:ref:`Module`\  calls to avoid reallocation. Data retrieved from \ ``Tensor``\  objects may be copied or may be a reference to the \ ``Tensor``\ 's internal data buffer. \ ``shape``\  is always copied. | ||||
|  | ||||
| Methods | ||||
| ------- | ||||
| allocateByteBuffer | ||||
| ^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static ByteBuffer allocateByteBuffer(int numElements) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Allocates a new direct \ :java:ref:`java.nio.ByteBuffer`\  with native byte order with specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(ByteBuffer,long[])`\ , \ :java:ref:`Tensor.fromBlobUnsigned(ByteBuffer,long[])`\ . | ||||
|  | ||||
|    :param numElements: capacity (number of elements) of result buffer. | ||||
|  | ||||
| allocateDoubleBuffer | ||||
| ^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static DoubleBuffer allocateDoubleBuffer(int numElements) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Allocates a new direct \ :java:ref:`java.nio.DoubleBuffer`\  with native byte order with specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(DoubleBuffer,long[])`\ . | ||||
|  | ||||
|    :param numElements: capacity (number of elements) of result buffer. | ||||
|  | ||||
| allocateFloatBuffer | ||||
| ^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static FloatBuffer allocateFloatBuffer(int numElements) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Allocates a new direct \ :java:ref:`java.nio.FloatBuffer`\  with native byte order with specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(FloatBuffer,long[])`\ . | ||||
|  | ||||
|    :param numElements: capacity (number of elements) of result buffer. | ||||
|  | ||||
| allocateIntBuffer | ||||
| ^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static IntBuffer allocateIntBuffer(int numElements) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Allocates a new direct \ :java:ref:`java.nio.IntBuffer`\  with native byte order with specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(IntBuffer,long[])`\ . | ||||
|  | ||||
|    :param numElements: capacity (number of elements) of result buffer. | ||||
|  | ||||
| allocateLongBuffer | ||||
| ^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static LongBuffer allocateLongBuffer(int numElements) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Allocates a new direct \ :java:ref:`java.nio.LongBuffer`\  with native byte order with specified capacity that can be used in \ :java:ref:`Tensor.fromBlob(LongBuffer,long[])`\ . | ||||
|  | ||||
|    :param numElements: capacity (number of elements) of result buffer. | ||||
|  | ||||
| dtype | ||||
| ^^^^^ | ||||
|  | ||||
| .. java:method:: public abstract DType dtype() | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    :return: data type of this tensor. | ||||
|  | ||||
| dtypeJniCode | ||||
| ^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method::  int dtypeJniCode() | ||||
|    :outertype: Tensor | ||||
|  | ||||
| fromBlob | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlob(byte[] data, long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of bytes. | ||||
|  | ||||
|    :param data: Tensor elements | ||||
|    :param shape: Tensor shape | ||||
|  | ||||
| fromBlob | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlob(int[] data, long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of ints. | ||||
|  | ||||
|    :param data: Tensor elements | ||||
|    :param shape: Tensor shape | ||||
|  | ||||
| fromBlob | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlob(float[] data, long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array of floats. | ||||
|  | ||||
|    :param data: Tensor elements | ||||
|    :param shape: Tensor shape | ||||
|  | ||||
| fromBlob | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlob(long[] data, long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of longs. | ||||
|  | ||||
|    :param data: Tensor elements | ||||
|    :param shape: Tensor shape | ||||
|  | ||||
| fromBlob | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlob(long[] shape, double[] data) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.float64 with specified shape and data as array of doubles. | ||||
|  | ||||
|    :param shape: Tensor shape | ||||
|    :param data: Tensor elements | ||||
|  | ||||
| fromBlob | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlob(ByteBuffer data, long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.int8 with specified shape and data. | ||||
|  | ||||
|    :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\  elements. The buffer is used directly without copying, and changes to its content will change the tensor. | ||||
|    :param shape: Tensor shape | ||||
|  | ||||
| fromBlob | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlob(IntBuffer data, long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.int32 with specified shape and data. | ||||
|  | ||||
|    :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\  elements. The buffer is used directly without copying, and changes to its content will change the tensor. | ||||
|    :param shape: Tensor shape | ||||
|  | ||||
| fromBlob | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlob(FloatBuffer data, long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.float32 with specified shape and data. | ||||
|  | ||||
|    :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\  elements. The buffer is used directly without copying, and changes to its content will change the tensor. | ||||
|    :param shape: Tensor shape | ||||
|  | ||||
| fromBlob | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlob(LongBuffer data, long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.int64 with specified shape and data. | ||||
|  | ||||
|    :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\  elements. The buffer is used directly without copying, and changes to its content will change the tensor. | ||||
|    :param shape: Tensor shape | ||||
|  | ||||
| fromBlob | ||||
| ^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlob(DoubleBuffer data, long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.float64 with specified shape and data. | ||||
|  | ||||
|    :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\  elements. The buffer is used directly without copying, and changes to its content will change the tensor. | ||||
|    :param shape: Tensor shape | ||||
|  | ||||
| fromBlobUnsigned | ||||
| ^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlobUnsigned(byte[] data, long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of bytes. | ||||
|  | ||||
|    :param data: Tensor elements | ||||
|    :param shape: Tensor shape | ||||
|  | ||||
| fromBlobUnsigned | ||||
| ^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Creates a new Tensor instance with dtype torch.uint8 with specified shape and data. | ||||
|  | ||||
|    :param data: Direct buffer with native byte order that contains \ ``Tensor.numel(shape)``\  elements. The buffer is used directly without copying, and changes to its content will change the tensor. | ||||
|    :param shape: Tensor shape | ||||
|  | ||||
| getDataAsByteArray | ||||
| ^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public byte[] getDataAsByteArray() | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    :throws IllegalStateException: if it is called for a non-int8 tensor. | ||||
|    :return: a Java byte array that contains the tensor data. This may be a copy or reference. | ||||
|  | ||||
| getDataAsDoubleArray | ||||
| ^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public double[] getDataAsDoubleArray() | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    :throws IllegalStateException: if it is called for a non-float64 tensor. | ||||
|    :return: a Java double array that contains the tensor data. This may be a copy or reference. | ||||
|  | ||||
| getDataAsFloatArray | ||||
| ^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public float[] getDataAsFloatArray() | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    :throws IllegalStateException: if it is called for a non-float32 tensor. | ||||
|    :return: a Java float array that contains the tensor data. This may be a copy or reference. | ||||
|  | ||||
| getDataAsIntArray | ||||
| ^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public int[] getDataAsIntArray() | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    :throws IllegalStateException: if it is called for a non-int32 tensor. | ||||
|    :return: a Java int array that contains the tensor data. This may be a copy or reference. | ||||
|  | ||||
| getDataAsLongArray | ||||
| ^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public long[] getDataAsLongArray() | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    :throws IllegalStateException: if it is called for a non-int64 tensor. | ||||
|    :return: a Java long array that contains the tensor data. This may be a copy or reference. | ||||
|  | ||||
| getDataAsUnsignedByteArray | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public byte[] getDataAsUnsignedByteArray() | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    :throws IllegalStateException: if it is called for a non-uint8 tensor. | ||||
|    :return: a Java byte array that contains the tensor data. This may be a copy or reference. | ||||
|  | ||||
| getRawDataBuffer | ||||
| ^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method::  Buffer getRawDataBuffer() | ||||
|    :outertype: Tensor | ||||
|  | ||||
| numel | ||||
| ^^^^^ | ||||
|  | ||||
| .. java:method:: public long numel() | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Returns the number of elements in this tensor. | ||||
|  | ||||
| numel | ||||
| ^^^^^ | ||||
|  | ||||
| .. java:method:: public static long numel(long[] shape) | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Calculates the number of elements in a tensor with the specified shape. | ||||
|  | ||||
| shape | ||||
| ^^^^^ | ||||
|  | ||||
| .. java:method:: public long[] shape() | ||||
|    :outertype: Tensor | ||||
|  | ||||
|    Returns the shape of this tensor. (The array is a fresh copy.) | ||||
							
								
								
									
										114
									
								
								docs/source/org/pytorch/TensorImageUtils.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								docs/source/org/pytorch/TensorImageUtils.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,114 @@ | ||||
| .. java:import:: android.graphics Bitmap | ||||
|  | ||||
| .. java:import:: android.graphics ImageFormat | ||||
|  | ||||
| .. java:import:: android.media Image | ||||
|  | ||||
| .. java:import:: org.pytorch Tensor | ||||
|  | ||||
| .. java:import:: java.nio ByteBuffer | ||||
|  | ||||
| .. java:import:: java.nio FloatBuffer | ||||
|  | ||||
| .. java:import:: java.util Locale | ||||
|  | ||||
| TensorImageUtils | ||||
| ================ | ||||
|  | ||||
| .. java:package:: org.pytorch.torchvision | ||||
|    :noindex: | ||||
|  | ||||
| .. java:type:: public final class TensorImageUtils | ||||
|  | ||||
|    Contains utility functions for \ :java:ref:`org.pytorch.Tensor`\  creation from \ :java:ref:`android.graphics.Bitmap`\  or \ :java:ref:`android.media.Image`\  source. | ||||
|  | ||||
| Fields | ||||
| ------ | ||||
| TORCHVISION_NORM_MEAN_RGB | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:field:: public static float[] TORCHVISION_NORM_MEAN_RGB | ||||
|    :outertype: TensorImageUtils | ||||
|  | ||||
| TORCHVISION_NORM_STD_RGB | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:field:: public static float[] TORCHVISION_NORM_STD_RGB | ||||
|    :outertype: TensorImageUtils | ||||
|  | ||||
| Methods | ||||
| ------- | ||||
| bitmapToFloat32Tensor | ||||
| ^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor bitmapToFloat32Tensor(Bitmap bitmap, float[] normMeanRGB, float[] normStdRGB) | ||||
|    :outertype: TensorImageUtils | ||||
|  | ||||
|    Creates a new \ :java:ref:`org.pytorch.Tensor`\  from a full \ :java:ref:`android.graphics.Bitmap`\ , normalized with the mean and std specified in the parameters. | ||||
|  | ||||
|    :param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order | ||||
|    :param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order | ||||
|  | ||||
| bitmapToFloat32Tensor | ||||
| ^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor bitmapToFloat32Tensor(Bitmap bitmap, int x, int y, int width, int height, float[] normMeanRGB, float[] normStdRGB) | ||||
|    :outertype: TensorImageUtils | ||||
|  | ||||
|    Creates a new \ :java:ref:`org.pytorch.Tensor`\  from the specified area of an \ :java:ref:`android.graphics.Bitmap`\ , normalized with the mean and std specified in the parameters. | ||||
|  | ||||
|    :param bitmap: \ :java:ref:`android.graphics.Bitmap`\  as a source for Tensor data | ||||
|    :param x: - x coordinate of top left corner of bitmap's area | ||||
|    :param y: - y coordinate of top left corner of bitmap's area | ||||
|    :param width: - width of bitmap's area | ||||
|    :param height: - height of bitmap's area | ||||
|    :param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order | ||||
|    :param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order | ||||
|  | ||||
| bitmapToFloatBuffer | ||||
| ^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static void bitmapToFloatBuffer(Bitmap bitmap, int x, int y, int width, int height, float[] normMeanRGB, float[] normStdRGB, FloatBuffer outBuffer, int outBufferOffset) | ||||
|    :outertype: TensorImageUtils | ||||
|  | ||||
|    Writes tensor content from the specified \ :java:ref:`android.graphics.Bitmap`\ , normalized with the mean and std specified in the parameters, to the specified \ :java:ref:`java.nio.FloatBuffer`\  at the specified offset. | ||||
|  | ||||
|    :param bitmap: \ :java:ref:`android.graphics.Bitmap`\  as a source for Tensor data | ||||
|    :param x: - x coordinate of top left corner of bitmap's area | ||||
|    :param y: - y coordinate of top left corner of bitmap's area | ||||
|    :param width: - width of bitmap's area | ||||
|    :param height: - height of bitmap's area | ||||
|    :param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order | ||||
|    :param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order | ||||
|  | ||||
| imageYUV420CenterCropToFloat32Tensor | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static Tensor imageYUV420CenterCropToFloat32Tensor(Image image, int rotateCWDegrees, int tensorWidth, int tensorHeight, float[] normMeanRGB, float[] normStdRGB) | ||||
|    :outertype: TensorImageUtils | ||||
|  | ||||
|    Creates a new \ :java:ref:`org.pytorch.Tensor`\  from the specified area of an \ :java:ref:`android.media.Image`\ , doing optional rotation, scaling (nearest) and center cropping. | ||||
|  | ||||
|    :param image: \ :java:ref:`android.media.Image`\  as a source for Tensor data | ||||
|    :param rotateCWDegrees: Clockwise angle through which the input image needs to be rotated to be upright. Range of valid values: 0, 90, 180, 270 | ||||
|    :param tensorWidth: return tensor width, must be positive | ||||
|    :param tensorHeight: return tensor height, must be positive | ||||
|    :param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order | ||||
|    :param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order | ||||
|  | ||||
| imageYUV420CenterCropToFloatBuffer | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. java:method:: public static void imageYUV420CenterCropToFloatBuffer(Image image, int rotateCWDegrees, int tensorWidth, int tensorHeight, float[] normMeanRGB, float[] normStdRGB, FloatBuffer outBuffer, int outBufferOffset) | ||||
|    :outertype: TensorImageUtils | ||||
|  | ||||
|    Writes tensor content from the specified \ :java:ref:`android.media.Image`\ , doing optional rotation, scaling (nearest) and center cropping, to the specified \ :java:ref:`java.nio.FloatBuffer`\  at the specified offset. | ||||
|  | ||||
|    :param image: \ :java:ref:`android.media.Image`\  as a source for Tensor data | ||||
|    :param rotateCWDegrees: Clockwise angle through which the input image needs to be rotated to be upright. Range of valid values: 0, 90, 180, 270 | ||||
|    :param tensorWidth: return tensor width, must be positive | ||||
|    :param tensorHeight: return tensor height, must be positive | ||||
|    :param normMeanRGB: means for RGB channels normalization, length must equal 3, RGB order | ||||
|    :param normStdRGB: standard deviation for RGB channels normalization, length must equal 3, RGB order | ||||
|    :param outBuffer: Output buffer, where tensor content will be written | ||||
|    :param outBufferOffset: Output buffer offset with which tensor content will be written | ||||
							
								
								
									
										18
									
								
								docs/source/org/pytorch/package-index.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								docs/source/org/pytorch/package-index.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,18 @@ | ||||
| org.pytorch | ||||
| =========== | ||||
|  | ||||
| .. java:package:: org.pytorch | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 1 | ||||
|  | ||||
|    DType | ||||
|    IValue | ||||
|    Module | ||||
|    Tensor | ||||
|    Tensor-Tensor_float32 | ||||
|    Tensor-Tensor_float64 | ||||
|    Tensor-Tensor_int32 | ||||
|    Tensor-Tensor_int64 | ||||
|    Tensor-Tensor_int8 | ||||
|    Tensor-Tensor_uint8 | ||||
							
								
								
									
										9
									
								
								docs/source/org/pytorch/torchvision/package-index.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								docs/source/org/pytorch/torchvision/package-index.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,9 @@ | ||||
| org.pytorch.torchvision | ||||
| ======================== | ||||
|  | ||||
| .. java:package:: org.pytorch.torchvision | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 1 | ||||
|  | ||||
|    TensorImageUtils | ||||
							
								
								
									
										29
									
								
								docs/source/packages.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								docs/source/packages.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,29 @@ | ||||
| Javadoc | ||||
| ========= | ||||
|  | ||||
| The main part of the Java API includes 3 classes: | ||||
| :: | ||||
|     org.pytorch.Tensor | ||||
|     org.pytorch.IValue | ||||
|     org.pytorch.Module | ||||
|  | ||||
| If the reader is familiar with the PyTorch Python API, one can think that | ||||
| ``org.pytorch.Tensor`` represents ``torch.tensor``, ``org.pytorch.Module`` represents ``torch.Module``, | ||||
| while ``org.pytorch.IValue`` represents value of TorchScript variable, supporting all | ||||
| its `types <https://pytorch.org/docs/stable/jit.html#types>`_. | ||||
|  | ||||
| Additionally, there is the DType class which contains code representing tensor data types and  | ||||
| TensorImageUtils which contains utility functions for \ :java:ref:`org.pytorch.Tensor`\ created from \ :java:ref:`android.graphics.Bitmap`\  or \ :java:ref:`android.media.Image`\  source. | ||||
|  | ||||
| You can find details to each of these classes linked below: | ||||
|  | ||||
| .. java:package:: org.pytorch | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 1 | ||||
|      | ||||
|    org/pytorch/DType  | ||||
|    org/pytorch/Tensor | ||||
|    org/pytorch/IValue | ||||
|    org/pytorch/Module | ||||
|    org/pytorch/TensorImageUtils | ||||
							
								
								
									
										583
									
								
								docs/source/quantization.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										583
									
								
								docs/source/quantization.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,583 @@ | ||||
| Quantization | ||||
| =========================== | ||||
|  | ||||
| Introduction to Quantization | ||||
| ---------------------------- | ||||
|  | ||||
| Quantization refers to techniques for performing computations and storing tensors at lower bitwidths than | ||||
| floating point precision. A quantized model executes some or all of the operations on tensors with | ||||
| integers rather than floating point values. This allows for a more | ||||
| compact model representation and the use of high performance vectorized | ||||
| operations on many hardware platforms. PyTorch supports INT8 | ||||
| quantization compared to typical FP32 models allowing for a 4x reduction in the model size and | ||||
| a 4x reduction in memory bandwidth requirements.  Hardware support for  INT8 computations | ||||
| is typically 2 to 4 times faster compared to FP32 compute. Quantization is primarily a technique | ||||
| to speed up inference and only the forward pass is supported for quantized operators. | ||||
|  | ||||
| PyTorch supports multiple approaches to quantizing a deep learning model. In most cases the model is trained | ||||
| in FP32 and then the model is converted to INT8. In addition, PyTorch also supports quantization aware | ||||
| training, which models quantization errors in both the forward and backward passes using fake-quantization | ||||
| modules. Note that the entire computation is carried out in floating point. At the end of quantization aware | ||||
| training, PyTorch provides conversion functions to convert the trained model into lower precision. | ||||
|  | ||||
| At a lower level, PyTorch provides a way to represent quantized tensors and | ||||
| perform operations with them. They can be used to directly construct models that | ||||
| perform all or part of the computation in lower precision. Higher-level APIs are | ||||
| provided that incorporate typical workflows of converting FP32 model to lower | ||||
| precision with minimal accuracy loss. | ||||
|  | ||||
| Today, PyTorch supports the following backends for running quantized operators efficiently: | ||||
|  | ||||
| * x86 CPUs with AVX2 support or higher (without AVX2 some operations have inefficient implementations) | ||||
| * ARM CPUs (typically found in mobile/embedded devices) | ||||
|  | ||||
| The corresponding implementation is chosen automatically based on the PyTorch build mode. | ||||
|  | ||||
| .. note:: | ||||
|  | ||||
|   PyTorch 1.3 doesn't provide quantized operator implementations on CUDA yet - this is direction of future work. | ||||
|   Move the model to CPU in order to test the quantized functionality. | ||||
|  | ||||
|   Quantization-aware training (through :class:`~torch.quantization.FakeQuantize`) supports both CPU and CUDA. | ||||
|  | ||||
| Quantized Tensors | ||||
| --------------------------------------- | ||||
|  | ||||
| PyTorch supports both per tensor and per channel asymmetric linear | ||||
| quantization. Per tensor means that all the values within the tensor are | ||||
| scaled the same way. Per channel means that for each dimension, typically | ||||
| the channel dimension of a tensor, the values | ||||
| in the tensor are scaled and offset by a different value (effectively | ||||
| the scale and offset become vectors). This allows for lesser error in converting tensors | ||||
| to quantized values. | ||||
|  | ||||
| The mapping is performed by converting the floating point tensors using | ||||
|  | ||||
| .. image:: math-quantizer-equation.png | ||||
|    :width: 40% | ||||
|  | ||||
| Note that, we ensure that zero in floating point is represented with no error after quantization, | ||||
| thereby ensuring that operations like padding do not cause additional quantization error. | ||||
|  | ||||
| In order to do quantization in PyTorch, we need to be able to represent | ||||
| quantized data in Tensors. A Quantized Tensor allows for storing | ||||
| quantized data (represented as int8/uint8/int32) along with quantization | ||||
| parameters like scale and zero\_point. Quantized Tensors allow for many | ||||
| useful operations making quantized arithmetic easy, in addition to | ||||
| allowing for serialization of data in a quantized format. | ||||
|  | ||||
| Operation coverage | ||||
| ------------------ | ||||
|  | ||||
| Quantized Tensors support a limited subset of data manipulation methods of the regular | ||||
| full-precision tensor. (see list below) | ||||
|  | ||||
| For NN operators included in PyTorch, we restrict support to: | ||||
|  | ||||
|    1. 8 bit weights (data\_type = qint8) | ||||
|    2. 8 bit activations (data\_type = quint8) | ||||
|  | ||||
| Note that operator implementations currently only | ||||
| support per channel quantization for weights of the **conv** and **linear** | ||||
| operators. Furthermore the minimum and the maximum of the input data is | ||||
| mapped linearly to the minimum and the maximum of the quantized data | ||||
| type such that zero is represented with no quantization error. | ||||
|  | ||||
| Additional data types and quantization schemes can be implemented through | ||||
| the `custom operator mechanism <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_. | ||||
|  | ||||
| Many operations for quantized tensors are available under the same API as full | ||||
| float version in ``torch`` or ``torch.nn``. Quantized version of NN modules that | ||||
| perform re-quantization are available in ``torch.nn.quantized``. Those | ||||
| operations explicitly take output quantization parameters (scale and zero\_point) in | ||||
| the operation signature. | ||||
|  | ||||
| In addition, we also support fused versions corresponding to common fusion patterns that impact quantization at: | ||||
| torch.nn.intrinsic.quantized. | ||||
|  | ||||
| For quantization aware training, we support modules prepared for quantization aware training at | ||||
| torch.nn.qat and torch.nn.intrinsic.qat | ||||
|  | ||||
| Current quantized operation list is sufficient to cover typical CNN and RNN | ||||
| models: | ||||
|  | ||||
|  | ||||
| Quantized ``torch.Tensor`` operations | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| Operations that are available from the ``torch`` namespace or as methods on Tensor for quantized tensors: | ||||
|  | ||||
| * :func:`~torch.quantize_per_tensor` - Convert float tensor to quantized tensor with per-tensor scale and zero point | ||||
| * :func:`~torch.quantize_per_channel` - Convert float tensor to quantized tensor with per-channel scale and zero point | ||||
| * View-based operations like :meth:`~torch.Tensor.view`, :meth:`~torch.Tensor.as_strided`, :meth:`~torch.Tensor.expand`, :meth:`~torch.Tensor.flatten`, :meth:`~torch.Tensor.slice`, python-style indexing, etc - work as on regular tensor (if quantization is not per-channel) | ||||
| * Comparators | ||||
|     * :meth:`~torch.Tensor.ne` — Not equal | ||||
|     * :meth:`~torch.Tensor.eq` — Equal | ||||
|     * :meth:`~torch.Tensor.ge` — Greater or equal | ||||
|     * :meth:`~torch.Tensor.le` — Less or equal | ||||
|     * :meth:`~torch.Tensor.gt` — Greater | ||||
|     * :meth:`~torch.Tensor.lt` — Less | ||||
| * :meth:`~torch.Tensor.copy_` — Copies src to self in-place | ||||
| * :meth:`~torch.Tensor.clone` —  Returns a deep copy of the passed-in tensor | ||||
| * :meth:`~torch.Tensor.dequantize` — Convert quantized tensor to float tensor | ||||
| * :meth:`~torch.Tensor.equal` — Compares two tensors, returns true if quantization parameters and all integer elements are the same | ||||
| * :meth:`~torch.Tensor.int_repr` — Prints the underlying integer representation of the quantized tensor | ||||
| * :meth:`~torch.Tensor.max` — Returns the maximum value of the tensor (reduction only) | ||||
| * :meth:`~torch.Tensor.mean` — Mean function. Supported variants: reduction, dim, out | ||||
| * :meth:`~torch.Tensor.min` — Returns the minimum value of the tensor (reduction only) | ||||
| * :meth:`~torch.Tensor.q_scale` — Returns the scale of the per-tensor quantized tensor | ||||
| * :meth:`~torch.Tensor.q_zero_point` — Returns the zero_point of the per-tensor quantized zero point | ||||
| * :meth:`~torch.Tensor.q_per_channel_scales` — Returns the scales of the per-channel quantized tensor | ||||
| * :meth:`~torch.Tensor.q_per_channel_zero_points` — Returns the zero points of the per-channel quantized tensor | ||||
| * :meth:`~torch.Tensor.q_per_channel_axis` — Returns the channel axis of the per-channel quantized tensor | ||||
| * :meth:`~torch.Tensor.relu` — Rectified linear unit (copy) | ||||
| * :meth:`~torch.Tensor.relu_` — Rectified linear unit (inplace) | ||||
| * :meth:`~torch.Tensor.resize_` — In-place resize | ||||
| * :meth:`~torch.Tensor.sort` — Sorts the tensor | ||||
| * :meth:`~torch.Tensor.topk` — Returns k largest values of a tensor | ||||
|  | ||||
|  | ||||
| ``torch.nn.intrinsic`` | ||||
| ~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| Fused modules are provided for common patterns in CNNs. Combining several operations together (like convolution and relu) allows for better quantization accuracy | ||||
|  | ||||
| * ``torch.nn.intrinsic`` — float versions of the modules, can be swapped with quantized version 1 to 1 | ||||
|     * :class:`~torch.nn.intrinsic.ConvBn2d` — Conv2d + BatchNorm | ||||
|     * :class:`~torch.nn.intrinsic.ConvBnReLU2d` — Conv2d + BatchNorm + ReLU | ||||
|     * :class:`~torch.nn.intrinsic.ConvReLU2d` — Conv2d + Relu | ||||
|     * :class:`~torch.nn.intrinsic.LinearReLU` — Linear + ReLU | ||||
| * ``torch.nn.intrinsic.qat`` — versions of layers for quantization-aware training | ||||
|     * :class:`~torch.nn.intrinsic.qat.ConvBn2d` — Conv2d + BatchNorm | ||||
|     * :class:`~torch.nn.intrinsic.qat.ConvBnReLU2d` — Conv2d + BatchNorm + ReLU | ||||
|     * :class:`~torch.nn.intrinsic.qat.ConvReLU2d` — Conv2d + ReLU | ||||
|     * :class:`~torch.nn.intrinsic.qat.LinearReLU` — Linear + ReLU | ||||
| * ``torch.nn.intrinsic.quantized`` — quantized version of fused layers for inference (no BatchNorm variants as it's usually folded into convolution for inference) | ||||
|     * :class:`~torch.nn.intrinsic.quantized.LinearReLU` — Linear + ReLU | ||||
|     * :class:`~torch.nn.intrinsic.quantized.ConvReLU2d` — 2D Convolution + ReLU | ||||
|  | ||||
| ``torch.nn.qat`` | ||||
| ~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| Layers for the quantization-aware training | ||||
|  | ||||
| * :class:`~torch.nn.qat.Linear` — Linear (fully-connected) layer | ||||
| * :class:`~torch.nn.qat.Conv2d` — 2D convolution | ||||
|  | ||||
| ``torch.quantization`` | ||||
| ~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| * Functions for quantization | ||||
|     * :func:`~torch.quantization.add_observer_` — Adds observer for the leaf modules (if quantization configuration is provided) | ||||
|     * :func:`~torch.quantization.add_quant_dequant` — Wraps the leaf child module using :class:`~torch.quantization.QuantWrapper` | ||||
|     * :func:`~torch.quantization.convert` — Converts float module with observers into its quantized counterpart. Must have quantization configuration | ||||
|     * :func:`~torch.quantization.get_observer_dict` — Traverses the module children and collects all observers into a ``dict`` | ||||
|     * :func:`~torch.quantization.prepare` — Prepares a copy of a model for quantization | ||||
|     * :func:`~torch.quantization.prepare_qat` — Prepares a copy of a model for quantization aware training | ||||
|     * :func:`~torch.quantization.propagate_qconfig_` — Propagates quantization configurations through the module hierarchy and assign them to each leaf module | ||||
|     * :func:`~torch.quantization.quantize` — Converts a float module to quantized version | ||||
|     * :func:`~torch.quantization.quantize_dynamic` — Converts a float module to dynamically quantized version | ||||
|     * :func:`~torch.quantization.quantize_qat` — Converts a float module to quantized version used in quantization aware training | ||||
|     * :func:`~torch.quantization.swap_module` — Swaps the module with its quantized counterpart (if quantizable and if it has an observer) | ||||
| * :func:`~torch.quantization.default_eval_fn` — Default evaluation function used by the :func:`torch.quantization.quantize` | ||||
| * :func:`~torch.quantization.fuse_modules` | ||||
| * :class:`~torch.quantization.FakeQuantize` — Module for simulating the quantization/dequantization at training time | ||||
| * Default Observers. The rest of observers are available from ``torch.quantization.observer`` | ||||
|     * :attr:`~torch.quantization.default_observer` — Same as ``MinMaxObserver.with_args(reduce_range=True)`` | ||||
|     * :attr:`~torch.quantization.default_weight_observer` — Same as ``MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)`` | ||||
|     * :class:`~torch.quantization.Observer` — Abstract base class for observers | ||||
| * Quantization configurations | ||||
|     * :class:`~torch.quantization.QConfig` — Quantization configuration class | ||||
|     * :attr:`~torch.quantization.default_qconfig` — Same as ``QConfig(activation=default_observer, weight=default_weight_observer)`` (See :class:`~torch.quantization.QConfig.QConfig`) | ||||
|     * :attr:`~torch.quantization.default_qat_qconfig` — Same as ``QConfig(activation=default_fake_quant, weight=default_weight_fake_quant)`` (See :class:`~torch.quantization.QConfig.QConfig`) | ||||
|     * :attr:`~torch.quantization.default_dynamic_qconfig` — Same as ``QConfigDynamic(weight=default_weight_observer)`` (See :class:`~torch.quantization.QConfig.QConfigDynamic`) | ||||
|     * :attr:`~torch.quantization.float16_dynamic_qconfig` — Same as ``QConfigDynamic(weight=NoopObserver.with_args(dtype=torch.float16))`` (See :class:`~torch.quantization.QConfig.QConfigDynamic`) | ||||
| * Stubs | ||||
|     * :class:`~torch.quantization.DeQuantStub` - placeholder module for dequantize() operation in float-valued models | ||||
|     * :class:`~torch.quantization.QuantStub` - placeholder module for quantize() operation in float-valued models | ||||
|     * :class:`~torch.quantization.QuantWrapper` — wraps the module to be quantized. Inserts the :class:`~torch.quantization.QuantStub` and :class:`~torch.quantization.DeQuantStub` | ||||
|  | ||||
| Observers for computing the quantization parameters | ||||
|  | ||||
| * :class:`~torch.quantization.MinMaxObserver` — Derives the quantization parameters from the running minimum and maximum of the observed tensor inputs (per tensor variant) | ||||
| * :class:`~torch.quantization.MovingAverageObserver` — Derives the quantization parameters from the running averages of the minimums and maximums of the observed tensor inputs (per tensor variant) | ||||
| * :class:`~torch.quantization.PerChannelMinMaxObserver` — Derives the quantization parameters from the running minimum and maximum of the observed tensor inputs (per channel variant) | ||||
| * :class:`~torch.quantization.MovingAveragePerChannelMinMaxObserver` — Derives the quantization parameters from the running averages of the minimums and maximums of the observed tensor inputs (per channel variant) | ||||
| * :class:`~torch.quantization.HistogramObserver` — Derives the quantization parameters by creating a histogram of running minimums and maximums. | ||||
| * Observers that do not compute the quantization parameters: | ||||
|     * :class:`~torch.quantization.RecordingObserver` — Records all incoming tensors. Used for debugging only. | ||||
|     * :class:`~torch.quantization.NoopObserver` — Pass-through observer. Used for situation when there are no quantization parameters (i.e. quantization to ``float16``) | ||||
|  | ||||
| ``torch.nn.quantized`` | ||||
| ~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| Quantized version of standard NN layers. | ||||
|  | ||||
| * :class:`~torch.nn.quantized.Quantize` — Quantization layer, used to automatically replace :class:`~torch.quantization.QuantStub` | ||||
| * :class:`~torch.nn.quantized.DeQuantize` — Dequantization layer, used to replace :class:`~torch.quantization.DeQuantStub` | ||||
| * :class:`~torch.nn.quantized.FloatFunctional` — Wrapper class to make stateless float operations stateful so that they can be replaced with quantized versions | ||||
| * :class:`~torch.nn.quantized.QFunctional` — Wrapper class for quantized versions of stateless operations like ``torch.add`` | ||||
| * :class:`~torch.nn.quantized.Conv2d` — 2D convolution | ||||
| * :class:`~torch.nn.quantized.Linear` — Linear (fully-connected) layer | ||||
| * :class:`~torch.nn.MaxPool2d` — 2D max pooling | ||||
| * :class:`~torch.nn.quantized.ReLU` — Rectified linear unit | ||||
| * :class:`~torch.nn.quantized.ReLU6` — Rectified linear unit with cut-off at quantized representation of 6 | ||||
|  | ||||
| ``torch.nn.quantized.dynamic`` | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| Layers used in dynamically quantized models (i.e. quantized only on weights) | ||||
|  | ||||
| * :class:`~torch.nn.quantized.dynamic.Linear` — Linear (fully-connected) layer | ||||
| * :class:`~torch.nn.quantized.dynamic.LSTM` — Long-Short Term Memory RNN module | ||||
|  | ||||
| ``torch.nn.quantized.functional`` | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| Functional versions of quantized NN layers (many of them accept explicit quantization output parameters) | ||||
|  | ||||
| * :func:`~torch.nn.quantized.functional.adaptive_avg_pool2d` — 2D adaptive average pooling | ||||
| * :func:`~torch.nn.quantized.functional.avg_pool2d` — 2D average pooling | ||||
| * :func:`~torch.nn.quantized.functional.conv2d` — 2D convolution | ||||
| * :func:`~torch.nn.quantized.functional.interpolate` — Down-/up- sampler | ||||
| * :func:`~torch.nn.quantized.functional.linear` — Linear (fully-connected) op | ||||
| * :func:`~torch.nn.quantized.functional.max_pool2d` — 2D max pooling | ||||
| * :func:`~torch.nn.quantized.functional.relu` — Rectified linear unit | ||||
| * :func:`~torch.nn.quantized.functional.upsample` — Upsampler. Will be deprecated in favor of :func:`~torch.nn.quantized.functional.interpolate` | ||||
| * :func:`~torch.nn.quantized.functional.upsample_bilinear` — Bilinear upsampler. Will be deprecated in favor of :func:`~torch.nn.quantized.functional.interpolate` | ||||
| * :func:`~torch.nn.quantized.functional.upsample_nearest` — Nearest neighbor upsampler. Will be deprecated in favor of :func:`~torch.nn.quantized.functional.interpolate` | ||||
|  | ||||
| Quantized dtypes and quantization schemes | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| * :attr:`torch.qscheme` — Type to describe the quantization scheme of a tensor. Supported types: | ||||
|     * :attr:`torch.per_tensor_affine` — per tensor, asymmetric | ||||
|     * :attr:`torch.per_channel_affine` — per channel, asymmetric | ||||
|     * :attr:`torch.per_tensor_symmetric` — per tensor, symmetric | ||||
|     * :attr:`torch.per_channel_symmetric` — per channel, symmetric | ||||
| * ``torch.dtype`` — Type to describe the data. Supported types: | ||||
|     * :attr:`torch.quint8` — 8-bit unsigned integer | ||||
|     * :attr:`torch.qint8` — 8-bit signed integer | ||||
|     * :attr:`torch.qint32` — 32-bit signed integer | ||||
|  | ||||
|  | ||||
|  | ||||
| Quantization Workflows | ||||
| ---------------------- | ||||
|  | ||||
| PyTorch provides three approaches to quantize models. | ||||
|  | ||||
| 1. Post Training Dynamic Quantization: This is the simplest to apply form of | ||||
|    quantization where the weights are quantized ahead of time but the | ||||
|    activations are dynamically quantized  during inference. This is used | ||||
|    for situations where the model execution time is dominated by loading | ||||
|    weights from memory rather than computing the matrix multiplications. | ||||
|    This is true for LSTM and Transformer type models with small | ||||
|    batch size. Applying dynamic quantization to a whole model can be | ||||
|    done with a single call to :func:`torch.quantization.quantize_dynamic()`. | ||||
|    See the `quantization tutorials <https://pytorch.org/tutorials/#quantization-experimental>`_ | ||||
| 2. Post Training Static Quantization: This is the most commonly used form of | ||||
|    quantization where the weights are quantized ahead of time and the | ||||
|    scale factor and bias for the activation tensors is pre-computed | ||||
|    based on observing the behavior of the model during a calibration | ||||
|    process. Post Training Quantization is typically when both memory bandwidth and compute | ||||
|    savings are important with CNNs being a typical use case. | ||||
|    The general process for doing post training quantization is: | ||||
|  | ||||
|  | ||||
|  | ||||
|    1. Prepare the model: | ||||
|       a. Specify where the activations are quantized and dequantized explicitly by adding QuantStub and DeQuantStub modules. | ||||
|       b. Ensure that modules are not reused. | ||||
|       c. Convert any operations that require requantization into modules | ||||
|    2. Fuse operations like conv + relu or conv+batchnorm + relu together to improve both model accuracy and performance. | ||||
|  | ||||
|    3. Specify the configuration of the quantization methods — such as | ||||
|       selecting symmetric or asymmetric quantization and MinMax or | ||||
|       L2Norm calibration techniques. | ||||
|    4. Use the :func:`torch.quantization.prepare` to insert modules | ||||
|       that will observe activation tensors during calibration | ||||
|    5. Calibrate the model by running inference against a calibration | ||||
|       dataset | ||||
|    6. Finally, convert the model itself with the | ||||
|       torch.quantization.convert() method. This does several things: it | ||||
|       quantizes the weights, computes and stores the scale and bias | ||||
|       value to be used with each activation tensor, and replaces key | ||||
|       operators with quantized implementations. | ||||
|  | ||||
|    See the `quantization tutorials <https://pytorch.org/tutorials/#quantization_experimental>`_ | ||||
|  | ||||
|  | ||||
| 3. Quantization Aware Training: In the rare cases where post training | ||||
|    quantization does not provide adequate accuracy training can be done | ||||
|    with simulated quantization using the :class:`torch.quantization.FakeQuantize`. Computations | ||||
|    will take place in FP32 but with values clamped and rounded to | ||||
|    simulate the effects of INT8 quantization. The sequence of steps is | ||||
|    very similar. | ||||
|  | ||||
|  | ||||
|    1. Steps (1) and (2) are identical. | ||||
|  | ||||
|    3. Specify the configuration of the fake quantization methods — such as | ||||
|       selecting symmetric or asymmetric quantization and MinMax or Moving Average | ||||
|       or L2Norm calibration techniques. | ||||
|    4. Use the :func:`torch.quantization.prepare_qat` to insert modules | ||||
|       that will simulate quantization during training. | ||||
|    5. Train or fine tune the model. | ||||
|    6. Identical to step (6) for post training quantization | ||||
|  | ||||
|    See the `quantization tutorials <https://pytorch.org/tutorials/#quantization_experimental>`_ | ||||
|  | ||||
|  | ||||
| While default implementations of observers to select the scale factor and bias | ||||
| based on observed tensor data are provided, developers can provide their own | ||||
| quantization functions. Quantization can be applied selectively to different | ||||
| parts of the model or configured differently for different parts of the model. | ||||
|  | ||||
| We also provide support for per channel quantization for **conv2d()** | ||||
| and **linear()** | ||||
|  | ||||
| Quantization workflows work by adding (e.g. adding observers as | ||||
| ``.observer`` submodule) or replacing (e.g. converting ``nn.Conv2d`` to | ||||
| ``nn.quantized.Conv2d``) submodules in the model's module hierarchy. It | ||||
| means that the model stays a regular ``nn.Module``-based instance throughout the | ||||
| process and thus can work with the rest of PyTorch APIs. | ||||
|  | ||||
|  | ||||
| Model Preparation for Quantization | ||||
| ---------------------------------- | ||||
|  | ||||
| It is currently necessary to make some modifications to the model definition | ||||
| prior to quantization. This is because currently quantization works on a module | ||||
| by module basis. Specifically, for all quantization techniques, the user needs to: | ||||
|  | ||||
| 1. Convert any operations that require output requantization (and thus have additional parameters) from functionals to module form. | ||||
| 2. Specify which parts of the model need to be quantized either by assigning ``.qconfig`` attributes on submodules or by specifying ``qconfig_dict`` | ||||
|  | ||||
| For static quantization techniques which quantize activations, the user needs to do the following in addition: | ||||
|  | ||||
| 1. Specify where activations are quantized and de-quantized. This is done using :class:`~torch.quantization.QuantStub` and :class:`~torch.quantization.DeQuantStub` modules. | ||||
| 2. Use :class:`torch.nn.quantized.FloatFunctional` to wrap tensor operations that require special handling for quantization into modules. Examples | ||||
|    are operations like ``add`` and ``cat`` which require special handling to determine output quantization parameters. | ||||
| 3. Fuse modules: combine operations/modules into a single module to obtain higher accuracy and performance. This is done using the | ||||
|    :func:`torch.quantization.fuse_modules` API, which takes in lists of modules to be fused. We currently support the following fusions: | ||||
|    [Conv, Relu], [Conv, BatchNorm], [Conv, BatchNorm, Relu], [Linear, Relu] | ||||
|  | ||||
|  | ||||
| torch.quantization | ||||
| --------------------------- | ||||
| .. automodule:: torch.quantization | ||||
|  | ||||
| This module implements the functions you call | ||||
| directly to convert your model from FP32 to quantized form. For | ||||
| example the :func:`~torch.quantization.prepare` is used in post training | ||||
| quantization to prepare your model for the calibration step and | ||||
| :func:`~torch.quantization.convert` actually converts the weights to int8 and | ||||
| replaces the operations with their quantized counterparts. There are | ||||
| other helper functions for things like quantizing the input to your | ||||
| model and performing critical fusions like conv+relu. | ||||
|  | ||||
| Top-level quantization APIs | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| .. autofunction:: quantize | ||||
| .. autofunction:: quantize_dynamic | ||||
| .. autofunction:: quantize_qat | ||||
| .. autofunction:: prepare | ||||
| .. autofunction:: prepare_qat | ||||
| .. autofunction:: convert | ||||
| .. autoclass:: QConfig | ||||
| .. autoclass:: QConfigDynamic | ||||
| .. autoattribute:: default_qconfig | ||||
|  | ||||
| Preparing model for quantization | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| .. autofunction:: fuse_modules | ||||
| .. autoclass:: QuantStub | ||||
| .. autoclass:: DeQuantStub | ||||
| .. autoclass:: QuantWrapper | ||||
| .. autofunction:: add_quant_dequant | ||||
|  | ||||
| Utility functions | ||||
| ~~~~~~~~~~~~~~~~~ | ||||
| .. autofunction:: add_observer_ | ||||
| .. autofunction:: swap_module | ||||
| .. autofunction:: propagate_qconfig_ | ||||
| .. autofunction:: default_eval_fn | ||||
|  | ||||
| Observers | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: Observer | ||||
|     :members: | ||||
| .. autoclass:: MinMaxObserver | ||||
| .. autoclass:: MovingAverageObserver | ||||
| .. autoclass:: PerChannelMinMaxObserver | ||||
| .. autoclass:: MovingAveragePerChannelMinMaxObserver | ||||
| .. autoclass:: HistogramObserver | ||||
| .. autoclass:: FakeQuantize | ||||
| .. autoclass:: NoopObserver | ||||
|  | ||||
| Debugging utilities | ||||
| ~~~~~~~~~~~~~~~~~~~ | ||||
| .. autofunction:: get_observer_dict | ||||
| .. autoclass:: RecordingObserver | ||||
|  | ||||
| torch.nn.intrinsic | ||||
| -------------------------------- | ||||
|  | ||||
| This module implements the combined (fused) modules conv + relu which can be then quantized. | ||||
|  | ||||
| .. automodule:: torch.nn.intrinsic | ||||
|  | ||||
| ConvBn2d | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: ConvBn2d | ||||
|     :members: | ||||
|  | ||||
| ConvBnReLU2d | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: ConvBnReLU2d | ||||
|     :members: | ||||
|  | ||||
| ConvReLU2d | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: ConvReLU2d | ||||
|     :members: | ||||
|  | ||||
| LinearReLU | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: LinearReLU | ||||
|     :members: | ||||
|  | ||||
| torch.nn.intrinsic.qat | ||||
| -------------------------------- | ||||
|  | ||||
| This module implements the versions of those fused operations needed for quantization aware training. | ||||
|  | ||||
| .. automodule:: torch.nn.intrinsic.qat | ||||
|  | ||||
| ConvBn2d | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: ConvBn2d | ||||
|     :members: | ||||
|  | ||||
| ConvBnReLU2d | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: ConvBnReLU2d | ||||
|     :members: | ||||
|  | ||||
| ConvReLU2d | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: ConvReLU2d | ||||
|     :members: | ||||
|  | ||||
| LinearReLU | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: LinearReLU | ||||
|     :members: | ||||
|  | ||||
| torch.nn.intrinsic.quantized | ||||
| -------------------------------------- | ||||
|  | ||||
| This module implements the quantized implementations of fused operations like conv + relu. | ||||
|  | ||||
| .. automodule:: torch.nn.intrinsic.quantized | ||||
|  | ||||
| ConvReLU2d | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: ConvReLU2d | ||||
|     :members: | ||||
|  | ||||
| LinearReLU | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: LinearReLU | ||||
|     :members: | ||||
|  | ||||
| torch.nn.qat | ||||
| --------------------------- | ||||
|  | ||||
| This module implements versions of the key nn modules **Conv2d()** and **Linear()** which | ||||
| run in FP32 but with rounding applied to simulate the effect of INT8 quantization. | ||||
|  | ||||
| .. automodule:: torch.nn.qat | ||||
|  | ||||
| Conv2d | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: Conv2d | ||||
|     :members: | ||||
|  | ||||
| Linear | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: Linear | ||||
|     :members: | ||||
|  | ||||
|  | ||||
| torch.nn.quantized | ||||
| ---------------------------- | ||||
|  | ||||
| This module implements the quantized versions of the nn layers such as **Conv2d** and **ReLU**. | ||||
|  | ||||
| Functional interface | ||||
| ~~~~~~~~~~~~~~~~~~~~ | ||||
| .. automodule:: torch.nn.quantized.functional | ||||
|  | ||||
| .. autofunction:: relu | ||||
| .. autofunction:: linear | ||||
| .. autofunction:: conv2d | ||||
| .. autofunction:: max_pool2d | ||||
|  | ||||
| .. automodule:: torch.nn.quantized | ||||
|  | ||||
| ReLU | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: ReLU | ||||
|     :members: | ||||
|  | ||||
| ReLU6 | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: ReLU6 | ||||
|     :members: | ||||
|  | ||||
| Conv2d | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: Conv2d | ||||
|     :members: | ||||
|  | ||||
| FloatFunctional | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: FloatFunctional | ||||
|     :members: | ||||
|  | ||||
| QFunctional | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: QFunctional | ||||
|     :members: | ||||
|  | ||||
| Quantize | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: Quantize | ||||
|     :members: | ||||
|  | ||||
| DeQuantize | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: DeQuantize | ||||
|     :members: | ||||
|  | ||||
| Linear | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: Linear | ||||
|     :members: | ||||
|  | ||||
| torch.nn.quantized.dynamic | ||||
| ---------------------------- | ||||
|  | ||||
| .. automodule:: torch.nn.quantized.dynamic | ||||
|  | ||||
| Linear | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: Linear | ||||
|     :members: | ||||
|  | ||||
| LSTM | ||||
| ~~~~~~~~~~~~~~~ | ||||
| .. autoclass:: LSTM | ||||
|     :members: | ||||
| @ -91,5 +91,6 @@ Expected result: | ||||
|    .. automethod:: add_pr_curve | ||||
|    .. automethod:: add_custom_scalars | ||||
|    .. automethod:: add_mesh | ||||
|    .. automethod:: add_hparams | ||||
|    .. automethod:: flush | ||||
|    .. automethod:: close | ||||
|  | ||||
| @ -304,6 +304,8 @@ view of a storage and defines numeric operations on it. | ||||
|    .. automethod:: le_ | ||||
|    .. automethod:: lerp | ||||
|    .. automethod:: lerp_ | ||||
|    .. automethod:: lgamma | ||||
|    .. automethod:: lgamma_ | ||||
|    .. automethod:: log | ||||
|    .. automethod:: log_ | ||||
|    .. automethod:: logdet | ||||
| @ -363,6 +365,8 @@ view of a storage and defines numeric operations on it. | ||||
|    .. automethod:: permute | ||||
|    .. automethod:: pin_memory | ||||
|    .. automethod:: pinverse | ||||
|    .. automethod:: polygamma | ||||
|    .. automethod:: polygamma_ | ||||
|    .. automethod:: pow | ||||
|    .. automethod:: pow_ | ||||
|    .. automethod:: prod | ||||
|  | ||||
| @ -52,6 +52,8 @@ Creation Ops | ||||
| .. autofunction:: empty_strided | ||||
| .. autofunction:: full | ||||
| .. autofunction:: full_like | ||||
| .. autofunction:: quantize_per_tensor | ||||
| .. autofunction:: quantize_per_channel | ||||
|  | ||||
| Indexing, Slicing, Joining, Mutating Ops | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| @ -207,6 +209,7 @@ Pointwise Ops | ||||
| .. autofunction:: fmod | ||||
| .. autofunction:: frac | ||||
| .. autofunction:: lerp | ||||
| .. autofunction:: lgamma | ||||
| .. autofunction:: log | ||||
| .. autofunction:: log10 | ||||
| .. autofunction:: log1p | ||||
| @ -216,6 +219,7 @@ Pointwise Ops | ||||
| .. autofunction:: mul | ||||
| .. autofunction:: mvlgamma | ||||
| .. autofunction:: neg | ||||
| .. autofunction:: polygamma | ||||
| .. autofunction:: pow | ||||
| .. autofunction:: reciprocal | ||||
| .. autofunction:: remainder | ||||
|  | ||||
| @ -17,7 +17,7 @@ For Objective-C developers, simply import the umbrella header | ||||
| #import <LibTorch/LibTorch.h> | ||||
| ``` | ||||
|  | ||||
| For Swift developers, you need to create an Objective-C class as a bridge to call the C++ APIs. We highly recommend you to follow the [Image Classification](https://github.com/pytorch/examples) demo where you can find out how C++, Objective-C and Swift work together.  | ||||
| For Swift developers, you need to create an Objective-C class as a bridge to call the C++ APIs. We highly recommend you to follow the [Image Classification](https://github.com/pytorch/ios-demo-app/tree/master/PyTorchDemo) demo where you can find out how C++, Objective-C and Swift work together.  | ||||
|  | ||||
| ### Disable Bitcode  | ||||
|  | ||||
| @ -25,4 +25,4 @@ Since PyTorch is not yet built with bitcode support, you need to disable bitcode | ||||
|  | ||||
| ## LICENSE | ||||
|  | ||||
| PyTorch is BSD-style licensed, as found in the LICENSE file. | ||||
| PyTorch is BSD-style licensed, as found in the LICENSE file. | ||||
|  | ||||
							
								
								
									
										10
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								setup.py
									
									
									
									
									
								
							| @ -265,7 +265,11 @@ try: | ||||
| except Exception: | ||||
|     pass | ||||
|  | ||||
| # the list of runtime dependencies required by this built package | ||||
| install_requires = [] | ||||
|  | ||||
| if os.getenv('PYTORCH_BUILD_VERSION'): | ||||
|     install_requires += ['numpy'] | ||||
|     assert os.getenv('PYTORCH_BUILD_NUMBER') is not None | ||||
|     build_number = int(os.getenv('PYTORCH_BUILD_NUMBER')) | ||||
|     version = os.getenv('PYTORCH_BUILD_VERSION') | ||||
| @ -345,10 +349,8 @@ def build_deps(): | ||||
| # Building dependent libraries | ||||
| ################################################################################ | ||||
|  | ||||
| # the list of runtime dependencies required by this built package | ||||
| install_requires = [] | ||||
|  | ||||
| if sys.version_info <= (2, 7): | ||||
| if sys.version_info <= (3, 0): | ||||
|     install_requires += ['future'] | ||||
|  | ||||
| missing_pydep = ''' | ||||
| @ -371,8 +373,6 @@ class build_ext(setuptools.command.build_ext.build_ext): | ||||
|         cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables()) | ||||
|         if cmake_cache_vars['USE_NUMPY']: | ||||
|             report('-- Building with NumPy bindings') | ||||
|             global install_requires | ||||
|             install_requires += ['numpy'] | ||||
|         else: | ||||
|             report('-- NumPy not found') | ||||
|         if cmake_cache_vars['USE_CUDNN']: | ||||
|  | ||||
| @ -14,8 +14,9 @@ import torch.nn.quantized as nnq | ||||
| import torch.nn.quantized.dynamic as nnqd | ||||
| from common_utils import TestCase | ||||
| from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \ | ||||
|     default_qconfig, QConfig, default_observer, default_weight_observer, \ | ||||
|     default_qat_qconfig, propagate_qconfig_, convert, DEFAULT_DYNAMIC_MODULE_MAPPING | ||||
|     default_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ | ||||
|     propagate_qconfig_, convert | ||||
| from torch.quantization.default_mappings import DEFAULT_DYNAMIC_MODULE_MAPPING | ||||
|  | ||||
| def test_only_eval_fn(model, calib_data): | ||||
|     r""" | ||||
| @ -214,7 +215,7 @@ class AnnotatedTwoLayerLinearModel(torch.nn.Module): | ||||
|         super(AnnotatedTwoLayerLinearModel, self).__init__() | ||||
|         self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) | ||||
|         self.fc2 = QuantWrapper(torch.nn.Linear(8, 5).to(dtype=torch.float)) | ||||
|         self.fc2.qconfig = default_qconfig | ||||
|         self.fc2.qconfig = torch.quantization.get_default_qconfig("fbgemm") | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.fc1(x) | ||||
| @ -252,7 +253,7 @@ class AnnotatedNestedModel(torch.nn.Module): | ||||
|         self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) | ||||
|         self.fc3.qconfig = default_qconfig | ||||
|         self.sub2.fc1 = QuantWrapper(self.sub2.fc1) | ||||
|         self.sub2.fc1.qconfig = default_qconfig | ||||
|         self.sub2.fc1.qconfig = default_per_channel_qconfig | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.sub1(x) | ||||
| @ -346,7 +347,7 @@ class QuantStubModel(torch.nn.Module): | ||||
|     """ | ||||
|     def __init__(self): | ||||
|         super(QuantStubModel, self).__init__() | ||||
|         self.qconfig = default_qconfig | ||||
|         self.qconfig = torch.quantization.get_default_qconfig("qnnpack") | ||||
|         self.quant = QuantStub() | ||||
|         self.dequant = DeQuantStub() | ||||
|         self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) | ||||
| @ -361,7 +362,7 @@ class ManualLinearQATModel(torch.nn.Module): | ||||
|     """ | ||||
|     def __init__(self): | ||||
|         super(ManualLinearQATModel, self).__init__() | ||||
|         self.qconfig = default_qat_qconfig | ||||
|         self.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm") | ||||
|         self.quant = QuantStub() | ||||
|         self.dequant = DeQuantStub() | ||||
|         self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) | ||||
| @ -379,7 +380,7 @@ class ManualConvLinearQATModel(torch.nn.Module): | ||||
|     """ | ||||
|     def __init__(self): | ||||
|         super(ManualConvLinearQATModel, self).__init__() | ||||
|         self.qconfig = default_qat_qconfig | ||||
|         self.qconfig = torch.quantization.get_default_qat_qconfig("qnnpack") | ||||
|         self.quant = QuantStub() | ||||
|         self.dequant = DeQuantStub() | ||||
|         self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float) | ||||
| @ -444,6 +445,38 @@ class ModelForFusion(nn.Module): | ||||
|         x = self.fc(x) | ||||
|         return x | ||||
|  | ||||
| class ConvBNReLU(nn.Sequential): | ||||
|     def __init__(self): | ||||
|         super(ConvBNReLU, self).__init__( | ||||
|             nn.Conv2d(3, 3, 1, 1, bias=False), | ||||
|             nn.BatchNorm2d(3), | ||||
|             nn.ReLU(inplace=False) | ||||
|         ) | ||||
|  | ||||
| class ModelWithSequentialFusion(nn.Module): | ||||
|     def __init__(self): | ||||
|         super(ModelWithSequentialFusion, self).__init__() | ||||
|         self.conv1 = nn.Conv2d(3, 3, 1) | ||||
|         self.relu1 = nn.ReLU(inplace=False) | ||||
|         layers = [] | ||||
|         for i in range(3): | ||||
|             layers.append(ConvBNReLU()) | ||||
|         self.features = nn.Sequential(*layers) | ||||
|         head = [nn.Linear(300, 10), nn.ReLU(inplace=False)] | ||||
|         self.classifier = nn.Sequential(*head) | ||||
|         self.quant = QuantStub() | ||||
|         self.dequant = DeQuantStub() | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.quant(x) | ||||
|         x = self.conv1(x) | ||||
|         x = self.relu1(x) | ||||
|         x = self.features(x) | ||||
|         x = torch.reshape(x, (-1, 3 * 10 * 10)) | ||||
|         x = self.classifier(x) | ||||
|         x = self.dequant(x) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class DummyObserver(torch.nn.Module): | ||||
|     def calculate_qparams(self): | ||||
| @ -453,30 +486,27 @@ class DummyObserver(torch.nn.Module): | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class ModForWrapping(torch.nn.Module): | ||||
|     def __init__(self, quantized=False): | ||||
|         super(ModForWrapping, self).__init__() | ||||
|         self.qconfig = default_qconfig | ||||
|         if quantized: | ||||
|             self.mycat = nnq.QFunctional() | ||||
|             self.myadd = nnq.QFunctional() | ||||
|         else: | ||||
|             self.mycat = nnq.FloatFunctional() | ||||
|             self.myadd = nnq.FloatFunctional() | ||||
|             self.mycat.observer = DummyObserver() | ||||
|             self.myadd.observer = DummyObserver() | ||||
| class ModelWithFunctionals(torch.nn.Module): | ||||
|     def __init__(self): | ||||
|         super(ModelWithFunctionals, self).__init__() | ||||
|         self.mycat = nnq.FloatFunctional() | ||||
|         self.myadd = nnq.FloatFunctional() | ||||
|         self.myadd_relu = nnq.FloatFunctional() | ||||
|         # Tracing doesnt work yet for c10 ops with scalar inputs | ||||
|         # https://github.com/pytorch/pytorch/issues/27097 | ||||
|         # self.my_scalar_add = nnq.FloatFunctional() | ||||
|         # self.my_scalar_mul = nnq.FloatFunctional() | ||||
|  | ||||
|     def forward(self, x): | ||||
|         y = self.mycat.cat([x, x, x]) | ||||
|         z = self.myadd.add(y, y) | ||||
|         return z | ||||
|         w = self.myadd_relu.add_relu(z, z) | ||||
|         # Tracing doesnt work yet for c10 ops with scalar inputs | ||||
|         # https://github.com/pytorch/pytorch/issues/27097 | ||||
|         # w = self.my_scalar_add.add_scalar(w, -0.5) | ||||
|         # w = self.my_scalar_mul.mul_scalar(w, 0.5) | ||||
|         return w | ||||
|  | ||||
|     @classmethod | ||||
|     def from_float(cls, mod): | ||||
|         new_mod = cls(quantized=True) | ||||
|         new_mod.mycat = new_mod.mycat.from_float(mod.mycat) | ||||
|         new_mod.myadd = new_mod.myadd.from_float(mod.myadd) | ||||
|         return new_mod | ||||
|  | ||||
| class ResNetBase(torch.nn.Module): | ||||
|     def __init__(self): | ||||
| @ -532,3 +562,62 @@ class ModelMultipleOps(torch.nn.Module): | ||||
|         out = out.view(-1, 3 * 2 * 2) | ||||
|         out = self.fc(out) | ||||
|         return out | ||||
|  | ||||
| # Model to ensure consistency of fake quant with true quant | ||||
| # Average pooling and mean operations are not modelled | ||||
| # accurately with fake-quant so this model does not | ||||
| # contain those operations | ||||
| class ModelMultipleOpsNoAvgPool(torch.nn.Module): | ||||
|     def __init__(self): | ||||
|         super(ModelMultipleOpsNoAvgPool, self).__init__() | ||||
|         norm_layer = nn.BatchNorm2d | ||||
|         inplanes = 3 | ||||
|         self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) | ||||
|         self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) | ||||
|         self.bn1 = norm_layer(inplanes) | ||||
|         self.relu1 = nn.ReLU() | ||||
|         self.relu2 = nn.ReLU() | ||||
|         self.skip_add = nn.quantized.FloatFunctional() | ||||
|         self.cat = nn.quantized.FloatFunctional() | ||||
|         self.maxpool = nn.MaxPool2d((4, 4)) | ||||
|         self.fc = nn.Linear(12, 6) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv1(x) | ||||
|         out = self.bn1(out) | ||||
|         out = self.relu1(out) | ||||
|         skip = self.conv2(x) | ||||
|         out = self.skip_add.add(out, skip) | ||||
|         out = self.relu2(out) | ||||
|         out = self.maxpool(out) | ||||
|         out = self.conv2(out) | ||||
|         out = torch.nn.functional.max_pool2d(out, 2, 2) | ||||
|         out = self.cat.cat([out, out]) | ||||
|         out = out.view(-1, 3 * 2 * 2) | ||||
|         out = self.fc(out) | ||||
|         return out | ||||
|  | ||||
| """Model to make sure that the observers are not inserted into custom modules. | ||||
| """ | ||||
| class ModelWithNoQconfigPropagation(nn.Module): | ||||
|     class ListOutModule(nn.Module): | ||||
|         def __init__(self): | ||||
|             super(ModelWithNoQconfigPropagation.ListOutModule, self).__init__() | ||||
|  | ||||
|         def forward(self, x): | ||||
|             # returns a list of tensors, not supported by observers | ||||
|             return [x] | ||||
|  | ||||
|     def __init__(self): | ||||
|         super(ModelWithNoQconfigPropagation, self).__init__() | ||||
|         self.fc1 = nn.Linear(5, 5).to(dtype=torch.float) | ||||
|         self.quant = QuantStub() | ||||
|         self.dequant = DeQuantStub() | ||||
|         self.no_quant_module = self.ListOutModule() | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.quant(x) | ||||
|         x = self.fc1(x) | ||||
|         x = self.dequant(x) | ||||
|         x = self.no_quant_module(x) | ||||
|         return x | ||||
|  | ||||
| @ -63,10 +63,36 @@ def _calculate_dynamic_qparams(X, dtype): | ||||
|         zero_point = min(qmax, zero_point) | ||||
|     return [float(scale), int(zero_point)] | ||||
|  | ||||
| def _calculate_dynamic_per_channel_qparams(X, dtype): | ||||
|     """Calculate the dynamic quantization parameters (scale, zero_point) | ||||
|     according to the min and max element of the tensor""" | ||||
|     if isinstance(X, torch.Tensor): | ||||
|         X = X.numpy() | ||||
|     qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max | ||||
|     n_levels = qmax - qmin | ||||
|     scale = np.zeros(X.shape[0], dtype=np.float64) | ||||
|     zero_point = np.zeros(X.shape[0], dtype=np.int64) | ||||
|     for i in range(zero_point.shape[0]): | ||||
|         min_val = X.min() | ||||
|         max_val = X.max() | ||||
|         if min_val == max_val: | ||||
|             scale[i] = 1.0 | ||||
|             zero_point[i] = 0 | ||||
|         else: | ||||
|             max_val = max(max_val, 0.0) | ||||
|             min_val = min(min_val, 0.0) | ||||
|             scale[i] = (max_val - min_val) / n_levels | ||||
|             scale[i] = max(scale[i], np.finfo(np.float32).eps) | ||||
|             zero_point[i] = qmin - round(min_val / scale[i]) | ||||
|             zero_point[i] = max(qmin, zero_point[i]) | ||||
|             zero_point[i] = min(qmax, zero_point[i]) | ||||
|  | ||||
|     return scale, zero_point | ||||
|  | ||||
| @contextmanager | ||||
| def enable_mobile_quantized_engine(): | ||||
| def override_quantized_engine(qengine): | ||||
|     previous = torch.backends.quantized.engine | ||||
|     torch.backends.quantized.engine = 'qnnpack' | ||||
|     torch.backends.quantized.engine = qengine | ||||
|     try: | ||||
|         yield | ||||
|     finally: | ||||
|  | ||||
| @ -185,7 +185,6 @@ TEST_WITH_ASAN = os.getenv('PYTORCH_TEST_WITH_ASAN', '0') == '1' | ||||
| TEST_WITH_TSAN = os.getenv('PYTORCH_TEST_WITH_TSAN', '0') == '1' | ||||
| TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1' | ||||
| TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1' | ||||
| TEST_WITH_QNNPACK = os.getenv('PYTORCH_TEST_WITH_QNNPACK', '0') == '1' | ||||
| # Enables tests that are slow to run (disabled by default) | ||||
| TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' | ||||
|  | ||||
|  | ||||
| @ -8,7 +8,7 @@ from hypothesis import strategies as st | ||||
| from hypothesis.extra import numpy as stnp | ||||
| from hypothesis.searchstrategy import SearchStrategy | ||||
|  | ||||
| from common_quantized import _calculate_dynamic_qparams | ||||
| from common_quantized import _calculate_dynamic_qparams, _calculate_dynamic_per_channel_qparams | ||||
|  | ||||
| # Setup for the hypothesis tests. | ||||
| # The tuples are (torch_quantized_dtype, zero_point_enforce), where the last | ||||
| @ -182,6 +182,38 @@ def tensor(draw, shapes=None, elements=None, qparams=None): | ||||
|         zp = enforced_zp | ||||
|     return X, (scale, zp, qparams[2]) | ||||
|  | ||||
| @st.composite | ||||
| def per_channel_tensor(draw, shapes=None, elements=None, qparams=None): | ||||
|     if isinstance(shapes, SearchStrategy): | ||||
|         _shape = draw(shapes) | ||||
|     else: | ||||
|         _shape = draw(st.sampled_from(shapes)) | ||||
|     if qparams is None: | ||||
|         if elements is None: | ||||
|             elements = st.floats(-1e6, 1e6, allow_nan=False) | ||||
|         X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape)) | ||||
|         assume(not (np.isnan(X).any() or np.isinf(X).any())) | ||||
|         return X, None | ||||
|     qparams = draw(qparams) | ||||
|     if elements is None: | ||||
|         min_value, max_value = _get_valid_min_max(qparams) | ||||
|         elements = st.floats(min_value, max_value, allow_infinity=False, | ||||
|                              allow_nan=False) | ||||
|     X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape)) | ||||
|     # Recompute the scale and zero_points according to the X statistics. | ||||
|     scale, zp = _calculate_dynamic_per_channel_qparams(X, qparams[2]) | ||||
|     enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None) | ||||
|     if enforced_zp is not None: | ||||
|         zp = enforced_zp | ||||
|     # Permute to model quantization along an axis | ||||
|     axis = int(np.random.randint(0, X.ndim, 1)) | ||||
|     permute_axes = np.arange(X.ndim) | ||||
|     permute_axes[0] = axis | ||||
|     permute_axes[axis] = 0 | ||||
|     X = np.transpose(X, permute_axes) | ||||
|  | ||||
|     return X, (scale, zp, axis, qparams[2]) | ||||
|  | ||||
| """Strategy for generating test cases for tensors used in Conv2D. | ||||
| The resulting tensors is in float32 format. | ||||
|  | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "0" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "0" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "0" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "0" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "0" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "0" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     output: "1" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "0" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "0" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "0" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "x" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "0" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     input: "0" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| ir_version: 4 | ||||
| producer_name: "pytorch" | ||||
| producer_version: "1.2" | ||||
| producer_version: "1.3" | ||||
| graph { | ||||
|   node { | ||||
|     output: "3" | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	