[DTensor] Turn on foreach implementation for clip_grad_norm_ for DTensor by default (#126423)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126423
Approved by: https://github.com/awgu
This commit is contained in:
wz337
2024-05-17 06:57:49 +00:00
committed by PyTorch MergeBot
parent f9a7033194
commit 15ca562f86
3 changed files with 17 additions and 7 deletions
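For context, the change means torch.nn.utils.clip_grad_norm_ no longer needs an explicit foreach=True to take the multi-tensor path when gradients are DTensors; the default foreach=None now resolves to the foreach implementation for them. A minimal sketch of the call this affects (the mesh and parameter setup below is illustrative, assumes an initialized process group, and is not taken from this PR):

# Hedged sketch: with this commit, clip_grad_norm_'s default foreach=None
# path can use the _foreach_* ops on DTensor gradients instead of falling
# back to the per-tensor loop.
import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# A toy row-sharded parameter with a matching sharded gradient.
param = torch.nn.Parameter(distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)]))
param.grad = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])

# foreach=None (the default) now selects the foreach implementation for DTensor.
total_norm = torch.nn.utils.clip_grad_norm_([param], max_norm=1.0)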


@@ -15,6 +15,8 @@ def _get_fused_kernels_supported_devices() -> List[str]:
 TensorListList: TypeAlias = List[List[Optional[Tensor]]]
 Indices: TypeAlias = List[int]
+_foreach_supported_types = [torch.Tensor]
+
 # This util function splits tensors into groups by device and dtype, which is useful before sending
 # tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
@@ -44,4 +46,4 @@ def _device_has_foreach_support(device: torch.device) -> bool:
 def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
-    return _device_has_foreach_support(device) and all(t is None or type(t) == torch.Tensor for t in tensors)
+    return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors)
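The second hunk is what makes the check extensible: _has_foreach_support now consults the module-level _foreach_supported_types list instead of hard-coding torch.Tensor, so a tensor subclass can opt into the foreach path by registering itself. A hedged sketch of such a registration (the call site is illustrative; the actual registration for DTensor is presumably done in one of the other files touched by this commit):

# Illustrative only: how a tensor subclass could opt into the foreach
# eligibility check after this change.
import torch
from torch.distributed._tensor import DTensor
from torch.utils._foreach_utils import (
    _foreach_supported_types,
    _has_foreach_support,
)

# Register the subclass so the exact-type check against the list passes.
if DTensor not in _foreach_supported_types:
    _foreach_supported_types.append(DTensor)

# After registration, a list of DTensor gradients on a device with foreach
# kernel support (e.g. CUDA) passes the gate that clip_grad_norm_ checks
# before taking the foreach path:
#   _has_foreach_support(dtensor_grads, device=torch.device("cuda")) -> True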