torch.nn.modules.LazyModuleMixin and torch.nn.LazyLinear (Shape Inference II) (#44538)

Summary:
Retake of https://github.com/pytorch/pytorch/issues/40493 incorporating all the feedback from albanD.

This PR implements the generic Lazy mechanism and a sample `LazyLinear` layer with the `UninitializedParameter`.

The main differences from the previous PR are twofold:
- `torch.nn.Module` now remains untouched.
- We no longer require an explicit initialization or a dummy forward pass before starting training or inference of the actual module, which makes this much simpler to use from the user side.
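A minimal usage sketch of what this enables (the shapes below are illustrative): the `in_features` of `LazyLinear` is inferred from the first real forward pass, so no separate initialization step or dummy pass is needed.

```
import torch

layer = torch.nn.LazyLinear(out_features=10)  # in_features is not known yet
x = torch.randn(4, 20)
out = layer(x)        # the weight is materialized here with shape (10, 20)
print(out.shape)      # torch.Size([4, 10])
```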

As we discussed offline, there was a suggestion to not use a mixin, but instead to change the `__class__` attribute of `LazyLinear` so that it becomes `Linear` once it is fully initialized. While this can be useful, for the time being we need `LazyLinear` to be a `torch.nn.Module` subclass, since many checks rely on modules being instances of `torch.nn.Module`.
Not being a `torch.nn.Module` subclass can cause problems when we create complex modules such as:
```
class MyNetwork(torch.nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.conv = torch.nn.Conv2d(20, 4, 2)
        # in_features is unknown here; it will be inferred from the conv output
        self.linear = torch.nn.LazyLinear(10)

    def forward(self, x):
        y = self.conv(x).clamp(min=0)
        return self.linear(y)
```
Here, when `Module.__setattr__` is called at the time `LazyLinear` is registered, it would not be added to the child modules of `MyNetwork` (because `__setattr__` only registers attributes that are themselves `torch.nn.Module` instances), so we would have to add it manually later. Currently there is no way to do that, since we cannot access the parent module from `LazyLinear` once it becomes the `Linear` module. (We can add a workaround for this if needed.)
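A minimal sketch of the registration issue, with a hypothetical `NotAModule` class standing in for a lazy layer that is not a `Module` subclass: `Module.__setattr__` only records attributes that are `Module` instances, so such an attribute would never appear among the child modules.

```
import torch

class NotAModule:  # hypothetical stand-in for a lazy layer that is not a Module subclass
    pass

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = NotAModule()  # stored as a plain attribute, not registered

net = Net()
print(dict(net.named_children()))  # {} -- the attribute does not show up as a child module
```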

TODO:

- Add convolutions once the design is OK
- Fix docstrings

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44538

Reviewed By: ngimel

Differential Revision: D24162854

Pulled By: albanD

fbshipit-source-id: 6d58dfe5d43bfb05b6ee506e266db3cf4b885f0c
This commit is contained in:
Emilio Castillo
2020-10-19 13:09:16 -07:00
committed by Facebook GitHub Bot
parent 7f8b02f5b7
commit d38a71d579
15 changed files with 662 additions and 12 deletions


```
@@ -120,6 +120,10 @@ class SpectralNorm:
         fn = SpectralNorm(name, n_power_iterations, dim, eps)
         weight = module._parameters[name]
+        if isinstance(weight, torch.nn.parameter.UninitializedParameter):
+            raise ValueError(
+                'The module passed to `SpectralNorm` can\'t have uninitialized parameters. '
+                'Make sure to run the dummy forward before applying spectral normalization')
         with torch.no_grad():
             weight_mat = fn.reshape_weight_to_matrix(weight)
```
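
A minimal sketch of the intended interaction, assuming a forward pass fully materializes the lazy parameter (the shapes below are illustrative): spectral normalization rejects uninitialized parameters, so the lazy module has to see one forward pass first.

```
import torch

lazy = torch.nn.LazyLinear(10)              # weight starts as an UninitializedParameter
lazy(torch.randn(1, 20))                    # forward pass materializes the weight
sn = torch.nn.utils.spectral_norm(lazy)     # now safe to apply spectral normalization
```

Applying `spectral_norm` before that forward pass would trigger the `ValueError` added in the hunk above.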