Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 08:00:58 +08:00)

Compare commits: ciflow/pul ... v1.1.0 (19 commits)
Commits (SHA1):
0b868b1906, fb6347d937, 1bb8cfcc5a, 142c973f41, e39ab6632f, 20607a99a3,
f0bc8d1dc5, cca6aca5d2, 5a5ff34ff1, 63b2ecd934, 82f6886f73, fbe8a37832,
89748dd0dd, 092bcc9c69, 54f9440479, 4adc14da61, ea1d0eeb92, c7ad499b33,
16f2b22120
@@ -143,7 +143,7 @@ install_doc_push_script: &install_doc_push_script
     if [ "\$is_master_doc" = true ]; then
       find "\$install_path" -name "*.html" -print0 | xargs -0 perl -pi -w -e "s@master\s+\((\d\.\d\.[A-Fa-f0-9]+\+[A-Fa-f0-9]+)\s+\)@<a href='http://pytorch.org/docs/versions.html'>\1 \▼</a>@g"
     else
-      find "\$install_path" -name "*.html" -print0 | xargs -0 perl -pi -w -e "s@master\s+\((\d\.\d\.[A-Fa-f0-9]+\+[A-Fa-f0-9]+)\s+\)@<a href='http://pytorch.org/docs/versions.html'>\$version \▼</a>@g"
+      find "\$install_path" -name "*.html" -print0 | xargs -0 perl -pi -w -e "s@master\s+\((\d\.\d\.\S+)\s+\)@<a href='http://pytorch.org/docs/versions.html'>\$version \▼</a>@g"
    fi

    git add "\$install_path" || true
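Why the second pattern is loosened: master/nightly doc builds embed a version of the form "x.y.z+sha", which the hex-and-plus character class expects, while a stable-release build reports a bare version like "1.1.0" that the old pattern never matched. A minimal Python sketch of the difference (the sample strings are illustrative assumptions, not taken from the build output):

import re

strict = r"master\s+\((\d\.\d\.[A-Fa-f0-9]+\+[A-Fa-f0-9]+)\s+\)"
loose = r"master\s+\((\d\.\d\.\S+)\s+\)"

nightly = "master (1.1.0a0+9eb0f43 )"   # hypothetical nightly label
release = "master (1.1.0 )"             # hypothetical release label

assert re.search(strict, nightly)           # old pattern matches nightlies...
assert re.search(strict, release) is None   # ...but not a plain release string
assert re.search(loose, nightly)            # the relaxed pattern covers both
assert re.search(loose, release)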
@@ -1169,11 +1169,11 @@ jobs:
             export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && ./doc_push_script.sh docs/master master") | docker exec -u jenkins -i "$id" bash) 2>&1'

           # stable release docs push. Due to some circleci limitations, we keep
-          # an eternal PR open (#16502) for merging v1.0.1 -> master for this job.
-          # XXX: The following code is only run on the v1.0.1 branch, which might
+          # an eternal PR open for merging v1.1.0 -> master for this job.
+          # XXX: The following code is only run on the v1.1.0 branch, which might
           # not be exactly the same as what you see here.
-          elif [[ "${CIRCLE_BRANCH}" == "v1.0.1" ]]; then
-            export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && ./doc_push_script.sh docs/stable 1.0.1") | docker exec -u jenkins -i "$id" bash) 2>&1'
+          elif [[ "${CIRCLE_BRANCH}" == "v1.1.0" ]]; then
+            export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && ./doc_push_script.sh docs/stable 1.1.0") | docker exec -u jenkins -i "$id" bash) 2>&1'

           # For open PRs: Do a dry_run of the docs build, don't push build
           else
@@ -7,8 +7,8 @@ cat >/home/circleci/project/login_to_anaconda.sh <<EOL
 set +x
 echo "Trying to login to Anaconda"
 yes | anaconda login \
-    --username "$PYTORCH_BINARY_PJH5_CONDA_USERNAME" \
-    --password "$PYTORCH_BINARY_PJH5_CONDA_PASSWORD"
+    --username "$PYTORCH_BINARY_SOUMITH_CONDA_USERNAME" \
+    --password "$PYTORCH_BINARY_SOUMITH_CONDA_PASSWORD"
 set -x
 EOL
 chmod +x /home/circleci/project/login_to_anaconda.sh
@@ -24,7 +24,7 @@ pushd /home/circleci/project/final_pkgs
 if [[ "$PACKAGE_TYPE" == conda ]]; then
   retry conda install -yq anaconda-client
   retry timeout 30 /home/circleci/project/login_to_anaconda.sh
-  anaconda upload "$(ls)" -u pytorch --label main --no-progress --force
+  anaconda upload "$(ls)" -u pytorch-testing --label main --no-progress --force
 elif [[ "$PACKAGE_TYPE" == libtorch ]]; then
   retry pip install -q awscli
   s3_dir="s3://pytorch/libtorch/${PIP_UPLOAD_FOLDER}${DESIRED_CUDA}/"
@@ -6,8 +6,8 @@ cat >/Users/distiller/project/login_to_anaconda.sh <<EOL
 set +x
 echo "Trying to login to Anaconda"
 yes | anaconda login \
-    --username "$PYTORCH_BINARY_PJH5_CONDA_USERNAME" \
-    --password "$PYTORCH_BINARY_PJH5_CONDA_PASSWORD"
+    --username "$PYTORCH_BINARY_SOUMITH_CONDA_USERNAME" \
+    --password "$PYTORCH_BINARY_SOUMITH_CONDA_PASSWORD"
 set -x
 EOL
 chmod +x /Users/distiller/project/login_to_anaconda.sh
@@ -24,7 +24,7 @@ pushd "$workdir/final_pkgs"
 if [[ "$PACKAGE_TYPE" == conda ]]; then
   retry conda install -yq anaconda-client
   retry /Users/distiller/project/login_to_anaconda.sh
-  retry anaconda upload "$(ls)" -u pytorch --label main --no-progress --force
+  retry anaconda upload "$(ls)" -u pytorch-testing --label main --no-progress --force
 elif [[ "$PACKAGE_TYPE" == libtorch ]]; then
   retry pip install -q awscli
   s3_dir="s3://pytorch/libtorch/${PIP_UPLOAD_FOLDER}${DESIRED_CUDA}/"
@@ -40,19 +40,19 @@ fi

 # Upload to parallel folder for gcc abis
 if [[ "$DESIRED_DEVTOOLSET" == 'devtoolset7' ]]; then
-  export PIP_UPLOAD_FOLDER='nightly/devtoolset7/'
+  export PIP_UPLOAD_FOLDER='devtoolset7/'
   if [[ "$PACKAGE_TYPE" == 'conda' ]]; then
     echo "We don't handle conda builds with gcc ABI of 1, since we don't"
     echo "want to add a new package name to the conda builds"
     exit 1
   fi
 else
-  export PIP_UPLOAD_FOLDER='nightly/'
+  export PIP_UPLOAD_FOLDER=''
 fi

 # We put this here so that OVERRIDE_PACKAGE_VERSION below can read from it
 export DATE="$(date -u +%Y%m%d)"
-export PYTORCH_BUILD_VERSION="1.1.0.dev$DATE"
+export PYTORCH_BUILD_VERSION="1.1.0"
 export PYTORCH_BUILD_NUMBER=1

 cat >>"$envfile" <<EOL
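A hedged Python sketch (not part of the diff) of the version logic this hunk flips from nightly to release: nightly wheels carry a dated ".dev" suffix, release wheels do not. The date format mirrors the script's `date -u +%Y%m%d`; the function name is illustrative:

from datetime import datetime, timezone

def pytorch_build_version(release):
    if release:
        return "1.1.0"
    date = datetime.now(timezone.utc).strftime("%Y%m%d")
    return "1.1.0.dev" + date

print(pytorch_build_version(release=False))  # e.g. 1.1.0.dev20190430
print(pytorch_build_version(release=True))   # 1.1.0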
@@ -73,8 +73,8 @@ export PYTORCH_BUILD_VERSION="$PYTORCH_BUILD_VERSION"
 export PYTORCH_BUILD_NUMBER="$PYTORCH_BUILD_NUMBER"
 export OVERRIDE_PACKAGE_VERSION="$PYTORCH_BUILD_VERSION"

-export TORCH_PACKAGE_NAME='torch-nightly'
-export TORCH_CONDA_BUILD_FOLDER='pytorch-nightly'
+export TORCH_PACKAGE_NAME='torch'
+export TORCH_CONDA_BUILD_FOLDER='pytorch-1.1.0'

 export NO_FBGEMM=1
 export PIP_UPLOAD_FOLDER="$PIP_UPLOAD_FOLDER"
@@ -143,7 +143,7 @@ install_doc_push_script: &install_doc_push_script
     if [ "\$is_master_doc" = true ]; then
       find "\$install_path" -name "*.html" -print0 | xargs -0 perl -pi -w -e "s@master\s+\((\d\.\d\.[A-Fa-f0-9]+\+[A-Fa-f0-9]+)\s+\)@<a href='http://pytorch.org/docs/versions.html'>\1 \▼</a>@g"
     else
-      find "\$install_path" -name "*.html" -print0 | xargs -0 perl -pi -w -e "s@master\s+\((\d\.\d\.[A-Fa-f0-9]+\+[A-Fa-f0-9]+)\s+\)@<a href='http://pytorch.org/docs/versions.html'>\$version \▼</a>@g"
+      find "\$install_path" -name "*.html" -print0 | xargs -0 perl -pi -w -e "s@master\s+\((\d\.\d\.\S+)\s+\)@<a href='http://pytorch.org/docs/versions.html'>\$version \▼</a>@g"
    fi

    git add "\$install_path" || true
@@ -62,11 +62,11 @@
             export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && ./doc_push_script.sh docs/master master") | docker exec -u jenkins -i "$id" bash) 2>&1'

           # stable release docs push. Due to some circleci limitations, we keep
-          # an eternal PR open (#16502) for merging v1.0.1 -> master for this job.
-          # XXX: The following code is only run on the v1.0.1 branch, which might
+          # an eternal PR open for merging v1.1.0 -> master for this job.
+          # XXX: The following code is only run on the v1.1.0 branch, which might
           # not be exactly the same as what you see here.
-          elif [[ "${CIRCLE_BRANCH}" == "v1.0.1" ]]; then
-            export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && ./doc_push_script.sh docs/stable 1.0.1") | docker exec -u jenkins -i "$id" bash) 2>&1'
+          elif [[ "${CIRCLE_BRANCH}" == "v1.1.0" ]]; then
+            export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && ./doc_push_script.sh docs/stable 1.1.0") | docker exec -u jenkins -i "$id" bash) 2>&1'

           # For open PRs: Do a dry_run of the docs build, don't push build
           else
README.md
@@ -37,11 +37,12 @@ At a granular level, PyTorch is a library that consists of the following components:

 | Component | Description |
 | ---- | --- |
-| **torch** | a Tensor library like NumPy, with strong GPU support |
-| **torch.autograd** | a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch |
-| **torch.nn** | a neural networks library deeply integrated with autograd designed for maximum flexibility |
-| **torch.multiprocessing** | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
-| **torch.utils** | DataLoader and other utility functions for convenience |
+| [**torch**](https://pytorch.org/docs/stable/torch.html) | a Tensor library like NumPy, with strong GPU support |
+| [**torch.autograd**](https://pytorch.org/docs/stable/autograd.html) | a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch |
+| [**torch.jit**](https://pytorch.org/docs/stable/jit.html) | a compilation stack (TorchScript) to create serializable and optimizable models from PyTorch code  |
+| [**torch.nn**](https://pytorch.org/docs/stable/nn.html) | a neural networks library deeply integrated with autograd designed for maximum flexibility |
+| [**torch.multiprocessing**](https://pytorch.org/docs/stable/multiprocessing.html) | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
+| [**torch.utils**](https://pytorch.org/docs/stable/data.html) | DataLoader and other utility functions for convenience |

 Usually one uses PyTorch either as:
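A small, hedged illustration (not from the README) of the first components in the table working together -- torch tensors, a torch.nn layer, and torch.autograd:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                     # torch.nn: a learnable layer
x = torch.randn(3, 4, requires_grad=True)   # torch: the Tensor library
loss = model(x).sum()
loss.backward()                             # torch.autograd: tape-based autodiff
print(x.grad.shape)                         # torch.Size([3, 4])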
@@ -2688,6 +2688,7 @@
  types:
    - floating_point
  backends:
    - CPU
    - CUDA
  return: self
  arguments:
@@ -2721,6 +2722,7 @@
[[
  name: _th_geometric_
  backends:
    - CPU
    - CUDA
  cname: geometric
  variants: function
@@ -55,6 +55,12 @@ void THTensor_(cappedRandom)(THTensor *self, THGenerator *_generator, int64_t max)
   THTensor_(clampedRandom)(self, _generator, 0, max);
 }

+void THTensor_(geometric)(THTensor *self, THGenerator *_generator, double p)
+{
+  std::lock_guard<std::mutex> lock(_generator->mutex);
+  TH_TENSOR_APPLY(scalar_t, self, *self_data = (scalar_t)THRandom_geometric(_generator, p););
+}
+
 #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)

 #if defined(TH_REAL_IS_FLOAT)
@@ -5,6 +5,7 @@
 TH_API void THTensor_(random)(THTensor *self, THGenerator *_generator);
 TH_API void THTensor_(clampedRandom)(THTensor *self, THGenerator *_generator, int64_t min, int64_t max);
 TH_API void THTensor_(cappedRandom)(THTensor *self, THGenerator *_generator, int64_t max);
+TH_API void THTensor_(geometric)(THTensor *self, THGenerator *_generator, double p);

 #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
 TH_API void THTensor_(bernoulli_Tensor)(THTensor *self, THGenerator *_generator, THTensor *p);
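A hedged usage sketch (not from the diff) of the user-facing API behind this declaration: Tensor.geometric_ fills a tensor in place with draws from a Geometric(p) distribution.

import torch

t = torch.empty(5)
t.geometric_(0.25)   # each element is the number of trials until the first success
print(t)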
@@ -338,6 +338,22 @@ if(BUILD_TEST)
   if (NOT CAFFE2_USE_MSVC_STATIC_RUNTIME)
       set(gtest_force_shared_crt ON CACHE BOOL "force shared crt on gtest" FORCE)
   endif()
+  # We need to replace googletest cmake scripts too.
+  # Otherwise, it will sometimes break the build.
+  # To make the git clean after the build, we make a backup first.
+  if (MSVC AND MSVC_Z7_OVERRIDE)
+    execute_process(
+      COMMAND ${CMAKE_COMMAND}
+              "-DFILENAME=${CMAKE_CURRENT_LIST_DIR}/../third_party/googletest/googletest/cmake/internal_utils.cmake"
+              "-DBACKUP=${CMAKE_CURRENT_LIST_DIR}/../third_party/googletest/googletest/cmake/internal_utils.cmake.bak"
+              "-DREVERT=0"
+              "-P"
+              "${CMAKE_CURRENT_LIST_DIR}/GoogleTestPatch.cmake"
+      RESULT_VARIABLE _exitcode)
+    if(NOT ${_exitcode} EQUAL 0)
+      message(WARNING "Patching failed for Google Test. The build may fail.")
+    endif()
+  endif()

   # Add googletest subdirectory but make sure our INCLUDE_DIRECTORIES
   # don't bleed into it. This is because libraries installed into the root conda
@@ -363,6 +379,21 @@ if(BUILD_TEST)

   # Recover build options.
   set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE)
+
+  # To make the git clean after the build, we revert the changes here.
+  if (MSVC AND MSVC_Z7_OVERRIDE)
+    execute_process(
+      COMMAND ${CMAKE_COMMAND}
+              "-DFILENAME=${CMAKE_CURRENT_LIST_DIR}/../third_party/googletest/googletest/cmake/internal_utils.cmake"
+              "-DBACKUP=${CMAKE_CURRENT_LIST_DIR}/../third_party/googletest/googletest/cmake/internal_utils.cmake.bak"
+              "-DREVERT=1"
+              "-P"
+              "${CMAKE_CURRENT_LIST_DIR}/GoogleTestPatch.cmake"
+      RESULT_VARIABLE _exitcode)
+    if(NOT ${_exitcode} EQUAL 0)
+      message(WARNING "Reverting changes failed for Google Test. The build may fail.")
+    endif()
+  endif()
 endif()

 # ---[ FBGEMM
cmake/GoogleTestPatch.cmake (new file, 24 lines)
@@ -0,0 +1,24 @@
+# CMake file to replace the string contents in Google Test and Google Mock
+# Usage example:
+# Patch the cmake file
+#   cmake -DFILENAME=internal_utils.cmake
+#         -DBACKUP=internal_utils.cmake.bak
+#         -DREVERT=0
+#         -P GoogleTestPatch.cmake
+# Revert the changes
+#   cmake -DFILENAME=internal_utils.cmake
+#         -DBACKUP=internal_utils.cmake.bak
+#         -DREVERT=1
+#         -P GoogleTestPatch.cmake
+
+
+if(REVERT)
+  file(READ ${BACKUP} content)
+  file(WRITE ${FILENAME} "${content}")
+  file(REMOVE ${BACKUP})
+else(REVERT)
+  file(READ ${FILENAME} content)
+  file(WRITE ${BACKUP} "${content}")
+  string(REGEX REPLACE "[-/]Z[iI]" "/Z7" content "${content}")
+  file(WRITE ${FILENAME} "${content}")
+endif(REVERT)
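Not part of the commit: a Python paraphrase of the CMake logic above, to make the patch/revert flow concrete. It swaps MSVC's /Zi-style debug-info flags for /Z7 (embedded debug info), keeping a backup so the tree stays clean after the build:

import re
import shutil

def patch(filename, backup, revert):
    if revert:
        shutil.move(backup, filename)      # restore the saved original
    else:
        with open(filename) as f:
            content = f.read()
        shutil.copy(filename, backup)      # keep a pristine copy for the revert
        patched = re.sub(r"[-/]Z[iI]", "/Z7", content)
        with open(filename, "w") as f:
            f.write(patched)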
@@ -63,7 +63,7 @@ Build + CI
 -  Jesse Hellemn (`pjh5 <https://github.com/pjh5>`__)
 -  Soumith Chintala (`soumith <https://github.com/soumith>`__)
 -  (sunsetting) Orion Reblitz-Richardson
-(`orionr <https://github.com/orionr>`__)
+   (`orionr <https://github.com/orionr>`__)

 Distributions & RNG
 ~~~~~~~~~~~~~~~~~~~
@@ -258,7 +258,7 @@ Probability distributions - torch.distributions
     :show-inheritance:

 :hidden:`LogitRelaxedBernoulli`
-~~~~~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 .. currentmodule:: torch.distributions.relaxed_bernoulli
 .. autoclass:: LogitRelaxedBernoulli
@@ -301,7 +301,7 @@ Probability distributions - torch.distributions
     :members:
     :undoc-members:
     :show-inheritance:
-    
+
 :hidden:`Weibull`
 ~~~~~~~~~~~~~~~~~~~~~~~
@@ -151,7 +151,7 @@ net models. In particular, TorchScript supports:
 Unlike Python, each variable in TorchScript function must have a single static type.
 This makes it easier to optimize TorchScript functions.

-Example::
+Example (a type mismatch)::

     @torch.jit.script
     def an_error(x):
@@ -201,35 +201,34 @@ Example::

         @torch.jit.script_method
         def forward(self, x):
-            # type: (Tensor) -> Tuple[List[Tuple[Tensor, Tensor]], Dict[int, Tensor]]
+            # type: (Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]

-            # This annotates the list to be a `List[Tuple[Tensor, Tensor]]`
-            list_of_tuple = torch.jit.annotate(List[Tuple[Tensor, Tensor]], [])
+            # This annotates the list to be a `List[Tuple[int, float]]`
+            my_list = torch.jit.annotate(List[Tuple[int, float]], [])
             for i in range(10):
-                list_of_tuple.append((x, x))
+                my_list.append((x, x))

-            # This annotates the list to be a `Dict[int, Tensor]`
-            int_tensor_dict = torch.jit.annotate(Dict[int, Tensor], {})
-            return list_of_tuple, int_tensor_dict
+            my_dict = torch.jit.annotate(Dict[str, int], {})
+            return my_list, my_dict


 Optional Type Refinement
 ^^^^^^^^^^^^^^^^^^^^^^^^

-TorchScript will refine the type of a variable of type Optional[T] when
-a comparison to None is made inside the conditional of an if statement.
-The compiler can reason about multiple None checks that are combined with
-AND, OR, or NOT. Refinement will also occur for else blocks of if statements
+TorchScript will refine the type of a variable of type ``Optional[T]`` when
+a comparison to ``None`` is made inside the conditional of an if-statement.
+The compiler can reason about multiple ``None`` checks that are combined with
+``and``, ``or``, and ``not``. Refinement will also occur for else blocks of if-statements
 that are not explicitly written.

 The expression must be emitted within the conditional; assigning
-a None check to a variable and using it in the conditional will not refine types.
+a ``None`` check to a variable and using it in the conditional will not refine types.


 Example::

   @torch.jit.script
-  def opt_unwrap(x, y, z):
+  def optional_unwrap(x, y, z):
     # type: (Optional[int], Optional[int], Optional[int]) -> int
     if x is None:
       x = 1
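A self-contained, runnable variant of the updated annotate example (a sketch against the 1.1.0-era TorchScript API this hunk documents; the function name is illustrative):

import torch
from typing import Dict, List, Tuple

@torch.jit.script
def make_containers():
    # type: () -> Tuple[List[Tuple[int, float]], Dict[str, int]]
    my_list = torch.jit.annotate(List[Tuple[int, float]], [])
    for i in range(10):
        my_list.append((i, float(i)))          # matches the annotated element type
    my_dict = torch.jit.annotate(Dict[str, int], {})
    my_dict["size"] = len(my_list)
    return my_list, my_dict

print(make_containers())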
@@ -240,6 +239,66 @@ Example::
     return x


+Classes
+^^^^^^^^^^^^^^^^^^^^^^^^
+Python classes can be used in TorchScript if they are annotated with ``@torch.jit.script``,
+similar to how you would declare a TorchScript function: ::
+
+    @torch.jit.script
+    class Foo:
+      def __init__(self, x, y)
+        self.x = x
+
+      def aug_add_x(self, inc):
+        self.x += inc
+
+
+This subset is restricted:
+
+* All functions must be valid TorchScript functions (including ``__init__()``)
+* Classes must be new-style classes, as we use ``__new__()`` to construct them with pybind11
+* TorchScript classes are statically typed. Members are declared by assigning to
+  self in the ``__init__()`` method
+
+    For example, assigning outside of the ``__init__()`` method: ::
+
+        @torch.jit.script
+        class Foo:
+          def assign_x(self):
+            self.x = torch.rand(2, 3)
+
+    Will result in: ::
+
+        RuntimeError:
+        Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
+        def assign_x(self):
+          self.x = torch.rand(2, 3)
+          ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
+
+* No expressions except method definitions are allowed in the body of the class
+* No support for inheritance or any other polymorphism strategy, except for inheriting
+  from object to specify a new-style class
+
+After a class is defined, it can be used in both TorchScript and Python interchangeably
+like any other TorchScript type:
+
+::
+
+    @torch.jit.script
+    class Pair:
+      def __init__(self, first, second)
+        self.first = first
+        self.second = second
+
+    @torch.jit.script
+    def sum_pair(p):
+      # type : (Pair) -> Tensor
+      return p.first + p.second
+
+    p = Pair(torch.rand(2, 3), torch.rand(2, 3)
+    print(sum_pair(p))
+
+
 Expressions
 ~~~~~~~~~~~
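A runnable version of the Pair example (hedged: the snippet as committed is missing a colon after ``__init__`` and a closing parenthesis, which this sketch adds):

import torch

@torch.jit.script
class Pair:
    def __init__(self, first, second):
        self.first = first       # members inferred as Tensor
        self.second = second

@torch.jit.script
def sum_pair(p):
    # type: (Pair) -> Tensor
    return p.first + p.second

p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))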
@@ -255,8 +314,9 @@ List Construction
     ``[3, 4]``, ``[]``, ``[torch.rand(3), torch.rand(4)]``

     .. note::
-        an empty list is assumed have type ``List[Tensor]``.
+        An empty list is assumed have type ``List[Tensor]``.
+        The types of other list literals are derived from the type of the members.
         To denote an empty list of another type, use ``torch.jit.annotate``.

 Tuple Construction
 """"""""""""""""""
@@ -268,8 +328,9 @@ Dict Construction
     ``{'hello': 3}``, ``{}``, ``{'a': torch.rand(3), 'b': torch.rand(4)}``

     .. note::
-        an empty dict is assumed have type ``Dict[str, Tensor]``.
+        An empty dict is assumed have type ``Dict[str, Tensor]``.
+        The types of other dict literals are derived from the type of the members.
         To denote an empty dict of another type, use ``torch.jit.annotate``.

 Variables
 ^^^^^^^^^
@@ -341,10 +402,6 @@ Subscripts

   ``t[i:j, i]``

-  .. note::
-    TorchScript currently does not support mutating tensors in place, so any
-    tensor indexing can only appear on the right-hand size of an expression.
-
 Function Calls
 ^^^^^^^^^^^^^^
    Calls to built-in functions: ``torch.rand(3, dtype=torch.int)``
@@ -468,11 +525,6 @@ For loops with ``range``
         for i in range(10):
             x *= i

-    .. note::
-      Script currently does not support iterating over generic iterable
-      objects like lists or tensors. Script currently does not support start or
-      increment parameters to range. These will be added in a future version.
-
 For loops over tuples:

     ::
@@ -512,9 +564,9 @@ For loops over constant ``torch.nn.ModuleList``
                   return v

       .. note::
-          To use a module list inside a ``@script_method`` it must be marked
+          To use a ``nn.ModuleList`` inside a ``@script_method`` it must be marked
           constant by adding the name of the attribute to the ``__constants__``
-          list for the type. For loops over a ModuleList will unroll the body of the
+          list for the type. For loops over a ``nn.ModuleList`` will unroll the body of the
           loop at compile time, with each member of the constant module list.

 Return
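A hedged sketch of the ``__constants__`` pattern this note describes (written against the 1.1.0-era ``torch.jit.script_method`` API; the class name is illustrative):

import torch
import torch.nn as nn

class Stack(torch.jit.ScriptModule):
    __constants__ = ['mods']   # mark the ModuleList constant so the loop unrolls

    def __init__(self):
        super(Stack, self).__init__()
        self.mods = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])

    @torch.jit.script_method
    def forward(self, v):
        for layer in self.mods:   # unrolled at compile time, one copy per member
            v = layer(v)
        return v

print(Stack()(torch.zeros(4, 10)).shape)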
@@ -557,17 +609,17 @@ To make writing TorchScript more convenient, we allow script code to refer
 to Python values in the surrounding scope. For instance, any time there is a
 reference to ``torch``, the TorchScript compiler is actually resolving it to the
 ``torch`` Python module when the function is declared.  These Python values are
-not a first class part of TorchScript. Instead they are desugared at compile-time
-into the primitive types that TorchScript supports. This section describes the
-rules that are used when accessing Python values in TorchScript. They depend
-on the dynamic type of the python valued referenced.
+not a first class part of TorchScript. Instead they are de-sugared at compile-time
+into the primitive types that TorchScript supports. This depends
+on the dynamic type of the Python valued referenced when compilation occurs.
+This section describes the rules that are used when accessing Python values in TorchScript.

 Functions
 ^^^^^^^^^

   TorchScript can call Python functions. This functionality is very useful when
-  incrementally converting a model into script. The model can be moved function-by-function
-  to script, leaving calls to Python functions in place. This way you can incrementally
+  incrementally converting a model to TorchScript. The model can be moved function-by-function
+  to TorchScript, leaving calls to Python functions in place. This way you can incrementally
   check the correctness of the model as you go.

   Example::
@@ -581,10 +633,37 @@ Functions
       def bar(x)
        return foo(x + 1)

-  .. note::
-    Attempting to call ``save`` on a ScriptModule that contains calls to Python
-    functions will fail. The intention is that this pathway is used for debugging
-    and the calls removed or turned into script functions before saving.
+  Attempting to call ``save`` on a ScriptModule that contains calls to Python
+  functions will fail. The intention is that this pathway is used for debugging
+  and the calls removed or turned into script functions before saving. If you
+  want to export a module with a Python function, add the ``@torch.jit.ignore``
+  decorator to the function which will replace these function calls with an
+  exception when the model is saved: ::
+
+      class M(torch.jit.ScriptModule):
+        def __init__(self):
+          super(M, self).__init__()
+
+        @torch.jit.script_method
+        def forward(self, x):
+          self.ignored_code(x)
+          return x + 2
+
+        @torch.jit.ignore
+        def ignored_code(self, x):
+          # non-TorchScript code
+          import pdb; pdb.set_trace()
+
+      m = M()
+      # Runs, makes upcall to Python to run `ignored_code`
+      m(torch.ones(2, 2))
+
+      # Replaces all calls to `ignored_code` with a `raise`
+      m.save("m.pt")
+      loaded = torch.jit.load("m.pt")
+
+      # This runs `ignored_code` after saving which will raise an Exception!
+      loaded(torch.ones(2, 2))

 Attribute Lookup On Python Modules
@@ -621,6 +700,7 @@ Python-defined Constants
     Supported constant Python Values are

     * ``int``
     * ``float``
+    * ``bool``
     * ``torch.device``
     * ``torch.layout``
@@ -629,6 +709,31 @@ Python-defined Constants
     * ``torch.nn.ModuleList`` which can be used in a TorchScript for loop


+Module Attributes
+^^^^^^^^^^^^^^^^^
+
+The ``torch.nn.Parameter`` wrapper and ``register_buffer`` can be used to assign
+tensors to a ``ScriptModule``. In a similar vein, attributes of any type can be
+assign on a ``ScriptModule`` by wrapping them with ``torch.jit.Attribute`` and
+specifying the type. All types available in TorchScript are supported. These
+attributes are mutable and are saved in a separate archive in the serialized
+model binary. Tensor attributes are semantically the same as buffers.
+
+Example::
+
+    class Foo(torch.jit.ScriptModule):
+      def __init__(self, a_dict):
+        super(Foo, self).__init__(False)
+        self.words = torch.jit.Attribute([], List[str])
+        self.some_dict = torch.jit.Attribute(a_dict, Dict[str, int])
+
+      @torch.jit.script_method
+      def forward(self, input):
+        # type: (str) -> int
+        self.words.append(input)
+        return self.some_dict[input]
+
+
 Debugging
 ~~~~~~~~~
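A hedged, self-contained variant of the attribute pattern above (again 1.1.0-era API; ``Words`` is an illustrative name, not from the docs):

import torch
from typing import List

class Words(torch.jit.ScriptModule):
    def __init__(self):
        super(Words, self).__init__()
        self.words = torch.jit.Attribute([], List[str])   # typed, mutable attribute

    @torch.jit.script_method
    def forward(self, s):
        # type: (str) -> int
        self.words.append(s)
        return len(self.words)

w = Words()
print(w("hello"), w("world"))   # 1 2 -- the attribute mutates across calls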
@@ -655,21 +760,21 @@ Disable JIT for Debugging

         traced_fn(torch.rand(3, 4))

-    Debugging this script with PDB works except for when we invoke the @script
-    function. We can globally disable JIT, so that we can call the @script
+    Debugging this script with PDB works except for when we invoke the ``@torch.jit.script``
+    function. We can globally disable JIT, so that we can call the ``@torch.jit.script``
     function as a normal python function and not compile it. If the above script
     is called ``disable_jit_example.py``, we can invoke it like so::

         $ PYTORCH_JIT=0 python disable_jit_example.py

-    and we will be able to step into the @script function as a normal Python
+    and we will be able to step into the ``@torch.jit.script`` function as a normal Python
     function.


 Inspecting Code
 ^^^^^^^^^^^^^^^

-    TorchScript provides a code pretty-printer for all ScriptModule instances. This
+    TorchScript provides a code pretty-printer for all ``ScriptModule`` instances. This
     pretty-printer gives an interpretation of the script method's code as valid
     Python syntax. For example::
@@ -688,11 +793,11 @@ Inspecting Code

     A ``ScriptModule`` with a single ``forward`` method will have an attribute
     ``code``, which you can use to inspect the ``ScriptModule``'s code.
-    If the ScriptModule has more than one method, you will need to access
+    If the ``ScriptModule`` has more than one method, you will need to access
     ``.code`` on the method itself and not the module. We can inspect the
     code of a method named ``bar`` on a ScriptModule by accessing ``.bar.code``.

-    The example script abouve produces the code::
+    The example script above produces the code::

         def forward(self,
                     len: int) -> Tensor:
@@ -706,7 +811,7 @@ Inspecting Code
                 rv0 = rv1
             return rv0

-    This is TorchScript's interpretation of the code for the ``forward`` method.
+    This is TorchScript's compilation of the code for the ``forward`` method.
     You can use this to ensure TorchScript (tracing or scripting) has captured
     your model code correctly.
@@ -734,7 +839,7 @@ Interpreting Graphs

         print(foo.graph)

-    ``.graph`` follows the same rules described in the Inspecting Code section
+    ``.graph`` follows the same rules described in the `Inspecting Code`_ section
     with regard to ``forward`` method lookup.

     The example script above produces the graph::
@@ -949,9 +1054,9 @@ best practices?
       # ... later, when using the model:

       if use_gpu:
-         model = torch.jit.load("gpu.pth")
+        model = torch.jit.load("gpu.pth")
       else:
-         model = torch.jit.load("cpu.pth")
+        model = torch.jit.load("cpu.pth")

       model(input)
@@ -961,6 +1066,40 @@ best practices?
    the correct device information.


+Q: How do I store attributes on a ``ScriptModule``?
+
+    Say we have a model like: ::
+
+      class Model(torch.jit.ScriptModule):
+        def __init__(self):
+          super(Model, self).__init__()
+          self.x = 2
+
+        @torch.jit.script_method
+        def forward(self):
+          return self.x
+
+    If ``Model`` is instantiated it will result in a compilation error
+    since the compiler doesn't know about ``x``. There are 4 ways to inform the
+    compiler of attributes on ``ScriptModule``:
+
+    1. ``nn.Parameter`` - values wrapped in ``nn.Parameter`` will work as they
+    do on ``nn.Module``\s
+
+    2. ``register_buffer`` - values wrapped in ``register_buffer`` will work as
+    they do on ``nn.Module``\s
+
+    3. ``__constants__`` - adding a list called ``__constants__`` at the
+    class definition level will mark the contained names as constants. Constants
+    are saved directly in the code of the model. See
+    `Python-defined Constants`_.
+
+    4. ``torch.jit.Attribute`` - values wrapped in ``torch.jit.Attribute`` can
+    be any ``TorchScript`` type, be mutated and are saved outside of the code of
+    the model. See `Module Attributes`_.
+
+
 Builtin Functions
 ~~~~~~~~~~~~~~~~~
@@ -530,7 +530,7 @@ Linear layers
 ----------------------------------

 :hidden:`Identity`
-~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~

 .. autoclass:: Identity
     :members:
@@ -20,6 +20,7 @@ and visualization by TensorBoard. For example:

 .. code:: python


     import torch
     import torchvision
     from torch.utils.tensorboard import SummaryWriter
@@ -31,7 +32,9 @@ and visualization by TensorBoard. For example:
     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
     trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
     trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
-    model = torchvision.models.vgg16(False)
+    model = torchvision.models.resnet50(False)
+    # Have ResNet model take in grayscale rather than RGB
+    model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
     images, labels = next(iter(trainloader))

     grid = torchvision.utils.make_grid(images)
@@ -39,8 +42,24 @@ and visualization by TensorBoard. For example:
     writer.add_graph(model, images)
     writer.close()

-This can then be visualized with TensorBoard, which should be installed
-with ``pip install tensorboard`` or equivalent.
+This can then be visualized with TensorBoard, which should be installable
+and runnable with::
+
+    pip install tb-nightly  # Until 1.14 moves to the release channel
+    tensorboard --logdir=runs

 .. currentmodule:: torch.utils.tensorboard.writer

 .. autoclass:: SummaryWriter

+   .. automethod:: add_scalar
+   .. automethod:: add_histogram
+   .. automethod:: add_image
+   .. automethod:: add_figure
+   .. automethod:: add_video
+   .. automethod:: add_audio
+   .. automethod:: add_text
+   .. automethod:: add_graph
+   .. automethod:: add_embedding
+   .. automethod:: add_pr_curve
+   .. automethod:: add_custom_scalars
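A hedged follow-on to the example above showing the most common of the newly documented methods, ``add_scalar`` (the tag and values are illustrative):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()   # event files go under ./runs/ by default
for step in range(100):
    writer.add_scalar('train/loss', 1.0 / (step + 1), step)
writer.close()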
							
								
								
									
setup.py
@@ -229,7 +229,7 @@ cmake_python_include_dir = distutils.sysconfig.get_python_inc()
 # Version, create_version_file, and package_name
 ################################################################################
 package_name = os.getenv('TORCH_PACKAGE_NAME', 'torch')
-version = '1.1.0a0'
+version = '1.1.0'
 sha = 'Unknown'

 try:
@@ -243,8 +243,8 @@ if os.getenv('PYTORCH_BUILD_VERSION'):
     version = os.getenv('PYTORCH_BUILD_VERSION')
     if build_number > 1:
         version += '.post' + str(build_number)
-elif sha != 'Unknown':
-    version += '+' + sha[:7]
+# elif sha != 'Unknown':
+#     version += '+' + sha[:7]
 report("Building wheel {}-{}".format(package_name, version))
@@ -583,10 +583,16 @@ main_sources = ["torch/csrc/stub.cpp"]
 # before libcaffe2.so in the linker command.
 main_link_args.extend(CAFFE2_LIBS)

+install_requires=[]
+
+if sys.version_info[0] == 2:
+    install_requires.append('future')
+
 try:
     import numpy as np
     NUMPY_INCLUDE_DIR = np.get_include()
     USE_NUMPY = True
+    install_requires.append('numpy')
 except ImportError:
     USE_NUMPY = False
@@ -726,6 +732,7 @@ if __name__ == '__main__':
         cmdclass=cmdclass,
         packages=packages,
         entry_points=entry_points,
+        install_requires=install_requires,
         package_data={
             'torch': [
                 'py.typed',
@@ -818,6 +825,33 @@ if __name__ == '__main__':
                 'python/serialized_test/data/operator_test/*.zip',
             ]
         },
+        url='https://pytorch.org/',
+        download_url='https://github.com/pytorch/pytorch/tags',
+        author='PyTorch Team',
+        author_email='packages@pytorch.org',
+        # PyPI package information.
+        classifiers=[
+            'Development Status :: 5 - Production/Stable',
+            'Intended Audience :: Developers',
+            'Intended Audience :: Education',
+            'Intended Audience :: Science/Research',
+            'License :: OSI Approved :: BSD License',
+            'Programming Language :: C++',
+            'Programming Language :: Python :: 2',
+            'Programming Language :: Python :: 2.7',
+            'Programming Language :: Python :: 3',
+            'Programming Language :: Python :: 3.5',
+            'Programming Language :: Python :: 3.6',
+            'Programming Language :: Python :: 3.7',
+            'Topic :: Scientific/Engineering',
+            'Topic :: Scientific/Engineering :: Mathematics',
+            'Topic :: Scientific/Engineering :: Artificial Intelligence',
+            'Topic :: Software Development',
+            'Topic :: Software Development :: Libraries',
+            'Topic :: Software Development :: Libraries :: Python Modules',
+        ],
+        license='BSD',
+        keywords='pytorch machine learning',
     )
     if EMIT_BUILD_WARNING:
         print_box(build_update_message)
@@ -2176,20 +2176,27 @@ class DistributedDataParallelTest(MultiProcessTestCase):
         input = torch.rand([batch_size, 2], dtype=torch.float)
         target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)

-        def test_find_unused_parameters(find_unused_parameters):
-            model = DistributedDataParallel(
-                FindUnusedParametersModule().float().to(device_id),
-                device_ids=[device_id],
-                process_group=process_group,
-                find_unused_parameters=find_unused_parameters,
-            )
+        def test_find_unused_parameters(find_unused_parameters, test_default=False):
+            if test_default:
+                model = DistributedDataParallel(
+                    FindUnusedParametersModule().float().to(device_id),
+                    device_ids=[device_id],
+                    process_group=process_group,
+                )
+            else:
+                model = DistributedDataParallel(
+                    FindUnusedParametersModule().float().to(device_id),
+                    device_ids=[device_id],
+                    process_group=process_group,
+                    find_unused_parameters=find_unused_parameters,
+                )

             output, fc3 = model(input)
             output = fc3(output)
             loss = criterion(output, target)
             loss.backward()

-        # First test that the default behavior under these conditions is to
+        # First test that finding unused params under these conditions is to
         # trigger an error when `backward` is called (because fc3 is an unused
         # parameter and will therefore be marked ready twice).
         try:
@@ -2207,6 +2214,12 @@ class DistributedDataParallelTest(MultiProcessTestCase):
         except Exception as ex:
             self.fail("Unexpected exception: %s" % ex)

+        # Test find_unused_parameters defaults to False
+        try:
+            test_find_unused_parameters(True, test_default=True)
+        except Exception as ex:
+            self.fail("Unexpected exception: %s" % ex)
+
     @skip_if_not_nccl
     @skip_if_not_multigpu
     def test_multiple_outputs_multiple_backward(self):
@@ -2257,6 +2270,151 @@ class DistributedDataParallelTest(MultiProcessTestCase):
         loss2 = criterion(output2, target)
         loss2.backward()

+    @skip_if_not_nccl
+    @skip_if_not_multigpu
+    def test_no_used_parameters(self):
+        """
+        Note: this test can be sped up by only running it on a CPU module
+        once DistributedDataParallel supports them.
+        """
+        store = c10d.FileStore(self.file.name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        class NoUsedParameters(nn.Module):
+            def __init__(self):
+                super(NoUsedParameters, self).__init__()
+
+                # Make sure this module has some parameters, only to then decide
+                # to never use them from the `forward` function.
+                self.fc1 = nn.Linear(2, 10, bias=False)
+                self.fc2 = nn.Linear(10, 4, bias=False)
+                self.fc3 = nn.Linear(4, 4, bias=False)
+                self.relu = nn.ReLU()
+
+            def forward(self, x):
+                return x * 0.0
+
+        device_id = gpus_for_rank(self.world_size)[self.rank][0]
+        model = DistributedDataParallel(
+            NoUsedParameters().float().to(device_id),
+            device_ids=[device_id],
+            process_group=process_group,
+            find_unused_parameters=True,
+        )
+
+        batch_size = 4
+        input = torch.rand([batch_size, 2], dtype=torch.float)
+
+        # After initialization, no parameter has their gradient set.
+        for p in model.parameters():
+            self.assertTrue(p.requires_grad)
+            self.assertIsNone(p.grad)
+
+        # Run `forward` function.
+        model(input)
+
+        # Because none of the parameters were used, we expect reduction for
+        # all parameters will be executed right when initializing the reducer.
+        # Once `forward` returns, all the parameter's gradients must be set.
+        for p in model.parameters():
+            self.assertTrue(p.requires_grad)
+            self.assertIsNotNone(p.grad)
+            self.assertTrue(torch.is_tensor(p.grad))
+            self.assertEqual(p.size(), p.grad.size())
+
+    @skip_if_not_nccl
+    @skip_if_not_multigpu
+    def test_no_grad(self):
+        """
+        Note: this test can be sped up by only running it on a CPU module
+        once DistributedDataParallel supports them.
+        """
+        store = c10d.FileStore(self.file.name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        class NoGradModule(nn.Module):
+            def __init__(self):
+                super(NoGradModule, self).__init__()
+                self.fc1 = nn.Linear(2, 10, bias=False)
+                self.fc2 = nn.Linear(10, 4, bias=False)
+                self.relu = nn.ReLU()
+
+            def forward(self, x):
+                x = self.relu(self.fc1(x))
+                x = self.relu(self.fc2(x))
+                return F.softmax(x, dim=1)
+
+        device_id = gpus_for_rank(self.world_size)[self.rank][0]
+        model = DistributedDataParallel(
+            NoGradModule().float().to(device_id),
+            device_ids=[device_id],
+            process_group=process_group,
+        )
+
+        batch_size = 4
+        input = torch.rand([batch_size, 2], dtype=torch.float)
+
+        def check_no_grads():
+            for p in model.parameters():
+                self.assertTrue(p.requires_grad)
+                self.assertIsNone(p.grad)
+
+        # After initialization, no parameter has their gradient set.
+        check_no_grads()
+
+        # Run `forward` function with torch.no_grad()
+        with torch.no_grad():
+            output = model(input)
+            self.assertTrue(torch.is_tensor(output))
+
+        # No parameter should have their gradient set.
+        check_no_grads()
+
+    @skip_if_not_nccl
+    @skip_if_not_multigpu
+    def test_ignored_output(self):
+        """
+        Note: this test can be sped up by only running it on a CPU module
+        once DistributedDataParallel supports them.
+        """
+        store = c10d.FileStore(self.file.name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        class IgnoredOutput(nn.Module):
+            def __init__(self):
+                super(IgnoredOutput, self).__init__()
+                self.fc1 = nn.Linear(2, 10, bias=False)
+                self.fc2 = nn.Linear(10, 4, bias=False)
+                self.relu = nn.ReLU()
+
+            def forward(self, x):
+                x = self.relu(self.fc1(x))
+                x = self.relu(self.fc2(x))
+                return F.softmax(x, dim=1)
+
+        device_id = gpus_for_rank(self.world_size)[self.rank][0]
+        model = DistributedDataParallel(
+            IgnoredOutput().float().to(device_id),
+            device_ids=[device_id],
+            process_group=process_group,
+        )
+
+        batch_size = 4
+        criterion = nn.CrossEntropyLoss()
+        input = torch.rand([batch_size, 2], dtype=torch.float)
+        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)
+
+        # Run a few iterations where we ignore the output.
+        for _ in range(4):
+            output = model(input)
+            del output
+
+        # Run a few iterations where we use the output.
+        for _ in range(4):
+            output = model(input)
+            loss = criterion(output, target)
+            loss.backward()
+

 class ReducerModule(nn.Module):
     def __init__(self):
@ -3648,20 +3648,14 @@ a")
        check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)

    def test_tensor_item(self):
        def test_scalar_to_float_coercion(x):
            return x.item() == 1

        self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1.0),))
        self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1),))

        def test_scalar_cast(x):
            scalar = x.item()
            return int(scalar), float(scalar)

        graph = torch.jit.script(test_scalar_cast).graph
        FileCheck().check("(int, float) = prim::TupleConstruct").run(graph)
        self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1.0),))
        self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1),))
        self.checkScript(test_scalar_cast, (torch.tensor(1.0),))
        self.checkScript(test_scalar_cast, (torch.tensor(1),))

        expected_str = r"Use int\(tensor\) or float\(tensor\) to retrieve"
        with self.assertRaisesRegex(RuntimeError, expected_str):
@ -11730,6 +11724,8 @@ a")

    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
    def test_attribute_unpickling(self):
        tensor = torch.randn(2, 2)

        class M(torch.jit.ScriptModule):
            def __init__(self):
                super(M, self).__init__()
@ -11738,37 +11734,24 @@ a")
                self.int = torch.jit.Attribute(99, int)
                self.tuple = torch.jit.Attribute((1, 2, 3, 4), Tuple[int, int, int, int])
                self.list = torch.jit.Attribute([(1, 2), (3, 4)], List[Tuple[int, int]])
                self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
                self.tensor = torch.jit.Attribute(tensor, torch.Tensor)
                self.int_list = torch.jit.Attribute([1, 2, 3, 4], List[int])

            @torch.jit.script_method
            def forward(self):
                return (self.table, self.float, self.int, self.tuple, self.list, self.int_list)

        class TensorID(object):
            def __setstate__(self, id):
                self.id = id

        class IntList(object):
            def __setstate__(self, data):
                self.data = data

        class JitUnpickler(pickle.Unpickler):
            def find_class(self, module, name):
                if not module == '__main__':
                    return None

                if name == 'TensorID':
                    return TensorID
                elif name == 'IntList':
                    return IntList

        with TemporaryFileName() as fname:
            M().save(fname)
            archive_name = os.path.basename(os.path.normpath(fname))
            archive = zipfile.ZipFile(fname, 'r')
            pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))
            JitUnpickler(io.BytesIO(pickled_data)).load()
            out = pickle.load(io.BytesIO(pickled_data))

            self.assertEqual(out[0], {"I": "am", "a test": "test"})
            self.assertEqual(out[1], 2.3)
            self.assertEqual(out[2], 99)
            self.assertEqual(out[6], [1, 2, 3, 4])

    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
    def test_old_models_bc(self):

@ -891,6 +891,18 @@ class _TestTorchMixin(object):
    def test_max(self):
        self._testSelection(torch.max, max)

    def test_log_normal(self):
        for device in torch.testing.get_all_device_types():
            a = torch.tensor([10], dtype=torch.float, device=device).log_normal_()
            self.assertEqual(a.dtype, torch.float)
            self.assertEqual(a.size(), torch.Size([1]))

    def test_geometric(self):
        for device in torch.testing.get_all_device_types():
            a = torch.tensor([10], dtype=torch.float, device=device).geometric_(0.5)
            self.assertEqual(a.dtype, torch.float)
            self.assertEqual(a.size(), torch.Size([1]))

    @staticmethod
    def _test_max_with_inf(self, dtypes=(torch.float, torch.double), device='cpu'):
        for dtype in dtypes:

@ -45,6 +45,7 @@ Reducer::Reducer(
    : replicas_(std::move(replicas)),
      process_group_(std::move(process_group)),
      expect_autograd_hooks_(false),
      require_finalize_(false),
      has_marked_unused_parameters_(false),
      next_bucket_(0),
      backward_stats_base_(0) {
@ -160,6 +161,12 @@ void Reducer::mark_variable_ready(
  backward_stats_[replica_index][variable_index] =
      current_time_in_nanos() - backward_stats_base_;

  // Any time we mark a variable ready (be it in line due to unused parameters,
  // or via an autograd hook), we require a call to the finalize function. If
  // this doesn't happen before the next iteration (or call to
  // `prepare_for_backwards`), we know something is wrong.
  require_finalize_ = true;

  const auto& bucket_index = variable_locators_[variable_index];
  auto& bucket = buckets_[bucket_index.bucket_index];
  auto& replica = bucket.replicas[replica_index];
@ -228,12 +235,16 @@ void Reducer::mark_variable_ready(
    }
  }

  // Queue function to finalize once the final bucket was marked ready.
  // Run finalizer function once the final bucket was marked ready.
  if (next_bucket_ == buckets_.size()) {
    // Autograd callbacks can only be registered while the engine is running.
    AT_ASSERT(called_from_autograd);
    torch::autograd::Engine::get_default_engine().queue_callback(
        [=] { this->finalize_backward(); });
    if (called_from_autograd) {
      torch::autograd::Engine::get_default_engine().queue_callback([=] {
        std::lock_guard<std::mutex> lock(this->mutex_);
        this->finalize_backward();
      });
    } else {
      finalize_backward();
    }
  }
}

@ -375,6 +386,28 @@ void Reducer::prepare_for_backward(
  std::unordered_set<torch::autograd::Function*> seen;
  std::vector<torch::autograd::Function*> queue;

  // Check that any prior reduction has finished.
  // The variable `expect_autograd_hooks` is true until gradients for all
  // parameters have been received and all buckets are ready.
  if (require_finalize_) {
    AT_ERROR(
        "Expected to have finished reduction in the prior iteration before ",
        "starting a new one. ",
        "",
        "This error indicates that your module has parameters that were ",
        "not used in producing its output (the return value of `forward`). ",
        "",
        "You can enable unused parameter detection by passing the keyword "
        "argument `find_unused_parameters=True` to ",
        "`torch.nn.parallel.DistributedDataParallel`. ",
        "",
        "If you already have this argument set, then the distributed data ",
        "parallel module wasn't able to locate the output tensors in the ",
        "return value of your module's `forward` function. ",
        "Please include the structure of the return value of `forward` of ",
        "your module when reporting this issue (e.g. list, dict, iterable).");
  }

  // Reset accounting.
  has_marked_unused_parameters_ = true;
  expect_autograd_hooks_ = true;
@ -433,34 +466,16 @@ void Reducer::prepare_for_backward(
}

void Reducer::finalize_backward() {
  std::lock_guard<std::mutex> lock(mutex_);

  // No longer expect autograd hooks to fire after this function returns.
  AT_ASSERT(expect_autograd_hooks_);
  expect_autograd_hooks_ = false;

  // No longer require call to finalize after this function returns.
  AT_ASSERT(require_finalize_);
  require_finalize_ = false;

  // Check that all buckets were completed and had their work kicked off.
  if (next_bucket_ < buckets_.size()) {
    // If the reducer marked unused parameters and we STILL didn't get
    // gradients for all module parameters, something is seriously wrong.
    AT_ASSERT(!has_marked_unused_parameters_);
    AT_ERROR(
        "Expected to have gradients for all module parameters upon returning ",
        "from the call to `torch.autograd.backward`. ",
        "",
        "This error indicates that your module has parameters that were ",
        "not used in producing its output (the return value of `forward`). ",
        "",
        "You can enable unused parameter detection by passing the keyword "
        "argument `find_unused_parameters=True` to ",
        "`torch.nn.parallel.DistributedDataParallel`. ",
        "",
        "If you already have this argument set, then the distributed data ",
        "parallel module wasn't able to locate the output tensors in the ",
        "return value of your module's `forward` function. ",
        "Please include the structure of the return value of `forward` of ",
        "your module when reporting this issue (e.g. list, dict, iterable).");
  }
  AT_ASSERT(next_bucket_ == buckets_.size());

  // Wait for asynchronous reduction to complete and unflatten contents.
  for (auto& bucket : buckets_) {

@ -54,6 +54,7 @@ class Reducer {
  std::unordered_map<torch::autograd::Function*, std::tuple<int, int>> func_;

  bool expect_autograd_hooks_;
  bool require_finalize_;
  bool has_marked_unused_parameters_;
  size_t next_bucket_;


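For context, the check added to `prepare_for_backward` above surfaces in ordinary Python usage roughly as follows. This is a hedged sketch, not part of the diff: process-group setup is elided and `PartiallyUsed` is an illustrative name.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

class PartiallyUsed(nn.Module):
    def __init__(self):
        super(PartiallyUsed, self).__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)  # parameter that never affects the output

    def forward(self, x):
        return self.fc1(x)

# With find_unused_parameters=False, fc2 never receives a gradient, the final
# bucket is never marked ready, and the next call to forward() raises the
# "Expected to have finished reduction in the prior iteration" error above.
# Opting in to unused parameter detection avoids this:
#
#   model = DistributedDataParallel(
#       PartiallyUsed().to(device_id),  # device_id assumed to be set up
#       device_ids=[device_id],
#       find_unused_parameters=True,
#   )
```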
@ -1290,20 +1290,17 @@ During export a list of all the tensors in a model is created. Tensors can come

### `attributes.pkl`

Attributes are all module properties that are not parameters or constants. Attributes are saved in a list in the order they were defined on the module. The list is stored as a Python `pickle` archive. `pickle`'s format was chosen due to:
* **user friendliness** - the attributes file can be loaded in Python with `pickle` without having PyTorch installed
* **size limits** - formats such as Protobuf impose size limits on total message size, whereas pickle limits are on individual values (e.g. strings cannot be longer than 4 GB)
* **standard format** - `pickle` is a standard Python module with a reasonably simple format. The format is a program to be consumed by a stack machine that is detailed in Python's [`pickletools.py`](https://svn.python.org/projects/python/trunk/Lib/pickletools.py)
* **built-in memoization** - for shared reference types (e.g. Tensor, string, lists, dicts)
* **self describing** - a separate definition file is not needed to understand the pickled data
* **eager mode save** - `torch.save()` already produces a `pickle` archive, so doing the same with attributes may ease unification of these formats in the future
[pickler.h](pickler.h),
[pickler.cpp](pickler.cpp),
[torch/jit/_pickle.py](../../../torch/jit/_pickle.py)
[caffe2/proto/torch.proto](../../../caffe2/proto/torch.proto)

A given module may have many attributes of different types and many submodules, each with their own attributes. Attributes are recorded in `model.json`:
Attributes are all module properties that are not parameters or constants. Attributes are saved in a list in the order they were defined on the module. A given module may have many attributes of different types and many submodules, each with their own attributes. Attribute metadata is recorded in `model.json`:
* `type` - the full type of the attribute (in [Mypy syntax](https://mypy.readthedocs.io/en/latest/cheat_sheet_py3.html))
* `name` - the attribute's name
* `id` - the offset into the saved list of all model attributes

`model.json`
In `model.json`:
```json
{
  "mainModule": {
@ -1344,41 +1341,61 @@ A given module may have many attributes of different types and many submodules,
}
```

Attributes of the main module and its submodules are saved to a single file in the `zip` archive of a `.pt` file named `attributes.pkl`. A single file is used so that attributes can reference each other and shared values. Unpickling this will return a list of values corresponding to the attributes.
Attributes of the main module and its submodules are saved to a single file in the `zip` archive of a `.pt` file named `attributes.pkl`. Attributes are stored as a Python `pickle` archive. `pickle`'s format was chosen due to:
* **user friendliness** - the attributes file can be loaded in Python with `pickle`
* **size limits** - formats such as Protobuf impose size limits on total message size, whereas pickle limits are on individual values (e.g. strings cannot be longer than 4 GB)
* **standard format** - `pickle` is a standard Python module with a reasonably simple format. The format is a program to be consumed by a stack machine that is detailed in Python's [`pickletools.py`](https://svn.python.org/projects/python/trunk/Lib/pickletools.py)
* **built-in memoization** - for shared reference types (e.g. Tensor, string, lists, dicts)
* **self describing** - a separate definition file is not needed to understand the pickled data
* **eager mode save** - `torch.save()` already produces a `pickle` archive, so doing the same with attributes avoids introducing yet another format

All attributes are written into the `attributes.pkl` file with the exception of tensors, which store only a tensor table index (see "tensors" above). Classes are used to mark special data types, such as this tensor table index or specialized lists. To load the `attributes.pkl` file without PyTorch for inspection or manual editing, these classes must be defined, so a custom [`Unpickler`](https://docs.python.org/3/library/pickle.html#pickle.Unpickler) is necessary:
[pickler.cpp](pickler.cpp) implements a subset of the Pickle format necessary for TorchScript models.


A single file is used for the top level module and all submodules so that attributes can reference each other and share values. Unpickling `attributes.pkl` will return a tuple of values corresponding to the attributes.

All attributes are written into the `attributes.pkl` file with the exception of tensors, which store only a tensor table index (see "tensors" above). PyTorch functions defined in [torch/jit/_pickle.py](../../../torch/jit/_pickle.py) are used to mark special data types, such as this tensor table index or specialized lists. To load the `attributes.pkl` file, use the `pickle` module in Python:

```python
import pickle
# attributes.pkl includes references to functions in torch.jit._pickle
import torch

pickle.load(open("attributes.pkl", "rb"))
```

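Since attributes come back in definition order, the loaded result can be indexed directly. A minimal sketch, assuming `attributes.pkl` has already been extracted next to the script (the index values mirror the earlier `test_attribute_unpickling` example and are illustrative):

```python
import pickle
import torch  # required: attributes.pkl references functions in torch.jit._pickle

with open("attributes.pkl", "rb") as f:
    attrs = pickle.load(f)

# Index i holds the i-th attribute defined on the module, e.g.
# attrs[0] -> a dict, attrs[1] -> 2.3, attrs[2] -> 99 in the test above.
print(attrs[2])
```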
If for some reason you don't have PyTorch installed, you can still load `attributes.pkl` with a custom [`Unpickler`](https://docs.python.org/3/library/pickle.html#pickle.Unpickler):

```python
import pickle

# Tensor objects are stored as instances of this class
class TensorID(object):
    def __setstate__(self, id):
        self.id = id

# List[int] has internal specializations, and these are indicated with this class
class IntList(object):
    def __setstate__(self, data):
        self.data = data

class JitUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if not module == '__main__':
            return None
        if module != 'torch.jit._pickle':
            raise RuntimeError("Unknown module")

        if name == 'TensorID':
            return TensorID
        elif name == 'IntList':
            return IntList
        identity = lambda x: x
        if name == 'build_tensor_from_id':
            # Without the tensor table we can't do anything other than
            # return the tensor ID
            return identity
        elif name == 'build_intlist':
            return identity

JitUnpickler(open("my_model/attributes.pkl", "rb")).load()
print(JitUnpickler(open("out_dir/out/attributes.pkl", "rb")).load())
```

#### Binary Format

Running the following snippet produces a `ScriptModule` with several attributes.
Python's `pickletools` module can be used to decode the binary blob of `attributes.pkl` into a human-readable format.

```python
import pickletools
import zipfile
import torch
from typing import Tuple, List

class M(torch.jit.ScriptModule):
    def __init__(self):
        super(M, self).__init__()
@ -1391,50 +1408,46 @@ class M(torch.jit.ScriptModule):
    def forward(self):
        return (self.float, self.tuple, self.tensor, self.int_list)

M().save("out.pt")
M().save("out.zip")
model_zip = zipfile.ZipFile("out.zip", 'r')
model_zip.extractall("out_dir")
pickletools.dis(open("out_dir/out/attributes.pkl", "rb"))
```

In a terminal, Python's `pickletools` module can be used to decode the binary blob of `attributes.pkl` into a human-readable format.

```bash
unzip -o out.pt
python -m pickletools out/attributes.pkl
```

The output of the above commands demonstrates the concepts described earlier. Attributes are wrapped in `2: EMPTY_LIST` and appear in the order they are defined on the module. Functions for certain special types (e.g. `List[int]`, `Tensor`) can be seen at `37: GLOBAL` and `66: GLOBAL`, followed by data specific to that type, then finally by an instruction to build the object at `65: BUILD` and `113: BUILD` respectively.

The output of the above commands demonstrates the concepts described earlier. Attributes are wrapped in `2: EMPTY_LIST` and appear in the order they are defined on the module. Classes for certain special types (`List[int]`, `Tensor`) can be seen at `37: GLOBAL` and `66: GLOBAL`, followed by data specific to that type, then finally by an instruction to build the object at `65: BUILD` and `113: BUILD` respectively.

```
    0: \x80 PROTO      2
    2: ]    EMPTY_LIST
    3: (    MARK
    4: G        BINFLOAT   2.3
   13: (        MARK
   14: J            BININT     1
   19: J            BININT     2
   24: J            BININT     3
   29: J            BININT     4
   34: t            TUPLE      (MARK at 13)
   35: q        BINPUT     0
   37: c        GLOBAL     '__main__ TensorID'
   56: q        BINPUT     1
   58: )        EMPTY_TUPLE
   59: \x81     NEWOBJ
   60: J        BININT     0
   65: b        BUILD
   66: c        GLOBAL     '__main__ IntList'
   84: q        BINPUT     2
   86: )        EMPTY_TUPLE
   87: \x81     NEWOBJ
   88: ]        EMPTY_LIST
   89: q        BINPUT     3
   91: (        MARK
   92: J            BININT     1
   97: J            BININT     2
  102: J            BININT     3
  107: J            BININT     4
  112: e            APPENDS    (MARK at 91)
  113: b        BUILD
  114: e        APPENDS    (MARK at 3)
  115: .    STOP
  0: \x80 PROTO      2
  2: (    MARK
  3: G        BINFLOAT   2.3
 12: (        MARK
 13: K            BININT1    1
 15: K            BININT1    2
 17: K            BININT1    3
 19: K            BININT1    4
 21: t            TUPLE      (MARK at 12)
 22: q        BINPUT     0
 24: c        GLOBAL     'torch.jit._pickle build_tensor_from_id'
 64: q        BINPUT     1
 66: (        MARK
 67: K            BININT1    0
 69: t            TUPLE      (MARK at 66)
 70: R        REDUCE
 71: c        GLOBAL     'torch.jit._pickle build_intlist'
104: q        BINPUT     2
106: (        MARK
107: ]            EMPTY_LIST
108: (            MARK
109: K                BININT1    1
111: K                BININT1    2
113: K                BININT1    3
115: K                BININT1    4
117: e                APPENDS    (MARK at 108)
118: t            TUPLE      (MARK at 106)
119: R        REDUCE
120: q        BINPUT     3
122: t        TUPLE      (MARK at 2)
123: .    STOP
highest protocol among opcodes = 2
```

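The `GLOBAL`/`TUPLE`/`REDUCE` sequence in the disassembly is ordinary `pickle` protocol behavior: `REDUCE` calls the pickled global with the argument tuple spread as `*args`. A standalone sketch that needs no PyTorch (`build_intlist` here is a stand-in for the function in `torch.jit._pickle`):

```python
import pickle
import pickletools

def build_intlist(data):
    # Stand-in: REDUCE calls this with the unpickled argument tuple.
    return data

class IntListPlaceholder(object):
    def __init__(self, data):
        self.data = data

    def __reduce__(self):
        # Emits GLOBAL 'build_intlist', the argument tuple, then REDUCE.
        return (build_intlist, (self.data,))

blob = pickle.dumps(IntListPlaceholder([1, 2, 3, 4]), protocol=2)
pickletools.dis(blob)      # shows GLOBAL ... REDUCE, as in the dump above
print(pickle.loads(blob))  # [1, 2, 3, 4]
```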
@ -6,6 +6,13 @@ namespace jit {
using ::c10::IValue;

PicklerClass getClass(const std::string& str) {
  if (str == "build_tensor_from_id") {
    return PicklerClass::TENSOR;
  } else if (str == "build_intlist") {
    return PicklerClass::INTLIST;
  }

  // TODO [unpickler refactor]
  if (str == "TensorID") {
    return PicklerClass::TENSOR;
  } else if (str == "IntList") {
@ -15,8 +22,8 @@ PicklerClass getClass(const std::string& str) {
}

const std::string& getClassName(PicklerClass cls) {
  static const std::string tensor_class("TensorID\n");
  static const std::string intlist_class("IntList\n");
  static const std::string tensor_class("build_tensor_from_id\n");
  static const std::string intlist_class("build_intlist\n");
  switch (cls) {
    case PicklerClass::TENSOR:
      return tensor_class;
@ -28,7 +35,7 @@ const std::string& getClassName(PicklerClass cls) {
}

const std::string& getModuleName() {
  static const std::string module_name("__main__\n");
  static const std::string module_name("torch.jit._pickle\n");
  return module_name;
}

@ -42,12 +49,11 @@ void Pickler::start() {

  // All attributes get pushed into a list and their indices saved in the
  // module def
  push<OpCode>(OpCode::EMPTY_LIST);
  push<OpCode>(OpCode::MARK);
}

void Pickler::finish() {
  push<OpCode>(OpCode::APPENDS);
  push<OpCode>(OpCode::TUPLE);
  push<OpCode>(OpCode::STOP);
}

@ -92,17 +98,19 @@ void Pickler::addIValue(const IValue& ivalue) {
  }
}

/// Returns a void* uniquely identifying this IValue's data. For non-containers,
/// returns nullptr.
const void* Pickler::getPointer(const IValue& ivalue) {
  if (ivalue.isGenericDict()) {
    return &(ivalue.toGenericDictRef());
    return ivalue.toGenericDict().get();
  } else if (ivalue.isGenericList()) {
    return &(ivalue.toGenericListRef());
    return ivalue.toGenericList().get();
  } else if (ivalue.isTuple()) {
    return &(ivalue.toTuple()->elements());
    return ivalue.toTuple().get();
  } else if (ivalue.isString()) {
    return &(ivalue.toStringRef());
    return ivalue.toString().get();
  } else if (ivalue.isIntList()) {
    return &(ivalue.toIntListRef());
    return ivalue.toIntList().get();
  }

  return nullptr;
@ -165,35 +173,48 @@ void Pickler::pushClass(PicklerClass cls) {
  } else {
    pushBinGet(memo_entry->second);
  }

  push<OpCode>(OpCode::EMPTY_TUPLE);
  push<OpCode>(OpCode::NEWOBJ);
}

void Pickler::pushTensor(const IValue& ivalue) {
  pushClass(PicklerClass::TENSOR);

  tensor_table_->push_back(ivalue.toTensor());
  auto tensor_id = tensor_table_->size() - 1;
  push<OpCode>(OpCode::BININT);
  push<uint32_t>(tensor_id);
  int64_t tensor_id = tensor_table_->size() - 1;
  // Reduce arguments are spread (e.g. `*args`) before calling the global,
  // so wrap in a tuple
  push<OpCode>(OpCode::MARK);
  addIValue(tensor_id);
  push<OpCode>(OpCode::TUPLE);

  push<OpCode>(OpCode::BUILD);
  push<OpCode>(OpCode::REDUCE);
}

void Pickler::pushIntList(const IValue& ivalue) {
  pushClass(PicklerClass::INTLIST);

  push<OpCode>(OpCode::EMPTY_LIST);
  pushMemoization(ivalue);

  // Reduce arguments are spread (e.g. `*args`) before calling the global,
  // so wrap in a tuple
  push<OpCode>(OpCode::MARK);

  push<OpCode>(OpCode::EMPTY_LIST);
  // Mark list
  push<OpCode>(OpCode::MARK);

  // Add items
  for (const auto& item : ivalue.toIntListRef()) {
    addIValue(item);
  }

  // Finish list
  push<OpCode>(OpCode::APPENDS);
  push<OpCode>(OpCode::BUILD);

  // Finish tuple
  push<OpCode>(OpCode::TUPLE);

  // Call reduce
  push<OpCode>(OpCode::REDUCE);
  pushMemoization(ivalue);
}

void Pickler::pushDouble(const IValue& ivalue) {
@ -208,8 +229,6 @@ void Pickler::pushDouble(const IValue& ivalue) {
}

void Pickler::pushDict(const IValue& ivalue) {
  auto dict = ivalue.toGenericDictRef();

  push<OpCode>(OpCode::EMPTY_DICT);
  pushMemoization(ivalue);

@ -226,7 +245,7 @@ void Pickler::pushDict(const IValue& ivalue) {
}

void Pickler::pushMemoization(const void* item) {
  AT_ASSERT(item != nullptr);
  AT_CHECK(item != nullptr, "Pickler cannot memoize a nullptr");
  if (memo_id <= std::numeric_limits<uint8_t>::max()) {
    push<OpCode>(OpCode::BINPUT);
    push<uint8_t>(memo_id);
@ -241,7 +260,14 @@ void Pickler::pushMemoization(const void* item) {
}

void Pickler::pushMemoization(const IValue& ivalue) {
  pushMemoization(getPointer(ivalue));
  auto ptr = getPointer(ivalue);
  AT_CHECK(
      ptr != nullptr,
      "Pickler cannot memoize ",
      ivalue.tagKind(),
      " IValue ",
      ivalue);
  pushMemoization(ptr);
}

void Pickler::pushList(const IValue& ivalue) {
@ -273,8 +299,17 @@ void Pickler::pushTuple(const IValue& ivalue) {

std::vector<IValue> Unpickler::parse_ivalue_list() {
  run();
  AT_ASSERT(stack_.size() == 1);
  return stack_[0].toGenericListRef();
  AT_CHECK(
      stack_.size() == 1,
      "Expected stack to end with a size of 1 but got ",
      stack_.size());

  auto value = stack_[0].ivalue();
  if (value.isGenericList()) {
    // TODO [unpickler refactor]
    return value.toGenericListRef();
  }
  return value.toTuple()->elements();
}

double Unpickler::readFloat() {
@ -294,7 +329,10 @@ double Unpickler::readFloat() {

void Unpickler::run() {
  // Expect a PROTO opcode and protocol number at the start of blob
  AT_ASSERT(readOpCode() == OpCode::PROTO);
  AT_CHECK(
      readOpCode() == OpCode::PROTO,
      "Expected PROTO opcode at the start"
      " of pickle archive");
  uint8_t protocol = read<uint8_t>();
  AT_CHECK(
      protocol == 2,
@ -312,17 +350,18 @@ void Unpickler::run() {
  AT_ERROR("Overran buffer while unpickling data, didn't find STOP opcode");
}


OpCode Unpickler::readInstruction() {
  auto opcode = readOpCode();
  switch (opcode) {
    case OpCode::EMPTY_LIST: {
      // Look back to see if the last opcode was an IntList class
      if (last_opcode_ == OpCode::NEWOBJ) {
        // TODO [unpickler refactor] remove this case
        // It's a list specialization, the enum ID of which is on the stack
        AT_CHECK(
            stack_.size() > 0,
            "Unpickler found an empty stack when it expected a value");
        auto value = stack_.back().toInt();
        auto value = stack_.back().ivalue().toInt();
        AT_CHECK(
            value >= 0 && value <= std::numeric_limits<uint8_t>::max(),
            "Unpickler could not decode PicklerClass for ",
@ -331,6 +370,14 @@ OpCode Unpickler::readInstruction() {
        if (cls == PicklerClass::INTLIST) {
          stack_.emplace_back(std::vector<int64_t>());
        }
      } else if (stack_.size() > 0 && stack_.back().pickler_class_opt()) {
        // Check if we're in a GLOBAL opcode and if so, if it's a list
        // specialization
        if (stack_.back().pickler_class() == PicklerClass::INTLIST) {
          stack_.emplace_back(std::vector<int64_t>());
        } else {
          AT_ERROR("Unknown list specialization");
        }
      } else {
        stack_.emplace_back(std::vector<IValue>());
      }
@ -394,10 +441,14 @@ OpCode Unpickler::readInstruction() {
    case OpCode::TUPLE: {
      size_t start = marks_.back();
      marks_.pop_back();
      IValue tup = c10::ivalue::Tuple::create(
          std::vector<IValue>(stack_.begin() + start, stack_.end()));
      stack_.resize(start);
      stack_.push_back(tup);
      auto tuple = c10::ivalue::Tuple::create({});
      tuple->elements().reserve(stack_.size() - start);
      auto start_it = stack_.begin() + start;
      for (auto it = start_it; it != stack_.end(); ++it) {
        tuple->elements().emplace_back(it->ivalue());
      }
      stack_.erase(start_it, stack_.end());
      stack_.emplace_back(IValue(tuple));
    } break;
    case OpCode::EMPTY_DICT:
      stack_.emplace_back(c10::ivalue::UnorderedMap());
@ -408,11 +459,11 @@ OpCode Unpickler::readInstruction() {
    case OpCode::SETITEMS: {
      size_t start = marks_.back();
      marks_.pop_back();
      auto dict = stack_.at(start - 1).toGenericDict();
      auto dict = stack_.at(start - 1).ivalue().toGenericDict();
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict->elements()[stack_[i]] = stack_[i + 1];
        dict->elements()[stack_[i].ivalue()] = stack_[i + 1].ivalue();
      }
      stack_.resize(start);
      stack_.erase(stack_.begin() + start, stack_.end());
    } break;
    case OpCode::BINGET: {
      stack_.push_back(memo_table_.at(read<uint8_t>()));
@ -423,35 +474,64 @@ OpCode Unpickler::readInstruction() {
    case OpCode::STOP:
      break;
    case OpCode::GLOBAL: {
      AT_ASSERT(readString() == "__main__");
      // Push class name to stack
      stack_.emplace_back(static_cast<uint8_t>(getClass(readString())));
      // Module name, it's not needed for anything
      auto module_name = readString();
      // TODO [unpickler refactor] __main__ isn't used by the pickler anymore
      if (module_name == "__main__") {
        stack_.emplace_back(static_cast<uint8_t>(getClass(readString())));
      } else {
        // Push class name to stack
        stack_.emplace_back(getClass(readString()));
      }
    } break;
    case OpCode::NEWOBJ: {
      // pop empty tuple
      stack_.pop_back();
    } break;
    case OpCode::BUILD: {
      auto setitem_data = stack_.back();
      // TODO: [unpickler refactor]
      auto setitem_data = stack_.back().ivalue();
      stack_.pop_back();


      auto class_name =
          static_cast<PicklerClass>(uint8_t(stack_.back().toInt()));
        static_cast<PicklerClass>(uint8_t(stack_.back().ivalue().toInt()));
      stack_.pop_back();

      switch (class_name) {
      case PicklerClass::TENSOR:
        stack_.emplace_back(tensor_table_->at(setitem_data.toInt()));
        break;
      case PicklerClass::INTLIST:
        stack_.emplace_back(setitem_data);
        break;
      default:
        AT_ERROR("Unknown pickler class id");
      }
    } break;
    case OpCode::REDUCE: {
      // Pop reduce arg off the stack
      auto data = stack_.back().ivalue().toTuple();
      stack_.pop_back();

      // Remove GLOBAL from stack
      auto class_name = stack_.back().pickler_class();
      stack_.pop_back();

      switch (class_name) {
        case PicklerClass::TENSOR:
          stack_.emplace_back(tensor_table_->at(setitem_data.toInt()));
          stack_.emplace_back(
              tensor_table_->at(data->elements().at(0).toInt()));
          break;
        case PicklerClass::INTLIST:
          stack_.push_back(setitem_data);
          stack_.emplace_back(data->elements().at(0).toIntListRef());
          break;
        default:
          AT_ERROR("Unknown pickler class id");
      }
    } break;
    default:
      AT_ERROR("Unknown opcode for unpickling: ", static_cast<uint8_t>(opcode));
      AT_ERROR("Unknown opcode for unpickling at ", reinterpret_cast<void*>(opcode), ": ", static_cast<uint8_t>(opcode));
  }
  return opcode;
}
@ -460,19 +540,27 @@ void Unpickler::readList() {
  size_t start = marks_.back();
  marks_.pop_back();
  auto list_ivalue = stack_.at(start - 1);
  if (list_ivalue.isIntList()) {
    auto list = stack_.at(start - 1).toIntList();
    auto num_elements = stack_.size() - start;
  auto num_elements = stack_.size() - start;
  if (list_ivalue.ivalue().isIntList()) {
    auto list = stack_.at(start - 1).ivalue().toIntList();
    list->elements().reserve(num_elements);
    for (auto it = stack_.begin() + start; it != stack_.end(); ++it) {
      list->elements().emplace_back(it->toInt());
      list->elements().emplace_back(it->ivalue().toInt());
    }
  } else {
    auto list = stack_.at(start - 1).toGenericList();
    list->elements().insert(
        list->elements().end(), stack_.begin() + start, stack_.end());
    auto list = stack_.at(start - 1).ivalue().toGenericList();
    list->elements().reserve(num_elements);
    for (auto it = stack_.begin() + start; it != stack_.end(); ++it) {
      list->elements().emplace_back(it->ivalue());
    }
  }
  stack_.resize(start);

  stack_.erase(stack_.begin() + start, stack_.end());
}

inline bool is_valid_python_id_char(char c) {
  return c == '_' || c == '.' || (c >= '0' && c <= '9') ||
      (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
}

// Read a newline terminated string
@ -487,7 +575,12 @@ std::string Unpickler::readString() {
    }

    // Simple check just in case there is no terminating '\n'
    AT_ASSERT(c >= '0' && c <= 'z');
    AT_CHECK(
        is_valid_python_id_char(c),
        "Found character '",
        uint8_t(c),
        "' in string, "
        "strings must be qualified Python identifiers");

    // Increment after to exclude newline from string
    ++n;

@ -1,3 +1,5 @@
#pragma once

#include <string>
#include <vector>

@ -119,7 +121,7 @@ class Pickler {
  // the left of a '::', its type cannot be deduced by the compiler so one must
  // explicitly instantiate the template, i.e. push<int>(int) works, push(int)
  // does not)
  template<typename T>
  template <typename T>
  void push(typename std::common_type<T>::type value) {
    const char* begin = reinterpret_cast<const char*>(&value);
    stack_.insert(stack_.end(), begin, begin + sizeof(T));
@ -140,6 +142,39 @@ class Pickler {
  uint32_t memo_id = 0;
};

// An item in the unpickler stack. There needs to be a way to differentiate
// between a GLOBAL item (PicklerClass) and a normal value item (IValue)
struct StackItem {
  StackItem(IValue ivalue)
      : pickler_class_(c10::nullopt), ivalue_(std::move(ivalue)) {}
  StackItem(PicklerClass pickler_class)
      : pickler_class_(pickler_class), ivalue_(c10::nullopt) {}

  IValue ivalue() {
    return *ivalue_;
  }

  PicklerClass pickler_class() {
    return *pickler_class_;
  }

  c10::optional<IValue> ivalue_opt() {
    return ivalue_;
  }

  c10::optional<PicklerClass> pickler_class_opt() {
    return pickler_class_;
  }

 private:
  c10::optional<PicklerClass> pickler_class_;
  c10::optional<IValue> ivalue_;
};

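For intuition, the same tagged-stack idea in a hedged Python sketch (illustrative only, not part of the C++ API):

```python
from dataclasses import dataclass
from typing import Any, Optional

# Each stack slot holds either the class pushed by a GLOBAL opcode or a
# normal unpickled value, never both.
@dataclass
class StackItem:
    pickler_class: Optional[str] = None
    ivalue: Optional[Any] = None

# REDUCE pops the argument tuple, then the GLOBAL item beneath it:
stack = [StackItem(pickler_class="INTLIST"), StackItem(ivalue=([1, 2, 3],))]
args = stack.pop().ivalue
cls = stack.pop().pickler_class
assert cls == "INTLIST"
stack.append(StackItem(ivalue=args[0]))  # build_intlist is the identity
print(stack[-1].ivalue)  # [1, 2, 3]
```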
// [unpickler refactor] There is some cruft around OpCode::BUILD,
// OpCode::NEWOBJ, and the last_opcode_ member below that should be deleted
// at some point; the Pickler doesn't produce these anymore and they are only
// around to support models saved before 1.1.
class Unpickler {
  TH_DISALLOW_COPY_AND_ASSIGN(Unpickler);

@ -176,12 +211,14 @@ class Unpickler {
  OpCode readOpCode();
  void readList();

  std::vector<IValue> stack_;
  std::vector<IValue> memo_table_;
  std::vector<StackItem> stack_;
  std::vector<StackItem> memo_table_;
  std::vector<size_t> marks_;
  const uint8_t* bytes_;
  const uint8_t* end_ptr_;
  const std::vector<at::Tensor>* tensor_table_;

  // [unpickler refactor]
  OpCode last_opcode_;
};


@ -107,10 +107,6 @@ Value* tryConvertToType(
        DeviceObjType::get()->isSubtypeOf(concrete_type)) {
      return graph.insert(aten::device, {value}, {}, loc);
    }
    if (concrete_type == FloatType::get() &&
        value->type() == NumberType::get()) {
      return graph.insert(prim::Float, {value}, {}, loc);
    }
  }

  return value;

@ -9,7 +9,6 @@ import torch._jit_internal as _jit_internal
from torch._six import with_metaclass, get_function_from_type, \
    string_classes
from torch._jit_internal import ignore  # noqa: F401
from torch.jit._pickle import Unpickler  # noqa: F401
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
    _list_with_default
import torch.testing
@ -99,7 +98,8 @@ def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP):
        Returns:
            A ``ScriptModule`` object.

        Example:
        Example: ::

            torch.jit.load('scriptmodule.pt')

            # Load ScriptModule from io.BytesIO object
@ -178,7 +178,8 @@ def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP):

            Please use something like ``io.BytesIO`` instead.

        Example:
        Example: ::

            m = torch.jit.ScriptModule()

            # Save to file
@ -1069,13 +1070,13 @@ if _enabled:
        The core data structure in TorchScript is the ``ScriptModule``. It is an
        analogue of torch's ``nn.Module`` and represents an entire model as a tree of
        submodules. Like normal modules, each individual module in a ``ScriptModule`` can
        have submodules, parameters, and methods. In ``nn.Module``s methods are implemented
        as Python functions, but in ``ScriptModule``s methods are implemented as
        have submodules, parameters, and methods. In ``nn.Module``\s methods are implemented
        as Python functions, but in ``ScriptModule``\s methods are implemented as
        TorchScript functions, a statically-typed subset of Python that contains all
        of PyTorch's built-in Tensor operations. This difference allows your
        ``ScriptModule``'s code to run without the need for a Python interpreter.

        ``ScriptModule``s can be created in two ways:
        ``ScriptModule``\s can be created in two ways:

        **Tracing:**

@ -1132,9 +1133,9 @@ if _enabled:
            You can write TorchScript code directly using Python syntax. You do this
            using the ``@torch.jit.script`` decorator (for functions) or
            ``@torch.jit.script_method`` decorator (for methods) on subclasses of
            ScriptModule. With this decorator the body of the annotated function is
            ``ScriptModule``. With this decorator the body of the annotated function is
            directly translated into TorchScript. TorchScript itself is a subset of
            the Python language, so not all features in python work, but we provide
            the Python language, so not all features in Python work, but we provide
            enough functionality to compute on tensors and do control-dependent
            operations.


@ -1,24 +1,8 @@
import pickle
def build_intlist(data):
    return data


class TensorID(object):
    def __setstate__(self, id):
        self.id = id


class IntList(object):
    def __setstate__(self, data):
        self.data = data


class Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if not module == '__main__':
            return None

        if name == 'TensorID':
            return TensorID
        elif name == 'IntList':
            return IntList
        elif name == 'LiteralTensor':
            return LiteralTensor
def build_tensor_from_id(data):
    if isinstance(data, int):
        # just the id, can't really do anything
        return data

@ -198,7 +198,7 @@ class DistributedDataParallel(Module):
                                       Parameters that don't receive gradients as
                                       part of this graph are preemptively marked
                                       as being ready to be reduced.
                                       (default: ``True``)
                                       (default: ``False``)
        check_reduction: when set to ``True``, it enables DistributedDataParallel
                         to automatically check if the previous iteration's
                         backward reductions were successfully issued at the
@ -220,7 +220,7 @@ class DistributedDataParallel(Module):
    def __init__(self, module, device_ids=None,
                 output_device=None, dim=0, broadcast_buffers=True,
                 process_group=None, bucket_cap_mb=25,
                 find_unused_parameters=True,
                 find_unused_parameters=False,
                 check_reduction=False):

        super(DistributedDataParallel, self).__init__()
@ -380,14 +380,16 @@ class DistributedDataParallel(Module):
        else:
            output = self.module(*inputs, **kwargs)

        # We'll return the output object verbatim since it is a freeform object.
        # We need to find any tensors in this object, though, because we need to
        # figure out which parameters were used during this forward pass,
        # to ensure we short circuit reduction for any unused parameters.
        if self.find_unused_parameters:
            self.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            self.reducer.prepare_for_backward([])
        if torch.is_grad_enabled():
            # We'll return the output object verbatim since it is a freeform
            # object. We need to find any tensors in this object, though,
            # because we need to figure out which parameters were used during
            # this forward pass, to ensure we short circuit reduction for any
            # unused parameters. Only if `find_unused_parameters` is set.
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        return output

    def scatter(self, inputs, kwargs, device_ids):

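The `torch.is_grad_enabled()` guard above means an evaluation pass no longer books a reduction at all. A hedged usage sketch (assumes `model` is an already-constructed `DistributedDataParallel` instance and the input tensors exist):

```python
import torch

# Under no_grad, forward() skips reducer.prepare_for_backward entirely, so an
# evaluation pass can't trip the "finished reduction" check on the next
# training iteration.
with torch.no_grad():
    eval_output = model(eval_input)

# A normal training pass still registers the pending reduction as before.
train_output = model(train_input)
train_output.sum().backward()
```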
@ -459,7 +459,7 @@ class SummaryWriter(object):
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event
        Shape:
            vid_tensor: :math:`(N, T, C, H, W)`.
            vid_tensor: :math:`(N, T, C, H, W)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`.
        """
        self._get_file_writer().add_summary(
            video(tag, vid_tensor, fps), global_step, walltime)
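A hedged usage sketch for the clarified value range (the writer import path and tensor are illustrative):

```python
import torch
from torch.utils.tensorboard import SummaryWriter  # assumption: this writer class

writer = SummaryWriter()
# (N, T, C, H, W); float values must lie in [0, 1] ([0, 255] for uint8).
vid = torch.rand(1, 16, 3, 64, 64)
writer.add_video("random_clip", vid, global_step=0, fps=4)
writer.close()
```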
@ -714,7 +714,7 @@ class SummaryWriter(object):
    def add_custom_scalars(self, layout):
        """Create a special chart by collecting chart tags in 'scalars'. Note that this function can only be called once
        for each SummaryWriter() object. Because it only provides metadata to tensorboard, the function can be called
        before or after the training loop. See ``examples/demo_custom_scalars.py`` for more.
        before or after the training loop.

        Args:
            layout (dict): {categoryName: *charts*}, where *charts* is also a dictionary

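A hedged sketch of a `layout` dict (category, chart, and tag names are illustrative; the `[chart_type, [tags]]` shape follows the custom scalars plugin's 'Multiline'/'Margin' convention):

```python
from torch.utils.tensorboard import SummaryWriter  # assumption: this writer class

writer = SummaryWriter()
layout = {
    "losses": {
        # chart name -> [chart type, [scalar tags collected into the chart]]
        "total": ["Multiline", ["loss/train", "loss/val"]],
    },
}
writer.add_custom_scalars(layout)  # can only be called once per writer
writer.close()
```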