From 09ae69a3648dc00a701b2d940871efec8e19f246 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Sun, 26 Jan 2025 03:41:00 +0000
Subject: [PATCH] Revert "Fix type annotation of `Linear.bias` (#142326)"

This reverts commit 81e370fc6b90f9cb98c88f3173e738aba0dc650a.

Reverted https://github.com/pytorch/pytorch/pull/142326 on behalf of https://github.com/malfet due to This introduced a graph break and regressed inductor tests, see https://hud.pytorch.org/hud/pytorch/pytorch/73622fc5fa9713f46a5cef9704772e645591bce6/1?per_page=50&name_filter=inductor_torchbench&mergeLF=true ([comment](https://github.com/pytorch/pytorch/pull/142326#issuecomment-2614196349))
---
 .../pruner/base_structured_sparsifier.py |  2 +-
 torch/nn/modules/activation.py           |  4 +---
 torch/nn/modules/linear.py               |  6 +-----
 torch/nn/modules/transformer.py          | 13 ++++++-------
 torch/utils/_typing_utils.py             | 22 +++++------------------
 5 files changed, 14 insertions(+), 33 deletions(-)

diff --git a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py
index fcbdb3593979..aa1129440fe5 100644
--- a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py
+++ b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py
@@ -265,7 +265,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
                     module.prune_bias = prune_bias
 
                 module.register_forward_hook(
-                    BiasHook(module.parametrizations.weight[0], prune_bias)  # type: ignore[union-attr, index]
+                    BiasHook(module.parametrizations.weight[0], prune_bias)
                 )
 
     def prune(self) -> None:
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index 6ea07eb68e89..564a516a2477 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -7,7 +7,6 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
 from torch.nn.parameter import Parameter
-from torch.utils._typing_utils import not_none
 
 from .linear import NonDynamicallyQuantizableLinear
 from .module import Module
@@ -1123,7 +1122,6 @@ class MultiheadAttention(Module):
             xavier_uniform_(self.v_proj_weight)
 
         if self.in_proj_bias is not None:
-            assert self.out_proj.bias is not None
             constant_(self.in_proj_bias, 0.0)
             constant_(self.out_proj.bias, 0.0)
         if self.bias_k is not None:
@@ -1321,7 +1319,7 @@ class MultiheadAttention(Module):
                         self.in_proj_weight,
                         self.in_proj_bias,
                         self.out_proj.weight,
-                        not_none(self.out_proj.bias),
+                        self.out_proj.bias,
                         merged_mask,
                         need_weights,
                         average_attn_weights,
diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py
index 23a5f98664a4..4e53df95acf5 100644
--- a/torch/nn/modules/linear.py
+++ b/torch/nn/modules/linear.py
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import math
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any
 
 import torch
 from torch import Tensor
@@ -89,8 +89,6 @@ class Linear(Module):
     in_features: int
     out_features: int
     weight: Tensor
-    if TYPE_CHECKING:
-        bias: Optional[Tensor]
 
     def __init__(
         self,
@@ -192,8 +190,6 @@ class Bilinear(Module):
     in2_features: int
     out_features: int
     weight: Tensor
-    if TYPE_CHECKING:
-        bias: Optional[Tensor]
 
     def __init__(
        self,
diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py
index 7ab0352b4308..4218bddc71e3 100644
--- a/torch/nn/modules/transformer.py
+++ b/torch/nn/modules/transformer.py
@@ -7,7 +7,6 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.init import xavier_uniform_
-from torch.utils._typing_utils import not_none
 
 from .activation import MultiheadAttention
 from .container import ModuleList
@@ -842,15 +841,15 @@ class TransformerEncoderLayer(Module):
                 self.self_attn.in_proj_weight,
                 self.self_attn.in_proj_bias,
                 self.self_attn.out_proj.weight,
-                not_none(self.self_attn.out_proj.bias),
+                self.self_attn.out_proj.bias,
                 self.norm1.weight,
                 self.norm1.bias,
                 self.norm2.weight,
                 self.norm2.bias,
                 self.linear1.weight,
-                not_none(self.linear1.bias),
+                self.linear1.bias,
                 self.linear2.weight,
-                not_none(self.linear2.bias),
+                self.linear2.bias,
             )
 
             # We have to use list comprehensions below because TorchScript does not support
@@ -886,7 +885,7 @@ class TransformerEncoderLayer(Module):
                     self.self_attn.in_proj_weight,
                     self.self_attn.in_proj_bias,
                     self.self_attn.out_proj.weight,
-                    not_none(self.self_attn.out_proj.bias),
+                    self.self_attn.out_proj.bias,
                     self.activation_relu_or_gelu == 2,
                     self.norm_first,
                     self.norm1.eps,
@@ -895,9 +894,9 @@ class TransformerEncoderLayer(Module):
                     self.norm2.weight,
                     self.norm2.bias,
                     self.linear1.weight,
-                    not_none(self.linear1.bias),
+                    self.linear1.bias,
                     self.linear2.weight,
-                    not_none(self.linear2.bias),
+                    self.linear2.bias,
                     merged_mask,
                     mask_type,
                 )
diff --git a/torch/utils/_typing_utils.py b/torch/utils/_typing_utils.py
index f59e50266dad..ffb6b383e4e6 100644
--- a/torch/utils/_typing_utils.py
+++ b/torch/utils/_typing_utils.py
@@ -1,6 +1,6 @@
 """Miscellaneous utilities to aid with typing."""
 
-from typing import Optional, TYPE_CHECKING, TypeVar
+from typing import Optional, TypeVar
 
 
 # Helper to turn Optional[T] into T when we know None either isn't
@@ -8,19 +8,7 @@ from typing import Optional, TYPE_CHECKING, TypeVar
 T = TypeVar("T")
 
 
-# TorchScript cannot handle the type signature of `not_none` at runtime, because it trips
-# over the `Optional[T]`. To allow using `not_none` from inside a TorchScript method/module,
-# we split the implementation, and hide the runtime type information from TorchScript.
-if TYPE_CHECKING:
-
-    def not_none(obj: Optional[T]) -> T:
-        ...
-
-else:
-
-    def not_none(obj):
-        if obj is None:
-            raise TypeError(
-                "Invariant encountered: value was None when it should not be"
-            )
-        return obj
+def not_none(obj: Optional[T]) -> T:
+    if obj is None:
+        raise TypeError("Invariant encountered: value was None when it should not be")
+    return obj