Files
pytorch/benchmarks/instruction_counts/core/utils.py
Aaron Orenstein 07669ed960 PEP585 update - benchmarks tools torchgen (#145101)
This is one of a series of PRs to update us to PEP585 (changing Dict -> dict, List -> list, etc).  Most of the PRs were completely automated with RUFF as follows:

Since RUFF UP006 is considered an "unsafe" fix, we first need to enable unsafe fixes:

```
--- a/tools/linter/adapters/ruff_linter.py
+++ b/tools/linter/adapters/ruff_linter.py
@@ -313,6 +313,7 @@
                     "ruff",
                     "check",
                     "--fix-only",
+                    "--unsafe-fixes",
                     "--exit-zero",
                     *([f"--config={config}"] if config else []),
                     "--stdin-filename",
```

Then we need to tell RUFF to allow UP006 (this will be made permanent in a final PR once all of these have landed):

```
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@

 [tool.ruff]
-target-version = "py38"
+target-version = "py39"
 line-length = 88
 src = ["caffe2", "torch", "torchgen", "functorch", "test"]

@@ -87,7 +87,6 @@
     "SIM116", # Disable Use a dictionary instead of consecutive `if` statements
     "SIM117",
     "SIM118",
-    "UP006", # keep-runtime-typing
     "UP007", # keep-runtime-typing
 ]
 select = [
```

Finally running `lintrunner -a --take RUFF` will fix up the deprecated uses.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145101
Approved by: https://github.com/bobrenjc93
2025-01-18 05:05:07 +00:00

104 lines
3.5 KiB
Python

# mypy: ignore-errors
import atexit
import re
import shutil
import textwrap
from typing import Optional
from core.api import GroupedBenchmark, TimerArgs
from core.types import Definition, FlatIntermediateDefinition, Label
from torch.utils.benchmark.utils.common import _make_temp_dir
# Cached path of the shared temporary directory. Starts as None and is
# assigned by `get_temp_dir()` the first time that function is called.
_TEMPDIR: Optional[str] = None
def get_temp_dir() -> str:
    """Return the module-wide temporary directory, creating it on first use.

    The first call creates the directory through `_make_temp_dir`, caches the
    resulting path in the module global `_TEMPDIR`, and registers
    `shutil.rmtree` with `atexit` so the directory is removed when the
    interpreter exits. Later calls return the cached path unchanged.

    Returns:
        The path of the temporary directory.
    """
    global _TEMPDIR

    # Fast path: the directory was already created by an earlier call.
    if _TEMPDIR is not None:
        return _TEMPDIR

    _TEMPDIR = _make_temp_dir(
        prefix="instruction_count_microbenchmarks",
        gc_dev_shm=True,
    )
    # Clean up exactly once, at interpreter shutdown.
    atexit.register(shutil.rmtree, path=_TEMPDIR)
    return _TEMPDIR
def _flatten(
    key_prefix: Label, sub_schema: Definition, result: FlatIntermediateDefinition
) -> None:
    """Recursively copy the leaves of `sub_schema` into `result`.

    Each key of `sub_schema` is normalised to a tuple of strings and appended
    to `key_prefix`:
      - a tuple key is used as is (every element must be a `str`),
      - a `None` key contributes nothing to the label,
      - any other key must be a `str` and contributes a single element.

    Values that are `TimerArgs` or `GroupedBenchmark` instances are stored in
    `result` under the combined key. Every other value must be a `dict` and
    is flattened recursively with the combined key as the new prefix.

    Args:
        key_prefix: Label accumulated from the enclosing levels.
        sub_schema: The (possibly nested) mapping to walk.
        result: Output mapping, mutated in place.

    Raises:
        AssertionError: On a key or value of an unexpected type, or when two
            entries resolve to the same flattened key.
    """
    for raw_key, node in sub_schema.items():
        # Normalise the key at this level into a tuple of strings.
        if raw_key is None:
            key_suffix: Label = ()
        elif isinstance(raw_key, tuple):
            assert all(isinstance(ki, str) for ki in raw_key)
            key_suffix = raw_key
        else:
            assert isinstance(raw_key, str)
            key_suffix = (raw_key,)

        key: Label = key_prefix + key_suffix

        # Interior node: descend with the extended prefix.
        if not isinstance(node, (TimerArgs, GroupedBenchmark)):
            assert isinstance(node, dict)
            _flatten(key_prefix=key, sub_schema=node, result=result)
            continue

        # Leaf node: record it, refusing to overwrite an existing entry.
        assert key not in result, f"duplicate key: {key}"
        result[key] = node
def flatten(schema: Definition) -> FlatIntermediateDefinition:
    """Convert a nested definition into a flat one.

    See types.py for an explanation of nested vs. flat definitions.

    Args:
        schema: The nested definition to flatten.

    Returns:
        A new mapping whose keys are tuples of strings and whose values are
        `TimerArgs` or `GroupedBenchmark` instances.

    Raises:
        AssertionError: If flattening fails or the produced mapping does not
            have the expected key/value types.
    """
    flat: FlatIntermediateDefinition = {}
    _flatten(key_prefix=(), sub_schema=schema, result=flat)

    # Sanity-check the output: every entry must be a well-formed flat item.
    for label, entry in flat.items():
        assert isinstance(label, tuple)
        assert all(isinstance(part, str) for part in label)
        assert isinstance(entry, (TimerArgs, GroupedBenchmark))

    return flat
def parse_stmts(stmts: str) -> tuple[str, str]:
    """Helper function for side-by-side Python and C++ stmts.

    For more complex statements, it can be useful to see Python and C++ code
    side by side. To this end, we provide an **extremely restricted** way
    to define Python and C++ code side-by-side. The schema should be mostly
    self explanatory, with the following non-obvious caveats:
      - Width for the left (Python) column MUST be 40 characters.
      - The column separator is " | ", not "|". Whitespace matters.

    The input is dedented and stripped first. The first line must be the
    column header, the second line the dashed separator row, and every
    following line one row of code.

    Args:
        stmts: The side-by-side table as a single (possibly indented) string.

    Returns:
        A `(python_source, cpp_source)` tuple. Each element is the rows of
        that column joined with newlines. Python rows keep their 40 character
        padding; a row with no C++ cell contributes an empty string.

    Raises:
        ValueError: If the header line or a code row does not match the
            expected layout.
        AssertionError: If there are fewer than three lines, the separator
            row is malformed, or a row cannot be reproduced from its parsed
            columns.
    """
    stmts = textwrap.dedent(stmts).strip()
    lines: list[str] = stmts.splitlines(keepends=False)
    assert len(lines) >= 3, f"Invalid string:\n{stmts}"

    column_header_pattern = r"^Python\s{35}\| C\+\+(\s*)$"
    # The pipe must be escaped. Unescaped, it is regex alternation, and any
    # line that merely starts with 40 dashes and a space would be accepted.
    # Trailing whitespace is tolerated, as it is for the header row.
    separation_pattern = r"^[-]{40} \| [-]{40}\s*$"
    # Compiled once because it is applied to every code row.
    code_regex = re.compile(r"^(.{40}) \|($| (.*)$)")

    if re.search(column_header_pattern, lines[0]) is None:
        raise ValueError(
            f"Column header `{lines[0]}` "
            f"does not match pattern `{column_header_pattern}`"
        )

    assert re.search(separation_pattern, lines[1]), (
        f"Separator row `{lines[1]}` "
        f"does not match pattern `{separation_pattern}`"
    )

    py_lines: list[str] = []
    cpp_lines: list[str] = []
    for line in lines[2:]:
        row_match = code_regex.search(line)
        if row_match is None:
            raise ValueError(f"Invalid line `{line}`")

        py_cell = row_match.group(1)
        # Group 3 is None when the row ends right after the separator.
        cpp_cell = row_match.group(3) or ""
        py_lines.append(py_cell)
        cpp_lines.append(cpp_cell)

        # Make sure we can round trip for correctness.
        rebuilt = f"{py_cell:<40} | {cpp_cell:<40}".rstrip()
        assert rebuilt == line.rstrip(), f"Failed to round trip `{line}`"

    return "\n".join(py_lines), "\n".join(cpp_lines)