Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
[DTensor] Turn on foreach implementation for clip_grad_norm_ for DTensor by default (#126423)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126423
Approved by: https://github.com/awgu
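The change generalizes the supported-types check in torch/utils/_foreach_utils.py from an exact type(t) == torch.Tensor comparison to membership in a _foreach_supported_types list, so tensor subclasses such as DTensor can opt into the foreach fast path. As a result, torch.nn.utils.clip_grad_norm_ (whose foreach argument defaults to None, meaning "use the foreach implementation when the gradients' type and device are supported") can take the fast path for DTensor gradients without the caller passing foreach=True. A minimal sketch of the user-facing call, shown with plain tensors since constructing a DTensor requires a distributed setup:

```python
import torch
from torch import nn

# clip_grad_norm_ defaults to foreach=None: it uses the multi-tensor
# (foreach) implementation whenever _has_foreach_support() reports that
# the gradients' type and device are supported.
model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(total_norm)  # 0-dim tensor holding the pre-clip total gradient norm
```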
@@ -15,6 +15,8 @@ def _get_fused_kernels_supported_devices() -> List[str]:
 
 TensorListList: TypeAlias = List[List[Optional[Tensor]]]
 Indices: TypeAlias = List[int]
+_foreach_supported_types = [torch.Tensor]
+
 
 # This util function splits tensors into groups by device and dtype, which is useful before sending
 # tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
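For reference, the grouping util described in the comment above is a private helper, but its behavior is straightforward to demonstrate. A sketch (the exact return shape of this private API may differ across releases):

```python
import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

tensors = [torch.ones(2), torch.ones(2, dtype=torch.float64), torch.zeros(2)]
grouped = _group_tensors_by_device_and_dtype([tensors], with_indices=True)
for (device, dtype), (tensorlists, indices) in grouped.items():
    # Every group is homogeneous in device and dtype, which is what the
    # _foreach_* kernels require of their inputs.
    print(device, dtype, len(tensorlists[0]), indices)
```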
@@ -44,4 +46,4 @@ def _device_has_foreach_support(device: torch.device) -> bool:
 
 
 def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
-    return _device_has_foreach_support(device) and all(t is None or type(t) == torch.Tensor for t in tensors)
+    return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors)
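The net effect of the second hunk: _has_foreach_support() now accepts any type registered in _foreach_supported_types rather than only exact torch.Tensor instances, which is presumably how DTensor opts in elsewhere in this PR. A minimal sketch using a hypothetical subclass as a stand-in for DTensor:

```python
import torch
from torch.utils._foreach_utils import (
    _foreach_supported_types,
    _has_foreach_support,
)

class Wrapped(torch.Tensor):  # hypothetical subclass standing in for DTensor
    pass

t = torch.ones(3).as_subclass(Wrapped)
print(_has_foreach_support([t], t.device))  # False: Wrapped is not registered
_foreach_supported_types.append(Wrapped)
print(_has_foreach_support([t], t.device))  # True: membership check now passes
```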