Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Foreach gradient clipping (#91846)
Faster gradient clipping using the foreach functions

```
[------------------------ (tensors, scalar) -------------------------]
                                    |  without foreach  |  with foreach  |    apex
1 threads: ----------------------------------------------------------------------
      10 tensors of size 4          |       120.5       |       61.1     |    50.3
     100 tensors of size 4          |       946.2       |      239.5     |   136.3
    1000 tensors of size 4          |      9808.5       |     2151.1     |  1006.9
   10000 tensors of size 4          |     96871.2       |    22637.4     | 10119.1
      10 tensors of size 16         |       121.0       |       64.1     |    52.5
     100 tensors of size 16         |       993.4       |      252.6     |   136.7
    1000 tensors of size 16         |      9427.7       |     2151.2     |  1049.5
   10000 tensors of size 16         |     97437.1       |    22203.1     | 10340.0
      10 tensors of size 256        |       118.9       |       62.3     |    51.5
     100 tensors of size 256        |       955.2       |      243.1     |   134.2
    1000 tensors of size 256        |      9374.9       |     2140.7     |  1009.6
   10000 tensors of size 256        |     95302.5       |    21849.4     | 10215.5
      10 tensors of size 65536      |       118.5       |       62.4     |    51.1
     100 tensors of size 65536      |      1740.7       |      243.3     |   225.3
    1000 tensors of size 65536      |     17364.1       |     2228.7     |  2004.5
   10000 tensors of size 65536      |    177510.1       |    25410.4     | 20678.2
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91846
Approved by: https://github.com/janeyx99
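For orientation, here is a minimal usage sketch of the foreach-backed clipping path described above. It is not taken from the PR itself; it assumes the `foreach` keyword on `torch.nn.utils.clip_grad_norm_` (the flag this change appears to expose), and the model shapes are arbitrary.

```python
import torch
import torch.nn as nn

# Arbitrary toy model and backward pass, just to populate .grad on the parameters.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
loss = model(torch.randn(32, 256)).sum()
loss.backward()

# foreach=True requests the multi-tensor (torch._foreach_*) path; the default
# foreach=None lets PyTorch choose. This assumes the keyword added by this PR.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
print(f"pre-clip gradient norm: {total_norm:.4f}")
```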
Committed by: PyTorch MergeBot
Parent: 44b7a0b7ef
Commit: e4d83d54a6
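The benchmark table in the commit message appears to come from `torch.utils.benchmark`. Below is a hedged sketch of how a comparable foreach vs. non-foreach measurement could be set up for one row of that table; absolute numbers will differ, and the apex column is omitted here.

```python
import torch
import torch.utils.benchmark as benchmark

# One configuration from the table above: 1000 gradient tensors of size 256.
params = [torch.zeros(256, requires_grad=True) for _ in range(1000)]
for p in params:
    p.grad = torch.randn(256)

results = []
for use_foreach in (False, True):
    timer = benchmark.Timer(
        stmt="clip_grad_norm_(params, 1.0, foreach=use_foreach)",
        setup="from torch.nn.utils import clip_grad_norm_",
        globals={"params": params, "use_foreach": use_foreach},
        num_threads=1,
        label="(tensors, scalar)",
        sub_label="1000 tensors of size 256",
        description="with foreach" if use_foreach else "without foreach",
    )
    results.append(timer.blocked_autorange(min_run_time=1))

# Print a comparison table in the same style as the commit message.
benchmark.Compare(results).print()
```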
```diff
@@ -3,11 +3,11 @@ from typing import List, Dict, Tuple, Optional, Union
 
 import torch
 from torch import Tensor
+from torch.autograd.grad_mode import no_grad
 
 
-# _group_tensors_by_device_and_dtype is a util function that splits tensors into groups by device and dtype,
-# which is useful before sending tensors off to a foreach implementation, which requires tensors to be on
-# one device and dtype.
+# This util function splits tensors into groups by device and dtype, which is useful before sending
+# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
 # If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
 #   - tensorlists CAN be None
 #   - all tensors in the first specified list cannot be None
@@ -17,16 +17,16 @@ from torch import Tensor
 # Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
 # original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
 # may be necessary. Check out torch/optim/sgd.py for an example.
-@torch.no_grad()
+@no_grad()
 def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
                                        with_indices: Optional[bool] = False) -> \
-        Dict[Tuple[str, torch.dtype], List[List[Union[Tensor, int]]]]:
+        Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]]:
     assert all([not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist]), (
         "all specified tensorlists must match in length")
-    per_device_and_dtype_tensors: Dict[Tuple[str, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
+    per_device_and_dtype_tensors: Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
         lambda: [[] for _ in range(len(tensorlistlist) + (1 if with_indices else 0))])
     for i, t in enumerate(tensorlistlist[0]):
-        key = (str(t.device), t.dtype)
+        key = (t.device, t.dtype)
         for j in range(len(tensorlistlist)):
             # a tensorlist may be empty/None
             if tensorlistlist[j]:
```
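For readers who have not seen this helper before: it buckets tensors by `(device, dtype)` so that each bucket can be handed to a single `torch._foreach_*` call, which requires all operands to share one device and dtype; the diff above switches the key from `str(t.device)` to the `torch.device` object itself. A simplified, single-list re-implementation of the idea (for illustration only, not the code in the diff) might look like this:

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
from torch import Tensor

def group_by_device_and_dtype(tensors: List[Tensor]) -> Dict[Tuple[torch.device, torch.dtype], List[Tensor]]:
    """Bucket tensors so each bucket is safe to pass to one foreach kernel."""
    groups: Dict[Tuple[torch.device, torch.dtype], List[Tensor]] = defaultdict(list)
    for t in tensors:
        # torch.device objects hash and compare by value, so they can be used
        # directly as dict keys; no str() round-trip is needed.
        groups[(t.device, t.dtype)].append(t)
    return groups

# Example: scale a mixed bag of gradients with one in-place foreach call per bucket.
grads = [torch.randn(3), torch.randn(5, dtype=torch.float64), torch.randn(2)]
for (device, dtype), bucket in group_by_device_and_dtype(grads).items():
    torch._foreach_mul_(bucket, 0.5)
```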