diff --git a/docs/source/benchmark_utils.rst b/docs/source/benchmark_utils.rst index c93fbfd66c3d..7546179c503f 100644 --- a/docs/source/benchmark_utils.rst +++ b/docs/source/benchmark_utils.rst @@ -19,6 +19,9 @@ Benchmark Utils - torch.utils.benchmark .. autoclass:: FunctionCounts :members: +.. autoclass:: Compare + :members: + .. These are missing documentation. Adding them here until a better place .. is made in this file. .. py:module:: torch.utils.benchmark.examples diff --git a/torch/utils/benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py index 337b742ca069..20122df66718 100644 --- a/torch/utils/benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -267,6 +267,21 @@ Times are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}). class Compare: + """Helper class for displaying the results of many measurements in a + formatted table. + + The table format is based on the information fields provided in + :class:`torch.utils.benchmark.Timer` (`description`, `label`, `sub_label`, + `num_threads`, etc). + + The table can be directly printed using :meth:`print` or casted as a `str`. + + For a full tutorial on how to use this class, see: + https://pytorch.org/tutorials/recipes/recipes/benchmark.html + + Args: + results: List of Measurment to display. + """ def __init__(self, results: List[common.Measurement]): self._results: List[common.Measurement] = [] self.extend_results(results) @@ -278,6 +293,10 @@ class Compare: return "\n".join(self._render()) def extend_results(self, results): + """Append results to already stored ones. + + All added results must be instances of ``Measurement``. + """ for r in results: if not isinstance(r, common.Measurement): raise ValueError( @@ -286,15 +305,22 @@ class Compare: self._results.extend(results) def trim_significant_figures(self): + """Enables trimming of significant figures when building the formatted table.""" self._trim_significant_figures = True def colorize(self, rowwise=False): + """Colorize formatted table. + + Colorize columnwise by default. + """ self._colorize = Colorize.ROWWISE if rowwise else Colorize.COLUMNWISE def highlight_warnings(self): + """Enables warning highlighting when building formatted table.""" self._highlight_warnings = True def print(self): + """Print formatted table""" print(str(self)) def _render(self):