mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-30 03:34:56 +08:00
Compare commits
126 Commits
dev/joona/
...
pool-separ
| Author | SHA1 | Date | |
|---|---|---|---|
| 2ce56de80e | |||
| 34cd5614c5 | |||
| 15d7f6ac2b | |||
| 7b80b3fd13 | |||
| b0b1902739 | |||
| 1d29dc5d9c | |||
| fe518636a6 | |||
| 765dd32545 | |||
| b9ca9918ba | |||
| 003540fcb6 | |||
| a7f788143e | |||
| befb5bd52a | |||
| 3aa84775e7 | |||
| a060f3d272 | |||
| 2ce0b66db8 | |||
| f66a159db5 | |||
| 5a0ca65555 | |||
| 053025494f | |||
| 5964cb5eb1 | |||
| 4759922c5e | |||
| ca96d55322 | |||
| 5c6830ced0 | |||
| 574f4c507a | |||
| 5926b7a38f | |||
| fe51ce62ca | |||
| 481c345f49 | |||
| cf7021a0ee | |||
| 477f13c3fb | |||
| 6592086ac3 | |||
| 8ca985b365 | |||
| 9ccd601a14 | |||
| 3443627e07 | |||
| 86c6f71ddb | |||
| 4d073af58c | |||
| 741539a790 | |||
| a9adc9a9b6 | |||
| 658d17dfb5 | |||
| 3fe42d4d5d | |||
| d965fa2c4b | |||
| 1503b3f897 | |||
| 1a722f62c2 | |||
| 7e16cb99b6 | |||
| 459ce6c12a | |||
| 7ed377f577 | |||
| 56e1c236bf | |||
| d1f1ff8610 | |||
| 1b4749f748 | |||
| 5aea57d653 | |||
| 9d3b6ee4c1 | |||
| d1dd2c1fc8 | |||
| ab757dcddc | |||
| 754b758ea1 | |||
| 2489b6470b | |||
| cb5f31a4a1 | |||
| e7a40fb301 | |||
| ea17cd067d | |||
| b972435158 | |||
| e4adf5df39 | |||
| b8fad785d5 | |||
| 90deff6d59 | |||
| a2e2f908fd | |||
| ae0e8f0c73 | |||
| b03e4f53d2 | |||
| f7ecc091a0 | |||
| 064f4c18f9 | |||
| a4c828199e | |||
| b22f01fcb9 | |||
| 00e5cb3db3 | |||
| 480ae2dab8 | |||
| cfee9046b6 | |||
| 236b08cbf8 | |||
| 2327c9eedc | |||
| db26aeaec2 | |||
| 0cb48633d9 | |||
| fa8543454a | |||
| 4f4ecc583e | |||
| 7482eb217c | |||
| e5e06d9cab | |||
| 22b124335e | |||
| f7a5aa1d8d | |||
| 129a2976a8 | |||
| 6e107899da | |||
| b297e01f4b | |||
| 4863e5c843 | |||
| 71027b13b2 | |||
| 004dad48f7 | |||
| 55784be01b | |||
| a762dd1f67 | |||
| 181bfabb9e | |||
| 9839ec1383 | |||
| b47be23461 | |||
| 910d2f96af | |||
| 0ca91af6b8 | |||
| ebd3268538 | |||
| b992a665d1 | |||
| d5ddc5ab20 | |||
| 3e8bda4ad5 | |||
| 756fd80734 | |||
| d2f6c6df1d | |||
| 014726d9d3 | |||
| 881a598a1e | |||
| eaf2dee10e | |||
| 725bbb6b5f | |||
| f5e0806f34 | |||
| 82dc3457e0 | |||
| 781ba0ac9d | |||
| c2bc7e2827 | |||
| 72fee137dd | |||
| e0dece510b | |||
| 7412b33e91 | |||
| f2e8e41855 | |||
| f887bfffda | |||
| 03d01860fd | |||
| 1bd6bc7190 | |||
| f363a3f51a | |||
| c92ea3bc98 | |||
| 5e6e52e7c9 | |||
| dda2c7c8fc | |||
| 6a28cc826f | |||
| a54bf43baa | |||
| 534b66fe30 | |||
| bf0fe4f828 | |||
| 8749fe8439 | |||
| 47d6feff7c | |||
| 6ef1cbc191 | |||
| 533fc58453 |
@ -251,14 +251,6 @@ case "$tag" in
|
||||
UCC_COMMIT=${_UCC_COMMIT}
|
||||
INDUCTOR_BENCHMARKS=yes
|
||||
;;
|
||||
pytorch-linux-jammy-xpu-2024.0-py3)
|
||||
ANACONDA_PYTHON_VERSION=3.9
|
||||
GCC_VERSION=11
|
||||
VISION=yes
|
||||
XPU_VERSION=0.5
|
||||
NINJA_VERSION=1.9.0
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-xpu-2025.0-py3)
|
||||
ANACONDA_PYTHON_VERSION=3.9
|
||||
GCC_VERSION=11
|
||||
@ -267,6 +259,14 @@ case "$tag" in
|
||||
NINJA_VERSION=1.9.0
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-xpu-2025.1-py3)
|
||||
ANACONDA_PYTHON_VERSION=3.9
|
||||
GCC_VERSION=11
|
||||
VISION=yes
|
||||
XPU_VERSION=2025.1
|
||||
NINJA_VERSION=1.9.0
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks)
|
||||
ANACONDA_PYTHON_VERSION=3.9
|
||||
GCC_VERSION=11
|
||||
|
||||
@ -26,7 +26,7 @@ function install_ubuntu() {
|
||||
wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
|
||||
| gpg --dearmor > /usr/share/keyrings/oneapi-archive-keyring.gpg.gpg
|
||||
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg.gpg] \
|
||||
https://apt.repos.intel.com/${XPU_REPO_NAME} all main" \
|
||||
https://apt.repos.intel.com/oneapi all main" \
|
||||
| tee /etc/apt/sources.list.d/oneAPI.list
|
||||
|
||||
# Update the packages list and repository index
|
||||
@ -74,7 +74,7 @@ function install_rhel() {
|
||||
tee > /etc/yum.repos.d/oneAPI.repo << EOF
|
||||
[oneAPI]
|
||||
name=Intel for Pytorch GPU dev repository
|
||||
baseurl=https://yum.repos.intel.com/${XPU_REPO_NAME}
|
||||
baseurl=https://yum.repos.intel.com/oneapi
|
||||
enabled=1
|
||||
gpgcheck=1
|
||||
repo_gpgcheck=1
|
||||
@ -118,7 +118,7 @@ function install_sles() {
|
||||
https://repositories.intel.com/gpu/sles/${VERSION_SP}${XPU_DRIVER_VERSION}/unified/intel-gpu-${VERSION_SP}.repo
|
||||
rpm --import https://repositories.intel.com/gpu/intel-graphics.key
|
||||
# To add the online network network package repository for the Intel Support Packages
|
||||
zypper addrepo https://yum.repos.intel.com/${XPU_REPO_NAME} oneAPI
|
||||
zypper addrepo https://yum.repos.intel.com/oneapi oneAPI
|
||||
rpm --import https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB
|
||||
|
||||
# The xpu-smi packages
|
||||
@ -141,10 +141,10 @@ if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then
|
||||
XPU_DRIVER_VERSION=""
|
||||
fi
|
||||
|
||||
XPU_REPO_NAME="intel-for-pytorch-gpu-dev"
|
||||
XPU_PACKAGES="intel-for-pytorch-gpu-dev-0.5 intel-pti-dev-0.9"
|
||||
if [[ "$XPU_VERSION" == "2025.0" ]]; then
|
||||
XPU_REPO_NAME="oneapi"
|
||||
# Default use Intel® oneAPI Deep Learning Essentials 2025.0
|
||||
if [[ "$XPU_VERSION" == "2025.1" ]]; then
|
||||
XPU_PACKAGES="intel-deep-learning-essentials-2025.1"
|
||||
else
|
||||
XPU_PACKAGES="intel-deep-learning-essentials-2025.0"
|
||||
fi
|
||||
|
||||
|
||||
@ -174,6 +174,6 @@ ENV XPU_DRIVER_TYPE ROLLING
|
||||
RUN python3 -m pip install --upgrade pip && \
|
||||
python3 -mpip install cmake==3.28.4
|
||||
ADD ./common/install_xpu.sh install_xpu.sh
|
||||
ENV XPU_VERSION 2025.0
|
||||
ENV XPU_VERSION 2025.1
|
||||
RUN bash ./install_xpu.sh && rm install_xpu.sh
|
||||
RUN pushd /opt/_internal && tar -xJf static-libs-for-embedding-only.tar.xz && popd
|
||||
|
||||
@ -20,7 +20,11 @@ fi
|
||||
source /opt/intel/oneapi/compiler/latest/env/vars.sh
|
||||
source /opt/intel/oneapi/pti/latest/env/vars.sh
|
||||
source /opt/intel/oneapi/umf/latest/env/vars.sh
|
||||
source /opt/intel/oneapi/ccl/latest/env/vars.sh
|
||||
source /opt/intel/oneapi/mpi/latest/env/vars.sh
|
||||
export USE_STATIC_MKL=1
|
||||
export USE_ONEMKL=1
|
||||
export USE_XCCL=1
|
||||
|
||||
WHEELHOUSE_DIR="wheelhousexpu"
|
||||
LIBTORCH_HOUSE_DIR="libtorch_housexpu"
|
||||
|
||||
@ -99,30 +99,6 @@ if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
|
||||
export ACL_ROOT_DIR=/ComputeLibrary
|
||||
fi
|
||||
|
||||
if [[ "$BUILD_ENVIRONMENT" == *libtorch* ]]; then
|
||||
POSSIBLE_JAVA_HOMES=()
|
||||
POSSIBLE_JAVA_HOMES+=(/usr/local)
|
||||
POSSIBLE_JAVA_HOMES+=(/usr/lib/jvm/java-8-openjdk-amd64)
|
||||
POSSIBLE_JAVA_HOMES+=(/Library/Java/JavaVirtualMachines/*.jdk/Contents/Home)
|
||||
# Add the Windows-specific JNI
|
||||
POSSIBLE_JAVA_HOMES+=("$PWD/.circleci/windows-jni/")
|
||||
for JH in "${POSSIBLE_JAVA_HOMES[@]}" ; do
|
||||
if [[ -e "$JH/include/jni.h" ]] ; then
|
||||
# Skip if we're not on Windows but haven't found a JAVA_HOME
|
||||
if [[ "$JH" == "$PWD/.circleci/windows-jni/" && "$OSTYPE" != "msys" ]] ; then
|
||||
break
|
||||
fi
|
||||
echo "Found jni.h under $JH"
|
||||
export JAVA_HOME="$JH"
|
||||
export BUILD_JNI=ON
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ -z "$JAVA_HOME" ]; then
|
||||
echo "Did not find jni.h"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Use special scripts for Android builds
|
||||
if [[ "${BUILD_ENVIRONMENT}" == *-android* ]]; then
|
||||
export ANDROID_NDK=/opt/ndk
|
||||
|
||||
@ -232,7 +232,8 @@ test_torchbench_smoketest() {
|
||||
mkdir -p "$TEST_REPORTS_DIR"
|
||||
|
||||
local device=mps
|
||||
local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam pytorch_unet stable_diffusion_text_encoder moco speech_transformer)
|
||||
local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo)
|
||||
local hf_models=(GoogleFnet YituTechConvBert)
|
||||
|
||||
for backend in eager inductor; do
|
||||
|
||||
@ -253,6 +254,16 @@ test_torchbench_smoketest() {
|
||||
--output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_accuracy.csv" || true
|
||||
fi
|
||||
done
|
||||
for model in "${hf_models[@]}"; do
|
||||
if [ "$backend" == "inductor" ]; then
|
||||
PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/huggingface.py \
|
||||
--performance --only "$model" --backend "$backend" --inference --devices "$device" "$dtype_arg" \
|
||||
--output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_performance.csv" || true
|
||||
PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/huggingface.py \
|
||||
--accuracy --only "$model" --backend "$backend" --inference --devices "$device" "$dtype_arg" \
|
||||
--output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_accuracy.csv" || true
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
for dtype in notset amp; do
|
||||
|
||||
@ -37,6 +37,11 @@ call %INSTALLER_DIR%\activate_miniconda3.bat
|
||||
if errorlevel 1 goto fail
|
||||
if not errorlevel 0 goto fail
|
||||
|
||||
:: Update CMake
|
||||
call choco upgrade -y cmake --no-progress --installargs 'ADD_CMAKE_TO_PATH=System' --apply-install-arguments-to-dependencies --version=3.27.9
|
||||
if errorlevel 1 goto fail
|
||||
if not errorlevel 0 goto fail
|
||||
|
||||
call pip install mkl-include==2021.4.0 mkl-devel==2021.4.0
|
||||
if errorlevel 1 goto fail
|
||||
if not errorlevel 0 goto fail
|
||||
@ -88,7 +93,7 @@ set PATH=%CUDA_PATH%\bin;%CUDA_PATH%\libnvvp;%PATH%
|
||||
:cuda_build_end
|
||||
|
||||
set DISTUTILS_USE_SDK=1
|
||||
set PATH=%TMP_DIR_WIN%\bin;%PATH%
|
||||
set PATH=%TMP_DIR_WIN%\bin;C:\Program Files\CMake\bin;%PATH%
|
||||
|
||||
:: The latest Windows CUDA test is running on AWS G5 runner with A10G GPU
|
||||
if "%TORCH_CUDA_ARCH_LIST%" == "" set TORCH_CUDA_ARCH_LIST=8.6
|
||||
|
||||
@ -10,53 +10,23 @@ if not "%CUDA_VERSION%" == "xpu" (
|
||||
set SRC_DIR=%NIGHTLIES_PYTORCH_ROOT%
|
||||
if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build"
|
||||
|
||||
set XPU_INSTALL_MODE=%~1
|
||||
if "%XPU_INSTALL_MODE%"=="" goto xpu_bundle_install_start
|
||||
if "%XPU_INSTALL_MODE%"=="bundle" goto xpu_bundle_install_start
|
||||
if "%XPU_INSTALL_MODE%"=="driver" goto xpu_driver_install_start
|
||||
if "%XPU_INSTALL_MODE%"=="all" goto xpu_driver_install_start
|
||||
|
||||
:arg_error
|
||||
|
||||
echo Illegal XPU installation mode. The value can be "bundle"/"driver"/"all"
|
||||
echo If keep the value as space, will use default "bundle" mode
|
||||
exit /b 1
|
||||
|
||||
:xpu_driver_install_start
|
||||
:: TODO Need more testing for driver installation
|
||||
set XPU_DRIVER_LINK=https://downloadmirror.intel.com/830975/gfx_win_101.5972.exe
|
||||
curl -o xpu_driver.exe --retry 3 --retry-all-errors -k %XPU_DRIVER_LINK%
|
||||
echo "XPU Driver installing..."
|
||||
start /wait "Intel XPU Driver Installer" "xpu_driver.exe"
|
||||
if errorlevel 1 exit /b 1
|
||||
del xpu_driver.exe
|
||||
if "%XPU_INSTALL_MODE%"=="driver" goto xpu_install_end
|
||||
|
||||
:xpu_bundle_install_start
|
||||
|
||||
set XPU_BUNDLE_PARENT_DIR=C:\Program Files (x86)\Intel\oneAPI
|
||||
set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d1a91e2-e8b8-40a5-8c7f-5db768a6a60c/w_intel-for-pytorch-gpu-dev_p_0.5.3.37_offline.exe
|
||||
set XPU_BUNDLE_PRODUCT_NAME=intel.oneapi.win.intel-for-pytorch-gpu-dev.product
|
||||
set XPU_BUNDLE_VERSION=0.5.3+31
|
||||
set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d6d6c17-ca2d-4735-9331-99447e4a1280/intel-deep-learning-essentials-2025.0.1.28_offline.exe
|
||||
set XPU_BUNDLE_PRODUCT_NAME=intel.oneapi.win.deep-learning-essentials.product
|
||||
set XPU_BUNDLE_VERSION=2025.0.1+20
|
||||
set XPU_BUNDLE_INSTALLED=0
|
||||
set XPU_BUNDLE_UNINSTALL=0
|
||||
set XPU_EXTRA_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d1a91e2-e8b8-40a5-8c7f-5db768a6a60c/w_intel-pti-dev_p_0.9.0.37_offline.exe
|
||||
set XPU_EXTRA_PRODUCT_NAME=intel.oneapi.win.intel-pti-dev.product
|
||||
set XPU_EXTRA_VERSION=0.9.0+36
|
||||
set XPU_EXTRA_URL=NULL
|
||||
set XPU_EXTRA_PRODUCT_NAME=intel.oneapi.win.compiler.product
|
||||
set XPU_EXTRA_VERSION=2025.0.1+1226
|
||||
set XPU_EXTRA_INSTALLED=0
|
||||
set XPU_EXTRA_UNINSTALL=0
|
||||
|
||||
if not [%XPU_VERSION%]==[] if [%XPU_VERSION%]==[2025.0] (
|
||||
set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d6d6c17-ca2d-4735-9331-99447e4a1280/intel-deep-learning-essentials-2025.0.1.28_offline.exe
|
||||
set XPU_BUNDLE_PRODUCT_NAME=intel.oneapi.win.deep-learning-essentials.product
|
||||
set XPU_BUNDLE_VERSION=2025.0.1+20
|
||||
set XPU_BUNDLE_INSTALLED=0
|
||||
set XPU_BUNDLE_UNINSTALL=0
|
||||
set XPU_EXTRA_URL=NULL
|
||||
set XPU_EXTRA_PRODUCT_NAME=intel.oneapi.win.compiler.product
|
||||
set XPU_EXTRA_VERSION=2025.0.1+1226
|
||||
set XPU_EXTRA_INSTALLED=0
|
||||
set XPU_EXTRA_UNINSTALL=0
|
||||
if not [%XPU_VERSION%]==[] if [%XPU_VERSION%]==[2025.1] (
|
||||
set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/1a9fff3d-04c2-4d77-8861-3d86c774b66f/intel-deep-learning-essentials-2025.1.1.26_offline.exe
|
||||
set XPU_BUNDLE_VERSION=2025.1.1+23
|
||||
)
|
||||
|
||||
:: Check if XPU bundle is target version or already installed
|
||||
|
||||
@ -26,6 +26,7 @@ set VS2022INSTALLDIR=%VS15INSTALLDIR%
|
||||
set XPU_BUNDLE_ROOT=%ProgramFiles(x86)%\Intel\oneAPI
|
||||
call "%XPU_BUNDLE_ROOT%\compiler\latest\env\vars.bat"
|
||||
call "%XPU_BUNDLE_ROOT%\ocloc\latest\env\vars.bat"
|
||||
set USE_ONEMKL=1
|
||||
IF ERRORLEVEL 1 goto :eof
|
||||
|
||||
if exist "%NIGHTLIES_PYTORCH_ROOT%" cd %NIGHTLIES_PYTORCH_ROOT%\..
|
||||
|
||||
@ -15,7 +15,7 @@ fi
|
||||
if [[ "$DESIRED_CUDA" == 'xpu' ]]; then
|
||||
export VC_YEAR=2022
|
||||
export USE_SCCACHE=0
|
||||
export XPU_VERSION=2025.0
|
||||
export XPU_VERSION=2025.1
|
||||
export XPU_ENABLE_KINETO=1
|
||||
fi
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ export VC_YEAR=2019
|
||||
|
||||
if [[ "$DESIRED_CUDA" == 'xpu' ]]; then
|
||||
export VC_YEAR=2022
|
||||
export XPU_VERSION=2025.0
|
||||
export XPU_VERSION=2025.1
|
||||
fi
|
||||
|
||||
pushd "$PYTORCH_ROOT/.ci/pytorch/"
|
||||
|
||||
@ -44,7 +44,7 @@ runs:
|
||||
retry_wait_seconds: 30
|
||||
command: |
|
||||
set -eu
|
||||
python3 -m pip install python-dateutil==2.8.2 boto3==1.35.42 pandas==2.1.3
|
||||
python3 -m pip install python-dateutil==2.8.2 boto3==1.35.42 pandas==2.1.3 dataclasses_json==0.6.7
|
||||
- name: Upload utilizatoin stats to s3
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
1
.github/pytorch-probot.yml
vendored
1
.github/pytorch-probot.yml
vendored
@ -25,7 +25,6 @@ ciflow_push_tags:
|
||||
- ciflow/unstable
|
||||
- ciflow/xpu
|
||||
- ciflow/torchbench
|
||||
- ciflow/autoformat
|
||||
- ciflow/op-benchmark
|
||||
- ciflow/pull
|
||||
- ciflow/h100
|
||||
|
||||
31
.github/scripts/generate_binary_build_matrix.py
vendored
31
.github/scripts/generate_binary_build_matrix.py
vendored
@ -88,17 +88,26 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
|
||||
"nvidia-cufile-cu12==1.13.0.11; platform_system == 'Linux' and platform_machine == 'x86_64'"
|
||||
),
|
||||
"xpu": (
|
||||
"intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | "
|
||||
"intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | "
|
||||
"intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | "
|
||||
"intel-sycl-rt==2025.0.4; platform_system == 'Linux' | "
|
||||
"intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | "
|
||||
"intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | "
|
||||
"intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | "
|
||||
"intel-sycl-rt==2025.0.5; platform_system == 'Windows' | "
|
||||
"tcmlib==1.2.0 | "
|
||||
"umf==0.9.1 | "
|
||||
"intel-pti==0.10.1"
|
||||
"intel-cmplr-lib-rt==2025.1.1 | "
|
||||
"intel-cmplr-lib-ur==2025.1.1 | "
|
||||
"intel-cmplr-lic-rt==2025.1.1 | "
|
||||
"intel-sycl-rt==2025.1.1 | "
|
||||
"oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||
"oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||
"impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||
"onemkl-sycl-blas==2025.1.0 | "
|
||||
"onemkl-sycl-dft==2025.1.0 | "
|
||||
"onemkl-sycl-lapack==2025.1.0 | "
|
||||
"onemkl-sycl-rng==2025.1.0 | "
|
||||
"onemkl-sycl-sparse==2025.1.0 | "
|
||||
"dpcpp-cpp-rt==2025.1.1 | "
|
||||
"intel-opencl-rt==2025.1.1 | "
|
||||
"mkl==2025.1.0 | "
|
||||
"intel-openmp==2025.1.1 | "
|
||||
"tbb==2022.1.0 | "
|
||||
"tcmlib==1.3.0 | "
|
||||
"umf==0.10.0 | "
|
||||
"intel-pti==0.12.0"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
15
.github/scripts/trymerge.py
vendored
15
.github/scripts/trymerge.py
vendored
@ -1938,6 +1938,7 @@ def get_ghstack_dependent_prs(
|
||||
|
||||
def do_revert_prs(
|
||||
repo: GitRepo,
|
||||
original_pr: GitHubPR,
|
||||
shas_and_prs: list[tuple[str, GitHubPR]],
|
||||
*,
|
||||
author_login: str,
|
||||
@ -1959,9 +1960,16 @@ def do_revert_prs(
|
||||
|
||||
# Comment/reopen PRs
|
||||
for commit_sha, pr in shas_and_prs:
|
||||
revert_message = (
|
||||
f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
|
||||
)
|
||||
revert_message = ""
|
||||
if pr.pr_num == original_pr.pr_num:
|
||||
revert_message += (
|
||||
f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
|
||||
)
|
||||
else:
|
||||
revert_message += (
|
||||
f"@{pr.get_pr_creator_login()} your PR has been reverted as part of the stack under "
|
||||
f"#{original_pr.pr_num}.\n"
|
||||
)
|
||||
if (
|
||||
pr.has_internal_changes()
|
||||
and not pr.has_no_connected_diff()
|
||||
@ -2013,6 +2021,7 @@ def try_revert(
|
||||
|
||||
do_revert_prs(
|
||||
repo,
|
||||
pr,
|
||||
shas_and_prs,
|
||||
author_login=author_login,
|
||||
extra_msg=extra_msg,
|
||||
|
||||
2
.github/workflows/docker-builds.yml
vendored
2
.github/workflows/docker-builds.yml
vendored
@ -67,8 +67,8 @@ jobs:
|
||||
pytorch-linux-jammy-py3.9-gcc11,
|
||||
pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks,
|
||||
pytorch-linux-jammy-py3.12-halide,
|
||||
pytorch-linux-jammy-xpu-2024.0-py3,
|
||||
pytorch-linux-jammy-xpu-2025.0-py3,
|
||||
pytorch-linux-jammy-xpu-2025.1-py3,
|
||||
pytorch-linux-jammy-py3-clang15-asan,
|
||||
pytorch-linux-jammy-py3-clang18-asan,
|
||||
pytorch-linux-focal-py3-clang10-onnx,
|
||||
|
||||
12
.github/workflows/generated-linux-binary-manywheel-nightly.yml
generated
vendored
12
.github/workflows/generated-linux-binary-manywheel-nightly.yml
generated
vendored
@ -565,7 +565,7 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build_name: manywheel-py3_9-xpu
|
||||
build_environment: linux-binary-manywheel
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
manywheel-py3_9-xpu-test: # Testing
|
||||
@ -1178,7 +1178,7 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build_name: manywheel-py3_10-xpu
|
||||
build_environment: linux-binary-manywheel
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
manywheel-py3_10-xpu-test: # Testing
|
||||
@ -1859,7 +1859,7 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build_name: manywheel-py3_11-xpu
|
||||
build_environment: linux-binary-manywheel
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
manywheel-py3_11-xpu-test: # Testing
|
||||
@ -2472,7 +2472,7 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build_name: manywheel-py3_12-xpu
|
||||
build_environment: linux-binary-manywheel
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
manywheel-py3_12-xpu-test: # Testing
|
||||
@ -3085,7 +3085,7 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build_name: manywheel-py3_13-xpu
|
||||
build_environment: linux-binary-manywheel
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
manywheel-py3_13-xpu-test: # Testing
|
||||
@ -3698,7 +3698,7 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build_name: manywheel-py3_13t-xpu
|
||||
build_environment: linux-binary-manywheel
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
manywheel-py3_13t-xpu-test: # Testing
|
||||
|
||||
12
.github/workflows/generated-windows-binary-wheel-nightly.yml
generated
vendored
12
.github/workflows/generated-windows-binary-wheel-nightly.yml
generated
vendored
@ -1004,7 +1004,7 @@ jobs:
|
||||
GPU_ARCH_TYPE: xpu
|
||||
SKIP_ALL_TESTS: 1
|
||||
DESIRED_PYTHON: "3.9"
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
@ -2189,7 +2189,7 @@ jobs:
|
||||
GPU_ARCH_TYPE: xpu
|
||||
SKIP_ALL_TESTS: 1
|
||||
DESIRED_PYTHON: "3.10"
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
@ -3374,7 +3374,7 @@ jobs:
|
||||
GPU_ARCH_TYPE: xpu
|
||||
SKIP_ALL_TESTS: 1
|
||||
DESIRED_PYTHON: "3.11"
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
@ -4559,7 +4559,7 @@ jobs:
|
||||
GPU_ARCH_TYPE: xpu
|
||||
SKIP_ALL_TESTS: 1
|
||||
DESIRED_PYTHON: "3.12"
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
@ -5744,7 +5744,7 @@ jobs:
|
||||
GPU_ARCH_TYPE: xpu
|
||||
SKIP_ALL_TESTS: 1
|
||||
DESIRED_PYTHON: "3.13"
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
@ -6929,7 +6929,7 @@ jobs:
|
||||
GPU_ARCH_TYPE: xpu
|
||||
SKIP_ALL_TESTS: 1
|
||||
DESIRED_PYTHON: "3.13t"
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
|
||||
2
.github/workflows/inductor-perf-compare.yml
vendored
2
.github/workflows/inductor-perf-compare.yml
vendored
@ -52,7 +52,7 @@ jobs:
|
||||
docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.test-matrix }}
|
||||
# disable monitor in perf tests for more investigation
|
||||
disable-monitor: true
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
|
||||
@ -129,7 +129,9 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-aarch64-py3_10-inductor-build.outputs.test-matrix }}
|
||||
timeout-minutes: 720
|
||||
# disable monitor in perf tests for more investigation
|
||||
disable-monitor: true
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
|
||||
|
||||
|
||||
@ -122,7 +122,7 @@ jobs:
|
||||
test-matrix: ${{ needs.build.outputs.test-matrix }}
|
||||
timeout-minutes: 720
|
||||
# disable monitor in perf tests, next step is to enable it
|
||||
disable-monitor: true
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
@ -139,7 +139,7 @@ jobs:
|
||||
test-matrix: ${{ needs.build.outputs.test-matrix }}
|
||||
timeout-minutes: 1440
|
||||
# disable monitor in perf tests, next step is to enable it
|
||||
disable-monitor: true
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
@ -156,7 +156,7 @@ jobs:
|
||||
test-matrix: ${{ needs.build.outputs.test-matrix }}
|
||||
timeout-minutes: 720
|
||||
# disable monitor in perf tests for more investigation
|
||||
disable-monitor: true
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
|
||||
@ -103,7 +103,7 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }}
|
||||
timeout-minutes: 720
|
||||
# disable monitor in perf tests
|
||||
disable-monitor: true
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
@ -121,7 +121,7 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }}
|
||||
timeout-minutes: 720
|
||||
# disable monitor in perf tests
|
||||
disable-monitor: true
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
|
||||
@ -123,8 +123,7 @@ jobs:
|
||||
docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.test-matrix }}
|
||||
timeout-minutes: 720
|
||||
# disable monitor in perf tests, next step is to enable it
|
||||
disable-monitor: true
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
@ -141,7 +140,7 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.test-matrix }}
|
||||
timeout-minutes: 1440
|
||||
# disable monitor in perf tests, next step is to enable it
|
||||
disable-monitor: true
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
@ -157,8 +156,7 @@ jobs:
|
||||
docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.test-matrix }}
|
||||
timeout-minutes: 720
|
||||
# disable monitor in perf tests, next step is to enable it
|
||||
disable-monitor: true
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
|
||||
4
.github/workflows/inductor-periodic.yml
vendored
4
.github/workflows/inductor-periodic.yml
vendored
@ -133,10 +133,6 @@ jobs:
|
||||
build-environment: linux-focal-cuda12.6-py3.10-gcc9-sm80
|
||||
docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build-gcp.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build-gcp.outputs.test-matrix }}
|
||||
# disable monitor in perf tests, next step is to enable it
|
||||
disable-monitor: true
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cpu-py3_9-gcc11-periodic-dynamo-benchmarks-build:
|
||||
|
||||
11
.github/workflows/lint-autoformat.yml
vendored
11
.github/workflows/lint-autoformat.yml
vendored
@ -1,10 +1,8 @@
|
||||
name: Apply lint suggestions
|
||||
|
||||
on:
|
||||
|
||||
push:
|
||||
tags:
|
||||
- ciflow/autoformat/*
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, labeled, unlabeled]
|
||||
|
||||
jobs:
|
||||
lintrunner-autoformat:
|
||||
@ -12,7 +10,7 @@ jobs:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: lf.linux.2xlarge
|
||||
if: ${{ github.repository_owner == 'pytorch' && github.event.pull_request.user.login != 'ezyang' && github.event.pull_request.user.login != 'malfet' && !startsWith(github.head_ref, 'export-') }}
|
||||
if: ${{ github.repository_owner == 'pytorch' && contains(github.event.pull_request.labels.*.name, 'autoformat') }}
|
||||
steps:
|
||||
- name: Checkout pytorch
|
||||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
||||
@ -21,12 +19,11 @@ jobs:
|
||||
fetch-depth: 0
|
||||
- name: Run lintrunner (nonretryable)
|
||||
continue-on-error: true
|
||||
# we can't run all files here because only changes around where the diff are shown in the PR UI
|
||||
run: |
|
||||
set -ex
|
||||
python3 -m venv /tmp/venv
|
||||
source /tmp/venv/bin/activate
|
||||
export ADDITIONAL_LINTRUNNER_ARGS="format"
|
||||
export ADDITIONAL_LINTRUNNER_ARGS="format --all-files"
|
||||
bash .github/scripts/lintrunner.sh
|
||||
- name: Check for changes
|
||||
id: git-check
|
||||
|
||||
10
.github/workflows/pull.yml
vendored
10
.github/workflows/pull.yml
vendored
@ -532,15 +532,15 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-xpu-2025_0-py3_9-build:
|
||||
name: linux-jammy-xpu-2025.0-py3.9
|
||||
linux-jammy-xpu-2025_1-py3_9-build:
|
||||
name: linux-jammy-xpu-2025.1-py3.9
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
sync-tag: linux-xpu-2025-0-build
|
||||
sync-tag: linux-xpu-2025-1-build
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-jammy-xpu-2025.0-py3.9
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-2025.0-py3
|
||||
build-environment: linux-jammy-xpu-2025.1-py3.9
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-2025.1-py3
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" },
|
||||
|
||||
1
.github/workflows/trunk.yml
vendored
1
.github/workflows/trunk.yml
vendored
@ -138,6 +138,7 @@ jobs:
|
||||
build-environment: win-vs2022-cpu-py3
|
||||
cuda-version: cpu
|
||||
test-matrix: ${{ needs.win-vs2022-cpu-py3-build.outputs.test-matrix }}
|
||||
disable-monitor: false
|
||||
secrets: inherit
|
||||
|
||||
win-vs2022-cuda12_6-py3-build:
|
||||
|
||||
20
.github/workflows/upload-test-stats.yml
vendored
20
.github/workflows/upload-test-stats.yml
vendored
@ -2,7 +2,25 @@ name: Upload test stats
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: [pull, trunk, periodic, periodic-rocm-mi300, inductor, unstable, slow, unstable-periodic, inductor-periodic, rocm, rocm-mi300, inductor-micro-benchmark, inductor-micro-benchmark-x86, inductor-cu124, inductor-rocm, inductor-rocm-mi300, mac-mps]
|
||||
workflows:
|
||||
- pull
|
||||
- trunk
|
||||
- periodic
|
||||
- periodic-rocm-mi300
|
||||
- inductor
|
||||
- unstable
|
||||
- slow
|
||||
- unstable-periodic
|
||||
- inductor-periodic
|
||||
- rocm
|
||||
- rocm-mi300
|
||||
- inductor-micro-benchmark
|
||||
- inductor-micro-benchmark-x86
|
||||
- inductor-cu124
|
||||
- inductor-rocm
|
||||
- inductor-rocm-mi300
|
||||
- mac-mps
|
||||
- linux-aarch64
|
||||
types:
|
||||
- completed
|
||||
|
||||
|
||||
45
.github/workflows/xpu.yml
vendored
45
.github/workflows/xpu.yml
vendored
@ -43,17 +43,38 @@ jobs:
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-xpu-2025_0-py3_9-test:
|
||||
name: linux-jammy-xpu-2025.0-py3.9
|
||||
linux-jammy-xpu-2025_1-py3_9-build:
|
||||
name: linux-jammy-xpu-2025.1-py3.9
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
sync-tag: linux-xpu-2025-1-build
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-jammy-xpu-2025.1-py3.9
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-2025.1-py3
|
||||
runner: linux.12xlarge
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
|
||||
{ config: "default", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
|
||||
{ config: "default", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
|
||||
{ config: "default", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
|
||||
{ config: "default", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
|
||||
{ config: "default", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-xpu-2025_1-py3_9-test:
|
||||
name: linux-jammy-xpu-2025.1-py3.9
|
||||
uses: ./.github/workflows/_xpu-test.yml
|
||||
needs: linux-jammy-xpu-2025_0-py3_9-build
|
||||
needs: linux-jammy-xpu-2025_1-py3_9-build
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
with:
|
||||
build-environment: linux-jammy-xpu-2025.0-py3.9
|
||||
docker-image: ${{ needs.linux-jammy-xpu-2025_0-py3_9-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-xpu-2025_0-py3_9-build.outputs.test-matrix }}
|
||||
build-environment: linux-jammy-xpu-2025.1-py3.9
|
||||
docker-image: ${{ needs.linux-jammy-xpu-2025_1-py3_9-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-xpu-2025_1-py3_9-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
windows-xpu-2025_0-build:
|
||||
@ -67,3 +88,15 @@ jobs:
|
||||
xpu-version: '2025.0'
|
||||
vc-year: '2022'
|
||||
secrets: inherit
|
||||
|
||||
windows-xpu-2025_1-build:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: win-vs2022-xpu-2025_1-py3
|
||||
uses: ./.github/workflows/_win-build.yml
|
||||
with:
|
||||
build-environment: win-vs2022-xpu-py3
|
||||
cuda-version: cpu
|
||||
use-xpu: true
|
||||
xpu-version: '2025.1'
|
||||
vc-year: '2022'
|
||||
secrets: inherit
|
||||
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@ -212,15 +212,6 @@ docs/source/scripts/lr_scheduler_images/
|
||||
# Compiled MATLAB
|
||||
*.mex*
|
||||
|
||||
# IPython notebook checkpoints
|
||||
.ipynb_checkpoints
|
||||
|
||||
# Editor temporaries
|
||||
*.swn
|
||||
*.swo
|
||||
*.swp
|
||||
*~
|
||||
|
||||
# NFS handle files
|
||||
**/.nfs*
|
||||
|
||||
|
||||
@ -1719,3 +1719,15 @@ include_patterns = [
|
||||
'torch/_dynamo/**',
|
||||
]
|
||||
is_formatter = false
|
||||
|
||||
[[linter]]
|
||||
code = 'TEST_DEVICE_BIAS'
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/test_device_bias_linter.py',
|
||||
'--',
|
||||
'@{{PATHSFILE}}',
|
||||
]
|
||||
include_patterns = [
|
||||
'test/**/test_*.py',
|
||||
]
|
||||
|
||||
@ -222,7 +222,6 @@ option(
|
||||
BUILD_MOBILE_TEST
|
||||
"Build C++ test binaries for mobile (ARM) targets(need gtest and gbenchmark)"
|
||||
OFF)
|
||||
option(BUILD_JNI "Build JNI bindings" OFF)
|
||||
option(BUILD_MOBILE_AUTOGRAD
|
||||
"Build autograd function in mobile build (in development)" OFF)
|
||||
cmake_dependent_option(INSTALL_TEST "Install test binaries if BUILD_TEST is on"
|
||||
@ -259,6 +258,8 @@ option(USE_NATIVE_ARCH "Use -march=native" OFF)
|
||||
cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF)
|
||||
cmake_dependent_option(USE_NCCL "Use NCCL" ON
|
||||
"USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
|
||||
cmake_dependent_option(USE_XCCL "Use XCCL" OFF
|
||||
"USE_XPU;UNIX;NOT APPLE" OFF)
|
||||
cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF)
|
||||
cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF)
|
||||
cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL"
|
||||
@ -338,6 +339,8 @@ cmake_dependent_option(
|
||||
USE_C10D_GLOO "USE C10D GLOO" ON "USE_DISTRIBUTED;USE_GLOO" OFF)
|
||||
cmake_dependent_option(
|
||||
USE_C10D_NCCL "USE C10D NCCL" ON "USE_DISTRIBUTED;USE_NCCL" OFF)
|
||||
cmake_dependent_option(
|
||||
USE_C10D_XCCL "USE C10D XCCL" ON "USE_DISTRIBUTED;USE_XCCL" OFF)
|
||||
cmake_dependent_option(
|
||||
USE_C10D_MPI "USE C10D MPI" ON "USE_DISTRIBUTED;USE_MPI" OFF)
|
||||
cmake_dependent_option(
|
||||
@ -1346,16 +1349,6 @@ if(BUILD_BINARY)
|
||||
add_subdirectory(binaries)
|
||||
endif()
|
||||
|
||||
# ---[ JNI
|
||||
if(BUILD_JNI)
|
||||
if(NOT MSVC)
|
||||
string(APPEND CMAKE_CXX_FLAGS " -Wno-unused-variable")
|
||||
endif()
|
||||
set(BUILD_LIBTORCH_WITH_JNI 1)
|
||||
set(FBJNI_SKIP_TESTS 1)
|
||||
add_subdirectory(android/pytorch_android)
|
||||
endif()
|
||||
|
||||
include(cmake/Summary.cmake)
|
||||
caffe2_print_configuration_summary()
|
||||
|
||||
|
||||
@ -1,240 +1,5 @@
|
||||
# Android
|
||||
# Android Demo App
|
||||
|
||||
## Demo applications and tutorials
|
||||
Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch).
|
||||
|
||||
Demo applications with code walk-through can be find in [this github repo](https://github.com/pytorch/android-demo-app).
|
||||
|
||||
## Publishing
|
||||
|
||||
##### Release
|
||||
Release artifacts are published to jcenter:
|
||||
|
||||
```groovy
|
||||
repositories {
|
||||
jcenter()
|
||||
}
|
||||
|
||||
# lite interpreter build
|
||||
dependencies {
|
||||
implementation 'org.pytorch:pytorch_android_lite:1.10.0'
|
||||
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.10.0'
|
||||
}
|
||||
|
||||
# full jit build
|
||||
dependencies {
|
||||
implementation 'org.pytorch:pytorch_android:1.10.0'
|
||||
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
|
||||
}
|
||||
```
|
||||
|
||||
##### Nightly
|
||||
|
||||
Nightly(snapshots) builds are published every night from `master` branch to [nexus sonatype snapshots repository](https://oss.sonatype.org/#nexus-search;quick~pytorch_android)
|
||||
|
||||
To use them repository must be specified explicitly:
|
||||
```groovy
|
||||
repositories {
|
||||
maven {
|
||||
url "https://oss.sonatype.org/content/repositories/snapshots"
|
||||
}
|
||||
}
|
||||
|
||||
# lite interpreter build
|
||||
dependencies {
|
||||
...
|
||||
implementation 'org.pytorch:pytorch_android_lite:1.12.0-SNAPSHOT'
|
||||
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.12.0-SNAPSHOT'
|
||||
...
|
||||
}
|
||||
|
||||
# full jit build
|
||||
dependencies {
|
||||
...
|
||||
implementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT'
|
||||
implementation 'org.pytorch:pytorch_android_torchvision:1.12.0-SNAPSHOT'
|
||||
...
|
||||
}
|
||||
```
|
||||
The current nightly(snapshots) version is the value of `VERSION_NAME` in `gradle.properties` in current folder, at this moment it is `1.8.0-SNAPSHOT`.
|
||||
|
||||
## Building PyTorch Android from Source
|
||||
|
||||
In some cases you might want to use a local build of pytorch android, for example you may build custom libtorch binary with another set of operators or to make local changes.
|
||||
|
||||
For this you can use `./scripts/build_pytorch_android.sh` script.
|
||||
```bash
|
||||
git clone https://github.com/pytorch/pytorch.git
|
||||
cd pytorch
|
||||
git submodule update --init --recursive
|
||||
bash ./scripts/build_pytorch_android.sh
|
||||
```
|
||||
|
||||
The workflow contains several steps:
|
||||
|
||||
1\. Build libtorch for android for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64)
|
||||
|
||||
2\. Create symbolic links to the results of those builds:
|
||||
`android/pytorch_android/src/main/jniLibs/${abi}` to the directory with output libraries
|
||||
`android/pytorch_android/src/main/cpp/libtorch_include/${abi}` to the directory with headers. These directories are used to build `libpytorch.so` library that will be loaded on android device.
|
||||
|
||||
3\. And finally run `gradle` in `android/pytorch_android` directory with task `assembleRelease`
|
||||
|
||||
Script requires that Android SDK, Android NDK and gradle are installed.
|
||||
They are specified as environment variables:
|
||||
|
||||
`ANDROID_HOME` - path to [Android SDK](https://developer.android.com/studio/command-line/sdkmanager.html)
|
||||
|
||||
`ANDROID_NDK` - path to [Android NDK](https://developer.android.com/studio/projects/install-ndk). It's recommended to use NDK 21.x.
|
||||
|
||||
`GRADLE_HOME` - path to [gradle](https://gradle.org/releases/)
|
||||
|
||||
|
||||
After successful build you should see the result as aar file:
|
||||
|
||||
```bash
|
||||
$ find pytorch_android/build/ -type f -name *aar
|
||||
pytorch_android/build/outputs/aar/pytorch_android.aar
|
||||
pytorch_android_torchvision/build/outputs/aar/pytorch_android.aar
|
||||
```
|
||||
|
||||
It can be used directly in android projects, as a gradle dependency:
|
||||
```groovy
|
||||
allprojects {
|
||||
repositories {
|
||||
flatDir {
|
||||
dirs 'libs'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation(name:'pytorch_android', ext:'aar')
|
||||
implementation(name:'pytorch_android_torchvision', ext:'aar')
|
||||
...
|
||||
implementation 'com.facebook.soloader:nativeloader:0.10.5'
|
||||
implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
|
||||
}
|
||||
```
|
||||
We also have to add all transitive dependencies of our aars.
|
||||
As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.5'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them.
|
||||
(In case of using maven dependencies they are added automatically from `pom.xml`).
|
||||
|
||||
You can check out [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses aars directly.
|
||||
|
||||
## Linking to prebuilt libtorch library from gradle dependency
|
||||
|
||||
In some cases, you may want to use libtorch from your android native build.
|
||||
You can do it without building libtorch android, using native libraries from PyTorch android gradle dependency.
|
||||
For that, you will need to add the next lines to your gradle build.
|
||||
```groovy
|
||||
android {
|
||||
...
|
||||
configurations {
|
||||
extractForNativeBuild
|
||||
}
|
||||
...
|
||||
compileOptions {
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
arguments "-DANDROID_STL=c++_shared"
|
||||
}
|
||||
}
|
||||
}
|
||||
...
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path "CMakeLists.txt"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
extractForNativeBuild('org.pytorch:pytorch_android:1.10.0')
|
||||
}
|
||||
|
||||
task extractAARForNativeBuild {
|
||||
doLast {
|
||||
configurations.extractForNativeBuild.files.each {
|
||||
def file = it.absoluteFile
|
||||
copy {
|
||||
from zipTree(file)
|
||||
into "$buildDir/$file.name"
|
||||
include "headers/**"
|
||||
include "jni/**"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tasks.whenTaskAdded { task ->
|
||||
if (task.name.contains('externalNativeBuild')) {
|
||||
task.dependsOn(extractAARForNativeBuild)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
pytorch_android aar contains headers to link in `headers` folder and native libraries in `jni/$ANDROID_ABI/`.
|
||||
As PyTorch native libraries use `ANDROID_STL` - we should use `ANDROID_STL=c++_shared` to have only one loaded binary of STL.
|
||||
|
||||
The added task will unpack them to gradle build directory.
|
||||
|
||||
In your native build you can link to them adding these lines to your CMakeLists.txt:
|
||||
|
||||
|
||||
```cmake
|
||||
# Relative path of gradle build directory to CMakeLists.txt
|
||||
set(build_DIR ${CMAKE_SOURCE_DIR}/build)
|
||||
|
||||
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
|
||||
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
|
||||
|
||||
set(BUILD_SUBDIR ${ANDROID_ABI})
|
||||
target_include_directories(${PROJECT_NAME} PRIVATE
|
||||
${PYTORCH_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
find_library(PYTORCH_LIBRARY pytorch_jni
|
||||
PATHS ${PYTORCH_LINK_DIRS}
|
||||
NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
find_library(FBJNI_LIBRARY fbjni
|
||||
PATHS ${PYTORCH_LINK_DIRS}
|
||||
NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
target_link_libraries(${PROJECT_NAME}
|
||||
${PYTORCH_LIBRARY})
|
||||
${FBJNI_LIBRARY})
|
||||
|
||||
```
|
||||
If your CMakeLists.txt file is located in the same directory as your build.gradle, `set(build_DIR ${CMAKE_SOURCE_DIR}/build)` should work for you. But if you have another location of it, you may need to change it.
|
||||
|
||||
After that, you can use libtorch C++ API from your native code.
|
||||
```cpp
|
||||
#include <string>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <torch/script.h>
|
||||
namespace pytorch_testapp_jni {
|
||||
namespace {
|
||||
struct JITCallGuard {
|
||||
c10::InferenceMode guard;
|
||||
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
|
||||
};
|
||||
}
|
||||
|
||||
void loadAndForwardModel(const std::string& modelPath) {
|
||||
JITCallGuard guard;
|
||||
torch::jit::Module module = torch::jit::load(modelPath);
|
||||
module.eval();
|
||||
torch::Tensor t = torch::randn({1, 3, 224, 224});
|
||||
c10::IValue t_out = module.forward({t});
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
To load torchscript model for mobile we need some special setup which is placed in `struct JITCallGuard` in this example. It may change in future, you can track the latest changes keeping an eye in our [pytorch android jni code]([https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp#L28)
|
||||
|
||||
[Example of linking to libtorch from aar](https://github.com/pytorch/pytorch/tree/master/android/test_app)
|
||||
|
||||
## PyTorch Android API Javadoc
|
||||
|
||||
You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/javadoc/).
|
||||
Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions.
|
||||
|
||||
@ -1,40 +0,0 @@
|
||||
allprojects {
|
||||
buildscript {
|
||||
ext {
|
||||
minSdkVersion = 21
|
||||
targetSdkVersion = 28
|
||||
compileSdkVersion = 28
|
||||
buildToolsVersion = '28.0.3'
|
||||
|
||||
coreVersion = "1.2.0"
|
||||
extJUnitVersion = "1.1.1"
|
||||
runnerVersion = "1.2.0"
|
||||
rulesVersion = "1.2.0"
|
||||
junitVersion = "4.12"
|
||||
|
||||
fbjniJavaOnlyVersion = "0.2.2"
|
||||
soLoaderNativeLoaderVersion = "0.10.5"
|
||||
}
|
||||
|
||||
repositories {
|
||||
google()
|
||||
mavenLocal()
|
||||
mavenCentral()
|
||||
jcenter()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
classpath 'com.android.tools.build:gradle:4.1.2'
|
||||
classpath 'com.vanniktech:gradle-maven-publish-plugin:0.14.2'
|
||||
}
|
||||
}
|
||||
|
||||
repositories {
|
||||
google()
|
||||
jcenter()
|
||||
}
|
||||
}
|
||||
|
||||
ext.deps = [
|
||||
jsr305: 'com.google.code.findbugs:jsr305:3.0.1',
|
||||
]
|
||||
@ -1,30 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -eux
|
||||
|
||||
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
|
||||
PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android
|
||||
|
||||
echo "PYTORCH_DIR:$PYTORCH_DIR"
|
||||
|
||||
source "$PYTORCH_ANDROID_DIR/common.sh"
|
||||
|
||||
check_android_sdk
|
||||
check_gradle
|
||||
parse_abis_list "$@"
|
||||
build_android
|
||||
|
||||
# To set proxy for gradle add following lines to ./gradle/gradle.properties:
|
||||
# systemProp.http.proxyHost=...
|
||||
# systemProp.http.proxyPort=8080
|
||||
# systemProp.https.proxyHost=...
|
||||
# systemProp.https.proxyPort=8080
|
||||
|
||||
if [ "$CUSTOM_ABIS_LIST" = true ]; then
|
||||
NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
|
||||
else
|
||||
NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
|
||||
fi
|
||||
|
||||
find $PYTORCH_ANDROID_DIR -type f -name *apk
|
||||
|
||||
find $PYTORCH_ANDROID_DIR -type f -name *apk | xargs echo "To install apk run: $ANDROID_HOME/platform-tools/adb install -r "
|
||||
@ -1,32 +0,0 @@
|
||||
#!/bin/bash
|
||||
###############################################################################
|
||||
# This script tests the custom selective build flow for PyTorch Android, which
|
||||
# optimizes library size by only including ops used by a specific model.
|
||||
###############################################################################
|
||||
|
||||
set -eux
|
||||
|
||||
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
|
||||
PYTORCH_ANDROID_DIR="${PYTORCH_DIR}/android"
|
||||
BUILD_ROOT="${PYTORCH_DIR}/build_pytorch_android_custom"
|
||||
|
||||
source "${PYTORCH_ANDROID_DIR}/common.sh"
|
||||
|
||||
prepare_model_and_dump_root_ops() {
|
||||
cd "${BUILD_ROOT}"
|
||||
MODEL="${BUILD_ROOT}/MobileNetV2.pt"
|
||||
ROOT_OPS="${BUILD_ROOT}/MobileNetV2.yaml"
|
||||
python "${PYTORCH_ANDROID_DIR}/test_app/make_assets_custom.py"
|
||||
cp "${MODEL}" "${PYTORCH_ANDROID_DIR}/test_app/app/src/main/assets/mobilenet2.pt"
|
||||
}
|
||||
|
||||
# Start building
|
||||
mkdir -p "${BUILD_ROOT}"
|
||||
check_android_sdk
|
||||
check_gradle
|
||||
parse_abis_list "$@"
|
||||
prepare_model_and_dump_root_ops
|
||||
SELECTED_OP_LIST="${ROOT_OPS}" build_android
|
||||
|
||||
# TODO: change this to build test_app instead
|
||||
$GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean assembleRelease
|
||||
@ -1,74 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -eux
|
||||
|
||||
##############################################################################
|
||||
# Common util functions for Android build scripts.
|
||||
##############################################################################
|
||||
|
||||
if [ -z "$PYTORCH_DIR" ]; then
|
||||
echo "PYTORCH_DIR not set!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
retry () {
|
||||
"$@" || (sleep 10 && "$@") || (sleep 20 && "$@") || (sleep 40 && "$@")
|
||||
}
|
||||
|
||||
check_android_sdk() {
|
||||
if [ -z "$ANDROID_HOME" ]; then
|
||||
echo "ANDROID_HOME not set; please set it to Android sdk directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d "$ANDROID_HOME" ]; then
|
||||
echo "ANDROID_HOME not a directory; did you install it under $ANDROID_HOME?"
|
||||
exit 1
|
||||
fi
|
||||
echo "ANDROID_HOME:$ANDROID_HOME"
|
||||
}
|
||||
|
||||
check_gradle() {
|
||||
GRADLE_PATH=$PYTORCH_DIR/android/gradlew
|
||||
echo "GRADLE_PATH:$GRADLE_PATH"
|
||||
}
|
||||
|
||||
parse_abis_list() {
|
||||
# sync with https://github.com/pytorch/pytorch/blob/0ca0e02685a9d033ac4f04e2fa5c8ba6dbc5ae50/android/gradle.properties#L1
|
||||
ABIS_LIST="armeabi-v7a,arm64-v8a,x86,x86_64"
|
||||
CUSTOM_ABIS_LIST=false
|
||||
if [ $# -gt 0 ]; then
|
||||
ABIS_LIST=$1
|
||||
CUSTOM_ABIS_LIST=true
|
||||
fi
|
||||
|
||||
echo "ABIS_LIST:$ABIS_LIST"
|
||||
echo "CUSTOM_ABIS_LIST:$CUSTOM_ABIS_LIST"
|
||||
}
|
||||
|
||||
build_android() {
|
||||
PYTORCH_ANDROID_DIR="$PYTORCH_DIR/android"
|
||||
BUILD_ROOT="${BUILD_ROOT:-$PYTORCH_DIR}"
|
||||
echo "BUILD_ROOT:$BUILD_ROOT"
|
||||
|
||||
LIB_DIR="$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs"
|
||||
INCLUDE_DIR="$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include"
|
||||
|
||||
# These directories only contain symbolic links.
|
||||
rm -rf "$LIB_DIR" && mkdir -p "$LIB_DIR"
|
||||
rm -rf "$INCLUDE_DIR" && mkdir -p "$INCLUDE_DIR"
|
||||
|
||||
for abi in $(echo "$ABIS_LIST" | tr ',' '\n')
|
||||
do
|
||||
echo "abi:$abi"
|
||||
ANDROID_BUILD_ROOT="$BUILD_ROOT/build_android_$abi"
|
||||
ANDROID_ABI="$abi" \
|
||||
BUILD_ROOT="$ANDROID_BUILD_ROOT" \
|
||||
"$PYTORCH_DIR/scripts/build_android.sh" \
|
||||
-DANDROID_CCACHE="$(which ccache)" \
|
||||
-DUSE_LITE_INTERPRETER_PROFILER="OFF"
|
||||
|
||||
echo "$abi build output lib,include at $ANDROID_BUILD_ROOT/install"
|
||||
ln -s "$ANDROID_BUILD_ROOT/install/lib" "$LIB_DIR/$abi"
|
||||
ln -s "$ANDROID_BUILD_ROOT/install/include" "$INCLUDE_DIR/$abi"
|
||||
done
|
||||
}
|
||||
@ -1,26 +0,0 @@
|
||||
ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64
|
||||
|
||||
VERSION_NAME=2.2.0-SNAPSHOT
|
||||
GROUP=org.pytorch
|
||||
MAVEN_GROUP=org.pytorch
|
||||
SONATYPE_STAGING_PROFILE=orgpytorch
|
||||
POM_URL=https://github.com/pytorch/pytorch/tree/master/android
|
||||
POM_SCM_URL=https://github.com/pytorch/pytorch.git
|
||||
POM_SCM_CONNECTION=scm:git:https://github.com/pytorch/pytorch
|
||||
POM_SCM_DEV_CONNECTION=scm:git:git@github.com:pytorch/pytorch.git
|
||||
POM_LICENSE_NAME=BSD 3-Clause
|
||||
POM_LICENSE_URL=https://github.com/pytorch/pytorch/blob/master/LICENSE
|
||||
POM_ISSUES_URL=https://github.com/pytorch/pytorch/issues
|
||||
POM_LICENSE_DIST=repo
|
||||
POM_DEVELOPER_ID=pytorch
|
||||
POM_DEVELOPER_NAME=pytorch
|
||||
|
||||
# Gradle internals
|
||||
org.gradle.internal.repository.max.retries=1
|
||||
org.gradle.jvmargs=-XX:MaxMetaspaceSize=1024m
|
||||
|
||||
android.useAndroidX=true
|
||||
android.enableJetifier=true
|
||||
|
||||
nativeLibsDoNotStrip=false
|
||||
testAppAllVariantsEnabled=false
|
||||
@ -1,11 +0,0 @@
|
||||
afterEvaluate { project ->
|
||||
if (POM_PACKAGING == 'aar') {
|
||||
task headersJar(type: Jar) {
|
||||
archiveClassifier.set('headers')
|
||||
from("$rootDir/cxx/") {
|
||||
include '**/*.h'
|
||||
}
|
||||
}
|
||||
artifacts.add('archives', headersJar)
|
||||
}
|
||||
}
|
||||
@ -1,3 +0,0 @@
|
||||
apply from: rootProject.file('gradle/android_tasks.gradle')
|
||||
|
||||
apply plugin: 'com.vanniktech.maven.publish'
|
||||
BIN
android/gradle/wrapper/gradle-wrapper.jar
vendored
BIN
android/gradle/wrapper/gradle-wrapper.jar
vendored
Binary file not shown.
@ -1,5 +0,0 @@
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
172
android/gradlew
vendored
172
android/gradlew
vendored
@ -1,172 +0,0 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
##############################################################################
|
||||
##
|
||||
## Gradle start up script for UN*X
|
||||
##
|
||||
##############################################################################
|
||||
|
||||
# Attempt to set APP_HOME
|
||||
# Resolve links: $0 may be a link
|
||||
PRG="$0"
|
||||
# Need this for relative symlinks.
|
||||
while [ -h "$PRG" ] ; do
|
||||
ls=`ls -ld "$PRG"`
|
||||
link=`expr "$ls" : '.*-> \(.*\)$'`
|
||||
if expr "$link" : '/.*' > /dev/null; then
|
||||
PRG="$link"
|
||||
else
|
||||
PRG=`dirname "$PRG"`"/$link"
|
||||
fi
|
||||
done
|
||||
SAVED="`pwd`"
|
||||
cd "`dirname \"$PRG\"`/" >/dev/null
|
||||
APP_HOME="`pwd -P`"
|
||||
cd "$SAVED" >/dev/null
|
||||
|
||||
APP_NAME="Gradle"
|
||||
APP_BASE_NAME=`basename "$0"`
|
||||
|
||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
DEFAULT_JVM_OPTS=""
|
||||
|
||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||
MAX_FD="maximum"
|
||||
|
||||
warn () {
|
||||
echo "$*"
|
||||
}
|
||||
|
||||
die () {
|
||||
echo
|
||||
echo "$*"
|
||||
echo
|
||||
exit 1
|
||||
}
|
||||
|
||||
# OS specific support (must be 'true' or 'false').
|
||||
cygwin=false
|
||||
msys=false
|
||||
darwin=false
|
||||
nonstop=false
|
||||
case "`uname`" in
|
||||
CYGWIN* )
|
||||
cygwin=true
|
||||
;;
|
||||
Darwin* )
|
||||
darwin=true
|
||||
;;
|
||||
MINGW* )
|
||||
msys=true
|
||||
;;
|
||||
NONSTOP* )
|
||||
nonstop=true
|
||||
;;
|
||||
esac
|
||||
|
||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||
|
||||
# Determine the Java command to use to start the JVM.
|
||||
if [ -n "$JAVA_HOME" ] ; then
|
||||
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||
# IBM's JDK on AIX uses strange locations for the executables
|
||||
JAVACMD="$JAVA_HOME/jre/sh/java"
|
||||
else
|
||||
JAVACMD="$JAVA_HOME/bin/java"
|
||||
fi
|
||||
if [ ! -x "$JAVACMD" ] ; then
|
||||
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
else
|
||||
JAVACMD="java"
|
||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
|
||||
# Increase the maximum file descriptors if we can.
|
||||
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
|
||||
MAX_FD_LIMIT=`ulimit -H -n`
|
||||
if [ $? -eq 0 ] ; then
|
||||
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
|
||||
MAX_FD="$MAX_FD_LIMIT"
|
||||
fi
|
||||
ulimit -n $MAX_FD
|
||||
if [ $? -ne 0 ] ; then
|
||||
warn "Could not set maximum file descriptor limit: $MAX_FD"
|
||||
fi
|
||||
else
|
||||
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
|
||||
fi
|
||||
fi
|
||||
|
||||
# For Darwin, add options to specify how the application appears in the dock
|
||||
if $darwin; then
|
||||
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
|
||||
fi
|
||||
|
||||
# For Cygwin, switch paths to Windows format before running java
|
||||
if $cygwin ; then
|
||||
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
|
||||
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
|
||||
JAVACMD=`cygpath --unix "$JAVACMD"`
|
||||
|
||||
# We build the pattern for arguments to be converted via cygpath
|
||||
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
|
||||
SEP=""
|
||||
for dir in $ROOTDIRSRAW ; do
|
||||
ROOTDIRS="$ROOTDIRS$SEP$dir"
|
||||
SEP="|"
|
||||
done
|
||||
OURCYGPATTERN="(^($ROOTDIRS))"
|
||||
# Add a user-defined pattern to the cygpath arguments
|
||||
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
|
||||
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
|
||||
fi
|
||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||
i=0
|
||||
for arg in "$@" ; do
|
||||
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
|
||||
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
|
||||
|
||||
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
|
||||
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
|
||||
else
|
||||
eval `echo args$i`="\"$arg\""
|
||||
fi
|
||||
i=$((i+1))
|
||||
done
|
||||
case $i in
|
||||
(0) set -- ;;
|
||||
(1) set -- "$args0" ;;
|
||||
(2) set -- "$args0" "$args1" ;;
|
||||
(3) set -- "$args0" "$args1" "$args2" ;;
|
||||
(4) set -- "$args0" "$args1" "$args2" "$args3" ;;
|
||||
(5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
|
||||
(6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
|
||||
(7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
|
||||
(8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
|
||||
(9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
|
||||
esac
|
||||
fi
|
||||
|
||||
# Escape application args
|
||||
save () {
|
||||
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
|
||||
echo " "
|
||||
}
|
||||
APP_ARGS=$(save "$@")
|
||||
|
||||
# Collect all arguments for the java command, following the shell quoting and substitution rules
|
||||
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
|
||||
|
||||
# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
|
||||
if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
|
||||
cd "$(dirname "$0")"
|
||||
fi
|
||||
|
||||
exec "$JAVACMD" "$@"
|
||||
84
android/gradlew.bat
vendored
84
android/gradlew.bat
vendored
@ -1,84 +0,0 @@
|
||||
@if "%DEBUG%" == "" @echo off
|
||||
@rem ##########################################################################
|
||||
@rem
|
||||
@rem Gradle startup script for Windows
|
||||
@rem
|
||||
@rem ##########################################################################
|
||||
|
||||
@rem Set local scope for the variables with windows NT shell
|
||||
if "%OS%"=="Windows_NT" setlocal
|
||||
|
||||
set DIRNAME=%~dp0
|
||||
if "%DIRNAME%" == "" set DIRNAME=.
|
||||
set APP_BASE_NAME=%~n0
|
||||
set APP_HOME=%DIRNAME%
|
||||
|
||||
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
set DEFAULT_JVM_OPTS=
|
||||
|
||||
@rem Find java.exe
|
||||
if defined JAVA_HOME goto findJavaFromJavaHome
|
||||
|
||||
set JAVA_EXE=java.exe
|
||||
%JAVA_EXE% -version >NUL 2>&1
|
||||
if "%ERRORLEVEL%" == "0" goto init
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:findJavaFromJavaHome
|
||||
set JAVA_HOME=%JAVA_HOME:"=%
|
||||
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
|
||||
|
||||
if exist "%JAVA_EXE%" goto init
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:init
|
||||
@rem Get command-line arguments, handling Windows variants
|
||||
|
||||
if not "%OS%" == "Windows_NT" goto win9xME_args
|
||||
|
||||
:win9xME_args
|
||||
@rem Slurp the command line arguments.
|
||||
set CMD_LINE_ARGS=
|
||||
set _SKIP=2
|
||||
|
||||
:win9xME_args_slurp
|
||||
if "x%~1" == "x" goto execute
|
||||
|
||||
set CMD_LINE_ARGS=%*
|
||||
|
||||
:execute
|
||||
@rem Setup the command line
|
||||
|
||||
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
||||
|
||||
@rem Execute Gradle
|
||||
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
|
||||
|
||||
:end
|
||||
@rem End local scope for the variables with windows NT shell
|
||||
if "%ERRORLEVEL%"=="0" goto mainEnd
|
||||
|
||||
:fail
|
||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
||||
rem the _cmd.exe /c_ return code!
|
||||
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
|
||||
exit /b 1
|
||||
|
||||
:mainEnd
|
||||
if "%OS%"=="Windows_NT" endlocal
|
||||
|
||||
:omega
|
||||
@ -1,184 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.5)
|
||||
option(BUILD_LITE_INTERPRETER "Master flag to build pytorch_jni_lite" ON)
|
||||
message(
|
||||
STATUS
|
||||
"BUILD_LITE_INTERPRETER (pytorch_jni_lite): ${BUILD_LITE_INTERPRETER}")
|
||||
|
||||
if(BUILD_LITE_INTERPRETER)
|
||||
project(pytorch_jni_lite CXX)
|
||||
set(PYTORCH_JNI_TARGET pytorch_jni_lite)
|
||||
else()
|
||||
project(pytorch_jni CXX)
|
||||
set(PYTORCH_JNI_TARGET pytorch_jni)
|
||||
endif()
|
||||
|
||||
include(GNUInstallDirs)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
|
||||
set(CMAKE_VERBOSE_MAKEFILE ON)
|
||||
message(STATUS "ANDROID_STL:${ANDROID_STL}")
|
||||
|
||||
set(TRACE_ENABLED OFF)
|
||||
if(DEFINED ENV{TRACE_ENABLED})
|
||||
if($ENV{TRACE_ENABLED} STREQUAL "1")
|
||||
message(STATUS "TRACE_ENABLED ON")
|
||||
set(TRACE_ENABLED ON)
|
||||
endif()
|
||||
endif()
|
||||
if(NOT TRACE_ENABLED)
|
||||
message(STATUS "TRACE_ENABLED OFF")
|
||||
endif()
|
||||
|
||||
set(USE_VULKAN OFF)
|
||||
|
||||
set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
|
||||
|
||||
if(ANDROID_ABI)
|
||||
set(USE_VULKAN ON)
|
||||
set(libtorch_include_DIR ${pytorch_android_DIR}/libtorch_include/${ANDROID_ABI})
|
||||
set(BUILD_SUBDIR ${ANDROID_ABI})
|
||||
elseif(BUILD_LIBTORCH_WITH_JNI)
|
||||
# Don't need LIBTORCH_HOME if we're building from within PyTorch.
|
||||
else()
|
||||
# Building against a pre-built libtorch.
|
||||
if(NOT LIBTORCH_HOME)
|
||||
message(FATAL_ERROR
|
||||
"pytorch_android requires LIBTORCH_HOME to be defined for non-Android builds.")
|
||||
endif()
|
||||
set(libtorch_include_DIR ${LIBTORCH_HOME}/include)
|
||||
link_directories(${LIBTORCH_HOME}/lib)
|
||||
set(BUILD_SUBDIR host)
|
||||
endif()
|
||||
|
||||
message(STATUS "libtorch dir:${libtorch_DIR}")
|
||||
|
||||
configure_file(
|
||||
${pytorch_android_DIR}/cmake_macros.h.in
|
||||
${pytorch_android_DIR}/cmake_macros.h)
|
||||
|
||||
|
||||
if(BUILD_LITE_INTERPRETER)
|
||||
file(GLOB pytorch_android_SOURCES
|
||||
${pytorch_android_DIR}/pytorch_jni_lite.cpp
|
||||
${pytorch_android_DIR}/pytorch_jni_common.cpp
|
||||
${pytorch_android_DIR}/pytorch_jni_common.h
|
||||
)
|
||||
else()
|
||||
file(GLOB pytorch_android_SOURCES
|
||||
${pytorch_android_DIR}/pytorch_jni_jit.cpp
|
||||
${pytorch_android_DIR}/pytorch_jni_common.cpp
|
||||
${pytorch_android_DIR}/pytorch_jni_common.h
|
||||
)
|
||||
endif()
|
||||
add_library(${PYTORCH_JNI_TARGET} SHARED ${pytorch_android_SOURCES})
|
||||
|
||||
if(APPLE)
|
||||
# Need to add rpath so dlopen can find dependencies.
|
||||
add_custom_command(TARGET pytorch_jni
|
||||
POST_BUILD COMMAND
|
||||
${CMAKE_INSTALL_NAME_TOOL} -add_rpath "@loader_path"
|
||||
$<TARGET_FILE:pytorch_jni>)
|
||||
endif()
|
||||
|
||||
target_compile_options(${PYTORCH_JNI_TARGET} PRIVATE
|
||||
-fexceptions
|
||||
)
|
||||
target_include_directories(${PYTORCH_JNI_TARGET} BEFORE
|
||||
PUBLIC $<BUILD_INTERFACE:${libtorch_include_DIR}>)
|
||||
|
||||
set(fbjni_DIR ${CMAKE_CURRENT_LIST_DIR}/../libs/fbjni/)
|
||||
set(fbjni_BUILD_DIR ${CMAKE_BINARY_DIR}/fbjni/${BUILD_SUBDIR})
|
||||
|
||||
add_subdirectory(${fbjni_DIR} ${fbjni_BUILD_DIR})
|
||||
|
||||
# ---[ Vulkan deps
|
||||
if(USE_VULKAN)
|
||||
set(Vulkan_LIBS)
|
||||
set(Vulkan_INCLUDES)
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/../../cmake/VulkanDependencies.cmake)
|
||||
endif()
|
||||
|
||||
if(ANDROID_ABI)
|
||||
function(import_static_lib name)
|
||||
add_library(${name} STATIC IMPORTED)
|
||||
set_property(
|
||||
TARGET ${name}
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/${name}.a)
|
||||
endfunction(import_static_lib)
|
||||
|
||||
import_static_lib(libtorch)
|
||||
import_static_lib(libtorch_cpu)
|
||||
import_static_lib(libc10)
|
||||
import_static_lib(libnnpack)
|
||||
import_static_lib(libXNNPACK)
|
||||
import_static_lib(libmicrokernels-prod)
|
||||
import_static_lib(libpytorch_qnnpack)
|
||||
import_static_lib(libpthreadpool)
|
||||
import_static_lib(libeigen_blas)
|
||||
import_static_lib(libcpuinfo)
|
||||
import_static_lib(libclog)
|
||||
|
||||
# Link most things statically on Android.
|
||||
set(pytorch_jni_LIBS
|
||||
fbjni
|
||||
-Wl,--gc-sections
|
||||
-Wl,--whole-archive
|
||||
libtorch
|
||||
libtorch_cpu
|
||||
-Wl,--no-whole-archive
|
||||
libc10
|
||||
libnnpack
|
||||
libXNNPACK
|
||||
libmicrokernels-prod
|
||||
libpytorch_qnnpack
|
||||
libpthreadpool
|
||||
libeigen_blas
|
||||
libcpuinfo
|
||||
libclog
|
||||
)
|
||||
else()
|
||||
# Prefer dynamic linking on the host
|
||||
set(pytorch_jni_LIBS
|
||||
fbjni
|
||||
torch
|
||||
torch_cpu
|
||||
c10
|
||||
cpuinfo
|
||||
)
|
||||
|
||||
if(USE_NNPACK)
|
||||
list(APPEND pytorch_jni_LIBS nnpack)
|
||||
endif()
|
||||
|
||||
if(USE_XNNPACK)
|
||||
list(APPEND pytorch_jni_LIBS XNNPACK)
|
||||
list(APPEND pytorch_jni_LIBS microkernels-prod)
|
||||
endif()
|
||||
|
||||
if(USE_SYSTEM_PTHREADPOOL)
|
||||
list(APPEND pytorch_jni_LIBS pthreadpool)
|
||||
endif()
|
||||
|
||||
if(USE_PYTORCH_QNNPACK)
|
||||
list(APPEND pytorch_jni_LIBS pytorch_qnnpack)
|
||||
list(APPEND pytorch_jni_LIBS clog)
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
if(USE_VULKAN)
|
||||
list(APPEND pytorch_jni_LIBS ${Vulkan_LIBS})
|
||||
endif()
|
||||
|
||||
target_link_libraries(${PYTORCH_JNI_TARGET} ${pytorch_jni_LIBS})
|
||||
|
||||
install(TARGETS ${PYTORCH_JNI_TARGET}
|
||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) #For windows
|
||||
|
||||
if(MSVC)
|
||||
install(FILES $<TARGET_PDB_FILE:pytorch_jni> DESTINATION ${CMAKE_INSTALL_LIBDIR} OPTIONAL)
|
||||
install(TARGETS ${PYTORCH_JNI_TARGET} DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
||||
endif()
|
||||
@ -1,163 +0,0 @@
|
||||
apply plugin: 'com.android.library'
|
||||
apply plugin: 'maven'
|
||||
|
||||
android {
|
||||
compileSdkVersion rootProject.compileSdkVersion
|
||||
buildToolsVersion rootProject.buildToolsVersion
|
||||
|
||||
defaultConfig {
|
||||
minSdkVersion rootProject.minSdkVersion
|
||||
targetSdkVersion rootProject.targetSdkVersion
|
||||
versionCode 0
|
||||
versionName "0.1"
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
ndk {
|
||||
abiFilters ABI_FILTERS.split(",")
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
if(System.env.BUILD_LITE_INTERPRETER == '0') {
|
||||
arguments "-DANDROID_STL=c++_shared", "-DBUILD_LITE_INTERPRETER=OFF", "-DUSE_LITE_INTERPRETER_PROFILER=OFF"
|
||||
} else {
|
||||
arguments "-DANDROID_STL=c++_shared", "-DUSE_LITE_INTERPRETER_PROFILER=OFF"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
buildTypes {
|
||||
debug {
|
||||
minifyEnabled false
|
||||
debuggable true
|
||||
}
|
||||
release {
|
||||
minifyEnabled false
|
||||
}
|
||||
}
|
||||
sourceSets {
|
||||
main {
|
||||
java {
|
||||
if(System.env.BUILD_LITE_INTERPRETER == '0') {
|
||||
println 'Build pytorch_jni'
|
||||
exclude 'org/pytorch/LiteModuleLoader.java'
|
||||
exclude 'org/pytorch/LiteNativePeer.java'
|
||||
} else {
|
||||
println 'Build pytorch_jni_lite'
|
||||
}
|
||||
}
|
||||
jniLibs.srcDirs = ['src/main/jniLibs']
|
||||
manifest.srcFile 'src/main/AndroidManifest.xml'
|
||||
}
|
||||
androidTest {
|
||||
java {
|
||||
if(System.env.BUILD_LITE_INTERPRETER == '0') {
|
||||
println 'Build test for full jit (pytorch_jni)'
|
||||
exclude 'org/pytorch/PytorchHostTests.java'
|
||||
exclude 'org/pytorch/PytorchLiteInstrumentedTests.java'
|
||||
exclude 'org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java'
|
||||
} else {
|
||||
println 'Build test for lite interpreter (pytorch_jni_lite)'
|
||||
exclude 'org/pytorch/PytorchHostTests.java'
|
||||
exclude 'org/pytorch/PytorchInstrumentedTests.java'
|
||||
exclude 'org/pytorch/suite/PytorchInstrumentedTestSuite.java'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path "CMakeLists.txt"
|
||||
}
|
||||
}
|
||||
|
||||
packagingOptions {
|
||||
if (nativeLibsDoNotStrip.toBoolean()) {
|
||||
doNotStrip "**/*.so"
|
||||
logger.warn('WARNING: nativeLibsDoNotStrip==true; debug symbols included')
|
||||
}
|
||||
}
|
||||
|
||||
useLibrary 'android.test.runner'
|
||||
useLibrary 'android.test.base'
|
||||
useLibrary 'android.test.mock'
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'com.facebook.fbjni:fbjni-java-only:' + rootProject.fbjniJavaOnlyVersion
|
||||
implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion
|
||||
|
||||
testImplementation 'junit:junit:' + rootProject.junitVersion
|
||||
testImplementation 'androidx.test:core:' + rootProject.coreVersion
|
||||
|
||||
androidTestImplementation 'junit:junit:' + rootProject.junitVersion
|
||||
androidTestImplementation 'androidx.test:core:' + rootProject.coreVersion
|
||||
androidTestImplementation 'androidx.test.ext:junit:' + rootProject.extJUnitVersion
|
||||
androidTestImplementation 'androidx.test:rules:' + rootProject.rulesVersion
|
||||
androidTestImplementation 'androidx.test:runner:' + rootProject.runnerVersion
|
||||
}
|
||||
|
||||
apply from: rootProject.file('gradle/release.gradle')
|
||||
|
||||
task sourcesJar(type: Jar) {
|
||||
from android.sourceSets.main.java.srcDirs
|
||||
classifier = 'sources'
|
||||
}
|
||||
|
||||
def getLibtorchHeadersDir() {
|
||||
def abi = ABI_FILTERS.split(",")[0]
|
||||
return "$rootDir/pytorch_android/src/main/cpp/libtorch_include/$abi"
|
||||
}
|
||||
|
||||
afterEvaluate {
|
||||
if (POM_PACKAGING == 'aar') {
|
||||
android.libraryVariants.all { variant ->
|
||||
variant.outputs.each { output ->
|
||||
File f = output.outputFile
|
||||
if (f.name.endsWith(".aar")) {
|
||||
output.assemble.finalizedBy addFolderToAarTask(
|
||||
"addHeadersToAar" + variant.name,
|
||||
f.path,
|
||||
getLibtorchHeadersDir(),
|
||||
"headers")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tasks.whenTaskAdded { task ->
|
||||
if (task.name.startsWith("bundle") && task.name.endsWith("Aar")) {
|
||||
doLast {
|
||||
addFolderToAar("addHeadersTo" + task.name, task.archivePath, getLibtorchHeadersDir(), 'headers')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def addFolderToAarTask(taskName, aarPath, folderPath, folderPathInAar) {
|
||||
return tasks.register(taskName) {
|
||||
doLast {
|
||||
addFolderToAar(taskName, aarPath, folderPath, folderPathInAar)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def addFolderToAar(taskName, aarPath, folderPath, folderPathInAar) {
|
||||
def tmpDir = file("${buildDir}/${taskName}")
|
||||
tmpDir.mkdir()
|
||||
def tmpDirFolder = file("${tmpDir.path}/${folderPathInAar}")
|
||||
tmpDirFolder.mkdir()
|
||||
copy {
|
||||
from zipTree(aarPath)
|
||||
into tmpDir
|
||||
}
|
||||
copy {
|
||||
from fileTree(folderPath)
|
||||
into tmpDirFolder
|
||||
}
|
||||
ant.zip(destfile: aarPath) {
|
||||
fileset(dir: tmpDir.path)
|
||||
}
|
||||
delete tmpDir
|
||||
}
|
||||
|
||||
artifacts.add('archives', sourcesJar)
|
||||
@ -1,20 +0,0 @@
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/jit.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
std::string input_file_path{argv[1]};
|
||||
std::string output_file_path{argv[2]};
|
||||
|
||||
std::ifstream ifs(input_file_path);
|
||||
std::stringstream buffer;
|
||||
buffer << ifs.rdbuf();
|
||||
torch::jit::Module m("TestModule");
|
||||
|
||||
m.define(buffer.str());
|
||||
m.save(output_file_path);
|
||||
}
|
||||
@ -1,151 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
OUTPUT_DIR = "src/androidTest/assets/"
|
||||
|
||||
|
||||
def scriptAndSave(module, fileName):
|
||||
print("-" * 80)
|
||||
script_module = torch.jit.script(module)
|
||||
print(script_module.graph)
|
||||
outputFileName = OUTPUT_DIR + fileName
|
||||
# note that the lite interpreter model can also be used in full JIT
|
||||
script_module._save_for_lite_interpreter(outputFileName)
|
||||
print("Saved to " + outputFileName)
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
class Test(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, input):
|
||||
return None
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqBool(self, input: bool) -> bool:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqInt(self, input: int) -> int:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqFloat(self, input: float) -> float:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqStr(self, input: str) -> str:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqTensor(self, input: Tensor) -> Tensor:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqDictStrKeyIntValue(self, input: dict[str, int]) -> dict[str, int]:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqDictIntKeyIntValue(self, input: dict[int, int]) -> dict[int, int]:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqDictFloatKeyIntValue(self, input: dict[float, int]) -> dict[float, int]:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def listIntSumReturnTuple(self, input: list[int]) -> tuple[list[int], int]:
|
||||
sum = 0
|
||||
for x in input:
|
||||
sum += x
|
||||
return (input, sum)
|
||||
|
||||
@torch.jit.script_method
|
||||
def listBoolConjunction(self, input: list[bool]) -> bool:
|
||||
res = True
|
||||
for x in input:
|
||||
res = res and x
|
||||
return res
|
||||
|
||||
@torch.jit.script_method
|
||||
def listBoolDisjunction(self, input: list[bool]) -> bool:
|
||||
res = False
|
||||
for x in input:
|
||||
res = res or x
|
||||
return res
|
||||
|
||||
@torch.jit.script_method
|
||||
def tupleIntSumReturnTuple(
|
||||
self, input: tuple[int, int, int]
|
||||
) -> tuple[tuple[int, int, int], int]:
|
||||
sum = 0
|
||||
for x in input:
|
||||
sum += x
|
||||
return (input, sum)
|
||||
|
||||
@torch.jit.script_method
|
||||
def optionalIntIsNone(self, input: Optional[int]) -> bool:
|
||||
return input is None
|
||||
|
||||
@torch.jit.script_method
|
||||
def intEq0None(self, input: int) -> Optional[int]:
|
||||
if input == 0:
|
||||
return None
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def str3Concat(self, input: str) -> str:
|
||||
return input + input + input
|
||||
|
||||
@torch.jit.script_method
|
||||
def newEmptyShapeWithItem(self, input):
|
||||
return torch.tensor([int(input.item())])[0]
|
||||
|
||||
@torch.jit.script_method
|
||||
def testAliasWithOffset(self) -> list[Tensor]:
|
||||
x = torch.tensor([100, 200])
|
||||
a = [x[0], x[1]]
|
||||
return a
|
||||
|
||||
@torch.jit.script_method
|
||||
def testNonContiguous(self):
|
||||
x = torch.tensor([100, 200, 300])[::2]
|
||||
assert not x.is_contiguous()
|
||||
assert x[0] == 100
|
||||
assert x[1] == 300
|
||||
return x
|
||||
|
||||
@torch.jit.script_method
|
||||
def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||
r = torch.nn.functional.conv2d(x, w)
|
||||
if toChannelsLast:
|
||||
r = r.contiguous(memory_format=torch.channels_last)
|
||||
else:
|
||||
r = r.contiguous()
|
||||
return r
|
||||
|
||||
@torch.jit.script_method
|
||||
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||
r = torch.nn.functional.conv3d(x, w)
|
||||
if toChannelsLast:
|
||||
r = r.contiguous(memory_format=torch.channels_last_3d)
|
||||
else:
|
||||
r = r.contiguous()
|
||||
return r
|
||||
|
||||
@torch.jit.script_method
|
||||
def contiguous(self, x: Tensor) -> Tensor:
|
||||
return x.contiguous()
|
||||
|
||||
@torch.jit.script_method
|
||||
def contiguousChannelsLast(self, x: Tensor) -> Tensor:
|
||||
return x.contiguous(memory_format=torch.channels_last)
|
||||
|
||||
@torch.jit.script_method
|
||||
def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
|
||||
return x.contiguous(memory_format=torch.channels_last_3d)
|
||||
|
||||
|
||||
scriptAndSave(Test(), "test.pt")
|
||||
@ -1,4 +0,0 @@
|
||||
POM_NAME=pytorch_android_lite pytorch android api
|
||||
POM_DESCRIPTION=pytorch_android_lite pytorch android api
|
||||
POM_ARTIFACT_ID=pytorch_android_lite
|
||||
POM_PACKAGING=aar
|
||||
@ -1,42 +0,0 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the Apache-2 license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
plugins {
|
||||
id 'java-library'
|
||||
}
|
||||
|
||||
repositories {
|
||||
mavenLocal()
|
||||
jcenter()
|
||||
}
|
||||
|
||||
sourceSets {
|
||||
main {
|
||||
java {
|
||||
srcDir '../src/main/java'
|
||||
exclude 'org/pytorch/PyTorchAndroid.java'
|
||||
exclude 'org/pytorch/LitePyTorchAndroid.java'
|
||||
exclude 'org/pytorch/LiteModuleLoader.java'
|
||||
exclude 'org/pytorch/LiteNativePeer.java'
|
||||
}
|
||||
}
|
||||
test {
|
||||
java {
|
||||
srcDir '../src/androidTest/java'
|
||||
exclude '**/PytorchInstrumented*'
|
||||
exclude '**/PytorchLiteInstrumented*'
|
||||
}
|
||||
resources.srcDirs = ["../src/androidTest/assets"]
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compileOnly 'com.google.code.findbugs:jsr305:3.0.1'
|
||||
implementation 'com.facebook.soloader:nativeloader:0.10.1'
|
||||
implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
|
||||
testImplementation 'junit:junit:4.12'
|
||||
}
|
||||
|
||||
apply from: rootProject.file('gradle/release.gradle')
|
||||
@ -1,4 +0,0 @@
|
||||
POM_NAME=pytorch_java_only pytorch java api
|
||||
POM_DESCRIPTION=pytorch_java_only pytorch java api
|
||||
POM_ARTIFACT_ID=pytorch_java_only
|
||||
POM_PACKAGING=jar
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,21 +0,0 @@
|
||||
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include "caffe2/android/pytorch_android/src/main/cpp/pytorch_jni_common.h"
|
||||
|
||||
using namespace ::testing;
|
||||
|
||||
TEST(pytorch_jni_common_test, newJIValueFromAtIValue) {
|
||||
auto dict = c10::impl::GenericDict(
|
||||
c10::dynT<c10::IntType>(), c10::dynT<c10::StringType>());
|
||||
auto dictCallback = [](auto&&) {
|
||||
return facebook::jni::local_ref<pytorch_jni::JIValue>{};
|
||||
};
|
||||
EXPECT_NO_THROW(pytorch_jni::JIValue::newJIValueFromAtIValue(
|
||||
dict, dictCallback, dictCallback));
|
||||
}
|
||||
@ -1,25 +0,0 @@
|
||||
package org.pytorch;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.Objects;
|
||||
|
||||
public class PytorchHostTests extends PytorchTestBase {
|
||||
|
||||
@Override
|
||||
protected Module loadModel(String path) throws IOException {
|
||||
return Module.load(assetFilePath(path));
|
||||
}
|
||||
|
||||
private String assetFilePath(String assetName) throws IOException {
|
||||
Path tempFile = Files.createTempFile("test", ".pt");
|
||||
try (InputStream resource =
|
||||
Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {
|
||||
Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING);
|
||||
}
|
||||
return tempFile.toAbsolutePath().toString();
|
||||
}
|
||||
}
|
||||
@ -1,42 +0,0 @@
|
||||
package org.pytorch;
|
||||
|
||||
import android.content.Context;
|
||||
import androidx.test.InstrumentationRegistry;
|
||||
import androidx.test.runner.AndroidJUnit4;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class PytorchInstrumentedTests extends PytorchTestBase {
|
||||
|
||||
@Override
|
||||
protected Module loadModel(String path) throws IOException {
|
||||
return Module.load(assetFilePath(path));
|
||||
}
|
||||
|
||||
private String assetFilePath(String assetName) throws IOException {
|
||||
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||
File file = new File(appContext.getFilesDir(), assetName);
|
||||
if (file.exists() && file.length() > 0) {
|
||||
return file.getAbsolutePath();
|
||||
}
|
||||
|
||||
try (InputStream is = appContext.getAssets().open(assetName)) {
|
||||
try (OutputStream os = new FileOutputStream(file)) {
|
||||
byte[] buffer = new byte[4 * 1024];
|
||||
int read;
|
||||
while ((read = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, read);
|
||||
}
|
||||
os.flush();
|
||||
}
|
||||
return file.getAbsolutePath();
|
||||
} catch (IOException e) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,42 +0,0 @@
|
||||
package org.pytorch;
|
||||
|
||||
import android.content.Context;
|
||||
import androidx.test.InstrumentationRegistry;
|
||||
import androidx.test.runner.AndroidJUnit4;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class PytorchLiteInstrumentedTests extends PytorchTestBase {
|
||||
|
||||
@Override
|
||||
protected Module loadModel(String path) throws IOException {
|
||||
return LiteModuleLoader.load(assetFilePath(path));
|
||||
}
|
||||
|
||||
private String assetFilePath(String assetName) throws IOException {
|
||||
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||
File file = new File(appContext.getFilesDir(), assetName);
|
||||
if (file.exists() && file.length() > 0) {
|
||||
return file.getAbsolutePath();
|
||||
}
|
||||
|
||||
try (InputStream is = appContext.getAssets().open(assetName)) {
|
||||
try (OutputStream os = new FileOutputStream(file)) {
|
||||
byte[] buffer = new byte[4 * 1024];
|
||||
int read;
|
||||
while ((read = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, read);
|
||||
}
|
||||
os.flush();
|
||||
}
|
||||
return file.getAbsolutePath();
|
||||
} catch (IOException e) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,694 +0,0 @@
|
||||
package org.pytorch;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Test;
|
||||
import org.junit.Ignore;
|
||||
|
||||
public abstract class PytorchTestBase {
|
||||
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
|
||||
|
||||
@Test
|
||||
public void testForwardNull() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
|
||||
assertTrue(input.isTensor());
|
||||
final IValue output = module.forward(input);
|
||||
assertTrue(output.isNull());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqBool() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
for (boolean value : new boolean[] {false, true}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isBool());
|
||||
assertTrue(value == input.toBool());
|
||||
final IValue output = module.runMethod("eqBool", input);
|
||||
assertTrue(output.isBool());
|
||||
assertTrue(value == output.toBool());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqInt() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isLong());
|
||||
assertTrue(value == input.toLong());
|
||||
final IValue output = module.runMethod("eqInt", input);
|
||||
assertTrue(output.isLong());
|
||||
assertTrue(value == output.toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqFloat() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
double[] values =
|
||||
new double[] {
|
||||
-Double.MAX_VALUE,
|
||||
Double.MAX_VALUE,
|
||||
-Double.MIN_VALUE,
|
||||
Double.MIN_VALUE,
|
||||
-Math.exp(1.d),
|
||||
-Math.sqrt(2.d),
|
||||
-3.1415f,
|
||||
3.1415f,
|
||||
-1,
|
||||
0,
|
||||
1,
|
||||
};
|
||||
for (double value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isDouble());
|
||||
assertTrue(value == input.toDouble());
|
||||
final IValue output = module.runMethod("eqFloat", input);
|
||||
assertTrue(output.isDouble());
|
||||
assertTrue(value == output.toDouble());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqTensor() throws IOException {
|
||||
final long[] inputTensorShape = new long[] {1, 3, 224, 224};
|
||||
final long numElements = Tensor.numel(inputTensorShape);
|
||||
final float[] inputTensorData = new float[(int) numElements];
|
||||
for (int i = 0; i < numElements; ++i) {
|
||||
inputTensorData[i] = i;
|
||||
}
|
||||
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue input = IValue.from(inputTensor);
|
||||
assertTrue(input.isTensor());
|
||||
assertTrue(inputTensor == input.toTensor());
|
||||
final IValue output = module.runMethod("eqTensor", input);
|
||||
assertTrue(output.isTensor());
|
||||
final Tensor outputTensor = output.toTensor();
|
||||
assertNotNull(outputTensor);
|
||||
assertArrayEquals(inputTensorShape, outputTensor.shape());
|
||||
float[] outputData = outputTensor.getDataAsFloatArray();
|
||||
for (int i = 0; i < numElements; i++) {
|
||||
assertTrue(inputTensorData[i] == outputData[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictIntKeyIntValue() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final Map<Long, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
|
||||
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
|
||||
inputMap.put(0l, IValue.from(0l));
|
||||
inputMap.put(1l, IValue.from(-1l));
|
||||
inputMap.put(-1l, IValue.from(1l));
|
||||
|
||||
final IValue input = IValue.dictLongKeyFrom(inputMap);
|
||||
assertTrue(input.isDictLongKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
|
||||
assertTrue(output.isDictLongKey());
|
||||
|
||||
final Map<Long, IValue> outputMap = output.toDictLongKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictStrKeyIntValue() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final Map<String, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
|
||||
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
|
||||
inputMap.put("long_0", IValue.from(0l));
|
||||
inputMap.put("long_1", IValue.from(1l));
|
||||
inputMap.put("long_-1", IValue.from(-1l));
|
||||
|
||||
final IValue input = IValue.dictStringKeyFrom(inputMap);
|
||||
assertTrue(input.isDictStringKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
|
||||
assertTrue(output.isDictStringKey());
|
||||
|
||||
final Map<String, IValue> outputMap = output.toDictStringKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testListIntSumReturnTuple() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
for (int n : new int[] {0, 1, 128}) {
|
||||
long[] a = new long[n];
|
||||
long sum = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
a[i] = i;
|
||||
sum += a[i];
|
||||
}
|
||||
final IValue input = IValue.listFrom(a);
|
||||
assertTrue(input.isLongList());
|
||||
|
||||
final IValue output = module.runMethod("listIntSumReturnTuple", input);
|
||||
|
||||
assertTrue(output.isTuple());
|
||||
assertTrue(2 == output.toTuple().length);
|
||||
|
||||
IValue output0 = output.toTuple()[0];
|
||||
IValue output1 = output.toTuple()[1];
|
||||
|
||||
assertArrayEquals(a, output0.toLongList());
|
||||
assertTrue(sum == output1.toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOptionalIntIsNone() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
|
||||
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntEq0None() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testRunUndefinedMethod() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
module.runMethod("test_undefined_method_throws_exception");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorMethods() {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
int[] ints = new int[numel];
|
||||
float[] floats = new float[numel];
|
||||
|
||||
byte[] bytes = new byte[numel];
|
||||
for (int i = 0; i < numel; i++) {
|
||||
bytes[i] = (byte) ((i % 255) - 128);
|
||||
ints[i] = i;
|
||||
floats[i] = i / 1000.f;
|
||||
}
|
||||
|
||||
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
|
||||
assertTrue(tensorBytes.dtype() == DType.INT8);
|
||||
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
|
||||
|
||||
Tensor tensorInts = Tensor.fromBlob(ints, shape);
|
||||
assertTrue(tensorInts.dtype() == DType.INT32);
|
||||
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
|
||||
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
float[] floatsOut = tensorFloats.getDataAsFloatArray();
|
||||
assertTrue(floatsOut.length == numel);
|
||||
for (int i = 0; i < numel; i++) {
|
||||
assertTrue(floats[i] == floatsOut[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void testTensorIllegalStateOnWrongType() {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
float[] floats = new float[numel];
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
tensorFloats.getDataAsByteArray();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqString() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("eqStr", input);
|
||||
assertTrue(output.isString());
|
||||
assertTrue(value.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStr3Concat() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("str3Concat", input);
|
||||
assertTrue(output.isString());
|
||||
String expectedOutput =
|
||||
new StringBuilder().append(value).append(value).append(value).toString();
|
||||
assertTrue(expectedOutput.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyShape() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final long someNumber = 43;
|
||||
final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
|
||||
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
|
||||
assertTrue(output.isTensor());
|
||||
Tensor value = output.toTensor();
|
||||
assertArrayEquals(new long[] {}, value.shape());
|
||||
assertArrayEquals(new long[] {someNumber}, value.getDataAsLongArray());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAliasWithOffset() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue output = module.runMethod("testAliasWithOffset");
|
||||
assertTrue(output.isTensorList());
|
||||
Tensor[] tensors = output.toTensorList();
|
||||
assertEquals(100, tensors[0].getDataAsLongArray()[0]);
|
||||
assertEquals(200, tensors[1].getDataAsLongArray()[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNonContiguous() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue output = module.runMethod("testNonContiguous");
|
||||
assertTrue(output.isTensor());
|
||||
Tensor value = output.toTensor();
|
||||
assertArrayEquals(new long[] {2}, value.shape());
|
||||
assertArrayEquals(new long[] {100, 300}, value.getDataAsLongArray());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLast() throws IOException {
|
||||
long[] inputShape = new long[] {1, 3, 2, 2};
|
||||
long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
|
||||
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
|
||||
assertIValueTensor(
|
||||
outputNCHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[] {1, 3, 2, 2},
|
||||
new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104});
|
||||
final IValue outputNHWC = module.runMethod("contiguousChannelsLast", IValue.from(inputNHWC));
|
||||
assertIValueTensor(outputNHWC, MemoryFormat.CHANNELS_LAST, inputShape, data);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLast3d() throws IOException {
|
||||
long[] shape = new long[] {1, 2, 2, 2, 2};
|
||||
long[] dataNCHWD = new long[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
|
||||
long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
|
||||
|
||||
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
|
||||
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
|
||||
|
||||
Tensor inputNCHWD = Tensor.fromBlob(dataNCHWD, shape, MemoryFormat.CONTIGUOUS);
|
||||
final IValue outputNHWDC =
|
||||
module.runMethod("contiguousChannelsLast3d", IValue.from(inputNCHWD));
|
||||
assertIValueTensor(outputNHWDC, MemoryFormat.CHANNELS_LAST_3D, shape, dataNHWDC);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLastConv2d() throws IOException {
|
||||
long[] inputShape = new long[] {1, 3, 2, 2};
|
||||
long[] dataNCHW = new long[] {
|
||||
111, 112,
|
||||
121, 122,
|
||||
|
||||
211, 212,
|
||||
221, 222,
|
||||
|
||||
311, 312,
|
||||
321, 322};
|
||||
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataNHWC = new long[] {
|
||||
111, 211, 311, 112, 212, 312,
|
||||
|
||||
121, 221, 321, 122, 222, 322};
|
||||
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||
long[] weightShape = new long[] {3, 3, 1, 1};
|
||||
long[] dataWeightOIHW = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1};
|
||||
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataWeightOHWI = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1};
|
||||
|
||||
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
final IValue outputNCHW =
|
||||
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
|
||||
assertIValueTensor(
|
||||
outputNCHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[] {1, 3, 2, 2},
|
||||
new long[] {
|
||||
2*111, 2*112,
|
||||
2*121, 2*122,
|
||||
|
||||
211, 212,
|
||||
221, 222,
|
||||
|
||||
-311, -312,
|
||||
-321, -322});
|
||||
|
||||
final IValue outputNHWC =
|
||||
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
|
||||
assertIValueTensor(
|
||||
outputNHWC,
|
||||
MemoryFormat.CHANNELS_LAST,
|
||||
new long[] {1, 3, 2, 2},
|
||||
new long[] {
|
||||
2*111, 211, -311, 2*112, 212, -312,
|
||||
2*121, 221, -321, 2*122, 222, -322});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLastConv3d() throws IOException {
|
||||
long[] inputShape = new long[] {1, 3, 2, 2, 2};
|
||||
long[] dataNCDHW = new long[] {
|
||||
1111, 1112,
|
||||
1121, 1122,
|
||||
1211, 1212,
|
||||
1221, 1222,
|
||||
|
||||
2111, 2112,
|
||||
2121, 2122,
|
||||
2211, 2212,
|
||||
2221, 2222,
|
||||
|
||||
3111, 3112,
|
||||
3121, 3122,
|
||||
3211, 3212,
|
||||
3221, 3222};
|
||||
Tensor inputNCDHW = Tensor.fromBlob(dataNCDHW, inputShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataNDHWC = new long[] {
|
||||
1111, 2111, 3111,
|
||||
1112, 2112, 3112,
|
||||
|
||||
1121, 2121, 3121,
|
||||
1122, 2122, 3122,
|
||||
|
||||
1211, 2211, 3211,
|
||||
1212, 2212, 3212,
|
||||
|
||||
1221, 2221, 3221,
|
||||
1222, 2222, 3222};
|
||||
|
||||
Tensor inputNDHWC = Tensor.fromBlob(dataNDHWC, inputShape, MemoryFormat.CHANNELS_LAST_3D);
|
||||
|
||||
long[] weightShape = new long[] {3, 3, 1, 1, 1};
|
||||
long[] dataWeightOIDHW = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1,
|
||||
};
|
||||
Tensor wNCDHW = Tensor.fromBlob(dataWeightOIDHW, weightShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataWeightODHWI = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1,
|
||||
};
|
||||
Tensor wNDHWC = Tensor.fromBlob(dataWeightODHWI, weightShape, MemoryFormat.CHANNELS_LAST_3D);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
final IValue outputNCDHW =
|
||||
module.runMethod("conv3d", IValue.from(inputNCDHW), IValue.from(wNCDHW), IValue.from(false));
|
||||
assertIValueTensor(
|
||||
outputNCDHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[] {1, 3, 2, 2, 2},
|
||||
new long[] {
|
||||
2*1111, 2*1112, 2*1121, 2*1122,
|
||||
2*1211, 2*1212, 2*1221, 2*1222,
|
||||
|
||||
2111, 2112, 2121, 2122,
|
||||
2211, 2212, 2221, 2222,
|
||||
|
||||
-3111, -3112, -3121, -3122,
|
||||
-3211, -3212, -3221, -3222});
|
||||
|
||||
final IValue outputNDHWC =
|
||||
module.runMethod("conv3d", IValue.from(inputNDHWC), IValue.from(wNDHWC), IValue.from(true));
|
||||
assertIValueTensor(
|
||||
outputNDHWC,
|
||||
MemoryFormat.CHANNELS_LAST_3D,
|
||||
new long[] {1, 3, 2, 2, 2},
|
||||
new long[] {
|
||||
2*1111, 2111, -3111, 2*1112, 2112, -3112,
|
||||
2*1121, 2121, -3121, 2*1122, 2122, -3122,
|
||||
|
||||
2*1211, 2211, -3211, 2*1212, 2212, -3212,
|
||||
2*1221, 2221, -3221, 2*1222, 2222, -3222});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMobileNetV2() throws IOException {
|
||||
try {
|
||||
final Module module = loadModel("mobilenet_v2.ptl");
|
||||
final IValue inputs = module.runMethod("get_all_bundled_inputs");
|
||||
assertTrue(inputs.isList());
|
||||
final IValue input = inputs.toList()[0];
|
||||
assertTrue(input.isTuple());
|
||||
module.forward(input.toTuple()[0]);
|
||||
assertTrue(true);
|
||||
} catch (Exception ex) {
|
||||
assertTrue("failed to run MobileNetV2 " + ex.getMessage(), false);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPointwiseOps() throws IOException {
|
||||
runModel("pointwise_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReductionOps() throws IOException {
|
||||
runModel("reduction_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComparisonOps() throws IOException {
|
||||
runModel("comparison_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOtherMathOps() throws IOException {
|
||||
runModel("other_math_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testSpectralOps() throws IOException {
|
||||
// NB: This model fails without lite interpreter. The error is as follows:
|
||||
// RuntimeError: stft requires the return_complex parameter be given for real inputs
|
||||
runModel("spectral_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBlasLapackOps() throws IOException {
|
||||
runModel("blas_lapack_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSamplingOps() throws IOException {
|
||||
runModel("sampling_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorOps() throws IOException {
|
||||
runModel("tensor_general_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorCreationOps() throws IOException {
|
||||
runModel("tensor_creation_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorIndexingOps() throws IOException {
|
||||
runModel("tensor_indexing_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorTypingOps() throws IOException {
|
||||
runModel("tensor_typing_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorViewOps() throws IOException {
|
||||
runModel("tensor_view_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConvolutionOps() throws IOException {
|
||||
runModel("convolution_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPoolingOps() throws IOException {
|
||||
runModel("pooling_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPaddingOps() throws IOException {
|
||||
runModel("padding_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActivationOps() throws IOException {
|
||||
runModel("activation_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNormalizationOps() throws IOException {
|
||||
runModel("normalization_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRecurrentOps() throws IOException {
|
||||
runModel("recurrent_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTransformerOps() throws IOException {
|
||||
runModel("transformer_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLinearOps() throws IOException {
|
||||
runModel("linear_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDropoutOps() throws IOException {
|
||||
runModel("dropout_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSparseOps() throws IOException {
|
||||
runModel("sparse_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDistanceFunctionOps() throws IOException {
|
||||
runModel("distance_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLossFunctionOps() throws IOException {
|
||||
runModel("loss_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVisionFunctionOps() throws IOException {
|
||||
runModel("vision_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShuffleOps() throws IOException {
|
||||
runModel("shuffle_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNNUtilsOps() throws IOException {
|
||||
runModel("nn_utils_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testQuantOps() throws IOException {
|
||||
runModel("general_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDynamicQuantOps() throws IOException {
|
||||
runModel("dynamic_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStaticQuantOps() throws IOException {
|
||||
runModel("static_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFusedQuantOps() throws IOException {
|
||||
runModel("fused_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTorchScriptBuiltinQuantOps() throws IOException {
|
||||
runModel("torchscript_builtin_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTorchScriptCollectionQuantOps() throws IOException {
|
||||
runModel("torchscript_collection_ops");
|
||||
}
|
||||
|
||||
static void assertIValueTensor(
|
||||
final IValue ivalue,
|
||||
final MemoryFormat memoryFormat,
|
||||
final long[] expectedShape,
|
||||
final long[] expectedData) {
|
||||
assertTrue(ivalue.isTensor());
|
||||
Tensor t = ivalue.toTensor();
|
||||
assertEquals(memoryFormat, t.memoryFormat());
|
||||
assertArrayEquals(expectedShape, t.shape());
|
||||
assertArrayEquals(expectedData, t.getDataAsLongArray());
|
||||
}
|
||||
|
||||
void runModel(final String name) throws IOException {
|
||||
final Module storage_module = loadModel(name + ".ptl");
|
||||
storage_module.forward();
|
||||
|
||||
// TODO enable this once the on-the-fly script is ready
|
||||
// final Module on_the_fly_module = loadModel(name + "_temp.ptl");
|
||||
// on_the_fly_module.forward();
|
||||
assertTrue(true);
|
||||
}
|
||||
|
||||
protected abstract Module loadModel(String assetName) throws IOException;
|
||||
}
|
||||
@ -1,9 +0,0 @@
|
||||
package org.pytorch.suite;
|
||||
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.pytorch.PytorchInstrumentedTests;
|
||||
|
||||
@RunWith(Suite.class)
|
||||
@Suite.SuiteClasses({PytorchInstrumentedTests.class})
|
||||
public class PytorchInstrumentedTestSuite {}
|
||||
@ -1,9 +0,0 @@
|
||||
package org.pytorch.suite;
|
||||
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.pytorch.PytorchLiteInstrumentedTests;
|
||||
|
||||
@RunWith(Suite.class)
|
||||
@Suite.SuiteClasses({PytorchLiteInstrumentedTests.class})
|
||||
public class PytorchLiteInstrumentedTestSuite {}
|
||||
@ -1 +0,0 @@
|
||||
<manifest package="org.pytorch" />
|
||||
@ -1,3 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
/* #undef TRACE_ENABLED */
|
||||
@ -1,3 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#cmakedefine TRACE_ENABLED
|
||||
@ -1,697 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#include <fbjni/ByteBuffer.h>
|
||||
#include <fbjni/fbjni.h>
|
||||
|
||||
#include "pytorch_jni_common.h"
|
||||
#if defined(__ANDROID__)
|
||||
#ifndef USE_PTHREADPOOL
|
||||
#define USE_PTHREADPOOL
|
||||
#endif /* USE_PTHREADPOOL */
|
||||
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
|
||||
#endif
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode) {
|
||||
if (deviceJniCode == kDeviceCPU) {
|
||||
return at::kCPU;
|
||||
} else if (deviceJniCode == kDeviceVulkan) {
|
||||
return at::kVulkan;
|
||||
}
|
||||
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException, "Unknown device");
|
||||
}
|
||||
|
||||
bool Trace::is_initialized_ = false;
|
||||
|
||||
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||
Trace::fp_ATrace_beginSection Trace::ATrace_beginSection;
|
||||
Trace::fp_ATrace_endSection Trace::ATrace_endSection;
|
||||
#endif
|
||||
|
||||
// Lazily binds the NDK ATrace symbols from libandroid.so so tracing works
// without linking against the tracing library at build time. No-op unless
// TRACE_ENABLED was configured and we are building for Android.
void Trace::init() {
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
  // BUGFIX: dlopen mode flags must be combined with bitwise OR. The previous
  // code used logical OR (RTLD_NOW || RTLD_LOCAL), which evaluates to the
  // integer 1 and therefore passed the wrong mode to dlopen.
  void* lib = dlopen("libandroid.so", RTLD_NOW | RTLD_LOCAL);
  if (lib != nullptr) {
    // dlsym may return nullptr if the symbol is absent; callers only invoke
    // these pointers when TRACE_ENABLED is set on a platform that has them.
    Trace::ATrace_beginSection = reinterpret_cast<fp_ATrace_beginSection>(
        dlsym(lib, "ATrace_beginSection"));
    Trace::ATrace_endSection =
        reinterpret_cast<fp_ATrace_endSection>(dlsym(lib, "ATrace_endSection"));
  }
#endif
}
|
||||
|
||||
// NOTE: Codes must be kept in sync with DType.java.
// NOTE: Never serialize these, because they can change between releases.
constexpr static int kTensorDTypeUInt8 = 1;
constexpr static int kTensorDTypeInt8 = 2;
constexpr static int kTensorDTypeInt32 = 3;
constexpr static int kTensorDTypeFloat32 = 4;
constexpr static int kTensorDTypeInt64 = 5;
constexpr static int kTensorDTypeFloat64 = 6;

// Memory-format codes exchanged with the Java side; consumed by newAtTensor
// and produced by TensorHybrid::newJTensorFromAtTensor below.
constexpr static int kTensorMemoryFormatContiguous = 1;
constexpr static int kTensorMemoryFormatChannelsLast = 2;
constexpr static int kTensorMemoryFormatChannelsLast3d = 3;
|
||||
|
||||
// Minimal fbjni wrapper over java.util.HashMap, exposed through the JMap
// interface so it can be handed to Java methods expecting a Map.
// K and V default to jobject; put() is type-erased regardless.
template <typename K = jobject, typename V = jobject>
struct JHashMap
    : facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>> {
  constexpr static auto kJavaDescriptor = "Ljava/util/HashMap;";

  using Super =
      facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>>;

  // Constructs an empty java.util.HashMap instance.
  static facebook::jni::local_ref<JHashMap<K, V>> create() {
    return Super::newInstance();
  }

  // Invokes HashMap.put(key, value). The previous mapping Java returns is
  // discarded. The method id is resolved once and cached in a local static.
  void put(
      facebook::jni::alias_ref<facebook::jni::JObject::javaobject> key,
      facebook::jni::alias_ref<facebook::jni::JObject::javaobject> value) {
    static auto putMethod =
        Super::javaClassStatic()
            ->template getMethod<facebook::jni::alias_ref<
                facebook::jni::JObject::javaobject>(
                facebook::jni::alias_ref<facebook::jni::JObject::javaobject>,
                facebook::jni::alias_ref<facebook::jni::JObject::javaobject>)>(
                "put");
    putMethod(Super::self(), key, value);
  }
};
|
||||
|
||||
// Builds an at::Tensor that aliases the memory of a Java direct buffer.
// The returned tensor does NOT own the buffer: the Java side must keep the
// buffer alive for the tensor's lifetime.
//
// jbuffer       - direct ByteBuffer holding the tensor data
// jshape        - tensor sizes (one jlong per dimension)
// jdtype        - one of the kTensorDType* codes (synced with DType.java)
// jmemoryFormat - one of the kTensorMemoryFormat* codes
// Throws IllegalArgumentException for unknown dtypes or when the buffer
// capacity (in elements) does not match the product of the sizes.
static at::Tensor newAtTensor(
    facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
    facebook::jni::alias_ref<jlongArray> jshape,
    jint jdtype,
    jint jmemoryFormat) {
  const auto rank = jshape->size();
  const auto shapeArr = jshape->getRegion(0, rank);
  std::vector<int64_t> shapeVec{};
  shapeVec.reserve(rank);
  // BUGFIX: accumulate the element count in 64 bits. `auto numel = 1`
  // deduced int, which could overflow for large tensors before the capacity
  // check below (and was compared against the jlong buffer capacity).
  int64_t numel = 1;
  for (const auto i : c10::irange(rank)) {
    shapeVec.push_back(shapeArr[i]);
    numel *= shapeArr[i];
  }
  JNIEnv* jni = facebook::jni::Environment::current();
  caffe2::TypeMeta typeMeta{};
  int dataElementSizeBytes = 0;
  if (kTensorDTypeFloat32 == jdtype) {
    dataElementSizeBytes = 4;
    typeMeta = caffe2::TypeMeta::Make<float>();
  } else if (kTensorDTypeInt32 == jdtype) {
    dataElementSizeBytes = 4;
    typeMeta = caffe2::TypeMeta::Make<int32_t>();
  } else if (kTensorDTypeInt8 == jdtype) {
    dataElementSizeBytes = 1;
    typeMeta = caffe2::TypeMeta::Make<int8_t>();
  } else if (kTensorDTypeUInt8 == jdtype) {
    dataElementSizeBytes = 1;
    typeMeta = caffe2::TypeMeta::Make<uint8_t>();
  } else if (kTensorDTypeFloat64 == jdtype) {
    dataElementSizeBytes = 8;
    typeMeta = caffe2::TypeMeta::Make<double>();
  } else if (kTensorDTypeInt64 == jdtype) {
    dataElementSizeBytes = 8;
    typeMeta = caffe2::TypeMeta::Make<int64_t>();
  } else {
    facebook::jni::throwNewJavaException(
        facebook::jni::gJavaLangIllegalArgumentException,
        "Unknown Tensor jdtype %d",
        jdtype);
  }
  // GetDirectBufferCapacity reports capacity in ELEMENTS of the buffer's
  // type (jlong return); it must equal the tensor's element count.
  const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
  if (dataCapacity != numel) {
    // BUGFIX: the 64-bit values were previously formatted with %d, which is
    // undefined behavior for jlong/int64_t varargs on LP64 platforms.
    facebook::jni::throwNewJavaException(
        facebook::jni::gJavaLangIllegalArgumentException,
        "Tensor dimensions(elements number:%lld, element byte size:%d, total "
        "bytes:%lld) inconsistent with buffer capacity(%lld)",
        (long long)numel,
        dataElementSizeBytes,
        (long long)(numel * dataElementSizeBytes),
        (long long)dataCapacity);
  }

  if (jmemoryFormat == kTensorMemoryFormatChannelsLast) {
    auto sizes = torch::IntArrayRef(shapeVec);
    // Explicit NHWC strides; sizes/strides are copied by from_blob.
    return torch::from_blob(
        jni->GetDirectBufferAddress(jbuffer.get()),
        sizes,
        torch::IntArrayRef(c10::get_channels_last_strides_2d(sizes)),
        at::TensorOptions(typeMeta).memory_format(
            at::MemoryFormat::ChannelsLast));
  } else if (jmemoryFormat == kTensorMemoryFormatChannelsLast3d) {
    auto sizes = torch::IntArrayRef(shapeVec);
    return torch::from_blob(
        jni->GetDirectBufferAddress(jbuffer.get()),
        sizes,
        torch::IntArrayRef(c10::get_channels_last_strides_3d(sizes)),
        at::TensorOptions(typeMeta).memory_format(
            at::MemoryFormat::ChannelsLast3d));
  }
  // Default: contiguous layout with implicit row-major strides.
  return torch::from_blob(
      jni->GetDirectBufferAddress(jbuffer.get()),
      torch::IntArrayRef(shapeVec),
      at::TensorOptions(typeMeta));
}
|
||||
|
||||
// Native peer of org.pytorch.Tensor. Holds the at::Tensor backing a Java
// Tensor object and converts in both directions. Method/field ids are
// resolved once and cached in local statics.
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
 public:
  constexpr static const char* kJavaDescriptor = "Lorg/pytorch/Tensor;";

  explicit TensorHybrid(at::Tensor tensor) : tensor_(tensor) {}

  // Called from Java to attach a native tensor to an existing Java Tensor:
  // reads the Java object's buffer/shape/dtype/memory-format and wraps the
  // buffer memory via newAtTensor (no copy).
  static facebook::jni::local_ref<TensorHybrid::jhybriddata> initHybrid(
      facebook::jni::alias_ref<TensorHybrid::javaobject> jTensorThis) {
    static auto cls = TensorHybrid::javaClassStatic();
    static const auto jMethodDTypeCode = cls->getMethod<jint()>("dtypeJniCode");
    static const auto jMethodMemoryFormatCode =
        cls->getMethod<jint()>("memoryFormatJniCode");
    static const auto jFieldShape = cls->getField<jlongArray>("shape");
    static const auto jMethodGetDataBuffer = cls->getMethod<
        facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
        "getRawDataBuffer");

    at::Tensor tensor = newAtTensor(
        jMethodGetDataBuffer(jTensorThis),
        jTensorThis->getFieldValue(jFieldShape),
        jMethodDTypeCode(jTensorThis),
        jMethodMemoryFormatCode(jTensorThis));
    return makeCxxInstance(std::move(tensor));
  }

  // Wraps an at::Tensor into a new Java Tensor. The Java buffer aliases the
  // tensor's memory (wrapBytes, no copy); the hybrid instance keeps the
  // tensor alive for the Java object's lifetime.
  static facebook::jni::local_ref<TensorHybrid::javaobject>
  newJTensorFromAtTensor(const at::Tensor& input_tensor) {
    // Java wrapper currently only supports contiguous tensors.

    // Preserve channels-last layouts as-is; anything else is made
    // (row-major) contiguous, which may copy.
    int jmemoryFormat = 0;
    at::Tensor tensor{};
    if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast)) {
      tensor = input_tensor;
      jmemoryFormat = kTensorMemoryFormatChannelsLast;
    } else if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
      tensor = input_tensor;
      jmemoryFormat = kTensorMemoryFormatChannelsLast3d;
    } else {
      tensor = input_tensor.contiguous();
      jmemoryFormat = kTensorMemoryFormatContiguous;
    }

    // Map the ATen scalar type to the Java-side dtype code.
    const auto scalarType = tensor.scalar_type();
    int jdtype = 0;
    if (at::kFloat == scalarType) {
      jdtype = kTensorDTypeFloat32;
    } else if (at::kInt == scalarType) {
      jdtype = kTensorDTypeInt32;
    } else if (at::kByte == scalarType) {
      jdtype = kTensorDTypeUInt8;
    } else if (at::kChar == scalarType) {
      jdtype = kTensorDTypeInt8;
    } else if (at::kLong == scalarType) {
      jdtype = kTensorDTypeInt64;
    } else if (at::kDouble == scalarType) {
      jdtype = kTensorDTypeFloat64;
    } else {
      facebook::jni::throwNewJavaException(
          facebook::jni::gJavaLangIllegalArgumentException,
          "at::Tensor scalar type %s is not supported on java side",
          c10::toString(scalarType));
    }

    // Copy the sizes into a Java long[].
    const auto& tensorShape = tensor.sizes();
    std::vector<jlong> tensorShapeVec;
    for (const auto& s : tensorShape) {
      tensorShapeVec.push_back(s);
    }
    facebook::jni::local_ref<jlongArray> jTensorShape =
        facebook::jni::make_long_array(tensorShapeVec.size());
    jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data());

    static auto cls = TensorHybrid::javaClassStatic();
    // Wrap the tensor's storage as a direct ByteBuffer in native byte order.
    facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
        facebook::jni::JByteBuffer::wrapBytes(
            (uint8_t*)tensor.data_ptr(), tensor.nbytes());
    jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());

    static const auto jMethodNewTensor =
        cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
            facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
            facebook::jni::alias_ref<jlongArray>,
            jint,
            jint,
            facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
    return jMethodNewTensor(
        cls,
        jTensorBuffer,
        jTensorShape,
        jdtype,
        jmemoryFormat,
        makeCxxInstance(tensor));
  }

  // Reads a Java Tensor's buffer/shape/dtype/memory-format and wraps them
  // into an at::Tensor aliasing the Java buffer (no copy).
  static at::Tensor newAtTensorFromJTensor(
      facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor) {
    static auto cls = TensorHybrid::javaClassStatic();
    static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
    jint jdtype = dtypeMethod(jtensor);

    static const auto memoryFormatMethod =
        cls->getMethod<jint()>("memoryFormatJniCode");
    jint jmemoryFormat = memoryFormatMethod(jtensor);

    static const auto shapeField = cls->getField<jlongArray>("shape");
    auto jshape = jtensor->getFieldValue(shapeField);

    static auto dataBufferMethod = cls->getMethod<
        facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
        "getRawDataBuffer");
    facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
        dataBufferMethod(jtensor);
    return newAtTensor(jbuffer, jshape, jdtype, jmemoryFormat);
  }

  // Returns a copy of the wrapped tensor handle (cheap: shares storage).
  at::Tensor tensor() const {
    return tensor_;
  }

 private:
  friend HybridBase;
  at::Tensor tensor_;
};
|
||||
|
||||
// Converts a string-keyed c10 Dict into a Java IValue by building a
// java.util.HashMap<String, IValue> and handing it to
// IValue.dictStringKeyFrom. Values are converted recursively.
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromStringDict(
    c10::Dict<c10::IValue, c10::IValue> dict) {
  static auto jMethodDictStringKey =
      JIValue::javaClassStatic()
          ->getStaticMethod<facebook::jni::local_ref<JIValue>(
              facebook::jni::alias_ref<facebook::jni::JMap<
                  facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
                  facebook::jni::alias_ref<JIValue::javaobject>>>)>(
              "dictStringKeyFrom");

  auto jmap = JHashMap<
      facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
      facebook::jni::alias_ref<JIValue::javaobject>>::create();
  for (auto& pair : dict) {
    jmap->put(
        facebook::jni::make_jstring(pair.key().toStringRef()),
        JIValue::newJIValueFromAtIValue(pair.value()));
  }
  return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
}
|
||||
|
||||
// Converts an int-keyed c10 Dict into a Java IValue by building a
// java.util.HashMap<Long, IValue> and handing it to IValue.dictLongKeyFrom.
// Values are converted recursively.
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromIntDict(
    c10::Dict<c10::IValue, c10::IValue> dict) {
  static auto jMethodDictLongKey =
      JIValue::javaClassStatic()
          ->getStaticMethod<facebook::jni::local_ref<JIValue>(
              facebook::jni::alias_ref<facebook::jni::JMap<
                  facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
                  facebook::jni::alias_ref<JIValue::javaobject>>>)>(
              "dictLongKeyFrom");
  auto jmap = JHashMap<
      facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
      facebook::jni::alias_ref<JIValue::javaobject>>::create();
  for (auto& pair : dict) {
    jmap->put(
        facebook::jni::JLong::valueOf(pair.key().toInt()),
        JIValue::newJIValueFromAtIValue(pair.value()));
  }
  return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
}
|
||||
|
||||
// Converts a native at::IValue into its Java org.pytorch.IValue wrapper,
// dispatching on the IValue's runtime tag. Containers are converted
// recursively. Generic dicts are routed through the supplied callbacks
// (defaulting to newJIValueFromStringDict / newJIValueFromIntDict).
// Unsupported tags raise a Java IllegalArgumentException.
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
    const at::IValue& ivalue,
    DictCallback stringDictCallback,
    DictCallback intDictCallback) {
  // Scoped systrace section (no-op unless tracing is enabled on Android).
  Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
  if (ivalue.isNone()) {
    static auto jMethodOptionalNull =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>()>(
                "optionalNull");
    return jMethodOptionalNull(JIValue::javaClassStatic());
  } else if (ivalue.isTensor()) {
    static auto jMethodTensor =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>(
                facebook::jni::local_ref<TensorHybrid::javaobject>)>("from");
    const auto& tensor = ivalue.toTensor();
    return jMethodTensor(
        JIValue::javaClassStatic(),
        TensorHybrid::newJTensorFromAtTensor(tensor));
  } else if (ivalue.isBool()) {
    static auto jMethodBool =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>(jboolean)>(
                "from");
    return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool());
  } else if (ivalue.isInt()) {
    static auto jMethodInt =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>(jlong)>("from");
    return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt());
  } else if (ivalue.isDouble()) {
    static auto jMethodDouble =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>(jdouble)>(
                "from");
    return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble());
  } else if (ivalue.isString()) {
    static auto jMethodString =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>(
                facebook::jni::alias_ref<facebook::jni::JString::javaobject>)>(
                "from");
    return jMethodString(
        JIValue::javaClassStatic(),
        facebook::jni::make_jstring(ivalue.toStringRef()));
  } else if (ivalue.isTuple()) {
    // Tuples become IValue[] on the Java side; elements converted recursively.
    auto elementsVec = ivalue.toTupleRef().elements();
    static auto jMethodTupleArr =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>(
                facebook::jni::alias_ref<facebook::jni::JArrayClass<
                    JIValue::javaobject>::javaobject>)>("tupleFrom");
    auto jElementsArray =
        facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
            elementsVec.size());
    auto index = 0;
    for (const auto& e : elementsVec) {
      (*jElementsArray)[index++] = JIValue::newJIValueFromAtIValue(e);
    }
    return jMethodTupleArr(JIValue::javaClassStatic(), jElementsArray);
  } else if (ivalue.isBoolList()) {
    // Specialized lists are copied into primitive Java arrays via pinning.
    auto list = ivalue.toBoolList();
    static auto jMethodBoolListArr =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>(
                facebook::jni::alias_ref<jbooleanArray>)>("listFrom");
    size_t n = list.size();
    auto jArray = facebook::jni::make_boolean_array(n);
    auto jArrayPinned = jArray->pin();
    auto index = 0;
    for (const auto& e : list) {
      jArrayPinned[index++] = e;
    }
    return jMethodBoolListArr(JIValue::javaClassStatic(), jArray);
  } else if (ivalue.isIntList()) {
    auto list = ivalue.toIntList();
    static auto jMethodLongListArr =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>(
                facebook::jni::alias_ref<jlongArray>)>("listFrom");
    size_t n = list.size();
    auto jArray = facebook::jni::make_long_array(n);
    auto jArrayPinned = jArray->pin();
    auto index = 0;
    for (const auto& e : list) {
      jArrayPinned[index++] = e;
    }
    return jMethodLongListArr(JIValue::javaClassStatic(), jArray);
  } else if (ivalue.isDoubleList()) {
    auto list = ivalue.toDoubleList();
    static auto jMethoDoubleListArr =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>(
                facebook::jni::alias_ref<jdoubleArray>)>("listFrom");
    size_t n = list.size();
    auto jArray = facebook::jni::make_double_array(n);
    auto jArrayPinned = jArray->pin();
    auto index = 0;
    for (const auto& e : list) {
      jArrayPinned[index++] = e;
    }
    return jMethoDoubleListArr(JIValue::javaClassStatic(), jArray);
  } else if (ivalue.isTensorList()) {
    auto list = ivalue.toTensorList();
    static auto jMethodTensorListArr =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>(
                facebook::jni::alias_ref<facebook::jni::JArrayClass<
                    TensorHybrid::javaobject>::javaobject>)>("listFrom");
    auto jArray =
        facebook::jni::JArrayClass<TensorHybrid::javaobject>::newArray(
            list.size());
    auto index = 0;
    for (const auto& e : list) {
      (*jArray)[index++] = TensorHybrid::newJTensorFromAtTensor(e);
    }
    return jMethodTensorListArr(JIValue::javaClassStatic(), jArray);
  } else if (ivalue.isList()) {
    // Generic (heterogeneously-typed) list: elements converted recursively.
    auto list = ivalue.toList();
    static auto jMethodListArr =
        JIValue::javaClassStatic()
            ->getStaticMethod<facebook::jni::local_ref<JIValue>(
                facebook::jni::alias_ref<facebook::jni::JArrayClass<
                    JIValue::javaobject>::javaobject>)>("listFrom");
    auto jArray =
        facebook::jni::JArrayClass<JIValue::javaobject>::newArray(list.size());
    auto index = 0;
    for (const auto& e : list) {
      (*jArray)[index++] = JIValue::newJIValueFromAtIValue(e);
    }
    return jMethodListArr(JIValue::javaClassStatic(), jArray);
  } else if (ivalue.isGenericDict()) {
    auto dict = ivalue.toGenericDict();
    const auto keyType = dict.keyType();

    if (!keyType) {
      facebook::jni::throwNewJavaException(
          facebook::jni::gJavaLangIllegalArgumentException,
          "Unknown IValue-Dict key type");
    }

    // Only string- and int-keyed dicts have Java representations.
    if (*keyType == *c10::StringType::get()) {
      return stringDictCallback(std::move(dict));
    } else if (*keyType == *c10::IntType::get()) {
      return intDictCallback(std::move(dict));
    }

    facebook::jni::throwNewJavaException(
        facebook::jni::gJavaLangIllegalArgumentException,
        "Unsupported IValue-Dict key type: %s",
        keyType->str().c_str());
  }

  facebook::jni::throwNewJavaException(
      facebook::jni::gJavaLangIllegalArgumentException,
      "Unsupported IValue type %s",
      ivalue.tagKind().c_str());
}
|
||||
|
||||
// Converts a Java org.pytorch.IValue into a native at::IValue, dispatching
// on the Java object's mTypeCode field (kTypeCode* constants, synced with
// IValue.java). Containers are converted recursively. Unknown codes raise a
// Java IllegalArgumentException.
at::IValue JIValue::JIValueToAtIValue(
    facebook::jni::alias_ref<JIValue> jivalue) {
  // Scoped systrace section (no-op unless tracing is enabled on Android).
  Trace _s{"jni::JIValue::JIValueToAtIValue"};
  static const auto typeCodeField =
      JIValue::javaClassStatic()->getField<jint>("mTypeCode");
  const auto typeCode = jivalue->getFieldValue(typeCodeField);
  if (JIValue::kTypeCodeNull == typeCode) {
    return at::IValue{};
  } else if (JIValue::kTypeCodeTensor == typeCode) {
    static const auto jMethodGetTensor =
        JIValue::javaClassStatic()
            ->getMethod<facebook::jni::alias_ref<TensorHybrid::javaobject>()>(
                "toTensor");
    return TensorHybrid::newAtTensorFromJTensor(jMethodGetTensor(jivalue));
  } else if (JIValue::kTypeCodeBool == typeCode) {
    static const auto jMethodGetBool =
        JIValue::javaClassStatic()->getMethod<jboolean()>("toBool");
    // explicit cast to bool as jboolean is defined as uint8_t, IValue ctor
    // for int will be called for jboolean
    bool b = jMethodGetBool(jivalue);
    return at::IValue{b};
  } else if (JIValue::kTypeCodeLong == typeCode) {
    static const auto jMethodGetLong =
        JIValue::javaClassStatic()->getMethod<jlong()>("toLong");
    return at::IValue{(int64_t)jMethodGetLong(jivalue)};
  } else if (JIValue::kTypeCodeDouble == typeCode) {
    static const auto jMethodGetDouble =
        JIValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
    return at::IValue{jMethodGetDouble(jivalue)};
  } else if (JIValue::kTypeCodeString == typeCode) {
    static const auto jMethodGetString =
        JIValue::javaClassStatic()->getMethod<jstring()>("toStr");
    return at::IValue{jMethodGetString(jivalue)->toStdString()};
  } else if (JIValue::kTypeCodeTuple == typeCode) {
    static const auto jMethodGetTuple =
        JIValue::javaClassStatic()
            ->getMethod<
                facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
                "toTuple");
    auto jarray = jMethodGetTuple(jivalue);
    size_t n = jarray->size();

    std::vector<at::IValue> elements;
    elements.reserve(n);
    for (const auto i : c10::irange(n)) {
      auto jivalue_element = jarray->getElement(i);
      auto element = JIValue::JIValueToAtIValue(jivalue_element);
      elements.push_back(std::move(element));
    }
    return c10::ivalue::Tuple::create(std::move(elements));
  } else if (JIValue::kTypeCodeBoolList == typeCode) {
    // Primitive Java arrays are read through a pinned view, then copied
    // element-by-element into the typed c10::List.
    static const auto jMethodGetBoolList =
        JIValue::javaClassStatic()->getMethod<jbooleanArray()>("toBoolList");
    auto jArray = jMethodGetBoolList(jivalue);
    auto jArrayPinned = jArray->pin();
    size_t n = jArrayPinned.size();
    c10::List<bool> list{};
    list.reserve(n);
    for (const auto i : c10::irange(n)) {
      list.push_back(jArrayPinned[i]);
    }
    return at::IValue{std::move(list)};
  } else if (JIValue::kTypeCodeLongList == typeCode) {
    static const auto jMethodGetLongList =
        JIValue::javaClassStatic()->getMethod<jlongArray()>("toLongList");
    auto jArray = jMethodGetLongList(jivalue);
    auto jArrayPinned = jArray->pin();
    size_t n = jArrayPinned.size();
    c10::List<int64_t> list{};
    list.reserve(n);
    for (const auto i : c10::irange(n)) {
      list.push_back(jArrayPinned[i]);
    }
    return at::IValue{std::move(list)};
  } else if (JIValue::kTypeCodeDoubleList == typeCode) {
    static const auto jMethodGetDoubleList =
        JIValue::javaClassStatic()->getMethod<jdoubleArray()>("toDoubleList");
    auto jArray = jMethodGetDoubleList(jivalue);
    auto jArrayPinned = jArray->pin();
    size_t n = jArrayPinned.size();
    c10::List<double> list{};
    list.reserve(n);
    for (const auto i : c10::irange(n)) {
      list.push_back(jArrayPinned[i]);
    }
    return at::IValue{std::move(list)};
  } else if (JIValue::kTypeCodeTensorList == typeCode) {
    static const auto jMethodGetTensorList =
        JIValue::javaClassStatic()
            ->getMethod<facebook::jni::JArrayClass<
                TensorHybrid::javaobject>::javaobject()>("toTensorList");
    auto jArray = jMethodGetTensorList(jivalue);
    size_t n = jArray->size();
    c10::List<at::Tensor> list{};
    list.reserve(n);
    for (const auto i : c10::irange(n)) {
      list.push_back(
          TensorHybrid::newAtTensorFromJTensor(jArray->getElement(i)));
    }
    return at::IValue{std::move(list)};
  } else if (JIValue::kTypeCodeList == typeCode) {
    static const auto jMethodGetList =
        JIValue::javaClassStatic()
            ->getMethod<
                facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
                "toList");
    auto jarray = jMethodGetList(jivalue);
    size_t n = jarray->size();
    // An empty generic list carries no element type information from Java;
    // it is materialized with TensorType as the element type.
    if (n == 0) {
      return at::IValue{c10::impl::GenericList(c10::TensorType::get())};
    }

    // The list's element type is taken from the first element (unshaped);
    // remaining elements are assumed to be of the same type.
    auto jivalue_first_element = jarray->getElement(0);
    auto first_element = JIValue::JIValueToAtIValue(jivalue_first_element);
    c10::impl::GenericList list{c10::unshapedType(first_element.type())};
    list.reserve(n);
    list.push_back(first_element);
    for (const auto i : c10::irange(1, n)) {
      auto jivalue_element = jarray->getElement(i);
      auto element = JIValue::JIValueToAtIValue(jivalue_element);
      list.push_back(element);
    }
    return at::IValue{list};
  } else if (JIValue::kTypeCodeDictStringKey == typeCode) {
    static const auto jMethodGetDictStringKey =
        JIValue::javaClassStatic()
            ->getMethod<facebook::jni::JMap<jstring, JIValue::javaobject>::
                            javaobject()>("toDictStringKey");
    auto jmap = jMethodGetDictStringKey(jivalue);
    auto it = jmap->begin();
    // Empty dict: no value type information; default to Tensor values.
    if (it == jmap->end()) {
      return at::IValue{c10::impl::GenericDict(
          c10::StringType::get(), c10::TensorType::get())};
    }

    // Value type is taken from the first entry (unshaped).
    auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
    c10::impl::GenericDict dict{
        c10::StringType::get(), c10::unshapedType(firstEntryValue.type())};
    dict.insert(it->first->toStdString(), firstEntryValue);
    it++;
    for (; it != jmap->end(); it++) {
      dict.insert(
          it->first->toStdString(), JIValue::JIValueToAtIValue(it->second));
    }
    return at::IValue{dict};
  } else if (JIValue::kTypeCodeDictLongKey == typeCode) {
    static const auto jMethodGetDictLongKey =
        JIValue::javaClassStatic()
            ->getMethod<facebook::jni::JMap<
                facebook::jni::JLong::javaobject,
                JIValue::javaobject>::javaobject()>("toDictLongKey");
    auto jmap = jMethodGetDictLongKey(jivalue);
    auto it = jmap->begin();
    if (it == jmap->end()) {
      return at::IValue{
          c10::impl::GenericDict(c10::IntType::get(), c10::TensorType::get())};
    }

    auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
    c10::impl::GenericDict dict{
        c10::IntType::get(), c10::unshapedType(firstEntryValue.type())};
    dict.insert((int64_t)it->first->longValue(), firstEntryValue);
    it++;
    for (; it != jmap->end(); it++) {
      dict.insert(
          (int64_t)it->first->longValue(),
          JIValue::JIValueToAtIValue(it->second));
    }
    return at::IValue{dict};
  }

  facebook::jni::throwNewJavaException(
      facebook::jni::gJavaLangIllegalArgumentException,
      "Unknown IValue typeCode %d",
      typeCode);
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
|
||||
public:
|
||||
constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
|
||||
|
||||
static void registerNatives() {
|
||||
javaClassStatic()->registerNatives({
|
||||
makeNativeMethod(
|
||||
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
|
||||
});
|
||||
}
|
||||
|
||||
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
|
||||
caffe2::pthreadpool()->set_thread_count(numThreads);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// Registers the JNI natives shared across builds. Runs at most once per
// process: the work is performed inside the thread-safe initialization of a
// function-local static.
void common_registerNatives() {
  static const bool registered = [] {
#if defined(__ANDROID__)
    pytorch_jni::PyTorchAndroidJni::registerNatives();
#endif
    return true;
  }();
  (void)registered;
}
|
||||
|
||||
} // namespace pytorch_jni
|
||||
@ -1,137 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/FunctionRef.h>
|
||||
#include <fbjni/fbjni.h>
|
||||
#include <torch/csrc/api/include/torch/types.h>
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
#include "cmake_macros.h"
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#include <android/log.h>
|
||||
#define ALOGI(...) \
|
||||
__android_log_print(ANDROID_LOG_INFO, "pytorch-jni", __VA_ARGS__)
|
||||
#define ALOGE(...) \
|
||||
__android_log_print(ANDROID_LOG_ERROR, "pytorch-jni", __VA_ARGS__)
|
||||
#endif
|
||||
|
||||
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||
#include <android/trace.h>
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
constexpr static int kDeviceCPU = 1;
|
||||
constexpr static int kDeviceVulkan = 2;
|
||||
|
||||
c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode);
|
||||
|
||||
// RAII helper around the Android ATrace (systrace) API: constructing a Trace
// opens a named section, the destructor closes it. All operations compile to
// no-ops unless TRACE_ENABLED is configured and the target is Android.
class Trace {
 public:
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
  typedef void* (*fp_ATrace_beginSection)(const char* sectionName);
  typedef void* (*fp_ATrace_endSection)(void);

  // Resolved from libandroid.so at runtime by init()
  // (defined in pytorch_jni_common.cpp).
  static fp_ATrace_beginSection ATrace_beginSection;
  static fp_ATrace_endSection ATrace_endSection;
#endif

  // Runs init() once per process.
  // NOTE(review): the check-then-set on is_initialized_ is not synchronized;
  // concurrent first calls could race — confirm first use is single-threaded.
  static void ensureInit() {
    if (!Trace::is_initialized_) {
      init();
      Trace::is_initialized_ = true;
    }
  }

  // Opens a named trace section.
  static void beginSection(const char* name) {
    Trace::ensureInit();
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
    ATrace_beginSection(name);
#endif
  }

  // Closes the most recently opened trace section.
  static void endSection() {
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
    ATrace_endSection();
#endif
  }

  Trace(const char* name) {
    ensureInit();
    beginSection(name);
  }

  ~Trace() {
    endSection();
  }

 private:
  static void init();
  static bool is_initialized_;
};
|
||||
|
||||
class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
|
||||
public:
|
||||
explicit MemoryReadAdapter(const void* data, off_t size)
|
||||
: data_(data), size_(size){};
|
||||
|
||||
size_t size() const override {
|
||||
return size_;
|
||||
}
|
||||
|
||||
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
|
||||
const override {
|
||||
memcpy(buf, (int8_t*)(data_) + pos, n);
|
||||
return n;
|
||||
}
|
||||
|
||||
~MemoryReadAdapter() {}
|
||||
|
||||
private:
|
||||
const void* data_;
|
||||
off_t size_;
|
||||
};
|
||||
|
||||
// C++ peer of org.pytorch.IValue. Provides the two-way conversion between
// at::IValue and the Java IValue wrapper (implemented in
// pytorch_jni_common.cpp).
class JIValue : public facebook::jni::JavaClass<JIValue> {
  // Hook type used to convert generic dicts; lets callers override how
  // string-/int-keyed dicts are materialized on the Java side.
  using DictCallback = c10::function_ref<facebook::jni::local_ref<JIValue>(
      c10::Dict<c10::IValue, c10::IValue>)>;

 public:
  constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";

  // Type codes; must be kept in sync with the mTypeCode values in IValue.java.
  constexpr static int kTypeCodeNull = 1;

  constexpr static int kTypeCodeTensor = 2;
  constexpr static int kTypeCodeBool = 3;
  constexpr static int kTypeCodeLong = 4;
  constexpr static int kTypeCodeDouble = 5;
  constexpr static int kTypeCodeString = 6;

  constexpr static int kTypeCodeTuple = 7;
  constexpr static int kTypeCodeBoolList = 8;
  constexpr static int kTypeCodeLongList = 9;
  constexpr static int kTypeCodeDoubleList = 10;
  constexpr static int kTypeCodeTensorList = 11;
  constexpr static int kTypeCodeList = 12;

  constexpr static int kTypeCodeDictStringKey = 13;
  constexpr static int kTypeCodeDictLongKey = 14;

  // Native -> Java conversion.
  static facebook::jni::local_ref<JIValue> newJIValueFromAtIValue(
      const at::IValue& ivalue,
      DictCallback stringDictCallback = newJIValueFromStringDict,
      DictCallback intDictCallback = newJIValueFromIntDict);

  // Java -> native conversion.
  static at::IValue JIValueToAtIValue(
      facebook::jni::alias_ref<JIValue> jivalue);

 private:
  static facebook::jni::local_ref<JIValue> newJIValueFromStringDict(
      c10::Dict<c10::IValue, c10::IValue>);
  static facebook::jni::local_ref<JIValue> newJIValueFromIntDict(
      c10::Dict<c10::IValue, c10::IValue>);
};
|
||||
|
||||
void common_registerNatives();
|
||||
} // namespace pytorch_jni
|
||||
@ -1,245 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <fbjni/ByteBuffer.h>
|
||||
#include <fbjni/fbjni.h>
|
||||
|
||||
#include <ATen/record_function.h>
|
||||
#include <torch/csrc/jit/runtime/print_handler.h>
|
||||
#include <torch/script.h>
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
#include "pytorch_jni_common.h"
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
#include <android/log.h>
|
||||
#endif
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
namespace {

// RAII bundle of guards installed around every call into the TorchScript
// runtime from JNI. Constructing one on the stack puts the thread into
// inference mode and freezes the graph optimizer for the guard's lifetime.
struct JITCallGuard {
  // Inference only workload.
  c10::InferenceMode guard;
  // Disable graph optimizer to ensure list of unused ops are not changed for
  // custom mobile build.
  torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};

} // namespace
|
||||
|
||||
// Native peer of org.pytorch.NativePeer: owns a loaded torch::jit::Module
// and exposes load/forward/runMethod to Java via fbjni.
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
 private:
  friend HybridBase;
  torch::jit::Module module_;
  c10::DeviceType deviceType_;

 public:
  // JNI descriptor of the Java peer class (org.pytorch.NativePeer).
  constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";

  // Entry point for NativePeer.initHybrid: loads a TorchScript module
  // from a filesystem path.
  static facebook::jni::local_ref<jhybriddata> initHybrid(
      facebook::jni::alias_ref<jclass>,
      facebook::jni::alias_ref<jstring> modelPath,
      facebook::jni::alias_ref<
          facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
          extraFiles,
      jint device) {
    return makeCxxInstance(modelPath, extraFiles, device);
  }

#ifdef __ANDROID__
  // Entry point for loading a model bundled as an Android APK asset.
  static facebook::jni::local_ref<jhybriddata> initHybridAndroidAsset(
      facebook::jni::alias_ref<jclass>,
      facebook::jni::alias_ref<jstring> assetName,
      facebook::jni::alias_ref<jobject> assetManager,
      jint device) {
    return makeCxxInstance(assetName, assetManager, device);
  }
#endif

#ifdef TRACE_ENABLED
  // RecordFunction observer hooks that bracket profiled ops with Trace
  // sections.
  static std::unique_ptr<at::ObserverContext> onFunctionEnter(
      const at::RecordFunction& fn) {
    Trace::beginSection(fn.name().str());
    return nullptr;
  }

  static void onFunctionExit(const at::RecordFunction&, at::ObserverContext*) {
    Trace::endSection();
  }
#endif

  // One-time process-wide setup performed before the first module load:
  // selects the QNNPACK quantized engine when available, routes
  // TorchScript print() output to logcat on Android, and installs the
  // tracing callbacks when TRACE_ENABLED is defined.
  static void preModuleLoadSetupOnce() {
    auto qengines = at::globalContext().supportedQEngines();
    if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
        qengines.end()) {
      at::globalContext().setQEngine(at::QEngine::QNNPACK);
    }

#ifdef __ANDROID__
    torch::jit::setPrintHandler([](const std::string& s) {
      __android_log_print(ANDROID_LOG_DEBUG, "pytorch-print", "%s", s.c_str());
    });
#endif

#ifdef TRACE_ENABLED
    at::addGlobalCallback(
        at::RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
            .scopes({RecordScope::FUNCTION, RecordScope::USER_SCOPE}));
#endif
  }

  // Runs preModuleLoadSetupOnce() exactly once; the function-local static
  // makes the initialization thread-safe (C++11 magic statics).
  void preModuleLoadSetup() {
    static const int once = []() {
      preModuleLoadSetupOnce();
      return 0;
    }();
    ((void)once);
  }

  // Loads a TorchScript module from |modelPath|. Keys present in
  // |extraFiles| are read from the model archive during load and their
  // contents are written back into the Java map for the caller.
  PytorchJni(
      facebook::jni::alias_ref<jstring> modelPath,
      facebook::jni::alias_ref<
          facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
          extraFiles,
      jint device) {
    preModuleLoadSetup();
    JITCallGuard guard;
    std::unordered_map<std::string, std::string> extra_files;
    const auto has_extra = extraFiles && extraFiles->size() > 0;
    if (has_extra) {
      // Seed the native map with the requested keys; torch::jit::load
      // fills in the values.
      for (const auto& e : *extraFiles) {
        extra_files[e.first->toStdString()] = "";
      }
    }
    deviceType_ = deviceJniCodeToDeviceType(device);
    // Pass the toStdString() prvalue directly; the original wrapped it in
    // a redundant std::move (clang-tidy performance-move-const-arg).
    module_ = torch::jit::load(
        modelPath->toStdString(), std::nullopt, extra_files);
    if (has_extra) {
      // Mirror the loaded extra-file contents back into the Java map.
      static auto putMethod =
          facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>::
              javaClassStatic()
                  ->template getMethod<facebook::jni::alias_ref<jobject>(
                      facebook::jni::alias_ref<jobject>,
                      facebook::jni::alias_ref<jobject>)>("put");
      for (const auto& ef : extra_files) {
        putMethod(
            extraFiles,
            facebook::jni::make_jstring(ef.first),
            facebook::jni::make_jstring(ef.second));
      }
    }

    module_.eval();
  }

#ifdef __ANDROID__
  // Loads a TorchScript module from an APK asset via AAssetManager.
  // Throws IllegalArgumentException back to Java on any failure.
  PytorchJni(
      facebook::jni::alias_ref<jstring> assetName,
      facebook::jni::alias_ref<jobject> assetManager,
      jint device) {
    preModuleLoadSetup();
    JNIEnv* env = facebook::jni::Environment::current();
    AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get());
    if (!mgr) {
      facebook::jni::throwNewJavaException(
          facebook::jni::gJavaLangIllegalArgumentException,
          "Unable to get asset manager");
    }
    // Convert once instead of re-materializing the std::string at each use.
    const std::string assetNameStr = assetName->toStdString();
    AAsset* asset =
        AAssetManager_open(mgr, assetNameStr.c_str(), AASSET_MODE_BUFFER);
    if (!asset) {
      facebook::jni::throwNewJavaException(
          facebook::jni::gJavaLangIllegalArgumentException,
          "Failed to open asset '%s'",
          assetNameStr.c_str());
    }
    auto assetBuffer = AAsset_getBuffer(asset);
    if (!assetBuffer) {
      // Close the asset before throwing; throwNewJavaException raises a
      // C++ exception, so the original code leaked the AAsset here.
      AAsset_close(asset);
      facebook::jni::throwNewJavaException(
          facebook::jni::gJavaLangIllegalArgumentException,
          "Could not get buffer for asset '%s'",
          assetNameStr.c_str());
    }
    JITCallGuard guard;
    module_ = torch::jit::load(std::make_unique<MemoryReadAdapter>(
        assetBuffer, AAsset_getLength(asset)));
    AAsset_close(asset);
    module_.eval();
    deviceType_ = deviceJniCodeToDeviceType(device);
  }
#endif

  // Registers all native methods of org.pytorch.NativePeer with JNI.
  static void registerNatives() {
    registerHybrid({
        makeNativeMethod("initHybrid", PytorchJni::initHybrid),
#ifdef __ANDROID__
        makeNativeMethod(
            "initHybridAndroidAsset", PytorchJni::initHybridAndroidAsset),
#endif
        makeNativeMethod("forward", PytorchJni::forward),
        makeNativeMethod("runMethod", PytorchJni::runMethod),
    });
  }

  // Runs the module's forward() on the given Java IValue arguments and
  // wraps the result back into a Java IValue.
  facebook::jni::local_ref<JIValue> forward(
      facebook::jni::alias_ref<
          facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
          jinputs) {
    Trace _s{"jni::Module::forward"};
    std::vector<at::IValue> inputs{};
    size_t n = jinputs->size();
    inputs.reserve(n);
    for (const auto i : c10::irange(n)) {
      at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
      inputs.push_back(std::move(atIValue));
    }
    // Scope the guards to the runtime call only.
    auto output = [&]() {
      JITCallGuard guard;
      return module_.forward(std::move(inputs));
    }();
    return JIValue::newJIValueFromAtIValue(output);
  }

  // Runs an arbitrary named method of the module; throws
  // IllegalArgumentException if no such method exists.
  facebook::jni::local_ref<JIValue> runMethod(
      facebook::jni::alias_ref<facebook::jni::JString::javaobject> jmethodName,
      facebook::jni::alias_ref<
          facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
          jinputs) {
    std::string methodName = jmethodName->toStdString();

    std::vector<at::IValue> inputs{};
    size_t n = jinputs->size();
    inputs.reserve(n);
    for (const auto i : c10::irange(n)) {
      at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
      inputs.push_back(std::move(atIValue));
    }
    if (auto method = module_.find_method(methodName)) {
      auto output = [&]() {
        JITCallGuard guard;
        return (*method)(std::move(inputs));
      }();
      return JIValue::newJIValueFromAtIValue(output);
    }

    facebook::jni::throwNewJavaException(
        facebook::jni::gJavaLangIllegalArgumentException,
        "Undefined method %s",
        methodName.c_str());
  }
};
|
||||
|
||||
} // namespace pytorch_jni
|
||||
|
||||
// JNI bootstrap: invoked by the JVM when this library is loaded.
// Hands fbjni an initialization callback that registers every native
// method exposed by this library.
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
  const auto registerAll = [] {
    pytorch_jni::common_registerNatives();
    pytorch_jni::PytorchJni::registerNatives();
  };
  return facebook::jni::initialize(vm, registerAll);
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user