Files
pytorch/torch/utils/data/graph_settings.py
Vitaly Fedyunin d90012689f [DataPipe] Control shuffle settings from DataLoader2 (#65756)
Summary:
Makes `shuffle` DataPipe sensitive to DataLoader(2) `shuffle` kwarg.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65756

Reviewed By: albanD

Differential Revision: D31344867

Pulled By: VitalyFedyunin

fbshipit-source-id: e0084e0ac193ac784d6298328ca1222745681347
2021-12-14 07:35:26 -08:00

36 lines
1.4 KiB
Python

import torch.utils.data.graph
def get_all_graph_pipes(graph):
    """Collect every pipe that appears anywhere in a nested traversal graph.

    Args:
        graph: Mapping whose keys are pipe objects and whose values are
            mappings of the same shape (an empty mapping for a leaf). Keys
            must be hashable because they are gathered into a set.

    Returns:
        A ``set`` holding every key found at any nesting depth. A pipe that
        is reachable through several paths appears once.
    """
    results = set()
    for datapipe, sub_graph in graph.items():
        results.add(datapipe)
        # Merge the whole sub-result in one call rather than item by item.
        results.update(get_all_graph_pipes(sub_graph))
    return results
def apply_sharding(datapipe, num_of_instances, instance_id):
    """Apply sharding to the one shardable pipe in ``datapipe``'s graph.

    The graph is obtained from ``torch.utils.data.graph.traverse`` and
    flattened with ``get_all_graph_pipes``. A pipe is treated as shardable
    when it has an ``is_shardable`` method that returns a truthy value and
    also has an ``apply_sharding`` method.

    Args:
        datapipe: Root of the pipeline to inspect.
        num_of_instances: Forwarded unchanged to the pipe's
            ``apply_sharding``.
        instance_id: Forwarded unchanged to the pipe's ``apply_sharding``.

    Raises:
        RuntimeError: If more than one shardable pipe is found. The check
            runs before any pipe is touched, so a failing call leaves the
            pipeline unmodified.
    """
    graph = torch.utils.data.graph.traverse(datapipe, exclude_primitive=True)
    shardable_pipes = [
        pipe
        for pipe in get_all_graph_pipes(graph)
        if hasattr(pipe, 'is_shardable')
        and pipe.is_shardable()
        and hasattr(pipe, 'apply_sharding')
    ]

    # Validate first: the pipes come out of a set, so applying sharding while
    # still scanning would mutate an arbitrary pipe before the error surfaces.
    if len(shardable_pipes) > 1:
        # Build one message string; passing several positional arguments to
        # RuntimeError makes str(exc) render as a tuple.
        raise RuntimeError(
            'This implementation of sharding can be only applied once per instance of DataPipeline. '
            'Found more than one shardable DataPipe: {!r} and {!r}'.format(
                shardable_pipes[0], shardable_pipes[1]
            )
        )

    for pipe in shardable_pipes:
        pipe.apply_sharding(num_of_instances, instance_id)
def apply_shuffle_settings(datapipe, shuffle):
    """Forward ``shuffle`` to every pipe in the graph that accepts it.

    When ``shuffle`` is ``None`` the function returns immediately without
    traversing anything. Otherwise the graph rooted at ``datapipe`` is
    traversed, and each pipe exposing ``set_shuffle_settings`` is called
    with ``shuffle``. Pipes without that attribute are skipped.

    Args:
        datapipe: Root of the pipeline to inspect.
        shuffle: Value handed to each pipe's ``set_shuffle_settings``;
            ``None`` means "leave every pipe as it is".
    """
    if shuffle is None:
        return

    # NOTE(review): unlike apply_sharding, this traversal does not pass
    # exclude_primitive=True - confirm that the difference is intentional.
    pipe_graph = torch.utils.data.graph.traverse(datapipe)
    for candidate in get_all_graph_pipes(pipe_graph):
        if not hasattr(candidate, 'set_shuffle_settings'):
            continue
        candidate.set_shuffle_settings(shuffle)