mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "Fix type annotation of Linear.bias
(#142326)"
This reverts commit 81e370fc6b90f9cb98c88f3173e738aba0dc650a.
Reverted https://github.com/pytorch/pytorch/pull/142326 on behalf of https://github.com/malfet due to This introduced a graph break and regressed inductor tests, see 73622fc5fa/1
([comment](https://github.com/pytorch/pytorch/pull/142326#issuecomment-2614196349))
This commit is contained in:
@ -265,7 +265,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
|
||||
module.prune_bias = prune_bias
|
||||
|
||||
module.register_forward_hook(
|
||||
BiasHook(module.parametrizations.weight[0], prune_bias) # type: ignore[union-attr, index]
|
||||
BiasHook(module.parametrizations.weight[0], prune_bias)
|
||||
)
|
||||
|
||||
def prune(self) -> None:
|
||||
|
@ -7,7 +7,6 @@ import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.utils._typing_utils import not_none
|
||||
|
||||
from .linear import NonDynamicallyQuantizableLinear
|
||||
from .module import Module
|
||||
@ -1123,7 +1122,6 @@ class MultiheadAttention(Module):
|
||||
xavier_uniform_(self.v_proj_weight)
|
||||
|
||||
if self.in_proj_bias is not None:
|
||||
assert self.out_proj.bias is not None
|
||||
constant_(self.in_proj_bias, 0.0)
|
||||
constant_(self.out_proj.bias, 0.0)
|
||||
if self.bias_k is not None:
|
||||
@ -1321,7 +1319,7 @@ class MultiheadAttention(Module):
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.out_proj.weight,
|
||||
not_none(self.out_proj.bias),
|
||||
self.out_proj.bias,
|
||||
merged_mask,
|
||||
need_weights,
|
||||
average_attn_weights,
|
||||
|
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import math
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -89,8 +89,6 @@ class Linear(Module):
|
||||
in_features: int
|
||||
out_features: int
|
||||
weight: Tensor
|
||||
if TYPE_CHECKING:
|
||||
bias: Optional[Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -192,8 +190,6 @@ class Bilinear(Module):
|
||||
in2_features: int
|
||||
out_features: int
|
||||
weight: Tensor
|
||||
if TYPE_CHECKING:
|
||||
bias: Optional[Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -7,7 +7,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn.init import xavier_uniform_
|
||||
from torch.utils._typing_utils import not_none
|
||||
|
||||
from .activation import MultiheadAttention
|
||||
from .container import ModuleList
|
||||
@ -842,15 +841,15 @@ class TransformerEncoderLayer(Module):
|
||||
self.self_attn.in_proj_weight,
|
||||
self.self_attn.in_proj_bias,
|
||||
self.self_attn.out_proj.weight,
|
||||
not_none(self.self_attn.out_proj.bias),
|
||||
self.self_attn.out_proj.bias,
|
||||
self.norm1.weight,
|
||||
self.norm1.bias,
|
||||
self.norm2.weight,
|
||||
self.norm2.bias,
|
||||
self.linear1.weight,
|
||||
not_none(self.linear1.bias),
|
||||
self.linear1.bias,
|
||||
self.linear2.weight,
|
||||
not_none(self.linear2.bias),
|
||||
self.linear2.bias,
|
||||
)
|
||||
|
||||
# We have to use list comprehensions below because TorchScript does not support
|
||||
@ -886,7 +885,7 @@ class TransformerEncoderLayer(Module):
|
||||
self.self_attn.in_proj_weight,
|
||||
self.self_attn.in_proj_bias,
|
||||
self.self_attn.out_proj.weight,
|
||||
not_none(self.self_attn.out_proj.bias),
|
||||
self.self_attn.out_proj.bias,
|
||||
self.activation_relu_or_gelu == 2,
|
||||
self.norm_first,
|
||||
self.norm1.eps,
|
||||
@ -895,9 +894,9 @@ class TransformerEncoderLayer(Module):
|
||||
self.norm2.weight,
|
||||
self.norm2.bias,
|
||||
self.linear1.weight,
|
||||
not_none(self.linear1.bias),
|
||||
self.linear1.bias,
|
||||
self.linear2.weight,
|
||||
not_none(self.linear2.bias),
|
||||
self.linear2.bias,
|
||||
merged_mask,
|
||||
mask_type,
|
||||
)
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Miscellaneous utilities to aid with typing."""
|
||||
|
||||
from typing import Optional, TYPE_CHECKING, TypeVar
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
|
||||
# Helper to turn Optional[T] into T when we know None either isn't
|
||||
@ -8,19 +8,7 @@ from typing import Optional, TYPE_CHECKING, TypeVar
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# TorchScript cannot handle the type signature of `not_none` at runtime, because it trips
|
||||
# over the `Optional[T]`. To allow using `not_none` from inside a TorchScript method/module,
|
||||
# we split the implementation, and hide the runtime type information from TorchScript.
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def not_none(obj: Optional[T]) -> T:
|
||||
...
|
||||
|
||||
else:
|
||||
|
||||
def not_none(obj):
|
||||
if obj is None:
|
||||
raise TypeError(
|
||||
"Invariant encountered: value was None when it should not be"
|
||||
)
|
||||
return obj
|
||||
def not_none(obj: Optional[T]) -> T:
|
||||
if obj is None:
|
||||
raise TypeError("Invariant encountered: value was None when it should not be")
|
||||
return obj
|
||||
|
Reference in New Issue
Block a user