Revert "Fix type annotation of Linear.bias (#142326)"

This reverts commit 81e370fc6b90f9cb98c88f3173e738aba0dc650a.

Reverted https://github.com/pytorch/pytorch/pull/142326 on behalf of https://github.com/malfet due to This introduced a graph break and regressed inductor tests, see 73622fc5fa/1 ([comment](https://github.com/pytorch/pytorch/pull/142326#issuecomment-2614196349))
Committed by PyTorch MergeBot on 2025-01-26 03:41:00 +00:00
parent 73622fc5fa
commit 09ae69a364
5 changed files with 14 additions and 33 deletions
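
The revert reason above cites a graph break and regressed inductor tests. As background only (none of this code is part of the commit), a minimal sketch of how such a regression can be surfaced locally with TorchDynamo; the toy model and input are made up for illustration:

# Hypothetical check for graph breaks; not part of this commit.
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
x = torch.randn(2, 8)

# fullgraph=True makes Dynamo raise instead of silently splitting the graph.
compiled = torch.compile(model, fullgraph=True)
compiled(x)

# torch._dynamo.explain reports how many graphs were captured and why they broke.
print(torch._dynamo.explain(model)(x))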


@@ -265,7 +265,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
                 module.prune_bias = prune_bias
 
             module.register_forward_hook(
-                BiasHook(module.parametrizations.weight[0], prune_bias)  # type: ignore[union-attr, index]
+                BiasHook(module.parametrizations.weight[0], prune_bias)
             )
 
     def prune(self) -> None:
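
For context, the `# type: ignore[union-attr, index]` dropped above is mypy's per-error-code suppression, which silences only the listed error codes on that one line. A tiny illustrative example (made-up function, not PyTorch code):

# Illustrative only: suppressing a single mypy error code on one line.
from typing import Optional


def first_upper(s: Optional[str]) -> str:
    # Without the ignore, mypy reports: Item "None" of "Optional[str]"
    # has no attribute "upper"  [union-attr]
    return s.upper()  # type: ignore[union-attr]


print(first_upper("bias"))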


@@ -7,7 +7,6 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
 from torch.nn.parameter import Parameter
-from torch.utils._typing_utils import not_none
 
 from .linear import NonDynamicallyQuantizableLinear
 from .module import Module
@@ -1123,7 +1122,6 @@ class MultiheadAttention(Module):
             xavier_uniform_(self.v_proj_weight)
 
         if self.in_proj_bias is not None:
-            assert self.out_proj.bias is not None
             constant_(self.in_proj_bias, 0.0)
             constant_(self.out_proj.bias, 0.0)
         if self.bias_k is not None:
@@ -1321,7 +1319,7 @@ class MultiheadAttention(Module):
                     self.in_proj_weight,
                     self.in_proj_bias,
                     self.out_proj.weight,
-                    not_none(self.out_proj.bias),
+                    self.out_proj.bias,
                     merged_mask,
                     need_weights,
                     average_attn_weights,
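
Both MultiheadAttention hunks above drop None-narrowing that the original PR added once out_proj.bias was typed Optional[Tensor]: an assert in _reset_parameters and a not_none(...) wrapper in the fast path. A small made-up sketch of why the narrowing was needed for mypy (init_bias is a hypothetical helper, not PyTorch API):

# Sketch of assert-based narrowing; illustrative only.
from typing import Optional

import torch
from torch import Tensor
from torch.nn.init import constant_


def init_bias(bias: Optional[Tensor]) -> None:
    # With bias typed Optional[Tensor], passing it straight to constant_ fails
    # type checking; the assert narrows Optional[Tensor] to Tensor.
    assert bias is not None
    constant_(bias, 0.0)


init_bias(torch.zeros(4))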


@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import math
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any
 
 import torch
 from torch import Tensor
@@ -89,8 +89,6 @@ class Linear(Module):
     in_features: int
     out_features: int
     weight: Tensor
-    if TYPE_CHECKING:
-        bias: Optional[Tensor]
 
     def __init__(
         self,
@@ -192,8 +190,6 @@ class Bilinear(Module):
     in2_features: int
     out_features: int
     weight: Tensor
-    if TYPE_CHECKING:
-        bias: Optional[Tensor]
 
     def __init__(
         self,
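
The two hunks above remove the pattern the reverted PR used to expose bias as Optional[Tensor] to static type checkers while hiding the annotation at runtime. A minimal made-up sketch of that pattern (MyLinearLike is hypothetical, not the real nn.Linear):

# Illustrative only: a TYPE_CHECKING-gated class annotation.
from typing import Optional, TYPE_CHECKING

import torch


class MyLinearLike(torch.nn.Module):
    weight: torch.Tensor
    if TYPE_CHECKING:
        # Seen by mypy only; at runtime the annotation does not exist, so
        # TorchScript and tracing never encounter the Optional type.
        bias: Optional[torch.Tensor]

    def __init__(self, n: int, bias: bool = True) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(n, n))
        self.bias = torch.nn.Parameter(torch.empty(n)) if bias else None


print(MyLinearLike(4, bias=False).bias)  # None at runtime, Optional[Tensor] to mypy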


@@ -7,7 +7,6 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.init import xavier_uniform_
-from torch.utils._typing_utils import not_none
 
 from .activation import MultiheadAttention
 from .container import ModuleList
@@ -842,15 +841,15 @@ class TransformerEncoderLayer(Module):
                 self.self_attn.in_proj_weight,
                 self.self_attn.in_proj_bias,
                 self.self_attn.out_proj.weight,
-                not_none(self.self_attn.out_proj.bias),
+                self.self_attn.out_proj.bias,
                 self.norm1.weight,
                 self.norm1.bias,
                 self.norm2.weight,
                 self.norm2.bias,
                 self.linear1.weight,
-                not_none(self.linear1.bias),
+                self.linear1.bias,
                 self.linear2.weight,
-                not_none(self.linear2.bias),
+                self.linear2.bias,
             )
 
             # We have to use list comprehensions below because TorchScript does not support
@@ -886,7 +885,7 @@
                     self.self_attn.in_proj_weight,
                     self.self_attn.in_proj_bias,
                     self.self_attn.out_proj.weight,
-                    not_none(self.self_attn.out_proj.bias),
+                    self.self_attn.out_proj.bias,
                     self.activation_relu_or_gelu == 2,
                     self.norm_first,
                     self.norm1.eps,
@@ -895,9 +894,9 @@
                     self.norm2.weight,
                     self.norm2.bias,
                     self.linear1.weight,
-                    not_none(self.linear1.bias),
+                    self.linear1.bias,
                     self.linear2.weight,
-                    not_none(self.linear2.bias),
+                    self.linear2.bias,
                     merged_mask,
                     mask_type,
                 )


@@ -1,6 +1,6 @@
 """Miscellaneous utilities to aid with typing."""
 
-from typing import Optional, TYPE_CHECKING, TypeVar
+from typing import Optional, TypeVar
 
 
 # Helper to turn Optional[T] into T when we know None either isn't
@@ -8,19 +8,7 @@ from typing import Optional, TYPE_CHECKING, TypeVar
 T = TypeVar("T")
 
 
-# TorchScript cannot handle the type signature of `not_none` at runtime, because it trips
-# over the `Optional[T]`. To allow using `not_none` from inside a TorchScript method/module,
-# we split the implementation, and hide the runtime type information from TorchScript.
-if TYPE_CHECKING:
-
-    def not_none(obj: Optional[T]) -> T:
-        ...
-
-else:
-
-    def not_none(obj):
-        if obj is None:
-            raise TypeError(
-                "Invariant encountered: value was None when it should not be"
-            )
-        return obj
+def not_none(obj: Optional[T]) -> T:
+    if obj is None:
+        raise TypeError("Invariant encountered: value was None when it should not be")
+    return obj
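
For reference, a small made-up usage of the restored helper (the Linear modules below are arbitrary examples):

# not_none narrows Optional[T] to T for the type checker and raises if the
# value is actually None at runtime.
import torch
from torch.utils._typing_utils import not_none

linear = torch.nn.Linear(4, 4)            # bias=True, so .bias is a Tensor
bias = not_none(linear.bias)              # typed as Tensor from here on
print(bias.shape)

try:
    not_none(torch.nn.Linear(4, 4, bias=False).bias)
except TypeError as err:
    print(err)  # Invariant encountered: value was None when it should not be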