# Owner(s): ["module: nn"]

import pickle
import unittest

import torch
import torch.nn as nn
from torch.nn import Buffer, Parameter
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    run_tests,
    suppress_warnings,
    TEST_PRIVATEUSE1,
    TestCase,
)
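

# A minimal LazyModuleMixin subclass, used as a shared fixture by the tests below.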
class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
    pass


class TestLazyModules(TestCase):
    @suppress_warnings
    def test_lazy_module_parameter(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict["test_param"], UninitializedParameter)
        new_module = LazyModule()
        # An error is raised when there is an attempt to replace an existing parameter
        # with an uninitialized one
        new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5)))
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            new_module.load_state_dict(state_dict)
        # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5)))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_param, torch.ones((5, 5)))

        # Uninitialized parameters are left unchanged
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())

        new_module = LazyModule()
        new_module.register_parameter("test_param", UninitializedParameter())
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())

    @suppress_warnings
    def test_lazy_module_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict["test_buffer"], UninitializedBuffer)
        new_module = LazyModule()
        # An error is raised when there is an attempt to replace an existing parameter
        # with an uninitialized one
        new_module.test_buffer = Buffer(torch.ones(5, 5))
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            new_module.load_state_dict(state_dict)
        # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.test_buffer = Buffer(torch.ones(5, 5))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_buffer, torch.ones((5, 5)))

        # Uninitialized parameters are left unchanged
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())

        new_module = LazyModule()
        new_module.test_buffer = UninitializedBuffer()
        module.load_state_dict(new_module.state_dict())
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())

    @suppress_warnings
    def test_lazy_module_jit_param(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "run a forward pass"):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_module_jit_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "run a forward pass"):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_share_memory_param(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "share memory on an uninitialized"):
            module.share_memory()

    @suppress_warnings
    def test_lazy_share_memory_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "share memory on an uninitialized"):
            module.share_memory()

    @suppress_warnings
    def test_linear(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        output = module(input)
        self.assertIsInstance(module, nn.Linear)
        self.assertNotIsInstance(module, nn.LazyLinear)
        self.assertTrue(module.weight.shape == (10, 5))
        self.assertTrue(module.bias.shape == (10,))
        self.assertTrue((module.weight != 0).any())
        self.assertTrue((module.bias != 0).any())
        self.assertTrue((output != 0).any())
        y = module(input)
        self.assertTrue(
            torch.equal(
                torch.nn.functional.linear(input, module.weight, module.bias), y
            )
        )

    @suppress_warnings
    def test_lazy_linear_pickle(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, nn.LazyLinear)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, nn.Linear)
        self.assertNotIsInstance(new_module, nn.LazyLinear)
        self.assertTrue(new_module.weight.shape == (10, 5))
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        self.assertTrue(new_module.bias.shape == (10,))
        self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    @suppress_warnings
    def test_linear_state(self):
        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Linear one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertTrue(lazy_module.weight.shape == (10, 5))
        self.assertTrue(lazy_module.bias.shape == (10,))

        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    @suppress_warnings
    def test_lazy_linear_state_and_forward(self):
        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Linear one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertTrue(isinstance(lazy_module, nn.LazyLinear))

        input = torch.randn(5, 5)
        lazy_module(input)
        self.assertFalse(isinstance(lazy_module, nn.LazyLinear))
        self.assertTrue(lazy_module.in_features == 5)
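
    # Shared helper: a lazy conv variant should start with uninitialized parameters,
    # turn into the regular conv class after the first forward pass, materialize the
    # expected weight/bias shapes, and then match the corresponding functional op.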
    def _check_lazy_conv(
        self,
        cls,
        lazy_cls,
        func,
        init_args,
        input_shape,
        expected_weight_shape,
        expected_bias_shape,
        *forward_args,
        **forward_kwargs,
    ):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input, *forward_args, **forward_kwargs)
        self.assertIsInstance(module, cls)
        self.assertNotIsInstance(module, lazy_cls)
        self.assertEqual(module.weight.shape, expected_weight_shape)
        if module.bias is not None:
            self.assertEqual(module.bias.shape, expected_bias_shape)
        y = module(input)
        self.assertTrue(torch.equal(func(input, module.weight, module.bias), y))
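
    # Shared helper: pickling before materialization should round-trip the
    # uninitialized parameters; pickling after the first forward pass should yield
    # a fully materialized regular conv module.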
    def _check_lazy_conv_pickle(
        self,
        cls,
        lazy_cls,
        init_args,
        input_shape,
        expected_weight_shape,
        expected_bias_shape,
    ):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, lazy_cls)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, cls)
        self.assertNotIsInstance(new_module, lazy_cls)
        self.assertEqual(new_module.weight.shape, expected_weight_shape)
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        if new_module.bias is not None:
            self.assertEqual(new_module.bias.shape, expected_bias_shape)
            self.assertNotIsInstance(new_module.bias, UninitializedParameter)
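
    # Shared helper: loading a regular conv state dict into a lazy conv module
    # materializes its parameters, while loading a lazy (still uninitialized) state
    # dict into a regular conv module raises a RuntimeError.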
    def _check_lazy_conv_state(
        self, gen_module, gen_lazy_module, expected_weight_shape, expected_bias_shape
    ):
        module = gen_module()
        lazy_module = gen_lazy_module()
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Conv one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, expected_weight_shape)
        if lazy_module.bias is not None:
            self.assertEqual(lazy_module.bias.shape, expected_bias_shape)

        module = gen_module()
        lazy_module = gen_lazy_module()
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def test_lazy_pre_forward_hook(self):
        """
        Test that a lazy module can register additional forward pre-hook
        functions successfully.
        """

        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input):
            return input[0] + 1

        module = TestModule()
        module.register_forward_pre_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    def test_lazy_forward_hook(self):
        """
        Test that a lazy module can register additional forward hook
        functions successfully.
        """

        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input, output):
            return input[0] + 1

        module = TestModule()
        module.register_forward_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    @suppress_warnings
    def test_lazy_conv1d(self):
        self._check_lazy_conv(
            nn.Conv1d,
            nn.LazyConv1d,
            torch.nn.functional.conv1d,
            (32, 2),
            (192, 16, 50),
            (32, 16, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv1d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv1d, nn.LazyConv1d, (32, 2), (192, 16, 50), (32, 16, 2), (32,)
        )

    @suppress_warnings
    def test_lazy_conv1d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv1d(16, 32, 2),
            lambda: nn.LazyConv1d(32, 2),
            (32, 16, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv2d(self):
        self._check_lazy_conv(
            nn.Conv2d,
            nn.LazyConv2d,
            torch.nn.functional.conv2d,
            (32, 2),
            (192, 16, 8, 6),
            (32, 16, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv2d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv2d, nn.LazyConv2d, (32, 2), (192, 16, 8, 6), (32, 16, 2, 2), (32,)
        )

    @suppress_warnings
    def test_lazy_conv2d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv2d(16, 32, 2),
            lambda: nn.LazyConv2d(32, 2),
            (32, 16, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d(self):
        self._check_lazy_conv(
            nn.Conv3d,
            nn.LazyConv3d,
            torch.nn.functional.conv3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv3d,
            nn.LazyConv3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv3d(16, 32, 2),
            lambda: nn.LazyConv3d(32, 2),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transposed1d(self):
        self._check_lazy_conv(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            torch.nn.functional.conv_transpose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            torch.nn.functional.conv_transpose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
            output_size=(51,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose1d(16, 32, 2),
            lambda: nn.LazyConvTranspose1d(32, 2),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d(self):
        self._check_lazy_conv(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            torch.nn.functional.conv_transpose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            torch.nn.functional.conv_transpose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
            output_size=(9, 7),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose2d(16, 32, 2),
            lambda: nn.LazyConvTranspose2d(32, 2),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d(self):
        self._check_lazy_conv(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            torch.nn.functional.conv_transpose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            torch.nn.functional.conv_transpose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
            output_size=(9, 8, 7),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose3d(16, 32, 2),
            lambda: nn.LazyConvTranspose3d(32, 2),
            (16, 32, 2, 2, 2),
            (32,),
        )
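
    # Shared helper: run a lazy norm variant and the corresponding regular norm
    # module on the same input, for every affine/track_running_stats combination,
    # and check that parameters, buffers, and outputs match after materialization.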
    def _check_lazy_norm(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                lazy_module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )

                if affine:
                    self.assertIsInstance(lazy_module.weight, UninitializedParameter)
                    self.assertIsInstance(lazy_module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(lazy_module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                lazy_output = lazy_module(input)
                self.assertIsInstance(lazy_module, cls)
                self.assertNotIsInstance(lazy_module, lazy_cls)

                num_features = input_shape[1]
                module = cls(
                    num_features, affine=affine, track_running_stats=track_running_stats
                )
                expected_output = module(input)

                self.assertEqual(lazy_output, expected_output)
                if module.weight is not None:
                    self.assertEqual(lazy_module.weight.shape, module.weight.shape)
                    self.assertEqual(lazy_module.weight, module.weight)
                if module.bias is not None:
                    self.assertEqual(lazy_module.bias.shape, module.bias.shape)
                    self.assertEqual(lazy_module.bias, module.bias)
                if module.running_mean is not None:
                    self.assertEqual(
                        lazy_module.running_mean.shape, module.running_mean.shape
                    )
                    self.assertEqual(lazy_module.running_mean, module.running_mean)
                if module.running_var is not None:
                    self.assertEqual(
                        lazy_module.running_var.shape, module.running_var.shape
                    )
                    self.assertEqual(lazy_module.running_var, module.running_var)
                if module.num_batches_tracked is not None:
                    self.assertEqual(
                        lazy_module.num_batches_tracked.shape,
                        module.num_batches_tracked.shape,
                    )
                    self.assertEqual(
                        lazy_module.num_batches_tracked, module.num_batches_tracked
                    )
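
    # Shared helper: pickling a lazy norm module before and after the first forward
    # pass should preserve uninitialized and materialized state respectively.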
    def _check_lazy_norm_pickle(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )
                module = pickle.loads(pickle.dumps(module))

                self.assertIsInstance(module, lazy_cls)
                if affine:
                    self.assertIsInstance(module.weight, UninitializedParameter)
                    self.assertIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                module(input)  # fully materialized
                module = pickle.loads(pickle.dumps(module))

                self.assertNotIsInstance(module, lazy_cls)
                self.assertIsInstance(module, cls)
                if affine:
                    self.assertNotIsInstance(module.weight, UninitializedParameter)
                    self.assertNotIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertNotIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertNotIsInstance(module.running_var, UninitializedBuffer)
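
    # Shared helper: loading a regular batch norm state dict into a lazy batch norm
    # module materializes weight, bias, and running stats; the reverse direction
    # raises because the lazy state dict is still uninitialized.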
    def _check_lazy_batchnorm_state(self, cls, lazy_cls):
        module = cls(10)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # BatchNorm one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, (10,))
        self.assertEqual(lazy_module.bias.shape, (10,))
        self.assertEqual(lazy_module.running_mean.shape, (10,))
        self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10)
        lazy_module = lazy_cls()
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())
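
    # Shared helper: same state-dict round trip as above, but for instance norm and
    # for every affine/track_running_stats combination.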
    def _check_lazy_instancenorm_state(self, cls, lazy_cls):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = cls(10, affine=affine, track_running_stats=track_running_stats)
                lazy_module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )
                lazy_module.load_state_dict(module.state_dict())
                # Parameters have been initialized but the module won't become a full
                # InstanceNorm one until the first iteration. This is due to
                # limitations on the state_dict loading logic
                self.assertFalse(lazy_module.has_uninitialized_params())
                if affine:
                    self.assertEqual(lazy_module.weight.shape, (10,))
                    self.assertEqual(lazy_module.bias.shape, (10,))
                if track_running_stats:
                    self.assertEqual(lazy_module.running_mean.shape, (10,))
                    self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10, affine=True, track_running_stats=True)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())
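
    # Shared helper: lazy norm modules should also materialize correctly when the
    # input tensor is passed as a keyword argument.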
    def _check_lazy_norm_with_dict_input(self, cls, lazy_cls, input_shape):
        input = {"input": torch.ones(*input_shape)}

        lazy_module = lazy_cls()
        lazy_output = lazy_module(**input)

        num_features = input_shape[1]
        module = cls(num_features)
        expected_output = module(**input)

        self.assertEqual(lazy_output, expected_output)

    def test_lazy_batchnorm1d(self):
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)

    def test_lazy_batchnorm2d(self):
        self._check_lazy_norm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)

    def test_lazy_batchnorm3d(self):
        self._check_lazy_norm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_batchnorm3d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_batchnorm3d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)

    def test_lazy_instancenorm1d(self):
        self._check_lazy_norm(nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6))

    def test_lazy_instancenorm1d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6)
        )

    def test_lazy_instancenorm1d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)

    def test_lazy_instancenorm2d(self):
        self._check_lazy_norm(nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7))

    def test_lazy_instancenorm2d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7)
        )

    def test_lazy_instancenorm2d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)

    def test_lazy_instancenorm3d(self):
        self._check_lazy_norm(
            nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_instancenorm3d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_instancenorm3d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)

    def test_lazy_batchnorm_with_dict_input(self):
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6)
        )
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7)
        )
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)
        )

    @suppress_warnings
    def test_materialize_dtype(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.get_default_dtype())
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.half()
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.float16)

    @unittest.skipIf(
        not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 not available"
    )
    @suppress_warnings
    def test_materialize_device(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == "cpu")
        if TEST_CUDA:
            device = "cuda"
        elif TEST_PRIVATEUSE1:
            device = torch._C._get_privateuse1_backend_name()
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.to(device)
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == device)

    @suppress_warnings
    def test_chained_initialization(self):
        class MyNetwork(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_1 = torch.nn.LazyLinear(15)
                self.linear_2 = torch.nn.LazyLinear(10)

            def forward(self, x):
                y = self.linear_1(x)
                return self.linear_2(y)

        net = MyNetwork()
        net(torch.ones(5, 10))
        self.assertTrue(net.linear_1.weight.shape == (15, 10))
        self.assertTrue(net.linear_1.bias.shape == (15,))
        self.assertTrue(net.linear_2.weight.shape == (10, 15))
        self.assertTrue(net.linear_2.bias.shape == (10,))

    @suppress_warnings
    def test_optimizer_pass(self):
        optimizers = [
            torch.optim.Adadelta,
            torch.optim.Adagrad,
            torch.optim.Adamax,
            torch.optim.Adam,
            torch.optim.AdamW,
            torch.optim.ASGD,
            torch.optim.SGD,
            torch.optim.Rprop,
            torch.optim.RMSprop,
            torch.optim.LBFGS,
            torch.optim.NAdam,
            torch.optim.RAdam,
        ]
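
        # Each optimizer is constructed while test_param is still uninitialized;
        # materializing the parameter afterwards should leave the optimizer's
        # param_groups holding a regular Parameter that can be stepped.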
        def run_step(module, optim):
            self.assertIsInstance(
                optim.param_groups[0]["params"][0], UninitializedParameter
            )
            module.test_param.materialize(10)
            self.assertIsInstance(optim.param_groups[0]["params"][0], Parameter)
            self.assertNotIsInstance(
                optim.param_groups[0]["params"][0], UninitializedParameter
            )
            for p in module.parameters():
                p.grad = torch.rand_like(p)
            if isinstance(optim, torch.optim.LBFGS):
                optim.step(lambda: 1.0)
            else:
                optim.step()

        for optim_cls in optimizers:
            module = LazyModule()
            module.register_parameter("test_param", UninitializedParameter())
            if optim_cls is torch.optim.SGD:
                optim = optim_cls(module.parameters(), lr=0.0)
            elif optim_cls is torch.optim.Adagrad:
                with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
                    optim = optim_cls(module.parameters())
                continue
            else:
                optim = optim_cls(module.parameters())
            run_step(module, optim)

    @suppress_warnings
    def test_weight_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, "have uninitialized parameters."):
            m = torch.nn.utils.weight_norm(m)

    @suppress_warnings
    def test_spectral_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, "have uninitialized parameters."):
            m = torch.nn.utils.spectral_norm(m)

    @suppress_warnings
    def test_invalid_functions(self):
        param = torch.nn.parameter.UninitializedParameter()
        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            torch.empty_like(param)

        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            torch.add(param, param)

        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            param + param


if __name__ == "__main__":
    run_tests()