mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
Summary: There were two problems with SN + DP: 1. In SN, the updated _u vector is saved back to the module via `setattr`. However, in DP everything runs on a replica, so those updates are lost. 2. In DP, the buffers are broadcast via `broadcast_coalesced`, so on replicas they are all views. Therefore, the `detach_` call won't work. Fixes are: 1. Update the _u vector in-place so that, through the storage shared between the first replica and the parallelized module, the update is retained. 2. Do not call `detach_`. 3. Added comments in SN about the subtlety. 4. Added a note to the DP doc on this particular behavior of DP. cc crcrpar taesung89 yaoshengfu Fixes https://github.com/pytorch/pytorch/issues/11476 Pull Request resolved: https://github.com/pytorch/pytorch/pull/12671 Differential Revision: D10410232 Pulled By: SsnL fbshipit-source-id: c447951844a30366d8c196bf9436340e88f3b6d9
175 lines
7.6 KiB
Python
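
A minimal sketch (not part of the file below; it assumes a machine with at least two CUDA devices, and the layer sizes and variable names are illustrative) of the SN + DP interaction the summary describes: because the power iteration writes into the `_u` buffer in-place, and the replica on device[0] shares storage with the parallelized module, the estimate keeps advancing across forward passes instead of being re-broadcast from the same stale initial value every time.

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Hypothetical toy setup: a spectrally normalized layer wrapped in DataParallel.
net = spectral_norm(nn.Linear(20, 40)).cuda(0)
dp = nn.DataParallel(net, device_ids=[0, 1])

u_before = net.weight_u.clone()
# Forward in training mode runs the power iteration on each replica.
dp(torch.randn(8, 20, device='cuda:0'))
# The in-place update made on the device[0] replica is visible on the original
# module, because their `weight_u` buffers share storage.
assert not torch.equal(u_before, net.weight_u)
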
"""
|
|
Spectral Normalization from https://arxiv.org/abs/1802.05957
|
|
"""
|
|
import torch
|
|
from torch.nn.functional import normalize
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
class SpectralNorm(object):
|
|
|
|
def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
|
|
self.name = name
|
|
self.dim = dim
|
|
if n_power_iterations <= 0:
|
|
raise ValueError('Expected n_power_iterations to be positive, but '
|
|
'got n_power_iterations={}'.format(n_power_iterations))
|
|
self.n_power_iterations = n_power_iterations
|
|
self.eps = eps
|
|
|
|
def compute_weight_and_update_u(self, module):
|
|
# NB: This updates the _u vector **in-place**. This is very important
|
|
# because in DataParallel forward, the _u vector (being a buffer) is
|
|
# broadcast from the parallelized module to each module replica,
|
|
# which is a new module object on the fly. And each replica runs its
|
|
# own spectral norm power iteration. So simply assigning the updated
|
|
# _u vector to module this runs on will cause the update to be lost
|
|
# forever. And the next time the parallelized module is replicated,
|
|
# the same randomly initialized _u vector is broadcast!
|
|
#
|
|
# Therefore, to make the change propagate back, we rely on two
|
|
# important bahaviors (also enforced via tests):
|
|
# 1. DataParallel doesn't clone storage if the broadcast tensor is
|
|
# alreay on correct device; and it makes sure that the
|
|
# parallelized module is already on device[0].
|
|
# 2. If the out tensor in out= kwarg has correct shape, it will
|
|
# just fill in the values.
|
|
# Therefore, since the same power iteration is performed on all
|
|
# devices, simply updating the _u tensor in-place will make sure
|
|
# that the module replica on device[0] will update the _u vector on
|
|
# the parallized module (by shared storage).
|
|
weight = getattr(module, self.name + '_orig')
|
|
u = getattr(module, self.name + '_u')
|
|
weight_mat = weight
|
|
if self.dim != 0:
|
|
# permute dim to front
|
|
weight_mat = weight_mat.permute(self.dim,
|
|
*[d for d in range(weight_mat.dim()) if d != self.dim])
|
|
height = weight_mat.size(0)
|
|
weight_mat = weight_mat.reshape(height, -1)
|
|
with torch.no_grad():
|
|
for _ in range(self.n_power_iterations):
|
|
# Spectral norm of weight equals to `u^T W v`, where `u` and `v`
|
|
# are the first left and right singular vectors.
|
|
# This power iteration produces approximations of `u` and `v`.
|
|
v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
|
|
u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps, out=u)
|
|
|
|
sigma = torch.dot(u, torch.matmul(weight_mat, v))
|
|
weight = weight / sigma
|
|
return weight
|
|
|
|
def remove(self, module):
|
|
weight = getattr(module, self.name)
|
|
delattr(module, self.name)
|
|
delattr(module, self.name + '_u')
|
|
delattr(module, self.name + '_orig')
|
|
module.register_parameter(self.name, torch.nn.Parameter(weight))
|
|
|
|
def __call__(self, module, inputs):
|
|
if module.training:
|
|
weight = self.compute_weight_and_update_u(module)
|
|
setattr(module, self.name, weight)
|
|
else:
|
|
r_g = getattr(module, self.name + '_orig').requires_grad
|
|
weight = getattr(module, self.name).detach()
|
|
# NB: Cannot detach weight in-place here because if this is used
|
|
# DataParallel, the buffers are broadcast using
|
|
# `broadacast_coalesced` and `weight` here is actually a view,
|
|
# and you can't detach views in-place.
|
|
setattr(module, self.name, weight.requires_grad_(r_g))
|
|
|
|
@staticmethod
|
|
def apply(module, name, n_power_iterations, dim, eps):
|
|
fn = SpectralNorm(name, n_power_iterations, dim, eps)
|
|
weight = module._parameters[name]
|
|
height = weight.size(dim)
|
|
|
|
u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
|
|
delattr(module, fn.name)
|
|
module.register_parameter(fn.name + "_orig", weight)
|
|
# We still need to assign weight back as fn.name because all sorts of
|
|
# things may assume that it exists, e.g., when initializing weights.
|
|
# However, we can't directly assign as it could be an nn.Parameter and
|
|
# gets added as a parameter. Instead, we register weight.data as a
|
|
# buffer, which will cause weight to be included in the state dict
|
|
# and also supports nn.init due to shared storage.
|
|
module.register_buffer(fn.name, weight.data)
|
|
module.register_buffer(fn.name + "_u", u)
|
|
|
|
module.register_forward_pre_hook(fn)
|
|
return fn
|
|
|
|
|
|
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
|
|
r"""Applies spectral normalization to a parameter in the given module.
|
|
|
|
.. math::
|
|
\mathbf{W} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
|
|
\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
|
|
|
|
Spectral normalization stabilizes the training of discriminators (critics)
|
|
in Generaive Adversarial Networks (GANs) by rescaling the weight tensor
|
|
with spectral norm :math:`\sigma` of the weight matrix calculated using
|
|
power iteration method. If the dimension of the weight tensor is greater
|
|
than 2, it is reshaped to 2D in power iteration method to get spectral
|
|
norm. This is implemented via a hook that calculates spectral norm and
|
|
rescales weight before every :meth:`~Module.forward` call.
|
|
|
|
See `Spectral Normalization for Generative Adversarial Networks`_ .
|
|
|
|
.. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
|
|
|
|
Args:
|
|
module (nn.Module): containing module
|
|
name (str, optional): name of weight parameter
|
|
n_power_iterations (int, optional): number of power iterations to
|
|
calculate spectal norm
|
|
eps (float, optional): epsilon for numerical stability in
|
|
calculating norms
|
|
dim (int, optional): dimension corresponding to number of outputs,
|
|
the default is 0, except for modules that are instances of
|
|
ConvTranspose1/2/3d, when it is 1
|
|
|
|
Returns:
|
|
The original module with the spectal norm hook
|
|
|
|
Example::
|
|
|
|
>>> m = spectral_norm(nn.Linear(20, 40))
|
|
Linear (20 -> 40)
|
|
>>> m.weight_u.size()
|
|
torch.Size([20])
|
|
|
|
"""
|
|
if dim is None:
|
|
if isinstance(module, (torch.nn.ConvTranspose1d,
|
|
torch.nn.ConvTranspose2d,
|
|
torch.nn.ConvTranspose3d)):
|
|
dim = 1
|
|
else:
|
|
dim = 0
|
|
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
|
|
return module
|
|
|
|
|
|
def remove_spectral_norm(module, name='weight'):
|
|
r"""Removes the spectral normalization reparameterization from a module.
|
|
|
|
Args:
|
|
module (nn.Module): containing module
|
|
name (str, optional): name of weight parameter
|
|
|
|
Example:
|
|
>>> m = spectral_norm(nn.Linear(40, 10))
|
|
>>> remove_spectral_norm(m)
|
|
"""
|
|
for k, hook in module._forward_pre_hooks.items():
|
|
if isinstance(hook, SpectralNorm) and hook.name == name:
|
|
hook.remove(module)
|
|
del module._forward_pre_hooks[k]
|
|
return module
|
|
|
|
raise ValueError("spectral_norm of '{}' not found in {}".format(
|
|
name, module))
|
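
A short usage sketch (not part of the file above; attribute names follow the code, printed values assume the default `dim=0`) showing the reparameterization that `SpectralNorm.apply` sets up and that `remove_spectral_norm` undoes: the original parameter survives as `weight_orig`, the power-iteration estimate lives in the `weight_u` buffer, and `weight` itself becomes a plain buffer that the pre-forward hook overwrites with the rescaled tensor.

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm, remove_spectral_norm

m = spectral_norm(nn.Linear(20, 40))
print(type(m.weight_orig))   # <class 'torch.nn.parameter.Parameter'>, shape [40, 20]
print(m.weight_u.shape)      # torch.Size([40]) -- one entry per row of the 2D weight matrix
print(m.weight.shape)        # torch.Size([40, 20]) -- rescaled by 1/sigma on every training forward

remove_spectral_norm(m)
print(type(m.weight))        # back to a plain Parameter; weight_orig and weight_u are gone
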