3cda34ebde
[2/N] Apply ruff UP035 check in torch files ( #164054 )
...
This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
`TypeAlias` is also imported from `typing`.
This PR is the follow-up of #163947 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164054
Approved by: https://github.com/ezyang , https://github.com/Skylion007
2025-09-29 03:35:32 +00:00
1e79872f2e
[BE] More torch.nn docs coverage test (except for torch.nn.parallel) ( #158654 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158654
Approved by: https://github.com/janeyx99
ghstack dependencies: #158491
2025-07-25 22:03:55 +00:00
4c0d5ad4be
Fix docstring for clip_grads_with_norm_ to reflect clamping behavior ( #158200 )
...
Fix docstring for clip_grads_with_norm_ to reflect clamping behavior
This PR updates the docstring for torch.nn.utils.clip_grads_with_norm_ to accurately reflect the implementation behavior. The current documentation suggests that gradients are always scaled by:
grad = grad * (max_norm / (total_norm + eps))
However, the actual implementation clamps the scale coefficient to a maximum of 1.0, ensuring gradients are only scaled down, not up. This PR corrects the formula and adds a clarifying note to avoid confusion for users.
Updated the formula in the docstring to:
grad = grad * min(max_norm / (total_norm + eps), 1.0)
Added a note explaining the rationale for clamping (to prevent gradient amplification).
Ensured consistency with the behavior of clip_grad_norm_.
Fixes #151554
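For concreteness, a minimal sketch of the clamped scaling the corrected docstring describes (illustrative only, not the library implementation):
```python
import torch

grads = [torch.randn(4) * 10 for _ in range(3)]
total_norm = torch.linalg.vector_norm(
    torch.stack([torch.linalg.vector_norm(g) for g in grads])
)

max_norm, eps = 1.0, 1e-6
# The coefficient is capped at 1.0, so gradients are only ever scaled down.
clip_coef = min(max_norm / (total_norm.item() + eps), 1.0)
for g in grads:
    g.mul_(clip_coef)
```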
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158200
Approved by: https://github.com/mikaylagawarecki
2025-07-25 18:07:41 +00:00
fcc682be4b
[BE][Ez]: Fully type nn.utils.clip_grad ( #154801 )
...
Fully types clip_grad and exposes typing annotations that were hidden by a bad decorator
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154801
Approved by: https://github.com/jansel
2025-07-09 14:27:51 +00:00
596b418391
[BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format ( #144548 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
2025-06-14 11:27:04 +00:00
5e93abe3c0
Address docs for clip_grad functions ( #155125 )
...
This PR takes the opinionated stance that `torch.nn.utils.<func>` should be the preferred API over `torch.nn.utils.clip_grad.<func>`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155125
Approved by: https://github.com/albanD , https://github.com/mikaylagawarecki , https://github.com/janeyx99
2025-06-05 19:22:09 +00:00
50de6ae253
Revert "[BE][Ez]: Fully type nn.utils.clip_grad ( #154801 )"
...
This reverts commit 9ce2732b685da527308dc2dc4b2eeb4e252f57d1.
Reverted https://github.com/pytorch/pytorch/pull/154801 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/154801#issuecomment-2937886337 ))
2025-06-04 00:41:27 +00:00
9ce2732b68
[BE][Ez]: Fully type nn.utils.clip_grad ( #154801 )
...
Fully types clip_grad and exposes typing annotations that were hidden by a bad decorator
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154801
Approved by: https://github.com/jansel
2025-05-31 23:06:45 +00:00
f12d8d60b1
Add hint message when parameters is empty in clip_grad_norm_ ( #151529 )
...
Fixes #148259
## Changes
- Print a warning message when the `parameters` generator is exhausted
## Test Result
### print warning
```python
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(16, 10)
targets = torch.randn(16, 1)

outputs = model(inputs)
loss = criterion(outputs, targets)
optimizer.zero_grad()
loss.backward()

params_to_clip = model.parameters()
for p in params_to_clip:
    print(p.shape)

max_norm = 1.0
norm_type = 2.0
total_norm = nn.utils.clip_grad_norm_(params_to_clip, max_norm, norm_type)
print(f"total_norm: {total_norm}")
```
```bash
/home/zong/code/pytorch/torch/nn/utils/clip_grad.py:222: UserWarning: `parameters` is an empty generator, no gradient clipping will occur.
  warnings.warn(
total_norm: 0.0
```
### UT
```bash
pytest test/test_nn.py -k test_clip_grad_norm
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151529
Approved by: https://github.com/jbschlosser
2025-05-22 11:23:39 +00:00
fe90a5c140
[Easy] Optimize clip_grad param description ( #151532 )
...
Fix the missing `optional` descriptions in the `clip_grad_norm_` and `clip_grad_value_` parameter docs
## Test Result
Before/After screenshots of the rendered parameter docs (images omitted).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151532
Approved by: https://github.com/Skylion007 , https://github.com/albanD
2025-04-17 16:47:38 +00:00
7178b827d7
PEP585: Missed conversions ( #145342 )
...
Differential Revision: [D68785969](https://our.internmc.facebook.com/intern/diff/D68785969 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145342
Approved by: https://github.com/bobrenjc93
2025-01-29 05:24:36 +00:00
5e8e1d725a
Remove some unused type ignores (round 1) ( #142325 )
...
Over time, a large number of the existing type ignores have become irrelevant/unused/dead as a result of improvements in annotations and type checking.
Having these `# type: ignore` linger around is not ideal for two reasons:
- They are verbose/ugly syntactically.
- They could hide genuine bugs in the future, if a refactoring would actually introduce a bug but it gets hidden by the ignore.
I'm counting over 1500 unused ignores already. This is a first PR that removes some of them. Note that I haven't touched type ignores that looked "conditional" like the import challenge mentioned in https://github.com/pytorch/pytorch/pull/60006#issuecomment-2480604728 . I will address these at a later point, and eventually would enable `warn_unused_ignores = True` in the mypy configuration as discussed in that comment to prevent accumulating more dead ignores going forward.
This PR should have no effect on runtime at all.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142325
Approved by: https://github.com/Skylion007 , https://github.com/janeyx99
2024-12-09 18:23:46 +00:00
2ee91db03d
Add APIs to separate norm calculation and gradient scaling in nn.utils.clip_grad_norm_ ( #139662 )
...
Fixes https://github.com/pytorch/pytorch/issues/139467
Refactor `nn.utils.clip_grad_norm_` into `nn.utils.get_total_norm` and `nn.utils.clip_grads_with_norm_`; `clip_grad_norm_` now calls into these two new ops.
`get_total_norm` is generalized (rather than `get_grad_norm`) following the discussion on the issue from @awgu.
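A usage sketch of the split described above, assuming the `get_total_norm(tensors, norm_type=2.0, ...)` and `clip_grads_with_norm_(parameters, max_norm, total_norm)` signatures introduced by this PR:
```python
import torch
from torch import nn

model = nn.Linear(10, 10)
model(torch.randn(2, 10)).sum().backward()

params = [p for p in model.parameters() if p.grad is not None]
grads = [p.grad for p in params]

# Step 1: compute the total norm over all gradients.
total_norm = nn.utils.get_total_norm(grads, norm_type=2.0)
# Step 2: scale the gradients in place using that norm (which could also come
# from an external reduction, e.g. across data-parallel ranks).
nn.utils.clip_grads_with_norm_(params, max_norm=1.0, total_norm=total_norm)

# Roughly equivalent to the combined call:
# nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```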
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139662
Approved by: https://github.com/H-Huang
2024-11-07 23:13:23 +00:00
fc5aa24a6e
Rewording doc string for clip_grad_norm_ ( #133406 )
...
The docstring for `torch.nn.utils.clip_grad_norm_` needed some clarity; it was previously unclear that the total norm is computed over the norms of the individual parameter gradients.
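For illustration, a hedged sketch of the "norm of the per-parameter gradient norms" wording (not the library code):
```python
import torch

grads = [torch.randn(3, 3), torch.randn(5)]
norm_type = 2.0

# The total norm is the norm of the vector of individual gradient norms.
per_param_norms = torch.stack([torch.linalg.vector_norm(g, norm_type) for g in grads])
total_norm = torch.linalg.vector_norm(per_param_norms, norm_type)
```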
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133406
Approved by: https://github.com/mikaylagawarecki
2024-08-15 16:21:27 +00:00
5a0068cc69
[BE] mypy: disallow untyped decorators ( #131428 )
...
Untyped decorators strip the types from their decorated function, so even if the underlying function is fully typed, callers don't get any benefit from its type annotations.
Step 1 - Enable the error and override in all the offending files.
#131429
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby , https://github.com/oulgen
2024-07-23 21:50:55 +00:00
5a80d2df84
[BE] enable UFMT for torch/nn/utils ( #128595 )
...
Part of #123062
- #123062
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128595
Approved by: https://github.com/Skylion007
2024-06-13 18:34:57 +00:00
27f9d3b0a1
Flip default value for mypy disallow_untyped_defs [8/11] ( #127845 )
...
See #127836 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127845
Approved by: https://github.com/oulgen
ghstack dependencies: #127842 , #127843 , #127844
2024-06-08 18:49:56 +00:00
67ef2683d9
[BE] wrap deprecated function/class with typing_extensions.deprecated ( #127689 )
...
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
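A minimal example of the annotation pattern described above (a sketch only; `old_clip` and `legacy_helper` are hypothetical names, not actual PyTorch call sites):
```python
import warnings

from typing_extensions import deprecated

@deprecated("old_clip is deprecated, use clip_grad_norm_ instead", category=FutureWarning)
def old_clip():  # hypothetical deprecated function, for illustration only
    ...

def legacy_helper():  # hypothetical; the fallback for plain warnings.warn calls
    warnings.warn("legacy_helper is deprecated", FutureWarning, stacklevel=2)
```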
Resolves #126888
- #126888
This PR is split from PR #126898 .
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
033e733021
Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated
( #126898 )"
...
This reverts commit 749a132fb0a8325cbad4734a563aa459ca611991.
Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456 ))
2024-05-31 19:47:24 +00:00
749a132fb0
[BE] wrap deprecated function/class with typing_extensions.deprecated ( #126898 )
...
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves #126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
7c71d7f32b
[DTensor] Supported foreach=True for clip_grad_norm_ ( #120910 )
...
This PR adds support for `clip_grad_norm_(foreach=True)` by implementing `aten._foreach_norm.Scalar` and `aten._foreach_mul_.Tensor`. `foreach=True` is required to get competitive performance with `DTensor`.
`foreach=True` reduces CPU overhead for Llama-7B from 388 ms to 63 ms. Existing flat-parameter FSDP's `clip_grad_norm_` takes 3 ms on CPU 😢 .
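A hedged usage sketch of the flag this enables (a plain `nn.Linear` stands in for a DTensor-sharded model here):
```python
import torch
from torch import nn

model = nn.Linear(8, 8)
model(torch.randn(4, 8)).sum().backward()

# foreach=True requests the multi-tensor path; the newly implemented
# aten._foreach_norm / aten._foreach_mul_ make this work for DTensor too.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
```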
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120910
Approved by: https://github.com/wanchaol , https://github.com/janeyx99
ghstack dependencies: #120238
2024-03-02 00:28:09 +00:00
ef9b6d6816
Replace individual detaches with overall torch.no_grad decorator ( #120638 )
...
Fixes https://github.com/pytorch/pytorch/issues/120611 .
At first, I thought there were too many detaches, but @awgu and I came to the conclusion that both `clip_grad_norm_` and `clip_grad_value_` should run under torch.no_grad, similar to the optimizer step. One option is to continue calling `detach`, but doing that on many tensors is slower than setting the context to no_grad (I think?), and Andrew had noticed: "the 1st round of detaches takes 10 ms for FSDP2, whereas existing FSDP's clip_grad_norm_ only takes 3 ms total" since there are more tensors in FSDP2.
This change also disables grad mode for the foreach path of `clip_grad_value_`; the first attempt did not do this, which was an oversight. Not sure how to add a test case for this since grad mode will be turned back on after the call.
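A simplified sketch of the decorator pattern described above (`scale_grads_` is a hypothetical helper, not the actual clip_grad.py code):
```python
import torch

@torch.no_grad()
def scale_grads_(parameters, clip_coef):
    # Grad mode is disabled for the whole body, so no per-tensor .detach()
    # calls are needed before the in-place multiplies.
    for p in parameters:
        if p.grad is not None:
            p.grad.mul_(clip_coef)
```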
New profile is not much different from the one in the bottom of this stack, but the number of detaches is 0 :D:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (c71bcceb)]$ python playground2.py
STAGE:2024-02-26 13:07:15 211224:211224 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:07:16 211224:211224 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:07:16 211224:211224 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
cudaLaunchKernel 70.63% 110.415ms 70.63% 110.415ms 5.811ms 0.000us 0.00% 0.000us 0.000us 19
aten::linalg_vector_norm 0.18% 284.000us 26.00% 40.636ms 40.636ms 3.000us 0.99% 3.000us 3.000us 1
aten::clamp 0.09% 148.000us 14.88% 23.261ms 23.261ms 1.000us 0.33% 1.000us 1.000us 1
aten::to 0.75% 1.170ms 14.05% 21.970ms 84.826us 0.000us 0.00% 258.000us 0.996us 259
aten::_to_copy 2.28% 3.562ms 13.31% 20.800ms 161.240us 0.000us 0.00% 258.000us 2.000us 129
aten::_foreach_norm 4.44% 6.935ms 12.72% 19.878ms 9.939ms 19.000us 6.29% 21.000us 10.500us 2
aten::add 0.11% 173.000us 10.97% 17.153ms 17.153ms 1.000us 0.33% 1.000us 1.000us 1
aten::stack 2.99% 4.673ms 9.15% 14.300ms 14.300ms 0.000us 0.00% 6.000us 6.000us 1
aten::copy_ 5.49% 8.586ms 8.96% 14.001ms 108.535us 258.000us 85.43% 258.000us 2.000us 129
aten::reciprocal 0.11% 179.000us 8.35% 13.051ms 13.051ms 1.000us 0.33% 1.000us 1.000us 1
aten::cat 0.64% 993.000us 4.42% 6.902ms 6.902ms 6.000us 1.99% 6.000us 6.000us 1
aten::zeros 0.04% 69.000us 4.28% 6.698ms 3.349ms 0.000us 0.00% 2.000us 1.000us 2
aten::zero_ 0.04% 66.000us 4.13% 6.462ms 3.231ms 0.000us 0.00% 2.000us 1.000us 2
aten::fill_ 0.06% 98.000us 4.09% 6.396ms 3.198ms 2.000us 0.66% 2.000us 1.000us 2
aten::_foreach_mul_ 1.50% 2.342ms 3.79% 5.924ms 2.962ms 10.000us 3.31% 10.000us 5.000us 2
aten::empty 3.27% 5.115ms 3.27% 5.115ms 19.826us 0.000us 0.00% 0.000us 0.000us 258
aten::empty_strided 2.07% 3.237ms 2.07% 3.237ms 25.093us 0.000us 0.00% 0.000us 0.000us 129
cudaDeviceEnablePeerAccess 1.93% 3.023ms 1.93% 3.023ms 1.512ms 0.000us 0.00% 0.000us 0.000us 2
aten::unsqueeze 1.21% 1.896ms 1.74% 2.725ms 10.645us 0.000us 0.00% 0.000us 0.000us 256
cudaMemcpyAsync 1.01% 1.572ms 1.01% 1.572ms 12.186us 0.000us 0.00% 0.000us 0.000us 129
aten::as_strided 0.54% 839.000us 0.54% 839.000us 3.265us 0.000us 0.00% 0.000us 0.000us 257
cudaStreamWaitEvent 0.34% 539.000us 0.34% 539.000us 2.089us 0.000us 0.00% 0.000us 0.000us 258
cudaEventRecord 0.18% 274.000us 0.18% 274.000us 1.062us 0.000us 0.00% 0.000us 0.000us 258
aten::mul 0.07% 107.000us 0.08% 132.000us 132.000us 1.000us 0.33% 1.000us 1.000us 1
cudaDeviceSynchronize 0.01% 17.000us 0.01% 17.000us 8.500us 0.000us 0.00% 0.000us 0.000us 2
cudaDeviceCanAccessPeer 0.00% 7.000us 0.00% 7.000us 3.500us 0.000us 0.00% 0.000us 0.000us 2
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 2.000us 0.66% 2.000us 1.000us 2
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 13.000us 4.30% 13.000us 3.250us 4
void at::native::lpnorm_cleanup<float, (at::native::... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 1.99% 6.000us 3.000us 2
Memcpy PtoP (Device -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 258.000us 85.43% 258.000us 2.000us 129
void at::native::(anonymous namespace)::CatArrayBatc... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 1.99% 6.000us 3.000us 2
void at::native::reduce_kernel<512, 1, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 3.000us 0.99% 3.000us 3.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 10.000us 3.31% 10.000us 2.500us 4
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 156.319ms
Self CUDA time total: 302.000us
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120638
Approved by: https://github.com/Skylion007 , https://github.com/albanD
ghstack dependencies: #120623
2024-02-27 01:27:05 +00:00
df72819f91
clip_grad_norm can use fast foreach path for inf norm ( #120623 )
...
Now that `foreach_norm` supports `inf`, we should not special-case it.
For a mere 256 parameters, we get a win of 30 ms in CPU time and a ~800us -> 300us decrease in CUDA time. This win only grows with more parameters.
New profile:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (bf1c0490|REBASE-i|detached HEAD)]$ python playground2.py
STAGE:2024-02-26 13:14:10 395517:395517 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:14:11 395517:395517 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:14:11 395517:395517 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
cudaLaunchKernel 67.01% 102.262ms 67.01% 102.262ms 5.382ms 2.000us 0.66% 2.000us 0.105us 19
aten::linalg_vector_norm 0.20% 311.000us 23.44% 35.776ms 35.776ms 3.000us 0.99% 3.000us 3.000us 1
aten::to 0.79% 1.208ms 14.62% 22.311ms 86.143us 0.000us 0.00% 263.000us 1.015us 259
aten::clamp 0.12% 182.000us 13.96% 21.303ms 21.303ms 1.000us 0.33% 1.000us 1.000us 1
aten::_to_copy 2.38% 3.628ms 13.83% 21.103ms 163.589us 0.000us 0.00% 263.000us 2.039us 129
aten::_foreach_norm 4.71% 7.185ms 13.54% 20.659ms 10.329ms 19.000us 6.29% 23.000us 11.500us 2
aten::add 0.14% 211.000us 10.86% 16.580ms 16.580ms 1.000us 0.33% 1.000us 1.000us 1
aten::stack 3.11% 4.744ms 9.59% 14.642ms 14.642ms 0.000us 0.00% 6.000us 6.000us 1
aten::copy_ 5.71% 8.721ms 9.27% 14.152ms 109.705us 258.000us 85.43% 263.000us 2.039us 129
aten::reciprocal 0.13% 193.000us 7.93% 12.100ms 12.100ms 1.000us 0.33% 1.000us 1.000us 1
aten::cat 0.67% 1.017ms 4.67% 7.129ms 7.129ms 6.000us 1.99% 6.000us 6.000us 1
aten::zeros 0.05% 79.000us 4.46% 6.800ms 3.400ms 0.000us 0.00% 2.000us 1.000us 2
aten::zero_ 0.05% 79.000us 4.28% 6.537ms 3.268ms 0.000us 0.00% 2.000us 1.000us 2
aten::fill_ 0.09% 131.000us 4.23% 6.458ms 3.229ms 2.000us 0.66% 2.000us 1.000us 2
aten::_foreach_mul_ 1.56% 2.377ms 3.86% 5.896ms 2.948ms 10.000us 3.31% 10.000us 5.000us 2
aten::empty 3.55% 5.414ms 3.55% 5.414ms 20.984us 0.000us 0.00% 0.000us 0.000us 258
aten::empty_strided 2.18% 3.323ms 2.18% 3.323ms 25.760us 0.000us 0.00% 0.000us 0.000us 129
aten::detach 0.85% 1.302ms 2.10% 3.199ms 12.496us 0.000us 0.00% 0.000us 0.000us 256
cudaDeviceEnablePeerAccess 2.01% 3.069ms 2.01% 3.069ms 1.534ms 0.000us 0.00% 0.000us 0.000us 2
aten::unsqueeze 1.24% 1.899ms 1.81% 2.769ms 10.816us 0.000us 0.00% 0.000us 0.000us 256
detach 1.24% 1.897ms 1.24% 1.897ms 7.410us 0.000us 0.00% 0.000us 0.000us 256
cudaMemcpyAsync 1.01% 1.539ms 1.01% 1.539ms 11.930us 0.000us 0.00% 0.000us 0.000us 129
aten::as_strided 0.58% 881.000us 0.58% 881.000us 3.428us 0.000us 0.00% 0.000us 0.000us 257
cudaStreamWaitEvent 0.35% 540.000us 0.35% 540.000us 2.093us 0.000us 0.00% 0.000us 0.000us 258
cudaEventRecord 0.18% 278.000us 0.18% 278.000us 1.078us 5.000us 1.66% 5.000us 0.019us 258
aten::mul 0.08% 125.000us 0.09% 138.000us 138.000us 1.000us 0.33% 1.000us 1.000us 1
cudaDeviceSynchronize 0.01% 13.000us 0.01% 13.000us 6.500us 0.000us 0.00% 0.000us 0.000us 2
cudaDeviceCanAccessPeer 0.00% 5.000us 0.00% 5.000us 2.500us 0.000us 0.00% 0.000us 0.000us 2
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 2.000us 0.66% 2.000us 1.000us 2
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 13.000us 4.30% 13.000us 3.250us 4
void at::native::lpnorm_cleanup<float, (at::native::... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 1.99% 6.000us 3.000us 2
Memcpy PtoP (Device -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 258.000us 85.43% 258.000us 2.000us 129
void at::native::(anonymous namespace)::CatArrayBatc... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 1.99% 6.000us 3.000us 2
void at::native::reduce_kernel<512, 1, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 3.000us 0.99% 3.000us 3.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.33% 1.000us 1.000us 1
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 10.000us 3.31% 10.000us 2.500us 4
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 152.613ms
Self CUDA time total: 302.000us
```
Compared to on main:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5a0a9644)]$ python playground2.py
STAGE:2024-02-26 13:09:56 285045:285045 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:09:57 285045:285045 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:09:57 285045:285045 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
cudaLaunchKernel 61.42% 113.375ms 61.42% 113.375ms 424.625us 45.000us 5.66% 45.000us 0.169us 267
aten::linalg_vector_norm 14.04% 25.909ms 37.67% 69.534ms 271.617us 514.000us 64.65% 559.000us 2.184us 256
aten::to 0.78% 1.433ms 12.87% 23.751ms 91.703us 0.000us 0.00% 278.000us 1.073us 259
aten::_to_copy 2.02% 3.730ms 12.09% 22.318ms 173.008us 0.000us 0.00% 278.000us 2.155us 129
aten::clamp 0.09% 174.000us 11.43% 21.103ms 21.103ms 1.000us 0.13% 1.000us 1.000us 1
aten::add 0.11% 205.000us 9.08% 16.768ms 16.768ms 1.000us 0.13% 1.000us 1.000us 1
aten::copy_ 4.94% 9.112ms 8.15% 15.043ms 116.612us 258.000us 32.45% 278.000us 2.155us 129
aten::stack 2.76% 5.091ms 7.97% 14.719ms 14.719ms 0.000us 0.00% 6.000us 6.000us 1
aten::reciprocal 0.11% 194.000us 7.01% 12.933ms 12.933ms 1.000us 0.13% 1.000us 1.000us 1
aten::max 0.09% 165.000us 6.43% 11.868ms 11.868ms 3.000us 0.38% 3.000us 3.000us 1
aten::detach 1.58% 2.911ms 4.12% 7.596ms 14.836us 0.000us 0.00% 0.000us 0.000us 512
aten::cat 0.56% 1.042ms 3.73% 6.882ms 6.882ms 6.000us 0.75% 6.000us 6.000us 1
aten::_foreach_mul_ 1.36% 2.503ms 3.33% 6.145ms 3.072ms 10.000us 1.26% 10.000us 5.000us 2
detach 2.54% 4.685ms 2.54% 4.685ms 9.150us 0.000us 0.00% 0.000us 0.000us 512
aten::empty_strided 1.92% 3.545ms 1.92% 3.545ms 27.481us 0.000us 0.00% 0.000us 0.000us 129
cudaDeviceEnablePeerAccess 1.64% 3.022ms 1.64% 3.022ms 1.511ms 0.000us 0.00% 0.000us 0.000us 2
aten::unsqueeze 1.03% 1.892ms 1.49% 2.746ms 10.727us 0.000us 0.00% 0.000us 0.000us 256
aten::as_strided 1.35% 2.494ms 1.35% 2.494ms 4.862us 0.000us 0.00% 0.000us 0.000us 513
cudaMemcpyAsync 1.01% 1.868ms 1.01% 1.868ms 14.481us 4.000us 0.50% 4.000us 0.031us 129
cudaStreamWaitEvent 0.41% 760.000us 0.41% 760.000us 2.946us 8.000us 1.01% 8.000us 0.031us 258
cudaEventRecord 0.15% 276.000us 0.15% 276.000us 1.070us 8.000us 1.01% 8.000us 0.031us 258
aten::mul 0.08% 139.000us 0.08% 153.000us 153.000us 1.000us 0.13% 1.000us 1.000us 1
aten::empty 0.02% 35.000us 0.02% 35.000us 35.000us 0.000us 0.00% 0.000us 0.000us 1
cudaDeviceSynchronize 0.01% 14.000us 0.01% 14.000us 7.000us 0.000us 0.00% 0.000us 0.000us 2
cudaDeviceCanAccessPeer 0.00% 5.000us 0.00% 5.000us 2.500us 0.000us 0.00% 0.000us 0.000us 2
void at::native::reduce_kernel<512, 1, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 514.000us 64.65% 514.000us 2.008us 256
Memcpy PtoP (Device -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 258.000us 32.45% 258.000us 2.000us 129
void at::native::(anonymous namespace)::CatArrayBatc... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 0.75% 6.000us 3.000us 2
void at::native::reduce_kernel<512, 1, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 3.000us 0.38% 3.000us 3.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.13% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.13% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.13% 1.000us 1.000us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.000us 0.13% 1.000us 1.000us 1
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 10.000us 1.26% 10.000us 2.500us 4
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 184.579ms
Self CUDA time total: 795.000us
```
For script:
```
import torch
from math import inf
from torch.nn.utils import clip_grad_norm_
params = [torch.rand(32, 16, device="cuda:3")*5 for _ in range(128)] + [torch.rand(32, 16, device="cuda:4")*-7 for _ in range(128)]
for p in params:
p.grad = torch.rand_like(p)
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
) as p:
total_norm = clip_grad_norm_(params, 10.0, norm_type=inf)
torch.cuda.synchronize()
print(p.key_averages().table(sort_by="cpu_time_total"))
print(total_norm)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120623
Approved by: https://github.com/Skylion007 , https://github.com/mikaylagawarecki
2024-02-27 01:27:05 +00:00
ad4472833c
define public API for torch.nn.utils ( #111026 )
...
Adding modules imported here and the following functions to the `__all__`:
* [clip_grad_norm_](https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html )
* [clip_grad_value_](https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html )
* [remove_weight_norm](https://pytorch.org/docs/stable/generated/torch.nn.utils.remove_weight_norm.html )
* [parameters_to_vector](https://pytorch.org/docs/stable/generated/torch.nn.utils.parameters_to_vector.html )
* [vector_to_parameters](https://pytorch.org/docs/stable/generated/torch.nn.utils.vector_to_parameters.html )
* [remove_spectral_norm](https://pytorch.org/docs/stable/generated/torch.nn.utils.remove_spectral_norm.html )
* [skip_init](https://pytorch.org/docs/stable/generated/torch.nn.utils.skip_init.html )
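A minimal sketch of the resulting `__all__` declaration (an illustrative subset; the real list in `torch/nn/utils/__init__.py` also includes the imported submodules):
```python
__all__ = [
    "clip_grad_norm_",
    "clip_grad_value_",
    "remove_weight_norm",
    "parameters_to_vector",
    "vector_to_parameters",
    "remove_spectral_norm",
    "skip_init",
    # ... plus the submodules imported in torch/nn/utils/__init__.py
]
```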
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111026
Approved by: https://github.com/mikaylagawarecki
2023-10-12 23:05:23 +00:00
6d2887cc06
Reland "Move tensor grouping to ATen" ( #103912 )
...
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verify by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If the above condition is not met, instead of printing a static_assert output, VC++ raises very cryptic compilation errors; see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.
Also, using `std::hash` for string should result in a faster hash function.
(cherry picked from commit 74b7a6c75e698378882d30958908073407f97fb3)
### 🤖 Generated by Copilot at 5914771
This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
2023-06-21 09:26:33 +00:00
0cb5bc3b04
Revert "Move tensor grouping to ATen ( #100007 )"
...
This reverts commit 74b7a6c75e698378882d30958908073407f97fb3.
Reverted https://github.com/pytorch/pytorch/pull/100007 on behalf of https://github.com/izaitsevfb due to Breaks internal builds, see D46629727 ([comment](https://github.com/pytorch/pytorch/pull/100007#issuecomment-1587861598 ))
2023-06-12 18:30:33 +00:00
74b7a6c75e
Move tensor grouping to ATen ( #100007 )
...
rel: #94344
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100007
Approved by: https://github.com/janeyx99
2023-06-09 15:44:46 +00:00
704283d61f
Improve clip_grad_norm to use torch.linalg.vector_norm ( #102429 )
...
Done in this PR:
- Use `torch.linalg.vector_norm` instead of `torch.norm`
- Reduce the memory bandwidth used by clip_grad_norm when used with `inf`, i.e., no need to read back the intermediate tensor returned by `abs`
What I'm slightly unsure about:
- I don't know if the `torch._foreach` API supports `inf`
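A hedged illustration of the bandwidth point above: `torch.linalg.vector_norm` handles `ord=inf` directly, so no intermediate `abs()` result needs to be read back:
```python
import torch

g = torch.randn(1024)

# Old pattern: materializes an intermediate abs() tensor before reducing.
old_inf_norm = g.abs().max()
# New pattern: a single reduction, no intermediate to read back.
new_inf_norm = torch.linalg.vector_norm(g, ord=float("inf"))
assert torch.equal(old_inf_norm, new_inf_norm)
```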
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102429
Approved by: https://github.com/lezcano
2023-05-30 18:35:18 +00:00
b005ec62b9
[BE] Remove dependency on six and future ( #94709 )
...
Remove the Python 2/3 compatibility libraries [six](https://pypi.org/project/six ) and [future](https://pypi.org/project/future ), as well as `torch._six`. We only support Python 3.8+ now; it's time to retire them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94709
Approved by: https://github.com/malfet , https://github.com/Skylion007
2023-02-14 09:14:14 +00:00
8c9f745af1
[foreach] guard default support on native tensors only ( #92923 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92923
Approved by: https://github.com/ngimel , https://github.com/crcrpar
2023-01-26 04:52:58 +00:00
e4d83d54a6
Foreach gradient clipping ( #91846 )
...
Faster gradient clipping using the foreach functions
```
[------------------------ (tensors, scalar) -------------------------]
| without foreach | with foreach | apex
1 threads: ----------------------------------------------------------------------
10 tensors of size 4 | 120.5 | 61.1 | 50.3
100 tensors of size 4 | 946.2 | 239.5 | 136.3
1000 tensors of size 4 | 9808.5 | 2151.1 | 1006.9
10000 tensors of size 4 | 96871.2 | 22637.4 | 10119.1
10 tensors of size 16 | 121.0 | 64.1 | 52.5
100 tensors of size 16 | 993.4 | 252.6 | 136.7
1000 tensors of size 16 | 9427.7 | 2151.2 | 1049.5
10000 tensors of size 16 | 97437.1 | 22203.1 | 10340.0
10 tensors of size 256 | 118.9 | 62.3 | 51.5
100 tensors of size 256 | 955.2 | 243.1 | 134.2
1000 tensors of size 256 | 9374.9 | 2140.7 | 1009.6
10000 tensors of size 256 | 95302.5 | 21849.4 | 10215.5
10 tensors of size 65536 | 118.5 | 62.4 | 51.1
100 tensors of size 65536 | 1740.7 | 243.3 | 225.3
1000 tensors of size 65536 | 17364.1 | 2228.7 | 2004.5
10000 tensors of size 65536 | 177510.1 | 25410.4 | 20678.2
```
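A hedged sketch of the foreach-based clipping pattern benchmarked above, using the private `torch._foreach_*` ops (illustrative only, not the library implementation):
```python
import torch

grads = [torch.randn(16) for _ in range(100)]
max_norm = 1.0

# One multi-tensor kernel for all per-tensor norms, then a single reduction.
norms = torch._foreach_norm(grads, 2.0)
total_norm = torch.linalg.vector_norm(torch.stack(norms), 2.0)

# One multi-tensor in-place multiply instead of a Python loop over mul_ calls.
clip_coef = (max_norm / (total_norm + 1e-6)).clamp(max=1.0)
torch._foreach_mul_(grads, clip_coef)
```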
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91846
Approved by: https://github.com/janeyx99
2023-01-20 21:43:29 +00:00
94a6d72032
Update doc of clip grad ( #91312 )
...
Replaces #85772 that has a broken internal state.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91312
Approved by: https://github.com/soulitzer
2022-12-22 22:34:32 +00:00
9db3c517de
Add __all__ for torch.nn.modules, torch.distributed.elastic, torch.nn.utils submodules ( #80240 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80240
Approved by: https://github.com/rohan-varma
2022-06-27 17:11:12 +00:00
7843a5e882
Move Tensor.grad back into C++
...
`Tensor.grad` was moved to python in #30531 to add a warning. However,
that warning has since been lowered into C++ so this wrapper is no
longer necessary.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76675
Approved by: https://github.com/albanD
2022-06-10 13:44:45 +00:00
da764f9224
Update clip_grad_norm_ documentation
...
It actually returns the gradient norm, not the parameter norms.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76230
Approved by: https://github.com/jbschlosser
2022-04-22 14:07:34 +00:00
7f7966a888
[Docs] Fix the syntax of documentation ( #69958 )
...
Summary:
Fixes the syntax of documentation in the file torch/nn/utils/clip_grad.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69958
Reviewed By: mruberry
Differential Revision: D33160612
Pulled By: albanD
fbshipit-source-id: 2dc199fee345bb4c75632900bc6f73a1ab8192a6
2021-12-16 10:38:39 -08:00
e858f6eed9
torch.nn.utils.clip_grad_norm_: remove device syncs ( #61042 )
...
Summary:
Fixes https://github.com/pytorch/pytorch/issues/60691
### Changes
Per the discussion in the above issue, this PR makes 2 changes:
1. When `error_if_nonfinite=False`, the NaN/Inf checks are truly skipped, and no device synchronization occurs.
- Additionally, when performing the checks, the 2 results are combined with `torch.logical_or` to incur only a single sync (instead of 2 in the happy/finite path).
2. The `clip_coef` conditional is removed, in favor of a call to `clamp(..., max=1.0)` and an unconditional multiplication.
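A hedged sketch of the two changes above (a simplified stand-in, not the exact `clip_grad_norm_` code):
```python
import torch

def clip_grads_sketch_(grads, total_norm, max_norm, error_if_nonfinite=False):
    if error_if_nonfinite:
        # 1. Combine the NaN/Inf checks so at most one device sync occurs.
        if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
            raise RuntimeError("The total norm is non-finite")
    # 2. Clamp instead of branching on clip_coef, so the finite path never
    #    syncs and the multiply is unconditional.
    clip_coef = (max_norm / (total_norm + 1e-6)).clamp(max=1.0)
    for g in grads:
        g.mul_(clip_coef)
```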
### Testing
- The existing unit tests for `clip_grad_norm_` pass.
- I have manually profiled the example program from https://github.com/pytorch/pytorch/issues/60691 , and verified that:
- No synchronizations occur when using `error_if_nonfinite=False`.
- A single synchronization occurs when using `error_if_nonfinite=True`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61042
Reviewed By: mrshenli
Differential Revision: D29764096
Pulled By: jbschlosser
fbshipit-source-id: db594b24608d16374b91bcbb9469046dfeeb152d
2021-07-22 08:53:40 -07:00
38a08a49ea
Flip clip_grad_norm default for error_if_nonfinite to false ( #55169 )
...
Summary:
Non-backwards-compatible change introduced in https://github.com/pytorch/pytorch/pull/53843 is tripping up a lot of code. Better to set it to False initially and then potentially flip to True in the later version to give people time to adapt.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55169
Reviewed By: mruberry
Differential Revision: D27511150
Pulled By: jbschlosser
fbshipit-source-id: 1ac018557c0900b31995c29f04aea060a27bc525
2021-04-02 12:25:32 -07:00
3ddc6174da
Raise error in clip_grad_norm_ if norm is non-finite ( #53843 )
...
Summary:
**BC-breaking note**: This change throws errors for cases that used to silently pass. The old behavior can be obtained by setting `error_if_nonfinite=False`
Fixes https://github.com/pytorch/pytorch/issues/46849
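A hedged usage sketch of the new flag (with `error_if_nonfinite=True` this raises on a NaN/Inf total norm; passing `False` restores the old silent behavior):
```python
import torch
from torch import nn

p = torch.zeros(3, requires_grad=True)
p.grad = torch.tensor([float("nan"), 1.0, 2.0])

# Opt out of the new error and keep the old behavior of clipping silently.
nn.utils.clip_grad_norm_([p], max_norm=1.0, error_if_nonfinite=False)
```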
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53843
Reviewed By: malfet
Differential Revision: D27291838
Pulled By: jbschlosser
fbshipit-source-id: 216d191b26e1b5919a44a3af5cde6f35baf825c4
2021-03-29 08:41:21 -07:00
e6779d4357
[*.py] Rename "Arguments:" to "Args:" ( #49736 )
...
Summary:
I've written custom parsers and emitters for everything from docstrings to classes and functions. However, I recently came across an issue when I was parsing/generating from the TensorFlow codebase: inconsistent use of `Args:` and `Arguments:` in its docstrings.
```sh
(pytorch#c348fae)$ for name in 'Args:' 'Arguments:'; do
printf '%-10s %04d\n' "$name" "$(rg -IFtpy --count-matches "$name" | paste -s -d+ -- | bc)"; done
Args: 1095
Arguments: 0336
```
It is easy enough to extend my parsers to support both variants, however it looks like `Arguments:` is wrong anyway, as per:
- https://google.github.io/styleguide/pyguide.html#doc-function-args @ [`ddccc0f`](https://github.com/google/styleguide/blob/ddccc0f/pyguide.md )
- https://chromium.googlesource.com/chromiumos/docs/+/master/styleguide/python.md#describing-arguments-in-docstrings @ [`9fc0fc0`](https://chromium.googlesource.com/chromiumos/docs/+/9fc0fc0/styleguide/python.md )
- https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html @ [`c0ae8e3`](https://github.com/sphinx-contrib/napoleon/blob/c0ae8e3/docs/source/example_google.rst )
Therefore, only `Args:` is valid. This PR replaces them throughout the codebase.
PS: For related PRs, see tensorflow/tensorflow/pull/45420
PPS: The trackbacks automatically appearing below are sending the same changes to other repositories in the [PyTorch](https://github.com/pytorch ) organisation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49736
Reviewed By: albanD
Differential Revision: D25710534
Pulled By: soumith
fbshipit-source-id: 61e8ff01abb433e9f78185c2d1d0cbd7c22c1619
2020-12-28 09:34:47 -08:00
1c6ace87d1
Embed torch.nn typing annotations ( #43044 )
...
Summary:
Delete several .pyi files and embed annotations from those files in respective .py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43044
Reviewed By: ezyang
Differential Revision: D23123234
Pulled By: malfet
fbshipit-source-id: 4ba361cc84402352090523924b0035e100ba48b1
2020-08-14 13:24:58 -07:00
54d4b419db
fix clip_grad_norm to work with parameters on the different devices ( #38615 )
...
Summary:
Per title.
We move all the individual gradient norms to a single device before stacking (no-op if all the gradients are already on a single device), `clip_coef` is copied to the device of gradient, which may be suboptimal as there could be multiple copies, but no worse than when we were synchronizing for each parameter. In a simple case of all gradients on a single device, there should be no synchronization.
Also, we no longer error out if the parameter list is empty or none of the parameters have gradients, and instead return a total_norm of 0.
Fixes https://github.com/pytorch/pytorch/issues/38605
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38615
Reviewed By: ailzhang
Differential Revision: D21634588
Pulled By: ngimel
fbshipit-source-id: ea4d08d4f3445438260052820c7ca285231a156b
2020-05-19 10:33:40 -07:00
e74a215ade
Changed clip_grad_norm_ total_norm calculation ( #32020 )
...
Summary:
Redefines the computation of the total_norm to increase performance as shown in https://github.com/pytorch/pytorch/issues/31474 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32020
Differential Revision: D19353309
Pulled By: ngimel
fbshipit-source-id: bf7530dcd39f56614a211b5f21445864d4f2e875
2020-01-13 08:13:46 -08:00
19a6de328f
Correct docstring of vision/init functions
...
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17351
Differential Revision: D14276355
Pulled By: soumith
fbshipit-source-id: 9b572b6a04eeb1e44cd93961edac76ed10f7b24e
2019-03-01 11:40:23 -08:00
27455e9c78
Use _six for inf and nan ( #9500 )
...
Summary:
Things like `float('inf')` are actually quite expensive.
```py
In [1]: import math
In [2]: %timeit -n 200 math.inf
49.3 ns ± 1.42 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)
In [3]: %timeit -n 200 float('inf')
194 ns ± 39.1 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9500
Reviewed By: soumith
Differential Revision: D8876229
Pulled By: SsnL
fbshipit-source-id: 78602b76bb53d5588910b58270930c0bd413d2d7
2018-07-18 10:40:29 -07:00
89c2b50a15
Grad clip for parameters on different devices ( #9302 )
...
Summary:
I'm trying to write a multi-GPU network by pipelining some layers onto different GPUs. However, the current gradient clipping requires all the parameters to be located on the same device.
The CUDA launch overhead is reduced since the scalar calculation is performed on the CPU, but it introduces extra data transfers.
No performance regression is observed by running the following snippet:
```python
import time
import torch

module = torch.nn.Sequential(
    torch.nn.LSTM(1024, 1024),
    torch.nn.LSTM(256, 256),
    torch.nn.Linear(100, 10000),
).cuda()

torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
    torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
torch.cuda.synchronize()
time_elapse = time.time() - start
print('{} ms per clip'.format(time_elapse))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9302
Differential Revision: D8781551
Pulled By: soumith
fbshipit-source-id: 9d76d01fe0531927f770a16b9523872a7e08e927
2018-07-10 07:56:55 -07:00
1078491502
Change is_tensor to isinstance(*, torch.Tensor) ( #7814 )
...
Thanks!
2018-05-24 15:08:16 -04:00
2222fc7666
Add support for accepting Tensor as input in clip_grad_* functions. ( #7769 )
2018-05-23 12:12:03 +02:00
8325206c6f
A clip grad fix for sparse tensors. ( #7257 )
2018-05-04 00:35:32 +02:00
7fcaf3b49e
Update torch.nn.init and torch.nn.utils.clip_grad ( #6173 )
...
Introducing two updates.
1. Add param to He initialization scheme in torch.nn.init
Problem solved:
The function calculate_gain can take an argument to specify the type of non-linearity used. However, it wasn't possible to pass this argument directly to the He / Kaiming weight initialization function.
2. Add util to clip gradient value in torch.nn.utils.clip_grad
Problem solved:
DL libraries typically provide users with easy access to functions for clipping the gradients both using the norm and a fixed value. However, the utils clip_grad.py only had a function to clip the gradient norm.
* add param to He initialization scheme in torch.nn.init
* add util to clip gradient value in torch/nn/utils/clip_grad.py
* update doc in torch.nn.utils.clip_grad
* update and add test for torch.nn.utils.clip_grad
* update function signature in torch.nn.utils.clip_grad to match suffix_ convention
* ensure backward compatibility in torch.nn.utils.clip_grad
* remove DeprecationWarning in torch.nn.utils.clip_grad
* extend test and implementation of torch.nn.utils.clip_grad
* update test and implementation torch.nn.utils.clip_grad
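A hedged usage sketch of the value-clipping utility added here:
```python
import torch
from torch import nn

model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

# Clamp every gradient element into [-0.5, 0.5] in place.
nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
```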
2018-04-17 11:32:32 -04:00