# mypy: allow-untyped-defs
r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""

from typing import Any, TypeVar

from typing_extensions import deprecated

from torch import _weight_norm, norm_except_dim
from torch.nn.modules import Module
from torch.nn.parameter import Parameter, UninitializedParameter


__all__ = ["WeightNorm", "weight_norm", "remove_weight_norm"]


class WeightNorm:
    name: str
    dim: int

    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    # TODO Make return type more specific
    def compute_weight(self, module: Module) -> Any:
        g = getattr(module, self.name + "_g")
        v = getattr(module, self.name + "_v")
        return _weight_norm(v, g, self.dim)

    @staticmethod
    @deprecated(
        "`torch.nn.utils.weight_norm` is deprecated "
        "in favor of `torch.nn.utils.parametrizations.weight_norm`.",
        category=FutureWarning,
    )
    def apply(module, name: str, dim: int) -> "WeightNorm":
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError(
                    f"Cannot register two weight_norm hooks on the same parameter {name}"
                )

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                "The module passed to `WeightNorm` can't have uninitialized parameters. "
                "Make sure to run a dummy forward pass before applying weight normalization."
            )
        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(
            name + "_g", Parameter(norm_except_dim(weight, 2, dim).data)
        )
        module.register_parameter(name + "_v", Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module: Module) -> None:
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + "_g"]
        del module._parameters[self.name + "_v"]
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(module, self.name, self.compute_weight(module))
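

# Editor's note: the helper below is NOT part of the original module; it is a
# minimal, hedged sketch added for illustration only. Assuming
# ``norm_except_dim(v, 2, dim)`` reduces the 2-norm over every dimension except
# ``dim`` (and over the whole tensor for ``dim=-1``), it spells out the
# reparameterization w = g * v / ||v|| that ``compute_weight`` delegates to the
# fused ``torch._weight_norm`` kernel. The name ``_reference_weight_norm`` is
# made up here and is not a public API.
def _reference_weight_norm(v, g, dim: int):
    # ``g`` broadcasts against ``v`` because it keeps only the ``dim`` axis,
    # e.g. shape (out_features, 1) for a Linear weight with dim=0.
    return v * (g / norm_except_dim(v, 2, dim))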


T_module = TypeVar("T_module", bound=Module)


def weight_norm(module: T_module, name: str = "weight", dim: int = 0) -> T_module:
    r"""Apply weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
    (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
    Weight normalization is implemented via a hook that recomputes the weight
    tensor from the magnitude and direction before every :meth:`~Module.forward`
    call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    .. warning::

        This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm`
        which uses the modern parametrization API. The new ``weight_norm`` is compatible
        with ``state_dict`` generated from the old ``weight_norm``.

        Migration guide:

        * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed
          as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1``
          respectively. If this is bothering you, please comment on
          https://github.com/pytorch/pytorch/issues/102999

        * To remove the weight normalization reparametrization, use
          :func:`torch.nn.utils.parametrize.remove_parametrizations`.

        * The weight is no longer recomputed once per forward call; instead, it is
          recomputed on every access. To restore the old behavior, use
          :func:`torch.nn.utils.parametrize.cached` before invoking the module
          in question.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    WeightNorm.apply(module, name, dim)
    return module
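

# Editor's note: a hedged migration sketch, not part of the original file. It
# illustrates the ``torch.nn.utils.parametrizations.weight_norm`` replacement
# described in the docstring above; the function name is made up, nothing runs
# at import, and the ``original0``/``original1`` shapes are expected (not
# guaranteed here) to match the old ``weight_g``/``weight_v``.
def _parametrization_migration_example() -> None:
    import torch
    import torch.nn as nn
    from torch.nn.utils import parametrize
    from torch.nn.utils.parametrizations import weight_norm as parametrized_weight_norm

    m = parametrized_weight_norm(nn.Linear(20, 40), name="weight")
    # magnitude and direction now live under ``parametrizations.weight``
    # instead of ``weight_g`` / ``weight_v``
    print(m.parametrizations.weight.original0.size())  # analogous to weight_g
    print(m.parametrizations.weight.original1.size())  # analogous to weight_v

    # the parametrized weight is recomputed on every access; wrap calls in
    # ``parametrize.cached()`` to approximate the old once-per-forward behavior
    with parametrize.cached():
        _ = m(torch.randn(3, 20))

    # undo the reparameterization, keeping the current value of ``weight``
    parametrize.remove_parametrizations(m, "weight")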


def remove_weight_norm(module: T_module, name: str = "weight") -> T_module:
    r"""Remove the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError(f"weight_norm of '{name}' not found in {module}")