mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
This is one of a series of PRs to update us to PEP585 (changing Dict -> dict, List -> list, etc.). Most of the PRs were completely automated with RUFF as follows: Since RUFF UP006 is considered an "unsafe" fix, first we need to enable unsafe fixes: ``` --- a/tools/linter/adapters/ruff_linter.py +++ b/tools/linter/adapters/ruff_linter.py @@ -313,6 +313,7 @@ "ruff", "check", "--fix-only", + "--unsafe-fixes", "--exit-zero", *([f"--config={config}"] if config else []), "--stdin-filename", ``` Then we need to tell RUFF to allow UP006 (as a final PR, once all of these have landed, this will be made permanent): ``` --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ [tool.ruff] -target-version = "py38" +target-version = "py39" line-length = 88 src = ["caffe2", "torch", "torchgen", "functorch", "test"] @@ -87,7 +87,6 @@ "SIM116", # Disable Use a dictionary instead of consecutive `if` statements "SIM117", "SIM118", - "UP006", # keep-runtime-typing "UP007", # keep-runtime-typing ] select = [ ``` Finally, running `lintrunner -a --take RUFF` will fix up the deprecated uses. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145101 Approved by: https://github.com/bobrenjc93
104 lines
3.5 KiB
Python
104 lines
3.5 KiB
Python
# mypy: ignore-errors
|
|
import atexit
|
|
import re
|
|
import shutil
|
|
import textwrap
|
|
from typing import Optional
|
|
|
|
from core.api import GroupedBenchmark, TimerArgs
|
|
from core.types import Definition, FlatIntermediateDefinition, Label
|
|
|
|
from torch.utils.benchmark.utils.common import _make_temp_dir
|
|
|
|
|
|
# Cached path of the process-wide scratch directory; None until the first
# call to get_temp_dir() creates it.
_TEMPDIR: Optional[str] = None


def get_temp_dir() -> str:
    """Return the shared scratch directory, creating it on first use.

    The first call creates the directory through ``_make_temp_dir`` (with the
    ``instruction_count_microbenchmarks`` prefix and ``gc_dev_shm=True``),
    caches its path in the module-level ``_TEMPDIR``, and registers
    ``shutil.rmtree`` with ``atexit`` so the directory is deleted when the
    interpreter exits. Later calls simply return the cached path.
    """
    global _TEMPDIR
    if _TEMPDIR is not None:
        return _TEMPDIR

    _TEMPDIR = _make_temp_dir(
        prefix="instruction_count_microbenchmarks", gc_dev_shm=True
    )
    # Register cleanup exactly once: this branch only runs on the first call.
    atexit.register(shutil.rmtree, path=_TEMPDIR)
    return _TEMPDIR
|
|
|
|
|
|
def _flatten(
    key_prefix: Label, sub_schema: Definition, result: FlatIntermediateDefinition
) -> None:
    """Recursively copy the leaves of ``sub_schema`` into ``result``.

    Each key of ``sub_schema`` is normalized to a tuple of strings and
    appended to ``key_prefix``:
      - ``None`` contributes no key components,
      - a ``tuple`` (which must contain only ``str``) is used as-is,
      - a ``str`` becomes a one-element tuple.

    Values that are ``TimerArgs`` or ``GroupedBenchmark`` are leaves and are
    stored in ``result`` under the full key. Any other value must be a
    ``dict`` and is descended into with the extended prefix.

    Raises:
        AssertionError: on a malformed key, a non-dict interior value, or a
            full key that is already present in ``result``.
    """
    for raw_key, child in sub_schema.items():
        if raw_key is None:
            suffix: Label = ()
        elif isinstance(raw_key, tuple):
            assert all(isinstance(part, str) for part in raw_key)
            suffix = raw_key
        else:
            assert isinstance(raw_key, str)
            suffix = (raw_key,)

        full_key: Label = key_prefix + suffix

        if not isinstance(child, (TimerArgs, GroupedBenchmark)):
            # Interior node: must be another nested definition.
            assert isinstance(child, dict)
            _flatten(key_prefix=full_key, sub_schema=child, result=result)
            continue

        assert full_key not in result, f"duplicate key: {full_key}"
        result[full_key] = child
