remove allow-untyped-defs from context.py (#155622)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155622
Approved by: https://github.com/Skylion007
This commit is contained in:
Bob Ren
2025-06-15 21:26:44 -07:00
committed by PyTorch MergeBot
parent d9799a2ee7
commit 39c605e8b3

View File

@ -1,8 +1,13 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import functools
from collections.abc import Sequence
from contextlib import nullcontext
from typing import Any, Callable, Optional
from typing import Any, Callable, TYPE_CHECKING, TypeVar
from typing_extensions import ParamSpec
if TYPE_CHECKING:
from collections.abc import Sequence
import torch
import torch._decomp
@ -15,8 +20,12 @@ import torch.overrides
from torch._prims_common import torch_function_passthrough
_P = ParamSpec("_P")
_R = TypeVar("_R")
@functools.cache
def torch_to_refs_map():
def torch_to_refs_map() -> dict[Any, Any]:
"""
Mapping of torch API functions to torch._refs functions.
E.g. torch_to_refs_map()[torch.add] == torch._refs.add
@ -71,7 +80,7 @@ def torch_to_refs_map():
@functools.cache
def all_prims():
def all_prims() -> set[Any]:
"""
Set of all prim functions, e.g., torch._prims.add in all_prims()
"""
@ -95,21 +104,21 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
def __init__(
self,
strict=False,
should_fallback_fn=lambda *_: False,
prims_mode_cls=nullcontext,
):
strict: bool = False,
should_fallback_fn: Callable[..., bool] = lambda *_: False,
prims_mode_cls: type = nullcontext,
) -> None:
self.strict = strict
self.should_fallback_fn = should_fallback_fn
self.prims_mode_cls = prims_mode_cls
def __torch_function__(
self,
orig_func: Callable,
types: Sequence,
orig_func: Callable[_P, _R],
types: Sequence[type],
args: Sequence[Any] = (),
kwargs: Optional[dict] = None,
):
kwargs: dict[str, Any] | None = None,
) -> Any:
if kwargs is None:
kwargs = {}
# For primitive operations, run them as is without interception