Options to address the "undocumented python objects":
1. Reference the functions in the .rst via the torch.nn.modules namespace. Note that this changes the generated doc filenames / locations for most of these functions!
2. [Not an option] Monkeypatch `__module__` for these objects (broke several tests in CI due to `inspect.findsource` failing after this change)
3. Update the .rst files to also document the torch.nn.modules forms of these functions, duplicating docs.
#### [this is the docs page added](https://docs-preview.pytorch.org/pytorch/pytorch/158491/nn.aliases.html)
This PR takes option 3 by adding an rst page nn.aliases that documents the aliases in nested namespaces, removing all the torch.nn.modules.* entries from the coverage skiplist except
- NLLLoss2d (deprecated)
- Container (deprecated)
- CrossMapLRN2d (what is this?)
- NonDynamicallyQuantizableLinear
This mostly required adding docstrings to `forward`, `extra_repr` and `reset_parameters`. Since forward arguments are already part of the module docstrings I just added a very basic docstring.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158491
Approved by: https://github.com/janeyx99
For a GroupNorm module, if num_channels is not divisible by num_groups, we need to report an error when defining a module other than at the running step.
example:
```
import torch
m = torch.nn.GroupNorm(5, 6)
x = torch.randn(1, 6, 4, 4)
y = m(x)
```
before:
```
Traceback (most recent call last):
File "group_norm_test.py", line 8, in <module>
y = m(x)
File "/home/xiaobinz/miniconda3/envs/pytorch_mater/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
return forward_call(*input, **kwargs)
File "/home/xiaobinz/miniconda3/envs/pytorch_mater/lib/python3.7/site-packages/torch/nn/modules/normalization.py", line 271, in forward
input, self.num_groups, self.weight, self.bias, self.eps)
File "/home/xiaobinz/miniconda3/envs/pytorch_mater/lib/python3.7/site-packages/torch/nn/functional.py", line 2500, in group_norm
return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Expected number of channels in input to be divisible by num_groups, but got input of shape [1, 6, 4, 4] and num_groups=5
```
after:
```
Traceback (most recent call last):
File "group_norm_test.py", line 6, in <module>
m = torch.nn.GroupNorm(5, 6)
File "/home/xiaobinz/miniconda3/envs/pytorch_test/lib/python3.7/site-packages/torch/nn/modules/normalization.py", line 251, in __init__
raise ValueError('num_channels must be divisible by num_groups')
```
This PR also update the doc of num_groups.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74293
Approved by: https://github.com/jbschlosser
Summary:
Closes https://github.com/pytorch/pytorch/issues/51455
I think the current implementation is aggregating over the correct dimensions. The shape of `normalized_shape` is only used to determine the dimensions to aggregate over. The actual values of `normalized_shape` are used when `elementwise_affine=True` to initialize the weights and biases.
This PR updates the docstring to clarify how `normalized_shape` is used. Here is a short script comparing the implementations for tensorflow and pytorch:
```python
import torch
import torch.nn as nn
import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization
rng = np.random.RandomState()
x = rng.randn(10, 20, 64, 64).astype(np.float32)
# slightly non-trival
x[:, :10, ...] = x[:, :10, ...] * 10 + 20
x[:, 10:, ...] = x[:, 10:, ...] * 30 - 100
# Tensorflow Layer norm
x_tf = tf.convert_to_tensor(x)
layer_norm_tf = LayerNormalization(axis=[-3, -2, -1], epsilon=1e-5)
output_tf = layer_norm_tf(x_tf)
output_tf_np = output_tf.numpy()
# PyTorch Layer norm
x_torch = torch.as_tensor(x)
layer_norm_torch = nn.LayerNorm([20, 64, 64], elementwise_affine=False)
output_torch = layer_norm_torch(x_torch)
output_torch_np = output_torch.detach().numpy()
# check tensorflow and pytorch
torch.testing.assert_allclose(output_tf_np, output_torch_np)
# manual comutation
manual_output = ((x_torch - x_torch.mean(dim=(-3, -2, -1), keepdims=True)) /
(x_torch.var(dim=(-3, -2, -1), keepdims=True, unbiased=False) + 1e-5).sqrt())
torch.testing.assert_allclose(output_torch, manual_output)
```
To get to the layer normalization as shown here:
<img width="157" alt="Screen Shot 2021-05-29 at 2 13 52 PM" src="https://user-images.githubusercontent.com/5402633/120080691-1e37f100-c088-11eb-9060-4f263e4cd093.png">
One needs to pass in `normalized_shape` with shape `x.dim() - 1` with the size of the channels and all spatial dimensions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59178
Reviewed By: ejguan
Differential Revision: D28931877
Pulled By: jbschlosser
fbshipit-source-id: 193e05205b9085bb190c221428c96d2ca29f2a70
Summary:
Context: https://github.com/pytorch/pytorch/pull/53299#discussion_r587882857
These are the only hand-written parts of this diff:
- the addition to `.github/workflows/lint.yml`
- the file endings changed in these four files (to appease FB-internal land-blocking lints):
- `GLOSSARY.md`
- `aten/src/ATen/core/op_registration/README.md`
- `scripts/README.md`
- `torch/csrc/jit/codegen/fuser/README.md`
The rest was generated by running this command (on macOS):
```
git grep -I -l ' $' -- . ':(exclude)**/contrib/**' ':(exclude)third_party' | xargs gsed -i 's/ *$//'
```
I looked over the auto-generated changes and didn't see anything that looked problematic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53406
Test Plan:
This run (after adding the lint but before removing existing trailing spaces) failed:
- https://github.com/pytorch/pytorch/runs/2043032377
This run (on the tip of this PR) succeeded:
- https://github.com/pytorch/pytorch/runs/2043296348
Reviewed By: walterddr, seemethere
Differential Revision: D26856620
Pulled By: samestep
fbshipit-source-id: 3f0de7f7c2e4b0f1c089eac9b5085a58dd7e0d97
Summary:
Fixes https://github.com/pytorch/pytorch/issues/49034
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49035
Test Plan:
Imported from GitHub, without a `Test Plan:` line.
Force rebased to deal with merge conflicts
Reviewed By: zhangguanheng66
Differential Revision: D25767065
Pulled By: walterddr
fbshipit-source-id: ffb904e449f137825824e3f43f3775a55e9b011b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38211
Just because the annotations are inline doesn't mean the files type
check; most of the newly annotated files have type errors and I
added exclusions for them in mypy.ini. The payoff of moving
all of these modules inline is I can delete the relevant code
generation logic for the pyi files (which was added ignore
annotations that weren't actually relevant anymore.)
For the most part the translation was completely mechanical, but there
were two hairy issues. First, I needed to work around a Python 3.6 and
earlier bug where Generic has a nontrivial metaclass. This fix is in
torch/jit/__init__.py. Second, module.py, we need to apply the same
fix for avoiding contravariance checks that the pyi file used to have;
this is done by declaring forward as a variable (rather than a
function), which appears to be sufficient enough to get mypy to not
contravariantly check input arguments.
Because we aren't actually typechecking these modules in most
cases, it is inevitable that some of these type annotations are wrong.
I slavishly copied the old annotations from the pyi files unless there
was an obvious correction I could make. These annotations will probably
need fixing up later.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Differential Revision: D21497397
Pulled By: ezyang
fbshipit-source-id: 2b08bacc152c48f074e7edc4ee5dce1b77d83702
Summary:
xref gh-32838, gh-34032
This is a major refactor of parts of the documentation to split it up using sphinx's `autosummary` feature which will build out `autofuction` and `autoclass` stub files and link to them. The end result is that the top module pages like torch.nn.rst and torch.rst are now more like table-of-contents to the actual single-class or single-function documentations pages.
Along the way, I modified many of the docstrings to eliminate sphinx warnings when building. I think the only thing I changed from a non-documentation perspective is to add names to `__all__` when adding them to `globals()` in `torch.__init__.py`
I do not know the CI system: are the documentation build artifacts available after the build, so reviewers can preview before merging?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37419
Differential Revision: D21337640
Pulled By: ezyang
fbshipit-source-id: d4ad198780c3ae7a96a9f22651e00ff2d31a0c0f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32745
Some parameters (like `bias` in conv) are optional. To achieve this
previously, you had to add `bias` as a constant, which would invoke some
pretty weird behavior in the frontend, summarized as:
```
if bias is not None:
add it as a parameter normally
else: # bias is None
add it as a constant with the value None
```
There are several things bad about this:
1. Bias is not a constant. Marking it `__constants__` is confusing.
2. It basically relies on an implementation detail (the frontend
processes parameters before constants) to work.
Okay, whatever. I don't even know why we did this originally, but
getting rid of it doesn't break anything, so I assume improved NoneType
refinement has made this a non-issue.
Note on perf: this will make no difference; if bias was `None` it's still
folded out today, if bias is a Tensor it would be added as a parameter
both before and after this change
Test Plan: Imported from OSS
Differential Revision: D19628634
Pulled By: suo
fbshipit-source-id: d9128a09c5d096b938fcf567b8c23b09ac9ab37f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25339
This is to get rid of backend-specific dispatch in modules; this autograd function is no longer backend specific so
doesn't need to be in a backend specific location.
Test Plan: Imported from OSS
Differential Revision: D17101576
Pulled By: gchanan
fbshipit-source-id: f4f0bd3ecc2d4dbd8cdfedbaabcadb8c603d2507
Summary:
We are planning to put up a deprecation warning for legacy autograd function in 1.2: https://github.com/pytorch/pytorch/pull/22922. This PR removes all usage of legacy function in PyTorch core and test suite, to prepare for the eventual removal of legacy function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22925
Differential Revision: D16344834
Pulled By: yf225
fbshipit-source-id: 8bf4cca740398835a08b7a290f3058c3e46781ba
Summary:
* Deletes all weak script decorators / associated data structures / methods
* In order to keep supporting the standard library in script, this enables recursive script on any function defined in `torch.nn`
* Most changes in `torch/nn` are the result of `ag -Q "weak" torch/nn/ -l | xargs sed -i '/weak/d'`, only `rnn.py` needed manual editing to use the `ignore` and `export` to continue supporting the overloaded `forward` methods
* `Sequential`/`ModuleList` no longer need to be added to constants since they are compiled on demand
This should also fix https://github.com/pytorch/pytorch/issues/22212
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22212
Differential Revision: D15988346
Pulled By: driazati
fbshipit-source-id: af223e3ad0580be895377312949997a70e988e4f
Summary:
A bunch of modules were missing entries for `__constants__` which was making their `__repr__`s not work. Others had `__constants__` that were not necessary since it was provided by some parent class instead.
Fixes#20978
](https://our.intern.facebook.com/intern/diff/15539518/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21071
Pulled By: driazati
Differential Revision: D15539518
fbshipit-source-id: 24bdd1ef41ef636eefd5d2bad4ab2d79646ed4f0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a
Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**
This was requested by someone at Facebook; this lint is turned
on for Facebook by default. "Sure, why not."
I had to noqa a number of imports in __init__. Hypothetically
we're supposed to use __all__ in this case, but I was too lazy
to fix it. Left for future work.
Be careful! flake8-2 and flake8-3 behave differently with
respect to import resolution for # type: comments. flake8-3 will
report an import unused; flake8-2 will not. For now, I just
noqa'd all these sites.
All the changes were done by hand.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D14687478
fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
Summary:
PR to update the shape notation for all of the torch.nn modules to take a unified form. The goal is to make these definitions machine-readable and those checkable by unifying the style across all of the different modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15741
Differential Revision: D13709601
Pulled By: ezyang
fbshipit-source-id: fb89a03903fdf0cd0dcf76f3e469b8582b2f3634