Options to address the "undocumented python objects":
1. Reference the functions in the .rst via the torch.nn.modules namespace. Note that this changes the generated doc filenames / locations for most of these functions!
2. [Not an option] Monkeypatch `__module__` for these objects; this broke several tests in CI because `inspect.findsource` failed after the change (see the sketch after this list).
3. Update the .rst files to also document the torch.nn.modules forms of these functions, duplicating docs.
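For context, the monkeypatch in option 2 would have looked something like this hedged sketch (illustrative, not the exact code tried in CI):
```python
import torch.nn as nn

# Re-attribute a few representative classes to the public namespace so
# doc tooling stops seeing them as torch.nn.modules.* objects.
for obj in (nn.Linear, nn.Conv2d):
    obj.__module__ = "torch.nn"  # this is what broke inspect.findsource in CI
```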
#### [this is the docs page added](https://docs-preview.pytorch.org/pytorch/pytorch/158491/nn.aliases.html)
This PR takes option 3 by adding an .rst page, nn.aliases, that documents the aliases in the nested namespaces, removing all the torch.nn.modules.* entries from the coverage skiplist except:
- NLLLoss2d (deprecated)
- Container (deprecated)
- CrossMapLRN2d (what is this?)
- NonDynamicallyQuantizableLinear
This mostly required adding docstrings to `forward`, `extra_repr`, and `reset_parameters`. Since forward arguments are already documented in the module docstrings, I just added a very basic docstring.
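A minimal sketch of the kind of one-line docstring added (the exact wording may differ per module):
```python
import torch
from torch import Tensor


class ExampleModule(torch.nn.Module):  # illustrative stand-in
    def forward(self, input: Tensor) -> Tensor:
        """Run the forward pass."""
        return input
```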
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158491
Approved by: https://github.com/janeyx99
I found the same issue as #147490 (@jibril-b-coulibaly).
There's an equivalent in the [doc-string](https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html#rnn) of `torch.nn.RNN`:
```python
# Efficient implementation equivalent to the following with bidirectional=False
def forward(x, hx=None):
    if batch_first:
        x = x.transpose(0, 1)
    seq_len, batch_size, _ = x.size()
    if hx is None:
        hx = torch.zeros(num_layers, batch_size, hidden_size)
    h_t_minus_1 = hx
    h_t = hx
    output = []
    for t in range(seq_len):
        for layer in range(num_layers):
            h_t[layer] = torch.tanh(
                x[t] @ weight_ih[layer].T
                + bias_ih[layer]
                + h_t_minus_1[layer] @ weight_hh[layer].T
                + bias_hh[layer]
            )
        output.append(h_t[-1])
        h_t_minus_1 = h_t
    output = torch.stack(output)
    if batch_first:
        output = output.transpose(0, 1)
    return output, h_t
```
However, there are several problems.
1. As mentioned in #147490, line 499 is wrong:
fb55bac3de/torch/nn/modules/rnn.py (L499)
The **input to the RNN cell should be different** for different layers.
2. The code contains several hidden **reference-related issues** that may result in unintended modifications to tensors. For example, line 504 causes all elements in the final output list to point to the same tensor (see the short demonstration after this list).
fb55bac3de/torch/nn/modules/rnn.py (L504)
3. Some variables are not **defined**. Although this is a relatively minor issue in an annotation, it can cause significant confusion for those who are new to the concept. For example, `weight_ih` in line 499:
fb55bac3de/torch/nn/modules/rnn.py (L499)
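A minimal standalone demonstration of the aliasing problem from point 2 (not taken from the docstring itself):
```python
import torch

# Appending a reference to a tensor that is later mutated in place makes
# every list entry reflect the final value.
h = torch.zeros(3)
out = []
for t in range(3):
    h += 1            # in-place update
    out.append(h)     # appends a reference to h, not a copy
print([o.tolist() for o in out])  # all three entries are [3.0, 3.0, 3.0]
```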
So I wrote a runnable version to make it clearer:
```python
# Efficient implementation equivalent to the following with bidirectional=False
rnn = nn.RNN(input_size, hidden_size, num_layers)
params = dict(rnn.named_parameters())
def forward(x, hx=None, batch_first=False):
    if batch_first:
        x = x.transpose(0, 1)
    seq_len, batch_size, _ = x.size()
    if hx is None:
        hx = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size)
    h_t_minus_1 = hx.clone()
    h_t = hx.clone()
    output = []
    for t in range(seq_len):
        for layer in range(rnn.num_layers):
            input_t = x[t] if layer == 0 else h_t[layer - 1]
            h_t[layer] = torch.tanh(
                input_t @ params[f"weight_ih_l{layer}"].T
                + h_t_minus_1[layer] @ params[f"weight_hh_l{layer}"].T
                + params[f"bias_hh_l{layer}"]
                + params[f"bias_ih_l{layer}"]
            )
        output.append(h_t[-1].clone())
        h_t_minus_1 = h_t.clone()
    output = torch.stack(output)
    if batch_first:
        output = output.transpose(0, 1)
    return output, h_t
```
This code reproduces the computation of `torch.nn.RNN`.
For example:
```python
import torch
import torch.nn as nn
torch.manual_seed(0)
input_size, hidden_size, num_layers = 3, 5, 2
rnn = nn.RNN(input_size, hidden_size, num_layers)
params = dict(rnn.named_parameters())
x = torch.randn(10, 4, 3)
official_imp = rnn(x)
my_imp = forward(x)
assert torch.allclose(official_imp[0], my_imp[0])
assert torch.allclose(official_imp[1], my_imp[1])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153620
Approved by: https://github.com/mikaylagawarecki
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
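A hedged sketch of the two patterns (illustrative function names, not code from the PR):
```python
import warnings

from typing_extensions import deprecated


# Preferred: the decorator marks the API for both static checkers and runtime.
@deprecated("`old_fn` is deprecated, use `new_fn` instead.", category=FutureWarning)
def old_fn() -> None: ...


# Fallback for plain warnings: make sure an explicit category is passed.
def legacy_fn() -> None:
    warnings.warn(
        "`legacy_fn` is deprecated, use `new_fn` instead.",
        category=FutureWarning,
        stacklevel=2,
    )
```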
Resolves #126888
- #126888
This PR is split from PR #126898.
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves #126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```python
import re

# Parse the mypy error log into {file_path: {line_number: error_type}}.
with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Append a `# type: ignore[...]` suppression to each reported line
# (processed bottom-up; appends don't shift line numbers, so order is harmless).
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```
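For example, given a hypothetical report line `torch/foo.py:42:5: error: Incompatible types in assignment [assignment]`, the script rewrites line 42 of `torch/foo.py` like this:
```python
# before (hypothetical offending line)
self.mode = mode
# after the script runs
self.mode = mode # type: ignore[assignment]
```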
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Fixes #112599
Fixed pydocstyle errors in the following files. The remaining errors are related to missing docstrings at the module level and in methods within each module, such as `forward()`, `reset_parameters`, and `__init__`.
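As a hedged illustration of the kind of error fixed (not a literal diff from this PR), pydocstyle's D401 requires the first line of a docstring to be in the imperative mood:
```python
# Hypothetical before/after for pydocstyle's D401; illustrative only.

def extra_repr_before() -> str:
    """Returns the extra representation of the module."""  # D401 error
    return ""


def extra_repr_after() -> str:
    """Return the extra representation of the module."""  # passes
    return ""
```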
`pydocstyle torch/nn/modules/pooling.py --count`
before: 49
after: 29
**remaining errors:**
```
torch/nn/modules/pooling.py:1 at module level:
D100: Missing docstring in public module
torch/nn/modules/pooling.py:90 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:163 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:240 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:315 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/pooling.py:321 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:402 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/pooling.py:408 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:472 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/pooling.py:478 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:541 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/pooling.py:550 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:620 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/pooling.py:630 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:706 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/pooling.py:716 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:720 in public method `__setstate__`:
D105: Missing docstring in magic method
torch/nn/modules/pooling.py:774 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/pooling.py:792 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:845 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/pooling.py:863 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:925 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:979 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:1026 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:1068 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:1111 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:1150 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:1189 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pooling.py:1228 in public method `forward`:
D102: Missing docstring in public method
```
`pydocstyle torch/nn/modules/upsampling.py --count`
before: 14
after: 7
**remaining:**
```
torch/nn/modules/upsampling.py:1 at module level:
D100: Missing docstring in public module
torch/nn/modules/upsampling.py:142 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/upsampling.py:156 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/upsampling.py:160 in public method `__setstate__`:
D105: Missing docstring in magic method
torch/nn/modules/upsampling.py:166 in public method `extra_repr`:
D102: Missing docstring in public method
torch/nn/modules/upsampling.py:216 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/upsampling.py:263 in public method `__init__`:
D107: Missing docstring in __init__
```
`pydocstyle torch/nn/modules/rnn.py --count`
before: 47
after: 40
**remaining:**
```
torch/nn/modules/rnn.py:1 at module level:
D100: Missing docstring in public module
torch/nn/modules/rnn.py:59 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:160 in public method `__setattr__`:
D105: Missing docstring in magic method
torch/nn/modules/rnn.py:225 in public method `reset_parameters`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:230 in public method `check_input`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:242 in public method `get_expected_hidden_size`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:256 in public method `check_hidden_size`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:272 in public method `check_forward_args`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:278 in public method `permute_hidden`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:284 in public method `extra_repr`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:305 in public method `__getstate__`:
D105: Missing docstring in magic method
torch/nn/modules/rnn.py:313 in public method `__setstate__`:
D105: Missing docstring in magic method
torch/nn/modules/rnn.py:355 in public method `all_weights`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:471 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:478 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:481 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:503 in public method `forward` (skipping F811):
D102: Missing docstring in public method
torch/nn/modules/rnn.py:762 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:768 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:771 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:774 in public method `get_expected_cell_size`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:786 in public method `check_forward_args`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:798 in public method `permute_hidden`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:809 in public method `forward` (skipping F811):
D102: Missing docstring in public method
torch/nn/modules/rnn.py:820 in public method `forward` (skipping F811):
D102: Missing docstring in public method
torch/nn/modules/rnn.py:1030 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:1036 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:1039 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:1046 in public method `forward` (skipping F811):
D102: Missing docstring in public method
torch/nn/modules/rnn.py:1054 in public method `forward` (skipping F811):
D102: Missing docstring in public method
torch/nn/modules/rnn.py:1123 in public class `RNNCellBase`:
D101: Missing docstring in public class
torch/nn/modules/rnn.py:1134 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:1152 in public method `extra_repr`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:1160 in public method `reset_parameters`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:1224 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:1230 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:1327 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:1332 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/rnn.py:1422 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/rnn.py:1427 in public method `forward`:
D102: Missing docstring in public method
```
`pydocstyle torch/nn/modules/pixelshuffle.py --count`
before: 13
after: 8
**remaining:**
```
torch/nn/modules/pixelshuffle.py:1 at module level:
D100: Missing docstring in public module
torch/nn/modules/pixelshuffle.py:52 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/pixelshuffle.py:56 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pixelshuffle.py:59 in public method `extra_repr`:
D102: Missing docstring in public method
torch/nn/modules/pixelshuffle.py:105 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/pixelshuffle.py:109 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/pixelshuffle.py:112 in public method `extra_repr`:
D102: Missing docstring in public method
```
`pydocstyle torch/nn/modules/sparse.py --count`
before: 14
after: 8
**remaining errors:**
```
torch/nn/modules/sparse.py:1 at module level:
D100: Missing docstring in public module
torch/nn/modules/sparse.py:124 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/sparse.py:153 in public method `reset_parameters`:
D102: Missing docstring in public method
torch/nn/modules/sparse.py:162 in public method `forward`:
D102: Missing docstring in public method
torch/nn/modules/sparse.py:167 in public method `extra_repr`:
D102: Missing docstring in public method
torch/nn/modules/sparse.py:320 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/modules/sparse.py:350 in public method `reset_parameters`:
D102: Missing docstring in public method
torch/nn/modules/sparse.py:396 in public method `extra_repr`:
D102: Missing docstring in public method
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113177
Approved by: https://github.com/ezyang
This pull request addresses an inconsistency in the representation of the Hadamard product across PyTorch documentation. Currently, the notation varies among different modules:
- In the `torch.nn.LSTM` documentation, the Hadamard product is represented with $\odot$
- In the `torch.nn.GRU` documentation, the Hadamard product is represented with $*$
- In the `torch.nn.LSTMCell` documentation, the Hadamard product is represented with $*$
- In the `torch.nn.GRUCell` documentation, the Hadamard product is represented with $*$
- In the `torch.ao.nn.quantized.dynamic.GRU` documentation, the Hadamard product is represented with $*$
This PR proposes consistently representing the Hadamard product throughout the documentation to enhance clarity and align with established standards.
The notation $\odot$ will be uniformly adopted, following the convention in the [Deep Learning Book](https://www.deeplearningbook.org/contents/linear_algebra.html).
**Changes Made:**
- Modified `torch.nn.GRU` documentation to represent the Hadamard product with $\odot$
- Modified `torch.nn.LSTMCell` documentation to represent the Hadamard product with $\odot$
- Modified `torch.nn.GRUCell` documentation to represent the Hadamard product with $\odot$
- Modified `torch.ao.nn.quantized.dynamic.GRU` documentation to represent the Hadamard product with $\odot$
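For example, the GRU hidden-state update now reads

$$h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}$$

with $\odot$ in place of the previous $*$.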
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111763
Approved by: https://github.com/albanD
Addresses [issue #106085](https://github.com/pytorch/pytorch/issues/106085).
In `torch/nn/modules/rnn.py`:
- Adds documentation string to RNNBase class.
- Adds parameter documentation to the `__init__` methods of the RNN, LSTM, and GRU classes.
- Adds type annotations to the `__init__` methods of RNN, LSTM, and GRU.
In `torch/ao/nn/quantized/dynamic/modules/rnn.py`:
- Adds type specifications to `_FLOAT_MODULE` attributes in RNNBase, RNN, LSTM, and GRU classes.
> This resolves a `mypy` assignment error `Incompatible types in assignment (expression has type "Type[LSTM]", base class "RNNBase" defined the type as "Type[RNNBase]")` that seemed to be a result of the fully specified type annotations in `torch/nn/modules/rnn.py`.
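A hedged sketch of the annotation pattern (illustrative names, not the literal diff): explicitly annotating the attribute keeps the declared type wide enough for subclass assignments.
```python
from typing import Type


class FloatBase: ...
class FloatLSTM(FloatBase): ...


class QuantRNNBase:
    _FLOAT_MODULE: Type[FloatBase] = FloatBase


class QuantLSTM(QuantRNNBase):
    # OK under mypy: the declared type matches the base class exactly,
    # and FloatLSTM is a valid Type[FloatBase] value.
    _FLOAT_MODULE: Type[FloatBase] = FloatLSTM
```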
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106222
Approved by: https://github.com/mikaylagawarecki
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.
I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling the rule so that it stays that way. :)
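For reference, a small example of what RUF017 flags, with the linear alternative ruff suggests (hedged illustration, not code from the codebase):
```python
import functools
import operator

lists = [[1, 2], [3], [4, 5]]

flat_slow = sum(lists, [])  # RUF017: each + copies the accumulator, O(n^2)
flat_fast = functools.reduce(operator.iadd, lists, [])  # in-place extend, linear
assert flat_slow == flat_fast == [1, 2, 3, 4, 5]
```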
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
- Enabled LSTM weight prepack in inductor.
- Added a mkldnn decomposition for LSTM that won't change for different `seq_lens`. With the previous decomposition, in the dynamic-shapes use case where `seq_lens` changes, the graph would be different.
- Extended several inductor utility functions to support `List[Tensor]` as input. Previously those functions only supported `Tensor` input.
**Update 2023-07-26:**
- https://github.com/pytorch/pytorch/pull/103851 has moved CPU weight packing to after AOTAutograd. Updated this PR to follow the same approach (mainly in 3b207f7f1c (diff-6dffed1ade0ba3e887f9a4eafa3bfcec267ab2365b8adcb91bd391f49b3fd2e3)).
LSTM is decomposed into `aten.mkldnn_rnn_layer` by layer and by direction. The weight prepack is done at the `mkldnn_rnn_layer` level.
- Added a fix in the rnn `__getstate__` function in case we need to recompile an `LSTM` module.
When compiling the module, the weight tensors, which are the `named_parameters` of the module, are converted to `functional_tensor` here:
76fb72e24a/torch/nn/utils/stateless.py (L125-L128)
The forward function of LSTM will be called:
76fb72e24a/torch/_functorch/aot_autograd.py (L3379-L3381)
In the forward function, the `_flat_weights` are updated to be the same as the weights, thus becoming `functional_tensor`:
76fb72e24a/torch/nn/modules/rnn.py (L775-L778)
The weights tensors are converted back to the original tensors (which are not `functional_tensor` anymore) before exiting the `_reparametrize_module` context here:
76fb72e24a/torch/nn/utils/stateless.py (L130-L142)
But since `_flat_weights` is not in the `named_parameters` of the module, it's still `functional_tensor` ([link of the parameters that will be converted to functional and reverted back](76fb72e24a/torch/_functorch/aot_autograd.py (L3695-L3698))).
At this moment, if we need to recompile the model, `deepcopy` will be called:
76fb72e24a/torch/_dynamo/utils.py (L915-L917)
And it will report `UnImplemented` since we have a `functional_tensor` (`_flat_weights`), triggering a graph break, which is not what we expect:
76fb72e24a/torch/_subclasses/meta_utils.py (L514)
Added a fix in `__getstate__` to update the `_flat_weights` if the weights have ever changed, fixing this issue. The fix is covered by the `test_lstm_packed` UT.
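A hedged sketch of the shape of that fix (names and structure illustrative, not the literal diff):
```python
class RNNSketch:
    """Illustrative stand-in for the real module, not PyTorch code."""

    def __init__(self):
        self.weight_ih_l0 = object()  # stand-in for a weight tensor
        self._flat_weights = [self.weight_ih_l0]

    def _refresh_flat_weights(self):
        # Hypothetical helper: rebuild _flat_weights from the current
        # attributes so stale (e.g. functionalized) tensors are replaced.
        self._flat_weights = [self.weight_ih_l0]

    def __getstate__(self):
        self._refresh_flat_weights()  # the fix: re-sync before copying state
        return self.__dict__.copy()
```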
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103071
Approved by: https://github.com/jgong5, https://github.com/jansel