Keep the sequence or mapping type in default_collate (#68779)

Summary:
`default_collate`, `default_convert`, and `pin_memory` convert sequences into lists. I believe they should keep the original type when possible (e.g., I have a class that inherits from `list`, which comes from a 3rd party library that I can't change, and provides extra functionality).

Note it's easy to do when the type supports an iterable in its creation but it's not always the case (e.g., `range`).

Even though this can be accomplished if using a custom `default_collate`/`default_convert`, 1) this is behavior they should support out-of-the-box IMHO, and 2) `pin_memory` still does it.

cc VitalyFedyunin ejguan NivekT

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68779

Reviewed By: wenleix

Differential Revision: D32651129

Pulled By: ejguan

fbshipit-source-id: 17c390934bacc0e4ead060469cf15dde815550b4
This commit is contained in:
Santiago Castro
2021-11-29 13:13:05 -08:00
committed by Facebook GitHub Bot
parent d9e7d85390
commit f776f30780
4 changed files with 106 additions and 11 deletions

View File

@ -1,3 +1,4 @@
import collections.abc
import copy
from dataclasses import dataclass
from typing import Callable, Any
@ -1013,10 +1014,22 @@ class DistributedDataParallel(Module, Joinable):
return [type(obj)(*args) for args in zip(*map(to_map, obj))]
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(to_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
return [list(i) for i in zip(*map(to_map, obj))]
if isinstance(obj, dict) and len(obj) > 0:
return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
if isinstance(obj, str):
# Needs to be checked, otherwise it's taken as a sequence infinitely.
# This is because the elements of a string are also strings, and so on.
return [obj]
if isinstance(obj, collections.abc.Sequence) and len(obj) > 0:
try:
return [type(obj)(i) for i in zip(*map(to_map, obj))]
except TypeError:
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
return [list(i) for i in zip(*map(to_map, obj))]
if isinstance(obj, collections.abc.Mapping) and len(obj) > 0:
try:
return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
except TypeError:
# The mapping type may not support `__init__(iterable)`.
return [dict(i) for i in zip(*map(to_map, obj.items()))]
return [obj]
# Avoid reference cycle