Files
pytorch/torch/_C/_autograd.pyi
Shivam Raikundalia 3ebbeb75fd [Profiler] Make Kineto traces export ns granularity for finer timestamps (#122425) (#123650)
Summary:

Kineto traces use microsecond-level granularity because Chrome tracing defaults to that precision. Fix this by adding a preprocessor flag to the TARGETS and BUCK files. Also remove any unnecessary ns-to-us conversions made in the profiler itself.

This diff contains profiler changes only. Libkineto changes found in D54964435.

Test Plan:
Check the JSON and Chrome tracing output to make sure values are as expected. Traces taken with the flags enabled should have ns precision. Traces taken without the flags should be the same as on master.
Zoomer: https://www.internalfb.com/intern/zoomer/?profiling_run_fbid=796886748550189
Ran key_averages() to make sure the FunctionEvent code is working as expected:
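-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------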
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls

                                          ProfilerStep*         0.74%       3.976ms        64.40%     346.613ms      69.323ms       0.000us         0.00%      61.710ms      12.342ms             5
                      Optimizer.zero_grad#SGD.zero_grad         0.76%       4.109ms         0.76%       4.109ms     821.743us       0.000us         0.00%       0.000us       0.000us             5
                                          ## forward ##         6.89%      37.057ms        27.19%     146.320ms      29.264ms       0.000us         0.00%      58.708ms      11.742ms             5
                                           aten::conv2d         0.22%       1.176ms         7.74%      41.658ms     157.199us       0.000us         0.00%      27.550ms     103.962us           265
                                      aten::convolution         0.79%       4.273ms         7.52%      40.482ms     152.762us       0.000us         0.00%      27.550ms     103.962us           265
                                     aten::_convolution         0.69%       3.688ms         6.73%      36.209ms     136.637us       0.000us         0.00%      27.550ms     103.962us           265
                                aten::cudnn_convolution         6.04%      32.520ms         6.04%      32.520ms     122.719us      27.550ms         8.44%      27.550ms     103.962us           265
                                             aten::add_         2.42%      13.045ms         2.42%      13.045ms      30.694us      12.700ms         3.89%      12.700ms      29.882us           425
                                       aten::batch_norm         0.19%       1.027ms         8.12%      43.717ms     164.971us       0.000us         0.00%      16.744ms      63.185us           265
                           aten::_batch_norm_impl_index         0.31%       1.646ms         7.93%      42.691ms     161.096us       0.000us         0.00%      16.744ms      63.185us           265
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------

Differential Revision: D55925068

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123650
Approved by: https://github.com/aaronenyeshi
2024-04-11 04:29:20 +00:00

125 lines
4.1 KiB
Python

from enum import Enum
from typing import Any, Callable, List, Optional, Set
import torch
from ._profiler import (
_ProfilerEvent,
ActiveProfilerType,
ProfilerActivity,
ProfilerConfig,
)
# Defined in tools/autograd/init.cpp
# Enum of device type names. `_KinetoEvent.device_type()` in this stub is
# annotated as returning one of these members. Member values are elided
# (`...`) because a stub only declares names for type checkers.
class DeviceType(Enum):
    CPU = ...
    CUDA = ...
    MKLDNN = ...
    OPENGL = ...
    OPENCL = ...
    IDEEP = ...
    HIP = ...
    FPGA = ...
    ORT = ...
    XLA = ...
    MPS = ...
    HPU = ...
    Meta = ...
    Vulkan = ...
    Metal = ...
    PrivateUse1 = ...
# Stub for the ProfilerEvent binding. Elsewhere in this file,
# `_ProfilerResult.legacy_events()` and `_disable_profiler_legacy()` are
# annotated as returning nested lists of these events.
class ProfilerEvent:
    # The three *_elapsed_us accessors take another ProfilerEvent and are
    # annotated as returning float (microseconds, per the `_us` suffix).
    def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def cpu_memory_usage(self) -> int: ...
    def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def cuda_memory_usage(self) -> int: ...
    def device(self) -> int: ...
    def handle(self) -> int: ...
    def has_cuda(self) -> bool: ...
    def is_remote(self) -> bool: ...
    def kind(self) -> int: ...
    def name(self) -> str: ...
    def node_id(self) -> int: ...
    def sequence_nr(self) -> int: ...
    def shapes(self) -> List[List[int]]: ...
    def thread_id(self) -> int: ...
    def flops(self) -> float: ...
    def is_async(self) -> bool: ...
# Stub for the _KinetoEvent binding; `_ProfilerResult.events()` in this file
# is annotated as returning a list of these.
class _KinetoEvent:
    def name(self) -> str: ...
    def device_index(self) -> int: ...
    def device_resource_id(self) -> int: ...
    # Start time and duration are integer values; nanoseconds per the
    # `_ns` suffix.
    def start_ns(self) -> int: ...
    def duration_ns(self) -> int: ...
    def is_async(self) -> bool: ...
    def linked_correlation_id(self) -> int: ...
    def shapes(self) -> List[List[int]]: ...
    def dtypes(self) -> List[str]: ...
    def concrete_inputs(self) -> List[Any]: ...
    def device_type(self) -> DeviceType: ...
    def start_thread_id(self) -> int: ...
    def end_thread_id(self) -> int: ...
    def correlation_id(self) -> int: ...
    def fwd_thread_id(self) -> int: ...
    def stack(self) -> List[str]: ...
    def scope(self) -> int: ...
    def sequence_nr(self) -> int: ...
    def flops(self) -> int: ...
    # Unlike ProfilerEvent's *_elapsed_us, these take no second event and
    # are annotated as int; they keep the `_us` suffix.
    def cuda_elapsed_us(self) -> int: ...
    def privateuse1_elapsed_us(self) -> int: ...
# Stub for the result object that `_disable_profiler()` in this file is
# annotated as returning.
class _ProfilerResult:
    def events(self) -> List[_KinetoEvent]: ...
    def legacy_events(self) -> List[List[ProfilerEvent]]: ...
    # Takes a path string and is declared to return None.
    def save(self, path: str) -> None: ...
    def experimental_event_tree(self) -> List[_ProfilerEvent]: ...
    # Integer start time of the trace; nanoseconds per the `_ns` suffix.
    def trace_start_ns(self) -> int: ...
# Opaque type: this stub declares the SavedTensor name but no members for it.
class SavedTensor: ...
# Signature only: accepts a ProfilerConfig plus a set of ProfilerActivity
# values and is declared to return nothing.
def _enable_profiler(
    config: ProfilerConfig,
    activities: Set[ProfilerActivity],
) -> None: ...
# Signature only: takes the same (config, activities) parameter list as
# `_enable_profiler` and is declared to return nothing.
def _prepare_profiler(
    config: ProfilerConfig,
    activities: Set[ProfilerActivity],
) -> None: ...
# Profiler-related module-level bindings. `_disable_profiler` is the only one
# here annotated as returning a `_ProfilerResult`.
def _disable_profiler() -> _ProfilerResult: ...
def _profiler_enabled() -> bool: ...
# Takes a key and a value, both plain strings.
def _add_metadata_json(key: str, value: str) -> None: ...
def _kineto_step() -> None: ...
def _get_sequence_nr() -> int: ...
def kineto_available() -> bool: ...
# NOTE(review): the Tensor returned by `_enter` is presumably the `handle`
# that `_exit` expects — confirm against the C++ binding.
def _record_function_with_args_enter(name: str, *args: Any) -> torch.Tensor: ...
def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
def _supported_activities() -> Set[ProfilerActivity]: ...
def _enable_record_function(enable: bool) -> None: ...
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
# Signature only: takes a pack/unpack callable pair. Per the annotations,
# pack_hook maps a Tensor to an arbitrary object and unpack_hook maps that
# object back to a Tensor. Declared to return nothing.
def _push_saved_tensors_default_hooks(
    pack_hook: Callable[[torch.Tensor], Any],
    unpack_hook: Callable[[Any], torch.Tensor],
) -> None: ...
# Takes no arguments; by its name, the counterpart of
# `_push_saved_tensors_default_hooks`.
def _pop_saved_tensors_default_hooks() -> None: ...
# Takes a tensor and an integer `prev_version`; declared to return None.
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
# Legacy profiler entry points: enable takes only a config, and disable is
# annotated as returning nested lists of ProfilerEvent.
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
def _profiler_type() -> ActiveProfilerType: ...
# Saved-tensor hook switches. Disabling takes a message string; the getter
# returns Optional[str] (presumably None when hooks are not disabled —
# TODO confirm).
def _saved_tensors_hooks_enable() -> None: ...
def _saved_tensors_hooks_disable(message: str) -> None: ...
def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ...
# Enum of creation-meta tags. `_set_creation_meta` and `_get_creation_meta`
# in this stub take and return these members. Member values are elided
# (`...`) because a stub only declares names for type checkers.
class CreationMeta(Enum):
    DEFAULT = ...
    IN_CUSTOM_FUNCTION = ...
    MULTI_OUTPUT_NODE = ...
    NO_GRAD_MODE = ...
    INFERENCE_MODE = ...
# Setter/getter pair: the setter takes a tensor and a CreationMeta member,
# and the getter takes a tensor and is annotated as returning a CreationMeta.
def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ...
def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ...