Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 13:44:15 +08:00
Use scalarType instead of string in function _group_tensors_by_device_and_dtype (#127869)
Now that torch.dtype can pass through pybind11, _group_tensors_by_device_and_dtype is changed to use the scalar type directly, without converting between torch.dtype and string on the Python and C++ sides. @ezyang @bdhirsh Pull Request resolved: https://github.com/pytorch/pytorch/pull/127869 Approved by: https://github.com/ezyang
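A minimal sketch of the resulting behaviour for callers of this private helper (assuming a PyTorch build that includes this change; torch.utils._foreach_utils is an internal module, so details may differ across versions):

import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

# One tensor list with mixed dtypes; the helper groups entries by (device, dtype).
tensors = [[torch.zeros(2), torch.ones(3, dtype=torch.float64), torch.zeros(4)]]
grouped = _group_tensors_by_device_and_dtype(tensors, with_indices=True)

for (device, dtype), (tensorlists, indices) in grouped.items():
    # With this change the key holds a real torch.dtype (e.g. torch.float32),
    # not a string that callers must convert via getattr(torch, str_dtype).
    assert isinstance(dtype, torch.dtype)
    print(device, dtype, indices)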
committed by PyTorch MergeBot
parent 0ff60236ab
commit 3bcc3cddb5
@@ -34,12 +34,7 @@ def _group_tensors_by_device_and_dtype(
     tensorlistlist: TensorListList,
     with_indices: bool = False,
 ) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
-    return {
-        (device, getattr(torch, str_dtype)): value
-        for (device, str_dtype), value in
-        torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices).items()
-    }
+    return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
 
 
 def _device_has_foreach_support(device: torch.device) -> bool:
     return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()
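For context, the dict comprehension removed above existed only to map the string dtype keys previously returned by the C++ binding back to torch.dtype objects. A tiny illustrative round-trip (the "float32" value here is hypothetical, standing in for whatever key the old binding produced):

import torch

str_dtype = "float32"  # hypothetical key component from the old string-based binding
# The old wrapper recovered the dtype object by attribute lookup on the torch module:
assert getattr(torch, str_dtype) is torch.float32
# With torch.dtype crossing pybind11 directly, this lookup is no longer needed.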