mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Makes `shuffle` DataPipe sensitive to DataLoader(2) `shuffle` kwarg. Pull Request resolved: https://github.com/pytorch/pytorch/pull/65756 Reviewed By: albanD Differential Revision: D31344867 Pulled By: VitalyFedyunin fbshipit-source-id: e0084e0ac193ac784d6298328ca1222745681347
36 lines
1.4 KiB
Python
36 lines
1.4 KiB
Python
import torch.utils.data.graph
|
|
|
|
|
|
def get_all_graph_pipes(graph):
    """Flatten a nested graph mapping into the set of all its keys.

    ``graph`` is a mapping whose keys are pipes and whose values are
    sub-graphs of the same shape. Every key found at any nesting level
    is returned exactly once.
    """
    collected = set()
    for node, children in graph.items():
        collected.add(node)
        # Merge everything reachable below this node.
        collected |= get_all_graph_pipes(children)
    return collected
|
|
|
|
|
|
def apply_sharding(datapipe, num_of_instances, instance_id):
    """Apply sharding settings to the shardable pipe in ``datapipe``'s graph.

    The graph of ``datapipe`` is traversed with ``exclude_primitive=True``.
    A pipe is sharded when it has an ``is_shardable`` method that returns a
    truthy value and also has an ``apply_sharding`` attribute; that method is
    then called with ``(num_of_instances, instance_id)``.

    Args:
        datapipe: Root object whose graph is traversed.
        num_of_instances: Forwarded unchanged to the pipe's ``apply_sharding``.
        instance_id: Forwarded unchanged to the pipe's ``apply_sharding``.

    Raises:
        RuntimeError: If a second pipe qualifies for sharding. The first
            qualifying pipe has already been sharded by that point.
    """
    graph = torch.utils.data.graph.traverse(datapipe, exclude_primitive=True)
    already_applied_to = None
    for pipe in get_all_graph_pipes(graph):
        # Guard clauses keep the same check order as the nested form:
        # has is_shardable -> is_shardable() -> has apply_sharding.
        if not hasattr(pipe, 'is_shardable'):
            continue
        if not pipe.is_shardable():
            continue
        if not hasattr(pipe, 'apply_sharding'):
            continue
        if already_applied_to is not None:
            # Build one message string; passing the fragments as separate
            # exception arguments makes str(exc) render as a tuple.
            raise RuntimeError(
                'This implementation of sharding can be only applied once per instance of DataPipeline. '
                'Already applied to {!r} while trying to apply to {!r}'.format(already_applied_to, pipe)
            )
        pipe.apply_sharding(num_of_instances, instance_id)
        already_applied_to = pipe
|
|
|
|
|
|
def apply_shuffle_settings(datapipe, shuffle):
    """Forward ``shuffle`` to every pipe in the graph that accepts it.

    Nothing happens when ``shuffle`` is ``None``. Otherwise the graph of
    ``datapipe`` is traversed and ``set_shuffle_settings(shuffle)`` is called
    on each pipe that has that attribute.
    """
    if shuffle is None:
        return

    pipe_graph = torch.utils.data.graph.traverse(datapipe)
    for node in get_all_graph_pipes(pipe_graph):
        if not hasattr(node, 'set_shuffle_settings'):
            continue
        node.set_shuffle_settings(shuffle)
|