Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-23 06:34:55 +08:00
Currently the `bias` attribute of `torch.nn.Linear` (and `Bilinear`) is typed incorrectly, because it relies on the implicit `Module.__getattr__`, which types it as `Tensor | Module`. This has two issues:

- It hides the fact that `bias` is optional and can be `None`, which in turn can hide actual bugs on the user side.
- It blurs the type by including `Module` in the union, which can require unnecessary `isinstance(linear.bias, Tensor)` checks on the user side.

This PR types the `bias` attribute explicitly to fix these issues.

CC @ezyang @Skylion007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142326
Approved by: https://github.com/ezyang
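The typing problem described above can be reproduced without PyTorch. The sketch below is a minimal, stdlib-only illustration under stated assumptions: `FakeModule`, `FakeLinear`, and `bias_l1` are hypothetical names, not PyTorch APIs, and `FakeModule.__getattr__` returns `Any` for simplicity, whereas the message above describes `Module.__getattr__` as yielding `Tensor | Module`. It shows one way to give a dynamically resolved attribute a precise static type (a bare class-level annotation, which leaves runtime lookup unchanged); whether the PR uses exactly this mechanism is not stated here.

```python
from typing import Any, Dict, List, Optional


class FakeModule:
    """Hypothetical stand-in for nn.Module: parameters live in a dict and are
    resolved through __getattr__, so a type checker only sees __getattr__'s
    declared return type for them."""

    def __init__(self) -> None:
        self._parameters: Dict[str, Optional[List[float]]] = {}

    def register_parameter(self, name: str, value: Optional[List[float]]) -> None:
        self._parameters[name] = value

    def __getattr__(self, name: str) -> Any:
        # Called only when normal attribute lookup fails.
        params = self.__dict__.get("_parameters", {})
        if name in params:
            return params[name]
        raise AttributeError(name)


class FakeLinear(FakeModule):
    # Annotation only, no value: runtime lookup still goes through
    # __getattr__, but a static checker now sees Optional[List[float]].
    bias: Optional[List[float]]

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.register_parameter("weight", [0.0] * (in_features * out_features))
        # Assumed to mirror nn.Linear(bias=False): the bias slot holds None.
        self.register_parameter("bias", [0.0] * out_features if bias else None)


def bias_l1(layer: FakeLinear) -> float:
    # Because bias is declared Optional, a checker requires this None check,
    # and no isinstance(layer.bias, ...) test is needed afterwards.
    if layer.bias is None:
        return 0.0
    return sum(abs(b) for b in layer.bias)
```

If the `None` check in `bias_l1` is removed, a checker such as mypy should flag the iteration over `layer.bias`, which is exactly the kind of bug the message says the looser type can hide.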
27 lines, 771 B, Python
"""Miscellaneous utilities to aid with typing."""

from typing import Optional, TYPE_CHECKING, TypeVar


# Helper to turn Optional[T] into T when we know None either isn't
# possible or should trigger an exception.
T = TypeVar("T")


# TorchScript cannot handle the type signature of `not_none` at runtime, because it trips
# over the `Optional[T]`. To allow using `not_none` from inside a TorchScript method/module,
# we split the implementation, and hide the runtime type information from TorchScript.
if TYPE_CHECKING:

    def not_none(obj: Optional[T]) -> T:
        ...

else:

    def not_none(obj):
        if obj is None:
            raise TypeError(
                "Invariant encountered: value was None when it should not be"
            )
        return obj
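A minimal usage sketch of `not_none`, assuming the definitions from the file are copied locally so the example is self-contained (the file's module path is not given on this page, so no import is shown). `DEVICES` and `device_name` are hypothetical names used only for illustration. The sketch also exercises the runtime side of the `TYPE_CHECKING` split: because `TYPE_CHECKING` is `False` when the code actually runs, only the unannotated `not_none` exists, so its `__annotations__` is empty.

```python
from typing import Dict, Optional, TYPE_CHECKING, TypeVar

T = TypeVar("T")

# Same split as the file: the typed stub is only visible to static checkers.
if TYPE_CHECKING:

    def not_none(obj: Optional[T]) -> T:
        ...

else:

    def not_none(obj):
        if obj is None:
            raise TypeError(
                "Invariant encountered: value was None when it should not be"
            )
        return obj


# Hypothetical lookup table; a value of None means "not configured".
DEVICES: Dict[str, Optional[str]] = {"cpu": "cpu:0", "gpu": None}


def device_name(key: str) -> str:
    # DEVICES.get returns Optional[str]; not_none narrows it to str for the
    # checker and raises TypeError at runtime if the value is None.
    return not_none(DEVICES.get(key))


print(device_name("cpu"))  # prints "cpu:0"
```

The check is `obj is None` rather than a truthiness test, so falsy values such as `0` or `""` pass through unchanged; only a missing or `None` entry (here `device_name("gpu")`) raises.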