mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126 Approved by: https://github.com/kit1980
103 lines
3.5 KiB
Python
103 lines
3.5 KiB
Python
import atexit
|
|
import re
|
|
import shutil
|
|
import textwrap
|
|
from typing import List, Optional, Tuple
|
|
|
|
from core.api import GroupedBenchmark, TimerArgs
|
|
from core.types import Definition, FlatIntermediateDefinition, Label
|
|
|
|
from torch.utils.benchmark.utils.common import _make_temp_dir
|
|
|
|
|
|
# Path of the scratch directory; stays None until `get_temp_dir` is first called.
_TEMPDIR: Optional[str] = None


def get_temp_dir() -> str:
    """Return the module's scratch directory, creating it on first use.

    The first call creates a directory through `_make_temp_dir` (prefix
    "instruction_count_microbenchmarks", `gc_dev_shm=True`), stores its path
    in the module-level `_TEMPDIR`, and registers `shutil.rmtree` with
    `atexit` so the directory is removed when the interpreter exits. Every
    later call returns the stored path without creating anything.
    """
    global _TEMPDIR

    # Fast path: the directory was already created by an earlier call.
    if _TEMPDIR is not None:
        return _TEMPDIR

    created = _make_temp_dir(
        prefix="instruction_count_microbenchmarks", gc_dev_shm=True
    )
    _TEMPDIR = created
    atexit.register(shutil.rmtree, path=created)
    return created
|
|
|
|
|
|
def _flatten(
    key_prefix: Label, sub_schema: Definition, result: FlatIntermediateDefinition
) -> None:
    """Recursively copy the leaves of `sub_schema` into `result`.

    Each key of `sub_schema` contributes a suffix to the label built so far:
    a tuple of strings is appended as-is, a single string is appended as a
    one-element tuple, and `None` appends nothing. A value that is a
    `TimerArgs` or `GroupedBenchmark` is stored in `result` under the
    combined label; any other value must be a dict and is descended into
    with the combined label as the new prefix.

    `result` is mutated in place; nothing is returned. An `AssertionError`
    is raised for a key or value of an unexpected type, or when a label is
    produced twice.
    """
    for raw_key, node in sub_schema.items():
        suffix: Label
        if raw_key is None:
            suffix = ()
        elif isinstance(raw_key, tuple):
            assert all(isinstance(part, str) for part in raw_key)
            suffix = raw_key
        else:
            assert isinstance(raw_key, str)
            suffix = (raw_key,)

        label: Label = key_prefix + suffix

        # Interior node: keep descending, carrying the label built so far.
        if not isinstance(node, (TimerArgs, GroupedBenchmark)):
            assert isinstance(node, dict)
            _flatten(key_prefix=label, sub_schema=node, result=result)
            continue

        # Leaf node: record it, refusing to overwrite an existing entry.
        assert label not in result, f"duplicate key: {label}"
        result[label] = node
|
|
|
|
|
|
def flatten(schema: Definition) -> FlatIntermediateDefinition:
    """Convert a nested definition into a flat one.

    See types.py for an explanation of nested vs. flat definitions.

    The work is delegated to `_flatten`, starting from an empty label. The
    result is then re-checked: every key must be a tuple of strings and
    every value a `TimerArgs` or `GroupedBenchmark`; a violation raises
    `AssertionError`.
    """
    flat: FlatIntermediateDefinition = {}
    _flatten(key_prefix=(), sub_schema=schema, result=flat)

    # Ensure that we produced a valid flat definition.
    for label, entry in flat.items():
        assert isinstance(label, tuple)
        assert all(isinstance(part, str) for part in label)
        assert isinstance(entry, (TimerArgs, GroupedBenchmark))

    return flat
|
|
|
|
|
|
def parse_stmts(stmts: str) -> Tuple[str, str]:
    """Helper function for side-by-side Python and C++ stmts.

    For more complex statements, it can be useful to see Python and C++ code
    side by side. To this end, we provide an **extremely restricted** way
    to define Python and C++ code side-by-side. The schema should be mostly
    self explanatory, with the following non-obvious caveats:
      - Width for the left (Python) column MUST be 40 characters.
      - The column separator is " | ", not "|". Whitespace matters.

    The input is dedented and stripped, then read as:
      line 0:  the column header ("Python", padding, "| C++")
      line 1:  a separator row of 40 dashes, " | ", and 40 dashes
      line 2+: code rows, a 40 character Python cell then " |" and an
               optional C++ cell

    Args:
        stmts: The side-by-side table as a single (possibly indented) string.

    Returns:
        A `(python_source, cpp_source)` tuple. Each element is the
        corresponding column's cells joined with newlines. Python cells keep
        their padding to 40 characters; a missing C++ cell becomes an empty
        line.

    Raises:
        AssertionError: If there are fewer than three lines, the separator
            row is malformed, or a code row cannot be reproduced from its
            parsed cells.
        ValueError: If the column header or a code row does not match the
            expected layout.
    """
    stmts = textwrap.dedent(stmts).strip()
    lines: List[str] = stmts.splitlines(keepends=False)
    assert len(lines) >= 3, f"Invalid string:\n{stmts}"

    column_header_pattern = r"^Python\s{35}\| C\+\+(\s*)$"
    # The pipe must be escaped: a bare `|` is regex alternation, which would
    # accept any line that merely starts with 40 dashes and a space. Trailing
    # whitespace is tolerated, as it is for the column header.
    separation_pattern = r"^[-]{40} \| [-]{40}\s*$"
    # Compiled once because it is applied to every code row.
    code_pattern = re.compile(r"^(.{40}) \|($| (.*)$)")

    if re.search(column_header_pattern, lines[0]) is None:
        raise ValueError(
            f"Column header `{lines[0]}` "
            f"does not match pattern `{column_header_pattern}`"
        )

    assert re.search(separation_pattern, lines[1]), (
        f"Separator `{lines[1]}` does not match pattern `{separation_pattern}`"
    )

    py_lines: List[str] = []
    cpp_lines: List[str] = []
    for line in lines[2:]:
        line_match = code_pattern.search(line)
        if line_match is None:
            raise ValueError(f"Invalid line `{line}`")

        py_code = line_match.group(1)
        # Group 3 is None when the row ends right after the column separator.
        cpp_code = line_match.group(3) or ""
        py_lines.append(py_code)
        cpp_lines.append(cpp_code)

        # Make sure we can round trip for correctness.
        round_tripped = f"{py_code:<40} | {cpp_code:<40}".rstrip()
        assert round_tripped == line.rstrip(), f"Failed to round trip `{line}`"

    return "\n".join(py_lines), "\n".join(cpp_lines)
|