mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
remove allow-untyped-defs from context.py (#155622)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155622 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
d9799a2ee7
commit
39c605e8b3
@ -1,8 +1,13 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
import torch._decomp
|
||||
@ -15,8 +20,12 @@ import torch.overrides
|
||||
from torch._prims_common import torch_function_passthrough
|
||||
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
@functools.cache
|
||||
def torch_to_refs_map():
|
||||
def torch_to_refs_map() -> dict[Any, Any]:
|
||||
"""
|
||||
Mapping of torch API functions to torch._refs functions.
|
||||
E.g. torch_to_refs_map()[torch.add] == torch._refs.add
|
||||
@ -71,7 +80,7 @@ def torch_to_refs_map():
|
||||
|
||||
|
||||
@functools.cache
|
||||
def all_prims():
|
||||
def all_prims() -> set[Any]:
|
||||
"""
|
||||
Set of all prim functions, e.g., torch._prims.add in all_prims()
|
||||
"""
|
||||
@ -95,21 +104,21 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strict=False,
|
||||
should_fallback_fn=lambda *_: False,
|
||||
prims_mode_cls=nullcontext,
|
||||
):
|
||||
strict: bool = False,
|
||||
should_fallback_fn: Callable[..., bool] = lambda *_: False,
|
||||
prims_mode_cls: type = nullcontext,
|
||||
) -> None:
|
||||
self.strict = strict
|
||||
self.should_fallback_fn = should_fallback_fn
|
||||
self.prims_mode_cls = prims_mode_cls
|
||||
|
||||
def __torch_function__(
|
||||
self,
|
||||
orig_func: Callable,
|
||||
types: Sequence,
|
||||
orig_func: Callable[_P, _R],
|
||||
types: Sequence[type],
|
||||
args: Sequence[Any] = (),
|
||||
kwargs: Optional[dict] = None,
|
||||
):
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
# For primitive operations, run them as is without interception
|
||||
|
Reference in New Issue
Block a user