mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: uncomment lines in the pyrefly.toml file step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1 after: 0 errors (2,753 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615 Approved by: https://github.com/oulgen
176 lines
4.9 KiB
Python
176 lines
4.9 KiB
Python
# mypy: allow-untyped-defs
|
|
import inspect
|
|
import re
|
|
import string
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Any, Optional
|
|
from types import ModuleType
|
|
|
|
import torch
|
|
|
|
_TAGS: dict[str, dict[str, Any]] = {
|
|
"torch": {
|
|
"cond": {},
|
|
"dynamic-shape": {},
|
|
"escape-hatch": {},
|
|
"map": {},
|
|
"dynamic-value": {},
|
|
"operator": {},
|
|
"mutation": {},
|
|
},
|
|
"python": {
|
|
"assert": {},
|
|
"builtin": {},
|
|
"closure": {},
|
|
"context-manager": {},
|
|
"control-flow": {},
|
|
"data-structure": {},
|
|
"standard-library": {},
|
|
"object-model": {},
|
|
},
|
|
}
|
|
|
|
|
|
class SupportLevel(Enum):
    """
    Stage at which export handles the feature demonstrated by an example.
    """

    # The feature used in the example is handled.
    SUPPORTED = 1
    # The feature used in the example is not handled yet.
    NOT_SUPPORTED_YET = 0
|
|
|
|
|
|
# Alias for positional example inputs: a tuple holding values of any type.
ArgsType = tuple[Any, ...]
|
|
|
|
|
|
def check_inputs_type(args, kwargs):
    """
    Validate the container types of a set of example inputs.

    Args:
        args: positional inputs; must be a tuple.
        kwargs: keyword inputs; must be a dict whose keys are all strings.

    Raises:
        ValueError: if ``args`` is not a tuple, ``kwargs`` is not a dict, or
            any key of ``kwargs`` is not a string.
    """
    if not isinstance(args, tuple):
        raise ValueError(f"Expecting args type to be a tuple, got: {type(args)}")

    if not isinstance(kwargs, dict):
        raise ValueError(f"Expecting kwargs type to be a dict, got: {type(kwargs)}")

    # Report the first offending key, in the dict's iteration order.
    non_string_keys = [key for key in kwargs if not isinstance(key, str)]
    if non_string_keys:
        raise ValueError(
            f"Expecting kwargs keys to be a string, got: {type(non_string_keys[0])}"
        )
|
|
|
|
def _validate_tag(tag: str):
    """
    Check that ``tag`` names a path in the ``_TAGS`` registry.

    The tag is split on "." and each component is looked up one level deeper
    in ``_TAGS``.

    Raises:
        AssertionError: if a component contains anything other than lowercase
            ASCII letters and "-".
        ValueError: if a component is not present at its level of ``_TAGS``.
    """
    allowed_chars = set(string.ascii_lowercase + "-")
    node = _TAGS
    for segment in tag.split("."):
        assert (
            set(segment) <= allowed_chars
        ), f"Tag contains invalid characters: {segment}"
        if segment not in node:
            raise ValueError(f"Tag {tag} is not found in registered tags.")
        # Descend into the sub-registry for the next component.
        node = node[segment]
|
|
|
|
|
|
@dataclass(frozen=True)
class ExportCase:
    """
    Immutable record describing one export example.

    All fields are validated in ``__post_init__``; construction raises if the
    example inputs, tags or description are malformed.
    """

    example_args: ArgsType
    description: str  # A description of the use case.
    model: torch.nn.Module
    name: str
    example_kwargs: dict[str, Any] = field(default_factory=dict)
    extra_args: Optional[ArgsType] = None  # For testing graph generalization.
    # Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
    tags: set[str] = field(default_factory=set)
    support_level: SupportLevel = SupportLevel.SUPPORTED
    dynamic_shapes: Optional[dict[str, Any]] = None

    def __post_init__(self):
        """
        Validate the freshly constructed case.

        Raises:
            ValueError: for badly typed example inputs, an unregistered tag,
                or a description that is not a non-empty string.
        """
        check_inputs_type(self.example_args, self.example_kwargs)

        extra = self.extra_args
        if extra is not None:
            # extra_args has no keyword counterpart, so pair it with an empty dict.
            check_inputs_type(extra, {})

        for tag in self.tags:
            _validate_tag(tag)

        text = self.description
        has_description = isinstance(text, str) and len(text) > 0
        if not has_description:
            raise ValueError(f'Invalid description: "{self.description}"')
|
|
|
|
|
|
# Registered cases keyed by ExportCase.name; populated by register_db_case.
_EXAMPLE_CASES: dict[str, ExportCase] = {}
# Modules in which export_case has already been applied; export_case refuses
# a second use within the same module.
_MODULES: set[ModuleType] = set()
# Cases whose name collided in register_db_case, keyed by the shared name.
# Each list starts with the case that was registered first.
_EXAMPLE_CONFLICT_CASES: dict[str, list[ExportCase]] = {}
# Cases created by export_rewrite_case, keyed by the name of their parent case.
_EXAMPLE_REWRITE_CASES: dict[str, list[ExportCase]] = {}
|
|
|
|
|
|
def register_db_case(case: ExportCase) -> None:
    """
    Add a user provided ExportCase to the example bank.

    The first case seen for a given name is stored in ``_EXAMPLE_CASES``.
    Any later case with the same name is not stored there; instead it is
    recorded in ``_EXAMPLE_CONFLICT_CASES`` together with the first one.
    """
    name = case.name

    if name not in _EXAMPLE_CASES:
        _EXAMPLE_CASES[name] = case
        return

    # Name clash: seed the conflict list with the already registered case the
    # first time this name collides, then record the newcomer.
    _EXAMPLE_CONFLICT_CASES.setdefault(name, [_EXAMPLE_CASES[name]]).append(case)
|
|
|
|
|
|
def to_snake_case(name):
    """
    Convert a CamelCase identifier to snake_case.

    Two regex passes insert underscores, then the result is lower-cased.
    """
    # Pass 1: break before a capitalised word ("HTTPResponse" -> "HTTP_Response").
    word_split = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Pass 2: break between a lowercase letter or digit and a following capital.
    fully_split = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", word_split)
    return fully_split.lower()
|
|
|
|
|
|
def _make_export_case(m, name, configs):
|
|
if not isinstance(m, torch.nn.Module):
|
|
raise TypeError("Export case class should be a torch.nn.Module.")
|
|
|
|
if "description" not in configs:
|
|
# Fallback to docstring if description is missing.
|
|
assert (
|
|
m.__doc__ is not None
|
|
), f"Could not find description or docstring for export case: {m}"
|
|
configs = {**configs, "description": m.__doc__}
|
|
# pyrefly: ignore # bad-argument-type
|
|
return ExportCase(**{**configs, "model": m, "name": name})
|
|
|
|
|
|
def export_case(**kwargs):
    """
    Decorator for registering a user provided case into example bank.

    The keyword arguments are forwarded as ExportCase fields. The case is
    named after the last dotted component of the decorated object's module
    name, and the decorator returns the created ExportCase.

    Raises:
        RuntimeError: if the decorator was already used in the same module.
    """

    def wrapper(m):
        owning_module = inspect.getmodule(m)
        if owning_module in _MODULES:
            raise RuntimeError("export_case should only be used once per example file.")

        assert owning_module is not None
        _MODULES.add(owning_module)

        # "pkg.sub.my_example" -> "my_example"
        short_name = owning_module.__name__.rsplit(".", 1)[-1]
        new_case = _make_export_case(m, short_name, kwargs)
        register_db_case(new_case)
        return new_case

    return wrapper
|
|
|
|
|
|
def export_rewrite_case(**kwargs):
    """
    Decorator for registering a rewritten variant of an existing ExportCase.

    The keyword arguments must contain ``parent``, the ExportCase being
    rewritten; the remaining ones are forwarded as ExportCase fields. The new
    case reuses the parent's ``example_args``, is named after the snake_case
    form of the decorated object's name, and is appended to
    ``_EXAMPLE_REWRITE_CASES`` under the parent's name. The decorator returns
    the created ExportCase.

    Raises:
        KeyError: if ``parent`` is not supplied.
        AssertionError: if ``parent`` is not an ExportCase.
    """

    def wrapper(m):
        # Work on a copy. Popping "parent" from the captured kwargs would make
        # a second application of the same decorator fail with KeyError.
        configs = dict(kwargs)

        parent = configs.pop("parent")
        assert isinstance(parent, ExportCase)
        rewrites = _EXAMPLE_REWRITE_CASES.setdefault(parent.name, [])

        configs["example_args"] = parent.example_args
        # _make_export_case only accepts torch.nn.Module instances, and an
        # instance does not expose __name__ unless one was set explicitly, so
        # fall back to the name of its class.
        raw_name = getattr(m, "__name__", type(m).__name__)
        case = _make_export_case(m, to_snake_case(raw_name), configs)
        rewrites.append(case)
        return case

    return wrapper
|