Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Profiler] Improve the docstring for export_memory_timeline (#110949)
Summary:
Add more details about the export_memory_timeline API, as we've landed new representations of the memory timeline data.

Test Plan:
CI, should be no functional change, as we only changed comments.

Differential Revision: D50123450

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110949
Approved by: https://github.com/davidberard98
Committed by: PyTorch MergeBot
Parent: 31611b40b9
Commit: 52b1470935
@@ -54,7 +54,8 @@ class _KinetoProfile:
             ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``.
             Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA.
         record_shapes (bool): save information about operator's input shapes.
-        profile_memory (bool): track tensor memory allocation/deallocation.
+        profile_memory (bool): track tensor memory allocation/deallocation (see ``export_memory_timeline``
+            for more details).
         with_stack (bool): record source information (file and line number) for the ops.
         with_flops (bool): use formula to estimate the FLOPS of specific operators
             (matrix multiplication and 2D convolution).
@@ -251,14 +252,26 @@ class _KinetoProfile:
         return MemoryProfile(self.profiler.kineto_results)
 
     def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
-        """Extract the memory information from the memory profile collected
-        tree for a given device, and export a timeline plot consisting of
-        [times, [sizes by category]], where times are timestamps and sizes
-        are memory usage for each category. The memory timeline plot will
-        be saved a JSON (by default) or gzipped JSON.
+        """Export memory event information from the profiler collected
+        tree for a given device, and export a timeline plot. There are 3
+        exportable files using ``export_memory_timeline``, each controlled by the
+        ``path``'s suffix.
 
-        Input: (path of file, device)
-        Output: File written as JSON or gzipped JSON
+        - For an HTML compatible plot, use the suffix ``.html``, and a memory timeline
+          plot will be embedded as a PNG file in the HTML file.
+
+        - For plot points consisting of ``[times, [sizes by category]]``, where
+          ``times`` are timestamps and ``sizes`` are memory usage for each category.
+          The memory timeline plot will be saved a JSON (``.json``) or gzipped JSON
+          (``.json.gz``) depending on the suffix.
+
+        - For raw memory points, use the suffix ``.raw.json.gz``. Each raw memory
+          event will consist of ``(timestamp, action, numbytes, category)``, where
+          ``action`` is one of ``[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]``,
+          and ``category`` is one of the enums from
+          ``torch.profiler._memory_profiler.Category``.
+
+        Output: Memory timeline written as gzipped JSON, JSON, or HTML.
         """
         # Default to device 0, if unset. Fallback on cpu.
         if device is None and self.use_device and self.use_device != "cuda":
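
The ``[times, [sizes by category]]`` layout described in the new docstring can be read back with the standard library alone. The sketch below is not part of this commit. It assumes that ``sizes[i]`` holds the per-category byte counts at ``times[i]``, which is one reading of the docstring, and it uses a small synthetic file in place of real profiler output. The helper names ``load_timeline`` and ``peak_total`` are hypothetical.

```python
import gzip
import json
import os
import tempfile


def load_timeline(path):
    """Load a memory timeline saved with a ``.json`` or ``.json.gz`` suffix."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt") as f:
        return json.load(f)


def peak_total(timeline):
    """Return (timestamp, total_bytes) at the point of highest total usage.

    Assumption: timeline == [times, sizes], and sizes[i] lists the bytes
    used by each category at times[i].
    """
    times, sizes = timeline
    totals = [sum(row) for row in sizes]
    i = max(range(len(totals)), key=totals.__getitem__)
    return times[i], totals[i]


# Synthetic stand-in for a file written by
# prof.export_memory_timeline("timeline.json.gz"); the numbers are made up.
sample = [[0, 10, 20], [[0, 0], [512, 1024], [256, 0]]]
tmp_dir = tempfile.mkdtemp()
sample_path = os.path.join(tmp_dir, "timeline.json.gz")
with gzip.open(sample_path, "wt") as f:
    json.dump(sample, f)

peak_time, peak_bytes = peak_total(load_timeline(sample_path))
print(peak_time, peak_bytes)  # prints "10 1536"
```

To produce a real file, a profiler created with ``torch.profiler.profile(profile_memory=True, record_shapes=True, with_stack=True)`` would call ``prof.export_memory_timeline(path, device)`` after the profiled region. Which options are strictly required may depend on the PyTorch version.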