mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 18:13:46 +08:00
Feat: add to_json (#3743)
This commit is contained in:
@ -80,6 +80,19 @@ class ParallelismConfig:
|
||||
f"\tcp_handler={self.cp_handler})\n"
|
||||
)
|
||||
|
||||
def to_json(self):
|
||||
import copy
|
||||
|
||||
_non_serializable_fields = ["device_mesh"]
|
||||
|
||||
copy.deepcopy(
|
||||
{
|
||||
k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
|
||||
for k, v in self.__dict__.items()
|
||||
if k not in _non_serializable_fields
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def dp_dim_names(self):
|
||||
"""Names of enabled dimensions across which data parallelism is applied."""
|
||||
|
Reference in New Issue
Block a user