Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-26 16:44:54 +08:00

Compare commits (156 commits)
Commits in this comparison, newest first (SHA1 only; author and date were not captured by the mirror):

c54597e0b2, 833b8cbc7a, 75aeb16e05, 10bb6bb9b8, 3c9ef69c37, dee987d6ee,
138f254ec1, c7c8aaa7f0, d0db624e02, e3e7b76310, dad02bceb9, b195285879,
8f3da5b51d, 825e919eb8, acb0ce8885, 72089c9c36, cf2f158fec, 6470b5bd21,
44196955e2, f08ec1394d, f8fb25e0a2, 6a0c66752f, a1bd4efb08, b43ce05268,
80e56cfda9, 24701fc5a7, f78a266d99, f096fb6859, a3e11d606b, 79232c24e2,
15d9d499ab, 962084c8e8, 7518b1eefb, 8215d7a4ba, 5aaa220d84, 12c16ab9bc,
76520512e7, 66de965882, 10d32fb0b7, e72c9b6e4a, ac1f68127a, 60d1852c7b,
d53eb521fc, 9808932f10, ea876eb6d5, 0a45864866, 2560b39796, 21afa4c88b,
9fc3c5e4d2, 3e3501c98d, 5e6fcd02b5, d46ebcfadf, 41480c8cf2, 236890d902,
55632d81d2, 0b276d622e, c81491b37d, 42e189425f, 3cfa0d7199, 7c9e088661,
e78aa4bb84, f8e94d0d8b, ebe6f40fce, 5fb37efb46, 4f47855873, 52ae6f682f,
c35f58f97b, 659b2f3154, 5ea05cfb96, dc9a5b7d2f, f7ab5a128a, 368cbe615d,
d4c9a3782b, 172dca5e8b, 818bf0c408, 03dcf8a83b, 604f607fd1, 956d946c25,
970caaa621, 00a5980cdf, e24eee04f0, f1b3af4ee2, fb2d28f477, 3a704ff725,
0180e638e5, 95c6ae04fb, 27c4c6e0af, da17414b3f, be2b27a747, aec2c8f752,
13e34b4679, 57373c7c29, 79f5bf84e5, 3ed720079e, e7c1e6a8e3, f1d0d73ed7,
9c411513bf, ce78bc898b, 887002e932, 31dea5ff23, ec4602a973, a38749d15f,
6ee77b4edd, 343d65db91, a90913105c, 9368596059, 80ed795ff1, a2938e3d11,
2ad967dbe4, 7415c090ac, a1fa995044, 3c2ecc6b15, fa1516d319, 5e26f49db4,
7694f65120, b5ebf68df1, aa46055274, 2cad802b68, 2d01f384f1, f8d4f980b3,
4f5a6c366e, ecfcf39f30, 3975a2676e, 138ee75a3b, 0048f228cb, 2748b920ab,
a92a2312d4, 945ce5cdb0, b39de2cbbe, 49a555e0f5, ce13900148, 4c77ad6ee4,
0bc4246425, c45ff2efe6, 99b520cc5d, e05607aee1, a360ba1734, c661b963b9,
e374dc1696, 116e0c7f38, 45596d5289, 342e7b873d, 00410c4496, 8b9276bbee,
3238786ea1, 07ebbcbcb3, ca555abcf9, 63893c3fa2, f8ae34706e, 7179002bfb,
43b5be1d78, b5f6fdb814, a69d819901, fef2b1526d, 3719994c96, 4461ae8090
.gitignore (vendored, 3 changed lines)

@@ -15,6 +15,9 @@ torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/*
test/data/legacy_modules.t7
test/htmlcov
.travis.yml (23 changed lines)

@@ -4,16 +4,25 @@ python:
  - 2.7.8
  - 2.7
  - 3.5
  - 3.6
  - nightly

cache:
  - ccache
  - directories:
    - $HOME/.ccache

install:
  - export CC="gcc-4.8"
  - export CXX="g++-4.8"
  - unset CCACHE_DISABLE
  - export CCACHE_DIR=$HOME/.ccache
  - export CC="ccache gcc-4.8"
  - export CXX="ccache g++-4.8"
  - ccache --show-stats
  - travis_retry pip install -r requirements.txt
  - travis_retry pip install .
  - python setup.py install

script:
  - ./test/run_test.sh
  - OMP_NUM_THREADS=2 ./test/run_test.sh

addons:
  apt:
@@ -30,3 +39,9 @@ sudo: false

matrix:
  fast_finish: true
  include:
    env: LINT_CHECK
    python: "2.7"
    addons: true
    install: pip install pep8
    script: pep8
Dockerfile (new file, 33 lines)

@@ -0,0 +1,33 @@
FROM nvidia/cuda:8.0-cudnn5-devel-ubuntu14.04

RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cmake \
         git \
         curl \
         ca-certificates \
         libjpeg-dev \
         libpng-dev &&\
    rm -rf /var/lib/apt/lists/*

RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-4.2.12-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install conda-build && \
    /opt/conda/bin/conda create -y --name pytorch-py35 python=3.5.2 numpy scipy ipython mkl&& \
    /opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/envs/pytorch-py35/bin:$PATH
RUN conda install --name pytorch-py35 -c soumith magma-cuda80
# This must be done before pip so that requirements.txt is available
WORKDIR /opt/pytorch
COPY . .

RUN cat requirements.txt | xargs -n1 pip install --no-cache-dir && \
    TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
    CMAKE_LIBRARY_PATH=/opt/conda/envs/pytorch-py35/lib \
    CMAKE_INCLUDE_PATH=/opt/conda/envs/pytorch-py35/include \
    pip install -v .

WORKDIR /workspace
RUN chmod -R a+w /workspace
README.md (49 changed lines)

@@ -14,17 +14,17 @@ We are in an early-release Beta. Expect some adventures and rough edges.
- [Installation](#installation)
- [Binaries](#binaries)
- [From source](#from-source)
- [Docker image](#docker-image)
- [Getting Started](#getting-started)
- [Communication](#communication)
- [Releases and Contributing](#releases-and-contributing)
- [The Team](#the-team)

| Python | **`Linux CPU`** | **`Linux GPU`** |
|--------|--------------------|------------------|
| 2.7.8 | [](https://travis-ci.com/apaszke/pytorch) | |
| 2.7 | [](https://travis-ci.com/apaszke/pytorch) | [](https://build.pytorch.org/job/pytorch-master-py2) |
| 3.5 | [](https://travis-ci.com/apaszke/pytorch) | [](https://build.pytorch.org/job/pytorch-master-py3) |
| Nightly| [](https://travis-ci.com/apaszke/pytorch) | |
| System | Python | Status |
| --- | --- | --- |
| Linux CPU | 2.7.8, 2.7, 3.5, nightly | [](https://travis-ci.org/pytorch/pytorch) |
| Linux GPU | 2.7 | [](https://build.pytorch.org/job/pytorch-master-py2) |
| Linux GPU | 3.5 | [](https://build.pytorch.org/job/pytorch-master-py3) |

## More about PyTorch

@@ -101,7 +101,7 @@ We hope you never spend hours debugging your code because of bad stack traces or

PyTorch has minimal framework overhead. We integrate acceleration libraries
such as Intel MKL and NVIDIA (CuDNN, NCCL) to maximize speed.
At the core, it's CPU and GPU Tensor and Neural Network backends
At the core, its CPU and GPU Tensor and Neural Network backends
(TH, THC, THNN, THCUNN) are written as independent libraries with a C99 API.
They are mature and have been tested for years.

@@ -135,24 +135,36 @@ conda install pytorch torchvision -c soumith

### From source

Instructions for an Anaconda environment.
If you are installing from source, we highly recommend installing an [Anaconda](https://www.continuum.io/downloads) environment.
You will get a high-quality BLAS library (MKL) and you get a controlled compiler version regardless of your Linux distro.

Once you have [anaconda](https://www.continuum.io/downloads) installed, here are the instructions.

If you want to compile with CUDA support, install
- [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 7.5 or above
- [NVIDIA CuDNN](https://developer.nvidia.com/cudnn) v5.x

If you want to disable CUDA support, export environment variable `NO_CUDA=1`.

#### Install optional dependencies

On Linux
```bash
export CMAKE_PREFIX_PATH=[anaconda root directory]

# Install basic dependencies
conda install numpy mkl setuptools cmake gcc cffi

# On Linux, add LAPACK support for the GPU
# Add LAPACK support for the GPU
conda install -c soumith magma-cuda75 # or magma-cuda80 if CUDA 8.0
```

On OSX
```bash
export CMAKE_PREFIX_PATH=[anaconda root directory]
conda install numpy setuptools cmake cffi
```

#### Install PyTorch
```bash
export MACOSX_DEPLOYMENT_TARGET=10.9 # if OSX
@@ -160,6 +172,25 @@ pip install -r requirements.txt
python setup.py install
```

### Docker image

Dockerfiles are supplied to build images with cuda support and cudnn v5 and cudnn v6 RC. Build them as usual
```
docker build . -t pytorch-cudnnv5
```
or
```
docker build . -t pytorch-cudnnv6 -f tools/docker/Dockerfile-v6
```
and run them with nvidia-docker:
```
nvidia-docker run --rm -ti --ipc=host pytorch-cudnnv5
```
Please note that pytorch uses shared memory to share data between processes, so if torch multiprocessing is used (e.g.
for multithreaded data loaders) the default shared memory segment size that container runs with is not enough, and you
should increase shared memory size either with --ipc=host or --shm-size command line options to nvidia-docker run.


## Getting Started

Three pointers to get you started:

@@ -201,12 +201,13 @@ from docutils import nodes
from sphinx.util.docfields import TypedField
from sphinx import addnodes


def patched_make_field(self, types, domain, items):
    # type: (List, unicode, Tuple) -> nodes.field
    def handle_item(fieldarg, content):
        par = nodes.paragraph()
        par += addnodes.literal_strong('', fieldarg)  # Patch: this line added
        #par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
        # par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
        #                           addnodes.literal_strong))
        if fieldarg in types:
            par += nodes.Text(' (')

@@ -7,6 +7,12 @@ torch.nn
.. automodule:: torch.nn
.. currentmodule:: torch.nn

Parameters
----------

.. autoclass:: Parameter
    :members:

Containers
----------------------------------

@@ -362,6 +368,12 @@ Loss functions
.. autoclass:: NLLLoss
    :members:

:hidden:`NLLLoss2d`
~~~~~~~~~~~~~~~~~~~

.. autoclass:: NLLLoss2d
    :members:

:hidden:`KLDivLoss`
~~~~~~~~~~~~~~~~~~~

@@ -432,6 +444,19 @@ Vision layers
.. autoclass:: PixelShuffle
    :members:

:hidden:`UpsamplingNearest2d`
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: UpsamplingNearest2d
    :members:

:hidden:`UpsamplingBilinear2d`
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: UpsamplingBilinear2d
    :members:


Multi-GPU layers
----------------

@@ -29,12 +29,15 @@ Below you can find a small example showcasing this::
    b = torch.FloatTensor(1).cuda()
    # a.get_device() == b.get_device() == 1

    c = a + b
    # c.get_device() == 1

    z = x + y
    # z.get_device() == 1
    # z.get_device() == 0

    # even within a context, you can give a GPU id to the .cuda call
    c = torch.randn(2).cuda(2)
    # c.get_device() == 2
    d = torch.randn(2).cuda(2)
    # d.get_device() == 2
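The hunk above shows the device-context rules only in fragments. A consolidated sketch of the same pattern, for orientation (an illustration, not part of the diff; assumes a machine with at least two GPUs):

```python
import torch

# Everything allocated inside the context lands on GPU 1 by default.
with torch.cuda.device(1):
    a = torch.cuda.FloatTensor(1)  # allocated on GPU 1
    b = torch.randn(2).cuda()      # .cuda() with no id honors the context
    c = torch.randn(2).cuda(0)     # an explicit id overrides the context

assert b.get_device() == 1 and c.get_device() == 0
```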

Best practices
--------------

@@ -144,9 +144,9 @@ This is how a ``Linear`` module can be implemented::
        if bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)

        def forward(self, input):
            # See the autograd section for explanation of what happens here.
            return Linear()(input, self.weight, self.bias)
    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return Linear()(input, self.weight, self.bias)


Writing custom C extensions

@@ -106,6 +106,8 @@ Algorithms
    :members:
.. autoclass:: ASGD
    :members:
.. autoclass:: LBFGS
    :members:
.. autoclass:: RMSprop
    :members:
.. autoclass:: Rprop

@@ -14,8 +14,8 @@ Data type CPU tensor GPU tensor
32-bit floating point     :class:`torch.FloatTensor`   :class:`torch.cuda.FloatTensor`
64-bit floating point     :class:`torch.DoubleTensor`  :class:`torch.cuda.DoubleTensor`
16-bit floating point     N/A                          :class:`torch.cuda.HalfTensor`
8-bit integer (signed)    :class:`torch.ByteTensor`    :class:`torch.cuda.ByteTensor`
8-bit integer (unsigned)  :class:`torch.CharTensor`    :class:`torch.cuda.CharTensor`
8-bit integer (unsigned)  :class:`torch.ByteTensor`    :class:`torch.cuda.ByteTensor`
8-bit integer (signed)    :class:`torch.CharTensor`    :class:`torch.cuda.CharTensor`
16-bit integer (signed)   :class:`torch.ShortTensor`   :class:`torch.cuda.ShortTensor`
32-bit integer (signed)   :class:`torch.IntTensor`     :class:`torch.cuda.IntTensor`
64-bit integer (signed)   :class:`torch.LongTensor`    :class:`torch.cuda.LongTensor`
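For reference, a small sketch of how the classes in this table are used (illustration only, not part of the diff; the CUDA lines assume a GPU is present):

```python
import torch

x = torch.FloatTensor(2, 3)      # 32-bit float, CPU
n = torch.LongTensor([1, 2, 3])  # 64-bit signed integer, CPU

if torch.cuda.is_available():
    g = torch.cuda.FloatTensor(2, 3)  # 32-bit float, GPU
    h = x.cuda()                      # copy a CPU tensor to the GPU
```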
@ -251,7 +251,6 @@ view of a storage and defines numeric operations on it.
|
||||
.. automethod:: scatter_
|
||||
.. automethod:: select
|
||||
.. automethod:: set_
|
||||
.. automethod:: set_index
|
||||
.. automethod:: share_memory_
|
||||
.. automethod:: short
|
||||
.. automethod:: sigmoid
|
||||
|
||||
@@ -37,6 +37,7 @@ Indexing, Slicing, Joining, Mutating Ops
.. autofunction:: stack
.. autofunction:: t
.. autofunction:: transpose
.. autofunction:: unbind


Random sampling

setup.cfg (new file, 8 lines)

@@ -0,0 +1,8 @@
[pep8]
max-line-length = 120
ignore = E402,E721,E731,W503
exclude = docs/src

[flake8]
max-line-length = 120
ignore = E305,E402,E721,E731,F401,F403,F405,F811,F812,F821,F841
setup.py (182 changed lines)

@@ -1,6 +1,7 @@
from setuptools import setup, Extension, distutils, Command, find_packages
import setuptools.command.build_ext
import setuptools.command.install
import distutils.unixccompiler
import distutils.command.build
import distutils.command.clean
import platform
@@ -13,18 +14,26 @@ from tools.setup_helpers.env import check_env_flag
from tools.setup_helpers.cuda import WITH_CUDA, CUDA_HOME
from tools.setup_helpers.cudnn import WITH_CUDNN, CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR
DEBUG = check_env_flag('DEBUG')
WITH_DISTRIBUTED = check_env_flag('WITH_DISTRIBUTED')
WITH_DISTRIBUTED_MW = WITH_DISTRIBUTED and check_env_flag('WITH_DISTRIBUTED_MW')

################################################################################
# Monkey-patch setuptools to compile in parallel
################################################################################
original_link = distutils.unixccompiler.UnixCCompiler.link

def parallelCCompile(self, sources, output_dir=None, macros=None, include_dirs=None, debug=0, extra_preargs=None, extra_postargs=None, depends=None):

def parallelCCompile(self, sources, output_dir=None, macros=None,
                     include_dirs=None, debug=0, extra_preargs=None,
                     extra_postargs=None, depends=None):
    # those lines are copied from distutils.ccompiler.CCompiler directly
    macros, objects, extra_postargs, pp_opts, build = self._setup_compile(output_dir, macros, include_dirs, sources, depends, extra_postargs)
    macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
        output_dir, macros, include_dirs, sources, depends, extra_postargs)
    cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)

    # compile using a thread pool
    import multiprocessing.pool

    def _single_compile(obj):
        src, ext = build[obj]
        self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
@@ -33,12 +42,23 @@ def parallelCCompile(self, sources, output_dir=None, macros=None, include_dirs=N

    return objects


def patched_link(self, *args, **kwargs):
    _cxx = self.compiler_cxx
    self.compiler_cxx = None
    result = original_link(self, *args, **kwargs)
    self.compiler_cxx = _cxx
    return result


distutils.ccompiler.CCompiler.compile = parallelCCompile
distutils.unixccompiler.UnixCCompiler.link = patched_link

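The patch above replaces distutils' serial compile loop. Stripped of the distutils plumbing, the idea is just this (a sketch; `compile_one` stands in for the per-object `self._compile` call and is not a name from the diff):

```python
import multiprocessing.pool

def compile_in_parallel(objects, compile_one, num_workers=8):
    # C compiler invocations release the GIL while the subprocess runs,
    # so a thread pool is enough to get real parallelism here.
    with multiprocessing.pool.ThreadPool(num_workers) as pool:
        list(pool.imap(compile_one, objects))
    return objects
```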
################################################################################
# Custom build commands
################################################################################


class build_deps(Command):
    user_options = []

@@ -53,6 +73,8 @@ class build_deps(Command):
        build_all_cmd = ['bash', 'torch/lib/build_all.sh']
        if WITH_CUDA:
            build_all_cmd += ['--with-cuda']
        if WITH_DISTRIBUTED:
            build_all_cmd += ['--with-distributed']
        if subprocess.call(build_all_cmd) != 0:
            sys.exit(1)
        generate_nn_wrappers()
@@ -73,6 +95,7 @@ class build_module(Command):


class build_ext(setuptools.command.build_ext.build_ext):

    def run(self):
        # Print build options
        if WITH_NUMPY:
@@ -116,6 +139,7 @@ class build(distutils.command.build.build):


class install(setuptools.command.install.install):

    def run(self):
        if not self.skip_build:
            self.run_command('build_deps')
@@ -123,6 +147,7 @@ class install(setuptools.command.install.install):


class clean(distutils.command.clean.clean):

    def run(self):
        import glob
        with open('.gitignore', 'r') as f:
@@ -138,7 +163,6 @@ class clean(distutils.command.clean.clean):
        distutils.command.clean.clean.run(self)



################################################################################
# Configure compile flags
################################################################################
@@ -161,31 +185,35 @@ include_dirs += [
    tmp_install_path + "/include",
    tmp_install_path + "/include/TH",
    tmp_install_path + "/include/THPP",
    tmp_install_path + "/include/THNN",
]

extra_link_args.append('-L' + lib_path)

# we specify exact lib names to avoid conflict with lua-torch installs
TH_LIB = os.path.join(lib_path, 'libTH.so.1')
THS_LIB = os.path.join(lib_path, 'libTHS.so.1')
THC_LIB = os.path.join(lib_path, 'libTHC.so.1')
THCS_LIB = os.path.join(lib_path, 'libTHCS.so.1')
THNN_LIB = os.path.join(lib_path, 'libTHNN.so.1')
TH_LIB = os.path.join(lib_path, 'libTH.so.1')
THS_LIB = os.path.join(lib_path, 'libTHS.so.1')
THC_LIB = os.path.join(lib_path, 'libTHC.so.1')
THCS_LIB = os.path.join(lib_path, 'libTHCS.so.1')
THNN_LIB = os.path.join(lib_path, 'libTHNN.so.1')
THCUNN_LIB = os.path.join(lib_path, 'libTHCUNN.so.1')
THPP_LIB = os.path.join(lib_path, 'libTHPP.so.1')
THPP_LIB = os.path.join(lib_path, 'libTHPP.so.1')
THD_LIB = os.path.join(lib_path, 'libTHD.so.1')
if platform.system() == 'Darwin':
    TH_LIB = os.path.join(lib_path, 'libTH.1.dylib')
    THS_LIB = os.path.join(lib_path, 'libTHS.1.dylib')
    THC_LIB = os.path.join(lib_path, 'libTHC.1.dylib')
    THCS_LIB = os.path.join(lib_path, 'libTHCS.1.dylib')
    THNN_LIB = os.path.join(lib_path, 'libTHNN.1.dylib')
    TH_LIB = os.path.join(lib_path, 'libTH.1.dylib')
    THS_LIB = os.path.join(lib_path, 'libTHS.1.dylib')
    THC_LIB = os.path.join(lib_path, 'libTHC.1.dylib')
    THCS_LIB = os.path.join(lib_path, 'libTHCS.1.dylib')
    THNN_LIB = os.path.join(lib_path, 'libTHNN.1.dylib')
    THCUNN_LIB = os.path.join(lib_path, 'libTHCUNN.1.dylib')
    THPP_LIB = os.path.join(lib_path, 'libTHPP.1.dylib')
    THPP_LIB = os.path.join(lib_path, 'libTHPP.1.dylib')
    THD_LIB = os.path.join(lib_path, 'libTHD.1.dylib')

main_compile_args = ['-D_THP_CORE']
main_libraries = ['shm']
main_link_args = [TH_LIB, THS_LIB, THPP_LIB]
main_link_args = [TH_LIB, THS_LIB, THPP_LIB, THNN_LIB]
main_sources = [
    "torch/csrc/PtrWrapper.cpp",
    "torch/csrc/Module.cpp",
    "torch/csrc/Generator.cpp",
    "torch/csrc/Size.cpp",
@@ -200,6 +228,7 @@ main_sources = [
    "torch/csrc/autograd/variable.cpp",
    "torch/csrc/autograd/function.cpp",
    "torch/csrc/autograd/engine.cpp",
    "torch/csrc/nn/THNN_generic.cpp",
]

try:
@@ -210,6 +239,20 @@ try:
except ImportError:
    WITH_NUMPY = False

if WITH_DISTRIBUTED:
    extra_compile_args += ['-DWITH_DISTRIBUTED']
    main_sources += [
        "torch/csrc/distributed/Module.cpp",
        "torch/csrc/distributed/utils.cpp",
    ]
    if WITH_DISTRIBUTED_MW:
        main_sources += [
            "torch/csrc/distributed/Tensor.cpp",
            "torch/csrc/distributed/Storage.cpp",
        ]
    include_dirs += [tmp_install_path + "/include/THD"]
    main_link_args += [THD_LIB]

if WITH_CUDA:
    cuda_lib_dirs = ['lib64', 'lib']
    cuda_include_path = os.path.join(CUDA_HOME, 'include')
@@ -218,11 +261,12 @@ if WITH_CUDA:
        if os.path.exists(cuda_lib_path):
            break
    include_dirs.append(cuda_include_path)
    include_dirs.append(tmp_install_path + "/include/THCUNN")
    extra_link_args.append('-L' + cuda_lib_path)
    extra_link_args.append('-Wl,-rpath,' + cuda_lib_path)
    extra_compile_args += ['-DWITH_CUDA']
    extra_compile_args += ['-DCUDA_LIB_PATH=' + cuda_lib_path]
    main_link_args += [THC_LIB, THCS_LIB]
    main_link_args += [THC_LIB, THCS_LIB, THCUNN_LIB]
    main_sources += [
        "torch/csrc/cuda/Module.cpp",
        "torch/csrc/cuda/Storage.cpp",
@@ -238,13 +282,11 @@ if WITH_CUDNN:
    include_dirs.append(CUDNN_INCLUDE_DIR)
    extra_link_args.append('-L' + CUDNN_LIB_DIR)
    main_sources += [
        "torch/csrc/cudnn/Module.cpp",
        "torch/csrc/cudnn/BatchNorm.cpp",
        "torch/csrc/cudnn/Conv.cpp",
        "torch/csrc/cudnn/cuDNN.cpp",
        "torch/csrc/cudnn/Types.cpp",
        "torch/csrc/cudnn/Handles.cpp",
        "torch/csrc/cudnn/CppWrapper.cpp",
    ]
    extra_compile_args += ['-DWITH_CUDNN']

@@ -267,70 +309,70 @@ extensions = []
packages = find_packages(exclude=('tools.*',))

C = Extension("torch._C",
    libraries=main_libraries,
    sources=main_sources,
    language='c++',
    extra_compile_args=main_compile_args + extra_compile_args,
    include_dirs=include_dirs,
    extra_link_args=extra_link_args + main_link_args + [make_relative_rpath('lib')],
)
              libraries=main_libraries,
              sources=main_sources,
              language='c++',
              extra_compile_args=main_compile_args + extra_compile_args,
              include_dirs=include_dirs,
              extra_link_args=extra_link_args + main_link_args + [make_relative_rpath('lib')],
              )
extensions.append(C)

DL = Extension("torch._dl",
    sources=["torch/csrc/dl.c"],
    language='c',
)
               sources=["torch/csrc/dl.c"],
               language='c',
               )
extensions.append(DL)

THNN = Extension("torch._thnn._THNN",
    sources=['torch/csrc/nn/THNN.cpp'],
    language='c++',
    extra_compile_args=extra_compile_args,
    include_dirs=include_dirs,
    extra_link_args=extra_link_args + [
        TH_LIB,
        THNN_LIB,
        make_relative_rpath('../lib'),
    ]
)
                 sources=['torch/csrc/nn/THNN.cpp'],
                 language='c++',
                 extra_compile_args=extra_compile_args,
                 include_dirs=include_dirs,
                 extra_link_args=extra_link_args + [
                     TH_LIB,
                     THNN_LIB,
                     make_relative_rpath('../lib'),
                 ]
                 )
extensions.append(THNN)

if WITH_CUDA:
    THCUNN = Extension("torch._thnn._THCUNN",
        sources=['torch/csrc/nn/THCUNN.cpp'],
        language='c++',
        extra_compile_args=extra_compile_args,
        include_dirs=include_dirs,
        extra_link_args=extra_link_args + [
            TH_LIB,
            THC_LIB,
            THCUNN_LIB,
            make_relative_rpath('../lib'),
        ]
    )
                       sources=['torch/csrc/nn/THCUNN.cpp'],
                       language='c++',
                       extra_compile_args=extra_compile_args,
                       include_dirs=include_dirs,
                       extra_link_args=extra_link_args + [
                           TH_LIB,
                           THC_LIB,
                           THCUNN_LIB,
                           make_relative_rpath('../lib'),
                       ]
                       )
    extensions.append(THCUNN)

version="0.1"
version = "0.1"
if os.getenv('PYTORCH_BUILD_VERSION'):
    version = os.getenv('PYTORCH_BUILD_VERSION') \
        + '_' + os.getenv('PYTORCH_BUILD_NUMBER')
        + '_' + os.getenv('PYTORCH_BUILD_NUMBER')

setup(name="torch", version=version,
    ext_modules=extensions,
    cmdclass = {
        'build': build,
        'build_ext': build_ext,
        'build_deps': build_deps,
        'build_module': build_module,
        'install': install,
        'clean': clean,
    },
    packages=packages,
    package_data={'torch': [
        'lib/*.so*', 'lib/*.dylib*',
        'lib/torch_shm_manager',
        'lib/*.h',
        'lib/include/TH/*.h', 'lib/include/TH/generic/*.h',
        'lib/include/THC/*.h', 'lib/include/THC/generic/*.h']},
    install_requires=['pyyaml'],
)
      ext_modules=extensions,
      cmdclass={
          'build': build,
          'build_ext': build_ext,
          'build_deps': build_deps,
          'build_module': build_module,
          'install': install,
          'clean': clean,
      },
      packages=packages,
      package_data={'torch': [
          'lib/*.so*', 'lib/*.dylib*',
          'lib/torch_shm_manager',
          'lib/*.h',
          'lib/include/TH/*.h', 'lib/include/TH/generic/*.h',
          'lib/include/THC/*.h', 'lib/include/THC/generic/*.h']},
      install_requires=['pyyaml'],
      )

@@ -1,3 +1,5 @@
import sys
import argparse
import unittest
import contextlib
from itertools import product
@@ -9,9 +11,17 @@ from torch.autograd import Variable, Function


torch.set_default_tensor_type('torch.DoubleTensor')
torch.manual_seed(123)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(123)


def run_tests():
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--seed', type=int, default=123)
    args, remaining = parser.parse_known_args()
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    remaining = [sys.argv[0]] + remaining
    unittest.main(argv=remaining)


TEST_NUMPY = True
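A hypothetical test module wired up to the new `run_tests` helper could look like this (illustration only; the class and test names are made up, not from the diff):

```python
import torch
from common import TestCase, run_tests

class TestExample(TestCase):

    def test_add(self):
        self.assertEqual((torch.ones(3) + 1).sum(), 6)

if __name__ == '__main__':
    run_tests()  # parses --seed itself, forwards remaining args to unittest
```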
@@ -20,6 +30,7 @@ try:
except ImportError:
    TEST_NUMPY = False


def get_cpu_type(t):
    assert t.__module__ == 'torch.cuda'
    return getattr(torch, t.__class__.__name__)
@@ -146,7 +157,7 @@ def make_jacobian(input, num_out):
        return torch.zeros(input.nelement(), num_out)
    else:
        return type(input)(filter(lambda x: x is not None,
                (make_jacobian(elem, num_out) for elem in input)))
                                  (make_jacobian(elem, num_out) for elem in input)))


def iter_tensors(x, only_requiring_grad=False):
@@ -197,7 +208,7 @@ def get_numerical_jacobian(fn, input, target):
        outb.copy_(fn(input))
        flat_tensor[i] = orig

        outb.add_(-1,outa).div_(2*perturbation)
        outb.add_(-1, outa).div_(2 * perturbation)
        d_tensor[i] = outb

    return jacobian

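The reformatted line computes a symmetric (central) difference. In isolation, the estimator looks like this (a standalone sketch, not part of the diff):

```python
def central_difference(f, x, h=1e-6):
    # (f(x + h) - f(x - h)) / (2 h): error is O(h^2), versus O(h)
    # for the one-sided difference (f(x + h) - f(x)) / h.
    return (f(x + h) - f(x - h)) / (2 * h)

# Example: d/dx x**3 at x = 2 is 12.
assert abs(central_difference(lambda v: v ** 3, 2.0) - 12.0) < 1e-4
```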
@@ -18,6 +18,7 @@ else:
TEST_CUDA = torch.cuda.is_available()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.cuda.FloatTensor(1))
TEST_CUDNN_VERSION = TEST_CUDNN and torch.backends.cudnn.version()
PRECISION = 1e-5

module_tests = [
@@ -25,14 +26,14 @@ module_tests = [
        module_name='Linear',
        constructor_args=(10, 8),
        input_size=(4, 10),
        reference_fn=lambda i,p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)
        reference_fn=lambda i, p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)
    ),
    dict(
        module_name='Linear',
        constructor_args=(10, 8, False),
        input_size=(4, 10),
        desc='no_bias',
        reference_fn=lambda i,p: torch.mm(i, p[0].t())
        reference_fn=lambda i, p: torch.mm(i, p[0].t())
    ),
    dict(
        module_name='Threshold',
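The `Linear` reference function above is the affine map y = x W^T + b, with the bias row broadcast across the batch. Spelled out as a named function (a sketch, not part of the diff):

```python
import torch

def linear_reference(x, weight, bias):
    # x: (batch, in_features), weight: (out_features, in_features)
    out = torch.mm(x, weight.t())
    return out + bias.view(1, -1).expand_as(out)
```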
@@ -72,7 +73,7 @@ module_tests = [
    dict(
        module_name='Hardtanh',
        input_size=(3, 2, 5),
        reference_fn=lambda i,_: i.clamp(-1, 1)
        reference_fn=lambda i, _: i.clamp(-1, 1)
    ),
    dict(
        module_name='Sigmoid',
@@ -85,17 +86,23 @@ module_tests = [
    dict(
        module_name='Softmax',
        input_size=(10, 20),
        reference_fn=lambda i,_: torch.exp(i).div(torch.exp(i).sum(1).expand(10, 20))
        reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1).expand(10, 20))
    ),
    dict(
        module_name='Softmax2d',
        input_size=(1, 3, 10, 20),
        reference_fn=lambda i,_: torch.exp(i).div(torch.exp(i).sum(1).expand_as(i))
        reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1).expand_as(i))
    ),
    dict(
        module_name='LogSoftmax',
        input_size=(10, 20),
        reference_fn=lambda i,_: torch.exp(i).div_(torch.exp(i).sum(1).expand(10, 20)).log_()
        reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1).expand(10, 20)).log_()
    ),
    dict(
        module_name='LogSoftmax',
        input_size=(1, 3, 10, 20),
        reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1).expand_as(i)).log_(),
        desc='multiparam'
    ),
    dict(
        module_name='ELU',
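These reference functions all spell softmax out by hand. Extracted into a named function it reads as below (a sketch, not part of the diff; it mirrors the lambdas above, which rely on the old behavior where `sum(1)` keeps the reduced dimension so `expand_as` can broadcast it back):

```python
import torch

def softmax_reference(i):
    # softmax along dim 1: exp(i) normalized by the per-row sum
    e = torch.exp(i)
    return e.div(e.sum(1).expand_as(i))
```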
@@ -124,18 +131,18 @@ module_tests = [
    dict(
        module_name='LogSigmoid',
        input_size=(2, 3, 4),
        reference_fn=lambda i,_: i.sigmoid().log()
        reference_fn=lambda i, _: i.sigmoid().log()
    ),
    dict(
        module_name='Softplus',
        input_size=(10, 20),
        reference_fn=lambda i,_: torch.log(1 + torch.exp(i))
        reference_fn=lambda i, _: torch.log(1 + torch.exp(i))
    ),
    dict(
        module_name='Softplus',
        constructor_args=(2,),
        input_size=(10, 20),
        reference_fn=lambda i,_: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
        reference_fn=lambda i, _: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
        desc='beta'
    ),
    dict(
@@ -166,7 +173,7 @@ module_tests = [
    dict(
        module_name='Softsign',
        input_size=(3, 2, 5),
        reference_fn=lambda i,_: i.div(1 + torch.abs(i))
        reference_fn=lambda i, _: i.div(1 + torch.abs(i))
    ),
    dict(
        module_name='Softmin',
@@ -181,11 +188,11 @@ module_tests = [

criterion_tests = [
    dict(module_name='L1Loss',
        input_size=(2, 3, 4),
        target=torch.randn(2, 3, 4),
        reference_fn=lambda i,t,_: 1./i.numel() * \
            sum((a-b).abs().sum() for a,b in zip(i, t))
        ),
         input_size=(2, 3, 4),
         target=torch.randn(2, 3, 4),
         reference_fn=lambda i, t, _: 1. / i.numel() *
         sum((a - b).abs().sum() for a, b in zip(i, t))
         ),
    dict(
        module_name='NLLLoss',
        input=torch.rand(15, 10).log(),
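The `L1Loss` reference above is the mean absolute error over all elements; equivalently (a sketch, not part of the diff):

```python
def l1_reference(i, t):
    # mean absolute error: sum |i - t| over every element, then divide by numel
    return (i - t).abs().sum() / i.numel()
```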
@@ -207,7 +214,7 @@ criterion_tests = [
        module_name='MSELoss',
        input=torch.randn(2, 3, 4, 5),
        target=torch.randn(2, 3, 4, 5),
        reference_fn=lambda i,t,_: (i-t).abs().pow(2).sum() / i.numel()
        reference_fn=lambda i, t, _: (i - t).abs().pow(2).sum() / i.numel()
    ),
    dict(
        module_name='BCELoss',
@@ -364,9 +371,9 @@ class NNTestCase(TestCase):

            if jacobian_input:
                for jacobian_x, d_x in zip(flat_jacobian_input, iter_tensors(d_input)):
                    jacobian_x[:,i] = d_x
                    jacobian_x[:, i] = d_x
            if jacobian_parameters:
                jacobian_param[:,i] = torch.cat(self._flatten_tensors(d_param), 0)
                jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)

        res = tuple()
        if jacobian_input:
@@ -427,7 +434,7 @@ class NNTestCase(TestCase):
            fx1 = self._forward_criterion(criterion, input, target)
            x[i] = original - eps
            fx2 = self._forward_criterion(criterion, input, target)
            deriv = (fx1 - fx2) / (2.*eps)
            deriv = (fx1 - fx2) / (2. * eps)
            d_x[i] = deriv
            x[i] = original

@@ -441,8 +448,9 @@ class NNTestCase(TestCase):


class TestBase(object):

    def __init__(self, constructor, constructor_args=tuple(), input_size=None,
            input=None, desc='', reference_fn=None, fullname=None, **kwargs):
                 input=None, desc='', reference_fn=None, fullname=None, **kwargs):
        if input_size is None and input is None:
            raise RuntimeError("Specify either an input tensor, or it's size!")
        self.constructor = constructor
@@ -490,6 +498,7 @@ class TestBase(object):


class ModuleTest(TestBase):

    def __init__(self, *args, **kwargs):
        super(ModuleTest, self).__init__(*args, **kwargs)
        self.jacobian_input = kwargs.get('jacobian_input', True)
@@ -562,6 +571,7 @@ class ModuleTest(TestBase):


class CriterionTest(TestBase):

    def __init__(self, *args, **kwargs):
        super(CriterionTest, self).__init__(*args, **kwargs)
        self.target = self._get_target(kwargs['target'])
@@ -584,7 +594,7 @@ class CriterionTest(TestBase):
        if isinstance(target, Variable):
            target = target.data
        expected_out = self.reference_fn(deepcopy(self._unpack_input(input)),
                deepcopy(target), module)
                                         deepcopy(target), module)
        test_case.assertEqual(out, expected_out)

        test_case.check_criterion_jacobian(module, input, self.target)
@@ -607,10 +617,10 @@ class CriterionTest(TestBase):

            cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target)
            gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target)
            test_case.assertEqual(cpu_output, gpu_output, 2e-4)
            test_case.assertEqual(cpu_output, gpu_output, 4e-4)

            cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_target)
            gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_target)
            test_case.assertEqual(cpu_gradInput, gpu_gradInput, 2e-4)
            test_case.assertEqual(cpu_gradInput, gpu_gradInput, 4e-4)
        except NotImplementedError:
            pass

@@ -2,6 +2,7 @@ import torch.nn as nn


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.linear = nn.Linear(10, 20)

@@ -2,6 +2,7 @@ import torch.nn as nn


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.linear = nn.Linear(10, 20)

@@ -1,5 +1,6 @@
import torch


def check_error(desc, fn, *required_substrings):
    try:
        fn()
@@ -16,54 +17,55 @@ def check_error(desc, fn, *required_substrings):
    assert False, "given function ({}) didn't raise an error".format(desc)

check_error(
        'Wrong argument types',
        lambda: torch.FloatStorage(object()),
        'object')
    'Wrong argument types',
    lambda: torch.FloatStorage(object()),
    'object')

check_error('Unknown keyword argument',
        lambda: torch.FloatStorage(content=1234.),
        'keyword')
            lambda: torch.FloatStorage(content=1234.),
            'keyword')

check_error('Invalid types inside a sequence',
        lambda: torch.FloatStorage(['a', 'b']),
        'list', 'str')
            lambda: torch.FloatStorage(['a', 'b']),
            'list', 'str')

check_error('Invalid size type',
        lambda: torch.FloatStorage(1.5),
        'float')
            lambda: torch.FloatStorage(1.5),
            'float')

check_error('Invalid offset',
        lambda: torch.FloatStorage(torch.FloatStorage(2), 4),
        '2', '4')
            lambda: torch.FloatStorage(torch.FloatStorage(2), 4),
            '2', '4')

check_error('Negative offset',
        lambda: torch.FloatStorage(torch.FloatStorage(2), -1),
        '2', '-1')
            lambda: torch.FloatStorage(torch.FloatStorage(2), -1),
            '2', '-1')

check_error('Invalid size',
        lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
        '2', '1', '5')
            lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
            '2', '1', '5')

check_error('Negative size',
        lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
        '2', '1', '-5')
            lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
            '2', '1', '-5')

check_error('Invalid index type',
        lambda: torch.FloatStorage(10)['first item'],
        'str')
            lambda: torch.FloatStorage(10)['first item'],
            'str')


def assign():
    torch.FloatStorage(10)[1:-1] = '1'
check_error('Invalid value type',
        assign,
        'str')
            assign,
            'str')

check_error('resize_ with invalid type',
        lambda: torch.FloatStorage(10).resize_(1.5),
        'float')
            lambda: torch.FloatStorage(10).resize_(1.5),
            'float')

check_error('fill_ with invalid type',
        lambda: torch.IntStorage(10).fill_('asdf'),
        'str')
            lambda: torch.IntStorage(10).fill_('asdf'),
            'str')

# TODO: frombuffer

@@ -1,5 +1,5 @@

# th test.lua > lua.out
th test.lua > lua.out
python3 test.py > python.out

diff lua.out python.out >/dev/null 2>&1

test/optim/lua.out (5060 changed lines): file diff suppressed because it is too large
@ -1,39 +0,0 @@
|
||||
assert(arg[1])
|
||||
funcs = {
|
||||
'resizeAs', 'add', 'zero', 'mul', 'div', 'abs',
|
||||
'addcmul', 'addcdiv', 'copy', 'sqrt', 'fill',
|
||||
{'cmul', 'mul'},
|
||||
{'cdiv', 'div'},
|
||||
}
|
||||
for _, val in pairs(funcs) do
|
||||
local name, newname
|
||||
if type(val) == 'table' then
|
||||
name = val[1]
|
||||
newname = val[2]
|
||||
else
|
||||
name = val
|
||||
newname = val .. '_'
|
||||
end
|
||||
|
||||
command = "sed -i -r "
|
||||
.. "'/torch\\." .. name .. "\\(/b; " -- short-circuits
|
||||
.. "s/([a-zA-Z]*)\\." .. name .. "\\(" -- substitution
|
||||
.. "/"
|
||||
.. "\\1\\." .. newname .. "\\(/g' " .. arg[1]
|
||||
print(command)
|
||||
os.execute(command)
|
||||
command = "sed -i 's/math\\." .. newname
|
||||
.. "/math\\." .. name .. "/' " .. arg[1]
|
||||
print(command)
|
||||
os.execute(command)
|
||||
end
|
||||
|
||||
funcs = {
|
||||
{'torch\.cmul', 'torch\.mul'},
|
||||
{'torch\.cdiv', 'torch\.div'},
|
||||
}
|
||||
for _, val in pairs(funcs) do
|
||||
command = "sed -i 's/" .. val[1] .. "/" .. val[2] .. "/' " .. arg[1]
|
||||
print(command)
|
||||
os.execute(command)
|
||||
end
|
||||
test/optim/test.lua (new file, 33 lines)

@@ -0,0 +1,33 @@
local cjson = require 'cjson'
require 'optim'

function rosenbrock(t)
    x, y = t[1], t[2]
    return (1 - x) ^ 2 + 100 * (y - x^2)^2
end

function drosenbrock(t)
    x, y = t[1], t[2]
    return torch.DoubleTensor({-400 * x * (y - x^2) - 2 * (1 - x), 200 * x * (y - x^2)})
end

local fd = io.open('tests.json', 'r')
local tests = cjson.decode(fd:read('*a'))
fd:close()

for i, test in ipairs(tests) do
    print(test.algorithm)
    algorithm = optim[test.algorithm]
    for i, config in ipairs(test.config) do
        print('================================================================================')
        params = torch.DoubleTensor({1.5, 1.5})
        for i = 1, 100 do
            function closure(x)
                return rosenbrock(x), drosenbrock(x)
            end
            algorithm(closure, params, config)
            print(string.format('%.8f\t%.8f', params[1], params[2]))
        end
    end
end

@@ -3,13 +3,15 @@ import torch
import torch.legacy.optim as optim
from pprint import pprint


def rosenbrock(tensor):
    x, y = tensor
    return (1 - x)**2 + 100 * (y - x**2)**2
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2


def drosenbrock(tensor):
    x, y = tensor
    return torch.DoubleTensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * x * (y - x**2)))
    return torch.DoubleTensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * x * (y - x ** 2)))

algorithms = {
    'adadelta': optim.adadelta,
@@ -22,6 +24,7 @@ algorithms = {
    'rmsprop': optim.rmsprop,
    'rprop': optim.rprop,
    'sgd': optim.sgd,
    'lbfgs': optim.lbfgs,
}

with open('tests.json', 'r') as f:
@@ -35,4 +38,4 @@ for test in tests:
        params = torch.DoubleTensor((1.5, 1.5))
        for i in range(100):
            algorithm(lambda x: (rosenbrock(x), drosenbrock(x)), params, config)
        print('{:.12f}\t{:.12f}\t'.format(params[0], params[1]))
        print('{:.8f}\t{:.8f}\t'.format(params[0], params[1]))

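For intuition, here is plain gradient descent on the same objective (an illustration reusing the `rosenbrock`/`drosenbrock` definitions above, not part of the diff; the step size is kept small because Rosenbrock's valley is badly conditioned):

```python
params = torch.DoubleTensor((1.5, 1.5))
lr = 1e-4
for _ in range(2000):
    params = params - drosenbrock(params) * lr
# The loss should drop well below its starting value of 56.5,
# with the iterates creeping toward the minimum at (1, 1).
print('{:.8f}'.format(rosenbrock(params)))
```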
@@ -98,5 +98,12 @@
        {"learningRate": 1e-4, "nesterov": true, "momentum": 0.95, "dampening": 0},
        {"weightDecay": 0.2}
        ]
    },
    {
        "algorithm": "lbfgs",
        "config": [
            {},
            {"learningRate": 1e-1}
        ]
    }
]

@@ -2,8 +2,17 @@
set -e

PYCMD=${PYCMD:="python"}
if [ "$1" == "coverage" ];
then
COVERAGE=0
while [[ "$#" -gt 0 ]]; do
    case "$1" in
        -p|--python) PYCMD=$2; shift 2 ;;
        -c|--coverage) COVERAGE=1; shift 2 ;;
        --) shift; break ;;
        *) echo "Invalid argument: $1!" ; exit 1 ;;
    esac
done

if [[ $COVERAGE -eq 1 ]]; then
    coverage erase
    PYCMD="coverage run --parallel-mode --source torch "
    echo "coverage flag found. Setting python command to: \"$PYCMD\""
@@ -12,39 +21,66 @@ fi
pushd "$(dirname "$0")"

echo "Running torch tests"
$PYCMD test_torch.py
$PYCMD test_torch.py $@

echo "Running autograd tests"
$PYCMD test_autograd.py
$PYCMD test_autograd.py $@

echo "Running sparse tests"
$PYCMD test_sparse.py
$PYCMD test_sparse.py $@

echo "Running nn tests"
$PYCMD test_nn.py
$PYCMD test_nn.py $@

echo "Running legacy nn tests"
$PYCMD test_legacy_nn.py
$PYCMD test_legacy_nn.py $@

echo "Running optim tests"
$PYCMD test_optim.py
$PYCMD test_optim.py $@

echo "Running multiprocessing tests"
$PYCMD test_multiprocessing.py
MULTIPROCESSING_METHOD=spawn $PYCMD test_multiprocessing.py
MULTIPROCESSING_METHOD=forkserver $PYCMD test_multiprocessing.py
$PYCMD test_multiprocessing.py $@
MULTIPROCESSING_METHOD=spawn $PYCMD test_multiprocessing.py $@
MULTIPROCESSING_METHOD=forkserver $PYCMD test_multiprocessing.py $@

echo "Running util tests"
$PYCMD test_utils.py
$PYCMD test_utils.py $@

echo "Running dataloader tests"
$PYCMD test_dataloader.py
$PYCMD test_dataloader.py $@

echo "Running cuda tests"
$PYCMD test_cuda.py
$PYCMD test_cuda.py $@

echo "Running NCCL tests"
$PYCMD test_nccl.py
$PYCMD test_nccl.py $@

################################################################################
if [[ "$TEST_DISTRIBUTED" -eq 1 ]]; then
    distributed_set_up() {
        export TEMP_DIR="$(mktemp -d)"
        rm -rf "$TEMP_DIR/"*
        mkdir "$TEMP_DIR/barrier"
        mkdir "$TEMP_DIR/test_dir"
    }

    distributed_tear_down() {
        rm -rf "$TEMP_DIR"
    }

    trap distributed_tear_down EXIT SIGHUP SIGINT SIGTERM

    echo "Running distributed tests for the TCP backend"
    distributed_set_up
    BACKEND=tcp WORLD_SIZE=3 $PYCMD ./test_distributed.py
    distributed_tear_down

    echo "Running distributed tests for the MPI backend"
    distributed_set_up
    BACKEND=mpi mpiexec -n 3 $PYCMD ./test_distributed.py
    distributed_tear_down
fi
################################################################################

if [ "$1" == "coverage" ];
then

@@ -7,7 +7,8 @@ import unittest
from copy import deepcopy
from collections import OrderedDict

from common import make_jacobian, TestCase, iter_tensors, get_numerical_jacobian
from common import make_jacobian, TestCase, iter_tensors, \
    get_numerical_jacobian, run_tests
from torch.autograd._functions import *
from torch.autograd import Variable, Function

@@ -45,7 +46,7 @@ def get_analytical_jacobian(input, output):
        zero_gradients(input)
        output.backward(grad_output, retain_variables=True)
        for jacobian_x, d_x in zip(jacobian, iter_gradients(input)):
            jacobian_x[:,i] = d_x
            jacobian_x[:, i] = d_x

    return jacobian

@@ -67,6 +68,7 @@ class TestAutograd(TestCase):
        y = Variable(torch.ones(5, 5) * 4, requires_grad=True)

        counter = [0]

        def bw_hook(inc, grad):
            self.assertIsInstance(grad, Variable)
            counter[0] += inc
@@ -102,6 +104,7 @@ class TestAutograd(TestCase):
        # WARNING: this is a test for autograd internals.
        # You should never have to use such things in your code.
        class NoneGradientFunction(Function):

            def forward(self, x, y):
                assert self.needs_input_grad[0]
                assert not self.needs_input_grad[1]
@@ -113,6 +116,7 @@ class TestAutograd(TestCase):
        fn = NoneGradientFunction()
        fn._backward_hooks = OrderedDict()
        was_called = [False]

        def hook(grad_input, grad_output):
            self.assertIsInstance(grad_input, tuple)
            self.assertIsInstance(grad_output, tuple)
@@ -142,7 +146,7 @@ class TestAutograd(TestCase):
        v.backward(grad_output)
        self.assertEqual(v.grad.data, grad_output)

        a = x + (y * z) + 4 * z**2 * x / y
        a = x + (y * z) + 4 * z ** 2 * x / y
        a.backward(grad_output)
        x_grad = 4 * z_t.pow(2) / y_t + 1
        y_grad = z_t - 4 * x_t * z_t.pow(2) / y_t.pow(2)
@@ -237,6 +241,7 @@ class TestAutograd(TestCase):
        self.assertFalse(a.requires_grad)
        b = a + z
        self.assertTrue(b.requires_grad)

        def error():
            raise RuntimeError
        # Make sure backward isn't called on these
@@ -374,6 +379,7 @@ class TestAutograd(TestCase):
        segfault.
        """
        class CollectOnDelete(Function):

            def __del__(self):
                gc.collect()

@@ -381,7 +387,7 @@ class TestAutograd(TestCase):
            Variable(torch.randn(10, 10), creator=CollectOnDelete())

    @unittest.skipIf(not torch.cuda.is_available() or torch.cuda.device_count() < 2,
            "CUDA not available or <2 GPUs detected")
                     "CUDA not available or <2 GPUs detected")
    def test_unused_output_gpu(self):
        from torch.nn.parallel._functions import Broadcast
        x = Variable(torch.randn(5, 5).float().cuda(), requires_grad=True)
@@ -431,6 +437,7 @@ class TestAutograd(TestCase):

    def test_return_leaf(self):
        class Identity(Function):

            def forward(self, a, b):
                return a, a + b

@@ -438,6 +445,7 @@ class TestAutograd(TestCase):
                return grad_a + grad_b, grad_b

        class Inplace(InplaceFunction):

            def forward(self, a, b):
                self.mark_dirty(a)
                return a.add_(b), b + 2
@@ -459,6 +467,7 @@ class TestAutograd(TestCase):

    def test_return_leaf_inplace(self):
        class Inplace(InplaceFunction):

            def forward(self, a, b):
                self.mark_dirty(a)
                return a.add_(b), b + 2
@@ -491,51 +500,51 @@ class TestAutograd(TestCase):
        self.assertEqual(z.grad.data, torch.ones(5) * 2)

    def test_backward_copy(self):
        # This tests checks backward engine for a very subtle bug that appreared
        # in one of the initial versions of autograd. Gradients tensors were
        # simply stored in lists while the function waited for all its gradients
        # to be computed. However, sometimes an output was used multiple times,
        # so the gradients needed to be summed. Engine used to keep a need_copy
        # set of tensors that will need a clone upon next addition and removed
        # them from the set as soon as the clone was performed. However, this
        # could lead to incorrect results if the same gradient tensor was
        # buffered in three places in the graph:
        # 1. When accumulating gradients in one of these places it was cloned
        #    and removed from need_copy set.
        # 2. When accumulating in second place, it wasn't in the need_copy set,
        #    so the gradients were simply accumulated in-place (which already
        #    modified the grad in 3rd place)
        # 3. When accumulating in the third place, it wasn't in the need_copy set
        #    as well, so the incoming gradient was summed in-place, yielding
        #    incorrect results in all functions, except the first one.
        x = Variable(torch.ones(5, 5), requires_grad=True)
        y = Variable(torch.ones(5, 5), requires_grad=True)
        # Simulate that we're in the middle of the graph
        a = x + 2
        b = y + 2
        c = x + 2
        # This op will just return grad_output two times in backward
        add1 = a + b
        add2 = add1 + c
        # Simulate a long branch, so grad_output will get buffered.
        for i in range(4):
            a = a * 2
            b = b * 2
            c = c * 2
        branch = a + b + c
        out = add2 + branch
        # expected gradients are:
        # for x: 34 (16 from final a, 16 from final c, 2 from add2)
        # for y: 17 (16 from final b, 1 from add2)
        grad_output = torch.ones(5, 5)
        out.backward(grad_output)
        self.assertEqual(x.grad.data, torch.ones(5, 5) * 34)
        self.assertEqual(y.grad.data, torch.ones(5, 5) * 17)
        # This tests checks backward engine for a very subtle bug that appreared
        # in one of the initial versions of autograd. Gradients tensors were
        # simply stored in lists while the function waited for all its gradients
        # to be computed. However, sometimes an output was used multiple times,
        # so the gradients needed to be summed. Engine used to keep a need_copy
        # set of tensors that will need a clone upon next addition and removed
        # them from the set as soon as the clone was performed. However, this
        # could lead to incorrect results if the same gradient tensor was
        # buffered in three places in the graph:
        # 1. When accumulating gradients in one of these places it was cloned
        #    and removed from need_copy set.
        # 2. When accumulating in second place, it wasn't in the need_copy set,
        #    so the gradients were simply accumulated in-place (which already
        #    modified the grad in 3rd place)
        # 3. When accumulating in the third place, it wasn't in the need_copy set
        #    as well, so the incoming gradient was summed in-place, yielding
        #    incorrect results in all functions, except the first one.
        x = Variable(torch.ones(5, 5), requires_grad=True)
        y = Variable(torch.ones(5, 5), requires_grad=True)
        # Simulate that we're in the middle of the graph
        a = x + 2
        b = y + 2
        c = x + 2
        # This op will just return grad_output two times in backward
        add1 = a + b
        add2 = add1 + c
        # Simulate a long branch, so grad_output will get buffered.
        for i in range(4):
            a = a * 2
            b = b * 2
            c = c * 2
        branch = a + b + c
        out = add2 + branch
        # expected gradients are:
        # for x: 34 (16 from final a, 16 from final c, 2 from add2)
        # for y: 17 (16 from final b, 1 from add2)
        grad_output = torch.ones(5, 5)
        out.backward(grad_output)
        self.assertEqual(x.grad.data, torch.ones(5, 5) * 34)
        self.assertEqual(y.grad.data, torch.ones(5, 5) * 17)
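The failure mode described in that comment is an instance of a general in-place aliasing hazard, shown here in miniature (a standalone sketch, not part of the diff):

```python
import torch

g = torch.ones(3)
a, b = g, g             # two consumers hold references to the same buffer
a.add_(torch.ones(3))   # in-place accumulation through one reference...
# ...mutated the other as well: b is now [2, 2, 2], not the [1, 1, 1]
# its holder expected. The engine must clone before accumulating in place.
assert b[0] == 2
```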
def test_functional_blas(self):
|
||||
def compare(fn, *args):
|
||||
unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg
|
||||
for arg in args)
|
||||
for arg in args)
|
||||
self.assertEqual(fn(*args).data, fn(*unpacked_args))
|
||||
|
||||
def test_blas_add(fn, x, y, z):
|
||||
@ -548,27 +557,29 @@ class TestAutograd(TestCase):
|
||||
compare(fn, x, y)
|
||||
|
||||
test_blas(torch.mm, Variable(torch.randn(2, 10)),
|
||||
Variable(torch.randn(10, 4)))
|
||||
Variable(torch.randn(10, 4)))
|
||||
test_blas_add(torch.addmm, Variable(torch.randn(2, 4)),
|
||||
Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
|
||||
Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
|
||||
test_blas(torch.bmm, Variable(torch.randn(4, 2, 10)),
|
||||
Variable(torch.randn(4, 10, 4)))
|
||||
Variable(torch.randn(4, 10, 4)))
|
||||
test_blas_add(torch.addbmm, Variable(torch.randn(2, 4)),
|
||||
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
|
||||
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
|
||||
test_blas_add(torch.baddbmm, Variable(torch.randn(4, 2, 4)),
|
||||
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
|
||||
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
|
||||
test_blas(torch.mv, Variable(torch.randn(2, 10)),
|
||||
Variable(torch.randn(10)))
|
||||
Variable(torch.randn(10)))
|
||||
test_blas_add(torch.addmv, Variable(torch.randn(2)),
|
||||
Variable(torch.randn(2, 10)), Variable(torch.randn(10)))
|
||||
Variable(torch.randn(2, 10)), Variable(torch.randn(10)))
|
||||
test_blas(torch.ger, Variable(torch.randn(5)),
|
||||
Variable(torch.randn(6)))
|
||||
Variable(torch.randn(6)))
|
||||
test_blas_add(torch.addr, Variable(torch.randn(5, 6)),
|
||||
Variable(torch.randn(5)), Variable(torch.randn(6)))
|
||||
Variable(torch.randn(5)), Variable(torch.randn(6)))
|
||||
|
||||
def test_save_none_for_backward(self):
|
||||
test_case = self
|
||||
|
||||
class MyFn(Function):
|
||||
|
||||
def forward(self, input):
|
||||
self.save_for_backward(None, input, None)
|
||||
return input * input
|
||||
@ -586,6 +597,7 @@ class TestAutograd(TestCase):
|
||||
|
||||
def test_too_many_grads(self):
|
||||
class MyFn(Function):
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
@ -674,6 +686,7 @@ class TestAutograd(TestCase):
|
||||
|
||||
def test_dep_nograd(self):
|
||||
class F1(Function):
|
||||
|
||||
def forward(self, input):
|
||||
out = torch.randn(input.size())
|
||||
self.mark_non_differentiable(out)
|
||||
@ -683,6 +696,7 @@ class TestAutograd(TestCase):
|
||||
                return grad_output

        class F2(Function):

            def forward(self, input, ignored):
                return input

@@ -705,6 +719,7 @@ def index_variable(shape, max_indices):
    index = torch.rand(*shape).mul_(max_indices).floor_().long()
    return Variable(index, requires_grad=False)


def gather_variable(shape, index_dim, max_indices):
    assert len(shape) == 2
    assert index_dim < 2
@@ -712,7 +727,7 @@ def gather_variable(shape, index_dim, max_indices):
    index = torch.LongTensor(*shape)
    for i in range(shape[index_dim]):
        index.select(index_dim, i).copy_(
            torch.randperm(max_indices)[:shape[batch_dim]])
    return Variable(index, requires_grad=False)

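# Minimal usage sketch for gather_variable (standalone; the sizes are
# hypothetical). For dim=0, out[i][j] = src[index[i][j]][j], so every index
# entry must be a valid row of src -- exactly what gather_variable((5, 5), 1, 10)
# guarantees by drawing entries from randperm(10).
src = torch.randn(10, 5)
index = gather_variable((5, 5), 1, 10)
out = src.gather(0, index.data)   # shape (5, 5)
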
@@ -720,215 +735,235 @@ L = 20
M = 10
S = 5
function_tests = [
    (Add, (), ((M, M), (M, M))),
    (Sub, (), ((M, M), (M, M))),
    (Mul, (), ((M, M), (M, M))),
    (Div, (), ((M, M), torch.rand(M, M) + 5e-2)),
    (Pow, (), (torch.rand(M, M) + 1e-3, torch.rand(M, M) + 0.1)),
    (AddConstant, (3.14,), ((L, L),)),
    (SubConstant, (3.14,), ((L, L),)),
    (SubConstant, (3.14, True), ((L, L),), 'from_tensor'),
    (MulConstant, (3.14,), ((L, L),)),
    (DivConstant, (3.14, True), (torch.rand(L, L) + 1e-1,), 'by_tensor'),
    (PowConstant, (3.14,), (torch.rand(L, L),)),
    (PowConstant, (3.14, True), (torch.rand(L, L),), 'tensor_power'),
    (Transpose, (0, 1), (torch.rand(L, L),)),
    (Transpose, (2, 0), (torch.rand(S, S, S),), '3d'),
    (Permute, ((0, 4, 3, 5, 1, 2),), ((1, 2, 3, 4, 5, 6),)),
    (Index, ((1, 2),), (torch.rand(S, S, S),)),
    (Index, (slice(0, 3),), (torch.rand(S, S, S),), 'slice'),
    (Index, ((slice(0, 3), 1),), (torch.rand(S, S, S),), 'slice_index'),
    (View, (S * S, S), (torch.rand(S, S, S),)),
    (Expand, ((S, 5, S, 5),), ((S, 1, S, 1),)),
    (Exp, (), (torch.rand(S, S, S),)),
    (Log, (), (torch.rand(S, S, S) + 1e-2,)),
    (Log1p, (), (torch.rand(S, S, S),)),
    (Tanh, (), ((S, S, S),)),
    (Sigmoid, (), ((S, S, S),)),
    (Sinh, (), ((S, S, S),)),
    (Cosh, (), ((S, S, S),)),
    (Abs, (), ((S, S, S),)),
    (Clamp, (0, 1), ((S, S, S),)),
    (Sqrt, (), (torch.rand(S, S, S) + 5e-4,)),
    (Sin, (), ((S, S, S),)),
    (Cos, (), ((S, S, S),)),
    (Tan, (), (torch.randn(S, S, S).clamp(-1, 1),)),
    (Asin, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),)),
    (Acos, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),)),
    (Atan, (), ((S, S, S),)),
    (Reciprocal, (), (torch.rand(S, S, S) + 0.1,)),
    (Cmax, (), ((S, S, S), (S, S, S))),
    (Cmin, (), ((S, S, S), (S, S, S))),
    (Round, (), ((S, S, S),)),
    (Sign, (), ((S, S, S),)),
    (Trunc, (), ((S, S, S),)),
    (Floor, (), ((S, S, S),)),
    (Ceil, (), ((S, S, S),)),
    (Frac, (), ((S, S, S),)),
    (Fmod, (1.5,), ((S, S, S),)),
    (Lerp, (0.2,), ((S, S, S), (S, S, S))),
    (Rsqrt, (), (torch.rand(S, S, S) + 1e-2,)),
    (Remainder, (1.5,), ((S, S, S),)),
    (CmaxConstant, (0.5,), ((S, S, S),)),
    (CminConstant, (0.5,), ((S, S, S),)),
    (Mean, (), ((S, S, S),)),
    (Mean, (1,), ((S, S, S),), 'dim'),
    (Sum, (), ((S, S, S),)),
    (Sum, (1,), ((S, S, S),), 'dim'),
    (Prod, (), ((S, S, S),)),
    (Prod, (1,), ((S, S, S),), 'dim'),
    (Addmm, (), ((S, M), (S, S), (S, M)),),
    (Addmm, (0.1, 1), ((S, M), (S, S), (S, M)), 'coef'),
    (Addbmm, (), ((S, M), (S, S, S), (S, S, M)),),
    (Addbmm, (0.1, 0.4), ((S, M), (S, S, S), (S, S, M)), 'coef'),
    (Baddbmm, (), ((S, S, M), (S, S, S), (S, S, M)),),
    (Baddbmm, (0.1, 0.4), ((S, S, M), (S, S, S), (S, S, M)), 'coef'),
    (Addmv, (), ((S,), (S, M), (M,)),),
    (Addmv, (0.1, 0.4), ((S,), (S, M), (M,)), 'coef'),
    (Addr, (), ((S, M), (S,), (M,)),),
    (Addr, (0.1, 0.4), ((S, M), (S,), (M,)), 'coef'),
    (Dot, (), ((L,), (L,)),),
    (Max, (), ((S, S, S),),),
    (Repeat, (torch.Size([2, 3, 1, 4]),), ((S, S, S, S),)),
    (Min, (), ((S, S, S),),),
    (Max, (0,), ((S, S, S),), 'dim'),
    (Min, (0,), ((S, S, S),), 'dim'),
    (Mode, (0,), ((S, S, S),),),
    (Kthvalue, (2, 0), ((S, S, S),),),
    (Median, (0,), ((S, S, S),),),
    (Norm, (1.5,), (torch.rand(S, S, S),), '1_5'),
    (Norm, (), ((S, S, S),), '2'),
    (Norm, (3,), ((S, S, S),), '3'),
    (Norm, (1.5, 0), (torch.rand(S, S, S),), '1_5_dim'),
    (Norm, (2, 0), ((S, S, S),), '2_dim'),
    (Norm, (3, 0), ((S, S, S),), '3_dim'),
    (Addcmul, (), ((S, S), (S, S), (S, S))),
    (Addcmul, (0.6,), ((S, S), (S, S), (S, S)), 'scale'),
    (Addcdiv, (), ((S, S), (S, S), torch.rand(S, S) + 1e-2)),
    (Addcdiv, (0.6,), ((S, S), (S, S), torch.rand(S, S) + 1e-2), 'scale'),
    (IndexAdd, (0,), ((S, S), index_variable(2, S), (2, S))),
    # (IndexCopy, (0,), ((S, S), index_variable(2, S), (2, S))),
    (IndexFill, (0, 2), ((S, S), index_variable(2, S))),
    (IndexSelect, (0,), ((S, S), index_variable(2, S))),
    (Gather, (0,), ((M, S), gather_variable((S, S), 1, M))),
    (Gather, (1,), ((M, S), gather_variable((M, S // 2), 0, S)), 'dim1'),
    (Scatter, (0,), ((M, S), gather_variable((S, S), 1, M), (S, S))),
    (Scatter, (1,), ((M, S), gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1'),
    (Concat, (0,), ((1, S, S), (2, S, S), (3, S, S))),
    (Resize, (S * S, S), ((S, S, S),)),
    (Diag, (), ((S, S),), '2d'),
    (Diag, (), ((S,),), '1d'),
    (Tril, (), ((S, S),)),
    (Tril, (2,), ((S, S),), 'idx'),
    (Triu, (), ((S, S),)),
    (Triu, (2,), ((S, S),), 'idx'),
    (Clone, (), ((S, M, S),)),
    (Squeeze, (), ((S, 1, M, 1),)),
    (Squeeze, (1,), ((S, 1, M, 1),), 'dim'),
    (Unsqueeze, (0,), ((S, M, S),), '0'),
    (Unsqueeze, (1,), ((S, M, S),), '1'),
    # (MaskedCopy, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False), (S, S),)),
    (MaskedFill, (10,), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
    (MaskedSelect, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
    (Sort, (), ((S, M, S),)),
    (Sort, (1,), ((S, M, S),), 'dim'),
    (Sort, (1, True), ((S, M, S),), 'dim_desc'),
    (Topk, (3,), ((S, M, S),)),
    (Topk, (3, 1), ((S, M, S),), 'dim'),
    (Topk, (3, 1, True), ((S, M, S),), 'dim_desc'),
    (Topk, (3, 1, True, True), ((S, M, S),), 'dim_desc_sort'),
]

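# Each function_tests entry is (FunctionClass, constructor_args, input_sizes
# [, variant_name]); plain size tuples become random double Variables with
# requires_grad=True via create_input (defined below). One entry expands to,
# roughly:
#   cls, constructor_args, call_args = (Add, (), ((M, M), (M, M)))
#   output = cls(*constructor_args)(*create_input(call_args))
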
method_tests = [
    ('add', (S, S, S), ((S, S, S),)),
    ('add', (S, S, S), (3.14,), 'constant'),
    ('sub', (S, S, S), ((S, S, S),)),
    ('sub', (S, S, S), (3.14,), 'constant'),
    ('mul', (S, S, S), ((S, S, S),)),
    ('mul', (S, S, S), (3.14,), 'constant'),
    ('div', (S, S, S), ((S, S, S),)),
    ('div', (S, S, S), (3.14,), 'constant'),
    ('pow', (S, S, S), ((S, S, S),)),
    ('pow', (S, S, S), (3.14,), 'constant'),
    ('transpose', (1, 2, 3), (1, 2)),
    ('t', (1, 2), ()),
    ('view', (S, S, S), (S * S, S),),
    ('view_as', (S, S, S), ((S * S, S),)),
    ('expand', (S, 1, S), (S, S, S)),
    ('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size'),
    ('exp', (S, S, S), ()),
    ('log', (S, S, S), ()),
    ('log1p', (S, S, S), ()),
    ('tanh', (S, S, S), ()),
    ('sigmoid', (S, S, S), ()),
    ('sinh', (S, S, S), ()),
    ('cosh', (S, S, S), ()),
    ('abs', (S, S, S), ()),
    ('clamp', (S, S, S), (0, 1)),
    ('sqrt', (S, S, S), ()),
    ('sin', (S, S, S), ()),
    ('cos', (S, S, S), ()),
    ('tan', (S, S, S), ()),
    ('asin', (S, S, S), ()),
    ('acos', (S, S, S), ()),
    ('atan', (S, S, S), ()),
    ('reciprocal', (S, S, S), ()),
    ('round', (S, S, S), ()),
    ('sign', (S, S, S), ()),
    ('trunc', (S, S, S), ()),
    ('floor', (S, S, S), ()),
    ('ceil', (S, S, S), ()),
    ('rsqrt', (S, S, S), ()),
    ('fmod', (S, S, S), (1.5,)),
    ('remainder', (S, S, S), (1.5,)),
    ('lerp', (S, S, S), ((S, S, S), 0.4)),
    ('max', (S, S, S), ()),
    ('max', (S, S, S), ((S, S, S),), 'elementwise'),
    ('min', (S, S, S), ()),
    ('min', (S, S, S), ((S, S, S),), 'elementwise'),
    ('mean', (S, S, S), ()),
    ('mean', (S, S, S), (1,), 'dim'),
    ('sum', (S, S, S), ()),
    ('sum', (S, S, S), (1,), 'dim'),
    ('prod', (S, S, S), ()),
    ('prod', (S, S, S), (1,), 'dim'),
    ('var', (S, S, S), ()),
    ('var', (S, S, S), (1,), 'dim'),
    ('std', (S, S, S), ()),
    ('std', (S, S, S), (1,), 'dim'),
    ('renorm', (S, S, S), (2, 1, 0.5)),
    ('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
    ('repeat', (S, S, S, S), (2, 3, 1, 4)),
    ('addmm', (S, M), ((S, S), (S, M)),),
    ('addmm', (S, M), (0.2, 0.6, (S, S), (S, M)), 'coef'),
    ('addbmm', (S, M), ((S, S, S), (S, S, M)),),
    ('addbmm', (S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef'),
    ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)),),
    ('baddbmm', (S, S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef'),
    ('addmv', (S,), ((S, M), (M,)),),
    ('addmv', (S,), (0.2, 0.6, (S, M), (M,)), 'coef'),
    ('addr', (S, M), ((S,), (M,)),),
    ('addr', (S, M), (0.2, 0.6, (S,), (M,)), 'coef'),
    ('dot', (L,), ((L,),),),
    ('addcmul', (S, S), ((S, S), (S, S))),
    ('addcmul', (S, S), (0.5, (S, S), (S, S)), 'scale'),
    ('addcdiv', (S, S), ((S, S), (S, S))),
    ('addcdiv', (S, S), (0.5, (S, S), (S, S)), 'scale'),
    ('norm', (S, S, S), (2,)),
    ('norm', (S, S, S), (2, 1), 'dim'),
    ('dist', (S, S, S), ((S, S, S),)),
    ('dist', (S, S, S), ((S, S, S), 4), '4'),
    ('index_select', (S, S, S), (0, index_variable(2, S))),
    ('diag', (M, M), (), '2d'),
    ('diag', (M,), (), '1d'),
    ('tril', (M, M), ()),
    ('triu', (M, M), ()),
    ('clone', (S, M, S), ()),
    ('eq', (S, S, S), ((S, S, S),)),
    ('ne', (S, S, S), ((S, S, S),)),
    ('gt', (S, S, S), ((S, S, S),)),
    ('ge', (S, S, S), ((S, S, S),)),
    ('lt', (S, S, S), ((S, S, S),)),
    ('le', (S, S, S), ((S, S, S),)),
    ('eq', (S, S, S), (0,), 'scalar'),
    ('ne', (S, S, S), (0,), 'scalar'),
    ('gt', (S, S, S), (0,), 'scalar'),
    ('ge', (S, S, S), (0,), 'scalar'),
    ('lt', (S, S, S), (0,), 'scalar'),
    ('le', (S, S, S), (0,), 'scalar'),
    ('permute', (1, 2, 3, 4), (0, 2, 3, 1)),
    ('select', (S, S, S), (1, 2)),
    ('narrow', (S, S, S), (1, 2, 2)),
    ('squeeze', (S, 1, S, 1), ()),
    ('squeeze', (S, 1, S, 1), (1,), '1_dim'),
    ('squeeze', (S, 1, S, 1), (2,), 'not_1_dim'),
    ('unsqueeze', (S, S, S), (0,), 'first'),
    ('unsqueeze', (S, S, S), (1,), 'middle'),
    ('unsqueeze', (S, S, S), (3,), 'last'),
    ('masked_select', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False),)),
    ('masked_fill_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), 10)),
    ('masked_copy_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), (M, M))),
]
# TODO: mm, bmm, mv, ger
# TODO: max, min with dim (problem with indices)
@@ -941,6 +976,7 @@ method_tests = [
def create_input(call_args):
    if not isinstance(call_args, tuple):
        call_args = (call_args,)

    def map_arg(arg):
        if isinstance(arg, tuple) and not isinstance(arg[0], Variable):
            return Variable(torch.randn(*arg).double(), requires_grad=True)
@@ -971,8 +1007,9 @@ ignore_inplace = set((
for test in function_tests:
    cls, constructor_args, call_args = test[:3]
    test_name = 'test_' + cls.__name__ + ('_' + test[3] if len(test) == 4 else '')

    def do_test(self, cls=cls, constructor_args=constructor_args,
                call_args=call_args, test_name=test_name):
        input = create_input(call_args)
        output = cls(*constructor_args)(*input)
        if not isinstance(output, tuple):
@@ -981,6 +1018,7 @@ for test in function_tests:
            if not o.requires_grad:
                continue
            analytical = get_analytical_jacobian(input, o)

            def fn(input):
                tmp = cls(*constructor_args)(*input)
                if not isinstance(tmp, tuple):
@@ -1027,6 +1065,7 @@ EXCLUDE_FUNCTIONAL = {
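# The cls=cls, constructor_args=constructor_args defaults in do_test above are
# deliberate: default values bind at definition time, so each generated test
# keeps its own loop values. A standalone sketch of the pitfall they avoid:
fns_late = [lambda: i for i in range(3)]       # all three close over the final i
fns_bound = [lambda i=i: i for i in range(3)]  # each keeps its own i
assert [f() for f in fns_late] == [2, 2, 2]
assert [f() for f in fns_bound] == [0, 1, 2]
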
for test in method_tests:
    name, self_size, args = test[:3]
    test_name = 'test_' + name + ('_' + test[3] if len(test) == 4 else '')

    def do_test(self, name=name, self_size=self_size, args=args, test_name=test_name):
        def check(name):
            self_variable = create_input((self_size,))[0]
@@ -1056,13 +1095,12 @@ for test in method_tests:
        try:
            check(inplace_name)
        except Exception as e:
            if 'only supports scalar' not in e.args[0]:
                raise

    assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
    setattr(TestAutograd, test_name, do_test)


if __name__ == '__main__':
    run_tests()

@@ -7,13 +7,14 @@ import torch
import torch.cuda
import torch.cuda.comm as comm

from common import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests

if not torch.cuda.is_available():
    print('CUDA not available, skipping tests')
    import sys
    sys.exit()

def is_floating(t):
    return type(t) in [torch.FloatTensor, torch.DoubleTensor,
                       torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
@@ -31,7 +32,8 @@ types = [
float_types = [
    torch.FloatTensor,
    torch.DoubleTensor
]  # TODO: add half...


def number(floating, integer, t):
    name = type(t).__name__
@@ -44,188 +46,204 @@ def number(floating, integer, t):
S = 10
M = 50


def make_tensor(t, *sizes):
    return t(*sizes).copy_(torch.randn(*sizes))


def small_2d(t):
    return make_tensor(t, S, S)


def small_2d_scaled(t, scale=10):
    return make_tensor(t, S, S).mul(scale)


def small_3d(t):
    return make_tensor(t, S, S, S)


def medium_1d(t):
    return make_tensor(t, M)


def medium_2d(t):
    return make_tensor(t, M, M)


def medium_2d_scaled(t, scale=10):
    return make_tensor(t, M, M).mul(scale)


def small_3d_ones(t):
    return t(S, S, S).copy_(torch.ones(S, S, S))


def small_3d_positive(t):
    min_val = 1e-3 if is_floating(t) else 2
    return make_tensor(t, S, S, S).clamp_(min_val, 120)


def small_3d_unique(t):
    return t(S, S, S).copy_(torch.range(1, S * S * S))


def small_1d_lapack(t):
    return t(1, 3).copy_(torch.range(1, 3).view(3))


def small_2d_lapack(t):
    return t(3, 3).copy_(torch.range(1, 9).view(3, 3))


def small_2d_lapack_skinny(t):
    return t(3, 4).copy_(torch.range(1, 12).view(3, 4))


def small_2d_lapack_fat(t):
    return t(4, 3).copy_(torch.range(1, 12).view(4, 3))


def new_t(*sizes):
    def tmp(t):
        return t(*sizes).copy_(torch.randn(*sizes))
    return tmp

tests = [
    ('add', small_3d, lambda t: [number(3.14, 3, t)]),
    ('add', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
    ('add', small_3d, lambda t: [number(0.2, 2, t), small_3d_positive(t)], 'scalar_tensor'),
    ('sub', small_3d, lambda t: [number(3.14, 3, t)],),
    ('sub', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
    ('mul', small_3d, lambda t: [number(3.14, 3, t)],),
    ('mul', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
    ('div', small_3d, lambda t: [number(3.14, 3, t)],),
    ('div', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
    ('pow', small_3d, lambda t: [number(3.14, 3, t)], None, float_types),
    ('pow', small_3d, lambda t: [small_3d(t).abs_()], 'tensor', float_types),
    ('addbmm', small_2d, lambda t: [small_3d(t), small_3d(t)], None, float_types),
    ('addbmm', small_2d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'),
    ('addbmm', small_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars'),
    ('baddbmm', small_3d, lambda t: [small_3d(t), small_3d(t)],),
    ('baddbmm', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'),
    ('baddbmm', small_3d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars'),
    ('addcdiv', small_2d_lapack, lambda t: [small_2d_lapack(t).mul(2), small_2d_lapack(t)],),
    ('addcdiv', small_2d_lapack, lambda t: [number(2.8, 1, t),
                                            small_2d_lapack(t).mul(2), small_2d_lapack(t)], 'scalar'),
    ('addcmul', small_3d, lambda t: [small_3d(t), small_3d(t)],),
    ('addcmul', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'),
    ('addmm', medium_2d, lambda t: [medium_2d(t), medium_2d(t)],),
    ('addmm', medium_2d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'scalar'),
    ('addmm', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'two_scalars'),
    ('addmv', medium_1d, lambda t: [medium_2d(t), medium_1d(t)],),
    ('addmv', medium_1d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'scalar'),
    ('addmv', medium_1d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'two_scalars'),
    ('addr', medium_2d, lambda t: [medium_1d(t), medium_1d(t)],),
    ('addr', medium_2d, lambda t: [number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'scalar'),
    ('addr', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'two_scalars'),
    ('atan2', medium_2d, lambda t: [medium_2d(t)], None, float_types),
    ('fmod', small_3d, lambda t: [3], 'value'),
    ('fmod', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
    ('chunk', medium_2d, lambda t: [4],),
    ('chunk', medium_2d, lambda t: [4, 1], 'dim'),
    ('clamp', medium_2d_scaled, lambda t: [-1, 5],),
    ('clone', medium_2d, lambda t: [],),
    ('contiguous', medium_2d, lambda t: [],),
    ('cross', new_t(M, 3, M), lambda t: [new_t(M, 3, M)(t)],),
    ('cumprod', small_3d, lambda t: [1],),
    ('cumsum', small_3d, lambda t: [1],),
    ('dim', small_3d, lambda t: [],),
    ('dist', small_2d, lambda t: [small_2d(t)],),
    ('dist', small_2d, lambda t: [small_2d(t), 3], '3_norm'),
    ('dist', small_2d, lambda t: [small_2d(t), 2.5], '2_5_norm'),
    ('dot', medium_1d, lambda t: [medium_1d(t)],),
    ('element_size', medium_1d, lambda t: [],),
    ('eq', small_3d_ones, lambda t: [small_3d(t)],),
    ('eq', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'),
    ('ne', small_3d_ones, lambda t: [small_3d(t)],),
    ('ne', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'),
    ('equal', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'),
    ('equal', small_3d_ones, lambda t: [small_3d(t)],),
    ('expand', new_t(M, 1, M), lambda t: [M, 4, M],),
    ('expand_as', new_t(M, 1, M), lambda t: [new_t(M, 4, M)(t)],),
    ('fill', medium_2d, lambda t: [number(3.14, 3, t)],),
    ('ge', medium_2d, lambda t: [medium_2d(t)],),
    ('le', medium_2d, lambda t: [medium_2d(t)],),
    ('gt', medium_2d, lambda t: [medium_2d(t)],),
    ('lt', medium_2d, lambda t: [medium_2d(t)],),
    ('is_contiguous', medium_2d, lambda t: [],),
    # TODO: can't check negative case - GPU copy will be contiguous
    ('is_same_size', medium_2d, lambda t: [small_3d(t)], 'negative'),
    ('is_same_size', medium_2d, lambda t: [medium_2d(t)], 'positive'),
    ('is_set_to', medium_2d, lambda t: [medium_2d(t)],),
    # TODO: positive case
    ('kthvalue', small_3d_unique, lambda t: [3],),
    ('kthvalue', small_3d_unique, lambda t: [3, 1], 'dim'),
    ('lerp', small_3d, lambda t: [small_3d(t), 0.3],),
    ('max', small_3d_unique, lambda t: [],),
    ('max', small_3d_unique, lambda t: [1], 'dim'),
    ('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
    ('min', small_3d_unique, lambda t: [],),
    ('min', small_3d_unique, lambda t: [1], 'dim'),
    ('min', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
    ('mean', small_3d, lambda t: [],),
    ('mean', small_3d, lambda t: [1], 'dim'),
    ('mode', small_3d, lambda t: [],),
    ('mode', small_3d, lambda t: [1], 'dim'),
    ('remainder', small_3d, lambda t: [3], 'value'),
    ('remainder', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
    ('std', small_3d, lambda t: [],),
    ('std', small_3d, lambda t: [1], 'dim'),
    ('var', small_3d, lambda t: [],),
    ('var', small_3d, lambda t: [1], 'dim'),
    ('ndimension', small_3d, lambda t: [],),
    ('nelement', small_3d, lambda t: [],),
    ('numel', small_3d, lambda t: [],),
    ('narrow', small_3d, lambda t: [1, 3, 2],),
    ('nonzero', small_3d, lambda t: [],),
    ('norm', small_3d, lambda t: [],),
    ('norm', small_3d, lambda t: [3], '3_norm'),
    ('norm', small_3d, lambda t: [3, 0], '3_norm_dim'),
    ('ones', small_3d, lambda t: [1, 2, 3, 4, 5],),
    ('permute', new_t(1, 2, 3, 4), lambda t: [2, 1, 3, 0],),
    ('prod', small_3d, lambda t: [],),
    ('prod', small_3d, lambda t: [1], 'dim'),
    ('sum', small_2d, lambda t: [],),
    ('sum', small_3d, lambda t: [1], 'dim'),
    ('renorm', small_3d, lambda t: [2, 1, 1], '2_norm'),
    ('renorm', small_3d, lambda t: [1.5, 1, 1], '1_5_norm'),
    ('repeat', small_2d, lambda t: [2, 2, 2],),
    ('size', new_t(1, 2, 3, 4), lambda t: [],),
    ('sort', small_3d_unique, lambda t: [],),
    ('sort', small_3d_unique, lambda t: [1], 'dim'),
    ('sort', small_3d_unique, lambda t: [1, True], 'dim_descending'),
    ('split', small_3d, lambda t: [2],),
    ('split', small_3d, lambda t: [2, 1], 'dim'),
    ('squeeze', new_t(1, 2, 1, 4), lambda t: [],),
    ('squeeze', new_t(1, 2, 1, 4), lambda t: [2], 'dim'),
    ('t', new_t(1, 2), lambda t: [],),
    ('transpose', new_t(1, 2, 3, 4), lambda t: [1, 2],),
    ('to_list', small_3d, lambda t: [],),
    ('topk', small_3d, lambda t: [2, 1, False, True], 'dim_sort'),
    ('topk', small_3d, lambda t: [2, 1, True, True], 'dim_desc_sort'),
    ('trace', medium_2d, lambda t: [],),
    ('tril', medium_2d, lambda t: [],),
    ('tril', medium_2d, lambda t: [2], 'positive'),
    ('tril', medium_2d, lambda t: [-2], 'negative'),
    ('triu', medium_2d, lambda t: [],),
    ('triu', medium_2d, lambda t: [2], 'positive'),
    ('triu', medium_2d, lambda t: [-2], 'negative'),
    ('view', small_3d, lambda t: [100, 10],),
    ('view_as', small_3d, lambda t: [t(100, 10)],),
    ('zero', small_3d, lambda t: [],),
    ('zeros', small_3d, lambda t: [1, 2, 3, 4],),
    ('rsqrt', lambda t: small_3d(t) + 1, lambda t: [], None, float_types),
    ('sinh', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], None, float_types),
    ('tan', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], None, float_types),
    # lapack tests
    ('qr', small_2d_lapack, lambda t: [], 'square', float_types),
    ('qr', small_2d_lapack_skinny, lambda t: [], 'skinny', float_types),
    ('qr', small_2d_lapack_fat, lambda t: [], 'fat', float_types),

]
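# Each tests entry is (method_name, tensor_constructor, arg_constructor
# [, variant_suffix[, allowed_types]]); compare_cpu_gpu below calls the named
# method on a CPU tensor and on its GPU copy and asserts the two results agree
# within the given precision.
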
@@ -275,6 +293,8 @@ for fn in simple_pointwise_float:
    tests.append((fn, small_3d, lambda t: [], None, float_types))

_cycles_per_ms = None


def get_cycles_per_ms():
    """Approximate number of cycles per millisecond for torch.cuda._sleep"""
    global _cycles_per_ms
@@ -288,6 +308,7 @@ def get_cycles_per_ms():
    _cycles_per_ms = 1000000 / start.elapsed_time(end)
    return _cycles_per_ms

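# A hedged reconstruction of the body hidden by the hunk above, inferred only
# from start.elapsed_time(end) and the docstring (the exact code is elided):
#     start = torch.cuda.Event(enable_timing=True)
#     end = torch.cuda.Event(enable_timing=True)
#     start.record()
#     torch.cuda._sleep(1000000)   # spin for ~1e6 device cycles
#     end.record()
#     end.synchronize()
# after which cycles/ms = 1000000 / start.elapsed_time(end), as returned above.
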
def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
    def tmp(self):
        cpu_tensor = tensor_constructor(t)
@@ -314,6 +335,7 @@ def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
        self.assertEqual(cpu_result, gpu_result, precision)
    return tmp

class TestCuda(TestCase):

    def test_autogpu(self):
@@ -412,7 +434,7 @@ class TestCuda(TestCase):
        y_cuda = y.cuda(1)
        result = comm.reduce_add((x_cuda, y_cuda))
        self.assertEqual(result.get_device(), 0)
        self.assertEqual(result.cpu(), x + y)

    def _test_scatter(self, input, chunk_sizes=None, dim=0):
        if torch.cuda.device_count() < 2:
@@ -473,7 +495,7 @@ class TestCuda(TestCase):
        self._test_gather(1)

    def test_from_sequence(self):
        seq = [list(range(i * 4, i * 4 + 4)) for i in range(5)]
        reference = torch.range(0, 19).resize_(5, 4)
        for t in types:
            cuda_type = get_gpu_type(t)
@@ -526,6 +548,7 @@ class TestCuda(TestCase):
    @unittest.skipIf(torch.cuda.device_count() < 2, "detected only one GPU")
    def test_multigpu_serialization_remap(self):
        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]

        def gpu_remap(storage, location):
            if location == 'cuda:1':
                return storage.cuda(0)
@@ -666,7 +689,8 @@ for decl in tests:
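# gpu_remap above is meant as a map_location callback: torch.load calls it with
# (storage, location) and uses the return value in place of the storage, so
# tensors saved on cuda:1 come back on cuda:0. Assumed usage (the call site is
# elided by the hunk): torch.load(buf, map_location=gpu_remap)
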
    if not hasattr(tensor, name_inner):
        continue
    if not hasattr(gpu_tensor, name_inner):
        print("Ignoring {}, because it's not implemented by torch.cuda.{}".format(
            name_inner, gpu_tensor.__class__.__name__))
        continue

    test_name = 'test_' + t.__name__ + '_' + name_inner
@@ -677,4 +701,4 @@ for decl in tests:
    setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))

if __name__ == '__main__':
    run_tests()

@@ -4,7 +4,7 @@ import torch
import traceback
import unittest
from torch.utils.data import Dataset, TensorDataset, DataLoader
from common import TestCase, run_tests
from common_nn import TEST_CUDA

@@ -27,11 +27,12 @@ class TestTensorDataset(TestCase):
        l = torch.randn(15)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i:i + 1], source[i][0])
            self.assertEqual(l[i:i + 1], source[i][1])


class ErrorDataset(Dataset):

    def __init__(self, size):
        self.size = size

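# TensorDataset pairs its two tensors row-wise, which is what the loop above
# checks: item i is (data row i, label row i). A minimal standalone sketch
# (shapes are hypothetical):
data, labels = torch.randn(4, 3), torch.randn(4)
pair = TensorDataset(data, labels)[2]   # row 2 of data and of labels
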
@@ -50,9 +51,9 @@ class TestDataLoader(TestCase):
        batch_size = loader.batch_size
        for i, (sample, target) in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(sample, self.data[idx:idx + batch_size])
            self.assertEqual(target, self.labels[idx:idx + batch_size].view(-1, 1))
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

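# The final assertion above pins down the batch count: with N samples and
# batch size B, the last loop index i equals floor((N - 1) / B); e.g. N = 10,
# B = 3 gives batches of sizes 3, 3, 3, 1, so i runs 0..3 = floor(9 / 3).
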
    def _test_shuffle(self, loader):
        found_data = {i: 0 for i in range(self.data.size(0))}
@@ -67,9 +68,9 @@ class TestDataLoader(TestCase):
                    break
            self.assertEqual(target, self.labels.narrow(0, data_point_idx, 1))
            found_labels[data_point_idx] += 1
        self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
        self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_error(self, loader):
        it = iter(loader)
@@ -81,10 +82,9 @@ class TestDataLoader(TestCase):
                errors += 1
            except StopIteration:
                self.assertEqual(errors,
                                 math.ceil(float(len(loader.dataset)) / loader.batch_size))
                return

    def test_sequential(self):
        self._test_sequential(DataLoader(self.dataset))

@@ -159,4 +159,4 @@ class TestDataLoader(TestCase):

if __name__ == '__main__':
    run_tests()

test/test_distributed.py (new file, 508 lines)
@@ -0,0 +1,508 @@
import fcntl
import multiprocessing
import os
import sys
import time
import unittest
from functools import wraps, reduce
from contextlib import contextmanager

import torch
import torch.distributed as dist
from common import TestCase

BACKEND = os.environ['BACKEND']
TEMP_DIR = os.environ['TEMP_DIR']
MASTER_PORT = '29500'
MASTER_ADDR = '127.0.0.1:' + MASTER_PORT


@contextmanager
def _lock():
    lockfile = os.path.join(TEMP_DIR, 'lockfile')
    with open(lockfile, 'w') as lf:
        try:
            fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
            yield
        finally:
            fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
            lf.close()


def _build_tensor(size, value=None):
    if value is None:
        value = size
    return torch.FloatTensor(size, size, size).fill_(value)


class Barrier(object):
    barrier_id = 0

    @classmethod
    def init(cls):
        cls.barrier_id = 0
        barrier_dir = os.path.join(TEMP_DIR, 'barrier')
        for f_name in os.listdir(barrier_dir):
            os.unlink(os.path.join(barrier_dir, f_name))

    @classmethod
    def sync(cls, timeout=5):
        cls.barrier_id += 1
        barrier_dir = os.path.join(TEMP_DIR, 'barrier')
        pid = str(os.getpid())
        barrier_file = os.path.join(barrier_dir, pid)
        with _lock():
            with open(barrier_file, 'w') as f:
                f.write(str(cls.barrier_id))

        start_time = time.time()
        while True:
            arrived = 0
            with _lock():
                for f_name in os.listdir(barrier_dir):
                    with open(os.path.join(barrier_dir, f_name), 'r') as f:
                        data = f.read()
                        if int(data) >= cls.barrier_id:
                            arrived += 1
            if arrived == dist.get_num_processes():
                break

            if time.time() - start_time > timeout:
                raise RuntimeError("barrier timeout")
            time.sleep(0.1)


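# How the barrier works: each process writes a monotonically increasing
# generation number to a file named after its pid, then polls (under the same
# flock as the writers) until all dist.get_num_processes() files carry a
# generation >= its own. Intended usage in this file:
#     Barrier.init()    # once, to clear stale files
#     ...               # rank-local work
#     Barrier.sync()    # blocks until every process has arrived
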
class _DistTestBase(object):

    def _barrier(self, *args, **kwargs):
        Barrier.sync(*args, **kwargs)

    def _init_group_test(self):
        group = [1, 2]
        group_id = dist.new_group(group)
        rank = dist.get_rank()
        if rank not in group:
            return ([], None, rank)

        return (group, group_id, rank)

    def _init_global_test(self):
        group = [i for i in range(0, dist.get_num_processes())]
        group_id = dist.group.WORLD
        rank = dist.get_rank()
        return (group, group_id, rank)

    # GET RANK
    def test_get_rank(self):
        test_dir = os.path.join(TEMP_DIR, 'test_dir')
        pid = str(os.getpid())
        num_processes = dist.get_num_processes()
        with open(os.path.join(test_dir, pid), 'w') as f:
            f.write(str(dist.get_rank()))

        self._barrier()

        all_ranks = set()
        for f_name in os.listdir(test_dir):
            with open(os.path.join(test_dir, f_name), 'r') as f:
                all_ranks.add(int(f.read()))
        self.assertEqual(len(all_ranks), num_processes)

        self._barrier()

        if dist.get_rank() == 0:
            for f_name in os.listdir(test_dir):
                os.unlink(os.path.join(test_dir, f_name))

        self._barrier()

    # SEND RECV
    def test_send_recv(self):
        rank = dist.get_rank()
        tensor = _build_tensor(rank + 1)
        for dest in range(0, dist.get_num_processes()):
            if dest == rank:
                continue
            dist.send(tensor, dest)

        for src in range(0, dist.get_num_processes()):
            if src == rank:
                continue
            tensor = _build_tensor(src + 1, value=-1)
            expected_tensor = _build_tensor(src + 1)
            dist.recv(tensor, src)
            self.assertEqual(tensor, expected_tensor)

        self._barrier()

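# The trick in test_send_recv above: _build_tensor(rank + 1) encodes the sender
# in both the shape (a (rank+1)^3 cube) and the fill value, so a receiver can
# verify it got the right tensor from the right peer with one assertEqual.
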
    # SEND RECV ANY SOURCE
    def test_send_recv_any_source(self):
        rank = dist.get_rank()
        tensor = _build_tensor(10, rank)
        for dest in range(0, dist.get_num_processes()):
            if dest == rank:
                continue
            dist.send(tensor, dest)

        recv_ranks = set()
        for src in range(0, dist.get_num_processes()):
            if src == rank:
                continue
            tensor = _build_tensor(10, value=-1)
            dist.recv(tensor)
            recv_ranks.add(tensor.resize_(1)[0])

        self.assertEqual(len(recv_ranks), dist.get_num_processes() - 1)
        self._barrier()

    # ISEND
    def test_isend(self):
        rank = dist.get_rank()
        world_size = dist.get_num_processes()

        if rank == 0:
            requests = [
                dist.isend(_build_tensor(dest, 10), dest) for dest in range(1, world_size)
            ]
            for request in requests:
                request.wait()
                self.assertTrue(request.is_completed())
        else:
            tensor = _build_tensor(rank, -1)
            dist.recv(tensor, 0)
            self.assertEqual(tensor, _build_tensor(rank, 10))

        self._barrier()

    # IRECV
    def test_irecv(self):
        rank = dist.get_rank()
        world_size = dist.get_num_processes()

        if rank == 0:
            expected_tensors = [_build_tensor(src, -1) for src in range(1, world_size)]
            requests = [
                dist.irecv(expected_tensors[src - 1], src) for src in range(1, world_size)
            ]

            for src in range(1, world_size):
                requests[src - 1].wait()
                self.assertTrue(requests[src - 1].is_completed())
                self.assertEqual(expected_tensors[src - 1], _build_tensor(src, 10))
        else:
            tensor = _build_tensor(rank, 10)
            dist.send(tensor, 0)

        self._barrier()

    # BROADCAST
    def _test_broadcast_helper(self, group, group_id, rank):
        for src in group:
            expected_tensor = _build_tensor(src + 1)
            if rank == src:
                dist.broadcast(expected_tensor, src, group_id)
            else:
                tensor = _build_tensor(src + 1, -1)
                dist.broadcast(tensor, src, group_id)
                self.assertEqual(tensor, expected_tensor)

        self._barrier()

    def test_broadcast(self):
        group, group_id, rank = self._init_global_test()
        self._test_broadcast_helper(group, group_id, rank)

    def test_broadcast_group(self):
        group, group_id, rank = self._init_group_test()
        self._test_broadcast_helper(group, group_id, rank)

    # REDUCE
    def _test_reduce_helper(self, group, group_id, rank, op, master_value, worker_value, expected_value):
        for src in group:
            if rank == src:
                tensor = _build_tensor(src + 1).fill_(master_value)
                dist.reduce(tensor, src, op, group_id)
                self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
            else:
                tensor = _build_tensor(src + 1).fill_(worker_value)
                dist.reduce(tensor, src, op, group_id)

        self._barrier()

    def test_reduce_sum(self):
        group, group_id, rank = self._init_global_test()
        self._test_reduce_helper(
            group, group_id, rank, dist.reduce_op.SUM, 2, 10, 2 + (10 * (len(group) - 1))
        )

    def test_reduce_product(self):
        group, group_id, rank = self._init_global_test()
        self._test_reduce_helper(
            group, group_id, rank, dist.reduce_op.PRODUCT,
            2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
        )

    def test_reduce_min(self):
        group, group_id, rank = self._init_global_test()
        self._test_reduce_helper(
            group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
        )

    def test_reduce_max(self):
        group, group_id, rank = self._init_global_test()
        self._test_reduce_helper(
            group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10
        )

    def test_reduce_group_sum(self):
        group, group_id, rank = self._init_group_test()
        self._test_reduce_helper(
            group, group_id, rank, dist.reduce_op.SUM, 2, 10, 2 + (10 * (len(group) - 1))
        )

    def test_reduce_group_product(self):
        group, group_id, rank = self._init_group_test()
        self._test_reduce_helper(
            group, group_id, rank, dist.reduce_op.PRODUCT,
            2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
        )

    def test_reduce_group_min(self):
        group, group_id, rank = self._init_group_test()
        self._test_reduce_helper(
            group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
        )

    def test_reduce_group_max(self):
        group, group_id, rank = self._init_group_test()
        self._test_reduce_helper(
            group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10
        )

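# Sanity-checking the expected values used above (standalone sketch): the root
# contributes master_value and each of the other n - 1 ranks contributes
# worker_value, where n is a hypothetical world size.
from functools import reduce as _reduce
n = 4
assert 2 + 10 * (n - 1) == 32                                  # SUM: master 2, workers 10
assert _reduce(lambda x, y: x * y, [10] * (n - 1), 2) == 2000  # PRODUCT: 2 * 10**3
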
    # ALL REDUCE
    def _test_all_reduce_helper(self, group, group_id, rank, op, master_value, worker_value, expected_value):
        for src in group:
            if rank == src:
                tensor = _build_tensor(src + 1).fill_(master_value)
                dist.all_reduce(tensor, op, group_id)
                self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
            else:
                tensor = _build_tensor(src + 1).fill_(worker_value)
                dist.all_reduce(tensor, op, group_id)
                self.assertEqual(tensor, _build_tensor(src + 1, expected_value))

        self._barrier()

    def test_all_reduce_sum(self):
        group, group_id, rank = self._init_global_test()
        self._test_all_reduce_helper(
            group, group_id, rank, dist.reduce_op.SUM, 2, 10, 2 + (10 * (len(group) - 1))
        )

    def test_all_reduce_product(self):
        group, group_id, rank = self._init_global_test()
        self._test_all_reduce_helper(
            group, group_id, rank, dist.reduce_op.PRODUCT,
            2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
        )

    def test_all_reduce_min(self):
        group, group_id, rank = self._init_global_test()
        self._test_all_reduce_helper(
            group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
        )

    def test_all_reduce_max(self):
        group, group_id, rank = self._init_global_test()
        self._test_all_reduce_helper(
            group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10
        )

    def test_all_reduce_group_sum(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_reduce_helper(
            group, group_id, rank, dist.reduce_op.SUM, 2, 10, 2 + (10 * (len(group) - 1))
        )

    def test_all_reduce_group_product(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_reduce_helper(
            group, group_id, rank, dist.reduce_op.PRODUCT,
            2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
        )

    def test_all_reduce_group_min(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_reduce_helper(
            group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
        )

    def test_all_reduce_group_max(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_reduce_helper(
            group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10
        )

    # SCATTER
    def _test_scatter_helper(self, group, group_id, rank):
        for dest in group:
            tensor = _build_tensor(dest + 1, -1)
            expected_tensor = _build_tensor(dest + 1, rank)
            if rank == dest:
                tensors = [_build_tensor(dest + 1, i) for i in group]
                dist.scatter_send(tensors, tensor, group_id)
                self.assertEqual(tensor, expected_tensor)
            else:
                dist.scatter_recv(tensor, dest, group_id)
                self.assertEqual(tensor, expected_tensor)

        self._barrier()

    def test_scatter(self):
        group, group_id, rank = self._init_global_test()
        self._test_scatter_helper(group, group_id, rank)

    def test_scatter_group(self):
        group, group_id, rank = self._init_group_test()
        self._test_scatter_helper(group, group_id, rank)

    # GATHER
    def _test_gather_helper(self, group, group_id, rank):
        for dest in group:
            tensor = _build_tensor(dest + 1, rank)
            if rank == dest:
                tensors = [_build_tensor(dest + 1, -1) for i in group]
                dist.gather_recv(tensors, tensor, group_id)

                expected_tensors = [_build_tensor(dest + 1, i) for i in group]
                for t1, t2 in zip(tensors, expected_tensors):
                    self.assertEqual(t1, t2)
            else:
                dist.gather_send(tensor, dest, group_id)

        self._barrier()

    def test_gather(self):
        group, group_id, rank = self._init_global_test()
        self._test_gather_helper(group, group_id, rank)

    def test_gather_group(self):
        group, group_id, rank = self._init_group_test()
        self._test_gather_helper(group, group_id, rank)
    # ALL GATHER
    def _test_all_gather_helper(self, group, group_id, rank):
        for dest in group:
            tensor = _build_tensor(dest + 1, rank)
            tensors = [_build_tensor(dest + 1, -1) for i in group]
            dist.all_gather(tensors, tensor, group_id)

            expected_tensors = [_build_tensor(dest + 1, i) for i in group]
            for t1, t2 in zip(tensors, expected_tensors):
                self.assertEqual(t1, t2)

        self._barrier()

    def test_all_gather(self):
        group, group_id, rank = self._init_global_test()
        self._test_all_gather_helper(group, group_id, rank)

    def test_all_gather_group(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_gather_helper(group, group_id, rank)

    # BARRIER
    def _test_barrier_helper(self, group, group_id, rank):
        WAIT_TIME = 0.3  # seconds

        for dest in group:
            expected_time = torch.DoubleTensor(1).fill_(0.0)
            if dest == rank:
                expected_time.fill_(time.time() + WAIT_TIME)
                dist.broadcast(expected_time, dest, group_id)
                time.sleep(WAIT_TIME + 0.1)  # sleep a little bit longer
                dist.barrier(group_id)
            else:
                dist.broadcast(expected_time, dest, group_id)
                dist.barrier(group_id)
                self.assertGreaterEqual(time.time(), expected_time[0])

        self._barrier()

    def test_barrier(self):
        group, group_id, rank = self._init_global_test()
        self._test_barrier_helper(group, group_id, rank)

    def test_barrier_group(self):
        group, group_id, rank = self._init_group_test()
        self._test_barrier_helper(group, group_id, rank)
if BACKEND == 'tcp':
    WORLD_SIZE = os.environ['WORLD_SIZE']

    class TestTCP(TestCase, _DistTestBase):

        MANAGER_PROCESS_RANK = -1
        JOIN_TIMEOUT = 5

        @staticmethod
        def manager_join(fn):
            @wraps(fn)
            def wrapper(self):
                if self.rank == self.MANAGER_PROCESS_RANK:
                    self._join_and_reduce()
                else:
                    fn(self)
            return wrapper

        @classmethod
        def setUpClass(cls):
            os.environ['MASTER_ADDR'] = MASTER_ADDR
            os.environ['MASTER_PORT'] = MASTER_PORT
            os.environ['WORLD_SIZE'] = WORLD_SIZE
            for attr in dir(cls):
                if attr.startswith('test'):
                    fn = getattr(cls, attr)
                    setattr(cls, attr, cls.manager_join(fn))

        def setUp(self):
            self.processes = []
            self.rank = self.MANAGER_PROCESS_RANK
            Barrier.init()
            for rank in range(int(WORLD_SIZE)):
                self.processes.append(self._spawn_process(rank))

        def tearDown(self):
            for p in self.processes:
                p.terminate()

        def _spawn_process(self, rank):
            os.environ['RANK'] = str(rank)
            name = 'process ' + str(rank)
            process = multiprocessing.Process(target=self._run, name=name,
                                              args=(rank,))
            process.start()
            return process

        def _run(self, rank):
            self.rank = rank
            dist.init_process_group(backend=BACKEND)
            # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
            # We're retrieving the corresponding test and executing it.
            getattr(self, self.id().split(".")[2])()
            sys.exit(0)

        def _join_and_reduce(self):
            for p in self.processes:
                p.join(self.JOIN_TIMEOUT)
                self.assertEqual(p.exitcode, 0)

elif BACKEND == 'mpi':
    dist.init_process_group(backend='mpi')

    class TestMPI(TestCase, _DistTestBase):
        pass


if __name__ == '__main__':
    unittest.main()
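Note on the harness above: under the TCP backend, TestTCP forks one worker process per rank, each worker re-runs the current test method after joining the process group, and the manager process only joins the workers and asserts their exit codes. A minimal launcher sketch under that convention (BACKEND and WORLD_SIZE are read from the environment earlier in this file, which is not shown here; the concrete values below are placeholder assumptions):

    # Hypothetical launcher for the tests above; the env var values are assumptions.
    import os
    import subprocess

    env = dict(os.environ, BACKEND='tcp', WORLD_SIZE='4')
    subprocess.check_call(['python', 'test/test_distributed.py'], env=env)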
File diff suppressed because it is too large
@@ -11,13 +11,14 @@ import torch.cuda
 import torch.multiprocessing as mp
 from torch.autograd import Variable
 from torch.nn import Parameter
-from common import TestCase
+from common import TestCase, run_tests


+TEST_REPEATS = 30
 HAS_SHM_FILES = os.path.isdir('/dev/shm')
 TEST_CUDA_IPC = torch.cuda.is_available() and \
-                sys.version_info[0] == 3 and \
-                sys.platform != 'darwin'
+    sys.version_info[0] == 3 and \
+    sys.platform != 'darwin'


 def simple_fill(queue, event):
@@ -74,7 +75,7 @@ def autograd_sharing(queue, ready, master_modified):
     master_modified.wait()

     expected_var = torch.range(1, 25).view(5, 5)
-    expected_var[0,0] = 1000
+    expected_var[0, 0] = 1000
     is_ok = var.data.equal(expected_var)
     var.data[:] = torch.ones(5, 5)

@@ -113,7 +114,7 @@ class leak_checker(object):
         # one-off initialization that may use up a file descriptor
         available_fds = self._get_next_fds(10)
         self.test_case.assertLessEqual(
-            available_fds[-1] - self.next_fds[-1], 4)
+            available_fds[-1] - self.next_fds[-1], 5)
         self.test_case.assertFalse(self.has_shm_files())
         return False

@@ -189,7 +190,7 @@ class TestMultiprocessing(TestCase):
     def _test_preserve_sharing(self, ctx=mp, repeat=1):
         def do_test():
             x = torch.randn(5, 5)
-            data = [x.storage(), x.storage()[1:4], x, x[2], x[:,1]]
+            data = [x.storage(), x.storage()[1:4], x, x[2], x[:, 1]]
             q = ctx.Queue()
             q.put(data)
             new_data = q.get()
@@ -229,27 +230,27 @@ class TestMultiprocessing(TestCase):

     @unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
     def test_fd_sharing(self):
-        self._test_sharing(repeat=20)
+        self._test_sharing(repeat=TEST_REPEATS)

     @unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
     def test_fd_preserve_sharing(self):
-        self._test_preserve_sharing(repeat=20)
+        self._test_preserve_sharing(repeat=TEST_REPEATS)

     @unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
     def test_fd_pool(self):
-        self._test_pool(repeat=20)
+        self._test_pool(repeat=TEST_REPEATS)

     def test_fs_sharing(self):
         with fs_sharing():
-            self._test_sharing(repeat=20)
+            self._test_sharing(repeat=TEST_REPEATS)

     def test_fs_preserve_sharing(self):
         with fs_sharing():
-            self._test_preserve_sharing(repeat=20)
+            self._test_preserve_sharing(repeat=TEST_REPEATS)

     def test_fs_pool(self):
         with fs_sharing():
-            self._test_pool(repeat=20)
+            self._test_pool(repeat=TEST_REPEATS)

     @unittest.skipIf(not HAS_SHM_FILES, "don't not how to check if shm files exist")
     def test_fs(self):
@@ -263,11 +264,12 @@ class TestMultiprocessing(TestCase):
             q.get()

         with fs_sharing(), leak_checker(self) as lc:
-            for i in range(20):
+            for i in range(TEST_REPEATS):
                 queue_put()

     def test_inherit_tensor(self):
         class SubProcess(mp.Process):
+
             def __init__(self, tensor):
                 super(SubProcess, self).__init__()
                 self.tensor = tensor
@@ -286,7 +288,6 @@ class TestMultiprocessing(TestCase):
         torch.cuda.FloatTensor([1])  # initialize CUDA outside of leak checker
         self._test_sharing(mp.get_context('spawn'), torch.cuda.FloatTensor)

-
     @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
     def test_cuda_small_tensors(self):
         # Check multiple small tensors which will likely use the same
@@ -359,7 +360,7 @@ class TestMultiprocessing(TestCase):
             queue.put(var)

         ready.wait()
-        var.data[0,0] = 1000
+        var.data[0, 0] = 1000
         if var.grad is not None:
             var.grad.data[:] = torch.ones(5, 5) * 4
         master_modified.set()
@@ -380,8 +381,8 @@ class TestMultiprocessing(TestCase):
         ]
         for requires_grad, volatile in configs:
             var = Variable(torch.range(1, 25).view(5, 5),
-                            requires_grad=requires_grad,
-                            volatile=volatile)
+                           requires_grad=requires_grad,
+                           volatile=volatile)
             self._test_autograd_sharing(var)

     def test_parameter_sharing(self):
@@ -409,4 +410,4 @@ class TestMultiprocessing(TestCase):


 if __name__ == '__main__':
-    unittest.main()
+    run_tests()
@@ -4,7 +4,7 @@ import torch
 import torch.cuda.nccl as nccl
 import torch.cuda

-from common import TestCase
+from common import TestCase, run_tests

 if not torch.cuda.is_available():
     print('CUDA not available, skipping tests')
@@ -87,4 +87,4 @@ class TestNCCL(TestCase):


 if __name__ == '__main__':
-    unittest.main()
+    run_tests()
341  test/test_nn.py
@@ -13,11 +13,14 @@ import torch.nn.parallel as dp
 from torch.autograd import Variable
 from torch.nn import Parameter
 from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
-    module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PRECISION
-from common import freeze_rng_state
+    module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
+    TEST_CUDNN_VERSION, PRECISION
+from common import freeze_rng_state, run_tests


 def default_tensor_type(type):
     type_str = torch.typename(type)

     def decorator(fn):
         @wraps(fn)
         def wrapper(*args, **kwargs):
@@ -30,9 +33,12 @@ def default_tensor_type(type):
         return wrapper
     return decorator

+
 class InputVariableMixin(object):
+
     def _get_input(self):
         input = TestBase._get_input(self)
+
         def map_variables(i):
             if isinstance(i, Variable):
                 return i
@@ -44,6 +50,7 @@ class InputVariableMixin(object):


 class NewModuleTest(InputVariableMixin, ModuleTest):
+
     def __init__(self, *args, **kwargs):
         super(NewModuleTest, self).__init__(*args, **kwargs)
         self.cudnn = kwargs.get('cudnn', False)
@@ -63,10 +70,14 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
             test_case.assertEqual(input._version, input_version)

             input_ip = deepcopy(input)
-            output_ip = module_ip(input_ip)
-            test_case.assertNotEqual(input_ip._version, input_version)
+            input_ip_clone = input_ip.clone()
+            output_ip = module_ip(input_ip_clone)
+            test_case.assertNotEqual(input_ip_clone._version, input_version)
             test_case.assertEqual(output, output_ip)
+            grad = output.data.clone().normal_()
+            output.backward(grad)
+            output_ip.backward(grad)
+            test_case.assertEqual(output.grad, output_ip.grad)

             if type(input.data) == torch.LongTensor and TEST_CUDA:
                 input = input.cuda()
@@ -352,21 +363,21 @@ class TestNN(NNTestCase):

     def _test_dropout(self, cls, input):
         p = 0.2
-        input.fill_(1-p)
+        input.fill_(1 - p)

         module = cls(p)
         input_var = Variable(input, requires_grad=True)
         output = module(input_var)
-        self.assertLess(abs(output.data.mean() - (1-p)), 0.05)
+        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
         output.backward(input)
-        self.assertLess(abs(input_var.grad.data.mean() - (1-p)), 0.05)
+        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

         module = cls(p, True)
         input_var = Variable(input.clone(), requires_grad=True)
         output = module(input_var + 0)
-        self.assertLess(abs(output.data.mean() - (1-p)), 0.05)
+        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
         output.backward(input)
-        self.assertLess(abs(input_var.grad.data.mean() - (1-p)), 0.05)
+        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

         # Check that these don't raise errors
         module.__repr__()
@@ -375,7 +386,9 @@ class TestNN(NNTestCase):
     def test_parameters(self):
         def num_params(module):
             return len(list(module.parameters()))
+
         class Net(nn.Module):
+
             def __init__(self):
                 super(Net, self).__init__()
                 self.l1 = l
@@ -390,6 +403,7 @@ class TestNN(NNTestCase):

     def test_modules(self):
         class Net(nn.Module):
+
             def __init__(self):
                 super(Net, self).__init__()
                 self.l1 = l
@@ -411,6 +425,71 @@ class TestNN(NNTestCase):
         self.assertEqual(n[2], l3)
         self.assertEqual(n[3], l4)

+    def test_ListModule(self):
+        modules = [nn.ReLU(), nn.Linear(5, 5)]
+        module_list = nn.ModuleList(modules)
+
+        def check():
+            self.assertEqual(len(module_list), len(modules))
+            for m1, m2 in zip(modules, module_list):
+                self.assertIs(m1, m2)
+            for m1, m2 in zip(modules, module_list.children()):
+                self.assertIs(m1, m2)
+            for i in range(len(modules)):
+                self.assertIs(module_list[i], modules[i])
+        check()
+        modules += [nn.Conv2d(3, 4, 3)]
+        module_list += [modules[-1]]
+        check()
+        modules.append(nn.Tanh())
+        module_list.append(modules[-1])
+        check()
+        next_modules = [nn.Linear(5, 5), nn.Sigmoid()]
+        modules.extend(next_modules)
+        module_list.extend(next_modules)
+        check()
+        modules[2] = nn.Conv2d(5, 3, 2)
+        module_list[2] = modules[2]
+        check()
+
+        with self.assertRaises(TypeError):
+            module_list += nn.ReLU()
+        with self.assertRaises(TypeError):
+            module_list.extend(nn.ReLU())
+
+    def test_ParameterList(self):
+        make_param = lambda: Parameter(torch.randn(10, 10))
+        parameters = [make_param(), make_param()]
+        param_list = nn.ParameterList(parameters)
+
+        def check():
+            self.assertEqual(len(parameters), len(param_list))
+            for p1, p2 in zip(parameters, param_list):
+                self.assertIs(p1, p2)
+            for p1, p2 in zip(parameters, param_list.parameters()):
+                self.assertIs(p1, p2)
+            for i in range(len(parameters)):
+                self.assertIs(parameters[i], param_list[i])
+        check()
+        parameters += [make_param()]
+        param_list += [parameters[-1]]
+        check()
+        parameters.append(make_param())
+        param_list.append(parameters[-1])
+        check()
+        next_params = [make_param(), make_param()]
+        parameters.extend(next_params)
+        param_list.extend(next_params)
+        check()
+        parameters[2] = make_param()
+        param_list[2] = parameters[2]
+        check()
+
+        with self.assertRaises(TypeError):
+            param_list += make_param()
+        with self.assertRaises(TypeError):
+            param_list.extend(make_param())
+
     def test_add_module(self):
         l = nn.Linear(10, 20)
         net = nn.Module()
@@ -451,6 +530,7 @@ class TestNN(NNTestCase):
     def test_non_leaf_parameters(self):
         l1 = nn.Linear(10, 10)
         l2 = nn.Linear(10, 10)
+
         def assign_weight():
             l2.weight = l1.weight + 2
         self.assertRaises(TypeError, assign_weight)
@@ -458,8 +538,8 @@ class TestNN(NNTestCase):
         l2.weight = Parameter(torch.randn(10, 10))

     def test_embedding_padding_idx(self):
-        embedding = nn.Embedding(10, 20, padding_idx = 0)
-        input = Variable(torch.LongTensor([[0,2,4,5],[4,3,0,9]]))
+        embedding = nn.Embedding(10, 20, padding_idx=0)
+        input = Variable(torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]]))
         output = embedding(input)
         self.assertEqual(output[0][0].sum().data[0], 0)
         self.assertEqual(output[1][2].sum().data[0], 0)
@@ -489,14 +569,14 @@ class TestNN(NNTestCase):
         def expected_indices(dim):
             if dim == 1:
                 return torch.DoubleTensor([1, 3])
-            lower_dim = expected_indices(dim-1)
+            lower_dim = expected_indices(dim - 1)
             lower_dim = lower_dim.view(1, *lower_dim.size())
-            return torch.cat((lower_dim+4, lower_dim+12), 0)
+            return torch.cat((lower_dim + 4, lower_dim + 12), 0)

         def expected_grad(dim):
             if dim == 1:
                 return torch.DoubleTensor([0, 1, 0, 1])
-            lower_dim_grad = expected_grad(dim-1)
+            lower_dim_grad = expected_grad(dim - 1)
             grad = lower_dim_grad.view(1, *lower_dim_grad.size())
             zero = torch.zeros(grad.size())
             return torch.cat((zero, grad, zero, grad), 0)
@@ -667,7 +747,9 @@ class TestNN(NNTestCase):
     def test_data_parallel_nested_output(self):
         def fn(input):
             return [input, (input.sin(), input.cos(), [input.add(1)]), input]
+
         class Net(nn.Module):
+
             def forward(self, input):
                 return fn(input)
         i = Variable(torch.randn(2, 2).float().cuda(1))
@@ -686,7 +768,9 @@ class TestNN(NNTestCase):
     def test_data_parallel_nested_input(self):
         def fn(input):
             return input[1][0]
+
         class Net(nn.Module):
+
             def forward(self, input):
                 return fn(input)
         i = Variable(torch.randn(20, 3).float().cuda(1))
@@ -708,7 +792,7 @@ class TestNN(NNTestCase):
     def test_state_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Module()
-        block.conv=nn.Conv2d(3, 3, 3, bias=False)
+        block.conv = nn.Conv2d(3, 3, 3, bias=False)
         net = nn.Module()
         net.linear1 = l
         net.linear2 = l
@@ -777,6 +861,7 @@ class TestNN(NNTestCase):

     def test_parameter_assignment(self):
         l = nn.Linear(5, 5)
+
         def num_params():
             return len(list(l.parameters()))
         self.assertEqual(num_params(), 2)
@@ -789,7 +874,7 @@ class TestNN(NNTestCase):
         var = Variable(torch.randn(5, 5))
         l.var_name = var
         self.assertEqual(num_params(), 3)
-        self.assertNotIn(var, l.parameters())
+        self.assertNotIn(id(var), map(id, l.parameters()))

         # Make sure Variables are not saved as parameters
         l.variable_attr = Variable(torch.Tensor(5, 5))
@@ -805,6 +890,32 @@ class TestNN(NNTestCase):
         l.param_attr = None
         self.assertEqual(num_params(), 3)

+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    def test_Conv2d_large_workspace(self):
+        # These sizes require huge cuDNN workspaces. Make sure we choose a
+        # reasonable algorithm that does not run out of memory
+        sizes = [
+            (1, 256, 109, 175),
+            (1, 256, 80, 128),
+            (1, 256, 120, 192),
+        ]
+        dtype = torch.cuda.FloatTensor
+
+        def run_test(benchmark):
+            torch.backends.cudnn.benchmark = benchmark
+            conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).type(dtype)
+            for size in sizes:
+                x = torch.randn(size).type(dtype)
+                out = conv(Variable(x, requires_grad=True))
+                out.backward(torch.ones(out.size()).type(dtype))
+
+        b = torch.backends.cudnn.benchmark
+        try:
+            run_test(benchmark=False)
+            run_test(benchmark=True)
+        finally:
+            torch.backends.cudnn.benchmark = b
+
     def test_ConvTranspose2d_output_size(self):
         m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
         i = Variable(torch.randn(2, 3, 6, 6))
@@ -857,7 +968,7 @@ class TestNN(NNTestCase):
         small_t = torch.rand(1, 1, 5, 5)
         for i in range(0, 4, 2):
             for j in range(0, 4, 2):
-                small_t[:,:,i,j] = 100
+                small_t[:, :, i, j] = 100
         output_small, indices_small = m(Variable(small_t))
         for h in range(3, 10):
             for w in range(3, 10):
@@ -870,10 +981,11 @@ class TestNN(NNTestCase):
                     mu(output_small, indices_small, output_size=size)
                 else:
                     self.assertRaises(ValueError, lambda:
-                        mu(output_small, indices_small, (h, w)))
+                                      mu(output_small, indices_small, (h, w)))

     def test_container_copy(self):
         class Model(nn.Module):
+
             def __init__(self):
                 super(Model, self).__init__()
                 self.linear = nn.Linear(4, 5)
@@ -925,11 +1037,22 @@ class TestNN(NNTestCase):
         for i in range(6):
             hx, cx = lstm(input, (hx, cx))

-        (hx+cx).sum().backward()
+        (hx + cx).sum().backward()

-    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
-    @default_tensor_type(torch.FloatTensor)  # FIXME: just until torch.cuda.DoubleTensor.sum() implemented
-    def test_RNN_cpu_vs_cudnn(self):
+    def test_rnn_initial_hidden_state(self):
+        rnn_modes = ['RNN', 'GRU', 'LSTM']
+        for mode in rnn_modes:
+            rnn = getattr(nn, mode)(30, 20, 2)
+            input = Variable(torch.randn(10, 32, 30))
+            hidden = Variable(torch.Tensor(2, 32, 20).zero_())
+            if mode is 'LSTM':
+                hidden = (hidden, hidden)
+            output1, hidden1 = rnn(input, hidden)
+            output2, hidden2 = rnn(input)
+            self.assertEqual(output1, output2)
+            self.assertEqual(hidden1, hidden2)
+
+    def _test_RNN_cpu_vs_cudnn(self, dropout):

         def forward_backward(cuda, rnn, input_val, hx_val, weights_val):
             is_lstm = type(rnn) == nn.LSTM
@@ -957,9 +1080,9 @@ class TestNN(NNTestCase):
                 output, hy = rnn(input, hx)
                 # FIXME this is because of a pytorch bug
                 if is_lstm:
-                    fake_loss = 0*(hy[0] + hy[1]).sum()
+                    fake_loss = 0 * (hy[0] + hy[1]).sum()
                 else:
-                    fake_loss = 0*hy.sum()
+                    fake_loss = 0 * hy.sum()

             loss = output.sum() + fake_loss
             loss.backward()
@@ -989,42 +1112,40 @@ class TestNN(NNTestCase):
             for (cpu_weight, gpu_weight) in zip(cpu_layer_weight, gpu_layer_weight):
                 self.assertEqual(cpu_weight.grad.data, gpu_weight.grad.data, prec=5e-5)

-
         for module in (nn.RNN, nn.LSTM, nn.GRU):
             for bias in (True, False):
                 for bidirectional in (False, True):
-                    for dropout in (0, 1):  # Because of dropout randomness, can only compare 0 and 1
-                        for batch_first in (False, True):
-                            num_directions = 2 if bidirectional else 1
-                            if batch_first:
-                                input_val = torch.randn(batch, seq_length, input_size)
-                            else:
-                                input_val = torch.randn(seq_length, batch, input_size)
-                            hx_val = torch.randn(num_layers * num_directions, batch, hidden_size)
+                    for batch_first in (False, True):
+                        num_directions = 2 if bidirectional else 1
+                        if batch_first:
+                            input_val = torch.randn(batch, seq_length, input_size)
+                        else:
+                            input_val = torch.randn(seq_length, batch, input_size)
+                        hx_val = torch.randn(num_layers * num_directions, batch, hidden_size)

-                            rnn = module(input_size,
-                                         hidden_size,
-                                         num_layers,
-                                         bias=bias,
-                                         dropout=dropout,
-                                         bidirectional=bidirectional,
-                                         batch_first=batch_first)
+                        rnn = module(input_size,
+                                     hidden_size,
+                                     num_layers,
+                                     bias=bias,
+                                     dropout=dropout,
+                                     bidirectional=bidirectional,
+                                     batch_first=batch_first)

-                            outputs_cpu = forward_backward(
-                                False, rnn, input_val, hx_val, rnn.all_weights)
+                        outputs_cpu = forward_backward(
+                            False, rnn, input_val, hx_val, rnn.all_weights)

-                            rnn_gpu = module(input_size,
-                                             hidden_size,
-                                             num_layers,
-                                             bias=bias,
-                                             dropout=dropout,
-                                             bidirectional=bidirectional,
-                                             batch_first = batch_first)
+                        rnn_gpu = module(input_size,
+                                         hidden_size,
+                                         num_layers,
+                                         bias=bias,
+                                         dropout=dropout,
+                                         bidirectional=bidirectional,
+                                         batch_first=batch_first)

-                            outputs_gpu = forward_backward(
-                                True, rnn_gpu, input_val, hx_val, rnn.all_weights)
+                        outputs_gpu = forward_backward(
+                            True, rnn_gpu, input_val, hx_val, rnn.all_weights)

-                            compare_cpu_gpu(outputs_cpu, outputs_gpu)
+                        compare_cpu_gpu(outputs_cpu, outputs_gpu)

         for nonlinearity in ('tanh', 'relu'):
             hx_val = torch.randn(num_layers, batch, hidden_size)
@@ -1039,6 +1160,17 @@ class TestNN(NNTestCase):
             compare_cpu_gpu(outputs_cpu, outputs_gpu)

+    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
+    @default_tensor_type(torch.FloatTensor)  # FIXME: just until torch.cuda.DoubleTensor.sum() implemented
+    def test_RNN_cpu_vs_cudnn_no_dropout(self):
+        self._test_RNN_cpu_vs_cudnn(0)
+
+    @unittest.skipIf(not (TEST_CUDNN and TEST_CUDNN_VERSION >= 5103), "needs cudnn >= 5.1")
+    @default_tensor_type(torch.FloatTensor)  # FIXME: just until torch.cuda.DoubleTensor.sum() implemented
+    def test_RNN_cpu_vs_cudnn_with_dropout(self):
+        # Because of dropout randomness, can only compare dropout=0 and dropout=1
+        self._test_RNN_cpu_vs_cudnn(1)
+
     @unittest.skipIf(not (TEST_CUDNN and TEST_CUDNN_VERSION >= 5103), "needs cudnn >= 5.1")
     def test_RNN_dropout(self):
         # checking the assumption that cuDNN sticks dropout in between
         # RNN layers
@@ -1057,8 +1189,8 @@ class TestNN(NNTestCase):
         rnn.weight_hh_l0.data.fill_(1)
         rnn.weight_ih_l1.data.fill_(1)
         rnn.weight_hh_l1.data.fill_(1)
-        input = Variable(torch.Tensor(1,1,10).fill_(1))
-        hx = Variable(torch.Tensor(2,1,1000).fill_(0))
+        input = Variable(torch.Tensor(1, 1, 10).fill_(1))
+        hx = Variable(torch.Tensor(2, 1, 1000).fill_(0))
         if cuda:
             input = input.cuda()
             hx = hx.cuda()
@@ -1081,7 +1213,7 @@ class TestNN(NNTestCase):
         self.assertEqual(hy.data[0][0][0], 10)
         self.assertEqual(hy.data[1][0][0], output_val)

-    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
+    @unittest.skipIf(not (TEST_CUDNN and TEST_CUDNN_VERSION >= 5103), "needs cudnn >= 5.1")
     def test_RNN_dropout_state(self):
         import sys
         if sys.version_info[0] == 2:
@@ -1099,8 +1231,8 @@ class TestNN(NNTestCase):
             rnn.train()
         else:
             rnn.eval()
-        input = Variable(torch.Tensor(1,1,100).uniform_())
-        hx = Variable(torch.Tensor(2,1,100).uniform_())
+        input = Variable(torch.Tensor(1, 1, 100).uniform_())
+        hx = Variable(torch.Tensor(2, 1, 100).uniform_())
         if cuda:
             input = input.cuda()
             hx = hx.cuda()
@@ -1133,6 +1265,15 @@ class TestNN(NNTestCase):
                 (c * upscale_factor ** 2)
             self.assertEqual(output[:, c, h, w], input[:, channel_idx, height_idx, weight_idx])

+    def test_inplace_thnn(self):
+        r = nn.ReLU(True)
+        input = Variable(torch.randn(5, 5), requires_grad=True)
+        output = r(input + 0)
+        grad_output = torch.randn(5, 5)
+        grad_output_clone = grad_output.clone()
+        output.backward(grad_output)
+        self.assertEqual(grad_output, grad_output_clone)
+
     def test_pixel_shuffle(self):
         batch_size = random.randint(1, 3)
         upscale_factor = random.randint(2, 5)
@@ -1147,6 +1288,32 @@ class TestNN(NNTestCase):
         output.backward(output.data)
         self.assertEqual(input.data, input.grad.data)

+    def test_batchnorm_eval(self):
+        types = (torch.FloatTensor,)
+        if TEST_CUDA:
+            types += (torch.cuda.FloatTensor,)
+        for tp in types:
+            module = nn.BatchNorm1d(3).type(tp)
+            module.eval()
+
+            data = Variable(torch.rand(4, 3).type(tp), requires_grad=True)
+            grad = torch.rand(4, 3).type(tp)
+
+            # 1st pass
+            res1 = module(data)
+            res1.backward(grad)
+            grad1 = data.grad.data.clone()
+
+            # 2nd pass
+            data.grad.data.zero_()
+
+            res2 = module(data)
+            res2.backward(grad)
+            grad2 = data.grad.data.clone()
+            self.assertEqual(res1, res2)
+            self.assertEqual(grad1, grad2)
+

 def add_test(test):
     test_name = test.get_name()
     cuda_test_name = test_name + '_cuda'
@@ -1154,8 +1321,8 @@ def add_test(test):
         raise RuntimeError('Found two tests with the same name: ' + test_name)
     if hasattr(TestNN, cuda_test_name):
         raise RuntimeError('Found two tests with the same name: ' + cuda_test_name)
-    setattr(TestNN, test_name, lambda self,test=test: test(self))
-    setattr(TestNN, cuda_test_name, lambda self,test=test: test.test_cuda(self))
+    setattr(TestNN, test_name, lambda self, test=test: test(self))
+    setattr(TestNN, cuda_test_name, lambda self, test=test: test.test_cuda(self))


 new_module_tests = [
@@ -1308,6 +1475,11 @@ new_module_tests = [
         input_size=(2, 4, 6, 5),
         cudnn=True,
     ),
+    dict(
+        fullname='Conv2d_groups_thnn',
+        constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
+        input_size=(2, 4, 6, 5),
+    ),
     dict(
         module_name='ConvTranspose2d',
         constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
@@ -1460,20 +1632,26 @@ new_module_tests = [
     dict(
         module_name='Embedding',
         constructor_args=(4, 3),
-        input=Variable(
-            torch.randperm(2).repeat(1, 2),
-            requires_grad=False
-        ),
+        input=Variable(torch.randperm(2).repeat(1, 2)),
         jacobian_input=False
     ),
     dict(
-        constructor=lambda: nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
+        constructor=lambda: nn.Embedding(4, 3, sparse=True),
+        input=Variable(torch.randperm(2).repeat(1, 2)),
+        jacobian_input=False,
+        fullname='Embedding_sparse',
+        test_cuda=False,
+    ),
+    dict(
+        constructor=lambda: nn.FractionalMaxPool2d(
+            2, output_ratio=0.5, _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
         input_size=(1, 3, 5, 5),
         fullname='FractionalMaxPool2d_ratio',
         test_cuda=False
     ),
     dict(
-        constructor=lambda: nn.FractionalMaxPool2d((2, 2), output_size=(4, 4), _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
+        constructor=lambda: nn.FractionalMaxPool2d((2, 2), output_size=(
+            4, 4), _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
         input_size=(1, 3, 7, 7),
         fullname='FractionalMaxPool2d_size',
         test_cuda=False
@@ -1483,6 +1661,40 @@ new_module_tests = [
         constructor_args=(3,),
         input_size=(1, 9, 4, 4),
     ),
+    dict(
+        module_name='UpsamplingNearest2d',
+        constructor_args=(12,),
+        input_size=(1, 2, 4, 4),
+    ),
+    dict(
+        module_name='UpsamplingNearest2d',
+        constructor_args=((12, 16)),
+        input_size=(1, 2, 3, 4),
+        desc='tuple'
+    ),
+    dict(
+        module_name='UpsamplingNearest2d',
+        constructor_args=(None, 4),
+        input_size=(1, 2, 4, 4),
+        desc='scale'
+    ),
+    dict(
+        module_name='UpsamplingBilinear2d',
+        constructor_args=(12,),
+        input_size=(1, 2, 4, 4),
+    ),
+    dict(
+        module_name='UpsamplingBilinear2d',
+        constructor_args=((4, 6)),
+        input_size=(1, 2, 2, 3),
+        desc='tuple'
+    ),
+    dict(
+        module_name='UpsamplingBilinear2d',
+        constructor_args=(None, 4),
+        input_size=(1, 2, 4, 4),
+        desc='scale'
+    ),
 ]


@@ -1501,6 +1713,7 @@ for test_params in criterion_tests:


 class UnpoolingNet(nn.Module):
+
     def __init__(self, pool, unpool):
         super(UnpoolingNet, self).__init__()
         self.pool = pool
@@ -1531,4 +1744,4 @@ add_test(NewModuleTest(


 if __name__ == '__main__':
-    unittest.main()
+    run_tests()
@@ -1,10 +1,12 @@
 import unittest
+import functools
+from copy import deepcopy
 import torch
 import torch.optim as optim
 import torch.legacy.optim as old_optim
 from torch.autograd import Variable

-from common import TestCase
+from common import TestCase, run_tests


 def rosenbrock(tensor):
@@ -14,7 +16,7 @@ def rosenbrock(tensor):

 def drosenbrock(tensor):
     x, y = tensor
-    return torch.DoubleTensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
+    return torch.DoubleTensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2)))


 def wrap_old_fn(old_fn, **config):
@@ -36,15 +38,22 @@ class TestOptim(TestCase):
         initial_dist = params.data.dist(solution)

         def eval():
+            optimizer.zero_grad()
             loss = rosenbrock(params)
             loss.backward()
+            # loss.backward() will give **slightly** different
+            # gradients, than drosenbrock, because of a different ordering
+            # of floating point operations. In most cases it doesn't matter,
+            # but some optimizers are so sensitive that they can temporarily
+            # diverge up to 1e-4, just to converge again. This makes the
+            # comparison more stable.
+            params.grad.data.copy_(drosenbrock(params.data))
             return loss

         for i in range(2000):
-            optimizer.zero_grad()
             optimizer.step(eval)
             old_fn(lambda _: (rosenbrock(params_t), drosenbrock(params_t)),
-                    params_t, state)
+                   params_t, state)
             self.assertEqual(params.data, params_t)

         self.assertLessEqual(params.data.dist(solution), initial_dist)
@@ -52,25 +61,65 @@ class TestOptim(TestCase):
     def _test_basic_cases_template(self, weight, bias, input, constructor):
         weight = Variable(weight, requires_grad=True)
         bias = Variable(bias, requires_grad=True)
-        input = Variable(input, requires_grad=False)
+        input = Variable(input)
         optimizer = constructor(weight, bias)

         def fn():
+            optimizer.zero_grad()
             y = weight.mv(input)
             if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
                 y = y.cuda(bias.get_device())
-            return (y + bias).abs().sum()
+            loss = (y + bias).pow(2).sum()
+            loss.backward()
+            return loss

         initial_value = fn().data[0]
         for i in range(200):
-            weight.grad.data.zero_()
-            bias.grad.data.zero_()
-            fn().backward()
-            optimizer.step()
+            optimizer.step(fn)
+        self.assertLess(fn().data[0], initial_value)

-        self.assertLessEqual(fn().data[0], initial_value)
+    def _test_state_dict(self, weight, bias, input, constructor):
+        weight = Variable(weight, requires_grad=True)
+        bias = Variable(bias, requires_grad=True)
+        input = Variable(input)

-    def _test_basic_cases(self, constructor):
+        def fn_base(optimizer, weight, bias):
+            optimizer.zero_grad()
+            loss = (weight.mv(input) + bias).pow(2).sum()
+            loss.backward()
+            return loss
+
+        optimizer = constructor(weight, bias)
+        fn = functools.partial(fn_base, optimizer, weight, bias)
+
+        # Prime the optimizer
+        for i in range(20):
+            optimizer.step(fn)
+        # Clone the weights and construct new optimizer for them
+        weight_c = Variable(weight.data.clone(), requires_grad=True)
+        bias_c = Variable(bias.data.clone(), requires_grad=True)
+        optimizer_c = constructor(weight_c, bias_c)
+        fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
+        # Load state dict
+        state_dict = deepcopy(optimizer.state_dict())
+        state_dict_c = deepcopy(optimizer.state_dict())
+        optimizer_c.load_state_dict(state_dict_c)
+        # Run both optimizations in parallel
+        for i in range(20):
+            optimizer.step(fn)
+            optimizer_c.step(fn_c)
+            self.assertEqual(weight, weight_c)
+            self.assertEqual(bias, bias_c)
+        # Make sure state dict wasn't modified
+        self.assertEqual(state_dict, state_dict_c)
+
+    def _test_basic_cases(self, constructor, ignore_multidevice=False):
+        self._test_state_dict(
+            torch.randn(10, 5),
+            torch.randn(10),
+            torch.randn(5),
+            constructor
+        )
         self._test_basic_cases_template(
             torch.randn(10, 5),
             torch.randn(10),
@@ -79,8 +128,8 @@ class TestOptim(TestCase):
         )
         # non-contiguous parameters
         self._test_basic_cases_template(
-            torch.randn(10, 5, 2)[...,0],
-            torch.randn(10, 2)[...,0],
+            torch.randn(10, 5, 2)[..., 0],
+            torch.randn(10, 2)[..., 0],
             torch.randn(5),
             constructor
         )
@@ -94,12 +143,12 @@ class TestOptim(TestCase):
             constructor
         )
         # Multi-GPU
-        if not torch.cuda.device_count() > 1:
+        if not torch.cuda.device_count() > 1 or ignore_multidevice:
            return
         self._test_basic_cases_template(
-            torch.randn(10, 5).cuda(),
-            torch.randn(10).cuda(),
-            torch.randn(5).cuda(),
+            torch.randn(10, 5).cuda(0),
+            torch.randn(10).cuda(1),
+            torch.randn(5).cuda(0),
             constructor
         )
@@ -275,10 +324,24 @@ class TestOptim(TestCase):
                        lr=1e-3)
         )

+    def test_lbfgs(self):
+        self._test_rosenbrock(
+            lambda params: optim.LBFGS(params),
+            wrap_old_fn(old_optim.lbfgs)
+        )
+        self._test_rosenbrock(
+            lambda params: optim.LBFGS(params, lr=5e-2, max_iter=5),
+            wrap_old_fn(old_optim.lbfgs, learningRate=5e-2, maxIter=5)
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.LBFGS([weight, bias]),
+            ignore_multidevice=True
+        )
+
     def test_invalid_param_type(self):
         with self.assertRaises(TypeError):
             optim.SGD(Variable(torch.randn(5, 5)), lr=3)


 if __name__ == '__main__':
-    unittest.main()
+    run_tests()
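The _test_state_dict helper above pins down the round-trip contract of optimizer state: deep-copying one optimizer's state_dict() into a freshly constructed optimizer over cloned parameters must leave both optimizers stepping identically, without mutating the saved dict. A standalone sketch of the same contract (the SGD hyperparameters and tensor shapes here are arbitrary placeholders):

    # Minimal state_dict round trip, mirroring the test above; values are arbitrary.
    import torch
    import torch.optim as optim
    from torch.autograd import Variable

    w = Variable(torch.randn(3, 3), requires_grad=True)
    opt_a = optim.SGD([w], lr=0.1, momentum=0.9)
    w.sum().backward()
    opt_a.step()  # opt_a now carries momentum buffers

    w_c = Variable(w.data.clone(), requires_grad=True)
    opt_b = optim.SGD([w_c], lr=0.1, momentum=0.9)
    opt_b.load_state_dict(opt_a.state_dict())  # transplant the state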
@@ -4,13 +4,14 @@ from torch import sparse
 import itertools
 import random
 import unittest
-from common import TestCase
+from common import TestCase, run_tests
 from numbers import Number

 SparseTensor = sparse.DoubleTensor


 class TestSparse(TestCase):
+
     @staticmethod
     def _gen_sparse(d, nnz, with_size):
         v = torch.randn(nnz)
@@ -19,7 +20,7 @@ class TestSparse(TestCase):
             x = SparseTensor(i, v)
         else:
             i = torch.rand(d, nnz) * \
-                    torch.Tensor(with_size).repeat(nnz, 1).transpose(0, 1)
+                torch.Tensor(with_size).repeat(nnz, 1).transpose(0, 1)
             i = i.type(torch.LongTensor)
             x = SparseTensor(i, v, torch.Size(with_size))

@@ -74,13 +75,13 @@ class TestSparse(TestCase):

     def test_contig(self):
         i = torch.LongTensor([
-           [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
+            [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
             [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
         ])
         v = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
         x = SparseTensor(i, v, torch.Size([100, 100]))
         exp_i = torch.LongTensor([
-           [0, 1, 6, 14, 27, 35, 39, 40, 66, 71],
+            [0, 1, 6, 14, 27, 35, 39, 40, 66, 71],
             [31, 92, 65, 50, 34, 62, 22, 56, 74, 89],
         ])
         exp_v = torch.Tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7])
@@ -216,5 +217,4 @@ class TestSparse(TestCase):


 if __name__ == '__main__':
-    unittest.main()
-
+    run_tests()
File diff suppressed because it is too large
@@ -19,7 +19,7 @@ from torch.utils.serialization import load_lua

 HAS_CUDA = torch.cuda.is_available()

-from common import TestCase
+from common import TestCase, run_tests

 try:
     import cffi
@@ -28,7 +28,9 @@ try:
 except ImportError:
     HAS_CFFI = False

+
 class SimplePlugin(Plugin):
+
     def __init__(self, interval):
         super(SimplePlugin, self).__init__(interval)
         self.trainer = None
@@ -58,6 +60,7 @@ class SimplePlugin(Plugin):


 class ModelMock(object):
+
     def __init__(self):
         self.num_calls = 0
         self.output = Variable(torch.ones(1, 1), requires_grad=True)
@@ -68,6 +71,7 @@ class ModelMock(object):


 class CriterionMock(object):
+
     def __init__(self):
         self.num_calls = 0

@@ -95,6 +99,7 @@ class OptimizerMock(object):


 class DatasetMock(object):
+
     def __iter__(self):
         for i in range(10):
             yield torch.randn(2, 10), torch.randperm(10)[:2]
@@ -183,6 +188,7 @@ class TestTrainer(TestCase):

 test_dir = os.path.abspath(os.path.dirname(str(__file__)))


 class TestFFI(TestCase):
+
     def setUp(self):
@@ -196,13 +202,13 @@ class TestFFI(TestCase):
     @unittest.skipIf(not HAS_CFFI, "ffi tests require cffi package")
     def test_cpu(self):
         compile_extension(
-                name='test_extensions.cpulib',
-                header=test_dir + '/ffi/src/cpu/lib.h',
-                sources=[
-                    test_dir + '/ffi/src/cpu/lib1.c',
-                    test_dir + '/ffi/src/cpu/lib2.c',
-                ],
-                verbose=False,
+            name='test_extensions.cpulib',
+            header=test_dir + '/ffi/src/cpu/lib.h',
+            sources=[
+                test_dir + '/ffi/src/cpu/lib1.c',
+                test_dir + '/ffi/src/cpu/lib2.c',
+            ],
+            verbose=False,
         )
         from test_extensions import cpulib
         tensor = torch.ones(2, 2).float()
@@ -217,20 +223,20 @@ class TestFFI(TestCase):
         self.assertIs(type(f), float)

         self.assertRaises(TypeError,
-                lambda: cpulib.good_func(tensor.double(), 2, 1.5))
+                          lambda: cpulib.good_func(tensor.double(), 2, 1.5))
         self.assertRaises(torch.FatalError,
-                lambda: cpulib.bad_func(tensor, 2, 1.5))
+                          lambda: cpulib.bad_func(tensor, 2, 1.5))

     @unittest.skipIf(not HAS_CFFI or not HAS_CUDA, "ffi tests require cffi package")
     def test_gpu(self):
         compile_extension(
-                name='gpulib',
-                header=test_dir + '/ffi/src/cuda/cudalib.h',
-                sources=[
-                    test_dir + '/ffi/src/cuda/cudalib.c',
-                ],
-                with_cuda=True,
-                verbose=False,
+            name='gpulib',
+            header=test_dir + '/ffi/src/cuda/cudalib.h',
+            sources=[
+                test_dir + '/ffi/src/cuda/cudalib.c',
+            ],
+            with_cuda=True,
+            verbose=False,
         )
         import gpulib
         tensor = torch.ones(2, 2).float()
@@ -243,9 +249,9 @@ class TestFFI(TestCase):
         self.assertEqual(ctensor, torch.ones(2, 2) * 2 + 1.5)

         self.assertRaises(TypeError,
-                lambda: gpulib.cuda_func(tensor, 2, 1.5))
+                          lambda: gpulib.cuda_func(tensor, 2, 1.5))
         self.assertRaises(TypeError,
-                lambda: gpulib.cuda_func(ctensor.storage(), 2, 1.5))
+                          lambda: gpulib.cuda_func(ctensor.storage(), 2, 1.5))


 class TestLuaReader(TestCase):
@@ -320,7 +326,7 @@ class TestLuaReader(TestCase):
             cls._download_data(test_file_path)
         except urllib.URLError as e:
             warnings.warn(("Couldn't download the test file for TestLuaReader! "
-                    "Tests will be incomplete!"), RuntimeWarning)
+                           "Tests will be incomplete!"), RuntimeWarning)
             return

         tests = load_lua(test_file_path)
@@ -364,4 +370,4 @@ class TestLuaReader(TestCase):

 TestLuaReader.init()
 if __name__ == '__main__':
-    unittest.main()
+    run_tests()
@@ -20,13 +20,14 @@ class cwrap(object):
 """)

     OPTION_CODE_TEMPLATE = [
-      '$call',
-      '$return_result',
+        '$call',
+        '$return_result',
     ]

     FUNCTION_CALL_TEMPLATE = Template("$capture_result$cname($arg_unpack);")

-    DEFAULT_PLUGIN_CLASSES = [ArgcountChecker, ConstantArguments, OptionalArguments, ArgumentReferences, BeforeAfterCall, ReturnArguments, GILRelease]
+    DEFAULT_PLUGIN_CLASSES = [ArgcountChecker, ConstantArguments, OptionalArguments,
+                              ArgumentReferences, BeforeAfterCall, ReturnArguments, GILRelease]

     def __init__(self, source, destination=None, plugins=[], default_plugins=True):
         if destination is None:
@@ -87,7 +88,7 @@ class cwrap(object):
             with open(fname, 'r') as f:
                 included = f.read().split('\n')
             # insert it into lines at position i+1
-            lines[i+1:i+1] = included
+            lines[i + 1:i + 1] = included
         else:
             output.append(line)
         i += 1
@@ -97,10 +98,10 @@ class cwrap(object):
     def set_declaration_defaults(self, declaration):
         declaration.setdefault('arguments', [])
         declaration.setdefault('return', 'void')
-        if not 'cname' in declaration:
+        if 'cname' not in declaration:
             declaration['cname'] = declaration['name']
         # Simulate multiple dispatch, even if it's not necessary
-        if not 'options' in declaration:
+        if 'options' not in declaration:
             declaration['options'] = [{'arguments': declaration['arguments']}]
             del declaration['arguments']
         # Parse arguments (some of them can be strings)
@@ -136,10 +137,10 @@ class cwrap(object):
             return fallback(*args)

     def get_type_check(self, arg, option):
-        return self.search_plugins('get_type_check', (arg, option), lambda arg,_: None)
+        return self.search_plugins('get_type_check', (arg, option), lambda arg, _: None)

     def get_type_unpack(self, arg, option):
-        return self.search_plugins('get_type_unpack', (arg, option), lambda arg,_: None)
+        return self.search_plugins('get_type_unpack', (arg, option), lambda arg, _: None)

     def get_return_wrapper(self, option):
         return self.search_plugins('get_return_wrapper', (option,), lambda _: self.RETURN_WRAPPERS[option['return']])
@@ -182,7 +183,7 @@ class cwrap(object):

     def generate_option(self, option, is_first):
         checked_args = list(filter(
-            lambda arg: not 'ignore_check' in arg or not arg['ignore_check'],
+            lambda arg: 'ignore_check' not in arg or not arg['ignore_check'],
             option['arguments']))
         option['num_checked_args'] = len(checked_args)
         idx_args = list(filter(
@@ -193,14 +194,14 @@ class cwrap(object):

         # Generate checks
         arg_checks = self.map_selected_arguments('get_type_check',
-                'process_single_check', option, checked_args)
+                                                 'process_single_check', option, checked_args)
         arg_checks = ' &&\n '.join(arg_checks)
         for plugin in self.plugins:
             arg_checks = plugin.process_all_checks(arg_checks, option)

         # Generate unpacks
         arg_unpack = self.map_selected_arguments('get_type_unpack',
-                'process_single_unpack', option, option['arguments'])
+                                                 'process_single_unpack', option, option['arguments'])
         arg_unpack = ', '.join(arg_unpack)
         for plugin in self.plugins:
             arg_unpack = plugin.process_all_unpacks(arg_unpack, option)
@@ -209,16 +210,16 @@ class cwrap(object):
         try:
             return_result = self.get_return_wrapper(option).substitute()
             call = self.FUNCTION_CALL_TEMPLATE.substitute(capture_result='',
-                    cname=option['cname'], arg_unpack=arg_unpack)
+                                                          cname=option['cname'], arg_unpack=arg_unpack)
         except KeyError:
             return_result = self.get_return_wrapper(option).substitute(result='__result')
             call = self.FUNCTION_CALL_TEMPLATE.substitute(capture_result=(option['return'] + ' __result = '),
-                    cname=option['cname'], arg_unpack=arg_unpack)
+                                                          cname=option['cname'], arg_unpack=arg_unpack)

         code_template = deepcopy(self.OPTION_CODE_TEMPLATE)
         for plugin in self.plugins:
             code_template = plugin.process_option_code_template(code_template,
-                    option)
+                                                                option)
         code_template = Template('\n'.join(code_template))
         code = code_template.substitute(call=call, return_result=return_result)
         code_lines = map(lambda s: s.strip(), code.split('\n'))
@@ -228,6 +229,8 @@ class cwrap(object):
             depth -= line.count('}') * 2
             code += ' ' * depth + line + '\n'
             depth += line.count('{') * 2
+            depth += line.count('(') * 4
+            depth -= line.count(')') * 4

         # Put everything together
         return self.OPTION_TEMPLATE.substitute(
@ -1,5 +1,6 @@
|
||||
from . import CWrapPlugin
|
||||
|
||||
|
||||
class ArgcountChecker(CWrapPlugin):
|
||||
|
||||
def process_all_checks(self, checks, option):
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from . import CWrapPlugin
|
||||
|
||||
|
||||
class ArgcountSortPlugin(CWrapPlugin):
|
||||
|
||||
def __init__(self, descending=True):
|
||||
@ -11,4 +12,3 @@ class ArgcountSortPlugin(CWrapPlugin):
|
||||
for declaration in declarations:
|
||||
declaration['options'].sort(key=num_checked_args, reverse=self.descending)
|
||||
return declarations
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from . import CWrapPlugin
|
||||
from string import Template
|
||||
|
||||
|
||||
class ArgumentReferences(CWrapPlugin):
|
||||
|
||||
def initialize(self, cwrap):
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from . import CWrapPlugin
|
||||
|
||||
|
||||
class AutoGPU(CWrapPlugin):
|
||||
|
||||
def __init__(self, has_self=True, condition=None):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from . import CWrapPlugin
|
||||
from string import Template
|
||||
|
||||
|
||||
class BeforeAfterCall(CWrapPlugin):
|
||||
|
||||
def initialize(self, cwrap):
|
||||
@ -13,7 +14,7 @@ class BeforeAfterCall(CWrapPlugin):
|
||||
if '$' in prepend_str:
|
||||
before_call_template = Template(option[name])
|
||||
args = {'arg' + str(i): self.cwrap.get_arg_accessor(arg, option) for i, arg
|
||||
in enumerate(option['arguments'])}
|
||||
in enumerate(option['arguments'])}
|
||||
prepend_str = before_call_template.substitute(args)
|
||||
template.insert(offset, prepend_str)
|
||||
|
||||
@ -23,5 +24,5 @@ class BeforeAfterCall(CWrapPlugin):
|
||||
self.insert_snippet(template, option, call_idx, 'before_call')
|
||||
# call position might have changed
|
||||
call_idx = template.index('$call')
|
||||
self.insert_snippet(template, option, call_idx+1, 'after_call')
|
||||
self.insert_snippet(template, option, call_idx + 1, 'after_call')
|
||||
return template
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from . import CWrapPlugin
|
||||
from string import Template
|
||||
|
||||
|
||||
class BoolOption(CWrapPlugin):
|
||||
|
||||
UNPACK_TEMPLATE = Template('$arg == Py_True ? $if_true : $if_false')
|
||||
@ -16,4 +17,3 @@ class BoolOption(CWrapPlugin):
|
||||
if self.is_bool_option(arg):
|
||||
return Template(self.UNPACK_TEMPLATE.safe_substitute(
|
||||
if_true=arg['if_true'], if_false=arg['if_false']))
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from . import CWrapPlugin
|
||||
from string import Template
|
||||
|
||||
|
||||
class ConstantArguments(CWrapPlugin):
|
||||
|
||||
def process_declarations(self, declarations):
|
||||
@ -18,5 +19,3 @@ class ConstantArguments(CWrapPlugin):
|
||||
def get_arg_accessor(self, arg, option):
|
||||
if arg['type'] == 'CONSTANT':
|
||||
return arg['name']
|
||||
|
||||
|
||||
|
||||
@@ -3,30 +3,31 @@ from copy import deepcopy
from . import CWrapPlugin
from itertools import product


class CuDNNPlugin(CWrapPlugin):

    TYPE_UNPACK = {
        'THTensor*': Template('((THPVoidTensor*)$arg)->cdata'),
        'int': Template('THPUtils_unpackLong($arg)'),
        'std::vector<int>': Template('THPUtils_unpackIntTuple($arg)'),
        'cudnnDataType_t': Template('$arg'),
        'cudnnHandle_t': Template('$arg'),
        'Convolution*': Template('(Convolution*)THPWrapper_get($arg)'),
        'bool': Template('$arg == Py_True'),
        'double': Template('THPDoubleUtils_unpackReal($arg)'),
    }

    TYPE_CHECK = {
        'Convolution*': Template('THPWrapper_check($arg)'),
        'THTensor*': Template('(PyObject*)Py_TYPE($arg) == tensorClass'),
        'int': Template('THPUtils_checkLong($arg)'),
        'std::vector<int>': Template('THPUtils_checkIntTuple($arg)'),
        'bool': Template('PyBool_Check($arg)'),
        'double': Template('THPDoubleUtils_checkReal($arg)'),
    }

    RETURN_WRAPPER = {
        'Convolution*': Template('return THPWrapper_New($result, [](void* arg) { delete (Convolution*)arg; });'),
    }

    METHODS_DECLARATION = Template("""
@@ -123,7 +124,8 @@ static PyObject * $name(PyObject *self, PyObject *args, PyObject *kwargs)

    def filter_unique_options(self, options):
        def signature(option):
-            return '#'.join(arg['type'] for arg in option['arguments'] if not 'ignore_check' in arg or not arg['ignore_check'])
+            return '#'.join(arg['type'] for arg in option['arguments']
+                            if 'ignore_check' not in arg or not arg['ignore_check'])
        seen_signatures = set()
        unique = []
        for option in options:
@@ -151,8 +153,8 @@ static PyObject * $name(PyObject *self, PyObject *args, PyObject *kwargs)
            if not declaration.get('only_register'):
                extra_flags += ' | METH_KEYWORDS'
            entry = Template('  {"$python_name", (PyCFunction)$name, METH_VARARGS$extra_flags, NULL},\n').substitute(
                python_name=declaration['python_name'], name=declaration['name'], extra_flags=extra_flags
            )
            if 'defined_if' in declaration:
                entry = self.preprocessor_guard(entry, declaration['defined_if'])
            methods += entry
@@ -1,6 +1,7 @@
from . import CWrapPlugin
from string import Template


class GILRelease(CWrapPlugin):

    OPTION_START = [
@@ -24,6 +25,5 @@ class GILRelease(CWrapPlugin):
    def process_option_code_template(self, template, option):
        call_idx = template.index('$call')
        template.insert(call_idx, self.BEFORE_CALL)
-        template.insert(call_idx+2, self.AFTER_CALL)
+        template.insert(call_idx + 2, self.AFTER_CALL)
        return self.OPTION_START + template + self.OPTION_END
tools/cwrap/plugins/GenericNN.py (new file, 211 lines)
@@ -0,0 +1,211 @@
import copy
from string import Template
from . import CWrapPlugin


class GenericNN(CWrapPlugin):

    INPUT_TYPE_CHECK = Template("checkTypes(is_cuda, $type, $tensor_args);")

    HEADER_TEMPLATE = Template("void $name($args);")

    WRAPPER_TEMPLATE = Template("""\
void $name($args)
{
  bool is_cuda = $input->isCuda();
  auto type = $input->type();
  $type_check
  $options
  } else {
    throw std::runtime_error("invalid arguments");
  }
}
""")

    THNN_TEMPLATE = Template("""\
    if (type == thpp::Type::FLOAT) {
        THNN_Float$name(
            NULL,
            $float_args);
    } else if (type == thpp::Type::DOUBLE) {
        THNN_Double$name(
            NULL,
            $double_args);
    } else {
        throw std::runtime_error("unsupported tensor type");
    }""")

    THCUNN_TEMPLATE = Template("""\
#ifdef WITH_CUDA
    if (type == thpp::Type::FLOAT) {
        THNN_Cuda$name(
            state,
            $float_args);
    } else if (type == thpp::Type::DOUBLE) {
        THNN_CudaDouble$name(
            state,
            $double_args);
    } else if (type == thpp::Type::HALF) {
        THNN_CudaHalf$name(
            state,
            $half_args);
    } else {
        throw std::runtime_error("unsupported tensor type");
    }
#endif
""")

    INDEX_TENSOR_TYPES = {'THIndexTensor*', 'THCIndexTensor*'}

    REAL_TENSOR_TYPES = {'THTensor*', 'THCTensor*'}

    INPUT_ARGUMENT_MAP = {
        'THNNState*': 'void*',
        'THCState*': 'void*',
        'THTensor*': 'thpp::Tensor*',
        'THCTensor*': 'thpp::Tensor*',
        'THIndexTensor*': 'thpp::Tensor*',
        'THIndex_t': 'long',
        'real': 'double',
    }

    def __init__(self, header=False):
        self.header = header
        self.declarations = []

    def process_full_file(self, base_wrapper):
        if self.header:
            wrapper = '#pragma once\n\n'
            wrapper += '#include <THPP/Tensor.hpp>\n\n'
        else:
            wrapper = '#include "THNN_generic.h"\n'
wrapper = '#include "THNN_generic.inc.h"\n\n'
|
||||
        wrapper += 'namespace torch { namespace nn {\n\n'
        wrapper += base_wrapper
        wrapper += '}} // namespace torch::nn\n'
        return wrapper

    def process_declarations(self, declarations):
        for declaration in declarations:
            base_args = declaration['options'][0]['arguments']
            for option in declaration['options']:
                for idx, arg in enumerate(option['arguments']):
                    arg['formal_name'] = base_args[idx]['name']
                    arg['formal_type'] = base_args[idx]['type']
                    if idx != 1:
                        arg['ignore_check'] = True
        return declarations

    def get_arg_accessor(self, arg, option):
        return self.get_type_unpack(arg, option)

    def process_option_code_template(self, template, option):
        code = '// fill me in'

        def base_cast(arg, CReal, real):
            name = arg['formal_name']
            type = arg['type']
            if type in self.REAL_TENSOR_TYPES:
                return ('(TH{CReal}Tensor*){name}->cdata()'
                        .format(CReal=CReal, name=name))
            elif type in self.INDEX_TENSOR_TYPES:
                return '({type}){name}->cdata()'.format(type=type, name=name)
            elif type == 'THCState*':
                return '({}){}'.format(type, name)
            elif type == 'real':
                if real == 'half':
                    return 'THC_float2half({})'.format(name)
                return '({real}){name}'.format(real=real, name=name)
            return name

        def cast(arg, CReal, real):
            expr = base_cast(arg, CReal, real)
            if arg.get('optional', False):
                name = arg['formal_name']
                return '{name} ? {expr} : NULL'.format(name=name, expr=expr)
            return expr

        if option['backend'] == 'nn':
            float_args = []
            double_args = []
            for idx, arg in enumerate(option['arguments']):
                float_args.append(cast(arg, 'Float', 'float'))
                double_args.append(cast(arg, 'Double', 'double'))

            code = self.THNN_TEMPLATE.substitute(
                name=option['cname'],
                float_args=',\n'.join(float_args),
                double_args=',\n'.join(double_args))

        elif option['backend'] == 'cunn':
            float_args = []
            double_args = []
            half_args = []
            for idx, arg in enumerate(option['arguments']):
                float_args.append(cast(arg, 'Cuda', 'float'))
                double_args.append(cast(arg, 'CudaDouble', 'double'))
                half_args.append(cast(arg, 'CudaHalf', 'half'))

            code = self.THCUNN_TEMPLATE.substitute(
                name=option['cname'],
                float_args=',\n'.join(float_args),
                double_args=',\n'.join(double_args),
                half_args=',\n'.join(half_args))

        return [code, '']

    def get_type_unpack(self, arg, option):
        return Template(arg['name'])

    def get_type_check(self, arg, option):
        if option['backend'] == 'cunn':
            return Template('is_cuda')
        else:
            return Template('!is_cuda')

    def get_formal_args(self, arguments):
        formal_args = []
        for arg in arguments:
            arg = copy.copy(arg)
            new_type = self.INPUT_ARGUMENT_MAP.get(arg['type'])
            if new_type is not None:
                arg['type'] = new_type
            formal_args.append(arg)
        return formal_args

    def get_wrapper_template(self, declaration):
        # get formal arguments string
        base_arguments = declaration['options'][0]['arguments']
        args = self.get_formal_args(base_arguments)
        arg_str = ', '.join([arg['type'] + ' ' + arg['name'] for arg in args])

        if self.header:
            return Template(self.HEADER_TEMPLATE.safe_substitute(args=arg_str))

        def get_checked_args(tensor_types):
            checked_args = []
            for arg in base_arguments:
                if arg['type'] in tensor_types:
                    name = arg.get('formal_name', arg['name'])
                    name_str = name
                    if arg.get('optional', False):
                        name_str = '?' + name_str
                    checked_args += ['"' + name_str + '"', name]
            checked_args += ['NULL']
            return checked_args

        real_args = get_checked_args(self.REAL_TENSOR_TYPES)
        long_args = get_checked_args(self.INDEX_TENSOR_TYPES)

        # check input types
        types_checks = []
        if len(real_args) > 1:
            types_checks.append(self.INPUT_TYPE_CHECK.substitute(
                type='type', tensor_args=', '.join(real_args)))
        if len(long_args) > 1:
            types_checks.append(self.INPUT_TYPE_CHECK.substitute(
                type='thpp::Type::LONG', tensor_args=', '.join(long_args)))

        return Template(self.WRAPPER_TEMPLATE.safe_substitute(
            input=args[0]['name'],
            args=arg_str,
            type_check='\n  '.join(types_checks)))
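For orientation, the templates above expand into a type-dispatching C++ wrapper of roughly the following shape. This is a sketch only: the function name, argument list, and the exact checkTypes expansion are illustrative, not taken from this diff.

void Abs_updateOutput(thpp::Tensor* input, thpp::Tensor* output)
{
  bool is_cuda = input->isCuda();
  auto type = input->type();
  // INPUT_TYPE_CHECK expansion: all real tensor args must share one scalar type
  checkTypes(is_cuda, type, "input", input, "output", output, NULL);
  if (!is_cuda) {
    // THNN_TEMPLATE expansion: CPU dispatch on scalar type
    if (type == thpp::Type::FLOAT) {
      THNN_FloatAbs_updateOutput(
          NULL,
          (THFloatTensor*)input->cdata(),
          (THFloatTensor*)output->cdata());
    } else if (type == thpp::Type::DOUBLE) {
      THNN_DoubleAbs_updateOutput(
          NULL,
          (THDoubleTensor*)input->cdata(),
          (THDoubleTensor*)output->cdata());
    } else {
      throw std::runtime_error("unsupported tensor type");
    }
  } else {
    // THCUNN_TEMPLATE expansion (Float/Double/Half), compiled only under WITH_CUDA
  }
}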
@@ -1,6 +1,7 @@
from . import CWrapPlugin
from string import Template


class KwargsPlugin(CWrapPlugin):

    ACCESSOR_TEMPLATE = Template('(__tuplecount > $idx ? PyTuple_GET_ITEM(args, $idx) : __kw_$name)')
@@ -53,7 +54,8 @@ class KwargsPlugin(CWrapPlugin):
            seen_args.add(name)
            args.append(name)
        declarations = '\n    '.join(['PyObject *__kw_{} = NULL;'.format(name) for name in args])
-        lookups = '\n    '.join(['__kw_{name} = PyDict_GetItemString(kwargs, "{name}");'.format(name=name) for name in args])
+        lookups = '\n    '.join(
+            ['__kw_{name} = PyDict_GetItemString(kwargs, "{name}");'.format(name=name) for name in args])
        start_idx = code.find('{') + 1
        new_code = self.WRAPPER_TEMPLATE.substitute(declarations=declarations, lookups=lookups)
        return code[:start_idx] + new_code + code[start_idx:]
@@ -1,6 +1,8 @@
from . import CWrapPlugin


class NullableArguments(CWrapPlugin):

    def process_single_check(self, code, arg, arg_accessor):
        if 'nullable' in arg and arg['nullable']:
            return '({} || {} == Py_None)'.format(code, arg_accessor)
@@ -10,5 +12,3 @@ class NullableArguments(CWrapPlugin):
        if 'nullable' in arg and arg['nullable']:
            return '({} == Py_None ? NULL : {})'.format(arg_accessor, code)
        return code
@@ -2,6 +2,7 @@ from copy import deepcopy
from . import CWrapPlugin
from itertools import product


class OptionalArguments(CWrapPlugin):

    def process_declarations(self, declarations):
@@ -32,20 +33,20 @@ class OptionalArguments(CWrapPlugin):
            else:
                kwarg_only_count = -kwarg_only_count
            arg_signature = '#'.join(
                arg['type']
                for arg in option['arguments'][:kwarg_only_count]
                if not arg.get('ignore_check'))
            if kwarg_only_count is None:
                return arg_signature
            kwarg_only_signature = '#'.join(
                arg['name'] + '#' + arg['type']
                for arg in option['arguments'][kwarg_only_count:]
                if not arg.get('ignore_check'))
            return arg_signature + "#-#" + kwarg_only_signature
        seen_signatures = set()
        unique = []
        for option in options:
-            for num_kwarg_only in range(0, len(option['arguments'])+1):
+            for num_kwarg_only in range(0, len(option['arguments']) + 1):
                sig = signature(option, num_kwarg_only)
                if sig not in seen_signatures:
                    if num_kwarg_only > 0:
@@ -55,4 +56,3 @@ class OptionalArguments(CWrapPlugin):
                    seen_signatures.add(sig)
                    break
        return unique
@@ -1,9 +1,10 @@
from . import CWrapPlugin
from string import Template


class ReturnArguments(CWrapPlugin):
    ARGUMENT_RETURN_TEMPLATE = Template("Py_INCREF($arg);\nreturn (PyObject*)($arg);")
    TUPLE_RETURN_TEMPLATE = Template("return PyTuple_Pack($num_args, $args);")

    def initialize(self, cwrap):
        self.cwrap = cwrap
@@ -16,4 +17,5 @@ class ReturnArguments(CWrapPlugin):
        if len(args) == 1:
            return Template(self.ARGUMENT_RETURN_TEMPLATE.safe_substitute(arg=accessors[0]))
        else:
-            return Template(self.TUPLE_RETURN_TEMPLATE.safe_substitute(num_args=len(args), args=', '.join(accessors)))
+            return Template(self.TUPLE_RETURN_TEMPLATE.safe_substitute(num_args=len(args),
+                                                                       args=', '.join(accessors)))
@@ -26,41 +26,41 @@ $METHODS
class StandaloneExtension(CWrapPlugin):

    TYPE_UNPACK = {
        'THFloatTensor*': Template('THPFloatTensor_CData((THPFloatTensor*)$arg)'),
        'THDoubleTensor*': Template('THPDoubleTensor_CData((THPDoubleTensor*)$arg)'),
        'THLongTensor*': Template('THPLongTensor_CData((THPLongTensor*)$arg)'),
        'THIntTensor*': Template('THPIntTensor_CData((THPIntTensor*)$arg)'),
        'THCudaHalfTensor*': Template('THCPHalfTensor_CData((THCPHalfTensor*)$arg)'),
        'THCudaTensor*': Template('THCPFloatTensor_CData((THCPFloatTensor*)$arg)'),
        'THCudaDoubleTensor*': Template('THCPDoubleTensor_CData((THCPDoubleTensor*)$arg)'),
        'THCudaLongTensor*': Template('THCPLongTensor_CData((THCPLongTensor*)$arg)'),
        'half': Template('THPHalfUtils_unpackReal($arg)'),
        'float': Template('THPFloatUtils_unpackReal($arg)'),
        'double': Template('THPDoubleUtils_unpackReal($arg)'),
        'bool': Template('($arg == Py_True ? true : false)'),
        'int': Template('THPUtils_unpackLong($arg)'),
        'long': Template('THPUtils_unpackLong($arg)'),
        'void*': Template('(void*)THPUtils_unpackLong($arg)'),
        'THGenerator*': Template('THPGenerator_CData((THPGenerator*)$arg)'),
    }

    TYPE_CHECK = {
        'THDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THPDoubleTensorClass'),
        'THFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THPFloatTensorClass'),
        'THLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THPLongTensorClass'),
        'THIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntTensorClass'),
        'THCudaHalfTensor*': Template('THCPHalfTensor_Check($arg)'),
        'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'),
        'THCudaDoubleTensor*': Template('THCPDoubleTensor_Check($arg)'),
        'THCudaLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPLongTensorClass'),
        'half': Template('THPHalfUtils_checkReal($arg)'),
        'float': Template('THPFloatUtils_checkReal($arg)'),
        'double': Template('THPDoubleUtils_checkReal($arg)'),
        'bool': Template('PyBool_Check($arg)'),
        'int': Template('THPUtils_checkLong($arg)'),
        'long': Template('THPUtils_checkLong($arg)'),
        'void*': Template('THPUtils_checkLong($arg)'),
        'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
    }

    WRAPPER_TEMPLATE = Template("""
@@ -131,6 +131,7 @@ PyObject * $name(PyObject *_unused, PyObject *args)

    def get_wrapper_template(self, declaration):
        arg_desc = []

        def describe_arg(arg):
            desc = self.TYPE_NAMES[arg['type']] + ' ' + arg['name']
            if arg.get('nullable'):
@@ -138,8 +139,8 @@ PyObject * $name(PyObject *_unused, PyObject *args)
            return desc
        for option in declaration['options']:
            option_desc = [describe_arg(arg)
                           for arg in option['arguments']
                           if not arg.get('ignore_check', False)]
            if option_desc:
                arg_desc.append('({})'.format(', '.join(option_desc)))
            else:
@@ -4,85 +4,91 @@ from . import CWrapPlugin
from itertools import product, chain
from collections import OrderedDict


class THPPlugin(CWrapPlugin):

    TYPE_UNPACK = {
        'THFloatTensor*': Template('((THPFloatTensor*)$arg)->cdata'),
        'THDoubleTensor*': Template('((THPDoubleTensor*)$arg)->cdata'),
        'THLongTensor*': Template('((THPLongTensor*)$arg)->cdata'),
        'THIntTensor*': Template('((THPIntTensor*)$arg)->cdata'),
        'THTensor*': Template('((THPTensor*)$arg)->cdata'),
        'THBoolTensor*': Template('((THPBoolTensor*)$arg)->cdata'),
        'THIndexTensor*': Template('((THPIndexTensor*)$arg)->cdata'),

        'THCudaTensor*': Template('((THCPFloatTensor*)$arg)->cdata'),
        'THCudaDoubleTensor*': Template('((THCPDoubleTensor*)$arg)->cdata'),

        'THSFloatTensor*': Template('((THSPFloatTensor*)$arg)->cdata'),
        'THSDoubleTensor*': Template('((THSPDoubleTensor*)$arg)->cdata'),
        'THSLongTensor*': Template('((THSPLongTensor*)$arg)->cdata'),
        'THSIntTensor*': Template('((THSPIntTensor*)$arg)->cdata'),
        'THSTensor*': Template('((THSPTensor*)$arg)->cdata'),
        'THSBoolTensor*': Template('((THSPBoolTensor*)$arg)->cdata'),
        'THSIndexTensor*': Template('((THSPIndexTensor*)$arg)->cdata'),

        'THLongStorage*': Template('((THPLongStorage*)$arg)->cdata'),
        'THStorage*': Template('((THPStorage*)$arg)->cdata'),
        'THGenerator*': Template('((THPGenerator*)$arg)->cdata'),
        'THSize*': Template('__size.get()'),
        'THStride*': Template('__stride.get()'),
        'void*': Template('THPUtils_unpackLong($arg)'),
        'long': Template('THPUtils_unpackLong($arg)'),
        'int': Template('THPUtils_unpackLong($arg)'),
        'bool': Template('($arg == Py_True ? true : false)'),
        'float': Template('THPFloatUtils_unpackReal($arg)'),
        'double': Template('THPDoubleUtils_unpackReal($arg)'),
        'real': Template('THPUtils_(unpackReal)($arg)'),
        'accreal': Template('THPUtils_(unpackAccreal)($arg)'),
    }

    TYPE_CHECK = {
        'THDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THPDoubleTensorClass'),
        'THFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THPFloatTensorClass'),
        'THLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THPLongTensorClass'),
        'THIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntTensorClass'),
        'THTensor*': Template('(PyObject*)Py_TYPE($arg) == THPTensorClass'),
        'THBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THPBoolTensorClass'),
        'THIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIndexTensorClass'),

        'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'),
        'THCudaDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPDoubleTensorClass'),

        'THSDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPDoubleTensorClass'),
        'THSFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPFloatTensorClass'),
        'THSLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPLongTensorClass'),
        'THSIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIntTensorClass'),
        'THSTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPTensorClass'),
        'THSBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPBoolTensorClass'),
        'THSIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIndexTensorClass'),

        'THLongStorage*': Template('(PyObject*)Py_TYPE($arg) == THPLongStorageClass'),
        'THStorage*': Template('(PyObject*)Py_TYPE($arg) == THPStorageClass'),
        'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
        'THSize*': Template('THPUtils_tryUnpackLongs($arg, __size)'),
        'THStride*': Template('THPUtils_tryUnpackLongs($arg, __stride)'),
        'void*': Template('THPUtils_checkLong($arg)'),
        'long': Template('THPUtils_checkLong($arg)'),
        'int': Template('THPUtils_checkLong($arg)'),
        'bool': Template('PyBool_Check($arg)'),
        'float': Template('THPFloatUtils_checkReal($arg)'),
        'double': Template('THPDoubleUtils_checkReal($arg)'),
        'real': Template('THPUtils_(checkReal)($arg)'),
        'accreal': Template('THPUtils_(checkAccreal)($arg)'),
    }

    SIZE_VARARG_CHECK = Template('THPUtils_tryUnpackLongVarArgs(args, $idx, __size)')

    RETURN_WRAPPER = {
        'THTensor*': Template('return THPTensor_(New)($result);'),
        'THSTensor*': Template('return THSPTensor_(New)($result);'),
        'THLongTensor*': Template('return THPLongTensor_New($result);'),
        'THLongStorage*': Template('return THPLongStorage_New($result);'),
        # TODO: make it smarter - it should return python long if result doesn't fit into an int
        'long': Template('return PyInt_FromLong($result);'),
        'accreal': Template('return THPUtils_(newAccreal)($result);'),
        'self': Template('Py_INCREF(self);\nreturn (PyObject*)self;'),
        'real': Template('return THPUtils_(newReal)($result);'),
    }

    TENSOR_METHODS_DECLARATION = Template("""
@@ -138,13 +144,13 @@ ${cpu}
        return Template(code)

    ALLOCATE_TYPE = {
        'THTensor*': _allocate('', ALLOCATE_TMPL),
        'THLongTensor*': _allocate('Long', ALLOCATE_TMPL),
        'THIntTensor*': _allocate('Int', ALLOCATE_TMPL),
        'THBoolTensor*': _allocate('Byte', ALLOCATE_TMPL, ALLOCATE_CUDA),
        'THIndexTensor*': _allocate('Long', ALLOCATE_TMPL, ALLOCATE_CUDA),

        'THSTensor*': _allocate('', ALLOCATE_TMPL, sparse=True),
    }

    TYPE_NAMES = {
@@ -159,6 +165,8 @@ ${cpu}
        'THIndexTensor*': '" THPModuleStr "LongTensor',
        'THFloatTensor*': '" THPModuleStr "FloatTensor',
        'THDoubleTensor*': '" THPModuleStr "DoubleTensor',
+        'THCudaTensor*': 'torch.cuda.FloatTensor',
+        'THCudaDoubleTensor*': 'torch.cuda.DoubleTensor',
        'THSize*': 'torch.Size',
        'THStride*': 'tuple',
        'long': 'int',
@@ -198,14 +206,14 @@ ${cpu}
        def format_args(args, var_args=False):
            option_desc = [format_arg(arg, var_args)
                           for arg in args
-                           if not arg.get('ignore_check', False)
-                           and not arg.get('output')]
+                           if not arg.get('ignore_check', False) and
+                           not arg.get('output')]
            output_args = list(filter(lambda a: a.get('output'), args))
            if output_args:
                if len(output_args) > 1:
                    out_type = 'tuple['
                    out_type += ', '.join(
                        self.TYPE_NAMES[arg['type']] for arg in output_args)
                    out_type += ']'
                    option_desc += ['#' + out_type + ' out']
                else:
@@ -287,7 +295,7 @@ ${cpu}
                    if not output_provided:
                        arg['ignore_check'] = True
                else:
                    option_copy['argcount_offset'] = -len(out_idx) + 1
                    arg['no_kwargs'] = True
                    arg['no_idx'] = True
            new_options.append(option_copy)
@@ -345,7 +353,6 @@ ${cpu}
            if arg['name'] == 'self':
                arg['ignore_check'] = True

        declarations = [d for d in declarations if not d.get('only_stateless', False)]
        self.declarations.extend(filter(lambda x: not x.get('only_stateless', False), register_only))
        self.stateless_declarations.extend(filter(lambda x: x.get('only_stateless', False), register_only))
@@ -377,9 +384,9 @@ ${cpu}
            if declaration.get('override_method_flags'):
                flags = declaration['override_method_flags']
            entry = Template('  {"$python_name", (PyCFunction)$name, $flags, $docstring},\n').substitute(
                python_name=declaration['python_name'], name=declaration['name'], flags=flags,
                docstring=declaration.get('docstring_var', 'NULL')
            )
            if 'defined_if' in declaration:
                entry = self.preprocessor_guard(entry, declaration['defined_if'])
            tensor_methods += entry
@@ -392,16 +399,16 @@ ${cpu}
    def process_full_file(self, code):
        # We have to find a place before all undefs
        idx = code.find('// PUT DEFINITIONS IN HERE PLEASE')
-        return (code[:idx]
-                + self.declare_methods(False, False)
-                + self.declare_methods(True, False)
-                + self.declare_methods(False, True)
-                + self.declare_methods(True, True)
-                + code[idx:]
+        return (code[:idx] +
+                self.declare_methods(False, False) +
+                self.declare_methods(True, False) +
+                self.declare_methods(False, True) +
+                self.declare_methods(True, True) +
+                code[idx:]
        )

    def preprocessor_guard(self, code, condition):
        return '#if ' + condition + '\n' + code + '#endif\n'

    def process_wrapper(self, code, declaration):
        if 'defined_if' in declaration:
@@ -419,7 +426,7 @@ ${cpu}
            if option['output_count'] > 1:
                checks += "PyTuple_Check(__out) &&\n" + indent
                length_check = "PyTuple_GET_SIZE(__out) == {} &&\n".format(
                    option['output_count'])
                checks += length_check + indent
            code = checks + code
        else:
@@ -443,13 +450,13 @@ ${cpu}
    def generate_docstrings_cpp(self):
        template = Template('char* $name = "$content";')
        return '\n\n'.join(
            template.substitute(name=decl['docstring_var'], content=decl['docstring_content'])
            for decl in chain(self.declarations, self.stateless_declarations)
            if 'docstring_var' in decl)

    def generate_docstrings_h(self):
        template = Template('extern char* $name;')
        return '\n\n'.join(
            template.substitute(name=decl['docstring_var'])
            for decl in chain(self.declarations, self.stateless_declarations)
            if 'docstring_var' in decl)
@@ -58,3 +58,4 @@ from .ReturnArguments import ReturnArguments
from .GILRelease import GILRelease
from .AutoGPU import AutoGPU
from .CuDNNPlugin import CuDNNPlugin
+from .GenericNN import GenericNN
tools/docker/Dockerfile-v6 (new file, 40 lines)
@@ -0,0 +1,40 @@
FROM nvidia/cuda:8.0-devel-ubuntu14.04

RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cmake \
         git \
         curl \
         ca-certificates \
         libjpeg-dev \
         libpng-dev &&\
     rm -rf /var/lib/apt/lists/*

RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v6.0/cudnn-8.0-linux-x64-v6.0-rc.tgz -O && \
    tar -xzf cudnn-8.0-linux-x64-v6.0-rc.tgz -C /usr/local && \
    rm cudnn-8.0-linux-x64-v6.0-rc.tgz && \
    ldconfig
RUN ln -s /usr/local/cuda/lib64/libcudnn.so.6.0.5 /usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.5

RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-4.2.12-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install conda-build && \
    /opt/conda/bin/conda create -y --name pytorch-py35 python=3.5.2 numpy scipy ipython mkl && \
    /opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/envs/pytorch-py35/bin:$PATH
RUN conda install --name pytorch-py35 -c soumith magma-cuda80
# This must be done before pip so that requirements.txt is available
WORKDIR /opt/pytorch
COPY . .

RUN cat requirements.txt | xargs -n1 pip install --no-cache-dir && \
    TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
    CMAKE_LIBRARY_PATH=/opt/conda/envs/pytorch-py35/lib \
    CMAKE_INCLUDE_PATH=/opt/conda/envs/pytorch-py35/include \
    pip install -v .

WORKDIR /workspace
RUN chmod -R a+w /workspace
@@ -2,12 +2,13 @@ import os
import sys
from string import Template, ascii_lowercase
from ..cwrap import cwrap
-from ..cwrap.plugins import StandaloneExtension, NullableArguments, AutoGPU
+from ..cwrap.plugins import StandaloneExtension, GenericNN, NullableArguments, AutoGPU

BASE_PATH = os.path.realpath(os.path.join(__file__, '..', '..', '..'))
WRAPPER_PATH = os.path.join(BASE_PATH, 'torch', 'csrc', 'nn')
THNN_UTILS_PATH = os.path.join(BASE_PATH, 'torch', '_thnn', 'utils.py')


def import_module(name, path):
    if sys.version_info >= (3, 5):
        import importlib.util
@@ -81,7 +82,8 @@ for t in ['CudaHalf', 'Cuda', 'CudaDouble']:
def wrap_function(name, type, arguments):
    cname = 'THNN_' + type + name
    declaration = ''
-    declaration += 'extern "C" void ' + cname + '(' + ', '.join(TYPE_TRANSFORMS[type].get(arg.type, arg.type) for arg in arguments) + ');\n'
+    declaration += 'extern "C" void ' + cname + \
+        '(' + ', '.join(TYPE_TRANSFORMS[type].get(arg.type, arg.type) for arg in arguments) + ');\n'
    declaration += FUNCTION_TEMPLATE.substitute(name=type + name, cname=cname)
    indent = ' ' * 4
    dict_indent = ' ' * 6
@@ -91,15 +93,18 @@ def wrap_function(name, type, arguments):
            declaration += prefix + TYPE_TRANSFORMS[type].get(arg.type, arg.type) + ' ' + arg.name + '\n'
        else:
            t = TYPE_TRANSFORMS[type].get(arg.type, arg.type)
            declaration += prefix + 'type: ' + t + '\n' + \
                dict_indent + 'name: ' + arg.name + '\n' + \
                dict_indent + 'nullable: True' + '\n'
    declaration += ']]\n\n\n'
    return declaration


def generate_wrappers():
    wrap_nn()
    wrap_cunn()
+    wrap_generic()


def wrap_nn():
    wrapper = '#include <TH/TH.h>\n\n\n'
@@ -114,6 +119,7 @@ def wrap_nn():
        NullableArguments(),
    ])


def wrap_cunn():
    wrapper = '#include <TH/TH.h>\n'
    wrapper += '#include <THC/THC.h>\n\n\n'
@@ -128,3 +134,66 @@ def wrap_cunn():
        NullableArguments(),
        AutoGPU(has_self=False),
    ])

GENERIC_FUNCTION_TEMPLATE = Template("""\
[[
  name: $name
  return: void
  options:
""")


def wrap_generic_function(name, backends):
    declaration = ''
    declaration += GENERIC_FUNCTION_TEMPLATE.substitute(name=name)
    for backend in backends:
        declaration += '    - cname: ' + name + '\n'
        declaration += '      backend: ' + backend['name'] + '\n'
        declaration += '      arguments:\n'
        for arg in backend['arguments']:
            declaration += '       - arg: ' + arg.type + ' ' + arg.name + '\n'
            if arg.is_optional:
                declaration += '         optional: True\n'
    declaration += ']]\n\n\n'
    return declaration


def wrap_generic():
    from collections import OrderedDict
    defs = OrderedDict()

    def should_wrap_function(name):
        if name.startswith('LookupTable'):
            return False
        return (name.endswith('updateOutput') or
                name.endswith('updateGradInput') or
                name.endswith('accGradParameters') or
                name.endswith('backward'))

    def add_functions(name, functions):
        for fn in functions:
            if not should_wrap_function(fn.name):
                continue
            if fn.name not in defs:
                defs[fn.name] = []
            defs[fn.name] += [{
                'name': name,
                'arguments': fn.arguments[1:],
            }]

    add_functions('nn', thnn_utils.parse_header(thnn_utils.THNN_H_PATH))
    add_functions('cunn', thnn_utils.parse_header(thnn_utils.THCUNN_H_PATH))

    wrapper = ''
    for name, backends in defs.items():
        wrapper += wrap_generic_function(name, backends)
    with open('torch/csrc/nn/THNN_generic.cwrap', 'w') as f:
        f.write(wrapper)

    cwrap('torch/csrc/nn/THNN_generic.cwrap', plugins=[
        GenericNN(header=True),
    ], default_plugins=False, destination='torch/csrc/nn/THNN_generic.h')

    cwrap('torch/csrc/nn/THNN_generic.cwrap', plugins=[
        GenericNN(),
    ], default_plugins=False)
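To make the generated file concrete: for a function present in both the THNN and THCUNN headers, wrap_generic_function() above would emit a declaration along these lines into torch/csrc/nn/THNN_generic.cwrap (Abs_updateOutput and its argument list are illustrative, and the exact indentation follows the template strings):

[[
  name: Abs_updateOutput
  return: void
  options:
    - cname: Abs_updateOutput
      backend: nn
      arguments:
        - arg: THTensor* input
        - arg: THTensor* output
    - cname: Abs_updateOutput
      backend: cunn
      arguments:
        - arg: THCTensor* input
        - arg: THCTensor* output
]]

Note that the leading state argument is deliberately absent: add_functions() keeps only fn.arguments[1:], and GenericNN supplies NULL (THNN) or the THC state (THCUNN) at call time.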
@@ -1,8 +1,17 @@
+import ctypes.util
import os

from .env import check_env_flag

-CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
-WITH_CUDA = not check_env_flag('NO_CUDA') and os.path.exists(CUDA_HOME)
-if not WITH_CUDA:
+if check_env_flag('NO_CUDA'):
+    WITH_CUDA = False
    CUDA_HOME = None
+else:
+    CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
+    if not os.path.exists(CUDA_HOME):
+        cudart_path = ctypes.util.find_library('cudart')
+        if cudart_path is not None:
+            CUDA_HOME = os.path.dirname(cudart_path)
+        else:
+            CUDA_HOME = None
+    WITH_CUDA = CUDA_HOME is not None
@@ -1,9 +1,15 @@
import os
import glob
+from itertools import chain

from .env import check_env_flag
from .cuda import WITH_CUDA, CUDA_HOME


+def gather_paths(env_vars):
+    return list(chain(*(os.getenv(v, '').split(':') for v in env_vars)))


WITH_CUDNN = False
CUDNN_LIB_DIR = None
CUDNN_INCLUDE_DIR = None
@@ -12,13 +18,19 @@ if WITH_CUDA and not check_env_flag('NO_CUDNN'):
        os.getenv('CUDNN_LIB_DIR'),
        os.path.join(CUDA_HOME, 'lib'),
        os.path.join(CUDA_HOME, 'lib64'),
        '/usr/lib/x86_64-linux-gnu/',
-    ]))
+    ] + gather_paths([
+        'LIBRARY_PATH',
+    ])))
    include_paths = list(filter(bool, [
        os.getenv('CUDNN_INCLUDE_DIR'),
        os.path.join(CUDA_HOME, 'include'),
-        '/usr/include/'
-    ]))
+        '/usr/include/',
+    ] + gather_paths([
+        'CPATH',
+        'C_INCLUDE_PATH',
+        'CPLUS_INCLUDE_PATH',
+    ])))
    for path in lib_paths:
        if path is None or not os.path.exists(path):
            continue
@@ -1,4 +1,5 @@
import os


def check_env_flag(name):
    return os.getenv(name) in ['ON', '1', 'YES', 'TRUE', 'Y']
@@ -56,6 +56,7 @@ del old_flags
# Define basic utilities
################################################################################


def typename(o):
    module = ''
    class_name = ''
@@ -91,7 +92,7 @@ def set_default_tensor_type(t):

def set_rng_state(new_state):
    r"""Sets the random number generator state.

    Args:
        new_state (torch.ByteTensor): The desired state
    """
@@ -104,9 +105,9 @@ def get_rng_state():


def manual_seed(seed):
    r"""Sets the seed for generating random numbers. And returns a
    `torch._C.Generator` object.

    Args:
        seed (int or long): The desired seed.
    """
@@ -114,7 +115,7 @@ def manual_seed(seed):


def initial_seed():
    r"""Returns the initial seed for generating random numbers as a
    python `long`.
    """
    return default_generator.initial_seed()
@@ -130,61 +131,101 @@ from ._tensor_str import set_printoptions
from .storage import _StorageBase
from .tensor import _TensorBase


class DoubleStorage(_C.DoubleStorageBase, _StorageBase):
    pass


class FloatStorage(_C.FloatStorageBase, _StorageBase):
    pass


class LongStorage(_C.LongStorageBase, _StorageBase):
    pass


class IntStorage(_C.IntStorageBase, _StorageBase):
    pass


class ShortStorage(_C.ShortStorageBase, _StorageBase):
    pass


class CharStorage(_C.CharStorageBase, _StorageBase):
    pass


class ByteStorage(_C.ByteStorageBase, _StorageBase):
    pass


class DoubleTensor(_C.DoubleTensorBase, _TensorBase):

    def is_signed(self):
        return True

    @classmethod
    def storage_type(cls):
        return DoubleStorage


class FloatTensor(_C.FloatTensorBase, _TensorBase):

    def is_signed(self):
        return True

    @classmethod
    def storage_type(cls):
        return FloatStorage


class LongTensor(_C.LongTensorBase, _TensorBase):

    def is_signed(self):
        return True

    @classmethod
    def storage_type(cls):
        return LongStorage


class IntTensor(_C.IntTensorBase, _TensorBase):

    def is_signed(self):
        return True

    @classmethod
    def storage_type(cls):
        return IntStorage


class ShortTensor(_C.ShortTensorBase, _TensorBase):

    def is_signed(self):
        return True

    @classmethod
    def storage_type(cls):
        return ShortStorage


class CharTensor(_C.CharTensorBase, _TensorBase):

    def is_signed(self):
        # TODO
        return False

    @classmethod
    def storage_type(cls):
        return CharStorage


class ByteTensor(_C.ByteTensorBase, _TensorBase):

    def is_signed(self):
        return False

    @classmethod
    def storage_type(cls):
        return ByteStorage
(File diff suppressed because it is too large)
@@ -22,7 +22,7 @@ def set_printoptions(
        edgeitems=None,
        linewidth=None,
        profile=None,
):
    """Set options for printing. Items shamelessly taken from Numpy

    Args:
@@ -119,7 +119,7 @@ def _number_format(tensor, min_sz=-1):
    else:
        if exp_max > prec + 1 or exp_max < 0:
            sz = max(min_sz, 7)
-            scale = math.pow(10, exp_max-1)
+            scale = math.pow(10, exp_max - 1)
        else:
            if exp_max == 0:
                sz = 7
@@ -132,19 +132,19 @@ def _number_format(tensor, min_sz=-1):

def _tensor_str(self):
    n = PRINT_OPTS.edgeitems
-    has_hdots = self.size()[-1] > 2*n
-    has_vdots = self.size()[-2] > 2*n
+    has_hdots = self.size()[-1] > 2 * n
+    has_vdots = self.size()[-2] > 2 * n
    print_full_mat = not has_hdots and not has_vdots
    formatter = _number_format(self, min_sz=3 if not print_full_mat else 0)
    print_dots = self.numel() >= PRINT_OPTS.threshold

    dim_sz = max(2, max(len(str(x)) for x in self.size()))
    dim_fmt = "{:^" + str(dim_sz) + "}"
-    dot_fmt = u"{:^" + str(dim_sz+1) + "}"
+    dot_fmt = u"{:^" + str(dim_sz + 1) + "}"

    counter_dim = self.ndimension() - 2
    counter = torch.LongStorage(counter_dim).fill_(0)
-    counter[counter.size()-1] = -1
+    counter[counter.size() - 1] = -1
    finished = False
    strt = ''
    while True:
@@ -152,7 +152,7 @@ def _tensor_str(self):
        nskipped = [False for i in counter]
        for i in _range(counter_dim - 1, -1, -1):
            counter[i] += 1
-            if print_dots and counter[i] == n and self.size(i) > 2*n:
+            if print_dots and counter[i] == n and self.size(i) > 2 * n:
                counter[i] = self.size(i) - n
                nskipped[i] = True
            if counter[i] == self.size(i):
@@ -188,18 +188,18 @@ def __repr_row(row, indent, fmt, scale, sz, truncate=None):
    if truncate is not None:
        dotfmt = " {:^5} "
        return (indent +
-                ' '.join(fmt.format(val/scale) for val in row[:truncate]) +
+                ' '.join(fmt.format(val / scale) for val in row[:truncate]) +
                dotfmt.format('...') +
-                ' '.join(fmt.format(val/scale) for val in row[-truncate:]) +
+                ' '.join(fmt.format(val / scale) for val in row[-truncate:]) +
                '\n')
    else:
-        return indent + ' '.join(fmt.format(val/scale) for val in row) + '\n'
+        return indent + ' '.join(fmt.format(val / scale) for val in row) + '\n'


def _matrix_str(self, indent='', formatter=None, force_truncate=False):
    n = PRINT_OPTS.edgeitems
-    has_hdots = self.size(1) > 2*n
-    has_vdots = self.size(0) > 2*n
+    has_hdots = self.size(1) > 2 * n
+    has_vdots = self.size(0) > 2 * n
    print_full_mat = not has_hdots and not has_vdots

    if formatter is None:
@@ -207,14 +207,14 @@ def _matrix_str(self, indent='', formatter=None, force_truncate=False):
            min_sz=5 if not print_full_mat else 0)
    else:
        fmt, scale, sz = formatter
-    nColumnPerLine = int(math.floor((PRINT_OPTS.linewidth-len(indent))/(sz+1)))
+    nColumnPerLine = int(math.floor((PRINT_OPTS.linewidth - len(indent)) / (sz + 1)))
    strt = ''
    firstColumn = 0

    if not force_truncate and \
       (self.numel() < PRINT_OPTS.threshold or print_full_mat):
        while firstColumn < self.size(1):
-            lastColumn = min(firstColumn + nColumnPerLine - 1, self.size(1)-1)
+            lastColumn = min(firstColumn + nColumnPerLine - 1, self.size(1) - 1)
            if nColumnPerLine < self.size(1):
                strt += '\n' if firstColumn != 1 else ''
                strt += 'Columns {} to {} \n{}'.format(
@@ -223,15 +223,15 @@ def _matrix_str(self, indent='', formatter=None, force_truncate=False):
                strt += SCALE_FORMAT.format(scale)
            for l in _range(self.size(0)):
                strt += indent + (' ' if scale != 1 else '')
-                row_slice = self[l, firstColumn:lastColumn+1]
-                strt += ' '.join(fmt.format(val/scale) for val in row_slice)
+                row_slice = self[l, firstColumn:lastColumn + 1]
+                strt += ' '.join(fmt.format(val / scale) for val in row_slice)
                strt += '\n'
            firstColumn = lastColumn + 1
    else:
        if scale != 1:
            strt += SCALE_FORMAT.format(scale)
        if has_vdots and has_hdots:
-            vdotfmt = "{:^" + str((sz+1)*n-1) + "}"
+            vdotfmt = "{:^" + str((sz + 1) * n - 1) + "}"
            ddotfmt = u"{:^5}"
            for row in self[:n]:
                strt += __repr_row(row, indent, fmt, scale, sz, n)
@@ -245,8 +245,8 @@ def _matrix_str(self, indent='', formatter=None, force_truncate=False):
                strt += __repr_row(row, indent, fmt, scale, sz, n)
        elif has_vdots and not has_hdots:
            vdotfmt = u"{:^" + \
                str(len(__repr_row(self[0], '', fmt, scale, sz))) + \
                "}\n"
            for row in self[:n]:
                strt += __repr_row(row, indent, fmt, scale, sz)
            strt += vdotfmt.format(u'\u22EE')
@@ -269,13 +269,13 @@ def _vector_str(self):
    ident = ' '
    if self.numel() < PRINT_OPTS.threshold:
        return (strt +
-                '\n'.join(ident + fmt.format(val/scale) for val in self) +
+                '\n'.join(ident + fmt.format(val / scale) for val in self) +
                '\n')
    else:
        return (strt +
-                '\n'.join(ident + fmt.format(val/scale) for val in self[:n]) +
+                '\n'.join(ident + fmt.format(val / scale) for val in self[:n]) +
                '\n' + (ident + dotfmt.format(u"\u22EE")) +
-                '\n'.join(ident + fmt.format(val/scale) for val in self[-n:]) +
+                '\n'.join(ident + fmt.format(val / scale) for val in self[-n:]) +
                '\n')


@@ -295,4 +295,3 @@ def _str(self):
    strt += '[{} of size {}{}]\n'.format(torch.typename(self),
                                         size_str, device_str)
    return '\n' + strt
@@ -2,7 +2,9 @@ import threading
import torch.cuda
from .utils import THNN_H_PATH, THCUNN_H_PATH, parse_header, load_backend


class Backends(object):

    def __init__(self):
        self.backends = {}

@@ -14,6 +16,7 @@ class Backends(object):


class Backend(object):

    def __init__(self, lib_prefix, lib_name, functions, mixins=tuple()):
        self.lib_prefix = lib_prefix
        self.lib_name = lib_name
@@ -32,11 +35,12 @@ class Backend(object):
        with self.loading_lock:
            if self.backend is None:
                self.backend = load_backend(self.lib_prefix, self.lib_name,
                                            self.functions, self.mixins)
        return self.backend


class THNNCudaBackendStateMixin(object):

    @property
    def library_state(self):
        return torch.cuda._state_cdata

@@ -12,6 +12,7 @@ def _unpickle_backend(backend_name):


class THNNBackendBase(object):

    def __init__(self):
        self.methods = {}

@@ -33,6 +34,7 @@ class THNNBackendBase(object):


class Function(object):

    def __init__(self, name):
        self.name = name
        self.arguments = []
@@ -46,6 +48,7 @@ class Function(object):


class Argument(object):

    def __init__(self, _type, name, is_optional):
        self.type = _type
        self.name = name
(File diff suppressed because it is too large)
@@ -12,6 +12,7 @@ from .stochastic_function import StochasticFunction

__all__ = ['Variable', 'Function', 'StochasticFunction', 'backward']


def backward(variables, grad_variables, retain_variables=False):
    """Computes the sum of gradients of given variables w.r.t. graph leaves.

@@ -28,7 +29,7 @@ def backward(variables, grad_variables, retain_variables=False):
    Arguments:
        variables (sequence of Variable): Variables of which the derivative will be
            computed.
-        grad_variables (sequence of Variable): Gradients w.r.t. each element of
+        grad_variables (sequence of Tensor): Gradients w.r.t. each element of
            corresponding variables. Required only for non-scalar variables that
            require gradient.
        retain_variables (bool): If ``True``, buffers necessary for computing
@@ -37,6 +38,6 @@ def backward(variables, grad_variables, retain_variables=False):
            times.
    """
    Variable._execution_engine.run_backward(
        tuple(variables), tuple(grad_variables), retain_variables)

assert torch._C._autograd_init()
@@ -5,4 +5,4 @@ from .reduce import *
from .linalg import *
from .blas import *
from .stochastic import *
+from .compare import *
@@ -59,7 +59,7 @@ class Pow(Function):

    def backward(self, grad_output):
        a, b = self.saved_tensors
-        return grad_output.mul(b).mul_(a.pow(b-1)), grad_output.mul(a.pow(b)).mul_(a.log())
+        return grad_output.mul(b).mul_(a.pow(b - 1)), grad_output.mul(a.pow(b)).mul_(a.log())


class AddConstant(InplaceFunction):
@@ -174,7 +174,7 @@ class PowConstant(Function):
            return grad_output.mul(self.fw_result).mul_(math.log(self.constant))
        else:
            a = self.saved_tensors[0]
-            return grad_output.mul(self.constant).mul_(a.pow(self.constant-1))
+            return grad_output.mul(self.constant).mul_(a.pow(self.constant - 1))


class Negate(InplaceFunction):
@@ -25,7 +25,7 @@ class Addmm(_BlasBase):
        self.save_for_backward(matrix1, matrix2)
        output = self._get_output(add_matrix)
        return torch.addmm(self.alpha, add_matrix, self.beta,
                           matrix1, matrix2, out=output)

    def backward(self, grad_output):
        matrix1, matrix2 = self.saved_tensors
@@ -55,7 +55,7 @@ class Addbmm(_BlasBase):
        self.save_for_backward(batch1, batch2)
        output = self._get_output(add_matrix)
        return torch.addbmm(self.alpha, add_matrix, self.beta,
                            batch1, batch2, out=output)

    def backward(self, grad_output):
        batch1, batch2 = self.saved_tensors
@@ -68,8 +68,8 @@ class Addbmm(_BlasBase):

        if any(self.needs_input_grad[1:]):
            batch_grad_output = (grad_output
                                 .unsqueeze(0)
                                 .expand(batch1.size(0), batch1.size(1), batch2.size(2)))

            if self.needs_input_grad[1]:
                grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
@@ -90,7 +90,7 @@ class Baddbmm(_BlasBase):
        self.save_for_backward(batch1, batch2)
        output = self._get_output(add_batch)
        return torch.baddbmm(self.alpha, add_batch, self.beta,
                             batch1, batch2, out=output)

    def backward(self, grad_output):
        batch1, batch2 = self.saved_tensors
@@ -120,7 +120,7 @@ class Addmv(_BlasBase):
        self.save_for_backward(matrix, vector)
        output = self._get_output(add_vector)
        return torch.addmv(self.alpha, add_vector, self.beta,
                           matrix, vector, out=output)

    def backward(self, grad_output):
        matrix, vector = self.saved_tensors
@@ -150,7 +150,7 @@ class Addr(_BlasBase):
        self.save_for_backward(vector1, vector2)
        output = self._get_output(add_matrix)
        return torch.addr(self.alpha, add_matrix, self.beta,
                          vector1, vector2, out=output)

    def backward(self, grad_output):
        vector1, vector2 = self.saved_tensors
@@ -199,4 +199,3 @@ class Dot(Function):
    # TODO: trace
    # TODO: tril
    # TODO: triu
torch/autograd/_functions/compare.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import torch
|
||||
|
||||
from ..function import Function
|
||||
|
||||
|
||||
class _CompareOp(Function):
|
||||
|
||||
def __init__(self, scalar=None):
|
||||
super(_CompareOp, self).__init__()
|
||||
self.scalar = scalar
|
||||
|
||||
def forward(self, tensor1, tensor2=None):
|
||||
other = tensor2 if tensor2 is not None else self.scalar
|
||||
mask = getattr(tensor1, self.fn_name)(other)
|
||||
self.mark_non_differentiable(mask)
|
||||
return mask
|
||||
|
||||
|
||||
class Eq(_CompareOp):
|
||||
fn_name = 'eq'
|
||||
|
||||
|
||||
class Ne(_CompareOp):
|
||||
fn_name = 'ne'
|
||||
|
||||
|
||||
class Gt(_CompareOp):
|
||||
fn_name = 'gt'
|
||||
|
||||
|
||||
class Ge(_CompareOp):
|
||||
fn_name = 'ge'
|
||||
|
||||
|
||||
class Lt(_CompareOp):
|
||||
fn_name = 'lt'
|
||||
|
||||
|
||||
class Le(_CompareOp):
|
||||
fn_name = 'le'
|
||||
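
A quick sketch of how these ops surface in practice (hedged usage example, not
part of the diff; it assumes the Variable.eq/gt wrappers added later in this
changeset are in place):

    # comparisons return a non-differentiable mask Variable
    import torch
    from torch.autograd import Variable

    a = Variable(torch.Tensor([1, 2, 3]))
    b = Variable(torch.Tensor([3, 2, 1]))
    mask = a.eq(b)         # Eq()(a, b): element-wise equality mask
    scalar_mask = a.gt(2)  # Gt(2)(a): compare against a Python scalar
    # the mask is marked non-differentiable, so it stops gradient flow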
@@ -42,4 +42,3 @@ class Triu(Function):
        return grad_output.triu(self.diagonal_idx)

# TODO: trace

@@ -165,6 +165,7 @@ class Tan(Function):


class Asin(Function):

    def forward(self, i):
        self.save_for_backward(i)
        return i.asin()
@@ -175,6 +176,7 @@ class Asin(Function):


class Acos(Function):

    def forward(self, i):
        self.save_for_backward(i)
        return i.acos()
@@ -185,6 +187,7 @@ class Acos(Function):


class Atan(Function):

    def forward(self, i):
        self.save_for_backward(i)
        return i.atan()

@@ -4,6 +4,7 @@ from ..function import Function


class _DimReduceFunction(Function):

    def __init__(self, dim=None):
        super(_DimReduceFunction, self).__init__()
        self.dim = dim
@@ -139,6 +140,7 @@ class Kthvalue(_SelectionFunction):


class Norm(Function):

    def __init__(self, norm_type=2, dim=None):
        super(Norm, self).__init__()
        self.norm_type = norm_type

@@ -65,7 +65,7 @@ class Normal(StochasticFunction):
            output.mul_(stddevs)
        else:
            raise RuntimeError("Normal function requires specifying a common "
-                   "stddev, or per-sample stddev")
+                               "stddev, or per-sample stddev")
        output.add_(means)
        self.save_for_backward(output, means, stddevs)
        self.mark_non_differentiable(output)
@@ -74,7 +74,7 @@ class Normal(StochasticFunction):
    def backward(self, reward):
        output, means, stddevs = self.saved_tensors
        grad_stddevs = None
-       grad_means = means - output # == -(output - means)
+       grad_means = means - output  # == -(output - means)
        assert self.stddev is not None or stddevs is not None
        if self.stddev is not None:
            grad_means /= 1e-6 + self.stddev ** 2
@@ -88,4 +88,3 @@ class Normal(StochasticFunction):
        grad_means /= stddevs_sq
        grad_means *= reward
        return grad_means, grad_stddevs

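
This backward is the REINFORCE (score-function) estimator: the derivative of
-log N(x; mu, sigma^2) with respect to mu is (mu - x) / sigma^2, scaled by the
received reward. A minimal sketch of the same computation, assuming a fixed
scalar stddev (the helper name is hypothetical, for illustration only):

    def normal_reinforce_grad(output, means, stddev, reward):
        # (mu - x) / sigma^2 * reward, with the same 1e-6 guard as above
        return [(m - x) / (1e-6 + stddev ** 2) * reward
                for x, m in zip(output, means)]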
@@ -35,18 +35,18 @@ class SetItem(InplaceFunction):
        self.mark_dirty(i)
        if value is None:
            value = self.value
-       i.set_index(self.index, value)
+       i._set_index(self.index, value)
        return i

    def backward(self, grad_output):
        if self.value is None:
            grad_input = grad_output.clone()
-           grad_input.set_index(self.index, 0)
+           grad_input._set_index(self.index, 0)
            grad_value = grad_output.index(self.index).clone()
            return grad_input, grad_value
        else:
            grad_input = grad_output.clone()
-           grad_input.set_index(self.index, 0)
+           grad_input._set_index(self.index, 0)
            return grad_input


@@ -103,6 +103,7 @@ class View(Function):


class Expand(Function):

    def __init__(self, sizes):
        super(Expand, self).__init__()
        self.sizes = sizes
@@ -110,8 +111,8 @@ class Expand(Function):

    def forward(self, i):
        self.expanded_dims = [dim for dim, (expanded, original)
-               in enumerate(zip(self.sizes, i.size()))
-               if expanded != original]
+                              in enumerate(zip(self.sizes, i.size()))
+                              if expanded != original]
        result = i.expand(*self.sizes)
        self.mark_shared_storage((i, result))
        return result
@@ -288,7 +289,7 @@ class IndexSelect(Function):
        if self.needs_input_grad[0]:
            index, = self.saved_tensors
            grad_tensor = grad_output.new(*self.input_size).zero_()
-           grad_tensor.index_copy_(self.dim, index, grad_output)
+           grad_tensor.index_add_(self.dim, index, grad_output)

        return grad_tensor, None

@@ -304,8 +305,8 @@ class Concat(Function):
        return torch.cat(inputs, self.dim)

    def backward(self, grad_output):
-       return tuple(grad_output.narrow(self.dim, end-size, size) for size, end
-               in zip(self.input_sizes, _accumulate(self.input_sizes)))
+       return tuple(grad_output.narrow(self.dim, end - size, size) for size, end
+                    in zip(self.input_sizes, _accumulate(self.input_sizes)))


class Resize(Function):
@@ -318,11 +319,11 @@ class Resize(Function):
    def forward(self, tensor):
        if tensor.numel() != self.numel:
            raise RuntimeError(("requested resize to {} ({} elements in total), "
-                   "but the given tensor has a size of {} ({} elements). "
-                   "autograd's resize can only change the shape of a given "
-                   "tensor, while preserving the number of elements. ").format(
-                   'x'.join(map(str, self.sizes)), self.numel,
-                   'x'.join(map(str, tensor.size())), tensor.numel()))
+                                "but the given tensor has a size of {} ({} elements). "
+                                "autograd's resize can only change the shape of a given "
+                                "tensor, while preserving the number of elements. ").format(
+               'x'.join(map(str, self.sizes)), self.numel,
+               'x'.join(map(str, tensor.size())), tensor.numel()))
        self.input_sizes = tensor.size()
        result = tensor.new(tensor).resize_(*self.sizes)
        self.mark_shared_storage((tensor, result))
@@ -474,7 +475,7 @@ class _MultiSelectionFunction(Function):

class Sort(_MultiSelectionFunction):

-   def __init__(self, dim=None, descending=False, return_indices=False):
+   def __init__(self, dim=None, descending=False, return_indices=True):
        super(Sort, self).__init__(dim, return_indices)
        self.descending = descending

@@ -486,14 +487,14 @@ class Sort(_MultiSelectionFunction):

class Topk(_MultiSelectionFunction):

-   def __init__(self, k, dim=None, largest=True, sort=True, return_indices=False):
+   def __init__(self, k, dim=None, largest=True, sort=True, return_indices=True):
        super(Topk, self).__init__(dim, return_indices)
        self.k = k
        self.largest = largest
        self.sort = sort

    def forward(self, input):
-       dim = self.dim if self.dim is not None else input.dim()-1
+       dim = self.dim if self.dim is not None else input.dim() - 1
        self.args = (self.k, dim, self.largest, self.sort)
        return super(Topk, self).forward(input)

@@ -567,9 +568,22 @@ class Scatter(InplaceFunction):
        return grad_input, None, grad_source


-# TODO: kthvalue
-# TODO: repeat
-# TODO: sort
-# TODO: split
-# TODO: topk
+class Repeat(Function):
+
+    def __init__(self, repeats):
+        super(Repeat, self).__init__()
+        self.repeats = repeats
+
+    def forward(self, input):
+        return input.repeat(self.repeats)
+
+    def backward(self, grad_output):
+        grad_input = grad_output
+        for dim, repeat in enumerate(self.repeats):
+            if repeat == 1:
+                continue
+            grad_input = sum(grad_input.chunk(repeat, dim))
+        return grad_input
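
Why summing chunks is the right gradient for repeat: each input element feeds
`repeat` output slots along a dimension, so its gradient is the sum over those
slots. An illustrative one-dimensional sketch (plain Python lists standing in
for tensors; not part of the diff):

    def repeat_backward_1d(grad_output, repeat):
        # forward produced `repeat` copies along this dim; sum them back
        n = len(grad_output) // repeat
        chunks = [grad_output[i * n:(i + 1) * n] for i in range(repeat)]
        return [sum(vals) for vals in zip(*chunks)]

    assert repeat_backward_1d([1, 2, 1, 2], 2) == [2, 4]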


# TODO: unfold

@@ -71,8 +71,8 @@ class BasicEngine(object):
        else:
            if prev_fn.num_outputs != 1:
                raise RuntimeError("one of the function outputs "
-                       "wasn't used - this is an error not, but "
-                       "it's going to be fixed soon")
+                                   "wasn't used - this is an error not, but "
+                                   "it's going to be fixed soon")
            prev_grad = (d_prev_fn,)
            ready.appendleft((prev_fn, prev_grad))
    else:

@@ -154,9 +154,10 @@ def _nested_map(condition, fn):
            return type(obj)(_map(x) for x in obj)
        else:
            raise ValueError("NestedIOFunction doesn't know how to process "
-                   "an input object of type " + torch.typename(obj))
+                             "an input object of type " + torch.typename(obj))
    return _map


def _iter_filter(condition):
    def _iter(obj):
        if condition(obj):
@@ -169,17 +170,29 @@ def _iter_filter(condition):
                yield var
        else:
            raise ValueError("NestedIOFunction doesn't know how to process "
-                   "an input object of type " + torch.typename(obj))
+                             "an input object of type " + torch.typename(obj))
    return _iter


+def _unflatten(input, proto):
+    # unflatten a list or tuple input into a nested list/tuple structure
+    # specified by proto
+    def unflatten_helper(input, proto):
+        res = []
+        if not isinstance(proto, (list, tuple)):
+            return input[0], input[1:]
+        for e in proto:
+            res_e, input = unflatten_helper(input, e)
+            res.append(res_e)
+        return type(proto)(res), input
+
+    return unflatten_helper(input, proto)[0]
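
_unflatten walks a nested prototype and consumes one flat element per leaf.
A self-contained check of that behavior (a copy of the helper above, so the
assertion can run on its own):

    def _unflatten(input, proto):
        def unflatten_helper(input, proto):
            res = []
            if not isinstance(proto, (list, tuple)):
                return input[0], input[1:]
            for e in proto:
                res_e, input = unflatten_helper(input, e)
                res.append(res_e)
            return type(proto)(res), input
        return unflatten_helper(input, proto)[0]

    assert _unflatten((1, 2, 3), ('a', ('b', 'c'))) == (1, (2, 3))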

_iter_variables = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable))
_iter_tensors = _iter_filter(torch.is_tensor)
_iter_None_tensors = _iter_filter(lambda o: o is None or torch.is_tensor(o))
_map_variable_tensor = _nested_map(lambda o: isinstance(o, torch.autograd.Variable), lambda o: o.data)

-def _map_tensor_fromiter(itr):
-    return _nested_map(lambda o: torch.is_tensor(o), lambda o: next(itr))

class NestedIOFunction(Function):

@@ -188,11 +201,11 @@ class NestedIOFunction(Function):
        flat_input = tuple(_iter_variables(input))
        flat_output = super(NestedIOFunction, self)._do_forward(*flat_input)
        nested_output = self._nested_output
-       nested_variables = _map_tensor_fromiter(iter(flat_output))(self._nested_output)
+       nested_variables = _unflatten(flat_output, self._nested_output)
        return nested_variables

    def backward(self, *gradients):
-       nested_gradients = _map_tensor_fromiter(iter(gradients))(self._nested_output)
+       nested_gradients = _unflatten(gradients, self._nested_output)
        del self._nested_output
        result = self.backward_extended(*nested_gradients)
        del self._to_save_nested
@@ -214,7 +227,7 @@ class NestedIOFunction(Function):
    @property
    def saved_tensors(self):
        flat_tensors = super(NestedIOFunction, self).saved_tensors
-       return _map_tensor_fromiter(iter(flat_tensors))(self._to_save_nested)
+       return _unflatten(flat_tensors, self._to_save_nested)

    def mark_dirty(self, *args, **kwargs):
        self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))

@@ -2,6 +2,7 @@ from .function import Function

_NOT_PROVIDED = object()


class StochasticFunction(Function):

    def __init__(self):
@@ -10,7 +11,7 @@ class StochasticFunction(Function):
    def _do_backward(self, grad_output, retain_variables):
        if self.reward is _NOT_PROVIDED:
            raise RuntimeError("differentiating stochastic functions requires "
-                   "providing a reward")
+                               "providing a reward")
        result = super(StochasticFunction, self)._do_backward((self.reward,), retain_variables)
        if not retain_variables:
            self.reward = None
@@ -18,4 +19,3 @@ class StochasticFunction(Function):

    def _reinforce(self, reward):
        self.reward = reward

@@ -72,12 +72,12 @@ class Variable(_C._VariableBase):
        if self.creator is not None:
            if value is False:
                hint = (" If you want to use a computed variable in a subgraph "
-                       "that doesn't require differentiation use "
-                       "var_no_grad = var.detach().")
+                        "that doesn't require differentiation use "
+                        "var_no_grad = var.detach().")
            else:
                hint = ''
            raise RuntimeError("you can only change requires_grad flags of "
-                   "leaf variables." + hint)
+                               "leaf variables." + hint)
        self._requires_grad = value

    def __getattr__(self, name):
@@ -87,13 +87,13 @@ class Variable(_C._VariableBase):

    def __getitem__(self, key):
        if (isinstance(key, Variable) and
-               type(key.data).__name__ == 'ByteTensor'):
+                type(key.data).__name__ == 'ByteTensor'):
            return MaskedSelect()(self, key)
        return Index(key)(self)

    def __setitem__(self, key, value):
        if (isinstance(key, Variable) and
-               type(key.data).__name__ == 'ByteTensor'):
+                type(key.data).__name__ == 'ByteTensor'):
            if isinstance(value, Variable):
                return MaskedCopy(inplace=True)(self, key, value)
            else:
@@ -107,9 +107,9 @@ class Variable(_C._VariableBase):
    def __deepcopy__(self, memo):
        if self.creator is not None:
            raise RuntimeError("Only Variables created explicitly by the user "
-                   "(graph leaves) support the deepcopy protocol at the moment")
+                               "(graph leaves) support the deepcopy protocol at the moment")
        result = type(self)(self.data.clone(), requires_grad=self.requires_grad,
-               volatile=self.volatile)
+                            volatile=self.volatile)
        memo[id(self)] = result
        return result

@@ -151,7 +151,9 @@ class Variable(_C._VariableBase):
            raise RuntimeError('calling backward on a volatile variable')
        if gradient is None and self.requires_grad:
            if self.data.numel() != 1:
-               raise RuntimeError('backward should be called only on a scalar (i.e. 1-element tensor) or with gradient w.r.t. the variable')
+               raise RuntimeError(
+                   'backward should be called only on a scalar (i.e. 1-element tensor) '
+                   'or with gradient w.r.t. the variable')
            gradient = self.data.new().resize_as_(self.data).fill_(1)
        self._execution_engine.run_backward((self,), (gradient,), retain_variables)

@@ -219,7 +221,7 @@ class Variable(_C._VariableBase):
        """
        if not isinstance(self.creator, StochasticFunction):
            raise RuntimeError("reinforce() can be only called on outputs "
-                   "of stochastic functions")
+                               "of stochastic functions")
        self.creator._reinforce(reward)

    def detach(self):
@@ -392,7 +394,7 @@ class Variable(_C._VariableBase):
    def clamp(self, min=None, max=None):
        if min is None and max is None:
            raise ValueError("clamp requires specifying at least one of "
-                   "min and max arguments")
+                              "min and max arguments")
        elif min is None and max is not None:
            return CminConstant(max)(self)
        elif min is not None and max is None:
@@ -482,6 +484,40 @@ class Variable(_C._VariableBase):
    def view_as(self, tensor):
        return View(*tensor.size())(self)

+    def split(self, split_size, dim=0):
+        return torch.split(self, split_size, dim)
+
+    def chunk(self, n_chunks, dim=0):
+        return torch.chunk(self, n_chunks, dim)
+
+    def repeat(self, *repeats):
+        if len(repeats) == 1 and isinstance(repeats[0], torch.Size):
+            repeats = repeats[0]
+        else:
+            repeats = torch.Size(repeats)
+        return Repeat(repeats)(self)
+
+    def var(self, dim=None, unbiased=True):
+        mean = self.mean(dim)
+        if dim is None:
+            mean = mean.view(*(1 for s in self.size()))
+        mean_expanded = mean.expand_as(self)
+        zero_centered = self.sub(mean_expanded)
+        var = zero_centered.mul(zero_centered).sum(dim)
+        numel = self.numel() if dim is None else self.size(dim)
+        return var.div(numel - int(unbiased))
+
+    def std(self, dim=None, unbiased=True):
+        return self.var(dim, unbiased).sqrt()
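
A worked check of the unbiased/biased switch in var (hand-picked values, not
from the test suite): for data [1, 2, 3], the mean is 2 and the sum of squared
deviations is 2, so dividing by n - 1 = 2 gives 1.0 (unbiased=True) while
dividing by n = 3 gives about 0.6667 (unbiased=False):

    data = [1.0, 2.0, 3.0]
    mean = sum(data) / len(data)
    ss = sum((x - mean) ** 2 for x in data)
    assert ss / (len(data) - 1) == 1.0              # unbiased=True path
    assert abs(ss / len(data) - 2.0 / 3.0) < 1e-12  # unbiased=False path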
+
+    def renorm(self, norm_type, dim, maxnorm):
+        t = self.transpose(dim, 0)
+        flat = t.contiguous().view(self.size(0), -1)
+        norms = flat.norm(norm_type, 1)
+        norms = norms.clamp(max=maxnorm).div(norms.add(1e-7))
+        flat_out = flat.mul(norms.expand_as(flat))
+        return flat_out.view(t.size()).transpose(dim, 0)
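
Hedged usage sketch for renorm (values illustrative; assumes the Variable
wrapper above): it caps the norm of each slice taken along `dim` at `maxnorm`,
a common trick for constraining embedding rows.

    import torch
    from torch.autograd import Variable

    w = Variable(torch.Tensor([[3.0, 4.0], [0.3, 0.4]]))
    clipped = w.renorm(2, 0, 1.0)
    # row 0 has L2 norm 5 > 1, so it is rescaled to (just under) unit norm;
    # row 1 has norm 0.5 <= 1 and stays nearly unchanged, up to the 1e-7 guard.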

    @staticmethod
    def _static_blas(cls, args, inplace):
        num_args = len(args)
@@ -503,7 +539,7 @@ class Variable(_C._VariableBase):

    def bmm(self, batch):
        output = Variable(self.data.new(self.data.size(0), self.data.size(1),
-               batch.data.size(2)))
+                                        batch.data.size(2)))
        return self._static_blas(Baddbmm, (output, 0, 1, self, batch), False)

    def mv(self, vector):
@@ -622,7 +658,7 @@ class Variable(_C._VariableBase):
        if isinstance(sizes[0], torch.Size):
            if len(sizes) > 1:
                raise ValueError("expand expects a several ints or a single "
-                       "torch.Size argument")
+                                 "torch.Size argument")
            sizes = sizes[0]
        return Expand(sizes)(self)

@@ -641,7 +677,7 @@ class Variable(_C._VariableBase):

    def narrow(self, dim, start_index, length):
        index = tuple(slice(None, None) for _ in range(dim)) + \
-           (slice(start_index, start_index+length),)
+           (slice(start_index, start_index + length),)

        return Index(index)(self)

@@ -672,6 +708,42 @@ class Variable(_C._VariableBase):
    def bernoulli(self):
        return Bernoulli()(self)

+    def eq(self, other):
+        if isinstance(other, Variable):
+            return Eq()(self, other)
+        assert not torch.is_tensor(other), "can't compare Variable and tensor"
+        return Eq(other)(self)
+
+    def ne(self, other):
+        if isinstance(other, Variable):
+            return Ne()(self, other)
+        assert not torch.is_tensor(other), "can't compare Variable and tensor"
+        return Ne(other)(self)
+
+    def gt(self, other):
+        if isinstance(other, Variable):
+            return Gt()(self, other)
+        assert not torch.is_tensor(other), "can't compare Variable and tensor"
+        return Gt(other)(self)
+
+    def ge(self, other):
+        if isinstance(other, Variable):
+            return Ge()(self, other)
+        assert not torch.is_tensor(other), "can't compare Variable and tensor"
+        return Ge(other)(self)
+
+    def lt(self, other):
+        if isinstance(other, Variable):
+            return Lt()(self, other)
+        assert not torch.is_tensor(other), "can't compare Variable and tensor"
+        return Lt(other)(self)
+
+    def le(self, other):
+        if isinstance(other, Variable):
+            return Le()(self, other)
+        assert not torch.is_tensor(other), "can't compare Variable and tensor"
+        return Le(other)(self)
+
    def __add__(self, other):
        return self.add(other)
    __radd__ = __add__
@@ -710,7 +782,7 @@ class Variable(_C._VariableBase):
        elif dim_self == 2 and dim_other == 2:
            return self.mm(other)
        raise ValueError("both arguments to __matmul__ need to be 1D or 2D, "
-               "but they are {}D and {}D".format(dim_self, dim_other))
+                         "but they are {}D and {}D".format(dim_self, dim_other))

    def __div__(self, other):
        return self.div(other)
@@ -741,6 +813,30 @@ class Variable(_C._VariableBase):
    def __iter__(self):
        return iter(map(lambda i: self[i], range(self.size(0))))

+    def __mod__(self, other):
+        return self.remainder(other)
+
+    def __eq__(self, other):
+        return self.eq(other)
+
+    def __ne__(self, other):
+        return self.ne(other)
+
+    def __lt__(self, other):
+        return self.lt(other)
+
+    def __le__(self, other):
+        return self.le(other)
+
+    def __gt__(self, other):
+        return self.gt(other)
+
+    def __ge__(self, other):
+        return self.ge(other)
+
+    def __hash__(self):
+        return id(self)
+
    class _torch(object):

        @staticmethod
@@ -748,11 +844,11 @@ class Variable(_C._VariableBase):
            return Concat(dim)(*iterable)

        @staticmethod
-       def normal(means, stddev=1):
-           if isinstance(stddev, Variable):
-               return Normal()(means, stddev)
+       def normal(means, std=1):
+           if isinstance(std, Variable):
+               return Normal()(means, std)
            else:
-               return Normal(stddev)(means)
+               return Normal(std)(means)

        @staticmethod
        def _blas(cls, args, inplace):

@@ -14,12 +14,14 @@ lib = None
thisdir = path.dirname(__file__)
libpaths = ['', path.join(thisdir, '../../lib')]
if sys.platform.startswith('linux'):
-   libnames = ['libcudnn.so.5.1.5', 'libcudnn.so.5.1.3', 'libcudnn.so.5.0.5', 'libcudnn.so.5.1.10']
+   libnames = ['libcudnn.so.6.0.5', 'libcudnn.so.6.0.10', 'libcudnn.so.5.1.5', 'libcudnn.so.5.1.3',
+               'libcudnn.so.5.0.5', 'libcudnn.so.5.1.10']
elif sys.platform == 'darwin':
-   libnames = ['libcudnn.5.dylib']
+   libnames = ['libcudnn.6.dylib', 'libcudnn.5.dylib']
else:
    libnames = []


def _loadlib():
    global lib
    loaded = False
@@ -39,6 +41,7 @@ def _loadlib():
        lib = None
        raise OSError("Could not load cuDNN")


def is_acceptable(tensor):
    if not enabled:
        return False
@@ -58,13 +61,15 @@ def is_acceptable(tensor):
        return False
    if not _C.has_cudnn:
        warnings.warn("cuDNN library has been detected, but your pytorch "
-               "installation was compiled without support for it. You "
-               "might want to rebuild pytorch, making sure the library "
-               "is visible to the build system.")
+                      "installation was compiled without support for it. You "
+                      "might want to rebuild pytorch, making sure the library "
+                      "is visible to the build system.")
        return False
    return True

__cudnn_version = []


def version():
    if not lib:
        raise RuntimeError("cuDNN not initialized")
@@ -108,7 +113,16 @@ CUDNN_GRU = 3
CUDNN_LINEAR_INPUT = 0
CUDNN_SKIP_INPUT = 1

+CUDNN_NON_DETERMINISTIC = 0
+CUDNN_DETERMINISTIC = 1
+
+CUDNN_RNN_ALGO_STANDARD = 0
+CUDNN_RNN_ALGO_PERSIST_STATIC = 1
+CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2
+

class CuDNNHandle:

    def __init__(self):
        ptr = ctypes.c_void_p()
        check_error(lib.cudnnCreate(ctypes.byref(ptr)))
@@ -117,7 +131,9 @@ class CuDNNHandle:
    def __del__(self):
        check_error(lib.cudnnDestroy(self))


class CuDNNError(RuntimeError):

    def __init__(self, status):
        self.status = status
        msg = '{}: {}'.format(status, get_error_string(status))
@@ -125,6 +141,7 @@ class CuDNNError(RuntimeError):


class TensorDescriptor(object):

    def __init__(self):
        ptr = ctypes.c_void_p()
        check_error(lib.cudnnCreateTensorDescriptor(ctypes.byref(ptr)))
@@ -147,6 +164,7 @@ class TensorDescriptor(object):


class TensorDescriptorArray(object):

    def __init__(self, N):
        self.ptrs = (ctypes.c_void_p * N)()
        for i in range(N):
@@ -175,6 +193,7 @@ class TensorDescriptorArray(object):


class ConvolutionDescriptor(object):

    def __init__(self):
        ptr = ctypes.c_void_p()
        check_error(lib.cudnnCreateConvolutionDescriptor(ctypes.byref(ptr)))
@@ -195,7 +214,9 @@ class ConvolutionDescriptor(object):
    def as_tuple(self):
        return (self._pad, self._stride)


class FilterDescriptor(object):

    def __init__(self):
        ptr = ctypes.c_void_p()
        check_error(lib.cudnnCreateFilterDescriptor(ctypes.byref(ptr)))
@@ -216,6 +237,7 @@ class FilterDescriptor(object):


class DropoutDescriptor(object):

    def __init__(self, handle, dropout, seed):
        ptr = ctypes.c_void_p()
        check_error(lib.cudnnCreateDropoutDescriptor(ctypes.byref(ptr)))
@@ -241,30 +263,43 @@ class DropoutDescriptor(object):
        check_error(lib.cudnnDestroyDropoutDescriptor(self))


class RNNDescriptor(object):
-   def __init__(self, hidden_size, num_layers, dropout_desc, input_mode,
-           bidirectional, mode, datatype):
+
+   def __init__(self, handle, hidden_size, num_layers, dropout_desc, input_mode,
+                bidirectional, mode, datatype):
        ptr = ctypes.c_void_p()
        check_error(lib.cudnnCreateRNNDescriptor(ctypes.byref(ptr)))
        self._as_parameter_ = ptr

-       check_error(lib.cudnnSetRNNDescriptor(
-           self,
-           hidden_size,
-           num_layers,
-           dropout_desc,
-           input_mode,
-           bidirectional,
-           mode,
-           datatype
-       ))
+       if version() >= 6000:
+           check_error(lib.cudnnSetRNNDescriptor_v6(
+               handle,
+               self,
+               hidden_size,
+               num_layers,
+               dropout_desc,
+               input_mode,
+               bidirectional,
+               mode,
+               CUDNN_RNN_ALGO_STANDARD,
+               datatype
+           ))
+       else:
+           check_error(lib.cudnnSetRNNDescriptor(
+               self,
+               hidden_size,
+               num_layers,
+               dropout_desc,
+               input_mode,
+               bidirectional,
+               mode,
+               datatype
+           ))

    def __del__(self):
        check_error(lib.cudnnDestroyRNNDescriptor(self))
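
A minimal sketch of the version gate used above (the helper is hypothetical;
cuDNN encodes its version as major*1000 + minor*100 + patch, so 5.1.5 is 5105
and 6.0.5 is 6005):

    def set_rnn_descriptor(lib, handle, desc, common_args, datatype, cudnn_version):
        # the v6 entry point also takes the handle and an explicit RNN
        # algorithm selector (0 == CUDNN_RNN_ALGO_STANDARD)
        if cudnn_version >= 6000:
            args = (handle, desc) + tuple(common_args) + (0, datatype)
            return lib.cudnnSetRNNDescriptor_v6(*args)
        args = (desc,) + tuple(common_args) + (datatype,)
        return lib.cudnnSetRNNDescriptor(*args)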

-class ConvolutionAlgoPerf(ctypes.Structure):
+class ConvolutionAlgoPerf_v5(ctypes.Structure):
    _fields_ = [
        ("algo", ctypes.c_int),
        ("status", ctypes.c_int),
@@ -272,13 +307,27 @@ class ConvolutionAlgoPerf(ctypes.Structure):
        ("memory", ctypes.c_size_t),
    ]


+class ConvolutionAlgoPerf_v6(ctypes.Structure):
+   _fields_ = [
+       ("algo", ctypes.c_int),
+       ("status", ctypes.c_int),
+       ("time", ctypes.c_float),
+       ("memory", ctypes.c_size_t),
+       ("determinism", ctypes.c_int),
+       ("reserved", ctypes.c_int * 4)
+   ]
+

def check_error(status):
    if status is not 0:
        raise CuDNNError(status)


def get_error_string(status):
    return lib.cudnnGetErrorString(status)


def get_handle():
    if lib is None:
        _loadlib()
@@ -296,11 +345,12 @@ _typemap = {
}

_sizeofmap = {
-   CUDNN_DATA_HALF : 2,
-   CUDNN_DATA_FLOAT : 4,
-   CUDNN_DATA_DOUBLE : 8,
+   CUDNN_DATA_HALF: 2,
+   CUDNN_DATA_FLOAT: 4,
+   CUDNN_DATA_DOUBLE: 8,
}


def c_type(tensor):
    if isinstance(tensor, torch.cuda.HalfTensor):
        return ctypes.c_float
@@ -311,10 +361,12 @@ def c_type(tensor):
    else:
        raise ValueError("unknown type '{}'".format(type(tensor)))


def int_array(itr):
    array_type = ctypes.c_int * len(itr)
    return array_type(*itr)


def descriptor(tensor, N=None):
    if N is not None:
        descriptor = TensorDescriptorArray(N)
@@ -331,16 +383,21 @@ _autotuner_forward = {}
_autotuner_backward_data = {}
_autotuner_backward_filter = {}


def convolution_autotuner_key(idesc, weight_desc, conv_desc):
    return (idesc.as_tuple(), weight_desc.as_tuple(), conv_desc.as_tuple())


def convolution_forward_algorithm(idesc, weight_desc, conv_desc, odesc):
    k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
    if k in _autotuner_forward:
        return _autotuner_forward[k]

    if benchmark:
-       perf_results = ConvolutionAlgoPerf()
+       if version() < 6000:
+           perf_results = ConvolutionAlgoPerf_v5()
+       else:
+           perf_results = ConvolutionAlgoPerf_v6()
        algo_count = ctypes.c_int()
        check_error(lib.cudnnFindConvolutionForwardAlgorithm(
            get_handle(), idesc, weight_desc, conv_desc, odesc, 1,
@@ -360,15 +417,19 @@ def convolution_forward_algorithm(idesc, weight_desc, conv_desc, odesc):
        wlimit, ctypes.byref(fwd_alg)))
    return fwd_alg


def convolution_forward_workspace_size(*args):
    check_error(lib.cudnnGetConvolutionForwardWorkspaceSize(*args))


def convolution_forward(*args):
    check_error(lib.cudnnConvolutionForward(*args))


def convolution_backward_data(*args):
    return check_error(lib.cudnnConvolutionBackwardData(*args))


def convolution_backward_data_algorithm(weight_desc, odesc, conv_desc, idesc):
    k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
    if k in _autotuner_backward_data:
@@ -395,12 +456,15 @@ def convolution_backward_data_algorithm(weight_desc, odesc, conv_desc, idesc):
        wlimit, ctypes.byref(bwd_data_alg)))
    return bwd_data_alg


def convolution_backward_data_workspace_size(*args):
    return check_error(lib.cudnnGetConvolutionBackwardDataWorkspaceSize(*args))


def convolution_backward_filter(*args):
    return check_error(lib.cudnnConvolutionBackwardFilter(*args))


def convolution_backward_filter_algorithm(idesc, odesc, conv_desc, weight_desc):
    k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
    if k in _autotuner_backward_filter:
@@ -427,11 +491,14 @@ def convolution_backward_filter_algorithm(idesc, odesc, conv_desc, weight_desc):
        wlimit, ctypes.byref(bwd_filter_alg)))
    return bwd_filter_alg


def convolution_backward_filter_workspace_size(*args):
    return check_error(lib.cudnnGetConvolutionBackwardFilterWorkspaceSize(*args))


def convolution_backward_bias(*args):
    check_error(lib.cudnnConvolutionBackwardBias(*args))


def add_tensor(*args):
    check_error(lib.cudnnAddTensor(*args))

@@ -3,6 +3,7 @@ import torch.backends.cudnn as cudnn
from torch.backends.cudnn import check_error
import ctypes


def get_cudnn_mode(mode):
    if mode == 'RNN_RELU':
        return cudnn.CUDNN_RNN_RELU
@@ -17,9 +18,10 @@ def get_cudnn_mode(mode):


class Unserializable(object):

    def __init__(self, inner):
        self.inner = inner

    def get(self):
        return self.inner

@@ -39,8 +41,10 @@ def init_dropout_descriptor(fn, handle):
        fn.dropout_seed
    )

-def init_rnn_descriptor(fn):
+
+def init_rnn_descriptor(fn, handle):
    return cudnn.RNNDescriptor(
+       handle,
        fn.hidden_size,
        fn.num_layers,
        fn.dropout_state['desc'].get(),
@@ -80,7 +84,7 @@ def get_num_weights(handle, rnn_desc, x_desc, datatype):
        datatype
    ))
    elem_size = cudnn._sizeofmap[datatype]
-   assert(weight_size.value % elem_size == 0)
+   assert weight_size.value % elem_size == 0
    return weight_size.value // elem_size


@@ -139,10 +143,11 @@ def get_parameters(fn, handle, weight_buf):
                ctypes.byref(nb_dims),
                ctypes.c_void_p(filter_dim_a.data_ptr())))

-           filter_dim_a.resize_(nb_dims.value)
+           assert nb_dims.value <= min_dim
+           filter_dim_a = filter_dim_a[:nb_dims.value]
            elem_size = cudnn._sizeofmap[fn.datatype]
            offset_bytes = (matrix_pointer.value - weight_buf.data_ptr())
-           assert(offset_bytes % elem_size == 0)
+           assert offset_bytes % elem_size == 0
            offset = offset_bytes // elem_size

            # for all the RNN types provided by CUDNN, all the ih weights
@@ -151,17 +156,16 @@ def get_parameters(fn, handle, weight_buf):
            # Since we're storing all the weights in a single tensor anyway,
            # might as well merge the CUDNN ones into a single tensor as well
            if linear_id == 0 or linear_id == num_linear_layers / 2:
-               assert(filter_dim_a.prod() == filter_dim_a[0])
+               assert filter_dim_a.prod() == filter_dim_a[0]
                param = fn.weight_buf.new().set_(
                    weight_buf.storage(), offset,
                    filter_dim_a[0] * num_linear_layers // 2, filter_dim_a[2])
                layer_params.append(param)
            else:
-               assert(cur_offset == offset)
+               assert cur_offset == offset

            cur_offset = offset + filter_dim_a[0]

        params.append(layer_params)

    return params
@@ -170,7 +174,7 @@ def get_parameters(fn, handle, weight_buf):
def _copyParams(params_from, params_to):
    for layer_params_from, layer_params_to in zip(params_from, params_to):
        for param_from, param_to in zip(layer_params_from, layer_params_to):
-           assert(param_from.type() == param_to.type())
+           assert param_from.type() == param_to.type()
            param_to.copy_(param_from)


@@ -204,9 +208,9 @@ def forward(fn, input, hx, weight, output, hy):
    output_size = _output_size(fn)
    x = input.contiguous()
    output.resize_(*output_size)
-   hy.resize_(*hidden_size).zero_()
+   hy.resize_(*hidden_size)
    if cy is not None:
-       cy.resize_(*hidden_size).zero_()
+       cy.resize_(*hidden_size)
    y = output

    # init descriptors
@@ -214,7 +218,7 @@ def forward(fn, input, hx, weight, output, hy):
    fn.dropout_state['desc'] = Unserializable(
        init_dropout_descriptor(fn, handle)
    )
-   fn.rnn_desc = init_rnn_descriptor(fn)
+   fn.rnn_desc = init_rnn_descriptor(fn, handle)
    fn.x_descs = cudnn.descriptor(x[0], fn.seq_length)
    fn.y_descs = cudnn.descriptor(y[0], fn.seq_length)
    fn.hx_desc = cudnn.descriptor(hx)
@@ -237,7 +241,7 @@ def forward(fn, input, hx, weight, output, hy):

    if tuple(hx.size()) != hidden_size:
        raise RuntimeError('Expected hidden size {}, got {}'.format(
-               hidden_size, tuple(hx.size())))
+           hidden_size, tuple(hx.size())))
    if cx is not None and tuple(cx.size()) != hidden_size:
        raise RuntimeError('Expected cell size {}, got {}'.format(
            hidden_size, tuple(cx.size())))
@@ -295,7 +299,6 @@ def forward(fn, input, hx, weight, output, hy):
        output = output.transpose_(0, 1)


def backward_grad(fn, input, hx, weight, output, grad_output, grad_hy, grad_input, grad_hx):
    with torch.cuda.device_of(input):
        handle = cudnn.get_handle()
@@ -321,8 +324,8 @@ def backward_grad(fn, input, hx, weight, output, grad_output, grad_hy, grad_inpu
        y = output
        w = fn.weight_buf
        dx = grad_input.resize_as_(input)
-       dhy = grad_hy.resize_(*hidden_size)
-       dcy = grad_cy.resize_(*hidden_size) if grad_cy is not None else None
+       dhy = grad_hy.contiguous().view(*hidden_size)
+       dcy = grad_cy.contiguous().view(*hidden_size) if grad_cy is not None else None
        dhx = grad_hx.resize_(*hidden_size)
        dcx = grad_cx.resize_(*hidden_size) if grad_cx is not None else None

@@ -697,8 +697,30 @@ bool THCSPShortTensor_init(PyObject *module);
bool THCSPCharTensor_init(PyObject *module);
bool THCSPByteTensor_init(PyObject *module);

+bool THDPDoubleStorage_init(PyObject *module);
+bool THDPFloatStorage_init(PyObject *module);
+//bool THDPHalfStorage_init(PyObject *module);
+bool THDPLongStorage_init(PyObject *module);
+bool THDPIntStorage_init(PyObject *module);
+bool THDPShortStorage_init(PyObject *module);
+bool THDPCharStorage_init(PyObject *module);
+bool THDPByteStorage_init(PyObject *module);
+
+bool THDPDoubleTensor_init(PyObject *module);
+bool THDPFloatTensor_init(PyObject *module);
+//bool THDPHalfTensor_init(PyObject *module);
+bool THDPLongTensor_init(PyObject *module);
+bool THDPIntTensor_init(PyObject *module);
+bool THDPShortTensor_init(PyObject *module);
+bool THDPCharTensor_init(PyObject *module);
+bool THDPByteTensor_init(PyObject *module);
+
static std::vector<PyMethodDef> methods;

+#ifdef WITH_DISTRIBUTED
+PyMethodDef* THDPModule_methods();
+#endif
+
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_C()
#else
@@ -716,6 +738,9 @@ PyMODINIT_FUNC PyInit__C()
#ifdef WITH_CUDNN
  THPUtils_addPyMethodDefs(methods, THCUDNN_methods());
#endif
+#ifdef WITH_DISTRIBUTED
+  THPUtils_addPyMethodDefs(methods, THDPModule_methods());
+#endif

#if PY_MAJOR_VERSION == 2
  ASSERT_TRUE(module = Py_InitModule("torch._C", methods.data()));
@@ -729,6 +754,7 @@ PyMODINIT_FUNC PyInit__C()
  };
  ASSERT_TRUE(module = PyModule_Create(&torchmodule));
#endif
+  ASSERT_TRUE(THPWrapper_init(module));
  ASSERT_TRUE(THPGenerator_init(module));
  ASSERT_TRUE(THPException_init(module));
  ASSERT_TRUE(THPSize_init(module));
@@ -796,7 +822,6 @@ PyMODINIT_FUNC PyInit__C()
#endif

#ifdef WITH_CUDNN
-  ASSERT_TRUE(THCUDNNModule_initModule(module));
  PyObject *has_cudnn = Py_True;
#else
  PyObject *has_cudnn = Py_False;
@@ -804,6 +829,28 @@ PyMODINIT_FUNC PyInit__C()
  Py_INCREF(has_cudnn);
  ASSERT_TRUE(PyModule_AddObject(module, "has_cudnn", has_cudnn) == 0);

+  // TODO THD: enable once master-worker mode is implemented
+#if 0 && defined(WITH_DISTRIBUTED)
+  // See comment on CUDA objects
+  ASSERT_TRUE(THDPDoubleStorage_init(module));
+  ASSERT_TRUE(THDPFloatStorage_init(module));
+  //ASSERT_TRUE(THDPHalfStorage_init(module));
+  ASSERT_TRUE(THDPLongStorage_init(module));
+  ASSERT_TRUE(THDPIntStorage_init(module));
+  ASSERT_TRUE(THDPShortStorage_init(module));
+  ASSERT_TRUE(THDPCharStorage_init(module));
+  ASSERT_TRUE(THDPByteStorage_init(module));
+
+  ASSERT_TRUE(THDPDoubleTensor_init(module));
+  ASSERT_TRUE(THDPFloatTensor_init(module));
+  //ASSERT_TRUE(THDPHalfTensor_init(module));
+  ASSERT_TRUE(THDPLongTensor_init(module));
+  ASSERT_TRUE(THDPIntTensor_init(module));
+  ASSERT_TRUE(THDPShortTensor_init(module));
+  ASSERT_TRUE(THDPCharTensor_init(module));
+  ASSERT_TRUE(THDPByteTensor_init(module));
+#endif
+
  THPDefaultGenerator = (THPGenerator*)THPGenerator_New();
  ASSERT_TRUE(THPDefaultGenerator != nullptr);
  ASSERT_TRUE(PyModule_AddObject(module, "default_generator", (PyObject*)THPDefaultGenerator) == 0);

@@ -52,7 +52,7 @@ static void THPWrapper_dealloc(THPWrapper* self)

PyTypeObject THPWrapperType = {
  PyVarObject_HEAD_INIT(NULL, 0)
-  "torch._C._CppWrapper",                /* tp_name */
+  "torch._C._PtrWrapper",                /* tp_name */
  sizeof(THPWrapper),                    /* tp_basicsize */
  0,                                     /* tp_itemsize */
  (destructor)THPWrapper_dealloc,        /* tp_dealloc */
@@ -1,5 +1,5 @@
-#ifndef THP_CUDNN_CPP_WRAPPER_INC
-#define THP_CUDNN_CPP_WRAPPER_INC
+#ifndef THP_PTR_WRAPPER_H
+#define THP_PTR_WRAPPER_H

#include <functional>

@@ -24,18 +24,17 @@ PyObject * THPSize_New(int dim, long *sizes)

static PyObject * THPSize_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
-  PyObject *self = PyTuple_Type.tp_new(type, args, kwargs);
+  THPObjectPtr self = PyTuple_Type.tp_new(type, args, kwargs);
  if (self) {
    for (Py_ssize_t i = 0; i < PyTuple_Size(self); ++i) {
-      PyObject *item = PyTuple_GET_ITEM(self, i);
+      PyObject *item = PyTuple_GET_ITEM(self.get(), i);
      if (!THPUtils_checkLong(item)) {
-        Py_DECREF(self);
        return PyErr_Format(PyExc_TypeError, "torch.Size() takes an iterable of 'int' (item %zd is '%s')",
            i, Py_TYPE(item)->tp_name);
      }
    }
  }
-  return self;
+  return self.release();
}

static PyObject * THPSize_repr(THPSize *self)

@@ -21,6 +21,7 @@

#define THP_API extern "C"

+#include "PtrWrapper.h"
#include "Exceptions.h"
#include "Generator.h"
#include "Storage.h"

@@ -1,6 +1,7 @@
#include "BatchNorm.h"

#include "Descriptors.h"
+#include "Types.h"


namespace torch { namespace cudnn {
@@ -78,6 +79,11 @@ void cudnn_batch_norm_forward(
  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  if (training) {
+    THVoidTensor_assertContiguous(bias);
+    THVoidTensor_assertContiguous(running_mean);
+    THVoidTensor_assertContiguous(running_var);
+    THVoidTensor_assertContiguous(save_mean);
+    THVoidTensor_assertContiguous(save_var);
    CHECK(cudnnBatchNormalizationForwardTraining(
      handle, mode, &one, &zero,
      idesc.desc, tensorPointer(dataType, input),
@@ -91,6 +97,9 @@ void cudnn_batch_norm_forward(
      tensorPointer(dataType, save_mean),
      tensorPointer(dataType, save_var)));
  } else {
+    THVoidTensor_assertContiguous(bias);
+    THVoidTensor_assertContiguous(running_mean);
+    THVoidTensor_assertContiguous(running_var);
    CHECK(cudnnBatchNormalizationForwardInference(
      handle, mode, &one, &zero,
      idesc.desc, tensorPointer(dataType, input),
@@ -129,6 +138,10 @@ void cudnn_batch_norm_backward(

  Constant one(dataType, 1);
  Constant zero(dataType, 0);
+  THVoidTensor_assertContiguous(grad_weight);
+  THVoidTensor_assertContiguous(grad_bias);
+  THVoidTensor_assertContiguous(save_mean);
+  THVoidTensor_assertContiguous(save_var);
  CHECK(cudnnBatchNormalizationBackward(
    handle, mode, &one, &zero, &one, &one,
    idesc.desc, tensorPointer(dataType, input),

@@ -2,6 +2,7 @@

#include "THC/THC.h"
+#include "Exceptions.h"
#include "Types.h"

#include <cudnn.h>
#include <functional>
@@ -31,6 +32,7 @@ void setWeightDescriptor(FilterDescriptor& desc, cudnnDataType_t dataType, THVoi
{
  CHECK_ARG(weight->nDimension <= 5);
  int weightSize[5];
+  THVoidTensor_assertContiguous(weight);
  for (int i = 0; i < weight->nDimension; ++i) {
    weightSize[i] = (int) weight->size[i];
  }
@@ -63,13 +65,13 @@ struct BenchmarkCache {
  std::mutex mutex;
  std::unordered_map<ConvolutionParams, T, ParamsHash, ParamsEqual> map;

-  bool find(const ConvolutionParams& params, T& results) {
+  bool find(const ConvolutionParams& params, T* results) {
    std::lock_guard<std::mutex> guard(mutex);
    auto it = map.find(params);
    if (it == map.end()) {
      return false;
    }
-    results = it->second;
+    *results = it->second;
    return true;
  }

@@ -84,77 +86,145 @@ BenchmarkCache<cudnnConvolutionBwdDataAlgo_t> bwd_data_algos;
BenchmarkCache<cudnnConvolutionBwdFilterAlgo_t> bwd_filter_algos;

struct Workspace {
-  void* data;
-  THCState* state;
-  Workspace(THCState* state, size_t size) : data(NULL), state(state) {
+  Workspace(THCState* state, size_t size) : state(state), size(size), data(NULL) {
    CUDA_CHECK(THCudaMalloc(state, &data, size));
  }
+  Workspace(const Workspace&) = delete;
+  Workspace(Workspace&&) = default;
  ~Workspace() {
-    THCudaFree(state, data);
+    if (data) {
+      THCudaFree(state, data);
+    }
  }

+  THCState* state;
+  size_t size;
+  void* data;
};

-cudnnConvolutionFwdAlgo_t chooseForwardAlgorithm(
-    cudnnHandle_t handle, const Convolution& conv, bool benchmark)
-{
-  cudnnConvolutionFwdAlgo_t algo;
-  if (benchmark) {
-    if (fwd_algos.find(conv.params, algo)) {
-      return algo;
-    }
+template<typename algo_t>
+struct algorithm_search {
+};
+
+template<>
+struct algorithm_search<cudnnConvolutionFwdAlgo_t> {
+  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+  static BenchmarkCache<cudnnConvolutionFwdAlgo_t>& cache() {
+    return fwd_algos;
+  }
+
+  static cudnnConvolutionFwdAlgoPerf_t findAlgorithm(cudnnHandle_t handle, const Convolution& conv) {
    int algoCount;
    cudnnConvolutionFwdAlgoPerf_t perfResults;
    CHECK(cudnnFindConvolutionForwardAlgorithm(handle, conv.idesc.desc,
        conv.wdesc.desc, conv.cdesc.desc, conv.odesc.desc, 1, &algoCount, &perfResults));
-    fwd_algos.insert(conv.params, perfResults.algo);
-    return perfResults.algo;
+    return perfResults;
  }
-  cudnnConvolutionFwdPreference_t pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
-  CHECK(cudnnGetConvolutionForwardAlgorithm(handle, conv.idesc.desc,
-      conv.wdesc.desc, conv.cdesc.desc, conv.odesc.desc, pref, 0, &algo));
-  return algo;
-}

-cudnnConvolutionBwdDataAlgo_t chooseBackwardDataAlgorithm(
-    cudnnHandle_t handle, const Convolution& conv, bool benchmark)
-{
-  cudnnConvolutionBwdDataAlgo_t algo;
-  if (benchmark) {
-    if (bwd_data_algos.find(conv.params, algo)) {
-      return algo;
-    }
+  static void getAlgorithm(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionFwdAlgo_t* algo) {
+    cudnnConvolutionFwdPreference_t pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
+    CHECK(cudnnGetConvolutionForwardAlgorithm(handle, conv.idesc.desc,
+        conv.wdesc.desc, conv.cdesc.desc, conv.odesc.desc, pref, 0, algo));
+  }
+
+  static void getWorkspaceSize(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionFwdAlgo_t algo, size_t* workspaceSize) {
+    CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle, conv.idesc.desc, conv.wdesc.desc,
+        conv.cdesc.desc, conv.odesc.desc, algo, workspaceSize));
+  }
+};
+
+template<>
+struct algorithm_search<cudnnConvolutionBwdDataAlgo_t> {
+  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
+  static BenchmarkCache<cudnnConvolutionBwdDataAlgo_t>& cache() {
+    return bwd_data_algos;
+  }
+
+  static cudnnConvolutionBwdDataAlgoPerf_t findAlgorithm(cudnnHandle_t handle, const Convolution& conv) {
    int algoCount;
    cudnnConvolutionBwdDataAlgoPerf_t perfResults;
    CHECK(cudnnFindConvolutionBackwardDataAlgorithm(handle, conv.wdesc.desc,
        conv.odesc.desc, conv.cdesc.desc, conv.idesc.desc, 1, &algoCount, &perfResults));
-    bwd_data_algos.insert(conv.params, perfResults.algo);
-    return perfResults.algo;
+    return perfResults;
  }
-  cudnnConvolutionBwdDataPreference_t pref = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
-  CHECK(cudnnGetConvolutionBackwardDataAlgorithm(handle, conv.wdesc.desc,
-      conv.odesc.desc, conv.cdesc.desc, conv.idesc.desc, pref, 0, &algo));
-  return algo;
-}

-cudnnConvolutionBwdFilterAlgo_t chooseBackwardFilterAlgorithm(
-    cudnnHandle_t handle, const Convolution& conv, bool benchmark)
-{
-  cudnnConvolutionBwdFilterAlgo_t algo;
-  if (benchmark) {
-    if (bwd_filter_algos.find(conv.params, algo)) {
-      return algo;
-    }
+  static void getAlgorithm(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdDataAlgo_t* algo) {
+    CHECK(cudnnGetConvolutionBackwardDataAlgorithm(handle, conv.wdesc.desc,
+        conv.odesc.desc, conv.cdesc.desc, conv.idesc.desc,
+        CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, algo));
+  }
+
+  static void getWorkspaceSize(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdDataAlgo_t algo, size_t* workspaceSize) {
+    CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(handle, conv.wdesc.desc,
+        conv.odesc.desc, conv.cdesc.desc, conv.idesc.desc, algo,
+        workspaceSize));
+  }
+};
+
+template<>
+struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
+  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
+  static BenchmarkCache<cudnnConvolutionBwdFilterAlgo_t>& cache() {
+    return bwd_filter_algos;
+  }
+
+  static cudnnConvolutionBwdFilterAlgoPerf_t findAlgorithm(cudnnHandle_t handle, const Convolution& conv) {
    int algoCount;
    cudnnConvolutionBwdFilterAlgoPerf_t perfResults;
    CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(handle, conv.idesc.desc,
        conv.odesc.desc, conv.cdesc.desc, conv.wdesc.desc, 1, &algoCount, &perfResults));
-    bwd_filter_algos.insert(conv.params, perfResults.algo);
-    return perfResults.algo;
+    return perfResults;
  }
+
+  static void getAlgorithm(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdFilterAlgo_t* algo) {
+    CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(handle, conv.idesc.desc,
+        conv.odesc.desc, conv.cdesc.desc, conv.wdesc.desc,
+        CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, algo));
+  }
+
+  static void getWorkspaceSize(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdFilterAlgo_t algo, size_t* workspaceSize) {
+    CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle, conv.idesc.desc,
+        conv.odesc.desc, conv.cdesc.desc, conv.wdesc.desc, algo, workspaceSize));
+  }
+};
+
+template<typename algo_t>
+Workspace chooseAlgorithm(
+    THCState* state, cudnnHandle_t handle, const Convolution& conv,
+    bool benchmark, algo_t* algo)
+{
+  using search = algorithm_search<algo_t>;
+  auto& cache = search::cache();
+
+  if (!cache.find(conv.params, algo)) {
+    if (benchmark) {
+      auto perfResults = search::findAlgorithm(handle, conv);
+      if (perfResults.status == CUDNN_STATUS_SUCCESS) {
+        *algo = perfResults.algo;
+      } else {
+        *algo = search::DEFAULT_ALGO;
+      }
+      cache.insert(conv.params, *algo);
+    } else {
+      search::getAlgorithm(handle, conv, algo);
+    }
+  }
+
+  size_t workspace_size;
+  search::getWorkspaceSize(handle, conv, *algo, &workspace_size);
+  try {
+    return Workspace(state, workspace_size);
+  } catch (std::runtime_error& e) {
+    cudaGetLastError(); // clear OOM error
+
+    // switch to default algorithm and record it in the cache to prevent
+    // further OOM errors
+    *algo = search::DEFAULT_ALGO;
+    cache.insert(conv.params, *algo);
+
+    search::getWorkspaceSize(handle, conv, *algo, &workspace_size);
+    return Workspace(state, workspace_size);
+  }
-  cudnnConvolutionBwdFilterPreference_t pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST;
-  CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(handle, conv.idesc.desc,
-      conv.odesc.desc, conv.cdesc.desc, conv.wdesc.desc, pref, 0, &algo));
-  return algo;
}
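
The same selection policy restated as a short Python sketch (all names are
illustrative stand-ins, not part of this codebase; `alloc` models the CUDA
workspace allocation and `MemoryError` its out-of-memory failure):

    def choose_algorithm(cache, params, benchmark, find, heuristic,
                         default_algo, workspace_size, alloc):
        algo = cache.get(params)
        if algo is None:
            if benchmark:
                perf = find()                  # time candidate algorithms
                algo = perf.algo if perf.ok else default_algo
                cache[params] = algo
            else:
                algo = heuristic()             # cheap cuDNN heuristic
        try:
            return algo, alloc(workspace_size(algo))
        except MemoryError:
            # fall back to the default algorithm and remember it, so later
            # calls don't retry the oversized workspace
            algo = default_algo
            cache[params] = algo
            return algo, alloc(workspace_size(algo))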
void* tensorPointer(cudnnDataType_t dataType, THVoidTensor* tensor, int groupIdx, int groups, int dim)
|
||||
@ -179,7 +249,7 @@ static_assert(std::is_pod<ConvolutionParams>::value, "ConvolutionParams not POD"
|
||||
Convolution::Convolution(
|
||||
cudnnDataType_t dataType, THVoidTensor* input, THVoidTensor* weight,
|
||||
THVoidTensor* bias, THVoidTensor* output, std::vector<int> pad,
|
||||
std::vector<int> stride, int groups, bool transposed)
|
||||
std::vector<int> stride, std::vector<int> dilation, int groups, bool transposed)
|
||||
: idesc(), odesc(), odesc_bias(), bdesc(), wdesc(), cdesc(), groups(groups)
|
||||
, transposed(transposed)
|
||||
{
|
||||
@ -197,6 +267,7 @@ Convolution::Convolution(
|
||||
for (size_t i = 0; i != pad.size(); ++i) {
|
||||
params.pad[i] = pad[i];
|
||||
params.stride[i] = stride[i];
|
||||
params.dilation[i] = dilation[i];
|
||||
}
|
||||
params.groups = groups;
|
||||
setTensorDescriptor(idesc, dataType, input, groups);
|
||||
@ -206,7 +277,7 @@ Convolution::Convolution(
|
||||
else
|
||||
setTensorDescriptor(odesc_bias, dataType, input, 1);
|
||||
setWeightDescriptor(wdesc, dataType, weight, groups);
|
||||
cdesc.set(dataType, pad.size(), pad.data(), stride.data());
|
||||
cdesc.set(dataType, pad.size(), pad.data(), stride.data(), dilation.data());
|
||||
}
|
||||
|
||||
void cudnn_convolution_forward(
|
||||
@ -215,18 +286,9 @@ void cudnn_convolution_forward(
|
||||
Convolution* info, bool benchmark)
|
||||
{
|
||||
int groups = info->groups;
|
||||
TensorDescriptor& idesc = info->idesc;
|
||||
TensorDescriptor& odesc = info->odesc;
|
||||
FilterDescriptor& wdesc = info->wdesc;
|
||||
ConvolutionDescriptor& cdesc = info->cdesc;
|
||||
|
||||
cudnnConvolutionFwdAlgo_t fwdAlg = chooseForwardAlgorithm(handle, *info, benchmark);
|
||||
|
||||
size_t workspaceSize;
|
||||
CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle, idesc.desc, wdesc.desc,
|
||||
cdesc.desc, odesc.desc, fwdAlg, &workspaceSize));
|
||||
|
||||
Workspace workspace(state, workspaceSize);
|
||||
cudnnConvolutionFwdAlgo_t fwdAlg;
|
||||
Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, &fwdAlg);
|
||||
|
||||
Constant one(dataType, 1);
|
||||
Constant zero(dataType, 0);
|
||||
@ -236,9 +298,9 @@ void cudnn_convolution_forward(
|
||||
void* weight_ptr = tensorPointer(dataType, weight, i, groups, 0);
|
||||
|
||||
CHECK(cudnnConvolutionForward(
|
||||
handle, &one, idesc.desc, input_ptr, wdesc.desc,
|
||||
weight_ptr, cdesc.desc, fwdAlg, workspace.data,
|
||||
workspaceSize, &zero, odesc.desc, output_ptr));
|
||||
handle, &one, info->idesc.desc, input_ptr, info->wdesc.desc,
|
||||
weight_ptr, info->cdesc.desc, fwdAlg, workspace.data,
|
||||
workspace.size, &zero, info->odesc.desc, output_ptr));
|
||||
}
|
||||
}
|
||||
|
||||
@ -248,7 +310,6 @@ void cudnn_convolution_add_bias(
|
||||
Convolution* info)
|
||||
{
|
||||
CHECK_ARG(output->nDimension <= 5);
|
||||
TensorDescriptor& odesc_bias = info->odesc_bias;
|
||||
TensorDescriptor& bdesc = info->bdesc;
|
||||
|
||||
int size[5] = { 1, (int)bias->size[0], 1, 1, 1 };
|
||||
@ -260,7 +321,7 @@ void cudnn_convolution_add_bias(
|
||||
|
||||
Constant one(dataType, 1);
|
||||
CHECK(cudnnAddTensor(handle, &one, bdesc.desc, bias_ptr, &one,
|
||||
odesc_bias.desc, output_ptr));
|
||||
info->odesc_bias.desc, output_ptr));
|
||||
}
|
||||
|
||||
void cudnn_convolution_backward_data(
|
||||
@ -268,20 +329,11 @@ void cudnn_convolution_backward_data(
|
||||
THVoidTensor* gradOutput, THVoidTensor* gradInput, THVoidTensor* weight,
|
||||
Convolution* info, bool benchmark)
|
||||
{
|
||||
TensorDescriptor& idesc = info->idesc;
|
||||
TensorDescriptor& odesc = info->odesc;
|
||||
FilterDescriptor& wdesc = info->wdesc;
|
||||
ConvolutionDescriptor& cdesc = info->cdesc;
|
||||
int groups = info->params.groups;
|
||||
|
||||
cudnnConvolutionBwdDataAlgo_t bwdDataAlg =
|
||||
chooseBackwardDataAlgorithm(handle, *info, benchmark);
|
||||
cudnnConvolutionBwdDataAlgo_t bwdDataAlg;
|
||||
Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, &bwdDataAlg);
|
||||
|
||||
size_t workspaceSize;
|
||||
CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(handle, wdesc.desc,
|
||||
odesc.desc, cdesc.desc, idesc.desc, bwdDataAlg, &workspaceSize));
|
||||
|
||||
Workspace workspace(state, workspaceSize);
|
||||
Constant one(dataType, 1);
|
||||
Constant zero(dataType, 0);
|
||||
for (int i = 0; i < groups; ++i) {
|
||||
@ -290,9 +342,9 @@ void cudnn_convolution_backward_data(
|
||||
void* weight_ptr = tensorPointer(dataType, weight, i, groups, 0);
|
||||
|
||||
CHECK(cudnnConvolutionBackwardData(
|
||||
handle, &one, wdesc.desc, weight_ptr, odesc.desc, gradOutput_ptr,
|
||||
cdesc.desc, bwdDataAlg, workspace.data, workspaceSize, &zero,
|
||||
idesc.desc, gradInput_ptr));
|
||||
handle, &one, info->wdesc.desc, weight_ptr, info->odesc.desc, gradOutput_ptr,
|
||||
info->cdesc.desc, bwdDataAlg, workspace.data, workspace.size, &zero,
|
||||
info->idesc.desc, gradInput_ptr));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,20 +353,11 @@ void cudnn_convolution_backward_filter(
    THVoidTensor* gradOutput, THVoidTensor* input, THVoidTensor* gradWeight,
    Convolution* info, bool benchmark)
{
- TensorDescriptor& idesc = info->idesc;
- TensorDescriptor& odesc = info->odesc;
- FilterDescriptor& wdesc = info->wdesc;
- ConvolutionDescriptor& cdesc = info->cdesc;
  int groups = info->params.groups;

- cudnnConvolutionBwdFilterAlgo_t bwdFilterAlg =
-   chooseBackwardFilterAlgorithm(handle, *info, benchmark);
+ cudnnConvolutionBwdFilterAlgo_t bwdFilterAlg;
+ Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, &bwdFilterAlg);

- size_t workspaceSize;
- CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle, idesc.desc,
-   odesc.desc, cdesc.desc, wdesc.desc, bwdFilterAlg, &workspaceSize));
-
- Workspace workspace(state, workspaceSize);
  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  for (int i = 0; i < groups; ++i) {
@@ -327,9 +370,9 @@ void cudnn_convolution_backward_filter(
    }

    CHECK(cudnnConvolutionBackwardFilter(
-     handle, &one, idesc.desc, input_ptr, odesc.desc, gradOutput_ptr,
-     cdesc.desc, bwdFilterAlg, workspace.data, workspaceSize, &zero,
-     wdesc.desc, gradWeight_ptr));
+     handle, &one, info->idesc.desc, input_ptr, info->odesc.desc, gradOutput_ptr,
+     info->cdesc.desc, bwdFilterAlg, workspace.data, workspace.size, &zero,
+     info->wdesc.desc, gradWeight_ptr));
  }
}

@@ -337,26 +380,23 @@ void cudnn_convolution_backward_bias(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* gradOutput, THVoidTensor* gradBias, Convolution* info)
{
- TensorDescriptor& bdesc = info->bdesc;
- TensorDescriptor& odesc_bias = info->odesc_bias;
-
  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  void* gradOutput_ptr = tensorPointer(dataType, gradOutput, 0, 1, 0);
  void* gradBias_ptr = tensorPointer(dataType, gradBias, 0, 1, 0);

  CHECK(cudnnConvolutionBackwardBias(
-     handle, &one, odesc_bias.desc, gradOutput_ptr, &zero, bdesc.desc,
-     gradBias_ptr));
+     handle, &one, info->odesc_bias.desc, gradOutput_ptr, &zero,
+     info->bdesc.desc, gradBias_ptr));
}

Convolution* cudnn_convolution_full_forward(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* input, THVoidTensor* weight, THVoidTensor* bias, THVoidTensor* output,
-   std::vector<int> pad, std::vector<int> stride, int groups, bool benchmark)
+   std::vector<int> pad, std::vector<int> stride, std::vector<int> dilation, int groups, bool benchmark)
{
  std::unique_ptr<Convolution> info(new Convolution(
-     dataType, input, weight, bias, output, pad, stride, groups, false));
+     dataType, input, weight, bias, output, pad, stride, dilation, groups, false));
  cudnn_convolution_forward(
      state, handle, dataType, input, weight, output, info.get(), benchmark);
  if (bias) {
@@ -369,10 +409,10 @@ Convolution* cudnn_convolution_full_forward(
Convolution* cudnn_convolution_transpose_full_forward(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* input, THVoidTensor* weight, THVoidTensor* bias, THVoidTensor* output,
-   std::vector<int> pad, std::vector<int> stride, int groups, bool benchmark)
+   std::vector<int> pad, std::vector<int> stride, std::vector<int> dilation, int groups, bool benchmark)
{
  std::unique_ptr<Convolution> info(new Convolution(
-     dataType, output, weight, bias, input, pad, stride, groups, true));
+     dataType, output, weight, bias, input, pad, stride, dilation, groups, true));
  cudnn_convolution_backward_data(
      state, handle, dataType, input, output, weight, info.get(), benchmark);
  if (bias) {

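The recurring change in the hunks above is structural: instead of querying cuDNN for a workspace size after picking an algorithm at each call site, a single chooseAlgorithm helper now selects the algorithm and hands back an already-sized Workspace. A standalone sketch of that shape (the Workspace struct, Algo enum, and sizes here are stand-ins, not the repository's real cuDNN-backed implementation):

#include <cstddef>
#include <cstdlib>

// Stand-in for the diff's Workspace: owns a scratch buffer and its size.
struct Workspace {
  void* data;
  std::size_t size;
  explicit Workspace(std::size_t size) : data(std::malloc(size)), size(size) {}
  Workspace(Workspace&& o) : data(o.data), size(o.size) { o.data = nullptr; o.size = 0; }
  Workspace(const Workspace&) = delete;
  ~Workspace() { std::free(data); }
};

enum Algo { ALGO_IMPLICIT_GEMM, ALGO_FFT };

// Selects an algorithm and returns a workspace sized for it in one step,
// mirroring how the diff folds the workspace-size query into chooseAlgorithm
// so call sites can no longer pair an algorithm with the wrong buffer.
Workspace chooseAlgorithm(bool benchmark, Algo* algo) {
  *algo = benchmark ? ALGO_FFT : ALGO_IMPLICIT_GEMM;
  return Workspace(*algo == ALGO_FFT ? (1 << 20) : 0);
}

int main() {
  Algo algo;
  Workspace ws = chooseAlgorithm(true, &algo);
  // ws.data and ws.size are then passed straight into the convolution call.
  return ws.size != 0 ? 0 : 1;
}

Keeping the algorithm and its workspace in one return value is also why the call sites above now read workspace.size instead of a separate workspaceSize local.
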
@@ -18,6 +18,7 @@ struct ConvolutionParams
  int weight_size[5];
  int pad[3];
  int stride[3];
+ int dilation[3];
  int groups;
};

@@ -41,7 +42,7 @@ struct Convolution
  Convolution(
      cudnnDataType_t dataType, THVoidTensor* input, THVoidTensor* weight,
      THVoidTensor* bias, THVoidTensor* output, std::vector<int> pad,
-     std::vector<int> stride, int groups, bool transposed);
+     std::vector<int> stride, std::vector<int> dilation, int groups, bool transposed);
};

void cudnn_convolution_forward(
@@ -73,12 +74,12 @@ void cudnn_convolution_backward_bias(
Convolution* cudnn_convolution_full_forward(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* input, THVoidTensor* weight, THVoidTensor* bias, THVoidTensor* output,
-   std::vector<int> pad, std::vector<int> stride, int groups, bool benchmark);
+   std::vector<int> pad, std::vector<int> stride, std::vector<int> dilation, int groups, bool benchmark);

Convolution* cudnn_convolution_transpose_full_forward(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* input, THVoidTensor* weight, THVoidTensor* bias, THVoidTensor* output,
-   std::vector<int> pad, std::vector<int> stride, int groups, bool benchmark);
+   std::vector<int> pad, std::vector<int> stride, std::vector<int> dilation, int groups, bool benchmark);

}} // namespace torch::cudnn

@@ -62,10 +62,11 @@ struct ConvolutionDescriptor
  ~ConvolutionDescriptor() {
    cudnnDestroyConvolutionDescriptor(desc);
  }
- void set(cudnnDataType_t dataType, int dim, int* pad, int* stride) {
-   int upscale[3] = {1, 1, 1};
+ void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale) {
+   cudnnDataType_t mathType = dataType;
+   if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
    CHECK(cudnnSetConvolutionNdDescriptor(desc, dim, pad, stride, upscale,
-     CUDNN_CROSS_CORRELATION, dataType));
+     CUDNN_CROSS_CORRELATION, mathType));
  }
};

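Two changes ride in this hunk: the caller's dilation array is now forwarded through the argument cuDNN's Nd descriptor historically named upscale (previously hard-coded to {1, 1, 1}), and half-precision descriptors are told to accumulate in float. A self-contained sketch of the promotion rule alone (the enum stands in for cudnnDataType_t):

#include <cassert>

enum DataType { DATA_HALF, DATA_FLOAT, DATA_DOUBLE };

// Mirrors the mathType logic in ConvolutionDescriptor::set(): fp16 tensors
// are convolved with fp32 math; other types keep their own precision.
DataType mathTypeFor(DataType dataType) {
  return dataType == DATA_HALF ? DATA_FLOAT : dataType;
}

int main() {
  assert(mathTypeFor(DATA_HALF) == DATA_FLOAT);
  assert(mathTypeFor(DATA_FLOAT) == DATA_FLOAT);
  assert(mathTypeFor(DATA_DOUBLE) == DATA_DOUBLE);
  return 0;
}
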
@@ -2,6 +2,7 @@
#define THP_CUDNN_EXCEPTIONS_INC

#include <cudnn.h>
+#include <string>
#include <stdexcept>
#include <sstream>

@@ -14,13 +15,21 @@ namespace torch { namespace cudnn {
class cudnn_exception : public std::runtime_error {
 public:
  cudnnStatus_t status;
- cudnn_exception(cudnnStatus_t status, const char* msg) : std::runtime_error(msg), status(status) {
- }
+ cudnn_exception(cudnnStatus_t status, const char* msg)
+   : std::runtime_error(msg)
+   , status(status) {}
+ cudnn_exception(cudnnStatus_t status, const std::string& msg)
+   : std::runtime_error(msg)
+   , status(status) {}
};

inline void CHECK(cudnnStatus_t status)
{
  if (status != CUDNN_STATUS_SUCCESS) {
+   if (status == CUDNN_STATUS_NOT_SUPPORTED) {
+     throw cudnn_exception(status, std::string(cudnnGetErrorString(status)) +
+       ". This error may appear if you passed in a non-contiguous input.");
+   }
    throw cudnn_exception(status, cudnnGetErrorString(status));
  }
}
@@ -28,7 +37,9 @@ inline void CHECK(cudnnStatus_t status)
inline void CUDA_CHECK(cudaError_t error)
{
  if (error) {
-   throw std::runtime_error("CUDA error");
+   std::string msg("CUDA error: ");
+   msg += cudaGetErrorString(error);
+   throw std::runtime_error(msg);
  }
}

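CUDNN_STATUS_NOT_SUPPORTED is the status cuDNN returns for, among other causes, layouts it cannot handle, so the CHECK above appends a hint about contiguity before throwing. A self-contained sketch of the same check-and-augment pattern (Status and errorString are stand-ins for the cuDNN equivalents):

#include <stdexcept>
#include <string>

enum Status { STATUS_SUCCESS, STATUS_NOT_SUPPORTED, STATUS_BAD_PARAM };

const char* errorString(Status s) {
  switch (s) {
    case STATUS_NOT_SUPPORTED: return "NOT_SUPPORTED";
    case STATUS_BAD_PARAM: return "BAD_PARAM";
    default: return "SUCCESS";
  }
}

// Every failing status throws; the NOT_SUPPORTED case carries an
// actionable hint, like the diff's CHECK() does for cuDNN.
void check(Status status) {
  if (status != STATUS_SUCCESS) {
    if (status == STATUS_NOT_SUPPORTED) {
      throw std::runtime_error(std::string(errorString(status)) +
          ". This error may appear if you passed in a non-contiguous input.");
    }
    throw std::runtime_error(errorString(status));
  }
}

int main() {
  try { check(STATUS_NOT_SUPPORTED); } catch (const std::exception&) { return 0; }
  return 1;
}
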
@@ -1,12 +0,0 @@
-#include "Module.h"
-
-#include <Python.h>
-#include "Types.h"
-#include "Conv.h"
-#include "CppWrapper.h"
-#include "torch/csrc/THP.h"
-
-bool THCUDNNModule_initModule(PyObject *module)
-{
-  return THPWrapper_init(module);
-}
@@ -4,6 +4,5 @@
#include <Python.h>

PyMethodDef* THCUDNN_methods();
-bool THCUDNNModule_initModule(PyObject *self);

#endif

@@ -31,4 +31,16 @@ PyObject * getTensorClass(PyObject *args)
  return NULL;
}

+void _THVoidTensor_assertContiguous(THVoidTensor *tensor, const std::string& name)
+{
+  static const std::string error_str = "cuDNN requires contiguous ";
+  // Contiguity check
+  long long expectedStride = 1;
+  for (int i = tensor->nDimension-1; i >= 0; --i) {
+    if (tensor->stride[i] != expectedStride)
+      throw std::invalid_argument(error_str + name);
+    expectedStride *= tensor->size[i];
+  }
+}
+
}} // namespace torch::cudnn

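The new check walks strides from the innermost dimension outward: a densely packed tensor has stride 1 in its last dimension, and each outer stride equals the product of all inner sizes. A standalone sketch over plain arrays, with a worked example:

#include <cstdio>

// A size[]/stride[] pair is contiguous iff stride[ndim-1] == 1 and each
// outer stride is the product of the sizes inside it (row-major packing).
bool isContiguous(const long* size, const long* stride, int ndim) {
  long long expected = 1;
  for (int i = ndim - 1; i >= 0; --i) {
    if (stride[i] != expected) return false;
    expected *= size[i];
  }
  return true;
}

int main() {
  long size[3] = {4, 3, 2};
  long packed[3] = {6, 2, 1};    // strides of a densely packed 4x3x2 tensor
  long permuted[3] = {1, 4, 12}; // strides of a transposed view
  std::printf("%d %d\n", isContiguous(size, packed, 3),
      isContiguous(size, permuted, 3)); // prints "1 0"
  return 0;
}
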
@@ -3,12 +3,18 @@

#include <Python.h>
+#include <cstddef>
+#include <string>
#include <cudnn.h>
#include "../Types.h"

namespace torch { namespace cudnn {

PyObject * getTensorClass(PyObject *args);
cudnnDataType_t getCudnnDataType(PyObject *tensorClass);
+void _THVoidTensor_assertContiguous(THVoidTensor *tensor, const std::string& name);
+
+#define THVoidTensor_assertContiguous(tensor) \
+  _THVoidTensor_assertContiguous(tensor, #tensor " tensor")

}} // namespace torch::cudnn

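The wrapper macro uses the preprocessor's stringizing operator, so the exception raised by the check names the offending argument: THVoidTensor_assertContiguous(weight) reports "cuDNN requires contiguous weight tensor". A runnable demonstration of the same trick:

#include <cstdio>

// #tensor turns the macro argument into a string literal, and adjacent
// literals concatenate, exactly as in THVoidTensor_assertContiguous.
#define NAME_OF(tensor) #tensor " tensor"

int main() {
  std::puts(NAME_OF(weight)); // prints "weight tensor"
  return 0;
}
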
@@ -4,7 +4,7 @@
#include "BatchNorm.h"
#include "Conv.h"
#include "torch/csrc/cuda/THCP.h"
-#include "CppWrapper.h"
+#include "../PtrWrapper.h"


using namespace torch::cudnn;
@@ -50,6 +50,7 @@ extern THCState* state;
  - THTensor* output
  - std::vector<int> pad
  - std::vector<int> stride
+ - std::vector<int> dilation
  - int groups
  - bool benchmark
]]
@@ -68,6 +69,7 @@ extern THCState* state;
  - THTensor* output
  - std::vector<int> pad
  - std::vector<int> stride
+ - std::vector<int> dilation
  - int groups
  - bool benchmark
]]

torch/csrc/distributed/Module.cpp (new file, 580 lines)
@@ -0,0 +1,580 @@
#include <Python.h>

#include <memory>
#include <unordered_map>
#include <vector>

#include "THDP.h"

static std::unordered_map<std::string, THDChannelType> name2channel_type = {
    {"mpi", THDChannelMPI},
    {"tcp", THDChannelTCP},
};

static bool THDPModule_loadClasses(PyObject *module_dict)
{
#define ASSERT_NOT_NULL(ptr) if (!(ptr)) { THPUtils_setError("couldn't load classes"); return false; }
  // TODO THD: enable once master-worker is implemented
#if 0
  ASSERT_NOT_NULL(THDPDoubleStorageClass = PyMapping_GetItemString(module_dict, (char*)"DoubleStorage"));
  ASSERT_NOT_NULL(THDPFloatStorageClass = PyMapping_GetItemString(module_dict, (char*)"FloatStorage"));
  //ASSERT_NOT_NULL(THDPHalfStorageClass = PyMapping_GetItemString(module_dict, (char*)"HalfStorage"));
  ASSERT_NOT_NULL(THDPLongStorageClass = PyMapping_GetItemString(module_dict, (char*)"LongStorage"));
  ASSERT_NOT_NULL(THDPIntStorageClass = PyMapping_GetItemString(module_dict, (char*)"IntStorage"));
  ASSERT_NOT_NULL(THDPShortStorageClass = PyMapping_GetItemString(module_dict, (char*)"ShortStorage"));
  ASSERT_NOT_NULL(THDPCharStorageClass = PyMapping_GetItemString(module_dict, (char*)"CharStorage"));
  ASSERT_NOT_NULL(THDPByteStorageClass = PyMapping_GetItemString(module_dict, (char*)"ByteStorage"));

  ASSERT_NOT_NULL(THDPDoubleTensorClass = PyMapping_GetItemString(module_dict, (char*)"DoubleTensor"));
  //ASSERT_NOT_NULL(THDPHalfTensorClass = PyMapping_GetItemString(module_dict, (char*)"HalfTensor"));
  ASSERT_NOT_NULL(THDPFloatTensorClass = PyMapping_GetItemString(module_dict, (char*)"FloatTensor"));
  ASSERT_NOT_NULL(THDPLongTensorClass = PyMapping_GetItemString(module_dict, (char*)"LongTensor"));
  ASSERT_NOT_NULL(THDPIntTensorClass = PyMapping_GetItemString(module_dict, (char*)"IntTensor"));
  ASSERT_NOT_NULL(THDPShortTensorClass = PyMapping_GetItemString(module_dict, (char*)"ShortTensor"));
  ASSERT_NOT_NULL(THDPCharTensorClass = PyMapping_GetItemString(module_dict, (char*)"CharTensor"));
  ASSERT_NOT_NULL(THDPByteTensorClass = PyMapping_GetItemString(module_dict, (char*)"ByteTensor"));
#endif

  return true;
#undef ASSERT_NOT_NULL
}

static std::unordered_map<PyObject*, THDReduceOp> obj2reduceop;
static std::unordered_map<PyObject*, THDGroup> obj2group;

static THPObjectPtr _ensureBytes(PyObject *obj)
{
#if PY_MAJOR_VERSION == 2
  if (PyString_Check(obj)) {
#elif PY_MAJOR_VERSION == 3
  if (PyBytes_Check(obj)) {
#endif
    Py_INCREF(obj);
    return obj;
  }
  if (PyUnicode_Check(obj)) {
    return PyUnicode_AsASCIIString(obj);
  }
  return NULL;
}

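_ensureBytes papers over the Python 2/3 split: a Python 2 str is already a bytes object, while a Python 3 str is unicode and must be encoded first. A hedged sketch of the same normalization written against the Python 3 C API only, returning a plain owned reference instead of the repository's THPObjectPtr:

#include <Python.h>

/* Returns a new reference to a bytes object, or NULL if obj is neither
 * bytes nor ASCII-encodable unicode; mirrors _ensureBytes above. */
static PyObject* ensure_bytes(PyObject* obj) {
  if (PyBytes_Check(obj)) {
    Py_INCREF(obj);                       /* pass bytes through, owning a ref */
    return obj;
  }
  if (PyUnicode_Check(obj)) {
    return PyUnicode_AsASCIIString(obj);  /* encode str -> bytes */
  }
  return NULL;
}
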
PyObject* THDPModule_initProcessGroup(PyObject *_unused, PyObject *_backend)
{
  HANDLE_TH_ERRORS
  THPObjectPtr backend_bytes = _ensureBytes(_backend);
  THPUtils_assert(backend_bytes, "backend argument has to be a string/bytes "
      "object, but got %s", THPUtils_typename(_backend));
  char *backend_name = THPUtils_bytesAsString(backend_bytes.get());
  THDChannelType channel_type = name2channel_type.at(backend_name);
  THPUtils_assert(THDProcessGroupInit(channel_type), "failed to initialize "
      "distributed library (THD)");
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_initMasterWorker(PyObject *_unused, PyObject *_backend)
{
  HANDLE_TH_ERRORS
  THPObjectPtr backend_bytes = _ensureBytes(_backend);
  THPUtils_assert(backend_bytes, "backend argument has to be a string/bytes "
      "object, but got %s", THPUtils_typename(_backend));
  char *backend_name = THPUtils_bytesAsString(backend_bytes.get());
  THDChannelType channel_type = name2channel_type.at(backend_name);
  THPUtils_assert(THDMasterWorkerInit(channel_type), "failed to initialize "
      "distributed library (THD)");
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_getRank(PyObject *_unused)
{
  HANDLE_TH_ERRORS
  return PyInt_FromLong(THDGetRank());
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_getNumProcesses(PyObject *_unused)
{
  HANDLE_TH_ERRORS
  return PyInt_FromLong(THDGetNumProcesses());
  END_HANDLE_TH_ERRORS
}

static THDTensorDescriptor* _makeDescriptor(PyObject *obj)
{
  PyObject *type = (PyObject*)Py_TYPE(obj);
#define REGISTER_TH_DESCRIPTOR(TYPE) \
  if (type == THP##TYPE##Class) \
    return THDTensorDescriptor_newFromTH##TYPE(((THP##TYPE*)obj)->cdata);
  REGISTER_TH_DESCRIPTOR(DoubleTensor);
  REGISTER_TH_DESCRIPTOR(FloatTensor);
  REGISTER_TH_DESCRIPTOR(LongTensor);
  REGISTER_TH_DESCRIPTOR(IntTensor);
  REGISTER_TH_DESCRIPTOR(ShortTensor);
  REGISTER_TH_DESCRIPTOR(CharTensor);
  REGISTER_TH_DESCRIPTOR(ByteTensor);
#undef REGISTER_TH_DESCRIPTOR
  throw std::runtime_error(std::string("don't know how to create a THDTensorDescriptor for "
      "type ") + std::string(THPUtils_typename(obj)));
}

static THDRequest* _unpackRequest(PyObject *obj)
{
  return static_cast<THDRequest*>(THPWrapper_get(obj));
}

static THDReduceOp _getReduceOp(PyObject *obj)
{
  auto it = obj2reduceop.find(obj);
  if (it == obj2reduceop.end()) {
    throw std::runtime_error("op should be a constant from "
        "torch.distributed.reduce_op");
  }
  return it->second;
}

static THDGroup _getGroup(PyObject *obj)
{
  auto it = obj2group.find(obj);
  if (it == obj2group.end()) {
    if (!THPUtils_checkLong(obj))
      throw std::runtime_error("group should be an int or one of the values "
          "from torch.distributed.group");
    return THPUtils_unpackLong(obj);
  }
  return it->second;
}

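Both lookup helpers use the same identity-keyed pattern: the Python constants registered later in initExtension are mapped by object pointer, so only those exact singletons are accepted. A self-contained sketch of that pattern:

#include <stdexcept>
#include <unordered_map>

enum ReduceOp { OP_SUM, OP_PRODUCT, OP_MIN, OP_MAX };

// Mirrors obj2reduceop: constants are keyed by identity (pointer), not by
// value, so an unregistered object is rejected outright.
static std::unordered_map<const void*, ReduceOp> obj2reduceop;

ReduceOp getReduceOp(const void* obj) {
  auto it = obj2reduceop.find(obj);
  if (it == obj2reduceop.end())
    throw std::runtime_error(
        "op should be a constant from torch.distributed.reduce_op");
  return it->second;
}

int main() {
  int sum_token; // stands in for the Python-level SUM constant
  obj2reduceop.emplace(&sum_token, OP_SUM);
  return getReduceOp(&sum_token) == OP_SUM ? 0 : 1;
}
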
PyObject* THDPModule_isend(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
      !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
    THPUtils_invalidArguments(args, NULL, "isend", 1, "(tensor input, int dst_rank)");
    return NULL;
  }

  THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
  return THPWrapper_New(THDIsend(desc, dst_rank), (void(*)(void*))THDRequest_free);
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_irecv(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
      !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
    THPUtils_invalidArguments(args, NULL, "irecv", 1, "(tensor output, int src_rank)");
    return NULL;
  }

  THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
  return THPWrapper_New(THDIrecv(desc, src_rank), (void(*)(void*))THDRequest_free);
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_send(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
      !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
    THPUtils_invalidArguments(args, NULL, "send", 1, "(tensor input, int dst_rank)");
    return NULL;
  }

  THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
  THDSend(desc, dst_rank);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_recvAnySource(PyObject *_unused, PyObject *_tensor)
{
  HANDLE_TH_ERRORS
  if (!THPModule_isTensor(_tensor)) {
    THPUtils_invalidArguments(_tensor, NULL, "recv", 1, "(tensor output)");
    return NULL;
  }

  THDPTensorDesc desc = _makeDescriptor(_tensor);
  THDRecvAnySource(desc);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_recv(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
      !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
    THPUtils_invalidArguments(args, NULL, "recv", 1, "(tensor output, int src_rank)");
    return NULL;
  }

  THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
  THDRecv(desc, src_rank);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_allReduce(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  if (PyTuple_GET_SIZE(args) != 3 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0))) {
    THPUtils_invalidArguments(args, NULL, "all_reduce", 1, "(tensor in_out, reduce_op op, group gr)");
    return NULL;
  }

  THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
  THDReduceOp op = _getReduceOp(PyTuple_GET_ITEM(args, 1));
  THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  THDAllReduce(desc, op, group);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_reduce(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  if (PyTuple_GET_SIZE(args) != 4 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
      !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
    THPUtils_invalidArguments(args, NULL, "reduce", 1,
        "(tensor reduced, int dst_rank, reduce_op op, group gr)");
    return NULL;
  }

  THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 3));
  THDReduceOp op = _getReduceOp(PyTuple_GET_ITEM(args, 2));
  THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
  THDReduce(desc, op, dst_rank, group);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_broadcast(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  if (PyTuple_GET_SIZE(args) != 3 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
      !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
    THPUtils_invalidArguments(args, NULL, "broadcast", 1,
        "(tensor src_dst, int src_rank, group gr)");
    return NULL;
  }

  THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
  THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
  THDBroadcast(desc, src_rank, group);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_allGather(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  PyObject* sequence = PyTuple_GET_ITEM(args, 0);
  Py_ssize_t tmp_length;
  std::size_t length;
  std::vector<THDPTensorDesc> descriptors;
  std::vector<THDTensorDescriptor*> raw_descriptors;

  if (PyTuple_GET_SIZE(args) != 3 || !PySequence_Check(sequence) ||
      !THPModule_isTensor(PyTuple_GET_ITEM(args, 1))) {
    goto invalid_arguments;
  }

  tmp_length = PySequence_Length(sequence);
  THPUtils_assert(tmp_length >= 0, "couldn't obtain the length of %s",
      THPUtils_typename(sequence));

  length = static_cast<std::size_t>(tmp_length);
  descriptors.reserve(length);
  for (std::size_t i = 0; i < length; ++i) {
    if (!THPModule_isTensor(PySequence_ITEM(sequence, i)))
      goto invalid_arguments;

    descriptors.push_back(
      THDPTensorDesc(_makeDescriptor(PySequence_ITEM(sequence, i)))
    );
    raw_descriptors.push_back(descriptors.back());
  }

  THDAllGather(
    raw_descriptors.data(), length,
    THDPTensorDesc(_makeDescriptor(PyTuple_GET_ITEM(args, 1))),
    _getGroup(PyTuple_GET_ITEM(args, 2))
  );
  Py_RETURN_NONE;

invalid_arguments:
  THPUtils_invalidArguments(args, NULL, "allGather", 1,
      "(list[tensor] output, tensor input, group gr)");
  return NULL;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_gatherSend(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  if (PyTuple_GET_SIZE(args) != 3 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0))) {
    THPUtils_invalidArguments(args, NULL, "gatherSend", 1,
        "(tensor input, int dst_rank, group gr)");
    return NULL;
  }

  THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
  THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
  THDGatherSend(desc, dst_rank, group);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_gatherRecv(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  PyObject* sequence = PyTuple_GET_ITEM(args, 0);
  Py_ssize_t tmp_length;
  std::size_t length;
  std::vector<THDPTensorDesc> descriptors;
  std::vector<THDTensorDescriptor*> raw_descriptors;

  if (PyTuple_GET_SIZE(args) != 3 || !PySequence_Check(sequence) ||
      !THPModule_isTensor(PyTuple_GET_ITEM(args, 1))) {
    goto invalid_arguments;
  }

  tmp_length = PySequence_Length(sequence);
  THPUtils_assert(tmp_length >= 0, "couldn't obtain the length of %s",
      THPUtils_typename(sequence));

  length = static_cast<std::size_t>(tmp_length);
  descriptors.reserve(length);
  for (std::size_t i = 0; i < length; ++i) {
    if (!THPModule_isTensor(PySequence_ITEM(sequence, i)))
      goto invalid_arguments;

    descriptors.push_back(
      THDPTensorDesc(_makeDescriptor(PySequence_ITEM(sequence, i)))
    );
    raw_descriptors.push_back(descriptors.back());
  }

  THDGatherRecv(
    raw_descriptors.data(), length,
    THDPTensorDesc(_makeDescriptor(PyTuple_GET_ITEM(args, 1))),
    _getGroup(PyTuple_GET_ITEM(args, 2))
  );
  Py_RETURN_NONE;

invalid_arguments:
  THPUtils_invalidArguments(args, NULL, "gatherRecv", 1,
      "(list[tensor] output, tensor input, group gr)");
  return NULL;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_scatterSend(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  PyObject* sequence = PyTuple_GET_ITEM(args, 0);
  Py_ssize_t tmp_length;
  std::size_t length;
  std::vector<THDPTensorDesc> descriptors;
  std::vector<THDTensorDescriptor*> raw_descriptors;

  if (PyTuple_GET_SIZE(args) != 3 || !PySequence_Check(sequence) ||
      !THPModule_isTensor(PyTuple_GET_ITEM(args, 1))) {
    goto invalid_arguments;
  }

  tmp_length = PySequence_Length(sequence);
  THPUtils_assert(tmp_length >= 0, "couldn't obtain the length of %s",
      THPUtils_typename(sequence));

  length = static_cast<std::size_t>(tmp_length);
  descriptors.reserve(length);
  for (std::size_t i = 0; i < length; ++i) {
    if (!THPModule_isTensor(PySequence_ITEM(sequence, i)))
      goto invalid_arguments;

    descriptors.push_back(
      THDPTensorDesc(_makeDescriptor(PySequence_ITEM(sequence, i)))
    );
    raw_descriptors.push_back(descriptors.back());
  }

  THDScatterSend(
    raw_descriptors.data(), length,
    THDPTensorDesc(_makeDescriptor(PyTuple_GET_ITEM(args, 1))),
    _getGroup(PyTuple_GET_ITEM(args, 2))
  );
  Py_RETURN_NONE;

invalid_arguments:
  THPUtils_invalidArguments(args, NULL, "scatterSend", 1,
      "(list[tensor] input, tensor output, group gr)");
  return NULL;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_scatterRecv(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  if (PyTuple_GET_SIZE(args) != 3 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
      !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
    THPUtils_invalidArguments(args, NULL, "scatterRecv", 1,
        "(tensor output, int src_rank, group gr)");
    return NULL;
  }

  THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
  THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
  THDScatterRecv(desc, src_rank, group);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_barrier(PyObject *_unused, PyObject *_group)
{
  HANDLE_TH_ERRORS
  THDBarrier(_getGroup(_group));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_newGroup(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  PyObject* sequence = PyTuple_GET_ITEM(args, 0);
  Py_ssize_t tmp_length;
  std::size_t length;
  std::vector<int> ranks;

  if (PyTuple_GET_SIZE(args) != 1 || !PySequence_Check(sequence))
    goto invalid_arguments;

  tmp_length = PySequence_Length(sequence);
  THPUtils_assert(tmp_length >= 0, "couldn't obtain the length of %s",
      THPUtils_typename(sequence));

  length = static_cast<std::size_t>(tmp_length);
  ranks.reserve(length);
  for (std::size_t i = 0; i < length; ++i) {
    if (!THPUtils_checkLong(PySequence_ITEM(sequence, i)))
      goto invalid_arguments;

    ranks.push_back(THPUtils_unpackLong(PySequence_ITEM(sequence, i)));
    for (std::size_t j = 0; j < i; ++j)
      THPUtils_assert(ranks[i] != ranks[j], "ranks should be unique");
  }

  return PyInt_FromLong(THDNewGroup(ranks.data(), length));

invalid_arguments:
  THPUtils_invalidArguments(args, NULL, "newGroup", 1, "(list[int] ranks)");
  return NULL;
  END_HANDLE_TH_ERRORS
}

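newGroup validates the rank list as it unpacks it, comparing each new rank against all earlier ones; the quadratic scan is fine at the sizes a process group takes. A standalone sketch of the validation:

#include <stdexcept>
#include <vector>

// Mirrors THDPModule_newGroup's check: ranks must be pairwise distinct.
void checkRanksUnique(const std::vector<int>& ranks) {
  for (std::size_t i = 0; i < ranks.size(); ++i)
    for (std::size_t j = 0; j < i; ++j)
      if (ranks[i] == ranks[j])
        throw std::invalid_argument("ranks should be unique");
}

int main() {
  checkRanksUnique({0, 1, 2}); // accepted
  try {
    checkRanksUnique({0, 1, 0}); // rejected: duplicate rank 0
  } catch (const std::invalid_argument&) {
    return 0;
  }
  return 1;
}
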
PyObject* THDPModule_requestIsCompleted(PyObject *_unused, PyObject *_req)
{
  HANDLE_TH_ERRORS
  if (!THPWrapper_check(_req)) {
    THPUtils_invalidArguments(_req, NULL, "requestIsCompleted", 1, "(request req)");
    return NULL;
  }

  return PyBool_FromLong(THDRequest_isCompleted(_unpackRequest(_req)));
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_requestWait(PyObject *_unused, PyObject *_req)
{
  HANDLE_TH_ERRORS
  if (!THPWrapper_check(_req)) {
    THPUtils_invalidArguments(_req, NULL, "requestWait", 1, "(request req)");
    return NULL;
  }

  THDRequest_wait(_unpackRequest(_req));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *args) {
  if (PyTuple_GET_SIZE(args) != 3) {
    THPUtils_invalidArguments(args, NULL, "initExtension", 1, "(bool is_master_worker, reduce_op obj, group obj)");
    return NULL;
  }

  PyObject* is_master_worker_obj = PyTuple_GET_ITEM(args, 0);
  PyObject* reduce_op_obj = PyTuple_GET_ITEM(args, 1);
  PyObject* group_obj = PyTuple_GET_ITEM(args, 2);

  THPUtils_assert(PyBool_Check(is_master_worker_obj), "first argument should be a bool");
  bool is_master_worker = is_master_worker_obj == Py_True;

  THPObjectPtr reduce_op;
#define REGISTER_REDUCE_OP(NAME) \
  reduce_op = PyObject_GetAttrString(reduce_op_obj, #NAME); \
  THPUtils_assert(reduce_op, "Missing object for reduce op " #NAME); \
  obj2reduceop.emplace(reduce_op.get(), THDReduce##NAME);
  REGISTER_REDUCE_OP(SUM);
  REGISTER_REDUCE_OP(PRODUCT);
  REGISTER_REDUCE_OP(MIN);
  REGISTER_REDUCE_OP(MAX);
#undef REGISTER_REDUCE_OP

  THPObjectPtr group;
#define REGISTER_GROUP(NAME) \
  group = PyObject_GetAttrString(group_obj, #NAME); \
  THPUtils_assert(group, "Missing object for group " #NAME); \
  obj2group.emplace(group.get(), THDGroup##NAME);
  REGISTER_GROUP(WORLD);
#undef REGISTER_GROUP

  if (is_master_worker) {
    PyObject *module = PyImport_ImportModule("torch.distributed");
    THPUtils_assert(module, "class loader couldn't access torch.distributed module");
    PyObject* module_dict = PyModule_GetDict(module);
    if (!THDPModule_loadClasses(module_dict)) return NULL;
  }
  Py_RETURN_TRUE;
}

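The REGISTER_* macros above are plain stringizing plus token pasting, and this is how the identity-keyed obj2reduceop map used by _getReduceOp gets populated. Expanded by hand, REGISTER_REDUCE_OP(SUM) becomes:

reduce_op = PyObject_GetAttrString(reduce_op_obj, "SUM");
THPUtils_assert(reduce_op, "Missing object for reduce op SUM");
obj2reduceop.emplace(reduce_op.get(), THDReduceSUM);
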
static struct PyMethodDef _THDPModule_methods[] = {
  {"_dist_init_extension", (PyCFunction)THDPModule_initExtension, METH_VARARGS, NULL},
  {"_dist_init_process_group", (PyCFunction)THDPModule_initProcessGroup, METH_O, NULL},
  {"_dist_init_master_worker", (PyCFunction)THDPModule_initMasterWorker, METH_O, NULL},
  {"_dist_get_rank", (PyCFunction)THDPModule_getRank, METH_NOARGS, NULL},
  {"_dist_get_num_processes", (PyCFunction)THDPModule_getNumProcesses, METH_NOARGS, NULL},
  {"_dist_isend", (PyCFunction)THDPModule_isend, METH_VARARGS, NULL},
  {"_dist_irecv", (PyCFunction)THDPModule_irecv, METH_VARARGS, NULL},
  {"_dist_send", (PyCFunction)THDPModule_send, METH_VARARGS, NULL},
  {"_dist_recv_any_source", (PyCFunction)THDPModule_recvAnySource, METH_O, NULL},
  {"_dist_recv", (PyCFunction)THDPModule_recv, METH_VARARGS, NULL},
  {"_dist_all_reduce", (PyCFunction)THDPModule_allReduce, METH_VARARGS, NULL},
  {"_dist_reduce", (PyCFunction)THDPModule_reduce, METH_VARARGS, NULL},
  {"_dist_broadcast", (PyCFunction)THDPModule_broadcast, METH_VARARGS, NULL},
  {"_dist_all_gather", (PyCFunction)THDPModule_allGather, METH_VARARGS, NULL},
  {"_dist_gather_send", (PyCFunction)THDPModule_gatherSend, METH_VARARGS, NULL},
  {"_dist_gather_recv", (PyCFunction)THDPModule_gatherRecv, METH_VARARGS, NULL},
  {"_dist_scatter_send", (PyCFunction)THDPModule_scatterSend, METH_VARARGS, NULL},
  {"_dist_scatter_recv", (PyCFunction)THDPModule_scatterRecv, METH_VARARGS, NULL},
  {"_dist_barrier", (PyCFunction)THDPModule_barrier, METH_O, NULL},
  {"_dist_new_group", (PyCFunction)THDPModule_newGroup, METH_VARARGS, NULL},
  {"_dist_request_is_completed", (PyCFunction)THDPModule_requestIsCompleted, METH_O, NULL},
  {"_dist_request_wait", (PyCFunction)THDPModule_requestWait, METH_O, NULL},
  {NULL}
};

PyMethodDef* THDPModule_methods() {
  return _THDPModule_methods;
}
torch/csrc/distributed/Storage.cpp (new file, 14 lines)
@@ -0,0 +1,14 @@
#include <Python.h>
#include <structmember.h>

#include <stdbool.h>
#include "THDP.h"

#include "override_macros.h"

#define THD_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
#include <THD/base/THDGenerateAllTypes.h>

//#define THD_GENERIC_FILE "torch/csrc/generic/StorageCopy.cpp"
//#include <THD/THDGenerateAllTypes.h>

torch/csrc/distributed/Storage.h (new file, 45 lines)
@@ -0,0 +1,45 @@
#ifndef THDP_STORAGE_INC
#define THDP_STORAGE_INC

#define THDPStorage TH_CONCAT_3(THDP,Real,Storage)
#define THDPStorageStr TH_CONCAT_STRING_3(torch.cuda.,Real,Storage)
#define THDPStorageClass TH_CONCAT_3(THDP,Real,StorageClass)
#define THDPStorage_(NAME) TH_CONCAT_4(THDP,Real,Storage_,NAME)

#define THDPDoubleStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPDoubleStorageClass)
#define THDPFloatStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPFloatStorageClass)
#define THDPHalfStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPHalfStorageClass)
#define THDPLongStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPLongStorageClass)
#define THDPIntStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPIntStorageClass)
#define THDPShortStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPShortStorageClass)
#define THDPCharStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPCharStorageClass)
#define THDPByteStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPByteStorageClass)

#define THDPDoubleStorage_CData(obj) (obj)->cdata
#define THDPFloatStorage_CData(obj) (obj)->cdata
#define THDPLongStorage_CData(obj) (obj)->cdata
#define THDPIntStorage_CData(obj) (obj)->cdata
#define THDPShortStorage_CData(obj) (obj)->cdata
#define THDPCharStorage_CData(obj) (obj)->cdata
#define THDPByteStorage_CData(obj) (obj)->cdata

#ifdef _THP_CORE
#define THDPStorageType TH_CONCAT_3(THDP,Real,StorageType)
#define THDPStorageBaseStr TH_CONCAT_STRING_3(Distributed,Real,StorageBase)
#endif

#include "override_macros.h"

#define THD_GENERIC_FILE "torch/csrc/generic/Storage.h"
#include <THD/base/THDGenerateAllTypes.h>

#endif

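The THDPStorage family above leans on TH's token-pasting helpers: with Real bound to a concrete type by THDGenerateAllTypes.h, TH_CONCAT_3(THDP,Real,Storage) becomes, for example, THDPFloatStorage. A minimal stand-in showing why the two-level macro is needed (so Real expands before pasting; the real TH_CONCAT_3 lives in TH's THGeneral.h):

#include <cstdio>

#define CONCAT_3_EXPAND(a, b, c) a##b##c
#define CONCAT_3(a, b, c) CONCAT_3_EXPAND(a, b, c)  // expands args first

#define Real Float
struct CONCAT_3(THDP, Real, Storage) { int refcount; }; // THDPFloatStorage
#undef Real

int main() {
  THDPFloatStorage s = {1};
  std::printf("%d\n", s.refcount); // prints "1"
  return 0;
}
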
Some files were not shown because too many files have changed in this diff.