Mirror of https://github.com/pytorch/pytorch.git
torch.nn.modules.LazyModuleMixin and torch.nn.LazyLinear (Shape Inference II) (#44538)
Summary:
Retake on https://github.com/pytorch/pytorch/issues/40493 after all the feedback from albanD.

This PR implements the generic lazy-initialization mechanism and a sample `LazyLinear` layer built on `UninitializedParameter`. There are two main differences from the previous PR: `torch.nn.Module` remains untouched, and we no longer require an explicit initialization or a dummy forward pass before starting training or inference of the actual module. This makes it much simpler to use from the user side.

As we discussed offline, there was a suggestion to not use a mixin and instead change the `__class__` attribute of `LazyLinear` so that it becomes `Linear` once it is completely initialized. While this can be useful, for the time being we need `LazyLinear` to be a `torch.nn.Module` subclass, since there are many checks that rely on modules being instances of `torch.nn.Module`. This can cause problems when we create complex modules such as

```
class MyNetwork(torch.nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.conv = torch.nn.Conv2d(20, 4, 2)
        self.linear = torch.nn.LazyLinear(10)

    def forward(self, x):
        y = self.conv(x).clamp(min=0)
        return self.linear(y)
```

Here, when the `__setattr__` function is called at the time `LazyLinear` is registered, it won't be added to the child modules of `MyNetwork`, so we would have to add it manually later; currently there is no way to do that, because we can't access the parent module from `LazyLinear` once it becomes the `Linear` module. (We can add a workaround for this if needed.)

TODO:
- Add convolutions once the design is OK
- Fix docstrings

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44538
Reviewed By: ngimel
Differential Revision: D24162854
Pulled By: albanD
fbshipit-source-id: 6d58dfe5d43bfb05b6ee506e266db3cf4b885f0c
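For context, a minimal usage sketch of the lazy mechanism described above. It assumes the `torch.nn.LazyLinear` API introduced by this PR; the tensor sizes are illustrative, not taken from the PR.

```python
import torch

# out_features is given up front; in_features is left unspecified and is
# inferred from the first input that flows through the module.
net = torch.nn.Sequential(
    torch.nn.LazyLinear(10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2),
)

x = torch.randn(4, 32)   # batch of 4 samples, 32 features each
out = net(x)             # the first forward pass materializes the lazy weights
print(out.shape)         # torch.Size([4, 2])
```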
Committed by: Facebook GitHub Bot
Parent: 7f8b02f5b7
Commit: d38a71d579
```diff
@@ -120,6 +120,10 @@ class SpectralNorm:
         fn = SpectralNorm(name, n_power_iterations, dim, eps)
         weight = module._parameters[name]
+        if isinstance(weight, torch.nn.parameter.UninitializedParameter):
+            raise ValueError(
+                'The module passed to `SpectralNorm` can\'t have uninitialized parameters. '
+                'Make sure to run the dummy forward before applying spectral normalization')
 
         with torch.no_grad():
             weight_mat = fn.reshape_weight_to_matrix(weight)
```
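The hunk above is the guard this PR adds to `SpectralNorm.apply`. The snippet below is a rough sketch of how that guard surfaces to users; it assumes `torch.nn.LazyLinear` from this PR together with the existing `torch.nn.utils.spectral_norm` entry point, and the input size is an illustrative assumption.

```python
import torch
from torch.nn.utils import spectral_norm

lazy = torch.nn.LazyLinear(10)   # weight is still an UninitializedParameter

try:
    spectral_norm(lazy)          # rejected by the new isinstance check above
except ValueError as err:
    print(err)

# A dummy (or first real) forward pass materializes the parameters, after
# which spectral normalization can be applied as usual.
lazy(torch.randn(2, 32))
spectral_norm(lazy)
```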