# mypy: allow-untyped-defs
"""Spectral Normalization from https://arxiv.org/abs/1802.05957."""

from typing import Any, Optional, TypeVar

import torch
import torch.nn.functional as F
from torch.nn.modules import Module


__all__ = [
    "SpectralNorm",
    "SpectralNormLoadStateDictPreHook",
    "SpectralNormStateDictHook",
    "spectral_norm",
    "remove_spectral_norm",
]


class SpectralNorm:
    # Invariant before and after each forward call:
    #   u = F.normalize(W @ v)
    # NB: At initialization, this invariant is not enforced

    _version: int = 1
    # At version 1:
    #   made `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode compute `W = W_orig / (u @ W_orig @ v)` rather than
    #   using the stored `W`.
    name: str
    dim: int
    n_power_iterations: int
    eps: float

    def __init__(
        self,
        name: str = "weight",
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12,
    ) -> None:
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                "Expected n_power_iterations to be positive, but "
                f"got n_power_iterations={n_power_iterations}"
            )
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
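        # For example, a (hypothetical) Conv2d weight of shape (64, 32, 3, 3)
        # with the default dim=0 is flattened to a 2D matrix of shape
        # (64, 32 * 3 * 3) = (64, 288); the spectral norm is then estimated
        # on that matrix.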
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(
                self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim]
            )
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)

    def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor:
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        # updated in power iteration **in-place**. This is very important
        # because in `DataParallel` forward, the vectors (being buffers) are
        # broadcast from the parallelized module to each module replica,
        # which is a new module object created on the fly. And each replica
        # runs its own spectral norm power iteration. So simply assigning
        # the updated vectors to the module this function runs on will cause
        # the update to be lost forever. And the next time the parallelized
        # module is replicated, the same randomly initialized vectors are
        # broadcast and used!
        #
        # Therefore, to make the change propagate back, we rely on two
        # important behaviors (also enforced via tests):
        #   1. `DataParallel` doesn't clone storage if the broadcast tensor
        #      is already on the correct device; and it makes sure that the
        #      parallelized module is already on `device[0]`.
        #   2. If the out tensor in the `out=` kwarg has the correct shape,
        #      it will just fill in the values.
        # Therefore, since the same power iteration is performed on all
        # devices, simply updating the tensors in-place will make sure that
        # the module replica on `device[0]` will update the _u vector on the
        # parallelized module (by shared storage).
        #
        # However, after we update `u` and `v` in-place, we need to **clone**
        # them before using them to normalize the weight. This is to support
        # backpropagating through two forward passes, e.g., the common pattern
        # in GAN training: loss = D(real) - D(fake). Otherwise, the engine will
        # complain that variables needed to do backward for the first forward
        # (i.e., the `u` and `v` vectors) are changed in the second forward.
        weight = getattr(module, self.name + "_orig")
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # The spectral norm of the weight equals `u^T W v`, where
                    # `u` and `v` are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
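                    # Concretely, each step computes
                    #   v <- normalize(W^T @ u),  u <- normalize(W @ v),
                    # driving `u` and `v` towards the dominant singular vectors.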
                    v = F.normalize(
                        torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v
                    )
                    u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone(memory_format=torch.contiguous_format)
                    v = v.clone(memory_format=torch.contiguous_format)

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module: Module) -> None:
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + "_u")
        delattr(module, self.name + "_v")
        delattr(module, self.name + "_orig")
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(
            module,
            self.name,
            self.compute_weight(module, do_power_iteration=module.training),
        )

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = F.normalize(W @ v)`
        # (the invariant at the top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
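        # Concretely, `(W^T W)^+ W^T` is the Moore-Penrose pseudoinverse of `W`,
        # so the multi_dot below yields the least-squares solution of
        # `W @ v = u`; the final `mul_` then rescales `v` so that `u @ W @ v`
        # equals `target_sigma` exactly.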
        v = torch.linalg.multi_dot(
            [weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)]
        ).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(
        module: Module, name: str, n_power_iterations: int, dim: int, eps: float
    ) -> "SpectralNorm":
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    f"Cannot register two spectral_norm hooks on the same parameter {name}"
                )

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        if weight is None:
            raise ValueError(
                f"`SpectralNorm` cannot be applied as parameter `{name}` is None"
            )
        if isinstance(weight, torch.nn.parameter.UninitializedParameter):
            raise ValueError(
                "The module passed to `SpectralNorm` can't have uninitialized parameters. "
                "Make sure to run the dummy forward before applying spectral normalization"
            )
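        # NB: for lazy modules such as torch.nn.LazyLinear, which hold
        # UninitializedParameter placeholders, running a single dummy forward
        # with a correctly shaped input materializes `weight` so that spectral
        # norm can then be applied.
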
        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

        h, w = weight_mat.size()
        # randomly initialize `u` and `v`
        u = F.normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
        v = F.normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign it as it could be an nn.Parameter
        # and would get added as a parameter. Instead, we register weight.data
        # as a plain attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn


# This is a top level class because Py2 pickle doesn't like inner classes or
# instance methods.
class SpectralNormLoadStateDictPreHook:
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    # For a state_dict with version None (assuming that it has gone through at
    # least one training forward), we have
    #
    #   u = F.normalize(W_orig @ v)
    #   W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #   v = x / (u @ W_orig @ x) * (W / W_orig).
    def __call__(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ) -> None:
        fn = self.fn
        version = local_metadata.get("spectral_norm", {}).get(
            fn.name + ".version", None
        )
        if version is None or version < 1:
            weight_key = prefix + fn.name
            if (
                version is None
                and all(weight_key + s in state_dict for s in ("_orig", "_u", "_v"))
                and weight_key not in state_dict
            ):
                # Detect if it is the updated state dict that is just missing
                # metadata. This could happen if users are crafting a state dict
                # themselves, so we just pretend that this is the newest version.
                return
            has_missing_keys = False
            for suffix in ("_orig", "", "_u"):
                key = weight_key + suffix
                if key not in state_dict:
                    has_missing_keys = True
                    if strict:
                        missing_keys.append(key)
            if has_missing_keys:
                return
            with torch.no_grad():
                weight_orig = state_dict[weight_key + "_orig"]
                weight = state_dict.pop(weight_key)
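                # Since W = W_orig / sigma elementwise, every entry of
                # `weight_orig / weight` equals sigma; `.mean()` just collapses
                # that constant ratio back to a scalar.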
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[weight_key + "_u"]
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[weight_key + "_v"] = v


# This is a top level class because Py2 pickle doesn't like inner classes or
# instance methods.
class SpectralNormStateDictHook:
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata) -> None:
        if "spectral_norm" not in local_metadata:
            local_metadata["spectral_norm"] = {}
        key = self.fn.name + ".version"
        if key in local_metadata["spectral_norm"]:
            raise RuntimeError(f"Unexpected key in metadata['spectral_norm']: {key}")
        local_metadata["spectral_norm"][key] = self.fn._version


T_module = TypeVar("T_module", bound=Module)


def spectral_norm(
    module: T_module,
    name: str = "weight",
    n_power_iterations: int = 1,
    eps: float = 1e-12,
    dim: Optional[int] = None,
) -> T_module:
    r"""Apply spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    by the spectral norm :math:`\sigma` of the weight matrix, computed with
    the power iteration method. If the weight tensor has more than 2
    dimensions, it is reshaped to 2D for the power iteration to obtain the
    spectral norm. This is implemented via a hook that calculates the spectral
    norm and rescales the weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with the spectral norm hook

    .. note::
        This function has been reimplemented as
        :func:`torch.nn.utils.parametrizations.spectral_norm` using the new
        parametrization functionality in
        :func:`torch.nn.utils.parametrize.register_parametrization`. Please use
        the newer version. This function will be deprecated in a future version
        of PyTorch.

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])
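        >>> # For ConvTranspose modules, dim defaults to 1 (the out_channels
        >>> # dimension), so `u` has out_channels entries:
        >>> m = spectral_norm(nn.ConvTranspose2d(16, 33, 3))
        >>> m.weight_u.size()
        torch.Size([33])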

    """
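    # NB: ConvTransposeNd stores its weight as (in_channels, out_channels, ...),
    # so the "output" dimension for spectral norm defaults to 1 for those
    # modules and to 0 (output features first) for everything else.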
    if dim is None:
        if isinstance(
            module,
            (
                torch.nn.ConvTranspose1d,
                torch.nn.ConvTranspose2d,
                torch.nn.ConvTranspose3d,
            ),
        ):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    # pyrefly: ignore # bad-return
    return module


def remove_spectral_norm(module: T_module, name: str = "weight") -> T_module:
    r"""Remove the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            break
    else:
        raise ValueError(f"spectral_norm of '{name}' not found in {module}")

    for k, hook in module._state_dict_hooks.items():
        if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
            del module._state_dict_hooks[k]
            break

    for k, hook in module._load_state_dict_pre_hooks.items():
        if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
            del module._load_state_dict_pre_hooks[k]
            break

    return module