mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adding Compare in torch.utils.benchmark documentation (#125009)
`torch.utils.benchmark.Compare` is not directly exposed in torch.utils.benchmark documentation. I think this is a valuable resource to add since it can help people embracing the torch benchmark way of doing things, and help people building documentation towards it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125009 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
committed by
PyTorch MergeBot
parent
4440d0755a
commit
d18a6f46d0
@ -19,6 +19,9 @@ Benchmark Utils - torch.utils.benchmark
|
||||
.. autoclass:: FunctionCounts
|
||||
:members:
|
||||
|
||||
.. autoclass:: Compare
|
||||
:members:
|
||||
|
||||
.. These are missing documentation. Adding them here until a better place
|
||||
.. is made in this file.
|
||||
.. py:module:: torch.utils.benchmark.examples
|
||||
|
@ -267,6 +267,21 @@ Times are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).
|
||||
|
||||
|
||||
class Compare:
|
||||
"""Helper class for displaying the results of many measurements in a
|
||||
formatted table.
|
||||
|
||||
The table format is based on the information fields provided in
|
||||
:class:`torch.utils.benchmark.Timer` (`description`, `label`, `sub_label`,
|
||||
`num_threads`, etc).
|
||||
|
||||
The table can be directly printed using :meth:`print` or casted as a `str`.
|
||||
|
||||
For a full tutorial on how to use this class, see:
|
||||
https://pytorch.org/tutorials/recipes/recipes/benchmark.html
|
||||
|
||||
Args:
|
||||
results: List of Measurment to display.
|
||||
"""
|
||||
def __init__(self, results: List[common.Measurement]):
|
||||
self._results: List[common.Measurement] = []
|
||||
self.extend_results(results)
|
||||
@ -278,6 +293,10 @@ class Compare:
|
||||
return "\n".join(self._render())
|
||||
|
||||
def extend_results(self, results):
|
||||
"""Append results to already stored ones.
|
||||
|
||||
All added results must be instances of ``Measurement``.
|
||||
"""
|
||||
for r in results:
|
||||
if not isinstance(r, common.Measurement):
|
||||
raise ValueError(
|
||||
@ -286,15 +305,22 @@ class Compare:
|
||||
self._results.extend(results)
|
||||
|
||||
def trim_significant_figures(self):
|
||||
"""Enables trimming of significant figures when building the formatted table."""
|
||||
self._trim_significant_figures = True
|
||||
|
||||
def colorize(self, rowwise=False):
|
||||
"""Colorize formatted table.
|
||||
|
||||
Colorize columnwise by default.
|
||||
"""
|
||||
self._colorize = Colorize.ROWWISE if rowwise else Colorize.COLUMNWISE
|
||||
|
||||
def highlight_warnings(self):
|
||||
"""Enables warning highlighting when building formatted table."""
|
||||
self._highlight_warnings = True
|
||||
|
||||
def print(self):
|
||||
"""Print formatted table"""
|
||||
print(str(self))
|
||||
|
||||
def _render(self):
|
||||
|
Reference in New Issue
Block a user