Compare commits

...

156 Commits

Author SHA1 Message Date
c54597e0b2 std::move fixes 2017-02-03 21:31:03 +01:00
833b8cbc7a Remove unused code from module 2017-02-02 17:20:11 +01:00
75aeb16e05 Merge commit '72089c9c36c6b880c695baf732cd04329d72c098' 2017-02-01 22:00:42 -08:00
10bb6bb9b8 Fix function names in error messages 2017-02-01 15:21:57 -08:00
3c9ef69c37 Fix THCTensor::isSparse 2017-02-01 14:51:06 -08:00
dee987d6ee use pseudo-fp16 2017-02-01 23:48:09 +01:00
138f254ec1 Support sparse tensors in THPP (#667) 2017-02-01 17:34:50 -05:00
c7c8aaa7f0 Add ModuleList and ParameterList to nn 2017-02-01 23:26:31 +01:00
d0db624e02 Add W503 to PEP8 ignore list (#646) 2017-02-01 15:57:09 -05:00
e3e7b76310 Rename all normal and log_normal args to std 2017-02-01 21:48:11 +01:00
dad02bceb9 Remove duplicated line in cwrap 2017-02-01 21:48:11 +01:00
b195285879 Improve CUDA detection in THPP 2017-02-01 21:48:11 +01:00
8f3da5b51d set_index -> _set_index 2017-02-01 21:48:11 +01:00
825e919eb8 Add torch.unbind 2017-02-01 21:48:11 +01:00
acb0ce8885 Add LongTensor indexing support 2017-02-01 21:48:11 +01:00
72089c9c36 Update THHalf.c 2017-02-01 11:53:29 -08:00
cf2f158fec Remove erroneous proprietary license header
This change was approved by NVIDIA Legal, and I am authorized to make the change on behalf of the company.
2017-02-01 11:43:44 -08:00
6470b5bd21 Add test for Embedding with sparse=True (#663) 2017-02-01 09:54:42 +05:30
44196955e2 ByteTensor should be unsigned (#664)
ByteTensor should be unsigned
2017-01-31 21:43:39 -05:00
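As a quick illustration of the unsigned semantics this commit pins down (a minimal sketch assuming the tensor API of this release):

```python
import torch

t = torch.ByteTensor([0, 255])   # 8-bit unsigned storage
t.add_(1)                        # arithmetic wraps modulo 256, so 255 + 1 -> 0
print(t.tolist())                # [1, 0]; values always stay within [0, 255]
```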
f08ec1394d Fix bug with inplace TH(CU)NN
Also, remove unnecessary zero_() calls
2017-01-31 21:00:49 +01:00
f8fb25e0a2 Add generic bindings to THNN and THCUNN (#645)
Adds bindings using thpp::Tensor to THNN and THCUNN. This allows calling
into those APIs without knowing the concrete types of the tensor
arguments.
2017-01-31 13:23:02 -05:00
6a0c66752f Fix documentation and argument name for Tensor.normal_(mean, stddev) (#652) 2017-01-31 11:55:39 -05:00
a1bd4efb08 readme: add guidance on disabling CUDA (#655) 2017-01-31 14:05:51 +05:30
b43ce05268 Refactor parts of utils.h (#648)
Moves THPObjectPtr into a separate header, so that it can be included
independently. Currently, utils.h requires all of THP.h. Also adds RAII
structs for acquiring and releasing the GIL.
2017-01-30 21:16:28 -05:00
80e56cfda9 Merge commit 'dc9a5b7d2fbcf21268b524b9da5ae38a74214a59' 2017-01-30 17:58:05 -08:00
24701fc5a7 Merge commit '03dcf8a83bb009ecfdd8f27c4d9a6db40829b690' 2017-01-30 17:57:20 -08:00
f78a266d99 Merge commit '368cbe615d0a7bdaadddcb3bd390abcd4cc17b91' 2017-01-30 17:56:37 -08:00
f096fb6859 adding cudnn V6 support (#515) 2017-01-31 02:01:37 +01:00
a3e11d606b Fix linter errors 2017-01-31 01:58:09 +01:00
79232c24e2 Fixes after rebase 2017-01-31 01:58:09 +01:00
15d9d499ab Remove ZMQ dependency from compilation files 2017-01-31 01:58:09 +01:00
962084c8e8 Add Data Channel receive from any source (#52) 2017-01-31 01:58:09 +01:00
7518b1eefb Introduce Scalar for easier send/receive types through DataChannel 2017-01-31 01:58:09 +01:00
8215d7a4ba Implement TH_API functions from the set 2 (#49) 2017-01-31 01:58:09 +01:00
5aaa220d84 Thd functions v3 (#46) 2017-01-31 01:58:09 +01:00
12c16ab9bc Remaining storage functions implemented 2017-01-31 01:58:09 +01:00
76520512e7 DataChannel tests rewrite (#42); DataChannel isend and irecv implementation (#44) 2017-01-31 01:58:09 +01:00
66de965882 Replace ZeroMQ (#41) 2017-01-31 01:58:09 +01:00
10d32fb0b7 Fix DataChannel tests failure (#43)
Tests failed due to accessing a reference which could be invalid.
2017-01-31 01:58:09 +01:00
e72c9b6e4a Storage constructors implemented (#40) 2017-01-31 01:58:09 +01:00
ac1f68127a Add barrier, scatter, gather and allGather implementations + groups (#34) 2017-01-31 01:58:09 +01:00
60d1852c7b Major improvements to master-worker mode
* Fixed all undefined symbol errors
* Implemented storage interface and THStorage class
* RPC improvements
* Code refactor
2017-01-31 01:58:09 +01:00
d53eb521fc Add missing headers. 2017-01-31 01:58:09 +01:00
9808932f10 Refactor RPC and change TensorType to Type 2017-01-31 01:58:09 +01:00
ea876eb6d5 Add initial bindings for master-worker mode 2017-01-31 01:58:09 +01:00
0a45864866 Add THDStorage and improve master-worker mode implementation 2017-01-31 01:58:09 +01:00
2560b39796 Merge TensorTypeTraits.hpp with TensorTraits.hpp 2017-01-31 01:58:09 +01:00
21afa4c88b Worker handling for constructors + destructor 2017-01-31 01:58:09 +01:00
9fc3c5e4d2 THDTensor constructors implemented + some minor fixes 2017-01-31 01:58:09 +01:00
3e3501c98d Integration tests of the THD Python interface (#28) 2017-01-31 01:58:09 +01:00
5e6fcd02b5 Implement data channel groups (#25) 2017-01-31 01:58:09 +01:00
d46ebcfadf Fix broadcast and reduce implementations
Due to bad rank mapping, broadcast and reduce were connecting the
wrong processes, which resulted in errors or tensors not being received/sent.

 * Introduced a new mapping method to solve this problem.
 * Added and improved tests for these cases.
2017-01-31 01:58:09 +01:00
41480c8cf2 Data channel maintenance 2017-01-31 01:58:09 +01:00
236890d902 Fix transitive library dependencies in CMake 2017-01-31 01:58:09 +01:00
55632d81d2 Add Python wrappers for process group mode 2017-01-31 01:58:09 +01:00
0b276d622e Add reduce and allReduce implementations (#15) 2017-01-31 01:58:09 +01:00
c81491b37d Preserve directory structure when installing headers 2017-01-31 01:58:09 +01:00
42e189425f Detect ZMQ libs and headers in CMake 2017-01-31 01:58:09 +01:00
3cfa0d7199 Expose C API for process group mode 2017-01-31 01:58:09 +01:00
7c9e088661 Reorganize THD directory structure 2017-01-31 01:58:09 +01:00
e78aa4bb84 Implement CommandChannel with ZMQ. 2017-01-31 01:58:09 +01:00
f8e94d0d8b Implement DataChannel (MPI and TCP) (#8) 2017-01-31 01:58:09 +01:00
ebe6f40fce RPC message packing and unpacking implemented 2017-01-31 01:58:09 +01:00
5fb37efb46 Use #pragma once instead of defines 2017-01-31 01:58:09 +01:00
4f47855873 Style improvements 2017-01-31 01:58:09 +01:00
52ae6f682f Add initial version of tensor wrappers 2017-01-31 01:58:09 +01:00
c35f58f97b Template for THD implementation 2017-01-31 01:58:09 +01:00
659b2f3154 Add more autograd functions 2017-01-31 00:39:34 +01:00
5ea05cfb96 Return indices from Variable sort and topk 2017-01-31 00:39:34 +01:00
dc9a5b7d2f Fix memory leak in SpatialMaxUnpooling 2017-01-30 23:23:07 +01:00
f7ab5a128a Delete extra bracket in RNNCellBase.__repr__. (#637)
This extra bracket causes a ValueError when trying to print a Module that uses RNNCellBase or any of its subclasses.
2017-01-29 23:21:24 -05:00
368cbe615d Add Ubuntu 16.04 lib paths in CMake 2017-01-30 01:16:02 +01:00
d4c9a3782b billinear -> bilinear, docs for upsampling, improved docs for Unpooling, pep8 tests fix (#617)
* billinear -> bilinear, docs for upsampling, improved docs for Unpooling, pep8 tests fix
2017-01-30 05:08:48 +05:30
172dca5e8b Fix bug in cat (non-contiguous first input) 2017-01-29 21:25:53 +01:00
818bf0c408 Compile with asserts by default 2017-01-29 21:21:59 +01:00
03dcf8a83b Compile with asserts on by default 2017-01-29 21:18:54 +01:00
604f607fd1 Add asserts in index* functions 2017-01-29 21:18:43 +01:00
956d946c25 Default initial hidden states for recurrent layers (#605)
Fixes #434
2017-01-29 12:38:56 +01:00
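In practice this means the initial hidden state argument can be omitted and defaults to zeros. A minimal sketch, assuming the `nn.RNN` interface of this release:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2)
x = Variable(torch.randn(5, 3, 10))    # (seq_len, batch, input_size)

# Previously an initial hidden state had to be supplied explicitly:
h0 = Variable(torch.zeros(2, 3, 20))   # (num_layers, batch, hidden_size)
out, hn = rnn(x, h0)

# With default initial hidden states the argument can simply be dropped:
out, hn = rnn(x)
```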
970caaa621 Exclude sphinx_rtd_theme from pep8 2017-01-28 23:37:39 -05:00
00a5980cdf Improve RNN doc formatting 2017-01-28 23:37:39 -05:00
e24eee04f0 Link THC to THPP 2017-01-28 23:37:39 -05:00
f1b3af4ee2 Add more bernoulli options in cwrap 2017-01-28 23:37:39 -05:00
fb2d28f477 remove circular references in NestedIOFunction 2017-01-28 23:30:06 +01:00
3a704ff725 Fix legacy load_lua for SpatialConvolution (#608)
* fix legacy load_lua for conv2d

* fix pep8
2017-01-28 20:19:18 +01:00
0180e638e5 Remove unnecessary zero_() calls in cuDNN RNN 2017-01-28 14:36:57 +01:00
95c6ae04fb Fix non-contiguous grad handling in cuDNN RNN 2017-01-28 14:36:57 +01:00
27c4c6e0af Merge commit '6ee77b4edd1552d3a9a2e5389ffc351e513a8089' 2017-01-27 17:29:07 -08:00
da17414b3f Merge commit '343d65db91c2419843d36aed5467c2d1374108bc' 2017-01-27 17:16:08 -08:00
be2b27a747 Merge commit '4461ae809043390d5223905cb82b17035c7f9f31' 2017-01-27 17:15:21 -08:00
aec2c8f752 Merge commit 'c45ff2efe64d0face3889194ba6f885fe9cc4d48' 2017-01-27 17:12:13 -08:00
13e34b4679 Fix multiprocessing tests 2017-01-28 01:18:42 +01:00
57373c7c29 Fix docs 2017-01-28 01:16:04 +01:00
79f5bf84e5 [pep8] Potentially breaking docstring changes 2017-01-28 01:15:51 +01:00
3ed720079e [pep8] Fix most remaining lint manually 2017-01-28 01:15:51 +01:00
e7c1e6a8e3 [pep8] Fix most lint automatically with autopep8
Here's the command I used to invoke autopep8 (in parallel!):

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i

Several rules are ignored in setup.cfg. The goal is to let autopep8
handle everything which it can handle safely, and to disable any rules
which are tricky or controversial to address. We may want to come back
and re-enable some of these rules later, but I'm trying to make this
patch as safe as possible.

Also configures flake8 to match pep8's behavior.

Also configures TravisCI to check the whole project for lint.
2017-01-28 01:15:51 +01:00
f1d0d73ed7 Fix flaky Sqrt test 2017-01-28 00:45:49 +01:00
9c411513bf Patch distutils crash when linking with ccache 2017-01-28 00:28:33 +01:00
ce78bc898b Fix travis builds and add ccache 2017-01-28 00:28:33 +01:00
887002e932 Add bindings to CUDA tensors and storages in THPP (#615) 2017-01-27 18:15:56 -05:00
31dea5ff23 Small typo in README (#613) 2017-01-27 20:18:36 +01:00
ec4602a973 Fix bad code alignment (#612)
forward *is* a method of the Linear class
2017-01-27 20:16:49 +01:00
a38749d15f Fix cuda notes
Target GPU *is* consistent with source GPU
2017-01-27 19:30:49 +01:00
6ee77b4edd Added cunn support for TemporalRowConvolutionMM (#415)
* Added cunn TemporalRowConvolutionMM support
2017-01-27 13:30:25 -05:00
343d65db91 Rowconv repull (#1120)
* Added TemporalRowConvolutionMM layer, tests, and documentation
2017-01-27 13:29:05 -05:00
a90913105c add make-contiguous in batchnorm backward (#602) 2017-01-26 16:17:39 -05:00
9368596059 legacy.nn Attributes: Add '_gradOutput' to SpatialConvolution. (#600) 2017-01-26 15:00:41 -05:00
80ed795ff1 Minor ffi utils fix 2017-01-26 11:55:49 +01:00
a2938e3d11 add cc 3.0 to nccl (#594) 2017-01-25 22:47:23 -05:00
2ad967dbe4 Fix pep8 in setup.py with "autopep8 -i setup.py" 2017-01-25 22:23:22 -05:00
7415c090ac Check setup.py for pep8 lint on TravisCI 2017-01-25 22:23:22 -05:00
a1fa995044 Fixes and improvements (#593)
* Fix error in ELU backward

* Add --seed flag for tests

* Add test for BatchNorm eval

* Fix autograd.backward docs

* Support cc flags in cuDNN search

* Fix IndexSelect backward formula
2017-01-25 22:21:49 -05:00
3c2ecc6b15 add dockerfiles (#583)
* add dockerfiles
2017-01-25 17:30:29 -05:00
fa1516d319 Install THCUNN.h and generic/THCUNN.h
The THCApply.cuh is moved to the .cu files so that THCUNN.h can be
compiled by a standard C compiler.
2017-01-25 14:13:17 -08:00
5e26f49db4 Install THNN.h and generic/THNN.h 2017-01-25 14:09:09 -08:00
7694f65120 Revert "Using accreal instead of real in the API" 2017-01-25 16:26:42 -05:00
b5ebf68df1 Revert "Convert real to accreal in libTHCUNN" 2017-01-25 16:13:20 -05:00
aa46055274 Update CI links in README (#579) 2017-01-25 13:58:05 -05:00
2cad802b68 Revert "cuda implementation of Gated Linear Unit" 2017-01-25 13:15:22 -05:00
2d01f384f1 fallback to nn batchnorm on backward-evaluate (#589) 2017-01-25 12:38:57 -05:00
f8d4f980b3 Add upsampling modules and functions 2017-01-24 17:30:50 -05:00
4f5a6c366e Make Variables non-comparable 2017-01-24 17:30:50 -05:00
ecfcf39f30 Improve optimizer serialization
Also, add optimizer.load_state_dict
2017-01-24 17:30:50 -05:00
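A minimal sketch of the round trip this enables, assuming `optimizer.state_dict()` is available alongside the new `load_state_dict` (the file name is illustrative):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# ... train for a while; the optimizer accumulates momentum buffers ...

# Snapshot the optimizer state alongside the model.
torch.save(optimizer.state_dict(), 'optimizer.pth')

# Later: rebuild an identical optimizer and restore its internal state.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.load_state_dict(torch.load('optimizer.pth'))
```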
3975a2676e Fix invalid DECREF in torch.Size constructor 2017-01-24 17:30:50 -05:00
138ee75a3b Fix for target_link_libraries on CMake 2.8 (#581) 2017-01-24 17:26:24 -05:00
0048f228cb Add spatial test for LogSoftmax 2017-01-24 23:24:25 +01:00
2748b920ab make adam have the same lr as lua torch (#576) 2017-01-24 16:35:28 -05:00
a92a2312d4 Add missing fields to read_lua_file for BatchNorm and Linear layers. 2017-01-24 22:09:47 +01:00
945ce5cdb0 Fix math block of GRUCell in docs (#572)
Added a blank space between the beginning of the `.. math::` block, otherwise it is displayed as a code block.
2017-01-24 14:28:56 -05:00
b39de2cbbe Merge pull request #416 from pavanky/half-fixes
Convert real to accreal in libTHCUNN
2017-01-24 12:17:49 -05:00
49a555e0f5 Merge pull request #1109 from pavanky/api
Using accreal instead of real in the API
2017-01-24 12:17:17 -05:00
ce13900148 update From Source instructions 2017-01-24 10:48:25 -05:00
4c77ad6ee4 step_rate -> lr in adadelta (#569) 2017-01-24 10:05:59 -05:00
0bc4246425 adding NLLLoss2d to docs 2017-01-24 09:22:51 -05:00
c45ff2efe6 Merge pull request #915 from pavanky/convert
Macros to convert between real and accreal
2017-01-24 09:14:33 -05:00
99b520cc5d Merge pull request #421 from huihuifan/cudaGLU
cuda implementation of Gated Linear Unit
2017-01-24 09:13:34 -05:00
e05607aee1 Add fall back to implicit GEMM and friends. (#558)
If we can't allocate the workspace for the desired algorithm, we fall
back to a default algorithm which does not require a workspace.
2017-01-24 09:10:39 -05:00
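The fallback described here is a try-then-degrade pattern; a rough Python-flavoured sketch of the control flow (the names below are illustrative, not the actual cuDNN binding code):

```python
def allocate_workspace(nbytes):
    """Stand-in for real workspace allocation; raises when memory is scarce."""
    if nbytes > 2 ** 20:              # pretend only 1 MiB of scratch space is free
        raise MemoryError
    return bytearray(nbytes)

def choose_algorithm(preferred_algo, workspace_bytes, fallback_algo):
    """Prefer the fast algorithm; if its workspace cannot be allocated,
    fall back to an algorithm that needs no workspace at all."""
    try:
        return preferred_algo, allocate_workspace(workspace_bytes)
    except MemoryError:
        return fallback_algo, None    # implicit GEMM and friends need no scratch space

print(choose_algorithm("winograd", 8 * 2 ** 20, "implicit_gemm"))  # falls back
```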
a360ba1734 Add a hint about CUDNN_STATUS_NOT_SUPPORTED 2017-01-24 09:09:30 -05:00
c661b963b9 Add more contiguity checks to cuDNN 2017-01-24 09:09:30 -05:00
e374dc1696 add step rate to adadelta (#568)
Scales `delta` before it is applied to the parameters in order to control the learning rate of the optimizer (inspired by the climin optim lib for Theano).
Also changed the link to the Adadelta paper to point to the right location.
2017-01-24 08:48:19 -05:00
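For reference, a sketch of where such a step rate enters the Adadelta update (variable names are illustrative, not the exact implementation):

```python
import torch

def adadelta_step(param, grad, acc_grad, acc_delta,
                  step_rate=1.0, rho=0.9, eps=1e-6):
    # Running average of squared gradients.
    acc_grad.mul_(rho).addcmul_(1 - rho, grad, grad)
    # Raw Adadelta update: RMS of past updates over RMS of gradients.
    delta = grad * acc_delta.add(eps).sqrt() / acc_grad.add(eps).sqrt()
    # The step rate scales `delta` before it is applied, acting like a learning rate.
    param.add_(-step_rate, delta)
    # Running average of squared updates.
    acc_delta.mul_(rho).addcmul_(1 - rho, delta, delta)

# Toy usage on a single parameter tensor:
p, g = torch.randn(3), torch.randn(3)
adadelta_step(p, g, torch.zeros(3), torch.zeros(3), step_rate=0.5)
```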
116e0c7f38 Merge commit '45596d52897fb187701943cb77456ff1e7249989' 2017-01-23 14:37:44 -08:00
45596d5289 Add contiguity checks to THCUNN 2017-01-23 14:17:51 -08:00
342e7b873d fixing THPP cmake for cmake < 3.1 (#559) 2017-01-23 14:47:06 -05:00
00410c4496 Fix broken THNN groups in conv functions 2017-01-22 18:32:51 -05:00
8b9276bbee Fix view bug in Conv1d 2017-01-22 18:32:51 -05:00
3238786ea1 Improve optimizer error messages 2017-01-22 18:32:51 -05:00
07ebbcbcb3 Add Parameter docs 2017-01-22 18:32:51 -05:00
ca555abcf9 fix comments 2017-01-22 18:02:40 -05:00
63893c3fa2 Fix auto-gpu semantics for indexing 2017-01-22 18:02:40 -05:00
f8ae34706e Port L-BFGS from Lua optim 2017-01-22 18:02:40 -05:00
7179002bfb cuda implementation of Gated Linear Unit 2017-01-19 23:01:30 -08:00
43b5be1d78 added c implementation of GatedLinearUnit 2017-01-19 22:18:08 -08:00
b5f6fdb814 Using accreal instead of real in the API
This is done to be consistent with the changes made to cunn
2017-01-17 16:58:19 -08:00
a69d819901 Converting all instances of real to accreal in libTHCUNN
This is because the current version of luaffifb fails to pass
custom structs (i.e. half) as arguments or accept them as return
values.

The accreal parameters are immediately converted to real internally.
This is done to ensure none of the internal code needs to be changed.

This change also removes transform_reals_to_half which is no longer
necessary.

Change-Id: I978151d001de5492576fb0eddfa0608cd4e99149
2017-01-17 16:06:42 -08:00
fef2b1526d Adding macros to convert between real and accreal 2017-01-17 15:14:45 -08:00
3719994c96 Remove redundant code in THGenerateAllTypes.h 2017-01-17 15:12:43 -08:00
4461ae8090 include cstddef for msvc 2017-01-15 23:45:48 +08:00
511 changed files with 20397 additions and 9101 deletions

3
.gitignore vendored

@ -15,6 +15,9 @@ torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/* docs/src/**/*
test/data/legacy_modules.t7 test/data/legacy_modules.t7
test/htmlcov test/htmlcov


@ -4,16 +4,25 @@ python:
- 2.7.8 - 2.7.8
- 2.7 - 2.7
- 3.5 - 3.5
- 3.6
- nightly - nightly
cache:
- ccache
- directories:
- $HOME/.ccache
install: install:
- export CC="gcc-4.8" - unset CCACHE_DISABLE
- export CXX="g++-4.8" - export CCACHE_DIR=$HOME/.ccache
- export CC="ccache gcc-4.8"
- export CXX="ccache g++-4.8"
- ccache --show-stats
- travis_retry pip install -r requirements.txt - travis_retry pip install -r requirements.txt
- travis_retry pip install . - python setup.py install
script: script:
- ./test/run_test.sh - OMP_NUM_THREADS=2 ./test/run_test.sh
addons: addons:
apt: apt:
@ -30,3 +39,9 @@ sudo: false
matrix: matrix:
fast_finish: true fast_finish: true
include:
env: LINT_CHECK
python: "2.7"
addons: true
install: pip install pep8
script: pep8

33
Dockerfile Normal file

@ -0,0 +1,33 @@
FROM nvidia/cuda:8.0-cudnn5-devel-ubuntu14.04
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
curl \
ca-certificates \
libjpeg-dev \
libpng-dev &&\
rm -rf /var/lib/apt/lists/*
RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-4.2.12-Linux-x86_64.sh && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh && \
/opt/conda/bin/conda install conda-build && \
/opt/conda/bin/conda create -y --name pytorch-py35 python=3.5.2 numpy scipy ipython mkl&& \
/opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/envs/pytorch-py35/bin:$PATH
RUN conda install --name pytorch-py35 -c soumith magma-cuda80
# This must be done before pip so that requirements.txt is available
WORKDIR /opt/pytorch
COPY . .
RUN cat requirements.txt | xargs -n1 pip install --no-cache-dir && \
TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
CMAKE_LIBRARY_PATH=/opt/conda/envs/pytorch-py35/lib \
CMAKE_INCLUDE_PATH=/opt/conda/envs/pytorch-py35/include \
pip install -v .
WORKDIR /workspace
RUN chmod -R a+w /workspace


@ -14,17 +14,17 @@ We are in an early-release Beta. Expect some adventures and rough edges.
- [Installation](#installation) - [Installation](#installation)
- [Binaries](#binaries) - [Binaries](#binaries)
- [From source](#from-source) - [From source](#from-source)
- [Docker image](#docker-image)
- [Getting Started](#getting-started) - [Getting Started](#getting-started)
- [Communication](#communication) - [Communication](#communication)
- [Releases and Contributing](#releases-and-contributing) - [Releases and Contributing](#releases-and-contributing)
- [The Team](#the-team) - [The Team](#the-team)
| Python | **`Linux CPU`** | **`Linux GPU`** | | System | Python | Status |
|--------|--------------------|------------------| | --- | --- | --- |
| 2.7.8 | [![Build Status](https://travis-ci.com/apaszke/pytorch.svg?token=shqHbUq29zKDxuqzGcjC&branch=master)](https://travis-ci.com/apaszke/pytorch) | | | Linux CPU | 2.7.8, 2.7, 3.5, nightly | [![Build Status](https://travis-ci.org/pytorch/pytorch.svg?branch=master)](https://travis-ci.org/pytorch/pytorch) |
| 2.7 | [![Build Status](https://travis-ci.com/apaszke/pytorch.svg?token=shqHbUq29zKDxuqzGcjC&branch=master)](https://travis-ci.com/apaszke/pytorch) | [![Build Status](http://build.pytorch.org:8080/buildStatus/icon?job=pytorch-master-py2)](https://build.pytorch.org/job/pytorch-master-py2) | | Linux GPU | 2.7 | [![Build Status](http://build.pytorch.org:8080/buildStatus/icon?job=pytorch-master-py2)](https://build.pytorch.org/job/pytorch-master-py2) |
| 3.5 | [![Build Status](https://travis-ci.com/apaszke/pytorch.svg?token=shqHbUq29zKDxuqzGcjC&branch=master)](https://travis-ci.com/apaszke/pytorch) | [![Build Status](http://build.pytorch.org:8080/buildStatus/icon?job=pytorch-master-py3)](https://build.pytorch.org/job/pytorch-master-py3) | | Linux GPU | 3.5 | [![Build Status](http://build.pytorch.org:8080/buildStatus/icon?job=pytorch-master-py3)](https://build.pytorch.org/job/pytorch-master-py3) |
| Nightly| [![Build Status](https://travis-ci.com/apaszke/pytorch.svg?token=shqHbUq29zKDxuqzGcjC&branch=master)](https://travis-ci.com/apaszke/pytorch) | |
## More about PyTorch ## More about PyTorch
@ -101,7 +101,7 @@ We hope you never spend hours debugging your code because of bad stack traces or
PyTorch has minimal framework overhead. We integrate acceleration libraries PyTorch has minimal framework overhead. We integrate acceleration libraries
such as Intel MKL and NVIDIA (CuDNN, NCCL) to maximize speed. such as Intel MKL and NVIDIA (CuDNN, NCCL) to maximize speed.
At the core, it's CPU and GPU Tensor and Neural Network backends At the core, its CPU and GPU Tensor and Neural Network backends
(TH, THC, THNN, THCUNN) are written as independent libraries with a C99 API. (TH, THC, THNN, THCUNN) are written as independent libraries with a C99 API.
They are mature and have been tested for years. They are mature and have been tested for years.
@ -135,24 +135,36 @@ conda install pytorch torchvision -c soumith
### From source ### From source
Instructions for an Anaconda environment. If you are installing from source, we highly recommend installing an [Anaconda](https://www.continuum.io/downloads) environment.
You will get a high-quality BLAS library (MKL) and you get a controlled compiler version regardless of your Linux distro.
Once you have [anaconda](https://www.continuum.io/downloads) installed, here are the instructions.
If you want to compile with CUDA support, install If you want to compile with CUDA support, install
- [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 7.5 or above - [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 7.5 or above
- [NVIDIA CuDNN](https://developer.nvidia.com/cudnn) v5.x - [NVIDIA CuDNN](https://developer.nvidia.com/cudnn) v5.x
If you want to disable CUDA support, export environment variable `NO_CUDA=1`.
#### Install optional dependencies #### Install optional dependencies
On Linux
```bash ```bash
export CMAKE_PREFIX_PATH=[anaconda root directory] export CMAKE_PREFIX_PATH=[anaconda root directory]
# Install basic dependencies # Install basic dependencies
conda install numpy mkl setuptools cmake gcc cffi conda install numpy mkl setuptools cmake gcc cffi
# On Linux, add LAPACK support for the GPU # Add LAPACK support for the GPU
conda install -c soumith magma-cuda75 # or magma-cuda80 if CUDA 8.0 conda install -c soumith magma-cuda75 # or magma-cuda80 if CUDA 8.0
``` ```
On OSX
```bash
export CMAKE_PREFIX_PATH=[anaconda root directory]
conda install numpy setuptools cmake cffi
```
#### Install PyTorch #### Install PyTorch
```bash ```bash
export MACOSX_DEPLOYMENT_TARGET=10.9 # if OSX export MACOSX_DEPLOYMENT_TARGET=10.9 # if OSX
@ -160,6 +172,25 @@ pip install -r requirements.txt
python setup.py install python setup.py install
``` ```
### Docker image
Dockerfiles are supplied to build images with cuda support and cudnn v5 and cudnn v6 RC. Build them as usual
```
docker build . -t pytorch-cudnnv5
```
or
```
docker build . -t pytorch-cudnnv6 -f tools/docker/Dockerfile-v6
```
and run them with nvidia-docker:
```
nvidia-docker run --rm -ti --ipc=host pytorch-cudnnv5
```
Please note that pytorch uses shared memory to share data between processes, so if torch multiprocessing is used (e.g.
for multithreaded data loaders) the default shared memory segment size that container runs with is not enough, and you
should increase shared memory size either with --ipc=host or --shm-size command line options to nvidia-docker run.
## Getting Started ## Getting Started
Three pointers to get you started: Three pointers to get you started:


@ -201,12 +201,13 @@ from docutils import nodes
from sphinx.util.docfields import TypedField from sphinx.util.docfields import TypedField
from sphinx import addnodes from sphinx import addnodes
def patched_make_field(self, types, domain, items): def patched_make_field(self, types, domain, items):
# type: (List, unicode, Tuple) -> nodes.field # type: (List, unicode, Tuple) -> nodes.field
def handle_item(fieldarg, content): def handle_item(fieldarg, content):
par = nodes.paragraph() par = nodes.paragraph()
par += addnodes.literal_strong('', fieldarg) # Patch: this line added par += addnodes.literal_strong('', fieldarg) # Patch: this line added
#par.extend(self.make_xrefs(self.rolename, domain, fieldarg, # par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
# addnodes.literal_strong)) # addnodes.literal_strong))
if fieldarg in types: if fieldarg in types:
par += nodes.Text(' (') par += nodes.Text(' (')


@ -7,6 +7,12 @@ torch.nn
.. automodule:: torch.nn .. automodule:: torch.nn
.. currentmodule:: torch.nn .. currentmodule:: torch.nn
Parameters
----------
.. autoclass:: Parameter
:members:
Containers Containers
---------------------------------- ----------------------------------
@ -362,6 +368,12 @@ Loss functions
.. autoclass:: NLLLoss .. autoclass:: NLLLoss
:members: :members:
:hidden:`NLLLoss2d`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: NLLLoss2d
:members:
:hidden:`KLDivLoss` :hidden:`KLDivLoss`
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
@ -432,6 +444,19 @@ Vision layers
.. autoclass:: PixelShuffle .. autoclass:: PixelShuffle
:members: :members:
:hidden:`UpsamplingNearest2d`
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: UpsamplingNearest2d
:members:
:hidden:`UpsamplingBilinear2d`
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: UpsamplingBilinear2d
:members:
Multi-GPU layers Multi-GPU layers
---------------- ----------------


@ -29,12 +29,15 @@ Below you can find a small example showcasing this::
b = torch.FloatTensor(1).cuda() b = torch.FloatTensor(1).cuda()
# a.get_device() == b.get_device() == 1 # a.get_device() == b.get_device() == 1
c = a + b
# c.get_device() == 1
z = x + y z = x + y
# z.get_device() == 1 # z.get_device() == 0
# even within a context, you can give a GPU id to the .cuda call # even within a context, you can give a GPU id to the .cuda call
c = torch.randn(2).cuda(2) d = torch.randn(2).cuda(2)
# c.get_device() == 2 # d.get_device() == 2
Best practices Best practices
-------------- --------------


@ -144,9 +144,9 @@ This is how a ``Linear`` module can be implemented::
if bias is not None: if bias is not None:
self.bias.data.uniform_(-0.1, 0.1) self.bias.data.uniform_(-0.1, 0.1)
def forward(self, input): def forward(self, input):
# See the autograd section for explanation of what happens here. # See the autograd section for explanation of what happens here.
return Linear()(input, self.weight, self.bias) return Linear()(input, self.weight, self.bias)
Writing custom C extensions Writing custom C extensions


@ -106,6 +106,8 @@ Algorithms
:members: :members:
.. autoclass:: ASGD .. autoclass:: ASGD
:members: :members:
.. autoclass:: LBFGS
:members:
.. autoclass:: RMSprop .. autoclass:: RMSprop
:members: :members:
.. autoclass:: Rprop .. autoclass:: Rprop


@ -14,8 +14,8 @@ Data type CPU tensor GPU tensor
32-bit floating point :class:`torch.FloatTensor` :class:`torch.cuda.FloatTensor` 32-bit floating point :class:`torch.FloatTensor` :class:`torch.cuda.FloatTensor`
64-bit floating point :class:`torch.DoubleTensor` :class:`torch.cuda.DoubleTensor` 64-bit floating point :class:`torch.DoubleTensor` :class:`torch.cuda.DoubleTensor`
16-bit floating point N/A :class:`torch.cuda.HalfTensor` 16-bit floating point N/A :class:`torch.cuda.HalfTensor`
8-bit integer (signed) :class:`torch.ByteTensor` :class:`torch.cuda.ByteTensor` 8-bit integer (unsigned) :class:`torch.ByteTensor` :class:`torch.cuda.ByteTensor`
8-bit integer (unsigned) :class:`torch.CharTensor` :class:`torch.cuda.CharTensor` 8-bit integer (signed) :class:`torch.CharTensor` :class:`torch.cuda.CharTensor`
16-bit integer (signed) :class:`torch.ShortTensor` :class:`torch.cuda.ShortTensor` 16-bit integer (signed) :class:`torch.ShortTensor` :class:`torch.cuda.ShortTensor`
32-bit integer (signed) :class:`torch.IntTensor` :class:`torch.cuda.IntTensor` 32-bit integer (signed) :class:`torch.IntTensor` :class:`torch.cuda.IntTensor`
64-bit integer (signed) :class:`torch.LongTensor` :class:`torch.cuda.LongTensor` 64-bit integer (signed) :class:`torch.LongTensor` :class:`torch.cuda.LongTensor`
@ -251,7 +251,6 @@ view of a storage and defines numeric operations on it.
.. automethod:: scatter_ .. automethod:: scatter_
.. automethod:: select .. automethod:: select
.. automethod:: set_ .. automethod:: set_
.. automethod:: set_index
.. automethod:: share_memory_ .. automethod:: share_memory_
.. automethod:: short .. automethod:: short
.. automethod:: sigmoid .. automethod:: sigmoid


@ -37,6 +37,7 @@ Indexing, Slicing, Joining, Mutating Ops
.. autofunction:: stack .. autofunction:: stack
.. autofunction:: t .. autofunction:: t
.. autofunction:: transpose .. autofunction:: transpose
.. autofunction:: unbind
Random sampling Random sampling

8
setup.cfg Normal file

@ -0,0 +1,8 @@
[pep8]
max-line-length = 120
ignore = E402,E721,E731,W503
exclude = docs/src
[flake8]
max-line-length = 120
ignore = E305,E402,E721,E731,F401,F403,F405,F811,F812,F821,F841

182
setup.py

@ -1,6 +1,7 @@
from setuptools import setup, Extension, distutils, Command, find_packages from setuptools import setup, Extension, distutils, Command, find_packages
import setuptools.command.build_ext import setuptools.command.build_ext
import setuptools.command.install import setuptools.command.install
import distutils.unixccompiler
import distutils.command.build import distutils.command.build
import distutils.command.clean import distutils.command.clean
import platform import platform
@ -13,18 +14,26 @@ from tools.setup_helpers.env import check_env_flag
from tools.setup_helpers.cuda import WITH_CUDA, CUDA_HOME from tools.setup_helpers.cuda import WITH_CUDA, CUDA_HOME
from tools.setup_helpers.cudnn import WITH_CUDNN, CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR from tools.setup_helpers.cudnn import WITH_CUDNN, CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR
DEBUG = check_env_flag('DEBUG') DEBUG = check_env_flag('DEBUG')
WITH_DISTRIBUTED = check_env_flag('WITH_DISTRIBUTED')
WITH_DISTRIBUTED_MW = WITH_DISTRIBUTED and check_env_flag('WITH_DISTRIBUTED_MW')
################################################################################ ################################################################################
# Monkey-patch setuptools to compile in parallel # Monkey-patch setuptools to compile in parallel
################################################################################ ################################################################################
original_link = distutils.unixccompiler.UnixCCompiler.link
def parallelCCompile(self, sources, output_dir=None, macros=None, include_dirs=None, debug=0, extra_preargs=None, extra_postargs=None, depends=None):
def parallelCCompile(self, sources, output_dir=None, macros=None,
include_dirs=None, debug=0, extra_preargs=None,
extra_postargs=None, depends=None):
# those lines are copied from distutils.ccompiler.CCompiler directly # those lines are copied from distutils.ccompiler.CCompiler directly
macros, objects, extra_postargs, pp_opts, build = self._setup_compile(output_dir, macros, include_dirs, sources, depends, extra_postargs) macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
output_dir, macros, include_dirs, sources, depends, extra_postargs)
cc_args = self._get_cc_args(pp_opts, debug, extra_preargs) cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)
# compile using a thread pool # compile using a thread pool
import multiprocessing.pool import multiprocessing.pool
def _single_compile(obj): def _single_compile(obj):
src, ext = build[obj] src, ext = build[obj]
self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts) self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
@ -33,12 +42,23 @@ def parallelCCompile(self, sources, output_dir=None, macros=None, include_dirs=N
return objects return objects
def patched_link(self, *args, **kwargs):
_cxx = self.compiler_cxx
self.compiler_cxx = None
result = original_link(self, *args, **kwargs)
self.compiler_cxx = _cxx
return result
distutils.ccompiler.CCompiler.compile = parallelCCompile distutils.ccompiler.CCompiler.compile = parallelCCompile
distutils.unixccompiler.UnixCCompiler.link = patched_link
################################################################################ ################################################################################
# Custom build commands # Custom build commands
################################################################################ ################################################################################
class build_deps(Command): class build_deps(Command):
user_options = [] user_options = []
@ -53,6 +73,8 @@ class build_deps(Command):
build_all_cmd = ['bash', 'torch/lib/build_all.sh'] build_all_cmd = ['bash', 'torch/lib/build_all.sh']
if WITH_CUDA: if WITH_CUDA:
build_all_cmd += ['--with-cuda'] build_all_cmd += ['--with-cuda']
if WITH_DISTRIBUTED:
build_all_cmd += ['--with-distributed']
if subprocess.call(build_all_cmd) != 0: if subprocess.call(build_all_cmd) != 0:
sys.exit(1) sys.exit(1)
generate_nn_wrappers() generate_nn_wrappers()
@ -73,6 +95,7 @@ class build_module(Command):
class build_ext(setuptools.command.build_ext.build_ext): class build_ext(setuptools.command.build_ext.build_ext):
def run(self): def run(self):
# Print build options # Print build options
if WITH_NUMPY: if WITH_NUMPY:
@ -116,6 +139,7 @@ class build(distutils.command.build.build):
class install(setuptools.command.install.install): class install(setuptools.command.install.install):
def run(self): def run(self):
if not self.skip_build: if not self.skip_build:
self.run_command('build_deps') self.run_command('build_deps')
@ -123,6 +147,7 @@ class install(setuptools.command.install.install):
class clean(distutils.command.clean.clean): class clean(distutils.command.clean.clean):
def run(self): def run(self):
import glob import glob
with open('.gitignore', 'r') as f: with open('.gitignore', 'r') as f:
@ -138,7 +163,6 @@ class clean(distutils.command.clean.clean):
distutils.command.clean.clean.run(self) distutils.command.clean.clean.run(self)
################################################################################ ################################################################################
# Configure compile flags # Configure compile flags
################################################################################ ################################################################################
@ -161,31 +185,35 @@ include_dirs += [
tmp_install_path + "/include", tmp_install_path + "/include",
tmp_install_path + "/include/TH", tmp_install_path + "/include/TH",
tmp_install_path + "/include/THPP", tmp_install_path + "/include/THPP",
tmp_install_path + "/include/THNN",
] ]
extra_link_args.append('-L' + lib_path) extra_link_args.append('-L' + lib_path)
# we specify exact lib names to avoid conflict with lua-torch installs # we specify exact lib names to avoid conflict with lua-torch installs
TH_LIB = os.path.join(lib_path, 'libTH.so.1') TH_LIB = os.path.join(lib_path, 'libTH.so.1')
THS_LIB = os.path.join(lib_path, 'libTHS.so.1') THS_LIB = os.path.join(lib_path, 'libTHS.so.1')
THC_LIB = os.path.join(lib_path, 'libTHC.so.1') THC_LIB = os.path.join(lib_path, 'libTHC.so.1')
THCS_LIB = os.path.join(lib_path, 'libTHCS.so.1') THCS_LIB = os.path.join(lib_path, 'libTHCS.so.1')
THNN_LIB = os.path.join(lib_path, 'libTHNN.so.1') THNN_LIB = os.path.join(lib_path, 'libTHNN.so.1')
THCUNN_LIB = os.path.join(lib_path, 'libTHCUNN.so.1') THCUNN_LIB = os.path.join(lib_path, 'libTHCUNN.so.1')
THPP_LIB = os.path.join(lib_path, 'libTHPP.so.1') THPP_LIB = os.path.join(lib_path, 'libTHPP.so.1')
THD_LIB = os.path.join(lib_path, 'libTHD.so.1')
if platform.system() == 'Darwin': if platform.system() == 'Darwin':
TH_LIB = os.path.join(lib_path, 'libTH.1.dylib') TH_LIB = os.path.join(lib_path, 'libTH.1.dylib')
THS_LIB = os.path.join(lib_path, 'libTHS.1.dylib') THS_LIB = os.path.join(lib_path, 'libTHS.1.dylib')
THC_LIB = os.path.join(lib_path, 'libTHC.1.dylib') THC_LIB = os.path.join(lib_path, 'libTHC.1.dylib')
THCS_LIB = os.path.join(lib_path, 'libTHCS.1.dylib') THCS_LIB = os.path.join(lib_path, 'libTHCS.1.dylib')
THNN_LIB = os.path.join(lib_path, 'libTHNN.1.dylib') THNN_LIB = os.path.join(lib_path, 'libTHNN.1.dylib')
THCUNN_LIB = os.path.join(lib_path, 'libTHCUNN.1.dylib') THCUNN_LIB = os.path.join(lib_path, 'libTHCUNN.1.dylib')
THPP_LIB = os.path.join(lib_path, 'libTHPP.1.dylib') THPP_LIB = os.path.join(lib_path, 'libTHPP.1.dylib')
THD_LIB = os.path.join(lib_path, 'libTHD.1.dylib')
main_compile_args = ['-D_THP_CORE'] main_compile_args = ['-D_THP_CORE']
main_libraries = ['shm'] main_libraries = ['shm']
main_link_args = [TH_LIB, THS_LIB, THPP_LIB] main_link_args = [TH_LIB, THS_LIB, THPP_LIB, THNN_LIB]
main_sources = [ main_sources = [
"torch/csrc/PtrWrapper.cpp",
"torch/csrc/Module.cpp", "torch/csrc/Module.cpp",
"torch/csrc/Generator.cpp", "torch/csrc/Generator.cpp",
"torch/csrc/Size.cpp", "torch/csrc/Size.cpp",
@ -200,6 +228,7 @@ main_sources = [
"torch/csrc/autograd/variable.cpp", "torch/csrc/autograd/variable.cpp",
"torch/csrc/autograd/function.cpp", "torch/csrc/autograd/function.cpp",
"torch/csrc/autograd/engine.cpp", "torch/csrc/autograd/engine.cpp",
"torch/csrc/nn/THNN_generic.cpp",
] ]
try: try:
@ -210,6 +239,20 @@ try:
except ImportError: except ImportError:
WITH_NUMPY = False WITH_NUMPY = False
if WITH_DISTRIBUTED:
extra_compile_args += ['-DWITH_DISTRIBUTED']
main_sources += [
"torch/csrc/distributed/Module.cpp",
"torch/csrc/distributed/utils.cpp",
]
if WITH_DISTRIBUTED_MW:
main_sources += [
"torch/csrc/distributed/Tensor.cpp",
"torch/csrc/distributed/Storage.cpp",
]
include_dirs += [tmp_install_path + "/include/THD"]
main_link_args += [THD_LIB]
if WITH_CUDA: if WITH_CUDA:
cuda_lib_dirs = ['lib64', 'lib'] cuda_lib_dirs = ['lib64', 'lib']
cuda_include_path = os.path.join(CUDA_HOME, 'include') cuda_include_path = os.path.join(CUDA_HOME, 'include')
@ -218,11 +261,12 @@ if WITH_CUDA:
if os.path.exists(cuda_lib_path): if os.path.exists(cuda_lib_path):
break break
include_dirs.append(cuda_include_path) include_dirs.append(cuda_include_path)
include_dirs.append(tmp_install_path + "/include/THCUNN")
extra_link_args.append('-L' + cuda_lib_path) extra_link_args.append('-L' + cuda_lib_path)
extra_link_args.append('-Wl,-rpath,' + cuda_lib_path) extra_link_args.append('-Wl,-rpath,' + cuda_lib_path)
extra_compile_args += ['-DWITH_CUDA'] extra_compile_args += ['-DWITH_CUDA']
extra_compile_args += ['-DCUDA_LIB_PATH=' + cuda_lib_path] extra_compile_args += ['-DCUDA_LIB_PATH=' + cuda_lib_path]
main_link_args += [THC_LIB, THCS_LIB] main_link_args += [THC_LIB, THCS_LIB, THCUNN_LIB]
main_sources += [ main_sources += [
"torch/csrc/cuda/Module.cpp", "torch/csrc/cuda/Module.cpp",
"torch/csrc/cuda/Storage.cpp", "torch/csrc/cuda/Storage.cpp",
@ -238,13 +282,11 @@ if WITH_CUDNN:
include_dirs.append(CUDNN_INCLUDE_DIR) include_dirs.append(CUDNN_INCLUDE_DIR)
extra_link_args.append('-L' + CUDNN_LIB_DIR) extra_link_args.append('-L' + CUDNN_LIB_DIR)
main_sources += [ main_sources += [
"torch/csrc/cudnn/Module.cpp",
"torch/csrc/cudnn/BatchNorm.cpp", "torch/csrc/cudnn/BatchNorm.cpp",
"torch/csrc/cudnn/Conv.cpp", "torch/csrc/cudnn/Conv.cpp",
"torch/csrc/cudnn/cuDNN.cpp", "torch/csrc/cudnn/cuDNN.cpp",
"torch/csrc/cudnn/Types.cpp", "torch/csrc/cudnn/Types.cpp",
"torch/csrc/cudnn/Handles.cpp", "torch/csrc/cudnn/Handles.cpp",
"torch/csrc/cudnn/CppWrapper.cpp",
] ]
extra_compile_args += ['-DWITH_CUDNN'] extra_compile_args += ['-DWITH_CUDNN']
@ -267,70 +309,70 @@ extensions = []
packages = find_packages(exclude=('tools.*',)) packages = find_packages(exclude=('tools.*',))
C = Extension("torch._C", C = Extension("torch._C",
libraries=main_libraries, libraries=main_libraries,
sources=main_sources, sources=main_sources,
language='c++', language='c++',
extra_compile_args=main_compile_args + extra_compile_args, extra_compile_args=main_compile_args + extra_compile_args,
include_dirs=include_dirs, include_dirs=include_dirs,
extra_link_args=extra_link_args + main_link_args + [make_relative_rpath('lib')], extra_link_args=extra_link_args + main_link_args + [make_relative_rpath('lib')],
) )
extensions.append(C) extensions.append(C)
DL = Extension("torch._dl", DL = Extension("torch._dl",
sources=["torch/csrc/dl.c"], sources=["torch/csrc/dl.c"],
language='c', language='c',
) )
extensions.append(DL) extensions.append(DL)
THNN = Extension("torch._thnn._THNN", THNN = Extension("torch._thnn._THNN",
sources=['torch/csrc/nn/THNN.cpp'], sources=['torch/csrc/nn/THNN.cpp'],
language='c++', language='c++',
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
include_dirs=include_dirs, include_dirs=include_dirs,
extra_link_args=extra_link_args + [ extra_link_args=extra_link_args + [
TH_LIB, TH_LIB,
THNN_LIB, THNN_LIB,
make_relative_rpath('../lib'), make_relative_rpath('../lib'),
] ]
) )
extensions.append(THNN) extensions.append(THNN)
if WITH_CUDA: if WITH_CUDA:
THCUNN = Extension("torch._thnn._THCUNN", THCUNN = Extension("torch._thnn._THCUNN",
sources=['torch/csrc/nn/THCUNN.cpp'], sources=['torch/csrc/nn/THCUNN.cpp'],
language='c++', language='c++',
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
include_dirs=include_dirs, include_dirs=include_dirs,
extra_link_args=extra_link_args + [ extra_link_args=extra_link_args + [
TH_LIB, TH_LIB,
THC_LIB, THC_LIB,
THCUNN_LIB, THCUNN_LIB,
make_relative_rpath('../lib'), make_relative_rpath('../lib'),
] ]
) )
extensions.append(THCUNN) extensions.append(THCUNN)
version="0.1" version = "0.1"
if os.getenv('PYTORCH_BUILD_VERSION'): if os.getenv('PYTORCH_BUILD_VERSION'):
version = os.getenv('PYTORCH_BUILD_VERSION') \ version = os.getenv('PYTORCH_BUILD_VERSION') \
+ '_' + os.getenv('PYTORCH_BUILD_NUMBER') + '_' + os.getenv('PYTORCH_BUILD_NUMBER')
setup(name="torch", version=version, setup(name="torch", version=version,
ext_modules=extensions, ext_modules=extensions,
cmdclass = { cmdclass={
'build': build, 'build': build,
'build_ext': build_ext, 'build_ext': build_ext,
'build_deps': build_deps, 'build_deps': build_deps,
'build_module': build_module, 'build_module': build_module,
'install': install, 'install': install,
'clean': clean, 'clean': clean,
}, },
packages=packages, packages=packages,
package_data={'torch': [ package_data={'torch': [
'lib/*.so*', 'lib/*.dylib*', 'lib/*.so*', 'lib/*.dylib*',
'lib/torch_shm_manager', 'lib/torch_shm_manager',
'lib/*.h', 'lib/*.h',
'lib/include/TH/*.h', 'lib/include/TH/generic/*.h', 'lib/include/TH/*.h', 'lib/include/TH/generic/*.h',
'lib/include/THC/*.h', 'lib/include/THC/generic/*.h']}, 'lib/include/THC/*.h', 'lib/include/THC/generic/*.h']},
install_requires=['pyyaml'], install_requires=['pyyaml'],
) )


@ -1,3 +1,5 @@
import sys
import argparse
import unittest import unittest
import contextlib import contextlib
from itertools import product from itertools import product
@ -9,9 +11,17 @@ from torch.autograd import Variable, Function
torch.set_default_tensor_type('torch.DoubleTensor') torch.set_default_tensor_type('torch.DoubleTensor')
torch.manual_seed(123)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(123) def run_tests():
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--seed', type=int, default=123)
args, remaining = parser.parse_known_args()
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
remaining = [sys.argv[0]] + remaining
unittest.main(argv=remaining)
TEST_NUMPY = True TEST_NUMPY = True
@ -20,6 +30,7 @@ try:
except ImportError: except ImportError:
TEST_NUMPY = False TEST_NUMPY = False
def get_cpu_type(t): def get_cpu_type(t):
assert t.__module__ == 'torch.cuda' assert t.__module__ == 'torch.cuda'
return getattr(torch, t.__class__.__name__) return getattr(torch, t.__class__.__name__)
@ -146,7 +157,7 @@ def make_jacobian(input, num_out):
return torch.zeros(input.nelement(), num_out) return torch.zeros(input.nelement(), num_out)
else: else:
return type(input)(filter(lambda x: x is not None, return type(input)(filter(lambda x: x is not None,
(make_jacobian(elem, num_out) for elem in input))) (make_jacobian(elem, num_out) for elem in input)))
def iter_tensors(x, only_requiring_grad=False): def iter_tensors(x, only_requiring_grad=False):
@ -197,7 +208,7 @@ def get_numerical_jacobian(fn, input, target):
outb.copy_(fn(input)) outb.copy_(fn(input))
flat_tensor[i] = orig flat_tensor[i] = orig
outb.add_(-1,outa).div_(2*perturbation) outb.add_(-1, outa).div_(2 * perturbation)
d_tensor[i] = outb d_tensor[i] = outb
return jacobian return jacobian


@ -18,6 +18,7 @@ else:
TEST_CUDA = torch.cuda.is_available() TEST_CUDA = torch.cuda.is_available()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2 TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.cuda.FloatTensor(1)) TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.cuda.FloatTensor(1))
TEST_CUDNN_VERSION = TEST_CUDNN and torch.backends.cudnn.version()
PRECISION = 1e-5 PRECISION = 1e-5
module_tests = [ module_tests = [
@ -25,14 +26,14 @@ module_tests = [
module_name='Linear', module_name='Linear',
constructor_args=(10, 8), constructor_args=(10, 8),
input_size=(4, 10), input_size=(4, 10),
reference_fn=lambda i,p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8) reference_fn=lambda i, p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)
), ),
dict( dict(
module_name='Linear', module_name='Linear',
constructor_args=(10, 8, False), constructor_args=(10, 8, False),
input_size=(4, 10), input_size=(4, 10),
desc='no_bias', desc='no_bias',
reference_fn=lambda i,p: torch.mm(i, p[0].t()) reference_fn=lambda i, p: torch.mm(i, p[0].t())
), ),
dict( dict(
module_name='Threshold', module_name='Threshold',
@ -72,7 +73,7 @@ module_tests = [
dict( dict(
module_name='Hardtanh', module_name='Hardtanh',
input_size=(3, 2, 5), input_size=(3, 2, 5),
reference_fn=lambda i,_: i.clamp(-1, 1) reference_fn=lambda i, _: i.clamp(-1, 1)
), ),
dict( dict(
module_name='Sigmoid', module_name='Sigmoid',
@ -85,17 +86,23 @@ module_tests = [
dict( dict(
module_name='Softmax', module_name='Softmax',
input_size=(10, 20), input_size=(10, 20),
reference_fn=lambda i,_: torch.exp(i).div(torch.exp(i).sum(1).expand(10, 20)) reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1).expand(10, 20))
), ),
dict( dict(
module_name='Softmax2d', module_name='Softmax2d',
input_size=(1, 3, 10, 20), input_size=(1, 3, 10, 20),
reference_fn=lambda i,_: torch.exp(i).div(torch.exp(i).sum(1).expand_as(i)) reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1).expand_as(i))
), ),
dict( dict(
module_name='LogSoftmax', module_name='LogSoftmax',
input_size=(10, 20), input_size=(10, 20),
reference_fn=lambda i,_: torch.exp(i).div_(torch.exp(i).sum(1).expand(10, 20)).log_() reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1).expand(10, 20)).log_()
),
dict(
module_name='LogSoftmax',
input_size=(1, 3, 10, 20),
reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1).expand_as(i)).log_(),
desc='multiparam'
), ),
dict( dict(
module_name='ELU', module_name='ELU',
@ -124,18 +131,18 @@ module_tests = [
dict( dict(
module_name='LogSigmoid', module_name='LogSigmoid',
input_size=(2, 3, 4), input_size=(2, 3, 4),
reference_fn=lambda i,_: i.sigmoid().log() reference_fn=lambda i, _: i.sigmoid().log()
), ),
dict( dict(
module_name='Softplus', module_name='Softplus',
input_size=(10, 20), input_size=(10, 20),
reference_fn=lambda i,_: torch.log(1 + torch.exp(i)) reference_fn=lambda i, _: torch.log(1 + torch.exp(i))
), ),
dict( dict(
module_name='Softplus', module_name='Softplus',
constructor_args=(2,), constructor_args=(2,),
input_size=(10, 20), input_size=(10, 20),
reference_fn=lambda i,_: 1. / 2. * torch.log(1 + torch.exp(2 * i)), reference_fn=lambda i, _: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
desc='beta' desc='beta'
), ),
dict( dict(
@ -166,7 +173,7 @@ module_tests = [
dict( dict(
module_name='Softsign', module_name='Softsign',
input_size=(3, 2, 5), input_size=(3, 2, 5),
reference_fn=lambda i,_: i.div(1 + torch.abs(i)) reference_fn=lambda i, _: i.div(1 + torch.abs(i))
), ),
dict( dict(
module_name='Softmin', module_name='Softmin',
@ -181,11 +188,11 @@ module_tests = [
criterion_tests = [ criterion_tests = [
dict(module_name='L1Loss', dict(module_name='L1Loss',
input_size=(2, 3, 4), input_size=(2, 3, 4),
target=torch.randn(2, 3, 4), target=torch.randn(2, 3, 4),
reference_fn=lambda i,t,_: 1./i.numel() * \ reference_fn=lambda i, t, _: 1. / i.numel() *
sum((a-b).abs().sum() for a,b in zip(i, t)) sum((a - b).abs().sum() for a, b in zip(i, t))
), ),
dict( dict(
module_name='NLLLoss', module_name='NLLLoss',
input=torch.rand(15, 10).log(), input=torch.rand(15, 10).log(),
@ -207,7 +214,7 @@ criterion_tests = [
module_name='MSELoss', module_name='MSELoss',
input=torch.randn(2, 3, 4, 5), input=torch.randn(2, 3, 4, 5),
target=torch.randn(2, 3, 4, 5), target=torch.randn(2, 3, 4, 5),
reference_fn=lambda i,t,_: (i-t).abs().pow(2).sum() / i.numel() reference_fn=lambda i, t, _: (i - t).abs().pow(2).sum() / i.numel()
), ),
dict( dict(
module_name='BCELoss', module_name='BCELoss',
@ -364,9 +371,9 @@ class NNTestCase(TestCase):
if jacobian_input: if jacobian_input:
for jacobian_x, d_x in zip(flat_jacobian_input, iter_tensors(d_input)): for jacobian_x, d_x in zip(flat_jacobian_input, iter_tensors(d_input)):
jacobian_x[:,i] = d_x jacobian_x[:, i] = d_x
if jacobian_parameters: if jacobian_parameters:
jacobian_param[:,i] = torch.cat(self._flatten_tensors(d_param), 0) jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)
res = tuple() res = tuple()
if jacobian_input: if jacobian_input:
@ -427,7 +434,7 @@ class NNTestCase(TestCase):
fx1 = self._forward_criterion(criterion, input, target) fx1 = self._forward_criterion(criterion, input, target)
x[i] = original - eps x[i] = original - eps
fx2 = self._forward_criterion(criterion, input, target) fx2 = self._forward_criterion(criterion, input, target)
deriv = (fx1 - fx2) / (2.*eps) deriv = (fx1 - fx2) / (2. * eps)
d_x[i] = deriv d_x[i] = deriv
x[i] = original x[i] = original
@ -441,8 +448,9 @@ class NNTestCase(TestCase):
class TestBase(object): class TestBase(object):
def __init__(self, constructor, constructor_args=tuple(), input_size=None, def __init__(self, constructor, constructor_args=tuple(), input_size=None,
input=None, desc='', reference_fn=None, fullname=None, **kwargs): input=None, desc='', reference_fn=None, fullname=None, **kwargs):
if input_size is None and input is None: if input_size is None and input is None:
raise RuntimeError("Specify either an input tensor, or it's size!") raise RuntimeError("Specify either an input tensor, or it's size!")
self.constructor = constructor self.constructor = constructor
@ -490,6 +498,7 @@ class TestBase(object):
class ModuleTest(TestBase): class ModuleTest(TestBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ModuleTest, self).__init__(*args, **kwargs) super(ModuleTest, self).__init__(*args, **kwargs)
self.jacobian_input = kwargs.get('jacobian_input', True) self.jacobian_input = kwargs.get('jacobian_input', True)
@ -562,6 +571,7 @@ class ModuleTest(TestBase):
class CriterionTest(TestBase): class CriterionTest(TestBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(CriterionTest, self).__init__(*args, **kwargs) super(CriterionTest, self).__init__(*args, **kwargs)
self.target = self._get_target(kwargs['target']) self.target = self._get_target(kwargs['target'])
@ -584,7 +594,7 @@ class CriterionTest(TestBase):
if isinstance(target, Variable): if isinstance(target, Variable):
target = target.data target = target.data
expected_out = self.reference_fn(deepcopy(self._unpack_input(input)), expected_out = self.reference_fn(deepcopy(self._unpack_input(input)),
deepcopy(target), module) deepcopy(target), module)
test_case.assertEqual(out, expected_out) test_case.assertEqual(out, expected_out)
test_case.check_criterion_jacobian(module, input, self.target) test_case.check_criterion_jacobian(module, input, self.target)
@ -607,10 +617,10 @@ class CriterionTest(TestBase):
cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target) cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target)
gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target) gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target)
test_case.assertEqual(cpu_output, gpu_output, 2e-4) test_case.assertEqual(cpu_output, gpu_output, 4e-4)
cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_target) cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_target)
gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_target) gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_target)
test_case.assertEqual(cpu_gradInput, gpu_gradInput, 2e-4) test_case.assertEqual(cpu_gradInput, gpu_gradInput, 4e-4)
except NotImplementedError: except NotImplementedError:
pass pass


@ -2,6 +2,7 @@ import torch.nn as nn
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.linear = nn.Linear(10, 20) self.linear = nn.Linear(10, 20)


@ -2,6 +2,7 @@ import torch.nn as nn
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.linear = nn.Linear(10, 20) self.linear = nn.Linear(10, 20)


@ -1,5 +1,6 @@
import torch import torch
def check_error(desc, fn, *required_substrings): def check_error(desc, fn, *required_substrings):
try: try:
fn() fn()
@ -16,54 +17,55 @@ def check_error(desc, fn, *required_substrings):
assert False, "given function ({}) didn't raise an error".format(desc) assert False, "given function ({}) didn't raise an error".format(desc)
check_error( check_error(
'Wrong argument types', 'Wrong argument types',
lambda: torch.FloatStorage(object()), lambda: torch.FloatStorage(object()),
'object') 'object')
check_error('Unknown keyword argument', check_error('Unknown keyword argument',
lambda: torch.FloatStorage(content=1234.), lambda: torch.FloatStorage(content=1234.),
'keyword') 'keyword')
check_error('Invalid types inside a sequence', check_error('Invalid types inside a sequence',
lambda: torch.FloatStorage(['a', 'b']), lambda: torch.FloatStorage(['a', 'b']),
'list', 'str') 'list', 'str')
check_error('Invalid size type', check_error('Invalid size type',
lambda: torch.FloatStorage(1.5), lambda: torch.FloatStorage(1.5),
'float') 'float')
check_error('Invalid offset', check_error('Invalid offset',
lambda: torch.FloatStorage(torch.FloatStorage(2), 4), lambda: torch.FloatStorage(torch.FloatStorage(2), 4),
'2', '4') '2', '4')
check_error('Negative offset', check_error('Negative offset',
lambda: torch.FloatStorage(torch.FloatStorage(2), -1), lambda: torch.FloatStorage(torch.FloatStorage(2), -1),
'2', '-1') '2', '-1')
check_error('Invalid size', check_error('Invalid size',
lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5), lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
'2', '1', '5') '2', '1', '5')
check_error('Negative size', check_error('Negative size',
lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5), lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
'2', '1', '-5') '2', '1', '-5')
check_error('Invalid index type', check_error('Invalid index type',
lambda: torch.FloatStorage(10)['first item'], lambda: torch.FloatStorage(10)['first item'],
'str') 'str')
def assign(): def assign():
torch.FloatStorage(10)[1:-1] = '1' torch.FloatStorage(10)[1:-1] = '1'
check_error('Invalid value type', check_error('Invalid value type',
assign, assign,
'str') 'str')
check_error('resize_ with invalid type', check_error('resize_ with invalid type',
lambda: torch.FloatStorage(10).resize_(1.5), lambda: torch.FloatStorage(10).resize_(1.5),
'float') 'float')
check_error('fill_ with invalid type', check_error('fill_ with invalid type',
lambda: torch.IntStorage(10).fill_('asdf'), lambda: torch.IntStorage(10).fill_('asdf'),
'str') 'str')
# TODO: frombuffer # TODO: frombuffer


@ -1,5 +1,5 @@
# th test.lua > lua.out th test.lua > lua.out
python3 test.py > python.out python3 test.py > python.out
diff lua.out python.out >/dev/null 2>&1 diff lua.out python.out >/dev/null 2>&1

File diff suppressed because it is too large.


@@ -1,39 +0,0 @@
assert(arg[1])

funcs = {
    'resizeAs', 'add', 'zero', 'mul', 'div', 'abs',
    'addcmul', 'addcdiv', 'copy', 'sqrt', 'fill',
    {'cmul', 'mul'},
    {'cdiv', 'div'},
}

for _, val in pairs(funcs) do
    local name, newname
    if type(val) == 'table' then
        name = val[1]
        newname = val[2]
    else
        name = val
        newname = val .. '_'
    end

    command = "sed -i -r "
        .. "'/torch\\." .. name .. "\\(/b; "   -- short-circuits
        .. "s/([a-zA-Z]*)\\." .. name .. "\\(" -- substitution
        .. "/"
        .. "\\1\\." .. newname .. "\\(/g' " .. arg[1]
    print(command)
    os.execute(command)

    command = "sed -i 's/math\\." .. newname
        .. "/math\\." .. name .. "/' " .. arg[1]
    print(command)
    os.execute(command)
end

funcs = {
    {'torch\.cmul', 'torch\.mul'},
    {'torch\.cdiv', 'torch\.div'},
}

for _, val in pairs(funcs) do
    command = "sed -i 's/" .. val[1] .. "/" .. val[2] .. "/' " .. arg[1]
    print(command)
    os.execute(command)
end

test/optim/test.lua (new file, 33 lines)

@@ -0,0 +1,33 @@
local cjson = require 'cjson'
require 'optim'

function rosenbrock(t)
    x, y = t[1], t[2]
    return (1 - x) ^ 2 + 100 * (y - x^2)^2
end

function drosenbrock(t)
    x, y = t[1], t[2]
    return torch.DoubleTensor({-400 * x * (y - x^2) - 2 * (1 - x), 200 * x * (y - x^2)})
end

local fd = io.open('tests.json', 'r')
local tests = cjson.decode(fd:read('*a'))
fd:close()

for i, test in ipairs(tests) do
    print(test.algorithm)
    algorithm = optim[test.algorithm]
    for i, config in ipairs(test.config) do
        print('================================================================================')
        params = torch.DoubleTensor({1.5, 1.5})
        for i = 1, 100 do
            function closure(x)
                return rosenbrock(x), drosenbrock(x)
            end
            algorithm(closure, params, config)
            print(string.format('%.8f\t%.8f', params[1], params[2]))
        end
    end
end


@@ -3,13 +3,15 @@ import torch
import torch.legacy.optim as optim
from pprint import pprint

def rosenbrock(tensor):
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

def drosenbrock(tensor):
    x, y = tensor
    return torch.DoubleTensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * x * (y - x ** 2)))

algorithms = {
    'adadelta': optim.adadelta,
@@ -22,6 +24,7 @@ algorithms = {
    'rmsprop': optim.rmsprop,
    'rprop': optim.rprop,
    'sgd': optim.sgd,
    'lbfgs': optim.lbfgs,
}

with open('tests.json', 'r') as f:
@@ -35,4 +38,4 @@ for test in tests:
    params = torch.DoubleTensor((1.5, 1.5))
    for i in range(100):
        algorithm(lambda x: (rosenbrock(x), drosenbrock(x)), params, config)
        print('{:.8f}\t{:.8f}\t'.format(params[0], params[1]))


@@ -98,5 +98,12 @@
            {"learningRate": 1e-4, "nesterov": true, "momentum": 0.95, "dampening": 0},
            {"weightDecay": 0.2}
        ]
    },
    {
        "algorithm": "lbfgs",
        "config": [
            {},
            {"learningRate": 1e-1}
        ]
    }
]


@@ -2,8 +2,17 @@
set -e

PYCMD=${PYCMD:="python"}
COVERAGE=0
while [[ "$#" -gt 0 ]]; do
    case "$1" in
        -p|--python) PYCMD=$2; shift 2 ;;
        -c|--coverage) COVERAGE=1; shift 2 ;;
        --) shift; break ;;
        *) echo "Invalid argument: $1!" ; exit 1 ;;
    esac
done

if [[ $COVERAGE -eq 1 ]]; then
    coverage erase
    PYCMD="coverage run --parallel-mode --source torch "
    echo "coverage flag found. Setting python command to: \"$PYCMD\""
@@ -12,39 +21,66 @@ fi
pushd "$(dirname "$0")"

echo "Running torch tests"
$PYCMD test_torch.py $@

echo "Running autograd tests"
$PYCMD test_autograd.py $@

echo "Running sparse tests"
$PYCMD test_sparse.py $@

echo "Running nn tests"
$PYCMD test_nn.py $@

echo "Running legacy nn tests"
$PYCMD test_legacy_nn.py $@

echo "Running optim tests"
$PYCMD test_optim.py $@

echo "Running multiprocessing tests"
$PYCMD test_multiprocessing.py $@
MULTIPROCESSING_METHOD=spawn $PYCMD test_multiprocessing.py $@
MULTIPROCESSING_METHOD=forkserver $PYCMD test_multiprocessing.py $@

echo "Running util tests"
$PYCMD test_utils.py $@

echo "Running dataloader tests"
$PYCMD test_dataloader.py $@

echo "Running cuda tests"
$PYCMD test_cuda.py $@

echo "Running NCCL tests"
$PYCMD test_nccl.py $@

################################################################################
if [[ "$TEST_DISTRIBUTED" -eq 1 ]]; then
    distributed_set_up() {
        export TEMP_DIR="$(mktemp -d)"
        rm -rf "$TEMP_DIR/"*
        mkdir "$TEMP_DIR/barrier"
        mkdir "$TEMP_DIR/test_dir"
    }

    distributed_tear_down() {
        rm -rf "$TEMP_DIR"
    }

    trap distributed_tear_down EXIT SIGHUP SIGINT SIGTERM

    echo "Running distributed tests for the TCP backend"
    distributed_set_up
    BACKEND=tcp WORLD_SIZE=3 $PYCMD ./test_distributed.py
    distributed_tear_down

    echo "Running distributed tests for the MPI backend"
    distributed_set_up
    BACKEND=mpi mpiexec -n 3 $PYCMD ./test_distributed.py
    distributed_tear_down
fi
################################################################################

if [ "$1" == "coverage" ];
then


@@ -7,7 +7,8 @@ import unittest
from copy import deepcopy
from collections import OrderedDict
from common import make_jacobian, TestCase, iter_tensors, \
    get_numerical_jacobian, run_tests

from torch.autograd._functions import *
from torch.autograd import Variable, Function
@@ -45,7 +46,7 @@ def get_analytical_jacobian(input, output):
        zero_gradients(input)
        output.backward(grad_output, retain_variables=True)
        for jacobian_x, d_x in zip(jacobian, iter_gradients(input)):
            jacobian_x[:, i] = d_x
    return jacobian
@ -67,6 +68,7 @@ class TestAutograd(TestCase):
y = Variable(torch.ones(5, 5) * 4, requires_grad=True) y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
counter = [0] counter = [0]
def bw_hook(inc, grad): def bw_hook(inc, grad):
self.assertIsInstance(grad, Variable) self.assertIsInstance(grad, Variable)
counter[0] += inc counter[0] += inc
@ -102,6 +104,7 @@ class TestAutograd(TestCase):
# WARNING: this is a test for autograd internals. # WARNING: this is a test for autograd internals.
# You should never have to use such things in your code. # You should never have to use such things in your code.
class NoneGradientFunction(Function): class NoneGradientFunction(Function):
def forward(self, x, y): def forward(self, x, y):
assert self.needs_input_grad[0] assert self.needs_input_grad[0]
assert not self.needs_input_grad[1] assert not self.needs_input_grad[1]
@ -113,6 +116,7 @@ class TestAutograd(TestCase):
fn = NoneGradientFunction() fn = NoneGradientFunction()
fn._backward_hooks = OrderedDict() fn._backward_hooks = OrderedDict()
was_called = [False] was_called = [False]
def hook(grad_input, grad_output): def hook(grad_input, grad_output):
self.assertIsInstance(grad_input, tuple) self.assertIsInstance(grad_input, tuple)
self.assertIsInstance(grad_output, tuple) self.assertIsInstance(grad_output, tuple)
@ -142,7 +146,7 @@ class TestAutograd(TestCase):
v.backward(grad_output) v.backward(grad_output)
self.assertEqual(v.grad.data, grad_output) self.assertEqual(v.grad.data, grad_output)
a = x + (y * z) + 4 * z**2 * x / y a = x + (y * z) + 4 * z ** 2 * x / y
a.backward(grad_output) a.backward(grad_output)
x_grad = 4 * z_t.pow(2) / y_t + 1 x_grad = 4 * z_t.pow(2) / y_t + 1
y_grad = z_t - 4 * x_t * z_t.pow(2) / y_t.pow(2) y_grad = z_t - 4 * x_t * z_t.pow(2) / y_t.pow(2)
@ -237,6 +241,7 @@ class TestAutograd(TestCase):
self.assertFalse(a.requires_grad) self.assertFalse(a.requires_grad)
b = a + z b = a + z
self.assertTrue(b.requires_grad) self.assertTrue(b.requires_grad)
def error(): def error():
raise RuntimeError raise RuntimeError
# Make sure backward isn't called on these # Make sure backward isn't called on these
@ -374,6 +379,7 @@ class TestAutograd(TestCase):
segfault. segfault.
""" """
class CollectOnDelete(Function): class CollectOnDelete(Function):
def __del__(self): def __del__(self):
gc.collect() gc.collect()
@ -381,7 +387,7 @@ class TestAutograd(TestCase):
Variable(torch.randn(10, 10), creator=CollectOnDelete()) Variable(torch.randn(10, 10), creator=CollectOnDelete())
@unittest.skipIf(not torch.cuda.is_available() or torch.cuda.device_count() < 2, @unittest.skipIf(not torch.cuda.is_available() or torch.cuda.device_count() < 2,
"CUDA not available or <2 GPUs detected") "CUDA not available or <2 GPUs detected")
def test_unused_output_gpu(self): def test_unused_output_gpu(self):
from torch.nn.parallel._functions import Broadcast from torch.nn.parallel._functions import Broadcast
x = Variable(torch.randn(5, 5).float().cuda(), requires_grad=True) x = Variable(torch.randn(5, 5).float().cuda(), requires_grad=True)
@ -431,6 +437,7 @@ class TestAutograd(TestCase):
def test_return_leaf(self): def test_return_leaf(self):
class Identity(Function): class Identity(Function):
def forward(self, a, b): def forward(self, a, b):
return a, a + b return a, a + b
@ -438,6 +445,7 @@ class TestAutograd(TestCase):
return grad_a + grad_b, grad_b return grad_a + grad_b, grad_b
class Inplace(InplaceFunction): class Inplace(InplaceFunction):
def forward(self, a, b): def forward(self, a, b):
self.mark_dirty(a) self.mark_dirty(a)
return a.add_(b), b + 2 return a.add_(b), b + 2
@ -459,6 +467,7 @@ class TestAutograd(TestCase):
def test_return_leaf_inplace(self): def test_return_leaf_inplace(self):
class Inplace(InplaceFunction): class Inplace(InplaceFunction):
def forward(self, a, b): def forward(self, a, b):
self.mark_dirty(a) self.mark_dirty(a)
return a.add_(b), b + 2 return a.add_(b), b + 2
@@ -491,51 +500,51 @@ class TestAutograd(TestCase):
        self.assertEqual(z.grad.data, torch.ones(5) * 2)

    def test_backward_copy(self):
        # This test checks the backward engine for a very subtle bug that appeared
        # in one of the initial versions of autograd. Gradient tensors were
        # simply stored in lists while the function waited for all its gradients
        # to be computed. However, sometimes an output was used multiple times,
        # so the gradients needed to be summed. The engine used to keep a need_copy
        # set of tensors that would need a clone upon the next addition, and removed
        # them from the set as soon as the clone was performed. However, this
        # could lead to incorrect results if the same gradient tensor was
        # buffered in three places in the graph:
        # 1. When accumulating gradients in one of these places it was cloned
        #    and removed from the need_copy set.
        # 2. When accumulating in the second place, it wasn't in the need_copy set,
        #    so the gradients were simply accumulated in-place (which already
        #    modified the grad in the 3rd place)
        # 3. When accumulating in the third place, it wasn't in the need_copy set
        #    either, so the incoming gradient was summed in-place, yielding
        #    incorrect results in all functions except the first one.
        x = Variable(torch.ones(5, 5), requires_grad=True)
        y = Variable(torch.ones(5, 5), requires_grad=True)
        # Simulate that we're in the middle of the graph
        a = x + 2
        b = y + 2
        c = x + 2
        # This op will just return grad_output two times in backward
        add1 = a + b
        add2 = add1 + c
        # Simulate a long branch, so grad_output will get buffered.
        for i in range(4):
            a = a * 2
            b = b * 2
            c = c * 2
        branch = a + b + c
        out = add2 + branch
        # expected gradients are:
        # for x: 34 (16 from final a, 16 from final c, 2 from add2)
        # for y: 17 (16 from final b, 1 from add2)
        grad_output = torch.ones(5, 5)
        out.backward(grad_output)
        self.assertEqual(x.grad.data, torch.ones(5, 5) * 34)
        self.assertEqual(y.grad.data, torch.ones(5, 5) * 17)
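For reference, the expected values quoted in the comment follow directly from the chain rule: each of the four doublings multiplies the upstream gradient by two, and add2 contributes one unit per leaf occurrence (x appears in a and c, y only in b), so element-wise

\[
\frac{\partial\,\mathrm{out}}{\partial x} = 2^4 + 2^4 + 2 = 34,
\qquad
\frac{\partial\,\mathrm{out}}{\partial y} = 2^4 + 1 = 17.
\]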
def test_functional_blas(self): def test_functional_blas(self):
def compare(fn, *args): def compare(fn, *args):
unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg
for arg in args) for arg in args)
self.assertEqual(fn(*args).data, fn(*unpacked_args)) self.assertEqual(fn(*args).data, fn(*unpacked_args))
def test_blas_add(fn, x, y, z): def test_blas_add(fn, x, y, z):
@ -548,27 +557,29 @@ class TestAutograd(TestCase):
compare(fn, x, y) compare(fn, x, y)
test_blas(torch.mm, Variable(torch.randn(2, 10)), test_blas(torch.mm, Variable(torch.randn(2, 10)),
Variable(torch.randn(10, 4))) Variable(torch.randn(10, 4)))
test_blas_add(torch.addmm, Variable(torch.randn(2, 4)), test_blas_add(torch.addmm, Variable(torch.randn(2, 4)),
Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4))) Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
test_blas(torch.bmm, Variable(torch.randn(4, 2, 10)), test_blas(torch.bmm, Variable(torch.randn(4, 2, 10)),
Variable(torch.randn(4, 10, 4))) Variable(torch.randn(4, 10, 4)))
test_blas_add(torch.addbmm, Variable(torch.randn(2, 4)), test_blas_add(torch.addbmm, Variable(torch.randn(2, 4)),
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4))) Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
test_blas_add(torch.baddbmm, Variable(torch.randn(4, 2, 4)), test_blas_add(torch.baddbmm, Variable(torch.randn(4, 2, 4)),
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4))) Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
test_blas(torch.mv, Variable(torch.randn(2, 10)), test_blas(torch.mv, Variable(torch.randn(2, 10)),
Variable(torch.randn(10))) Variable(torch.randn(10)))
test_blas_add(torch.addmv, Variable(torch.randn(2)), test_blas_add(torch.addmv, Variable(torch.randn(2)),
Variable(torch.randn(2, 10)), Variable(torch.randn(10))) Variable(torch.randn(2, 10)), Variable(torch.randn(10)))
test_blas(torch.ger, Variable(torch.randn(5)), test_blas(torch.ger, Variable(torch.randn(5)),
Variable(torch.randn(6))) Variable(torch.randn(6)))
test_blas_add(torch.addr, Variable(torch.randn(5, 6)), test_blas_add(torch.addr, Variable(torch.randn(5, 6)),
Variable(torch.randn(5)), Variable(torch.randn(6))) Variable(torch.randn(5)), Variable(torch.randn(6)))
def test_save_none_for_backward(self): def test_save_none_for_backward(self):
test_case = self test_case = self
class MyFn(Function): class MyFn(Function):
def forward(self, input): def forward(self, input):
self.save_for_backward(None, input, None) self.save_for_backward(None, input, None)
return input * input return input * input
@ -586,6 +597,7 @@ class TestAutograd(TestCase):
def test_too_many_grads(self): def test_too_many_grads(self):
class MyFn(Function): class MyFn(Function):
def forward(self, input): def forward(self, input):
return input return input
@ -674,6 +686,7 @@ class TestAutograd(TestCase):
def test_dep_nograd(self): def test_dep_nograd(self):
class F1(Function): class F1(Function):
def forward(self, input): def forward(self, input):
out = torch.randn(input.size()) out = torch.randn(input.size())
self.mark_non_differentiable(out) self.mark_non_differentiable(out)
@ -683,6 +696,7 @@ class TestAutograd(TestCase):
return grad_output return grad_output
class F2(Function): class F2(Function):
def forward(self, input, ignored): def forward(self, input, ignored):
return input return input
@@ -705,6 +719,7 @@ def index_variable(shape, max_indices):
    index = torch.rand(*shape).mul_(max_indices).floor_().long()
    return Variable(index, requires_grad=False)

def gather_variable(shape, index_dim, max_indices):
    assert len(shape) == 2
    assert index_dim < 2
@@ -712,7 +727,7 @@ def gather_variable(shape, index_dim, max_indices):
    index = torch.LongTensor(*shape)
    for i in range(shape[index_dim]):
        index.select(index_dim, i).copy_(
            torch.randperm(max_indices)[:shape[batch_dim]])
    return Variable(index, requires_grad=False)
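gather_variable builds a random LongTensor index for the Gather/Scatter cases below; as a reminder of the semantics being exercised, here is a minimal pure-Python sketch (not part of the test suite, names are illustrative) of gather along dim 0, where out[i][j] = src[index[i][j]][j]:

def gather_dim0(src, index):
    # src and index are plain nested lists; every index entry must be a
    # valid row number of src.
    return [[src[index[i][j]][j] for j in range(len(index[0]))]
            for i in range(len(index))]

src = [[10, 11], [20, 21], [30, 31]]   # shape (3, 2)
index = [[2, 0], [1, 1]]               # shape (2, 2), entries < 3
print(gather_dim0(src, index))         # [[30, 11], [20, 21]]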
@ -720,215 +735,235 @@ L = 20
M = 10 M = 10
S = 5 S = 5
function_tests = [ function_tests = [
(Add, (), ((M, M), (M, M)) ), (Add, (), ((M, M), (M, M))),
(Sub, (), ((M, M), (M, M)) ), (Sub, (), ((M, M), (M, M))),
(Mul, (), ((M, M), (M, M)) ), (Mul, (), ((M, M), (M, M))),
(Div, (), ((M, M), torch.rand(M, M) + 5e-2) ), (Div, (), ((M, M), torch.rand(M, M) + 5e-2)),
(Pow, (), (torch.rand(M, M) + 1e-3, torch.rand(M, M) + 0.1)), (Pow, (), (torch.rand(M, M) + 1e-3, torch.rand(M, M) + 0.1)),
(AddConstant, (3.14,), ((L, L),) ), (AddConstant, (3.14,), ((L, L),)),
(SubConstant, (3.14,), ((L, L),) ), (SubConstant, (3.14,), ((L, L),)),
(SubConstant, (3.14, True), ((L, L),), 'from_tensor' ), (SubConstant, (3.14, True), ((L, L),), 'from_tensor'),
(MulConstant, (3.14,), ((L, L),) ), (MulConstant, (3.14,), ((L, L),)),
(DivConstant, (3.14, True), (torch.rand(L, L) + 1e-1,), 'by_tensor' ), (DivConstant, (3.14, True), (torch.rand(L, L) + 1e-1,), 'by_tensor'),
(PowConstant, (3.14,), (torch.rand(L, L),) ), (PowConstant, (3.14,), (torch.rand(L, L),)),
(PowConstant, (3.14, True), (torch.rand(L, L),), 'tensor_power' ), (PowConstant, (3.14, True), (torch.rand(L, L),), 'tensor_power'),
(Transpose, (0, 1), (torch.rand(L, L),) ), (Transpose, (0, 1), (torch.rand(L, L),)),
(Transpose, (2, 0), (torch.rand(S, S, S),), '3d' ), (Transpose, (2, 0), (torch.rand(S, S, S),), '3d'),
(Permute, ((0, 4, 3, 5, 1, 2),), ((1, 2, 3, 4, 5, 6),) ), (Permute, ((0, 4, 3, 5, 1, 2),), ((1, 2, 3, 4, 5, 6),)),
(Index, ((1, 2),), (torch.rand(S, S, S),) ), (Index, ((1, 2),), (torch.rand(S, S, S),)),
(Index, (slice(0, 3),), (torch.rand(S, S, S),), 'slice' ), (Index, (slice(0, 3),), (torch.rand(S, S, S),), 'slice'),
(Index, ((slice(0, 3), 1),),(torch.rand(S, S, S),), 'slice_index' ), (Index, ((slice(0, 3), 1),), (torch.rand(S, S, S),), 'slice_index'),
(View, (S*S, S), (torch.rand(S, S, S),) ), (View, (S * S, S), (torch.rand(S, S, S),)),
(Expand, ((S, 5, S, 5),), ((S, 1, S, 1),) ), (Expand, ((S, 5, S, 5),), ((S, 1, S, 1),)),
(Exp, (), (torch.rand(S, S, S),) ), (Exp, (), (torch.rand(S, S, S),)),
(Log, (), (torch.rand(S, S, S) + 1e-2,) ), (Log, (), (torch.rand(S, S, S) + 1e-2,)),
(Log1p, (), (torch.rand(S, S, S),) ), (Log1p, (), (torch.rand(S, S, S),)),
(Tanh, (), ((S, S, S),) ), (Tanh, (), ((S, S, S),)),
(Sigmoid, (), ((S, S, S),) ), (Sigmoid, (), ((S, S, S),)),
(Sinh, (), ((S, S, S),) ), (Sinh, (), ((S, S, S),)),
(Cosh, (), ((S, S, S),) ), (Cosh, (), ((S, S, S),)),
(Abs, (), ((S, S, S),) ), (Abs, (), ((S, S, S),)),
(Clamp, (0, 1), ((S, S, S),) ), (Clamp, (0, 1), ((S, S, S),)),
(Sqrt, (), (torch.rand(S, S, S) + 1e-4,) ), (Sqrt, (), (torch.rand(S, S, S) + 5e-4,)),
(Sin, (), ((S, S, S),) ), (Sin, (), ((S, S, S),)),
(Cos, (), ((S, S, S),) ), (Cos, (), ((S, S, S),)),
(Tan, (), (torch.randn(S, S, S).clamp(-1, 1),) ), (Tan, (), (torch.randn(S, S, S).clamp(-1, 1),)),
(Asin, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),) ), (Asin, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),)),
(Acos, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),) ), (Acos, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),)),
(Atan, (), ((S, S, S),) ), (Atan, (), ((S, S, S),)),
(Reciprocal, (), (torch.rand(S, S, S) + 0.1,) ), (Reciprocal, (), (torch.rand(S, S, S) + 0.1,)),
(Cmax, (), ((S, S, S), (S, S, S)) ), (Cmax, (), ((S, S, S), (S, S, S))),
(Cmin, (), ((S, S, S), (S, S, S)) ), (Cmin, (), ((S, S, S), (S, S, S))),
(Round, (), ((S, S, S),) ), (Round, (), ((S, S, S),)),
(Sign, (), ((S, S, S),) ), (Sign, (), ((S, S, S),)),
(Trunc, (), ((S, S, S),) ), (Trunc, (), ((S, S, S),)),
(Floor, (), ((S, S, S),) ), (Floor, (), ((S, S, S),)),
(Ceil, (), ((S, S, S),) ), (Ceil, (), ((S, S, S),)),
(Frac, (), ((S, S, S),) ), (Frac, (), ((S, S, S),)),
(Fmod, (1.5,), ((S, S, S),) ), (Fmod, (1.5,), ((S, S, S),)),
(Lerp, (0.2,), ((S, S, S), (S, S, S)) ), (Lerp, (0.2,), ((S, S, S), (S, S, S))),
(Rsqrt, (), (torch.rand(S, S, S) + 1e-2,) ), (Rsqrt, (), (torch.rand(S, S, S) + 1e-2,)),
(Remainder, (1.5,), ((S, S, S),) ), (Remainder, (1.5,), ((S, S, S),)),
(CmaxConstant, (0.5,), ((S, S, S),) ), (CmaxConstant, (0.5,), ((S, S, S),)),
(CminConstant, (0.5,), ((S, S, S),) ), (CminConstant, (0.5,), ((S, S, S),)),
(Mean, (), ((S, S, S),) ), (Mean, (), ((S, S, S),)),
(Mean, (1,), ((S, S, S),), 'dim' ), (Mean, (1,), ((S, S, S),), 'dim'),
(Sum, (), ((S, S, S),) ), (Sum, (), ((S, S, S),)),
(Sum, (1,), ((S, S, S),), 'dim' ), (Sum, (1,), ((S, S, S),), 'dim'),
(Prod, (), ((S, S, S),) ), (Prod, (), ((S, S, S),)),
(Prod, (1,), ((S, S, S),), 'dim' ), (Prod, (1,), ((S, S, S),), 'dim'),
(Addmm, (), ((S, M), (S, S), (S, M)), ), (Addmm, (), ((S, M), (S, S), (S, M)),),
(Addmm, (0.1, 1), ((S, M), (S, S), (S, M)), 'coef' ), (Addmm, (0.1, 1), ((S, M), (S, S), (S, M)), 'coef'),
(Addbmm, (), ((S, M), (S, S, S), (S, S, M)), ), (Addbmm, (), ((S, M), (S, S, S), (S, S, M)),),
(Addbmm, (0.1, 0.4), ((S, M), (S, S, S), (S, S, M)), 'coef' ), (Addbmm, (0.1, 0.4), ((S, M), (S, S, S), (S, S, M)), 'coef'),
(Baddbmm, (), ((S, S, M), (S, S, S), (S, S, M)), ), (Baddbmm, (), ((S, S, M), (S, S, S), (S, S, M)),),
(Baddbmm, (0.1, 0.4), ((S, S, M), (S, S, S), (S, S, M)), 'coef' ), (Baddbmm, (0.1, 0.4), ((S, S, M), (S, S, S), (S, S, M)), 'coef'),
(Addmv, (), ((S,), (S, M), (M,)), ), (Addmv, (), ((S,), (S, M), (M,)),),
(Addmv, (0.1, 0.4), ((S,), (S, M), (M,)), 'coef' ), (Addmv, (0.1, 0.4), ((S,), (S, M), (M,)), 'coef'),
(Addr, (), ((S, M), (S,), (M,)), ), (Addr, (), ((S, M), (S,), (M,)),),
(Addr, (0.1, 0.4), ((S, M), (S,), (M,)), 'coef' ), (Addr, (0.1, 0.4), ((S, M), (S,), (M,)), 'coef'),
(Dot, (), ((L,), (L,)), ), (Dot, (), ((L,), (L,)),),
(Max, (), ((S, S, S),), ), (Max, (), ((S, S, S),),),
(Min, (), ((S, S, S),), ), (Repeat, (torch.Size([2, 3, 1, 4]),), ((S, S, S, S),)),
(Max, (0,), ((S, S, S),), 'dim' ), (Min, (), ((S, S, S),),),
(Min, (0,), ((S, S, S),), 'dim' ), (Max, (0,), ((S, S, S),), 'dim'),
(Mode, (0,), ((S, S, S),), ), (Min, (0,), ((S, S, S),), 'dim'),
(Kthvalue, (2, 0), ((S, S, S),), ), (Mode, (0,), ((S, S, S),),),
(Median, (0,), ((S, S, S),), ), (Kthvalue, (2, 0), ((S, S, S),),),
(Norm, (1.5,), (torch.rand(S, S, S),), '1_5' ), (Median, (0,), ((S, S, S),),),
(Norm, (), ((S, S, S),), '2' ), (Norm, (1.5,), (torch.rand(S, S, S),), '1_5'),
(Norm, (3,), ((S, S, S),), '3' ), (Norm, (), ((S, S, S),), '2'),
(Norm, (1.5, 0), (torch.rand(S, S, S),), '1_5_dim' ), (Norm, (3,), ((S, S, S),), '3'),
(Norm, (2, 0), ((S, S, S),), '2_dim' ), (Norm, (1.5, 0), (torch.rand(S, S, S),), '1_5_dim'),
(Norm, (3, 0), ((S, S, S),), '3_dim' ), (Norm, (2, 0), ((S, S, S),), '2_dim'),
(Addcmul, (), ((S, S), (S, S), (S, S)) ), (Norm, (3, 0), ((S, S, S),), '3_dim'),
(Addcmul, (0.6,), ((S, S), (S, S), (S, S)), 'scale' ), (Addcmul, (), ((S, S), (S, S), (S, S))),
(Addcdiv, (), ((S, S), (S, S), torch.rand(S, S) + 1e-2) ), (Addcmul, (0.6,), ((S, S), (S, S), (S, S)), 'scale'),
(Addcdiv, (0.6,), ((S, S), (S, S), torch.rand(S, S) + 1e-2), 'scale'), (Addcdiv, (), ((S, S), (S, S), torch.rand(S, S) + 1e-2)),
(IndexAdd, (0,), ((S, S), index_variable(2, S), (2, S)) ), (Addcdiv, (0.6,), ((S, S), (S, S), torch.rand(S, S) + 1e-2), 'scale'),
(IndexAdd, (0,), ((S, S), index_variable(2, S), (2, S))),
# (IndexCopy, (0,), ((S, S), index_variable(2, S), (2, S)) ), # (IndexCopy, (0,), ((S, S), index_variable(2, S), (2, S)) ),
(IndexFill, (0, 2), ((S, S), index_variable(2, S)) ), (IndexFill, (0, 2), ((S, S), index_variable(2, S))),
(IndexSelect, (0,), ((S, S), index_variable(2, S)) ), (IndexSelect, (0,), ((S, S), index_variable(2, S))),
(Gather, (0,), ((M, S), gather_variable((S, S), 1, M)) ), (Gather, (0,), ((M, S), gather_variable((S, S), 1, M))),
(Gather, (1,), ((M, S), gather_variable((M, S//2), 0, S)), 'dim1'), (Gather, (1,), ((M, S), gather_variable((M, S // 2), 0, S)), 'dim1'),
(Scatter, (0,), ((M, S), gather_variable((S, S), 1, M), (S, S))), (Scatter, (0,), ((M, S), gather_variable((S, S), 1, M), (S, S))),
(Scatter, (1,), ((M, S), gather_variable((M, S//2), 0, S), (M, S//2)), 'dim1'), (Scatter, (1,), ((M, S), gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1'),
(Concat, (0,), ((1, S, S), (2, S, S), (3, S, S)) ), (Concat, (0,), ((1, S, S), (2, S, S), (3, S, S))),
(Resize, (S*S, S), ((S, S, S),) ), (Resize, (S * S, S), ((S, S, S),)),
(Diag, (), ((S, S),), '2d' ), (Diag, (), ((S, S),), '2d'),
(Diag, (), ((S,),), '1d' ), (Diag, (), ((S,),), '1d'),
(Tril, (), ((S, S),) ), (Tril, (), ((S, S),)),
(Tril, (2,), ((S, S),), 'idx' ), (Tril, (2,), ((S, S),), 'idx'),
(Triu, (), ((S, S),) ), (Triu, (), ((S, S),)),
(Triu, (2,), ((S, S),), 'idx' ), (Triu, (2,), ((S, S),), 'idx'),
(Clone, (), ((S, M, S),) ), (Clone, (), ((S, M, S),)),
(Squeeze, (), ((S, 1, M, 1),) ), (Squeeze, (), ((S, 1, M, 1),)),
(Squeeze, (1,), ((S, 1, M, 1),), 'dim' ), (Squeeze, (1,), ((S, 1, M, 1),), 'dim'),
(Unsqueeze, (0,), ((S, M, S),), '0' ), (Unsqueeze, (0,), ((S, M, S),), '0'),
(Unsqueeze, (1,), ((S, M, S),), '1' ), (Unsqueeze, (1,), ((S, M, S),), '1'),
# (MaskedCopy, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False), (S, S),)), # (MaskedCopy, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False), (S, S),)),
(MaskedFill, (10,), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))), (MaskedFill, (10,), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
(MaskedSelect, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))), (MaskedSelect, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
(Sort, (), ((S, M, S),) ), (Sort, (), ((S, M, S),)),
(Sort, (1,), ((S, M, S),), 'dim' ), (Sort, (1,), ((S, M, S),), 'dim'),
(Sort, (1, True), ((S, M, S),), 'dim_desc' ), (Sort, (1, True), ((S, M, S),), 'dim_desc'),
(Topk, (3,), ((S, M, S),) ), (Topk, (3,), ((S, M, S),)),
(Topk, (3, 1), ((S, M, S),), 'dim' ), (Topk, (3, 1), ((S, M, S),), 'dim'),
(Topk, (3, 1, True), ((S, M, S),), 'dim_desc' ), (Topk, (3, 1, True), ((S, M, S),), 'dim_desc'),
(Topk, (3, 1, True, True), ((S, M, S),), 'dim_desc_sort' ), (Topk, (3, 1, True, True), ((S, M, S),), 'dim_desc_sort'),
] ]
method_tests = [ method_tests = [
('add', (S, S, S), ((S, S, S),) ), ('add', (S, S, S), ((S, S, S),)),
('add', (S, S, S), (3.14,), 'constant' ), ('add', (S, S, S), (3.14,), 'constant'),
('sub', (S, S, S), ((S, S, S),) ), ('sub', (S, S, S), ((S, S, S),)),
('sub', (S, S, S), (3.14,), 'constant' ), ('sub', (S, S, S), (3.14,), 'constant'),
('mul', (S, S, S), ((S, S, S),) ), ('mul', (S, S, S), ((S, S, S),)),
('mul', (S, S, S), (3.14,), 'constant' ), ('mul', (S, S, S), (3.14,), 'constant'),
('div', (S, S, S), ((S, S, S),) ), ('div', (S, S, S), ((S, S, S),)),
('div', (S, S, S), (3.14,), 'constant' ), ('div', (S, S, S), (3.14,), 'constant'),
('pow', (S, S, S), ((S, S, S),) ), ('pow', (S, S, S), ((S, S, S),)),
('pow', (S, S, S), (3.14,), 'constant' ), ('pow', (S, S, S), (3.14,), 'constant'),
('transpose', (1, 2, 3), (1, 2) ), ('transpose', (1, 2, 3), (1, 2)),
('t', (1, 2), () ), ('t', (1, 2), ()),
('view', (S, S, S), (S*S, S), ), ('view', (S, S, S), (S * S, S),),
('view_as', (S, S, S), ((S*S, S),) ), ('view_as', (S, S, S), ((S * S, S),)),
('expand', (S, 1, S), (S, S, S) ), ('expand', (S, 1, S), (S, S, S)),
('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size' ), ('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size'),
('exp', (S, S, S), () ), ('exp', (S, S, S), ()),
('log', (S, S, S), () ), ('log', (S, S, S), ()),
('log1p', (S, S, S), () ), ('log1p', (S, S, S), ()),
('tanh', (S, S, S), () ), ('tanh', (S, S, S), ()),
('sigmoid', (S, S, S), () ), ('sigmoid', (S, S, S), ()),
('sinh', (S, S, S), () ), ('sinh', (S, S, S), ()),
('cosh', (S, S, S), () ), ('cosh', (S, S, S), ()),
('abs', (S, S, S), () ), ('abs', (S, S, S), ()),
('clamp', (S, S, S), (0, 1) ), ('clamp', (S, S, S), (0, 1)),
('sqrt', (S, S, S), () ), ('sqrt', (S, S, S), ()),
('sin', (S, S, S), () ), ('sin', (S, S, S), ()),
('cos', (S, S, S), () ), ('cos', (S, S, S), ()),
('tan', (S, S, S), () ), ('tan', (S, S, S), ()),
('asin', (S, S, S), () ), ('asin', (S, S, S), ()),
('acos', (S, S, S), () ), ('acos', (S, S, S), ()),
('atan', (S, S, S), () ), ('atan', (S, S, S), ()),
('reciprocal', (S, S, S), () ), ('reciprocal', (S, S, S), ()),
('round', (S, S, S), () ), ('round', (S, S, S), ()),
('sign', (S, S, S), () ), ('sign', (S, S, S), ()),
('trunc', (S, S, S), () ), ('trunc', (S, S, S), ()),
('floor', (S, S, S), () ), ('floor', (S, S, S), ()),
('ceil', (S, S, S), () ), ('ceil', (S, S, S), ()),
('rsqrt', (S, S, S), () ), ('rsqrt', (S, S, S), ()),
('fmod', (S, S, S), (1.5,) ), ('fmod', (S, S, S), (1.5,)),
('remainder', (S, S, S), (1.5,) ), ('remainder', (S, S, S), (1.5,)),
('lerp', (S, S, S), ((S, S, S), 0.4) ), ('lerp', (S, S, S), ((S, S, S), 0.4)),
('max', (S, S, S), () ), ('max', (S, S, S), ()),
('max', (S, S, S), ((S, S, S),), 'elementwise' ), ('max', (S, S, S), ((S, S, S),), 'elementwise'),
('min', (S, S, S), () ), ('min', (S, S, S), ()),
('min', (S, S, S), ((S, S, S),), 'elementwise' ), ('min', (S, S, S), ((S, S, S),), 'elementwise'),
('mean', (S, S, S), () ), ('mean', (S, S, S), ()),
('mean', (S, S, S), (1,), 'dim' ), ('mean', (S, S, S), (1,), 'dim'),
('sum', (S, S, S), () ), ('sum', (S, S, S), ()),
('sum', (S, S, S), (1,), 'dim' ), ('sum', (S, S, S), (1,), 'dim'),
('prod', (S, S, S), () ), ('prod', (S, S, S), ()),
('prod', (S, S, S), (1,), 'dim' ), ('prod', (S, S, S), (1,), 'dim'),
('addmm', (S, M), ((S, S), (S, M)), ), ('var', (S, S, S), ()),
('addmm', (S, M), (0.2, 0.6, (S, S), (S, M)), 'coef' ), ('var', (S, S, S), (1,), 'dim'),
('addbmm', (S, M), ((S, S, S), (S, S, M)), ), ('std', (S, S, S), ()),
('addbmm', (S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef' ), ('std', (S, S, S), (1,), 'dim'),
('baddbmm', (S, S, M), ((S, S, S), (S, S, M)), ), ('renorm', (S, S, S), (2, 1, 0.5)),
('baddbmm', (S, S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef' ), ('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
('addmv', (S,), ((S, M), (M,)), ), ('repeat', (S, S, S, S), (2, 3, 1, 4)),
('addmv', (S,), (0.2, 0.6, (S, M), (M,)), 'coef' ), ('addmm', (S, M), ((S, S), (S, M)),),
('addr', (S, M), ((S,), (M,)), ), ('addmm', (S, M), (0.2, 0.6, (S, S), (S, M)), 'coef'),
('addr', (S, M), (0.2, 0.6, (S,), (M,)), 'coef' ), ('addbmm', (S, M), ((S, S, S), (S, S, M)),),
('dot', (L,), ((L,),), ), ('addbmm', (S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef'),
('addcmul', (S, S), ((S, S), (S, S)) ), ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)),),
('addcmul', (S, S), (0.5, (S, S), (S, S)), 'scale' ), ('baddbmm', (S, S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef'),
('addcdiv', (S, S), ((S, S), (S, S)) ), ('addmv', (S,), ((S, M), (M,)),),
('addcdiv', (S, S), (0.5, (S, S), (S, S)), 'scale' ), ('addmv', (S,), (0.2, 0.6, (S, M), (M,)), 'coef'),
('norm', (S, S, S), (2,) ), ('addr', (S, M), ((S,), (M,)),),
('norm', (S, S, S), (2, 1), 'dim' ), ('addr', (S, M), (0.2, 0.6, (S,), (M,)), 'coef'),
('dist', (S, S, S), ((S, S, S),) ), ('dot', (L,), ((L,),),),
('dist', (S, S, S), ((S, S, S), 4), '4' ), ('addcmul', (S, S), ((S, S), (S, S))),
('index_select', (S, S, S), (0, index_variable(2, S)) ), ('addcmul', (S, S), (0.5, (S, S), (S, S)), 'scale'),
('diag', (M, M), (), '2d' ), ('addcdiv', (S, S), ((S, S), (S, S))),
('diag', (M,), (), '1d' ), ('addcdiv', (S, S), (0.5, (S, S), (S, S)), 'scale'),
('tril', (M, M), () ), ('norm', (S, S, S), (2,)),
('triu', (M, M), () ), ('norm', (S, S, S), (2, 1), 'dim'),
('clone', (S, M, S), () ), ('dist', (S, S, S), ((S, S, S),)),
('permute', (1, 2, 3, 4), (0, 2, 3, 1) ), ('dist', (S, S, S), ((S, S, S), 4), '4'),
('select', (S, S, S), (1, 2) ), ('index_select', (S, S, S), (0, index_variable(2, S))),
('narrow', (S, S, S), (1, 2, 2) ), ('diag', (M, M), (), '2d'),
('squeeze', (S, 1, S, 1), () ), ('diag', (M,), (), '1d'),
('squeeze', (S, 1, S, 1), (1,), '1_dim' ), ('tril', (M, M), ()),
('squeeze', (S, 1, S, 1), (2,), 'not_1_dim' ), ('triu', (M, M), ()),
('unsqueeze', (S, S, S), (0,), 'first' ), ('clone', (S, M, S), ()),
('unsqueeze', (S, S, S), (1,), 'middle' ), ('eq', (S, S, S), ((S, S, S),)),
('unsqueeze', (S, S, S), (3,), 'last' ), ('ne', (S, S, S), ((S, S, S),)),
('masked_select', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False),) ), ('gt', (S, S, S), ((S, S, S),)),
('masked_fill_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), 10) ), ('ge', (S, S, S), ((S, S, S),)),
('masked_copy_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), (M, M)) ), ('lt', (S, S, S), ((S, S, S),)),
('le', (S, S, S), ((S, S, S),)),
('eq', (S, S, S), (0,), 'scalar'),
('ne', (S, S, S), (0,), 'scalar'),
('gt', (S, S, S), (0,), 'scalar'),
('ge', (S, S, S), (0,), 'scalar'),
('lt', (S, S, S), (0,), 'scalar'),
('le', (S, S, S), (0,), 'scalar'),
('permute', (1, 2, 3, 4), (0, 2, 3, 1)),
('select', (S, S, S), (1, 2)),
('narrow', (S, S, S), (1, 2, 2)),
('squeeze', (S, 1, S, 1), ()),
('squeeze', (S, 1, S, 1), (1,), '1_dim'),
('squeeze', (S, 1, S, 1), (2,), 'not_1_dim'),
('unsqueeze', (S, S, S), (0,), 'first'),
('unsqueeze', (S, S, S), (1,), 'middle'),
('unsqueeze', (S, S, S), (3,), 'last'),
('masked_select', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False),)),
('masked_fill_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), 10)),
('masked_copy_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), (M, M))),
] ]
# TODO: mm, bmm, mv, ger # TODO: mm, bmm, mv, ger
# TODO: max, min with dim (problem with indices) # TODO: max, min with dim (problem with indices)
@ -941,6 +976,7 @@ method_tests = [
def create_input(call_args): def create_input(call_args):
if not isinstance(call_args, tuple): if not isinstance(call_args, tuple):
call_args = (call_args,) call_args = (call_args,)
def map_arg(arg): def map_arg(arg):
if isinstance(arg, tuple) and not isinstance(arg[0], Variable): if isinstance(arg, tuple) and not isinstance(arg[0], Variable):
return Variable(torch.randn(*arg).double(), requires_grad=True) return Variable(torch.randn(*arg).double(), requires_grad=True)
@ -971,8 +1007,9 @@ ignore_inplace = set((
for test in function_tests: for test in function_tests:
cls, constructor_args, call_args = test[:3] cls, constructor_args, call_args = test[:3]
test_name = 'test_' + cls.__name__ + ('_' + test[3] if len(test) == 4 else '') test_name = 'test_' + cls.__name__ + ('_' + test[3] if len(test) == 4 else '')
def do_test(self, cls=cls, constructor_args=constructor_args, def do_test(self, cls=cls, constructor_args=constructor_args,
call_args=call_args, test_name=test_name): call_args=call_args, test_name=test_name):
input = create_input(call_args) input = create_input(call_args)
output = cls(*constructor_args)(*input) output = cls(*constructor_args)(*input)
if not isinstance(output, tuple): if not isinstance(output, tuple):
@ -981,6 +1018,7 @@ for test in function_tests:
if not o.requires_grad: if not o.requires_grad:
continue continue
analytical = get_analytical_jacobian(input, o) analytical = get_analytical_jacobian(input, o)
def fn(input): def fn(input):
tmp = cls(*constructor_args)(*input) tmp = cls(*constructor_args)(*input)
if not isinstance(tmp, tuple): if not isinstance(tmp, tuple):
@ -1027,6 +1065,7 @@ EXCLUDE_FUNCTIONAL = {
for test in method_tests: for test in method_tests:
name, self_size, args = test[:3] name, self_size, args = test[:3]
test_name = 'test_' + name + ('_' + test[3] if len(test) == 4 else '') test_name = 'test_' + name + ('_' + test[3] if len(test) == 4 else '')
def do_test(self, name=name, self_size=self_size, args=args, test_name=test_name): def do_test(self, name=name, self_size=self_size, args=args, test_name=test_name):
def check(name): def check(name):
self_variable = create_input((self_size,))[0] self_variable = create_input((self_size,))[0]
@ -1056,13 +1095,12 @@ for test in method_tests:
        try:
            check(inplace_name)
        except Exception as e:
            if 'only supports scalar' not in e.args[0]:
                raise

    assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
    setattr(TestAutograd, test_name, do_test)

if __name__ == '__main__':
    run_tests()


@@ -7,13 +7,14 @@ import torch
import torch.cuda
import torch.cuda.comm as comm
from common import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests

if not torch.cuda.is_available():
    print('CUDA not available, skipping tests')
    import sys
    sys.exit()

def is_floating(t):
    return type(t) in [torch.FloatTensor, torch.DoubleTensor,
                       torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
@@ -31,7 +32,8 @@ types = [
float_types = [
    torch.FloatTensor,
    torch.DoubleTensor
]  # TODO: add half...
def number(floating, integer, t): def number(floating, integer, t):
name = type(t).__name__ name = type(t).__name__
@ -44,188 +46,204 @@ def number(floating, integer, t):
S = 10 S = 10
M = 50 M = 50
def make_tensor(t, *sizes): def make_tensor(t, *sizes):
return t(*sizes).copy_(torch.randn(*sizes)) return t(*sizes).copy_(torch.randn(*sizes))
def small_2d(t): def small_2d(t):
return make_tensor(t, S, S) return make_tensor(t, S, S)
def small_2d_scaled(t, scale=10): def small_2d_scaled(t, scale=10):
return make_tensor(t, S, S).mul(scale) return make_tensor(t, S, S).mul(scale)
def small_3d(t): def small_3d(t):
return make_tensor(t, S, S, S) return make_tensor(t, S, S, S)
def medium_1d(t): def medium_1d(t):
return make_tensor(t, M) return make_tensor(t, M)
def medium_2d(t): def medium_2d(t):
return make_tensor(t, M, M) return make_tensor(t, M, M)
def medium_2d_scaled(t, scale=10): def medium_2d_scaled(t, scale=10):
return make_tensor(t, M, M).mul(scale) return make_tensor(t, M, M).mul(scale)
def small_3d_ones(t): def small_3d_ones(t):
return t(S, S, S).copy_(torch.ones(S, S, S)) return t(S, S, S).copy_(torch.ones(S, S, S))
def small_3d_positive(t): def small_3d_positive(t):
min_val = 1e-3 if is_floating(t) else 2 min_val = 1e-3 if is_floating(t) else 2
return make_tensor(t, S, S, S).clamp_(min_val, 120) return make_tensor(t, S, S, S).clamp_(min_val, 120)
def small_3d_unique(t): def small_3d_unique(t):
return t(S, S, S).copy_(torch.range(1, S*S*S)) return t(S, S, S).copy_(torch.range(1, S * S * S))
def small_1d_lapack(t): def small_1d_lapack(t):
return t(1, 3).copy_(torch.range(1, 3).view(3)) return t(1, 3).copy_(torch.range(1, 3).view(3))
def small_2d_lapack(t): def small_2d_lapack(t):
return t(3, 3).copy_(torch.range(1, 9).view(3, 3)) return t(3, 3).copy_(torch.range(1, 9).view(3, 3))
def small_2d_lapack_skinny(t): def small_2d_lapack_skinny(t):
return t(3, 4).copy_(torch.range(1, 12).view(3, 4)) return t(3, 4).copy_(torch.range(1, 12).view(3, 4))
def small_2d_lapack_fat(t): def small_2d_lapack_fat(t):
return t(4, 3).copy_(torch.range(1, 12).view(4, 3)) return t(4, 3).copy_(torch.range(1, 12).view(4, 3))
def new_t(*sizes): def new_t(*sizes):
def tmp(t): def tmp(t):
return t(*sizes).copy_(torch.randn(*sizes)) return t(*sizes).copy_(torch.randn(*sizes))
return tmp return tmp
tests = [ tests = [
('add', small_3d, lambda t: [number(3.14, 3, t)] ), ('add', small_3d, lambda t: [number(3.14, 3, t)]),
('add', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ), ('add', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
('add', small_3d, lambda t: [number(0.2, 2, t), small_3d_positive(t)], 'scalar_tensor' ), ('add', small_3d, lambda t: [number(0.2, 2, t), small_3d_positive(t)], 'scalar_tensor'),
('sub', small_3d, lambda t: [number(3.14, 3, t)], ), ('sub', small_3d, lambda t: [number(3.14, 3, t)],),
('sub', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ), ('sub', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
('mul', small_3d, lambda t: [number(3.14, 3, t)], ), ('mul', small_3d, lambda t: [number(3.14, 3, t)],),
('mul', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ), ('mul', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
('div', small_3d, lambda t: [number(3.14, 3, t)], ), ('div', small_3d, lambda t: [number(3.14, 3, t)],),
('div', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ), ('div', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
('pow', small_3d, lambda t: [number(3.14, 3, t)], None, float_types), ('pow', small_3d, lambda t: [number(3.14, 3, t)], None, float_types),
('pow', small_3d, lambda t: [small_3d(t).abs_()], 'tensor', float_types), ('pow', small_3d, lambda t: [small_3d(t).abs_()], 'tensor', float_types),
('addbmm', small_2d, lambda t: [small_3d(t), small_3d(t)], None, float_types), ('addbmm', small_2d, lambda t: [small_3d(t), small_3d(t)], None, float_types),
('addbmm', small_2d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar' ), ('addbmm', small_2d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'),
('addbmm', small_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars' ), ('addbmm', small_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars'),
('baddbmm', small_3d, lambda t: [small_3d(t), small_3d(t)], ), ('baddbmm', small_3d, lambda t: [small_3d(t), small_3d(t)],),
('baddbmm', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar' ), ('baddbmm', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'),
('baddbmm', small_3d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars' ), ('baddbmm', small_3d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars'),
('addcdiv', small_2d_lapack, lambda t: [small_2d_lapack(t).mul(2), small_2d_lapack(t)], ), ('addcdiv', small_2d_lapack, lambda t: [small_2d_lapack(t).mul(2), small_2d_lapack(t)],),
('addcdiv', small_2d_lapack, lambda t: [number(2.8, 1, t), small_2d_lapack(t).mul(2), small_2d_lapack(t)], 'scalar' ), ('addcdiv', small_2d_lapack, lambda t: [number(2.8, 1, t),
('addcmul', small_3d, lambda t: [small_3d(t), small_3d(t)], ), small_2d_lapack(t).mul(2), small_2d_lapack(t)], 'scalar'),
('addcmul', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar' ), ('addcmul', small_3d, lambda t: [small_3d(t), small_3d(t)],),
('addmm', medium_2d, lambda t: [medium_2d(t), medium_2d(t)], ), ('addcmul', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'),
('addmm', medium_2d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'scalar' ), ('addmm', medium_2d, lambda t: [medium_2d(t), medium_2d(t)],),
('addmm', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'two_scalars' ), ('addmm', medium_2d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'scalar'),
('addmv', medium_1d, lambda t: [medium_2d(t), medium_1d(t)], ), ('addmm', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'two_scalars'),
('addmv', medium_1d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'scalar' ), ('addmv', medium_1d, lambda t: [medium_2d(t), medium_1d(t)],),
('addmv', medium_1d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'two_scalars' ), ('addmv', medium_1d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'scalar'),
('addr', medium_2d, lambda t: [medium_1d(t), medium_1d(t)], ), ('addmv', medium_1d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'two_scalars'),
('addr', medium_2d, lambda t: [number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'scalar' ), ('addr', medium_2d, lambda t: [medium_1d(t), medium_1d(t)],),
('addr', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'two_scalars' ), ('addr', medium_2d, lambda t: [number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'scalar'),
('atan2', medium_2d, lambda t: [medium_2d(t)], None, float_types), ('addr', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'two_scalars'),
('fmod', small_3d, lambda t: [3], 'value' ), ('atan2', medium_2d, lambda t: [medium_2d(t)], None, float_types),
('fmod', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ), ('fmod', small_3d, lambda t: [3], 'value'),
('chunk', medium_2d, lambda t: [4], ), ('fmod', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
('chunk', medium_2d, lambda t: [4, 1], 'dim' ), ('chunk', medium_2d, lambda t: [4],),
('clamp', medium_2d_scaled, lambda t: [-1, 5], ), ('chunk', medium_2d, lambda t: [4, 1], 'dim'),
('clone', medium_2d, lambda t: [], ), ('clamp', medium_2d_scaled, lambda t: [-1, 5],),
('contiguous', medium_2d, lambda t: [], ), ('clone', medium_2d, lambda t: [],),
('cross', new_t(M, 3, M), lambda t: [new_t(M, 3, M)(t)], ), ('contiguous', medium_2d, lambda t: [],),
('cumprod', small_3d, lambda t: [1], ), ('cross', new_t(M, 3, M), lambda t: [new_t(M, 3, M)(t)],),
('cumsum', small_3d, lambda t: [1], ), ('cumprod', small_3d, lambda t: [1],),
('dim', small_3d, lambda t: [], ), ('cumsum', small_3d, lambda t: [1],),
('dist', small_2d, lambda t: [small_2d(t)], ), ('dim', small_3d, lambda t: [],),
('dist', small_2d, lambda t: [small_2d(t), 3], '3_norm' ), ('dist', small_2d, lambda t: [small_2d(t)],),
('dist', small_2d, lambda t: [small_2d(t), 2.5], '2_5_norm' ), ('dist', small_2d, lambda t: [small_2d(t), 3], '3_norm'),
('dot', medium_1d, lambda t: [medium_1d(t)], ), ('dist', small_2d, lambda t: [small_2d(t), 2.5], '2_5_norm'),
('element_size', medium_1d, lambda t: [], ), ('dot', medium_1d, lambda t: [medium_1d(t)],),
('eq', small_3d_ones, lambda t: [small_3d(t)], ), ('element_size', medium_1d, lambda t: [],),
('eq', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ), ('eq', small_3d_ones, lambda t: [small_3d(t)],),
('ne', small_3d_ones, lambda t: [small_3d(t)], ), ('eq', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'),
('ne', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ), ('ne', small_3d_ones, lambda t: [small_3d(t)],),
('equal', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ), ('ne', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'),
('equal', small_3d_ones, lambda t: [small_3d(t)], ), ('equal', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'),
('expand', new_t(M, 1, M), lambda t: [M, 4, M], ), ('equal', small_3d_ones, lambda t: [small_3d(t)],),
('expand_as', new_t(M, 1, M), lambda t: [new_t(M, 4, M)(t)], ), ('expand', new_t(M, 1, M), lambda t: [M, 4, M],),
('fill', medium_2d, lambda t: [number(3.14, 3, t)], ), ('expand_as', new_t(M, 1, M), lambda t: [new_t(M, 4, M)(t)],),
('ge', medium_2d, lambda t: [medium_2d(t)], ), ('fill', medium_2d, lambda t: [number(3.14, 3, t)],),
('le', medium_2d, lambda t: [medium_2d(t)], ), ('ge', medium_2d, lambda t: [medium_2d(t)],),
('gt', medium_2d, lambda t: [medium_2d(t)], ), ('le', medium_2d, lambda t: [medium_2d(t)],),
('lt', medium_2d, lambda t: [medium_2d(t)], ), ('gt', medium_2d, lambda t: [medium_2d(t)],),
('is_contiguous', medium_2d, lambda t: [], ), ('lt', medium_2d, lambda t: [medium_2d(t)],),
('is_contiguous', medium_2d, lambda t: [],),
# TODO: can't check negative case - GPU copy will be contiguous # TODO: can't check negative case - GPU copy will be contiguous
('is_same_size', medium_2d, lambda t: [small_3d(t)], 'negative' ), ('is_same_size', medium_2d, lambda t: [small_3d(t)], 'negative'),
('is_same_size', medium_2d, lambda t: [medium_2d(t)], 'positive' ), ('is_same_size', medium_2d, lambda t: [medium_2d(t)], 'positive'),
('is_set_to', medium_2d, lambda t: [medium_2d(t)], ), ('is_set_to', medium_2d, lambda t: [medium_2d(t)],),
# TODO: positive case # TODO: positive case
('kthvalue', small_3d_unique, lambda t: [3], ), ('kthvalue', small_3d_unique, lambda t: [3],),
('kthvalue', small_3d_unique, lambda t: [3, 1], 'dim' ), ('kthvalue', small_3d_unique, lambda t: [3, 1], 'dim'),
('lerp', small_3d, lambda t: [small_3d(t), 0.3], ), ('lerp', small_3d, lambda t: [small_3d(t), 0.3],),
('max', small_3d_unique, lambda t: [], ), ('max', small_3d_unique, lambda t: [],),
('max', small_3d_unique, lambda t: [1], 'dim' ), ('max', small_3d_unique, lambda t: [1], 'dim'),
('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise' ), ('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
('min', small_3d_unique, lambda t: [], ), ('min', small_3d_unique, lambda t: [],),
('min', small_3d_unique, lambda t: [1], 'dim' ), ('min', small_3d_unique, lambda t: [1], 'dim'),
('min', medium_2d, lambda t: [medium_2d(t)], 'elementwise' ), ('min', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
('mean', small_3d, lambda t: [], ), ('mean', small_3d, lambda t: [],),
('mean', small_3d, lambda t: [1], 'dim' ), ('mean', small_3d, lambda t: [1], 'dim'),
('mode', small_3d, lambda t: [], ), ('mode', small_3d, lambda t: [],),
('mode', small_3d, lambda t: [1], 'dim' ), ('mode', small_3d, lambda t: [1], 'dim'),
('remainder', small_3d, lambda t: [3], 'value' ), ('remainder', small_3d, lambda t: [3], 'value'),
('remainder', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ), ('remainder', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
('std', small_3d, lambda t: [], ), ('std', small_3d, lambda t: [],),
('std', small_3d, lambda t: [1], 'dim' ), ('std', small_3d, lambda t: [1], 'dim'),
('var', small_3d, lambda t: [], ), ('var', small_3d, lambda t: [],),
('var', small_3d, lambda t: [1], 'dim' ), ('var', small_3d, lambda t: [1], 'dim'),
('ndimension', small_3d, lambda t: [], ), ('ndimension', small_3d, lambda t: [],),
('nelement', small_3d, lambda t: [], ), ('nelement', small_3d, lambda t: [],),
('numel', small_3d, lambda t: [], ), ('numel', small_3d, lambda t: [],),
('narrow', small_3d, lambda t: [1, 3, 2], ), ('narrow', small_3d, lambda t: [1, 3, 2],),
('nonzero', small_3d, lambda t: [], ), ('nonzero', small_3d, lambda t: [],),
('norm', small_3d, lambda t: [], ), ('norm', small_3d, lambda t: [],),
('norm', small_3d, lambda t: [3], '3_norm' ), ('norm', small_3d, lambda t: [3], '3_norm'),
('norm', small_3d, lambda t: [3, 0], '3_norm_dim' ), ('norm', small_3d, lambda t: [3, 0], '3_norm_dim'),
('ones', small_3d, lambda t: [1, 2, 3, 4, 5], ), ('ones', small_3d, lambda t: [1, 2, 3, 4, 5],),
('permute', new_t(1, 2, 3, 4), lambda t: [2, 1, 3, 0], ), ('permute', new_t(1, 2, 3, 4), lambda t: [2, 1, 3, 0],),
('prod', small_3d, lambda t: [], ), ('prod', small_3d, lambda t: [],),
('prod', small_3d, lambda t: [1], 'dim' ), ('prod', small_3d, lambda t: [1], 'dim'),
('sum', small_2d, lambda t: [], ), ('sum', small_2d, lambda t: [],),
('sum', small_3d, lambda t: [1], 'dim' ), ('sum', small_3d, lambda t: [1], 'dim'),
('renorm', small_3d, lambda t: [2, 1, 1], '2_norm' ), ('renorm', small_3d, lambda t: [2, 1, 1], '2_norm'),
('renorm', small_3d, lambda t: [1.5, 1, 1], '1_5_norm' ), ('renorm', small_3d, lambda t: [1.5, 1, 1], '1_5_norm'),
('repeat', small_2d, lambda t: [2, 2, 2], ), ('repeat', small_2d, lambda t: [2, 2, 2],),
('size', new_t(1, 2, 3, 4), lambda t: [], ), ('size', new_t(1, 2, 3, 4), lambda t: [],),
('sort', small_3d_unique, lambda t: [], ), ('sort', small_3d_unique, lambda t: [],),
('sort', small_3d_unique, lambda t: [1], 'dim' ), ('sort', small_3d_unique, lambda t: [1], 'dim'),
('sort', small_3d_unique, lambda t: [1, True], 'dim_descending'), ('sort', small_3d_unique, lambda t: [1, True], 'dim_descending'),
('split', small_3d, lambda t: [2], ), ('split', small_3d, lambda t: [2],),
('split', small_3d, lambda t: [2, 1], 'dim' ), ('split', small_3d, lambda t: [2, 1], 'dim'),
('squeeze', new_t(1, 2, 1, 4), lambda t: [], ), ('squeeze', new_t(1, 2, 1, 4), lambda t: [],),
('squeeze', new_t(1, 2, 1, 4), lambda t: [2], 'dim' ), ('squeeze', new_t(1, 2, 1, 4), lambda t: [2], 'dim'),
('t', new_t(1, 2), lambda t: [], ), ('t', new_t(1, 2), lambda t: [],),
('transpose', new_t(1, 2, 3, 4), lambda t: [1, 2], ), ('transpose', new_t(1, 2, 3, 4), lambda t: [1, 2],),
('to_list', small_3d, lambda t: [], ), ('to_list', small_3d, lambda t: [],),
('topk', small_3d, lambda t: [2, 1, False, True], 'dim_sort' ), ('topk', small_3d, lambda t: [2, 1, False, True], 'dim_sort'),
('topk', small_3d, lambda t: [2, 1, True, True], 'dim_desc_sort' ), ('topk', small_3d, lambda t: [2, 1, True, True], 'dim_desc_sort'),
('trace', medium_2d, lambda t: [], ), ('trace', medium_2d, lambda t: [],),
('tril', medium_2d, lambda t: [], ), ('tril', medium_2d, lambda t: [],),
('tril', medium_2d, lambda t: [2], 'positive' ), ('tril', medium_2d, lambda t: [2], 'positive'),
('tril', medium_2d, lambda t: [-2], 'negative' ), ('tril', medium_2d, lambda t: [-2], 'negative'),
('triu', medium_2d, lambda t: [], ), ('triu', medium_2d, lambda t: [],),
('triu', medium_2d, lambda t: [2], 'positive' ), ('triu', medium_2d, lambda t: [2], 'positive'),
('triu', medium_2d, lambda t: [-2], 'negative' ), ('triu', medium_2d, lambda t: [-2], 'negative'),
('view', small_3d, lambda t: [100, 10], ), ('view', small_3d, lambda t: [100, 10],),
('view_as', small_3d, lambda t: [t(100, 10)], ), ('view_as', small_3d, lambda t: [t(100, 10)],),
('zero', small_3d, lambda t: [], ), ('zero', small_3d, lambda t: [],),
('zeros', small_3d, lambda t: [1, 2, 3, 4], ), ('zeros', small_3d, lambda t: [1, 2, 3, 4],),
('rsqrt', lambda t: small_3d(t) + 1, lambda t: [], None, float_types), ('rsqrt', lambda t: small_3d(t) + 1, lambda t: [], None, float_types),
('sinh', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], None, float_types), ('sinh', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], None, float_types),
('tan', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], None, float_types), ('tan', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], None, float_types),
# lapack tests # lapack tests
('qr', small_2d_lapack, lambda t: [], 'square', float_types), ('qr', small_2d_lapack, lambda t: [], 'square', float_types),
('qr', small_2d_lapack_skinny, lambda t: [], 'skinny', float_types), ('qr', small_2d_lapack_skinny, lambda t: [], 'skinny', float_types),
('qr', small_2d_lapack_fat, lambda t: [], 'fat', float_types), ('qr', small_2d_lapack_fat, lambda t: [], 'fat', float_types),
] ]
@ -275,6 +293,8 @@ for fn in simple_pointwise_float:
tests.append((fn, small_3d, lambda t: [], None, float_types)) tests.append((fn, small_3d, lambda t: [], None, float_types))
_cycles_per_ms = None _cycles_per_ms = None
def get_cycles_per_ms(): def get_cycles_per_ms():
"""Approximate number of cycles per millisecond for torch.cuda._sleep""" """Approximate number of cycles per millisecond for torch.cuda._sleep"""
global _cycles_per_ms global _cycles_per_ms
@ -288,6 +308,7 @@ def get_cycles_per_ms():
_cycles_per_ms = 1000000 / start.elapsed_time(end) _cycles_per_ms = 1000000 / start.elapsed_time(end)
return _cycles_per_ms return _cycles_per_ms
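For readers skimming the hunk above: get_cycles_per_ms calibrates torch.cuda._sleep by timing a fixed spin (1000000 cycles, per the line above) with CUDA events. A self-contained sketch of that measurement, assuming only that torch.cuda.Event timing and torch.cuda._sleep behave as they do elsewhere in this suite:

    import torch

    def measure_cycles_per_ms(spin_cycles=1000000):
        # Time a fixed busy-wait on the GPU and convert it to cycles per millisecond.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        torch.cuda._sleep(spin_cycles)      # spin kernel used by the test suite
        end.record()
        end.synchronize()                   # wait until the spin kernel finishes
        return spin_cycles / start.elapsed_time(end)   # elapsed_time() is in ms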
def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5): def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
def tmp(self): def tmp(self):
cpu_tensor = tensor_constructor(t) cpu_tensor = tensor_constructor(t)
@ -314,6 +335,7 @@ def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
self.assertEqual(cpu_result, gpu_result, precision) self.assertEqual(cpu_result, gpu_result, precision)
return tmp return tmp
class TestCuda(TestCase): class TestCuda(TestCase):
def test_autogpu(self): def test_autogpu(self):
@ -412,7 +434,7 @@ class TestCuda(TestCase):
y_cuda = y.cuda(1) y_cuda = y.cuda(1)
result = comm.reduce_add((x_cuda, y_cuda)) result = comm.reduce_add((x_cuda, y_cuda))
self.assertEqual(result.get_device(), 0) self.assertEqual(result.get_device(), 0)
self.assertEqual(result.cpu(), x+y) self.assertEqual(result.cpu(), x + y)
def _test_scatter(self, input, chunk_sizes=None, dim=0): def _test_scatter(self, input, chunk_sizes=None, dim=0):
if torch.cuda.device_count() < 2: if torch.cuda.device_count() < 2:
@ -473,7 +495,7 @@ class TestCuda(TestCase):
self._test_gather(1) self._test_gather(1)
def test_from_sequence(self): def test_from_sequence(self):
seq = [list(range(i*4,i*4+4)) for i in range(5)] seq = [list(range(i * 4, i * 4 + 4)) for i in range(5)]
reference = torch.range(0, 19).resize_(5, 4) reference = torch.range(0, 19).resize_(5, 4)
for t in types: for t in types:
cuda_type = get_gpu_type(t) cuda_type = get_gpu_type(t)
@ -526,6 +548,7 @@ class TestCuda(TestCase):
@unittest.skipIf(torch.cuda.device_count() < 2, "detected only one GPU") @unittest.skipIf(torch.cuda.device_count() < 2, "detected only one GPU")
def test_multigpu_serialization_remap(self): def test_multigpu_serialization_remap(self):
x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)] x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
def gpu_remap(storage, location): def gpu_remap(storage, location):
if location == 'cuda:1': if location == 'cuda:1':
return storage.cuda(0) return storage.cuda(0)
@ -666,7 +689,8 @@ for decl in tests:
if not hasattr(tensor, name_inner): if not hasattr(tensor, name_inner):
continue continue
if not hasattr(gpu_tensor, name_inner): if not hasattr(gpu_tensor, name_inner):
print("Ignoring {}, because it's not implemented by torch.cuda.{}".format(name_inner, gpu_tensor.__class__.__name__)) print("Ignoring {}, because it's not implemented by torch.cuda.{}".format(
name_inner, gpu_tensor.__class__.__name__))
continue continue
test_name = 'test_' + t.__name__ + '_' + name_inner test_name = 'test_' + t.__name__ + '_' + name_inner
@ -677,4 +701,4 @@ for decl in tests:
setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name_inner, t, precision)) setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() run_tests()
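The setattr loop in the hunk above is what turns the big tests table at the top of this file into individual test methods: each entry pairs a tensor constructor with an argument constructor, and compare_cpu_gpu builds a closure that runs the named method on a CPU tensor and its CUDA copy and compares the results. A minimal sketch of that dynamic-registration pattern with hypothetical method names (it needs a CUDA device to run):

    import unittest
    import torch

    class ParityCase(unittest.TestCase):
        pass

    def make_parity_test(method_name, precision=1e-5):
        def test(self):
            cpu = torch.randn(4, 4)
            gpu = cpu.cuda()
            cpu_result = getattr(cpu, method_name)()
            gpu_result = getattr(gpu, method_name)().cpu()
            self.assertTrue((cpu_result - gpu_result).abs().max() <= precision)
        return test

    # One generated test per method name, mirroring the loop over the tests table.
    for name in ('abs', 'neg', 'sign'):
        setattr(ParityCase, 'test_' + name, make_parity_test(name))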

@ -4,7 +4,7 @@ import torch
import traceback import traceback
import unittest import unittest
from torch.utils.data import Dataset, TensorDataset, DataLoader from torch.utils.data import Dataset, TensorDataset, DataLoader
from common import TestCase from common import TestCase, run_tests
from common_nn import TEST_CUDA from common_nn import TEST_CUDA
@ -27,11 +27,12 @@ class TestTensorDataset(TestCase):
l = torch.randn(15) l = torch.randn(15)
source = TensorDataset(t, l) source = TensorDataset(t, l)
for i in range(15): for i in range(15):
self.assertEqual(t[i:i+1], source[i][0]) self.assertEqual(t[i:i + 1], source[i][0])
self.assertEqual(l[i:i+1], source[i][1]) self.assertEqual(l[i:i + 1], source[i][1])
class ErrorDataset(Dataset): class ErrorDataset(Dataset):
def __init__(self, size): def __init__(self, size):
self.size = size self.size = size
@ -50,9 +51,9 @@ class TestDataLoader(TestCase):
batch_size = loader.batch_size batch_size = loader.batch_size
for i, (sample, target) in enumerate(loader): for i, (sample, target) in enumerate(loader):
idx = i * batch_size idx = i * batch_size
self.assertEqual(sample, self.data[idx:idx+batch_size]) self.assertEqual(sample, self.data[idx:idx + batch_size])
self.assertEqual(target, self.labels[idx:idx+batch_size].view(-1, 1)) self.assertEqual(target, self.labels[idx:idx + batch_size].view(-1, 1))
self.assertEqual(i, math.floor((len(self.dataset)-1) / batch_size)) self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
def _test_shuffle(self, loader): def _test_shuffle(self, loader):
found_data = {i: 0 for i in range(self.data.size(0))} found_data = {i: 0 for i in range(self.data.size(0))}
@ -67,9 +68,9 @@ class TestDataLoader(TestCase):
break break
self.assertEqual(target, self.labels.narrow(0, data_point_idx, 1)) self.assertEqual(target, self.labels.narrow(0, data_point_idx, 1))
found_labels[data_point_idx] += 1 found_labels[data_point_idx] += 1
self.assertEqual(sum(found_data.values()), (i+1) * batch_size) self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
self.assertEqual(sum(found_labels.values()), (i+1) * batch_size) self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
self.assertEqual(i, math.floor((len(self.dataset)-1) / batch_size)) self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
def _test_error(self, loader): def _test_error(self, loader):
it = iter(loader) it = iter(loader)
@ -81,10 +82,9 @@ class TestDataLoader(TestCase):
errors += 1 errors += 1
except StopIteration: except StopIteration:
self.assertEqual(errors, self.assertEqual(errors,
math.ceil(float(len(loader.dataset))/loader.batch_size)) math.ceil(float(len(loader.dataset)) / loader.batch_size))
return return
def test_sequential(self): def test_sequential(self):
self._test_sequential(DataLoader(self.dataset)) self._test_sequential(DataLoader(self.dataset))
@ -159,4 +159,4 @@ class TestDataLoader(TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() run_tests()
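The assertions above pin down the sequential DataLoader contract: without shuffling, batch i is exactly the contiguous slice data[i * batch_size:(i + 1) * batch_size]. A small sketch of the same relationship, with arbitrary sizes and the default (sequential) sampler assumed:

    import torch
    from torch.utils.data import TensorDataset, DataLoader

    data = torch.randn(8, 3)
    labels = torch.randn(8, 1)
    loader = DataLoader(TensorDataset(data, labels), batch_size=4)

    for i, (sample, target) in enumerate(loader):
        idx = i * 4
        # Each batch is the next contiguous slice of the underlying tensor.
        assert sample.equal(data[idx:idx + 4])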

test/test_distributed.py (new file, 508 lines)
@ -0,0 +1,508 @@
import fcntl
import multiprocessing
import os
import sys
import time
import unittest
from functools import wraps, reduce
from contextlib import contextmanager
import torch
import torch.distributed as dist
from common import TestCase
BACKEND = os.environ['BACKEND']
TEMP_DIR = os.environ['TEMP_DIR']
MASTER_PORT = '29500'
MASTER_ADDR = '127.0.0.1:' + MASTER_PORT
@contextmanager
def _lock():
lockfile = os.path.join(TEMP_DIR, 'lockfile')
with open(lockfile, 'w') as lf:
try:
fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
yield
finally:
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
lf.close()
def _build_tensor(size, value=None):
if value is None:
value = size
return torch.FloatTensor(size, size, size).fill_(value)
class Barrier(object):
barrier_id = 0
@classmethod
def init(cls):
cls.barrier_id = 0
barrier_dir = os.path.join(TEMP_DIR, 'barrier')
for f_name in os.listdir(barrier_dir):
os.unlink(os.path.join(barrier_dir, f_name))
@classmethod
def sync(cls, timeout=5):
cls.barrier_id += 1
barrier_dir = os.path.join(TEMP_DIR, 'barrier')
pid = str(os.getpid())
barrier_file = os.path.join(barrier_dir, pid)
with _lock():
with open(barrier_file, 'w') as f:
f.write(str(cls.barrier_id))
start_time = time.time()
while True:
arrived = 0
with _lock():
for f_name in os.listdir(barrier_dir):
with open(os.path.join(barrier_dir, f_name), 'r') as f:
data = f.read()
if int(data) >= cls.barrier_id:
arrived += 1
if arrived == dist.get_num_processes():
break
if time.time() - start_time > timeout:
raise RuntimeError("barrier timeout")
time.sleep(0.1)
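Because the spawned ranks share nothing but the filesystem, both the barrier files above and the _lock() helper rely on fcntl.flock for cross-process mutual exclusion. A minimal standalone sketch of that flock idiom (the file names are made up for illustration):

    import fcntl

    def append_under_lock(path, line, lock_path='/tmp/example.lock'):
        # Cross-process critical section: flock blocks until we hold the lock.
        with open(lock_path, 'w') as lf:
            fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
            try:
                with open(path, 'a') as f:
                    f.write(line + '\n')
            finally:
                fcntl.flock(lf.fileno(), fcntl.LOCK_UN)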
class _DistTestBase(object):
def _barrier(self, *args, **kwargs):
Barrier.sync(*args, **kwargs)
def _init_group_test(self):
group = [1, 2]
group_id = dist.new_group(group)
rank = dist.get_rank()
if rank not in group:
return ([], None, rank)
return (group, group_id, rank)
def _init_global_test(self):
group = [i for i in range(0, dist.get_num_processes())]
group_id = dist.group.WORLD
rank = dist.get_rank()
return (group, group_id, rank)
# GET RANK
def test_get_rank(self):
test_dir = os.path.join(TEMP_DIR, 'test_dir')
pid = str(os.getpid())
num_processes = dist.get_num_processes()
with open(os.path.join(test_dir, pid), 'w') as f:
f.write(str(dist.get_rank()))
self._barrier()
all_ranks = set()
for f_name in os.listdir(test_dir):
with open(os.path.join(test_dir, f_name), 'r') as f:
all_ranks.add(int(f.read()))
self.assertEqual(len(all_ranks), num_processes)
self._barrier()
if dist.get_rank() == 0:
for f_name in os.listdir(test_dir):
os.unlink(os.path.join(test_dir, f_name))
self._barrier()
# SEND RECV
def test_send_recv(self):
rank = dist.get_rank()
tensor = _build_tensor(rank + 1)
for dest in range(0, dist.get_num_processes()):
if dest == rank:
continue
dist.send(tensor, dest)
for src in range(0, dist.get_num_processes()):
if src == rank:
continue
tensor = _build_tensor(src + 1, value=-1)
expected_tensor = _build_tensor(src + 1)
dist.recv(tensor, src)
self.assertEqual(tensor, expected_tensor)
self._barrier()
# SEND RECV ANY SOURCE
def test_send_recv_any_source(self):
rank = dist.get_rank()
tensor = _build_tensor(10, rank)
for dest in range(0, dist.get_num_processes()):
if dest == rank:
continue
dist.send(tensor, dest)
recv_ranks = set()
for src in range(0, dist.get_num_processes()):
if src == rank:
continue
tensor = _build_tensor(10, value=-1)
dist.recv(tensor)
recv_ranks.add(tensor.resize_(1)[0])
self.assertEqual(len(recv_ranks), dist.get_num_processes() - 1)
self._barrier()
# ISEND
def test_isend(self):
rank = dist.get_rank()
world_size = dist.get_num_processes()
if rank == 0:
requests = [
dist.isend(_build_tensor(dest, 10), dest) for dest in range(1, world_size)
]
for request in requests:
request.wait()
self.assertTrue(request.is_completed())
else:
tensor = _build_tensor(rank, -1)
dist.recv(tensor, 0)
self.assertEqual(tensor, _build_tensor(rank, 10))
self._barrier()
# IRECV
def test_irecv(self):
rank = dist.get_rank()
world_size = dist.get_num_processes()
if rank == 0:
expected_tensors = [_build_tensor(src, -1) for src in range(1, world_size)]
requests = [
dist.irecv(expected_tensors[src - 1], src) for src in range(1, world_size)
]
for src in range(1, world_size):
requests[src - 1].wait()
self.assertTrue(requests[src - 1].is_completed())
self.assertEqual(expected_tensors[src - 1], _build_tensor(src, 10))
else:
tensor = _build_tensor(rank, 10)
dist.send(tensor, 0)
self._barrier()
# BROADCAST
def _test_broadcast_helper(self, group, group_id, rank):
for src in group:
expected_tensor = _build_tensor(src + 1)
if rank == src:
dist.broadcast(expected_tensor, src, group_id)
else:
tensor = _build_tensor(src + 1, -1)
dist.broadcast(tensor, src, group_id)
self.assertEqual(tensor, expected_tensor)
self._barrier()
def test_broadcast(self):
group, group_id, rank = self._init_global_test()
self._test_broadcast_helper(group, group_id, rank)
def test_broadcast_group(self):
group, group_id, rank = self._init_group_test()
self._test_broadcast_helper(group, group_id, rank)
# REDUCE
def _test_reduce_helper(self, group, group_id, rank, op, master_value, worker_value, expected_value):
for src in group:
if rank == src:
tensor = _build_tensor(src + 1).fill_(master_value)
dist.reduce(tensor, src, op, group_id)
self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
else:
tensor = _build_tensor(src + 1).fill_(worker_value)
dist.reduce(tensor, src, op, group_id)
self._barrier()
def test_reduce_sum(self):
group, group_id, rank = self._init_global_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.SUM, 2, 10, 2 + (10 * (len(group) - 1))
)
def test_reduce_product(self):
group, group_id, rank = self._init_global_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.PRODUCT,
2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
)
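The literals passed to _test_reduce_helper are the values each collective must produce: the root rank contributes master_value once and every other rank contributes worker_value. With a hypothetical world size of 4, for example, SUM expects 2 + 10 * (4 - 1) = 32 and PRODUCT expects 2 * 10 ** 3 = 2000, which is exactly what the reduce(...) expression computes:

    from functools import reduce

    world_size = 4                       # hypothetical number of processes
    master_value, worker_value = 2, 10
    expected_sum = master_value + worker_value * (world_size - 1)               # 32
    expected_product = reduce(lambda x, y: x * y,
                              [worker_value] * (world_size - 1), master_value)  # 2000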
def test_reduce_min(self):
group, group_id, rank = self._init_global_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
)
def test_reduce_max(self):
group, group_id, rank = self._init_global_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10
)
def test_reduce_group_sum(self):
group, group_id, rank = self._init_group_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.SUM, 2, 10, 2 + (10 * (len(group) - 1))
)
def test_reduce_group_product(self):
group, group_id, rank = self._init_group_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.PRODUCT,
2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
)
def test_reduce_group_min(self):
group, group_id, rank = self._init_group_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
)
def test_reduce_group_max(self):
group, group_id, rank = self._init_group_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10
)
# ALL REDUCE
def _test_all_reduce_helper(self, group, group_id, rank, op, master_value, worker_value, expected_value):
for src in group:
if rank == src:
tensor = _build_tensor(src + 1).fill_(master_value)
dist.all_reduce(tensor, op, group_id)
self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
else:
tensor = _build_tensor(src + 1).fill_(worker_value)
dist.all_reduce(tensor, op, group_id)
self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
self._barrier()
def test_all_reduce_sum(self):
group, group_id, rank = self._init_global_test()
self._test_all_reduce_helper(
group, group_id, rank, dist.reduce_op.SUM, 2, 10, 2 + (10 * (len(group) - 1))
)
def test_all_reduce_product(self):
group, group_id, rank = self._init_global_test()
self._test_all_reduce_helper(
group, group_id, rank, dist.reduce_op.PRODUCT,
2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
)
def test_all_reduce_min(self):
group, group_id, rank = self._init_global_test()
self._test_all_reduce_helper(
group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
)
def test_all_reduce_max(self):
group, group_id, rank = self._init_global_test()
self._test_all_reduce_helper(
group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10
)
def test_all_reduce_group_sum(self):
group, group_id, rank = self._init_group_test()
self._test_all_reduce_helper(
group, group_id, rank, dist.reduce_op.SUM, 2, 10, 2 + (10 * (len(group) - 1))
)
def test_all_reduce_group_product(self):
group, group_id, rank = self._init_group_test()
self._test_all_reduce_helper(
group, group_id, rank, dist.reduce_op.PRODUCT,
2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
)
def test_all_reduce_group_min(self):
group, group_id, rank = self._init_group_test()
self._test_all_reduce_helper(
group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
)
def test_all_reduce_group_max(self):
group, group_id, rank = self._init_group_test()
self._test_all_reduce_helper(
group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10
)
# SCATTER
def _test_scatter_helper(self, group, group_id, rank):
for dest in group:
tensor = _build_tensor(dest + 1, -1)
expected_tensor = _build_tensor(dest + 1, rank)
if rank == dest:
tensors = [_build_tensor(dest + 1, i) for i in group]
dist.scatter_send(tensors, tensor, group_id)
self.assertEqual(tensor, expected_tensor)
else:
dist.scatter_recv(tensor, dest, group_id)
self.assertEqual(tensor, expected_tensor)
self._barrier()
def test_scatter(self):
group, group_id, rank = self._init_global_test()
self._test_scatter_helper(group, group_id, rank)
def test_scatter_group(self):
group, group_id, rank = self._init_group_test()
self._test_scatter_helper(group, group_id, rank)
# GATHER
def _test_gather_helper(self, group, group_id, rank):
for dest in group:
tensor = _build_tensor(dest + 1, rank)
if rank == dest:
tensors = [_build_tensor(dest + 1, -1) for i in group]
dist.gather_recv(tensors, tensor, group_id)
expected_tensors = [_build_tensor(dest + 1, i) for i in group]
for t1, t2 in zip(tensors, expected_tensors):
self.assertEqual(t1, t2)
else:
dist.gather_send(tensor, dest, group_id)
self._barrier()
def test_gather(self):
group, group_id, rank = self._init_global_test()
self._test_gather_helper(group, group_id, rank)
def test_gather_group(self):
group, group_id, rank = self._init_group_test()
self._test_gather_helper(group, group_id, rank)
# ALL GATHER
def _test_all_gather_helper(self, group, group_id, rank):
for dest in group:
tensor = _build_tensor(dest + 1, rank)
tensors = [_build_tensor(dest + 1, -1) for i in group]
dist.all_gather(tensors, tensor, group_id)
expected_tensors = [_build_tensor(dest + 1, i) for i in group]
for t1, t2 in zip(tensors, expected_tensors):
self.assertEqual(t1, t2)
self._barrier()
def test_all_gather(self):
group, group_id, rank = self._init_global_test()
self._test_all_gather_helper(group, group_id, rank)
def test_all_gather_group(self):
group, group_id, rank = self._init_group_test()
self._test_all_gather_helper(group, group_id, rank)
# BARRIER
def _test_barrier_helper(self, group, group_id, rank):
WAIT_TIME = 0.3 # seconds
for dest in group:
expected_time = torch.DoubleTensor(1).fill_(0.0)
if dest == rank:
expected_time.fill_(time.time() + WAIT_TIME)
dist.broadcast(expected_time, dest, group_id)
time.sleep(WAIT_TIME + 0.1) # sleep a little bit longer
dist.barrier(group_id)
else:
dist.broadcast(expected_time, dest, group_id)
dist.barrier(group_id)
self.assertGreaterEqual(time.time(), expected_time[0])
self._barrier()
def test_barrier(self):
group, group_id, rank = self._init_global_test()
self._test_barrier_helper(group, group_id, rank)
def test_barrier_group(self):
group, group_id, rank = self._init_group_test()
self._test_barrier_helper(group, group_id, rank)
if BACKEND == 'tcp':
WORLD_SIZE = os.environ['WORLD_SIZE']
class TestTCP(TestCase, _DistTestBase):
MANAGER_PROCESS_RANK = -1
JOIN_TIMEOUT = 5
@staticmethod
def manager_join(fn):
@wraps(fn)
def wrapper(self):
if self.rank == self.MANAGER_PROCESS_RANK:
self._join_and_reduce()
else:
fn(self)
return wrapper
@classmethod
def setUpClass(cls):
os.environ['MASTER_ADDR'] = MASTER_ADDR
os.environ['MASTER_PORT'] = MASTER_PORT
os.environ['WORLD_SIZE'] = WORLD_SIZE
for attr in dir(cls):
if attr.startswith('test'):
fn = getattr(cls, attr)
setattr(cls, attr, cls.manager_join(fn))
def setUp(self):
self.processes = []
self.rank = self.MANAGER_PROCESS_RANK
Barrier.init()
for rank in range(int(WORLD_SIZE)):
self.processes.append(self._spawn_process(rank))
def tearDown(self):
for p in self.processes:
p.terminate()
def _spawn_process(self, rank):
os.environ['RANK'] = str(rank)
name = 'process ' + str(rank)
process = multiprocessing.Process(target=self._run, name=name,
args=(rank,))
process.start()
return process
def _run(self, rank):
self.rank = rank
dist.init_process_group(backend=BACKEND)
# self.id() == e.g. '__main__.TestDistributed.test_get_rank'
# We're retrieving the corresponding test and executing it.
getattr(self, self.id().split(".")[2])()
sys.exit(0)
def _join_and_reduce(self):
for p in self.processes:
p.join(self.JOIN_TIMEOUT)
self.assertEqual(p.exitcode, 0)
elif BACKEND == 'mpi':
dist.init_process_group(backend='mpi')
class TestMPI(TestCase, _DistTestBase):
pass
if __name__ == '__main__':
unittest.main()
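To summarize the TCP harness above for readers who stop at the diff: the unittest process itself never calls torch.distributed; setUp forks one worker per rank, each worker looks up the test it should run from self.id() and executes it, and manager_join turns every test method on the manager (rank -1) into a join-and-check-exit-codes step. A stripped-down sketch of that fork-per-rank pattern, with hypothetical names and no distributed backend involved:

    import multiprocessing
    import os
    import sys
    import unittest

    def _worker(rank):
        # Stand-in for the real per-rank test body.
        os.environ['RANK'] = str(rank)
        sys.exit(0)

    class MultiProcessCase(unittest.TestCase):
        WORLD_SIZE = 2
        JOIN_TIMEOUT = 5

        def test_all_ranks_exit_cleanly(self):
            procs = [multiprocessing.Process(target=_worker, args=(rank,))
                     for rank in range(self.WORLD_SIZE)]
            for p in procs:
                p.start()
            for p in procs:
                p.join(self.JOIN_TIMEOUT)     # bounded join, like _join_and_reduce above
                self.assertEqual(p.exitcode, 0)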

File diff suppressed because it is too large.
@ -11,13 +11,14 @@ import torch.cuda
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.autograd import Variable from torch.autograd import Variable
from torch.nn import Parameter from torch.nn import Parameter
from common import TestCase from common import TestCase, run_tests
TEST_REPEATS = 30
HAS_SHM_FILES = os.path.isdir('/dev/shm') HAS_SHM_FILES = os.path.isdir('/dev/shm')
TEST_CUDA_IPC = torch.cuda.is_available() and \ TEST_CUDA_IPC = torch.cuda.is_available() and \
sys.version_info[0] == 3 and \ sys.version_info[0] == 3 and \
sys.platform != 'darwin' sys.platform != 'darwin'
def simple_fill(queue, event): def simple_fill(queue, event):
@ -74,7 +75,7 @@ def autograd_sharing(queue, ready, master_modified):
master_modified.wait() master_modified.wait()
expected_var = torch.range(1, 25).view(5, 5) expected_var = torch.range(1, 25).view(5, 5)
expected_var[0,0] = 1000 expected_var[0, 0] = 1000
is_ok = var.data.equal(expected_var) is_ok = var.data.equal(expected_var)
var.data[:] = torch.ones(5, 5) var.data[:] = torch.ones(5, 5)
@ -113,7 +114,7 @@ class leak_checker(object):
# one-off initialization that may use up a file descriptor # one-off initialization that may use up a file descriptor
available_fds = self._get_next_fds(10) available_fds = self._get_next_fds(10)
self.test_case.assertLessEqual( self.test_case.assertLessEqual(
available_fds[-1] - self.next_fds[-1], 4) available_fds[-1] - self.next_fds[-1], 5)
self.test_case.assertFalse(self.has_shm_files()) self.test_case.assertFalse(self.has_shm_files())
return False return False
@ -189,7 +190,7 @@ class TestMultiprocessing(TestCase):
def _test_preserve_sharing(self, ctx=mp, repeat=1): def _test_preserve_sharing(self, ctx=mp, repeat=1):
def do_test(): def do_test():
x = torch.randn(5, 5) x = torch.randn(5, 5)
data = [x.storage(), x.storage()[1:4], x, x[2], x[:,1]] data = [x.storage(), x.storage()[1:4], x, x[2], x[:, 1]]
q = ctx.Queue() q = ctx.Queue()
q.put(data) q.put(data)
new_data = q.get() new_data = q.get()
@ -229,27 +230,27 @@ class TestMultiprocessing(TestCase):
@unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X") @unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
def test_fd_sharing(self): def test_fd_sharing(self):
self._test_sharing(repeat=20) self._test_sharing(repeat=TEST_REPEATS)
@unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X") @unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
def test_fd_preserve_sharing(self): def test_fd_preserve_sharing(self):
self._test_preserve_sharing(repeat=20) self._test_preserve_sharing(repeat=TEST_REPEATS)
@unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X") @unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
def test_fd_pool(self): def test_fd_pool(self):
self._test_pool(repeat=20) self._test_pool(repeat=TEST_REPEATS)
def test_fs_sharing(self): def test_fs_sharing(self):
with fs_sharing(): with fs_sharing():
self._test_sharing(repeat=20) self._test_sharing(repeat=TEST_REPEATS)
def test_fs_preserve_sharing(self): def test_fs_preserve_sharing(self):
with fs_sharing(): with fs_sharing():
self._test_preserve_sharing(repeat=20) self._test_preserve_sharing(repeat=TEST_REPEATS)
def test_fs_pool(self): def test_fs_pool(self):
with fs_sharing(): with fs_sharing():
self._test_pool(repeat=20) self._test_pool(repeat=TEST_REPEATS)
@unittest.skipIf(not HAS_SHM_FILES, "don't know how to check if shm files exist") @unittest.skipIf(not HAS_SHM_FILES, "don't know how to check if shm files exist")
def test_fs(self): def test_fs(self):
@ -263,11 +264,12 @@ class TestMultiprocessing(TestCase):
q.get() q.get()
with fs_sharing(), leak_checker(self) as lc: with fs_sharing(), leak_checker(self) as lc:
for i in range(20): for i in range(TEST_REPEATS):
queue_put() queue_put()
def test_inherit_tensor(self): def test_inherit_tensor(self):
class SubProcess(mp.Process): class SubProcess(mp.Process):
def __init__(self, tensor): def __init__(self, tensor):
super(SubProcess, self).__init__() super(SubProcess, self).__init__()
self.tensor = tensor self.tensor = tensor
@ -286,7 +288,6 @@ class TestMultiprocessing(TestCase):
torch.cuda.FloatTensor([1]) # initialize CUDA outside of leak checker torch.cuda.FloatTensor([1]) # initialize CUDA outside of leak checker
self._test_sharing(mp.get_context('spawn'), torch.cuda.FloatTensor) self._test_sharing(mp.get_context('spawn'), torch.cuda.FloatTensor)
@unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available') @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
def test_cuda_small_tensors(self): def test_cuda_small_tensors(self):
# Check multiple small tensors which will likely use the same # Check multiple small tensors which will likely use the same
@ -359,7 +360,7 @@ class TestMultiprocessing(TestCase):
queue.put(var) queue.put(var)
ready.wait() ready.wait()
var.data[0,0] = 1000 var.data[0, 0] = 1000
if var.grad is not None: if var.grad is not None:
var.grad.data[:] = torch.ones(5, 5) * 4 var.grad.data[:] = torch.ones(5, 5) * 4
master_modified.set() master_modified.set()
@ -380,8 +381,8 @@ class TestMultiprocessing(TestCase):
] ]
for requires_grad, volatile in configs: for requires_grad, volatile in configs:
var = Variable(torch.range(1, 25).view(5, 5), var = Variable(torch.range(1, 25).view(5, 5),
requires_grad=requires_grad, requires_grad=requires_grad,
volatile=volatile) volatile=volatile)
self._test_autograd_sharing(var) self._test_autograd_sharing(var)
def test_parameter_sharing(self): def test_parameter_sharing(self):
@ -409,4 +410,4 @@ class TestMultiprocessing(TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() run_tests()

@ -4,7 +4,7 @@ import torch
import torch.cuda.nccl as nccl import torch.cuda.nccl as nccl
import torch.cuda import torch.cuda
from common import TestCase from common import TestCase, run_tests
if not torch.cuda.is_available(): if not torch.cuda.is_available():
print('CUDA not available, skipping tests') print('CUDA not available, skipping tests')
@ -87,4 +87,4 @@ class TestNCCL(TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() run_tests()

@ -13,11 +13,14 @@ import torch.nn.parallel as dp
from torch.autograd import Variable from torch.autograd import Variable
from torch.nn import Parameter from torch.nn import Parameter
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \ from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PRECISION module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
from common import freeze_rng_state TEST_CUDNN_VERSION, PRECISION
from common import freeze_rng_state, run_tests
def default_tensor_type(type): def default_tensor_type(type):
type_str = torch.typename(type) type_str = torch.typename(type)
def decorator(fn): def decorator(fn):
@wraps(fn) @wraps(fn)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -30,9 +33,12 @@ def default_tensor_type(type):
return wrapper return wrapper
return decorator return decorator
class InputVariableMixin(object): class InputVariableMixin(object):
def _get_input(self): def _get_input(self):
input = TestBase._get_input(self) input = TestBase._get_input(self)
def map_variables(i): def map_variables(i):
if isinstance(i, Variable): if isinstance(i, Variable):
return i return i
@ -44,6 +50,7 @@ class InputVariableMixin(object):
class NewModuleTest(InputVariableMixin, ModuleTest): class NewModuleTest(InputVariableMixin, ModuleTest):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(NewModuleTest, self).__init__(*args, **kwargs) super(NewModuleTest, self).__init__(*args, **kwargs)
self.cudnn = kwargs.get('cudnn', False) self.cudnn = kwargs.get('cudnn', False)
@ -63,10 +70,14 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
test_case.assertEqual(input._version, input_version) test_case.assertEqual(input._version, input_version)
input_ip = deepcopy(input) input_ip = deepcopy(input)
output_ip = module_ip(input_ip) input_ip_clone = input_ip.clone()
test_case.assertNotEqual(input_ip._version, input_version) output_ip = module_ip(input_ip_clone)
test_case.assertNotEqual(input_ip_clone._version, input_version)
test_case.assertEqual(output, output_ip) test_case.assertEqual(output, output_ip)
grad = output.data.clone().normal_()
output.backward(grad)
output_ip.backward(grad)
test_case.assertEqual(output.grad, output_ip.grad)
if type(input.data) == torch.LongTensor and TEST_CUDA: if type(input.data) == torch.LongTensor and TEST_CUDA:
input = input.cuda() input = input.cuda()
@ -352,21 +363,21 @@ class TestNN(NNTestCase):
def _test_dropout(self, cls, input): def _test_dropout(self, cls, input):
p = 0.2 p = 0.2
input.fill_(1-p) input.fill_(1 - p)
module = cls(p) module = cls(p)
input_var = Variable(input, requires_grad=True) input_var = Variable(input, requires_grad=True)
output = module(input_var) output = module(input_var)
self.assertLess(abs(output.data.mean() - (1-p)), 0.05) self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
output.backward(input) output.backward(input)
self.assertLess(abs(input_var.grad.data.mean() - (1-p)), 0.05) self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
module = cls(p, True) module = cls(p, True)
input_var = Variable(input.clone(), requires_grad=True) input_var = Variable(input.clone(), requires_grad=True)
output = module(input_var + 0) output = module(input_var + 0)
self.assertLess(abs(output.data.mean() - (1-p)), 0.05) self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
output.backward(input) output.backward(input)
self.assertLess(abs(input_var.grad.data.mean() - (1-p)), 0.05) self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
# Check that these don't raise errors # Check that these don't raise errors
module.__repr__() module.__repr__()
@ -375,7 +386,9 @@ class TestNN(NNTestCase):
def test_parameters(self): def test_parameters(self):
def num_params(module): def num_params(module):
return len(list(module.parameters())) return len(list(module.parameters()))
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.l1 = l self.l1 = l
@ -390,6 +403,7 @@ class TestNN(NNTestCase):
def test_modules(self): def test_modules(self):
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.l1 = l self.l1 = l
@ -411,6 +425,71 @@ class TestNN(NNTestCase):
self.assertEqual(n[2], l3) self.assertEqual(n[2], l3)
self.assertEqual(n[3], l4) self.assertEqual(n[3], l4)
def test_ListModule(self):
modules = [nn.ReLU(), nn.Linear(5, 5)]
module_list = nn.ModuleList(modules)
def check():
self.assertEqual(len(module_list), len(modules))
for m1, m2 in zip(modules, module_list):
self.assertIs(m1, m2)
for m1, m2 in zip(modules, module_list.children()):
self.assertIs(m1, m2)
for i in range(len(modules)):
self.assertIs(module_list[i], modules[i])
check()
modules += [nn.Conv2d(3, 4, 3)]
module_list += [modules[-1]]
check()
modules.append(nn.Tanh())
module_list.append(modules[-1])
check()
next_modules = [nn.Linear(5, 5), nn.Sigmoid()]
modules.extend(next_modules)
module_list.extend(next_modules)
check()
modules[2] = nn.Conv2d(5, 3, 2)
module_list[2] = modules[2]
check()
with self.assertRaises(TypeError):
module_list += nn.ReLU()
with self.assertRaises(TypeError):
module_list.extend(nn.ReLU())
def test_ParameterList(self):
make_param = lambda: Parameter(torch.randn(10, 10))
parameters = [make_param(), make_param()]
param_list = nn.ParameterList(parameters)
def check():
self.assertEqual(len(parameters), len(param_list))
for p1, p2 in zip(parameters, param_list):
self.assertIs(p1, p2)
for p1, p2 in zip(parameters, param_list.parameters()):
self.assertIs(p1, p2)
for i in range(len(parameters)):
self.assertIs(parameters[i], param_list[i])
check()
parameters += [make_param()]
param_list += [parameters[-1]]
check()
parameters.append(make_param())
param_list.append(parameters[-1])
check()
next_params = [make_param(), make_param()]
parameters.extend(next_params)
param_list.extend(next_params)
check()
parameters[2] = make_param()
param_list[2] = parameters[2]
check()
with self.assertRaises(TypeError):
param_list += make_param()
with self.assertRaises(TypeError):
param_list.extend(make_param())
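The two tests above exercise the new ModuleList and ParameterList containers. Their purpose is that submodules or Parameters kept in a plain Python list are invisible to parameters(), state_dict() and .cuda(), while the container versions register their contents the way named attributes do. A minimal usage sketch (the Stack module here is hypothetical):

    import torch
    import torch.nn as nn

    class Stack(nn.Module):
        def __init__(self, sizes):
            super(Stack, self).__init__()
            # ModuleList registers every layer, so their weights appear in parameters().
            self.layers = nn.ModuleList([nn.Linear(n_in, n_out)
                                         for n_in, n_out in zip(sizes, sizes[1:])])

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    net = Stack([8, 16, 4])
    assert len(list(net.parameters())) == 4   # weight and bias for each of the 2 layers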
def test_add_module(self): def test_add_module(self):
l = nn.Linear(10, 20) l = nn.Linear(10, 20)
net = nn.Module() net = nn.Module()
@ -451,6 +530,7 @@ class TestNN(NNTestCase):
def test_non_leaf_parameters(self): def test_non_leaf_parameters(self):
l1 = nn.Linear(10, 10) l1 = nn.Linear(10, 10)
l2 = nn.Linear(10, 10) l2 = nn.Linear(10, 10)
def assign_weight(): def assign_weight():
l2.weight = l1.weight + 2 l2.weight = l1.weight + 2
self.assertRaises(TypeError, assign_weight) self.assertRaises(TypeError, assign_weight)
@ -458,8 +538,8 @@ class TestNN(NNTestCase):
l2.weight = Parameter(torch.randn(10, 10)) l2.weight = Parameter(torch.randn(10, 10))
def test_embedding_padding_idx(self): def test_embedding_padding_idx(self):
embedding = nn.Embedding(10, 20, padding_idx = 0) embedding = nn.Embedding(10, 20, padding_idx=0)
input = Variable(torch.LongTensor([[0,2,4,5],[4,3,0,9]])) input = Variable(torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]]))
output = embedding(input) output = embedding(input)
self.assertEqual(output[0][0].sum().data[0], 0) self.assertEqual(output[0][0].sum().data[0], 0)
self.assertEqual(output[1][2].sum().data[0], 0) self.assertEqual(output[1][2].sum().data[0], 0)
@ -489,14 +569,14 @@ class TestNN(NNTestCase):
def expected_indices(dim): def expected_indices(dim):
if dim == 1: if dim == 1:
return torch.DoubleTensor([1, 3]) return torch.DoubleTensor([1, 3])
lower_dim = expected_indices(dim-1) lower_dim = expected_indices(dim - 1)
lower_dim = lower_dim.view(1, *lower_dim.size()) lower_dim = lower_dim.view(1, *lower_dim.size())
return torch.cat((lower_dim+4, lower_dim+12), 0) return torch.cat((lower_dim + 4, lower_dim + 12), 0)
def expected_grad(dim): def expected_grad(dim):
if dim == 1: if dim == 1:
return torch.DoubleTensor([0, 1, 0, 1]) return torch.DoubleTensor([0, 1, 0, 1])
lower_dim_grad = expected_grad(dim-1) lower_dim_grad = expected_grad(dim - 1)
grad = lower_dim_grad.view(1, *lower_dim_grad.size()) grad = lower_dim_grad.view(1, *lower_dim_grad.size())
zero = torch.zeros(grad.size()) zero = torch.zeros(grad.size())
return torch.cat((zero, grad, zero, grad), 0) return torch.cat((zero, grad, zero, grad), 0)
@ -667,7 +747,9 @@ class TestNN(NNTestCase):
def test_data_parallel_nested_output(self): def test_data_parallel_nested_output(self):
def fn(input): def fn(input):
return [input, (input.sin(), input.cos(), [input.add(1)]), input] return [input, (input.sin(), input.cos(), [input.add(1)]), input]
class Net(nn.Module): class Net(nn.Module):
def forward(self, input): def forward(self, input):
return fn(input) return fn(input)
i = Variable(torch.randn(2, 2).float().cuda(1)) i = Variable(torch.randn(2, 2).float().cuda(1))
@ -686,7 +768,9 @@ class TestNN(NNTestCase):
def test_data_parallel_nested_input(self): def test_data_parallel_nested_input(self):
def fn(input): def fn(input):
return input[1][0] return input[1][0]
class Net(nn.Module): class Net(nn.Module):
def forward(self, input): def forward(self, input):
return fn(input) return fn(input)
i = Variable(torch.randn(20, 3).float().cuda(1)) i = Variable(torch.randn(20, 3).float().cuda(1))
@ -708,7 +792,7 @@ class TestNN(NNTestCase):
def test_state_dict(self): def test_state_dict(self):
l = nn.Linear(5, 5) l = nn.Linear(5, 5)
block = nn.Module() block = nn.Module()
block.conv=nn.Conv2d(3, 3, 3, bias=False) block.conv = nn.Conv2d(3, 3, 3, bias=False)
net = nn.Module() net = nn.Module()
net.linear1 = l net.linear1 = l
net.linear2 = l net.linear2 = l
@ -777,6 +861,7 @@ class TestNN(NNTestCase):
def test_parameter_assignment(self): def test_parameter_assignment(self):
l = nn.Linear(5, 5) l = nn.Linear(5, 5)
def num_params(): def num_params():
return len(list(l.parameters())) return len(list(l.parameters()))
self.assertEqual(num_params(), 2) self.assertEqual(num_params(), 2)
@ -789,7 +874,7 @@ class TestNN(NNTestCase):
var = Variable(torch.randn(5, 5)) var = Variable(torch.randn(5, 5))
l.var_name = var l.var_name = var
self.assertEqual(num_params(), 3) self.assertEqual(num_params(), 3)
self.assertNotIn(var, l.parameters()) self.assertNotIn(id(var), map(id, l.parameters()))
# Make sure Variables are not saved as parameters # Make sure Variables are not saved as parameters
l.variable_attr = Variable(torch.Tensor(5, 5)) l.variable_attr = Variable(torch.Tensor(5, 5))
@ -805,6 +890,32 @@ class TestNN(NNTestCase):
l.param_attr = None l.param_attr = None
self.assertEqual(num_params(), 3) self.assertEqual(num_params(), 3)
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
def test_Conv2d_large_workspace(self):
# These sizes require huge cuDNN workspaces. Make sure we choose a
# reasonable algorithm that does not run out of memory
sizes = [
(1, 256, 109, 175),
(1, 256, 80, 128),
(1, 256, 120, 192),
]
dtype = torch.cuda.FloatTensor
def run_test(benchmark):
torch.backends.cudnn.benchmark = benchmark
conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).type(dtype)
for size in sizes:
x = torch.randn(size).type(dtype)
out = conv(Variable(x, requires_grad=True))
out.backward(torch.ones(out.size()).type(dtype))
b = torch.backends.cudnn.benchmark
try:
run_test(benchmark=False)
run_test(benchmark=True)
finally:
torch.backends.cudnn.benchmark = b
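The try/finally above saves and restores torch.backends.cudnn.benchmark around the two runs. The same save-and-restore could be packaged as a small context manager; a sketch, not part of the diff:

    from contextlib import contextmanager

    import torch.backends.cudnn as cudnn

    @contextmanager
    def cudnn_benchmark(enabled):
        # Temporarily flip the benchmark flag and restore the previous value on exit.
        previous = cudnn.benchmark
        cudnn.benchmark = enabled
        try:
            yield
        finally:
            cudnn.benchmark = previous

With such a helper, each run_test(benchmark=...) call above could simply wrap its body in "with cudnn_benchmark(True): ...".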
def test_ConvTranspose2d_output_size(self): def test_ConvTranspose2d_output_size(self):
m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2) m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
i = Variable(torch.randn(2, 3, 6, 6)) i = Variable(torch.randn(2, 3, 6, 6))
@ -857,7 +968,7 @@ class TestNN(NNTestCase):
small_t = torch.rand(1, 1, 5, 5) small_t = torch.rand(1, 1, 5, 5)
for i in range(0, 4, 2): for i in range(0, 4, 2):
for j in range(0, 4, 2): for j in range(0, 4, 2):
small_t[:,:,i,j] = 100 small_t[:, :, i, j] = 100
output_small, indices_small = m(Variable(small_t)) output_small, indices_small = m(Variable(small_t))
for h in range(3, 10): for h in range(3, 10):
for w in range(3, 10): for w in range(3, 10):
@ -870,10 +981,11 @@ class TestNN(NNTestCase):
mu(output_small, indices_small, output_size=size) mu(output_small, indices_small, output_size=size)
else: else:
self.assertRaises(ValueError, lambda: self.assertRaises(ValueError, lambda:
mu(output_small, indices_small, (h, w))) mu(output_small, indices_small, (h, w)))
def test_container_copy(self): def test_container_copy(self):
class Model(nn.Module): class Model(nn.Module):
def __init__(self): def __init__(self):
super(Model, self).__init__() super(Model, self).__init__()
self.linear = nn.Linear(4, 5) self.linear = nn.Linear(4, 5)
@ -925,11 +1037,22 @@ class TestNN(NNTestCase):
for i in range(6): for i in range(6):
hx, cx = lstm(input, (hx, cx)) hx, cx = lstm(input, (hx, cx))
(hx+cx).sum().backward() (hx + cx).sum().backward()
@unittest.skipIf(not TEST_CUDNN, "needs cudnn") def test_rnn_initial_hidden_state(self):
@default_tensor_type(torch.FloatTensor) # FIXME: just until torch.cuda.DoubleTensor.sum() implemented rnn_modes = ['RNN', 'GRU', 'LSTM']
def test_RNN_cpu_vs_cudnn(self): for mode in rnn_modes:
rnn = getattr(nn, mode)(30, 20, 2)
input = Variable(torch.randn(10, 32, 30))
hidden = Variable(torch.Tensor(2, 32, 20).zero_())
if mode == 'LSTM':
hidden = (hidden, hidden)
output1, hidden1 = rnn(input, hidden)
output2, hidden2 = rnn(input)
self.assertEqual(output1, output2)
self.assertEqual(hidden1, hidden2)
def _test_RNN_cpu_vs_cudnn(self, dropout):
def forward_backward(cuda, rnn, input_val, hx_val, weights_val): def forward_backward(cuda, rnn, input_val, hx_val, weights_val):
is_lstm = type(rnn) == nn.LSTM is_lstm = type(rnn) == nn.LSTM
@ -957,9 +1080,9 @@ class TestNN(NNTestCase):
output, hy = rnn(input, hx) output, hy = rnn(input, hx)
# FIXME this is because of a pytorch bug # FIXME this is because of a pytorch bug
if is_lstm: if is_lstm:
fake_loss = 0*(hy[0] + hy[1]).sum() fake_loss = 0 * (hy[0] + hy[1]).sum()
else: else:
fake_loss = 0*hy.sum() fake_loss = 0 * hy.sum()
loss = output.sum() + fake_loss loss = output.sum() + fake_loss
loss.backward() loss.backward()
@ -989,42 +1112,40 @@ class TestNN(NNTestCase):
for (cpu_weight, gpu_weight) in zip(cpu_layer_weight, gpu_layer_weight): for (cpu_weight, gpu_weight) in zip(cpu_layer_weight, gpu_layer_weight):
self.assertEqual(cpu_weight.grad.data, gpu_weight.grad.data, prec=5e-5) self.assertEqual(cpu_weight.grad.data, gpu_weight.grad.data, prec=5e-5)
for module in (nn.RNN, nn.LSTM, nn.GRU): for module in (nn.RNN, nn.LSTM, nn.GRU):
for bias in (True, False): for bias in (True, False):
for bidirectional in (False, True): for bidirectional in (False, True):
for dropout in (0, 1): # Because of dropout randomness, can only compare 0 and 1 for batch_first in (False, True):
for batch_first in (False, True): num_directions = 2 if bidirectional else 1
num_directions = 2 if bidirectional else 1 if batch_first:
if batch_first: input_val = torch.randn(batch, seq_length, input_size)
input_val = torch.randn(batch, seq_length, input_size) else:
else: input_val = torch.randn(seq_length, batch, input_size)
input_val = torch.randn(seq_length, batch, input_size) hx_val = torch.randn(num_layers * num_directions, batch, hidden_size)
hx_val = torch.randn(num_layers * num_directions, batch, hidden_size)
rnn = module(input_size, rnn = module(input_size,
hidden_size,
num_layers,
bias=bias,
dropout=dropout,
bidirectional=bidirectional,
batch_first=batch_first)
outputs_cpu = forward_backward(
False, rnn, input_val, hx_val, rnn.all_weights)
rnn_gpu = module(input_size,
hidden_size, hidden_size,
num_layers, num_layers,
bias=bias, bias=bias,
dropout=dropout, dropout=dropout,
bidirectional=bidirectional, bidirectional=bidirectional,
batch_first = batch_first) batch_first=batch_first)
outputs_cpu = forward_backward( outputs_gpu = forward_backward(
False, rnn, input_val, hx_val, rnn.all_weights) True, rnn_gpu, input_val, hx_val, rnn.all_weights)
rnn_gpu = module(input_size, compare_cpu_gpu(outputs_cpu, outputs_gpu)
hidden_size,
num_layers,
bias=bias,
dropout=dropout,
bidirectional=bidirectional,
batch_first = batch_first)
outputs_gpu = forward_backward(
True, rnn_gpu, input_val, hx_val, rnn.all_weights)
compare_cpu_gpu(outputs_cpu, outputs_gpu)
for nonlinearity in ('tanh', 'relu'): for nonlinearity in ('tanh', 'relu'):
hx_val = torch.randn(num_layers, batch, hidden_size) hx_val = torch.randn(num_layers, batch, hidden_size)
@ -1039,6 +1160,17 @@ class TestNN(NNTestCase):
compare_cpu_gpu(outputs_cpu, outputs_gpu) compare_cpu_gpu(outputs_cpu, outputs_gpu)
@unittest.skipIf(not TEST_CUDNN, "needs cudnn") @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
@default_tensor_type(torch.FloatTensor) # FIXME: just until torch.cuda.DoubleTensor.sum() implemented
def test_RNN_cpu_vs_cudnn_no_dropout(self):
self._test_RNN_cpu_vs_cudnn(0)
@unittest.skipIf(not (TEST_CUDNN and TEST_CUDNN_VERSION >= 5103), "needs cudnn >= 5.1")
@default_tensor_type(torch.FloatTensor) # FIXME: just until torch.cuda.DoubleTensor.sum() implemented
def test_RNN_cpu_vs_cudnn_with_dropout(self):
# Because of dropout randomness, can only compare dropout=0 and dropout=1
self._test_RNN_cpu_vs_cudnn(1)
@unittest.skipIf(not (TEST_CUDNN and TEST_CUDNN_VERSION >= 5103), "needs cudnn >= 5.1")
def test_RNN_dropout(self): def test_RNN_dropout(self):
# checking the assumption that cuDNN sticks dropout in between # checking the assumption that cuDNN sticks dropout in between
# RNN layers # RNN layers
@ -1057,8 +1189,8 @@ class TestNN(NNTestCase):
rnn.weight_hh_l0.data.fill_(1) rnn.weight_hh_l0.data.fill_(1)
rnn.weight_ih_l1.data.fill_(1) rnn.weight_ih_l1.data.fill_(1)
rnn.weight_hh_l1.data.fill_(1) rnn.weight_hh_l1.data.fill_(1)
input = Variable(torch.Tensor(1,1,10).fill_(1)) input = Variable(torch.Tensor(1, 1, 10).fill_(1))
hx = Variable(torch.Tensor(2,1,1000).fill_(0)) hx = Variable(torch.Tensor(2, 1, 1000).fill_(0))
if cuda: if cuda:
input = input.cuda() input = input.cuda()
hx = hx.cuda() hx = hx.cuda()
@ -1081,7 +1213,7 @@ class TestNN(NNTestCase):
self.assertEqual(hy.data[0][0][0], 10) self.assertEqual(hy.data[0][0][0], 10)
self.assertEqual(hy.data[1][0][0], output_val) self.assertEqual(hy.data[1][0][0], output_val)
@unittest.skipIf(not TEST_CUDNN, "needs cudnn") @unittest.skipIf(not (TEST_CUDNN and TEST_CUDNN_VERSION >= 5103), "needs cudnn >= 5.1")
def test_RNN_dropout_state(self): def test_RNN_dropout_state(self):
import sys import sys
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
@ -1099,8 +1231,8 @@ class TestNN(NNTestCase):
rnn.train() rnn.train()
else: else:
rnn.eval() rnn.eval()
input = Variable(torch.Tensor(1,1,100).uniform_()) input = Variable(torch.Tensor(1, 1, 100).uniform_())
hx = Variable(torch.Tensor(2,1,100).uniform_()) hx = Variable(torch.Tensor(2, 1, 100).uniform_())
if cuda: if cuda:
input = input.cuda() input = input.cuda()
hx = hx.cuda() hx = hx.cuda()
@ -1133,6 +1265,15 @@ class TestNN(NNTestCase):
(c * upscale_factor ** 2) (c * upscale_factor ** 2)
self.assertEqual(output[:, c, h, w], input[:, channel_idx, height_idx, weight_idx]) self.assertEqual(output[:, c, h, w], input[:, channel_idx, height_idx, weight_idx])
def test_inplace_thnn(self):
r = nn.ReLU(True)
input = Variable(torch.randn(5, 5), requires_grad=True)
output = r(input + 0)
grad_output = torch.randn(5, 5)
grad_output_clone = grad_output.clone()
output.backward(grad_output)
self.assertEqual(grad_output, grad_output_clone)
def test_pixel_shuffle(self): def test_pixel_shuffle(self):
batch_size = random.randint(1, 3) batch_size = random.randint(1, 3)
upscale_factor = random.randint(2, 5) upscale_factor = random.randint(2, 5)
@ -1147,6 +1288,32 @@ class TestNN(NNTestCase):
output.backward(output.data) output.backward(output.data)
self.assertEqual(input.data, input.grad.data) self.assertEqual(input.data, input.grad.data)
def test_batchnorm_eval(self):
types = (torch.FloatTensor,)
if TEST_CUDA:
types += (torch.cuda.FloatTensor,)
for tp in types:
module = nn.BatchNorm1d(3).type(tp)
module.eval()
data = Variable(torch.rand(4, 3).type(tp), requires_grad=True)
grad = torch.rand(4, 3).type(tp)
# 1st pass
res1 = module(data)
res1.backward(grad)
grad1 = data.grad.data.clone()
# 2nd pass
data.grad.data.zero_()
res2 = module(data)
res2.backward(grad)
grad2 = data.grad.data.clone()
self.assertEqual(res1, res2)
self.assertEqual(grad1, grad2)
def add_test(test): def add_test(test):
test_name = test.get_name() test_name = test.get_name()
cuda_test_name = test_name + '_cuda' cuda_test_name = test_name + '_cuda'
@ -1154,8 +1321,8 @@ def add_test(test):
raise RuntimeError('Found two tests with the same name: ' + test_name) raise RuntimeError('Found two tests with the same name: ' + test_name)
if hasattr(TestNN, cuda_test_name): if hasattr(TestNN, cuda_test_name):
raise RuntimeError('Found two tests with the same name: ' + cuda_test_name) raise RuntimeError('Found two tests with the same name: ' + cuda_test_name)
setattr(TestNN, test_name, lambda self,test=test: test(self)) setattr(TestNN, test_name, lambda self, test=test: test(self))
setattr(TestNN, cuda_test_name, lambda self,test=test: test.test_cuda(self)) setattr(TestNN, cuda_test_name, lambda self, test=test: test.test_cuda(self))
new_module_tests = [ new_module_tests = [
@ -1308,6 +1475,11 @@ new_module_tests = [
input_size=(2, 4, 6, 5), input_size=(2, 4, 6, 5),
cudnn=True, cudnn=True,
), ),
dict(
fullname='Conv2d_groups_thnn',
constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
input_size=(2, 4, 6, 5),
),
dict( dict(
module_name='ConvTranspose2d', module_name='ConvTranspose2d',
constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)), constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
@ -1460,20 +1632,26 @@ new_module_tests = [
dict( dict(
module_name='Embedding', module_name='Embedding',
constructor_args=(4, 3), constructor_args=(4, 3),
input=Variable( input=Variable(torch.randperm(2).repeat(1, 2)),
torch.randperm(2).repeat(1, 2),
requires_grad=False
),
jacobian_input=False jacobian_input=False
), ),
dict( dict(
constructor=lambda: nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()), constructor=lambda: nn.Embedding(4, 3, sparse=True),
input=Variable(torch.randperm(2).repeat(1, 2)),
jacobian_input=False,
fullname='Embedding_sparse',
test_cuda=False,
),
dict(
constructor=lambda: nn.FractionalMaxPool2d(
2, output_ratio=0.5, _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
input_size=(1, 3, 5, 5), input_size=(1, 3, 5, 5),
fullname='FractionalMaxPool2d_ratio', fullname='FractionalMaxPool2d_ratio',
test_cuda=False test_cuda=False
), ),
dict( dict(
constructor=lambda: nn.FractionalMaxPool2d((2, 2), output_size=(4, 4), _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()), constructor=lambda: nn.FractionalMaxPool2d((2, 2), output_size=(
4, 4), _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
input_size=(1, 3, 7, 7), input_size=(1, 3, 7, 7),
fullname='FractionalMaxPool2d_size', fullname='FractionalMaxPool2d_size',
test_cuda=False test_cuda=False
@ -1483,6 +1661,40 @@ new_module_tests = [
constructor_args=(3,), constructor_args=(3,),
input_size=(1, 9, 4, 4), input_size=(1, 9, 4, 4),
), ),
dict(
module_name='UpsamplingNearest2d',
constructor_args=(12,),
input_size=(1, 2, 4, 4),
),
dict(
module_name='UpsamplingNearest2d',
constructor_args=((12, 16)),
input_size=(1, 2, 3, 4),
desc='tuple'
),
dict(
module_name='UpsamplingNearest2d',
constructor_args=(None, 4),
input_size=(1, 2, 4, 4),
desc='scale'
),
dict(
module_name='UpsamplingBilinear2d',
constructor_args=(12,),
input_size=(1, 2, 4, 4),
),
dict(
module_name='UpsamplingBilinear2d',
constructor_args=((4, 6)),
input_size=(1, 2, 2, 3),
desc='tuple'
),
dict(
module_name='UpsamplingBilinear2d',
constructor_args=(None, 4),
input_size=(1, 2, 4, 4),
desc='scale'
),
] ]
@ -1501,6 +1713,7 @@ for test_params in criterion_tests:
class UnpoolingNet(nn.Module): class UnpoolingNet(nn.Module):
def __init__(self, pool, unpool): def __init__(self, pool, unpool):
super(UnpoolingNet, self).__init__() super(UnpoolingNet, self).__init__()
self.pool = pool self.pool = pool
@ -1531,4 +1744,4 @@ add_test(NewModuleTest(
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() run_tests()

@ -1,10 +1,12 @@
import unittest import unittest
import functools
from copy import deepcopy
import torch import torch
import torch.optim as optim import torch.optim as optim
import torch.legacy.optim as old_optim import torch.legacy.optim as old_optim
from torch.autograd import Variable from torch.autograd import Variable
from common import TestCase from common import TestCase, run_tests
def rosenbrock(tensor): def rosenbrock(tensor):
@ -14,7 +16,7 @@ def rosenbrock(tensor):
def drosenbrock(tensor): def drosenbrock(tensor):
x, y = tensor x, y = tensor
return torch.DoubleTensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2))) return torch.DoubleTensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2)))
def wrap_old_fn(old_fn, **config): def wrap_old_fn(old_fn, **config):
@ -36,15 +38,22 @@ class TestOptim(TestCase):
initial_dist = params.data.dist(solution) initial_dist = params.data.dist(solution)
def eval(): def eval():
optimizer.zero_grad()
loss = rosenbrock(params) loss = rosenbrock(params)
loss.backward() loss.backward()
# loss.backward() will give **slightly** different
# gradients than drosenbrock, because of a different ordering
# of floating point operations. In most cases it doesn't matter,
# but some optimizers are so sensitive that they can temporarily
# diverge up to 1e-4, just to converge again. This makes the
# comparison more stable.
params.grad.data.copy_(drosenbrock(params.data))
return loss return loss
for i in range(2000): for i in range(2000):
optimizer.zero_grad()
optimizer.step(eval) optimizer.step(eval)
old_fn(lambda _: (rosenbrock(params_t), drosenbrock(params_t)), old_fn(lambda _: (rosenbrock(params_t), drosenbrock(params_t)),
params_t, state) params_t, state)
self.assertEqual(params.data, params_t) self.assertEqual(params.data, params_t)
self.assertLessEqual(params.data.dist(solution), initial_dist) self.assertLessEqual(params.data.dist(solution), initial_dist)
@ -52,25 +61,65 @@ class TestOptim(TestCase):
def _test_basic_cases_template(self, weight, bias, input, constructor): def _test_basic_cases_template(self, weight, bias, input, constructor):
weight = Variable(weight, requires_grad=True) weight = Variable(weight, requires_grad=True)
bias = Variable(bias, requires_grad=True) bias = Variable(bias, requires_grad=True)
input = Variable(input, requires_grad=False) input = Variable(input)
optimizer = constructor(weight, bias) optimizer = constructor(weight, bias)
def fn(): def fn():
optimizer.zero_grad()
y = weight.mv(input) y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device(): if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device()) y = y.cuda(bias.get_device())
return (y + bias).abs().sum() loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().data[0] initial_value = fn().data[0]
for i in range(200): for i in range(200):
weight.grad.data.zero_() optimizer.step(fn)
bias.grad.data.zero_() self.assertLess(fn().data[0], initial_value)
fn().backward()
optimizer.step()
self.assertLessEqual(fn().data[0], initial_value) def _test_state_dict(self, weight, bias, input, constructor):
weight = Variable(weight, requires_grad=True)
bias = Variable(bias, requires_grad=True)
input = Variable(input)
def _test_basic_cases(self, constructor): def fn_base(optimizer, weight, bias):
optimizer.zero_grad()
loss = (weight.mv(input) + bias).pow(2).sum()
loss.backward()
return loss
optimizer = constructor(weight, bias)
fn = functools.partial(fn_base, optimizer, weight, bias)
# Prime the optimizer
for i in range(20):
optimizer.step(fn)
# Clone the weights and construct new optimizer for them
weight_c = Variable(weight.data.clone(), requires_grad=True)
bias_c = Variable(bias.data.clone(), requires_grad=True)
optimizer_c = constructor(weight_c, bias_c)
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
# Load state dict
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_c.load_state_dict(state_dict_c)
# Run both optimizations in parallel
for i in range(20):
optimizer.step(fn)
optimizer_c.step(fn_c)
self.assertEqual(weight, weight_c)
self.assertEqual(bias, bias_c)
# Make sure state dict wasn't modified
self.assertEqual(state_dict, state_dict_c)
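The refactor in this file moves zero_grad() and backward() into a closure handed to optimizer.step(fn). That calling convention lets optimizers which need to re-evaluate the loss (the LBFGS test added further down relies on it) drive the forward/backward pass themselves, and it is also what makes the fn_base/functools.partial plumbing in _test_state_dict work. A minimal sketch of the pattern, using the same Variable-era API as the rest of the file:

    import torch
    import torch.optim as optim
    from torch.autograd import Variable

    weight = Variable(torch.randn(5, 3), requires_grad=True)
    target = Variable(torch.randn(5, 3))
    optimizer = optim.SGD([weight], lr=0.1)

    def closure():
        # step(closure) expects: clear grads, compute the loss, backprop, return it.
        optimizer.zero_grad()
        loss = (weight - target).pow(2).sum()
        loss.backward()
        return loss

    for _ in range(10):
        optimizer.step(closure)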
def _test_basic_cases(self, constructor, ignore_multidevice=False):
self._test_state_dict(
torch.randn(10, 5),
torch.randn(10),
torch.randn(5),
constructor
)
self._test_basic_cases_template( self._test_basic_cases_template(
torch.randn(10, 5), torch.randn(10, 5),
torch.randn(10), torch.randn(10),
@ -79,8 +128,8 @@ class TestOptim(TestCase):
) )
# non-contiguous parameters # non-contiguous parameters
self._test_basic_cases_template( self._test_basic_cases_template(
torch.randn(10, 5, 2)[...,0], torch.randn(10, 5, 2)[..., 0],
torch.randn(10, 2)[...,0], torch.randn(10, 2)[..., 0],
torch.randn(5), torch.randn(5),
constructor constructor
) )
@ -94,12 +143,12 @@ class TestOptim(TestCase):
constructor constructor
) )
# Multi-GPU # Multi-GPU
if not torch.cuda.device_count() > 1: if not torch.cuda.device_count() > 1 or ignore_multidevice:
return return
self._test_basic_cases_template( self._test_basic_cases_template(
torch.randn(10, 5).cuda(), torch.randn(10, 5).cuda(0),
torch.randn(10).cuda(), torch.randn(10).cuda(1),
torch.randn(5).cuda(), torch.randn(5).cuda(0),
constructor constructor
) )
@ -275,10 +324,24 @@ class TestOptim(TestCase):
lr=1e-3) lr=1e-3)
) )
def test_lbfgs(self):
self._test_rosenbrock(
lambda params: optim.LBFGS(params),
wrap_old_fn(old_optim.lbfgs)
)
self._test_rosenbrock(
lambda params: optim.LBFGS(params, lr=5e-2, max_iter=5),
wrap_old_fn(old_optim.lbfgs, learningRate=5e-2, maxIter=5)
)
self._test_basic_cases(
lambda weight, bias: optim.LBFGS([weight, bias]),
ignore_multidevice=True
)
def test_invalid_param_type(self): def test_invalid_param_type(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
optim.SGD(Variable(torch.randn(5, 5)), lr=3) optim.SGD(Variable(torch.randn(5, 5)), lr=3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() run_tests()
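The rewritten basic-case test above drives each optimizer through a closure that zeroes the gradients, recomputes the loss, and backpropagates, then hands that closure to optimizer.step(); LBFGS relies on this because it may re-evaluate the function several times per step. A minimal sketch of the pattern, written against the Variable-era API these tests use:

import torch
from torch import optim
from torch.autograd import Variable

weight = Variable(torch.randn(10, 5), requires_grad=True)
bias = Variable(torch.randn(10), requires_grad=True)
input = Variable(torch.randn(5))
optimizer = optim.LBFGS([weight, bias], lr=5e-2, max_iter=5)

def closure():
    # zero grads, recompute the loss, and backpropagate;
    # step() may call this more than once (e.g. for line searches)
    optimizer.zero_grad()
    loss = (weight.mv(input) + bias).pow(2).sum()
    loss.backward()
    return loss

for _ in range(20):
    optimizer.step(closure)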


@ -4,13 +4,14 @@ from torch import sparse
import itertools import itertools
import random import random
import unittest import unittest
from common import TestCase from common import TestCase, run_tests
from numbers import Number from numbers import Number
SparseTensor = sparse.DoubleTensor SparseTensor = sparse.DoubleTensor
class TestSparse(TestCase): class TestSparse(TestCase):
@staticmethod @staticmethod
def _gen_sparse(d, nnz, with_size): def _gen_sparse(d, nnz, with_size):
v = torch.randn(nnz) v = torch.randn(nnz)
@ -19,7 +20,7 @@ class TestSparse(TestCase):
x = SparseTensor(i, v) x = SparseTensor(i, v)
else: else:
i = torch.rand(d, nnz) * \ i = torch.rand(d, nnz) * \
torch.Tensor(with_size).repeat(nnz, 1).transpose(0, 1) torch.Tensor(with_size).repeat(nnz, 1).transpose(0, 1)
i = i.type(torch.LongTensor) i = i.type(torch.LongTensor)
x = SparseTensor(i, v, torch.Size(with_size)) x = SparseTensor(i, v, torch.Size(with_size))
@ -74,13 +75,13 @@ class TestSparse(TestCase):
def test_contig(self): def test_contig(self):
i = torch.LongTensor([ i = torch.LongTensor([
[1, 0, 35, 14, 39, 6, 71, 66, 40, 27], [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
[92, 31, 62, 50, 22, 65, 89, 74, 56, 34], [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
]) ])
v = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) v = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x = SparseTensor(i, v, torch.Size([100, 100])) x = SparseTensor(i, v, torch.Size([100, 100]))
exp_i = torch.LongTensor([ exp_i = torch.LongTensor([
[0, 1, 6, 14, 27, 35, 39, 40, 66, 71], [0, 1, 6, 14, 27, 35, 39, 40, 66, 71],
[31, 92, 65, 50, 34, 62, 22, 56, 74, 89], [31, 92, 65, 50, 34, 62, 22, 56, 74, 89],
]) ])
exp_v = torch.Tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7]) exp_v = torch.Tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7])
@ -216,5 +217,4 @@ class TestSparse(TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() run_tests()
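For reference, the sparse tensors exercised above are constructed from a d x nnz LongTensor of indices, a length-nnz value tensor, and an explicit size. A small hand-written instance (the to_dense() call is only assumed here to show the resulting layout):

import torch
from torch import sparse

# two dimensions, three non-zero entries at (0,2), (1,0) and (1,2)
i = torch.LongTensor([[0, 1, 1],
                      [2, 0, 2]])
v = torch.DoubleTensor([3, 4, 5])
x = sparse.DoubleTensor(i, v, torch.Size([2, 3]))
# x.to_dense() would be:
#   [[0, 0, 3],
#    [4, 0, 5]]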

File diff suppressed because it is too large


@ -19,7 +19,7 @@ from torch.utils.serialization import load_lua
HAS_CUDA = torch.cuda.is_available() HAS_CUDA = torch.cuda.is_available()
from common import TestCase from common import TestCase, run_tests
try: try:
import cffi import cffi
@ -28,7 +28,9 @@ try:
except ImportError: except ImportError:
HAS_CFFI = False HAS_CFFI = False
class SimplePlugin(Plugin): class SimplePlugin(Plugin):
def __init__(self, interval): def __init__(self, interval):
super(SimplePlugin, self).__init__(interval) super(SimplePlugin, self).__init__(interval)
self.trainer = None self.trainer = None
@ -58,6 +60,7 @@ class SimplePlugin(Plugin):
class ModelMock(object): class ModelMock(object):
def __init__(self): def __init__(self):
self.num_calls = 0 self.num_calls = 0
self.output = Variable(torch.ones(1, 1), requires_grad=True) self.output = Variable(torch.ones(1, 1), requires_grad=True)
@ -68,6 +71,7 @@ class ModelMock(object):
class CriterionMock(object): class CriterionMock(object):
def __init__(self): def __init__(self):
self.num_calls = 0 self.num_calls = 0
@ -95,6 +99,7 @@ class OptimizerMock(object):
class DatasetMock(object): class DatasetMock(object):
def __iter__(self): def __iter__(self):
for i in range(10): for i in range(10):
yield torch.randn(2, 10), torch.randperm(10)[:2] yield torch.randn(2, 10), torch.randperm(10)[:2]
@ -183,6 +188,7 @@ class TestTrainer(TestCase):
test_dir = os.path.abspath(os.path.dirname(str(__file__))) test_dir = os.path.abspath(os.path.dirname(str(__file__)))
class TestFFI(TestCase): class TestFFI(TestCase):
def setUp(self): def setUp(self):
@ -196,13 +202,13 @@ class TestFFI(TestCase):
@unittest.skipIf(not HAS_CFFI, "ffi tests require cffi package") @unittest.skipIf(not HAS_CFFI, "ffi tests require cffi package")
def test_cpu(self): def test_cpu(self):
compile_extension( compile_extension(
name='test_extensions.cpulib', name='test_extensions.cpulib',
header=test_dir + '/ffi/src/cpu/lib.h', header=test_dir + '/ffi/src/cpu/lib.h',
sources=[ sources=[
test_dir + '/ffi/src/cpu/lib1.c', test_dir + '/ffi/src/cpu/lib1.c',
test_dir + '/ffi/src/cpu/lib2.c', test_dir + '/ffi/src/cpu/lib2.c',
], ],
verbose=False, verbose=False,
) )
from test_extensions import cpulib from test_extensions import cpulib
tensor = torch.ones(2, 2).float() tensor = torch.ones(2, 2).float()
@ -217,20 +223,20 @@ class TestFFI(TestCase):
self.assertIs(type(f), float) self.assertIs(type(f), float)
self.assertRaises(TypeError, self.assertRaises(TypeError,
lambda: cpulib.good_func(tensor.double(), 2, 1.5)) lambda: cpulib.good_func(tensor.double(), 2, 1.5))
self.assertRaises(torch.FatalError, self.assertRaises(torch.FatalError,
lambda: cpulib.bad_func(tensor, 2, 1.5)) lambda: cpulib.bad_func(tensor, 2, 1.5))
@unittest.skipIf(not HAS_CFFI or not HAS_CUDA, "ffi tests require cffi package") @unittest.skipIf(not HAS_CFFI or not HAS_CUDA, "ffi tests require cffi package")
def test_gpu(self): def test_gpu(self):
compile_extension( compile_extension(
name='gpulib', name='gpulib',
header=test_dir + '/ffi/src/cuda/cudalib.h', header=test_dir + '/ffi/src/cuda/cudalib.h',
sources=[ sources=[
test_dir + '/ffi/src/cuda/cudalib.c', test_dir + '/ffi/src/cuda/cudalib.c',
], ],
with_cuda=True, with_cuda=True,
verbose=False, verbose=False,
) )
import gpulib import gpulib
tensor = torch.ones(2, 2).float() tensor = torch.ones(2, 2).float()
@ -243,9 +249,9 @@ class TestFFI(TestCase):
self.assertEqual(ctensor, torch.ones(2, 2) * 2 + 1.5) self.assertEqual(ctensor, torch.ones(2, 2) * 2 + 1.5)
self.assertRaises(TypeError, self.assertRaises(TypeError,
lambda: gpulib.cuda_func(tensor, 2, 1.5)) lambda: gpulib.cuda_func(tensor, 2, 1.5))
self.assertRaises(TypeError, self.assertRaises(TypeError,
lambda: gpulib.cuda_func(ctensor.storage(), 2, 1.5)) lambda: gpulib.cuda_func(ctensor.storage(), 2, 1.5))
class TestLuaReader(TestCase): class TestLuaReader(TestCase):
@ -320,7 +326,7 @@ class TestLuaReader(TestCase):
cls._download_data(test_file_path) cls._download_data(test_file_path)
except urllib.URLError as e: except urllib.URLError as e:
warnings.warn(("Couldn't download the test file for TestLuaReader! " warnings.warn(("Couldn't download the test file for TestLuaReader! "
"Tests will be incomplete!"), RuntimeWarning) "Tests will be incomplete!"), RuntimeWarning)
return return
tests = load_lua(test_file_path) tests = load_lua(test_file_path)
@ -364,4 +370,4 @@ class TestLuaReader(TestCase):
TestLuaReader.init() TestLuaReader.init()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() run_tests()


@ -20,13 +20,14 @@ class cwrap(object):
""") """)
OPTION_CODE_TEMPLATE = [ OPTION_CODE_TEMPLATE = [
'$call', '$call',
'$return_result', '$return_result',
] ]
FUNCTION_CALL_TEMPLATE = Template("$capture_result$cname($arg_unpack);") FUNCTION_CALL_TEMPLATE = Template("$capture_result$cname($arg_unpack);")
DEFAULT_PLUGIN_CLASSES = [ArgcountChecker, ConstantArguments, OptionalArguments, ArgumentReferences, BeforeAfterCall, ReturnArguments, GILRelease] DEFAULT_PLUGIN_CLASSES = [ArgcountChecker, ConstantArguments, OptionalArguments,
ArgumentReferences, BeforeAfterCall, ReturnArguments, GILRelease]
def __init__(self, source, destination=None, plugins=[], default_plugins=True): def __init__(self, source, destination=None, plugins=[], default_plugins=True):
if destination is None: if destination is None:
@ -87,7 +88,7 @@ class cwrap(object):
with open(fname, 'r') as f: with open(fname, 'r') as f:
included = f.read().split('\n') included = f.read().split('\n')
# insert it into lines at position i+1 # insert it into lines at position i+1
lines[i+1:i+1] = included lines[i + 1:i + 1] = included
else: else:
output.append(line) output.append(line)
i += 1 i += 1
@ -97,10 +98,10 @@ class cwrap(object):
def set_declaration_defaults(self, declaration): def set_declaration_defaults(self, declaration):
declaration.setdefault('arguments', []) declaration.setdefault('arguments', [])
declaration.setdefault('return', 'void') declaration.setdefault('return', 'void')
if not 'cname' in declaration: if 'cname' not in declaration:
declaration['cname'] = declaration['name'] declaration['cname'] = declaration['name']
# Simulate multiple dispatch, even if it's not necessary # Simulate multiple dispatch, even if it's not necessary
if not 'options' in declaration: if 'options' not in declaration:
declaration['options'] = [{'arguments': declaration['arguments']}] declaration['options'] = [{'arguments': declaration['arguments']}]
del declaration['arguments'] del declaration['arguments']
# Parse arguments (some of them can be strings) # Parse arguments (some of them can be strings)
@ -136,10 +137,10 @@ class cwrap(object):
return fallback(*args) return fallback(*args)
def get_type_check(self, arg, option): def get_type_check(self, arg, option):
return self.search_plugins('get_type_check', (arg, option), lambda arg,_: None) return self.search_plugins('get_type_check', (arg, option), lambda arg, _: None)
def get_type_unpack(self, arg, option): def get_type_unpack(self, arg, option):
return self.search_plugins('get_type_unpack', (arg, option), lambda arg,_: None) return self.search_plugins('get_type_unpack', (arg, option), lambda arg, _: None)
def get_return_wrapper(self, option): def get_return_wrapper(self, option):
return self.search_plugins('get_return_wrapper', (option,), lambda _: self.RETURN_WRAPPERS[option['return']]) return self.search_plugins('get_return_wrapper', (option,), lambda _: self.RETURN_WRAPPERS[option['return']])
@ -182,7 +183,7 @@ class cwrap(object):
def generate_option(self, option, is_first): def generate_option(self, option, is_first):
checked_args = list(filter( checked_args = list(filter(
lambda arg: not 'ignore_check' in arg or not arg['ignore_check'], lambda arg: 'ignore_check' not in arg or not arg['ignore_check'],
option['arguments'])) option['arguments']))
option['num_checked_args'] = len(checked_args) option['num_checked_args'] = len(checked_args)
idx_args = list(filter( idx_args = list(filter(
@ -193,14 +194,14 @@ class cwrap(object):
# Generate checks # Generate checks
arg_checks = self.map_selected_arguments('get_type_check', arg_checks = self.map_selected_arguments('get_type_check',
'process_single_check', option, checked_args) 'process_single_check', option, checked_args)
arg_checks = ' &&\n '.join(arg_checks) arg_checks = ' &&\n '.join(arg_checks)
for plugin in self.plugins: for plugin in self.plugins:
arg_checks = plugin.process_all_checks(arg_checks, option) arg_checks = plugin.process_all_checks(arg_checks, option)
# Generate unpacks # Generate unpacks
arg_unpack = self.map_selected_arguments('get_type_unpack', arg_unpack = self.map_selected_arguments('get_type_unpack',
'process_single_unpack', option, option['arguments']) 'process_single_unpack', option, option['arguments'])
arg_unpack = ', '.join(arg_unpack) arg_unpack = ', '.join(arg_unpack)
for plugin in self.plugins: for plugin in self.plugins:
arg_unpack = plugin.process_all_unpacks(arg_unpack, option) arg_unpack = plugin.process_all_unpacks(arg_unpack, option)
@ -209,16 +210,16 @@ class cwrap(object):
try: try:
return_result = self.get_return_wrapper(option).substitute() return_result = self.get_return_wrapper(option).substitute()
call = self.FUNCTION_CALL_TEMPLATE.substitute(capture_result='', call = self.FUNCTION_CALL_TEMPLATE.substitute(capture_result='',
cname=option['cname'], arg_unpack=arg_unpack) cname=option['cname'], arg_unpack=arg_unpack)
except KeyError: except KeyError:
return_result = self.get_return_wrapper(option).substitute(result='__result') return_result = self.get_return_wrapper(option).substitute(result='__result')
call = self.FUNCTION_CALL_TEMPLATE.substitute(capture_result=(option['return'] + ' __result = '), call = self.FUNCTION_CALL_TEMPLATE.substitute(capture_result=(option['return'] + ' __result = '),
cname=option['cname'], arg_unpack=arg_unpack) cname=option['cname'], arg_unpack=arg_unpack)
code_template = deepcopy(self.OPTION_CODE_TEMPLATE) code_template = deepcopy(self.OPTION_CODE_TEMPLATE)
for plugin in self.plugins: for plugin in self.plugins:
code_template = plugin.process_option_code_template(code_template, code_template = plugin.process_option_code_template(code_template,
option) option)
code_template = Template('\n'.join(code_template)) code_template = Template('\n'.join(code_template))
code = code_template.substitute(call=call, return_result=return_result) code = code_template.substitute(call=call, return_result=return_result)
code_lines = map(lambda s: s.strip(), code.split('\n')) code_lines = map(lambda s: s.strip(), code.split('\n'))
@ -228,6 +229,8 @@ class cwrap(object):
depth -= line.count('}') * 2 depth -= line.count('}') * 2
code += ' ' * depth + line + '\n' code += ' ' * depth + line + '\n'
depth += line.count('{') * 2 depth += line.count('{') * 2
depth += line.count('(') * 4
depth -= line.count(')') * 4
# Put everything together # Put everything together
return self.OPTION_TEMPLATE.substitute( return self.OPTION_TEMPLATE.substitute(
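The hunk above extends the generator's naive re-indentation pass: braces still shift the depth by two spaces per level, and unbalanced parentheses now shift it by four, so wrapped call arguments line up under the call. A standalone sketch of that heuristic (not the cwrap class itself):

def reindent(code):
    # 2 spaces per brace level, 4 per unbalanced open parenthesis
    out, depth = '', 0
    for line in map(str.strip, code.split('\n')):
        depth -= line.count('}') * 2
        out += ' ' * depth + line + '\n'
        depth += line.count('{') * 2
        depth += line.count('(') * 4
        depth -= line.count(')') * 4
    return out

print(reindent('if (x) {\nfoo(a,\nb);\n}'))
# if (x) {
#   foo(a,
#       b);
# }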


@ -1,5 +1,6 @@
from . import CWrapPlugin from . import CWrapPlugin
class ArgcountChecker(CWrapPlugin): class ArgcountChecker(CWrapPlugin):
def process_all_checks(self, checks, option): def process_all_checks(self, checks, option):


@ -1,5 +1,6 @@
from . import CWrapPlugin from . import CWrapPlugin
class ArgcountSortPlugin(CWrapPlugin): class ArgcountSortPlugin(CWrapPlugin):
def __init__(self, descending=True): def __init__(self, descending=True):
@ -11,4 +12,3 @@ class ArgcountSortPlugin(CWrapPlugin):
for declaration in declarations: for declaration in declarations:
declaration['options'].sort(key=num_checked_args, reverse=self.descending) declaration['options'].sort(key=num_checked_args, reverse=self.descending)
return declarations return declarations


@ -1,6 +1,7 @@
from . import CWrapPlugin from . import CWrapPlugin
from string import Template from string import Template
class ArgumentReferences(CWrapPlugin): class ArgumentReferences(CWrapPlugin):
def initialize(self, cwrap): def initialize(self, cwrap):


@ -1,5 +1,6 @@
from . import CWrapPlugin from . import CWrapPlugin
class AutoGPU(CWrapPlugin): class AutoGPU(CWrapPlugin):
def __init__(self, has_self=True, condition=None): def __init__(self, has_self=True, condition=None):


@ -1,6 +1,7 @@
from . import CWrapPlugin from . import CWrapPlugin
from string import Template from string import Template
class BeforeAfterCall(CWrapPlugin): class BeforeAfterCall(CWrapPlugin):
def initialize(self, cwrap): def initialize(self, cwrap):
@ -13,7 +14,7 @@ class BeforeAfterCall(CWrapPlugin):
if '$' in prepend_str: if '$' in prepend_str:
before_call_template = Template(option[name]) before_call_template = Template(option[name])
args = {'arg' + str(i): self.cwrap.get_arg_accessor(arg, option) for i, arg args = {'arg' + str(i): self.cwrap.get_arg_accessor(arg, option) for i, arg
in enumerate(option['arguments'])} in enumerate(option['arguments'])}
prepend_str = before_call_template.substitute(args) prepend_str = before_call_template.substitute(args)
template.insert(offset, prepend_str) template.insert(offset, prepend_str)
@ -23,5 +24,5 @@ class BeforeAfterCall(CWrapPlugin):
self.insert_snippet(template, option, call_idx, 'before_call') self.insert_snippet(template, option, call_idx, 'before_call')
# call position might have changed # call position might have changed
call_idx = template.index('$call') call_idx = template.index('$call')
self.insert_snippet(template, option, call_idx+1, 'after_call') self.insert_snippet(template, option, call_idx + 1, 'after_call')
return template return template


@ -1,6 +1,7 @@
from . import CWrapPlugin from . import CWrapPlugin
from string import Template from string import Template
class BoolOption(CWrapPlugin): class BoolOption(CWrapPlugin):
UNPACK_TEMPLATE = Template('$arg == Py_True ? $if_true : $if_false') UNPACK_TEMPLATE = Template('$arg == Py_True ? $if_true : $if_false')
@ -16,4 +17,3 @@ class BoolOption(CWrapPlugin):
if self.is_bool_option(arg): if self.is_bool_option(arg):
return Template(self.UNPACK_TEMPLATE.safe_substitute( return Template(self.UNPACK_TEMPLATE.safe_substitute(
if_true=arg['if_true'], if_false=arg['if_false'])) if_true=arg['if_true'], if_false=arg['if_false']))


@ -1,6 +1,7 @@
from . import CWrapPlugin from . import CWrapPlugin
from string import Template from string import Template
class ConstantArguments(CWrapPlugin): class ConstantArguments(CWrapPlugin):
def process_declarations(self, declarations): def process_declarations(self, declarations):
@ -18,5 +19,3 @@ class ConstantArguments(CWrapPlugin):
def get_arg_accessor(self, arg, option): def get_arg_accessor(self, arg, option):
if arg['type'] == 'CONSTANT': if arg['type'] == 'CONSTANT':
return arg['name'] return arg['name']


@ -3,30 +3,31 @@ from copy import deepcopy
from . import CWrapPlugin from . import CWrapPlugin
from itertools import product from itertools import product
class CuDNNPlugin(CWrapPlugin): class CuDNNPlugin(CWrapPlugin):
TYPE_UNPACK = { TYPE_UNPACK = {
'THTensor*': Template('((THPVoidTensor*)$arg)->cdata'), 'THTensor*': Template('((THPVoidTensor*)$arg)->cdata'),
'int': Template('THPUtils_unpackLong($arg)'), 'int': Template('THPUtils_unpackLong($arg)'),
'std::vector<int>': Template('THPUtils_unpackIntTuple($arg)'), 'std::vector<int>': Template('THPUtils_unpackIntTuple($arg)'),
'cudnnDataType_t': Template('$arg'), 'cudnnDataType_t': Template('$arg'),
'cudnnHandle_t': Template('$arg'), 'cudnnHandle_t': Template('$arg'),
'Convolution*': Template('(Convolution*)THPWrapper_get($arg)'), 'Convolution*': Template('(Convolution*)THPWrapper_get($arg)'),
'bool': Template('$arg == Py_True'), 'bool': Template('$arg == Py_True'),
'double': Template('THPDoubleUtils_unpackReal($arg)'), 'double': Template('THPDoubleUtils_unpackReal($arg)'),
} }
TYPE_CHECK = { TYPE_CHECK = {
'Convolution*': Template('THPWrapper_check($arg)'), 'Convolution*': Template('THPWrapper_check($arg)'),
'THTensor*': Template('(PyObject*)Py_TYPE($arg) == tensorClass'), 'THTensor*': Template('(PyObject*)Py_TYPE($arg) == tensorClass'),
'int': Template('THPUtils_checkLong($arg)'), 'int': Template('THPUtils_checkLong($arg)'),
'std::vector<int>': Template('THPUtils_checkIntTuple($arg)'), 'std::vector<int>': Template('THPUtils_checkIntTuple($arg)'),
'bool': Template('PyBool_Check($arg)'), 'bool': Template('PyBool_Check($arg)'),
'double': Template('THPDoubleUtils_checkReal($arg)'), 'double': Template('THPDoubleUtils_checkReal($arg)'),
} }
RETURN_WRAPPER = { RETURN_WRAPPER = {
'Convolution*': Template('return THPWrapper_New($result, [](void* arg) { delete (Convolution*)arg; });'), 'Convolution*': Template('return THPWrapper_New($result, [](void* arg) { delete (Convolution*)arg; });'),
} }
METHODS_DECLARATION = Template(""" METHODS_DECLARATION = Template("""
@ -123,7 +124,8 @@ static PyObject * $name(PyObject *self, PyObject *args, PyObject *kwargs)
def filter_unique_options(self, options): def filter_unique_options(self, options):
def signature(option): def signature(option):
return '#'.join(arg['type'] for arg in option['arguments'] if not 'ignore_check' in arg or not arg['ignore_check']) return '#'.join(arg['type'] for arg in option['arguments']
if 'ignore_check' not in arg or not arg['ignore_check'])
seen_signatures = set() seen_signatures = set()
unique = [] unique = []
for option in options: for option in options:
@ -151,8 +153,8 @@ static PyObject * $name(PyObject *self, PyObject *args, PyObject *kwargs)
if not declaration.get('only_register'): if not declaration.get('only_register'):
extra_flags += ' | METH_KEYWORDS' extra_flags += ' | METH_KEYWORDS'
entry = Template(' {"$python_name", (PyCFunction)$name, METH_VARARGS$extra_flags, NULL},\n').substitute( entry = Template(' {"$python_name", (PyCFunction)$name, METH_VARARGS$extra_flags, NULL},\n').substitute(
python_name=declaration['python_name'], name=declaration['name'], extra_flags=extra_flags python_name=declaration['python_name'], name=declaration['name'], extra_flags=extra_flags
) )
if 'defined_if' in declaration: if 'defined_if' in declaration:
entry = self.preprocessor_guard(entry, declaration['defined_if']) entry = self.preprocessor_guard(entry, declaration['defined_if'])
methods += entry methods += entry


@ -1,6 +1,7 @@
from . import CWrapPlugin from . import CWrapPlugin
from string import Template from string import Template
class GILRelease(CWrapPlugin): class GILRelease(CWrapPlugin):
OPTION_START = [ OPTION_START = [
@ -24,6 +25,5 @@ class GILRelease(CWrapPlugin):
def process_option_code_template(self, template, option): def process_option_code_template(self, template, option):
call_idx = template.index('$call') call_idx = template.index('$call')
template.insert(call_idx, self.BEFORE_CALL) template.insert(call_idx, self.BEFORE_CALL)
template.insert(call_idx+2, self.AFTER_CALL) template.insert(call_idx + 2, self.AFTER_CALL)
return self.OPTION_START + template + self.OPTION_END return self.OPTION_START + template + self.OPTION_END


@ -0,0 +1,211 @@
import copy
from string import Template
from . import CWrapPlugin
class GenericNN(CWrapPlugin):
INPUT_TYPE_CHECK = Template("checkTypes(is_cuda, $type, $tensor_args);")
HEADER_TEMPLATE = Template("void $name($args);")
WRAPPER_TEMPLATE = Template("""\
void $name($args)
{
bool is_cuda = $input->isCuda();
auto type = $input->type();
$type_check
$options
} else {
throw std::runtime_error("invalid arguments");
}
}
""")
THNN_TEMPLATE = Template("""\
if (type == thpp::Type::FLOAT) {
THNN_Float$name(
NULL,
$float_args);
} else if (type == thpp::Type::DOUBLE) {
THNN_Double$name(
NULL,
$double_args);
} else {
throw std::runtime_error("unsupported tensor type");
}""")
THCUNN_TEMPLATE = Template("""\
#ifdef WITH_CUDA
if (type == thpp::Type::FLOAT) {
THNN_Cuda$name(
state,
$float_args);
} else if (type == thpp::Type::DOUBLE) {
THNN_CudaDouble$name(
state,
$double_args);
} else if (type == thpp::Type::HALF) {
THNN_CudaHalf$name(
state,
$half_args);
} else {
throw std::runtime_error("unsupported tensor type");
}
#endif
""")
INDEX_TENSOR_TYPES = {'THIndexTensor*', 'THCIndexTensor*'}
REAL_TENSOR_TYPES = {'THTensor*', 'THCTensor*'}
INPUT_ARGUMENT_MAP = {
'THNNState*': 'void*',
'THCState*': 'void*',
'THTensor*': 'thpp::Tensor*',
'THCTensor*': 'thpp::Tensor*',
'THIndexTensor*': 'thpp::Tensor*',
'THIndex_t': 'long',
'real': 'double',
}
def __init__(self, header=False):
self.header = header
self.declarations = []
def process_full_file(self, base_wrapper):
if self.header:
wrapper = '#pragma once\n\n'
wrapper += '#include <THPP/Tensor.hpp>\n\n'
else:
wrapper = '#include "THNN_generic.h"\n'
wrapper += '#include "THNN_generic.inc.h"\n\n'
wrapper += 'namespace torch { namespace nn {\n\n'
wrapper += base_wrapper
wrapper += '}} // namespace torch::nn\n'
return wrapper
def process_declarations(self, declarations):
for declaration in declarations:
base_args = declaration['options'][0]['arguments']
for option in declaration['options']:
for idx, arg in enumerate(option['arguments']):
arg['formal_name'] = base_args[idx]['name']
arg['formal_type'] = base_args[idx]['type']
if idx != 1:
arg['ignore_check'] = True
return declarations
def get_arg_accessor(self, arg, option):
return self.get_type_unpack(arg, option)
def process_option_code_template(self, template, option):
code = '// fill me in'
def base_cast(arg, CReal, real):
name = arg['formal_name']
type = arg['type']
if type in self.REAL_TENSOR_TYPES:
return ('(TH{CReal}Tensor*){name}->cdata()'
.format(CReal=CReal, name=name))
elif type in self.INDEX_TENSOR_TYPES:
return '({type}){name}->cdata()'.format(type=type, name=name)
elif type == 'THCState*':
return '({}){}'.format(type, name)
elif type == 'real':
if real == 'half':
return 'THC_float2half({})'.format(name)
return '({real}){name}'.format(real=real, name=name)
return name
def cast(arg, CReal, real):
expr = base_cast(arg, CReal, real)
if arg.get('optional', False):
name = arg['formal_name']
return '{name} ? {expr} : NULL'.format(name=name, expr=expr)
return expr
if option['backend'] == 'nn':
float_args = []
double_args = []
for idx, arg in enumerate(option['arguments']):
float_args.append(cast(arg, 'Float', 'float'))
double_args.append(cast(arg, 'Double', 'double'))
code = self.THNN_TEMPLATE.substitute(
name=option['cname'],
float_args=',\n'.join(float_args),
double_args=',\n'.join(double_args))
elif option['backend'] == 'cunn':
float_args = []
double_args = []
half_args = []
for idx, arg in enumerate(option['arguments']):
float_args.append(cast(arg, 'Cuda', 'float'))
double_args.append(cast(arg, 'CudaDouble', 'double'))
half_args.append(cast(arg, 'CudaHalf', 'half'))
code = self.THCUNN_TEMPLATE.substitute(
name=option['cname'],
float_args=',\n'.join(float_args),
double_args=',\n'.join(double_args),
half_args=',\n'.join(half_args))
return [code, '']
def get_type_unpack(self, arg, option):
return Template(arg['name'])
def get_type_check(self, arg, option):
if option['backend'] == 'cunn':
return Template('is_cuda')
else:
return Template('!is_cuda')
def get_formal_args(self, arguments):
formal_args = []
for arg in arguments:
arg = copy.copy(arg)
new_type = self.INPUT_ARGUMENT_MAP.get(arg['type'])
if new_type is not None:
arg['type'] = new_type
formal_args.append(arg)
return formal_args
def get_wrapper_template(self, declaration):
# get formal arguments string
base_arguments = declaration['options'][0]['arguments']
args = self.get_formal_args(base_arguments)
arg_str = ', '.join([arg['type'] + ' ' + arg['name'] for arg in args])
if self.header:
return Template(self.HEADER_TEMPLATE.safe_substitute(args=arg_str))
def get_checked_args(tensor_types):
checked_args = []
for arg in base_arguments:
if arg['type'] in tensor_types:
name = arg.get('formal_name', arg['name'])
name_str = name
if arg.get('optional', False):
name_str = '?' + name_str
checked_args += ['"' + name_str + '"', name]
checked_args += ['NULL']
return checked_args
real_args = get_checked_args(self.REAL_TENSOR_TYPES)
long_args = get_checked_args(self.INDEX_TENSOR_TYPES)
# check input types
types_checks = []
if len(real_args) > 1:
types_checks.append(self.INPUT_TYPE_CHECK.substitute(
type='type', tensor_args=', '.join(real_args)))
if len(long_args) > 1:
types_checks.append(self.INPUT_TYPE_CHECK.substitute(
type='thpp::Type::LONG', tensor_args=', '.join(long_args)))
return Template(self.WRAPPER_TEMPLATE.safe_substitute(
input=args[0]['name'],
args=arg_str,
type_check='\n '.join(types_checks)))
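Since THNN_TEMPLATE and THCUNN_TEMPLATE are plain string.Template objects, the shape of the generated dispatch code can be previewed by substituting a function name and pre-cast argument strings by hand; the function and argument names below are illustrative only:

from string import Template

# the CPU branch of the template above, reproduced for a self-contained demo
THNN_TEMPLATE = Template("""\
if (type == thpp::Type::FLOAT) {
  THNN_Float$name(
      NULL,
      $float_args);
} else if (type == thpp::Type::DOUBLE) {
  THNN_Double$name(
      NULL,
      $double_args);
} else {
  throw std::runtime_error("unsupported tensor type");
}""")

print(THNN_TEMPLATE.substitute(
    name='Abs_updateOutput',
    float_args='(THFloatTensor*)input->cdata(),\n      (THFloatTensor*)output->cdata()',
    double_args='(THDoubleTensor*)input->cdata(),\n      (THDoubleTensor*)output->cdata()'))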


@ -1,6 +1,7 @@
from . import CWrapPlugin from . import CWrapPlugin
from string import Template from string import Template
class KwargsPlugin(CWrapPlugin): class KwargsPlugin(CWrapPlugin):
ACCESSOR_TEMPLATE = Template('(__tuplecount > $idx ? PyTuple_GET_ITEM(args, $idx) : __kw_$name)') ACCESSOR_TEMPLATE = Template('(__tuplecount > $idx ? PyTuple_GET_ITEM(args, $idx) : __kw_$name)')
@ -53,7 +54,8 @@ class KwargsPlugin(CWrapPlugin):
seen_args.add(name) seen_args.add(name)
args.append(name) args.append(name)
declarations = '\n '.join(['PyObject *__kw_{} = NULL;'.format(name) for name in args]) declarations = '\n '.join(['PyObject *__kw_{} = NULL;'.format(name) for name in args])
lookups = '\n '.join(['__kw_{name} = PyDict_GetItemString(kwargs, "{name}");'.format(name=name) for name in args]) lookups = '\n '.join(
['__kw_{name} = PyDict_GetItemString(kwargs, "{name}");'.format(name=name) for name in args])
start_idx = code.find('{') + 1 start_idx = code.find('{') + 1
new_code = self.WRAPPER_TEMPLATE.substitute(declarations=declarations, lookups=lookups) new_code = self.WRAPPER_TEMPLATE.substitute(declarations=declarations, lookups=lookups)
return code[:start_idx] + new_code + code[start_idx:] return code[:start_idx] + new_code + code[start_idx:]


@ -1,6 +1,8 @@
from . import CWrapPlugin from . import CWrapPlugin
class NullableArguments(CWrapPlugin): class NullableArguments(CWrapPlugin):
def process_single_check(self, code, arg, arg_accessor): def process_single_check(self, code, arg, arg_accessor):
if 'nullable' in arg and arg['nullable']: if 'nullable' in arg and arg['nullable']:
return '({} || {} == Py_None)'.format(code, arg_accessor) return '({} || {} == Py_None)'.format(code, arg_accessor)
@ -10,5 +12,3 @@ class NullableArguments(CWrapPlugin):
if 'nullable' in arg and arg['nullable']: if 'nullable' in arg and arg['nullable']:
return '({} == Py_None ? NULL : {})'.format(arg_accessor, code) return '({} == Py_None ? NULL : {})'.format(arg_accessor, code)
return code return code


@ -2,6 +2,7 @@ from copy import deepcopy
from . import CWrapPlugin from . import CWrapPlugin
from itertools import product from itertools import product
class OptionalArguments(CWrapPlugin): class OptionalArguments(CWrapPlugin):
def process_declarations(self, declarations): def process_declarations(self, declarations):
@ -32,20 +33,20 @@ class OptionalArguments(CWrapPlugin):
else: else:
kwarg_only_count = -kwarg_only_count kwarg_only_count = -kwarg_only_count
arg_signature = '#'.join( arg_signature = '#'.join(
arg['type'] arg['type']
for arg in option['arguments'][:kwarg_only_count] for arg in option['arguments'][:kwarg_only_count]
if not arg.get('ignore_check')) if not arg.get('ignore_check'))
if kwarg_only_count is None: if kwarg_only_count is None:
return arg_signature return arg_signature
kwarg_only_signature = '#'.join( kwarg_only_signature = '#'.join(
arg['name'] + '#' + arg['type'] arg['name'] + '#' + arg['type']
for arg in option['arguments'][kwarg_only_count:] for arg in option['arguments'][kwarg_only_count:]
if not arg.get('ignore_check')) if not arg.get('ignore_check'))
return arg_signature + "#-#" + kwarg_only_signature return arg_signature + "#-#" + kwarg_only_signature
seen_signatures = set() seen_signatures = set()
unique = [] unique = []
for option in options: for option in options:
for num_kwarg_only in range(0, len(option['arguments'])+1): for num_kwarg_only in range(0, len(option['arguments']) + 1):
sig = signature(option, num_kwarg_only) sig = signature(option, num_kwarg_only)
if sig not in seen_signatures: if sig not in seen_signatures:
if num_kwarg_only > 0: if num_kwarg_only > 0:
@ -55,4 +56,3 @@ class OptionalArguments(CWrapPlugin):
seen_signatures.add(sig) seen_signatures.add(sig)
break break
return unique return unique


@ -1,9 +1,10 @@
from . import CWrapPlugin from . import CWrapPlugin
from string import Template from string import Template
class ReturnArguments(CWrapPlugin): class ReturnArguments(CWrapPlugin):
ARGUMENT_RETURN_TEMPLATE = Template("Py_INCREF($arg);\nreturn (PyObject*)($arg);") ARGUMENT_RETURN_TEMPLATE = Template("Py_INCREF($arg);\nreturn (PyObject*)($arg);")
TUPLE_RETURN_TEMPLATE = Template("return PyTuple_Pack($num_args, $args);") TUPLE_RETURN_TEMPLATE = Template("return PyTuple_Pack($num_args, $args);")
def initialize(self, cwrap): def initialize(self, cwrap):
self.cwrap = cwrap self.cwrap = cwrap
@ -16,4 +17,5 @@ class ReturnArguments(CWrapPlugin):
if len(args) == 1: if len(args) == 1:
return Template(self.ARGUMENT_RETURN_TEMPLATE.safe_substitute(arg=accessors[0])) return Template(self.ARGUMENT_RETURN_TEMPLATE.safe_substitute(arg=accessors[0]))
else: else:
return Template(self.TUPLE_RETURN_TEMPLATE.safe_substitute(num_args=len(args), args=', '.join(accessors))) return Template(self.TUPLE_RETURN_TEMPLATE.safe_substitute(num_args=len(args),
args=', '.join(accessors)))


@ -26,41 +26,41 @@ $METHODS
class StandaloneExtension(CWrapPlugin): class StandaloneExtension(CWrapPlugin):
TYPE_UNPACK = { TYPE_UNPACK = {
'THFloatTensor*': Template('THPFloatTensor_CData((THPFloatTensor*)$arg)'), 'THFloatTensor*': Template('THPFloatTensor_CData((THPFloatTensor*)$arg)'),
'THDoubleTensor*': Template('THPDoubleTensor_CData((THPDoubleTensor*)$arg)'), 'THDoubleTensor*': Template('THPDoubleTensor_CData((THPDoubleTensor*)$arg)'),
'THLongTensor*': Template('THPLongTensor_CData((THPLongTensor*)$arg)'), 'THLongTensor*': Template('THPLongTensor_CData((THPLongTensor*)$arg)'),
'THIntTensor*': Template('THPIntTensor_CData((THPIntTensor*)$arg)'), 'THIntTensor*': Template('THPIntTensor_CData((THPIntTensor*)$arg)'),
'THCudaHalfTensor*': Template('THCPHalfTensor_CData((THCPHalfTensor*)$arg)'), 'THCudaHalfTensor*': Template('THCPHalfTensor_CData((THCPHalfTensor*)$arg)'),
'THCudaTensor*': Template('THCPFloatTensor_CData((THCPFloatTensor*)$arg)'), 'THCudaTensor*': Template('THCPFloatTensor_CData((THCPFloatTensor*)$arg)'),
'THCudaDoubleTensor*': Template('THCPDoubleTensor_CData((THCPDoubleTensor*)$arg)'), 'THCudaDoubleTensor*': Template('THCPDoubleTensor_CData((THCPDoubleTensor*)$arg)'),
'THCudaLongTensor*': Template('THCPLongTensor_CData((THCPLongTensor*)$arg)'), 'THCudaLongTensor*': Template('THCPLongTensor_CData((THCPLongTensor*)$arg)'),
'half': Template('THPHalfUtils_unpackReal($arg)'), 'half': Template('THPHalfUtils_unpackReal($arg)'),
'float': Template('THPFloatUtils_unpackReal($arg)'), 'float': Template('THPFloatUtils_unpackReal($arg)'),
'double': Template('THPDoubleUtils_unpackReal($arg)'), 'double': Template('THPDoubleUtils_unpackReal($arg)'),
'bool': Template('($arg == Py_True ? true : false)'), 'bool': Template('($arg == Py_True ? true : false)'),
'int': Template('THPUtils_unpackLong($arg)'), 'int': Template('THPUtils_unpackLong($arg)'),
'long': Template('THPUtils_unpackLong($arg)'), 'long': Template('THPUtils_unpackLong($arg)'),
'void*': Template('(void*)THPUtils_unpackLong($arg)'), 'void*': Template('(void*)THPUtils_unpackLong($arg)'),
'THGenerator*': Template('THPGenerator_CData((THPGenerator*)$arg)'), 'THGenerator*': Template('THPGenerator_CData((THPGenerator*)$arg)'),
} }
TYPE_CHECK = { TYPE_CHECK = {
'THDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THPDoubleTensorClass'), 'THDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THPDoubleTensorClass'),
'THFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THPFloatTensorClass'), 'THFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THPFloatTensorClass'),
'THLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THPLongTensorClass'), 'THLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THPLongTensorClass'),
'THIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntTensorClass'), 'THIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntTensorClass'),
'THCudaHalfTensor*': Template('THCPHalfTensor_Check($arg)'), 'THCudaHalfTensor*': Template('THCPHalfTensor_Check($arg)'),
'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'), 'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'),
'THCudaDoubleTensor*': Template('THCPDoubleTensor_Check($arg)'), 'THCudaDoubleTensor*': Template('THCPDoubleTensor_Check($arg)'),
'THCudaLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPLongTensorClass'), 'THCudaLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPLongTensorClass'),
'half': Template('THPHalfUtils_checkReal($arg)'), 'half': Template('THPHalfUtils_checkReal($arg)'),
'float': Template('THPFloatUtils_checkReal($arg)'), 'float': Template('THPFloatUtils_checkReal($arg)'),
'double': Template('THPDoubleUtils_checkReal($arg)'), 'double': Template('THPDoubleUtils_checkReal($arg)'),
'bool': Template('PyBool_Check($arg)'), 'bool': Template('PyBool_Check($arg)'),
'int': Template('THPUtils_checkLong($arg)'), 'int': Template('THPUtils_checkLong($arg)'),
'long': Template('THPUtils_checkLong($arg)'), 'long': Template('THPUtils_checkLong($arg)'),
'void*': Template('THPUtils_checkLong($arg)'), 'void*': Template('THPUtils_checkLong($arg)'),
'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'), 'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
} }
WRAPPER_TEMPLATE = Template(""" WRAPPER_TEMPLATE = Template("""
@ -131,6 +131,7 @@ PyObject * $name(PyObject *_unused, PyObject *args)
def get_wrapper_template(self, declaration): def get_wrapper_template(self, declaration):
arg_desc = [] arg_desc = []
def describe_arg(arg): def describe_arg(arg):
desc = self.TYPE_NAMES[arg['type']] + ' ' + arg['name'] desc = self.TYPE_NAMES[arg['type']] + ' ' + arg['name']
if arg.get('nullable'): if arg.get('nullable'):
@ -138,8 +139,8 @@ PyObject * $name(PyObject *_unused, PyObject *args)
return desc return desc
for option in declaration['options']: for option in declaration['options']:
option_desc = [describe_arg(arg) option_desc = [describe_arg(arg)
for arg in option['arguments'] for arg in option['arguments']
if not arg.get('ignore_check', False)] if not arg.get('ignore_check', False)]
if option_desc: if option_desc:
arg_desc.append('({})'.format(', '.join(option_desc))) arg_desc.append('({})'.format(', '.join(option_desc)))
else: else:


@ -4,85 +4,91 @@ from . import CWrapPlugin
from itertools import product, chain from itertools import product, chain
from collections import OrderedDict from collections import OrderedDict
class THPPlugin(CWrapPlugin): class THPPlugin(CWrapPlugin):
TYPE_UNPACK = { TYPE_UNPACK = {
'THFloatTensor*': Template('((THPFloatTensor*)$arg)->cdata'), 'THFloatTensor*': Template('((THPFloatTensor*)$arg)->cdata'),
'THDoubleTensor*': Template('((THPDoubleTensor*)$arg)->cdata'), 'THDoubleTensor*': Template('((THPDoubleTensor*)$arg)->cdata'),
'THLongTensor*': Template('((THPLongTensor*)$arg)->cdata'), 'THLongTensor*': Template('((THPLongTensor*)$arg)->cdata'),
'THIntTensor*': Template('((THPIntTensor*)$arg)->cdata'), 'THIntTensor*': Template('((THPIntTensor*)$arg)->cdata'),
'THTensor*': Template('((THPTensor*)$arg)->cdata'), 'THTensor*': Template('((THPTensor*)$arg)->cdata'),
'THBoolTensor*': Template('((THPBoolTensor*)$arg)->cdata'), 'THBoolTensor*': Template('((THPBoolTensor*)$arg)->cdata'),
'THIndexTensor*': Template('((THPIndexTensor*)$arg)->cdata'), 'THIndexTensor*': Template('((THPIndexTensor*)$arg)->cdata'),
'THSFloatTensor*': Template('((THSPFloatTensor*)$arg)->cdata'), 'THCudaTensor*': Template('((THCPFloatTensor*)$arg)->cdata'),
'THCudaDoubleTensor*': Template('((THCPDoubleTensor*)$arg)->cdata'),
'THSFloatTensor*': Template('((THSPFloatTensor*)$arg)->cdata'),
'THSDoubleTensor*': Template('((THSPDoubleTensor*)$arg)->cdata'), 'THSDoubleTensor*': Template('((THSPDoubleTensor*)$arg)->cdata'),
'THSLongTensor*': Template('((THSPLongTensor*)$arg)->cdata'), 'THSLongTensor*': Template('((THSPLongTensor*)$arg)->cdata'),
'THSIntTensor*': Template('((THSPIntTensor*)$arg)->cdata'), 'THSIntTensor*': Template('((THSPIntTensor*)$arg)->cdata'),
'THSTensor*': Template('((THSPTensor*)$arg)->cdata'), 'THSTensor*': Template('((THSPTensor*)$arg)->cdata'),
'THSBoolTensor*': Template('((THSPBoolTensor*)$arg)->cdata'), 'THSBoolTensor*': Template('((THSPBoolTensor*)$arg)->cdata'),
'THSIndexTensor*': Template('((THSPIndexTensor*)$arg)->cdata'), 'THSIndexTensor*': Template('((THSPIndexTensor*)$arg)->cdata'),
'THLongStorage*': Template('((THPLongStorage*)$arg)->cdata'), 'THLongStorage*': Template('((THPLongStorage*)$arg)->cdata'),
'THStorage*': Template('((THPStorage*)$arg)->cdata'), 'THStorage*': Template('((THPStorage*)$arg)->cdata'),
'THGenerator*': Template('((THPGenerator*)$arg)->cdata'), 'THGenerator*': Template('((THPGenerator*)$arg)->cdata'),
'THSize*': Template('__size.get()'), 'THSize*': Template('__size.get()'),
'THStride*': Template('__stride.get()'), 'THStride*': Template('__stride.get()'),
'void*': Template('THPUtils_unpackLong($arg)'), 'void*': Template('THPUtils_unpackLong($arg)'),
'long': Template('THPUtils_unpackLong($arg)'), 'long': Template('THPUtils_unpackLong($arg)'),
'int': Template('THPUtils_unpackLong($arg)'), 'int': Template('THPUtils_unpackLong($arg)'),
'bool': Template('($arg == Py_True ? true : false)'), 'bool': Template('($arg == Py_True ? true : false)'),
'float': Template('THPFloatUtils_unpackReal($arg)'), 'float': Template('THPFloatUtils_unpackReal($arg)'),
'double': Template('THPDoubleUtils_unpackReal($arg)'), 'double': Template('THPDoubleUtils_unpackReal($arg)'),
'real': Template('THPUtils_(unpackReal)($arg)'), 'real': Template('THPUtils_(unpackReal)($arg)'),
'accreal': Template('THPUtils_(unpackAccreal)($arg)'), 'accreal': Template('THPUtils_(unpackAccreal)($arg)'),
} }
TYPE_CHECK = { TYPE_CHECK = {
'THDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THPDoubleTensorClass'), 'THDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THPDoubleTensorClass'),
'THFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THPFloatTensorClass'), 'THFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THPFloatTensorClass'),
'THLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THPLongTensorClass'), 'THLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THPLongTensorClass'),
'THIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntTensorClass'), 'THIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntTensorClass'),
'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'), 'THTensor*': Template('(PyObject*)Py_TYPE($arg) == THPTensorClass'),
'THTensor*': Template('(PyObject*)Py_TYPE($arg) == THPTensorClass'), 'THBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THPBoolTensorClass'),
'THBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THPBoolTensorClass'), 'THIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIndexTensorClass'),
'THIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIndexTensorClass'),
'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'),
'THCudaDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPDoubleTensorClass'),
'THSDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPDoubleTensorClass'), 'THSDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPDoubleTensorClass'),
'THSFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPFloatTensorClass'), 'THSFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPFloatTensorClass'),
'THSLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPLongTensorClass'), 'THSLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPLongTensorClass'),
'THSIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIntTensorClass'), 'THSIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIntTensorClass'),
'THSTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPTensorClass'), 'THSTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPTensorClass'),
'THSBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPBoolTensorClass'), 'THSBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPBoolTensorClass'),
'THSIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIndexTensorClass'), 'THSIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIndexTensorClass'),
'THLongStorage*': Template('(PyObject*)Py_TYPE($arg) == THPLongStorageClass'), 'THLongStorage*': Template('(PyObject*)Py_TYPE($arg) == THPLongStorageClass'),
'THStorage*': Template('(PyObject*)Py_TYPE($arg) == THPStorageClass'), 'THStorage*': Template('(PyObject*)Py_TYPE($arg) == THPStorageClass'),
'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'), 'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
'THSize*': Template('THPUtils_tryUnpackLongs($arg, __size)'), 'THSize*': Template('THPUtils_tryUnpackLongs($arg, __size)'),
'THStride*': Template('THPUtils_tryUnpackLongs($arg, __stride)'), 'THStride*': Template('THPUtils_tryUnpackLongs($arg, __stride)'),
'void*': Template('THPUtils_checkLong($arg)'), 'void*': Template('THPUtils_checkLong($arg)'),
'long': Template('THPUtils_checkLong($arg)'), 'long': Template('THPUtils_checkLong($arg)'),
'int': Template('THPUtils_checkLong($arg)'), 'int': Template('THPUtils_checkLong($arg)'),
'bool': Template('PyBool_Check($arg)'), 'bool': Template('PyBool_Check($arg)'),
'float': Template('THPFloatUtils_checkReal($arg)'), 'float': Template('THPFloatUtils_checkReal($arg)'),
'double': Template('THPDoubleUtils_checkReal($arg)'), 'double': Template('THPDoubleUtils_checkReal($arg)'),
'real': Template('THPUtils_(checkReal)($arg)'), 'real': Template('THPUtils_(checkReal)($arg)'),
'accreal': Template('THPUtils_(checkAccreal)($arg)'), 'accreal': Template('THPUtils_(checkAccreal)($arg)'),
} }
SIZE_VARARG_CHECK = Template('THPUtils_tryUnpackLongVarArgs(args, $idx, __size)') SIZE_VARARG_CHECK = Template('THPUtils_tryUnpackLongVarArgs(args, $idx, __size)')
RETURN_WRAPPER = { RETURN_WRAPPER = {
'THTensor*': Template('return THPTensor_(New)($result);'), 'THTensor*': Template('return THPTensor_(New)($result);'),
'THSTensor*': Template('return THSPTensor_(New)($result);'), 'THSTensor*': Template('return THSPTensor_(New)($result);'),
'THLongTensor*': Template('return THPLongTensor_New($result);'), 'THLongTensor*': Template('return THPLongTensor_New($result);'),
'THLongStorage*': Template('return THPLongStorage_New($result);'), 'THLongStorage*': Template('return THPLongStorage_New($result);'),
# TODO: make it smarter - it should return python long if result doesn't fit into an int # TODO: make it smarter - it should return python long if result doesn't fit into an int
'long': Template('return PyInt_FromLong($result);'), 'long': Template('return PyInt_FromLong($result);'),
'accreal': Template('return THPUtils_(newAccreal)($result);'), 'accreal': Template('return THPUtils_(newAccreal)($result);'),
'self': Template('Py_INCREF(self);\nreturn (PyObject*)self;'), 'self': Template('Py_INCREF(self);\nreturn (PyObject*)self;'),
'real': Template('return THPUtils_(newReal)($result);'), 'real': Template('return THPUtils_(newReal)($result);'),
} }
TENSOR_METHODS_DECLARATION = Template(""" TENSOR_METHODS_DECLARATION = Template("""
@ -138,13 +144,13 @@ ${cpu}
return Template(code) return Template(code)
ALLOCATE_TYPE = { ALLOCATE_TYPE = {
'THTensor*': _allocate('', ALLOCATE_TMPL), 'THTensor*': _allocate('', ALLOCATE_TMPL),
'THLongTensor*': _allocate('Long', ALLOCATE_TMPL), 'THLongTensor*': _allocate('Long', ALLOCATE_TMPL),
'THIntTensor*': _allocate('Int', ALLOCATE_TMPL), 'THIntTensor*': _allocate('Int', ALLOCATE_TMPL),
'THBoolTensor*': _allocate('Byte', ALLOCATE_TMPL, ALLOCATE_CUDA), 'THBoolTensor*': _allocate('Byte', ALLOCATE_TMPL, ALLOCATE_CUDA),
'THIndexTensor*': _allocate('Long', ALLOCATE_TMPL, ALLOCATE_CUDA), 'THIndexTensor*': _allocate('Long', ALLOCATE_TMPL, ALLOCATE_CUDA),
'THSTensor*': _allocate('', ALLOCATE_TMPL, sparse=True), 'THSTensor*': _allocate('', ALLOCATE_TMPL, sparse=True),
} }
TYPE_NAMES = { TYPE_NAMES = {
@ -159,6 +165,8 @@ ${cpu}
'THIndexTensor*': '" THPModuleStr "LongTensor', 'THIndexTensor*': '" THPModuleStr "LongTensor',
'THFloatTensor*': '" THPModuleStr "FloatTensor', 'THFloatTensor*': '" THPModuleStr "FloatTensor',
'THDoubleTensor*': '" THPModuleStr "DoubleTensor', 'THDoubleTensor*': '" THPModuleStr "DoubleTensor',
'THCudaTensor*': 'torch.cuda.FloatTensor',
'THCudaDoubleTensor*': 'torch.cuda.DoubleTensor',
'THSize*': 'torch.Size', 'THSize*': 'torch.Size',
'THStride*': 'tuple', 'THStride*': 'tuple',
'long': 'int', 'long': 'int',
@ -198,14 +206,14 @@ ${cpu}
def format_args(args, var_args=False): def format_args(args, var_args=False):
option_desc = [format_arg(arg, var_args) option_desc = [format_arg(arg, var_args)
for arg in args for arg in args
if not arg.get('ignore_check', False) if not arg.get('ignore_check', False) and
and not arg.get('output')] not arg.get('output')]
output_args = list(filter(lambda a: a.get('output'), args)) output_args = list(filter(lambda a: a.get('output'), args))
if output_args: if output_args:
if len(output_args) > 1: if len(output_args) > 1:
out_type = 'tuple[' out_type = 'tuple['
out_type += ', '.join( out_type += ', '.join(
self.TYPE_NAMES[arg['type']] for arg in output_args) self.TYPE_NAMES[arg['type']] for arg in output_args)
out_type += ']' out_type += ']'
option_desc += ['#' + out_type + ' out'] option_desc += ['#' + out_type + ' out']
else: else:
@ -287,7 +295,7 @@ ${cpu}
if not output_provided: if not output_provided:
arg['ignore_check'] = True arg['ignore_check'] = True
else: else:
option_copy['argcount_offset'] = -len(out_idx) + 1 option_copy['argcount_offset'] = -len(out_idx) + 1
arg['no_kwargs'] = True arg['no_kwargs'] = True
arg['no_idx'] = True arg['no_idx'] = True
new_options.append(option_copy) new_options.append(option_copy)
@ -345,7 +353,6 @@ ${cpu}
if arg['name'] == 'self': if arg['name'] == 'self':
arg['ignore_check'] = True arg['ignore_check'] = True
declarations = [d for d in declarations if not d.get('only_stateless', False)] declarations = [d for d in declarations if not d.get('only_stateless', False)]
self.declarations.extend(filter(lambda x: not x.get('only_stateless', False), register_only)) self.declarations.extend(filter(lambda x: not x.get('only_stateless', False), register_only))
self.stateless_declarations.extend(filter(lambda x: x.get('only_stateless', False), register_only)) self.stateless_declarations.extend(filter(lambda x: x.get('only_stateless', False), register_only))
@ -377,9 +384,9 @@ ${cpu}
if declaration.get('override_method_flags'): if declaration.get('override_method_flags'):
flags = declaration['override_method_flags'] flags = declaration['override_method_flags']
entry = Template(' {"$python_name", (PyCFunction)$name, $flags, $docstring},\n').substitute( entry = Template(' {"$python_name", (PyCFunction)$name, $flags, $docstring},\n').substitute(
python_name=declaration['python_name'], name=declaration['name'], flags=flags, python_name=declaration['python_name'], name=declaration['name'], flags=flags,
docstring=declaration.get('docstring_var', 'NULL') docstring=declaration.get('docstring_var', 'NULL')
) )
if 'defined_if' in declaration: if 'defined_if' in declaration:
entry = self.preprocessor_guard(entry, declaration['defined_if']) entry = self.preprocessor_guard(entry, declaration['defined_if'])
tensor_methods += entry tensor_methods += entry
@ -392,16 +399,16 @@ ${cpu}
def process_full_file(self, code): def process_full_file(self, code):
# We have to find a place before all undefs # We have to find a place before all undefs
idx = code.find('// PUT DEFINITIONS IN HERE PLEASE') idx = code.find('// PUT DEFINITIONS IN HERE PLEASE')
return (code[:idx] return (code[:idx] +
+ self.declare_methods(False, False) self.declare_methods(False, False) +
+ self.declare_methods(True, False) self.declare_methods(True, False) +
+ self.declare_methods(False, True) self.declare_methods(False, True) +
+ self.declare_methods(True, True) self.declare_methods(True, True) +
+ code[idx:] code[idx:]
) )
def preprocessor_guard(self, code, condition): def preprocessor_guard(self, code, condition):
return '#if ' + condition + '\n' + code + '#endif\n' return '#if ' + condition + '\n' + code + '#endif\n'
def process_wrapper(self, code, declaration): def process_wrapper(self, code, declaration):
if 'defined_if' in declaration: if 'defined_if' in declaration:
@ -419,7 +426,7 @@ ${cpu}
if option['output_count'] > 1: if option['output_count'] > 1:
checks += "PyTuple_Check(__out) &&\n" + indent checks += "PyTuple_Check(__out) &&\n" + indent
length_check = "PyTuple_GET_SIZE(__out) == {} &&\n".format( length_check = "PyTuple_GET_SIZE(__out) == {} &&\n".format(
option['output_count']) option['output_count'])
checks += length_check + indent checks += length_check + indent
code = checks + code code = checks + code
else: else:
@ -443,13 +450,13 @@ ${cpu}
def generate_docstrings_cpp(self): def generate_docstrings_cpp(self):
template = Template('char* $name = "$content";') template = Template('char* $name = "$content";')
return '\n\n'.join( return '\n\n'.join(
template.substitute(name=decl['docstring_var'], content=decl['docstring_content']) template.substitute(name=decl['docstring_var'], content=decl['docstring_content'])
for decl in chain(self.declarations, self.stateless_declarations) for decl in chain(self.declarations, self.stateless_declarations)
if 'docstring_var' in decl) if 'docstring_var' in decl)
def generate_docstrings_h(self): def generate_docstrings_h(self):
template = Template('extern char* $name;') template = Template('extern char* $name;')
return '\n\n'.join( return '\n\n'.join(
template.substitute(name=decl['docstring_var']) template.substitute(name=decl['docstring_var'])
for decl in chain(self.declarations, self.stateless_declarations) for decl in chain(self.declarations, self.stateless_declarations)
if 'docstring_var' in decl) if 'docstring_var' in decl)


@ -58,3 +58,4 @@ from .ReturnArguments import ReturnArguments
from .GILRelease import GILRelease from .GILRelease import GILRelease
from .AutoGPU import AutoGPU from .AutoGPU import AutoGPU
from .CuDNNPlugin import CuDNNPlugin from .CuDNNPlugin import CuDNNPlugin
from .GenericNN import GenericNN


@ -0,0 +1,40 @@
FROM nvidia/cuda:8.0-devel-ubuntu14.04
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
curl \
ca-certificates \
libjpeg-dev \
libpng-dev &&\
rm -rf /var/lib/apt/lists/*
RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v6.0/cudnn-8.0-linux-x64-v6.0-rc.tgz -O && \
tar -xzf cudnn-8.0-linux-x64-v6.0-rc.tgz -C /usr/local && \
rm cudnn-8.0-linux-x64-v6.0-rc.tgz && \
ldconfig
RUN ln -s /usr/local/cuda/lib64/libcudnn.so.6.0.5 /usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.5
RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-4.2.12-Linux-x86_64.sh && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh && \
/opt/conda/bin/conda install conda-build && \
/opt/conda/bin/conda create -y --name pytorch-py35 python=3.5.2 numpy scipy ipython mkl&& \
/opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/envs/pytorch-py35/bin:$PATH
RUN conda install --name pytorch-py35 -c soumith magma-cuda80
# This must be done before pip so that requirements.txt is available
WORKDIR /opt/pytorch
COPY . .
RUN cat requirements.txt | xargs -n1 pip install --no-cache-dir && \
TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
CMAKE_LIBRARY_PATH=/opt/conda/envs/pytorch-py35/lib \
CMAKE_INCLUDE_PATH=/opt/conda/envs/pytorch-py35/include \
pip install -v .
WORKDIR /workspace
RUN chmod -R a+w /workspace


@ -2,12 +2,13 @@ import os
import sys import sys
from string import Template, ascii_lowercase from string import Template, ascii_lowercase
from ..cwrap import cwrap from ..cwrap import cwrap
from ..cwrap.plugins import StandaloneExtension, NullableArguments, AutoGPU from ..cwrap.plugins import StandaloneExtension, GenericNN, NullableArguments, AutoGPU
BASE_PATH = os.path.realpath(os.path.join(__file__, '..', '..', '..')) BASE_PATH = os.path.realpath(os.path.join(__file__, '..', '..', '..'))
WRAPPER_PATH = os.path.join(BASE_PATH, 'torch', 'csrc', 'nn') WRAPPER_PATH = os.path.join(BASE_PATH, 'torch', 'csrc', 'nn')
THNN_UTILS_PATH = os.path.join(BASE_PATH, 'torch', '_thnn', 'utils.py') THNN_UTILS_PATH = os.path.join(BASE_PATH, 'torch', '_thnn', 'utils.py')
def import_module(name, path): def import_module(name, path):
if sys.version_info >= (3, 5): if sys.version_info >= (3, 5):
import importlib.util import importlib.util
@ -81,7 +82,8 @@ for t in ['CudaHalf', 'Cuda', 'CudaDouble']:
def wrap_function(name, type, arguments): def wrap_function(name, type, arguments):
cname = 'THNN_' + type + name cname = 'THNN_' + type + name
declaration = '' declaration = ''
declaration += 'extern "C" void ' + cname + '(' + ', '.join(TYPE_TRANSFORMS[type].get(arg.type, arg.type) for arg in arguments) + ');\n' declaration += 'extern "C" void ' + cname + \
'(' + ', '.join(TYPE_TRANSFORMS[type].get(arg.type, arg.type) for arg in arguments) + ');\n'
declaration += FUNCTION_TEMPLATE.substitute(name=type + name, cname=cname) declaration += FUNCTION_TEMPLATE.substitute(name=type + name, cname=cname)
indent = ' ' * 4 indent = ' ' * 4
dict_indent = ' ' * 6 dict_indent = ' ' * 6
@ -91,15 +93,18 @@ def wrap_function(name, type, arguments):
declaration += prefix + TYPE_TRANSFORMS[type].get(arg.type, arg.type) + ' ' + arg.name + '\n' declaration += prefix + TYPE_TRANSFORMS[type].get(arg.type, arg.type) + ' ' + arg.name + '\n'
else: else:
t = TYPE_TRANSFORMS[type].get(arg.type, arg.type) t = TYPE_TRANSFORMS[type].get(arg.type, arg.type)
declaration += prefix + 'type: ' + t + '\n' + \ declaration += prefix + 'type: ' + t + '\n' + \
dict_indent + 'name: ' + arg.name + '\n' + \ dict_indent + 'name: ' + arg.name + '\n' + \
dict_indent + 'nullable: True' + '\n' dict_indent + 'nullable: True' + '\n'
declaration += ']]\n\n\n' declaration += ']]\n\n\n'
return declaration return declaration
def generate_wrappers(): def generate_wrappers():
wrap_nn() wrap_nn()
wrap_cunn() wrap_cunn()
wrap_generic()
def wrap_nn(): def wrap_nn():
wrapper = '#include <TH/TH.h>\n\n\n' wrapper = '#include <TH/TH.h>\n\n\n'
@ -114,6 +119,7 @@ def wrap_nn():
NullableArguments(), NullableArguments(),
]) ])
def wrap_cunn(): def wrap_cunn():
wrapper = '#include <TH/TH.h>\n' wrapper = '#include <TH/TH.h>\n'
wrapper += '#include <THC/THC.h>\n\n\n' wrapper += '#include <THC/THC.h>\n\n\n'
@ -128,3 +134,66 @@ def wrap_cunn():
NullableArguments(), NullableArguments(),
AutoGPU(has_self=False), AutoGPU(has_self=False),
]) ])
GENERIC_FUNCTION_TEMPLATE = Template("""\
[[
name: $name
return: void
options:
""")
def wrap_generic_function(name, backends):
declaration = ''
declaration += GENERIC_FUNCTION_TEMPLATE.substitute(name=name)
for backend in backends:
declaration += ' - cname: ' + name + '\n'
declaration += ' backend: ' + backend['name'] + '\n'
declaration += ' arguments:\n'
for arg in backend['arguments']:
declaration += ' - arg: ' + arg.type + ' ' + arg.name + '\n'
if arg.is_optional:
declaration += ' optional: True\n'
declaration += ']]\n\n\n'
return declaration
def wrap_generic():
from collections import OrderedDict
defs = OrderedDict()
def should_wrap_function(name):
if name.startswith('LookupTable'):
return False
return (name.endswith('updateOutput') or
name.endswith('updateGradInput') or
name.endswith('accGradParameters') or
name.endswith('backward'))
def add_functions(name, functions):
for fn in functions:
if not should_wrap_function(fn.name):
continue
if fn.name not in defs:
defs[fn.name] = []
defs[fn.name] += [{
'name': name,
'arguments': fn.arguments[1:],
}]
add_functions('nn', thnn_utils.parse_header(thnn_utils.THNN_H_PATH))
add_functions('cunn', thnn_utils.parse_header(thnn_utils.THCUNN_H_PATH))
wrapper = ''
for name, backends in defs.items():
wrapper += wrap_generic_function(name, backends)
with open('torch/csrc/nn/THNN_generic.cwrap', 'w') as f:
f.write(wrapper)
cwrap('torch/csrc/nn/THNN_generic.cwrap', plugins=[
GenericNN(header=True),
], default_plugins=False, destination='torch/csrc/nn/THNN_generic.h')
cwrap('torch/csrc/nn/THNN_generic.cwrap', plugins=[
GenericNN(),
], default_plugins=False)
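
A standalone sketch of the name filter wrap_generic applies when scanning the THNN/THCUNN headers; the sample names below are illustrative inputs, not a dump of THNN.h.

def should_wrap_function(name):
    if name.startswith('LookupTable'):
        return False
    return (name.endswith('updateOutput') or
            name.endswith('updateGradInput') or
            name.endswith('accGradParameters') or
            name.endswith('backward'))

for name in ['Threshold_updateOutput',
             'SpatialConvolutionMM_accGradParameters',
             'LookupTable_updateOutput',   # skipped: LookupTable prefix
             'unfolded_copy']:             # skipped: helper, no matching suffix
    print(name, should_wrap_function(name))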

View File

@ -1,8 +1,17 @@
import ctypes.util
import os import os
from .env import check_env_flag from .env import check_env_flag
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda') if check_env_flag('NO_CUDA'):
WITH_CUDA = not check_env_flag('NO_CUDA') and os.path.exists(CUDA_HOME) WITH_CUDA = False
if not WITH_CUDA:
CUDA_HOME = None CUDA_HOME = None
else:
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
if not os.path.exists(CUDA_HOME):
cudart_path = ctypes.util.find_library('cudart')
if cudart_path is not None:
CUDA_HOME = os.path.dirname(cudart_path)
else:
CUDA_HOME = None
WITH_CUDA = CUDA_HOME is not None
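
The new fallback leans on ctypes.util.find_library, which only asks the system linker where the CUDA runtime lives and never loads it. A small illustration (note that on Linux it may return just a soname such as 'libcudart.so.8.0' rather than an absolute path, in which case os.path.dirname gives an empty string):

import ctypes.util
import os

name = ctypes.util.find_library('cudart')
print(name)                                   # None if no CUDA runtime is visible
if name is not None:
    print(os.path.dirname(name) or '(no directory component)')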

View File

@ -1,9 +1,15 @@
import os import os
import glob import glob
from itertools import chain
from .env import check_env_flag from .env import check_env_flag
from .cuda import WITH_CUDA, CUDA_HOME from .cuda import WITH_CUDA, CUDA_HOME
def gather_paths(env_vars):
return list(chain(*(os.getenv(v, '').split(':') for v in env_vars)))
WITH_CUDNN = False WITH_CUDNN = False
CUDNN_LIB_DIR = None CUDNN_LIB_DIR = None
CUDNN_INCLUDE_DIR = None CUDNN_INCLUDE_DIR = None
@ -12,13 +18,19 @@ if WITH_CUDA and not check_env_flag('NO_CUDNN'):
os.getenv('CUDNN_LIB_DIR'), os.getenv('CUDNN_LIB_DIR'),
os.path.join(CUDA_HOME, 'lib'), os.path.join(CUDA_HOME, 'lib'),
os.path.join(CUDA_HOME, 'lib64'), os.path.join(CUDA_HOME, 'lib64'),
'/usr/lib/x86_64-linux-gnu/', '/usr/lib/x86_64-linux-gnu/',
])) ] + gather_paths([
'LIBRARY_PATH',
])))
include_paths = list(filter(bool, [ include_paths = list(filter(bool, [
os.getenv('CUDNN_INCLUDE_DIR'), os.getenv('CUDNN_INCLUDE_DIR'),
os.path.join(CUDA_HOME, 'include'), os.path.join(CUDA_HOME, 'include'),
'/usr/include/' '/usr/include/',
])) ] + gather_paths([
'CPATH',
'C_INCLUDE_PATH',
'CPLUS_INCLUDE_PATH',
])))
for path in lib_paths: for path in lib_paths:
if path is None or not os.path.exists(path): if path is None or not os.path.exists(path):
continue continue
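
gather_paths simply splits each environment variable on ':' and concatenates the pieces; the surrounding filter(bool, ...) then drops the empty strings that unset variables produce. A quick illustration with a made-up value:

import os
from itertools import chain

def gather_paths(env_vars):
    return list(chain(*(os.getenv(v, '').split(':') for v in env_vars)))

os.environ['LIBRARY_PATH'] = '/opt/cudnn/lib64:/usr/local/lib'   # example value only
print(gather_paths(['LIBRARY_PATH', 'UNSET_VAR']))
# ['/opt/cudnn/lib64', '/usr/local/lib', '']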

View File

@ -1,4 +1,5 @@
import os import os
def check_env_flag(name): def check_env_flag(name):
return os.getenv(name) in ['ON', '1', 'YES', 'TRUE', 'Y'] return os.getenv(name) in ['ON', '1', 'YES', 'TRUE', 'Y']
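
check_env_flag only recognizes the exact strings listed; matching is case-sensitive, so for example NO_CUDA=true is not treated as "on":

import os

def check_env_flag(name):
    return os.getenv(name) in ['ON', '1', 'YES', 'TRUE', 'Y']

os.environ['NO_CUDA'] = '1'
print(check_env_flag('NO_CUDA'))   # True
os.environ['NO_CUDA'] = 'true'
print(check_env_flag('NO_CUDA'))   # False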

View File

@ -56,6 +56,7 @@ del old_flags
# Define basic utilities # Define basic utilities
################################################################################ ################################################################################
def typename(o): def typename(o):
module = '' module = ''
class_name = '' class_name = ''
@ -91,7 +92,7 @@ def set_default_tensor_type(t):
def set_rng_state(new_state): def set_rng_state(new_state):
r"""Sets the random number generator state. r"""Sets the random number generator state.
Args: Args:
new_state (torch.ByteTensor): The desired state new_state (torch.ByteTensor): The desired state
""" """
@ -104,9 +105,9 @@ def get_rng_state():
def manual_seed(seed): def manual_seed(seed):
r"""Sets the seed for generating random numbers. And returns a r"""Sets the seed for generating random numbers. And returns a
`torch._C.Generator` object. `torch._C.Generator` object.
Args: Args:
seed (int or long): The desired seed. seed (int or long): The desired seed.
""" """
@ -114,7 +115,7 @@ def manual_seed(seed):
def initial_seed(): def initial_seed():
r"""Returns the initial seed for generating random numbers as a r"""Returns the initial seed for generating random numbers as a
python `long`. python `long`.
""" """
return default_generator.initial_seed() return default_generator.initial_seed()
@ -130,61 +131,101 @@ from ._tensor_str import set_printoptions
from .storage import _StorageBase from .storage import _StorageBase
from .tensor import _TensorBase from .tensor import _TensorBase
class DoubleStorage(_C.DoubleStorageBase, _StorageBase): class DoubleStorage(_C.DoubleStorageBase, _StorageBase):
pass pass
class FloatStorage(_C.FloatStorageBase, _StorageBase): class FloatStorage(_C.FloatStorageBase, _StorageBase):
pass pass
class LongStorage(_C.LongStorageBase, _StorageBase): class LongStorage(_C.LongStorageBase, _StorageBase):
pass pass
class IntStorage(_C.IntStorageBase, _StorageBase): class IntStorage(_C.IntStorageBase, _StorageBase):
pass pass
class ShortStorage(_C.ShortStorageBase, _StorageBase): class ShortStorage(_C.ShortStorageBase, _StorageBase):
pass pass
class CharStorage(_C.CharStorageBase, _StorageBase): class CharStorage(_C.CharStorageBase, _StorageBase):
pass pass
class ByteStorage(_C.ByteStorageBase, _StorageBase): class ByteStorage(_C.ByteStorageBase, _StorageBase):
pass pass
class DoubleTensor(_C.DoubleTensorBase, _TensorBase): class DoubleTensor(_C.DoubleTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return True return True
@classmethod @classmethod
def storage_type(cls): def storage_type(cls):
return DoubleStorage return DoubleStorage
class FloatTensor(_C.FloatTensorBase, _TensorBase): class FloatTensor(_C.FloatTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return True return True
@classmethod @classmethod
def storage_type(cls): def storage_type(cls):
return FloatStorage return FloatStorage
class LongTensor(_C.LongTensorBase, _TensorBase): class LongTensor(_C.LongTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return True return True
@classmethod @classmethod
def storage_type(cls): def storage_type(cls):
return LongStorage return LongStorage
class IntTensor(_C.IntTensorBase, _TensorBase): class IntTensor(_C.IntTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return True return True
@classmethod @classmethod
def storage_type(cls): def storage_type(cls):
return IntStorage return IntStorage
class ShortTensor(_C.ShortTensorBase, _TensorBase): class ShortTensor(_C.ShortTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return True return True
@classmethod @classmethod
def storage_type(cls): def storage_type(cls):
return ShortStorage return ShortStorage
class CharTensor(_C.CharTensorBase, _TensorBase): class CharTensor(_C.CharTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
# TODO # TODO
return False return False
@classmethod @classmethod
def storage_type(cls): def storage_type(cls):
return CharStorage return CharStorage
class ByteTensor(_C.ByteTensorBase, _TensorBase): class ByteTensor(_C.ByteTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return False return False
@classmethod @classmethod
def storage_type(cls): def storage_type(cls):
return ByteStorage return ByteStorage
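
A small reproducibility sketch using the RNG helpers documented above; manual_seed and get/set_rng_state act on the default CPU generator:

import torch

torch.manual_seed(123)
a = torch.rand(3)
state = torch.get_rng_state()    # ByteTensor snapshot of the generator
b = torch.rand(3)
torch.set_rng_state(state)       # rewind the generator
c = torch.rand(3)
print(torch.equal(b, c))         # True: the same numbers are drawn again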

File diff suppressed because it is too large

View File

@ -22,7 +22,7 @@ def set_printoptions(
edgeitems=None, edgeitems=None,
linewidth=None, linewidth=None,
profile=None, profile=None,
): ):
"""Set options for printing. Items shamelessly taken from Numpy """Set options for printing. Items shamelessly taken from Numpy
Args: Args:
@ -119,7 +119,7 @@ def _number_format(tensor, min_sz=-1):
else: else:
if exp_max > prec + 1 or exp_max < 0: if exp_max > prec + 1 or exp_max < 0:
sz = max(min_sz, 7) sz = max(min_sz, 7)
scale = math.pow(10, exp_max-1) scale = math.pow(10, exp_max - 1)
else: else:
if exp_max == 0: if exp_max == 0:
sz = 7 sz = 7
@ -132,19 +132,19 @@ def _number_format(tensor, min_sz=-1):
def _tensor_str(self): def _tensor_str(self):
n = PRINT_OPTS.edgeitems n = PRINT_OPTS.edgeitems
has_hdots = self.size()[-1] > 2*n has_hdots = self.size()[-1] > 2 * n
has_vdots = self.size()[-2] > 2*n has_vdots = self.size()[-2] > 2 * n
print_full_mat = not has_hdots and not has_vdots print_full_mat = not has_hdots and not has_vdots
formatter = _number_format(self, min_sz=3 if not print_full_mat else 0) formatter = _number_format(self, min_sz=3 if not print_full_mat else 0)
print_dots = self.numel() >= PRINT_OPTS.threshold print_dots = self.numel() >= PRINT_OPTS.threshold
dim_sz = max(2, max(len(str(x)) for x in self.size())) dim_sz = max(2, max(len(str(x)) for x in self.size()))
dim_fmt = "{:^" + str(dim_sz) + "}" dim_fmt = "{:^" + str(dim_sz) + "}"
dot_fmt = u"{:^" + str(dim_sz+1) + "}" dot_fmt = u"{:^" + str(dim_sz + 1) + "}"
counter_dim = self.ndimension() - 2 counter_dim = self.ndimension() - 2
counter = torch.LongStorage(counter_dim).fill_(0) counter = torch.LongStorage(counter_dim).fill_(0)
counter[counter.size()-1] = -1 counter[counter.size() - 1] = -1
finished = False finished = False
strt = '' strt = ''
while True: while True:
@ -152,7 +152,7 @@ def _tensor_str(self):
nskipped = [False for i in counter] nskipped = [False for i in counter]
for i in _range(counter_dim - 1, -1, -1): for i in _range(counter_dim - 1, -1, -1):
counter[i] += 1 counter[i] += 1
if print_dots and counter[i] == n and self.size(i) > 2*n: if print_dots and counter[i] == n and self.size(i) > 2 * n:
counter[i] = self.size(i) - n counter[i] = self.size(i) - n
nskipped[i] = True nskipped[i] = True
if counter[i] == self.size(i): if counter[i] == self.size(i):
@ -188,18 +188,18 @@ def __repr_row(row, indent, fmt, scale, sz, truncate=None):
if truncate is not None: if truncate is not None:
dotfmt = " {:^5} " dotfmt = " {:^5} "
return (indent + return (indent +
' '.join(fmt.format(val/scale) for val in row[:truncate]) + ' '.join(fmt.format(val / scale) for val in row[:truncate]) +
dotfmt.format('...') + dotfmt.format('...') +
' '.join(fmt.format(val/scale) for val in row[-truncate:]) + ' '.join(fmt.format(val / scale) for val in row[-truncate:]) +
'\n') '\n')
else: else:
return indent + ' '.join(fmt.format(val/scale) for val in row) + '\n' return indent + ' '.join(fmt.format(val / scale) for val in row) + '\n'
def _matrix_str(self, indent='', formatter=None, force_truncate=False): def _matrix_str(self, indent='', formatter=None, force_truncate=False):
n = PRINT_OPTS.edgeitems n = PRINT_OPTS.edgeitems
has_hdots = self.size(1) > 2*n has_hdots = self.size(1) > 2 * n
has_vdots = self.size(0) > 2*n has_vdots = self.size(0) > 2 * n
print_full_mat = not has_hdots and not has_vdots print_full_mat = not has_hdots and not has_vdots
if formatter is None: if formatter is None:
@ -207,14 +207,14 @@ def _matrix_str(self, indent='', formatter=None, force_truncate=False):
min_sz=5 if not print_full_mat else 0) min_sz=5 if not print_full_mat else 0)
else: else:
fmt, scale, sz = formatter fmt, scale, sz = formatter
nColumnPerLine = int(math.floor((PRINT_OPTS.linewidth-len(indent))/(sz+1))) nColumnPerLine = int(math.floor((PRINT_OPTS.linewidth - len(indent)) / (sz + 1)))
strt = '' strt = ''
firstColumn = 0 firstColumn = 0
if not force_truncate and \ if not force_truncate and \
(self.numel() < PRINT_OPTS.threshold or print_full_mat): (self.numel() < PRINT_OPTS.threshold or print_full_mat):
while firstColumn < self.size(1): while firstColumn < self.size(1):
lastColumn = min(firstColumn + nColumnPerLine - 1, self.size(1)-1) lastColumn = min(firstColumn + nColumnPerLine - 1, self.size(1) - 1)
if nColumnPerLine < self.size(1): if nColumnPerLine < self.size(1):
strt += '\n' if firstColumn != 1 else '' strt += '\n' if firstColumn != 1 else ''
strt += 'Columns {} to {} \n{}'.format( strt += 'Columns {} to {} \n{}'.format(
@ -223,15 +223,15 @@ def _matrix_str(self, indent='', formatter=None, force_truncate=False):
strt += SCALE_FORMAT.format(scale) strt += SCALE_FORMAT.format(scale)
for l in _range(self.size(0)): for l in _range(self.size(0)):
strt += indent + (' ' if scale != 1 else '') strt += indent + (' ' if scale != 1 else '')
row_slice = self[l, firstColumn:lastColumn+1] row_slice = self[l, firstColumn:lastColumn + 1]
strt += ' '.join(fmt.format(val/scale) for val in row_slice) strt += ' '.join(fmt.format(val / scale) for val in row_slice)
strt += '\n' strt += '\n'
firstColumn = lastColumn + 1 firstColumn = lastColumn + 1
else: else:
if scale != 1: if scale != 1:
strt += SCALE_FORMAT.format(scale) strt += SCALE_FORMAT.format(scale)
if has_vdots and has_hdots: if has_vdots and has_hdots:
vdotfmt = "{:^" + str((sz+1)*n-1) + "}" vdotfmt = "{:^" + str((sz + 1) * n - 1) + "}"
ddotfmt = u"{:^5}" ddotfmt = u"{:^5}"
for row in self[:n]: for row in self[:n]:
strt += __repr_row(row, indent, fmt, scale, sz, n) strt += __repr_row(row, indent, fmt, scale, sz, n)
@ -245,8 +245,8 @@ def _matrix_str(self, indent='', formatter=None, force_truncate=False):
strt += __repr_row(row, indent, fmt, scale, sz, n) strt += __repr_row(row, indent, fmt, scale, sz, n)
elif has_vdots and not has_hdots: elif has_vdots and not has_hdots:
vdotfmt = u"{:^" + \ vdotfmt = u"{:^" + \
str(len(__repr_row(self[0], '', fmt, scale, sz))) + \ str(len(__repr_row(self[0], '', fmt, scale, sz))) + \
"}\n" "}\n"
for row in self[:n]: for row in self[:n]:
strt += __repr_row(row, indent, fmt, scale, sz) strt += __repr_row(row, indent, fmt, scale, sz)
strt += vdotfmt.format(u'\u22EE') strt += vdotfmt.format(u'\u22EE')
@ -269,13 +269,13 @@ def _vector_str(self):
ident = ' ' ident = ' '
if self.numel() < PRINT_OPTS.threshold: if self.numel() < PRINT_OPTS.threshold:
return (strt + return (strt +
'\n'.join(ident + fmt.format(val/scale) for val in self) + '\n'.join(ident + fmt.format(val / scale) for val in self) +
'\n') '\n')
else: else:
return (strt + return (strt +
'\n'.join(ident + fmt.format(val/scale) for val in self[:n]) + '\n'.join(ident + fmt.format(val / scale) for val in self[:n]) +
'\n' + (ident + dotfmt.format(u"\u22EE")) + '\n' + (ident + dotfmt.format(u"\u22EE")) +
'\n'.join(ident + fmt.format(val/scale) for val in self[-n:]) + '\n'.join(ident + fmt.format(val / scale) for val in self[-n:]) +
'\n') '\n')
@ -295,4 +295,3 @@ def _str(self):
strt += '[{} of size {}{}]\n'.format(torch.typename(self), strt += '[{} of size {}{}]\n'.format(torch.typename(self),
size_str, device_str) size_str, device_str)
return '\n' + strt return '\n' + strt
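
Usage sketch for the print options this file implements. The hunk above only shows edgeitems, linewidth and profile; precision, threshold and the 'default' profile name are assumed from the full signature, which is elided here.

import torch

torch.set_printoptions(precision=2, threshold=10, edgeitems=2, linewidth=80)
print(torch.randn(100))                      # summarized with '...' once numel >= threshold
torch.set_printoptions(profile='default')    # assumed way to restore the defaults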

View File

@ -2,7 +2,9 @@ import threading
import torch.cuda import torch.cuda
from .utils import THNN_H_PATH, THCUNN_H_PATH, parse_header, load_backend from .utils import THNN_H_PATH, THCUNN_H_PATH, parse_header, load_backend
class Backends(object): class Backends(object):
def __init__(self): def __init__(self):
self.backends = {} self.backends = {}
@ -14,6 +16,7 @@ class Backends(object):
class Backend(object): class Backend(object):
def __init__(self, lib_prefix, lib_name, functions, mixins=tuple()): def __init__(self, lib_prefix, lib_name, functions, mixins=tuple()):
self.lib_prefix = lib_prefix self.lib_prefix = lib_prefix
self.lib_name = lib_name self.lib_name = lib_name
@ -32,11 +35,12 @@ class Backend(object):
with self.loading_lock: with self.loading_lock:
if self.backend is None: if self.backend is None:
self.backend = load_backend(self.lib_prefix, self.lib_name, self.backend = load_backend(self.lib_prefix, self.lib_name,
self.functions, self.mixins) self.functions, self.mixins)
return self.backend return self.backend
class THNNCudaBackendStateMixin(object): class THNNCudaBackendStateMixin(object):
@property @property
def library_state(self): def library_state(self):
return torch.cuda._state_cdata return torch.cuda._state_cdata
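
The Backend class defers load_backend until first use and guards it with a lock. A generic sketch of that double-checked lazy-loading pattern (the names here are illustrative, not the torch._thnn API):

import threading

class LazyBackend(object):
    def __init__(self, loader):
        self._loader = loader
        self._lock = threading.Lock()
        self._backend = None

    def load(self):
        if self._backend is None:            # fast path, no lock taken
            with self._lock:
                if self._backend is None:    # re-check under the lock
                    self._backend = self._loader()
        return self._backend

backend = LazyBackend(lambda: {'Threshold_updateOutput': object()})
print(backend.load() is backend.load())      # True: loaded exactly once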

View File

@ -12,6 +12,7 @@ def _unpickle_backend(backend_name):
class THNNBackendBase(object): class THNNBackendBase(object):
def __init__(self): def __init__(self):
self.methods = {} self.methods = {}
@ -33,6 +34,7 @@ class THNNBackendBase(object):
class Function(object): class Function(object):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.arguments = [] self.arguments = []
@ -46,6 +48,7 @@ class Function(object):
class Argument(object): class Argument(object):
def __init__(self, _type, name, is_optional): def __init__(self, _type, name, is_optional):
self.type = _type self.type = _type
self.name = name self.name = name

File diff suppressed because it is too large

View File

@ -12,6 +12,7 @@ from .stochastic_function import StochasticFunction
__all__ = ['Variable', 'Function', 'StochasticFunction', 'backward'] __all__ = ['Variable', 'Function', 'StochasticFunction', 'backward']
def backward(variables, grad_variables, retain_variables=False): def backward(variables, grad_variables, retain_variables=False):
"""Computes the sum of gradients of given variables w.r.t. graph leaves. """Computes the sum of gradients of given variables w.r.t. graph leaves.
@ -28,7 +29,7 @@ def backward(variables, grad_variables, retain_variables=False):
Arguments: Arguments:
variables (sequence of Variable): Variables of which the derivative will be variables (sequence of Variable): Variables of which the derivative will be
computed. computed.
grad_variables (sequence of Variable): Gradients w.r.t. each element of grad_variables (sequence of Tensor): Gradients w.r.t. each element of
corresponding variables. Required only for non-scalar variables that corresponding variables. Required only for non-scalar variables that
require gradient. require gradient.
retain_variables (bool): If ``True``, buffers necessary for computing retain_variables (bool): If ``True``, buffers necessary for computing
@ -37,6 +38,6 @@ def backward(variables, grad_variables, retain_variables=False):
times. times.
""" """
Variable._execution_engine.run_backward( Variable._execution_engine.run_backward(
tuple(variables), tuple(grad_variables), retain_variables) tuple(variables), tuple(grad_variables), retain_variables)
assert torch._C._autograd_init() assert torch._C._autograd_init()
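
Usage sketch for torch.autograd.backward with the signature above; for a non-scalar output the grad_variables entry supplies the upstream gradient. Assumes the 0.1.x-era Variable API shown throughout this compare.

import torch
from torch.autograd import Variable, backward

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x * 3
backward([y], [torch.ones(2, 2)])   # upstream gradient for the non-scalar output
print(x.grad)                       # filled with 3s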

View File

@ -5,4 +5,4 @@ from .reduce import *
from .linalg import * from .linalg import *
from .blas import * from .blas import *
from .stochastic import * from .stochastic import *
from .compare import *

View File

@ -59,7 +59,7 @@ class Pow(Function):
def backward(self, grad_output): def backward(self, grad_output):
a, b = self.saved_tensors a, b = self.saved_tensors
return grad_output.mul(b).mul_(a.pow(b-1)), grad_output.mul(a.pow(b)).mul_(a.log()) return grad_output.mul(b).mul_(a.pow(b - 1)), grad_output.mul(a.pow(b)).mul_(a.log())
class AddConstant(InplaceFunction): class AddConstant(InplaceFunction):
@ -174,7 +174,7 @@ class PowConstant(Function):
return grad_output.mul(self.fw_result).mul_(math.log(self.constant)) return grad_output.mul(self.fw_result).mul_(math.log(self.constant))
else: else:
a = self.saved_tensors[0] a = self.saved_tensors[0]
return grad_output.mul(self.constant).mul_(a.pow(self.constant-1)) return grad_output.mul(self.constant).mul_(a.pow(self.constant - 1))
class Negate(InplaceFunction): class Negate(InplaceFunction):
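
The Pow/PowConstant backward methods above implement d(a^b)/da = b * a^(b-1) and d(a^b)/db = a^b * ln(a). A plain-tensor finite-difference check of those formulas (no autograd involved):

import torch

a = torch.rand(5).add_(0.5)      # keep a > 0 so log(a) is defined
b = torch.rand(5)
go = torch.ones(5)               # stand-in for grad_output

grad_a = go.mul(b).mul(a.pow(b - 1))
grad_b = go.mul(a.pow(b)).mul(a.log())

eps = 1e-4
fd_a = (a.add(eps).pow(b) - a.pow(b)) / eps
fd_b = (a.pow(b.add(eps)) - a.pow(b)) / eps
print((grad_a - fd_a).abs().max(), (grad_b - fd_b).abs().max())   # both tiny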

View File

@ -25,7 +25,7 @@ class Addmm(_BlasBase):
self.save_for_backward(matrix1, matrix2) self.save_for_backward(matrix1, matrix2)
output = self._get_output(add_matrix) output = self._get_output(add_matrix)
return torch.addmm(self.alpha, add_matrix, self.beta, return torch.addmm(self.alpha, add_matrix, self.beta,
matrix1, matrix2, out=output) matrix1, matrix2, out=output)
def backward(self, grad_output): def backward(self, grad_output):
matrix1, matrix2 = self.saved_tensors matrix1, matrix2 = self.saved_tensors
@ -55,7 +55,7 @@ class Addbmm(_BlasBase):
self.save_for_backward(batch1, batch2) self.save_for_backward(batch1, batch2)
output = self._get_output(add_matrix) output = self._get_output(add_matrix)
return torch.addbmm(self.alpha, add_matrix, self.beta, return torch.addbmm(self.alpha, add_matrix, self.beta,
batch1, batch2, out=output) batch1, batch2, out=output)
def backward(self, grad_output): def backward(self, grad_output):
batch1, batch2 = self.saved_tensors batch1, batch2 = self.saved_tensors
@ -68,8 +68,8 @@ class Addbmm(_BlasBase):
if any(self.needs_input_grad[1:]): if any(self.needs_input_grad[1:]):
batch_grad_output = (grad_output batch_grad_output = (grad_output
.unsqueeze(0) .unsqueeze(0)
.expand(batch1.size(0), batch1.size(1), batch2.size(2))) .expand(batch1.size(0), batch1.size(1), batch2.size(2)))
if self.needs_input_grad[1]: if self.needs_input_grad[1]:
grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2)) grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
@ -90,7 +90,7 @@ class Baddbmm(_BlasBase):
self.save_for_backward(batch1, batch2) self.save_for_backward(batch1, batch2)
output = self._get_output(add_batch) output = self._get_output(add_batch)
return torch.baddbmm(self.alpha, add_batch, self.beta, return torch.baddbmm(self.alpha, add_batch, self.beta,
batch1, batch2, out=output) batch1, batch2, out=output)
def backward(self, grad_output): def backward(self, grad_output):
batch1, batch2 = self.saved_tensors batch1, batch2 = self.saved_tensors
@ -120,7 +120,7 @@ class Addmv(_BlasBase):
self.save_for_backward(matrix, vector) self.save_for_backward(matrix, vector)
output = self._get_output(add_vector) output = self._get_output(add_vector)
return torch.addmv(self.alpha, add_vector, self.beta, return torch.addmv(self.alpha, add_vector, self.beta,
matrix, vector, out=output) matrix, vector, out=output)
def backward(self, grad_output): def backward(self, grad_output):
matrix, vector = self.saved_tensors matrix, vector = self.saved_tensors
@ -150,7 +150,7 @@ class Addr(_BlasBase):
self.save_for_backward(vector1, vector2) self.save_for_backward(vector1, vector2)
output = self._get_output(add_matrix) output = self._get_output(add_matrix)
return torch.addr(self.alpha, add_matrix, self.beta, return torch.addr(self.alpha, add_matrix, self.beta,
vector1, vector2, out=output) vector1, vector2, out=output)
def backward(self, grad_output): def backward(self, grad_output):
vector1, vector2 = self.saved_tensors vector1, vector2 = self.saved_tensors
@ -199,4 +199,3 @@ class Dot(Function):
# TODO: trace # TODO: trace
# TODO: tril # TODO: tril
# TODO: triu # TODO: triu
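
The backward methods in this file lean on the standard matrix-product gradient identities: for C = A.mm(B), dL/dA = dL/dC.mm(B.t()) and dL/dB = A.t().mm(dL/dC). A plain-tensor shape check, illustrative only:

import torch

A = torch.randn(3, 4)
B = torch.randn(4, 5)
grad_C = torch.randn(3, 5)       # stand-in for the upstream gradient
grad_A = grad_C.mm(B.t())        # (3, 4), same shape as A
grad_B = A.t().mm(grad_C)        # (4, 5), same shape as B
print(grad_A.size(), grad_B.size())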

View File

@ -0,0 +1,40 @@
import torch
from ..function import Function
class _CompareOp(Function):
def __init__(self, scalar=None):
super(_CompareOp, self).__init__()
self.scalar = scalar
def forward(self, tensor1, tensor2=None):
other = tensor2 if tensor2 is not None else self.scalar
mask = getattr(tensor1, self.fn_name)(other)
self.mark_non_differentiable(mask)
return mask
class Eq(_CompareOp):
fn_name = 'eq'
class Ne(_CompareOp):
fn_name = 'ne'
class Gt(_CompareOp):
fn_name = 'gt'
class Ge(_CompareOp):
fn_name = 'ge'
class Lt(_CompareOp):
fn_name = 'lt'
class Le(_CompareOp):
fn_name = 'le'
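
Usage sketch for the new comparison functions, via the Variable.gt/eq wrappers added further down in this compare: the result is a non-differentiable mask (ByteTensor data) wrapped in a Variable.

import torch
from torch.autograd import Variable

a = Variable(torch.Tensor([1, 2, 3]))
b = Variable(torch.Tensor([3, 2, 1]))
print(a.gt(b).data)    # 0, 0, 1
print(a.eq(2).data)    # 0, 1, 0  (scalar comparison goes through Eq(other))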

View File

@ -42,4 +42,3 @@ class Triu(Function):
return grad_output.triu(self.diagonal_idx) return grad_output.triu(self.diagonal_idx)
# TODO: trace # TODO: trace

View File

@ -165,6 +165,7 @@ class Tan(Function):
class Asin(Function): class Asin(Function):
def forward(self, i): def forward(self, i):
self.save_for_backward(i) self.save_for_backward(i)
return i.asin() return i.asin()
@ -175,6 +176,7 @@ class Asin(Function):
class Acos(Function): class Acos(Function):
def forward(self, i): def forward(self, i):
self.save_for_backward(i) self.save_for_backward(i)
return i.acos() return i.acos()
@ -185,6 +187,7 @@ class Acos(Function):
class Atan(Function): class Atan(Function):
def forward(self, i): def forward(self, i):
self.save_for_backward(i) self.save_for_backward(i)
return i.atan() return i.atan()

View File

@ -4,6 +4,7 @@ from ..function import Function
class _DimReduceFunction(Function): class _DimReduceFunction(Function):
def __init__(self, dim=None): def __init__(self, dim=None):
super(_DimReduceFunction, self).__init__() super(_DimReduceFunction, self).__init__()
self.dim = dim self.dim = dim
@ -139,6 +140,7 @@ class Kthvalue(_SelectionFunction):
class Norm(Function): class Norm(Function):
def __init__(self, norm_type=2, dim=None): def __init__(self, norm_type=2, dim=None):
super(Norm, self).__init__() super(Norm, self).__init__()
self.norm_type = norm_type self.norm_type = norm_type

View File

@ -65,7 +65,7 @@ class Normal(StochasticFunction):
output.mul_(stddevs) output.mul_(stddevs)
else: else:
raise RuntimeError("Normal function requires specifying a common " raise RuntimeError("Normal function requires specifying a common "
"stddev, or per-sample stddev") "stddev, or per-sample stddev")
output.add_(means) output.add_(means)
self.save_for_backward(output, means, stddevs) self.save_for_backward(output, means, stddevs)
self.mark_non_differentiable(output) self.mark_non_differentiable(output)
@ -74,7 +74,7 @@ class Normal(StochasticFunction):
def backward(self, reward): def backward(self, reward):
output, means, stddevs = self.saved_tensors output, means, stddevs = self.saved_tensors
grad_stddevs = None grad_stddevs = None
grad_means = means - output # == -(output - means) grad_means = means - output # == -(output - means)
assert self.stddev is not None or stddevs is not None assert self.stddev is not None or stddevs is not None
if self.stddev is not None: if self.stddev is not None:
grad_means /= 1e-6 + self.stddev ** 2 grad_means /= 1e-6 + self.stddev ** 2
@ -88,4 +88,3 @@ class Normal(StochasticFunction):
grad_means /= stddevs_sq grad_means /= stddevs_sq
grad_means *= reward grad_means *= reward
return grad_means, grad_stddevs return grad_means, grad_stddevs
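
A short derivation sketch of why this is the REINFORCE (score-function) gradient: for x ~ N(mu, sigma^2), d log p(x) / d mu = (x - mu) / sigma^2, so the surrogate gradient w.r.t. the mean is -reward * (x - mu) / sigma^2 = reward * (mu - x) / sigma^2, which is exactly grad_means above (means - output, divided by the variance, scaled by the reward); the 1e-6 term only guards against division by zero.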

View File

@ -35,18 +35,18 @@ class SetItem(InplaceFunction):
self.mark_dirty(i) self.mark_dirty(i)
if value is None: if value is None:
value = self.value value = self.value
i.set_index(self.index, value) i._set_index(self.index, value)
return i return i
def backward(self, grad_output): def backward(self, grad_output):
if self.value is None: if self.value is None:
grad_input = grad_output.clone() grad_input = grad_output.clone()
grad_input.set_index(self.index, 0) grad_input._set_index(self.index, 0)
grad_value = grad_output.index(self.index).clone() grad_value = grad_output.index(self.index).clone()
return grad_input, grad_value return grad_input, grad_value
else: else:
grad_input = grad_output.clone() grad_input = grad_output.clone()
grad_input.set_index(self.index, 0) grad_input._set_index(self.index, 0)
return grad_input return grad_input
@ -103,6 +103,7 @@ class View(Function):
class Expand(Function): class Expand(Function):
def __init__(self, sizes): def __init__(self, sizes):
super(Expand, self).__init__() super(Expand, self).__init__()
self.sizes = sizes self.sizes = sizes
@ -110,8 +111,8 @@ class Expand(Function):
def forward(self, i): def forward(self, i):
self.expanded_dims = [dim for dim, (expanded, original) self.expanded_dims = [dim for dim, (expanded, original)
in enumerate(zip(self.sizes, i.size())) in enumerate(zip(self.sizes, i.size()))
if expanded != original] if expanded != original]
result = i.expand(*self.sizes) result = i.expand(*self.sizes)
self.mark_shared_storage((i, result)) self.mark_shared_storage((i, result))
return result return result
@ -288,7 +289,7 @@ class IndexSelect(Function):
if self.needs_input_grad[0]: if self.needs_input_grad[0]:
index, = self.saved_tensors index, = self.saved_tensors
grad_tensor = grad_output.new(*self.input_size).zero_() grad_tensor = grad_output.new(*self.input_size).zero_()
grad_tensor.index_copy_(self.dim, index, grad_output) grad_tensor.index_add_(self.dim, index, grad_output)
return grad_tensor, None return grad_tensor, None
@ -304,8 +305,8 @@ class Concat(Function):
return torch.cat(inputs, self.dim) return torch.cat(inputs, self.dim)
def backward(self, grad_output): def backward(self, grad_output):
return tuple(grad_output.narrow(self.dim, end-size, size) for size, end return tuple(grad_output.narrow(self.dim, end - size, size) for size, end
in zip(self.input_sizes, _accumulate(self.input_sizes))) in zip(self.input_sizes, _accumulate(self.input_sizes)))
class Resize(Function): class Resize(Function):
@ -318,11 +319,11 @@ class Resize(Function):
def forward(self, tensor): def forward(self, tensor):
if tensor.numel() != self.numel: if tensor.numel() != self.numel:
raise RuntimeError(("requested resize to {} ({} elements in total), " raise RuntimeError(("requested resize to {} ({} elements in total), "
"but the given tensor has a size of {} ({} elements). " "but the given tensor has a size of {} ({} elements). "
"autograd's resize can only change the shape of a given " "autograd's resize can only change the shape of a given "
"tensor, while preserving the number of elements. ").format( "tensor, while preserving the number of elements. ").format(
'x'.join(map(str, self.sizes)), self.numel, 'x'.join(map(str, self.sizes)), self.numel,
'x'.join(map(str, tensor.size())), tensor.numel())) 'x'.join(map(str, tensor.size())), tensor.numel()))
self.input_sizes = tensor.size() self.input_sizes = tensor.size()
result = tensor.new(tensor).resize_(*self.sizes) result = tensor.new(tensor).resize_(*self.sizes)
self.mark_shared_storage((tensor, result)) self.mark_shared_storage((tensor, result))
@ -474,7 +475,7 @@ class _MultiSelectionFunction(Function):
class Sort(_MultiSelectionFunction): class Sort(_MultiSelectionFunction):
def __init__(self, dim=None, descending=False, return_indices=False): def __init__(self, dim=None, descending=False, return_indices=True):
super(Sort, self).__init__(dim, return_indices) super(Sort, self).__init__(dim, return_indices)
self.descending = descending self.descending = descending
@ -486,14 +487,14 @@ class Sort(_MultiSelectionFunction):
class Topk(_MultiSelectionFunction): class Topk(_MultiSelectionFunction):
def __init__(self, k, dim=None, largest=True, sort=True, return_indices=False): def __init__(self, k, dim=None, largest=True, sort=True, return_indices=True):
super(Topk, self).__init__(dim, return_indices) super(Topk, self).__init__(dim, return_indices)
self.k = k self.k = k
self.largest = largest self.largest = largest
self.sort = sort self.sort = sort
def forward(self, input): def forward(self, input):
dim = self.dim if self.dim is not None else input.dim()-1 dim = self.dim if self.dim is not None else input.dim() - 1
self.args = (self.k, dim, self.largest, self.sort) self.args = (self.k, dim, self.largest, self.sort)
return super(Topk, self).forward(input) return super(Topk, self).forward(input)
@ -567,9 +568,22 @@ class Scatter(InplaceFunction):
return grad_input, None, grad_source return grad_input, None, grad_source
# TODO: kthvalue
# TODO: repeat
# TODO: sort
# TODO: split
# TODO: topk
class Repeat(Function):
def __init__(self, repeats):
super(Repeat, self).__init__()
self.repeats = repeats
def forward(self, input):
return input.repeat(self.repeats)
def backward(self, grad_output):
grad_input = grad_output
for dim, repeat in enumerate(self.repeats):
if repeat == 1:
continue
grad_input = sum(grad_input.chunk(repeat, dim))
return grad_input
# TODO: unfold # TODO: unfold
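
Repeat.backward folds the gradient back by summing the chunks along each repeated dimension, since every input element contributed to several output positions. A plain-tensor illustration for a 1-D tensor repeated twice:

import torch

x = torch.Tensor([1, 2, 3])           # x.repeat(2) -> [1, 2, 3, 1, 2, 3]
grad_out = torch.ones(6)              # gradient w.r.t. the repeated tensor
chunks = grad_out.chunk(2, 0)         # split back into the two copies
grad_in = chunks[0] + chunks[1]       # sum over copies, mirroring the loop above
print(grad_in)                        # 2, 2, 2: each element was used twice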

View File

@ -71,8 +71,8 @@ class BasicEngine(object):
else: else:
if prev_fn.num_outputs != 1: if prev_fn.num_outputs != 1:
raise RuntimeError("one of the function outputs " raise RuntimeError("one of the function outputs "
"wasn't used - this is an error not, but " "wasn't used - this is an error not, but "
"it's going to be fixed soon") "it's going to be fixed soon")
prev_grad = (d_prev_fn,) prev_grad = (d_prev_fn,)
ready.appendleft((prev_fn, prev_grad)) ready.appendleft((prev_fn, prev_grad))
else: else:

View File

@ -154,9 +154,10 @@ def _nested_map(condition, fn):
return type(obj)(_map(x) for x in obj) return type(obj)(_map(x) for x in obj)
else: else:
raise ValueError("NestedIOFunction doesn't know how to process " raise ValueError("NestedIOFunction doesn't know how to process "
"an input object of type " + torch.typename(obj)) "an input object of type " + torch.typename(obj))
return _map return _map
def _iter_filter(condition): def _iter_filter(condition):
def _iter(obj): def _iter(obj):
if condition(obj): if condition(obj):
@ -169,17 +170,29 @@ def _iter_filter(condition):
yield var yield var
else: else:
raise ValueError("NestedIOFunction doesn't know how to process " raise ValueError("NestedIOFunction doesn't know how to process "
"an input object of type " + torch.typename(obj)) "an input object of type " + torch.typename(obj))
return _iter return _iter
def _unflatten(input, proto):
# unflatten a list or tuple input into a nested list/tuple structure
# specified by proto
def unflatten_helper(input, proto):
res = []
if not isinstance(proto, (list, tuple)):
return input[0], input[1:]
for e in proto:
res_e, input = unflatten_helper(input, e)
res.append(res_e)
return type(proto)(res), input
return unflatten_helper(input, proto)[0]
_iter_variables = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable)) _iter_variables = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable))
_iter_tensors = _iter_filter(torch.is_tensor) _iter_tensors = _iter_filter(torch.is_tensor)
_iter_None_tensors = _iter_filter(lambda o: o is None or torch.is_tensor(o)) _iter_None_tensors = _iter_filter(lambda o: o is None or torch.is_tensor(o))
_map_variable_tensor = _nested_map(lambda o: isinstance(o, torch.autograd.Variable), lambda o: o.data) _map_variable_tensor = _nested_map(lambda o: isinstance(o, torch.autograd.Variable), lambda o: o.data)
def _map_tensor_fromiter(itr):
return _nested_map(lambda o: torch.is_tensor(o), lambda o: next(itr))
class NestedIOFunction(Function): class NestedIOFunction(Function):
@ -188,11 +201,11 @@ class NestedIOFunction(Function):
flat_input = tuple(_iter_variables(input)) flat_input = tuple(_iter_variables(input))
flat_output = super(NestedIOFunction, self)._do_forward(*flat_input) flat_output = super(NestedIOFunction, self)._do_forward(*flat_input)
nested_output = self._nested_output nested_output = self._nested_output
nested_variables = _map_tensor_fromiter(iter(flat_output))(self._nested_output) nested_variables = _unflatten(flat_output, self._nested_output)
return nested_variables return nested_variables
def backward(self, *gradients): def backward(self, *gradients):
nested_gradients = _map_tensor_fromiter(iter(gradients))(self._nested_output) nested_gradients = _unflatten(gradients, self._nested_output)
del self._nested_output del self._nested_output
result = self.backward_extended(*nested_gradients) result = self.backward_extended(*nested_gradients)
del self._to_save_nested del self._to_save_nested
@ -214,7 +227,7 @@ class NestedIOFunction(Function):
@property @property
def saved_tensors(self): def saved_tensors(self):
flat_tensors = super(NestedIOFunction, self).saved_tensors flat_tensors = super(NestedIOFunction, self).saved_tensors
return _map_tensor_fromiter(iter(flat_tensors))(self._to_save_nested) return _unflatten(flat_tensors, self._to_save_nested)
def mark_dirty(self, *args, **kwargs): def mark_dirty(self, *args, **kwargs):
self.dirty_tensors = tuple(_iter_tensors((args, kwargs))) self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
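
_unflatten consumes a flat sequence and rebuilds the nesting described by proto (only the shape of proto matters; its values are ignored). A standalone illustration:

def _unflatten(input, proto):
    def unflatten_helper(input, proto):
        res = []
        if not isinstance(proto, (list, tuple)):
            return input[0], input[1:]
        for e in proto:
            res_e, input = unflatten_helper(input, e)
            res.append(res_e)
        return type(proto)(res), input
    return unflatten_helper(input, proto)[0]

print(_unflatten((10, 20, 30), ('h', ('x', 'cell'))))   # (10, (20, 30))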

View File

@ -2,6 +2,7 @@ from .function import Function
_NOT_PROVIDED = object() _NOT_PROVIDED = object()
class StochasticFunction(Function): class StochasticFunction(Function):
def __init__(self): def __init__(self):
@ -10,7 +11,7 @@ class StochasticFunction(Function):
def _do_backward(self, grad_output, retain_variables): def _do_backward(self, grad_output, retain_variables):
if self.reward is _NOT_PROVIDED: if self.reward is _NOT_PROVIDED:
raise RuntimeError("differentiating stochastic functions requires " raise RuntimeError("differentiating stochastic functions requires "
"providing a reward") "providing a reward")
result = super(StochasticFunction, self)._do_backward((self.reward,), retain_variables) result = super(StochasticFunction, self)._do_backward((self.reward,), retain_variables)
if not retain_variables: if not retain_variables:
self.reward = None self.reward = None
@ -18,4 +19,3 @@ class StochasticFunction(Function):
def _reinforce(self, reward): def _reinforce(self, reward):
self.reward = reward self.reward = reward

View File

@ -72,12 +72,12 @@ class Variable(_C._VariableBase):
if self.creator is not None: if self.creator is not None:
if value is False: if value is False:
hint = (" If you want to use a computed variable in a subgraph " hint = (" If you want to use a computed variable in a subgraph "
"that doesn't require differentiation use " "that doesn't require differentiation use "
"var_no_grad = var.detach().") "var_no_grad = var.detach().")
else: else:
hint = '' hint = ''
raise RuntimeError("you can only change requires_grad flags of " raise RuntimeError("you can only change requires_grad flags of "
"leaf variables." + hint) "leaf variables." + hint)
self._requires_grad = value self._requires_grad = value
def __getattr__(self, name): def __getattr__(self, name):
@ -87,13 +87,13 @@ class Variable(_C._VariableBase):
def __getitem__(self, key): def __getitem__(self, key):
if (isinstance(key, Variable) and if (isinstance(key, Variable) and
type(key.data).__name__ == 'ByteTensor'): type(key.data).__name__ == 'ByteTensor'):
return MaskedSelect()(self, key) return MaskedSelect()(self, key)
return Index(key)(self) return Index(key)(self)
def __setitem__(self, key, value): def __setitem__(self, key, value):
if (isinstance(key, Variable) and if (isinstance(key, Variable) and
type(key.data).__name__ == 'ByteTensor'): type(key.data).__name__ == 'ByteTensor'):
if isinstance(value, Variable): if isinstance(value, Variable):
return MaskedCopy(inplace=True)(self, key, value) return MaskedCopy(inplace=True)(self, key, value)
else: else:
@ -107,9 +107,9 @@ class Variable(_C._VariableBase):
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
if self.creator is not None: if self.creator is not None:
raise RuntimeError("Only Variables created explicitly by the user " raise RuntimeError("Only Variables created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment") "(graph leaves) support the deepcopy protocol at the moment")
result = type(self)(self.data.clone(), requires_grad=self.requires_grad, result = type(self)(self.data.clone(), requires_grad=self.requires_grad,
volatile=self.volatile) volatile=self.volatile)
memo[id(self)] = result memo[id(self)] = result
return result return result
@ -151,7 +151,9 @@ class Variable(_C._VariableBase):
raise RuntimeError('calling backward on a volatile variable') raise RuntimeError('calling backward on a volatile variable')
if gradient is None and self.requires_grad: if gradient is None and self.requires_grad:
if self.data.numel() != 1: if self.data.numel() != 1:
raise RuntimeError('backward should be called only on a scalar (i.e. 1-element tensor) or with gradient w.r.t. the variable') raise RuntimeError(
'backward should be called only on a scalar (i.e. 1-element tensor) '
'or with gradient w.r.t. the variable')
gradient = self.data.new().resize_as_(self.data).fill_(1) gradient = self.data.new().resize_as_(self.data).fill_(1)
self._execution_engine.run_backward((self,), (gradient,), retain_variables) self._execution_engine.run_backward((self,), (gradient,), retain_variables)
@ -219,7 +221,7 @@ class Variable(_C._VariableBase):
""" """
if not isinstance(self.creator, StochasticFunction): if not isinstance(self.creator, StochasticFunction):
raise RuntimeError("reinforce() can be only called on outputs " raise RuntimeError("reinforce() can be only called on outputs "
"of stochastic functions") "of stochastic functions")
self.creator._reinforce(reward) self.creator._reinforce(reward)
def detach(self): def detach(self):
@ -392,7 +394,7 @@ class Variable(_C._VariableBase):
def clamp(self, min=None, max=None): def clamp(self, min=None, max=None):
if min is None and max is None: if min is None and max is None:
raise ValueError("clamp requires specifying at least one of " raise ValueError("clamp requires specifying at least one of "
"min and max arguments") "min and max arguments")
elif min is None and max is not None: elif min is None and max is not None:
return CminConstant(max)(self) return CminConstant(max)(self)
elif min is not None and max is None: elif min is not None and max is None:
@ -482,6 +484,40 @@ class Variable(_C._VariableBase):
def view_as(self, tensor): def view_as(self, tensor):
return View(*tensor.size())(self) return View(*tensor.size())(self)
def split(self, split_size, dim=0):
return torch.split(self, split_size, dim)
def chunk(self, n_chunks, dim=0):
return torch.chunk(self, n_chunks, dim)
def repeat(self, *repeats):
if len(repeats) == 1 and isinstance(repeats[0], torch.Size):
repeats = repeats[0]
else:
repeats = torch.Size(repeats)
return Repeat(repeats)(self)
def var(self, dim=None, unbiased=True):
mean = self.mean(dim)
if dim is None:
mean = mean.view(*(1 for s in self.size()))
mean_expanded = mean.expand_as(self)
zero_centered = self.sub(mean_expanded)
var = zero_centered.mul(zero_centered).sum(dim)
numel = self.numel() if dim is None else self.size(dim)
return var.div(numel - int(unbiased))
def std(self, dim=None, unbiased=True):
return self.var(dim, unbiased).sqrt()
def renorm(self, norm_type, dim, maxnorm):
t = self.transpose(dim, 0)
flat = t.contiguous().view(self.size(0), -1)
norms = flat.norm(norm_type, 1)
norms = norms.clamp(max=maxnorm).div(norms.add(1e-7))
flat_out = flat.mul(norms.expand_as(flat))
return flat_out.view(t.size()).transpose(dim, 0)
@staticmethod @staticmethod
def _static_blas(cls, args, inplace): def _static_blas(cls, args, inplace):
num_args = len(args) num_args = len(args)
@ -503,7 +539,7 @@ class Variable(_C._VariableBase):
def bmm(self, batch): def bmm(self, batch):
output = Variable(self.data.new(self.data.size(0), self.data.size(1), output = Variable(self.data.new(self.data.size(0), self.data.size(1),
batch.data.size(2))) batch.data.size(2)))
return self._static_blas(Baddbmm, (output, 0, 1, self, batch), False) return self._static_blas(Baddbmm, (output, 0, 1, self, batch), False)
def mv(self, vector): def mv(self, vector):
@ -622,7 +658,7 @@ class Variable(_C._VariableBase):
if isinstance(sizes[0], torch.Size): if isinstance(sizes[0], torch.Size):
if len(sizes) > 1: if len(sizes) > 1:
raise ValueError("expand expects a several ints or a single " raise ValueError("expand expects a several ints or a single "
"torch.Size argument") "torch.Size argument")
sizes = sizes[0] sizes = sizes[0]
return Expand(sizes)(self) return Expand(sizes)(self)
@ -641,7 +677,7 @@ class Variable(_C._VariableBase):
def narrow(self, dim, start_index, length): def narrow(self, dim, start_index, length):
index = tuple(slice(None, None) for _ in range(dim)) + \ index = tuple(slice(None, None) for _ in range(dim)) + \
(slice(start_index, start_index+length),) (slice(start_index, start_index + length),)
return Index(index)(self) return Index(index)(self)
@ -672,6 +708,42 @@ class Variable(_C._VariableBase):
def bernoulli(self): def bernoulli(self):
return Bernoulli()(self) return Bernoulli()(self)
def eq(self, other):
if isinstance(other, Variable):
return Eq()(self, other)
assert not torch.is_tensor(other), "can't compare Variable and tensor"
return Eq(other)(self)
def ne(self, other):
if isinstance(other, Variable):
return Ne()(self, other)
assert not torch.is_tensor(other), "can't compare Variable and tensor"
return Ne(other)(self)
def gt(self, other):
if isinstance(other, Variable):
return Gt()(self, other)
assert not torch.is_tensor(other), "can't compare Variable and tensor"
return Gt(other)(self)
def ge(self, other):
if isinstance(other, Variable):
return Ge()(self, other)
assert not torch.is_tensor(other), "can't compare Variable and tensor"
return Ge(other)(self)
def lt(self, other):
if isinstance(other, Variable):
return Lt()(self, other)
assert not torch.is_tensor(other), "can't compare Variable and tensor"
return Lt(other)(self)
def le(self, other):
if isinstance(other, Variable):
return Le()(self, other)
assert not torch.is_tensor(other), "can't compare Variable and tensor"
return Le(other)(self)
def __add__(self, other): def __add__(self, other):
return self.add(other) return self.add(other)
__radd__ = __add__ __radd__ = __add__
@ -710,7 +782,7 @@ class Variable(_C._VariableBase):
elif dim_self == 2 and dim_other == 2: elif dim_self == 2 and dim_other == 2:
return self.mm(other) return self.mm(other)
raise ValueError("both arguments to __matmul__ need to be 1D or 2D, " raise ValueError("both arguments to __matmul__ need to be 1D or 2D, "
"but they are {}D and {}D".format(dim_self, dim_other)) "but they are {}D and {}D".format(dim_self, dim_other))
def __div__(self, other): def __div__(self, other):
return self.div(other) return self.div(other)
@ -741,6 +813,30 @@ class Variable(_C._VariableBase):
def __iter__(self): def __iter__(self):
return iter(map(lambda i: self[i], range(self.size(0)))) return iter(map(lambda i: self[i], range(self.size(0))))
def __mod__(self, other):
return self.remainder(other)
def __eq__(self, other):
return self.eq(other)
def __ne__(self, other):
return self.ne(other)
def __lt__(self, other):
return self.lt(other)
def __le__(self, other):
return self.le(other)
def __gt__(self, other):
return self.gt(other)
def __ge__(self, other):
return self.ge(other)
def __hash__(self):
return id(self)
class _torch(object): class _torch(object):
@staticmethod @staticmethod
@ -748,11 +844,11 @@ class Variable(_C._VariableBase):
return Concat(dim)(*iterable) return Concat(dim)(*iterable)
@staticmethod @staticmethod
def normal(means, stddev=1): def normal(means, std=1):
if isinstance(stddev, Variable): if isinstance(std, Variable):
return Normal()(means, stddev) return Normal()(means, std)
else: else:
return Normal(stddev)(means) return Normal(std)(means)
@staticmethod @staticmethod
def _blas(cls, args, inplace): def _blas(cls, args, inplace):
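
The new Variable.var builds the unbiased estimate out of differentiable primitives: sum((x - mean)^2) / (numel - 1). A plain-tensor check against Tensor.var, which is unbiased by default:

import torch

x = torch.Tensor([1.0, 2.0, 3.0, 4.0])
mean = x.mean()
manual = (x - mean).pow(2).sum() / (x.numel() - 1)
print(manual, x.var())   # both ~1.667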

View File

@ -14,12 +14,14 @@ lib = None
thisdir = path.dirname(__file__) thisdir = path.dirname(__file__)
libpaths = ['', path.join(thisdir, '../../lib')] libpaths = ['', path.join(thisdir, '../../lib')]
if sys.platform.startswith('linux'): if sys.platform.startswith('linux'):
libnames = ['libcudnn.so.5.1.5', 'libcudnn.so.5.1.3', 'libcudnn.so.5.0.5', 'libcudnn.so.5.1.10'] libnames = ['libcudnn.so.6.0.5', 'libcudnn.so.6.0.10', 'libcudnn.so.5.1.5', 'libcudnn.so.5.1.3',
'libcudnn.so.5.0.5', 'libcudnn.so.5.1.10']
elif sys.platform == 'darwin': elif sys.platform == 'darwin':
libnames = ['libcudnn.5.dylib'] libnames = ['libcudnn.6.dylib', 'libcudnn.5.dylib']
else: else:
libnames = [] libnames = []
def _loadlib(): def _loadlib():
global lib global lib
loaded = False loaded = False
@ -39,6 +41,7 @@ def _loadlib():
lib = None lib = None
raise OSError("Could not load cuDNN") raise OSError("Could not load cuDNN")
def is_acceptable(tensor): def is_acceptable(tensor):
if not enabled: if not enabled:
return False return False
@ -58,13 +61,15 @@ def is_acceptable(tensor):
return False return False
if not _C.has_cudnn: if not _C.has_cudnn:
warnings.warn("cuDNN library has been detected, but your pytorch " warnings.warn("cuDNN library has been detected, but your pytorch "
"installation was compiled without support for it. You " "installation was compiled without support for it. You "
"might want to rebuild pytorch, making sure the library " "might want to rebuild pytorch, making sure the library "
"is visible to the build system.") "is visible to the build system.")
return False return False
return True return True
__cudnn_version = [] __cudnn_version = []
def version(): def version():
if not lib: if not lib:
raise RuntimeError("cuDNN not initialized") raise RuntimeError("cuDNN not initialized")
@ -108,7 +113,16 @@ CUDNN_GRU = 3
CUDNN_LINEAR_INPUT = 0 CUDNN_LINEAR_INPUT = 0
CUDNN_SKIP_INPUT = 1 CUDNN_SKIP_INPUT = 1
CUDNN_NON_DETERMINISTIC = 0
CUDNN_DETERMINISTIC = 1
CUDNN_RNN_ALGO_STANDARD = 0
CUDNN_RNN_ALGO_PERSIST_STATIC = 1
CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2
class CuDNNHandle: class CuDNNHandle:
def __init__(self): def __init__(self):
ptr = ctypes.c_void_p() ptr = ctypes.c_void_p()
check_error(lib.cudnnCreate(ctypes.byref(ptr))) check_error(lib.cudnnCreate(ctypes.byref(ptr)))
@ -117,7 +131,9 @@ class CuDNNHandle:
def __del__(self): def __del__(self):
check_error(lib.cudnnDestroy(self)) check_error(lib.cudnnDestroy(self))
class CuDNNError(RuntimeError): class CuDNNError(RuntimeError):
def __init__(self, status): def __init__(self, status):
self.status = status self.status = status
msg = '{}: {}'.format(status, get_error_string(status)) msg = '{}: {}'.format(status, get_error_string(status))
@ -125,6 +141,7 @@ class CuDNNError(RuntimeError):
class TensorDescriptor(object): class TensorDescriptor(object):
def __init__(self): def __init__(self):
ptr = ctypes.c_void_p() ptr = ctypes.c_void_p()
check_error(lib.cudnnCreateTensorDescriptor(ctypes.byref(ptr))) check_error(lib.cudnnCreateTensorDescriptor(ctypes.byref(ptr)))
@ -147,6 +164,7 @@ class TensorDescriptor(object):
class TensorDescriptorArray(object): class TensorDescriptorArray(object):
def __init__(self, N): def __init__(self, N):
self.ptrs = (ctypes.c_void_p * N)() self.ptrs = (ctypes.c_void_p * N)()
for i in range(N): for i in range(N):
@ -175,6 +193,7 @@ class TensorDescriptorArray(object):
class ConvolutionDescriptor(object): class ConvolutionDescriptor(object):
def __init__(self): def __init__(self):
ptr = ctypes.c_void_p() ptr = ctypes.c_void_p()
check_error(lib.cudnnCreateConvolutionDescriptor(ctypes.byref(ptr))) check_error(lib.cudnnCreateConvolutionDescriptor(ctypes.byref(ptr)))
@ -195,7 +214,9 @@ class ConvolutionDescriptor(object):
def as_tuple(self): def as_tuple(self):
return (self._pad, self._stride) return (self._pad, self._stride)
class FilterDescriptor(object): class FilterDescriptor(object):
def __init__(self): def __init__(self):
ptr = ctypes.c_void_p() ptr = ctypes.c_void_p()
check_error(lib.cudnnCreateFilterDescriptor(ctypes.byref(ptr))) check_error(lib.cudnnCreateFilterDescriptor(ctypes.byref(ptr)))
@ -216,6 +237,7 @@ class FilterDescriptor(object):
class DropoutDescriptor(object): class DropoutDescriptor(object):
def __init__(self, handle, dropout, seed): def __init__(self, handle, dropout, seed):
ptr = ctypes.c_void_p() ptr = ctypes.c_void_p()
check_error(lib.cudnnCreateDropoutDescriptor(ctypes.byref(ptr))) check_error(lib.cudnnCreateDropoutDescriptor(ctypes.byref(ptr)))
@ -241,30 +263,43 @@ class DropoutDescriptor(object):
check_error(lib.cudnnDestroyDropoutDescriptor(self)) check_error(lib.cudnnDestroyDropoutDescriptor(self))
class RNNDescriptor(object): class RNNDescriptor(object):
def __init__(self, hidden_size, num_layers, dropout_desc, input_mode,
bidirectional, mode, datatype):
def __init__(self, handle, hidden_size, num_layers, dropout_desc, input_mode,
bidirectional, mode, datatype):
ptr = ctypes.c_void_p() ptr = ctypes.c_void_p()
check_error(lib.cudnnCreateRNNDescriptor(ctypes.byref(ptr))) check_error(lib.cudnnCreateRNNDescriptor(ctypes.byref(ptr)))
self._as_parameter_ = ptr self._as_parameter_ = ptr
check_error(lib.cudnnSetRNNDescriptor(
self,
hidden_size,
num_layers,
dropout_desc,
input_mode,
bidirectional,
mode,
datatype
))
if version() >= 6000:
check_error(lib.cudnnSetRNNDescriptor_v6(
handle,
self,
hidden_size,
num_layers,
dropout_desc,
input_mode,
bidirectional,
mode,
CUDNN_RNN_ALGO_STANDARD,
datatype
))
else:
check_error(lib.cudnnSetRNNDescriptor(
self,
hidden_size,
num_layers,
dropout_desc,
input_mode,
bidirectional,
mode,
datatype
))
def __del__(self): def __del__(self):
check_error(lib.cudnnDestroyRNNDescriptor(self)) check_error(lib.cudnnDestroyRNNDescriptor(self))
class ConvolutionAlgoPerf(ctypes.Structure): class ConvolutionAlgoPerf_v5(ctypes.Structure):
_fields_ = [ _fields_ = [
("algo", ctypes.c_int), ("algo", ctypes.c_int),
("status", ctypes.c_int), ("status", ctypes.c_int),
@ -272,13 +307,27 @@ class ConvolutionAlgoPerf(ctypes.Structure):
("memory", ctypes.c_size_t), ("memory", ctypes.c_size_t),
] ]
class ConvolutionAlgoPerf_v6(ctypes.Structure):
_fields_ = [
("algo", ctypes.c_int),
("status", ctypes.c_int),
("time", ctypes.c_float),
("memory", ctypes.c_size_t),
("determinism", ctypes.c_int),
("reserved", ctypes.c_int * 4)
]
def check_error(status): def check_error(status):
if status is not 0: if status is not 0:
raise CuDNNError(status) raise CuDNNError(status)
def get_error_string(status): def get_error_string(status):
return lib.cudnnGetErrorString(status) return lib.cudnnGetErrorString(status)
def get_handle(): def get_handle():
if lib is None: if lib is None:
_loadlib() _loadlib()
@ -296,11 +345,12 @@ _typemap = {
} }
_sizeofmap = { _sizeofmap = {
CUDNN_DATA_HALF : 2, CUDNN_DATA_HALF: 2,
CUDNN_DATA_FLOAT : 4, CUDNN_DATA_FLOAT: 4,
CUDNN_DATA_DOUBLE : 8, CUDNN_DATA_DOUBLE: 8,
} }
def c_type(tensor): def c_type(tensor):
if isinstance(tensor, torch.cuda.HalfTensor): if isinstance(tensor, torch.cuda.HalfTensor):
return ctypes.c_float return ctypes.c_float
@ -311,10 +361,12 @@ def c_type(tensor):
else: else:
raise ValueError("unknown type '{}'".format(type(tensor))) raise ValueError("unknown type '{}'".format(type(tensor)))
def int_array(itr): def int_array(itr):
array_type = ctypes.c_int * len(itr) array_type = ctypes.c_int * len(itr)
return array_type(*itr) return array_type(*itr)
def descriptor(tensor, N=None): def descriptor(tensor, N=None):
if N is not None: if N is not None:
descriptor = TensorDescriptorArray(N) descriptor = TensorDescriptorArray(N)
@ -331,16 +383,21 @@ _autotuner_forward = {}
_autotuner_backward_data = {} _autotuner_backward_data = {}
_autotuner_backward_filter = {} _autotuner_backward_filter = {}
def convolution_autotuner_key(idesc, weight_desc, conv_desc): def convolution_autotuner_key(idesc, weight_desc, conv_desc):
return (idesc.as_tuple(), weight_desc.as_tuple(), conv_desc.as_tuple()) return (idesc.as_tuple(), weight_desc.as_tuple(), conv_desc.as_tuple())
def convolution_forward_algorithm(idesc, weight_desc, conv_desc, odesc): def convolution_forward_algorithm(idesc, weight_desc, conv_desc, odesc):
k = convolution_autotuner_key(idesc, weight_desc, conv_desc) k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
if k in _autotuner_forward: if k in _autotuner_forward:
return _autotuner_forward[k] return _autotuner_forward[k]
if benchmark: if benchmark:
perf_results = ConvolutionAlgoPerf()
if version() < 6000:
perf_results = ConvolutionAlgoPerf_v5()
else:
perf_results = ConvolutionAlgoPerf_v6()
algo_count = ctypes.c_int() algo_count = ctypes.c_int()
check_error(lib.cudnnFindConvolutionForwardAlgorithm( check_error(lib.cudnnFindConvolutionForwardAlgorithm(
get_handle(), idesc, weight_desc, conv_desc, odesc, 1, get_handle(), idesc, weight_desc, conv_desc, odesc, 1,
@@ -360,15 +417,19 @@ def convolution_forward_algorithm(idesc, weight_desc, conv_desc, odesc):
        wlimit, ctypes.byref(fwd_alg)))
    return fwd_alg

def convolution_forward_workspace_size(*args):
    check_error(lib.cudnnGetConvolutionForwardWorkspaceSize(*args))

def convolution_forward(*args):
    check_error(lib.cudnnConvolutionForward(*args))

def convolution_backward_data(*args):
    return check_error(lib.cudnnConvolutionBackwardData(*args))

def convolution_backward_data_algorithm(weight_desc, odesc, conv_desc, idesc):
    k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
    if k in _autotuner_backward_data:
@@ -395,12 +456,15 @@ def convolution_backward_data_algorithm(weight_desc, odesc, conv_desc, idesc):
        wlimit, ctypes.byref(bwd_data_alg)))
    return bwd_data_alg

def convolution_backward_data_workspace_size(*args):
    return check_error(lib.cudnnGetConvolutionBackwardDataWorkspaceSize(*args))

def convolution_backward_filter(*args):
    return check_error(lib.cudnnConvolutionBackwardFilter(*args))

def convolution_backward_filter_algorithm(idesc, odesc, conv_desc, weight_desc):
    k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
    if k in _autotuner_backward_filter:
@@ -427,11 +491,14 @@ def convolution_backward_filter_algorithm(idesc, odesc, conv_desc, weight_desc):
        wlimit, ctypes.byref(bwd_filter_alg)))
    return bwd_filter_alg

def convolution_backward_filter_workspace_size(*args):
    return check_error(lib.cudnnGetConvolutionBackwardFilterWorkspaceSize(*args))

def convolution_backward_bias(*args):
    check_error(lib.cudnnConvolutionBackwardBias(*args))

def add_tensor(*args):
    check_error(lib.cudnnAddTensor(*args))
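Side note on the version() check introduced above: cuDNN 6 extended the cudnnConvolution*AlgoPerf_t structs returned by cudnnFindConvolution*Algorithm, which is why the bindings now pick a ctypes layout based on the library version. A minimal, runnable sketch of the idea; the v5 field list follows the cuDNN 5 headers and the extra v6 fields mirror the ones added in this diff, so treat the exact layouts as illustrative rather than authoritative:

import ctypes

class ConvolutionAlgoPerf_v5(ctypes.Structure):
    # cuDNN 5.x perf result: algorithm, status, elapsed time, workspace size
    _fields_ = [
        ("algo", ctypes.c_int),
        ("status", ctypes.c_int),
        ("time", ctypes.c_float),
        ("memory", ctypes.c_size_t),
    ]

class ConvolutionAlgoPerf_v6(ctypes.Structure):
    # cuDNN 6.x adds a determinism flag plus reserved padding
    _fields_ = [
        ("algo", ctypes.c_int),
        ("status", ctypes.c_int),
        ("time", ctypes.c_float),
        ("memory", ctypes.c_size_t),
        ("determinism", ctypes.c_int),
        ("reserved", ctypes.c_int * 4),
    ]

def perf_results_for(version):
    # mirrors the selection logic in convolution_forward_algorithm above
    return ConvolutionAlgoPerf_v5() if version < 6000 else ConvolutionAlgoPerf_v6()

assert ctypes.sizeof(ConvolutionAlgoPerf_v6) > ctypes.sizeof(ConvolutionAlgoPerf_v5)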

View File

@@ -3,6 +3,7 @@ import torch.backends.cudnn as cudnn
from torch.backends.cudnn import check_error
import ctypes

def get_cudnn_mode(mode):
    if mode == 'RNN_RELU':
        return cudnn.CUDNN_RNN_RELU
@@ -17,9 +18,10 @@ def get_cudnn_mode(mode):
class Unserializable(object):

    def __init__(self, inner):
        self.inner = inner

    def get(self):
        return self.inner
@@ -39,8 +41,10 @@ def init_dropout_descriptor(fn, handle):
        fn.dropout_seed
    )

def init_rnn_descriptor(fn, handle):
    return cudnn.RNNDescriptor(
        handle,
        fn.hidden_size,
        fn.num_layers,
        fn.dropout_state['desc'].get(),
@@ -80,7 +84,7 @@ def get_num_weights(handle, rnn_desc, x_desc, datatype):
        datatype
    ))
    elem_size = cudnn._sizeofmap[datatype]
    assert weight_size.value % elem_size == 0
    return weight_size.value // elem_size
@@ -139,10 +143,11 @@ def get_parameters(fn, handle, weight_buf):
                ctypes.byref(nb_dims),
                ctypes.c_void_p(filter_dim_a.data_ptr())))
            assert nb_dims.value <= min_dim
            filter_dim_a = filter_dim_a[:nb_dims.value]
            elem_size = cudnn._sizeofmap[fn.datatype]
            offset_bytes = (matrix_pointer.value - weight_buf.data_ptr())
            assert offset_bytes % elem_size == 0
            offset = offset_bytes // elem_size

            # for all the RNN types provided by CUDNN, all the ih weights
@@ -151,17 +156,16 @@ def get_parameters(fn, handle, weight_buf):
            # Since we're storing all the weights in a single tensor anyway,
            # might as well merge the CUDNN ones into a single tensor as well
            if linear_id == 0 or linear_id == num_linear_layers / 2:
                assert filter_dim_a.prod() == filter_dim_a[0]
                param = fn.weight_buf.new().set_(
                    weight_buf.storage(), offset,
                    filter_dim_a[0] * num_linear_layers // 2, filter_dim_a[2])
                layer_params.append(param)
            else:
                assert cur_offset == offset

            cur_offset = offset + filter_dim_a[0]

        params.append(layer_params)

    return params
@@ -170,7 +174,7 @@ def get_parameters(fn, handle, weight_buf):
def _copyParams(params_from, params_to):
    for layer_params_from, layer_params_to in zip(params_from, params_to):
        for param_from, param_to in zip(layer_params_from, layer_params_to):
            assert param_from.type() == param_to.type()
            param_to.copy_(param_from)
@@ -204,9 +208,9 @@ def forward(fn, input, hx, weight, output, hy):
        output_size = _output_size(fn)

        x = input.contiguous()
        output.resize_(*output_size)
        hy.resize_(*hidden_size)
        if cy is not None:
            cy.resize_(*hidden_size)
        y = output

        # init descriptors
@@ -214,7 +218,7 @@ def forward(fn, input, hx, weight, output, hy):
            fn.dropout_state['desc'] = Unserializable(
                init_dropout_descriptor(fn, handle)
            )
        fn.rnn_desc = init_rnn_descriptor(fn, handle)
        fn.x_descs = cudnn.descriptor(x[0], fn.seq_length)
        fn.y_descs = cudnn.descriptor(y[0], fn.seq_length)
        fn.hx_desc = cudnn.descriptor(hx)
@@ -237,7 +241,7 @@ def forward(fn, input, hx, weight, output, hy):
        if tuple(hx.size()) != hidden_size:
            raise RuntimeError('Expected hidden size {}, got {}'.format(
                hidden_size, tuple(hx.size())))
        if cx is not None and tuple(cx.size()) != hidden_size:
            raise RuntimeError('Expected cell size {}, got {}'.format(
                hidden_size, tuple(cx.size())))
@@ -295,7 +299,6 @@ def forward(fn, input, hx, weight, output, hy):
            output = output.transpose_(0, 1)

def backward_grad(fn, input, hx, weight, output, grad_output, grad_hy, grad_input, grad_hx):
    with torch.cuda.device_of(input):
        handle = cudnn.get_handle()
@@ -321,8 +324,8 @@ def backward_grad(fn, input, hx, weight, output, grad_output, grad_hy, grad_inpu
        y = output
        w = fn.weight_buf
        dx = grad_input.resize_as_(input)
        dhy = grad_hy.contiguous().view(*hidden_size)
        dcy = grad_cy.contiguous().view(*hidden_size) if grad_cy is not None else None
        dhx = grad_hx.resize_(*hidden_size)
        dcx = grad_cx.resize_(*hidden_size) if grad_cx is not None else None
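A note on the dhy/dcy change above: incoming gradients have to keep their values, so they are reshaped with contiguous().view() instead of resize_(). resize_() only guarantees a buffer of the right shape and reads the underlying storage in whatever order it happens to be, so a non-contiguous gradient would be silently reordered. A small CPU-only illustration of the difference (shapes are made up for the example):

import torch

grad_hy = torch.arange(0, 12).view(3, 4).t()      # non-contiguous (4, 3) tensor
hidden_size = (2, 2, 3)

safe = grad_hy.contiguous().view(*hidden_size)    # copy into logical order first
# grad_hy.resize_(*hidden_size) would instead reinterpret the raw storage,
# which is only correct when grad_hy is already contiguous
print(safe.size())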

View File

@@ -697,8 +697,30 @@ bool THCSPShortTensor_init(PyObject *module);
bool THCSPCharTensor_init(PyObject *module);
bool THCSPByteTensor_init(PyObject *module);
bool THDPDoubleStorage_init(PyObject *module);
bool THDPFloatStorage_init(PyObject *module);
//bool THDPHalfStorage_init(PyObject *module);
bool THDPLongStorage_init(PyObject *module);
bool THDPIntStorage_init(PyObject *module);
bool THDPShortStorage_init(PyObject *module);
bool THDPCharStorage_init(PyObject *module);
bool THDPByteStorage_init(PyObject *module);
bool THDPDoubleTensor_init(PyObject *module);
bool THDPFloatTensor_init(PyObject *module);
//bool THDPHalfTensor_init(PyObject *module);
bool THDPLongTensor_init(PyObject *module);
bool THDPIntTensor_init(PyObject *module);
bool THDPShortTensor_init(PyObject *module);
bool THDPCharTensor_init(PyObject *module);
bool THDPByteTensor_init(PyObject *module);
static std::vector<PyMethodDef> methods;
#ifdef WITH_DISTRIBUTED
PyMethodDef* THDPModule_methods();
#endif
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_C()
#else
@@ -716,6 +738,9 @@ PyMODINIT_FUNC PyInit__C()
#ifdef WITH_CUDNN
  THPUtils_addPyMethodDefs(methods, THCUDNN_methods());
#endif
#ifdef WITH_DISTRIBUTED
THPUtils_addPyMethodDefs(methods, THDPModule_methods());
#endif
#if PY_MAJOR_VERSION == 2
  ASSERT_TRUE(module = Py_InitModule("torch._C", methods.data()));
@@ -729,6 +754,7 @@ PyMODINIT_FUNC PyInit__C()
  };
  ASSERT_TRUE(module = PyModule_Create(&torchmodule));
#endif
  ASSERT_TRUE(THPWrapper_init(module));
  ASSERT_TRUE(THPGenerator_init(module));
  ASSERT_TRUE(THPException_init(module));
  ASSERT_TRUE(THPSize_init(module));
@@ -796,7 +822,6 @@ PyMODINIT_FUNC PyInit__C()
#endif
#ifdef WITH_CUDNN
ASSERT_TRUE(THCUDNNModule_initModule(module));
  PyObject *has_cudnn = Py_True;
#else
  PyObject *has_cudnn = Py_False;
@@ -804,6 +829,28 @@ PyMODINIT_FUNC PyInit__C()
  Py_INCREF(has_cudnn);
  ASSERT_TRUE(PyModule_AddObject(module, "has_cudnn", has_cudnn) == 0);
// TODO THD: enable once master-worker mode is implemented
#if 0 && defined(WITH_DISTRIBUTED)
// See comment on CUDA objects
ASSERT_TRUE(THDPDoubleStorage_init(module));
ASSERT_TRUE(THDPFloatStorage_init(module));
//ASSERT_TRUE(THDPHalfStorage_init(module));
ASSERT_TRUE(THDPLongStorage_init(module));
ASSERT_TRUE(THDPIntStorage_init(module));
ASSERT_TRUE(THDPShortStorage_init(module));
ASSERT_TRUE(THDPCharStorage_init(module));
ASSERT_TRUE(THDPByteStorage_init(module));
ASSERT_TRUE(THDPDoubleTensor_init(module));
ASSERT_TRUE(THDPFloatTensor_init(module));
//ASSERT_TRUE(THDPHalfTensor_init(module));
ASSERT_TRUE(THDPLongTensor_init(module));
ASSERT_TRUE(THDPIntTensor_init(module));
ASSERT_TRUE(THDPShortTensor_init(module));
ASSERT_TRUE(THDPCharTensor_init(module));
ASSERT_TRUE(THDPByteTensor_init(module));
#endif
  THPDefaultGenerator = (THPGenerator*)THPGenerator_New();
  ASSERT_TRUE(THPDefaultGenerator != nullptr);
  ASSERT_TRUE(PyModule_AddObject(module, "default_generator", (PyObject*)THPDefaultGenerator) == 0);

View File

@@ -52,7 +52,7 @@ static void THPWrapper_dealloc(THPWrapper* self)
PyTypeObject THPWrapperType = {
  PyVarObject_HEAD_INIT(NULL, 0)
  "torch._C._PtrWrapper",              /* tp_name */
  sizeof(THPWrapper),                  /* tp_basicsize */
  0,                                   /* tp_itemsize */
  (destructor)THPWrapper_dealloc,      /* tp_dealloc */

View File

@@ -1,5 +1,5 @@
#ifndef THP_PTR_WRAPPER_H
#define THP_PTR_WRAPPER_H

#include <functional>

View File

@@ -24,18 +24,17 @@ PyObject * THPSize_New(int dim, long *sizes)
static PyObject * THPSize_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
  THPObjectPtr self = PyTuple_Type.tp_new(type, args, kwargs);
  if (self) {
    for (Py_ssize_t i = 0; i < PyTuple_Size(self); ++i) {
      PyObject *item = PyTuple_GET_ITEM(self.get(), i);
      if (!THPUtils_checkLong(item)) {
Py_DECREF(self);
        return PyErr_Format(PyExc_TypeError, "torch.Size() takes an iterable of 'int' (item %zd is '%s')",
            i, Py_TYPE(item)->tp_name);
      }
    }
  }
  return self.release();
}
static PyObject * THPSize_repr(THPSize *self) static PyObject * THPSize_repr(THPSize *self)

View File

@@ -21,6 +21,7 @@
#define THP_API extern "C"

#include "PtrWrapper.h"
#include "Exceptions.h"
#include "Generator.h"
#include "Storage.h"

View File

@@ -1,6 +1,7 @@
#include "BatchNorm.h"
#include "Descriptors.h"
#include "Types.h"

namespace torch { namespace cudnn {
@@ -78,6 +79,11 @@ void cudnn_batch_norm_forward(
  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  if (training) {
THVoidTensor_assertContiguous(bias);
THVoidTensor_assertContiguous(running_mean);
THVoidTensor_assertContiguous(running_var);
THVoidTensor_assertContiguous(save_mean);
THVoidTensor_assertContiguous(save_var);
    CHECK(cudnnBatchNormalizationForwardTraining(
      handle, mode, &one, &zero,
      idesc.desc, tensorPointer(dataType, input),
@@ -91,6 +97,9 @@ void cudnn_batch_norm_forward(
      tensorPointer(dataType, save_mean),
      tensorPointer(dataType, save_var)));
  } else {
THVoidTensor_assertContiguous(bias);
THVoidTensor_assertContiguous(running_mean);
THVoidTensor_assertContiguous(running_var);
    CHECK(cudnnBatchNormalizationForwardInference(
      handle, mode, &one, &zero,
      idesc.desc, tensorPointer(dataType, input),
@@ -129,6 +138,10 @@ void cudnn_batch_norm_backward(
  Constant one(dataType, 1);
  Constant zero(dataType, 0);
THVoidTensor_assertContiguous(grad_weight);
THVoidTensor_assertContiguous(grad_bias);
THVoidTensor_assertContiguous(save_mean);
THVoidTensor_assertContiguous(save_var);
  CHECK(cudnnBatchNormalizationBackward(
    handle, mode, &one, &zero, &one, &one,
    idesc.desc, tensorPointer(dataType, input),

View File

@@ -2,6 +2,7 @@
#include "THC/THC.h"
#include "Exceptions.h"
#include "Types.h"

#include <cudnn.h>
#include <functional>
@@ -31,6 +32,7 @@ void setWeightDescriptor(FilterDescriptor& desc, cudnnDataType_t dataType, THVoi
{
  CHECK_ARG(weight->nDimension <= 5);
  int weightSize[5];
  THVoidTensor_assertContiguous(weight);
  for (int i = 0; i < weight->nDimension; ++i) {
    weightSize[i] = (int) weight->size[i];
  }
@@ -63,13 +65,13 @@ struct BenchmarkCache {
  std::mutex mutex;
  std::unordered_map<ConvolutionParams, T, ParamsHash, ParamsEqual> map;

  bool find(const ConvolutionParams& params, T* results) {
    std::lock_guard<std::mutex> guard(mutex);
    auto it = map.find(params);
    if (it == map.end()) {
      return false;
    }
    *results = it->second;
    return true;
  }
@@ -84,77 +86,145 @@ BenchmarkCache<cudnnConvolutionBwdDataAlgo_t> bwd_data_algos;
BenchmarkCache<cudnnConvolutionBwdFilterAlgo_t> bwd_filter_algos;

struct Workspace {
  Workspace(THCState* state, size_t size) : state(state), size(size), data(NULL) {
    CUDA_CHECK(THCudaMalloc(state, &data, size));
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  ~Workspace() {
    if (data) {
      THCudaFree(state, data);
    }
  }

  THCState* state;
  size_t size;
  void* data;
};
template<typename algo_t>
struct algorithm_search {
};

template<>
struct algorithm_search<cudnnConvolutionFwdAlgo_t> {
  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
  static BenchmarkCache<cudnnConvolutionFwdAlgo_t>& cache() {
    return fwd_algos;
  }

  static cudnnConvolutionFwdAlgoPerf_t findAlgorithm(cudnnHandle_t handle, const Convolution& conv) {
    int algoCount;
    cudnnConvolutionFwdAlgoPerf_t perfResults;
    CHECK(cudnnFindConvolutionForwardAlgorithm(handle, conv.idesc.desc,
        conv.wdesc.desc, conv.cdesc.desc, conv.odesc.desc, 1, &algoCount, &perfResults));
    return perfResults;
  }

  static void getAlgorithm(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionFwdAlgo_t* algo) {
    cudnnConvolutionFwdPreference_t pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
    CHECK(cudnnGetConvolutionForwardAlgorithm(handle, conv.idesc.desc,
        conv.wdesc.desc, conv.cdesc.desc, conv.odesc.desc, pref, 0, algo));
  }

  static void getWorkspaceSize(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionFwdAlgo_t algo, size_t* workspaceSize) {
    CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle, conv.idesc.desc, conv.wdesc.desc,
        conv.cdesc.desc, conv.odesc.desc, algo, workspaceSize));
  }
};

template<>
struct algorithm_search<cudnnConvolutionBwdDataAlgo_t> {
  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  static BenchmarkCache<cudnnConvolutionBwdDataAlgo_t>& cache() {
    return bwd_data_algos;
  }

  static cudnnConvolutionBwdDataAlgoPerf_t findAlgorithm(cudnnHandle_t handle, const Convolution& conv) {
    int algoCount;
    cudnnConvolutionBwdDataAlgoPerf_t perfResults;
    CHECK(cudnnFindConvolutionBackwardDataAlgorithm(handle, conv.wdesc.desc,
        conv.odesc.desc, conv.cdesc.desc, conv.idesc.desc, 1, &algoCount, &perfResults));
    return perfResults;
  }

  static void getAlgorithm(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdDataAlgo_t* algo) {
    CHECK(cudnnGetConvolutionBackwardDataAlgorithm(handle, conv.wdesc.desc,
        conv.odesc.desc, conv.cdesc.desc, conv.idesc.desc,
        CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, algo));
  }

  static void getWorkspaceSize(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdDataAlgo_t algo, size_t* workspaceSize) {
    CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(handle, conv.wdesc.desc,
        conv.odesc.desc, conv.cdesc.desc, conv.idesc.desc, algo,
        workspaceSize));
  }
};

template<>
struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
  static BenchmarkCache<cudnnConvolutionBwdFilterAlgo_t>& cache() {
    return bwd_filter_algos;
  }

  static cudnnConvolutionBwdFilterAlgoPerf_t findAlgorithm(cudnnHandle_t handle, const Convolution& conv) {
    int algoCount;
    cudnnConvolutionBwdFilterAlgoPerf_t perfResults;
    CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(handle, conv.idesc.desc,
        conv.odesc.desc, conv.cdesc.desc, conv.wdesc.desc, 1, &algoCount, &perfResults));
    return perfResults;
  }

  static void getAlgorithm(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdFilterAlgo_t* algo) {
    CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(handle, conv.idesc.desc,
        conv.odesc.desc, conv.cdesc.desc, conv.wdesc.desc,
        CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, algo));
  }

  static void getWorkspaceSize(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdFilterAlgo_t algo, size_t* workspaceSize) {
    CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle, conv.idesc.desc,
        conv.odesc.desc, conv.cdesc.desc, conv.wdesc.desc, algo, workspaceSize));
  }
};
template<typename algo_t>
Workspace chooseAlgorithm(
THCState* state, cudnnHandle_t handle, const Convolution& conv,
bool benchmark, algo_t* algo)
{
using search = algorithm_search<algo_t>;
auto& cache = search::cache();
if (!cache.find(conv.params, algo)) {
if (benchmark) {
auto perfResults = search::findAlgorithm(handle, conv);
if (perfResults.status == CUDNN_STATUS_SUCCESS) {
*algo = perfResults.algo;
} else {
*algo = search::DEFAULT_ALGO;
}
cache.insert(conv.params, *algo);
} else {
search::getAlgorithm(handle, conv, algo);
}
}
size_t workspace_size;
search::getWorkspaceSize(handle, conv, *algo, &workspace_size);
try {
return Workspace(state, workspace_size);
} catch (std::runtime_error& e) {
cudaGetLastError(); // clear OOM error
// switch to default algorithm and record it in the cache to prevent
// further OOM errors
*algo = search::DEFAULT_ALGO;
cache.insert(conv.params, *algo);
search::getWorkspaceSize(handle, conv, *algo, &workspace_size);
return Workspace(state, workspace_size);
  }
}
void* tensorPointer(cudnnDataType_t dataType, THVoidTensor* tensor, int groupIdx, int groups, int dim)
@@ -179,7 +249,7 @@ static_assert(std::is_pod<ConvolutionParams>::value, "ConvolutionParams not POD"
Convolution::Convolution(
    cudnnDataType_t dataType, THVoidTensor* input, THVoidTensor* weight,
    THVoidTensor* bias, THVoidTensor* output, std::vector<int> pad,
    std::vector<int> stride, std::vector<int> dilation, int groups, bool transposed)
  : idesc(), odesc(), odesc_bias(), bdesc(), wdesc(), cdesc(), groups(groups)
  , transposed(transposed)
{
@@ -197,6 +267,7 @@ Convolution::Convolution(
  for (size_t i = 0; i != pad.size(); ++i) {
    params.pad[i] = pad[i];
    params.stride[i] = stride[i];
    params.dilation[i] = dilation[i];
  }
  params.groups = groups;
  setTensorDescriptor(idesc, dataType, input, groups);
@@ -206,7 +277,7 @@ Convolution::Convolution(
  else
    setTensorDescriptor(odesc_bias, dataType, input, 1);
  setWeightDescriptor(wdesc, dataType, weight, groups);
  cdesc.set(dataType, pad.size(), pad.data(), stride.data(), dilation.data());
}
void cudnn_convolution_forward(
@@ -215,18 +286,9 @@ void cudnn_convolution_forward(
    Convolution* info, bool benchmark)
{
  int groups = info->groups;

  cudnnConvolutionFwdAlgo_t fwdAlg;
  Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, &fwdAlg);

  Constant one(dataType, 1);
  Constant zero(dataType, 0);
@@ -236,9 +298,9 @@ void cudnn_convolution_forward(
    void* weight_ptr = tensorPointer(dataType, weight, i, groups, 0);

    CHECK(cudnnConvolutionForward(
      handle, &one, info->idesc.desc, input_ptr, info->wdesc.desc,
      weight_ptr, info->cdesc.desc, fwdAlg, workspace.data,
      workspace.size, &zero, info->odesc.desc, output_ptr));
  }
}
@@ -248,7 +310,6 @@ void cudnn_convolution_add_bias(
    Convolution* info)
{
  CHECK_ARG(output->nDimension <= 5);
TensorDescriptor& odesc_bias = info->odesc_bias;
  TensorDescriptor& bdesc = info->bdesc;

  int size[5] = { 1, (int)bias->size[0], 1, 1, 1 };
@@ -260,7 +321,7 @@ void cudnn_convolution_add_bias(
  Constant one(dataType, 1);
  CHECK(cudnnAddTensor(handle, &one, bdesc.desc, bias_ptr, &one,
      info->odesc_bias.desc, output_ptr));
}
void cudnn_convolution_backward_data(
@@ -268,20 +329,11 @@ void cudnn_convolution_backward_data(
    THVoidTensor* gradOutput, THVoidTensor* gradInput, THVoidTensor* weight,
    Convolution* info, bool benchmark)
{
  int groups = info->params.groups;

  cudnnConvolutionBwdDataAlgo_t bwdDataAlg;
  Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, &bwdDataAlg);

  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  for (int i = 0; i < groups; ++i) {
@@ -290,9 +342,9 @@ void cudnn_convolution_backward_data(
    void* weight_ptr = tensorPointer(dataType, weight, i, groups, 0);

    CHECK(cudnnConvolutionBackwardData(
        handle, &one, info->wdesc.desc, weight_ptr, info->odesc.desc, gradOutput_ptr,
        info->cdesc.desc, bwdDataAlg, workspace.data, workspace.size, &zero,
        info->idesc.desc, gradInput_ptr));
  }
}
@@ -301,20 +353,11 @@ void cudnn_convolution_backward_filter(
    THVoidTensor* gradOutput, THVoidTensor* input, THVoidTensor* gradWeight,
    Convolution* info, bool benchmark)
{
  int groups = info->params.groups;

  cudnnConvolutionBwdFilterAlgo_t bwdFilterAlg;
  Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, &bwdFilterAlg);

  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  for (int i = 0; i < groups; ++i) {
@@ -327,9 +370,9 @@ void cudnn_convolution_backward_filter(
    }

    CHECK(cudnnConvolutionBackwardFilter(
        handle, &one, info->idesc.desc, input_ptr, info->odesc.desc, gradOutput_ptr,
        info->cdesc.desc, bwdFilterAlg, workspace.data, workspace.size, &zero,
        info->wdesc.desc, gradWeight_ptr));
  }
}
@@ -337,26 +380,23 @@ void cudnn_convolution_backward_bias(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* gradOutput, THVoidTensor* gradBias, Convolution* info)
{
  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  void* gradOutput_ptr = tensorPointer(dataType, gradOutput, 0, 1, 0);
  void* gradBias_ptr = tensorPointer(dataType, gradBias, 0, 1, 0);

  CHECK(cudnnConvolutionBackwardBias(
      handle, &one, info->odesc_bias.desc, gradOutput_ptr, &zero,
      info->bdesc.desc, gradBias_ptr));
}

Convolution* cudnn_convolution_full_forward(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* input, THVoidTensor* weight, THVoidTensor* bias, THVoidTensor* output,
    std::vector<int> pad, std::vector<int> stride, std::vector<int> dilation, int groups, bool benchmark)
{
  std::unique_ptr<Convolution> info(new Convolution(
      dataType, input, weight, bias, output, pad, stride, dilation, groups, false));
  cudnn_convolution_forward(
      state, handle, dataType, input, weight, output, info.get(), benchmark);
  if (bias) {
@@ -369,10 +409,10 @@ Convolution* cudnn_convolution_full_forward(
Convolution* cudnn_convolution_transpose_full_forward(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* input, THVoidTensor* weight, THVoidTensor* bias, THVoidTensor* output,
    std::vector<int> pad, std::vector<int> stride, std::vector<int> dilation, int groups, bool benchmark)
{
  std::unique_ptr<Convolution> info(new Convolution(
      dataType, output, weight, bias, input, pad, stride, dilation, groups, true));
  cudnn_convolution_backward_data(
      state, handle, dataType, input, output, weight, info.get(), benchmark);
  if (bias) {
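To recap the strategy that chooseAlgorithm() introduces in this file: look the algorithm up in a per-configuration cache, otherwise either benchmark with cudnnFind*Algorithm or ask the cudnnGet*Algorithm heuristic, then allocate the workspace, falling back to a known-cheap default algorithm (and caching it) if the allocation runs out of memory. A compact Python rendering of that control flow; every name below is illustrative and not part of the bindings:

_algo_cache = {}

def choose_algorithm(key, benchmark, find_best, heuristic,
                     workspace_size, allocate, default_algo):
    algo = _algo_cache.get(key)
    if algo is None:
        # benchmark once per convolution configuration, otherwise use the heuristic
        algo = find_best() if benchmark else heuristic()
        if benchmark:
            _algo_cache[key] = algo
    try:
        return algo, allocate(workspace_size(algo))
    except MemoryError:
        # fall back to the default algorithm and remember it, so the same
        # configuration does not keep triggering out-of-memory errors
        algo = default_algo
        _algo_cache[key] = algo
        return algo, allocate(workspace_size(algo))

algo, workspace = choose_algorithm(
    key=("3x3", "fp32"), benchmark=True,
    find_best=lambda: "winograd", heuristic=lambda: "implicit_gemm",
    workspace_size=lambda a: 1024, allocate=bytearray,
    default_algo="implicit_gemm")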

View File

@@ -18,6 +18,7 @@ struct ConvolutionParams
  int weight_size[5];
  int pad[3];
  int stride[3];
  int dilation[3];
  int groups;
};
@@ -41,7 +42,7 @@ struct Convolution
  Convolution(
      cudnnDataType_t dataType, THVoidTensor* input, THVoidTensor* weight,
      THVoidTensor* bias, THVoidTensor* output, std::vector<int> pad,
      std::vector<int> stride, std::vector<int> dilation, int groups, bool transposed);
};

void cudnn_convolution_forward(
@@ -73,12 +74,12 @@ void cudnn_convolution_backward_bias(
Convolution* cudnn_convolution_full_forward(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* input, THVoidTensor* weight, THVoidTensor* bias, THVoidTensor* output,
    std::vector<int> pad, std::vector<int> stride, std::vector<int> dilation, int groups, bool benchmark);

Convolution* cudnn_convolution_transpose_full_forward(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* input, THVoidTensor* weight, THVoidTensor* bias, THVoidTensor* output,
    std::vector<int> pad, std::vector<int> stride, std::vector<int> dilation, int groups, bool benchmark);

}} // namespace torch::cudnn

View File

@@ -62,10 +62,11 @@ struct ConvolutionDescriptor
  ~ConvolutionDescriptor() {
    cudnnDestroyConvolutionDescriptor(desc);
  }
  void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale) {
    cudnnDataType_t mathType = dataType;
    if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
    CHECK(cudnnSetConvolutionNdDescriptor(desc, dim, pad, stride, upscale,
        CUDNN_CROSS_CORRELATION, mathType));
  }
};
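For reference, the new upscale argument is cuDNN's name for dilation, and the half-to-float switch makes cuDNN run the convolution math in single precision even when the data is stored as half. Dilation changes how many input elements a kernel spans; a quick sanity check of the effective kernel extent along one dimension (standard formula, not specific to this code):

def effective_kernel_size(k, dilation):
    # footprint of a dilated kernel along one spatial dimension
    return dilation * (k - 1) + 1

assert effective_kernel_size(3, 1) == 3   # ordinary 3-tap kernel
assert effective_kernel_size(3, 2) == 5   # dilation 2 skips every other input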

View File

@@ -2,6 +2,7 @@
#define THP_CUDNN_EXCEPTIONS_INC

#include <cudnn.h>
#include <string>
#include <stdexcept>
#include <sstream>
@@ -14,13 +15,21 @@ namespace torch { namespace cudnn {
class cudnn_exception : public std::runtime_error {
public:
  cudnnStatus_t status;
  cudnn_exception(cudnnStatus_t status, const char* msg)
    : std::runtime_error(msg)
    , status(status) {}
  cudnn_exception(cudnnStatus_t status, const std::string& msg)
    : std::runtime_error(msg)
    , status(status) {}
};

inline void CHECK(cudnnStatus_t status)
{
  if (status != CUDNN_STATUS_SUCCESS) {
    if (status == CUDNN_STATUS_NOT_SUPPORTED) {
      throw cudnn_exception(status, std::string(cudnnGetErrorString(status)) +
          ". This error may appear if you passed in a non-contiguous input.");
    }
    throw cudnn_exception(status, cudnnGetErrorString(status));
  }
}
@@ -28,7 +37,9 @@ inline void CHECK(cudnnStatus_t status)
inline void CUDA_CHECK(cudaError_t error)
{
  if (error) {
    std::string msg("CUDA error: ");
    msg += cudaGetErrorString(error);
    throw std::runtime_error(msg);
  }
}

View File

@@ -1,12 +0,0 @@
#include "Module.h"
#include <Python.h>
#include "Types.h"
#include "Conv.h"
#include "CppWrapper.h"
#include "torch/csrc/THP.h"
bool THCUDNNModule_initModule(PyObject *module)
{
return THPWrapper_init(module);
}

View File

@@ -4,6 +4,5 @@
#include <Python.h>

PyMethodDef* THCUDNN_methods();
bool THCUDNNModule_initModule(PyObject *self);
#endif

View File

@@ -31,4 +31,16 @@ PyObject * getTensorClass(PyObject *args)
  return NULL;
}
void _THVoidTensor_assertContiguous(THVoidTensor *tensor, const std::string& name)
{
static const std::string error_str = "cuDNN requires contiguous ";
// Contiguity check
long long expectedStride = 1;
for (int i = tensor->nDimension-1; i >= 0; --i) {
if (tensor->stride[i] != expectedStride)
throw std::invalid_argument(error_str + name);
expectedStride *= tensor->size[i];
}
}
}} // namespace torch::cudnn
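The helper above encodes the usual dense row-major contiguity rule: walking the dimensions from innermost to outermost, each stride must equal the product of the sizes to its right. The same check in a few lines of Python, with sizes and strides as plain tuples (purely illustrative):

def is_contiguous(sizes, strides):
    expected = 1
    # walk from the innermost dimension outwards, as the C++ helper does
    for size, stride in zip(reversed(sizes), reversed(strides)):
        if stride != expected:
            return False
        expected *= size
    return True

assert is_contiguous((2, 3, 4), (12, 4, 1))        # standard row-major layout
assert not is_contiguous((2, 3, 4), (1, 2, 6))     # e.g. a permuted/transposed view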

View File

@@ -3,12 +3,18 @@
#include <Python.h>
#include <cstddef>
#include <string>

#include <cudnn.h>
#include "../Types.h"

namespace torch { namespace cudnn {

PyObject * getTensorClass(PyObject *args);
cudnnDataType_t getCudnnDataType(PyObject *tensorClass);

void _THVoidTensor_assertContiguous(THVoidTensor *tensor, const std::string& name);

#define THVoidTensor_assertContiguous(tensor) \
  _THVoidTensor_assertContiguous(tensor, #tensor " tensor")

}} // namespace torch::cudnn

View File

@@ -4,7 +4,7 @@
#include "BatchNorm.h"
#include "Conv.h"
#include "torch/csrc/cuda/THCP.h"
#include "../PtrWrapper.h"

using namespace torch::cudnn;
@@ -50,6 +50,7 @@ extern THCState* state;
  - THTensor* output
  - std::vector<int> pad
  - std::vector<int> stride
  - std::vector<int> dilation
  - int groups
  - bool benchmark
]]
@@ -68,6 +69,7 @@ extern THCState* state;
  - THTensor* output
  - std::vector<int> pad
  - std::vector<int> stride
  - std::vector<int> dilation
  - int groups
  - bool benchmark
]]

View File

@@ -0,0 +1,580 @@
#include <Python.h>
#include <memory>
#include <unordered_map>
#include <vector>
#include "THDP.h"
static std::unordered_map<std::string, THDChannelType> name2channel_type = {
{"mpi", THDChannelMPI},
{"tcp", THDChannelTCP},
};
static bool THDPModule_loadClasses(PyObject *module_dict)
{
#define ASSERT_NOT_NULL(ptr) if (!(ptr)) { THPUtils_setError("couldn't load classes"); return false; }
// TODO THD: enable once master-worker is implemented
#if 0
ASSERT_NOT_NULL(THDPDoubleStorageClass = PyMapping_GetItemString(module_dict, (char*)"DoubleStorage"));
ASSERT_NOT_NULL(THDPFloatStorageClass = PyMapping_GetItemString(module_dict, (char*)"FloatStorage"));
//ASSERT_NOT_NULL(THDPHalfStorageClass = PyMapping_GetItemString(module_dict, (char*)"HalfStorage"));
ASSERT_NOT_NULL(THDPLongStorageClass = PyMapping_GetItemString(module_dict, (char*)"LongStorage"));
ASSERT_NOT_NULL(THDPIntStorageClass = PyMapping_GetItemString(module_dict, (char*)"IntStorage"));
ASSERT_NOT_NULL(THDPShortStorageClass = PyMapping_GetItemString(module_dict, (char*)"ShortStorage"));
ASSERT_NOT_NULL(THDPCharStorageClass = PyMapping_GetItemString(module_dict, (char*)"CharStorage"));
ASSERT_NOT_NULL(THDPByteStorageClass = PyMapping_GetItemString(module_dict, (char*)"ByteStorage"));
ASSERT_NOT_NULL(THDPDoubleTensorClass = PyMapping_GetItemString(module_dict, (char*)"DoubleTensor"));
//ASSERT_NOT_NULL(THDPHalfTensorClass = PyMapping_GetItemString(module_dict, (char*)"HalfTensor"));
ASSERT_NOT_NULL(THDPFloatTensorClass = PyMapping_GetItemString(module_dict, (char*)"FloatTensor"));
ASSERT_NOT_NULL(THDPLongTensorClass = PyMapping_GetItemString(module_dict, (char*)"LongTensor"));
ASSERT_NOT_NULL(THDPIntTensorClass = PyMapping_GetItemString(module_dict, (char*)"IntTensor"));
ASSERT_NOT_NULL(THDPShortTensorClass = PyMapping_GetItemString(module_dict, (char*)"ShortTensor"));
ASSERT_NOT_NULL(THDPCharTensorClass = PyMapping_GetItemString(module_dict, (char*)"CharTensor"));
ASSERT_NOT_NULL(THDPByteTensorClass = PyMapping_GetItemString(module_dict, (char*)"ByteTensor"));
#endif
return true;
#undef ASSERT_NOT_NULL
}
static std::unordered_map<PyObject*, THDReduceOp> obj2reduceop;
static std::unordered_map<PyObject*, THDGroup> obj2group;
static THPObjectPtr _ensureBytes(PyObject *obj)
{
#if PY_MAJOR_VERSION == 2
if (PyString_Check(obj)) {
#elif PY_MAJOR_VERSION == 3
if (PyBytes_Check(obj)) {
#endif
Py_INCREF(obj);
return obj;
}
if (PyUnicode_Check(obj)) {
return PyUnicode_AsASCIIString(obj);
}
return NULL;
}
PyObject* THDPModule_initProcessGroup(PyObject *_unused, PyObject *_backend)
{
HANDLE_TH_ERRORS
THPObjectPtr backend_bytes = _ensureBytes(_backend);
THPUtils_assert(backend_bytes, "backend argument has to be a string/bytes "
"object, but got %s", THPUtils_typename(_backend));
char *backend_name = THPUtils_bytesAsString(backend_bytes.get());
THDChannelType channel_type = name2channel_type.at(backend_name);
THPUtils_assert(THDProcessGroupInit(channel_type), "failed to initialize "
"distributed library (THD)");
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_initMasterWorker(PyObject *_unused, PyObject *_backend)
{
HANDLE_TH_ERRORS
THPObjectPtr backend_bytes = _ensureBytes(_backend);
THPUtils_assert(backend_bytes, "backend argument has to be a string/bytes "
"object, but got %s", THPUtils_typename(_backend));
char *backend_name = THPUtils_bytesAsString(backend_bytes.get());
THDChannelType channel_type = name2channel_type.at(backend_name);
THPUtils_assert(THDMasterWorkerInit(channel_type), "failed to initialize "
"distributed library (THD)");
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_getRank(PyObject *_unused)
{
HANDLE_TH_ERRORS
return PyInt_FromLong(THDGetRank());
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_getNumProcesses(PyObject *_unused)
{
HANDLE_TH_ERRORS
return PyInt_FromLong(THDGetNumProcesses());
END_HANDLE_TH_ERRORS
}
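These entry points are registered in the method table at the bottom of this file as flat functions on torch._C (_dist_init_process_group, _dist_get_rank, _dist_get_num_processes, ...). A minimal usage sketch, assuming a build with WITH_DISTRIBUTED and calling the raw bindings directly rather than any higher-level torch.distributed wrapper (which is not part of this diff):

import torch

# "tcp" and "mpi" are the keys registered in name2channel_type above;
# THD reads the actual connection details from its own configuration
torch._C._dist_init_process_group("tcp")

rank = torch._C._dist_get_rank()
world_size = torch._C._dist_get_num_processes()
print("process {} of {}".format(rank, world_size))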
static THDTensorDescriptor* _makeDescriptor(PyObject *obj)
{
PyObject *type = (PyObject*)Py_TYPE(obj);
#define REGISTER_TH_DESCRIPTOR(TYPE) \
if (type == THP##TYPE##Class) \
return THDTensorDescriptor_newFromTH##TYPE(((THP##TYPE*)obj)->cdata);
REGISTER_TH_DESCRIPTOR(DoubleTensor);
REGISTER_TH_DESCRIPTOR(FloatTensor);
REGISTER_TH_DESCRIPTOR(LongTensor);
REGISTER_TH_DESCRIPTOR(IntTensor);
REGISTER_TH_DESCRIPTOR(ShortTensor);
REGISTER_TH_DESCRIPTOR(CharTensor);
REGISTER_TH_DESCRIPTOR(ByteTensor);
#undef REGISTER_TH_DESCRIPTOR
throw std::runtime_error(std::string("don't know how to create a THDTensorDescriptor for "
"type ") + std::string(THPUtils_typename(obj)));
}
static THDRequest* _unpackRequest(PyObject *obj)
{
return static_cast<THDRequest*>(THPWrapper_get(obj));
}
static THDReduceOp _getReduceOp(PyObject *obj)
{
auto it = obj2reduceop.find(obj);
if (it == obj2reduceop.end()) {
throw std::runtime_error("op should be a constant from "
"torch.distributed.reduce_op");
}
return it->second;
}
static THDGroup _getGroup(PyObject *obj)
{
auto it = obj2group.find(obj);
if (it == obj2group.end()) {
if (!THPUtils_checkLong(obj))
throw std::runtime_error("group should be an int or one of the values "
"from torch.distributed.group");
return THPUtils_unpackLong(obj);
}
return it->second;
}
PyObject* THDPModule_isend(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
THPUtils_invalidArguments(args, NULL, "isend", 1, "(tensor input, int dst_rank)");
return NULL;
}
THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
return THPWrapper_New(THDIsend(desc, dst_rank), (void(*)(void*))THDRequest_free);
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_irecv(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
THPUtils_invalidArguments(args, NULL, "irecv", 1, "(tensor output, int src_rank)");
return NULL;
}
THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
return THPWrapper_New(THDIrecv(desc, src_rank), (void(*)(void*))THDRequest_free);
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_send(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
THPUtils_invalidArguments(args, NULL, "send", 1, "(tensor input, int dst_rank)");
return NULL;
}
THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
THDSend(desc, dst_rank);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_recvAnySource(PyObject *_unused, PyObject *_tensor)
{
HANDLE_TH_ERRORS
if (!THPModule_isTensor(_tensor)) {
THPUtils_invalidArguments(_tensor, NULL, "recv", 1, "(tensor output)");
return NULL;
}
THDPTensorDesc desc = _makeDescriptor(_tensor);
THDRecvAnySource(desc);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_recv(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
THPUtils_invalidArguments(args, NULL, "recv", 1, "(tensor output, int src_rank)");
return NULL;
}
THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
THDRecv(desc, src_rank);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_allReduce(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
if (PyTuple_GET_SIZE(args) != 3 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0))) {
THPUtils_invalidArguments(args, NULL, "all_reduce", 1, "(tensor in_out, reduce_op op, group gr)");
return NULL;
}
THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
THDReduceOp op = _getReduceOp(PyTuple_GET_ITEM(args, 1));
THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
THDAllReduce(desc, op, group);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_reduce(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
if (PyTuple_GET_SIZE(args) != 4 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
THPUtils_invalidArguments(args, NULL, "reduce", 1,
"(tensor reduced, int dst_rank, reduce_op op, group gr)");
return NULL;
}
THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 3));
THDReduceOp op = _getReduceOp(PyTuple_GET_ITEM(args, 2));
THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
THDReduce(desc, op, dst_rank, group);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_broadcast(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
if (PyTuple_GET_SIZE(args) != 3 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
THPUtils_invalidArguments(args, NULL, "broadcast", 1,
"(tensor src_dst, int src_rank, group gr)");
return NULL;
}
THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
THDBroadcast(desc, src_rank, group);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_allGather(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
PyObject* sequence = PyTuple_GET_ITEM(args, 0);
Py_ssize_t tmp_length;
std::size_t length;
std::vector<THDPTensorDesc> descriptors;
std::vector<THDTensorDescriptor*> raw_descriptors;
if (PyTuple_GET_SIZE(args) != 3 || !PySequence_Check(sequence) ||
!THPModule_isTensor(PyTuple_GET_ITEM(args, 1))) {
goto invalid_arguments;
}
tmp_length = PySequence_Length(sequence);
THPUtils_assert(tmp_length >= 0, "couldn't obtain the length of %s",
THPUtils_typename(sequence));
length = static_cast<std::size_t>(tmp_length);
descriptors.reserve(length);
for (std::size_t i = 0; i < length; ++i) {
if (!THPModule_isTensor(PySequence_ITEM(sequence, i)))
goto invalid_arguments;
descriptors.push_back(
THDPTensorDesc(_makeDescriptor(PySequence_ITEM(sequence, i)))
);
raw_descriptors.push_back(descriptors.back());
}
THDAllGather(
raw_descriptors.data(), length,
THDPTensorDesc(_makeDescriptor(PyTuple_GET_ITEM(args, 1))),
_getGroup(PyTuple_GET_ITEM(args, 2))
);
Py_RETURN_NONE;
invalid_arguments:
THPUtils_invalidArguments(args, NULL, "allGather", 1,
"(list[tensor] output, tensor input, group gr)");
return NULL;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_gatherSend(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
if (PyTuple_GET_SIZE(args) != 3 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0))) {
THPUtils_invalidArguments(args, NULL, "gatherSend", 1,
"(tensor input, int dst_rank, group gr)");
return NULL;
}
THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
THDGatherSend(desc, dst_rank, group);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_gatherRecv(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
PyObject* sequence = PyTuple_GET_ITEM(args, 0);
Py_ssize_t tmp_length;
std::size_t length;
std::vector<THDPTensorDesc> descriptors;
std::vector<THDTensorDescriptor*> raw_descriptors;
if (PyTuple_GET_SIZE(args) != 3 || !PySequence_Check(sequence) ||
!THPModule_isTensor(PyTuple_GET_ITEM(args, 1))) {
goto invalid_arguments;
}
tmp_length = PySequence_Length(sequence);
THPUtils_assert(tmp_length >= 0, "couldn't obtain the length of %s",
THPUtils_typename(sequence));
length = static_cast<std::size_t>(tmp_length);
descriptors.reserve(length);
for (std::size_t i = 0; i < length; ++i) {
if (!THPModule_isTensor(PySequence_ITEM(sequence, i)))
goto invalid_arguments;
descriptors.push_back(
THDPTensorDesc(_makeDescriptor(PySequence_ITEM(sequence, i)))
);
raw_descriptors.push_back(descriptors.back());
}
THDGatherRecv(
raw_descriptors.data(), length,
THDPTensorDesc(_makeDescriptor(PyTuple_GET_ITEM(args, 1))),
_getGroup(PyTuple_GET_ITEM(args, 2))
);
Py_RETURN_NONE;
invalid_arguments:
THPUtils_invalidArguments(args, NULL, "gatherRecv", 1,
"(list[tensor] output, tensor input, group gr)");
return NULL;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_scatterSend(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
PyObject* sequence = PyTuple_GET_ITEM(args, 0);
Py_ssize_t tmp_length;
std::size_t length;
std::vector<THDPTensorDesc> descriptors;
std::vector<THDTensorDescriptor*> raw_descriptors;
if (PyTuple_GET_SIZE(args) != 3 || !PySequence_Check(sequence) ||
!THPModule_isTensor(PyTuple_GET_ITEM(args, 1))) {
goto invalid_arguments;
}
tmp_length = PySequence_Length(sequence);
THPUtils_assert(tmp_length >= 0, "couldn't obtain the length of %s",
THPUtils_typename(sequence));
length = static_cast<std::size_t>(tmp_length);
descriptors.reserve(length);
for (std::size_t i = 0; i < length; ++i) {
if (!THPModule_isTensor(PySequence_ITEM(sequence, i)))
goto invalid_arguments;
descriptors.push_back(
THDPTensorDesc(_makeDescriptor(PySequence_ITEM(sequence, i)))
);
raw_descriptors.push_back(descriptors.back());
}
THDScatterSend(
raw_descriptors.data(), length,
THDPTensorDesc(_makeDescriptor(PyTuple_GET_ITEM(args, 1))),
_getGroup(PyTuple_GET_ITEM(args, 2))
);
Py_RETURN_NONE;
invalid_arguments:
THPUtils_invalidArguments(args, NULL, "scatterSend", 1,
"(list[tensor] input, tensor output, group gr)");
return NULL;
END_HANDLE_TH_ERRORS
}
PyObject* THDPModule_scatterRecv(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
if (PyTuple_GET_SIZE(args) != 3 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
THPUtils_invalidArguments(args, NULL, "scatterRecv", 1,
"(tensor output, int src_rank, group gr)");
return NULL;
}
THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
THDScatterRecv(desc, src_rank, group);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_barrier(PyObject *_unused, PyObject *_group)
{
  HANDLE_TH_ERRORS
  THDBarrier(_getGroup(_group));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_newGroup(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  PyObject* sequence = PyTuple_GET_ITEM(args, 0);
  Py_ssize_t tmp_length;
  std::size_t length;
  std::vector<int> ranks;

  if (PyTuple_GET_SIZE(args) != 1 || !PySequence_Check(sequence))
    goto invalid_arguments;

  tmp_length = PySequence_Length(sequence);
  THPUtils_assert(tmp_length >= 0, "couldn't obtain the length of %s",
      THPUtils_typename(sequence));
  length = static_cast<std::size_t>(tmp_length);
  ranks.reserve(length);
  for (std::size_t i = 0; i < length; ++i) {
    if (!THPUtils_checkLong(PySequence_ITEM(sequence, i)))
      goto invalid_arguments;

    ranks.push_back(THPUtils_unpackLong(PySequence_ITEM(sequence, i)));
    // reject duplicate ranks up front, before handing the list to THD
    for (std::size_t j = 0; j < i; ++j)
      THPUtils_assert(ranks[i] != ranks[j], "ranks should be unique");
  }

  return PyInt_FromLong(THDNewGroup(ranks.data(), length));

invalid_arguments:
  THPUtils_invalidArguments(args, NULL, "newGroup", 1, "(list[int] ranks)");
  return NULL;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_requestIsCompleted(PyObject *_unused, PyObject *_req)
{
  HANDLE_TH_ERRORS
  if (!THPWrapper_check(_req)) {
    THPUtils_invalidArguments(_req, NULL, "requestIsCompleted", 1, "(request req)");
    return NULL;
  }
  return PyBool_FromLong(THDRequest_isCompleted(_unpackRequest(_req)));
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_requestWait(PyObject *_unused, PyObject *_req)
{
  HANDLE_TH_ERRORS
  if (!THPWrapper_check(_req)) {
    THPUtils_invalidArguments(_req, NULL, "requestWait", 1, "(request req)");
    return NULL;
  }
  THDRequest_wait(_unpackRequest(_req));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *args) {
  if (PyTuple_GET_SIZE(args) != 3) {
    THPUtils_invalidArguments(args, NULL, "initExtension", 1, "(bool is_master_worker, reduce_op obj, group obj)");
    return NULL;
  }

  PyObject* is_master_worker_obj = PyTuple_GET_ITEM(args, 0);
  PyObject* reduce_op_obj = PyTuple_GET_ITEM(args, 1);
  PyObject* group_obj = PyTuple_GET_ITEM(args, 2);

  THPUtils_assert(PyBool_Check(is_master_worker_obj), "first argument should be a bool");
  bool is_master_worker = is_master_worker_obj == Py_True;

  // Map the members of the Python-level reduce_op object onto THD reduce ops.
  THPObjectPtr reduce_op;
#define REGISTER_REDUCE_OP(NAME)                                      \
  reduce_op = PyObject_GetAttrString(reduce_op_obj, #NAME);           \
  THPUtils_assert(reduce_op, "Missing object for reduce op " #NAME);  \
  obj2reduceop.emplace(reduce_op.get(), THDReduce##NAME);
  REGISTER_REDUCE_OP(SUM);
  REGISTER_REDUCE_OP(PRODUCT);
  REGISTER_REDUCE_OP(MIN);
  REGISTER_REDUCE_OP(MAX);
#undef REGISTER_REDUCE_OP

  // Same mapping for the predefined process groups.
  THPObjectPtr group;
#define REGISTER_GROUP(NAME)                                  \
  group = PyObject_GetAttrString(group_obj, #NAME);           \
  THPUtils_assert(group, "Missing object for group " #NAME);  \
  obj2group.emplace(group.get(), THDGroup##NAME);
  REGISTER_GROUP(WORLD);
#undef REGISTER_GROUP

  if (is_master_worker) {
    PyObject *module = PyImport_ImportModule("torch.distributed");
    THPUtils_assert(module, "class loader couldn't access torch.distributed module");
    PyObject* module_dict = PyModule_GetDict(module);
    if (!THDPModule_loadClasses(module_dict)) return NULL;
  }
  Py_RETURN_TRUE;
}
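
// The two maps filled above (obj2reduceop, obj2group) are what let the
// collective wrappers translate the Python-level constants back into THD
// enum values. A minimal sketch of such a lookup -- the helper name, the
// THDReduceOp spelling and the error text are illustrative assumptions,
// not part of this diff:
//
//   static THDReduceOp _getReduceOp(PyObject* obj) {
//     auto it = obj2reduceop.find(obj);
//     THPUtils_assert(it != obj2reduceop.end(),
//         "op should be a constant from torch.distributed.reduce_op");
//     return it->second;
//   }
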
static struct PyMethodDef _THDPModule_methods[] = {
  {"_dist_init_extension", (PyCFunction)THDPModule_initExtension, METH_VARARGS, NULL},
  {"_dist_init_process_group", (PyCFunction)THDPModule_initProcessGroup, METH_O, NULL},
  {"_dist_init_master_worker", (PyCFunction)THDPModule_initMasterWorker, METH_O, NULL},
  {"_dist_get_rank", (PyCFunction)THDPModule_getRank, METH_NOARGS, NULL},
  {"_dist_get_num_processes", (PyCFunction)THDPModule_getNumProcesses, METH_NOARGS, NULL},
  {"_dist_isend", (PyCFunction)THDPModule_isend, METH_VARARGS, NULL},
  {"_dist_irecv", (PyCFunction)THDPModule_irecv, METH_VARARGS, NULL},
  {"_dist_send", (PyCFunction)THDPModule_send, METH_VARARGS, NULL},
  {"_dist_recv_any_source", (PyCFunction)THDPModule_recvAnySource, METH_O, NULL},
  {"_dist_recv", (PyCFunction)THDPModule_recv, METH_VARARGS, NULL},
  {"_dist_all_reduce", (PyCFunction)THDPModule_allReduce, METH_VARARGS, NULL},
  {"_dist_reduce", (PyCFunction)THDPModule_reduce, METH_VARARGS, NULL},
  {"_dist_broadcast", (PyCFunction)THDPModule_broadcast, METH_VARARGS, NULL},
  {"_dist_all_gather", (PyCFunction)THDPModule_allGather, METH_VARARGS, NULL},
  {"_dist_gather_send", (PyCFunction)THDPModule_gatherSend, METH_VARARGS, NULL},
  {"_dist_gather_recv", (PyCFunction)THDPModule_gatherRecv, METH_VARARGS, NULL},
  {"_dist_scatter_send", (PyCFunction)THDPModule_scatterSend, METH_VARARGS, NULL},
  {"_dist_scatter_recv", (PyCFunction)THDPModule_scatterRecv, METH_VARARGS, NULL},
  {"_dist_barrier", (PyCFunction)THDPModule_barrier, METH_O, NULL},
  {"_dist_new_group", (PyCFunction)THDPModule_newGroup, METH_VARARGS, NULL},
  {"_dist_request_is_completed", (PyCFunction)THDPModule_requestIsCompleted, METH_O, NULL},
  {"_dist_request_wait", (PyCFunction)THDPModule_requestWait, METH_O, NULL},
  {NULL}
};

PyMethodDef* THDPModule_methods() {
  return _THDPModule_methods;
}
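
THDPModule_methods() simply exposes the table above to whatever code initializes the extension. As a rough sketch of how an init routine could attach the table to an already-created module object on Python 3 (the helper name and the `module` argument are illustrative assumptions, not code from this diff):

#if PY_MAJOR_VERSION >= 3
static bool THDPModule_addMethods(PyObject *module) {
  // PyModule_AddFunctions (Python >= 3.5) appends every entry of the table
  // to `module`; the PyMethodDef array must outlive the module, which the
  // static table above guarantees.
  return PyModule_AddFunctions(module, THDPModule_methods()) == 0;
}
#endif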


@@ -0,0 +1,14 @@
#include <Python.h>
#include <structmember.h>
#include <stdbool.h>
#include "THDP.h"
#include "override_macros.h"
#define THD_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
#include <THD/base/THDGenerateAllTypes.h>
//#define THD_GENERIC_FILE "torch/csrc/generic/StorageCopy.cpp"
//#include <THD/THDGenerateAllTypes.h>
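
The two active lines above use TH's generate-all-types pattern: THDGenerateAllTypes.h re-includes the file named by THD_GENERIC_FILE once per scalar type, rebinding the type macros (real, Real, ...) on each pass, so one generic Storage.cpp produces bindings for every storage type. A self-contained sketch of the same per-type stamping idea, written with an X-macro instead of re-inclusion (all names are illustrative, not the actual TH machinery):

#include <cstdio>

// Stand-in for the per-type expansion the generator header performs.
#define DEMO_FOR_ALL_TYPES(_) _(float, Float) _(double, Double) _(int, Int)

// "Generic file" body: one function gets stamped out per scalar type.
#define DEMO_DEFINE_FILL(real, Real)                                 \
  static void Demo##Real##Storage_fill(real *data, int n, real v) {  \
    for (int i = 0; i < n; ++i) data[i] = v;                         \
  }

DEMO_FOR_ALL_TYPES(DEMO_DEFINE_FILL)
#undef DEMO_DEFINE_FILL

int main() {
  float buf[4];
  DemoFloatStorage_fill(buf, 4, 1.5f);  // the Float instantiation
  std::printf("%f\n", buf[0]);          // prints 1.500000
  return 0;
}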


@@ -0,0 +1,45 @@
#ifndef THDP_STORAGE_INC
#define THDP_STORAGE_INC
#define THDPStorage TH_CONCAT_3(THDP,Real,Storage)
#define THDPStorageStr TH_CONCAT_STRING_3(torch.distributed.,Real,Storage)
#define THDPStorageClass TH_CONCAT_3(THDP,Real,StorageClass)
#define THDPStorage_(NAME) TH_CONCAT_4(THDP,Real,Storage_,NAME)
#define THDPDoubleStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPDoubleStorageClass)
#define THDPFloatStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPFloatStorageClass)
#define THDPHalfStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPHalfStorageClass)
#define THDPLongStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPLongStorageClass)
#define THDPIntStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPIntStorageClass)
#define THDPShortStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPShortStorageClass)
#define THDPCharStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPCharStorageClass)
#define THDPByteStorage_Check(obj) \
  PyObject_IsInstance(obj, THDPByteStorageClass)
#define THDPDoubleStorage_CData(obj) (obj)->cdata
#define THDPFloatStorage_CData(obj) (obj)->cdata
#define THDPHalfStorage_CData(obj) (obj)->cdata
#define THDPLongStorage_CData(obj) (obj)->cdata
#define THDPIntStorage_CData(obj) (obj)->cdata
#define THDPShortStorage_CData(obj) (obj)->cdata
#define THDPCharStorage_CData(obj) (obj)->cdata
#define THDPByteStorage_CData(obj) (obj)->cdata
#ifdef _THP_CORE
#define THDPStorageType TH_CONCAT_3(THDP,Real,StorageType)
#define THDPStorageBaseStr TH_CONCAT_STRING_3(Distributed,Real,StorageBase)
#endif
#include "override_macros.h"
#define THD_GENERIC_FILE "torch/csrc/generic/Storage.h"
#include <THD/base/THDGenerateAllTypes.h>
#endif
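
For orientation, the generic Storage.h included just above is instantiated once per type by THDGenerateAllTypes.h; with Real bound to Float, for instance, the concatenation macros at the top of this header resolve roughly as follows (illustrative expansion only):

// THDPStorage        -> THDPFloatStorage
// THDPStorageClass   -> THDPFloatStorageClass
// THDPStorage_(new)  -> THDPFloatStorage_new
// THDPStorageStr     -> "torch.distributed.FloatStorage"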

Some files were not shown because too many files have changed in this diff.