Revert "Move tensor grouping to ATen (#100007)"
This reverts commit 74b7a6c75e698378882d30958908073407f97fb3. Reverted https://github.com/pytorch/pytorch/pull/100007 on behalf of https://github.com/izaitsevfb due to Breaks internal builds, see D46629727 ([comment](https://github.com/pytorch/pytorch/pull/100007#issuecomment-1587861598))
@@ -1,4 +1,5 @@
-from typing import List, Dict, Tuple, Optional
+from collections import defaultdict
+from typing import List, Dict, Tuple, Optional, Union
 
 import torch
 from torch import Tensor
@@ -28,16 +29,23 @@ def _get_fused_kernels_supported_devices() -> List[str]:
 # original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
 # may be necessary. Check out torch/optim/sgd.py for an example.
 @no_grad()
-def _group_tensors_by_device_and_dtype(
-    tensorlistlist: List[List[Optional[Tensor]]],
-    with_indices: bool = False,
-) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Optional[Tensor]]], List[int]]]:
-    return {
-        (device, getattr(torch, str_dtype)): value
-        for (device, str_dtype), value in
-        torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices).items()
-    }
-
+def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
+                                       with_indices: Optional[bool] = False) -> \
+        Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]]:
+    assert all(not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist), (
+           "all specified tensorlists must match in length")
+    per_device_and_dtype_tensors: Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
+        lambda: [[] for _ in range(len(tensorlistlist) + (1 if with_indices else 0))])
+    for i, t in enumerate(tensorlistlist[0]):
+        key = (t.device, t.dtype)
+        for j in range(len(tensorlistlist)):
+            # a tensorlist may be empty/None
+            if tensorlistlist[j]:
+                per_device_and_dtype_tensors[key][j].append(tensorlistlist[j][i])
+        if with_indices:
+            # tack on previous index
+            per_device_and_dtype_tensors[key][j + 1].append(i)
+    return per_device_and_dtype_tensors
 
 def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
     if device.type not in set(_get_foreach_kernels_supported_devices() + ["cpu"]) or torch.jit.is_scripting():
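
For context, a minimal usage sketch of the restored pure-Python helper. It assumes the helper lives in torch/utils/_foreach_utils.py (as in the PyTorch tree; the import path is not shown in this diff), and the example tensors are illustrative, not part of the commit:

# Hedged usage sketch of the restored helper; the return shape matches the
# pure-Python implementation above, not the reverted ATen-backed one.
import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

# Two parallel tensorlists with mixed dtypes, all on CPU.
params = [torch.zeros(2), torch.zeros(3, dtype=torch.float64)]
grads = [torch.ones(2), torch.ones(3, dtype=torch.float64)]

# Tensors are bucketed by (device, dtype); with_indices=True appends one
# extra inner list holding each tensor's original position, so callers
# (e.g. torch/optim/sgd.py) can propagate results back to the input order.
grouped = _group_tensors_by_device_and_dtype([params, grads], with_indices=True)
for (device, dtype), lists in grouped.items():
    # e.g. (cpu, torch.float32) -> [[param], [grad], [0]]
    print(device, dtype, lists)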