mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
The codebase has a few locations where callable parameter type information is lost when the unpackings *args and **kwargs are typed as Any. Refactor these instances to retain type information using typing_extensions.ParamSpec. Also, in these functions, enforce return type with TypeVar. Addresses #142306 Pull Request resolved: https://github.com/pytorch/pytorch/pull/143797 Approved by: https://github.com/Skylion007 Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com> Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
79b354ee37
commit
a7915c56f6
@ -20,6 +20,7 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import ParamSpec
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@ -51,6 +52,8 @@ three = 3
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def clone_me(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
||||
if x is None:
|
||||
@ -407,9 +410,9 @@ def check_dynamic_shape_capture() -> bool:
|
||||
return not config.assume_static_by_default
|
||||
|
||||
|
||||
def _make_fn_with_patches(fn: Callable[..., _T], *patches: Any) -> Callable[..., _T]:
|
||||
def _make_fn_with_patches(fn: Callable[_P, _T], *patches: Any) -> Callable[_P, _T]:
|
||||
@functools.wraps(fn)
|
||||
def _fn(*args: Any, **kwargs: Any) -> _T:
|
||||
def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
|
||||
with contextlib.ExitStack() as stack:
|
||||
for module, attr, val in patches:
|
||||
stack.enter_context(patch.object(module, attr, val))
|
||||
|
Reference in New Issue
Block a user