pytorch/torch/utils/_foreach_utils.py
haozhe.zhu 3c964ad1ca add fused_sgd_kernel support for CPU device (#123629)
Add fused_sgd_kernel support for the CPU device.

## Bench result:
Hardware: ICX, 32 cores per socket
Test Scripts:
https://gist.github.com/zhuhaozhe/688763e17e93e4c5e12f25f676ec90d9
https://gist.github.com/zhuhaozhe/ad9938694bc7fae8b66d376f4dffc6c9
```
Tensor Size: 262144, Num Tensor 4, Num Threads: 1
_single_tensor_sgd time: 0.2301 seconds
_fused_sgd time: 0.0925 seconds
Tensor Size: 4194304, Num Tensor 32, Num Threads: 32
_single_tensor_sgd time: 2.6195 seconds
_fused_sgd time: 1.7543 seconds
```
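The speedups implied by these timings can be worked out directly. The snippet below is a small sketch that only divides the numbers reported above; the dictionary keys and the `speedup` helper are illustrative names, not part of the benchmark scripts.

```python
# Timings copied from the benchmark output above, in seconds:
# (single-tensor SGD, fused SGD).
timings = {
    "size=262144, tensors=4, threads=1": (0.2301, 0.0925),
    "size=4194304, tensors=32, threads=32": (2.6195, 1.7543),
}


def speedup(baseline_s: float, fused_s: float) -> float:
    """Return how many times faster the fused run is than the baseline."""
    return baseline_s / fused_s


speedups = {name: speedup(*pair) for name, pair in timings.items()}

for name, value in speedups.items():
    print(f"{name}: {value:.2f}x")
```

On these two data points the fused kernel is roughly 2.5x faster in the small single-threaded case and roughly 1.5x faster in the large 32-thread case.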
## Test Plan:
```
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_optim.py -k test_can_load_older_state_dict
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
python test_torch.py -k test_grad_scaling_autocast_fused
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
```
There are already some PRs under issue https://github.com/pytorch/pytorch/issues/123451 to unify the UTs, so I did not modify the UTs in this PR.

Co-authored-by: Jane Xu <janeyx@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123629
Approved by: https://github.com/jgong5, https://github.com/janeyx99
2024-04-23 08:28:19 +00:00


from typing import List, Dict, Tuple, Optional

import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad
from typing_extensions import TypeAlias


def _get_foreach_kernels_supported_devices() -> List[str]:
    r"""Return the device type list that supports foreach kernels."""
    return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]


def _get_fused_kernels_supported_devices() -> List[str]:
    r"""Return the device type list that supports fused kernels in optimizer."""
    return ["cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()]


TensorListList: TypeAlias = List[List[Optional[Tensor]]]
Indices: TypeAlias = List[int]

# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
#   - tensorlists CAN be None
#   - all tensors in the first specified list cannot be None
#   - given an index i, all specified tensorlist[i]s match in dtype and device
# with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry.
#   It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
#   Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
#   original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
#   may be necessary. Check out torch/optim/sgd.py for an example.
@no_grad()
def _group_tensors_by_device_and_dtype(
    tensorlistlist: TensorListList,
    with_indices: bool = False,
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
    return {
        (device, getattr(torch, str_dtype)): value
        for (device, str_dtype), value in
        torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices).items()
    }


def _device_has_foreach_support(device: torch.device) -> bool:
    return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()


def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
    return _device_has_foreach_support(device) and all(t is None or type(t) == torch.Tensor for t in tensors)
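
The grouping contract described in the comments on `_group_tensors_by_device_and_dtype` can be illustrated without PyTorch. The block below is a minimal pure-Python sketch of that contract, not the C++ implementation behind `torch._C._group_tensors_by_device_and_dtype`. `FakeTensor` and `group_by_device_and_dtype` are hypothetical names introduced only for this illustration, devices and dtypes are plain strings, and it assumes every list is present (only individual entries in later lists may be `None`).

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor; only device and dtype matter here."""
    name: str
    device: str
    dtype: str


FakeListList = List[List[Optional[FakeTensor]]]


def group_by_device_and_dtype(
    tensorlistlist: FakeListList,
    with_indices: bool = False,
) -> Dict[Tuple[str, str], Tuple[FakeListList, List[int]]]:
    """Illustrative analogue of the grouping contract (not the real kernel)."""
    groups: Dict[Tuple[str, str], Tuple[FakeListList, List[int]]] = {}
    first = tensorlistlist[0]  # assumed to contain no None entries
    for i, ref in enumerate(first):
        # The first list decides the group; other lists are assumed to match.
        key = (ref.device, ref.dtype)
        if key not in groups:
            groups[key] = ([[] for _ in tensorlistlist], [])
        lists, indices = groups[key]
        for out, src in zip(lists, tensorlistlist):
            out.append(src[i])  # same object, so mutations are shared
        if with_indices:
            indices.append(i)  # remember the original position
    return groups


params = [
    FakeTensor("p0", "cpu", "float32"),
    FakeTensor("p1", "cuda:0", "float16"),
    FakeTensor("p2", "cpu", "float32"),
]
grads = [FakeTensor("g0", "cpu", "float32"), None, FakeTensor("g2", "cpu", "float32")]

grouped = group_by_device_and_dtype([params, grads], with_indices=True)
for key, (lists, indices) in grouped.items():
    print(key, [[t.name if t else None for t in lst] for lst in lists], indices)
```

The recorded indices are what lets a caller write a replaced `None` or literal back to its original slot, since only the tensor objects themselves are shared between the input lists and the grouped lists.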