Propagate callable parameter types using ParamSpec (#142306) (#143797)

The codebase has a few locations where callable parameter type information is lost when the unpackings *args and **kwargs are typed as Any. Refactor these instances to retain type information using typing_extensions.ParamSpec.

Also, in these functions, enforce return type with TypeVar.

Addresses #142306

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143797
Approved by: https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
This commit is contained in:
Kasperi Apell
2024-12-29 23:03:14 +00:00
committed by PyTorch MergeBot
parent 79b354ee37
commit a7915c56f6
8 changed files with 90 additions and 39 deletions

View File

@ -20,6 +20,7 @@ from typing import (
TypeVar,
Union,
)
from typing_extensions import ParamSpec
from unittest.mock import patch
import torch
@ -51,6 +52,8 @@ three = 3
log = logging.getLogger(__name__)
_P = ParamSpec("_P")
def clone_me(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
if x is None:
@ -407,9 +410,9 @@ def check_dynamic_shape_capture() -> bool:
return not config.assume_static_by_default
def _make_fn_with_patches(fn: Callable[..., _T], *patches: Any) -> Callable[..., _T]:
def _make_fn_with_patches(fn: Callable[_P, _T], *patches: Any) -> Callable[_P, _T]:
@functools.wraps(fn)
def _fn(*args: Any, **kwargs: Any) -> _T:
def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
with contextlib.ExitStack() as stack:
for module, attr, val in patches:
stack.enter_context(patch.object(module, attr, val))