Mirror of https://github.com/pytorch/pytorch.git
[Bilinear] move check to reset_parameters (#160952)
Fixes #160407

### Summary:
Moved the check to `reset_parameters` to allow the `Bilinear` module to be made lazy. Lazy modules have `in_features` initialized to 0 and a pre-forward hook that initializes them to the appropriate shape and then calls `reset_parameters`.

### Impact:
module: nn, linear.py

### Test:
Screenshot from 2025-08-19 13-27-12: https://github.com/user-attachments/assets/bc04b0d6-5174-4dc9-8b21-9e019b3822a5

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160952
Approved by: https://github.com/mikaylagawarecki
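For context on the lazy-module flow the summary refers to, here is a minimal sketch using `nn.LazyLinear`, which already ships with PyTorch; a lazy `Bilinear` variant enabled by this change would follow the same pattern, but no such class is added by this PR.

```python
import torch
import torch.nn as nn

# Lazy modules start with uninitialized parameters and in_features == 0.
lazy = nn.LazyLinear(out_features=4)
print(lazy.in_features)   # 0
print(lazy.weight)        # <UninitializedParameter>

# The pre-forward hook infers in_features from the first input, materializes
# the parameters, and then calls reset_parameters().
y = lazy(torch.randn(2, 8))
print(lazy.in_features, tuple(lazy.weight.shape))  # 8 (4, 8)
```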
Committed by: PyTorch MergeBot
Parent: 595e13feb7
Commit: a749c40342
@@ -214,8 +214,6 @@ class Bilinear(Module):
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
-        if in1_features <= 0:
-            raise ValueError(f"in1_features must be > 0, but got {in1_features}")
         self.in1_features = in1_features
         self.in2_features = in2_features
         self.out_features = out_features
@@ -233,6 +231,10 @@ class Bilinear(Module):
         """
         Resets parameters based on their initialization used in ``__init__``.
         """
+        if self.in1_features <= 0:
+            raise ValueError(
+                f"in1_features must be > 0, but got (in1_features={self.in1_features})"
+            )
         bound = 1 / math.sqrt(self.weight.size(1))
         init.uniform_(self.weight, -bound, bound)
         if self.bias is not None:
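A quick sanity check of the behavior after this change, as a sketch assuming a build that includes this commit (the error text below matches the new message in `reset_parameters`):

```python
import torch
import torch.nn as nn

# Eager construction is unchanged: __init__ still calls reset_parameters(),
# so parameters are validated and initialized immediately.
m = nn.Bilinear(in1_features=20, in2_features=30, out_features=40)
out = m(torch.randn(128, 20), torch.randn(128, 30))
print(out.shape)  # torch.Size([128, 40])

# The in1_features check now fires from reset_parameters() instead of __init__.
try:
    nn.Bilinear(in1_features=0, in2_features=30, out_features=40)
except ValueError as err:
    print(err)  # in1_features must be > 0, but got (in1_features=0)
```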