[TorchTidy] Add Pattern to detect Synchronous Data Loader (#81740)

Summary: Setting num_workers > 0 in DataLoader enables asynchronous data loading, which does not block computation and therefore speeds up training. By matching the call structure of DataLoader.__iter__ in the profiled stack, we can detect when a synchronous DataLoader is being used.
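
As a minimal sketch of how the pattern is exercised (this mirrors the new unit test below; the dataset shape, batch size, and num_workers value are arbitrary):

```python
import torch
from torch.profiler import profile
from torch.profiler._pattern_matcher import SynchronizedDataLoaderPattern

dataset = torch.rand((100, 100))
# num_workers defaults to 0, so this loader blocks the main process on every batch.
sync_loader = torch.utils.data.DataLoader(dataset, batch_size=10)
# With num_workers > 0, batches are prepared asynchronously by worker processes.
async_loader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=4)

with profile(with_stack=True) as prof:
    next(iter(sync_loader))
    next(iter(async_loader))

# Only the synchronous loader's __iter__ call should match the pattern.
matches = SynchronizedDataLoaderPattern(prof).matched_events()
print(len(matches))  # expected: 1
```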

Test Plan:
Added a test in test/test_profiler.py
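
The pattern is also registered in report_all_anti_patterns (second file in the diff), so it surfaces in the combined anti-pattern report. A minimal sketch of that path, assuming the report is emitted to stdout:

```python
import torch
from torch.profiler import profile
from torch.profiler._pattern_matcher import report_all_anti_patterns

# num_workers defaults to 0, so this loader should be flagged.
loader = torch.utils.data.DataLoader(torch.rand((100, 100)), batch_size=10)

with profile(with_stack=True) as prof:
    next(iter(loader))

# Should include the "Synchronized DataLoader Pattern" message among any other findings.
report_all_anti_patterns(prof)
```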

Differential Revision: [D38082644](https://our.internmc.facebook.com/intern/diff/D38082644)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81740
Approved by: https://github.com/robieta
Author: David Chen
Date: 2022-07-27 17:10:54 +00:00
Committed by: PyTorch MergeBot
Parent: df1b7c2978
Commit: d537f868f3
2 changed files with 62 additions and 7 deletions

test/test_profiler.py

@@ -32,7 +32,8 @@ from torch.profiler._pattern_matcher import (Pattern, NamePattern,
                                              ExtraCUDACopyPattern,
                                              ForLoopIndexingPattern,
                                              FP32MatMulPattern,
-                                             OptimizerSingleTensorPattern)
+                                             OptimizerSingleTensorPattern,
+                                             SynchronizedDataLoaderPattern)
 from torch.testing._internal.common_device_type import skipCUDAVersionIn
 
 try:
@@ -1655,6 +1656,16 @@ aten::mm""")
         num_matched.append(len(pattern.matched_events()))
         self.assertEqual(num_matched, [i for i, _ in cases])
 
+    def test_profiler_synchronized_dataloader_pattern(self):
+        dataset = torch.rand((100, 100))
+        sync_dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
+        async_dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=4)
+        with profile(with_stack=True) as prof:
+            next(iter(sync_dataloader))
+            next(iter(async_dataloader))
+        pattern = SynchronizedDataLoaderPattern(prof)
+        num_matched = len(pattern.matched_events())
+        self.assertEqual(num_matched, 1)
 
 if __name__ == '__main__':
     run_tests()

torch/profiler/_pattern_matcher.py

@@ -1,4 +1,5 @@
 from collections import deque
+import os
 import re
 from typing import Dict, List, Set
@@ -340,12 +341,11 @@ class OptimizerSingleTensorPattern(Pattern):
     def __init__(self, prof: profile, should_benchmark: bool = False):
         super().__init__(prof, should_benchmark)
         self.name = "Optimizer Single Tensor Pattern"
-        self.optimizers_with_foreach = [
-            "adam", "sgd", "adamw"
-        ]
+        self.optimizers_with_foreach = ["adam", "sgd", "adamw"]
         self.description = (
             "Detected optimizer running with single tensor implementation. "
-            "Please enable multi tensor implementation by passing 'foreach=True' into optimizer.")
+            "Please enable multi tensor implementation by passing 'foreach=True' into optimizer."
+        )
 
     def match(self, event: _ProfilerEvent):
         for optimizer in self.optimizers_with_foreach:
@@ -354,6 +354,49 @@ class OptimizerSingleTensorPattern(Pattern):
         return False
 
 
+class SynchronizedDataLoaderPattern(Pattern):
+    '''
+    This pattern identifies if we are using num_workers=0 in DataLoader.
+    example:
+    torch.utils.data.DataLoader(dataset, batch_size=batch_size)
+    Add num_workers=N to the arguments. N depends on system configuration.
+
+    Pattern:
+    dataloader.py(...): __iter__
+        dataloader.py(...): _get_iterator
+            NOT dataloader.py(...): check_worker_number_rationality
+
+    Algorithm:
+    If we don't see a check_worker_number_rationality call in the dataloader __iter__,
+    it is not an asynchronous dataloader.
+    '''
+
+    def __init__(self, prof: profile, should_benchmark: bool = False):
+        super().__init__(prof, should_benchmark)
+        self.name = "Synchronized DataLoader Pattern"
+        self.description = (
+            "Detected DataLoader running with synchronized implementation. "
+            "Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader."
+        )
+
+    def match(self, event: _ProfilerEvent):
+        def is_dataloader_function(name: str, function_name: str):
+            return name.startswith(os.path.join("torch", "utils", "data", "dataloader.py")) and name.endswith(function_name)
+
+        if not is_dataloader_function(event.name(), "__iter__"):
+            return False
+        if not event.children:
+            return False
+        event = event.children[0]
+        if not is_dataloader_function(event.name(), "_get_iterator"):
+            return False
+        if not event.children:
+            return False
+        event = event.children[0]
+        return not is_dataloader_function(event.name(), "check_worker_number_rationality")
+        # TODO: We should also check if the loader is the bottleneck.
+
+
 def source_code_location(event: _ProfilerEvent):
     while event:
         if event_type(event) == _EventType.PyCall or event_type(
@@ -361,7 +404,7 @@ def source_code_location(event: _ProfilerEvent):
             assert isinstance(event.extra_fields,
                               _ExtraFields_PyCall) or isinstance(
                                   event.extra_fields, _ExtraFields_PyCCall)
-            if not event.extra_fields.caller.file_name.startswith("torch/"):
+            if not event.extra_fields.caller.file_name.startswith("torch" + os.sep):
                 return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"
         event = event.parent
     return "No source code location found"
@@ -377,7 +420,8 @@ def report_all_anti_patterns(prof, should_benchmark: bool = False):
         ExtraCUDACopyPattern(prof, should_benchmark),
         ForLoopIndexingPattern(prof, should_benchmark),
         FP32MatMulPattern(prof, should_benchmark),
-        OptimizerSingleTensorPattern(prof, should_benchmark)
+        OptimizerSingleTensorPattern(prof, should_benchmark),
+        SynchronizedDataLoaderPattern(prof, should_benchmark)
     ]
     reported = set()
     summaries = []