mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-03 23:45:05 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74984 Approved by: https://github.com/ejguan
91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
import io
|
|
import pickle
|
|
|
|
from torch.utils.data import IterDataPipe, MapDataPipe
|
|
from torch.utils.data._utils.serialization import DILL_AVAILABLE
|
|
|
|
from typing import Any, Dict, Set, Tuple, Type, Union
|
|
|
|
# Public API of this module: only `traverse` is exported by `from ... import *`.
__all__ = ["traverse", ]
|
|
|
|
# Type alias covering both DataPipe flavours this module traverses.
DataPipe = Union[IterDataPipe, MapDataPipe]
|
|
# NOTE(review): this module-level name is not read anywhere in the visible part
# of this module — confirm whether external code relies on it before removing.
reduce_ex_hook = None
|
|
|
|
|
|
def _stub_unpickler():
    """Return the placeholder string that stands in for a stubbed-out DataPipe.

    ``_list_connected_datapipes`` hands this callable to the pickler as the
    reconstruction function for every DataPipe it captures.
    """
    return "STUB"
|
|
|
|
|
|
# TODO(VitalyFedyunin): Make sure it works without dill module installed
|
|
def _list_connected_datapipes(scan_obj, only_datapipe, cache):
    """List the DataPipes that the pickler reaches while serializing ``scan_obj``.

    ``scan_obj`` is pickled into a throw-away buffer while temporary hooks are
    installed on ``IterDataPipe`` and ``MapDataPipe``. The reduce hook records
    each DataPipe it is called for and returns a stub reduction for it.

    Args:
        scan_obj: The DataPipe whose connections are listed.
        only_datapipe: When true, a getstate hook is installed as well; it keeps
            only the instance attributes whose value is a DataPipe or a tuple.
        cache: Collection of DataPipes that must not be captured (membership is
            tested with ``in``).

    Returns:
        The list of captured DataPipes, in the order the hook was called.

    Raises:
        pickle.PickleError, AttributeError, TypeError: Re-raised when the stock
            pickler fails and ``dill`` is not available.
    """
    f = io.BytesIO()
    p = pickle.Pickler(f)  # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is
    if DILL_AVAILABLE:
        from dill import Pickler as dill_Pickler
        d = dill_Pickler(f)
    else:
        d = None

    captured_connections = []

    def getstate_hook(obj):
        # Reduce the state to attributes that are DataPipes or tuples.
        state = {}
        for k, v in obj.__dict__.items():
            if isinstance(v, (IterDataPipe, MapDataPipe, tuple)):
                state[k] = v
        return state

    def reduce_hook(obj):
        # NOTE(review): NotImplementedError is presumably caught by the DataPipe
        # `__reduce_ex__` implementation, which then pickles the object
        # normally — confirm against the DataPipe classes.
        if obj == scan_obj or obj in cache:
            raise NotImplementedError
        captured_connections.append(obj)
        return _stub_unpickler, ()

    datapipe_classes: Tuple[Type[DataPipe]] = (IterDataPipe, MapDataPipe)  # type: ignore[assignment]

    try:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(reduce_hook)
            if only_datapipe:
                cls.set_getstate_hook(getstate_hook)
        try:
            p.dump(scan_obj)
        except (pickle.PickleError, AttributeError, TypeError):
            if not DILL_AVAILABLE:
                raise
            # The failed stock dump may already have captured some DataPipes
            # before it hit the unpicklable attribute. The dill pass starts
            # over with its own memo, so drop those partial captures to avoid
            # reporting the same DataPipe twice.
            captured_connections.clear()
            d.dump(scan_obj)
    finally:
        # Reset the hooks to None whether or not pickling succeeded.
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(None)
            if only_datapipe:
                cls.set_getstate_hook(None)
        if DILL_AVAILABLE:
            from dill import extend as dill_extend
            dill_extend(False)  # Undo change to dispatch table
    return captured_connections
|
|
|
|
|
|
def traverse(datapipe, only_datapipe=False):
    """Build the nested graph of DataPipes reachable from ``datapipe``.

    Args:
        datapipe: The root ``IterDataPipe`` or ``MapDataPipe``.
        only_datapipe: Passed through unchanged to ``_traverse_helper``.

    Returns:
        A dict of the form ``{datapipe: {connected_datapipe: {...}, ...}}``.

    Raises:
        RuntimeError: If ``datapipe`` is not an ``IterDataPipe`` or
            ``MapDataPipe`` (raised by ``_traverse_helper``).
    """
    # Each top-level call starts from a fresh, empty cache.
    return _traverse_helper(datapipe, only_datapipe, set())
|
|
|
|
|
|
# Add cache here to prevent infinite recursion on DataPipe
|
|
def _traverse_helper(datapipe, only_datapipe, cache):
    """Recursively build the graph rooted at ``datapipe``.

    Args:
        datapipe: The ``IterDataPipe`` or ``MapDataPipe`` to expand.
        only_datapipe: Passed through to ``_list_connected_datapipes``.
        cache: Set of DataPipes already visited on the current path; it is
            mutated by adding ``datapipe``.

    Returns:
        A single-key dict ``{datapipe: children}``, where ``children`` merges
        the graphs of every connected DataPipe.

    Raises:
        RuntimeError: If ``datapipe`` is neither an ``IterDataPipe`` nor a
            ``MapDataPipe``.
    """
    is_datapipe = isinstance(datapipe, (IterDataPipe, MapDataPipe))
    if not is_datapipe:
        message = "Expected `IterDataPipe` or `MapDataPipe`, but {} is found".format(type(datapipe))
        raise RuntimeError(message)

    cache.add(datapipe)
    children: Dict[DataPipe, Any] = {}
    for child in _list_connected_datapipes(datapipe, only_datapipe, cache):
        # Each child gets its own copy of the cache, so recursion is cut only
        # along a single path and not across the whole graph: one DataPipe may
        # legitimately appear on several different paths.
        children.update(_traverse_helper(child, only_datapipe, cache.copy()))
    return {datapipe: children}
|