|
|
|
|
|
|
def flatten(schema: Definition) -> FlatIntermediateDefinition:
    """Convert a nested benchmark definition into a flat one.

    See types.py for an explanation of nested vs. flat definitions. Every key
    of the returned mapping is a tuple of ``str`` and every value is a
    ``TimerArgs`` or ``GroupedBenchmark``. Both properties are re-checked with
    asserts before the result is returned.
    """
    flat: FlatIntermediateDefinition = {}
    _flatten(key_prefix=(), sub_schema=schema, result=flat)

    # Defensive check that _flatten produced a well-formed flat definition.
    for label, leaf in flat.items():
        assert isinstance(label, tuple)
        assert all(isinstance(part, str) for part in label)
        assert isinstance(leaf, (TimerArgs, GroupedBenchmark))

    return flat
|
|
|
|
|
|
def parse_stmts(stmts: str) -> tuple[str, str]:
    """Helper function for side-by-side Python and C++ stmts.

    For more complex statements, it can be useful to see Python and C++ code
    side by side. To this end, we provide an **extremely restricted** way
    to define Python and C++ code side-by-side. The schema should be mostly
    self explanatory, with the following non-obvious caveats:
      - Width for the left (Python) column MUST be 40 characters.
      - The column separator is " | ", not "|". Whitespace matters.

    After ``textwrap.dedent`` and ``strip``, the input must consist of:
      1. a header line: ``Python`` padded to 40 columns, then ``| C++``;
      2. a separator line: 40 dashes, `` | ``, 40 dashes (trailing
         whitespace allowed);
      3. one or more code lines: exactly 40 characters of Python, then either
         `` |`` (no C++ on that line) or `` | `` followed by the C++ code.

    Returns:
        ``(python_stmts, cpp_stmts)``: each column joined with newlines.
        Python lines keep their 40-column padding.

    Raises:
        ValueError: if there are fewer than three lines, or if the header,
            separator, or any code line does not match its layout.
    """
    stmts = textwrap.dedent(stmts).strip()
    lines: list[str] = stmts.splitlines(keepends=False)
    if len(lines) < 3:
        raise ValueError(f"Invalid string:\n{stmts}")

    column_header_pattern = r"^Python\s{35}\| C\+\+(\s*)$"
    # The pipe must be escaped: an unescaped "|" is regex alternation, which
    # would accept any line that merely starts with 40 dashes and a space.
    separation_pattern = r"^[-]{40} \| [-]{40}\s*$"
    code_pattern = r"^(.{40}) \|($| (.*)$)"

    if re.search(column_header_pattern, lines[0]) is None:
        raise ValueError(
            f"Column header `{lines[0]}` "
            f"does not match pattern `{column_header_pattern}`"
        )

    if re.search(separation_pattern, lines[1]) is None:
        raise ValueError(
            f"Separator `{lines[1]}` "
            f"does not match pattern `{separation_pattern}`"
        )

    py_lines: list[str] = []
    cpp_lines: list[str] = []
    for line in lines[2:]:
        match = re.search(code_pattern, line)
        if match is None:
            raise ValueError(f"Invalid line `{line}`")
        py_lines.append(match.group(1))
        # Group 3 is None when the line ends right after " |" (no C++ code).
        cpp_lines.append(match.group(3) or "")

        # Make sure we can round trip for correctness.
        line_from_stmts = f"{py_lines[-1]:<40} | {cpp_lines[-1]:<40}".rstrip()
        assert line_from_stmts == line.rstrip(), f"Failed to round trip `{line}`"

    return "\n".join(py_lines), "\n".join(cpp_lines)
|