[TorchTidy] Add option to generate json report (#82261)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82261
Approved by: https://github.com/robieta
commit 7922bbef73 (parent 1a74fd166d)
committed by PyTorch MergeBot
@@ -1,4 +1,5 @@
 from collections import deque
+import json
 import math
 import os
 import re
@@ -26,6 +27,7 @@ class Pattern:
         self.should_benchmark = should_benchmark
         self.name = "Please specify a name for pattern"
         self.description = "Please specify a description for pattern"
+        self.url = ""
         assert prof.profiler is not None and prof.profiler.kineto_results is not None
         self.event_tree = prof.profiler.kineto_results.experimental_event_tree(
         )
@@ -52,11 +54,13 @@ class Pattern:
         default_summary = f"{self.name}: {len(events)} events matched."
         if self.should_benchmark:
             # If benchmark summary is not empty, use it.
-            return self.benchmark_summary(events) if hasattr(  # type: ignore[attr-defined]
-                self, 'benchmark') else default_summary
+            return self.benchmark_summary(
+                events) if hasattr(  # type: ignore[attr-defined]
+                    self, 'benchmark') else default_summary
         return default_summary

     def benchmark_summary(self, events: List[_ProfilerEvent]):

         def format_time(time_ns: int):
             unit_lst = ["ns", "us", "ms"]
             for unit in unit_lst:
@@ -66,7 +70,8 @@ class Pattern:
             return f"{time_ns:.2f} s"

         assert hasattr(self, 'benchmark'), 'Please implement benchmark()'
-        shapes_factor_map = self.benchmark(events)  # type: ignore[attr-defined]
+        shapes_factor_map = self.benchmark(  # type: ignore[attr-defined]
+            events)
         original_time = sum(event.duration_time_ns for event in events)
         new_time = sum(shapes_factor_map[input_shapes(event)] *
                        event.duration_time_ns for event in events)
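The two hunks above show format_time only in fragments. A minimal self-contained sketch of a helper with this shape (the exact thresholds are an assumption):

def format_time(time_ns: int) -> str:
    # Walk up the units, dividing by 1000 until the value fits the unit.
    unit_lst = ["ns", "us", "ms"]
    for unit in unit_lst:
        if time_ns < 1000:
            return f"{time_ns:.2f} {unit}"
        time_ns /= 1000
    return f"{time_ns:.2f} s"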
@@ -158,6 +163,7 @@ class ExtraCUDACopyPattern(Pattern):
         super().__init__(prof, should_benchmark)
         self.name = "Extra CUDA Copy Pattern"
         self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU."
+        self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device"
         self.init_ops = {
             "aten::fill_", "aten::zero_", "aten::normal_", "aten::uniform_"
         }
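The remediation this pattern points at (see the tuning-guide URL added above) is to construct tensors directly on the target device rather than filling on CPU and copying. A minimal sketch:

import torch

# Anti-pattern: fill a CPU tensor, then pay for an extra host-to-device copy.
x = torch.zeros(100, 100).cuda()

# Recommended: initialize directly on the GPU.
x = torch.zeros(100, 100, device="cuda")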
@@ -273,6 +279,7 @@ class FP32MatMulPattern(Pattern):
             "You are currently using GPU that supports TF32. "
             "Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'"
         )
+        self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"

     @property
     def skip(self):
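The description's remediation is a one-line backend flag; a minimal sketch:

import torch

# Allow TF32 tensor-core math for float32 matmuls (Ampere and newer GPUs);
# trades a small amount of precision for throughput.
torch.backends.cuda.matmul.allow_tf32 = True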
@@ -341,6 +348,7 @@ class OptimizerSingleTensorPattern(Pattern):
             "Detected optimizer running with single tensor implementation. "
             "Please enable multi tensor implementation by passing 'foreach=True' into optimizer."
         )
+        self.url = ""

     def match(self, event: _ProfilerEvent):
         for optimizer in self.optimizers_with_foreach:
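The multi-tensor remediation named in the description, as a minimal sketch (assuming an optimizer that accepts foreach, e.g. torch.optim.Adam):

import torch

model = torch.nn.Linear(10, 10)
# foreach=True selects the horizontally fused implementation instead of a
# Python loop over individual parameter tensors.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)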
@@ -374,10 +382,17 @@ class SynchronizedDataLoaderPattern(Pattern):
             "Detected DataLoader running with synchronized implementation. "
             "Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader."
         )
+        self.url = (
+            "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
+            "#enable-async-data-loading-and-augmentation")

     def match(self, event: _ProfilerEvent):

         def is_dataloader_function(name: str, function_name: str):
-            return name.startswith(os.path.join("torch", "utils", "data", "dataloader.py")) and name.endswith(function_name)
+            return name.startswith(
+                os.path.join("torch", "utils", "data",
+                             "dataloader.py")) and name.endswith(function_name)

         if not is_dataloader_function(event.name(), "__iter__"):
             return False
         if not event.children:
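A minimal sketch of the asynchronous loading the description asks for (the dataset and sizes are placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1000, 10))
# num_workers > 0 moves batch preparation into worker processes so it can
# overlap with training compute.
loader = DataLoader(dataset, batch_size=32, num_workers=2)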
@@ -388,7 +403,8 @@ class SynchronizedDataLoaderPattern(Pattern):
         if not event.children:
             return False
         event = event.children[0]
-        return not is_dataloader_function(event.name(), "check_worker_number_rationality")
+        return not is_dataloader_function(event.name(),
+                                          "check_worker_number_rationality")
         # TODO: We should also check if the loader is bottleneck.

@@ -417,6 +433,9 @@ class GradNotSetToNonePattern(Pattern):
         self.description = (
             "Detected gradient set to zero instead of None. "
             "Please add 'set_to_none=True' when calling zero_grad().")
+        self.url = (
+            "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
+            "#disable-gradient-calculation-for-validation-or-inference")

     def match(self, event: _ProfilerEvent):
         if not event.name().endswith(": zero_grad"):
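The zero_grad remediation from the description, as a minimal sketch:

import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Setting grads to None skips one fill kernel per parameter compared to
# zeroing the gradient tensors in place.
optimizer.zero_grad(set_to_none=True)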
@@ -449,6 +468,9 @@ class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern):
         super().__init__(prof, should_benchmark)
         self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern"
         self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d."
+        self.url = (
+            "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
+            "#disable-bias-for-convolutions-directly-followed-by-a-batch-norm")

     @property
     def skip(self):
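The bias=False remediation from the description, as a minimal sketch:

import torch.nn as nn

# BatchNorm2d re-centers its input, so a Conv2d bias directly before it is
# redundant work; disable it.
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)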
@@ -476,6 +498,7 @@ class MatMulDimInFP16Pattern(Pattern):
         super().__init__(prof, should_benchmark)
         self.name = "Matrix Multiplication Dimension Not Aligned Pattern"
         self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension."
+        self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp"

     @property
     def skip(self):
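"Aligned" here refers to the multiple-of-8 dimensions that FP16 tensor cores prefer; zero padding is one way to get there (the pad-to-1024 choice below is our illustration, not code from this PR):

import torch
import torch.nn.functional as F

a = torch.randn(1024, 1022, dtype=torch.float16, device="cuda")
b = torch.randn(1022, 1024, dtype=torch.float16, device="cuda")
# 1022 is not a multiple of 8. Zero-padding the shared inner dimension to
# 1024 keeps tensor cores engaged and leaves the product a @ b unchanged.
a_pad = F.pad(a, (0, 2))          # (1024, 1024)
b_pad = F.pad(b, (0, 0, 0, 2))    # (1024, 1024)
out = a_pad @ b_pad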
@@ -538,7 +561,8 @@ def source_code_location(event: _ProfilerEvent):
         assert isinstance(event.extra_fields,
                           _ExtraFields_PyCall) or isinstance(
                               event.extra_fields, _ExtraFields_PyCCall)
-        if not event.extra_fields.caller.file_name.startswith("torch" + os.sep):
+        if not event.extra_fields.caller.file_name.startswith("torch" +
+                                                              os.sep):
             return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"
         event = event.parent
     return "No source code location found"
@@ -578,7 +602,11 @@ def eventTreeBFS(event_tree: List[_ProfilerEvent]):
             stack.append(child_event)


-def report_all_anti_patterns(prof, should_benchmark: bool = False):
+def report_all_anti_patterns(prof,
+                             should_benchmark: bool = False,
+                             print_enable: bool = True,
+                             json_report_dir: str = None):
+    report_dict: Dict = {}
     anti_patterns = [
         ExtraCUDACopyPattern(prof, should_benchmark),
         ForLoopIndexingPattern(prof, should_benchmark),
@@ -604,8 +632,27 @@ def report_all_anti_patterns(prof, should_benchmark: bool = False):
             if report_msg not in reported:
                 message_list.append(report_msg)
                 reported.add(report_msg)
+                src_location, line_no = source_code_location(event).split(":")
+                report_dict.setdefault(src_location, []).append({
+                    "line_number": int(line_no),
+                    "name": anti_pattern.name,
+                    "url": anti_pattern.url,
+                    "message": anti_pattern.description,
+                })
+
+    if json_report_dir is not None:
+        json_report_path = os.path.join(json_report_dir,
+                                        "torchtidy_report.json")
+        if os.path.exists(json_report_path):
+            with open(json_report_path, "r") as f:
+                existing_report = json.load(f)
+                existing_report.update(report_dict)
+                report_dict = existing_report
+        with open(json_report_path, "w") as f:
+            json.dump(report_dict, f, indent=4)

     message_list.append("Summary:")
     message_list += summaries
     message_list.append(f"{'-'*40}TorchTidy Report{'-'*40}")
-    print("\n".join(message_list))
+    if print_enable:
+        print("\n".join(message_list))
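Putting the new parameters together: a minimal usage sketch, assuming the diffed file is torch/profiler/_pattern_matcher.py (the model and shapes are placeholders):

import json
import torch
from torch.profiler import profile
from torch.profiler._pattern_matcher import report_all_anti_patterns

model = torch.nn.Linear(10, 10)
with profile(with_stack=True, record_shapes=True) as prof:
    model(torch.randn(32, 10))

# Write (or merge into) ./torchtidy_report.json instead of printing.
report_all_anti_patterns(prof, print_enable=False, json_report_dir=".")

with open("torchtidy_report.json") as f:
    report = json.load(f)
# report maps source location -> [{"line_number", "name", "url", "message"}, ...]

Because an existing report is loaded and updated before writing, repeated runs against the same directory accumulate findings per source file rather than overwriting them.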