## Summary
Adds missing type annotations to `torch.nn.init` and removes `# mypy: allow-untyped-defs` since all functions are now properly typed.
## Changes
- Added missing type annotations to initialization functions in the module.
- Added missing typing imports: `Any`, `Callable`, `Union`
- Removed `# mypy: allow-untyped-defs` comment
- Created `Literal` types for the kaiming initialization mode and nonlinearity (sketched below).
- Created `__all__`
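For illustration, a hedged sketch of what those `Literal` aliases might look like (the alias names here are my assumption, not necessarily the upstream spellings):
```py
from typing import Literal

from torch import Tensor

# Illustrative aliases; the exact names in torch/nn/init.py may differ.
_FanMode = Literal["fan_in", "fan_out"]
_NonlinearityType = Literal[
    "linear", "conv1d", "conv2d", "conv3d",
    "conv_transpose1d", "conv_transpose2d", "conv_transpose3d",
    "sigmoid", "tanh", "relu", "leaky_relu", "selu",
]

def kaiming_uniform_(
    tensor: Tensor,
    a: float = 0,
    mode: _FanMode = "fan_in",
    nonlinearity: _NonlinearityType = "leaky_relu",
) -> Tensor: ...
```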
## Why
Better IDE support, earlier detection of type errors, and bringing the module up to PyTorch's typing standards. There are no runtime changes; these are purely additive typing improvements.
Tested with the existing test suite and lintrunner.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154504
Approved by: https://github.com/Skylion007
Try to unblock https://github.com/pytorch/pytorch/issues/131991
- `nn.init.orthogonal_` uses `tensor.new`, the legacy factory function. We change this to `tensor.new_empty` (empty is okay since it is immediately followed by `.normal_()` to fill the tensor) so that it preserves `DTensor`-ness; see the sketch after this list.
- `nn.init.orthogonal_` uses QR decomposition (`aten.linalg_qr.default`) and `torch.diag` (calling into `aten.diagonal_copy.default`). For simplicity, we use naive replicate strategies for now. `aten.diagonal_copy.default` could do something more sophisticated for sharded inputs, but I would rather defer that to later due to the complexity. For `orthogonal_` support specifically, since the result of the QR decomp will be replicated, the input to `aten.diagonal_copy.default` will be replicated.
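A minimal sketch of the `tensor.new` → `tensor.new_empty` swap from the first bullet (the helper name is hypothetical; the real code lives inside `orthogonal_`):
```py
import torch

def _sample_flat(tensor: torch.Tensor, rows: int, cols: int) -> torch.Tensor:
    # The legacy factory loses subclass-ness such as DTensor:
    #   flattened = tensor.new(rows, cols).normal_(0, 1)
    # new_empty preserves it; uninitialized memory is fine because
    # .normal_() immediately fills the tensor:
    return tensor.new_empty((rows, cols)).normal_(0, 1)
```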
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132104
Approved by: https://github.com/albanD, https://github.com/wanchaol
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
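A minimal sketch of the two patterns, with stand-in function names:
```py
import warnings

from typing_extensions import deprecated

@deprecated("`old_fn` is deprecated, use `new_fn` instead.", category=FutureWarning)
def old_fn() -> None: ...

# Where the decorator does not fit, pass the category explicitly:
def also_old_fn() -> None:
    warnings.warn(
        "`also_old_fn` is deprecated, use `new_fn` instead.",
        category=FutureWarning,
        stacklevel=2,
    )
```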
Resolves #126888
- #126888
This PR is split from PR #126898.
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves #126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
A follow-up of PR #112617 on issue #112596
Applied the suggested changes from the review:
- Be more specific about the type of uniform and normal distribution used.
```py
def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier uniform distribution.

    The method is described in `Understanding the difficulty of training...
    """
```
```py
def kaiming_normal_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fill the input `Tensor` with values using a Kaiming normal distribution.

    The method is described in `Delving deep into rectifiers: Surpassing...
    """
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112864
Approved by: https://github.com/kit1980
Fixes #112596
Fix docstring errors in init.py
### Before the change -> 38 errors
```
╭─user@pc ~/Path/to/pytorch ‹fix/docstring_init›
╰─➤ pydocstyle torch/nn/init.py --count
torch/nn/init.py:1 at module level:
D100: Missing docstring in public module
torch/nn/init.py:68 in public function `calculate_gain`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:123 in public function `uniform_`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:123 in public function `uniform_`:
D400: First line should end with a period (not 'm')
torch/nn/init.py:123 in public function `uniform_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:141 in public function `normal_`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:141 in public function `normal_`:
D400: First line should end with a period (not 'l')
torch/nn/init.py:141 in public function `normal_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:165 in public function `trunc_normal_`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:165 in public function `trunc_normal_`:
D400: First line should end with a period (not 'd')
torch/nn/init.py:165 in public function `trunc_normal_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:187 in public function `constant_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:203 in public function `ones_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:216 in public function `zeros_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:229 in public function `eye_`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:229 in public function `eye_`:
D400: First line should end with a period (not 'y')
torch/nn/init.py:229 in public function `eye_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:249 in public function `dirac_`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:249 in public function `dirac_`:
D400: First line should end with a period (not 'c')
torch/nn/init.py:249 in public function `dirac_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:311 in public function `xavier_uniform_`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:311 in public function `xavier_uniform_`:
D400: First line should end with a period (not 'd')
torch/nn/init.py:311 in public function `xavier_uniform_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:338 in public function `xavier_normal_`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:338 in public function `xavier_normal_`:
D400: First line should end with a period (not 'd')
torch/nn/init.py:338 in public function `xavier_normal_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:376 in public function `kaiming_uniform_`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:376 in public function `kaiming_uniform_`:
D400: First line should end with a period (not 'd')
torch/nn/init.py:376 in public function `kaiming_uniform_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:425 in public function `kaiming_normal_`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:425 in public function `kaiming_normal_`:
D400: First line should end with a period (not 'd')
torch/nn/init.py:425 in public function `kaiming_normal_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:462 in public function `orthogonal_`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:462 in public function `orthogonal_`:
D400: First line should end with a period (not 's')
torch/nn/init.py:462 in public function `orthogonal_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:507 in public function `sparse_`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:507 in public function `sparse_`:
D400: First line should end with a period (not 'e')
torch/nn/init.py:507 in public function `sparse_`:
D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
38
```
### After the change -> 0 errors
```
╭─user@pc ~/Path/to/pytorch ‹fix/docstring_init*›
╰─➤ pydocstyle torch/nn/init.py --count
0
```
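For illustration, the kind of rewrite that clears D205/D400/D401 (a sketch, not the exact upstream wording):
```py
# Before -- D205 (no blank line after summary), D400 (no period), D401 (not imperative):
def uniform_(tensor, a=0.0, b=1.0):
    r"""Fills the input Tensor with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`
    """

# After -- imperative summary ending in a period, blank line before the description:
def uniform_(tensor, a=0.0, b=1.0):
    r"""Fill the input Tensor with values drawn from a uniform distribution.

    The values are drawn from :math:`\mathcal{U}(a, b)`.
    """
```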
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112617
Approved by: https://github.com/mikaylagawarecki
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.
I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling the rule so that it stays that way. :)
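For context, a sketch of the pattern RUF017 flags, with a linear alternative:
```py
import itertools

lists = [[1, 2], [3], [4, 5]]

flat_bad = sum(lists, [])  # RUF017: each `+` copies the accumulator -> O(n^2)
flat_good = list(itertools.chain.from_iterable(lists))  # linear time
assert flat_bad == flat_good == [1, 2, 3, 4, 5]
```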
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
This will solve @albertz's issue as described in #98200, threading the generator argument through the `trunc_normal_` function. I'm still working on #99796 (and won't let it stall out), but this fix doesn't trigger any JIT issues, so I think it might be helpful to get it merged now.
Would be happy to iterate on this if there are any issues.
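A usage sketch of the threaded argument (the `generator` keyword comes from this PR; the surrounding values are illustrative):
```py
import torch

g = torch.Generator().manual_seed(0)
w = torch.empty(3, 5)
# reproducible truncated-normal init via an explicit generator
torch.nn.init.trunc_normal_(w, mean=0.0, std=1.0, a=-2.0, b=2.0, generator=g)
```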
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100810
Approved by: https://github.com/Skylion007, https://github.com/albanD
This is a new version of #15648 based on the latest master branch.
Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.
In my initial commit, I do the bare minimum to get something running, with the dashboards still failing. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will disable those tests. (Unfortunately, I don't have a tool that will insert the `# xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
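For reference, the directive sits inside the docstring itself (stand-in function below):
```py
def frobnicate():
    """Stand-in example of skipping a failing doctest.

    Example:
        >>> # xdoctest: +SKIP
        >>> frobnicate()  # would otherwise run (and fail) on the dashboards
    """
```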
Fixes https://github.com/pytorch/pytorch/issues/71105
@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72164
The torch.Tensor ctor creates an empty tensor, and this PR makes ShardedTensor on par with that.
In particular, we remove TensorInitParams and instead always create an empty tensor and then fill it in for things like ones, zeros, full, etc. This is in line with torch.ones etc., since even for those APIs we first create an empty tensor and then fill it out.
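The same empty-then-fill pattern on plain tensors, for illustration (ShardedTensor applies it analogously; the helper name is made up):
```py
import torch

def ones_via_fill(*size):
    t = torch.empty(*size)  # create uninitialized storage first
    return t.fill_(1.0)     # then fill it, as torch.ones does internally

assert torch.equal(ones_via_fill(2, 3), torch.ones(2, 3))
```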
ghstack-source-id: 148318045
Test Plan: waitforbuildbot
Reviewed By: wanchaol
Differential Revision: D33934603
fbshipit-source-id: 5655bbd726f29e74600ebe9f33f9dc5952b528f4
(cherry picked from commit 78b301c78c9d5046e2f0a9818dcbc2cc45e7cdd0)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69874
We have a handful of ops supported for ShardedTensor via
``__torch_function__`` dispatch. However, we currently can't cover all torch
operators, and giving users a way to extend this support will make it much
more general.
In this PR, I've introduced a custom_sharded_op decorator which can be used to
register a custom sharded op implementation.
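A hedged sketch of how such a registration might look (the import path and handler signature are assumptions based on this PR's era, not a stable API):
```py
import torch.nn.functional as F
from torch.distributed._sharded_tensor import custom_sharded_op  # hypothetical path

@custom_sharded_op(F.linear)
def my_sharded_linear(types, args=(), kwargs=None, pg=None):
    # Unpack (input, weight, bias) from args and compute the sharded linear
    # on each rank's local shard, communicating over process group ``pg``.
    ...
```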
ghstack-source-id: 145841141
Test Plan: waitforbuildbot
Reviewed By: wanchaol
Differential Revision: D33078587
fbshipit-source-id: 5936b7ac25582e613653c19afa559219719ee54b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63997
Use `__torch_function__` to extend `torch.nn.init.uniform_`.
The init is done in SPMD fashion. Note that ideally we would aggregate the sharded tensors into a global tensor, init it, and reshard, but it's fine to run the init SPMD since uniform sampling is i.i.d. (independent and identically distributed).
Also enabled the unit test in test_linear.py for the OSS test suite.
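A minimal, generic sketch of the `__torch_function__` interception this relies on (an illustrative subclass, not the ShardedTensor implementation):
```py
import torch

class MyTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.nn.init.uniform_:
            # A ShardedTensor handler would fill each rank's local shard here;
            # running this SPMD is safe because uniform sampling is i.i.d.
            print("intercepted nn.init.uniform_")
        return super().__torch_function__(func, types, args, kwargs)

t = torch.empty(2, 2).as_subclass(MyTensor)
torch.nn.init.uniform_(t)  # prints "intercepted nn.init.uniform_"
```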
Test Plan:
a) Unit Test
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_init.py TestShardedTensorNNInit --v
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_linear.py --v (before this change, running this command was a no-op)
or b) Manual run: instructions here: https://docs.google.com/document/d/1_m1Hdo5w51-hhPlZ_F8Y6PIWrN7UgJZqiSpARYvhsaE/edit#
Imported from OSS
Reviewed By: pritamdamania87, anjali411
Differential Revision: D30563017
fbshipit-source-id: d1859f7682235bcb44515efc69ca92bc5e34fce1
Summary:
This uses the shape of the tensor instead of directly indexing it, which is useful when extending PyTorch's tensor class, e.g. for lazy access. Since the `init` sub-module doesn't check for `__torch_function__`, it is not possible to override its functions. Explicitly indexing the tensor forces the full tensor to be reconstructed or its elements to be accessed explicitly; simply using the shape avoids that.
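A generic sketch of the contrast (hypothetical code, not the exact upstream diff):
```py
import torch

t = torch.empty(4, 6)

# Before: deriving sizes by indexing touches elements, forcing lazy tensor
# subclasses to reconstruct the full tensor.
rows_bad, cols_bad = len(t), len(t[0])

# After: shape metadata only -- no element access needed.
rows, cols = t.shape[0], t.shape[1]
assert (rows, cols) == (rows_bad, cols_bad)
```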
Fixes https://github.com/pytorch/pytorch/issues/53540
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53522
Reviewed By: anjali411
Differential Revision: D26947794
Pulled By: jbschlosser
fbshipit-source-id: 80cd65efed16383f21363cee2eb404c9bc05971c
Summary:
Fixes [#24991](https://github.com/pytorch/pytorch/issues/24991)
I used a value of 0.75 as suggested in the forums by Thomas: https://discuss.pytorch.org/t/calculate-gain-tanh/20854/6
I verified that the value keeps the gradient stable for a 100-layer network.
Code to reproduce (from [jpeg729](https://discuss.pytorch.org/t/calculate-gain-tanh/20854/4)):
```python
import torch
import torch.nn.functional as F

a = torch.randn(1000, 1000, requires_grad=True)
b = a
print(f"in: {a.std().item():.4f}")
for i in range(100):
    l = torch.nn.Linear(1000, 1000, bias=False)
    torch.nn.init.xavier_normal_(l.weight, torch.nn.init.calculate_gain("selu"))
    b = F.selu(l(b))
    if i % 10 == 0:
        print(f"out: {b.std().item():.4f}", end=" ")
        a.grad = None
        b.sum().backward(retain_graph=True)
        print(f"grad: {a.grad.abs().mean().item():.4f}")
```
Output:
```
in: 1.0008
out: 0.7968 grad: 0.6509
out: 0.3127 grad: 0.2760
out: 0.2404 grad: 0.2337
out: 0.2062 grad: 0.2039
out: 0.2056 grad: 0.1795
out: 0.2044 grad: 0.1977
out: 0.2005 grad: 0.2045
out: 0.2042 grad: 0.2273
out: 0.1944 grad: 0.2034
out: 0.2085 grad: 0.2464
```
I included the necessary documentation change, and it passes the _test_calculate_gain_nonlinear_ unittest.
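For reference, the resulting behavior:
```py
>>> import torch
>>> torch.nn.init.calculate_gain("selu")
0.75
```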
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50664
Reviewed By: mruberry
Differential Revision: D25942217
Pulled By: ngimel
fbshipit-source-id: 29ff1be25713484fa7c516df71b12fdaecfb9af8
Summary:
Hello there,
I was going through the default initialization of some layers and ended up on the `torch.nn.init` documentation. There was a slight issue with the docstrings of both `kaiming_normal_` and `kaiming_uniform_` that yielded a wrong list of function parameters in the rendered docs.
This PR fixes the indentation in the corresponding docstrings.
Any feedback is welcome!
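A sketch of the kind of mis-indentation involved (illustrative, not the exact upstream docstring): a continuation line in the `Args:` section that falls back to the parameter-name indent gets parsed as a bogus new parameter, so the rendered list is wrong. Correctly indented it looks like this:
```py
def kaiming_uniform_(tensor, a=0, mode="fan_in", nonlinearity="leaky_relu"):
    r"""Fill the input Tensor using a Kaiming uniform distribution.

    Args:
        tensor: an n-dimensional torch.Tensor
        a: the negative slope of the rectifier used after this layer
            (only used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``
    """
```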
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37739
Differential Revision: D21393728
Pulled By: ngimel
fbshipit-source-id: 64523cb328e72d2e51c2c42b20a4545c1ec5f478