mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This PR adds fused Adam and AdamW implementations. Benchmark on Macbook Pro with M1 Max chip and 64GB unified memory: **Fast math enabled:** ``` [---------------------------------------------- Fused Adam ----------------------------------------------] | Fused: True | Fused: False 1 threads: ----------------------------------------------------------------------------------------------- amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 89 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 90 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 83 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 12 | 94 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 88 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 12 | 90 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 100 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 27 | 100 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 23 | 100 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 27 | 100 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 23 | 98 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 480 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 72 | 450 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 82 | 450 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 73 | 420 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 91 | 500 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 83 | 400 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 94 | 500 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 78 | 400 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 170 | 500 amsgrad: False, adamWflag: True, 
numel: 1048576, num_tensors: 500 | 140 | 600 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 170 | 600 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 140 | 500 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 250 | 890 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 220 | 850 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 250 | 830 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 220 | 770 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 270 | 870 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 230 | 840 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 270 | 810 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 240 | 800 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 400 | 1000 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 360 | 2000 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 430 | 2000 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 360 | 1300 Times are in milliseconds (ms). 
``` **Fast math disabled:** ``` [---------------------------------------------- Fused Adam ----------------------------------------------] | Fused: True | Fused: False 1 threads: ----------------------------------------------------------------------------------------------- amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 84 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 84 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 79 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 93 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 10 | 90 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 91 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 81 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 34 | 100 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 31 | 100 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 34 | 95 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 31 | 100 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 94 | 500 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 430 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 92 | 430 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 81 | 390 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 98 | 500 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 88 | 430 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 100 | 500 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 88 | 400 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 210 | 500 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 190 | 610 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 210 | 510 
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 190 | 500 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 300 | 900 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 260 | 850 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 295 | 900 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 260 | 800 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 320 | 910 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 280 | 900 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 320 | 900 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 300 | 900 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 500 | 2000 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 480 | 2000 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 540 | 1500 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 480 | 1200 Times are in milliseconds (ms). 
``` ```python def profile_fused_adam(): from torch.optim import adam, adamw import torch.utils.benchmark as benchmark import itertools def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused): fn( params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach=False, capturable=False, fused=fused, amsgrad=amsgrad, beta1=0.9, beta2=0.99, lr=1e-3, weight_decay=.0, eps=1e-5, maximize=False, grad_scale=None, found_inf=None, ) torch.mps.synchronize() device = "mps" results = [] for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]): print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}") params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)] max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else [] state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)] if adamWflag: fn = adamw.adamw else: fn = adam.adam for fused in [True, False]: t = benchmark.Timer( stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)', label='Fused Adam', sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}", globals=locals(), description= f"Fused: {fused}", ).blocked_autorange(min_run_time=5) results.append(t) compare = benchmark.Compare(results) compare.trim_significant_figures() compare.colorize(rowwise=True) compare.print() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242 Approved by: https://github.com/kulinseth, https://github.com/janeyx99
45 lines
2.3 KiB
Python
45 lines
2.3 KiB
Python
from typing import List, Dict, Tuple, Optional
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from torch.autograd.grad_mode import no_grad
|
|
from typing_extensions import TypeAlias
|
|
|
|
def _get_foreach_kernels_supported_devices() -> List[str]:
|
|
r"""Return the device type list that supports foreach kernels."""
|
|
return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
|
|
|
|
def _get_fused_kernels_supported_devices() -> List[str]:
|
|
r"""Return the device type list that supports fused kernels in optimizer."""
|
|
return ["mps", "cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()]
|
|
|
|
TensorListList: TypeAlias = List[List[Optional[Tensor]]]
|
|
Indices: TypeAlias = List[int]
|
|
_foreach_supported_types = [torch.Tensor]
|
|
|
|
|
|
# This util function splits tensors into groups by device and dtype, which is useful before sending
|
|
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
|
|
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
|
|
# - tensorlists CAN be None
|
|
# - all tensors in the first specified list cannot be None
|
|
# - given an index i, all specified tensorlist[i]s match in dtype and device
|
|
# with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry.
|
|
# It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
|
|
# Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
|
|
# original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
|
|
# may be necessary. Check out torch/optim/sgd.py for an example.
|
|
@no_grad()
def _group_tensors_by_device_and_dtype(
    tensorlistlist: TensorListList,
    with_indices: bool = False,
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
    """Group tensor lists by device and dtype via the native ``torch._C`` binding.

    The call runs with gradient tracking disabled by the ``no_grad`` decorator.

    Args:
        tensorlistlist: the tensor lists to split up; forwarded unchanged.
        with_indices: forwarded unchanged to the native implementation
            (defaults to ``False``).

    Returns:
        The result of ``torch._C._group_tensors_by_device_and_dtype``, annotated
        as a dict keyed by ``(device, dtype)`` whose values pair the grouped
        tensor lists with a list of indices.
    """
    # Both arguments are passed positionally, exactly as received.
    grouped = torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
    return grouped
|
|
|
|
def _device_has_foreach_support(device: torch.device) -> bool:
    """Report whether foreach kernels may be used for ``device``.

    Args:
        device: the device whose ``type`` string is checked.

    Returns:
        ``True`` only if ``device.type`` is ``"cpu"`` or one of the entries from
        ``_get_foreach_kernels_supported_devices()``, and
        ``torch.jit.is_scripting()`` is false; ``False`` otherwise.
    """
    # "cpu" is accepted here in addition to the foreach-kernel device list.
    eligible_types = _get_foreach_kernels_supported_devices() + ["cpu"]
    if device.type not in eligible_types:
        return False
    return not torch.jit.is_scripting()
|
|
|
|
|
|
def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
    """Report whether ``tensors`` on ``device`` qualify for foreach kernels.

    Args:
        tensors: the entries to inspect; ``None`` entries are accepted as-is.
        device: passed to ``_device_has_foreach_support`` for the device check.

    Returns:
        ``True`` when the device check passes and every non-``None`` entry has a
        type listed in ``_foreach_supported_types``; ``False`` otherwise.
    """
    if not _device_has_foreach_support(device):
        return False
    for candidate in tensors:
        if candidate is None:
            continue
        # Exact type match on purpose: subclasses are only accepted if their own
        # type appears in the list.
        if type(candidate) not in _foreach_supported_types:
            return False
    return True
|