Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-13 04:34:43 +08:00)
torch.nn.modules.LazyModuleMixin and torch.nn.LazyLinear (Shape Inference II) (#44538)
Summary: Retake on https://github.com/pytorch/pytorch/issues/40493 after all the feedback from albanD.

This PR implements the generic lazy-initialization mechanism and a sample `LazyLinear` layer built on `UninitializedParameter`. There are two main differences from the previous PR: `torch.nn.Module` now remains untouched, and we no longer require an explicit initialization or a dummy forward pass before starting training or inference of the actual module. This makes the feature much simpler to use from the user side.

As we discussed offline, there was a suggestion to not use a mixin and instead only change the `__class__` attribute of `LazyLinear` so that it becomes `Linear` once it is completely initialized. While that could be useful, for the time being we need `LazyLinear` to be a `torch.nn.Module` subclass, since many checks rely on modules being instances of `torch.nn.Module`. If `LazyLinear` were not a module subclass, we would hit problems when creating complex modules such as:

```
class MyNetwork(torch.nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.conv = torch.nn.Conv2d(20, 4, 2)
        self.linear = torch.nn.LazyLinear(10)

    def forward(self, x):
        y = self.conv(x).clamp(min=0)
        return self.linear(y)
```

Here, when `__setattr__` is called at the time `LazyLinear` is registered, it would not be added to the child modules of `MyNetwork`, so we would have to do that manually later. Currently there is no way to do so, because we can't access the parent module from `LazyLinear` once it becomes the `Linear` module. (We can add a workaround for this if needed.)

TODO:
- Add convolutions once the design is OK
- Fix docstrings

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44538

Reviewed By: ngimel
Differential Revision: D24162854
Pulled By: albanD
fbshipit-source-id: 6d58dfe5d43bfb05b6ee506e266db3cf4b885f0c
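As a quick illustration of the workflow described above, a minimal sketch based on the `LazyLinear` layer and the tests added in this PR:

```
import torch

lazy = torch.nn.LazyLinear(out_features=10)   # in_features is not known yet
print(type(lazy.weight))                      # torch.nn.parameter.UninitializedParameter

x = torch.ones(5, 5)
lazy(x)                                       # first forward pass infers in_features=5

print(isinstance(lazy, torch.nn.Linear))      # True: __class__ was swapped to Linear
print(lazy.weight.shape)                      # torch.Size([10, 5])
```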
Committed by: Facebook GitHub Bot
Parent commit: 7f8b02f5b7
This commit: d38a71d579
@@ -21,6 +21,7 @@ These are the basic building block for graphs
     :template: classtemplate.rst

     ~parameter.Parameter
+    ~parameter.UninitializedParameter

 Containers
 ----------------------------------
@@ -207,6 +208,7 @@ Linear Layers
     nn.Identity
     nn.Linear
     nn.Bilinear
+    nn.LazyLinear

 Dropout Layers
 --------------
@@ -373,3 +375,14 @@ Quantized Functions

 Quantization refers to techniques for performing computations and storing tensors at lower bitwidths than
 floating point precision. PyTorch supports both per tensor and per channel asymmetric linear quantization. To learn more how to use quantized functions in PyTorch, please refer to the :ref:`quantization-doc` documentation.
+
+Lazy Modules Initialization
+---------------------------
+
+.. currentmodule:: torch
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classtemplate.rst
+
+    nn.modules.lazy.LazyModuleMixin
@@ -2403,6 +2403,15 @@ class DistributedDataParallelTest(MultiProcessTestCase):
     def test_arbitrary_forward_return_value_grad_is_view(self):
         self._test_arbitrary_forward_return_value(gradient_as_bucket_view=True)

+    @requires_nccl()
+    @skip_if_not_multigpu
+    @skip_if_rocm
+    def test_ddp_with_lazy_parameters(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+        with self.assertRaisesRegex(RuntimeError, 'Modules with uninitialized parameters'):
+            DistributedDataParallel(torch.nn.LazyLinear(10), process_group=process_group)
+
     def _test_find_unused_parameters_kwarg(self, gradient_as_bucket_view=False):
         """
         Note: this test can be sped up by only running it on a CPU module
@@ -78,6 +78,13 @@ class TestDataParallel(TestCase):
         for p1, p2 in zip(model.parameters(), model_dp.parameters()):
             self.assertTrue(p1.allclose(p2))

+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    def test_data_parallel_lazy_linear(self):
+
+        with self.assertRaisesRegex(RuntimeError, 'Modules with uninitialized parameters'):
+            model_dp = torch.nn.DataParallel(torch.nn.LazyLinear(10).to(0))
+            model_dp(torch.rand(10, 10).to(0))
+
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_parallel_apply(self):
         l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
test/test_nn.py (196 lines changed)
@@ -32,12 +32,13 @@ from torch.nn.utils import parameters_to_vector, vector_to_parameters
 from torch.autograd import gradcheck
 from torch.autograd.gradcheck import gradgradcheck
 from torch.nn import Parameter
+from torch.nn.parameter import UninitializedParameter
 from torch.nn.parallel._functions import Broadcast
 from torch.testing import get_all_fp_dtypes
 from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
     TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, download_file, \
     get_function_arglist, load_tests, repeat_test_for_types, ALL_TENSORTYPES, \
-    ALL_TENSORTYPES2, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC
+    ALL_TENSORTYPES2, suppress_warnings, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
 from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
     module_tests, criterion_tests, loss_reference_fns, \
@@ -13101,6 +13102,199 @@ class TestModuleGlobalHooks(TestCase):
         output.backward(torch.ones(5, 5), retain_graph=True)
         self.assertTrue(local_backward_called and global_backward_called)

+
+class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
+    pass
+
+
+class TestLazyModules(TestCase):
+
+    @suppress_warnings
+    def test_lazy_module_parameter(self):
+        module = LazyModule()
+        module.register_parameter('test_param', UninitializedParameter())
+        self.assertTrue(module.has_uninitialized_params())
+        state_dict = module.state_dict()
+        self.assertIsInstance(state_dict['test_param'], UninitializedParameter)
+        new_module = LazyModule()
+        # An error is raised when there is an attempt to replace an existing parameter
+        # with an uninitialized one
+        new_module.register_parameter('test_param', nn.Parameter(torch.ones(5, 5)))
+        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
+            new_module.load_state_dict(state_dict)
+        # Uninitialized parameters are overriden when the state dict to be loaded contains a valid one
+        new_module = LazyModule()
+        new_module.register_parameter('test_param', nn.Parameter(torch.ones(5, 5)))
+        module.load_state_dict(new_module.state_dict())
+        self.assertEqual(module.test_param, torch.ones((5, 5)))
+
+        # Uninitialized parameters are left unchanged
+        module = LazyModule()
+        module.register_parameter('test_param', UninitializedParameter())
+        self.assertTrue(module.has_uninitialized_params())
+
+        new_module = LazyModule()
+        new_module.register_parameter('test_param', UninitializedParameter())
+        module.load_state_dict(new_module.state_dict())
+        self.assertTrue(module.has_uninitialized_params())
+
+    @suppress_warnings
+    def test_lazy_module_jit(self):
+        module = LazyModule()
+        module.register_parameter('test_param', UninitializedParameter())
+        self.assertTrue(module.has_uninitialized_params())
+        with self.assertRaisesRegex(RuntimeError, 'run a forward pass'):
+            torch.jit.script(module)
+
+    @suppress_warnings
+    def test_lazy_share_memory(self):
+        module = LazyModule()
+        module.register_parameter('test_param', UninitializedParameter())
+        self.assertTrue(module.has_uninitialized_params())
+        with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
+            module.share_memory()
+
+    @suppress_warnings
+    def test_linear(self):
+        module = nn.LazyLinear(10)
+        self.assertIsInstance(module.weight, UninitializedParameter)
+        input = torch.ones(5, 5)
+        module(input)
+        self.assertIsInstance(module, nn.Linear)
+        self.assertNotIsInstance(module, nn.LazyLinear)
+        self.assertTrue(module.weight.shape == (10, 5))
+        y = module(input)
+        self.assertTrue(torch.equal(torch.nn.functional.linear(input, module.weight, module.bias), y))
+
+    @suppress_warnings
+    def test_lazy_linear_pickle(self):
+        module = nn.LazyLinear(10)
+        self.assertIsInstance(module.weight, UninitializedParameter)
+        module = pickle.loads(pickle.dumps(module))
+        self.assertIsInstance(module, nn.LazyLinear)
+        self.assertIsInstance(module.weight, UninitializedParameter)
+        input = torch.ones(5, 5)
+        module(input)  # fully materialized
+        new_module = pickle.loads(pickle.dumps(module))
+        self.assertIsInstance(new_module, nn.Linear)
+        self.assertNotIsInstance(new_module, nn.LazyLinear)
+        self.assertTrue(new_module.weight.shape == (10, 5))
+        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
+
+    @suppress_warnings
+    def test_linear_state(self):
+        module = nn.Linear(5, 10)
+        lazy_module = nn.LazyLinear(10)
+        lazy_module.load_state_dict(module.state_dict())
+        # Parameters have been initialized but the module won't become a full
+        # Linear one until the first iteration. This is due to
+        # limitations on the state_dict loading logic
+        self.assertFalse(lazy_module.has_uninitialized_params())
+        self.assertTrue(lazy_module.weight.shape == (10, 5))
+
+        module = nn.Linear(5, 10)
+        lazy_module = nn.LazyLinear(10)
+        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
+            module.load_state_dict(lazy_module.state_dict())
+
+    @suppress_warnings
+    def test_materialize_dtype(self):
+        module = LazyModule()
+        module.register_parameter('test_param', UninitializedParameter())
+        module.test_param.materialize(10)
+        self.assertTrue(module.test_param.dtype == torch.float64)
+        module = LazyModule()
+        module.register_parameter('test_param', UninitializedParameter())
+        module.half()
+        module.test_param.materialize(10)
+        self.assertTrue(module.test_param.dtype == torch.float16)
+
+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    @suppress_warnings
+    def test_materialize_device(self):
+        module = LazyModule()
+        module.register_parameter('test_param', UninitializedParameter())
+        module.test_param.materialize(10)
+        self.assertTrue(module.test_param.device.type == 'cpu')
+        module = LazyModule()
+        module.register_parameter('test_param', UninitializedParameter())
+        module.cuda()
+        module.test_param.materialize(10)
+        self.assertTrue(module.test_param.device.type == 'cuda')
+
+    @suppress_warnings
+    def test_chained_initialization(self):
+        class MyNetwork(torch.nn.Module):
+            def __init__(self):
+                super(MyNetwork, self).__init__()
+                self.linear_1 = torch.nn.LazyLinear(15)
+                self.linear_2 = torch.nn.LazyLinear(10)
+
+            def forward(self, x):
+                y = self.linear_1(x)
+                return self.linear_2(y)
+
+        net = MyNetwork()
+        net(torch.ones(5, 10))
+        self.assertTrue(net.linear_1.weight.shape == (15, 10))
+        self.assertTrue(net.linear_2.weight.shape == (10, 15))
+
+    @suppress_warnings
+    def test_optimizer_pass(self):
+        optimizers = [torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam,
+                      torch.optim.AdamW, torch.optim.Adamax,
+                      torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
+                      torch.optim.RMSprop, torch.optim.LBFGS]
+
+        def run_step(module, optim):
+            self.assertIsInstance(optim.param_groups[0]['params'][0], UninitializedParameter)
+            module.test_param.materialize(10)
+            self.assertIsInstance(optim.param_groups[0]['params'][0], Parameter)
+            self.assertNotIsInstance(optim.param_groups[0]['params'][0], UninitializedParameter)
+            for p in module.parameters():
+                p.grad = torch.rand_like(p)
+            if isinstance(optim, torch.optim.LBFGS):
+                optim.step(lambda: 1.0)
+            else:
+                optim.step()
+
+        for optim_cls in optimizers:
+            module = LazyModule()
+            module.register_parameter('test_param', UninitializedParameter())
+            if optim_cls is torch.optim.SGD:
+                optim = optim_cls(module.parameters(), lr=0.0)
+            elif optim_cls is torch.optim.Adagrad:
+                with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
+                    optim = optim_cls(module.parameters())
+                continue
+            else:
+                optim = optim_cls(module.parameters())
+            run_step(module, optim)
+
+    @suppress_warnings
+    def test_weight_norm(self):
+        m = nn.LazyLinear(7)
+        with self.assertRaisesRegex(ValueError, 'have uninitialized parameters.'):
+            m = torch.nn.utils.weight_norm(m)
+
+    @suppress_warnings
+    def test_spectral_norm(self):
+        m = nn.LazyLinear(7)
+        with self.assertRaisesRegex(ValueError, 'have uninitialized parameters.'):
+            m = torch.nn.utils.spectral_norm(m)
+
+    @suppress_warnings
+    def test_invalid_functions(self):
+        param = torch.nn.parameter.UninitializedParameter()
+        with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
+            torch.empty_like(param)
+
+        with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
+            torch.add(param, param)
+
+        with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
+            param + param
+
 instantiate_device_type_tests(TestNNDeviceType, globals())

 if __name__ == '__main__':
@@ -561,6 +561,13 @@ def check_module_initialized(mod):
         raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
                            .format(torch.typename(type(mod))))

+    # This is to avoid importing torch.distributed.nn
+    if not hasattr(mod, 'remote_parameters'):
+        for name, param in mod._parameters.items():
+            if isinstance(param, torch.nn.parameter.UninitializedParameter):
+                raise RuntimeError("'{}' has uninitialized parameters {}. Did you forget to run a forward pass?"
+                                   .format(torch.typename(type(mod)), name))
+
 def infer_methods_to_compile(nn_module):
     """
     Implements the default rules for which methods should act as starting
@@ -1,5 +1,5 @@
 from .modules import *
-from .parameter import Parameter
+from .parameter import Parameter, UninitializedParameter
 from .parallel import DataParallel
 from . import init
 from . import utils
@@ -1,5 +1,5 @@
 from .module import Module
-from .linear import Identity, Linear, Bilinear
+from .linear import Identity, Linear, Bilinear, LazyLinear
 from .conv import Conv1d, Conv2d, Conv3d, \
     ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
 from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
@@ -54,6 +54,6 @@ __all__ = [
     'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
     'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
     'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
-    'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
+    'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer', 'LazyLinear',
     'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle'
 ]
torch/nn/modules/lazy.py (new file, 257 lines)
@@ -0,0 +1,257 @@
import itertools
from typing_extensions import Protocol
import warnings

import torch
from ..parameter import UninitializedParameter


class _LazyProtocol(Protocol):
    """This is to avoid errors with mypy checks for
    The attributes in a mixin:
    https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
    """
    def _register_load_state_dict_pre_hook(self, hook):
        ...

    def register_forward_pre_hook(self, hook):
        ...

    def _lazy_load_hook(
            self, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
        ...

    def _get_name(self):
        ...

    def _infer_parameters(self, module, input):
        ...

    @property
    def _parameters(self):
        ...

    @property
    def _buffers(self):
        ...

    @property
    def _non_persistent_buffers_set(self):
        ...

    @property
    def _load_hook(self):
        ...

    @property
    def _initialize_hook(self):
        ...


class LazyModuleMixin:
    r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules."

    .. warning:
        Lazy modules are an experimental new feature under active development,
        and their API is likely to change.

    Modules that lazily initialize parameters, or "lazy modules",
    derive the shapes of their parameters from the first input(s)
    to their forward method. Until that first forward they contain
    :class:`torch.nn.UninitializedParameter`s that should not be accessed
    or used, and afterward they contain regular :class:`torch.nn.Parameter`s.
    Lazy modules are convenient since they don't require computing some
    module arguments, like the `in_features` argument of a
    typical :class:`torch.nn.Linear`.

    After construction, networks with lazy modules should first
    be converted to the desired dtype and placed on the desired device.
    The lazy modules should then be initialized with one or more "dry runs".
    These "dry runs" send inputs of the correct size, dtype, and device through
    the network and to each one of its lazy modules. After this the network can be used as usual.

    >>> class LazyMLP(torch.nn.Module):
    ...    def __init__(self):
    ...        super().__init__()
    ...        self.fc1 = torch.nn.LazyLinear(10)
    ...        self.relu1 = torch.nn.ReLU()
    ...        self.fc2 = torch.nn.LazyLinear(1)
    ...        self.relu2 = torch.nn.ReLU()
    ...
    ...    def forward(self, input):
    ...        x = self.relu1(self.fc1(input))
    ...        y = self.relu2(self.fc2(x))
    ...        return y
    >>> # constructs a network with lazy modules
    >>> lazy_mlp = LazyMLP()
    >>> # transforms the network's device and dtype
    >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs'
    >>> lazy_mlp = mlp.cuda().double()
    >>> lazy_mlp
    LazyMLP(
      (fc1): LazyLinear(in_features=0, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): LazyLinear(in_features=0, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # performs a dry run to initialize the network's lazy modules
    >>> lazy_mlp(torch.ones(10,10).cuda())
    >>> # after initialization, LazyLinear modules become regular Linear modules
    >>> lazy_mlp
    LazyMLP(
      (fc1): Linear(in_features=10, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): Linear(in_features=10, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # attaches an optimizer, since parameters can now be used as usual
    >>> optim = torch.optim.SGD(mlp.parameters(), lr=0.01)

    A final caveat when using lazy modules is that the order of initialization of a network's
    parameters may change, since the lazy modules are always initialized after other modules.
    This can cause the parameters of a network using lazy modules to be initialized differently
    than the parameters of a network without lazy modules.
    For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module
    first and then a regular :class:`torch.nn.Linear` second, the second module would be
    initialized on construction and the first module would be initialized during the first dry run.

    Lazy modules can be serialized with a state dict like other modules. For example:

    >>> lazy_mlp = LazyMLP()
    >>> # The state dict shows the uninitialized parameters
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight', Uninitialized parameter),
                 ('fc1.bias',
                  tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                           4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
                 ('fc2.weight', Uninitialized parameter),
                 ('fc2.bias', tensor([0.0019]))])

    Lazy modules can also load regular :class:`torch.nn.Parameter` s,
    which replace their :class:`torch.nn.UninitializedParameter` s:

    >>> full_mlp = LazyMLP()
    >>> # Dry run to initialize another module
    >>> full_mlp.forward(torch.ones(10, 1))
    >>> # Load an initialized state into a lazy module
    >>> lazy_mlp.load_state_dict(full_mlp.state_dict())
    >>> # The state dict now holds valid values
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight',
                  tensor([[-0.3837],
                          [ 0.0907],
                          [ 0.6708],
                          [-0.5223],
                          [-0.9028],
                          [ 0.2851],
                          [-0.4537],
                          [ 0.6813],
                          [ 0.5766],
                          [-0.8678]])),
                 ('fc1.bias',
                  tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                           4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
                 ('fc2.weight',
                  tensor([[ 0.1320,  0.2938,  0.0679,  0.2793,  0.1088, -0.1795, -0.2301,  0.2807,
                            0.2479,  0.1091]])),
                 ('fc2.bias', tensor([0.0019]))])

    Note, however, that lazy modules cannot validate that the shape of parameters they load is correct.

    """

    # modules inheriting from this will change their __class__ to the specified
    # one after they are fully initialized
    cls_to_become = None

    def __init__(self: _LazyProtocol, *args, **kwargs):
        # Mypy doesnt like this super call in a mixin
        super().__init__(*args, **kwargs)  # type: ignore
        self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
        self._initialize_hook = self.register_forward_pre_hook(self._infer_parameters)
        warnings.warn('Lazy modules are a new feature under heavy development '
                      'so changes to the API or functionality can happen at any moment.')

    def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars):
        # This should be ideally implemented as a hook,
        # but we should override `detach` in the UninitializedParameter to return itself
        # which is not clean
        for name, param in self._parameters.items():
            if param is not None:
                if isinstance(param, UninitializedParameter):
                    destination[prefix + name] = param
                else:
                    destination[prefix + name] = param if keep_vars else param.detach()
        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                destination[prefix + name] = buf if keep_vars else buf.detach()

    def _lazy_load_hook(
            self: _LazyProtocol, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
        """load_state_dict pre-hook function for lazy buffers and parameters.

        The purpose of this hook is to adjust the current state and/or
        ``state_dict`` being loaded so that a module instance serialized in
        both un/initialized state can be deserialized onto both un/initialized
        module instance.
        See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
        for the details of the hook specification.
        """
        local_state = {k: v for k, v in self._parameters.items() if v is not None}
        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]
                if isinstance(param, UninitializedParameter):
                    # The current parameter is not initialized but the one being loaded one is
                    # create a new parameter based on the uninitialized one
                    if not isinstance(input_param, UninitializedParameter):
                        with torch.no_grad():
                            param.materialize(input_param.shape)

    def initialize_parameters(self: _LazyProtocol, *args, **kwargs):
        r"""Initialize parameters according to the input batch properties.
        This adds an interface to isolate parameter initialization from the
        forward pass when doing parameter shape inference.
        """
        raise NotImplementedError('initialize_parameters is not implemented for {}'.format(self.__class__.__name__))

    def has_uninitialized_params(self: _LazyProtocol):
        r"""Check if a module has parameters that are not initialized
        """
        # This is to avoid the JIT to track this parameter and force
        # custom modules __setstate__ to add it
        params = self._parameters.values()
        for param in itertools.chain(params):
            if isinstance(param, (UninitializedParameter)):
                return True
        return False

    def _infer_parameters(self: _LazyProtocol, module, input):
        r"""Infers the size and initializes the parameters according to the
        provided input batch.
        Given a module that contains parameters that were declared inferrable
        using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass
        in the complete module using the provided input to initialize all the parameters
        as needed.
        The module is set into evaluation mode before running the forward pass in order
        to avoid saving statistics or calculating gradients
        """
        module.initialize_parameters(*input)
        if module.has_uninitialized_params():
            raise RuntimeError('module {} has not been fully initialized'.format(self._get_name()))
        module._initialize_hook.remove()
        module._load_hook.remove()
        delattr(module, '_initialize_hook')
        delattr(module, '_load_hook')
        if module.cls_to_become is not None:
            module.__class__ = module.cls_to_become


    def _replicate_for_data_parallel(self: _LazyProtocol):
        raise RuntimeError('Modules with uninitialized parameters can\'t be used with `DataParallel`. '
                           'Run a dummy forward pass to correctly initialize the modules')
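For orientation, a minimal sketch of how a user-defined lazy module can be built on this mixin, following the same pattern as the `LazyLinear` class added later in this diff. The `Scale`/`LazyScale` names are illustrative, not part of the PR:

```
import torch
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedParameter


class Scale(nn.Module):
    """Multiplies its input by a learned per-feature scale (hypothetical target module)."""
    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        self.scale = nn.Parameter(torch.ones(num_features))

    def forward(self, x):
        return x * self.scale


class LazyScale(LazyModuleMixin, Scale):
    # once the parameter is materialized, the instance is rewritten to Scale
    cls_to_become = Scale

    def __init__(self):
        # construct the base class with a dummy size, then swap in an uninitialized parameter
        super().__init__(0)
        self.scale = UninitializedParameter()

    def initialize_parameters(self, input):
        # called by the forward pre-hook installed by LazyModuleMixin
        if self.has_uninitialized_params():
            with torch.no_grad():
                self.num_features = input.shape[-1]
                self.scale.materialize((self.num_features,))
                self.scale.fill_(1.0)  # materialize() leaves values uninitialized


m = LazyScale()
m(torch.randn(4, 8))          # dry run infers num_features=8
assert isinstance(m, Scale)   # __class__ has been swapped via cls_to_become
```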
@@ -2,10 +2,11 @@ import math

 import torch
 from torch import Tensor
-from torch.nn.parameter import Parameter
+from torch.nn.parameter import Parameter, UninitializedParameter
 from .. import functional as F
 from .. import init
 from .module import Module
+from .lazy import LazyModuleMixin


 class Identity(Module):
@@ -178,4 +179,49 @@ class Bilinear(Module):
             self.in1_features, self.in2_features, self.out_features, self.bias is not None
         )
+
+
+class LazyLinear(LazyModuleMixin, Linear):
+    r"""A :class:`torch.nn.Linear` module with lazy initialization.
+
+    In this module, the `weight` and `bias` are of :class:`torch.nn.UninitializedParameter`
+    class. They will be initialized after the first call to ``forward`` is done and the
+    module will become a regular :class:`torch.nn.Linear` module.
+
+    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
+    on lazy modules and their limitations.
+
+    Args:
+        out_features: size of each output sample
+        bias: If set to ``False``, the layer will not learn an additive bias.
+            Default: ``True``
+
+    Attributes:
+        weight: the learnable weights of the module of shape
+            :math:`(\text{out\_features}, \text{in\_features})`. The values are
+            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
+            :math:`k = \frac{1}{\text{in\_features}}`
+        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
+            If :attr:`bias` is ``True``, the values are initialized from
+            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_features}}`
+
+    """
+
+    cls_to_become = Linear
+
+    def __init__(self, out_features: int, bias: bool = True) -> None:
+        super().__init__(0, out_features, bias)
+        self.weight = UninitializedParameter()
+
+    def reset_parameters(self) -> None:
+        if not self.has_uninitialized_params() and self.in_features != 0:
+            super().reset_parameters()
+
+    def initialize_parameters(self, input) -> None:
+        if self.has_uninitialized_params():
+            with torch.no_grad():
+                self.in_features = input.shape[-1]
+                self.weight.materialize((self.out_features, self.in_features))
+                self.reset_parameters()
 # TODO: PartialLinear - maybe in sparse?
@@ -965,18 +965,20 @@ class Module:
             key = prefix + name
             if key in state_dict:
                 input_param = state_dict[key]
+                # This is used to avoid copying uninitialized parameters into
+                # non-lazy modules, since they dont have the hook to do the checks
+                # in such case, it will error when accessing the .shape attribute.
+                is_param_lazy = isinstance(param, torch.nn.parameter.UninitializedParameter)
                 # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
-                if len(param.shape) == 0 and len(input_param.shape) == 1:
+                if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                     input_param = input_param[0]

-                if input_param.shape != param.shape:
+                if not is_param_lazy and input_param.shape != param.shape:
                     # local shape should match the one in checkpoint
                     error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                       'the shape in current model is {}.'
                                       .format(key, input_param.shape, param.shape))
                     continue

                 try:
                     with torch.no_grad():
                         param.copy_(input_param)
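The `is_param_lazy` check above is what lets an uninitialized parameter accept a state dict entry of any shape. A short sketch of the resulting behavior, based on the `test_linear_state` test added in this PR:

```
import torch

full = torch.nn.Linear(5, 10)
lazy = torch.nn.LazyLinear(10)

# Loading an initialized state dict materializes the lazy weight from the stored shape.
lazy.load_state_dict(full.state_dict())
assert not lazy.has_uninitialized_params()
assert lazy.weight.shape == (10, 5)

# The reverse direction still fails: an uninitialized weight has no shape to copy from.
fresh = torch.nn.LazyLinear(10)
try:
    torch.nn.Linear(5, 10).load_state_dict(fresh.state_dict())
except RuntimeError as e:
    print(e)  # "Can't access the shape of an uninitialized parameter ..."
```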
@@ -404,7 +404,13 @@ class DistributedDataParallel(Module):
             )
             pass

-        # used for intra-node param sync and inter-node sync as well
+        # Check that a module does not have Uninitialized parameters
+        for param in module.parameters():
+            if isinstance(param, torch.nn.parameter.UninitializedParameter):
+                raise RuntimeError(
+                    'Modules with uninitialized parameters can\'t be used with `DistributedDataParallel`. '
+                    'Run a dummy forward pass to correctly initialize the modules')
+        # used for intra-node param sync and inter-node sync as well
         self.broadcast_bucket_size = int(250 * 1024 * 1024)

         # reduction bucket size
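The check added above rejects modules that still hold uninitialized parameters. A minimal sketch of the intended call order, assuming a process group has already been initialized elsewhere:

```
import torch

# assumes torch.distributed.init_process_group(...) has already been called
module = torch.nn.LazyLinear(10).cuda()
module(torch.ones(4, 16).cuda())   # dry run: parameters are materialized here
ddp = torch.nn.parallel.DistributedDataParallel(module)   # no longer rejected
```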
@@ -44,3 +44,97 @@ class Parameter(torch.Tensor):
                 )

     __torch_function__ = _disabled_torch_function_impl
+
+
+class UninitializedParameter(Parameter):
+    r"""A parameter that is not initialized.
+
+    Uninitialized Parameters are a special case of :class:`torch.nn.Parameter`
+    where the shape of the data is still unknown.
+
+    Unlike a :class:`torch.nn.Parameter`, uninitialized parameters
+    hold no data and attempting to access some properties, like their shape,
+    will throw a runtime error. The only operations that can be performed on an uninitialized
+    parameter are changing its datatype, moving it to a different device and
+    converting it to a regular :class:`torch.nn.Parameter`.
+    """
+    _allowed_methods = [
+        torch.Tensor.__hash__,
+        torch.Tensor.size,
+        torch.Tensor.copy_,
+        torch.Tensor.is_floating_point,
+        torch.Tensor.half,
+        torch.Tensor.float,
+        torch.Tensor.double,
+        torch.Tensor.char,
+        torch.Tensor.short,
+        torch.Tensor.int,
+        torch.Tensor.long,
+        torch.Tensor.cuda,
+        torch.Tensor.cpu,
+        torch.Tensor.to,
+        torch.Tensor.get_device,
+        torch._has_compatible_shallow_copy_type]
+
+    def __new__(cls, requires_grad=True):
+        data = torch.Tensor()
+        return torch.Tensor._make_subclass(cls, data, requires_grad)
+
+    def materialize(self, shape, device=None, dtype=None):
+        r"""Create a Parameter with the same properties of the uninitialized one.
+        Given a shape, it materializes a parameter in the same device
+        and with the same `dtype` as the current one or the specified ones in the
+        arguments.
+
+        Args:
+            shape (tuple): the shape for the materialized tensor.
+            device (:class:`torch.device`): the desired device of the parameters
+                and buffers in this module. Optional.
+            dtype (:class:`torch.dtype`): the desired floating point type of
+                the floating point parameters and buffers in this module. Optional.
+        """
+        if device is None:
+            device = self.data.device
+        if dtype is None:
+            dtype = self.data.dtype
+        self.data = torch.empty(shape, device=device, dtype=dtype)
+        self.__class__ = Parameter
+
+    @property
+    def shape(self):
+        raise RuntimeError(
+            'Can\'t access the shape of an uninitialized parameter. '
+            'This error usually happens in `load_state_dict` when trying to load '
+            'an uninitialized parameter into an initialized one. '
+            'Call `forward` to initialize the parameters before accessing their attributes.')
+
+    def share_memory_(self):
+        raise RuntimeError(
+            'Can\'t share memory on an uninitialized parameter. '
+            'Call `forward` to initialize the parameters before calling '
+            '`module.share_memory()`.')
+
+    def __repr__(self):
+        return 'Uninitialized parameter'
+
+    def __reduce_ex__(self, proto):
+        # See Note [Don't serialize hooks]
+        return (
+            UninitializedParameter,
+            (self.requires_grad,)
+        )
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        # method-wrapper is to detect access to Tensor properties that are
+        # wrapped in descriptors
+        if func in cls._allowed_methods or func.__class__.__name__ == 'method-wrapper':
+            if kwargs is None:
+                kwargs = {}
+            return super().__torch_function__(func, types, args, kwargs)
+        raise ValueError(
+            'Attempted to use an uninitialized parameter in {}. '
+            'This error happens when you are using a `LazyModule` or '
+            'explicitly manipulating `torch.nn.parameter.UninitializedParameter` '
+            'objects. When using LazyModules Call `forward` with a dummy batch '
+            'to initialize the parameters before calling torch functions'.format(func))
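A short sketch of the `materialize()` contract implemented above, mirroring the `test_materialize_dtype` and `test_materialize_device` tests added in this PR:

```
import torch
from torch.nn.parameter import UninitializedParameter

param = UninitializedParameter()
# Before materialization, most tensor operations raise; dtype/device moves are allowed.
param.materialize((10,))             # allocates storage and rewrites __class__ to Parameter
assert isinstance(param, torch.nn.Parameter)
assert not isinstance(param, UninitializedParameter)
assert param.shape == (10,)          # .shape is accessible again
# Note: the values come from torch.empty(), so callers are expected to reset them.
```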
@@ -1,7 +1,15 @@
+import torch
 from .. import Tensor
+from typing import Tuple, Optional
 import builtins

 class Parameter(Tensor):
     def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...

     ...
+
+class UninitializedParameter(Tensor):
+    def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...
+
+    def materialize(self, shape: Tuple[int], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
+
+    ...
@@ -120,6 +120,10 @@ class SpectralNorm:

         fn = SpectralNorm(name, n_power_iterations, dim, eps)
         weight = module._parameters[name]
+        if isinstance(weight, torch.nn.parameter.UninitializedParameter):
+            raise ValueError(
+                'The module passed to `SpectralNorm` can\'t have uninitialized parameters. '
+                'Make sure to run the dummy forward before applying spectral normalization')

         with torch.no_grad():
             weight_mat = fn.reshape_weight_to_matrix(weight)
@@ -1,7 +1,7 @@
 r"""
 Weight Normalization from https://arxiv.org/abs/1602.07868
 """
-from torch.nn.parameter import Parameter
+from torch.nn.parameter import Parameter, UninitializedParameter
 from torch import _weight_norm, norm_except_dim
 from typing import Any, TypeVar
 from ..modules import Module
@@ -36,7 +36,10 @@ class WeightNorm(object):
         fn = WeightNorm(name, dim)

         weight = getattr(module, name)
+        if isinstance(weight, UninitializedParameter):
+            raise ValueError(
+                'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
+                'Make sure to run the dummy forward before applying weight normalization')
         # remove w from parameter list
         del module._parameters[name]

|
|||||||
Reference in New Issue
Block a user