Add Sphinx docs
docs/Makefile  (new file, 20 lines)
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
SPHINXPROJ    = PyTorch
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/make.bat  (new file, 36 lines)
@@ -0,0 +1,36 @@
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
set SPHINXPROJ=PyTorch

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%

:end
popd
docs/requirements.txt  (new file, 2 lines)
@@ -0,0 +1,2 @@
sphinx
sphinx_rtd_theme
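For reviewers who want to render these docs locally, a minimal sketch (assuming a Python environment where installing the requirements puts `sphinx-build` on PATH; the second command is exactly what the Makefile's catch-all target runs for `make html`):

```python
import subprocess
import sys

# install the two requirements above: sphinx and sphinx_rtd_theme
subprocess.run([sys.executable, "-m", "pip", "install", "-r", "docs/requirements.txt"],
               check=True)
# equivalent of running `make html` inside docs/: sphinx-build -M html source build
subprocess.run(["sphinx-build", "-M", "html", "docs/source", "docs/build"], check=True)
# the rendered pages land in docs/build/html/
```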
docs/source/autograd.rst  (new file, 2 lines)
@@ -0,0 +1,2 @@
torch.autograd
===================================
docs/source/conf.py  (new file, 174 lines)
@@ -0,0 +1,174 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# PyTorch documentation build configuration file, created by
# sphinx-quickstart on Fri Dec 23 13:31:47 2016.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import torch
import sphinx_rtd_theme


# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.doctest',
    'sphinx.ext.intersphinx',
    'sphinx.ext.todo',
    'sphinx.ext.coverage',
    'sphinx.ext.mathjax',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

# General information about the project.
project = 'PyTorch'
copyright = '2016, Torch Contributors'
author = 'Torch Contributors'

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = '0.1.6'
# The full version, including alpha/beta/rc tags.
release = '0.1.6'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# These patterns also affect html_static_path and html_extra_path.
exclude_patterns = []

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True


# -- Options for HTML output ----------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {
    'collapse_navigation': False,
    'display_version': False,
}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']


# -- Options for HTMLHelp output ------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'PyTorchdoc'


# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',

    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'pytorch.tex', 'PyTorch Documentation',
     'Torch Contributors', 'manual'),
]


# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, 'PyTorch', 'PyTorch Documentation',
     [author], 1)
]


# -- Options for Texinfo output -------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'PyTorch', 'PyTorch Documentation',
     author, 'PyTorch', 'One line description of project.',
     'Miscellaneous'),
]


# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'https://docs.python.org/': None}
docs/source/cuda.rst  (new file, 5 lines)
@@ -0,0 +1,5 @@
torch.cuda
===================================

.. automodule:: torch.cuda
    :members:
docs/source/data.rst  (new file, 7 lines)
@@ -0,0 +1,7 @@
torch.utils.data
===================================

.. automodule:: torch.utils.data
.. autoclass:: Dataset
.. autoclass:: TensorDataset
.. autoclass:: DataLoader
docs/source/index.rst  (new file, 33 lines)
@@ -0,0 +1,33 @@
.. PyTorch documentation master file, created by
   sphinx-quickstart on Fri Dec 23 13:31:47 2016.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

:github_url: https://github.com/pytorch/pytorch

PyTorch documentation
===================================

PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.

.. toctree::
   :maxdepth: 1
   :caption: Package Reference

   torch
   tensors
   nn
   optim
   autograd
   multiprocessing
   legacy
   cuda
   data


Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
docs/source/legacy.rst  (new file, 2 lines)
@@ -0,0 +1,2 @@
torch.legacy
===================================
docs/source/multiprocessing.rst  (new file, 4 lines)
@@ -0,0 +1,4 @@
torch.multiprocessing
===================================

.. automodule:: torch.multiprocessing
docs/source/nn.rst  (new file, 5 lines)
@@ -0,0 +1,5 @@
torch.nn
===================================

.. automodule:: torch.nn
.. autoclass:: Container
docs/source/optim.rst  (new file, 6 lines)
@@ -0,0 +1,6 @@
torch.optim
===================================

.. automodule:: torch.optim
.. autoclass:: Optimizer
.. autoclass:: SGD
docs/source/tensors.rst  (new file, 5 lines)
@@ -0,0 +1,5 @@
torch.Tensor
===================================

.. autoclass:: torch.FloatTensor
    :members:
docs/source/torch.rst  (new file, 149 lines)
@@ -0,0 +1,149 @@
torch
===================================
.. automodule:: torch

Math operations
----------------------------------

.. autofunction:: abs
.. autofunction:: acos
.. autofunction:: add
.. autofunction:: addbmm
.. autofunction:: addcdiv
.. autofunction:: addcmul
.. autofunction:: addmm
.. autofunction:: addmv
.. autofunction:: addr
.. autofunction:: all
.. autofunction:: any
.. autofunction:: asin
.. autofunction:: atan
.. autofunction:: atan2
.. autofunction:: baddbmm
.. autofunction:: bernoulli
.. autofunction:: bmm
.. autofunction:: cat
.. autofunction:: cauchy
.. autofunction:: cdiv
.. autofunction:: ceil
.. autofunction:: cfmod
.. autofunction:: cinv
.. autofunction:: clamp
.. autofunction:: cmax
.. autofunction:: cmin
.. autofunction:: cmod
.. autofunction:: cmul
.. autofunction:: cos
.. autofunction:: cosh
.. autofunction:: cpow
.. autofunction:: cremainder
.. autofunction:: cross
.. autofunction:: csub
.. autofunction:: cumprod
.. autofunction:: cumsum
.. autofunction:: diag
.. autofunction:: dist
.. autofunction:: div
.. autofunction:: dot
.. autofunction:: eig
.. autofunction:: eq
.. autofunction:: equal
.. autofunction:: exp
.. autofunction:: exponential
.. autofunction:: eye
.. autofunction:: fill
.. autofunction:: floor
.. autofunction:: fmod
.. autofunction:: frac
.. autofunction:: from_numpy
.. autofunction:: gather
.. autofunction:: ge
.. autofunction:: gels
.. autofunction:: geometric
.. autofunction:: geqrf
.. autofunction:: ger
.. autofunction:: gesv
.. autofunction:: gt
.. autofunction:: histc
.. autofunction:: index_select
.. autofunction:: inverse
.. autofunction:: kthvalue
.. autofunction:: le
.. autofunction:: lerp
.. autofunction:: linspace
.. autofunction:: log
.. autofunction:: log1p
.. autofunction:: log_normal
.. autofunction:: logspace
.. autofunction:: lt
.. autofunction:: masked_select
.. autofunction:: max
.. autofunction:: mean
.. autofunction:: median
.. autofunction:: min
.. autofunction:: mm
.. autofunction:: mod
.. autofunction:: mode
.. autofunction:: mul
.. autofunction:: multinomial
.. autofunction:: mv
.. autofunction:: ne
.. autofunction:: neg
.. autofunction:: nonzero
.. autofunction:: norm
.. autofunction:: normal
.. autofunction:: numel
.. autofunction:: ones
.. autofunction:: orgqr
.. autofunction:: ormqr
.. autofunction:: potrf
.. autofunction:: potri
.. autofunction:: potrs
.. autofunction:: pow
.. autofunction:: prod
.. autofunction:: pstrf
.. autofunction:: qr
.. autofunction:: rand
.. autofunction:: randn
.. autofunction:: random
.. autofunction:: randperm
.. autofunction:: range
.. autofunction:: remainder
.. autofunction:: renorm
.. autofunction:: reshape
.. autofunction:: round
.. autofunction:: rsqrt
.. autofunction:: scatter
.. autofunction:: sigmoid
.. autofunction:: sign
.. autofunction:: sin
.. autofunction:: sinh
.. autofunction:: sort
.. autofunction:: sqrt
.. autofunction:: squeeze
.. autofunction:: std
.. autofunction:: sum
.. autofunction:: svd
.. autofunction:: symeig
.. autofunction:: t
.. autofunction:: tan
.. autofunction:: tanh
.. autofunction:: topk
.. autofunction:: trace
.. autofunction:: transpose
.. autofunction:: tril
.. autofunction:: triu
.. autofunction:: trtrs
.. autofunction:: trunc
.. autofunction:: unfold
.. autofunction:: uniform
.. autofunction:: var
.. autofunction:: zero
.. autofunction:: zeros


Parallelism
----------------------------------
.. autofunction:: get_num_threads
.. autofunction:: set_num_threads
torch/__init__.py  (modified)
@@ -1,7 +1,26 @@
"""
The torch package contains data structures for multi-dimensional
tensors and defines mathematical operations over them.
Additionally, it provides many utilities for efficient serializing of
Tensors and arbitrary types, and other useful utilities.

It has a CUDA counterpart that enables you to run your tensor computations
on an NVIDIA GPU with compute capability >= 2.0.
"""

import sys
import math
from ._utils import _import_dotted_name

__all__ = [
    'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
    'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed',
    'save', 'load', 'set_printoptions',
    'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
    'ShortStorage', 'CharStorage', 'ByteStorage',
    'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
    'ShortTensor', 'CharTensor', 'ByteTensor',
]

################################################################################
# Load the extension module
################################################################################

@@ -22,7 +41,13 @@ if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_NOW'):

old_flags = sys.getdlopenflags()
sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_NOW)

from torch._C import *

__all__ += [name for name in dir(_C)
            if name[0] != '_' and
            not name.endswith('Base')]

sys.setdlopenflags(old_flags)
del _dl_flags
del old_flags

@@ -150,25 +175,15 @@ class ByteTensor(_C.ByteTensorBase, _TensorBase):
        return ByteStorage


-_tensor_classes = set()
-_storage_classes = set()
-
-
-_storage_classes.add(DoubleStorage)
-_storage_classes.add(FloatStorage)
-_storage_classes.add(LongStorage)
-_storage_classes.add(IntStorage)
-_storage_classes.add(ShortStorage)
-_storage_classes.add(CharStorage)
-_storage_classes.add(ByteStorage)
-
-_tensor_classes.add(DoubleTensor)
-_tensor_classes.add(FloatTensor)
-_tensor_classes.add(LongTensor)
-_tensor_classes.add(IntTensor)
-_tensor_classes.add(ShortTensor)
-_tensor_classes.add(CharTensor)
-_tensor_classes.add(ByteTensor)
+_storage_classes = {
+    DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage,
+    CharStorage, ByteStorage,
+}
+
+_tensor_classes = {
+    DoubleTensor, FloatTensor, LongTensor, IntTensor, ShortTensor,
+    CharTensor, ByteTensor,
+}


set_default_tensor_type('torch.FloatTensor')

@@ -215,3 +230,5 @@ import torch.cuda
import torch.autograd
import torch.nn
import torch.optim
from . import docs  # attaches docstrings to torch functions
del docs
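As a quick illustration of the package the new module docstring describes, a minimal sketch against the 2016-era API used throughout this diff (assumes a built torch):

```python
import torch

x = torch.FloatTensor([-1, -2, 3])  # a tensor class re-exported from torch._C
print(torch.abs(x))                 # element-wise math op; cf. the abs example in torch/docs.py
print(torch.is_tensor(x))           # True; is_tensor is listed in __all__ above
```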
torch/csrc/Module.cpp  (modified)
@@ -524,6 +524,31 @@ PyObject *THPModule_safeCall(PyObject *_unused, PyObject *args, PyObject *kwargs
  return result;
}

PyObject *THPModule_addDocStr(PyObject *_unused, PyObject *args)
{
  // adds a __doc__ string to a function, similar to numpy's arr_add_docstring
  PyObject *obj;
  PyObject *doc;
  if (!PyArg_ParseTuple(args, "OO!", &obj, &THPUtils_stringType, &doc)) {
    return NULL;
  }

  if (Py_TYPE(obj) == &PyCFunction_Type) {
    PyCFunctionObject* f = (PyCFunctionObject *)obj;
    if (f->m_ml->ml_doc) {
      return PyErr_Format(PyExc_RuntimeError,
          "function '%s' already has a docstring", f->m_ml->ml_name);
    }
    f->m_ml->ml_doc = THPUtils_stringAsString(doc);
    Py_INCREF(doc);
  } else {
    return PyErr_Format(PyExc_TypeError,
        "don't know how to add docstring to type '%s'", Py_TYPE(obj)->tp_name);
  }

  Py_RETURN_NONE;
}

#ifdef WITH_CUDA
extern PyObject * THCPModule_initExtension(PyObject *self);
extern PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg);

@@ -547,28 +572,29 @@ extern PyObject * THCPModule_cudaSleep(PyObject *_unused, PyObject *cycles);
#endif

static PyMethodDef TorchMethods[] = {
  {"_initExtension", (PyCFunction)THPModule_initExtension, METH_O, NULL},
  {"_autograd_init", (PyCFunction)THPAutograd_initExtension, METH_NOARGS, NULL},
+ {"_add_docstr", (PyCFunction)THPModule_addDocStr, METH_VARARGS, NULL},
#ifdef WITH_CUDA
  {"_cuda_init", (PyCFunction)THCPModule_initExtension, METH_NOARGS, NULL},
  {"_cuda_setDevice", (PyCFunction)THCPModule_setDevice_wrap, METH_O, NULL},
  {"_cuda_getDevice", (PyCFunction)THCPModule_getDevice_wrap, METH_NOARGS, NULL},
  {"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, NULL},
  {"_cuda_getCurrentStream", (PyCFunction)THCPModule_getCurrentStream_wrap, METH_NOARGS, NULL},
  {"_cuda_setStream", (PyCFunction)THCPModule_setStream_wrap, METH_O, NULL},
  {"_cuda_isDriverSufficient", (PyCFunction)THCPModule_isDriverSufficient, METH_NOARGS, NULL},
  {"_cuda_getDriverVersion", (PyCFunction)THCPModule_getDriverVersion, METH_NOARGS, NULL},
  {"_cuda_getRNGState", (PyCFunction)THCPModule_getRNGState, METH_NOARGS, NULL},
  {"_cuda_setRNGState", (PyCFunction)THCPModule_setRNGState, METH_O, NULL},
  {"_cuda_manualSeed", (PyCFunction)THCPModule_manualSeed, METH_O, NULL},
  {"_cuda_manualSeedAll", (PyCFunction)THCPModule_manualSeedAll, METH_O, NULL},
  {"_cuda_seed", (PyCFunction)THCPModule_seed, METH_NOARGS, NULL},
  {"_cuda_seedAll", (PyCFunction)THCPModule_seedAll, METH_NOARGS, NULL},
  {"_cuda_initialSeed", (PyCFunction)THCPModule_initialSeed, METH_NOARGS, NULL},
  {"_cuda_cudaHostAllocator", (PyCFunction)THCPModule_cudaHostAllocator, METH_NOARGS, NULL},
  {"_cuda_synchronize", (PyCFunction)THCPModule_cudaSynchronize, METH_NOARGS, NULL},
  {"_cuda_getLibPath", (PyCFunction)THCPModule_getLibPath, METH_NOARGS, NULL},
  {"_cuda_sleep", (PyCFunction)THCPModule_cudaSleep, METH_O, NULL},
#endif
  {"_safe_call", (PyCFunction)THPModule_safeCall, METH_VARARGS | METH_KEYWORDS, NULL},
  {"_sendfd", (PyCFunction)THPModule_sendfd, METH_VARARGS, NULL},
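For context, this `_add_docstr` binding is what the new torch/docs.py (later in this diff) calls to attach docstrings to the C-extension functions. A minimal sketch of the call pattern; `neg` is used purely as an illustration, since in the shipped module torch/docs.py already registers it:

```python
import torch._C
from torch._C import _add_docstr as add_docstr

# attach a __doc__ to a C-level function that does not have one yet
add_docstr(torch._C.neg, """neg([result], tensor) -> tensor""")

# a second call on the same function raises, per THPModule_addDocStr above:
#   RuntimeError: function 'neg' already has a docstring
```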
torch/csrc/utils.h  (modified)
@@ -27,10 +27,14 @@
#define THPUtils_bytesFromString(c_string) PyString_FromString(c_string)
#define THPUtils_checkBytes(obj) PyString_Check(obj)
#define THPUtils_bytesAsString(obj) PyString_AS_STRING(obj)
#define THPUtils_stringType PyString_Type
#define THPUtils_stringAsString(obj) PyString_AS_STRING(obj)
#else
#define THPUtils_bytesFromString(c_string) PyBytes_FromString(c_string)
#define THPUtils_checkBytes(obj) PyBytes_Check(obj)
#define THPUtils_bytesAsString(obj) PyBytes_AS_STRING(obj)
#define THPUtils_stringType PyUnicode_Type
#define THPUtils_stringAsString(obj) PyBytes_AS_STRING(PyUnicode_AsUTF8String(obj))
#endif
torch/docs.py  (new file, 592 lines)
@@ -0,0 +1,592 @@
"""Adds docstrings to functions defined in torch._C"""

import torch._C
from torch._C import _add_docstr as add_docstr

add_docstr(torch._C.abs,
"""abs([result], tensor) -> tensor

Computes the element-wise absolute value of a tensor.

Example:
    >>> torch.abs(torch.FloatTensor([-1, -2, 3]))
    FloatTensor([1, 2, 3])
""")

add_docstr(torch._C.acos,
"""
acos([result], tensor) -> tensor

Computes the element-wise inverse cosine of a tensor.

Example:
    >>> torch.acos(torch.FloatTensor([1, -1]))
    FloatTensor([0.0000, 3.1416])
""")

add_docstr(torch._C.add,
"""
""")

add_docstr(torch._C.addbmm,
"""
""")

add_docstr(torch._C.addcdiv,
"""
""")

add_docstr(torch._C.addcmul,
"""
""")

add_docstr(torch._C.addmm,
"""
""")

add_docstr(torch._C.addmv,
"""
""")

add_docstr(torch._C.addr,
"""
""")

add_docstr(torch._C.all,
"""
""")

add_docstr(torch._C.any,
"""
""")

add_docstr(torch._C.asin,
"""
""")

add_docstr(torch._C.atan,
"""
""")

add_docstr(torch._C.atan2,
"""
""")

add_docstr(torch._C.baddbmm,
"""
""")

add_docstr(torch._C.bernoulli,
"""
""")

add_docstr(torch._C.bmm,
"""
""")

add_docstr(torch._C.cat,
"""
""")

add_docstr(torch._C.cauchy,
"""
""")

add_docstr(torch._C.cdiv,
"""
""")

add_docstr(torch._C.ceil,
"""
""")

add_docstr(torch._C.cfmod,
"""
""")

add_docstr(torch._C.cinv,
"""
""")

add_docstr(torch._C.clamp,
"""
""")

add_docstr(torch._C.cmax,
"""
""")

add_docstr(torch._C.cmin,
"""
""")

add_docstr(torch._C.cmod,
"""
""")

add_docstr(torch._C.cmul,
"""
""")

add_docstr(torch._C.cos,
"""
""")

add_docstr(torch._C.cosh,
"""
""")

add_docstr(torch._C.cpow,
"""
""")

add_docstr(torch._C.cremainder,
"""
""")

add_docstr(torch._C.cross,
"""
""")

add_docstr(torch._C.csub,
"""
""")

add_docstr(torch._C.cumprod,
"""
""")

add_docstr(torch._C.cumsum,
"""
""")

add_docstr(torch._C.diag,
"""
""")

add_docstr(torch._C.dist,
"""
""")

add_docstr(torch._C.div,
"""
""")

add_docstr(torch._C.dot,
"""
""")

add_docstr(torch._C.eig,
"""
""")

add_docstr(torch._C.eq,
"""
""")

add_docstr(torch._C.equal,
"""
""")

add_docstr(torch._C.exp,
"""
""")

add_docstr(torch._C.exponential,
"""
""")

add_docstr(torch._C.eye,
"""
""")

add_docstr(torch._C.fill,
"""
""")

add_docstr(torch._C.floor,
"""
""")

add_docstr(torch._C.fmod,
"""
""")

add_docstr(torch._C.frac,
"""
""")

add_docstr(torch._C.from_numpy,
"""
""")

add_docstr(torch._C.gather,
"""
""")

add_docstr(torch._C.ge,
"""
""")

add_docstr(torch._C.gels,
"""
""")

add_docstr(torch._C.geometric,
"""
""")

add_docstr(torch._C.geqrf,
"""
""")

add_docstr(torch._C.ger,
"""
""")

add_docstr(torch._C.gesv,
"""
""")

add_docstr(torch._C.get_num_threads,
"""
get_num_threads() -> int

Gets the number of OpenMP threads used for parallelizing CPU operations
""")

add_docstr(torch._C.gt,
"""
""")

add_docstr(torch._C.histc,
"""
histc([result], tensor, bins=100, min=0, max=0) -> tensor

Computes the histogram of a tensor.

The elements are sorted into equal width bins between `min` and `max`. If `min`
and `max` are both zero, the minimum and maximum values of the data are used.

Args:
    result: (tensor) optional result tensor
    tensor: (tensor) input data
    bins: (int) number of histogram bins
    min: (int) lower end of the range (inclusive)
    max: (int) upper end of the range (inclusive)

Returns:
    tensor: the histogram

Example:
    >>> torch.histc(torch.FloatTensor([1, 2, 1]), bins=4, min=0, max=3)
    FloatTensor([0, 2, 1, 0])

""")

add_docstr(torch._C.index_select,
"""
""")

add_docstr(torch._C.inverse,
"""
""")

add_docstr(torch._C.kthvalue,
"""
""")

add_docstr(torch._C.le,
"""
""")

add_docstr(torch._C.lerp,
"""
""")

add_docstr(torch._C.linspace,
"""
""")

add_docstr(torch._C.log,
"""
""")

add_docstr(torch._C.log1p,
"""
""")

add_docstr(torch._C.log_normal,
"""
""")

add_docstr(torch._C.logspace,
"""
""")

add_docstr(torch._C.lt,
"""
""")

add_docstr(torch._C.masked_select,
"""
""")

add_docstr(torch._C.max,
"""
""")

add_docstr(torch._C.mean,
"""
""")

add_docstr(torch._C.median,
"""
""")

add_docstr(torch._C.min,
"""
""")

add_docstr(torch._C.mm,
"""
""")

add_docstr(torch._C.mod,
"""
""")

add_docstr(torch._C.mode,
"""
""")

add_docstr(torch._C.mul,
"""
""")

add_docstr(torch._C.multinomial,
"""
""")

add_docstr(torch._C.mv,
"""
""")

add_docstr(torch._C.ne,
"""
""")

add_docstr(torch._C.neg,
"""
""")

add_docstr(torch._C.nonzero,
"""
""")

add_docstr(torch._C.norm,
"""
""")

add_docstr(torch._C.normal,
"""
""")

add_docstr(torch._C.numel,
"""
""")

add_docstr(torch._C.ones,
"""
""")

add_docstr(torch._C.orgqr,
"""
""")

add_docstr(torch._C.ormqr,
"""
""")

add_docstr(torch._C.potrf,
"""
""")

add_docstr(torch._C.potri,
"""
""")

add_docstr(torch._C.potrs,
"""
""")

add_docstr(torch._C.pow,
"""
""")

add_docstr(torch._C.prod,
"""
""")

add_docstr(torch._C.pstrf,
"""
""")

add_docstr(torch._C.qr,
"""
""")

add_docstr(torch._C.rand,
"""
""")

add_docstr(torch._C.randn,
"""
""")

add_docstr(torch._C.random,
"""
""")

add_docstr(torch._C.randperm,
"""
""")

add_docstr(torch._C.range,
"""
""")

add_docstr(torch._C.remainder,
"""
""")

add_docstr(torch._C.renorm,
"""
""")

add_docstr(torch._C.reshape,
"""
""")

add_docstr(torch._C.round,
"""
""")

add_docstr(torch._C.rsqrt,
"""
""")

add_docstr(torch._C.scatter,
"""
""")

add_docstr(torch._C.set_num_threads,
"""
set_num_threads(int)

Sets the number of OpenMP threads used for parallelizing CPU operations
""")

add_docstr(torch._C.sigmoid,
"""
""")

add_docstr(torch._C.sign,
"""
""")

add_docstr(torch._C.sin,
"""
""")

add_docstr(torch._C.sinh,
"""
""")

add_docstr(torch._C.sort,
"""
""")

add_docstr(torch._C.sqrt,
"""
""")

add_docstr(torch._C.squeeze,
"""
""")

add_docstr(torch._C.std,
"""
""")

add_docstr(torch._C.sum,
"""
""")

add_docstr(torch._C.svd,
"""
""")

add_docstr(torch._C.symeig,
"""
""")

add_docstr(torch._C.t,
"""
""")

add_docstr(torch._C.tan,
"""
""")

add_docstr(torch._C.tanh,
"""
""")

add_docstr(torch._C.topk,
"""
""")

add_docstr(torch._C.trace,
"""
""")

add_docstr(torch._C.transpose,
"""
""")

add_docstr(torch._C.tril,
"""
""")

add_docstr(torch._C.triu,
"""
""")

add_docstr(torch._C.trtrs,
"""
""")

add_docstr(torch._C.trunc,
"""
""")

add_docstr(torch._C.unfold,
"""
""")

add_docstr(torch._C.uniform,
"""
""")

add_docstr(torch._C.var,
"""
""")

add_docstr(torch._C.zero,
"""
""")

add_docstr(torch._C.zeros,
"""
""")
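Once `from . import docs` runs at the bottom of torch/__init__.py (see the earlier hunk), the strings above become ordinary `__doc__` attributes on the C-extension functions; a quick check (assuming a built torch):

```python
import torch

print(torch.abs.__doc__)  # -> "abs([result], tensor) -> tensor ..." from torch/docs.py
help(torch.histc)         # renders the Args/Returns/Example sections added above
```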
torch/optim/__init__.py  (modified)
@@ -1,3 +1,116 @@
"""
:mod:`torch.optim` is a package for optimizing neural networks.
It provides a wide variety of optimization methods such as SGD, Adam etc.

Currently, the following optimization methods are supported, typically with
options such as weight decay and other bells and whistles.

- SGD `(params, lr=required, momentum=0, dampening=0)`
- AdaDelta `(params, rho=0.9, eps=1e-6, weight_decay=0)`
- Adagrad `(params, lr=1e-2, lr_decay=0, weight_decay=0)`
- Adam `(params, lr=1e-2, betas=(0.9, 0.999), epsilon=1e-8, weight_decay=0)`
- AdaMax `(params, lr=1e-2, betas=(0.9, 0.999), eps=1e-38, weight_decay=0)`
- Averaged SGD `(params, lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0)`
- RProp `(params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50))`
- RMSProp `(params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0)`


Usage of the optim package itself is as follows:

1. Construct an optimizer
2. Use `optimizer.step(...)` to optimize.
   - Call `optimizer.zero_grad()` to zero out the gradient buffers when appropriate

## 1. Constructing the optimizer

One first constructs an `Optimizer` object by giving it a list of parameters
to optimize, as well as the optimizer options, such as learning rate, weight decay, etc.

Examples:

`optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)`

`optimizer = optim.Adam([var1, var2], lr=0.0001)`

### Per-parameter options

In a more advanced usage, one can specify per-layer options by passing each parameter group along with its custom options.

**Any parameter group that does not have an attribute defined will use the default attributes.**

This is very useful when one wants to specify per-layer learning rates, for example.

Example:

`optim.SGD([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)`

`model1`'s parameters will use the default learning rate of `1e-2` and momentum of `0.9`.
`model2`'s parameters will use a learning rate of `1e-3`, and the default momentum of `0.9`.

Then, you can use the optimizer by calling `optimizer.zero_grad()` and `optimizer.step(...)`. Read the next sections.

## 2. Taking an optimization step using `Optimizer.step(...)`

The step function has the following two signatures:

### a. `Optimizer.step(closure)`

The `step` function takes a user-defined closure that computes f(x) and returns the loss.

The closure needs to do the following:

- Call `optimizer.zero_grad()`
- Compute the loss
- Call `loss.backward()`
- Return the loss

Example 1: training a neural network

```python
# Example 1: training a neural network with optimizer.step(closure)
net = MNISTNet()
criterion = ClassNLLLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001)

for data in data_batches:
    input, target = data
    def closure():
        optimizer.zero_grad()
        output = net(input)
        loss = criterion(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)
```

Note: why is this required? Why can't the optimizer simply take the parameters and grads?
Some optimization algorithms, such as Conjugate Gradient and LBFGS, need to evaluate their function
multiple times. For such optimization methods, the function (i.e. the closure) has to be defined.


### b. `Optimizer.step()`

This is a simplified usage that supports most, but not all, optimization algorithms. For example, it does not support LBFGS or Conjugate Gradient.

The usage here is to simply call the function after `backward()` has been called on your model.

Example 2: training a neural network

```python
# Example 2: training a neural network with optimizer.step()
net = MNISTNet()
criterion = ClassNLLLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001)

for data in data_batches:
    input, target = data
    optimizer.zero_grad()
    output = net(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
```
"""

from .adadelta import Adadelta
from .adagrad import Adagrad
from .adam import Adam

@@ -6,6 +119,7 @@ from .asgd import ASGD
from .sgd import SGD
from .rprop import Rprop
from .rmsprop import RMSprop
from .optimizer import Optimizer

del adadelta
del adagrad
torch/optim/sgd.py  (modified)
@@ -5,10 +5,10 @@ class SGD(Optimizer):
    """Implements stochastic gradient descent with optional momentum.

    Args:
-       params: parameters to optimize
-       lr: learning rate
-       momentum: momentum factory (default: 0)
-       weight_decay: weight decay (L2 penalty) (default: 0)
+       params: (sequence) parameters to optimize
+       lr: (float) learning rate
+       momentum: (float) momentum factor (default: 0)
+       weight_decay: (float) weight decay (L2 penalty) (default: 0)

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> def closure():