[BE][Ez]: Autotype torch/profiler with ruff ANN (#157923)

Apply ruff autotyping fixes to add type annotations to torch/profiler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157923
Approved by: https://github.com/albanD, https://github.com/sraikund16
Author: Aaron Gokaslan
Date: 2025-07-09 22:07:46 +00:00
Committed by: PyTorch MergeBot
Parent: 53ab73090e
Commit: a1dad2f2d2
5 changed files with 60 additions and 57 deletions
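
For orientation before the diff: ruff's ANN rules flag missing function annotations, and the fixes applied here are mechanical return-type additions. Below is a minimal sketch of the pattern, assuming a hypothetical `Counter` class for illustration; it is not code from torch/profiler.

    # Before autotyping: return types are implicit.
    class Counter:
        def __init__(self):
            self.count = 0

        def reset(self):
            self.count = 0

        def is_empty(self):
            return self.count == 0

    # After autotyping: "-> None" is added to functions that never return a
    # value (the bulk of this diff); simple return types such as bool or str
    # are spelled out where the annotation was added.
    class Counter:
        def __init__(self) -> None:
            self.count = 0

        def reset(self) -> None:
            self.count = 0

        def is_empty(self) -> bool:
            return self.count == 0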


@@ -514,7 +514,7 @@ class DataFlowGraph:
def flow_nodes(self) -> tuple[DataFlowNode, ...]:
return tuple(self._flow_nodes)
-def validate(self):
+def validate(self) -> None:
# Check that each (Tensor, version) pair has a unique creation node
outputs: set[tuple[TensorKey, int]] = set()
for node in self.flow_nodes:
@@ -964,7 +964,7 @@ class MemoryProfile:
if key is not None:
self._categories.set_by_id(key, Category.OPTIMIZER_STATE)
-def _set_autograd_detail(self):
+def _set_autograd_detail(self) -> None:
prior = {None, Category.AUTOGRAD_DETAIL}
for node in self._data_flow_graph.flow_nodes:
if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
@@ -976,7 +976,7 @@ class MemoryProfile:
class MemoryProfileTimeline:
-def __init__(self, memory_profile):
+def __init__(self, memory_profile) -> None:
"""The minimum representation of the memory profile timeline
includes the memory timeline and categories. The timeline
consists of [timestamp, action, (TensorKey, version), numbytes]
@@ -999,7 +999,7 @@ class MemoryProfileTimeline:
times: list[int] = []
sizes: list[list[int]] = []
-def update(key, version, delta):
+def update(key, version, delta) -> None:
category = (
self.categories.get(key, version)
if isinstance(key, TensorKey)


@@ -26,7 +26,7 @@ class Pattern:
In subclass, define description and skip property.
"""
-def __init__(self, prof: profile, should_benchmark: bool = False):
+def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
self.prof = prof
self.should_benchmark = should_benchmark
self.name = "Please specify a name for pattern"
@@ -39,7 +39,7 @@ class Pattern:
self.tid_root.setdefault(event.start_tid, []).append(event)
@property
-def skip(self):
+def skip(self) -> bool:
return False
def report(self, event: _ProfilerEvent):
@@ -66,8 +66,8 @@ class Pattern:
)
return default_summary
-def benchmark_summary(self, events: list[_ProfilerEvent]):
-def format_time(time_ns: int):
+def benchmark_summary(self, events: list[_ProfilerEvent]) -> str:
+def format_time(time_ns: int) -> str:
unit_lst = ["ns", "us", "ms"]
for unit in unit_lst:
if time_ns < 1000:
@@ -135,7 +135,9 @@ class Pattern:
class NamePattern(Pattern):
-def __init__(self, prof: profile, name: str, should_benchmark: bool = False):
+def __init__(
+self, prof: profile, name: str, should_benchmark: bool = False
+) -> None:
super().__init__(prof, should_benchmark)
self.description = f"Matched Name Event: {name}"
self.name = name
@@ -161,7 +163,7 @@ class ExtraCUDACopyPattern(Pattern):
If at any step we failed, it is not a match.
"""
-def __init__(self, prof: profile, should_benchmark: bool = False):
+def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
super().__init__(prof, should_benchmark)
self.name = "Extra CUDA Copy Pattern"
self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU."
@@ -174,7 +176,7 @@ class ExtraCUDACopyPattern(Pattern):
}
@property
-def skip(self):
+def skip(self) -> bool:
return not self.prof.with_stack or not self.prof.record_shapes
def match(self, event):
@@ -248,7 +250,7 @@ class ForLoopIndexingPattern(Pattern):
We also keep a dictionary to avoid duplicate match in the for loop.
"""
-def __init__(self, prof: profile, should_benchmark: bool = False):
+def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
super().__init__(prof, should_benchmark)
self.name = "For Loop Indexing Pattern"
self.description = "For loop indexing detected. Vectorization recommended."
@@ -271,7 +273,7 @@ class ForLoopIndexingPattern(Pattern):
return False
# Custom event list matching
-def same_ops(list1, list2):
+def same_ops(list1, list2) -> bool:
if len(list1) != len(list2):
return False
for op1, op2 in zip(list1, list2):
@@ -295,7 +297,7 @@ class ForLoopIndexingPattern(Pattern):
class FP32MatMulPattern(Pattern):
-def __init__(self, prof: profile, should_benchmark: bool = False):
+def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
super().__init__(prof, should_benchmark)
self.name = "FP32 MatMul Pattern"
self.description = (
@@ -316,7 +318,7 @@ class FP32MatMulPattern(Pattern):
)
return has_tf32 is False or super().skip or not self.prof.record_shapes
-def match(self, event: _ProfilerEvent):
+def match(self, event: _ProfilerEvent) -> bool:
# If we saw this pattern once, we don't need to match it again
if event.tag != _EventType.TorchOp:
return False
@@ -365,7 +367,7 @@ class OptimizerSingleTensorPattern(Pattern):
String match
"""
-def __init__(self, prof: profile, should_benchmark: bool = False):
+def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
super().__init__(prof, should_benchmark)
self.name = "Optimizer Single Tensor Pattern"
self.optimizers_with_foreach = ["adam", "sgd", "adamw"]
@@ -375,7 +377,7 @@ class OptimizerSingleTensorPattern(Pattern):
)
self.url = ""
-def match(self, event: _ProfilerEvent):
+def match(self, event: _ProfilerEvent) -> bool:
for optimizer in self.optimizers_with_foreach:
if event.name.endswith(f"_single_tensor_{optimizer}"):
return True
@@ -400,7 +402,7 @@ class SynchronizedDataLoaderPattern(Pattern):
"""
-def __init__(self, prof: profile, should_benchmark: bool = False):
+def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
super().__init__(prof, should_benchmark)
self.name = "Synchronized DataLoader Pattern"
self.description = (
@@ -412,7 +414,7 @@ class SynchronizedDataLoaderPattern(Pattern):
"#enable-async-data-loading-and-augmentation"
)
-def match(self, event: _ProfilerEvent):
+def match(self, event: _ProfilerEvent) -> bool:
def is_dataloader_function(name: str, function_name: str):
return name.startswith(
os.path.join("torch", "utils", "data", "dataloader.py")
@@ -459,7 +461,7 @@ class GradNotSetToNonePattern(Pattern):
String match
"""
-def __init__(self, prof: profile, should_benchmark: bool = False):
+def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
super().__init__(prof, should_benchmark)
self.name = "Gradient Set To Zero Instead of None Pattern"
self.description = (
@@ -471,7 +473,7 @@ class GradNotSetToNonePattern(Pattern):
"#disable-gradient-calculation-for-validation-or-inference"
)
-def match(self, event: _ProfilerEvent):
+def match(self, event: _ProfilerEvent) -> bool:
if not event.name.endswith(": zero_grad"):
return False
if not event.children:
@@ -500,7 +502,7 @@ class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern):
String match
"""
-def __init__(self, prof: profile, should_benchmark: bool = False):
+def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
super().__init__(prof, should_benchmark)
self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern"
self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d."
@@ -531,17 +533,17 @@ class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern):
class MatMulDimInFP16Pattern(Pattern):
-def __init__(self, prof: profile, should_benchmark: bool = False):
+def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
super().__init__(prof, should_benchmark)
self.name = "Matrix Multiplication Dimension Not Aligned Pattern"
self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension."
self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp"
@property
-def skip(self):
+def skip(self) -> bool:
return not self.prof.with_stack or not self.prof.record_shapes
-def match(self, event: _ProfilerEvent):
+def match(self, event: _ProfilerEvent) -> bool:
def mutiple_of(shapes, multiple):
return all(dim % multiple == 0 for shape in shapes for dim in shape[-2:])
@@ -584,7 +586,7 @@ class MatMulDimInFP16Pattern(Pattern):
return shapes_factor_map
-def source_code_location(event: Optional[_ProfilerEvent]):
+def source_code_location(event: Optional[_ProfilerEvent]) -> str:
while event:
if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall:
assert isinstance(
@@ -611,7 +613,7 @@ def report_all_anti_patterns(
should_benchmark: bool = False,
print_enable: bool = True,
json_report_dir: Optional[str] = None,
-):
+) -> None:
report_dict: dict = {}
anti_patterns = [
ExtraCUDACopyPattern(prof, should_benchmark),


@@ -52,7 +52,7 @@ class Interval:
class EventKey:
-def __init__(self, event):
+def __init__(self, event) -> None:
self.event = event
def __hash__(self):
@@ -61,7 +61,7 @@ class EventKey:
def __eq__(self, other):
return self.event.id == other.event.id
-def __repr__(self):
+def __repr__(self) -> str:
return f"{self.event.name}"
def intervals_overlap(self, intervals: list[Interval]):
@@ -98,7 +98,7 @@ class EventKey:
class BasicEvaluation:
-def __init__(self, prof: profile):
+def __init__(self, prof: profile) -> None:
self.profile = prof
self.metrics: dict[EventKey, EventMetrics] = {}
self.compute_self_time()
@@ -110,7 +110,7 @@ class BasicEvaluation:
self.queue_depth_list = self.compute_queue_depth()
self.compute_idle_time()
-def compute_self_time(self):
+def compute_self_time(self) -> None:
"""
Computes event's self time(total time - time in child ops).
"""
@@ -234,7 +234,7 @@ class BasicEvaluation:
return queue_depth_list
-def compute_idle_time(self):
+def compute_idle_time(self) -> None:
"""
Computes idle time of the profile.
"""
@@ -386,7 +386,7 @@ def source_code_location(event):
# https://github.com/pytorch/pytorch/issues/75504
# TODO(dberard) - deprecate / remove workaround for CUDA >= 12, when
# we stop supporting older CUDA versions.
-def _init_for_cuda_graphs():
+def _init_for_cuda_graphs() -> None:
from torch.autograd.profiler import profile
with profile():


@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
+from typing import NoReturn
try:
@@ -8,13 +9,13 @@ except ImportError:
class _ITTStub:
@staticmethod
-def _fail(*args, **kwargs):
+def _fail(*args, **kwargs) -> NoReturn:
raise RuntimeError(
"ITT functions not installed. Are you sure you have a ITT build?"
)
@staticmethod
-def is_available():
+def is_available() -> bool:
return False
rangePush = _fail


@@ -145,7 +145,7 @@ class _KinetoProfile:
execution_trace_observer: Optional[_ITraceObserver] = None,
acc_events: bool = False,
custom_trace_id_callback: Optional[Callable[[], str]] = None,
-):
+) -> None:
self.activities = set(activities) if activities else supported_activities()
self.record_shapes = record_shapes
self.with_flops = with_flops
@@ -174,14 +174,14 @@
# user-defined metadata to be amended to the trace
self.preset_metadata: dict[str, str] = {}
-def start(self):
+def start(self) -> None:
self.prepare_trace()
self.start_trace()
-def stop(self):
+def stop(self) -> None:
self.stop_trace()
-def prepare_trace(self):
+def prepare_trace(self) -> None:
if hasattr(torch, "_inductor"):
import torch._inductor.config as inductor_config
@@ -202,7 +202,7 @@
)
self.profiler._prepare_trace()
-def start_trace(self):
+def start_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.start()
assert self.profiler is not None
@@ -248,7 +248,7 @@
for k, v in self.preset_metadata.items():
self.add_metadata_json(k, v)
-def stop_trace(self):
+def stop_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.stop()
assert self.profiler is not None
@@ -284,7 +284,7 @@
def toggle_collection_dynamic(
self, enable: bool, activities: Iterable[ProfilerActivity]
-):
+) -> None:
"""Toggle collection of activities on/off at any point of collection. Currently supports toggling Torch Ops
(CPU) and CUDA activity supported in Kineto
@@ -341,7 +341,7 @@
assert self.profiler
return self.profiler.function_events
-def add_metadata(self, key: str, value: str):
+def add_metadata(self, key: str, value: str) -> None:
"""
Adds a user defined metadata with a string key and a string value
into the trace file
@@ -349,14 +349,14 @@
wrapped_value = '"' + value.replace('"', '\\"') + '"'
torch.autograd._add_metadata_json(key, wrapped_value)
-def add_metadata_json(self, key: str, value: str):
+def add_metadata_json(self, key: str, value: str) -> None:
"""
Adds a user defined metadata with a string key and a valid json value
into the trace file
"""
torch.autograd._add_metadata_json(key, value)
-def preset_metadata_json(self, key: str, value: str):
+def preset_metadata_json(self, key: str, value: str) -> None:
"""
Preset a user defined metadata when the profiler is not started
and added into the trace file later.
@@ -702,7 +702,7 @@ class profile(_KinetoProfile):
# deprecated:
use_cuda: Optional[bool] = None,
custom_trace_id_callback: Optional[Callable[[], str]] = None,
-):
+) -> None:
activities_set = set(activities) if activities else supported_activities()
if use_cuda is not None:
warn(
@@ -818,7 +818,7 @@
if self.execution_trace_observer:
self.execution_trace_observer.cleanup()
-def start(self):
+def start(self) -> None:
self._transit_action(ProfilerAction.NONE, self.current_action)
if self.record_steps:
self.step_rec_fn = prof.record_function(
@@ -826,12 +826,12 @@
)
self.step_rec_fn.__enter__()
-def stop(self):
+def stop(self) -> None:
if self.record_steps and self.step_rec_fn:
self.step_rec_fn.__exit__(None, None, None)
self._transit_action(self.current_action, None)
-def step(self):
+def step(self) -> None:
"""
Signals the profiler that the next profiling step has started.
"""
@@ -853,7 +853,7 @@
)
self.step_rec_fn.__enter__()
-def set_custom_trace_id_callback(self, callback):
+def set_custom_trace_id_callback(self, callback) -> None:
"""
Sets a callback to be called when a new trace ID is generated.
"""
@@ -867,11 +867,11 @@
return None
return self.profiler.trace_id
-def _trace_ready(self):
+def _trace_ready(self) -> None:
if self.on_trace_ready:
self.on_trace_ready(self)
-def _transit_action(self, prev_action, current_action):
+def _transit_action(self, prev_action, current_action) -> None:
action_list = self.action_map.get((prev_action, current_action))
if action_list:
for action in action_list:
@@ -909,7 +909,7 @@ class ExecutionTraceObserver(_ITraceObserver):
self.output_file_path: str = ""
self.output_file_path_observer: str = ""
-def __del__(self):
+def __del__(self) -> None:
"""
Calls unregister_callback() to make sure to finalize outputs.
"""
@@ -1021,7 +1021,7 @@ class ExecutionTraceObserver(_ITraceObserver):
return None
return resource_dir
-def unregister_callback(self):
+def unregister_callback(self) -> None:
"""
Removes ET observer from record function callbacks.
"""
@@ -1087,7 +1087,7 @@
"""
return self._execution_trace_running
-def start(self):
+def start(self) -> None:
"""
Starts to capture.
"""
@@ -1096,7 +1096,7 @@
self._execution_trace_running = True
self._record_pg_config()
-def stop(self):
+def stop(self) -> None:
"""
Stops to capture.
"""
@@ -1104,7 +1104,7 @@
_disable_execution_trace_observer()
self._execution_trace_running = False
-def cleanup(self):
+def cleanup(self) -> None:
"""
Calls unregister_callback() to make sure to finalize outputs.
"""