mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 14:34:54 +08:00
Keep the sequence or mapping type in default_collate (#68779)
Summary: `default_collate`, `default_convert`, and `pin_memory` convert sequences into lists. I believe they should keep the original type when possible (e.g., I have a class that inherits from `list`, which comes from a 3rd party library that I can't change, and provides extra functionality). Note it's easy to do when the type supports an iterable in its creation but it's not always the case (e.g., `range`). Even though this can be accomplished if using a custom `default_collate`/`default_convert`, 1) this is behavior they should support out-of-the-box IMHO, and 2) `pin_memory` still does it. cc VitalyFedyunin ejguan NivekT Pull Request resolved: https://github.com/pytorch/pytorch/pull/68779 Reviewed By: wenleix Differential Revision: D32651129 Pulled By: ejguan fbshipit-source-id: 17c390934bacc0e4ead060469cf15dde815550b4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
d9e7d85390
commit
f776f30780
@ -1,3 +1,4 @@
|
||||
import collections.abc
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Any
|
||||
@ -1013,10 +1014,22 @@ class DistributedDataParallel(Module, Joinable):
|
||||
return [type(obj)(*args) for args in zip(*map(to_map, obj))]
|
||||
if isinstance(obj, tuple) and len(obj) > 0:
|
||||
return list(zip(*map(to_map, obj)))
|
||||
if isinstance(obj, list) and len(obj) > 0:
|
||||
return [list(i) for i in zip(*map(to_map, obj))]
|
||||
if isinstance(obj, dict) and len(obj) > 0:
|
||||
return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
|
||||
if isinstance(obj, str):
|
||||
# Needs to be checked, otherwise it's taken as a sequence infinitely.
|
||||
# This is because the elements of a string are also strings, and so on.
|
||||
return [obj]
|
||||
if isinstance(obj, collections.abc.Sequence) and len(obj) > 0:
|
||||
try:
|
||||
return [type(obj)(i) for i in zip(*map(to_map, obj))]
|
||||
except TypeError:
|
||||
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
|
||||
return [list(i) for i in zip(*map(to_map, obj))]
|
||||
if isinstance(obj, collections.abc.Mapping) and len(obj) > 0:
|
||||
try:
|
||||
return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
|
||||
except TypeError:
|
||||
# The mapping type may not support `__init__(iterable)`.
|
||||
return [dict(i) for i in zip(*map(to_map, obj.items()))]
|
||||
return [obj]
|
||||
|
||||
# Avoid reference cycle
|
||||
|
||||
Reference in New Issue
Block a user