[Bilinear] move check to reset_parameters (#160952)

Fixes #160407

### Summary:
Moved the check to reset_parameters to make `Bilinear` module lazy. Lazy modules have in_features initialized to 0 and a pre forward hook that initializes these to the appropriate shape, then calls reset parameters,

### Impact:
module: nn, linear.py

### Test:

<img width="903" height="182" alt="Screenshot From 2025-08-19 13-27-12" src="https://github.com/user-attachments/assets/bc04b0d6-5174-4dc9-8b21-9e019b3822a5" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160952
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Parshant Sharma
2025-09-13 01:17:06 +00:00
committed by PyTorch MergeBot
parent 595e13feb7
commit a749c40342

View File

@ -214,8 +214,6 @@ class Bilinear(Module):
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
if in1_features <= 0:
raise ValueError(f"in1_features must be > 0, but got {in1_features}")
self.in1_features = in1_features self.in1_features = in1_features
self.in2_features = in2_features self.in2_features = in2_features
self.out_features = out_features self.out_features = out_features
@ -233,6 +231,10 @@ class Bilinear(Module):
""" """
Resets parameters based on their initialization used in ``__init__``. Resets parameters based on their initialization used in ``__init__``.
""" """
if self.in1_features <= 0:
raise ValueError(
f"in1_features must be > 0, but got (in1_features={self.in1_features})"
)
bound = 1 / math.sqrt(self.weight.size(1)) bound = 1 / math.sqrt(self.weight.size(1))
init.uniform_(self.weight, -bound, bound) init.uniform_(self.weight, -bound, bound)
if self.bias is not None: if self.bias is not None: