torch.nn.modules.LazyModuleMixin and torch.nn.LazyLinear (Shape Inference II) (#44538)

Summary:
Retake on https://github.com/pytorch/pytorch/issues/40493 after all the feedback from albanD

This PR implements the generic Lazy mechanism and a sample `LazyLinear` layer with the `UninitializedParameter`.

There are two main differences from the previous PR:
- `torch.nn.Module` remains untouched.
- We no longer require an explicit initialization or a dummy forward pass before starting training or inference of the actual module, which makes this much simpler to use from the user side (see the sketch below).
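
As a rough usage sketch (hypothetical sizes; this snippet is not part of the diff), the first real batch is what materializes the parameters, so no separate warm-up step is needed:
```
import torch

# Only the output size is given; in_features is inferred from the first batch.
model = torch.nn.LazyLinear(10)

x = torch.randn(32, 128)
y = model(x)  # the first forward materializes weight with shape (10, 128)

# From here on the module behaves like a regular nn.Linear.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
y.sum().backward()
optimizer.step()
```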

As we discussed offline, there was a suggestion to not use a mixin, but to instead change the `__class__` attribute of `LazyLinear` so that it becomes `Linear` once it is completely initialized. While this can be useful, for the time being we need `LazyLinear` to be a `torch.nn.Module` subclass, since there are many checks that rely on modules being instances of `torch.nn.Module`.
This can cause problems when we create complex modules such as:
```
class MyNetwork(torch.nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.conv = torch.nn.Conv2d(20, 4, 2)
        self.linear = torch.nn.LazyLinear(10)
    def forward(self, x):
        y = self.conv(x).clamp(min=0)
        return self.linear(y)
```
Here, when `__setattr__` is called at the time `LazyLinear` is registered, it won't be added to the child modules of `MyNetwork`, so we would have to do that manually later; currently there is no way to do so, since we can't access the parent module from `LazyLinear` once it becomes the `Linear` module. (We can add a workaround for this if needed.)
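
For reference, the mixin approach taken in this PR keeps the class-swap idea but applies it from inside a regular `torch.nn.Module` subclass: a forward pre-hook materializes the parameters and then rewrites `__class__` to `cls_to_become`. A condensed sketch of that mechanism, simplified from the new `torch/nn/modules/lazy.py` in this diff (state-dict hooks, warnings, and error checks omitted):
```
import torch

class LazyModuleMixin:
    # Subclasses set this to the class they should turn into once their
    # parameters are materialized (e.g. Linear for LazyLinear).
    cls_to_become = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Shape inference runs right before the first forward call.
        self._initialize_hook = self.register_forward_pre_hook(self._infer_parameters)

    def _infer_parameters(self, module, input):
        # Subclasses materialize their UninitializedParameters here.
        module.initialize_parameters(*input)
        module._initialize_hook.remove()
        delattr(module, '_initialize_hook')
        if module.cls_to_become is not None:
            # After full initialization the instance becomes, e.g., nn.Linear.
            module.__class__ = module.cls_to_become
```
Because the mixin is combined with a real `torch.nn.Module` subclass (`class LazyLinear(LazyModuleMixin, Linear)`), registration as a child module through `__setattr__` keeps working as usual.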

TODO:

- Add convolutions once the design is OK
- Fix docstrings

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44538

Reviewed By: ngimel

Differential Revision: D24162854

Pulled By: albanD

fbshipit-source-id: 6d58dfe5d43bfb05b6ee506e266db3cf4b885f0c
Author: Emilio Castillo
Date: 2020-10-19 13:09:16 -07:00
Committed by: Facebook GitHub Bot
Parent: 7f8b02f5b7
Commit: d38a71d579
15 changed files with 662 additions and 12 deletions


@@ -21,6 +21,7 @@ These are the basic building block for graphs
:template: classtemplate.rst
~parameter.Parameter
~parameter.UninitializedParameter
Containers
----------------------------------
@@ -207,6 +208,7 @@ Linear Layers
nn.Identity
nn.Linear
nn.Bilinear
nn.LazyLinear
Dropout Layers
--------------
@@ -373,3 +375,14 @@ Quantized Functions
Quantization refers to techniques for performing computations and storing tensors at lower bitwidths than
floating point precision. PyTorch supports both per tensor and per channel asymmetric linear quantization. To learn more how to use quantized functions in PyTorch, please refer to the :ref:`quantization-doc` documentation.
Lazy Modules Initialization
---------------------------
.. currentmodule:: torch
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
nn.modules.lazy.LazyModuleMixin


@@ -2403,6 +2403,15 @@ class DistributedDataParallelTest(MultiProcessTestCase):
def test_arbitrary_forward_return_value_grad_is_view(self):
self._test_arbitrary_forward_return_value(gradient_as_bucket_view=True)
@requires_nccl()
@skip_if_not_multigpu
@skip_if_rocm
def test_ddp_with_lazy_parameters(self):
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
with self.assertRaisesRegex(RuntimeError, 'Modules with uninitialized parameters'):
DistributedDataParallel(torch.nn.LazyLinear(10), process_group=process_group)
def _test_find_unused_parameters_kwarg(self, gradient_as_bucket_view=False):
"""
Note: this test can be sped up by only running it on a CPU module


@@ -78,6 +78,13 @@ class TestDataParallel(TestCase):
for p1, p2 in zip(model.parameters(), model_dp.parameters()):
self.assertTrue(p1.allclose(p2))
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_lazy_linear(self):
with self.assertRaisesRegex(RuntimeError, 'Modules with uninitialized parameters'):
model_dp = torch.nn.DataParallel(torch.nn.LazyLinear(10).to(0))
model_dp(torch.rand(10, 10).to(0))
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_parallel_apply(self):
l1 = nn.Linear(10, 5).to("cuda:0", torch.float)


@@ -32,12 +32,13 @@ from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.autograd import gradcheck
from torch.autograd.gradcheck import gradgradcheck
from torch.nn import Parameter
from torch.nn.parameter import UninitializedParameter
from torch.nn.parallel._functions import Broadcast
from torch.testing import get_all_fp_dtypes
from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, download_file, \
get_function_arglist, load_tests, repeat_test_for_types, ALL_TENSORTYPES, \
ALL_TENSORTYPES2, suppress_warnings, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
module_tests, criterion_tests, loss_reference_fns, \
@@ -13101,6 +13102,199 @@ class TestModuleGlobalHooks(TestCase):
output.backward(torch.ones(5, 5), retain_graph=True)
self.assertTrue(local_backward_called and global_backward_called)
class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
pass
class TestLazyModules(TestCase):
@suppress_warnings
def test_lazy_module_parameter(self):
module = LazyModule()
module.register_parameter('test_param', UninitializedParameter())
self.assertTrue(module.has_uninitialized_params())
state_dict = module.state_dict()
self.assertIsInstance(state_dict['test_param'], UninitializedParameter)
new_module = LazyModule()
# An error is raised when there is an attempt to replace an existing parameter
# with an uninitialized one
new_module.register_parameter('test_param', nn.Parameter(torch.ones(5, 5)))
with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
new_module.load_state_dict(state_dict)
# Uninitialized parameters are overridden when the state dict to be loaded contains a valid one
new_module = LazyModule()
new_module.register_parameter('test_param', nn.Parameter(torch.ones(5, 5)))
module.load_state_dict(new_module.state_dict())
self.assertEqual(module.test_param, torch.ones((5, 5)))
# Uninitialized parameters are left unchanged
module = LazyModule()
module.register_parameter('test_param', UninitializedParameter())
self.assertTrue(module.has_uninitialized_params())
new_module = LazyModule()
new_module.register_parameter('test_param', UninitializedParameter())
module.load_state_dict(new_module.state_dict())
self.assertTrue(module.has_uninitialized_params())
@suppress_warnings
def test_lazy_module_jit(self):
module = LazyModule()
module.register_parameter('test_param', UninitializedParameter())
self.assertTrue(module.has_uninitialized_params())
with self.assertRaisesRegex(RuntimeError, 'run a forward pass'):
torch.jit.script(module)
@suppress_warnings
def test_lazy_share_memory(self):
module = LazyModule()
module.register_parameter('test_param', UninitializedParameter())
self.assertTrue(module.has_uninitialized_params())
with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
module.share_memory()
@suppress_warnings
def test_linear(self):
module = nn.LazyLinear(10)
self.assertIsInstance(module.weight, UninitializedParameter)
input = torch.ones(5, 5)
module(input)
self.assertIsInstance(module, nn.Linear)
self.assertNotIsInstance(module, nn.LazyLinear)
self.assertTrue(module.weight.shape == (10, 5))
y = module(input)
self.assertTrue(torch.equal(torch.nn.functional.linear(input, module.weight, module.bias), y))
@suppress_warnings
def test_lazy_linear_pickle(self):
module = nn.LazyLinear(10)
self.assertIsInstance(module.weight, UninitializedParameter)
module = pickle.loads(pickle.dumps(module))
self.assertIsInstance(module, nn.LazyLinear)
self.assertIsInstance(module.weight, UninitializedParameter)
input = torch.ones(5, 5)
module(input) # fully materialized
new_module = pickle.loads(pickle.dumps(module))
self.assertIsInstance(new_module, nn.Linear)
self.assertNotIsInstance(new_module, nn.LazyLinear)
self.assertTrue(new_module.weight.shape == (10, 5))
self.assertNotIsInstance(new_module.weight, UninitializedParameter)
@suppress_warnings
def test_linear_state(self):
module = nn.Linear(5, 10)
lazy_module = nn.LazyLinear(10)
lazy_module.load_state_dict(module.state_dict())
# Parameters have been initialized but the module won't become a full
# Linear one until the first iteration. This is due to
# limitations on the state_dict loading logic
self.assertFalse(lazy_module.has_uninitialized_params())
self.assertTrue(lazy_module.weight.shape == (10, 5))
module = nn.Linear(5, 10)
lazy_module = nn.LazyLinear(10)
with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
module.load_state_dict(lazy_module.state_dict())
@suppress_warnings
def test_materialize_dtype(self):
module = LazyModule()
module.register_parameter('test_param', UninitializedParameter())
module.test_param.materialize(10)
self.assertTrue(module.test_param.dtype == torch.float64)
module = LazyModule()
module.register_parameter('test_param', UninitializedParameter())
module.half()
module.test_param.materialize(10)
self.assertTrue(module.test_param.dtype == torch.float16)
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
@suppress_warnings
def test_materialize_device(self):
module = LazyModule()
module.register_parameter('test_param', UninitializedParameter())
module.test_param.materialize(10)
self.assertTrue(module.test_param.device.type == 'cpu')
module = LazyModule()
module.register_parameter('test_param', UninitializedParameter())
module.cuda()
module.test_param.materialize(10)
self.assertTrue(module.test_param.device.type == 'cuda')
@suppress_warnings
def test_chained_initialization(self):
class MyNetwork(torch.nn.Module):
def __init__(self):
super(MyNetwork, self).__init__()
self.linear_1 = torch.nn.LazyLinear(15)
self.linear_2 = torch.nn.LazyLinear(10)
def forward(self, x):
y = self.linear_1(x)
return self.linear_2(y)
net = MyNetwork()
net(torch.ones(5, 10))
self.assertTrue(net.linear_1.weight.shape == (15, 10))
self.assertTrue(net.linear_2.weight.shape == (10, 15))
@suppress_warnings
def test_optimizer_pass(self):
optimizers = [torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam,
torch.optim.AdamW, torch.optim.Adamax,
torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
torch.optim.RMSprop, torch.optim.LBFGS]
def run_step(module, optim):
self.assertIsInstance(optim.param_groups[0]['params'][0], UninitializedParameter)
module.test_param.materialize(10)
self.assertIsInstance(optim.param_groups[0]['params'][0], Parameter)
self.assertNotIsInstance(optim.param_groups[0]['params'][0], UninitializedParameter)
for p in module.parameters():
p.grad = torch.rand_like(p)
if isinstance(optim, torch.optim.LBFGS):
optim.step(lambda: 1.0)
else:
optim.step()
for optim_cls in optimizers:
module = LazyModule()
module.register_parameter('test_param', UninitializedParameter())
if optim_cls is torch.optim.SGD:
optim = optim_cls(module.parameters(), lr=0.0)
elif optim_cls is torch.optim.Adagrad:
with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
optim = optim_cls(module.parameters())
continue
else:
optim = optim_cls(module.parameters())
run_step(module, optim)
@suppress_warnings
def test_weight_norm(self):
m = nn.LazyLinear(7)
with self.assertRaisesRegex(ValueError, 'have uninitialized parameters.'):
m = torch.nn.utils.weight_norm(m)
@suppress_warnings
def test_spectral_norm(self):
m = nn.LazyLinear(7)
with self.assertRaisesRegex(ValueError, 'have uninitialized parameters.'):
m = torch.nn.utils.spectral_norm(m)
@suppress_warnings
def test_invalid_functions(self):
param = torch.nn.parameter.UninitializedParameter()
with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
torch.empty_like(param)
with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
torch.add(param, param)
with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
param + param
instantiate_device_type_tests(TestNNDeviceType, globals())
if __name__ == '__main__':


@@ -561,6 +561,13 @@ def check_module_initialized(mod):
raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
.format(torch.typename(type(mod))))
# This is to avoid importing torch.distributed.nn
if not hasattr(mod, 'remote_parameters'):
for name, param in mod._parameters.items():
if isinstance(param, torch.nn.parameter.UninitializedParameter):
raise RuntimeError("'{}' has uninitialized parameters {}. Did you forget to run a forward pass?"
.format(torch.typename(type(mod)), name))
def infer_methods_to_compile(nn_module):
"""
Implements the default rules for which methods should act as starting


@@ -1,5 +1,5 @@
from .modules import *
from .parameter import Parameter, UninitializedParameter
from .parallel import DataParallel
from . import init
from . import utils


@@ -1,5 +1,5 @@
from .module import Module
from .linear import Identity, Linear, Bilinear, LazyLinear
from .conv import Conv1d, Conv2d, Conv3d, \
ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
@@ -54,6 +54,6 @@ __all__ = [
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer', 'LazyLinear',
'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle'
]

torch/nn/modules/lazy.py (new file, 257 lines)

@@ -0,0 +1,257 @@
import itertools
from typing_extensions import Protocol
import warnings
import torch
from ..parameter import UninitializedParameter
class _LazyProtocol(Protocol):
"""This is to avoid errors with mypy checks for
The attributes in a mixin:
https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
"""
def _register_load_state_dict_pre_hook(self, hook):
...
def register_forward_pre_hook(self, hook):
...
def _lazy_load_hook(
self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
...
def _get_name(self):
...
def _infer_parameters(self, module, input):
...
@property
def _parameters(self):
...
@property
def _buffers(self):
...
@property
def _non_persistent_buffers_set(self):
...
@property
def _load_hook(self):
...
@property
def _initialize_hook(self):
...
class LazyModuleMixin:
r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules."
.. warning:
Lazy modules are an experimental new feature under active development,
and their API is likely to change.
Modules that lazily initialize parameters, or "lazy modules",
derive the shapes of their parameters from the first input(s)
to their forward method. Until that first forward they contain
:class:`torch.nn.UninitializedParameter`s that should not be accessed
or used, and afterward they contain regular :class:`torch.nn.Parameter`s.
Lazy modules are convenient since they don't require computing some
module arguments, like the `in_features` argument of a
typical :class:`torch.nn.Linear`.
After construction, networks with lazy modules should first
be converted to the desired dtype and placed on the desired device.
The lazy modules should then be initialized with one or more "dry runs".
These "dry runs" send inputs of the correct size, dtype, and device through
the network and to each one of its lazy modules. After this the network can be used as usual.
>>> class LazyMLP(torch.nn.Module):
... def __init__(self):
... super().__init__()
... self.fc1 = torch.nn.LazyLinear(10)
... self.relu1 = torch.nn.ReLU()
... self.fc2 = torch.nn.LazyLinear(1)
... self.relu2 = torch.nn.ReLU()
...
... def forward(self, input):
... x = self.relu1(self.fc1(input))
... y = self.relu2(self.fc2(x))
... return y
>>> # constructs a network with lazy modules
>>> lazy_mlp = LazyMLP()
>>> # transforms the network's device and dtype
>>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs'
>>> lazy_mlp = lazy_mlp.cuda().double()
>>> lazy_mlp
LazyMLP(
(fc1): LazyLinear(in_features=0, out_features=10, bias=True)
(relu1): ReLU()
(fc2): LazyLinear(in_features=0, out_features=1, bias=True)
(relu2): ReLU()
)
>>> # performs a dry run to initialize the network's lazy modules
>>> lazy_mlp(torch.ones(10,10).cuda().double())
>>> # after initialization, LazyLinear modules become regular Linear modules
>>> lazy_mlp
LazyMLP(
(fc1): Linear(in_features=10, out_features=10, bias=True)
(relu1): ReLU()
(fc2): Linear(in_features=10, out_features=1, bias=True)
(relu2): ReLU()
)
>>> # attaches an optimizer, since parameters can now be used as usual
>>> optim = torch.optim.SGD(lazy_mlp.parameters(), lr=0.01)
A final caveat when using lazy modules is that the order of initialization of a network's
parameters may change, since the lazy modules are always initialized after other modules.
This can cause the parameters of a network using lazy modules to be initialized differently
than the parameters of a network without lazy modules.
For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module
first and then a regular :class:`torch.nn.Linear` second, the second module would be
initialized on construction and the first module would be initialized during the first dry run.
Lazy modules can be serialized with a state dict like other modules. For example:
>>> lazy_mlp = LazyMLP()
>>> # The state dict shows the uninitialized parameters
>>> lazy_mlp.state_dict()
OrderedDict([('fc1.weight', Uninitialized parameter),
('fc1.bias',
tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30,
4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])),
('fc2.weight', Uninitialized parameter),
('fc2.bias', tensor([0.0019]))])
Lazy modules can also load regular :class:`torch.nn.Parameter` s,
which replace their :class:`torch.nn.UninitializedParameter` s:
>>> full_mlp = LazyMLP()
>>> # Dry run to initialize another module
>>> full_mlp.forward(torch.ones(10, 1))
>>> # Load an initialized state into a lazy module
>>> lazy_mlp.load_state_dict(full_mlp.state_dict())
>>> # The state dict now holds valid values
>>> lazy_mlp.state_dict()
OrderedDict([('fc1.weight',
tensor([[-0.3837],
[ 0.0907],
[ 0.6708],
[-0.5223],
[-0.9028],
[ 0.2851],
[-0.4537],
[ 0.6813],
[ 0.5766],
[-0.8678]])),
('fc1.bias',
tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30,
4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])),
('fc2.weight',
tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.1088, -0.1795, -0.2301, 0.2807,
0.2479, 0.1091]])),
('fc2.bias', tensor([0.0019]))])
Note, however, that lazy modules cannot validate that the shape of parameters they load is correct.
"""
# modules inheriting from this will change their __class__ to the specified
# one after they are fully initialized
cls_to_become = None
def __init__(self: _LazyProtocol, *args, **kwargs):
# Mypy doesn't like this super call in a mixin
super().__init__(*args, **kwargs) # type: ignore
self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
self._initialize_hook = self.register_forward_pre_hook(self._infer_parameters)
warnings.warn('Lazy modules are a new feature under heavy development '
'so changes to the API or functionality can happen at any moment.')
def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars):
# This should be ideally implemented as a hook,
# but we should override `detach` in the UninitializedParameter to return itself
# which is not clean
for name, param in self._parameters.items():
if param is not None:
if isinstance(param, UninitializedParameter):
destination[prefix + name] = param
else:
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
def _lazy_load_hook(
self: _LazyProtocol, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""load_state_dict pre-hook function for lazy buffers and parameters.
The purpose of this hook is to adjust the current state and/or
``state_dict`` being loaded so that a module instance serialized in
both un/initialized state can be deserialized onto both un/initialized
module instance.
See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
for the details of the hook specification.
"""
local_state = {k: v for k, v in self._parameters.items() if v is not None}
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if isinstance(param, UninitializedParameter):
# The current parameter is not initialized but the one being loaded is;
# create a new parameter based on the uninitialized one
if not isinstance(input_param, UninitializedParameter):
with torch.no_grad():
param.materialize(input_param.shape)
def initialize_parameters(self: _LazyProtocol, *args, **kwargs):
r"""Initialize parameters according to the input batch properties.
This adds an interface to isolate parameter initialization from the
forward pass when doing parameter shape inference.
"""
raise NotImplementedError('initialize_parameters is not implemented for {}'.format(self.__class__.__name__))
def has_uninitialized_params(self: _LazyProtocol):
r"""Check if a module has parameters that are not initialized
"""
# This is to avoid the JIT to track this parameter and force
# custom modules __setstate__ to add it
params = self._parameters.values()
for param in itertools.chain(params):
if isinstance(param, (UninitializedParameter)):
return True
return False
def _infer_parameters(self: _LazyProtocol, module, input):
r"""Infers the size and initializes the parameters according to the
provided input batch.
Given a module that contains parameters that were declared inferrable
using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass
in the complete module using the provided input to initialize all the parameters
as needed.
The module is set into evaluation mode before running the forward pass in order
to avoid saving statistics or calculating gradients
"""
module.initialize_parameters(*input)
if module.has_uninitialized_params():
raise RuntimeError('module {} has not been fully initialized'.format(self._get_name()))
module._initialize_hook.remove()
module._load_hook.remove()
delattr(module, '_initialize_hook')
delattr(module, '_load_hook')
if module.cls_to_become is not None:
module.__class__ = module.cls_to_become
def _replicate_for_data_parallel(self: _LazyProtocol):
raise RuntimeError('Modules with uninitialized parameters can\'t be used with `DataParallel`. '
'Run a dummy forward pass to correctly initialize the modules')


@@ -2,10 +2,11 @@ import math
import torch
from torch import Tensor
from torch.nn.parameter import Parameter, UninitializedParameter
from .. import functional as F
from .. import init
from .module import Module
from .lazy import LazyModuleMixin
class Identity(Module):
@@ -178,4 +179,49 @@ class Bilinear(Module):
self.in1_features, self.in2_features, self.out_features, self.bias is not None
)
class LazyLinear(LazyModuleMixin, Linear):
r"""A :class:`torch.nn.Linear` module with lazy initialization.
In this module, the `weight` and `bias` are of :class:`torch.nn.UninitializedParameter`
class. They will be initialized after the first call to ``forward`` is done and the
module will become a regular :class:`torch.nn.Linear` module.
Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
on lazy modules and their limitations.
Args:
out_features: size of each output sample
bias: If set to ``False``, the layer will not learn an additive bias.
Default: ``True``
Attributes:
weight: the learnable weights of the module of shape
:math:`(\text{out\_features}, \text{in\_features})`. The values are
initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
:math:`k = \frac{1}{\text{in\_features}}`
bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
If :attr:`bias` is ``True``, the values are initialized from
:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
:math:`k = \frac{1}{\text{in\_features}}`
"""
cls_to_become = Linear
def __init__(self, out_features: int, bias: bool = True) -> None:
super().__init__(0, out_features, bias)
self.weight = UninitializedParameter()
def reset_parameters(self) -> None:
if not self.has_uninitialized_params() and self.in_features != 0:
super().reset_parameters()
def initialize_parameters(self, input) -> None:
if self.has_uninitialized_params():
with torch.no_grad():
self.in_features = input.shape[-1]
self.weight.materialize((self.out_features, self.in_features))
self.reset_parameters()
# TODO: PartialLinear - maybe in sparse?


@@ -965,18 +965,20 @@ class Module:
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they don't have the hook to do the checks
# in such case, it will error when accessing the .shape attribute.
is_param_lazy = isinstance(param, torch.nn.parameter.UninitializedParameter)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if not is_param_lazy and input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
continue
try:
with torch.no_grad():
param.copy_(input_param)


@@ -404,7 +404,13 @@ class DistributedDataParallel(Module):
)
pass
# Check that a module does not have Uninitialized parameters
for param in module.parameters():
if isinstance(param, torch.nn.parameter.UninitializedParameter):
raise RuntimeError(
'Modules with uninitialized parameters can\'t be used with `DistributedDataParallel`. '
'Run a dummy forward pass to correctly initialize the modules')
# used for intra-node param sync and inter-node sync as well
self.broadcast_bucket_size = int(250 * 1024 * 1024)
# reduction bucket size


@@ -44,3 +44,97 @@ class Parameter(torch.Tensor):
)
__torch_function__ = _disabled_torch_function_impl
class UninitializedParameter(Parameter):
r"""A parameter that is not initialized.
Uninitialized parameters are a special case of :class:`torch.nn.Parameter`
where the shape of the data is still unknown.
Unlike a :class:`torch.nn.Parameter`, uninitialized parameters
hold no data, and attempting to access some properties, like their shape,
will throw a runtime error. The only operations that can be performed on an uninitialized
parameter are changing its datatype, moving it to a different device and
converting it to a regular :class:`torch.nn.Parameter`.
"""
_allowed_methods = [
torch.Tensor.__hash__,
torch.Tensor.size,
torch.Tensor.copy_,
torch.Tensor.is_floating_point,
torch.Tensor.half,
torch.Tensor.float,
torch.Tensor.double,
torch.Tensor.char,
torch.Tensor.short,
torch.Tensor.int,
torch.Tensor.long,
torch.Tensor.cuda,
torch.Tensor.cpu,
torch.Tensor.to,
torch.Tensor.get_device,
torch._has_compatible_shallow_copy_type]
def __new__(cls, requires_grad=True):
data = torch.Tensor()
return torch.Tensor._make_subclass(cls, data, requires_grad)
def materialize(self, shape, device=None, dtype=None):
r"""Create a Parameter with the same properties of the uninitialized one.
Given a shape, it materializes a parameter in the same device
and with the same `dtype` as the current one or the specified ones in the
arguments.
Args:
shape (tuple): the shape for the materialized tensor.
device (:class:`torch.device`): the desired device of the parameters
and buffers in this module. Optional.
dtype (:class:`torch.dtype`): the desired floating point type of
the floating point parameters and buffers in this module. Optional.
"""
if device is None:
device = self.data.device
if dtype is None:
dtype = self.data.dtype
self.data = torch.empty(shape, device=device, dtype=dtype)
self.__class__ = Parameter
@property
def shape(self):
raise RuntimeError(
'Can\'t access the shape of an uninitialized parameter. '
'This error usually happens in `load_state_dict` when trying to load '
'an uninitialized parameter into an initialized one. '
'Call `forward` to initialize the parameters before accessing their attributes.')
def share_memory_(self):
raise RuntimeError(
'Can\'t share memory on an uninitialized parameter. '
'Call `forward` to initialize the parameters before calling '
'`module.share_memory()`.')
def __repr__(self):
return 'Uninitialized parameter'
def __reduce_ex__(self, proto):
# See Note [Don't serialize hooks]
return (
UninitializedParameter,
(self.requires_grad,)
)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# method-wrapper is to detect access to Tensor properties that are
# wrapped in descriptors
if func in cls._allowed_methods or func.__class__.__name__ == 'method-wrapper':
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
raise ValueError(
'Attempted to use an uninitialized parameter in {}. '
'This error happens when you are using a `LazyModule` or '
'explicitly manipulating `torch.nn.parameter.UninitializedParameter` '
'objects. When using LazyModules, call `forward` with a dummy batch '
'to initialize the parameters before calling torch functions'.format(func))


@@ -1,7 +1,15 @@
import torch
from .. import Tensor
from typing import Tuple, Optional
import builtins
class Parameter(Tensor):
def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...
...
class UninitializedParameter(Tensor):
def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...
def materialize(self, shape: Tuple[int], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
...


@@ -120,6 +120,10 @@ class SpectralNorm:
fn = SpectralNorm(name, n_power_iterations, dim, eps)
weight = module._parameters[name]
if isinstance(weight, torch.nn.parameter.UninitializedParameter):
raise ValueError(
'The module passed to `SpectralNorm` can\'t have uninitialized parameters. '
'Make sure to run the dummy forward before applying spectral normalization')
with torch.no_grad():
weight_mat = fn.reshape_weight_to_matrix(weight)


@@ -1,7 +1,7 @@
r"""
Weight Normalization from https://arxiv.org/abs/1602.07868
"""
from torch.nn.parameter import Parameter, UninitializedParameter
from torch import _weight_norm, norm_except_dim
from typing import Any, TypeVar
from ..modules import Module
@@ -36,7 +36,10 @@ class WeightNorm(object):
fn = WeightNorm(name, dim)
weight = getattr(module, name)
if isinstance(weight, UninitializedParameter):
raise ValueError(
'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
'Make sure to run the dummy forward before applying weight normalization')
# remove w from parameter list
del module._parameters[name]