mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
2020-06-16 nightly release (bcb44796ba00f9ac5f22e33f1de3a24b277c4c3a)
This commit is contained in:
@ -37,19 +37,6 @@ CONFIG_TREE_DATA = [
|
||||
]),
|
||||
]),
|
||||
]),
|
||||
("android", [
|
||||
("r19c", [
|
||||
("3.6", [
|
||||
("android_abi", [XImportant("x86_32")]),
|
||||
("android_abi", [X("x86_64")]),
|
||||
("android_abi", [X("arm-v7a")]),
|
||||
("android_abi", [X("arm-v8a")]),
|
||||
("vulkan", [
|
||||
("android_abi", [XImportant("x86_32")]),
|
||||
]),
|
||||
])
|
||||
]),
|
||||
]),
|
||||
]),
|
||||
("bionic", [
|
||||
("clang", [
|
||||
@ -137,8 +124,6 @@ class ExperimentalFeatureConfigNode(TreeConfigNode):
|
||||
"libtorch": LibTorchConfigNode,
|
||||
"important": ImportantConfigNode,
|
||||
"build_only": BuildOnlyConfigNode,
|
||||
"android_abi": AndroidAbiConfigNode,
|
||||
"vulkan": VulkanConfigNode,
|
||||
"cuda_gcc_override": CudaGccOverrideConfigNode
|
||||
}
|
||||
return next_nodes[experimental_feature]
|
||||
@ -188,24 +173,6 @@ class LibTorchConfigNode(TreeConfigNode):
|
||||
return ImportantConfigNode
|
||||
|
||||
|
||||
class AndroidAbiConfigNode(TreeConfigNode):
    """Config-tree node that records the Android ABI named by the node."""

    def init2(self, node_name):
        """Store ``node_name`` under the ``android_abi`` property."""
        self.props["android_abi"] = node_name

    def child_constructor(self):
        """Return the class used to build child nodes."""
        return ImportantConfigNode
|
||||
|
||||
class VulkanConfigNode(TreeConfigNode):
    """Config-tree node for the ``vulkan`` experimental feature."""

    def modify_label(self, label):
        """Return ``label`` rendered as text and prefixed with ``Vulkan=``."""
        return "Vulkan={}".format(str(label))

    def init2(self, node_name):
        """Store ``node_name`` under the ``vulkan`` property."""
        self.props["vulkan"] = node_name

    def child_constructor(self):
        """Return the class used to build child nodes."""
        return AndroidAbiConfigNode
|
||||
|
||||
class CudaGccOverrideConfigNode(TreeConfigNode):
|
||||
def init2(self, node_name):
|
||||
self.props["cuda_gcc_override"] = node_name
|
||||
|
||||
92
.circleci/cimodel/data/simple/android_definitions.py
Normal file
92
.circleci/cimodel/data/simple/android_definitions.py
Normal file
@ -0,0 +1,92 @@
|
||||
import cimodel.data.simple.util.branch_filters
|
||||
from cimodel.data.simple.util.docker_constants import DOCKER_IMAGE_NDK
|
||||
|
||||
|
||||
class AndroidJob:
    """A single Android NDK build job for the generated CI workflow.

    Args:
        variant: Iterable of name parts describing the build flavour, for
            example ``["arm", "v7a"]``. The parts are spliced into the job
            name between the fixed prefix and the trailing ``"build"``.
        template_name: Key under which ``gen_tree`` emits the job properties.
        is_master_only: When true, ``gen_tree`` adds a ``filters`` entry taken
            from ``branch_filters.gen_filter_dict()``.
    """

    def __init__(self,
                 variant,
                 template_name,
                 is_master_only=True):

        self.variant = variant
        self.template_name = template_name
        self.is_master_only = is_master_only

    def gen_tree(self):
        """Return ``[{template_name: props}]`` describing this job.

        ``props`` holds the underscore-joined job ``name``, the quoted
        dash-joined ``build_environment`` and the quoted ``docker_image``,
        plus ``filters`` when ``is_master_only`` is set.
        """
        # list() lets ``variant`` be a tuple (or any other iterable); a bare
        # ``list + variant`` raises TypeError unless it is already a list.
        base_name_parts = [
            "pytorch",
            "linux",
            "xenial",
            "py3",
            "clang5",
            "android",
            "ndk",
            "r19c",
        ] + list(self.variant) + [
            "build",
        ]

        # The job name and the build environment share the same parts and
        # differ only in the separator.
        full_job_name = "_".join(base_name_parts)
        build_env_name = "-".join(base_name_parts)

        props_dict = {
            "name": full_job_name,
            "build_environment": "\"{}\"".format(build_env_name),
            "docker_image": "\"{}\"".format(DOCKER_IMAGE_NDK),
        }

        if self.is_master_only:
            props_dict["filters"] = cimodel.data.simple.util.branch_filters.gen_filter_dict()

        return [{self.template_name: props_dict}]
|
||||
|
||||
|
||||
class AndroidGradleJob:
    """A gradle job that runs after a set of Android NDK build jobs.

    Args:
        job_name: Value emitted as the job's ``name``.
        template_name: Key under which ``gen_tree`` emits the job properties.
        dependencies: Iterable of job names emitted as ``requires``.
        is_master_only: When true, ``gen_tree`` adds a ``filters`` entry taken
            from ``branch_filters.gen_filter_dict()``.
    """

    def __init__(self,
                 job_name,
                 template_name,
                 dependencies,
                 is_master_only=True):

        self.job_name = job_name
        self.template_name = template_name
        self.dependencies = dependencies
        self.is_master_only = is_master_only

    def gen_tree(self):
        """Return ``[{template_name: props}]`` describing this job."""
        props_dict = {
            "name": self.job_name,
            # Emit a fresh list: the generated tree must not alias the list
            # stored on this object, and a tuple is normalised to a list.
            "requires": list(self.dependencies),
        }

        if self.is_master_only:
            props_dict["filters"] = cimodel.data.simple.util.branch_filters.gen_filter_dict()

        return [{self.template_name: props_dict}]
|
||||
|
||||
|
||||
# Android jobs for the build workflow, in emission order. The gradle jobs name
# the NDK build jobs they require by the underscore-joined names those jobs
# generate.
WORKFLOW_DATA = [
    AndroidJob(["x86_32"], "pytorch_linux_build", is_master_only=False),
    AndroidJob(["x86_64"], "pytorch_linux_build"),
    AndroidJob(["arm", "v7a"], "pytorch_linux_build"),
    AndroidJob(["arm", "v8a"], "pytorch_linux_build"),
    AndroidJob(["vulkan", "x86_32"], "pytorch_linux_build", is_master_only=False),
    AndroidGradleJob(
        "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-x86_32",
        "pytorch_android_gradle_build-x86_32",
        [
            "pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build",
        ],
        is_master_only=False,
    ),
    AndroidGradleJob(
        "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build",
        "pytorch_android_gradle_build",
        [
            "pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build",
            "pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build",
            "pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build",
            "pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build",
        ],
    ),
]
|
||||
|
||||
|
||||
def get_workflow_jobs():
    """Return the ``gen_tree()`` result of each ``WORKFLOW_DATA`` entry, in order."""
    job_trees = []
    for job in WORKFLOW_DATA:
        job_trees.append(job.gen_tree())
    return job_trees
|
||||
@ -1,37 +0,0 @@
|
||||
import cimodel.data.simple.util.branch_filters
|
||||
|
||||
|
||||
class AndroidJob:
    """A gradle job entry that always carries the branch filter.

    Args:
        job_name: Value emitted as the job's ``name``.
        template_name: Key under which ``gen_tree`` emits the job properties.
        dependencies: Job names emitted as ``requires``.
    """

    def __init__(self, job_name, template_name, dependencies):
        self.job_name = job_name
        self.template_name = template_name
        self.dependencies = dependencies

    def gen_tree(self):
        """Return ``[{template_name: props}]`` describing this job."""
        props_dict = dict(
            filters=cimodel.data.simple.util.branch_filters.gen_filter_dict(),
            name=self.job_name,
            requires=self.dependencies,
        )
        return [{self.template_name: props_dict}]
|
||||
|
||||
|
||||
# Gradle jobs for the build workflow; each lists the NDK build jobs it
# requires by name.
WORKFLOW_DATA = [
    AndroidJob(
        "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-x86_32",
        "pytorch_android_gradle_build-x86_32",
        [
            "pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build",
        ],
    ),
    AndroidJob(
        "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build",
        "pytorch_android_gradle_build",
        [
            "pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build",
            "pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build",
            "pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build",
            "pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build",
        ],
    ),
]
|
||||
|
||||
|
||||
def get_workflow_jobs():
    """Return the ``gen_tree()`` result of each ``WORKFLOW_DATA`` entry, in order."""
    job_trees = []
    for job in WORKFLOW_DATA:
        job_trees.append(job.gen_tree())
    return job_trees
|
||||
@ -6840,44 +6840,6 @@ workflows:
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7:209062ef-ab58-422a-b295-36c4eed6e906"
|
||||
use_cuda_docker_runtime: "1"
|
||||
resource_class: gpu.medium
|
||||
- pytorch_linux_build:
|
||||
name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build
|
||||
build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_32-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:209062ef-ab58-422a-b295-36c4eed6e906"
|
||||
- pytorch_linux_build:
|
||||
name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
- /release\/.*/
|
||||
build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_64-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:209062ef-ab58-422a-b295-36c4eed6e906"
|
||||
- pytorch_linux_build:
|
||||
name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
- /release\/.*/
|
||||
build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v7a-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:209062ef-ab58-422a-b295-36c4eed6e906"
|
||||
- pytorch_linux_build:
|
||||
name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
- /release\/.*/
|
||||
build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v8a-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:209062ef-ab58-422a-b295-36c4eed6e906"
|
||||
- pytorch_linux_build:
|
||||
name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_vulkan_x86_32_build
|
||||
build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-vulkan-x86_32-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:209062ef-ab58-422a-b295-36c4eed6e906"
|
||||
- pytorch_linux_build:
|
||||
name: pytorch_linux_bionic_py3_6_clang9_build
|
||||
build_environment: "pytorch-linux-bionic-py3.6-clang9-build"
|
||||
@ -6917,13 +6879,45 @@ workflows:
|
||||
name: pytorch_macos_10_13_py3_test
|
||||
requires:
|
||||
- pytorch_macos_10_13_py3_build
|
||||
- pytorch_android_gradle_build-x86_32:
|
||||
- pytorch_linux_build:
|
||||
build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_32-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:209062ef-ab58-422a-b295-36c4eed6e906"
|
||||
name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build
|
||||
- pytorch_linux_build:
|
||||
build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_64-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:209062ef-ab58-422a-b295-36c4eed6e906"
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
- /release\/.*/
|
||||
name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build
|
||||
- pytorch_linux_build:
|
||||
build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v7a-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:209062ef-ab58-422a-b295-36c4eed6e906"
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
- /release\/.*/
|
||||
name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build
|
||||
- pytorch_linux_build:
|
||||
build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v8a-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:209062ef-ab58-422a-b295-36c4eed6e906"
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
- /ci-all\/.*/
|
||||
- /release\/.*/
|
||||
name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build
|
||||
- pytorch_linux_build:
|
||||
build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-vulkan-x86_32-build"
|
||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:209062ef-ab58-422a-b295-36c4eed6e906"
|
||||
name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_vulkan_x86_32_build
|
||||
- pytorch_android_gradle_build-x86_32:
|
||||
name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-x86_32
|
||||
requires:
|
||||
- pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build
|
||||
|
||||
@ -20,7 +20,7 @@ import cimodel.data.simple.bazel_definitions
|
||||
import cimodel.data.simple.ge_config_tests
|
||||
import cimodel.data.simple.binary_smoketest
|
||||
import cimodel.data.simple.ios_definitions
|
||||
import cimodel.data.simple.android_gradle
|
||||
import cimodel.data.simple.android_definitions
|
||||
import cimodel.data.simple.nightly_android
|
||||
import cimodel.data.simple.nightly_ios
|
||||
import cimodel.lib.miniutils as miniutils
|
||||
@ -83,7 +83,7 @@ def gen_build_workflows_tree():
|
||||
build_workflows_functions = [
|
||||
pytorch_build_definitions.get_workflow_jobs,
|
||||
cimodel.data.simple.macos_definitions.get_workflow_jobs,
|
||||
cimodel.data.simple.android_gradle.get_workflow_jobs,
|
||||
cimodel.data.simple.android_definitions.get_workflow_jobs,
|
||||
cimodel.data.simple.ios_definitions.get_workflow_jobs,
|
||||
cimodel.data.simple.mobile_definitions.get_workflow_jobs,
|
||||
cimodel.data.simple.ge_config_tests.get_workflow_jobs,
|
||||
|
||||
@ -6,69 +6,6 @@ TEST_DIR="$ROOT_DIR/caffe2_tests"
|
||||
gtest_reports_dir="${TEST_DIR}/cpp"
|
||||
pytest_reports_dir="${TEST_DIR}/python"
|
||||
|
||||
# This is needed to work around ROCm using old docker images until
|
||||
# the transition to new images is complete.
|
||||
# TODO: Remove once ROCm CI is using new images.
|
||||
if [[ $BUILD_ENVIRONMENT == py3.6-devtoolset7-rocmrpm-centos* ]]; then
|
||||
# This file is sourced multiple times, only install conda the first time.
|
||||
# We must install conda where we have write access.
|
||||
CONDA_DIR="$ROOT_DIR/conda"
|
||||
if [[ ! -d $CONDA_DIR ]]; then
|
||||
ANACONDA_PYTHON_VERSION=3.6
|
||||
BASE_URL="https://repo.anaconda.com/miniconda"
|
||||
CONDA_FILE="Miniconda3-latest-Linux-x86_64.sh"
|
||||
mkdir $CONDA_DIR
|
||||
pushd /tmp
|
||||
wget -q "${BASE_URL}/${CONDA_FILE}"
|
||||
chmod +x "${CONDA_FILE}"
|
||||
./"${CONDA_FILE}" -b -f -p "$CONDA_DIR"
|
||||
popd
|
||||
export PATH="$CONDA_DIR/bin:$PATH"
|
||||
# Ensure we run conda in a directory that jenkins has write access to
|
||||
pushd $CONDA_DIR
|
||||
# Track latest conda update
|
||||
conda update -n base conda
|
||||
# Install correct Python version
|
||||
conda install python="$ANACONDA_PYTHON_VERSION"
|
||||
|
||||
conda_install() {
|
||||
# Ensure that the install command don't upgrade/downgrade Python
|
||||
# This should be called as
|
||||
# conda_install pkg1 pkg2 ... [-c channel]
|
||||
conda install -q -y python="$ANACONDA_PYTHON_VERSION" $*
|
||||
}
|
||||
|
||||
# Install PyTorch conda deps, as per https://github.com/pytorch/pytorch README
|
||||
conda_install numpy pyyaml mkl mkl-include setuptools cffi typing future six
|
||||
|
||||
# TODO: This isn't working atm
|
||||
conda_install nnpack -c killeent
|
||||
|
||||
# Install some other packages
|
||||
|
||||
# Need networkx 2.0 because bellmand_ford was moved in 2.1 . Scikit-image by
|
||||
# defaults installs the most recent networkx version, so we install this lower
|
||||
# version explicitly before scikit-image pulls it in as a dependency
|
||||
pip install networkx==2.0
|
||||
|
||||
# TODO: Why is scipy pinned
|
||||
# numba & llvmlite is pinned because of https://github.com/numba/numba/issues/4368
|
||||
# scikit-learn is pinned because of
|
||||
# https://github.com/scikit-learn/scikit-learn/issues/14485 (affects gcc 5.5
|
||||
# only)
|
||||
pip install --progress-bar off pytest scipy==1.1.0 scikit-learn==0.20.3 scikit-image librosa>=0.6.2 psutil numba==0.46.0 llvmlite==0.30.0
|
||||
|
||||
# click - onnx
|
||||
# hypothesis - tests
|
||||
# jupyter - for tutorials
|
||||
pip install --progress-bar off click hypothesis jupyter protobuf tabulate virtualenv mock typing-extensions
|
||||
|
||||
popd
|
||||
else
|
||||
export PATH="$CONDA_DIR/bin:$PATH"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Figure out which Python to use
|
||||
PYTHON="$(which python)"
|
||||
if [[ "${BUILD_ENVIRONMENT}" =~ py((2|3)\.?[0-9]?\.?[0-9]?) ]]; then
|
||||
|
||||
@ -11,6 +11,15 @@ if [[ "${BUILD_ENVIRONMENT}" == *-rocm* ]]; then
|
||||
# temporary to locate some kernel issues on the CI nodes
|
||||
export HSAKMT_DEBUG_LEVEL=4
|
||||
fi
|
||||
# These additional packages are needed for circleci ROCm builds.
|
||||
if [[ $BUILD_ENVIRONMENT == pytorch-linux-xenial-rocm* ]]; then
|
||||
# Need networkx 2.0 because bellman_ford was moved in 2.1. Scikit-image by
|
||||
# default installs the most recent networkx version, so we install this lower
|
||||
# version explicitly before scikit-image pulls it in as a dependency
|
||||
pip install networkx==2.0
|
||||
# click - onnx
|
||||
pip install --progress-bar off click protobuf tabulate virtualenv mock typing-extensions
|
||||
fi
|
||||
|
||||
# Find where cpp tests and Caffe2 itself are installed
|
||||
if [[ "$BUILD_ENVIRONMENT" == *cmake* ]]; then
|
||||
@ -71,16 +80,24 @@ if [[ "$BUILD_ENVIRONMENT" == *cmake* ]]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# If pip is installed as root, we must use sudo.
|
||||
# CircleCI docker images could install conda as jenkins user, or use the OS's python package.
|
||||
PIP=$(which pip)
|
||||
PIP_USER=$(stat --format '%U' $PIP)
|
||||
if [[ "$PIP_USER" = root ]]; then
|
||||
MAYBE_SUDO=sudo
|
||||
fi
|
||||
|
||||
# if [[ "$BUILD_ENVIRONMENT" == *ubuntu14.04* ]]; then
|
||||
# Hotfix, use hypothesis 3.44.6 on Ubuntu 14.04
|
||||
# See comments on
|
||||
# https://github.com/HypothesisWorks/hypothesis-python/commit/eadd62e467d6cee6216e71b391951ec25b4f5830
|
||||
sudo pip -q uninstall -y hypothesis
|
||||
$MAYBE_SUDO pip -q uninstall -y hypothesis
|
||||
# "pip install hypothesis==3.44.6" from official server is unreliable on
|
||||
# CircleCI, so we host a copy on S3 instead
|
||||
sudo pip -q install attrs==18.1.0 -f https://s3.amazonaws.com/ossci-linux/wheels/attrs-18.1.0-py2.py3-none-any.whl
|
||||
sudo pip -q install coverage==4.5.1 -f https://s3.amazonaws.com/ossci-linux/wheels/coverage-4.5.1-cp36-cp36m-macosx_10_12_x86_64.whl
|
||||
sudo pip -q install hypothesis==3.44.6 -f https://s3.amazonaws.com/ossci-linux/wheels/hypothesis-3.44.6-py3-none-any.whl
|
||||
$MAYBE_SUDO pip -q install attrs==18.1.0 -f https://s3.amazonaws.com/ossci-linux/wheels/attrs-18.1.0-py2.py3-none-any.whl
|
||||
$MAYBE_SUDO pip -q install coverage==4.5.1 -f https://s3.amazonaws.com/ossci-linux/wheels/coverage-4.5.1-cp36-cp36m-macosx_10_12_x86_64.whl
|
||||
$MAYBE_SUDO pip -q install hypothesis==3.44.6 -f https://s3.amazonaws.com/ossci-linux/wheels/hypothesis-3.44.6-py3-none-any.whl
|
||||
# else
|
||||
# pip install --user --no-cache-dir hypothesis==3.59.0
|
||||
# fi
|
||||
|
||||
@ -326,6 +326,7 @@ filegroup(
|
||||
"aten/src/TH/THStorageFunctions.cpp",
|
||||
"aten/src/TH/THTensor.cpp",
|
||||
"aten/src/TH/THTensorEvenMoreMath.cpp",
|
||||
"aten/src/TH/THTensorFill.cpp",
|
||||
"aten/src/TH/THTensorLapack.cpp",
|
||||
"aten/src/TH/THTensorMath.cpp",
|
||||
"aten/src/TH/THTensorMoreMath.cpp",
|
||||
|
||||
@ -2,83 +2,16 @@
|
||||
set -eux
|
||||
|
||||
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
|
||||
|
||||
PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android
|
||||
WORK_DIR=$PYTORCH_DIR
|
||||
|
||||
echo "PYTORCH_DIR:$PYTORCH_DIR"
|
||||
echo "WORK_DIR:$WORK_DIR"
|
||||
|
||||
echo "ANDROID_HOME:$ANDROID_HOME"
|
||||
if [ ! -z "$ANDROID_HOME" ]; then
|
||||
echo "ANDROID_HOME not set; please set it to Android sdk directory"
|
||||
fi
|
||||
source "$PYTORCH_ANDROID_DIR/common.sh"
|
||||
|
||||
if [ ! -d $ANDROID_HOME ]; then
|
||||
echo "ANDROID_HOME not a directory; did you install it under $ANDROID_HOME?"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
GRADLE_PATH=gradle
|
||||
GRADLE_NOT_FOUND_MSG="Unable to find gradle, please add it to PATH or set GRADLE_HOME"
|
||||
|
||||
if [ ! -x "$(command -v gradle)" ]; then
|
||||
if [ -z "$GRADLE_HOME" ]; then
|
||||
echo GRADLE_NOT_FOUND_MSG
|
||||
exit 1
|
||||
fi
|
||||
GRADLE_PATH=$GRADLE_HOME/bin/gradle
|
||||
if [ ! -f "$GRADLE_PATH" ]; then
|
||||
echo GRADLE_NOT_FOUND_MSG
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
echo "GRADLE_PATH:$GRADLE_PATH"
|
||||
|
||||
ABIS_LIST="armeabi-v7a,arm64-v8a,x86,x86_64"
|
||||
CUSTOM_ABIS_LIST=false
|
||||
if [ $# -gt 0 ]; then
|
||||
ABIS_LIST=$1
|
||||
CUSTOM_ABIS_LIST=true
|
||||
fi
|
||||
|
||||
echo "ABIS_LIST:$ABIS_LIST"
|
||||
|
||||
LIB_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs
|
||||
INCLUDE_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include
|
||||
mkdir -p $LIB_DIR
|
||||
rm -f $LIB_DIR/*
|
||||
mkdir -p $INCLUDE_DIR
|
||||
|
||||
for abi in $(echo $ABIS_LIST | tr ',' '\n')
|
||||
do
|
||||
echo "abi:$abi"
|
||||
|
||||
OUT_DIR=$WORK_DIR/build_android_$abi
|
||||
|
||||
rm -rf $OUT_DIR
|
||||
mkdir -p $OUT_DIR
|
||||
|
||||
pushd $PYTORCH_DIR
|
||||
python $PYTORCH_DIR/setup.py clean
|
||||
|
||||
ANDROID_ABI=$abi VERBOSE=1 ANDROID_DEBUG_SYMBOLS=1 $PYTORCH_DIR/scripts/build_android.sh -DANDROID_CCACHE=$(which ccache)
|
||||
|
||||
cp -R $PYTORCH_DIR/build_android/install/lib $OUT_DIR/
|
||||
cp -R $PYTORCH_DIR/build_android/install/include $OUT_DIR/
|
||||
|
||||
echo "$abi build output lib,include copied to $OUT_DIR"
|
||||
|
||||
LIB_LINK_PATH=$LIB_DIR/$abi
|
||||
INCLUDE_LINK_PATH=$INCLUDE_DIR/$abi
|
||||
|
||||
rm -f $LIB_LINK_PATH
|
||||
rm -f $INCLUDE_LINK_PATH
|
||||
|
||||
ln -s $OUT_DIR/lib $LIB_LINK_PATH
|
||||
ln -s $OUT_DIR/include $INCLUDE_LINK_PATH
|
||||
|
||||
done
|
||||
check_android_sdk
|
||||
check_gradle
|
||||
parse_abis_list "$@"
|
||||
build_android
|
||||
|
||||
# To set proxy for gradle add following lines to ./gradle/gradle.properties:
|
||||
# systemProp.http.proxyHost=...
|
||||
@ -95,5 +28,3 @@ fi
|
||||
find $PYTORCH_ANDROID_DIR -type f -name *apk
|
||||
|
||||
find $PYTORCH_ANDROID_DIR -type f -name *apk | xargs echo "To install apk run: $ANDROID_HOME/platform-tools/adb install -r "
|
||||
|
||||
popd
|
||||
|
||||
81
android/common.sh
Normal file
81
android/common.sh
Normal file
@ -0,0 +1,81 @@
|
||||
#!/bin/bash
|
||||
set -eux
|
||||
|
||||
##############################################################################
|
||||
# Common util functions for Android build scripts.
|
||||
##############################################################################
|
||||
|
||||
# Callers must set PYTORCH_DIR before sourcing this file.
# ${PYTORCH_DIR:-} expands to empty when the variable is unset, so the check
# prints its own message instead of tripping `set -u` ("unbound variable").
if [ -z "${PYTORCH_DIR:-}" ]; then
  echo "PYTORCH_DIR not set!"
  exit 1
fi
|
||||
|
||||
# Verify that ANDROID_HOME is set and names an existing directory.
# Prints a diagnostic and exits with status 1 otherwise; on success echoes the
# resolved value.
check_android_sdk() {
  # ${ANDROID_HOME:-} expands to empty when the variable is unset, so the
  # friendly message is printed instead of `set -u` aborting the script.
  if [ -z "${ANDROID_HOME:-}" ]; then
    echo "ANDROID_HOME not set; please set it to Android sdk directory"
    exit 1
  fi

  if [ ! -d "$ANDROID_HOME" ]; then
    echo "ANDROID_HOME not a directory; did you install it under $ANDROID_HOME?"
    exit 1
  fi
  echo "ANDROID_HOME:$ANDROID_HOME"
}
|
||||
|
||||
# Locate gradle and store the command to run in GRADLE_PATH.
# Uses `gradle` from PATH when it is executable, otherwise
# $GRADLE_HOME/bin/gradle. Prints a diagnostic and exits with status 1 when
# neither is available.
check_gradle() {
  GRADLE_PATH=gradle
  GRADLE_NOT_FOUND_MSG="Unable to find gradle, please add it to PATH or set GRADLE_HOME"

  if [ ! -x "$(command -v gradle)" ]; then
    # ${GRADLE_HOME:-} expands to empty when the variable is unset, so the
    # message below is printed instead of `set -u` aborting the script.
    if [ -z "${GRADLE_HOME:-}" ]; then
      echo "$GRADLE_NOT_FOUND_MSG"
      exit 1
    fi
    GRADLE_PATH=$GRADLE_HOME/bin/gradle
    if [ ! -f "$GRADLE_PATH" ]; then
      echo "$GRADLE_NOT_FOUND_MSG"
      exit 1
    fi
  fi
  echo "GRADLE_PATH:$GRADLE_PATH"
}
|
||||
|
||||
# Set ABIS_LIST and CUSTOM_ABIS_LIST from the optional first argument.
# With no arguments the built-in comma-separated ABI list is used and
# CUSTOM_ABIS_LIST stays false; otherwise $1 replaces the list and
# CUSTOM_ABIS_LIST becomes true. Both values are echoed.
parse_abis_list() {
  ABIS_LIST="armeabi-v7a,arm64-v8a,x86,x86_64"
  CUSTOM_ABIS_LIST=false
  if (( $# > 0 )); then
    ABIS_LIST=$1
    CUSTOM_ABIS_LIST=true
  fi

  echo "ABIS_LIST:$ABIS_LIST"
  echo "CUSTOM_ABIS_LIST:$CUSTOM_ABIS_LIST"
}
|
||||
|
||||
# Run scripts/build_android.sh once per ABI in $ABIS_LIST and symlink each
# build's install/lib and install/include into the pytorch_android project.
# Reads PYTORCH_DIR and ABIS_LIST; BUILD_ROOT defaults to PYTORCH_DIR.
build_android() {
  PYTORCH_ANDROID_DIR="$PYTORCH_DIR/android"
  BUILD_ROOT="${BUILD_ROOT:-$PYTORCH_DIR}"
  echo "BUILD_ROOT:$BUILD_ROOT"

  LIB_DIR="$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs"
  INCLUDE_DIR="$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include"

  # These directories only contain symbolic links.
  rm -rf "$LIB_DIR" && mkdir -p "$LIB_DIR"
  rm -rf "$INCLUDE_DIR" && mkdir -p "$INCLUDE_DIR"

  for abi in $(echo "$ABIS_LIST" | tr ',' '\n'); do
    echo "abi:$abi"
    ANDROID_BUILD_ROOT="$BUILD_ROOT/build_android_$abi"

    # The per-ABI settings are passed only through this call's environment.
    ANDROID_ABI="$abi" BUILD_ROOT="$ANDROID_BUILD_ROOT" \
      "$PYTORCH_DIR/scripts/build_android.sh" -DANDROID_CCACHE="$(which ccache)"

    echo "$abi build output lib,include at $ANDROID_BUILD_ROOT/install"
    ln -s "$ANDROID_BUILD_ROOT/install/lib" "$LIB_DIR/$abi"
    ln -s "$ANDROID_BUILD_ROOT/install/include" "$INCLUDE_DIR/$abi"
  done
}
|
||||
@ -1,51 +0,0 @@
|
||||
apply plugin: 'com.android.library'
|
||||
apply plugin: 'maven'
|
||||
|
||||
android {
|
||||
compileSdkVersion rootProject.compileSdkVersion
|
||||
buildToolsVersion rootProject.buildToolsVersion
|
||||
|
||||
defaultConfig {
|
||||
minSdkVersion rootProject.minSdkVersion
|
||||
targetSdkVersion rootProject.targetSdkVersion
|
||||
|
||||
sourceSets {
|
||||
main {
|
||||
manifest.srcFile '../fbjni/java/com/facebook/jni/AndroidManifest.xml'
|
||||
java {
|
||||
srcDir '../fbjni/java'
|
||||
}
|
||||
}
|
||||
}
|
||||
ndk {
|
||||
abiFilters ABI_FILTERS.split(",")
|
||||
}
|
||||
}
|
||||
buildTypes {
|
||||
debug {
|
||||
minifyEnabled false
|
||||
}
|
||||
release {
|
||||
minifyEnabled false
|
||||
}
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path "../fbjni/CMakeLists.txt"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compileOnly 'com.google.code.findbugs:jsr305:3.0.1'
|
||||
implementation 'com.facebook.soloader:nativeloader:0.8.0'
|
||||
}
|
||||
|
||||
apply from: rootProject.file('gradle/release.gradle')
|
||||
|
||||
task sourcesJar(type: Jar) {
|
||||
from android.sourceSets.main.java.srcDirs
|
||||
classifier = 'sources'
|
||||
}
|
||||
|
||||
artifacts.add('archives', sourcesJar)
|
||||
@ -1,4 +0,0 @@
|
||||
POM_NAME=pytorch_android_fbjni
|
||||
POM_DESCRIPTION=pytorch_android_fbjni
|
||||
POM_ARTIFACT_ID=pytorch_android_fbjni
|
||||
POM_PACKAGING=aar
|
||||
@ -4,42 +4,10 @@ set -eux
|
||||
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
|
||||
PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android
|
||||
|
||||
echo "ANDROID_HOME:$ANDROID_HOME"
|
||||
if [ ! -z "$ANDROID_HOME" ]; then
|
||||
echo "ANDROID_HOME not set; please set it to Android sdk directory"
|
||||
fi
|
||||
|
||||
if [ ! -d $ANDROID_HOME ]; then
|
||||
echo "ANDROID_HOME not a directory; did you install it under $ANDROID_HOME?"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "ANDROID_NDK:$ANDROID_NDK"
|
||||
if [ ! -z "$ANDROID_NDK" ]; then
|
||||
echo "ANDROID_NDK not set; please set it to Android sdk directory"
|
||||
fi
|
||||
|
||||
if [ ! -d $ANDROID_NDK ]; then
|
||||
echo "ANDROID_NDK not a directory; did you install it under $ANDROID_NDK?"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
GRADLE_PATH=gradle
|
||||
GRADLE_NOT_FOUND_MSG="Unable to find gradle, please add it to PATH or set GRADLE_HOME"
|
||||
|
||||
if [ ! -x "$(command -v gradle)" ]; then
|
||||
if [ -z "$GRADLE_HOME" ]; then
|
||||
echo GRADLE_NOT_FOUND_MSG
|
||||
exit 1
|
||||
fi
|
||||
GRADLE_PATH=$GRADLE_HOME/bin/gradle
|
||||
if [ ! -f "$GRADLE_PATH" ]; then
|
||||
echo GRADLE_NOT_FOUND_MSG
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
echo "GRADLE_PATH:$GRADLE_PATH"
|
||||
source "$PYTORCH_ANDROID_DIR/common.sh"
|
||||
|
||||
check_android_sdk
|
||||
check_gradle
|
||||
|
||||
# Run android instrumented tests on x86 emulator
|
||||
|
||||
@ -70,7 +38,7 @@ cat <<- EOF
|
||||
$ANDROID_HOME/platform-tools/adb devices
|
||||
|
||||
If everything is ok the output will be:
|
||||
|
||||
|
||||
List of devices attached
|
||||
emulator-5554 device
|
||||
EOF
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
include ':app', ':pytorch_android', ':fbjni', ':pytorch_android_torchvision', ':pytorch_host', ':test_app'
|
||||
include ':app', ':pytorch_android', ':pytorch_android_torchvision', ':pytorch_host', ':test_app'
|
||||
|
||||
project(':fbjni').projectDir = file('libs/fbjni_local')
|
||||
project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torchvision')
|
||||
|
||||
project(':pytorch_host').projectDir = file('pytorch_android/host')
|
||||
|
||||
@ -341,21 +341,6 @@ apply_op(int64_t numel, int64_t offset, const Op& op, Args... iters) {
|
||||
b_val[0] * c_val[0]; };
|
||||
*/
|
||||
|
||||
template <typename scalar1, typename Op>
|
||||
inline void CPU_tensor_apply1(Tensor tensor1, const Op op) {
|
||||
if (!_apply_preamble({tensor1}))
|
||||
return;
|
||||
if (tensor1.ndimension() < 8) {
|
||||
apply_op(
|
||||
tensor1.numel(),
|
||||
0,
|
||||
op,
|
||||
strided_tensor_iter_fixed<scalar1, 8>(tensor1, true));
|
||||
} else {
|
||||
apply_op(tensor1.numel(), 0, op, strided_tensor_iter<scalar1>(tensor1));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar1, typename scalar2, typename Op>
|
||||
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
|
||||
if (!_apply_preamble({tensor1, tensor2}))
|
||||
|
||||
@ -26,8 +26,12 @@ inline void parallel_for(
|
||||
#ifdef _OPENMP
|
||||
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
|
||||
std::exception_ptr eptr;
|
||||
// Work around memory leak when using 1 thread in nested "omp parallel"
|
||||
// caused by some buggy OpenMP versions and the fact that omp_in_parallel()
|
||||
// returns false when omp_get_max_threads() == 1 inside nested "omp parallel"
|
||||
// See issue gh-32284
|
||||
|
||||
#pragma omp parallel if (!omp_in_parallel() && ((end - begin) > grain_size))
|
||||
#pragma omp parallel if (omp_get_max_threads() > 1 && !omp_in_parallel() && ((end - begin) > grain_size))
|
||||
{
|
||||
// choose number of tasks based on grain size and number of threads
|
||||
// can't use num_threads clause due to bugs in GOMP's thread pool (See #32008)
|
||||
|
||||
@ -126,6 +126,67 @@ struct uniform_real_distribution {
|
||||
T to_;
|
||||
};
|
||||
|
||||
// The SFINAE checks introduced in #39816 look overcomplicated and must be revisited
|
||||
// https://github.com/pytorch/pytorch/issues/40052
|
||||
// Defines the trait has_member_<member><T>. Its ::value is true when
// `&T::member` is a valid expression, which selects the `yes` overload of
// test(); otherwise the variadic `no` overload is chosen and ::value is false.
// The two overloads are told apart by the size of their return type.
#define DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member)              \
template <typename T>                                                \
struct has_member_##member                                           \
{                                                                    \
    using yes = char;                                                \
    using no = long;                                                 \
    template <typename U> static yes test(decltype(&U::member));     \
    template <typename U> static no test(...);                       \
    static constexpr bool value = sizeof(test<T>(0)) == sizeof(yes); \
}

// Traits for the cached-normal-sample accessors a generator may provide.
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_double_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_double_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_float_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_float_normal_sample);
|
||||
|
||||
#define DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(TYPE) \
|
||||
\
|
||||
template <typename RNG, typename ret_type, \
|
||||
typename std::enable_if_t<( \
|
||||
has_member_next_##TYPE##_normal_sample<RNG>::value && \
|
||||
has_member_set_next_##TYPE##_normal_sample<RNG>::value \
|
||||
), int> = 0> \
|
||||
C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) { \
|
||||
if (generator->next_##TYPE##_normal_sample()) { \
|
||||
*ret = *(generator->next_##TYPE##_normal_sample()); \
|
||||
generator->set_next_##TYPE##_normal_sample(c10::optional<TYPE>()); \
|
||||
return true; \
|
||||
} \
|
||||
return false; \
|
||||
} \
|
||||
\
|
||||
template <typename RNG, typename ret_type, \
|
||||
typename std::enable_if_t<( \
|
||||
!has_member_next_##TYPE##_normal_sample<RNG>::value || \
|
||||
!has_member_set_next_##TYPE##_normal_sample<RNG>::value \
|
||||
), int> = 0> \
|
||||
C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) { \
|
||||
return false; \
|
||||
} \
|
||||
\
|
||||
template <typename RNG, typename ret_type, \
|
||||
typename std::enable_if_t<( \
|
||||
has_member_set_next_##TYPE##_normal_sample<RNG>::value \
|
||||
), int> = 0> \
|
||||
C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \
|
||||
generator->set_next_##TYPE##_normal_sample(cache); \
|
||||
} \
|
||||
\
|
||||
template <typename RNG, typename ret_type, \
|
||||
typename std::enable_if_t<( \
|
||||
!has_member_set_next_##TYPE##_normal_sample<RNG>::value \
|
||||
), int> = 0> \
|
||||
C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \
|
||||
}
|
||||
|
||||
DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(double);
|
||||
DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float);
|
||||
|
||||
/**
|
||||
* Samples a normal distribution using the Box-Muller method
|
||||
* Takes mean and standard deviation as inputs
|
||||
@ -144,41 +205,29 @@ struct normal_distribution {
|
||||
template <typename RNG>
|
||||
C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
|
||||
dist_acctype<T> ret;
|
||||
#if !defined(__CUDACC__) && !defined(__HIPCC__)
|
||||
// return cached values if available
|
||||
if (std::is_same<T, double>::value) {
|
||||
if (generator->next_double_normal_sample()) {
|
||||
ret = *(generator->next_double_normal_sample()) * stdv + mean;
|
||||
// reset c10::optional to null
|
||||
generator->set_next_double_normal_sample(c10::optional<double>());
|
||||
return ret;
|
||||
if (maybe_get_next_double_normal_sample(generator, &ret)) {
|
||||
return transformation::normal(ret, mean, stdv);
|
||||
}
|
||||
} else {
|
||||
if (generator->next_float_normal_sample()) {
|
||||
ret = *(generator->next_float_normal_sample()) * stdv + mean;
|
||||
// reset c10::optional to null
|
||||
generator->set_next_float_normal_sample(c10::optional<float>());
|
||||
return ret;
|
||||
if (maybe_get_next_float_normal_sample(generator, &ret)) {
|
||||
return transformation::normal(ret, mean, stdv);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// otherwise generate new normal values
|
||||
uniform_real_distribution<T> uniform(0.0, 1.0);
|
||||
const dist_acctype<T> u1 = uniform(generator);
|
||||
const dist_acctype<T> u2 = uniform(generator);
|
||||
const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log(static_cast<T>(1.0)-u2));
|
||||
const dist_acctype<T> theta = static_cast<T>(2.0) * static_cast<T>(M_PI) * u1;
|
||||
#if !defined(__CUDACC__) && !defined(__HIPCC__)
|
||||
if (std::is_same<T, double>::value) {
|
||||
dist_acctype<double> cache = r * ::sin(theta);
|
||||
generator->set_next_double_normal_sample(c10::optional<double>(cache));
|
||||
maybe_set_next_double_normal_sample(generator, r * ::sin(theta));
|
||||
} else {
|
||||
dist_acctype<float> cache = r * ::sin(theta);
|
||||
generator->set_next_float_normal_sample(c10::optional<float>(cache));
|
||||
maybe_set_next_float_normal_sample(generator, r * ::sin(theta));
|
||||
}
|
||||
#endif
|
||||
ret = transformation::normal(r * ::cos(theta), mean, stdv);
|
||||
return ret;
|
||||
ret = r * ::cos(theta);
|
||||
return transformation::normal(ret, mean, stdv);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
@ -51,7 +51,11 @@ using supported_primitive_arg_types = guts::typelist::typelist<
|
||||
/* everything is ok, this is a primitive type */
|
||||
}, /* else */ [] {
|
||||
auto tmap = c10::getCustomClassTypeMap();
|
||||
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class as input argument");
|
||||
TORCH_CHECK(
|
||||
c10::isCustomClassRegistered<T>(),
|
||||
"Tried to use undefined class ",
|
||||
c10::util::get_fully_qualified_type_name<T>(),
|
||||
" as input argument");
|
||||
});
|
||||
}
|
||||
};
|
||||
@ -140,7 +144,7 @@ using supported_primitive_arg_types = guts::typelist::typelist<
|
||||
/* everything is ok, this is a primitive type */
|
||||
}, /* else */ [] {
|
||||
auto tmap = getCustomClassTypeMap();
|
||||
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class as output");
|
||||
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class ", c10::util::get_fully_qualified_type_name<T>(), " as output");
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@ -318,6 +318,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
|
||||
// Read-only access to the stored value. Takes mutex_ for the duration of the
// call and asserts, in order, that the future is completed and that no error
// has been recorded, before handing back a reference to value_.
const IValue& constValue() {
  std::unique_lock<std::mutex> guard(mutex_);
  // Precondition 1: the future must already be completed.
  AT_ASSERT(completed());
  // Precondition 2: it must not have been completed with an error.
  AT_ASSERT(!error_);
  const IValue& stored = value_;
  return stored;
}
|
||||
|
||||
@ -366,6 +367,11 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
|
||||
return completed_;
|
||||
}
|
||||
|
||||
bool hasValue() const {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
return completed_ && !error_;
|
||||
}
|
||||
|
||||
bool hasError() const {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
return error_ ? true : false;
|
||||
|
||||
@ -25,6 +25,19 @@ class Context;
|
||||
// NB: Class must live in `at` due to limitations of Registry.h.
|
||||
namespace at {
|
||||
|
||||
#ifdef _MSC_VER
|
||||
constexpr const char* CUDA_HELP =
|
||||
"PyTorch splits its backend into two shared libraries: a CPU library "
|
||||
"and a CUDA library; this error has occurred because you are trying "
|
||||
"to use some CUDA functionality, but the CUDA library has not been "
|
||||
"loaded by the dynamic linker for some reason. The CUDA library MUST "
|
||||
"be loaded, EVEN IF you don't directly use any symbols from the CUDA library! "
|
||||
"One common culprit is a lack of -INCLUDE:?warp_size@cuda@at@@YAHXZ "
|
||||
"in your link arguments; many dynamic linkers will delete dynamic library "
|
||||
"dependencies if you don't depend on any of their symbols. You can check "
|
||||
"if this has occurred by using link on your binary to see if there is a "
|
||||
"dependency on *_cuda.dll library.";
|
||||
#else
|
||||
constexpr const char* CUDA_HELP =
|
||||
"PyTorch splits its backend into two shared libraries: a CPU library "
|
||||
"and a CUDA library; this error has occurred because you are trying "
|
||||
@ -36,6 +49,7 @@ constexpr const char* CUDA_HELP =
|
||||
"depend on any of their symbols. You can check if this has occurred by "
|
||||
"using ldd on your binary to see if there is a dependency on *_cuda.so "
|
||||
"library.";
|
||||
#endif
|
||||
|
||||
// The CUDAHooksInterface is an omnibus interface for any CUDA functionality
|
||||
// which we may want to call into from CPU code (and thus must be dynamically
|
||||
|
||||
@ -314,21 +314,23 @@ void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(generator->mutex_);
|
||||
using self_t = scalar_t;
|
||||
auto p = std::get<0>(expand_inplace(self, p_.to(kCPU)));
|
||||
auto iter = TensorIterator();
|
||||
iter.add_output(self);
|
||||
iter.add_input(p);
|
||||
iter.check_all_same_dtype(false);
|
||||
iter.build();
|
||||
if (p_.scalar_type() == kDouble) {
|
||||
auto p = std::get<0>(expand_inplace(self, p_.to(kCPU)));
|
||||
CPU_tensor_apply2<self_t, double>(
|
||||
self, p, [generator](self_t& ret_val, double& p_val) {
|
||||
at::bernoulli_distribution<double> bernoulli(p_val);
|
||||
ret_val = static_cast<self_t>(bernoulli(generator));
|
||||
});
|
||||
cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
|
||||
at::bernoulli_distribution<double> bernoulli(p_val);
|
||||
return static_cast<self_t>(bernoulli(generator));
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_TYPES(p_.scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
|
||||
auto p = std::get<0>(expand_inplace(self, p_.to(kCPU)));
|
||||
using p_t = scalar_t;
|
||||
CPU_tensor_apply2<self_t, p_t>(
|
||||
self, p, [generator](self_t& ret_val, p_t& p_val) {
|
||||
at::bernoulli_distribution<float> bernoulli(p_val);
|
||||
ret_val = static_cast<self_t>(bernoulli(generator));
|
||||
cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
|
||||
at::bernoulli_distribution<float> bernoulli(p_val);
|
||||
return static_cast<self_t>(bernoulli(generator));
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -340,11 +342,11 @@ void bernoulli_kernel(Tensor& self, double p, RNG generator) {
|
||||
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(generator->mutex_);
|
||||
CPU_tensor_apply1<scalar_t>(
|
||||
self, [generator, p](scalar_t& ret_val) {
|
||||
at::bernoulli_distribution<double> bernoulli(p);
|
||||
ret_val = static_cast<scalar_t>(bernoulli(generator));
|
||||
});
|
||||
auto iter = TensorIterator::nullary_op(self);
|
||||
cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
|
||||
at::bernoulli_distribution<double> bernoulli(p);
|
||||
return static_cast<scalar_t>(bernoulli(generator));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -76,6 +76,7 @@ static Tensor _mkldnn_pool2d(
|
||||
bool ceil_mode,
|
||||
ideep::algorithm algo) {
|
||||
auto kernel_size_vec = expand_param_if_needed(kernel_size, "kernel_size", 2);
|
||||
if (stride.empty()) stride = kernel_size;
|
||||
auto stride_vec = expand_param_if_needed(stride, "stride", 2);
|
||||
auto padding_vec = expand_param_if_needed(padding, "padding", 2);
|
||||
auto padding_vec_l = padding_vec;
|
||||
|
||||
@ -69,8 +69,8 @@ Tensor q_batch_norm2d_impl(
|
||||
TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
|
||||
TORCH_CHECK(bias.numel() == C, "Expect weight size to match C");
|
||||
|
||||
const float* weight_data = weight.template data<float>();
|
||||
const float* bias_data = bias.template data<float>();
|
||||
const float* weight_data = weight.template data_ptr<float>();
|
||||
const float* bias_data = bias.template data_ptr<float>();
|
||||
|
||||
TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
|
||||
TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");
|
||||
@ -80,8 +80,8 @@ Tensor q_batch_norm2d_impl(
|
||||
float* alpha_data = alpha.data_ptr<float>();
|
||||
float* beta_data = beta.data_ptr<float>();
|
||||
|
||||
const float* mean_data = mean.template data<float>();
|
||||
const float* var_data = var.template data<float>();
|
||||
const float* mean_data = mean.template data_ptr<float>();
|
||||
const float* var_data = var.template data_ptr<float>();
|
||||
|
||||
auto oSizes = qx.sizes();
|
||||
auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast);
|
||||
@ -165,8 +165,8 @@ Tensor q_batch_norm3d_impl(
|
||||
TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
|
||||
TORCH_CHECK(bias.numel() == C, "Expect weight size to match C");
|
||||
|
||||
const float* weight_data = weight.template data<float>();
|
||||
const float* bias_data = bias.template data<float>();
|
||||
const float* weight_data = weight.template data_ptr<float>();
|
||||
const float* bias_data = bias.template data_ptr<float>();
|
||||
|
||||
TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
|
||||
TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");
|
||||
@ -176,8 +176,8 @@ Tensor q_batch_norm3d_impl(
|
||||
float* alpha_data = alpha.data_ptr<float>();
|
||||
float* beta_data = beta.data_ptr<float>();
|
||||
|
||||
const float* mean_data = mean.template data<float>();
|
||||
const float* var_data = var.template data<float>();
|
||||
const float* mean_data = mean.template data_ptr<float>();
|
||||
const float* var_data = var.template data_ptr<float>();
|
||||
|
||||
auto oSizes = qx.sizes();
|
||||
auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast3d);
|
||||
@ -230,6 +230,30 @@ Tensor q_batch_norm3d_impl(
|
||||
return qy;
|
||||
}
|
||||
|
||||
template <bool ReluFused>
|
||||
Tensor q_batch_norm_impl(
|
||||
Tensor qx,
|
||||
c10::optional<Tensor> mb_weight,
|
||||
c10::optional<Tensor> mb_bias,
|
||||
Tensor mean,
|
||||
Tensor var,
|
||||
double eps,
|
||||
double output_scale,
|
||||
int64_t output_zero_point) {
|
||||
Tensor qy;
|
||||
int64_t dim = qx.dim();
|
||||
if (dim == 4) {
|
||||
qy = q_batch_norm2d_impl<ReluFused>(
|
||||
qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
|
||||
} else if (dim == 5) {
|
||||
qy = q_batch_norm3d_impl<ReluFused>(
|
||||
qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
|
||||
} else {
|
||||
TORCH_CHECK(false, "quantized::batch_norm only support 4d or 5d inputs.");
|
||||
}
|
||||
return qy;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Tensor quantized_batch_norm(
|
||||
@ -252,6 +276,8 @@ Tensor quantized_batch_norm(
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
|
||||
m.impl("batch_norm", q_batch_norm_impl<false>);
|
||||
m.impl("batch_norm_relu", q_batch_norm_impl<true>);
|
||||
m.impl("batch_norm2d", q_batch_norm2d_impl<false>);
|
||||
m.impl("batch_norm2d_relu", q_batch_norm2d_impl<true>);
|
||||
m.impl("batch_norm3d", q_batch_norm3d_impl<false>);
|
||||
|
||||
@ -30,6 +30,14 @@ TORCH_LIBRARY(quantized, m) {
|
||||
m.def("add_scalar.Tensor(Tensor qa, Tensor b) -> Tensor qc");
|
||||
m.def("add_scalar_relu.Tensor(Tensor qa, Tensor b) -> Tensor qc");
|
||||
m.def("add_scalar_relu_out.Tensor(Tensor qa, Tensor b, Tensor(a!) out) -> Tensor(a!) out");
|
||||
// This is needed for graph mode quantization, when we fuse
|
||||
// dequant - aten::batch_norm - quant into quantized::batch_norm
|
||||
// and the dimension is unknown given only the aten op call.
|
||||
// quantized::batch_norm supports both 2d and 3d batch norm right now;
|
||||
// it should also support 1d batch_norm once quantized::batch_norm1d is
|
||||
// implemented.
|
||||
m.def("batch_norm(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
|
||||
m.def("batch_norm_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
|
||||
m.def("batch_norm2d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
|
||||
m.def("batch_norm2d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
|
||||
m.def("batch_norm3d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
|
||||
|
||||
@ -15,6 +15,7 @@ set(ATen_TH_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/THStorageFunctions.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/THTensor.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/THTensorRandom.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/THTensorFill.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/THTensorMath.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/THTensorMoreMath.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/THTensorEvenMoreMath.cpp
|
||||
@ -98,6 +99,8 @@ install(FILES
|
||||
generic/THTensor.cpp
|
||||
generic/THTensor.h
|
||||
generic/THTensor.hpp
|
||||
generic/THTensorFill.cpp
|
||||
generic/THTensorFill.h
|
||||
generic/THTensorLapack.cpp
|
||||
generic/THTensorLapack.h
|
||||
generic/THTensorMath.cpp
|
||||
|
||||
@ -42,6 +42,19 @@
|
||||
#include <TH/generic/THTensorMath.h>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
/* fill and zero*/
|
||||
#include <TH/generic/THTensorFill.h>
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
||||
#include <TH/generic/THTensorFill.h>
|
||||
#include <TH/THGenerateHalfType.h>
|
||||
|
||||
#include <TH/generic/THTensorFill.h>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
#include <TH/generic/THTensorFill.h>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
/* lapack support */
|
||||
#include <TH/generic/THTensorLapack.h>
|
||||
#include <TH/THGenerateFloatTypes.h>
|
||||
|
||||
14
aten/src/TH/THTensorFill.cpp
Normal file
14
aten/src/TH/THTensorFill.cpp
Normal file
@ -0,0 +1,14 @@
|
||||
#include <TH/THTensor.hpp>
|
||||
#include <TH/THVector.h>
|
||||
|
||||
#include <TH/generic/THTensorFill.cpp>
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
||||
#include <TH/generic/THTensorFill.cpp>
|
||||
#include <TH/THGenerateHalfType.h>
|
||||
|
||||
#include <TH/generic/THTensorFill.cpp>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
#include <TH/generic/THTensorFill.cpp>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
@ -332,7 +332,7 @@ void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar
|
||||
{
|
||||
tSlice = THTensor_(new)();
|
||||
THTensor_(select)(tSlice, tensor,dim,index_data[i]);
|
||||
THTensor_wrap(tSlice).fill_(val);
|
||||
THTensor_(fill)(tSlice, val);
|
||||
c10::raw::intrusive_ptr::decref(tSlice);
|
||||
}
|
||||
else
|
||||
|
||||
30
aten/src/TH/generic/THTensorFill.cpp
Normal file
30
aten/src/TH/generic/THTensorFill.cpp
Normal file
@ -0,0 +1,30 @@
|
||||
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "TH/generic/THTensorFill.cpp"
|
||||
#else
|
||||
|
||||
#include <TH/generic/THTensorApply.hpp>
|
||||
|
||||
/* Sets every element of r_ to value.
 *
 * Two paths are taken depending on the tensor layout:
 *  - isContiguous or isTransposed: a single THVector_(fill) over the data
 *    exposed by TH_TENSOR_APPLY_CONTIG (r__data / r__len).
 *  - otherwise: element iteration through TH_TENSOR_APPLY, with a shortcut
 *    when the stride seen by the loop body is 1.
 *
 * NOTE(review): r__data, r__len, r__size, r__stride and r__i are not declared
 * here; they are presumably introduced by the TH_TENSOR_APPLY* macros -
 * confirm against TH/generic/THTensorApply.hpp.
 */
void THTensor_(fill)(THTensor *r_, scalar_t value)
|
||||
{
|
||||
if (THTensor_(isContiguous)(r_) || THTensor_(isTransposed)(r_)) {
|
||||
/* Fill the whole buffer of r__len elements in one vector call. */
TH_TENSOR_APPLY_CONTIG(scalar_t, r_, THVector_(fill)(r__data, value, r__len););
|
||||
} else {
|
||||
/* With unit stride, fill r__size elements at once, then set r__i and
   r__data past that run and break out of the loop this body runs in;
   with any other stride, write the single current element. */
TH_TENSOR_APPLY(scalar_t, r_,
|
||||
if (r__stride == 1) {
|
||||
THVector_(fill)(r__data, value, r__size);
|
||||
r__i = r__size;
|
||||
r__data += r__stride * r__size;
|
||||
break;
|
||||
} else {
|
||||
*r__data = value;
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/* Sets every element of r_ to zero by delegating to THTensor_(fill). */
void THTensor_(zero)(THTensor *r_)
{
  const scalar_t zero_value = 0;
  THTensor_(fill)(r_, zero_value);
}
|
||||
|
||||
#endif
|
||||
8
aten/src/TH/generic/THTensorFill.h
Normal file
8
aten/src/TH/generic/THTensorFill.h
Normal file
@ -0,0 +1,8 @@
|
||||
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "TH/generic/THTensorFill.h"
|
||||
#else
|
||||
|
||||
TH_API void THTensor_(fill)(THTensor *r_, scalar_t value);
|
||||
TH_API void THTensor_(zero)(THTensor *r_);
|
||||
|
||||
#endif
|
||||
@ -211,7 +211,7 @@ void THTensor_(addr)(THTensor *r_, THTensor *t, THTensor *vec1, THTensor *vec2,
|
||||
}
|
||||
|
||||
if(beta == 0) {
|
||||
THTensor_wrap(r_).zero_();
|
||||
THTensor_(zero)(r_);
|
||||
}
|
||||
else if(beta != 1)
|
||||
THTensor_(mul)(r_, r_, beta);
|
||||
|
||||
@ -734,7 +734,7 @@ void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, scalar_t
|
||||
scalar_t *h_data;
|
||||
|
||||
THTensor_(resize1d)(hist, nbins);
|
||||
THTensor_wrap(hist).zero_();
|
||||
THTensor_(zero)(hist);
|
||||
minval = minvalue;
|
||||
maxval = maxvalue;
|
||||
if (minval == maxval)
|
||||
|
||||
@ -18,7 +18,6 @@ const int kSmallTensorSize = 1;
|
||||
const float kSampingProb = 0.1;
|
||||
}
|
||||
|
||||
using namespace torch::autograd;
|
||||
|
||||
void setupCallbacks() {
|
||||
// non-sampled callback
|
||||
|
||||
@ -21,10 +21,20 @@ C10_DEFINE_TEST(TestExponential, IPi) {
|
||||
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> e_i_pi = ::exp(c10::complex<float>(0, float(PI)));
|
||||
C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
|
||||
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> e_i_pi = std::exp(c10::complex<double>(0, PI));
|
||||
C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
|
||||
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> e_i_pi = ::exp(c10::complex<double>(0, PI));
|
||||
C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
|
||||
C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
|
||||
}
|
||||
}
|
||||
|
||||
C10_DEFINE_TEST(TestExponential, EulerFormula) {
|
||||
@ -38,6 +48,14 @@ C10_DEFINE_TEST(TestExponential, EulerFormula) {
|
||||
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> e = ::exp(x);
|
||||
float expected_real = ::exp(x.real()) * ::cos(x.imag());
|
||||
float expected_imag = ::exp(x.real()) * ::sin(x.imag());
|
||||
C10_ASSERT_NEAR(e.real(), expected_real, tol);
|
||||
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> e = std::exp(x);
|
||||
float expected_real = std::exp(x.real()) * std::cos(x.imag());
|
||||
@ -45,6 +63,14 @@ C10_DEFINE_TEST(TestExponential, EulerFormula) {
|
||||
C10_ASSERT_NEAR(e.real(), expected_real, tol);
|
||||
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> e = ::exp(x);
|
||||
float expected_real = ::exp(x.real()) * ::cos(x.imag());
|
||||
float expected_imag = ::exp(x.real()) * ::sin(x.imag());
|
||||
C10_ASSERT_NEAR(e.real(), expected_real, tol);
|
||||
C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
|
||||
}
|
||||
}
|
||||
|
||||
C10_DEFINE_TEST(TestLog, Definition) {
|
||||
@ -58,6 +84,14 @@ C10_DEFINE_TEST(TestLog, Definition) {
|
||||
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(1.2, 3.4);
|
||||
c10::complex<float> l = ::log(x);
|
||||
float expected_real = ::log(std::abs(x));
|
||||
float expected_imag = std::arg(x);
|
||||
C10_ASSERT_NEAR(l.real(), expected_real, tol);
|
||||
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(1.2, 3.4);
|
||||
c10::complex<double> l = std::log(x);
|
||||
float expected_real = std::log(std::abs(x));
|
||||
@ -65,6 +99,14 @@ C10_DEFINE_TEST(TestLog, Definition) {
|
||||
C10_ASSERT_NEAR(l.real(), expected_real, tol);
|
||||
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(1.2, 3.4);
|
||||
c10::complex<double> l = ::log(x);
|
||||
float expected_real = ::log(std::abs(x));
|
||||
float expected_imag = std::arg(x);
|
||||
C10_ASSERT_NEAR(l.real(), expected_real, tol);
|
||||
C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
|
||||
}
|
||||
}
|
||||
|
||||
C10_DEFINE_TEST(TestLog10, Rev) {
|
||||
@ -76,11 +118,23 @@ C10_DEFINE_TEST(TestLog10, Rev) {
|
||||
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> l = ::log10(::pow(float(10), x));
|
||||
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
|
||||
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> l = std::log10(std::pow(double(10), x));
|
||||
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
|
||||
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> l = ::log10(::pow(double(10), x));
|
||||
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
|
||||
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
|
||||
}
|
||||
}
|
||||
|
||||
// Power functions
|
||||
@ -95,12 +149,26 @@ C10_DEFINE_TEST(TestPowSqrt, Equal) {
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> y = ::pow(x, float(0.5));
|
||||
c10::complex<float> z = ::sqrt(x);
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = std::pow(x, double(0.5));
|
||||
c10::complex<double> z = std::sqrt(x);
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = ::pow(x, double(0.5));
|
||||
c10::complex<double> z = ::sqrt(x);
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
}
|
||||
|
||||
C10_DEFINE_TEST(TestPow, Square) {
|
||||
@ -113,12 +181,26 @@ C10_DEFINE_TEST(TestPow, Square) {
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> y = ::pow(x, float(2));
|
||||
c10::complex<float> z = x * x;
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = std::pow(x, double(2));
|
||||
c10::complex<double> z = x * x;
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = ::pow(x, double(2));
|
||||
c10::complex<double> z = x * x;
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
}
|
||||
|
||||
// Trigonometric functions and hyperbolic functions
|
||||
@ -136,6 +218,14 @@ C10_DEFINE_TEST(TestSinCosSinhCosh, Identity) {
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> y = ::sin(x);
|
||||
float expected_real = ::sin(x.real()) * ::cosh(x.imag());
|
||||
float expected_imag = ::cos(x.real()) * ::sinh(x.imag());
|
||||
C10_ASSERT_NEAR(y.real(), expected_real, tol);
|
||||
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> y = std::cos(x);
|
||||
float expected_real = std::cos(x.real()) * std::cosh(x.imag());
|
||||
float expected_imag = - std::sin(x.real()) * std::sinh(x.imag());
|
||||
@ -143,6 +233,14 @@ C10_DEFINE_TEST(TestSinCosSinhCosh, Identity) {
|
||||
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> y = ::cos(x);
|
||||
float expected_real = ::cos(x.real()) * ::cosh(x.imag());
|
||||
float expected_imag = - ::sin(x.real()) * ::sinh(x.imag());
|
||||
C10_ASSERT_NEAR(y.real(), expected_real, tol);
|
||||
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = std::sin(x);
|
||||
float expected_real = std::sin(x.real()) * std::cosh(x.imag());
|
||||
@ -152,12 +250,28 @@ C10_DEFINE_TEST(TestSinCosSinhCosh, Identity) {
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = ::sin(x);
|
||||
float expected_real = ::sin(x.real()) * ::cosh(x.imag());
|
||||
float expected_imag = ::cos(x.real()) * ::sinh(x.imag());
|
||||
C10_ASSERT_NEAR(y.real(), expected_real, tol);
|
||||
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = std::cos(x);
|
||||
float expected_real = std::cos(x.real()) * std::cosh(x.imag());
|
||||
float expected_imag = - std::sin(x.real()) * std::sinh(x.imag());
|
||||
C10_ASSERT_NEAR(y.real(), expected_real, tol);
|
||||
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = ::cos(x);
|
||||
float expected_real = ::cos(x.real()) * ::cosh(x.imag());
|
||||
float expected_imag = - ::sin(x.real()) * ::sinh(x.imag());
|
||||
C10_ASSERT_NEAR(y.real(), expected_real, tol);
|
||||
C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
|
||||
}
|
||||
}
|
||||
|
||||
C10_DEFINE_TEST(TestTan, Identity) {
|
||||
@ -170,12 +284,26 @@ C10_DEFINE_TEST(TestTan, Identity) {
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> y = ::tan(x);
|
||||
c10::complex<float> z = ::sin(x) / ::cos(x);
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = std::tan(x);
|
||||
c10::complex<double> z = std::sin(x) / std::cos(x);
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = ::tan(x);
|
||||
c10::complex<double> z = ::sin(x) / ::cos(x);
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
}
|
||||
|
||||
C10_DEFINE_TEST(TestTanh, Identity) {
|
||||
@ -188,12 +316,26 @@ C10_DEFINE_TEST(TestTanh, Identity) {
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> y = ::tanh(x);
|
||||
c10::complex<float> z = ::sinh(x) / ::cosh(x);
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = std::tanh(x);
|
||||
c10::complex<double> z = std::sinh(x) / std::cosh(x);
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> y = ::tanh(x);
|
||||
c10::complex<double> z = ::sinh(x) / ::cosh(x);
|
||||
C10_ASSERT_NEAR(y.real(), z.real(), tol);
|
||||
C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
|
||||
}
|
||||
}
|
||||
|
||||
// Rev trigonometric functions
|
||||
@ -218,6 +360,21 @@ C10_DEFINE_TEST(TestRevTrigonometric, Rev) {
|
||||
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.5, 0.6);
|
||||
c10::complex<float> s = ::sin(x);
|
||||
c10::complex<float> ss = ::asin(s);
|
||||
c10::complex<float> c = ::cos(x);
|
||||
c10::complex<float> cc = ::acos(c);
|
||||
c10::complex<float> t = ::tan(x);
|
||||
c10::complex<float> tt = ::atan(t);
|
||||
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
|
||||
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
|
||||
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.5, 0.6);
|
||||
c10::complex<double> s = std::sin(x);
|
||||
c10::complex<double> ss = std::asin(s);
|
||||
@ -232,6 +389,21 @@ C10_DEFINE_TEST(TestRevTrigonometric, Rev) {
|
||||
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.5, 0.6);
|
||||
c10::complex<double> s = ::sin(x);
|
||||
c10::complex<double> ss = ::asin(s);
|
||||
c10::complex<double> c = ::cos(x);
|
||||
c10::complex<double> cc = ::acos(c);
|
||||
c10::complex<double> t = ::tan(x);
|
||||
c10::complex<double> tt = ::atan(t);
|
||||
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
|
||||
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
|
||||
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
|
||||
}
|
||||
}
|
||||
|
||||
// Rev hyperbolic functions
|
||||
@ -256,6 +428,21 @@ C10_DEFINE_TEST(TestRevHyperbolic, Rev) {
|
||||
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.5, 0.6);
|
||||
c10::complex<float> s = ::sinh(x);
|
||||
c10::complex<float> ss = ::asinh(s);
|
||||
c10::complex<float> c = ::cosh(x);
|
||||
c10::complex<float> cc = ::acosh(c);
|
||||
c10::complex<float> t = ::tanh(x);
|
||||
c10::complex<float> tt = ::atanh(t);
|
||||
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
|
||||
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
|
||||
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.5, 0.6);
|
||||
c10::complex<double> s = std::sinh(x);
|
||||
c10::complex<double> ss = std::asinh(s);
|
||||
@ -270,4 +457,19 @@ C10_DEFINE_TEST(TestRevHyperbolic, Rev) {
|
||||
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.5, 0.6);
|
||||
c10::complex<double> s = ::sinh(x);
|
||||
c10::complex<double> ss = ::asinh(s);
|
||||
c10::complex<double> c = ::cosh(x);
|
||||
c10::complex<double> cc = ::acosh(c);
|
||||
c10::complex<double> t = ::tanh(x);
|
||||
c10::complex<double> tt = ::atanh(t);
|
||||
C10_ASSERT_NEAR(x.real(), ss.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
|
||||
C10_ASSERT_NEAR(x.real(), cc.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
|
||||
C10_ASSERT_NEAR(x.real(), tt.real(), tol);
|
||||
C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
#endif
|
||||
|
||||
|
||||
namespace std {
|
||||
namespace c10_complex_math {
|
||||
|
||||
// Exponential functions
|
||||
|
||||
@ -213,4 +213,44 @@ C10_HOST_DEVICE inline c10::complex<T> atanh(const c10::complex<T> &x) {
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace std
|
||||
} // namespace c10_complex_math
|
||||
|
||||
using c10_complex_math::exp;
|
||||
using c10_complex_math::log;
|
||||
using c10_complex_math::log10;
|
||||
using c10_complex_math::sqrt;
|
||||
using c10_complex_math::pow;
|
||||
using c10_complex_math::sin;
|
||||
using c10_complex_math::cos;
|
||||
using c10_complex_math::tan;
|
||||
using c10_complex_math::asin;
|
||||
using c10_complex_math::acos;
|
||||
using c10_complex_math::atan;
|
||||
using c10_complex_math::sinh;
|
||||
using c10_complex_math::cosh;
|
||||
using c10_complex_math::tanh;
|
||||
using c10_complex_math::asinh;
|
||||
using c10_complex_math::acosh;
|
||||
using c10_complex_math::atanh;
|
||||
|
||||
namespace std {
|
||||
|
||||
using c10_complex_math::exp;
|
||||
using c10_complex_math::log;
|
||||
using c10_complex_math::log10;
|
||||
using c10_complex_math::sqrt;
|
||||
using c10_complex_math::pow;
|
||||
using c10_complex_math::sin;
|
||||
using c10_complex_math::cos;
|
||||
using c10_complex_math::tan;
|
||||
using c10_complex_math::asin;
|
||||
using c10_complex_math::acos;
|
||||
using c10_complex_math::atan;
|
||||
using c10_complex_math::sinh;
|
||||
using c10_complex_math::cosh;
|
||||
using c10_complex_math::tanh;
|
||||
using c10_complex_math::asinh;
|
||||
using c10_complex_math::acosh;
|
||||
using c10_complex_math::atanh;
|
||||
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@ REGISTER_CPU_OPERATOR(RangeFill, RangeFillOp<float, CPUContext>);
|
||||
REGISTER_CPU_OPERATOR(LengthsRangeFill, LengthsRangeFillOp<CPUContext>);
|
||||
|
||||
OPERATOR_SCHEMA(ConstantFill)
|
||||
.NumInputs(0, 1)
|
||||
.NumInputs(0, 2)
|
||||
.NumOutputs(1)
|
||||
.AllowInplace({{0, 0}})
|
||||
.TensorInferenceFunction(FillerTensorInference<>)
|
||||
@ -65,6 +65,8 @@ provided, a shape argument should not be set)
|
||||
containing the desired output shape (the dimensions specified in `extra_shape`
|
||||
will also be appended)
|
||||
|
||||
- If a second input V is passed, fill the output with the first element of V
|
||||
|
||||
When specifying `dtype` argument, use the integer keys from the *DataType* enum
|
||||
in TensorProto:
|
||||
|
||||
|
||||
@ -77,9 +77,14 @@ class FillerOp : public Operator<Context> {
|
||||
"data type int64_t");
|
||||
CAFFE_ENFORCE(input.numel() > 0);
|
||||
auto* shape_data = input.template data<int64_t>();
|
||||
std::unique_ptr<int64_t[]> shape_data_copy = std::make_unique<int64_t[]>(input.dim32(0));
|
||||
context_.template CopyToCPU<int64_t>(input.dim32(0), shape_data, shape_data_copy.get());
|
||||
shape.insert(shape.end(), shape_data_copy.get(), shape_data_copy.get() + input.dim32(0));
|
||||
std::unique_ptr<int64_t[]> shape_data_copy =
|
||||
std::make_unique<int64_t[]>(input.dim32(0));
|
||||
context_.template CopyToCPU<int64_t>(
|
||||
input.dim32(0), shape_data, shape_data_copy.get());
|
||||
shape.insert(
|
||||
shape.end(),
|
||||
shape_data_copy.get(),
|
||||
shape_data_copy.get() + input.dim32(0));
|
||||
}
|
||||
} else {
|
||||
auto& input = Input(0);
|
||||
@ -295,6 +300,15 @@ class ConstantFillOp final : public FillerOp<Context> {
|
||||
template <typename T>
|
||||
bool FillWithType(Tensor* output) {
|
||||
T value = this->template GetSingleArgument<T>("value", 0);
|
||||
if (InputSize() == 2) {
|
||||
auto& value_vec = Input(1);
|
||||
if (value_vec) {
|
||||
CAFFE_ENFORCE_EQ(
|
||||
value_vec.size(), 1, "value vector must have 1 element");
|
||||
value = value_vec.template data<T>()[0];
|
||||
}
|
||||
}
|
||||
|
||||
auto* data = output->template mutable_data<T>();
|
||||
if (output->numel()) {
|
||||
math::Set<T, Context>(output->numel(), value, data, &context_);
|
||||
@ -303,6 +317,8 @@ class ConstantFillOp final : public FillerOp<Context> {
|
||||
}
|
||||
|
||||
bool FillWithString(Tensor* output) {
|
||||
CAFFE_ENFORCE_LT(
|
||||
InputSize(), 2, "constant fill string from tensor is not supported");
|
||||
auto value = this->template GetSingleArgument<std::string>("value", "");
|
||||
auto* data = output->template mutable_data<std::string>();
|
||||
for (int i = 0; i < output->numel(); ++i) {
|
||||
|
||||
@ -1906,6 +1906,34 @@ class TestOperators(hu.HypothesisTestCase):
|
||||
out, = self.assertReferenceChecks(gc, op, inputs, ref)
|
||||
self.assertEqual(dtype, out.dtype)
|
||||
|
||||
@given(data=_dtypes(dtypes=[np.int32, np.int64, np.float32, np.bool]).
|
||||
flatmap(lambda dtype: hu.tensor(
|
||||
min_dim=1, dtype=dtype, elements=hu.elements_of_type(dtype))),
|
||||
**hu.gcs)
|
||||
def test_constant_fill_from_tensor(self, data, gc, dc):
|
||||
dtype = data.dtype.type
|
||||
if data.dtype == np.dtype(np.bool):
|
||||
dtype = np.bool
|
||||
|
||||
value = np.array([data.item(0)], dtype=dtype)
|
||||
inputs = [data, value]
|
||||
enum_type = _NUMPY_TYPE_TO_ENUM[dtype]
|
||||
|
||||
op = core.CreateOperator(
|
||||
'ConstantFill',
|
||||
["X", "V"],
|
||||
["Y"],
|
||||
dtype=enum_type,
|
||||
)
|
||||
|
||||
def ref(x, v):
|
||||
outputs = np.full(shape=data.shape, fill_value=value[0], dtype=dtype)
|
||||
return [outputs]
|
||||
|
||||
self.assertDeviceChecks(dc, op, inputs, [0])
|
||||
out, = self.assertReferenceChecks(gc, op, inputs, ref)
|
||||
self.assertEqual(dtype, out.dtype)
|
||||
|
||||
@given(t=st.integers(1, 5),
|
||||
n=st.integers(1, 5),
|
||||
d=st.integers(1, 5))
|
||||
|
||||
@ -1,6 +1,17 @@
|
||||
#!/bin/bash
|
||||
set -eux
|
||||
|
||||
##############################################################################
|
||||
# Master script to build PyTorch Android library with Java bindings.
|
||||
##############################################################################
|
||||
# Example usage:
|
||||
# - Build default AARs:
|
||||
# scripts/build_pytorch_android.sh
|
||||
#
|
||||
# - Build for specific ABI(s):
|
||||
# scripts/build_pytorch_android.sh armeabi-v7a
|
||||
# scripts/build_pytorch_android.sh arm64-v8a,x86,x86_64
|
||||
#
|
||||
# Script's workflow:
|
||||
# 1. Builds libtorch for android for specified android ABIs (by default for all 4).
|
||||
# Custom list of android abis can be specified as a bash argument as comma separated list.
|
||||
@ -12,82 +23,16 @@ set -eux
|
||||
# gradle assembleRelease
|
||||
|
||||
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
|
||||
|
||||
PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android
|
||||
WORK_DIR=$PYTORCH_DIR
|
||||
|
||||
echo "PYTORCH_DIR:$PYTORCH_DIR"
|
||||
echo "WORK_DIR:$WORK_DIR"
|
||||
|
||||
echo "ANDROID_HOME:$ANDROID_HOME"
|
||||
if [ -z "$ANDROID_HOME" ]; then
|
||||
echo "ANDROID_HOME not set; please set it to Android sdk directory"
|
||||
fi
|
||||
source "$PYTORCH_ANDROID_DIR/common.sh"
|
||||
|
||||
if [ ! -d $ANDROID_HOME ]; then
|
||||
echo "ANDROID_HOME not a directory; did you install it under $ANDROID_HOME?"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
GRADLE_PATH=gradle
|
||||
GRADLE_NOT_FOUND_MSG="Unable to find gradle, please add it to PATH or set GRADLE_HOME"
|
||||
|
||||
if [ ! -x "$(command -v gradle)" ]; then
|
||||
if [ -z "$GRADLE_HOME" ]; then
|
||||
echo GRADLE_NOT_FOUND_MSG
|
||||
exit 1
|
||||
fi
|
||||
GRADLE_PATH=$GRADLE_HOME/bin/gradle
|
||||
if [ ! -f "$GRADLE_PATH" ]; then
|
||||
echo GRADLE_NOT_FOUND_MSG
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
echo "GRADLE_PATH:$GRADLE_PATH"
|
||||
|
||||
ABIS_LIST="armeabi-v7a,arm64-v8a,x86,x86_64"
|
||||
CUSTOM_ABIS_LIST=false
|
||||
if [ $# -gt 0 ]; then
|
||||
ABIS_LIST=$1
|
||||
CUSTOM_ABIS_LIST=true
|
||||
fi
|
||||
|
||||
echo "ABIS_LIST:$ABIS_LIST"
|
||||
|
||||
LIB_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs
|
||||
INCLUDE_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include
|
||||
mkdir -p $LIB_DIR
|
||||
mkdir -p $INCLUDE_DIR
|
||||
|
||||
for abi in $(echo $ABIS_LIST | tr ',' '\n')
|
||||
do
|
||||
echo "abi:$abi"
|
||||
|
||||
OUT_DIR=$WORK_DIR/build_android_$abi
|
||||
|
||||
rm -rf $OUT_DIR
|
||||
mkdir -p $OUT_DIR
|
||||
|
||||
pushd $PYTORCH_DIR
|
||||
python $PYTORCH_DIR/setup.py clean
|
||||
|
||||
ANDROID_ABI=$abi $PYTORCH_DIR/scripts/build_android.sh -DANDROID_CCACHE=$(which ccache)
|
||||
|
||||
cp -R $PYTORCH_DIR/build_android/install/lib $OUT_DIR/
|
||||
cp -R $PYTORCH_DIR/build_android/install/include $OUT_DIR/
|
||||
|
||||
echo "$abi build output lib,include copied to $OUT_DIR"
|
||||
|
||||
LIB_LINK_PATH=$LIB_DIR/$abi
|
||||
INCLUDE_LINK_PATH=$INCLUDE_DIR/$abi
|
||||
|
||||
rm -f $LIB_LINK_PATH
|
||||
rm -f $INCLUDE_LINK_PATH
|
||||
|
||||
ln -s $OUT_DIR/lib $LIB_LINK_PATH
|
||||
ln -s $OUT_DIR/include $INCLUDE_LINK_PATH
|
||||
|
||||
done
|
||||
check_android_sdk
|
||||
check_gradle
|
||||
parse_abis_list "$@"
|
||||
build_android
|
||||
|
||||
# To set proxy for gradle add following lines to ./gradle/gradle.properties:
|
||||
# systemProp.http.proxyHost=...
|
||||
@ -102,4 +47,3 @@ else
|
||||
fi
|
||||
|
||||
find $PYTORCH_ANDROID_DIR -type f -name *aar | xargs ls -lah
|
||||
popd
|
||||
|
||||
@ -113,6 +113,25 @@ graph(%0 : Tensor,
|
||||
%10 : int = prim::Constant[value=1]()
|
||||
%11 : Tensor = aten::add(%5, %3, %10)
|
||||
return (%11)
|
||||
)IR");
|
||||
}
|
||||
{
|
||||
checkRoundtrip(R"IR(
|
||||
graph(%0 : Tensor,
|
||||
%1 : Tensor,
|
||||
%2 : Tensor):
|
||||
%3 : int = prim::Constant[value=-1]()
|
||||
%4 : Tensor = aten::add(%0, %1, %3)
|
||||
%5 : Tensor = prim::If(%2)
|
||||
block0():
|
||||
%6 : int = prim::Constant[value=1]()
|
||||
%7 : Tensor = aten::add(%1, %3, %6)
|
||||
%8 : int = prim::Constant[value=1]()
|
||||
%9 : Tensor = aten::add(%7, %3, %8)
|
||||
-> (%9)
|
||||
%10 : int = prim::Constant[value=-987]()
|
||||
%11 : Tensor = aten::add(%5, %3, %10)
|
||||
return (%11)
|
||||
)IR");
|
||||
}
|
||||
{
|
||||
|
||||
@ -1881,11 +1881,13 @@ void testFutures() {
|
||||
{
|
||||
auto f1 = c10::make_intrusive<Future>(IntType::get());
|
||||
ASSERT_FALSE(f1->completed());
|
||||
ASSERT_FALSE(f1->hasValue());
|
||||
int32_t sat1 = 0;
|
||||
int32_t sat2 = 0;
|
||||
f1->addCallback([&]() { ++sat1; });
|
||||
f1->markCompleted(43);
|
||||
ASSERT_TRUE(f1->completed());
|
||||
ASSERT_TRUE(f1->hasValue());
|
||||
ASSERT_FALSE(f1->hasError());
|
||||
ASSERT_EQ(sat1, 1);
|
||||
ASSERT_EQ(f1->constValue().toInt(), 43);
|
||||
@ -1905,17 +1907,13 @@ void testFutures() {
|
||||
ASSERT_EQ(sat1, 1);
|
||||
ASSERT_TRUE(f1->completed());
|
||||
ASSERT_TRUE(f1->hasError());
|
||||
ASSERT_FALSE(f1->hasValue());
|
||||
try {
|
||||
(void)f1->value();
|
||||
ASSERT_TRUE(false); // Supposed to throw.
|
||||
} catch (const std::exception& e) {
|
||||
ASSERT_TRUE(strcmp(e.what(), "Failed") == 0);
|
||||
}
|
||||
try {
|
||||
(void)f1->constValue();
|
||||
} catch (const std::exception& e) {
|
||||
ASSERT_TRUE(false); // Not supposed to throw.
|
||||
}
|
||||
f1->addCallback([&]() { ++sat2; });
|
||||
ASSERT_EQ(sat1, 1);
|
||||
ASSERT_EQ(sat2, 1);
|
||||
|
||||
0
test/distributed/nn/api/__init__.py
Normal file
0
test/distributed/nn/api/__init__.py
Normal file
21
test/distributed/nn/api/test_remote_module_spawn.py
Normal file
21
test/distributed/nn/api/test_remote_module_spawn.py
Normal file
@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python3
|
||||
import unittest
|
||||
|
||||
from torch.testing._internal.common_distributed import MultiProcessTestCase
|
||||
from torch.testing._internal.common_utils import TEST_WITH_ASAN, run_tests
|
||||
from torch.testing._internal.distributed.nn.api.remote_module_test import (
|
||||
RemoteModuleTest,
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
    TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues"
)
class RemoteModuleTestWithSpawn(MultiProcessTestCase, RemoteModuleTest):
    """Combine ``RemoteModuleTest`` with ``MultiProcessTestCase`` process spawning.

    The whole class is skipped when ``TEST_WITH_ASAN`` is set.
    """

    def setUp(self):
        # Base-class setup runs first, then the worker processes are started.
        super(RemoteModuleTestWithSpawn, self).setUp()
        self._spawn_processes()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
0
test/distributed/nn/jit/__init__.py
Normal file
0
test/distributed/nn/jit/__init__.py
Normal file
66
test/distributed/nn/jit/test_instantiator.py
Normal file
66
test/distributed/nn/jit/test_instantiator.py
Normal file
@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python3
|
||||
import unittest
|
||||
import pathlib
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.distributed.nn.jit import instantiator
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
|
||||
# Interface whose ``forward`` takes a Tensor, an int and a str defaulting to
# "default", and returns a (Tensor, int, str) tuple. The method body is kept as
# a lone ``pass`` statement, as in the original declaration.
@torch.jit.interface
class MyModuleInterface:
    def forward(
        self,
        tensor: Tensor,
        number: int,
        word: str = "default",
    ) -> Tuple[Tensor, int, str]:
        pass
|
||||
|
||||
|
||||
class MyModule(nn.Module):
    """Bare ``nn.Module`` subclass that adds no attributes or overrides."""
|
||||
|
||||
|
||||
def create_module():
    """Return a newly constructed ``MyModule`` instance."""
    module = MyModule()
    return module
|
||||
|
||||
|
||||
class TestInstantiator(unittest.TestCase):
    """Tests for the ``instantiator`` helpers that generate remote-module code."""

    def _check_generated_module_and_cleanup(self, generated_module):
        """Assert the generated attributes exist and that cleanup removes files."""
        for attr_name in ("_remote_forward", "_generated_methods"):
            self.assertTrue(hasattr(generated_module, attr_name))

        template_dir = pathlib.Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH)
        files_before = len(list(template_dir.iterdir()))
        instantiator.cleanup_generated_modules()
        files_after = len(list(template_dir.iterdir()))
        # This check assumes only one RPC worker group is active at a time.
        self.assertGreater(files_before, files_after)

    def test_get_arg_return_types_from_interface(self):
        """The interface signature is rendered into the three expected strings."""
        rendered = instantiator.get_arg_return_types_from_interface(MyModuleInterface)
        args_str, arg_types_str, return_type_str = rendered
        self.assertEqual(args_str, "tensor, number, word")
        self.assertEqual(arg_types_str, "tensor: Tensor, number: int, word: str")
        self.assertEqual(return_type_str, "Tuple[Tensor, int, str]")

    def test_instantiate_scripted_remote_module_template(self):
        """Instantiate the scriptable template from ``MyModuleInterface``."""
        module = instantiator.instantiate_scriptable_remote_module_template(
            MyModuleInterface
        )
        self._check_generated_module_and_cleanup(module)

    def test_instantiate_non_scripted_remote_module_template(self):
        """Instantiate the non-scriptable template, which takes no interface."""
        module = instantiator.instantiate_non_scriptable_remote_module_template()
        self._check_generated_module_and_cleanup(module)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -2184,7 +2184,7 @@ class TestQuantizeScriptPTSQOps(QuantizationTestCase):
|
||||
FileCheck().check_count("aten::dequantize", 1, exactly=True) \
|
||||
.run(m.graph)
|
||||
|
||||
def test_quantize_general_shape_ops(self):
|
||||
def test_general_shape_ops(self):
|
||||
""" A test that checks dequantize will be swapped for
|
||||
all supported general shape ops like aten::flatten
|
||||
without actually checking for execution of these ops
|
||||
@ -2210,8 +2210,14 @@ class TestQuantizeScriptPTSQOps(QuantizationTestCase):
|
||||
x = x.reshape([-1])
|
||||
x = x.resize_(1, 1, x.numel())
|
||||
x = x.view(-1)
|
||||
# prim::ListConstruct
|
||||
xs = [x, x]
|
||||
y, x = xs
|
||||
# prim::ListUnpack
|
||||
x, y = xs
|
||||
# prim::TupleConstruct
|
||||
xs = (x, x)
|
||||
# prim::TupleUnpack
|
||||
x, y = xs
|
||||
x = x.transpose(1, 2)
|
||||
x = x.contiguous()
|
||||
x, y = torch.chunk(x, 2)
|
||||
@ -2251,7 +2257,7 @@ class TestQuantizeScriptPTSQOps(QuantizationTestCase):
|
||||
.check("aten::dequantize") \
|
||||
.run(m.graph)
|
||||
|
||||
def test_quantize_general_value_ops(self):
|
||||
def test_general_value_ops(self):
|
||||
""" A test that checks correct patterns are produced for
|
||||
all supported general value ops like aten::avg_pool2d \
|
||||
without actually checking for execution of these ops
|
||||
|
||||
@ -68,6 +68,9 @@ TESTS = [
|
||||
'test_overrides',
|
||||
'test_jit_fuser_te',
|
||||
'test_tensorexpr',
|
||||
'test_openmp',
|
||||
'distributed/nn/jit/test_instantiator',
|
||||
'distributed/nn/api/test_remote_module_spawn',
|
||||
'distributed/rpc/faulty_agent/test_dist_autograd_spawn',
|
||||
'distributed/rpc/faulty_agent/test_rpc_spawn',
|
||||
'distributed/rpc/jit/test_dist_autograd_spawn',
|
||||
@ -86,6 +89,8 @@ TESTS = [
|
||||
]
|
||||
|
||||
WINDOWS_BLACKLIST = [
|
||||
'distributed/nn/jit/test_instantiator',
|
||||
'distributed/nn/api/test_remote_module_spawn',
|
||||
'distributed/rpc/faulty_agent/test_dist_autograd_spawn',
|
||||
'distributed/rpc/faulty_agent/test_rpc_spawn',
|
||||
'distributed/rpc/jit/test_dist_autograd_spawn',
|
||||
@ -101,6 +106,8 @@ WINDOWS_BLACKLIST = [
|
||||
]
|
||||
|
||||
ROCM_BLACKLIST = [
|
||||
'distributed/nn/jit/test_instantiator',
|
||||
'distributed/nn/api/test_remote_module_spawn',
|
||||
'distributed/rpc/faulty_agent/test_dist_autograd_spawn',
|
||||
'distributed/rpc/faulty_agent/test_rpc_spawn',
|
||||
'distributed/rpc/jit/test_dist_autograd_spawn',
|
||||
@ -147,6 +154,8 @@ SLOW_TESTS = [
|
||||
'test_jit',
|
||||
'test_jit_profiling',
|
||||
'test_torch',
|
||||
'distributed/nn/jit/test_instantiator',
|
||||
'distributed/nn/api/test_remote_module_spawn',
|
||||
'distributed/test_distributed',
|
||||
'distributed/rpc/tensorpipe/test_dist_autograd_spawn',
|
||||
'distributed/rpc/tensorpipe/test_dist_optimizer_spawn',
|
||||
|
||||
@ -1214,6 +1214,15 @@ except RuntimeError as e:
|
||||
seeds.add(batch[0])
|
||||
self.assertEqual(len(seeds), num_workers)
|
||||
|
||||
def test_worker_seed_reproducibility(self):
|
||||
def get_dataloader():
|
||||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=torch.Generator().manual_seed(42))
|
||||
|
||||
num_workers = 6
|
||||
batch_size = 1
|
||||
dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
|
||||
self.assertEqual(set(int(batch) for batch in get_dataloader()), set(int(batch) for batch in get_dataloader()))
|
||||
|
||||
def test_worker_init_fn(self):
|
||||
dataset = SeedDataset(4)
|
||||
dataloader = DataLoader(dataset, batch_size=2, num_workers=2,
|
||||
@ -1241,6 +1250,13 @@ except RuntimeError as e:
|
||||
def test_shuffle_batch(self):
|
||||
self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True))
|
||||
|
||||
def test_shuffle_reproducibility(self):
|
||||
for fn in (
|
||||
lambda: DataLoader(self.dataset, shuffle=True, num_workers=0, generator=torch.Generator().manual_seed(42)),
|
||||
lambda: DataLoader(self.dataset, shuffle=True, num_workers=2, generator=torch.Generator().manual_seed(42)),
|
||||
):
|
||||
self.assertEqual(list(fn()), list(fn()))
|
||||
|
||||
def test_sequential_workers(self):
|
||||
self._test_sequential(DataLoader(self.dataset, num_workers=4))
|
||||
|
||||
@ -1253,7 +1269,7 @@ except RuntimeError as e:
|
||||
def test_shuffle_batch_workers(self):
|
||||
self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4))
|
||||
|
||||
def test_RandomSampler(self):
|
||||
def test_random_sampler(self):
|
||||
|
||||
from collections import Counter
|
||||
from torch.utils.data import RandomSampler
|
||||
@ -1330,6 +1346,20 @@ except RuntimeError as e:
|
||||
|
||||
self.assertEqual(scanned_data.size(), scanned_data.unique().size())
|
||||
|
||||
def test_sampler_reproducibility(self):
|
||||
from torch.utils.data import RandomSampler, WeightedRandomSampler, SubsetRandomSampler
|
||||
|
||||
weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]
|
||||
for fn in (
|
||||
lambda: RandomSampler(self.dataset, num_samples=5, replacement=True, generator=torch.Generator().manual_seed(42)),
|
||||
lambda: RandomSampler(self.dataset, replacement=False, generator=torch.Generator().manual_seed(42)),
|
||||
lambda: WeightedRandomSampler(weights, num_samples=5, replacement=True, generator=torch.Generator().manual_seed(42)),
|
||||
lambda: WeightedRandomSampler(weights, num_samples=5, replacement=False, generator=torch.Generator().manual_seed(42)),
|
||||
lambda: SubsetRandomSampler(range(10), generator=torch.Generator().manual_seed(42)),
|
||||
):
|
||||
self.assertEqual(list(fn()), list(fn()))
|
||||
|
||||
|
||||
def _test_sampler(self, **kwargs):
|
||||
indices = range(2, 12) # using a regular iterable
|
||||
dl = DataLoader(self.dataset, sampler=indices, batch_size=2, **kwargs)
|
||||
|
||||
@ -4378,6 +4378,7 @@ def foo(x):
|
||||
torch._jit_internal.is_optional(ann)
|
||||
|
||||
def test_interpreter_fuzz(self):
|
||||
import builtins
|
||||
# This test generates random tree-like programs to fuzz test
|
||||
# that the interpreter does not have a bug in its stack manipulation
|
||||
# code. An assert in that code ensures individual operators are
|
||||
@ -4427,7 +4428,7 @@ def foo(x):
|
||||
for i in range(100):
|
||||
g = {'torch': torch}
|
||||
code = gen_code()
|
||||
torch._six.exec_(code, g, None)
|
||||
builtins.exec(code, g, None)
|
||||
cu = torch.jit.CompilationUnit(code)
|
||||
with freeze_rng_state():
|
||||
o1 = g['f']()
|
||||
@ -10882,37 +10883,34 @@ a")
|
||||
m_c = m.copy()
|
||||
|
||||
def test_script_forward_method_replacement(self):
|
||||
# We want to support the use case of having a different `forward` method
|
||||
# when we are not in scripting mode.
|
||||
# We want to support the use case of attaching a different `forward` method
|
||||
class LowLevelModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(LowLevelModule, self).__init__()
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
# Generic forward dispatch, when we are not in scripting mode
|
||||
assert not torch.jit.is_scripting()
|
||||
return self.forward_pytorch(*args, **kwargs) * 2
|
||||
def forward(self, input: torch.Tensor):
|
||||
# Generic forward dispatch
|
||||
return self.forward_pytorch(input) * 2
|
||||
|
||||
class TestModule(LowLevelModule):
|
||||
def __init__(self):
|
||||
super(TestModule, self).__init__()
|
||||
# Replace the forward method if we know we are not in scripting mode
|
||||
if not torch.jit.is_scripting():
|
||||
self.forward = types.MethodType(LowLevelModule.forward, self)
|
||||
# Replace the forward method
|
||||
self.forward = types.MethodType(LowLevelModule.forward, self)
|
||||
|
||||
def forward_pytorch(self, input: torch.Tensor):
|
||||
return torch.tensor(123)
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
# Scripting should only use this forward method
|
||||
assert torch.jit.is_scripting()
|
||||
# Should not use this forward method
|
||||
raise AssertionError("This method should not be used")
|
||||
return self.forward_pytorch(input)
|
||||
|
||||
m = TestModule()
|
||||
self.assertEqual(m(torch.tensor(1)), torch.tensor(246))
|
||||
|
||||
m_scripted = torch.jit.script(m)
|
||||
self.assertEqual(m_scripted(torch.tensor(1)), torch.tensor(123))
|
||||
self.assertEqual(m_scripted(torch.tensor(1)), torch.tensor(246))
|
||||
|
||||
# Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch.
|
||||
@suppress_warnings
|
||||
|
||||
@ -11,6 +11,7 @@ except ImportError:
|
||||
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.jit
|
||||
import torch.backends.mkldnn
|
||||
from torch.utils import mkldnn as mkldnn_utils
|
||||
@ -204,6 +205,29 @@ class TestMkldnn(TestCase):
|
||||
max_pool2d(x),
|
||||
max_pool2d(x.to_mkldnn()).to_dense())
|
||||
|
||||
def test_max_pool2d_stride_none(self):
|
||||
N = torch.randint(3, 10, (1,)).item()
|
||||
C = torch.randint(3, 10, (1,)).item()
|
||||
|
||||
for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]:
|
||||
x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
|
||||
for ceil_mode in [False, True]:
|
||||
y1 = F.max_pool2d(
|
||||
x,
|
||||
kernel_size=3 if not ceil_mode else 7,
|
||||
stride=None,
|
||||
padding=1,
|
||||
ceil_mode=ceil_mode)
|
||||
|
||||
y2 = F.max_pool2d(
|
||||
x.to_mkldnn(),
|
||||
kernel_size=3 if not ceil_mode else 7,
|
||||
stride=None,
|
||||
padding=1,
|
||||
ceil_mode=ceil_mode)
|
||||
|
||||
self.assertEqual(y1, y2.to_dense())
|
||||
|
||||
def test_avg_pool2d(self):
|
||||
N = torch.randint(3, 10, (1,)).item()
|
||||
C = torch.randint(3, 10, (1,)).item()
|
||||
@ -220,6 +244,27 @@ class TestMkldnn(TestCase):
|
||||
avg_pool2d(x),
|
||||
avg_pool2d(x.to_mkldnn()).to_dense())
|
||||
|
||||
def test_avg_pool2d_stride_none(self):
|
||||
N = torch.randint(3, 10, (1,)).item()
|
||||
C = torch.randint(3, 10, (1,)).item()
|
||||
x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
|
||||
|
||||
for count_include_pad in [True, False]:
|
||||
y1 = F.avg_pool2d(
|
||||
x,
|
||||
kernel_size=3,
|
||||
stride=None,
|
||||
padding=1,
|
||||
count_include_pad=count_include_pad)
|
||||
y2 = F.avg_pool2d(
|
||||
x.to_mkldnn(),
|
||||
kernel_size=3,
|
||||
stride=None,
|
||||
padding=1,
|
||||
count_include_pad=count_include_pad)
|
||||
|
||||
self.assertEqual(y1, y2.to_dense())
|
||||
|
||||
def test_adaptive_avg_pool2d(self):
|
||||
N = torch.randint(3, 10, (1,)).item()
|
||||
C = torch.randint(3, 10, (1,)).item()
|
||||
|
||||
69
test/test_openmp.py
Normal file
69
test/test_openmp.py
Normal file
@ -0,0 +1,69 @@
|
||||
import collections
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests, TEST_WITH_ASAN)
|
||||
|
||||
try:
|
||||
import psutil
|
||||
HAS_PSUTIL = True
|
||||
except ImportError:
|
||||
HAS_PSUTIL = False
|
||||
|
||||
device = torch.device('cpu')
|
||||
|
||||
|
||||
class Network(torch.nn.Module):
    """Module whose forward pass applies one ``MaxPool2d(1, 1)`` layer."""

    # Kept as a class-level attribute, exactly as before: the pooling layer is
    # defined on the class rather than assigned in ``__init__``.
    maxp1 = torch.nn.MaxPool2d(1, 1)

    def forward(self, x):
        pooled = self.maxp1(x)
        return pooled
|
||||
|
||||
|
||||
@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run")
@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN")
class TestOpenMP_ParallelFor(TestCase):
    """Check that repeated forward passes do not keep growing process RSS."""

    batch = 20
    channels = 1
    side_dim = 80
    x = torch.randn([batch, channels, side_dim, side_dim], device=device)
    model = Network()
    # The class body executes at import time, before the skipIf decorators can
    # take effect, so psutil must only be touched when it was imported.
    # psutil.cpu_count(logical=False) may also return None when the physical
    # core count cannot be determined; fall back to a single thread then.
    ncores = min(5, psutil.cpu_count(logical=False) or 1) if HAS_PSUTIL else 1

    def func(self, runs):
        """Run the model in 10 rounds of ``runs`` passes and sample RSS.

        Returns a deque holding the RSS (in bytes, as reported by psutil)
        measured after each of the last 5 rounds.
        """
        p = psutil.Process()
        # warm up for 5 runs, then things should be stable for the last 5
        last_rss = collections.deque(maxlen=5)
        for _ in range(10):
            for _ in range(runs):
                self.model(self.x)
            last_rss.append(p.memory_info().rss)
        return last_rss

    def func_rss(self, runs):
        """Assert that the RSS samples from ``func(runs)`` have stopped growing."""
        last_rss = self.func(runs)
        # Do a least-mean-squares fit of last_rss to a line
        poly = np.polynomial.Polynomial.fit(
            range(len(last_rss)), np.array(last_rss), 1)
        coefs = poly.convert().coef
        # The coefs are (b, m) for the line y = m * x + b that fits the data.
        # If m == 0 it will not be present. Assert it is missing or < 1000.
        self.assertTrue(len(coefs) < 2 or coefs[1] < 1000,
                        msg='memory did not stabilize, {}'.format(str(list(last_rss))))

    def test_one_thread(self):
        """Make sure there is no memory leak with one thread: issue gh-32284
        """
        torch.set_num_threads(1)
        self.func_rss(300)

    def test_n_threads(self):
        """Make sure there is no memory leak with many threads
        """
        torch.set_num_threads(self.ncores)
        self.func_rss(300)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
@ -10953,18 +10953,20 @@ class TestTorchDeviceType(TestCase):
|
||||
t.bernoulli_(0.5)
|
||||
self.assertTrue(isBinary(t))
|
||||
|
||||
p = torch.rand(10, dtype=torch.float32, device=device).expand(10, 10)
|
||||
t.fill_(2)
|
||||
t.bernoulli_(p)
|
||||
self.assertTrue(isBinary(t))
|
||||
for p_dtype in torch.testing.get_all_fp_dtypes(include_half=device.startswith('cuda'),
|
||||
include_bfloat16=False):
|
||||
p = torch.rand(10, dtype=p_dtype, device=device).expand(10, 10)
|
||||
t.fill_(2)
|
||||
t.bernoulli_(p)
|
||||
self.assertTrue(isBinary(t))
|
||||
|
||||
t.fill_(2)
|
||||
torch.bernoulli(torch.rand_like(t, dtype=torch.float32), out=t)
|
||||
self.assertTrue(isBinary(t))
|
||||
t.fill_(2)
|
||||
torch.bernoulli(torch.rand_like(t, dtype=p_dtype), out=t)
|
||||
self.assertTrue(isBinary(t))
|
||||
|
||||
t.fill_(2)
|
||||
t.bernoulli_(torch.rand_like(t, dtype=torch.float32))
|
||||
self.assertTrue(isBinary(t))
|
||||
t.fill_(2)
|
||||
t.bernoulli_(torch.rand_like(t, dtype=p_dtype))
|
||||
self.assertTrue(isBinary(t))
|
||||
|
||||
@slowTest
|
||||
@dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False)))
|
||||
|
||||
@ -53,10 +53,7 @@ __version_info__ = tuple(int(segment) for segment in __version__.split("."))
|
||||
import sys
|
||||
import os
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
|
||||
if PY3:
|
||||
unicode = str
|
||||
unicode = str
|
||||
|
||||
if sys.platform.startswith('java'):
|
||||
import platform
|
||||
@ -498,10 +495,7 @@ def _get_win_folder_from_registry(csidl_name):
|
||||
registry for this guarantees us the correct answer for all CSIDL_*
|
||||
names.
|
||||
"""
|
||||
if PY3:
|
||||
import winreg as _winreg
|
||||
else:
|
||||
import _winreg
|
||||
import winreg as _winreg
|
||||
|
||||
shell_folder_name = {
|
||||
"CSIDL_APPDATA": "AppData",
|
||||
|
||||
126
torch/_six.py
126
torch/_six.py
@ -18,55 +18,24 @@
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import itertools
|
||||
import builtins
|
||||
import collections.abc
|
||||
import io
|
||||
import math
|
||||
import sys
|
||||
import types
|
||||
import inspect
|
||||
import queue # noqa: F401
|
||||
|
||||
|
||||
PY2 = sys.version_info[0] == 2
|
||||
inf = math.inf
|
||||
nan = math.nan
|
||||
string_classes = (str, bytes)
|
||||
int_classes = int
|
||||
FileNotFoundError = builtins.FileNotFoundError
|
||||
StringIO = io.StringIO
|
||||
container_abcs = collections.abc
|
||||
PY3 = sys.version_info[0] == 3
|
||||
PY37 = sys.version_info[0] == 3 and sys.version_info[1] == 7
|
||||
|
||||
|
||||
if PY2:
|
||||
import __builtin__ as builtins
|
||||
elif PY3:
|
||||
import builtins
|
||||
|
||||
|
||||
if PY2:
|
||||
inf = float('inf')
|
||||
nan = float('nan')
|
||||
else:
|
||||
import math
|
||||
inf = math.inf
|
||||
nan = math.nan
|
||||
|
||||
if PY2:
|
||||
string_classes = basestring
|
||||
else:
|
||||
string_classes = (str, bytes)
|
||||
|
||||
|
||||
if PY2:
|
||||
int_classes = (int, long)
|
||||
else:
|
||||
int_classes = int
|
||||
|
||||
|
||||
if PY2:
|
||||
FileNotFoundError = IOError
|
||||
else:
|
||||
FileNotFoundError = builtins.FileNotFoundError
|
||||
|
||||
|
||||
if PY2:
|
||||
import Queue as queue # noqa: F401
|
||||
else:
|
||||
import queue # noqa: F401
|
||||
|
||||
|
||||
def with_metaclass(meta, *bases):
|
||||
"""Create a base class with a metaclass."""
|
||||
# This requires a bit of explanation: the basic idea is to make a dummy
|
||||
@ -79,74 +48,16 @@ def with_metaclass(meta, *bases):
|
||||
return type.__new__(metaclass, 'temporary_class', (), {})
|
||||
|
||||
|
||||
# A portable way of referring to the generator version of map
|
||||
# in both Python 2 and Python 3.
|
||||
if hasattr(itertools, 'imap'):
|
||||
imap = itertools.imap # type: ignore
|
||||
else:
|
||||
imap = map # type: ignore
|
||||
|
||||
|
||||
if PY3:
|
||||
import builtins
|
||||
# See https://github.com/PyCQA/flake8-bugbear/issues/64
|
||||
exec_ = getattr(builtins, "exec") # noqa: B009
|
||||
else:
|
||||
def exec_(_code_, _globs_=None, _locs_=None):
|
||||
"""Execute code in a namespace."""
|
||||
if _globs_ is None:
|
||||
frame = sys._getframe(1)
|
||||
_globs_ = frame.f_globals
|
||||
if _locs_ is None:
|
||||
_locs_ = frame.f_locals
|
||||
del frame
|
||||
elif _locs_ is None:
|
||||
_locs_ = _globs_
|
||||
exec("""exec _code_ in _globs_, _locs_""")
|
||||
|
||||
|
||||
if sys.version_info[:2] == (3, 2):
|
||||
exec_("""def raise_from(value, from_value):
|
||||
try:
|
||||
if from_value is None:
|
||||
raise value
|
||||
raise value from from_value
|
||||
finally:
|
||||
value = None
|
||||
""")
|
||||
elif sys.version_info[:2] > (3, 2):
|
||||
exec_("""def raise_from(value, from_value):
|
||||
def raise_from(value, from_value):
|
||||
try:
|
||||
raise value from from_value
|
||||
finally:
|
||||
value = None
|
||||
""")
|
||||
else:
|
||||
def raise_from(value, from_value):
|
||||
raise value
|
||||
|
||||
if PY2:
|
||||
import collections
|
||||
container_abcs = collections
|
||||
elif PY3:
|
||||
import collections.abc
|
||||
container_abcs = collections.abc
|
||||
|
||||
# Gets a function from the name of a method on a type
|
||||
if PY2:
|
||||
def get_function_from_type(cls, name):
|
||||
method = getattr(cls, name, None)
|
||||
return getattr(method, "__func__", None)
|
||||
elif PY3:
|
||||
def get_function_from_type(cls, name):
|
||||
return getattr(cls, name, None)
|
||||
|
||||
if PY2:
|
||||
import StringIO
|
||||
StringIO = StringIO.StringIO
|
||||
elif PY3:
|
||||
import io
|
||||
StringIO = io.StringIO
|
||||
def get_function_from_type(cls, name):
|
||||
return getattr(cls, name, None)
|
||||
|
||||
|
||||
# The codes below is not copied from the six package, so the copyright
|
||||
@ -164,9 +75,4 @@ def istuple(obj):
|
||||
return isinstance(obj, tuple) or t.__module__ == 'torch.return_types'
|
||||
|
||||
def bind_method(fn, obj, obj_type):
|
||||
if PY2:
|
||||
if inspect.ismethod(fn):
|
||||
fn = fn.__func__
|
||||
return types.MethodType(fn, obj, obj_type)
|
||||
else:
|
||||
return types.MethodType(fn, obj)
|
||||
return types.MethodType(fn, obj)
|
||||
|
||||
@ -95,7 +95,12 @@ PyObject* rpc_init(PyObject* /* unused */) {
|
||||
// triggered in this context because it conflicts with the struct
|
||||
// torch::hash, so we need to use the qualified name
|
||||
// py::detail::hash, which unfortunately is in a detail namespace.
|
||||
.def(py::detail::hash(py::self));
|
||||
.def(py::detail::hash(py::self)) // NOLINT
|
||||
.def("__repr__", [](const WorkerInfo& workerInfo) {
|
||||
std::ostringstream os;
|
||||
os << workerInfo;
|
||||
return os.str();
|
||||
});
|
||||
|
||||
auto rpcAgent =
|
||||
shared_ptr_class_<RpcAgent>(module, "RpcAgent")
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
|
||||
#include <torch/csrc/distributed/rpc/utils.h>
|
||||
@ -27,7 +29,14 @@ const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
|
||||
const std::string kRpcTimeoutErrorStr =
|
||||
"RPC ran for more than set timeout ({} ms) and will now be marked with an error";
|
||||
|
||||
std::string guessAddress(tensorpipe::transport::uv::Context& uvContext) {
|
||||
} // namespace
|
||||
|
||||
C10_DEFINE_REGISTRY(TensorPipeTransportRegistry, TransportRegistration);
|
||||
|
||||
C10_DEFINE_REGISTRY(TensorPipeChannelRegistry, ChannelRegistration);
|
||||
|
||||
std::string TensorPipeAgent::guessUvAddress(
|
||||
tensorpipe::transport::uv::Context& uvContext) {
|
||||
tensorpipe::Error error;
|
||||
std::string uvAddress;
|
||||
char* ifnameEnv = std::getenv(kSocketIfnameEnvVar.c_str());
|
||||
@ -50,6 +59,63 @@ std::string guessAddress(tensorpipe::transport::uv::Context& uvContext) {
|
||||
return uvAddress;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::unique_ptr<TransportRegistration> makeUvTransport() {
|
||||
auto context = std::make_shared<tensorpipe::transport::uv::Context>();
|
||||
std::string address = TensorPipeAgent::guessUvAddress(*context);
|
||||
return std::make_unique<TransportRegistration>(
|
||||
TransportRegistration{std::move(context), 1, std::move(address)});
|
||||
}
|
||||
|
||||
// The UV transport is implemented using standard TCP connections. It leverages
|
||||
// libuv (https://github.com/libuv/libuv) in order to be cross-platform.
|
||||
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport);
|
||||
|
||||
#ifdef TP_ENABLE_SHM
|
||||
|
||||
std::unique_ptr<TransportRegistration> makeShmTransport() {
|
||||
auto context = std::make_shared<tensorpipe::transport::shm::Context>();
|
||||
std::string address = TensorPipeAgent::createUniqueShmAddr();
|
||||
return std::make_unique<TransportRegistration>(
|
||||
TransportRegistration{std::move(context), 0, std::move(address)});
|
||||
}
|
||||
|
||||
// The SHM implements connections using ringbuffers residing in anonymous shared
|
||||
// memory (plus UNIX domain sockets to bootstrap the connection and exchange
|
||||
// file descriptors). It is Linux-only due to some advanced features (O_TMPFILE,
|
||||
// eventfd, ...).
|
||||
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, shm, makeShmTransport);
|
||||
|
||||
#endif
|
||||
|
||||
std::unique_ptr<ChannelRegistration> makeBasicChannel() {
|
||||
auto context = std::make_shared<tensorpipe::channel::basic::Context>();
|
||||
return std::make_unique<ChannelRegistration>(
|
||||
ChannelRegistration{std::move(context), 1});
|
||||
}
|
||||
|
||||
// The basic channel is just a straightforward adapter wrapper that allows any
|
||||
// transport to be used as a channel.
|
||||
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, basic, makeBasicChannel);
|
||||
|
||||
#ifdef TP_ENABLE_CMA
|
||||
|
||||
std::unique_ptr<ChannelRegistration> makeCmaChannel() {
|
||||
auto context = std::make_shared<tensorpipe::channel::cma::Context>();
|
||||
return std::make_unique<ChannelRegistration>(
|
||||
ChannelRegistration{std::move(context), 0});
|
||||
}
|
||||
|
||||
// The CMA channel uses the Linux cross-memory attach syscalls (process_vm_readv
|
||||
// and _writev), which allow one process to access the private memory of another
|
||||
// process (as long as they belong to the same user and other security
|
||||
// constraints are satisfied). It does, more or less, what GDB does when it's
|
||||
// attached to a running process.
|
||||
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, cma, makeCmaChannel);
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
|
||||
////////////////////////// MetricsTracker /////////////////////////////////
|
||||
@ -133,36 +199,40 @@ TensorPipeAgent::~TensorPipeAgent() {
|
||||
|
||||
void TensorPipeAgent::startImpl() {
|
||||
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is starting";
|
||||
auto uvContext = std::make_shared<tensorpipe::transport::uv::Context>();
|
||||
|
||||
std::string uvAddress = guessAddress(*uvContext);
|
||||
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is using address "
|
||||
<< uvAddress;
|
||||
std::vector<std::string> addresses;
|
||||
int lowestPriority = std::numeric_limits<int>::max();
|
||||
std::string lowestPriorityTransport;
|
||||
|
||||
context_->registerTransport(1, "tcp", std::move(uvContext));
|
||||
#ifdef TP_ENABLE_SHM
|
||||
context_->registerTransport(
|
||||
0, "shm", std::make_shared<tensorpipe::transport::shm::Context>());
|
||||
#endif
|
||||
context_->registerChannel(
|
||||
1, "basic", std::make_shared<tensorpipe::channel::basic::Context>());
|
||||
#ifdef TP_ENABLE_CMA
|
||||
context_->registerChannel(
|
||||
0, "cma", std::make_shared<tensorpipe::channel::cma::Context>());
|
||||
#endif
|
||||
for (auto& key : TensorPipeTransportRegistry()->Keys()) {
|
||||
std::unique_ptr<TransportRegistration> reg =
|
||||
TensorPipeTransportRegistry()->Create(key);
|
||||
if (reg->priority < lowestPriority) {
|
||||
lowestPriority = reg->priority;
|
||||
lowestPriorityTransport = key;
|
||||
}
|
||||
addresses.push_back(c10::str(key, "://", reg->address));
|
||||
context_->registerTransport(
|
||||
reg->priority, std::move(key), std::move(reg->transport));
|
||||
}
|
||||
|
||||
std::vector<std::string> addresses = {"tcp://" + uvAddress};
|
||||
#ifdef TP_ENABLE_SHM
|
||||
addresses.push_back(createUniqueShmAddr());
|
||||
#endif
|
||||
for (auto& key : TensorPipeChannelRegistry()->Keys()) {
|
||||
std::unique_ptr<ChannelRegistration> reg =
|
||||
TensorPipeChannelRegistry()->Create(key);
|
||||
context_->registerChannel(
|
||||
reg->priority, std::move(key), std::move(reg->channel));
|
||||
}
|
||||
|
||||
listener_ = context_->listen(addresses);
|
||||
|
||||
// Store our own url.
|
||||
const auto address = listener_->url("tcp");
|
||||
const auto address = listener_->url(lowestPriorityTransport);
|
||||
const std::vector<uint8_t> selfAddrData(address.begin(), address.end());
|
||||
nameToAddressStore_.set(workerInfo_.name_, selfAddrData);
|
||||
|
||||
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is using address "
|
||||
<< address;
|
||||
|
||||
for (const auto& p : workerNameToInfo_) {
|
||||
const auto& name = p.first;
|
||||
auto nodeAddrData = nameToAddressStore_.get(name);
|
||||
@ -664,10 +734,6 @@ void TensorPipeAgent::shutdownImpl() {
|
||||
// FIXME Isn't it too verbose for a library to print logs in normal operation?
|
||||
LOG(INFO) << "RPC agent for " << workerInfo_.name_ << " is shutting down";
|
||||
|
||||
threadPool_.waitWorkComplete();
|
||||
VLOG(1) << "RPC agent for " << workerInfo_.name_
|
||||
<< " done waiting for thread pool to complete work";
|
||||
|
||||
// Join the Timeout Thread
|
||||
timeoutThreadCV_.notify_one();
|
||||
if (timeoutThread_.joinable()) {
|
||||
@ -681,6 +747,17 @@ void TensorPipeAgent::shutdownImpl() {
|
||||
context_->join();
|
||||
VLOG(1) << "RPC agent for " << workerInfo_.name_
|
||||
<< " done waiting for TensorPipe context to join";
|
||||
|
||||
// NOTE: We need to call waitWorkComplete in the end after we have shutdown
|
||||
// all listeners for Tensorpipe. This is to drain any already accepted work
|
||||
// in the ThreadPool. If this is done before we shutdown the listeners,
|
||||
// additional work could be added after this call and before we shutdown
|
||||
// listeners. This work would continue executing in the threadpool and might
|
||||
// cause issues during shutdown of the system.
|
||||
threadPool_.waitWorkComplete();
|
||||
VLOG(1) << "RPC agent for " << workerInfo_.name_
|
||||
<< " done waiting for thread pool to complete work";
|
||||
|
||||
}
|
||||
|
||||
const WorkerInfo& TensorPipeAgent::getWorkerInfo(
|
||||
|
||||
@ -37,6 +37,21 @@ struct AggregatedNetworkData {
|
||||
uint64_t totalErrors{0};
|
||||
};
|
||||
|
||||
struct TransportRegistration {
|
||||
std::shared_ptr<tensorpipe::transport::Context> transport;
|
||||
int64_t priority;
|
||||
std::string address;
|
||||
};
|
||||
|
||||
C10_DECLARE_REGISTRY(TensorPipeTransportRegistry, TransportRegistration);
|
||||
|
||||
struct ChannelRegistration {
|
||||
std::shared_ptr<tensorpipe::channel::Context> channel;
|
||||
int64_t priority;
|
||||
};
|
||||
|
||||
C10_DECLARE_REGISTRY(TensorPipeChannelRegistry, ChannelRegistration);
|
||||
|
||||
// TensorPipeAgent leverages TensorPipe (https://github.com/pytorch/tensorpipe)
|
||||
// to transparently move tensors and payloads through the fastest available
|
||||
// transport or channel. It acts like a hybrid RPC transport, providing shared
|
||||
@ -84,16 +99,19 @@ class TensorPipeAgent : public RpcAgent {
|
||||
// Returns NetworkSourceInfo struct
|
||||
NetworkSourceInfo getNetworkSourceInfo();
|
||||
|
||||
static std::string guessUvAddress(
|
||||
tensorpipe::transport::uv::Context& uvContext);
|
||||
|
||||
#ifdef TP_ENABLE_SHM
|
||||
static std::string createUniqueShmAddr();
|
||||
#endif
|
||||
|
||||
private:
|
||||
// Populates workerIdToInfo_ and workerNameToInfo_ using addressStore_
|
||||
void collectNames();
|
||||
|
||||
const std::string& findWorkerURL(const WorkerInfo& worker) const;
|
||||
|
||||
#ifdef TP_ENABLE_SHM
|
||||
std::string createUniqueShmAddr();
|
||||
#endif
|
||||
|
||||
// TensorPipe read function that could be used to read response messages
|
||||
// by client, and read request messages by server.
|
||||
void pipeRead(
|
||||
|
||||
@ -149,7 +149,10 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
|
||||
case '-':
|
||||
str = "-";
|
||||
L.next();
|
||||
L.expect(TK_NUMBER);
|
||||
if (L.cur().kind != TK_NUMBER) {
|
||||
throw ErrorReport(token.range)
|
||||
<< "Expected a number after '-' but got:" << token.text();
|
||||
}
|
||||
// Fallthrough
|
||||
case TK_NUMBER:
|
||||
str += L.cur().text();
|
||||
|
||||
@ -295,9 +295,10 @@ std::vector<Value*> getPassThroughInputs(Value* v) {
|
||||
inputs.push_back(output);
|
||||
}
|
||||
return inputs;
|
||||
} else if (n->kind() == prim::ListUnpack) {
|
||||
} else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
|
||||
return {n->input(0)};
|
||||
} else if (n->kind() == prim::ListConstruct) {
|
||||
} else if (
|
||||
n->kind() == prim::ListConstruct || n->kind() == prim::TupleConstruct) {
|
||||
std::vector<Value*> inputs;
|
||||
for (auto* v : n->inputs()) {
|
||||
inputs.push_back(v);
|
||||
|
||||
@ -837,11 +837,17 @@ void initJITBindings(PyObject* module) {
|
||||
[](Argument& self) -> py::object {
|
||||
return (self.N()) ? py::cast(*self.N()) : py::none();
|
||||
})
|
||||
.def_property_readonly("default_value", [](Argument& self) -> py::object {
|
||||
if (!self.default_value())
|
||||
return py::none();
|
||||
IValue v = *self.default_value();
|
||||
return toPyObject(std::move(v));
|
||||
.def_property_readonly(
|
||||
"default_value",
|
||||
[](Argument& self) -> py::object {
|
||||
if (!self.default_value()) {
|
||||
return py::none();
|
||||
}
|
||||
IValue v = *self.default_value();
|
||||
return toPyObject(std::move(v));
|
||||
})
|
||||
.def("has_default_value", [](Argument& self) -> py::bool_ {
|
||||
return self.default_value().has_value();
|
||||
});
|
||||
m.def("_jit_get_all_schemas", []() {
|
||||
const std::vector<std::shared_ptr<Operator>>& operations =
|
||||
@ -956,6 +962,22 @@ void initJITBindings(PyObject* module) {
|
||||
return fut->wait();
|
||||
});
|
||||
|
||||
m.def(
|
||||
"_collect_all",
|
||||
[](const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>& futures)
|
||||
-> std::shared_ptr<jit::PythonFutureWrapper> {
|
||||
auto typePtr =
|
||||
futures.empty() ? AnyType::get() : futures[0]->fut->elementType();
|
||||
c10::List<c10::intrusive_ptr<c10::ivalue::Future>> asList(
|
||||
c10::FutureType::create(typePtr));
|
||||
asList.reserve(futures.size());
|
||||
for (const auto& f : futures) {
|
||||
asList.push_back(f->fut);
|
||||
}
|
||||
return std::make_shared<jit::PythonFutureWrapper>(
|
||||
c10::collectAll(asList));
|
||||
});
|
||||
|
||||
m.def("_jit_assert_is_instance", [](py::object obj, TypePtr type) {
|
||||
toIValue(obj, type);
|
||||
});
|
||||
|
||||
@ -780,6 +780,12 @@ void initPythonIRBindings(PyObject* module_) {
|
||||
return get_python_cu()->get_interface(
|
||||
c10::QualifiedName(qualified_name));
|
||||
}))
|
||||
.def(
|
||||
"getMethod",
|
||||
[](InterfaceType& self, const std::string& name) {
|
||||
return self.getMethod(name);
|
||||
},
|
||||
py::return_value_policy::reference)
|
||||
.def("getMethodNames", [](InterfaceType& self) {
|
||||
std::vector<std::string> names;
|
||||
for (const FunctionSchema& fn : self.methods()) {
|
||||
|
||||
@ -1063,7 +1063,11 @@ void initJitScriptBindings(PyObject* module) {
|
||||
const std::string& src,
|
||||
ResolutionCallback rcb) {
|
||||
cu.define(c10::nullopt, src, pythonResolver(rcb), nullptr);
|
||||
});
|
||||
})
|
||||
.def(
|
||||
"get_interface",
|
||||
[](const std::shared_ptr<CompilationUnit>& self,
|
||||
const std::string& name) { return self->get_interface(name); });
|
||||
|
||||
py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr())
|
||||
.def(
|
||||
|
||||
1
torch/distributed/nn/__init__.py
Normal file
1
torch/distributed/nn/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .api.remote_module import RemoteModule
|
||||
0
torch/distributed/nn/api/__init__.py
Normal file
0
torch/distributed/nn/api/__init__.py
Normal file
215
torch/distributed/nn/api/remote_module.py
Normal file
215
torch/distributed/nn/api/remote_module.py
Normal file
@ -0,0 +1,215 @@
|
||||
#!/usr/bin/python3
|
||||
import types
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed.rpc as rpc
|
||||
from torch import nn
|
||||
from torch.distributed.nn.jit import instantiator
|
||||
|
||||
|
||||
_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = (
|
||||
instantiator.instantiate_non_scriptable_remote_module_template()
|
||||
)
|
||||
|
||||
|
||||
# RPC handler.
|
||||
def _instantiate_template(module_interface_cls):
    """Instantiate the scriptable remote-module template for ``module_interface_cls``.

    Invoked over RPC by ``_RemoteModule.__init__`` so that the generated
    template module also exists on the worker that will own the module.
    The generated module itself is not returned.
    """
    instantiate = instantiator.instantiate_scriptable_remote_module_template
    instantiate(module_interface_cls)
|
||||
|
||||
|
||||
def _create_module(module_cls, args, kwargs, module_interface_cls=None):
    """Build ``module_cls(*args, **kwargs)`` and return it wrapped in an ``RRef``.

    Arguments:
        module_cls: callable that is expected to produce an ``nn.Module``.
        args: positional arguments forwarded to ``module_cls``.
        kwargs: keyword arguments forwarded to ``module_cls``.
        module_interface_cls (optional): when given, the created module is
            compiled with ``torch.jit.script`` and this value is passed to
            ``rpc.RRef`` as its second argument.

    Returns:
        An ``rpc.RRef`` holding the created (and possibly scripted) module.

    Raises:
        ValueError: if ``module_cls(*args, **kwargs)`` is not an ``nn.Module``.
    """
    module = module_cls(*args, **kwargs)
    if not isinstance(module, nn.Module):
        raise ValueError(
            "Expect `module_cls(*args, **kwargs)` returns an instance of <class nn.Module>, "
            f"but it returns an instance of {type(module)}."
        )
    if module_interface_cls is not None:
        # Only interface-typed modules are scripted; plain modules stay eager.
        module = torch.jit.script(module)
    return rpc.RRef(module, module_interface_cls)
|
||||
|
||||
|
||||
class _RemoteModule(nn.Module):
    def __init__(
        self,
        on: str,
        module_cls: nn.Module,
        args: Tuple = None,
        kwargs: Dict[str, Any] = None,
        _module_interface_cls: Any = None,
    ):
        """
        A RemoteModule instance can only be created after RPC initialization.
        It creates a user-specified module on a specified remote node.
        It behaves like a regular ``nn.Module`` except that the ``forward`` method is
        executed on the remote node.
        It takes care of autograd recording to ensure the backward pass propagates
        gradients back to the corresponding remote module.

        The arguments of ``forward_async`` and ``forward`` are the same as
        the ``forward`` method of the module returned by the ``module_cls``.

        For example, if ``module_cls`` returns an instance of ``nn.Linear``,
        that has ``forward`` method signature, ``def forward(input: Tensor) -> Tensor:``,
        the generated ``RemoteModule`` will have 2 methods in signature of
        ``def forward(input: Tensor) -> Tensor:`` and
        ``def forward_async(input: Tensor) -> Future[Tensor]:``.

        Arguments:
            on (str or WorkerInfo): id or name of the destination worker.
            module_cls (nn.Module): For example,
                >>> class MyModule(nn.Module):
                >>>     def forward(self, input):
                >>>         return input + 1
                >>>
                >>> module_cls = MyModule
            args (Sequence, optional): args to be passed to ``module_cls``.
            kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
            _module_interface_cls (type, optional): The TorchScript interface type for the module
                to be created. The type object should be decorated by @torch.jit.interface.
                If not provided, the generated RemoteModule is not torchscript-able.
                Warning, this is an experimental API and susceptible to frequent changes.

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` created by the
            user-provided ``module_cls``, it has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.

        Example::
            Run the following code in two different processes:

            >>> # On worker 0:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>> from torch import nn, Tensor
            >>> from torch.distributed.nn.api.remote_module import RemoteModule
            >>>
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
            >>> remote_linear_module = RemoteModule(
            >>>     "worker1", nn.Linear, args=(20, 30),
            >>> )
            >>> input = torch.randn(128, 20)
            >>> ret_fut = remote_linear_module.forward_async(input)
            >>> ret = ret_fut.wait()
            >>> rpc.shutdown()

            >>> # On worker 1:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>>
            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
            >>> rpc.shutdown()
        """
        super().__init__()

        # Sanity check: every step below issues RPCs.
        assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."

        # Substitute empty containers for omitted constructor arguments.
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}

        # Users rely on this field to know if this generated RemoteModule is
        # TorchScript-able.
        scriptable = _module_interface_cls is not None
        self.is_scriptable = scriptable

        if scriptable:
            # Start instantiating the template on the remote side ...
            remote_instantiation = rpc.rpc_async(
                on, _instantiate_template, (_module_interface_cls,)
            )

            # ... and instantiate the same template locally in the meantime.
            template_module = instantiator.instantiate_scriptable_remote_module_template(
                _module_interface_cls
            )
            generated_methods = template_module._generated_methods

            # The remote side must have the template before the module is created.
            remote_instantiation.wait()
        else:
            generated_methods = _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods

        # Create the module on the remote side and keep a reference to it.
        creation_args = (module_cls, args, kwargs, _module_interface_cls)
        self.module_rref = rpc.rpc_sync(on, _create_module, creation_args)

        # Install the generated methods as bound methods of this instance.
        for generated in generated_methods:
            attribute_name = generated.__name__
            exported = torch.jit.export(generated)
            setattr(self, attribute_name, types.MethodType(exported, self))
|
||||
|
||||
|
||||
class RemoteModule(_RemoteModule):
    """
    A RemoteModule instance can only be created after RPC initialization.
    It creates a user-specified module on a specified remote node.
    It behaves like a regular ``nn.Module`` except that the ``forward`` method is
    executed on the remote node.
    It takes care of autograd recording to ensure the backward pass propagates
    gradients back to the corresponding remote module.

    The arguments of ``forward_async`` and ``forward`` are the same as
    the ``forward`` method of the module returned by the ``module_cls``.

    For example, if ``module_cls`` returns an instance of ``nn.Linear``,
    that has ``forward`` method signature, ``def forward(input: Tensor) -> Tensor:``,
    the generated ``RemoteModule`` will have 2 methods in signature of
    ``def forward(input: Tensor) -> Tensor:`` and
    ``def forward_async(input: Tensor) -> Future[Tensor]:``.

    Arguments:
        on (str or WorkerInfo): id or name of the destination worker.
        module_cls (nn.Module): For example,
            >>> class MyModule(nn.Module):
            >>>     def forward(self, input):
            >>>         return input + 1
            >>>
            >>> module_cls = MyModule
        args (Sequence, optional): args to be passed to ``module_cls``.
        kwargs (Dict, optional): kwargs to be passed to ``module_cls``.

    Returns:
        A remote module instance which wraps the :class:`~nn.Module` created by the
        user-provided ``module_cls``, it has a blocking ``forward`` method and an
        asynchronous ``forward_async`` method that returns a future of the ``forward`` call
        on the user-provided module on the remote side.

    Example::
        Run the following code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> from torch import nn, Tensor
        >>> from torch.distributed.nn.api.remote_module import RemoteModule
        >>>
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> remote_linear_module = RemoteModule(
        >>>     "worker1", nn.Linear, args=(20, 30),
        >>> )
        >>> input = torch.randn(128, 20)
        >>> ret_fut = remote_linear_module.forward_async(input)
        >>> ret = ret_fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>>
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """

    def __init__(
        self,
        on: str,
        module_cls: nn.Module,
        args: Tuple = None,
        kwargs: Dict[str, Any] = None,
    ):
        # The public class never passes a TorchScript interface, so the base
        # class takes its non-scriptable path.
        super().__init__(on, module_cls, args=args, kwargs=kwargs)
|
||||
0
torch/distributed/nn/jit/__init__.py
Normal file
0
torch/distributed/nn/jit/__init__.py
Normal file
160
torch/distributed/nn/jit/instantiator.py
Normal file
160
torch/distributed/nn/jit/instantiator.py
Normal file
@ -0,0 +1,160 @@
|
||||
#!/usr/bin/python3
|
||||
import atexit
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
from torch.distributed.nn.jit.templates.instantiated import (
|
||||
INSTANTIATED_TEMPLATE_DIR_PATH,
|
||||
)
|
||||
from torch.distributed.nn.jit.templates.remote_module_template import (
|
||||
REMOTE_MODULE_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_FILE_PREFIX = "_remote_module_"
|
||||
|
||||
|
||||
def get_arg_return_types_from_interface(module_interface):
    """Render the ``forward`` signature of a TorchScript interface as strings.

    Arguments:
        module_interface: a class decorated by ``@torch.jit.interface``.

    Returns:
        A tuple ``(args_str, arg_types_str, return_type_str)`` where
        ``args_str`` is the comma-separated argument names,
        ``arg_types_str`` is the comma-separated ``name: type[ = default]``
        declarations, and ``return_type_str`` is the single return type.
        ``self`` is excluded from both argument strings.

    Raises:
        AssertionError: if ``module_interface`` is not a TorchScript
            interface, has no ``forward`` method, or ``forward`` does not
            have exactly one return value.
    """
    assert getattr(
        module_interface, "__torch_script_interface__", False
    ), "Expect a TorchScript class interface decorated by @torch.jit.interface."
    qualified_name = torch.jit._qualified_name(module_interface)
    cu = torch.jit._python_cu
    module_interface_c = cu.get_interface(qualified_name)
    method_names = module_interface_c.getMethodNames()
    assert (
        "forward" in method_names
    ), "Expect forward in interface methods, while it has {}".format(method_names)
    method_schema = module_interface_c.getMethod("forward")

    arg_str_list = []
    arg_type_str_list = []
    # The first schema argument is "self"; it is not part of the generated
    # signature.
    for argument in list(method_schema.arguments)[1:]:
        arg_str_list.append(argument.name)

        if argument.has_default_value():
            # The Python binding exposes the default as ``default_value``;
            # there is no ``default`` attribute on the argument object.
            default_value_str = " = {}".format(argument.default_value)
        else:
            default_value_str = ""
        arg_type_str_list.append(
            "{name}: {type}{default_value}".format(
                name=argument.name, type=argument.type, default_value=default_value_str
            )
        )

    args_str = ", ".join(arg_str_list)
    arg_types_str = ", ".join(arg_type_str_list)

    assert len(method_schema.returns) == 1
    return_type_str = str(method_schema.returns[0].type)

    return args_str, arg_types_str, return_type_str
|
||||
|
||||
|
||||
@atexit.register
def cleanup_generated_modules():
    """Delete the generated template modules when the interpreter exits.

    Every ``{_FILE_PREFIX}*.py`` file in ``INSTANTIATED_TEMPLATE_DIR_PATH`` is
    removed. A failure on one file is logged as a warning and does not stop
    the remaining files from being processed.
    """
    template_dir = pathlib.Path(INSTANTIATED_TEMPLATE_DIR_PATH)
    for generated_file in template_dir.glob(f"{_FILE_PREFIX}*.py"):
        try:
            print(f"Removing {generated_file}")
            assert generated_file.is_file(), f"Expect {generated_file} to be a file"
            generated_file.unlink()
        except Exception:
            # Best effort: cleanup at exit must never raise.
            logger.warning(f"Failed to remove {generated_file}:\n{traceback.format_exc()}")
|
||||
|
||||
|
||||
def _write(out_path, text):
    """Write ``text`` to ``out_path`` unless the file already holds exactly ``text``.

    A file that cannot be read (for example because it does not exist yet)
    is treated as having no previous content, so it is always written.
    """
    try:
        with open(out_path, "r") as existing_file:
            previous_text = existing_file.read()
    except IOError:
        previous_text = None

    if previous_text == text:
        print("Skipped writing {}".format(out_path))
        return

    with open(out_path, "w") as out_file:
        print("Writing {}".format(out_path))
        out_file.write(text)
|
||||
|
||||
|
||||
def _do_instantiate_remote_module_template(generated_module_name, str_dict):
    """Render the remote-module template to a file and import the result.

    Arguments:
        generated_module_name: file stem, and last component of the import
            path, of the module to generate.
        str_dict: values substituted into ``REMOTE_MODULE_TEMPLATE``.

    Returns:
        The imported generated module.
    """
    rendered_source = REMOTE_MODULE_TEMPLATE.format(**str_dict)
    module_file = os.path.join(
        INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py"
    )
    _write(module_file, rendered_source)

    # From importlib doc,
    # > If you are dynamically importing a module that was created since
    # the interpreter began execution (e.g., created a Python source file),
    # you may need to call invalidate_caches() in order for the new module
    # to be noticed by the import system.
    importlib.invalidate_caches()
    return importlib.import_module(
        f"torch.distributed.nn.jit.templates.instantiated.{generated_module_name}"
    )
|
||||
|
||||
|
||||
def instantiate_scriptable_remote_module_template(module_interface_cls):
    """Generate and import the TorchScript-able remote module for an interface.

    Arguments:
        module_interface_cls: a type object decorated by ``@torch.jit.interface``.

    Returns:
        The imported generated module.

    Raises:
        ValueError: if ``module_interface_cls`` is not a TorchScript interface.
    """
    is_interface = getattr(module_interface_cls, "__torch_script_interface__", False)
    if not is_interface:
        raise ValueError(
            f"module_interface_cls {module_interface_cls} must be a type object decorated by "
            "@torch.jit.interface"
        )

    # The generated module is named after the interface's qualified name,
    # with dots replaced by underscores.
    flattened_name = torch.jit._qualified_name(module_interface_cls).replace(".", "_")
    generated_module_name = f"{_FILE_PREFIX}{flattened_name}"

    # Import statement that binds the interface class inside the generated module.
    interface_import_stmt = (
        f"from {module_interface_cls.__module__} import "
        f"{module_interface_cls.__name__} as module_interface_cls"
    )
    args_str, arg_types_str, return_type_str = get_arg_return_types_from_interface(
        module_interface_cls
    )

    str_dict = {
        "assign_module_interface_cls": interface_import_stmt,
        "arg_types": arg_types_str,
        "arrow_and_return_type": f" -> {return_type_str}",
        "arrow_and_future_return_type": f" -> Future[{return_type_str}]",
        "args": args_str,
        "kwargs": "",
        "jit_script_decorator": "@torch.jit.script",
    }
    return _do_instantiate_remote_module_template(generated_module_name, str_dict)
|
||||
|
||||
|
||||
def instantiate_non_scriptable_remote_module_template():
    """Generate and import the remote module used without a TorchScript interface.

    The generated methods accept ``*args, **kwargs``, carry no return
    annotations and are not decorated with ``@torch.jit.script``.

    Returns:
        The imported generated module.
    """
    str_dict = {
        "assign_module_interface_cls": "module_interface_cls = None",
        "args": "*args",
        "kwargs": "**kwargs",
        "arg_types": "*args, **kwargs",
        "arrow_and_return_type": "",
        "arrow_and_future_return_type": "",
        "jit_script_decorator": "",
    }
    return _do_instantiate_remote_module_template(
        f"{_FILE_PREFIX}non_sriptable", str_dict
    )
|
||||
0
torch/distributed/nn/jit/templates/__init__.py
Normal file
0
torch/distributed/nn/jit/templates/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
import os
|
||||
|
||||
|
||||
INSTANTIATED_TEMPLATE_DIR_PATH = os.path.dirname(__file__)
|
||||
47
torch/distributed/nn/jit/templates/remote_module_template.py
Normal file
47
torch/distributed/nn/jit/templates/remote_module_template.py
Normal file
@ -0,0 +1,47 @@
|
||||
#!/usr/bin/python3
|
||||
|
||||
REMOTE_MODULE_TEMPLATE = """from typing import *
|
||||
|
||||
import torch
|
||||
import torch.distributed.rpc as rpc
|
||||
from torch import Tensor
|
||||
from torch._jit_internal import Future, RRef
|
||||
|
||||
|
||||
{assign_module_interface_cls}
|
||||
|
||||
|
||||
{jit_script_decorator}
|
||||
def _remote_forward(module_rref: RRef[module_interface_cls], {arg_types}){arrow_and_return_type}: # noqa
|
||||
module = module_rref.local_value()
|
||||
return module.forward({args}, {kwargs})
|
||||
|
||||
|
||||
def forward_async(self, {arg_types}){arrow_and_future_return_type}: # noqa
|
||||
args = (self.module_rref, {args})
|
||||
kwargs = {{{kwargs}}}
|
||||
return rpc.rpc_async(
|
||||
self.module_rref.owner(),
|
||||
_remote_forward,
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
|
||||
def forward(self, {arg_types}){arrow_and_return_type}: # noqa
|
||||
args = (self.module_rref, {args})
|
||||
kwargs = {{{kwargs}}}
|
||||
ret_fut = rpc.rpc_async(
|
||||
self.module_rref.owner(),
|
||||
_remote_forward,
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
return ret_fut.wait()
|
||||
|
||||
|
||||
_generated_methods = [
|
||||
forward_async,
|
||||
forward,
|
||||
]
|
||||
"""
|
||||
@ -1,24 +1,28 @@
|
||||
import collections
|
||||
import copyreg
|
||||
from enum import Enum
|
||||
import io
|
||||
import pickle
|
||||
import threading
|
||||
import traceback
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from . import _get_current_rpc_agent
|
||||
|
||||
|
||||
# Thread local tensor tables to store tensors while pickling torch.Tensor
|
||||
# objects
|
||||
_thread_local_tensor_tables = threading.local()
|
||||
|
||||
|
||||
class RPCExecMode(Enum):
    """Closed set of RPC execution modes, each identified by a string value."""

    # Member order and string values are part of the observable behaviour
    # (iteration order, lookup by value), so they are kept exactly as-is.
    SYNC = "sync"
    ASYNC = "async"
    REMOTE = "remote"
|
||||
|
||||
|
||||
class _InternalRPCPickler:
|
||||
r"""
|
||||
This class provides serialize() and deserialize() interfaces to serialize
|
||||
@ -51,7 +55,7 @@ class _InternalRPCPickler:
|
||||
|
||||
def _rref_reducer(self, rref):
|
||||
rref_fork_data = rref._serialize()
|
||||
return (_InternalRPCPickler._rref_receiver, (rref_fork_data, ))
|
||||
return (_InternalRPCPickler._rref_receiver, (rref_fork_data,))
|
||||
|
||||
def serialize(self, obj):
|
||||
r"""
|
||||
@ -107,9 +111,12 @@ class _InternalRPCPickler:
|
||||
except AttributeError as e:
|
||||
# Occurs when function is not found on module/class during
|
||||
# unpickling.
|
||||
except_str = str(e) + """ Default RPC pickler does not serialize
|
||||
except_str = (
|
||||
str(e)
|
||||
+ """ Default RPC pickler does not serialize
|
||||
function code. Ensure that UDFs are defined on both caller and
|
||||
callee modules."""
|
||||
)
|
||||
ret = AttributeError(except_str)
|
||||
|
||||
# restore _thread_local_tensor_tables.recv_tables if return
|
||||
@ -148,7 +155,10 @@ def _run_function(python_udf):
|
||||
result = python_udf.func(*python_udf.args, **python_udf.kwargs)
|
||||
except Exception as e:
|
||||
# except str = exception info + traceback string
|
||||
except_str = "{}\n{}".format(repr(e), traceback.format_exc())
|
||||
except_str = (
|
||||
f"On {_get_current_rpc_agent().get_worker_info()}:\n"
|
||||
f"{repr(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
result = RemoteException(except_str, type(e))
|
||||
return result
|
||||
|
||||
@ -157,7 +167,10 @@ def _handle_exception(result):
|
||||
if isinstance(result, RemoteException):
|
||||
raise result.exception_type(result.msg)
|
||||
|
||||
def _build_rpc_profiling_key(exec_type, func_name, current_worker_name, dst_worker_name):
|
||||
|
||||
def _build_rpc_profiling_key(
|
||||
exec_type, func_name, current_worker_name, dst_worker_name
|
||||
):
|
||||
"""
|
||||
Builds the key that RPC calls are profiled with using the autograd profiler.
|
||||
This will be the name of the corresponding Event recorded in the profiler.
|
||||
|
||||
@ -94,3 +94,30 @@ class Future(torch._C.Future):
|
||||
>>> t.join()
|
||||
"""
|
||||
super(Future, self).set_result(result)
|
||||
|
||||
|
||||
def collect_all(futures):
    r"""
    Combine several Futures into one Future that completes once every
    sub-future has completed.

    Arguments:
        futures: a list of Futures

    Returns:
        A Future object whose value is a list of the passed in Futures.
    """
    # The aggregation itself is done by the C++ binding.
    combined_future = torch._C._collect_all(futures)
    return combined_future
|
||||
|
||||
def wait_all(futures):
    r"""
    Block until every provided future is complete and return their values.

    Arguments:
        futures: a list of Futures

    Returns:
        A list of the completed Future results, in the order produced by the
        combined future.
    """
    # First wait on the combined future, which yields the completed futures ...
    completed_futures = torch._C._collect_all(futures).wait()
    # ... then unwrap each of them into its result value.
    results = []
    for completed in completed_futures:
        results.append(completed.wait())
    return results
|
||||
|
||||
@ -35,7 +35,7 @@ def make_stub(func, name):
|
||||
return ScriptMethodStub(rcb, ast, func)
|
||||
|
||||
def make_stub_from_method(nn_module, method_name):
|
||||
func = get_function_from_type(type(nn_module), method_name)
|
||||
func = getattr(nn_module, method_name)
|
||||
if isinstance(func, ScriptMethodStub):
|
||||
return func
|
||||
# Make sure the name present in the resulting AST will match the name
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import sys
|
||||
import ast
|
||||
import inspect
|
||||
import re
|
||||
@ -17,9 +16,6 @@ from torch._utils_internal import get_source_lines_and_file
|
||||
from typing import Callable
|
||||
|
||||
|
||||
PY35 = sys.version_info >= (3, 5)
|
||||
|
||||
|
||||
class Module(object):
|
||||
def __init__(self, name, members):
|
||||
self.name = name
|
||||
@ -56,18 +52,15 @@ class EvalEnv(object):
|
||||
return getattr(builtins, name, None)
|
||||
|
||||
def get_signature(fn, rcb, loc, is_method):
|
||||
# Python 3.5 adds support for the nice annotation syntax, so try that first.
|
||||
signature = None
|
||||
if PY35:
|
||||
signature = try_real_annotations(fn, loc)
|
||||
if signature is not None and is_method:
|
||||
# If this is a method, then the signaure will include a type for
|
||||
# `self`, but type comments do not contain a `self`. So strip it
|
||||
# away here so everything is consistent (`inspect.ismethod` does
|
||||
# not work here since `fn` is unbound at this point)
|
||||
param_types, return_type = signature
|
||||
param_types = param_types[1:]
|
||||
signature = (param_types, return_type)
|
||||
signature = try_real_annotations(fn, loc)
|
||||
if signature is not None and is_method:
|
||||
# If this is a method, then the signature will include a type for
|
||||
# `self`, but type comments do not contain a `self`. So strip it
|
||||
# away here so everything is consistent (`inspect.ismethod` does
|
||||
# not work here since `fn` is unbound at this point)
|
||||
param_types, return_type = signature
|
||||
param_types = param_types[1:]
|
||||
signature = (param_types, return_type)
|
||||
|
||||
if signature is None:
|
||||
type_line, source = None, None
|
||||
|
||||
@ -6,7 +6,6 @@ from collections import OrderedDict
|
||||
import torch.utils.hooks as hooks
|
||||
import warnings
|
||||
import weakref
|
||||
from torch._six import imap
|
||||
from torch._C import _add_docstr
|
||||
from numbers import Number
|
||||
import functools
|
||||
@ -445,12 +444,6 @@ class Tensor(torch._C._TensorBase):
|
||||
return self.shape[0]
|
||||
|
||||
def __iter__(self):
|
||||
# NB: we use 'imap' and not 'map' here, so that in Python 2 we get a
|
||||
# generator and don't eagerly perform all the indexes. This could
|
||||
# save us work, and also helps keep trace ordering deterministic
|
||||
# (e.g., if you zip(*hiddens), the eager map will force all the
|
||||
# indexes of hiddens[0] before hiddens[1], while the generator
|
||||
# map will interleave them.)
|
||||
if self.dim() == 0:
|
||||
raise TypeError('iteration over a 0-d tensor')
|
||||
if torch._C._get_tracing_state():
|
||||
@ -458,7 +451,7 @@ class Tensor(torch._C._TensorBase):
|
||||
'Passing a tensor of different shape won\'t change the number of '
|
||||
'iterations executed (and might lead to errors or silently give '
|
||||
'incorrect results).', category=RuntimeWarning)
|
||||
return iter(imap(lambda i: self[i], range(self.size(0))))
|
||||
return iter(map(lambda i: self[i], range(self.size(0))))
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
@ -212,10 +212,7 @@ def repeat_test_for_types(dtypes):
|
||||
@wraps(f)
|
||||
def call_helper(self, *args):
|
||||
for dtype in dtypes:
|
||||
if PY34:
|
||||
with TestCase.subTest(self, dtype=dtype):
|
||||
f(self, *args, dtype=dtype)
|
||||
else:
|
||||
with TestCase.subTest(self, dtype=dtype):
|
||||
f(self, *args, dtype=dtype)
|
||||
|
||||
return call_helper
|
||||
@ -224,8 +221,6 @@ def repeat_test_for_types(dtypes):
|
||||
# Environment variable `IS_PYTORCH_CI` is set in `.jenkins/common.sh`.
|
||||
IS_PYTORCH_CI = bool(os.environ.get('IS_PYTORCH_CI'))
|
||||
|
||||
PY3 = sys.version_info > (3, 0)
|
||||
PY34 = sys.version_info >= (3, 4)
|
||||
|
||||
def discover_test_cases_recursively(suite_or_case):
|
||||
if isinstance(suite_or_case, unittest.TestCase):
|
||||
@ -242,7 +237,6 @@ def chunk_list(lst, nchunks):
|
||||
return [lst[i::nchunks] for i in range(nchunks)]
|
||||
|
||||
|
||||
|
||||
def run_tests(argv=UNITTEST_ARGS):
|
||||
if TEST_DISCOVER:
|
||||
suite = unittest.TestLoader().loadTestsFromModule(__main__)
|
||||
@ -319,15 +313,10 @@ def _check_module_exists(name):
|
||||
our tests, e.g., setting multiprocessing start method when imported
|
||||
(see librosa/#747, torchvision/#544).
|
||||
"""
|
||||
if not PY34: # Python [3, 3.4)
|
||||
import importlib
|
||||
loader = importlib.find_loader(name)
|
||||
return loader is not None
|
||||
else: # Python >= 3.4
|
||||
import importlib
|
||||
import importlib.util
|
||||
spec = importlib.util.find_spec(name)
|
||||
return spec is not None
|
||||
import importlib
|
||||
import importlib.util
|
||||
spec = importlib.util.find_spec(name)
|
||||
return spec is not None
|
||||
|
||||
TEST_NUMPY = _check_module_exists('numpy')
|
||||
TEST_SCIPY = _check_module_exists('scipy')
|
||||
|
||||
0
torch/testing/_internal/distributed/nn/__init__.py
Normal file
0
torch/testing/_internal/distributed/nn/__init__.py
Normal file
199
torch/testing/_internal/distributed/nn/api/remote_module_test.py
Normal file
199
torch/testing/_internal/distributed/nn/api/remote_module_test.py
Normal file
@ -0,0 +1,199 @@
|
||||
#!/usr/bin/python3
|
||||
import enum
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.testing._internal.dist_utils as dist_utils
|
||||
from torch import Tensor, nn
|
||||
from torch._jit_internal import Future
|
||||
from torch.distributed.nn.api.remote_module import (
|
||||
_RemoteModule,
|
||||
)
|
||||
from torch.distributed.nn import RemoteModule
|
||||
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
|
||||
RpcAgentTestFixture,
|
||||
)
|
||||
|
||||
|
||||
class ModuleCreationMode(enum.Enum):
    """Ways in which the tests below construct a remote module."""

    # Definition order is kept: it determines the iteration order of
    # ``ModuleCreationMode.__members__``.
    MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface"
    MODULE_CTOR = "module_ctor"
|
||||
|
||||
|
||||
# TorchScript interface declaring the ``forward`` signature used by the tests
# in this file: takes (tensor, number, word="default") and is annotated to
# return a (str, int, Tensor) tuple.
@torch.jit.interface
class MyModuleInterface:
    # Declaration only -- the body is intentionally just ``pass``.
    def forward(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Tuple[str, int, Tensor]:
        pass
|
||||
|
||||
|
||||
# TorchScript interface declaring both call styles exercised by the tests:
# a ``forward`` annotated to return a (str, int, Tensor) tuple, and a
# ``forward_async`` annotated to return a Future of that same tuple type.
@torch.jit.interface
class RemoteMyModuleInterface:
    # Declaration only -- the body is intentionally just ``pass``.
    def forward(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Tuple[str, int, Tensor]:
        pass

    # Declaration only; same parameters as ``forward``, Future-wrapped return.
    def forward_async(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Future[Tuple[str, int, Tensor]]:
        pass
|
||||
|
||||
|
||||
class MyModule(nn.Module):
    """Minimal module whose ``forward`` returns its inputs in reverse order."""

    def __init__(self, first_arg, first_kwarg=-1):
        # The constructor arguments are accepted but not stored or used.
        super(MyModule, self).__init__()

    def forward(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Tuple[str, int, Tensor]:
        """Return ``(word, number, tensor)``."""
        reversed_inputs = (word, number, tensor)
        return reversed_inputs
|
||||
|
||||
|
||||
class BadModule:
    """Plain class (not an ``nn.Module``) with the same constructor signature as ``MyModule``."""

    def __init__(self, first_arg, first_kwarg=-1):
        """Accept the arguments and do nothing with them."""
|
||||
|
||||
|
||||
def create_scripted_module(first_arg, first_kwarg=-1):
    """Build a ``MyModule`` from the given arguments and return it compiled with ``torch.jit.script``."""
    return torch.jit.script(MyModule(first_arg, first_kwarg=first_kwarg))
|
||||
|
||||
|
||||
class RemoteModuleTest(RpcAgentTestFixture):
    """Tests that build remote modules on a peer worker and call them.

    Each test body only runs on rank 0 and targets the worker with the next
    rank (modulo ``world_size``).
    """

    @property
    def world_size(self):
        # Override setting in RpcAgentTestFixture
        return 2

    def setUp(self):
        super().setUp()
        self._fork_processes()

    @staticmethod
    def _create_remote_module_iter(dst_worker_name, modes=None):
        """Yield one remote module per requested creation mode.

        When ``modes`` is None, every ``ModuleCreationMode`` member is used.
        The ``MODULE_CTOR`` module (if requested) is yielded first, followed by
        the scripted ``MODULE_CTOR_WITH_INTERFACE`` module (if requested).
        """
        selected_modes = (
            ModuleCreationMode.__members__.values() if modes is None else modes
        )
        ctor_args = (1,)
        ctor_kwargs = {"first_kwarg": 2}

        if ModuleCreationMode.MODULE_CTOR in selected_modes:
            yield RemoteModule(dst_worker_name, MyModule, ctor_args, ctor_kwargs)

        if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in selected_modes:
            interface_module = _RemoteModule(
                dst_worker_name,
                create_scripted_module,
                ctor_args,
                ctor_kwargs,
                _module_interface_cls=MyModuleInterface,
            )
            yield torch.jit.script(interface_module)

    @dist_utils.dist_init
    def test_bad_module(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        ctor_args = (1,)
        ctor_kwargs = {"first_kwarg": 2}
        expected_error = r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instancee of <class nn.Module>,"

        # The construction is attempted twice and must fail the same way both
        # times.
        for _ in range(2):
            with self.assertRaisesRegex(ValueError, expected_error):
                RemoteModule(dst_worker_name, BadModule, ctor_args, ctor_kwargs)

    @dist_utils.dist_init
    def test_forward_async(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        call_args = (torch.ones(1), 2, "3")
        expected = tuple(reversed(call_args))
        for remote_module in self._create_remote_module_iter(dst_worker_name):
            result = remote_module.forward_async(*call_args).wait()
            self.assertEqual(result, expected)

    @dist_utils.dist_init
    def test_forward_async_script(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        module_iter = self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
        )
        scripted_remote_module = next(module_iter)

        @torch.jit.script
        def run_forward_async(scripted_remote_module: RemoteMyModuleInterface):
            ret_fut = scripted_remote_module.forward_async(torch.ones(1), 2, "3")
            ret = ret_fut.wait()
            return ret

        result = run_forward_async(scripted_remote_module)

        self.assertEqual(result, ("3", 2, torch.ones(1)))

    @dist_utils.dist_init
    def test_forward_sync(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        call_args = (torch.ones(1), 2, "3")
        expected = tuple(reversed(call_args))
        for remote_module in self._create_remote_module_iter(dst_worker_name):
            self.assertEqual(remote_module.forward(*call_args), expected)

    @dist_utils.dist_init
    def test_forward_sync_script(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        module_iter = self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
        )
        scripted_remote_module = next(module_iter)

        @torch.jit.script
        def run_forward(scripted_remote_module: MyModuleInterface):
            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
            return ret

        result = run_forward(scripted_remote_module)

        self.assertEqual(result, ("3", 2, torch.ones(1)))

    @dist_utils.dist_init
    def test_forward_with_kwargs(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        call_args = (torch.ones(1), 2)
        call_kwargs = {"word": "3"}
        expected = tuple(reversed(call_args + ("3",)))
        # Only test Python nn.Module, because script module methods don't support taking kwargs.
        python_only_modules = self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        )
        for remote_module in python_only_modules:
            async_result = remote_module.forward_async(*call_args, **call_kwargs).wait()
            self.assertEqual(async_result, expected)

            sync_result = remote_module.forward(*call_args, **call_kwargs)
            self.assertEqual(sync_result, expected)
|
||||
@ -1045,8 +1045,7 @@ class DistAutogradTest(RpcAgentTestFixture):
|
||||
_set_rpc_done(None, 0)
|
||||
|
||||
# wait until all trainers are done
|
||||
for fut in futures:
|
||||
fut.wait()
|
||||
torch.futures.wait_all(futures)
|
||||
|
||||
@dist_init
|
||||
def test_trainer_ps(self):
|
||||
|
||||
@ -822,8 +822,7 @@ class RpcTest(RpcAgentTestFixture):
|
||||
fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
|
||||
futs.append(fut)
|
||||
|
||||
for fut in futs:
|
||||
fut.wait()
|
||||
for fut in torch.futures.collect_all(futs).wait():
|
||||
self.assertEqual(fut.wait(), 0)
|
||||
|
||||
# Phase 2: Only worker2 has workload.
|
||||
@ -835,9 +834,8 @@ class RpcTest(RpcAgentTestFixture):
|
||||
fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
|
||||
futs.append(fut)
|
||||
|
||||
for fut in futs:
|
||||
fut.wait()
|
||||
self.assertEqual(fut.wait(), 0)
|
||||
for val in torch.futures.wait_all(futs):
|
||||
self.assertEqual(val, 0)
|
||||
|
||||
def test_wait_all_workers(self):
|
||||
rpc.init_rpc(
|
||||
@ -1242,9 +1240,9 @@ class RpcTest(RpcAgentTestFixture):
|
||||
futs.append(fut)
|
||||
|
||||
j = 0
|
||||
for fut in futs:
|
||||
for val in torch.futures.wait_all(futs):
|
||||
self.assertEqual(
|
||||
fut.wait(), my_tensor_function(torch.ones(j, j), torch.ones(j, j))
|
||||
val, my_tensor_function(torch.ones(j, j), torch.ones(j, j))
|
||||
)
|
||||
j += 1
|
||||
|
||||
@ -1310,8 +1308,8 @@ class RpcTest(RpcAgentTestFixture):
|
||||
fut = rpc.rpc_async(worker_name(dst_rank), f, args=args)
|
||||
futs.append(fut)
|
||||
|
||||
for fut in futs:
|
||||
self.assertEqual(fut.wait(), 0)
|
||||
for val in torch.futures.wait_all(futs):
|
||||
self.assertEqual(val, 0)
|
||||
tok = time.time()
|
||||
print(
|
||||
"Rank {} finished testing {} times in {} seconds.".format(
|
||||
|
||||
@ -122,7 +122,8 @@ class DataLoader(object):
|
||||
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
|
||||
batch_sampler=None, num_workers=0, collate_fn=None,
|
||||
pin_memory=False, drop_last=False, timeout=0,
|
||||
worker_init_fn=None, multiprocessing_context=None):
|
||||
worker_init_fn=None, multiprocessing_context=None,
|
||||
generator=None):
|
||||
torch._C._log_api_usage_once("python.data_loader")
|
||||
|
||||
if num_workers < 0:
|
||||
@ -211,7 +212,7 @@ class DataLoader(object):
|
||||
sampler = _InfiniteConstantSampler()
|
||||
else: # map-style
|
||||
if shuffle:
|
||||
sampler = RandomSampler(dataset)
|
||||
sampler = RandomSampler(dataset, generator=generator)
|
||||
else:
|
||||
sampler = SequentialSampler(dataset)
|
||||
|
||||
@ -223,6 +224,7 @@ class DataLoader(object):
|
||||
self.drop_last = drop_last
|
||||
self.sampler = sampler
|
||||
self.batch_sampler = batch_sampler
|
||||
self.generator = generator
|
||||
|
||||
if collate_fn is None:
|
||||
if self._auto_collation:
|
||||
@ -336,7 +338,7 @@ class _BaseDataLoaderIter(object):
|
||||
self._timeout = loader.timeout
|
||||
self._collate_fn = loader.collate_fn
|
||||
self._sampler_iter = iter(self._index_sampler)
|
||||
self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
|
||||
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
|
||||
self._num_yielded = 0
|
||||
|
||||
def __iter__(self):
|
||||
@ -1061,7 +1063,13 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
||||
if self._workers_status[worker_id]:
|
||||
self._shutdown_worker(worker_id)
|
||||
for w in self._workers:
|
||||
w.join()
|
||||
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
|
||||
if w.is_alive():
|
||||
# Existing mechanisms try to make the workers exit
|
||||
# peacefully, but in case that we unfortunately reach
|
||||
# here, which we shouldn't, (e.g., pytorch/pytorch#39570),
|
||||
# we kill the worker.
|
||||
w.terminate()
|
||||
for q in self._index_queues:
|
||||
q.cancel_join_thread()
|
||||
q.close()
|
||||
|
||||
@ -74,12 +74,14 @@ class RandomSampler(Sampler):
|
||||
replacement (bool): samples are drawn with replacement if ``True``, default=``False``
|
||||
num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
|
||||
is supposed to be specified only when `replacement` is ``True``.
|
||||
generator (Generator): Generator used in sampling.
|
||||
"""
|
||||
|
||||
def __init__(self, data_source, replacement=False, num_samples=None):
|
||||
def __init__(self, data_source, replacement=False, num_samples=None, generator=None):
|
||||
self.data_source = data_source
|
||||
self.replacement = replacement
|
||||
self._num_samples = num_samples
|
||||
self.generator = generator
|
||||
|
||||
if not isinstance(self.replacement, bool):
|
||||
raise TypeError("replacement should be a boolean value, but got "
|
||||
@ -103,8 +105,9 @@ class RandomSampler(Sampler):
|
||||
def __iter__(self):
|
||||
n = len(self.data_source)
|
||||
if self.replacement:
|
||||
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
|
||||
return iter(torch.randperm(n).tolist())
|
||||
rand_tensor = torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=self.generator)
|
||||
return iter(rand_tensor.tolist())
|
||||
return iter(torch.randperm(n, generator=self.generator).tolist())
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
@ -115,13 +118,15 @@ class SubsetRandomSampler(Sampler):
|
||||
|
||||
Arguments:
|
||||
indices (sequence): a sequence of indices
|
||||
generator (Generator): Generator used in sampling.
|
||||
"""
|
||||
|
||||
def __init__(self, indices):
|
||||
def __init__(self, indices, generator=None):
|
||||
self.indices = indices
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self):
|
||||
return (self.indices[i] for i in torch.randperm(len(self.indices)))
|
||||
return (self.indices[i] for i in torch.randperm(len(self.indices), generator=self.generator))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.indices)
|
||||
@ -136,6 +141,7 @@ class WeightedRandomSampler(Sampler):
|
||||
replacement (bool): if ``True``, samples are drawn with replacement.
|
||||
If not, they are drawn without replacement, which means that when a
|
||||
sample index is drawn for a row, it cannot be drawn again for that row.
|
||||
generator (Generator): Generator used in sampling.
|
||||
|
||||
Example:
|
||||
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
|
||||
@ -144,7 +150,7 @@ class WeightedRandomSampler(Sampler):
|
||||
[0, 1, 4, 3, 2]
|
||||
"""
|
||||
|
||||
def __init__(self, weights, num_samples, replacement=True):
|
||||
def __init__(self, weights, num_samples, replacement=True, generator=None):
|
||||
if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
|
||||
num_samples <= 0:
|
||||
raise ValueError("num_samples should be a positive integer "
|
||||
@ -155,9 +161,11 @@ class WeightedRandomSampler(Sampler):
|
||||
self.weights = torch.as_tensor(weights, dtype=torch.double)
|
||||
self.num_samples = num_samples
|
||||
self.replacement = replacement
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self):
|
||||
return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
|
||||
rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
|
||||
return iter(rand_tensor.tolist())
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
Reference in New Issue
Block a user