mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
[BE] Revert distributed change in https://github.com/pytorch/pytorch/pull/68779 (#83181)
https://github.com/pytorch/pytorch/issues/82641 points out a regression in how inputs / outputs are processed by DDP, blocking their HF use case. It was narrowed down to https://github.com/pytorch/pytorch/pull/68779 and reverting the distributed change there fixes the issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/83181 Approved by: https://github.com/kumpera
This commit is contained in:
committed by
PyTorch MergeBot
parent
4e90526a4f
commit
b29a074882
@ -1,5 +1,3 @@
|
||||
import collections
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel._functions import _get_stream
|
||||
@ -40,22 +38,10 @@ def _recursive_to(inputs, target_gpu, use_side_stream_for_tensor_copies):
|
||||
return [type(obj)(*args) for args in zip(*map(to_map, obj))]
|
||||
if isinstance(obj, tuple) and len(obj) > 0:
|
||||
return list(zip(*map(to_map, obj)))
|
||||
if isinstance(obj, str):
|
||||
# Needs to be checked, otherwise it's taken as a sequence infinitely.
|
||||
# This is because the elements of a string are also strings, and so on.
|
||||
return [obj]
|
||||
if isinstance(obj, collections.abc.Sequence) and len(obj) > 0:
|
||||
try:
|
||||
return [type(obj)(i) for i in zip(*map(to_map, obj))] # type: ignore[call-arg]
|
||||
except TypeError:
|
||||
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
|
||||
return [list(i) for i in zip(*map(to_map, obj))]
|
||||
if isinstance(obj, collections.abc.Mapping) and len(obj) > 0:
|
||||
try:
|
||||
return [type(obj)(i) for i in zip(*map(to_map, obj.items()))] # type: ignore[call-arg]
|
||||
except TypeError:
|
||||
# The mapping type may not support `__init__(iterable)`.
|
||||
return [dict(i) for i in zip(*map(to_map, obj.items()))]
|
||||
if isinstance(obj, list) and len(obj) > 0:
|
||||
return [list(i) for i in zip(*map(to_map, obj))]
|
||||
if isinstance(obj, dict) and len(obj) > 0:
|
||||
return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
|
||||
return [obj]
|
||||
|
||||
# Avoid reference cycle
|
||||
|
||||
Reference in New Issue
Block a user