Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Fixes #163105

Note that the new `SWALR.load_state_dict` is **not backwards compatible**:

```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load the scheduler's state.

    Args:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    self.__dict__.update(state_dict)
    self._set_anneal_func(self._anneal_strategy)
```

If we'd like to maintain compatibility with old state_dicts (loaded with `weights_only=False`), we could use something along these lines:

```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load the scheduler's state.

    Args:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    anneal_func = state_dict.pop("anneal_func", None)
    strategy = state_dict.get("_anneal_strategy")
    self.__dict__.update(state_dict)

    if anneal_func is not None:
        state_dict["anneal_func"] = anneal_func
        if strategy is None:
            if anneal_func == self._linear_anneal:
                strategy = "linear"
            elif anneal_func == self._cosine_anneal:
                strategy = "cos"

    if strategy is None:
        strategy = getattr(self, "_anneal_strategy", "cos")

    self._set_anneal_func(strategy)
```

But given the fact that loading an `SWALR` state_dict before this PR would have caused an error, this seems okay. A GitHub/Google search for `SWALR.load_state_dict` had no results. Happy to change if not, or add a warning just in case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163122
Approved by: https://github.com/janeyx99
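For illustration, here is a minimal, torch-free sketch of the pattern discussed above: the state dict keeps only the strategy name, and the callable is rebuilt on load. `ToyAnnealScheduler` and its `step_count` attribute are hypothetical stand-ins, not the real `SWALR` class, and this variant copies the incoming dict rather than popping and restoring the key.

```python
import math
from typing import Any


class ToyAnnealScheduler:
    """Hypothetical stand-in used only to illustrate the pattern; not SWALR."""

    def __init__(self, anneal_strategy: str = "cos") -> None:
        self.step_count = 0
        self._set_anneal_func(anneal_strategy)

    @staticmethod
    def _linear_anneal(t: float) -> float:
        return t

    @staticmethod
    def _cosine_anneal(t: float) -> float:
        return (1 - math.cos(math.pi * t)) / 2

    def _set_anneal_func(self, strategy: str) -> None:
        if strategy not in ("cos", "linear"):
            raise ValueError(f"unknown anneal strategy: {strategy!r}")
        self._anneal_strategy = strategy
        self.anneal_func = (
            self._cosine_anneal if strategy == "cos" else self._linear_anneal
        )

    def state_dict(self) -> dict[str, Any]:
        # Keep only plain data; the callable is rebuilt from the strategy name.
        return {k: v for k, v in self.__dict__.items() if k != "anneal_func"}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        state = dict(state_dict)  # shallow copy: the caller's dict is not mutated
        anneal_func = state.pop("anneal_func", None)
        strategy = state.get("_anneal_strategy")
        self.__dict__.update(state)

        # Legacy state: no strategy name, so infer it from the stored callable.
        if strategy is None and anneal_func is not None:
            if anneal_func == self._linear_anneal:
                strategy = "linear"
            elif anneal_func == self._cosine_anneal:
                strategy = "cos"

        # Last resort: keep whatever strategy this instance already has.
        if strategy is None:
            strategy = getattr(self, "_anneal_strategy", "cos")

        self._set_anneal_func(strategy)


# Round trip with the new format.
src = ToyAnnealScheduler("linear")
src.step_count = 3
dst = ToyAnnealScheduler("cos")
dst.load_state_dict(src.state_dict())

# Old-style state that still carries the callable instead of the name.
legacy = {"step_count": 5, "anneal_func": ToyAnnealScheduler._linear_anneal}
old = ToyAnnealScheduler("cos")
old.load_state_dict(legacy)
```

In this sketch both `dst` and `old` end up with the `"linear"` strategy, and the new-format state dict contains no function objects, which is the property that matters when loading with `weights_only=True`.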