Add swap_tensors path to nn parametrizations (#124130)

Fixes #123859
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124130
Approved by: https://github.com/albanD

Committed by: PyTorch MergeBot
Parent: 87f651c7e7
Commit: 5ba6bb7b2f
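For context, the swap_tensors path that this change wires into nn parametrizations is toggled through torch.__future__. A minimal sketch of how a user opts in (the Symmetric parametrization below is illustrative, not part of this PR):

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    # Opt into the swap_tensors path globally (off by default).
    torch.__future__.set_swap_module_params_on_conversion(True)

    class Symmetric(nn.Module):
        # Illustrative parametrization: constrain a square weight to be symmetric.
        def forward(self, X):
            return X.triu() + X.triu(1).transpose(-1, -2)

    m = nn.Linear(3, 3)
    parametrize.register_parametrization(m, "weight", Symmetric())
    print(torch.allclose(m.weight, m.weight.T))  # True

    # With the future flag enabled, writes back into the stored tensor
    # (e.g. on removal) go through torch.utils.swap_tensors instead of Tensor.set_.
    parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)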
@@ -9,6 +9,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
 import torch.nn.utils.parametrize as parametrize
+from torch import Tensor
+from torch.__future__ import get_swap_module_params_on_conversion
 from torch.nn import Parameter
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -20,8 +22,10 @@ from torch.testing._internal.common_utils import (
     set_default_dtype,
     skipIfNoLapack,
     skipIfTorchDynamo,
+    swap,
     TemporaryFileName,
 )
+from torch.testing._internal.two_tensor import TwoTensor


 class TestNNParametrization(NNTestCase):
@@ -32,6 +36,7 @@ class TestNNParametrization(NNTestCase):
     # and remove the `@skipIfNoLapack` (see #70995)
     # torch/nn/utils/parametrize
     @skipIfNoLapack
+    @swap([True, False])
     def test_register_and_remove_parametrization(self):
         r"""Test that it is possible to add a few parametrizations
        on a parameter or a buffer and that removing them restores the initial state
@@ -94,8 +99,7 @@ class TestNNParametrization(NNTestCase):
         self.assertTrue(parametrize.is_parametrized(model, "weight"))
         self.assertFalse(parametrize.is_parametrized(model, "bias"))
         self.assertNotIn("weight", model._parameters)
-        A = model.weight
-        self.assertTrue(A.shape[0] == 1)
+        self.assertTrue(model.weight.shape[0] == 1)
         parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
         self.assertFalse(hasattr(model, "parametrizations"))
         self.assertEqual(model.weight, initial_model.weight)
@@ -110,8 +114,7 @@ class TestNNParametrization(NNTestCase):
         self.assertTrue(parametrize.is_parametrized(model, "weight"))
         self.assertFalse(parametrize.is_parametrized(model, "bias"))
         self.assertNotIn("weight", model._parameters)
-        A = model.weight
-        self.assertTrue(A.shape[0] == 1)
+        self.assertTrue(model.weight.shape[0] == 1)
         parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
         self.assertFalse(hasattr(model, "parametrizations"))
         self.assertEqual(model.weight, initial_model.weight)
@@ -128,6 +131,10 @@ class TestNNParametrization(NNTestCase):
         # Result should be skew-symmetric
         A = model.weight
         self.assertEqual(A, -A.T)
+        if get_swap_module_params_on_conversion():
+            # When using the swap_tensors path, this is needed so that the autograd
+            # graph is not alive anymore.
+            del A
         # Remove and check consistency
         parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
         self.assertFalse(hasattr(model, "parametrizations"))
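The `del A` statements added throughout these tests drop the last Python reference to the parametrized output so that nothing (including its autograd graph) is still holding on to the old tensor when the swap path exchanges it. A minimal sketch of the underlying primitive the swap path relies on (values are illustrative):

    import torch

    # torch.utils.swap_tensors exchanges the payloads of two tensors in place:
    # every existing reference to t1 afterwards sees t2's data, and vice versa.
    t1 = torch.randn(3)
    t2 = torch.zeros(3)
    ref_to_t1 = t1

    torch.utils.swap_tensors(t1, t2)

    print(torch.equal(ref_to_t1, torch.zeros(3)))  # True: old references follow the swap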
@@ -145,6 +152,10 @@ class TestNNParametrization(NNTestCase):
         # Result should be skew-symmetric
         A = model.weight
         self.assertEqual(A, -A.T)
+        if get_swap_module_params_on_conversion():
+            # When using the swap_tensors path, this is needed so that the autograd
+            # graph is not alive anymore.
+            del A
         # Remove and check consistency
         parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
         self.assertFalse(hasattr(model, "parametrizations"))
@@ -159,6 +170,10 @@ class TestNNParametrization(NNTestCase):
         X = model.weight
         Id = torch.eye(X.size(0), device=X.device)
         self.assertEqual(X.T @ X, Id)
+        if get_swap_module_params_on_conversion():
+            # When using the swap_tensors path, this is needed so that the autograd
+            # graph is not alive anymore.
+            del X
         # Structure tests
         self.assertTrue(hasattr(model, "parametrizations"))
         self.assertTrue(parametrize.is_parametrized(model))
@@ -246,6 +261,10 @@ class TestNNParametrization(NNTestCase):
             sgd.step()
             self.assertNotEqual(model.weight, weight_copy)
             self.assertNotEqual(model.bias, bias_copy)
+            if get_swap_module_params_on_conversion():
+                # When using the swap_tensors path, this is needed so that the autograd
+                # graph is not alive anymore.
+                del weight_copy, bias_copy

         # Test leave_parametrized=True
         for _ in range(2):
@@ -266,7 +285,12 @@ class TestNNParametrization(NNTestCase):
             sgd.step()
             self.assertNotEqual(model.weight, weight_copy)
             self.assertNotEqual(model.bias, bias_copy)
+            if get_swap_module_params_on_conversion():
+                # When using the swap_tensors path, this is needed so that the autograd
+                # graph is not alive anymore.
+                del weight_copy, bias_copy

+    @swap([True, False])
     def test_register_and_remove_nested_parametrization(self):
         r"""Test that it is possible to nest the parametrizations
        meaning that the original param is parametrized again
@@ -288,6 +312,10 @@ class TestNNParametrization(NNTestCase):
         # Result should be skew-symmetric
         A = model.weight
         self.assertEqual(A, -A.T)
+        if get_swap_module_params_on_conversion():
+            # When using the swap_tensors path, this is needed so that the autograd
+            # graph is not alive anymore.
+            del A

         # Add nested parametrization
         param_mod = model.parametrizations.weight
@@ -316,6 +344,7 @@ class TestNNParametrization(NNTestCase):
         self.assertFalse(hasattr(model, "parametrizations"))
         self.assertEqual(model.__class__, nn.Linear)

+    @swap([True, False])
     def test_register_and_remove_buffer_parametrization(self):
         r"""Test that it is possible to add and remove parametrizations on buffers"""

@@ -354,6 +383,7 @@ class TestNNParametrization(NNTestCase):
     # FIXME: Rewrite this test using functions not depending on LAPACK
     # and remove the `@skipIfNoLapack` (see #70995)
     @skipIfNoLapack
+    @swap([True, False])
     def test_serialization_parametrization(self):
         r"""Test that it is possible to serialize a parametrized model via state_dict"""

@@ -403,6 +433,7 @@ class TestNNParametrization(NNTestCase):
     # FIXME: Rewrite this test using functions not depending on LAPACK
     # and remove the `@skipIfNoLapack` (see #70995)
     @skipIfNoLapack
+    @swap([True, False])
     def test_initialization_parametrization(self):
         r"""Test that it is possible to initialize a parametrization when it
        implements a `right_inverse` method
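The initialization test above exercises the `right_inverse` protocol: when a parametrization defines it, the registered `original` tensor is initialized with `right_inverse(weight)` so that the visible value is preserved (and, with this PR, that write can go through the swap path). A small illustrative parametrization, not taken from the test itself:

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class Scale(nn.Module):
        # Illustrative parametrization with a right inverse: forward doubles,
        # right_inverse halves, so registering it leaves module.weight unchanged.
        def forward(self, X):
            return 2.0 * X

        def right_inverse(self, X):
            return 0.5 * X

    m = nn.Linear(4, 4)
    w_before = m.weight.clone()
    parametrize.register_parametrization(m, "weight", Scale())

    # The stored "original" is right_inverse(weight); the visible weight is unchanged.
    print(torch.allclose(m.weight, w_before))                                   # True
    print(torch.allclose(m.parametrizations.weight.original, 0.5 * w_before))   # True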
@@ -472,6 +503,7 @@ class TestNNParametrization(NNTestCase):
         self.assertEqual(model.weight, X)
         self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X))

+    @swap([True, False])
     def test_errors_unparametrized_tensor_parametrization(self):
         # Test errors when registering a parametrization on an unparametrized tensor
         module = nn.Linear(3, 4)
@@ -621,6 +653,7 @@ class TestNNParametrization(NNTestCase):
         self.assertFalse(parametrize.is_parametrized(module))
         self.assertEqual(module.weight, weight_init)

+    @swap([True, False])
     def test_errors_parametrized_tensor_parametrization(self):
         # Test errors when registering a parametrization on a parametrized tensor

@@ -702,6 +735,7 @@ class TestNNParametrization(NNTestCase):
     # FIXME: Rewrite this test using functions not depending on LAPACK
     # and remove the `@skipIfNoLapack` (see #70995)
     @skipIfNoLapack
+    @swap([True, False])
     def test_multiple_inputs_parametrization(self):
         # A parametrization with several outputs
         class RankOne(nn.Module):
@@ -803,6 +837,7 @@ class TestNNParametrization(NNTestCase):
     # FIXME: Rewrite this test using functions not depending on LAPACK
     # and remove the `@skipIfNoLapack` (see #70995)
     @skipIfNoLapack
+    @swap([True, False])
     def test_caching_parametrization(self):
         r"""Test the caching system of a parametrization"""

@@ -830,6 +865,7 @@ class TestNNParametrization(NNTestCase):
     # FIXME: Rewrite this test using functions not depending on LAPACK
     # and remove the `@skipIfNoLapack` (see #70995)
     @skipIfNoLapack
+    @swap([True, False])
     def test_caching_parametrization_with_transfer_parametrizations_and_params(self):
         r"""Test that transferring parametrizations doesn't cause issues with caching"""

@@ -862,6 +898,7 @@ class TestNNParametrization(NNTestCase):
         # test that the results are distinct objects for each module
         self.assertNotEqual(id(A), id(X))

+    @swap([True, False])
     def test_parametrization_same_training_mode(self):
         r"""Test training mode updated on parametrization registration"""

@@ -878,6 +915,7 @@ class TestNNParametrization(NNTestCase):
         self.assertTrue(module.parametrizations.weight[0].training)
         self.assertTrue(module.parametrizations.weight[1].training)

+    @swap([True, False])
     def test_type_before_parametrizations(self):
         r"""Test that type_before_parametrizations always retrieves original type"""

@@ -895,6 +933,7 @@ class TestNNParametrization(NNTestCase):
             parametrize.type_before_parametrizations(model) == original_type
         )

+    @swap([True, False])
     def test_deepcopy_after_parametrization(self):
         r"""Test that we are able to create a deepcopy of the module when it's parametrized."""

@@ -955,6 +994,7 @@ class TestNNParametrization(NNTestCase):
         parametrize.register_parametrization(model, "weight", AddOne())
         check_deepcopy(model, deepcopy(model))

+    @swap([True, False])
     def test_transfer_parametrizations_and_params(self):
         r"""Test that all parametrizations and their associated parameters are transferred."""

@@ -994,6 +1034,10 @@ class TestNNParametrization(NNTestCase):

         # check that the transfer didn't affect the original value
         self.assertEqual(hold_weight, model.weight)
+        if get_swap_module_params_on_conversion():
+            # When using the swap_tensors path, this is needed so that the autograd
+            # graph is not alive anymore.
+            del hold_weight

         # testing that changes to one set of parametrizations do not affect the other
         parametrize.remove_parametrizations(to_model, "weight")
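For reference, `parametrize.transfer_parametrizations_and_params` copies the registered parametrizations and their underlying `original` tensors from one module to another; a short sketch (module shapes and the Negate parametrization are illustrative):

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class Negate(nn.Module):
        # Illustrative parametrization.
        def forward(self, X):
            return -X

    src = nn.Linear(4, 4)
    dst = nn.Linear(4, 4)
    parametrize.register_parametrization(src, "weight", Negate())

    # Moves the parametrization (and the underlying `original` value) onto `dst`.
    parametrize.transfer_parametrizations_and_params(src, dst)

    print(parametrize.is_parametrized(dst, "weight"))   # True
    print(torch.allclose(src.weight, dst.weight))       # True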
@@ -1018,6 +1062,7 @@ class TestNNParametrization(NNTestCase):
         # check that the new transfer didn't change the value for the from_module
         self.assertEqual(hold_test_param, model.test_param)

+    @swap([True, False])
     def test_transfer_parametrizations_and_params_right_inverse(self):
         r"""Test that all parametrizations and their associated parameters are transferred."""

@@ -1047,6 +1092,7 @@ class TestNNParametrization(NNTestCase):
         # check that transfer doesn't affect the from_model weight
         self.assertEqual(hold_weight, model.weight)

+    @swap([True, False])
     def test_transfer_parametrizations_and_params_single_param(self):
         r"""Test that all parametrizations and their associated parameters are transferred."""

@@ -1086,6 +1132,7 @@ class TestNNParametrization(NNTestCase):
     # FIXME: Rewrite this test using functions not depending on LAPACK
     # and remove the `@skipIfNoLapack` (see #70995)
     @skipIfNoLapack
+    @swap([True, False])
     def test_transfer_parametrizations_and_params_many_to_one(self):
         # A parametrization with several outputs
         class RankOne(nn.Module):
@@ -1152,6 +1199,7 @@ class TestNNParametrization(NNTestCase):
         # check that the new transfer didn't change the value for the from_module
         self.assertEqual(hold_test_param, model.test_param)

+    @swap([True, False])
     def test_new_spectral_norm(self):
         with set_default_dtype(torch.double):
             input = torch.randn(3, 5)
@@ -1289,16 +1337,30 @@ class TestNNParametrization(NNTestCase):
             # avoid doing another power iteration
             m, wrapped_m, _ = get_modules()
             pre_remove_out = wrapped_m(input)
+            if get_swap_module_params_on_conversion():
+                # When using the swap_tensors path, this is needed so that the autograd
+                # graph is not alive anymore.
+                pre_remove_out_ref = pre_remove_out.detach()
+                del pre_remove_out
+            else:
+                pre_remove_out_ref = pre_remove_out
             m.eval()
             m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
-            self.assertEqual(wrapped_m(input), pre_remove_out)
+            self.assertEqual(wrapped_m(input), pre_remove_out_ref)

             torch.nn.utils.parametrizations.spectral_norm(m)
             for _ in range(3):
                 pre_remove_out = wrapped_m(input)
+                if get_swap_module_params_on_conversion():
+                    # When using the swap_tensors path, this is needed so that the autograd
+                    # graph is not alive anymore.
+                    pre_remove_out_ref = pre_remove_out.detach()
+                    del pre_remove_out
+                else:
+                    pre_remove_out_ref = pre_remove_out
                 m.eval()
                 m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
-                self.assertEqual(wrapped_m(input), pre_remove_out)
+                self.assertEqual(wrapped_m(input), pre_remove_out_ref)

             # TEST EVAL BEHAVIOR
             m, wrapped_m, spectral_norm_m = get_modules()
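As background for this hunk, the parametrization-based spectral norm lives under `torch.nn.utils.parametrizations`; a compact sketch of the register/remove round-trip the test exercises (shapes are illustrative):

    import torch
    import torch.nn as nn
    from torch.nn.utils import parametrize
    from torch.nn.utils.parametrizations import spectral_norm

    m = spectral_norm(nn.Linear(5, 7))
    x = torch.randn(3, 5)
    out = m(x).detach()  # detach so no autograd graph keeps the old weight alive

    m.eval()  # freeze the power-iteration estimates of u and v
    m = parametrize.remove_parametrizations(m, "weight")
    print(torch.allclose(m(x), out))  # True: the normalized weight was baked in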
@@ -1352,6 +1414,7 @@ class TestNNParametrization(NNTestCase):

         gradcheck(fn, (m.parametrizations.weight.original,))

+    @swap([True, False])
     def test_new_spectral_norm_load_state_dict(self):
         for activate_times in (0, 3):
             inp = torch.randn(2, 3)
@@ -1431,6 +1494,7 @@ class TestNNParametrization(NNTestCase):
             snm.eval()
             self.assertEqual(out3_eval, snm(inp))

+    @swap([True, False])
     def test_new_spectral_norm_dim(self):
         inp = torch.randn(2, 3, 10, 12)
         m = nn.ConvTranspose2d(3, 4, (5, 6))
@@ -1443,6 +1507,7 @@ class TestNNParametrization(NNTestCase):
             snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape
         )

+    @swap([True, False])
     def test_new_spectral_norm_forward(self):
         input = torch.randn(3, 5)
         m = nn.Linear(5, 7)
@@ -1461,6 +1526,7 @@ class TestNNParametrization(NNTestCase):
         expect_out = m(input)
         self.assertEqual(expect_out, out_hat)

+    @swap([True, False])
     @skipIfTorchDynamo("Test does not work with TorchDynamo")
     def test_new_spectral_norm_value(self):
         # a test that the spectral norm (= top singular value)
@@ -1477,6 +1543,7 @@ class TestNNParametrization(NNTestCase):
         self.assertEqual(m.weight.data, expected)

     @skipIfNoLapack
+    @swap([True, False])
     def test_orthogonal_parametrization(self):
         # Orthogonal implements 6 algorithms (3x parametrizations times 2 options of use_trivialization)

@@ -1532,7 +1599,13 @@ class TestNNParametrization(NNTestCase):

                 # We do not support householder for complex inputs
                 # See Note [Householder complex]
-                w_init = m.weight.clone()
+
+                # When using the swap_tensors path, this is needed so that the autograd
+                # graph is not alive anymore.
+                if get_swap_module_params_on_conversion():
+                    w_init = m.weight.clone().detach()
+                else:
+                    w_init = m.weight.clone()
                 if parametrization == "householder" and m.weight.is_complex():
                     msg = "householder parametrization does not support complex tensors"
                     with self.assertRaisesRegex(ValueError, msg):
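For context, the orthogonal parametrization being tested here is the public `torch.nn.utils.parametrizations.orthogonal`; a minimal usage sketch:

    import torch
    import torch.nn as nn
    from torch.nn.utils.parametrizations import orthogonal

    m = orthogonal(nn.Linear(5, 5))  # constrain m.weight to be orthogonal

    Q = m.weight
    Id = torch.eye(Q.size(0))
    print(torch.allclose(Q.T @ Q, Id, atol=1e-5))  # True up to numerical error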
@@ -1605,6 +1678,7 @@ class TestNNParametrization(NNTestCase):
             assert_is_orthogonal(m.weight)

     @skipIfNoLapack
+    @swap([True, False])
     def test_orthogonal_errors(self):
         m = nn.Linear(3, 4)
         with self.assertRaisesRegex(ValueError, "has to be one of"):
@@ -1618,6 +1692,7 @@ class TestNNParametrization(NNTestCase):
         m.weight = torch.randn(5, 5)
         torch.nn.utils.parametrize.remove_parametrizations(m, "weight")

+    @swap([True, False])
     def test_weight_norm_state_dict_compat(self):
         m = nn.Linear(4, 5)
         m = torch.nn.utils.weight_norm(m)
@@ -1630,12 +1705,14 @@ class TestNNParametrization(NNTestCase):
         input = torch.randn(3, 4)
         self.assertEqual(m(input), m2(input))

+    @swap([True, False])
     def test_weight_norm_pickle(self):
         m = nn.Linear(4, 5)
         m = torch.nn.utils.parametrizations.weight_norm(m)
         with self.assertRaisesRegex(RuntimeError, "state_dict"):
             pickle.dumps(m)

+    @swap([True, False])
     def test_weight_norm_deepcopy(self):
         m = nn.Linear(4, 5)
         m = torch.nn.utils.parametrizations.weight_norm(m)
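As a reminder of the API these tests cover, the parametrization-based weight norm decomposes `weight` into a magnitude and a direction stored on the ParametrizationList; a small sketch:

    import torch
    import torch.nn as nn
    from torch.nn.utils.parametrizations import weight_norm

    m = weight_norm(nn.Linear(4, 5))

    # The original weight is backed by two tensors: the magnitude and the direction.
    print(hasattr(m.parametrizations.weight, "original0"))  # True (magnitude)
    print(hasattr(m.parametrizations.weight, "original1"))  # True (direction)

    x = torch.randn(3, 4)
    print(m(x).shape)  # torch.Size([3, 5])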
@@ -1643,8 +1720,130 @@ class TestNNParametrization(NNTestCase):
         input = torch.randn(3, 4)
         self.assertEqual(m(input), m2(input))

+    @swap([True])
+    def test_wrapper_subclass_parametrization(self):
+        class Subclassify(nn.Module):
+            def forward(self, X):
+                return TwoTensor(X, X)
+
+        class UnSubclassify(nn.Module):
+            def forward(self, X):
+                return X.a
+
+        class IdentityWithRightInverse(nn.Module):
+            def forward(self, X):
+                return X
+
+            def right_inverse(self, X):
+                return TwoTensor(X, X)
+
+        def _check_parametrization(
+            parametrization,
+            type_before_registration,
+            type_after_registration,
+            leave_parametrized=False,
+            type_after_right_inverse=None,
+        ):
+            model = nn.Linear(2, 2)
+            buf = torch.randn(2, 2)
+            model.register_buffer("buf", buf)
+            if (
+                type_before_registration == TwoTensor
+                and type_after_registration == Tensor
+            ):
+                model._apply(lambda t: TwoTensor(t, t))
+            initial_weight = model.weight.clone().detach()
+            initial_weight_id = id(model.weight)
+            initial_buf = model.buf.clone().detach()
+            initial_buf_id = id(model.buf)
+            type_original_weight = (
+                type_before_registration
+                if type_after_right_inverse is None
+                else type_after_right_inverse
+            )
+            type_original_buf = (
+                Tensor if type_original_weight is nn.Parameter else type_original_weight
+            )
+            type_after_removal_buf = (
+                type_after_registration if leave_parametrized else type_original_buf
+            )
+            if leave_parametrized:
+                if type_after_registration is Tensor:
+                    type_after_removal_weight = nn.Parameter
+                else:
+                    type_after_removal_weight = type_after_registration
+            else:
+                type_after_removal_weight = type_original_weight
+
+            parametrize.register_parametrization(model, "weight", parametrization())
+            parametrize.register_parametrization(model, "buf", parametrization())
+            self.assertTrue(hasattr(model, "parametrizations"))
+            self.assertTrue(parametrize.is_parametrized(model))
+            self.assertFalse(parametrize.is_parametrized(model, "bias"))
+            # checks for weight
+            self.assertTrue(parametrize.is_parametrized(model, "weight"))
+            self.assertTrue(
+                isinstance(model.parametrizations.weight.original, nn.Parameter)
+            )
+            self.assertTrue(
+                type(model.parametrizations.weight.original) is type_original_weight
+            )
+            self.assertNotIn("weight", model._parameters)
+            self.assertTrue(type(model.weight) is type_after_registration)
+            # checks for buf
+            self.assertTrue(parametrize.is_parametrized(model, "buf"))
+            self.assertFalse(
+                isinstance(model.parametrizations.buf.original, nn.Parameter)
+            )
+            self.assertTrue(
+                type(model.parametrizations.buf.original) is type_original_buf
+            )
+            self.assertTrue(type(model.buf) is type_after_registration)
+            parametrize.remove_parametrizations(
+                model, "weight", leave_parametrized=leave_parametrized
+            )
+            parametrize.remove_parametrizations(
+                model, "buf", leave_parametrized=leave_parametrized
+            )
+            self.assertFalse(hasattr(model, "parametrizations"))
+            self.assertEqual(model.__class__, nn.Linear)
+            # checks for weight
+            self.assertTrue(type(model.weight) is type_after_removal_weight)
+            self.assertTrue(isinstance(model.weight, nn.Parameter))
+            self.assertEqual(id(model.weight), initial_weight_id)
+            # checks for buf
+            self.assertTrue(type(model.buf) is type_after_removal_buf)
+            self.assertFalse(isinstance(model.buf, nn.Parameter))
+            self.assertEqual(id(model.buf), initial_buf_id)
+            if not leave_parametrized and type_after_right_inverse is None:
+                self.assertEqual(model.weight, initial_weight)
+                self.assertEqual(model.buf, initial_buf)
+
+        _check_parametrization(Subclassify, nn.Parameter, TwoTensor)
+        _check_parametrization(UnSubclassify, TwoTensor, Tensor)
+        _check_parametrization(
+            IdentityWithRightInverse,
+            nn.Parameter,
+            TwoTensor,
+            type_after_right_inverse=TwoTensor,
+        )
+        _check_parametrization(
+            Subclassify, nn.Parameter, TwoTensor, leave_parametrized=True
+        )
+        _check_parametrization(
+            UnSubclassify, TwoTensor, Tensor, leave_parametrized=True
+        )
+        _check_parametrization(
+            IdentityWithRightInverse,
+            nn.Parameter,
+            TwoTensor,
+            leave_parametrized=True,
+            type_after_right_inverse=TwoTensor,
+        )
+

 class TestNNParametrizationDevice(NNTestCase):
+    @swap([True, False])
     def test_weight_norm_parametrization(self, device):
         for dtype in [torch.float, torch.bfloat16]:
             input = torch.randn(3, 4, dtype=dtype, device=device)
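The new test above drives parametrizations whose outputs are traceable wrapper tensor subclasses. `TwoTensor` is a test-only subclass from `torch.testing._internal.two_tensor` that wraps two tensors of the same shape. A rough sketch of the scenario the test covers, assuming the swap path is enabled as in `@swap([True])`:

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize
    from torch.testing._internal.two_tensor import TwoTensor  # test-only subclass

    torch.__future__.set_swap_module_params_on_conversion(True)

    class Subclassify(nn.Module):
        # Parametrization whose output is a wrapper tensor subclass.
        def forward(self, X):
            return TwoTensor(X, X)

    m = nn.Linear(2, 2)
    parametrize.register_parametrization(m, "weight", Subclassify())
    print(type(m.weight).__name__)  # TwoTensor

    # Removing with leave_parametrized=True used to require Tensor.set_ to work on
    # the subclass; with this change the write goes through torch.utils.swap_tensors.
    parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)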
@@ -1,6 +1,8 @@
 import torch
+from torch.__future__ import get_swap_module_params_on_conversion
 from torch.nn.modules.container import ModuleList, ModuleDict, Module
 from torch.nn.parameter import Parameter
+from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from torch import Tensor

 import collections
@@ -64,6 +66,14 @@ def _register_parameter_or_buffer(module, name, X):
     else:
         module.register_buffer(name, X)

+def _maybe_set(dest: Tensor, src: Tensor) -> None:
+    should_swap = get_swap_module_params_on_conversion() or is_traceable_wrapper_subclass(dest)
+    if should_swap:
+        if isinstance(dest, Parameter) and not isinstance(src, Parameter):
+            src = Parameter(src, requires_grad=dest.requires_grad)
+        torch.utils.swap_tensors(dest, src)
+    else:
+        dest.set_(src)  # type: ignore[call-overload]

 class ParametrizationList(ModuleList):
     r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`.
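The `_maybe_set` helper above is the core of the change: when the global future flag is on, or when the destination is a traceable wrapper subclass, writes into the stored tensor go through `torch.utils.swap_tensors` instead of the in-place `Tensor.set_`. A rough illustration of why either route keeps existing references valid (for example an optimizer's param list), assuming a plain `Parameter`:

    import torch
    from torch.nn import Parameter

    p = Parameter(torch.ones(2, 2))
    opt = torch.optim.SGD([p], lr=0.1)   # the optimizer keeps a reference to `p`

    new_value = torch.zeros(2, 2)

    # set_ rebinds p's storage in place; swap_tensors exchanges the two tensors'
    # payloads. Either way, every existing reference to `p` (including the one
    # inside the optimizer) observes the new value afterwards.
    torch.utils.swap_tensors(p, Parameter(new_value, requires_grad=p.requires_grad))

    print(opt.param_groups[0]["params"][0] is p)  # True: same Python object
    print(torch.equal(p, torch.zeros(2, 2)))      # True: holds the new value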
@@ -157,7 +167,7 @@ class ParametrizationList(ModuleList):
             # Set the original to original so that the user does not need to re-register the parameter
             # manually in the optimiser
             with torch.no_grad():
-                original.set_(new)  # type: ignore[call-overload]
+                _maybe_set(original, new)
             _register_parameter_or_buffer(self, "original", original)
         else:
             for i, originali in enumerate(new):
@@ -231,7 +241,7 @@ class ParametrizationList(ModuleList):
                        f"while `original` has dtype {self.original.dtype}"
                    )
                # We know that the result is going to have the same dtype
-                self.original.set_(value)  # type: ignore[call-overload]
+                _maybe_set(self.original, value)
            else:
                if not isinstance(value, collections.abc.Sequence):
                    raise ValueError(
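This setter runs when a parametrized attribute is assigned to directly: the new value is passed through `right_inverse` and then written into `original` via `_maybe_set`. A short usage sketch of that assignment path (the Symmetric parametrization is illustrative):

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class Symmetric(nn.Module):
        def forward(self, X):
            return X.triu() + X.triu(1).transpose(-1, -2)

        def right_inverse(self, A):
            return A.triu()  # map a symmetric matrix back to its upper triangle

    m = nn.Linear(3, 3)
    parametrize.register_parametrization(m, "weight", Symmetric())

    A = torch.rand(3, 3)
    A = A + A.T       # a symmetric target value
    m.weight = A      # goes through right_inverse, then _maybe_set on `original`
    print(torch.allclose(m.weight, A))  # True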
@@ -255,7 +265,7 @@ class ParametrizationList(ModuleList):
                            f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
                            f"while `original{i}` has dtype {original_i.dtype}"
                        )
-                    original_i.set_(tensor)
+                    _maybe_set(original_i, tensor)

    def forward(self) -> Tensor:
        if torch.jit.is_scripting():
@@ -645,18 +655,20 @@ def remove_parametrizations(
        # This way the user does not need to update the optimizer
        with torch.no_grad():
            if type(original) is torch.Tensor:
-               original.set_(t)
+               _maybe_set(original, t)
            else:
                try:
-                   original.set_(t)
+                   _maybe_set(original, t)
                except RuntimeError as e:
                    # TODO: Fix this for tensor subclasses that are parameters:
                    # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
                    raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True "
                                       "for a parameter that is an instance of a tensor subclass requires "
-                                      "set_() to be implemented correctly for the tensor subclass. Either "
-                                      "set leave_parametrized=False or provide a working implementation for "
-                                      "set_() in the tensor subclass.") from e
+                                      "set_() to be implemented correctly for the tensor subclass."
+                                      "Alternatively, one can opt into the swap_tensors path"
+                                      "Either set leave_parametrized=False or provide a working implementation"
+                                      "for set_() in the tensor subclass or set "
+                                      "torch.__future__.set_swap_module_params_on_conversion(True).") from e
    else:
        if leave_parametrized:
            # We cannot use no_grad because we need to know whether one or more
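The reworded error now points at the future flag as an escape hatch for tensor subclasses without a working `set_`. A minimal sketch of that opt-in; note that, beyond parametrization removal, the same flag also switches `nn.Module._apply` (e.g. `.to()`) onto the swap path:

    import torch

    # Global opt-in referenced by the error message above.
    torch.__future__.set_swap_module_params_on_conversion(True)
    print(torch.__future__.get_swap_module_params_on_conversion())  # True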