Adding Compare in torch.utils.benchmark documentation (#125009)

`torch.utils.benchmark.Compare` is not directly exposed in torch.utils.benchmark documentation.

I think this is a valuable resource to add since it can help people embracing the torch benchmark way of doing things, and help people building documentation towards it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125009
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Alexandre Ghelfi, PhD
2024-05-03 00:50:49 +00:00
committed by PyTorch MergeBot
parent 4440d0755a
commit d18a6f46d0
2 changed files with 29 additions and 0 deletions

View File

@ -19,6 +19,9 @@ Benchmark Utils - torch.utils.benchmark
.. autoclass:: FunctionCounts
:members:
.. autoclass:: Compare
:members:
.. These are missing documentation. Adding them here until a better place
.. is made in this file.
.. py:module:: torch.utils.benchmark.examples

View File

@ -267,6 +267,21 @@ Times are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).
class Compare:
"""Helper class for displaying the results of many measurements in a
formatted table.
The table format is based on the information fields provided in
:class:`torch.utils.benchmark.Timer` (`description`, `label`, `sub_label`,
`num_threads`, etc).
The table can be directly printed using :meth:`print` or casted as a `str`.
For a full tutorial on how to use this class, see:
https://pytorch.org/tutorials/recipes/recipes/benchmark.html
Args:
results: List of Measurment to display.
"""
def __init__(self, results: List[common.Measurement]):
self._results: List[common.Measurement] = []
self.extend_results(results)
@ -278,6 +293,10 @@ class Compare:
return "\n".join(self._render())
def extend_results(self, results):
"""Append results to already stored ones.
All added results must be instances of ``Measurement``.
"""
for r in results:
if not isinstance(r, common.Measurement):
raise ValueError(
@ -286,15 +305,22 @@ class Compare:
self._results.extend(results)
def trim_significant_figures(self):
"""Enables trimming of significant figures when building the formatted table."""
self._trim_significant_figures = True
def colorize(self, rowwise=False):
"""Colorize formatted table.
Colorize columnwise by default.
"""
self._colorize = Colorize.ROWWISE if rowwise else Colorize.COLUMNWISE
def highlight_warnings(self):
"""Enables warning highlighting when building formatted table."""
self._highlight_warnings = True
def print(self):
"""Print formatted table"""
print(str(self))
def _render(self):