Filip 167ad09be5 [optim] override SWALR.state_dict and load_state_dict (#163122)
Fixes #163105

Note that the new `SWALR.load_state_dict` is **not backwards compatible**:
```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load the scheduler's state.

    Args:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    self.__dict__.update(state_dict)
    self._set_anneal_func(self._anneal_strategy)
```
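
To make the pattern concrete, here is a minimal, torch-free sketch of the save/restore round trip this approach implies. `ToyAnnealScheduler` is a hypothetical stand-in, not the real `SWALR`, and it assumes that `state_dict` leaves out the `anneal_func` callable and that `_set_anneal_func` records `_anneal_strategy`; neither is shown in this description.

```python
import math
from typing import Any


class ToyAnnealScheduler:
    """Hypothetical stand-in for a scheduler; not the real SWALR."""

    def __init__(self, anneal_strategy: str = "cos") -> None:
        self.last_epoch = 0
        self._set_anneal_func(anneal_strategy)

    @staticmethod
    def _cosine_anneal(t: float) -> float:
        return (1 - math.cos(math.pi * t)) / 2

    @staticmethod
    def _linear_anneal(t: float) -> float:
        return t

    def _set_anneal_func(self, anneal_strategy: str) -> None:
        # Assumption: the strategy name is stored next to the callable.
        self._anneal_strategy = anneal_strategy
        if anneal_strategy == "cos":
            self.anneal_func = self._cosine_anneal
        elif anneal_strategy == "linear":
            self.anneal_func = self._linear_anneal
        else:
            raise ValueError(f"unknown anneal_strategy: {anneal_strategy!r}")

    def state_dict(self) -> dict[str, Any]:
        # Leave the callable out so the dict holds only plain data.
        return {k: v for k, v in self.__dict__.items() if k != "anneal_func"}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.__dict__.update(state_dict)
        # Rebuild the callable from the stored strategy name.
        self._set_anneal_func(self._anneal_strategy)


source = ToyAnnealScheduler("linear")
source.last_epoch = 7
saved = source.state_dict()

target = ToyAnnealScheduler("cos")
target.load_state_dict(saved)
```

Because only the strategy name is saved, the resulting dict contains no function objects, and the target instance ends up with the annealing function that matches the saved strategy rather than the one it was constructed with.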

If we'd like to maintain compatibility with old state_dicts (loaded with `weights_only=False`), we could use something along these lines:
```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load the scheduler's state.

    Args:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
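    # Old state_dicts stored the callable itself; take it out so that
    # update() below does not write it into __dict__.
    anneal_func = state_dict.pop("anneal_func", None)
    strategy = state_dict.get("_anneal_strategy")
    self.__dict__.update(state_dict)

    if anneal_func is not None:
        # Restore the caller's dict, then infer the strategy name from the
        # legacy callable if the name was not saved.
        state_dict["anneal_func"] = anneal_func
        if strategy is None:
            if anneal_func == self._linear_anneal:
                strategy = "linear"
            elif anneal_func == self._cosine_anneal:
                strategy = "cos"

    if strategy is None:
        # Fall back to the instance's current strategy, or "cos".
        strategy = getattr(self, "_anneal_strategy", "cos")

    self._set_anneal_func(strategy)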
```

But since loading an `SWALR` state_dict before this PR would have raised an error, breaking compatibility with old state_dicts seems acceptable. A GitHub/Google search for `SWALR.load_state_dict` returned no results. Happy to switch to the compatible version if that is not okay, or to add a warning just in case.
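
If a warning were added, one possible shape is sketched below. `strip_legacy_anneal_func` is a hypothetical helper, not part of this PR or of `torch.optim`, and the warning text and category are assumptions; the idea is only to drop a legacy `anneal_func` entry and tell the user it was ignored.

```python
import warnings
from typing import Any


def strip_legacy_anneal_func(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: return a new dict without a legacy 'anneal_func'."""
    cleaned = {k: v for k, v in state_dict.items() if k != "anneal_func"}
    if len(cleaned) != len(state_dict):
        # A key was dropped, so the input came from the old format.
        warnings.warn(
            "state_dict contains a legacy 'anneal_func' entry; it is ignored "
            "and the annealing function is rebuilt from '_anneal_strategy'.",
            UserWarning,
            stacklevel=2,
        )
    return cleaned


# Example: an old-style dict that still carries a callable (len is a placeholder).
legacy = {"anneal_func": len, "_anneal_strategy": "cos", "last_epoch": 3}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cleaned = strip_legacy_anneal_func(legacy)
```

The helper returns a new dict instead of mutating its argument, so the caller's state_dict is left untouched.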
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163122
Approved by: https://github.com/janeyx99
2025-09-17 18:17:26 +00:00