Commit Graph

55 Commits

3cda34ebde [2/N] Apply ruff UP035 check in torch files (#164054)
This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
`TypeAlias` is also imported from `typing`.
This PR is the follow-up of #163947.
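
For illustration, this is the kind of rewrite UP035 produces (the alias below is a hypothetical example, not code from this PR):

```python
# Before (flagged by ruff UP035):
#   from typing import Callable

# After: `Callable` comes from collections.abc; `TypeAlias` stays in typing.
from collections.abc import Callable
from typing import TypeAlias

Reducer: TypeAlias = Callable[[int, int], int]
```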

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164054
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2025-09-29 03:35:32 +00:00
1e79872f2e [BE] More torch.nn docs coverage test (except for torch.nn.parallel) (#158654)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158654
Approved by: https://github.com/janeyx99
ghstack dependencies: #158491
2025-07-25 22:03:55 +00:00
4c0d5ad4be Fix docstring for clip_grads_with_norm_ to reflect clamping behavior (#158200)
This PR updates the docstring for torch.nn.utils.clip_grads_with_norm_ to accurately reflect the implementation behavior. The current documentation suggests that gradients are always scaled by:

grad = grad * (max_norm / (total_norm + eps))

However, the actual implementation clamps the scale coefficient to a maximum of 1.0, ensuring gradients are only scaled down, not up. This PR corrects the formula and adds a clarifying note to avoid confusion for users.

Updated the formula in the docstring to:

grad = grad * min(max_norm / (total_norm + eps), 1.0)

Added a note explaining the rationale for clamping (to prevent gradient amplification).
Ensured consistency with the behavior of clip_grad_norm_.
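
A minimal sketch of the clamping described above (illustrative only, not the actual implementation):

```python
import torch

grads = [torch.tensor([3.0, 4.0])]                          # total L2 norm = 5.0
max_norm, eps = 10.0, 1e-6

total_norm = torch.linalg.vector_norm(
    torch.stack([torch.linalg.vector_norm(g) for g in grads])
)
clip_coef = (max_norm / (total_norm + eps)).clamp(max=1.0)  # capped at 1.0
for g in grads:
    g.mul_(clip_coef)                                       # never scaled up
```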

Fixes #151554

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158200
Approved by: https://github.com/mikaylagawarecki
2025-07-25 18:07:41 +00:00
fcc682be4b [BE][Ez]: Fully type nn.utils.clip_grad (#154801)
Fully types clip_grad and exposes typing annotations that were hidden by a bad decorator

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154801
Approved by: https://github.com/jansel
2025-07-09 14:27:51 +00:00
596b418391 [BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
2025-06-14 11:27:04 +00:00
5e93abe3c0 Address docs for clip_grad functions (#155125)
This PR takes the opinionated stance that `torch.nn.utils.<func>` should be the preferred API over `torch.nn.utils.clip_grad.<func>`.
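
In code, that preference reads as:

```python
# Preferred: the public torch.nn.utils namespace
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

# Discouraged: reaching into the submodule directly
# from torch.nn.utils.clip_grad import clip_grad_norm_, clip_grad_value_
```
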
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155125
Approved by: https://github.com/albanD, https://github.com/mikaylagawarecki, https://github.com/janeyx99
2025-06-05 19:22:09 +00:00
50de6ae253 Revert "[BE][Ez]: Fully type nn.utils.clip_grad (#154801)"
This reverts commit 9ce2732b685da527308dc2dc4b2eeb4e252f57d1.

Reverted https://github.com/pytorch/pytorch/pull/154801 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/154801#issuecomment-2937886337))
2025-06-04 00:41:27 +00:00
9ce2732b68 [BE][Ez]: Fully type nn.utils.clip_grad (#154801)
Fully types clip_grad and exposes typing annotations that were hidden by a bad decorator

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154801
Approved by: https://github.com/jansel
2025-05-31 23:06:45 +00:00
f12d8d60b1 Add hint message when parameters is empty in clip_grad_norm_ (#151529)
Fixes #148259

## Changes

- Add a warning message when the `parameters` generator is already exhausted

## Test Result
### print warning
```python

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(16, 10)
targets = torch.randn(16, 1)

outputs = model(inputs)
loss = criterion(outputs, targets)
optimizer.zero_grad()
loss.backward()

params_to_clip = model.parameters()
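# The loop below exhausts this generator, so clip_grad_norm_ later receives
# no parameters and emits the warning shown in the output.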

for p in params_to_clip:
    print(p.shape)

max_norm = 1.0
norm_type = 2.0
total_norm = nn.utils.clip_grad_norm_(params_to_clip, max_norm, norm_type)
print(f"total_norm: {total_norm}")
```

```bash
/home/zong/code/pytorch/torch/nn/utils/clip_grad.py:222: UserWarning: `parameters` is an empty generator, no gradient clipping will occur.
  warnings.warn(
total_norm: 0.0
```

### UT

```bash
pytest test/test_nn.py -k test_clip_grad_norm
```

![image](https://github.com/user-attachments/assets/0aa0f06c-e0a5-43cf-9a97-d7c2747c9180)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151529
Approved by: https://github.com/jbschlosser
2025-05-22 11:23:39 +00:00
fe90a5c140 [Easy] Optimize clip_grad param description (#151532)
Fix the missing `optional` markers in the parameter descriptions of `clip_grad_norm_` and `clip_grad_value_`

## Test Result

### Before

![image](https://github.com/user-attachments/assets/3393dd4b-a730-4dd4-8304-9b895ac669d4)

![image](https://github.com/user-attachments/assets/220c4738-a728-474b-b06d-b5be7660d150)

### After

![image](https://github.com/user-attachments/assets/5637bb68-3b6d-49a3-8ee1-3af636950aa0)

![image](https://github.com/user-attachments/assets/c0f1d966-a9ba-4fac-a874-9d4955f6e0d6)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151532
Approved by: https://github.com/Skylion007, https://github.com/albanD
2025-04-17 16:47:38 +00:00
7178b827d7 PEP585: Missed conversions (#145342)
Differential Revision: [D68785969](https://our.internmc.facebook.com/intern/diff/D68785969)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145342
Approved by: https://github.com/bobrenjc93
2025-01-29 05:24:36 +00:00
5e8e1d725a Remove some unused type ignores (round 1) (#142325)
Over time, a large number of the existing type ignores have become irrelevant/unused/dead as a result of improvements in annotations and type checking.

Having these `# type: ignore` linger around is not ideal for two reasons:

- They are syntactically verbose/ugly.
- They could hide genuine bugs in the future: if a refactoring actually introduces a bug, the stale ignore can mask it.

I'm counting over 1500 unused ignores already. This is a first PR that removes some of them. Note that I haven't touched type ignores that looked "conditional", like the import challenge mentioned in https://github.com/pytorch/pytorch/pull/60006#issuecomment-2480604728. I will address these at a later point, and would eventually enable `warn_unused_ignores = True` in the mypy configuration, as discussed in that comment, to prevent accumulating more dead ignores going forward.

This PR should have no effect on runtime at all.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142325
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
2024-12-09 18:23:46 +00:00
2ee91db03d Add APIs to separate norm calculation and gradient scaling in nn.utils.clip_grad_norm_ (#139662)
Fixes https://github.com/pytorch/pytorch/issues/139467

Refactors `nn.utils.clip_grad_norm_` into `nn.utils.get_total_norm` and `nn.utils.clip_grads_with_norm_`; `clip_grad_norm_` now calls into these two new ops.

`get_total_norm` is the generalized name (rather than `get_grad_norm`), per the discussion on the issue from @awgu.
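
A sketch of how the two new ops compose (signatures assumed from the description above):

```python
import torch
from torch import nn

model = nn.Linear(10, 10)
model(torch.randn(4, 10)).sum().backward()

grads = [p.grad for p in model.parameters() if p.grad is not None]
total_norm = nn.utils.get_total_norm(grads, norm_type=2.0)
nn.utils.clip_grads_with_norm_(model.parameters(), max_norm=1.0, total_norm=total_norm)
# Together these match the effect of nn.utils.clip_grad_norm_(model.parameters(), 1.0)
```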

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139662
Approved by: https://github.com/H-Huang
2024-11-07 23:13:23 +00:00
fc5aa24a6e Rewording doc string for clip_grad_norm_ (#133406)
The docstring for `torch.nn.utils.clip_grad_norm_` needed some clarity; it was previously unclear that the total norm is computed over the norms of the individual parameter gradients.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133406
Approved by: https://github.com/mikaylagawarecki
2024-08-15 16:21:27 +00:00
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from the functions they decorate, so even if the underlying function is fully typed, its callers get no benefit from the type annotations.
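
A small illustration of the problem (hypothetical decorators, not PyTorch code):

```python
from collections.abc import Callable
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def untyped(fn):  # no annotations: mypy treats the decorated result as Any
    return fn

def typed(fn: Callable[P, R]) -> Callable[P, R]:  # signature is preserved
    return fn

@untyped
def add(x: int, y: int) -> int:
    return x + y

add("a", "b")  # not flagged by mypy: add's annotations were stripped by the decorator
```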

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
5a80d2df84 [BE] enable UFMT for torch/nn/utils (#128595)
Part of #123062

- #123062
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128595
Approved by: https://github.com/Skylion007
2024-06-13 18:34:57 +00:00
27f9d3b0a1 Flip default value for mypy disallow_untyped_defs [8/11] (#127845)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127845
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843, #127844
2024-06-08 18:49:56 +00:00
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
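
For reference, a sketch of the two patterns (hypothetical function names):

```python
import warnings
from typing_extensions import deprecated

@deprecated("old_clip() is deprecated, use clip_grad_norm_() instead", category=FutureWarning)
def old_clip() -> None:
    ...

def legacy_path() -> None:
    # plain warnings.warn call, now with an explicit category
    warnings.warn("legacy_path() is deprecated", category=FutureWarning)
```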

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0a8325cbad4734a563aa459ca611991.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
7c71d7f32b [DTensor] Supported foreach=True for clip_grad_norm_ (#120910)
This PR adds support for `clip_grad_norm_(foreach=True)` by implementing `aten._foreach_norm.Scalar` and `aten._foreach_mul_.Tensor`. `foreach=True` is required to get competitive performance with `DTensor`.

`foreach=True` reduces CPU overhead for Llama-7B from 388 ms to 63 ms. Existing flat-parameter FSDP's `clip_grad_norm_` takes 3 ms on CPU 😢.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120910
Approved by: https://github.com/wanchaol, https://github.com/janeyx99
ghstack dependencies: #120238
2024-03-02 00:28:09 +00:00
ef9b6d6816 Replace individual detaches with overall torch.no_grad decorator (#120638)
Fixes https://github.com/pytorch/pytorch/issues/120611.

At first, I thought there were too many detaches, but @awgu and I concluded that both `clip_grad_norm_` and `clip_grad_value_` should be run under torch.no_grad, similar to the optimizer step. One option is to continue calling `detach`, but doing that on many tensors is slower than setting the context to no_grad (I think?), and Andrew had noticed that "the 1st round of detaches takes 10 ms for FSDP2, whereas existing FSDP's clip_grad_norm_ only takes 3 ms total" since there are more tensors in FSDP2.

This change also disables grad mode for the foreach path of `clip_grad_value_`; the first attempt not doing this was an oversight. Not sure how to add a test case for this, since grad mode will be turned back on after the call.
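
A minimal sketch of the pattern (assumed, not the actual diff): decorate the whole routine with `torch.no_grad` instead of detaching each gradient.

```python
import torch

@torch.no_grad()
def clip_grad_value_sketch(parameters, clip_value: float) -> None:
    # Under no_grad, in-place ops on .grad need no per-tensor detach().
    for p in parameters:
        if p.grad is not None:
            p.grad.clamp_(min=-clip_value, max=clip_value)
```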

The new profile is not much different from the one at the bottom of this stack, but the number of detaches is 0 :D:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (c71bcceb)]$ python playground2.py
STAGE:2024-02-26 13:07:15 211224:211224 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:07:16 211224:211224 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:07:16 211224:211224 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                       cudaLaunchKernel        70.63%     110.415ms        70.63%     110.415ms       5.811ms       0.000us         0.00%       0.000us       0.000us            19
                               aten::linalg_vector_norm         0.18%     284.000us        26.00%      40.636ms      40.636ms       3.000us         0.99%       3.000us       3.000us             1
                                            aten::clamp         0.09%     148.000us        14.88%      23.261ms      23.261ms       1.000us         0.33%       1.000us       1.000us             1
                                               aten::to         0.75%       1.170ms        14.05%      21.970ms      84.826us       0.000us         0.00%     258.000us       0.996us           259
                                         aten::_to_copy         2.28%       3.562ms        13.31%      20.800ms     161.240us       0.000us         0.00%     258.000us       2.000us           129
                                    aten::_foreach_norm         4.44%       6.935ms        12.72%      19.878ms       9.939ms      19.000us         6.29%      21.000us      10.500us             2
                                              aten::add         0.11%     173.000us        10.97%      17.153ms      17.153ms       1.000us         0.33%       1.000us       1.000us             1
                                            aten::stack         2.99%       4.673ms         9.15%      14.300ms      14.300ms       0.000us         0.00%       6.000us       6.000us             1
                                            aten::copy_         5.49%       8.586ms         8.96%      14.001ms     108.535us     258.000us        85.43%     258.000us       2.000us           129
                                       aten::reciprocal         0.11%     179.000us         8.35%      13.051ms      13.051ms       1.000us         0.33%       1.000us       1.000us             1
                                              aten::cat         0.64%     993.000us         4.42%       6.902ms       6.902ms       6.000us         1.99%       6.000us       6.000us             1
                                            aten::zeros         0.04%      69.000us         4.28%       6.698ms       3.349ms       0.000us         0.00%       2.000us       1.000us             2
                                            aten::zero_         0.04%      66.000us         4.13%       6.462ms       3.231ms       0.000us         0.00%       2.000us       1.000us             2
                                            aten::fill_         0.06%      98.000us         4.09%       6.396ms       3.198ms       2.000us         0.66%       2.000us       1.000us             2
                                    aten::_foreach_mul_         1.50%       2.342ms         3.79%       5.924ms       2.962ms      10.000us         3.31%      10.000us       5.000us             2
                                            aten::empty         3.27%       5.115ms         3.27%       5.115ms      19.826us       0.000us         0.00%       0.000us       0.000us           258
                                    aten::empty_strided         2.07%       3.237ms         2.07%       3.237ms      25.093us       0.000us         0.00%       0.000us       0.000us           129
                             cudaDeviceEnablePeerAccess         1.93%       3.023ms         1.93%       3.023ms       1.512ms       0.000us         0.00%       0.000us       0.000us             2
                                        aten::unsqueeze         1.21%       1.896ms         1.74%       2.725ms      10.645us       0.000us         0.00%       0.000us       0.000us           256
                                        cudaMemcpyAsync         1.01%       1.572ms         1.01%       1.572ms      12.186us       0.000us         0.00%       0.000us       0.000us           129
                                       aten::as_strided         0.54%     839.000us         0.54%     839.000us       3.265us       0.000us         0.00%       0.000us       0.000us           257
                                    cudaStreamWaitEvent         0.34%     539.000us         0.34%     539.000us       2.089us       0.000us         0.00%       0.000us       0.000us           258
                                        cudaEventRecord         0.18%     274.000us         0.18%     274.000us       1.062us       0.000us         0.00%       0.000us       0.000us           258
                                              aten::mul         0.07%     107.000us         0.08%     132.000us     132.000us       1.000us         0.33%       1.000us       1.000us             1
                                  cudaDeviceSynchronize         0.01%      17.000us         0.01%      17.000us       8.500us       0.000us         0.00%       0.000us       0.000us             2
                                cudaDeviceCanAccessPeer         0.00%       7.000us         0.00%       7.000us       3.500us       0.000us         0.00%       0.000us       0.000us             2
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         0.66%       2.000us       1.000us             2
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      13.000us         4.30%      13.000us       3.250us             4
void at::native::lpnorm_cleanup<float, (at::native::...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         1.99%       6.000us       3.000us             2
                         Memcpy PtoP (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us     258.000us        85.43%     258.000us       2.000us           129
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         1.99%       6.000us       3.000us             2
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us       3.000us         0.99%       3.000us       3.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us         3.31%      10.000us       2.500us             4
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 156.319ms
Self CUDA time total: 302.000us
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120638
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #120623
2024-02-27 01:27:05 +00:00
df72819f91 clip_grad_norm can use fast foreach path for inf norm (#120623)
Now that foreach_norm supports inf, we should not special case it.

For a mere 256 parameters, we get a win of 30 ms in CPU time and a ~800 us -> 300 us decrease in CUDA time. This win only grows with more parameters.

New profile:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (bf1c0490|REBASE-i|detached HEAD)]$ python playground2.py
STAGE:2024-02-26 13:14:10 395517:395517 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:14:11 395517:395517 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:14:11 395517:395517 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                       cudaLaunchKernel        67.01%     102.262ms        67.01%     102.262ms       5.382ms       2.000us         0.66%       2.000us       0.105us            19
                               aten::linalg_vector_norm         0.20%     311.000us        23.44%      35.776ms      35.776ms       3.000us         0.99%       3.000us       3.000us             1
                                               aten::to         0.79%       1.208ms        14.62%      22.311ms      86.143us       0.000us         0.00%     263.000us       1.015us           259
                                            aten::clamp         0.12%     182.000us        13.96%      21.303ms      21.303ms       1.000us         0.33%       1.000us       1.000us             1
                                         aten::_to_copy         2.38%       3.628ms        13.83%      21.103ms     163.589us       0.000us         0.00%     263.000us       2.039us           129
                                    aten::_foreach_norm         4.71%       7.185ms        13.54%      20.659ms      10.329ms      19.000us         6.29%      23.000us      11.500us             2
                                              aten::add         0.14%     211.000us        10.86%      16.580ms      16.580ms       1.000us         0.33%       1.000us       1.000us             1
                                            aten::stack         3.11%       4.744ms         9.59%      14.642ms      14.642ms       0.000us         0.00%       6.000us       6.000us             1
                                            aten::copy_         5.71%       8.721ms         9.27%      14.152ms     109.705us     258.000us        85.43%     263.000us       2.039us           129
                                       aten::reciprocal         0.13%     193.000us         7.93%      12.100ms      12.100ms       1.000us         0.33%       1.000us       1.000us             1
                                              aten::cat         0.67%       1.017ms         4.67%       7.129ms       7.129ms       6.000us         1.99%       6.000us       6.000us             1
                                            aten::zeros         0.05%      79.000us         4.46%       6.800ms       3.400ms       0.000us         0.00%       2.000us       1.000us             2
                                            aten::zero_         0.05%      79.000us         4.28%       6.537ms       3.268ms       0.000us         0.00%       2.000us       1.000us             2
                                            aten::fill_         0.09%     131.000us         4.23%       6.458ms       3.229ms       2.000us         0.66%       2.000us       1.000us             2
                                    aten::_foreach_mul_         1.56%       2.377ms         3.86%       5.896ms       2.948ms      10.000us         3.31%      10.000us       5.000us             2
                                            aten::empty         3.55%       5.414ms         3.55%       5.414ms      20.984us       0.000us         0.00%       0.000us       0.000us           258
                                    aten::empty_strided         2.18%       3.323ms         2.18%       3.323ms      25.760us       0.000us         0.00%       0.000us       0.000us           129
                                           aten::detach         0.85%       1.302ms         2.10%       3.199ms      12.496us       0.000us         0.00%       0.000us       0.000us           256
                             cudaDeviceEnablePeerAccess         2.01%       3.069ms         2.01%       3.069ms       1.534ms       0.000us         0.00%       0.000us       0.000us             2
                                        aten::unsqueeze         1.24%       1.899ms         1.81%       2.769ms      10.816us       0.000us         0.00%       0.000us       0.000us           256
                                                 detach         1.24%       1.897ms         1.24%       1.897ms       7.410us       0.000us         0.00%       0.000us       0.000us           256
                                        cudaMemcpyAsync         1.01%       1.539ms         1.01%       1.539ms      11.930us       0.000us         0.00%       0.000us       0.000us           129
                                       aten::as_strided         0.58%     881.000us         0.58%     881.000us       3.428us       0.000us         0.00%       0.000us       0.000us           257
                                    cudaStreamWaitEvent         0.35%     540.000us         0.35%     540.000us       2.093us       0.000us         0.00%       0.000us       0.000us           258
                                        cudaEventRecord         0.18%     278.000us         0.18%     278.000us       1.078us       5.000us         1.66%       5.000us       0.019us           258
                                              aten::mul         0.08%     125.000us         0.09%     138.000us     138.000us       1.000us         0.33%       1.000us       1.000us             1
                                  cudaDeviceSynchronize         0.01%      13.000us         0.01%      13.000us       6.500us       0.000us         0.00%       0.000us       0.000us             2
                                cudaDeviceCanAccessPeer         0.00%       5.000us         0.00%       5.000us       2.500us       0.000us         0.00%       0.000us       0.000us             2
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         0.66%       2.000us       1.000us             2
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      13.000us         4.30%      13.000us       3.250us             4
void at::native::lpnorm_cleanup<float, (at::native::...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         1.99%       6.000us       3.000us             2
                         Memcpy PtoP (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us     258.000us        85.43%     258.000us       2.000us           129
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         1.99%       6.000us       3.000us             2
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us       3.000us         0.99%       3.000us       3.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us         3.31%      10.000us       2.500us             4
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 152.613ms
Self CUDA time total: 302.000us
```

Compared to main:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5a0a9644)]$ python playground2.py
STAGE:2024-02-26 13:09:56 285045:285045 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:09:57 285045:285045 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:09:57 285045:285045 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                       cudaLaunchKernel        61.42%     113.375ms        61.42%     113.375ms     424.625us      45.000us         5.66%      45.000us       0.169us           267
                               aten::linalg_vector_norm        14.04%      25.909ms        37.67%      69.534ms     271.617us     514.000us        64.65%     559.000us       2.184us           256
                                               aten::to         0.78%       1.433ms        12.87%      23.751ms      91.703us       0.000us         0.00%     278.000us       1.073us           259
                                         aten::_to_copy         2.02%       3.730ms        12.09%      22.318ms     173.008us       0.000us         0.00%     278.000us       2.155us           129
                                            aten::clamp         0.09%     174.000us        11.43%      21.103ms      21.103ms       1.000us         0.13%       1.000us       1.000us             1
                                              aten::add         0.11%     205.000us         9.08%      16.768ms      16.768ms       1.000us         0.13%       1.000us       1.000us             1
                                            aten::copy_         4.94%       9.112ms         8.15%      15.043ms     116.612us     258.000us        32.45%     278.000us       2.155us           129
                                            aten::stack         2.76%       5.091ms         7.97%      14.719ms      14.719ms       0.000us         0.00%       6.000us       6.000us             1
                                       aten::reciprocal         0.11%     194.000us         7.01%      12.933ms      12.933ms       1.000us         0.13%       1.000us       1.000us             1
                                              aten::max         0.09%     165.000us         6.43%      11.868ms      11.868ms       3.000us         0.38%       3.000us       3.000us             1
                                           aten::detach         1.58%       2.911ms         4.12%       7.596ms      14.836us       0.000us         0.00%       0.000us       0.000us           512
                                              aten::cat         0.56%       1.042ms         3.73%       6.882ms       6.882ms       6.000us         0.75%       6.000us       6.000us             1
                                    aten::_foreach_mul_         1.36%       2.503ms         3.33%       6.145ms       3.072ms      10.000us         1.26%      10.000us       5.000us             2
                                                 detach         2.54%       4.685ms         2.54%       4.685ms       9.150us       0.000us         0.00%       0.000us       0.000us           512
                                    aten::empty_strided         1.92%       3.545ms         1.92%       3.545ms      27.481us       0.000us         0.00%       0.000us       0.000us           129
                             cudaDeviceEnablePeerAccess         1.64%       3.022ms         1.64%       3.022ms       1.511ms       0.000us         0.00%       0.000us       0.000us             2
                                        aten::unsqueeze         1.03%       1.892ms         1.49%       2.746ms      10.727us       0.000us         0.00%       0.000us       0.000us           256
                                       aten::as_strided         1.35%       2.494ms         1.35%       2.494ms       4.862us       0.000us         0.00%       0.000us       0.000us           513
                                        cudaMemcpyAsync         1.01%       1.868ms         1.01%       1.868ms      14.481us       4.000us         0.50%       4.000us       0.031us           129
                                    cudaStreamWaitEvent         0.41%     760.000us         0.41%     760.000us       2.946us       8.000us         1.01%       8.000us       0.031us           258
                                        cudaEventRecord         0.15%     276.000us         0.15%     276.000us       1.070us       8.000us         1.01%       8.000us       0.031us           258
                                              aten::mul         0.08%     139.000us         0.08%     153.000us     153.000us       1.000us         0.13%       1.000us       1.000us             1
                                            aten::empty         0.02%      35.000us         0.02%      35.000us      35.000us       0.000us         0.00%       0.000us       0.000us             1
                                  cudaDeviceSynchronize         0.01%      14.000us         0.01%      14.000us       7.000us       0.000us         0.00%       0.000us       0.000us             2
                                cudaDeviceCanAccessPeer         0.00%       5.000us         0.00%       5.000us       2.500us       0.000us         0.00%       0.000us       0.000us             2
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us     514.000us        64.65%     514.000us       2.008us           256
                         Memcpy PtoP (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us     258.000us        32.45%     258.000us       2.000us           129
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         0.75%       6.000us       3.000us             2
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us       3.000us         0.38%       3.000us       3.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.13%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.13%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.13%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.13%       1.000us       1.000us             1
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us         1.26%      10.000us       2.500us             4
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 184.579ms
Self CUDA time total: 795.000us
```

Profiling script:
```
import torch
from math import inf
from torch.nn.utils import clip_grad_norm_

params = [torch.rand(32, 16, device="cuda:3")*5 for _ in range(128)] + [torch.rand(32, 16, device="cuda:4")*-7 for _ in range(128)]
for p in params:
    p.grad = torch.rand_like(p)

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as p:
    total_norm = clip_grad_norm_(params, 10.0, norm_type=inf)
    torch.cuda.synchronize()

print(p.key_averages().table(sort_by="cpu_time_total"))
print(total_norm)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120623
Approved by: https://github.com/Skylion007, https://github.com/mikaylagawarecki
2024-02-27 01:27:05 +00:00
ad4472833c define public API for torch.nn.utils (#111026)
Adds the modules imported here and the following functions to `__all__`:
* [clip_grad_norm_](https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html)
* [clip_grad_value_](https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html)
* [remove_weight_norm](https://pytorch.org/docs/stable/generated/torch.nn.utils.remove_weight_norm.html)
* [parameters_to_vector](https://pytorch.org/docs/stable/generated/torch.nn.utils.parameters_to_vector.html)
* [vector_to_parameters](https://pytorch.org/docs/stable/generated/torch.nn.utils.vector_to_parameters.html)
* [remove_spectral_norm](https://pytorch.org/docs/stable/generated/torch.nn.utils.remove_spectral_norm.html)
* [skip_init](https://pytorch.org/docs/stable/generated/torch.nn.utils.skip_init.html)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111026
Approved by: https://github.com/mikaylagawarecki
2023-10-12 23:05:23 +00:00
6d2887cc06 Reland "Move tensor grouping to ATen" (#103912)
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verify by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If the above condition is not met, instead of printing a static_assert output, VC++ raises very cryptic compilation errors; see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.

Also, using `std::hash` for string should result in a faster hash function.

(cherry picked from commit 74b7a6c75e698378882d30958908073407f97fb3)

### <samp>🤖 Generated by Copilot at 5914771</samp>

This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
2023-06-21 09:26:33 +00:00
0cb5bc3b04 Revert "Move tensor grouping to ATen (#100007)"
This reverts commit 74b7a6c75e698378882d30958908073407f97fb3.

Reverted https://github.com/pytorch/pytorch/pull/100007 on behalf of https://github.com/izaitsevfb due to Breaks internal builds, see D46629727 ([comment](https://github.com/pytorch/pytorch/pull/100007#issuecomment-1587861598))
2023-06-12 18:30:33 +00:00
74b7a6c75e Move tensor grouping to ATen (#100007)
rel: #94344
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100007
Approved by: https://github.com/janeyx99
2023-06-09 15:44:46 +00:00
704283d61f Improve clip_grad_norm to use torch.linalg.vector_norm (#102429)
Done in this PR:
 - Use `torch.linalg.vector_norm` instead of `torch.norm`
 - Reduce the bandwidth cost of `clip_grad_norm_` when used with `inf`, i.e. no need to fetch the tensor returned by `abs` (see the sketch below)

What I'm slightly unsure:
 - I don't know if the `torch._foreach` API supports `inf`
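
A sketch of the inf-norm point above (assuming the old path materialized `abs()` first):

```python
import torch
from math import inf

g = torch.randn(1_000_000)

# Single fused reduction; no intermediate abs() tensor to write out and read back.
new_norm = torch.linalg.vector_norm(g, ord=inf)

# Older pattern for the inf case.
old_norm = g.abs().max()
assert torch.allclose(new_norm, old_norm)
```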

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102429
Approved by: https://github.com/lezcano
2023-05-30 18:35:18 +00:00
b005ec62b9 [BE] Remove dependency on six and future (#94709)
Remove the Python 2 and 3 compatibility libraries [six](https://pypi.org/project/six) and [future](https://pypi.org/project/future), as well as `torch._six`. We only support Python 3.8+ now. It's time to retire them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94709
Approved by: https://github.com/malfet, https://github.com/Skylion007
2023-02-14 09:14:14 +00:00
8c9f745af1 [foreach] guard default support on native tensors only (#92923)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92923
Approved by: https://github.com/ngimel, https://github.com/crcrpar
2023-01-26 04:52:58 +00:00
e4d83d54a6 Foreach gradient clipping (#91846)
Faster gradient clipping using the foreach functions

```
[------------------------ (tensors, scalar) -------------------------]
                                   |  without foreach  |  with foreach |    apex
1 threads: ----------------------------------------------------------------------
      10 tensors of size 4         |         120.5     |       61.1    |     50.3
      100 tensors of size 4        |         946.2     |      239.5    |    136.3
      1000 tensors of size 4       |        9808.5     |     2151.1    |   1006.9
      10000 tensors of size 4      |       96871.2     |    22637.4    |  10119.1
      10 tensors of size 16        |         121.0     |       64.1    |     52.5
      100 tensors of size 16       |         993.4     |      252.6    |    136.7
      1000 tensors of size 16      |        9427.7     |     2151.2    |   1049.5
      10000 tensors of size 16     |       97437.1     |    22203.1    |  10340.0
      10 tensors of size 256       |         118.9     |       62.3    |     51.5
      100 tensors of size 256      |         955.2     |      243.1    |    134.2
      1000 tensors of size 256     |        9374.9     |     2140.7    |   1009.6
      10000 tensors of size 256    |       95302.5     |    21849.4    |  10215.5
      10 tensors of size 65536     |         118.5     |       62.4    |     51.1
      100 tensors of size 65536    |        1740.7     |      243.3    |    225.3
      1000 tensors of size 65536   |       17364.1     |     2228.7    |   2004.5
      10000 tensors of size 65536  |      177510.1     |    25410.4    |  20678.2
```
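
A usage sketch of the fast path benchmarked above (the `foreach` flag is assumed to select the fused implementation):

```python
import torch
from torch import nn

model = nn.Linear(64, 64)
model(torch.randn(8, 64)).sum().backward()

# foreach=True batches the per-parameter norm/scale ops into fused kernels.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
```
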
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91846
Approved by: https://github.com/janeyx99
2023-01-20 21:43:29 +00:00
94a6d72032 Update doc of clip grad (#91312)
Replaces #85772, which has a broken internal state.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91312
Approved by: https://github.com/soulitzer
2022-12-22 22:34:32 +00:00
9db3c517de Add __all__ for torch.nn.modules, torch.distributed.elastic, torch.nn.utils submodules (#80240)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80240
Approved by: https://github.com/rohan-varma
2022-06-27 17:11:12 +00:00
7843a5e882 Move Tensor.grad back into C++
`Tensor.grad` was moved to Python in #30531 to add a warning. However,
that warning has since been lowered into C++ so this wrapper is no
longer necessary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76675

Approved by: https://github.com/albanD
2022-06-10 13:44:45 +00:00
da764f9224 Update clip_grad_norm_ documentation
It actually returns the gradient norm, not the parameter norms.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76230
Approved by: https://github.com/jbschlosser
2022-04-22 14:07:34 +00:00
7f7966a888 [Docs] Fix the syntax of documentation (#69958)
Summary:
Fixes the syntax of documentation in the file torch/nn/utils/clip_grad.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69958

Reviewed By: mruberry

Differential Revision: D33160612

Pulled By: albanD

fbshipit-source-id: 2dc199fee345bb4c75632900bc6f73a1ab8192a6
2021-12-16 10:38:39 -08:00
e858f6eed9 torch.nn.utils.clip_grad_norm_: remove device syncs (#61042)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/60691

### Changes

Per the discussion in the above issue, this PR makes 2 changes:
1. When `error_if_nonfinite=False`, the NaN/Inf checks are truly skipped, and no device synchronization occurs.
    - Additionally, when performing the checks, the 2 results are combined with `torch.logical_or` to incur only a single sync (instead of 2 in the happy/finite path).
2. The `clip_coef` conditional is removed in favor of a call to `clamp(..., max=1.0)` and an unconditional multiplication (see the sketch below).
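
A sketch of change 2 (illustrative values, not the real code):

```python
import torch

total_norm = torch.tensor(5.0)  # stand-in for the computed gradient norm
max_norm, eps = 1.0, 1e-6
grad = torch.randn(4)

# Old: `if clip_coef < 1.0:` forces a device-to-host sync to read the tensor.
# New: clamp and multiply unconditionally; no host-side branch, no sync.
clip_coef = (max_norm / (total_norm + eps)).clamp(max=1.0)
grad.mul_(clip_coef)
```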

### Testing

- The existing unit tests for `clip_grad_norm_` pass.
- I have manually profiled the example program from https://github.com/pytorch/pytorch/issues/60691, and verified that:
    - No synchronizations occur when using `error_if_nonfinite=False`.
    - A single synchronization occurs when using `error_if_nonfinite=True`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61042

Reviewed By: mrshenli

Differential Revision: D29764096

Pulled By: jbschlosser

fbshipit-source-id: db594b24608d16374b91bcbb9469046dfeeb152d
2021-07-22 08:53:40 -07:00
38a08a49ea Flip clip_grad_norm default for error_if_nonfinite to false (#55169)
Summary:
The non-backwards-compatible change introduced in https://github.com/pytorch/pytorch/pull/53843 is tripping up a lot of code. Better to set it to False initially and then potentially flip it to True in a later version, to give people time to adapt.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55169

Reviewed By: mruberry

Differential Revision: D27511150

Pulled By: jbschlosser

fbshipit-source-id: 1ac018557c0900b31995c29f04aea060a27bc525
2021-04-02 12:25:32 -07:00
3ddc6174da Raise error in clip_grad_norm_ if norm is non-finite (#53843)
Summary:
**BC-breaking note**: This change throws errors for cases that used to silently pass. The old behavior can be obtained by setting `error_if_nonfinite=False`
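
A small example of the new behavior (a sketch; the default for `error_if_nonfinite` was later flipped to False in the commit above):

```python
import torch
from torch import nn

p = nn.Parameter(torch.ones(2))
p.grad = torch.tensor([float("nan"), 1.0])

try:
    nn.utils.clip_grad_norm_([p], max_norm=1.0, error_if_nonfinite=True)
except RuntimeError as e:
    print("non-finite total norm:", e)

# With error_if_nonfinite=False, the call silently proceeds instead of raising.
```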

Fixes https://github.com/pytorch/pytorch/issues/46849

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53843

Reviewed By: malfet

Differential Revision: D27291838

Pulled By: jbschlosser

fbshipit-source-id: 216d191b26e1b5919a44a3af5cde6f35baf825c4
2021-03-29 08:41:21 -07:00
e6779d4357 [*.py] Rename "Arguments:" to "Args:" (#49736)
Summary:
I've written custom parsers and emitters for everything from docstrings to classes and functions. However, I recently came across an issue when I was parsing/generating from the TensorFlow codebase: inconsistent use of `Args:` and `Arguments:` in its docstrings.

```sh
(pytorch#c348fae)$ for name in 'Args:' 'Arguments:'; do
    printf '%-10s %04d\n' "$name" "$(rg -IFtpy --count-matches "$name" | paste -s -d+ -- | bc)"; done
Args:      1095
Arguments: 0336
```

It is easy enough to extend my parsers to support both variants; however, it looks like `Arguments:` is wrong anyway, as per:

  - https://google.github.io/styleguide/pyguide.html#doc-function-args @ [`ddccc0f`](https://github.com/google/styleguide/blob/ddccc0f/pyguide.md)

  - https://chromium.googlesource.com/chromiumos/docs/+/master/styleguide/python.md#describing-arguments-in-docstrings @ [`9fc0fc0`](https://chromium.googlesource.com/chromiumos/docs/+/9fc0fc0/styleguide/python.md)

  - https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html @ [`c0ae8e3`](https://github.com/sphinx-contrib/napoleon/blob/c0ae8e3/docs/source/example_google.rst)

Therefore, only `Args:` is valid. This PR replaces `Arguments:` with `Args:` throughout the codebase.

PS: For related PRs, see tensorflow/tensorflow/pull/45420

PPS: The trackbacks automatically appearing below are sending the same changes to other repositories in the [PyTorch](https://github.com/pytorch) organisation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49736

Reviewed By: albanD

Differential Revision: D25710534

Pulled By: soumith

fbshipit-source-id: 61e8ff01abb433e9f78185c2d1d0cbd7c22c1619
2020-12-28 09:34:47 -08:00
1c6ace87d1 Embed torch.nn typing annotations (#43044)
Summary:
Delete several .pyi files and embed the annotations from those files in the respective .py files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43044

Reviewed By: ezyang

Differential Revision: D23123234

Pulled By: malfet

fbshipit-source-id: 4ba361cc84402352090523924b0035e100ba48b1
2020-08-14 13:24:58 -07:00
54d4b419db fix clip_grad_norm to work with parameters on the different devices (#38615)
Summary:
Per title.
We move all the individual gradient norms to a single device before stacking (a no-op if all the gradients are already on a single device). `clip_coef` is copied to the device of each gradient, which may be suboptimal as there could be multiple copies, but no worse than when we were synchronizing for each parameter. In the simple case of all gradients on a single device, there should be no synchronization.
Also, we no longer error out if the parameter list is empty or none of the parameters have gradients; instead we return a total_norm of 0.
Fixes https://github.com/pytorch/pytorch/issues/38605
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38615

Reviewed By: ailzhang

Differential Revision: D21634588

Pulled By: ngimel

fbshipit-source-id: ea4d08d4f3445438260052820c7ca285231a156b
2020-05-19 10:33:40 -07:00
e74a215ade Changed clip_grad_norm_ total_norm calculation (#32020)
Summary:
Redefines the computation of the total_norm to increase performance as shown in https://github.com/pytorch/pytorch/issues/31474.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32020

Differential Revision: D19353309

Pulled By: ngimel

fbshipit-source-id: bf7530dcd39f56614a211b5f21445864d4f2e875
2020-01-13 08:13:46 -08:00
19a6de328f Correct docstring of vision/init functions
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17351

Differential Revision: D14276355

Pulled By: soumith

fbshipit-source-id: 9b572b6a04eeb1e44cd93961edac76ed10f7b24e
2019-03-01 11:40:23 -08:00
27455e9c78 Use _six for inf and nan (#9500)
Summary:
Things like `float('inf')` are actually quite expensive.
```py
In [1]: import math

In [2]: %timeit -n 200 math.inf
49.3 ns ± 1.42 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)

In [3]: %timeit -n 200 float('inf')
194 ns ± 39.1 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9500

Reviewed By: soumith

Differential Revision: D8876229

Pulled By: SsnL

fbshipit-source-id: 78602b76bb53d5588910b58270930c0bd413d2d7
2018-07-18 10:40:29 -07:00
89c2b50a15 Grad clip for parameters on different devices (#9302)
Summary:
I'm trying to write a multi-GPU network by pipelining some layers onto different GPUs. However, the current gradient clipping requires all the parameters to be on the same device.

The CUDA launch overhead is reduced since the scalar calculation is performed on the CPU, but this introduces extra data transfers.

No performance regression is observed by running the following snippet:
```python
import time

import torch

module = torch.nn.Sequential(
    torch.nn.LSTM(1024, 1024),
    torch.nn.LSTM(256, 256),
    torch.nn.Linear(100, 10000),
).cuda()

torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
    torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
torch.cuda.synchronize()
time_elapse = time.time() - start
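# total seconds over 1000 iterations equals milliseconds per single clip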
print('{} ms per clip'.format(time_elapse))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9302

Differential Revision: D8781551

Pulled By: soumith

fbshipit-source-id: 9d76d01fe0531927f770a16b9523872a7e08e927
2018-07-10 07:56:55 -07:00
1078491502 Change is_tensor to isinstance(*, torch.Tensor) (#7814)
Thanks!
2018-05-24 15:08:16 -04:00
2222fc7666 Add support for accepting Tensor as input in clip_grad_* functions. (#7769) 2018-05-23 12:12:03 +02:00
8325206c6f A clip grad fix for sparse tensors. (#7257) 2018-05-04 00:35:32 +02:00
7fcaf3b49e Update torch.nn.init and torch.nn.utils.clip_grad (#6173)
Introducing two updates.

1. Add param to He initialization scheme in torch.nn.init
Problem solved:
The function calculate_gain can take an argument to specify the type of non-linearity used. However, it wasn't possible to pass this argument directly to the He / Kaiming weight initialization function.

2. Add util to clip gradient value in torch.nn.utils.clip_grad
Problem solved:
DL libraries typically give users easy access to functions for clipping gradients both by norm and by a fixed value. However, `torch.nn.utils.clip_grad` only had a function to clip by the gradient norm.
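
With the new util, both clipping styles are available (a minimal sketch using the current underscore-suffixed names):

```python
import torch
from torch import nn

model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)  # clip each grad element
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)     # clip by total norm
```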

* add param to He initialization scheme in torch.nn.init

* add util to clip gradient value in torch/nn/utils/clip_grad.py

* update doc in torch.nn.utils.clip_grad

* update and add test for torch.nn.utils.clip_grad

* update function signature in torch.nn.utils.clip_grad to match suffix_ convention

* ensure backward compatibility in torch.nn.utils.clip_grad

* remove DeprecationWarning in torch.nn.utils.clip_grad

* extend test and implementation of torch.nn.utils.clip_grad

* update test and implementation torch.nn.utils.clip_grad
2018-04-17 11:32:32 -04:00