pytorch/torch/optim/sparse_adam.py
ErezYosef 197601eeea Add Support for Tracking Parameter Names (named_parameters) in Optimizer State Dict (#134107)
A proposal addressing Issue #1489: **Optimizer should track parameter names and not id.**

(also mentioned here: [[RFC] Introducing FQNs/clarity eyeglasses to optim state_dict](https://dev-discuss.pytorch.org/t/rfc-introducing-fqns-clarity-to-optim-state-dict/1552))

## Summary
This PR introduces a backward-compatible enhancement where optimizers track parameter names in addition to their ids.
Optimizers can be initialized with `named_parameters()` as:
```python
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
```
This allows for greater clarity and ease when handling optimizers, as the parameters' names are preserved within the optimizer’s `state_dict` as:
```
state_dict =
{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0, 1],
            'param_names': ['layer.weight', 'layer.bias']   (optional)
        }
    ]
}
```
Loading a `state_dict` is unchanged (backward-compatible); the `param_names` key, if present, is simply ignored.

## Key Features
#### Named Parameters in Optimizer Initialization:
Optimizers can accept the output of `model.named_parameters()` during initialization, allowing them to store parameter names directly.
#### Parameter Names in `state_dict`:
The parameter names are saved as a list in the optimizer’s `state_dict` with key `param_names`, alongside the `params` indices, ensuring seamless tracking of both names and parameters.
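For illustration, assuming the `model` and the `optim.SGD(model.named_parameters(), ...)` initialization from the Summary above, the stored names can be read back directly from the `state_dict`:
```python
group = optimizer.state_dict()["param_groups"][0]
print(group["params"])       # [0, 1]
print(group["param_names"])  # ['layer.weight', 'layer.bias']
```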

## Backward Compatibility
#### No Breaking Changes:
This change is fully backward-compatible. The added `param_names` key in the optimizer's `state_dict` is ignored when loading a state to the optimizer.

#### Customization with Hooks:
For more control, the loaded `state_dict` can be adjusted by a custom hook registered with `register_load_state_dict_pre_hook`, providing flexibility for different design needs (see the Solution Example below).

## Documentation Updates
Please refer to the documentation changes for more details on how this feature is implemented and how it can be used effectively.

## Solution Example

A suggested solution to the problem described in #1489: loading a state dict that was saved with the same parameters but in a different order.
The following hook should be registered on the optimizer via `register_load_state_dict_pre_hook` before loading, so that the state dict can be loaded correctly:
```python
def adapt_state_dict_ids(optimizer, state_dict):
    # assuming a single param group.
    current_state_group = optimizer.state_dict()['param_groups'][0]
    loaded_state_group = state_dict['param_groups'][0]

    # same number of params, same names, only different ordering
    current_state_name_to_id_mapping = {}  # mapping --  param_name: id
    for i, name in enumerate(current_state_group['param_names']):
        current_state_name_to_id_mapping[name] = current_state_group['params'][i]

    # Remap the ids in the loaded state dict so that each loaded parameter name
    # points to the id used by the current optimizer for that name.
    for i, name in enumerate(loaded_state_group['param_names']):
        loaded_state_group['params'][i] = current_state_name_to_id_mapping[name]

    return state_dict
```
In this code, the ids in the loaded `state_dict` are remapped so that each parameter name is matched to the id the current optimizer uses for it.
Both the previous and the current optimizer must be initialized with `named_parameters()` so that the `param_names` key is present in both state dicts.
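For example (a sketch in which `loaded_state_dict` stands for the previously saved state dict), the hook is attached to the current optimizer before loading:
```python
optimizer.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer.load_state_dict(loaded_state_dict)  # states are now matched by parameter name
```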

### Note
This is my first contribution to PyTorch, and I would welcome feedback or suggestions for improvement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134107
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2024-10-14 19:24:44 +00:00

185 lines
7.8 KiB
Python

# mypy: allow-untyped-defs
from typing import List, Tuple, Union

import torch
from torch import Tensor

from . import _functional as F
from .optimizer import _maximize_doc, _params_doc, Optimizer, ParamsT

__all__ = ["SparseAdam"]


class SparseAdam(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        maximize: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 < lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 < eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

        defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize)
        super().__init__(params, defaults)

        sparse_params = []
        complex_params = []
        for index, param_group in enumerate(self.param_groups):
            assert isinstance(
                param_group, dict
            ), f"param_groups must be a list of dicts, but got {type(param_group)}"
            # given param group, convert given params to a list first before iterating
            for d_index, d_param in enumerate(param_group["params"]):
                if d_param.is_sparse:
                    sparse_params.append([index, d_index])
                if d_param.is_complex():
                    complex_params.append([index, d_index])
        if sparse_params:
            raise ValueError(
                f"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors"
            )
        if complex_params:
            raise ValueError(
                f"Complex params at indices {complex_params}: SparseAdam does not support complex parameters"
            )
    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            state_steps: List[int] = []
            beta1, beta2 = group["betas"]
            maximize = group.get("maximize", False)

            for p in group["params"]:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if not p.grad.is_sparse:
                        raise RuntimeError(
                            "SparseAdam does not support dense gradients, please consider Adam instead"
                        )
                    grads.append(p.grad)

                    state = self.state[p]

                    # State initialization
                    if len(state) == 0:
                        state["step"] = 0
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )
                        # Exponential moving average of squared gradient values
                        state["exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )

                    exp_avgs.append(state["exp_avg"])
                    exp_avg_sqs.append(state["exp_avg_sq"])

                    # update the steps for each param group update
                    state["step"] += 1
                    # record the step after step update
                    state_steps.append(state["step"])

            F.sparse_adam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
                eps=group["eps"],
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                maximize=maximize,
            )

        return loss
SparseAdam.__doc__ = rf"""SparseAdam implements a masked version of the Adam algorithm
    suitable for sparse gradients. Currently, due to implementation constraints (explained
    below), SparseAdam is only intended for a narrow subset of use cases, specifically
    parameters of a dense layout with gradients of a sparse layout. This occurs in a
    special case where the module backwards produces grads already in a sparse layout.
    One example NN module that behaves as such is ``nn.Embedding(sparse=True)``.

    SparseAdam approximates the Adam algorithm by masking out the parameter and moment
    updates corresponding to the zero values in the gradients. Whereas the Adam algorithm
    will update the first moment, the second moment, and the parameters based on all values
    of the gradients, SparseAdam only updates the moments and parameters corresponding
    to the non-zero values of the gradients.

    A simplified way of thinking about the `intended` implementation is as such:

    1. Create a mask of the non-zero values in the sparse gradients. For example,
       if your gradient looks like [0, 5, 0, 0, 9], the mask would be [0, 1, 0, 0, 1].
    2. Apply this mask over the running moments and do computation on only the
       non-zero values.
    3. Apply this mask over the parameters and only apply an update on non-zero values.

    In actuality, we use sparse layout Tensors to optimize this approximation, which means the
    more gradients that are masked by not being materialized, the more performant the optimization.
    Since we rely on using sparse layout tensors, we infer that any materialized value in the
    sparse layout is non-zero and we do NOT actually verify that all values are not zero!

    It is important to not conflate a semantically sparse tensor (a tensor where many
    of its values are zeros) with a sparse layout tensor (a tensor where ``.is_sparse``
    returns ``True``). The SparseAdam approximation is intended for `semantically` sparse
    tensors and the sparse layout is only an implementation detail. A clearer implementation
    would be to use MaskedTensors, but those are experimental.

    .. note::

        If you suspect your gradients are semantically sparse (but do not have sparse
        layout), this variant may not be the best for you. Ideally, you want to avoid
        materializing anything that is suspected to be sparse in the first place, since
        needing to convert all your grads from dense layout to sparse layout may outweigh
        the performance gain. Here, using Adam may be the best alternative, unless you
        can easily rig up your module to output sparse grads similar to
        ``nn.Embedding(sparse=True)``. If you insist on converting your grads, you can do
        so by manually overriding your parameters' ``.grad`` fields with their sparse
        equivalents before calling ``.step()``.

    Args:
        {_params_doc}
        lr (float, Tensor, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        {_maximize_doc}

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980

    """