[FSDP][optim_state_dict] Fix a memory leakage in optim_state_dict (#96263)

Summary: The original code used a class variable to store the flat_parameter result. Because a class variable is shared across all instances, the stored state accumulated for the lifetime of the process, which could cause a memory leak.
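
A minimal sketch of the leak pattern described above (class and attribute names here are illustrative, not the actual FSDP code): a dict declared at class level is shared by every instance, so tensors stored into it remain reachable even after the instance that wrote them is gone.

from typing import Dict

import torch


class _StateWithClassDefault:
    # Declared on the class, so every instance shares this one dict.
    tensor_state: Dict[str, torch.Tensor] = {}


a = _StateWithClassDefault()
a.tensor_state["exp_avg"] = torch.zeros(1024)
del a

b = _StateWithClassDefault()
# The tensor written through `a` is still reachable via the class attribute,
# so it is never garbage collected.
print("exp_avg" in b.tensor_state)  # True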

Test Plan: CI and an E2E run

Reviewed By: awgu

Differential Revision: D43893577

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96263
Approved by: https://github.com/zhaojuanmao
Author: Chien-Chin Huang
Date: 2023-03-08 08:43:40 +00:00
Committed by: PyTorch MergeBot
Parent: 2bb022e902
Commit: 92edac72aa


@@ -1,7 +1,7 @@
 import copy
 import functools
 import warnings
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import (
     Any,
     cast,
@@ -50,6 +50,7 @@ def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
         yield k, dictionary[k]
 
 
+@dataclass
 class _ConsolidatedOptimState:
     """
     This holds the consolidated optimizer state on the target rank. Positive-
@@ -70,9 +71,9 @@ class _ConsolidatedOptimState:
             name to its value.
     """
-    tensor_state: Dict[str, torch.Tensor] = {}
-    zero_dim_tensor_state: Dict[str, torch.Tensor] = {}
-    non_tensor_state: Dict[str, Any] = {}
+    tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
+    zero_dim_tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
+    non_tensor_state: Dict[str, Any] = field(default_factory=dict)
 
 
 class _PosDimTensorInfo(NamedTuple):
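
The pattern the diff moves to, sketched below with an illustrative class name: making the class a dataclass and using field(default_factory=dict) builds a fresh dict for each instance, so the consolidated optimizer state can be released together with the instance instead of accumulating on the class.

from dataclasses import dataclass, field
from typing import Any, Dict

import torch


@dataclass
class _PerInstanceState:
    # Each instance gets its own dicts, created by the factory at __init__ time.
    tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
    non_tensor_state: Dict[str, Any] = field(default_factory=dict)


a = _PerInstanceState()
a.tensor_state["exp_avg"] = torch.zeros(1024)

b = _PerInstanceState()
print("exp_avg" in b.tensor_state)  # False: no state shared across instances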