Add swap_tensors path to nn parametrizations (#124130)

Fixes #123859

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124130
Approved by: https://github.com/albanD
Author: Mikayla Gawarecki
Date: 2024-04-18 09:53:32 -07:00
Committed by: PyTorch MergeBot
Parent: 87f651c7e7
Commit: 5ba6bb7b2f

15 changed files with 226 additions and 15 deletions
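
For context, the swap_tensors path replaces the in-place Tensor.set_() writes that the parametrization machinery performs on the original tensor with calls to torch.utils.swap_tensors, which also works for wrapper tensor subclasses whose storage cannot be reassigned via set_(). Below is a minimal sketch of how a user opts in; the Symmetric parametrization is illustrative and not part of this PR.

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize
    from torch.__future__ import set_swap_module_params_on_conversion

    # Opt into the swap_tensors path globally; parametrize then writes back to
    # the original tensor with torch.utils.swap_tensors instead of Tensor.set_().
    set_swap_module_params_on_conversion(True)

    class Symmetric(nn.Module):
        def forward(self, X):
            return X.triu() + X.triu(1).transpose(-1, -2)

    m = nn.Linear(3, 3)
    parametrize.register_parametrization(m, "weight", Symmetric())
    print(m.weight)  # symmetric, recomputed from the unconstrained original

    # With the flag set, removing the parametrization materializes the
    # parametrized value back into m.weight via the swap path.
    parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)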

test/nn/test_parametrization.py

@ -9,6 +9,8 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.parametrize as parametrize
from torch import Tensor
from torch.__future__ import get_swap_module_params_on_conversion
from torch.nn import Parameter
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import instantiate_device_type_tests
@ -20,8 +22,10 @@ from torch.testing._internal.common_utils import (
set_default_dtype,
skipIfNoLapack,
skipIfTorchDynamo,
swap,
TemporaryFileName,
)
from torch.testing._internal.two_tensor import TwoTensor
class TestNNParametrization(NNTestCase):
@ -32,6 +36,7 @@ class TestNNParametrization(NNTestCase):
# and remove the `@skipIfNoLapack` (see #70995)
# torch/nn/utils/parametrize
@skipIfNoLapack
@swap([True, False])
def test_register_and_remove_parametrization(self):
r"""Test that it is possible to add a few parametrizations
on a parameter or a buffer and that removing them restores the initial state
@ -94,8 +99,7 @@ class TestNNParametrization(NNTestCase):
self.assertTrue(parametrize.is_parametrized(model, "weight"))
self.assertFalse(parametrize.is_parametrized(model, "bias"))
self.assertNotIn("weight", model._parameters)
A = model.weight
self.assertTrue(A.shape[0] == 1)
self.assertTrue(model.weight.shape[0] == 1)
parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
self.assertFalse(hasattr(model, "parametrizations"))
self.assertEqual(model.weight, initial_model.weight)
@ -110,8 +114,7 @@ class TestNNParametrization(NNTestCase):
self.assertTrue(parametrize.is_parametrized(model, "weight"))
self.assertFalse(parametrize.is_parametrized(model, "bias"))
self.assertNotIn("weight", model._parameters)
A = model.weight
self.assertTrue(A.shape[0] == 1)
self.assertTrue(model.weight.shape[0] == 1)
parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
self.assertFalse(hasattr(model, "parametrizations"))
self.assertEqual(model.weight, initial_model.weight)
@ -128,6 +131,10 @@ class TestNNParametrization(NNTestCase):
# Result should be skew-symmetric
A = model.weight
self.assertEqual(A, -A.T)
if get_swap_module_params_on_conversion():
# When using the swap_tensors path, this is needed so that the autograd
# graph is not alive anymore.
del A
# Remove and check consistency
parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
self.assertFalse(hasattr(model, "parametrizations"))
@ -145,6 +152,10 @@ class TestNNParametrization(NNTestCase):
# Result should be skew-symmetric
A = model.weight
self.assertEqual(A, -A.T)
if get_swap_module_params_on_conversion():
# When using the swap_tensors path, this is needed so that the autograd
# graph is not alive anymore.
del A
# Remove and check consistency
parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
self.assertFalse(hasattr(model, "parametrizations"))
@ -159,6 +170,10 @@ class TestNNParametrization(NNTestCase):
X = model.weight
Id = torch.eye(X.size(0), device=X.device)
self.assertEqual(X.T @ X, Id)
if get_swap_module_params_on_conversion():
# When using the swap_tensors path, this is needed so that the autograd
# graph is not alive anymore.
del X
# Structure tests
self.assertTrue(hasattr(model, "parametrizations"))
self.assertTrue(parametrize.is_parametrized(model))
@ -246,6 +261,10 @@ class TestNNParametrization(NNTestCase):
sgd.step()
self.assertNotEqual(model.weight, weight_copy)
self.assertNotEqual(model.bias, bias_copy)
if get_swap_module_params_on_conversion():
# When using the swap_tensors path, this is needed so that the autograd
# graph is not alive anymore.
del weight_copy, bias_copy
# Test leave_parametrized=True
for _ in range(2):
@ -266,7 +285,12 @@ class TestNNParametrization(NNTestCase):
sgd.step()
self.assertNotEqual(model.weight, weight_copy)
self.assertNotEqual(model.bias, bias_copy)
if get_swap_module_params_on_conversion():
# When using the swap_tensors path, this is needed so that the autograd
# graph is not alive anymore.
del weight_copy, bias_copy
@swap([True, False])
def test_register_and_remove_nested_parametrization(self):
r"""Test that it is possible to nest the parametrizations
meaning that the original param is parametrized again
@ -288,6 +312,10 @@ class TestNNParametrization(NNTestCase):
# Result should be skew-symmetric
A = model.weight
self.assertEqual(A, -A.T)
if get_swap_module_params_on_conversion():
# When using the swap_tensors path, this is needed so that the autograd
# graph is not alive anymore.
del A
# Add nested parametrization
param_mod = model.parametrizations.weight
@ -316,6 +344,7 @@ class TestNNParametrization(NNTestCase):
self.assertFalse(hasattr(model, "parametrizations"))
self.assertEqual(model.__class__, nn.Linear)
@swap([True, False])
def test_register_and_remove_buffer_parametrization(self):
r"""Test that it is possible to add and remove parametrizations on buffers"""
@ -354,6 +383,7 @@ class TestNNParametrization(NNTestCase):
# FIXME: Rewrite this test using functions not depending on LAPACK
# and remove the `@skipIfNoLapack` (see #70995)
@skipIfNoLapack
@swap([True, False])
def test_serialization_parametrization(self):
r"""Test that it is possible to serialize a parametrized model via state_dict"""
@ -403,6 +433,7 @@ class TestNNParametrization(NNTestCase):
# FIXME: Rewrite this test using functions not depending on LAPACK
# and remove the `@skipIfNoLapack` (see #70995)
@skipIfNoLapack
@swap([True, False])
def test_initialization_parametrization(self):
r"""Test that it is possible to initialize a parametrization when it
implements a `right_inverse` method
@ -472,6 +503,7 @@ class TestNNParametrization(NNTestCase):
self.assertEqual(model.weight, X)
self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X))
@swap([True, False])
def test_errors_unparametrized_tensor_parametrization(self):
# Test errors when registering a parametrization on an unparametrized tensor
module = nn.Linear(3, 4)
@ -621,6 +653,7 @@ class TestNNParametrization(NNTestCase):
self.assertFalse(parametrize.is_parametrized(module))
self.assertEqual(module.weight, weight_init)
@swap([True, False])
def test_errors_parametrized_tensor_parametrization(self):
# Test errors when registering a parametrization on a parametrized tensor
@ -702,6 +735,7 @@ class TestNNParametrization(NNTestCase):
# FIXME: Rewrite this test using functions not depending on LAPACK
# and remove the `@skipIfNoLapack` (see #70995)
@skipIfNoLapack
@swap([True, False])
def test_multiple_inputs_parametrization(self):
# A parametrization with several outputs
class RankOne(nn.Module):
@ -803,6 +837,7 @@ class TestNNParametrization(NNTestCase):
# FIXME: Rewrite this test using functions not depending on LAPACK
# and remove the `@skipIfNoLapack` (see #70995)
@skipIfNoLapack
@swap([True, False])
def test_caching_parametrization(self):
r"""Test the caching system of a parametrization"""
@ -830,6 +865,7 @@ class TestNNParametrization(NNTestCase):
# FIXME: Rewrite this test using functions not depending on LAPACK
# and remove the `@skipIfNoLapack` (see #70995)
@skipIfNoLapack
@swap([True, False])
def test_caching_parametrization_with_transfer_parametrizations_and_params(self):
r"""Test that transferring parametrizations doesn't cause issues with caching"""
@ -862,6 +898,7 @@ class TestNNParametrization(NNTestCase):
# test that the results are distinct objects for each module
self.assertNotEqual(id(A), id(X))
@swap([True, False])
def test_parametrization_same_training_mode(self):
r"""Test training mode updated on parametrization registration"""
@ -878,6 +915,7 @@ class TestNNParametrization(NNTestCase):
self.assertTrue(module.parametrizations.weight[0].training)
self.assertTrue(module.parametrizations.weight[1].training)
@swap([True, False])
def test_type_before_parametrizations(self):
r"""Test that type_before_parametrizations always retrieves original type"""
@ -895,6 +933,7 @@ class TestNNParametrization(NNTestCase):
parametrize.type_before_parametrizations(model) == original_type
)
@swap([True, False])
def test_deepcopy_after_parametrization(self):
r"""Test that we are able to create a deepcopy of the module when it's parametrized."""
@ -955,6 +994,7 @@ class TestNNParametrization(NNTestCase):
parametrize.register_parametrization(model, "weight", AddOne())
check_deepcopy(model, deepcopy(model))
@swap([True, False])
def test_transfer_parametrizations_and_params(self):
r"""Test that all parametrizations and their associated parameters are transferred."""
@ -994,6 +1034,10 @@ class TestNNParametrization(NNTestCase):
# check that the transfer didn't affect the original value
self.assertEqual(hold_weight, model.weight)
if get_swap_module_params_on_conversion():
# When using the swap_tensors path, this is needed so that the autograd
# graph is not alive anymore.
del hold_weight
# testing that changes to one set of parametrizations do not affect the other
parametrize.remove_parametrizations(to_model, "weight")
@ -1018,6 +1062,7 @@ class TestNNParametrization(NNTestCase):
# check that the new transfer didn't change the value for the from_module
self.assertEqual(hold_test_param, model.test_param)
@swap([True, False])
def test_transfer_parametrizations_and_params_right_inverse(self):
r"""Test that all parametrizations and their associated parameters are transferred."""
@ -1047,6 +1092,7 @@ class TestNNParametrization(NNTestCase):
# check that transfer doesn't affect the from_model weight
self.assertEqual(hold_weight, model.weight)
@swap([True, False])
def test_transfer_parametrizations_and_params_single_param(self):
r"""Test that all parametrizations and their associated parameters are transferred."""
@ -1086,6 +1132,7 @@ class TestNNParametrization(NNTestCase):
# FIXME: Rewrite this test using functions not depending on LAPACK
# and remove the `@skipIfNoLapack` (see #70995)
@skipIfNoLapack
@swap([True, False])
def test_transfer_parametrizations_and_params_many_to_one(self):
# A parametrization with several outputs
class RankOne(nn.Module):
@ -1152,6 +1199,7 @@ class TestNNParametrization(NNTestCase):
# check that the new transfer didn't change the value for the from_module
self.assertEqual(hold_test_param, model.test_param)
@swap([True, False])
def test_new_spectral_norm(self):
with set_default_dtype(torch.double):
input = torch.randn(3, 5)
@ -1289,16 +1337,30 @@ class TestNNParametrization(NNTestCase):
# avoid doing another power iteration
m, wrapped_m, _ = get_modules()
pre_remove_out = wrapped_m(input)
if get_swap_module_params_on_conversion():
# When using the swap_tensors path, this is needed so that the autograd
# graph is not alive anymore.
pre_remove_out_ref = pre_remove_out.detach()
del pre_remove_out
else:
pre_remove_out_ref = pre_remove_out
m.eval()
m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
self.assertEqual(wrapped_m(input), pre_remove_out)
self.assertEqual(wrapped_m(input), pre_remove_out_ref)
torch.nn.utils.parametrizations.spectral_norm(m)
for _ in range(3):
pre_remove_out = wrapped_m(input)
if get_swap_module_params_on_conversion():
# When using the swap_tensors path, this is needed so that the autograd
# graph is not alive anymore.
pre_remove_out_ref = pre_remove_out.detach()
del pre_remove_out
else:
pre_remove_out_ref = pre_remove_out
m.eval()
m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
self.assertEqual(wrapped_m(input), pre_remove_out)
self.assertEqual(wrapped_m(input), pre_remove_out_ref)
# TEST EVAL BEHAVIOR
m, wrapped_m, spectral_norm_m = get_modules()
@ -1352,6 +1414,7 @@ class TestNNParametrization(NNTestCase):
gradcheck(fn, (m.parametrizations.weight.original,))
@swap([True, False])
def test_new_spectral_norm_load_state_dict(self):
for activate_times in (0, 3):
inp = torch.randn(2, 3)
@ -1431,6 +1494,7 @@ class TestNNParametrization(NNTestCase):
snm.eval()
self.assertEqual(out3_eval, snm(inp))
@swap([True, False])
def test_new_spectral_norm_dim(self):
inp = torch.randn(2, 3, 10, 12)
m = nn.ConvTranspose2d(3, 4, (5, 6))
@ -1443,6 +1507,7 @@ class TestNNParametrization(NNTestCase):
snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape
)
@swap([True, False])
def test_new_spectral_norm_forward(self):
input = torch.randn(3, 5)
m = nn.Linear(5, 7)
@ -1461,6 +1526,7 @@ class TestNNParametrization(NNTestCase):
expect_out = m(input)
self.assertEqual(expect_out, out_hat)
@swap([True, False])
@skipIfTorchDynamo("Test does not work with TorchDynamo")
def test_new_spectral_norm_value(self):
# a test that the spectral norm (= top singular value)
@ -1477,6 +1543,7 @@ class TestNNParametrization(NNTestCase):
self.assertEqual(m.weight.data, expected)
@skipIfNoLapack
@swap([True, False])
def test_orthogonal_parametrization(self):
# Orthogonal implements 6 algorithms (3x parametrizations times 2 options of use_trivialization)
@ -1532,7 +1599,13 @@ class TestNNParametrization(NNTestCase):
# We do not support householder for complex inputs
# See Note [Householder complex]
w_init = m.weight.clone()
# When using the swap_tensors path, this is needed so that the autograd
# graph is not alive anymore.
if get_swap_module_params_on_conversion():
w_init = m.weight.clone().detach()
else:
w_init = m.weight.clone()
if parametrization == "householder" and m.weight.is_complex():
msg = "householder parametrization does not support complex tensors"
with self.assertRaisesRegex(ValueError, msg):
@ -1605,6 +1678,7 @@ class TestNNParametrization(NNTestCase):
assert_is_orthogonal(m.weight)
@skipIfNoLapack
@swap([True, False])
def test_orthogonal_errors(self):
m = nn.Linear(3, 4)
with self.assertRaisesRegex(ValueError, "has to be one of"):
@ -1618,6 +1692,7 @@ class TestNNParametrization(NNTestCase):
m.weight = torch.randn(5, 5)
torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
@swap([True, False])
def test_weight_norm_state_dict_compat(self):
m = nn.Linear(4, 5)
m = torch.nn.utils.weight_norm(m)
@ -1630,12 +1705,14 @@ class TestNNParametrization(NNTestCase):
input = torch.randn(3, 4)
self.assertEqual(m(input), m2(input))
@swap([True, False])
def test_weight_norm_pickle(self):
m = nn.Linear(4, 5)
m = torch.nn.utils.parametrizations.weight_norm(m)
with self.assertRaisesRegex(RuntimeError, "state_dict"):
pickle.dumps(m)
@swap([True, False])
def test_weight_norm_deepcopy(self):
m = nn.Linear(4, 5)
m = torch.nn.utils.parametrizations.weight_norm(m)
@ -1643,8 +1720,130 @@ class TestNNParametrization(NNTestCase):
input = torch.randn(3, 4)
self.assertEqual(m(input), m2(input))
@swap([True])
def test_wrapper_subclass_parametrization(self):
class Subclassify(nn.Module):
def forward(self, X):
return TwoTensor(X, X)
class UnSubclassify(nn.Module):
def forward(self, X):
return X.a
class IdentityWithRightInverse(nn.Module):
def forward(self, X):
return X
def right_inverse(self, X):
return TwoTensor(X, X)
def _check_parametrization(
parametrization,
type_before_registration,
type_after_registration,
leave_parametrized=False,
type_after_right_inverse=None,
):
model = nn.Linear(2, 2)
buf = torch.randn(2, 2)
model.register_buffer("buf", buf)
if (
type_before_registration == TwoTensor
and type_after_registration == Tensor
):
model._apply(lambda t: TwoTensor(t, t))
initial_weight = model.weight.clone().detach()
initial_weight_id = id(model.weight)
initial_buf = model.buf.clone().detach()
initial_buf_id = id(model.buf)
type_original_weight = (
type_before_registration
if type_after_right_inverse is None
else type_after_right_inverse
)
type_original_buf = (
Tensor if type_original_weight is nn.Parameter else type_original_weight
)
type_after_removal_buf = (
type_after_registration if leave_parametrized else type_original_buf
)
if leave_parametrized:
if type_after_registration is Tensor:
type_after_removal_weight = nn.Parameter
else:
type_after_removal_weight = type_after_registration
else:
type_after_removal_weight = type_original_weight
parametrize.register_parametrization(model, "weight", parametrization())
parametrize.register_parametrization(model, "buf", parametrization())
self.assertTrue(hasattr(model, "parametrizations"))
self.assertTrue(parametrize.is_parametrized(model))
self.assertFalse(parametrize.is_parametrized(model, "bias"))
# checks for weight
self.assertTrue(parametrize.is_parametrized(model, "weight"))
self.assertTrue(
isinstance(model.parametrizations.weight.original, nn.Parameter)
)
self.assertTrue(
type(model.parametrizations.weight.original) is type_original_weight
)
self.assertNotIn("weight", model._parameters)
self.assertTrue(type(model.weight) is type_after_registration)
# checks for buf
self.assertTrue(parametrize.is_parametrized(model, "buf"))
self.assertFalse(
isinstance(model.parametrizations.buf.original, nn.Parameter)
)
self.assertTrue(
type(model.parametrizations.buf.original) is type_original_buf
)
self.assertTrue(type(model.buf) is type_after_registration)
parametrize.remove_parametrizations(
model, "weight", leave_parametrized=leave_parametrized
)
parametrize.remove_parametrizations(
model, "buf", leave_parametrized=leave_parametrized
)
self.assertFalse(hasattr(model, "parametrizations"))
self.assertEqual(model.__class__, nn.Linear)
# checks for weight
self.assertTrue(type(model.weight) is type_after_removal_weight)
self.assertTrue(isinstance(model.weight, nn.Parameter))
self.assertEqual(id(model.weight), initial_weight_id)
# checks for buf
self.assertTrue(type(model.buf) is type_after_removal_buf)
self.assertFalse(isinstance(model.buf, nn.Parameter))
self.assertEqual(id(model.buf), initial_buf_id)
if not leave_parametrized and type_after_right_inverse is None:
self.assertEqual(model.weight, initial_weight)
self.assertEqual(model.buf, initial_buf)
_check_parametrization(Subclassify, nn.Parameter, TwoTensor)
_check_parametrization(UnSubclassify, TwoTensor, Tensor)
_check_parametrization(
IdentityWithRightInverse,
nn.Parameter,
TwoTensor,
type_after_right_inverse=TwoTensor,
)
_check_parametrization(
Subclassify, nn.Parameter, TwoTensor, leave_parametrized=True
)
_check_parametrization(
UnSubclassify, TwoTensor, Tensor, leave_parametrized=True
)
_check_parametrization(
IdentityWithRightInverse,
nn.Parameter,
TwoTensor,
leave_parametrized=True,
type_after_right_inverse=TwoTensor,
)
class TestNNParametrizationDevice(NNTestCase):
@swap([True, False])
def test_weight_norm_parametrization(self, device):
for dtype in [torch.float, torch.bfloat16]:
input = torch.randn(3, 4, dtype=dtype, device=device)

torch/nn/utils/parametrize.py

@ -1,6 +1,8 @@
import torch
from torch.__future__ import get_swap_module_params_on_conversion
from torch.nn.modules.container import ModuleList, ModuleDict, Module
from torch.nn.parameter import Parameter
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch import Tensor
import collections
@ -64,6 +66,14 @@ def _register_parameter_or_buffer(module, name, X):
else:
module.register_buffer(name, X)
def _maybe_set(dest: Tensor, src: Tensor) -> None:
should_swap = get_swap_module_params_on_conversion() or is_traceable_wrapper_subclass(dest)
if should_swap:
if isinstance(dest, Parameter) and not isinstance(src, Parameter):
src = Parameter(src, requires_grad=dest.requires_grad)
torch.utils.swap_tensors(dest, src)
else:
dest.set_(src) # type: ignore[call-overload]
class ParametrizationList(ModuleList):
r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`.
@ -157,7 +167,7 @@ class ParametrizationList(ModuleList):
# Set the original to original so that the user does not need to re-register the parameter
# manually in the optimiser
with torch.no_grad():
original.set_(new) # type: ignore[call-overload]
_maybe_set(original, new)
_register_parameter_or_buffer(self, "original", original)
else:
for i, originali in enumerate(new):
@ -231,7 +241,7 @@ class ParametrizationList(ModuleList):
f"while `original` has dtype {self.original.dtype}"
)
# We know that the result is going to have the same dtype
self.original.set_(value) # type: ignore[call-overload]
_maybe_set(self.original, value)
else:
if not isinstance(value, collections.abc.Sequence):
raise ValueError(
@ -255,7 +265,7 @@ class ParametrizationList(ModuleList):
f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
f"while `original{i}` has dtype {original_i.dtype}"
)
original_i.set_(tensor)
_maybe_set(original_i, tensor)
def forward(self) -> Tensor:
if torch.jit.is_scripting():
@ -645,18 +655,20 @@ def remove_parametrizations(
# This way the user does not need to update the optimizer
with torch.no_grad():
if type(original) is torch.Tensor:
original.set_(t)
_maybe_set(original, t)
else:
try:
original.set_(t)
_maybe_set(original, t)
except RuntimeError as e:
# TODO: Fix this for tensor subclasses that are parameters:
# RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True "
"for a parameter that is an instance of a tensor subclass requires "
"set_() to be implemented correctly for the tensor subclass. Either "
"set leave_parametrized=False or provide a working implementation for "
"set_() in the tensor subclass.") from e
"set_() to be implemented correctly for the tensor subclass."
"Alternatively, one can opt into the swap_tensors path"
"Either set leave_parametrized=False or provide a working implementation"
"for set_() in the tensor subclass or set "
"torch.__future__.set_swap_module_params_on_conversion(True).") from e
else:
if leave_parametrized:
# We cannot use no_grad because we need to know whether one or more