[FSDP][optim_state_dict] Fix a memory leak in optim_state_dict (#96263)
Summary: The original code uses a class variable to store the flat_parameter result. This can cause a memory leak, because a mutable class attribute is shared by all instances and stays alive for the lifetime of the class.

Test Plan: CI and an E2E run

Reviewed By: awgu

Differential Revision: D43893577

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96263
Approved by: https://github.com/zhaojuanmao
committed by PyTorch MergeBot
parent 2bb022e902
commit 92edac72aa
@@ -1,7 +1,7 @@
 import copy
 import functools
 import warnings
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import (
     Any,
     cast,
@@ -50,6 +50,7 @@ def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
         yield k, dictionary[k]
 
 
+@dataclass
 class _ConsolidatedOptimState:
     """
     This holds the consolidated optimizer state on the target rank. Positive-
@@ -70,9 +71,9 @@ class _ConsolidatedOptimState:
     name to its value.
     """
 
-    tensor_state: Dict[str, torch.Tensor] = {}
-    zero_dim_tensor_state: Dict[str, torch.Tensor] = {}
-    non_tensor_state: Dict[str, Any] = {}
+    tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
+    zero_dim_tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
+    non_tensor_state: Dict[str, Any] = field(default_factory=dict)
 
 
 class _PosDimTensorInfo(NamedTuple):
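Why the fix works: a mutable class attribute such as tensor_state: Dict[str, torch.Tensor] = {} is created once at class-definition time and shared by every instance, so anything stored in it stays reachable for as long as the class object exists and is never garbage collected. With @dataclass and field(default_factory=dict), each instance gets its own fresh dict, which is freed together with the instance. The standalone sketch below illustrates the difference using placeholder classes (Leaky and Fixed are hypothetical names, not the actual FSDP code):

from dataclasses import dataclass, field
from typing import Any, Dict


class Leaky:
    # Mutable class attribute: one dict, created once, shared by all
    # instances and kept alive as long as the class itself.
    tensor_state: Dict[str, Any] = {}


@dataclass
class Fixed:
    # default_factory builds a new dict per instance, so the state is
    # released together with the instance.
    tensor_state: Dict[str, Any] = field(default_factory=dict)


a, b = Leaky(), Leaky()
a.tensor_state["param"] = "pretend this is a large tensor"
assert b.tensor_state is a.tensor_state   # same shared dict: the leak

c, d = Fixed(), Fixed()
c.tensor_state["param"] = "pretend this is a large tensor"
assert d.tensor_state == {}               # each instance owns its own dict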