Files
pytorch/torch/_export/db/case.py
Maggie Moss 4ab847bbc7 Pyrefly suppressions 4/n (#164615)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1

after:

0 errors (2,753 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615
Approved by: https://github.com/oulgen
2025-10-06 16:14:36 +00:00

176 lines
4.9 KiB
Python

# mypy: allow-untyped-defs
import inspect
import re
import string
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional
from types import ModuleType
import torch
_TAGS: dict[str, dict[str, Any]] = {
"torch": {
"cond": {},
"dynamic-shape": {},
"escape-hatch": {},
"map": {},
"dynamic-value": {},
"operator": {},
"mutation": {},
},
"python": {
"assert": {},
"builtin": {},
"closure": {},
"context-manager": {},
"control-flow": {},
"data-structure": {},
"standard-library": {},
"object-model": {},
},
}
class SupportLevel(Enum):
    """
    Records whether export handles the feature that an example demonstrates.
    """

    SUPPORTED = 1  # export handles the feature used by the example
    NOT_SUPPORTED_YET = 0  # export does not handle the feature yet


# Positional example inputs: a tuple of arbitrary values.
ArgsType = tuple[Any, ...]
def check_inputs_type(args, kwargs):
    """
    Validate the container types of an example's inputs.

    Args:
        args: positional inputs; must be a tuple.
        kwargs: keyword inputs; must be a dict whose keys are all strings.

    Raises:
        ValueError: for the first violated requirement, checked in this order:
            the type of ``args``, the type of ``kwargs``, then each key of
            ``kwargs`` in iteration order.
    """
    if not isinstance(args, tuple):
        raise ValueError(f"Expecting args type to be a tuple, got: {type(args)}")
    if not isinstance(kwargs, dict):
        raise ValueError(f"Expecting kwargs type to be a dict, got: {type(kwargs)}")
    non_string_keys = (k for k in kwargs if not isinstance(k, str))
    for offending in non_string_keys:
        # Report only the first bad key encountered.
        raise ValueError(f"Expecting kwargs keys to be a string, got: {type(offending)}")
def _validate_tag(tag: str):
    """
    Check that a dotted tag names an entry registered in ``_TAGS``.

    The tag is split on ".". Each segment is first asserted to contain only
    lowercase ASCII letters and hyphens. It is then looked up one level deeper
    in the ``_TAGS`` mapping.

    Raises:
        AssertionError: a segment contains a disallowed character (only when
            assertions are enabled).
        ValueError: a segment is not present at its level of ``_TAGS``.
    """
    # NOTE(review): a bare category such as "torch" also passes, because no
    # leaf check follows the loop -- confirm this is intended.
    allowed_chars = set(string.ascii_lowercase + "-")
    level: dict[str, Any] = _TAGS
    for segment in tag.split("."):
        assert set(segment) <= allowed_chars, f"Tag contains invalid characters: {segment}"
        if segment not in level:
            raise ValueError(f"Tag {tag} is not found in registered tags.")
        level = level[segment]
@dataclass(frozen=True)
class ExportCase:
    """
    One registered export example: a model, its example inputs, and metadata.

    Instances validate their inputs, tags, and description on construction.
    """

    example_args: ArgsType
    description: str  # A description of the use case.
    model: torch.nn.Module
    name: str
    example_kwargs: dict[str, Any] = field(default_factory=dict)
    extra_args: Optional[ArgsType] = None  # For testing graph generalization.
    # Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
    tags: set[str] = field(default_factory=set)
    support_level: SupportLevel = SupportLevel.SUPPORTED
    dynamic_shapes: Optional[dict[str, Any]] = None

    def __post_init__(self):
        """
        Validate the fields after dataclass initialisation.

        Checks run in order: the example inputs, then ``extra_args`` (if set),
        then every tag, then the description.

        Raises:
            ValueError: from ``check_inputs_type`` or ``_validate_tag``, or if
                ``description`` is not a non-empty string.
            AssertionError: from ``_validate_tag`` on invalid tag characters.
        """
        check_inputs_type(self.example_args, self.example_kwargs)
        if self.extra_args is not None:
            # extra_args has no keyword counterpart, so check it with empty kwargs.
            check_inputs_type(self.extra_args, {})
        for tag_name in self.tags:
            _validate_tag(tag_name)
        has_text = isinstance(self.description, str) and len(self.description) > 0
        if not has_text:
            raise ValueError(f'Invalid description: "{self.description}"')
# Module-level registries filled by register_db_case, export_case and
# export_rewrite_case.
# case name -> first case registered under that name
_EXAMPLE_CASES: dict[str, ExportCase] = dict()
# modules that have already used export_case
_MODULES: set[ModuleType] = set()
# case name -> every case registered under a duplicated name (original first)
_EXAMPLE_CONFLICT_CASES: dict[str, list[ExportCase]] = dict()
# parent case name -> rewrite cases derived from it
_EXAMPLE_REWRITE_CASES: dict[str, list[ExportCase]] = dict()
def register_db_case(case: ExportCase) -> None:
    """
    Add ``case`` to the example bank, keyed by ``case.name``.

    The first case registered under a name is kept in ``_EXAMPLE_CASES``. A
    later case with the same name does not replace it. Instead, the name's
    entry in ``_EXAMPLE_CONFLICT_CASES`` lists the original case first,
    followed by every duplicate in registration order.
    """
    key = case.name
    if key not in _EXAMPLE_CASES:
        _EXAMPLE_CASES[key] = case
        return
    conflicts = _EXAMPLE_CONFLICT_CASES.setdefault(key, [_EXAMPLE_CASES[key]])
    conflicts.append(case)
def to_snake_case(name):
    """
    Convert a CamelCase identifier to snake_case.

    Two regex passes insert underscores, and the result is then lower-cased:
      1. before each capitalised word (an uppercase letter followed by
         lowercase letters) that has a preceding character;
      2. between a lowercase letter or digit and a following uppercase letter.

    For example, ``"HTTPResponse"`` becomes ``"http_response"``.
    """
    passes = (
        (r"(.)([A-Z][a-z]+)", r"\1_\2"),
        (r"([a-z0-9])([A-Z])", r"\1_\2"),
    )
    converted = name
    for pattern, replacement in passes:
        converted = re.sub(pattern, replacement, converted)
    return converted.lower()
def _make_export_case(m, name, configs):
    """
    Build an ``ExportCase`` for model ``m`` under the case name ``name``.

    Args:
        m: the example model; must be a ``torch.nn.Module`` instance.
        name: the case name (overrides any ``"name"`` key in ``configs``).
        configs: further ``ExportCase`` fields as a mapping. It is not
            mutated; a merged copy is passed on. Any ``"model"`` key is
            overridden by ``m``.

    If ``configs`` has no ``"description"``, ``m.__doc__`` is used instead.

    Raises:
        TypeError: if ``m`` is not a ``torch.nn.Module``.
        AssertionError: if no description is given and ``m`` has no docstring.
    """
    if not isinstance(m, torch.nn.Module):
        raise TypeError("Export case class should be a torch.nn.Module.")
    case_kwargs = dict(configs)
    if "description" not in case_kwargs:
        # Fallback to docstring if description is missing.
        assert m.__doc__ is not None, f"Could not find description or docstring for export case: {m}"
        case_kwargs["description"] = m.__doc__
    case_kwargs["model"] = m
    case_kwargs["name"] = name
    # pyrefly: ignore # bad-argument-type
    return ExportCase(**case_kwargs)
def export_case(**kwargs):
    """
    Decorator for registering a user provided case into example bank.

    ``kwargs`` are forwarded as ``ExportCase`` fields. The decorated model is
    registered under the last dotted component of its defining module's name.
    The decorator returns the resulting ``ExportCase``, not the decorated
    object.

    Raises (when the decorator is applied):
        RuntimeError: if the defining module has already registered a case.
        AssertionError: if the defining module cannot be determined.
        TypeError / ValueError / AssertionError: propagated from building the
            ``ExportCase``.
    """

    def wrapper(m):
        module = inspect.getmodule(m)
        if module in _MODULES:
            raise RuntimeError("export_case should only be used once per example file.")
        assert module is not None
        module_name = module.__name__.split(".")[-1]
        case = _make_export_case(m, module_name, kwargs)
        # Mark the module as used only after its case was built successfully.
        # If building fails, the module is not recorded, so the "once per file"
        # guard does not block a retry and the failed module is not kept alive.
        _MODULES.add(module)
        register_db_case(case)
        return case

    return wrapper
def export_rewrite_case(**kwargs):
    """
    Decorator for registering a rewrite of an existing export case.

    ``kwargs`` must contain ``parent``, the ``ExportCase`` being rewritten.
    Every other entry is forwarded as an ``ExportCase`` field. The rewrite:
      - reuses the parent's ``example_args``, overriding any ``example_args``
        in ``kwargs``;
      - is named after the decorated object's ``__name__`` in snake_case;
      - is appended to ``_EXAMPLE_REWRITE_CASES[parent.name]``.
    The decorator returns the new ``ExportCase``.

    Raises (when the decorator is applied):
        KeyError: if ``parent`` was not supplied.
        AssertionError: if ``parent`` is not an ``ExportCase``.
        TypeError / ValueError / AssertionError: propagated from building the
            ``ExportCase``.
    """

    def wrapper(m):
        # Work on a copy. Popping "parent" from the captured kwargs would make
        # any second application of this same decorator fail with KeyError.
        configs = dict(kwargs)
        parent = configs.pop("parent")
        assert isinstance(parent, ExportCase)
        configs["example_args"] = parent.example_args
        case = _make_export_case(m, to_snake_case(m.__name__), configs)
        # Create the parent's list only once a case exists, so a failed build
        # does not leave an empty entry behind.
        _EXAMPLE_REWRITE_CASES.setdefault(parent.name, []).append(case)
        return case

    return wrapper