Files
pytorch/benchmarks/instruction_counts/core/utils.py
Xuehai Pan 26f4f10ac8 [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
2024-05-27 14:49:57 +00:00

103 lines
3.5 KiB
Python

import atexit
import re
import shutil
import textwrap
from typing import List, Optional, Tuple
from core.api import GroupedBenchmark, TimerArgs
from core.types import Definition, FlatIntermediateDefinition, Label
from torch.utils.benchmark.utils.common import _make_temp_dir
# Module-level cache for the scratch directory; populated on first use.
_TEMPDIR: Optional[str] = None


def get_temp_dir() -> str:
    """Return the scratch directory shared by the instruction-count benchmarks.

    The directory is created lazily, via `_make_temp_dir`, on the first call
    only. Every later call returns the same cached path. When the directory
    is created, `shutil.rmtree` is registered with `atexit` so the directory
    is removed when the interpreter exits.
    """
    global _TEMPDIR
    if _TEMPDIR is not None:
        return _TEMPDIR

    _TEMPDIR = _make_temp_dir(
        prefix="instruction_count_microbenchmarks", gc_dev_shm=True
    )
    atexit.register(shutil.rmtree, path=_TEMPDIR)
    return _TEMPDIR
def _flatten(
    key_prefix: Label, sub_schema: Definition, result: FlatIntermediateDefinition
) -> None:
    """Recursively copy the leaves of ``sub_schema`` into ``result``.

    Each key of ``sub_schema`` contributes a suffix that is appended to
    ``key_prefix``:
      - a tuple of strings is appended as-is,
      - ``None`` appends nothing,
      - a single string is appended as a one-element tuple.

    Values that are ``TimerArgs`` or ``GroupedBenchmark`` are stored in
    ``result`` under the accumulated key. Any other value must be a dict, and
    the function recurses into it with the accumulated key as the new prefix.

    ``result`` is mutated in place. Key types, intermediate value types and
    key uniqueness are checked with ``assert``.
    """
    for raw_key, value in sub_schema.items():
        key_suffix: Label
        if raw_key is None:
            key_suffix = ()
        elif isinstance(raw_key, tuple):
            assert all(isinstance(part, str) for part in raw_key)
            key_suffix = raw_key
        else:
            assert isinstance(raw_key, str)
            key_suffix = (raw_key,)

        full_key: Label = key_prefix + key_suffix
        is_leaf = isinstance(value, (TimerArgs, GroupedBenchmark))
        if not is_leaf:
            # Intermediate node: descend with the extended prefix.
            assert isinstance(value, dict)
            _flatten(key_prefix=full_key, sub_schema=value, result=result)
            continue

        assert full_key not in result, f"duplicate key: {full_key}"
        result[full_key] = value
def flatten(schema: Definition) -> FlatIntermediateDefinition:
    """Convert a nested benchmark definition into its flat form.

    See types.py for an explanation of nested vs. flat definitions.

    The work is done by `_flatten`. Before the mapping is returned, each key
    is checked (with ``assert``) to be a tuple of strings, and each value is
    checked to be a ``TimerArgs`` or ``GroupedBenchmark``.
    """
    flat: FlatIntermediateDefinition = {}
    _flatten(key_prefix=(), sub_schema=schema, result=flat)

    # Sanity-check the output: keys must be tuples of str, values must be leaves.
    for label, benchmark in flat.items():
        assert isinstance(label, tuple)
        assert all(isinstance(part, str) for part in label)
        assert isinstance(benchmark, (TimerArgs, GroupedBenchmark))

    return flat
def parse_stmts(stmts: str) -> Tuple[str, str]:
    """Helper function for side-by-side Python and C++ stmts.

    For more complex statements, it can be useful to see Python and C++ code
    side by side. To this end, we provide an **extremely restricted** way
    to define Python and C++ code side-by-side. The schema should be mostly
    self explanatory, with the following non-obvious caveats:
      - Width for the left (Python) column MUST be 40 characters.
      - The column separator is " | ", not "|". Whitespace matters.

    Expected layout, after `textwrap.dedent` and `str.strip`::

        Python                                   | C++
        ---------------------------------------- | ----------------------------------------
        x = 1                                    | auto x = 1;

    Args:
        stmts: The side-by-side table. Common leading indentation and
            surrounding whitespace are removed before parsing.

    Returns:
        A ``(python_code, cpp_code)`` tuple. Each element is the lines of its
        column joined with newlines. Python lines keep their padding to 40
        characters. A C++ line is the text after the " | " separator, or the
        empty string when the C++ cell is blank.

    Raises:
        ValueError: If there are fewer than three lines (header, separator
            and at least one code row), or if the header, the separator or
            any code row does not match the expected layout.
    """
    stmts = textwrap.dedent(stmts).strip()
    lines: List[str] = stmts.splitlines(keepends=False)
    if len(lines) < 3:
        raise ValueError(f"Invalid string:\n{stmts}")

    # "Python" + 35 spaces puts the "|" at column 41, just like code rows.
    column_header_pattern = r"^Python\s{35}\| C\+\+(\s*)$"
    # The pipe must be escaped. An unescaped "|" is regex alternation, which
    # would accept any line that merely starts (or ends) with 40 dashes.
    separation_pattern = r"^[-]{40} \| [-]{40}$"
    # Group 1 is the Python cell (exactly 40 characters). Group 3 is the C++
    # cell, or None when nothing follows the " |" separator.
    code_pattern = r"^(.{40}) \|($| (.*)$)"

    if re.search(column_header_pattern, lines[0]) is None:
        raise ValueError(
            f"Column header `{lines[0]}` "
            f"does not match pattern `{column_header_pattern}`"
        )

    if re.search(separation_pattern, lines[1]) is None:
        raise ValueError(
            f"Separator line `{lines[1]}` "
            f"does not match pattern `{separation_pattern}`"
        )

    py_lines: List[str] = []
    cpp_lines: List[str] = []
    for line in lines[2:]:
        line_match = re.search(code_pattern, line)
        if line_match is None:
            raise ValueError(f"Invalid line `{line}`")
        py_line, _, cpp_line = line_match.groups()
        py_lines.append(py_line)
        cpp_lines.append(cpp_line or "")

        # Make sure we can round trip for correctness.
        line_from_stmts = f"{py_lines[-1]:<40} | {cpp_lines[-1]:<40}".rstrip()
        assert line_from_stmts == line.rstrip(), f"Failed to round trip `{line}`"

    return "\n".join(py_lines), "\n".join(cpp_lines)