Options to address the "undocumented python objects":

1. Reference the functions in the .rst via the `torch.nn.modules` namespace. Note that this changes the generated doc filenames / locations for most of these functions!
2. [Not an option] Monkeypatch `__module__` for these objects (broke several tests in CI due to `inspect.findsource` failing after this change).
3. Update the .rst files to also document the `torch.nn.modules` forms of these functions, duplicating docs.

#### [This is the docs page added](https://docs-preview.pytorch.org/pytorch/pytorch/158491/nn.aliases.html)

This PR takes option 3 by adding an rst page, nn.aliases, that documents the aliases in the nested namespaces, removing all the torch.nn.modules.* entries from the coverage skiplist except:

- NLLLoss2d (deprecated)
- Container (deprecated)
- CrossMapLRN2d (what is this?)
- NonDynamicallyQuantizableLinear

This mostly required adding docstrings to `forward`, `extra_repr`, and `reset_parameters`. Since forward arguments are already part of the module docstrings, I just added a very basic docstring.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158491
Approved by: https://github.com/janeyx99
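For context, the duplication that option 3 accepts exists because each public module class is literally the same object under two import paths; a minimal sketch (mine, not part of the PR) using the file below:

```python
# Minimal sketch (not from the PR): nn.Flatten and its nested-namespace form
# resolve to the same class object, so doc coverage sees both spellings.
import torch.nn as nn
from torch.nn.modules import flatten

assert nn.Flatten is flatten.Flatten  # one class, two importable paths
print(nn.Flatten.__module__)  # "torch.nn.modules.flatten" -- what option 2 tried to patch
print(nn.Flatten())  # Flatten(start_dim=1, end_dim=-1) -- extra_repr() in action
```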
171 lines · 5.7 KiB · Python
# mypy: allow-untyped-defs
from typing import Union

from torch import Tensor
from torch.types import _size

from .module import Module


__all__ = ["Flatten", "Unflatten"]


class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor.

    For use with :class:`~nn.Sequential`; see :meth:`torch.flatten` for details.

    Shape:
        - Input: :math:`(*, S_{\text{start}}, ..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::
        >>> input = torch.randn(32, 1, 5, 5)
        >>> # With default parameters
        >>> m = nn.Flatten()
        >>> output = m(input)
        >>> output.size()
        torch.Size([32, 25])
        >>> # With non-default parameters
        >>> m = nn.Flatten(0, 2)
        >>> output = m(input)
        >>> output.size()
        torch.Size([160, 5])
    """

    __constants__ = ["start_dim", "end_dim"]
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        """
        Runs the forward pass.
        """
        return input.flatten(self.start_dim, self.end_dim)

    def extra_repr(self) -> str:
        """
        Returns the extra representation of the module.
        """
        return f"start_dim={self.start_dim}, end_dim={self.end_dim}"


class Unflatten(Module):
    r"""
    Unflattens a tensor dim, expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
      be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.

    * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
      a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape`
      (tuple of `(name, size)` tuples) for `NamedTensor` input.

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
          :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension

    Examples:
        >>> input = torch.randn(2, 50)
        >>> # With tuple of ints
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, (2, 5, 5))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With torch.Size
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=("N", "features"))
        >>> unflatten = nn.Unflatten("features", (("C", 2), ("H", 5), ("W", 5)))
        >>> output = unflatten(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
    """

    NamedShape = tuple[tuple[str, int]]

    __constants__ = ["dim", "unflattened_size"]
    dim: Union[int, str]
    unflattened_size: Union[_size, NamedShape]

    def __init__(
        self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]
    ) -> None:
        super().__init__()

        if isinstance(dim, int):
            self._require_tuple_int(unflattened_size)
        elif isinstance(dim, str):
            self._require_tuple_tuple(unflattened_size)
        else:
            raise TypeError("invalid argument type for dim parameter")

        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input) -> None:
        if isinstance(input, tuple):
            for idx, elem in enumerate(input):
                if not isinstance(elem, tuple):
                    raise TypeError(
                        "unflattened_size must be tuple of tuples, "
                        + f"but found element of type {type(elem).__name__} at pos {idx}"
                    )
            return
        raise TypeError(
            "unflattened_size must be a tuple of tuples, "
            + f"but found type {type(input).__name__}"
        )

    def _require_tuple_int(self, input) -> None:
        if isinstance(input, (tuple, list)):
            for idx, elem in enumerate(input):
                if not isinstance(elem, int):
                    raise TypeError(
                        "unflattened_size must be tuple of ints, "
                        + f"but found element of type {type(elem).__name__} at pos {idx}"
                    )
            return
        raise TypeError(
            f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}"
        )

    def forward(self, input: Tensor) -> Tensor:
        """
        Runs the forward pass.
        """
        return input.unflatten(self.dim, self.unflattened_size)

    def extra_repr(self) -> str:
        """
        Returns the extra representation of the module.
        """
        return f"dim={self.dim}, unflattened_size={self.unflattened_size}"