Foreach gradient clipping (#91846)

Faster gradient clipping using the foreach functions. The benchmark below compares the existing per-tensor implementation ("without foreach"), the new foreach-based implementation, and NVIDIA Apex, across varying tensor counts and sizes:

```
[------------------------ (tensors, scalar) -------------------------]
                                   |  without foreach  |  with foreach |    apex
1 threads: ----------------------------------------------------------------------
      10 tensors of size 4         |         120.5     |       61.1    |     50.3
      100 tensors of size 4        |         946.2     |      239.5    |    136.3
      1000 tensors of size 4       |        9808.5     |     2151.1    |   1006.9
      10000 tensors of size 4      |       96871.2     |    22637.4    |  10119.1
      10 tensors of size 16        |         121.0     |       64.1    |     52.5
      100 tensors of size 16       |         993.4     |      252.6    |    136.7
      1000 tensors of size 16      |        9427.7     |     2151.2    |   1049.5
      10000 tensors of size 16     |       97437.1     |    22203.1    |  10340.0
      10 tensors of size 256       |         118.9     |       62.3    |     51.5
      100 tensors of size 256      |         955.2     |      243.1    |    134.2
      1000 tensors of size 256     |        9374.9     |     2140.7    |   1009.6
      10000 tensors of size 256    |       95302.5     |    21849.4    |  10215.5
      10 tensors of size 65536     |         118.5     |       62.4    |     51.1
      100 tensors of size 65536    |        1740.7     |      243.3    |    225.3
      1000 tensors of size 65536   |       17364.1     |     2228.7    |   2004.5
      10000 tensors of size 65536  |      177510.1     |    25410.4    |  20678.2
```
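
For reference, a minimal usage sketch (the toy model and sizes are illustrative; it assumes a build that includes this change, where `torch.nn.utils.clip_grad_norm_` exposes a `foreach` flag):

```
import torch
from torch import nn

# Toy module, just to produce some gradients.
model = nn.Linear(512, 512)
model(torch.randn(16, 512)).sum().backward()

# foreach=True forces the fused foreach path; the default (None) lets
# PyTorch pick it when the gradients' device/dtype allow it.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
print(f"grad norm before clipping: {total_norm.item():.3f}")
```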
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91846
Approved by: https://github.com/janeyx99
Author: milesial
Date: 2023-01-20 21:43:29 +00:00
Committed by: PyTorch MergeBot
Parent: 44b7a0b7ef
Commit: e4d83d54a6
5 changed files with 156 additions and 104 deletions

@@ -3,11 +3,11 @@ from typing import List, Dict, Tuple, Optional, Union
 import torch
 from torch import Tensor
+from torch.autograd.grad_mode import no_grad
 
-# _group_tensors_by_device_and_dtype is a util function that splits tensors into groups by device and dtype,
-# which is useful before sending tensors off to a foreach implementation, which requires tensors to be on
-# one device and dtype.
+# This util function splits tensors into groups by device and dtype, which is useful before sending
+# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
 # If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
 # - tensorlists CAN be None
 # - all tensors in the first specified list cannot be None
@@ -17,16 +17,16 @@ from torch import Tensor
 # Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
 # original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
 # may be necessary. Check out torch/optim/sgd.py for an example.
-@torch.no_grad()
+@no_grad()
 def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
                                        with_indices: Optional[bool] = False) -> \
-        Dict[Tuple[str, torch.dtype], List[List[Union[Tensor, int]]]]:
+        Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]]:
     assert all([not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist]), (
         "all specified tensorlists must match in length")
-    per_device_and_dtype_tensors: Dict[Tuple[str, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
+    per_device_and_dtype_tensors: Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
         lambda: [[] for _ in range(len(tensorlistlist) + (1 if with_indices else 0))])
     for i, t in enumerate(tensorlistlist[0]):
-        key = (str(t.device), t.dtype)
+        key = (t.device, t.dtype)
         for j in range(len(tensorlistlist)):
             # a tensorlist may be empty/None
             if tensorlistlist[j]:
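
For context, here is a minimal sketch (not the PR's exact code; `clip_grads_foreach` is an illustrative name) of how grouping by `(device, dtype)` combines with the fused ops: `torch._foreach_norm` computes all per-tensor norms in one launch per group, and `torch._foreach_mul_` scales the gradients in place:

```
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
from torch import Tensor


@torch.no_grad()
def clip_grads_foreach(grads: List[Tensor], max_norm: float) -> None:
    # Illustrative sketch of norm clipping built on the foreach kernels.
    if not grads:
        return

    # Group gradients by (device, dtype): the foreach ops require
    # homogeneous tensor lists.
    groups: Dict[Tuple[torch.device, torch.dtype], List[Tensor]] = defaultdict(list)
    for g in grads:
        groups[(g.device, g.dtype)].append(g)

    # One fused kernel launch per group computes every per-tensor 2-norm.
    per_tensor_norms = []
    for group in groups.values():
        per_tensor_norms += [n.float().cpu() for n in torch._foreach_norm(group)]
    total_norm = float(torch.linalg.vector_norm(torch.stack(per_tensor_norms)))

    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for group in groups.values():
            # In-place fused scaling of every gradient in the group.
            torch._foreach_mul_(group, clip_coef)


# Example: clip all gradients of a model to a total norm of 1.0.
model = torch.nn.Linear(8, 8)
model(torch.randn(4, 8)).sum().backward()
clip_grads_foreach([p.grad for p in model.parameters() if p.grad is not None], max_norm=1.0)
```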