mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
81 lines
2.5 KiB
Python
Executable File
81 lines
2.5 KiB
Python
Executable File
"""
|
|
Copyright (c) Microsoft Corporation
|
|
Licensed under the MIT license.
|
|
"""
|
|
"""
|
|
Collection of DeepSpeed configuration utilities
|
|
"""
|
|
import json
|
|
import collections
|
|
|
|
|
|
# adapted from https://stackoverflow.com/a/50701137/9201239
|
|
class ScientificNotationEncoder(json.JSONEncoder):
    """
    JSON encoder that prints large numbers in scientific notation.

    Everything is encoded as usual, except that ints and floats greater than
    1e3 are formatted with ``e`` notation (e.g. ``20000`` -> ``2.000000e+04``).

    Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it.
    """

    def iterencode(self, o, _one_shot=False, level=0):
        """Encode ``o`` and return the resulting JSON text as one string.

        Args:
            o: The object to encode.
            _one_shot: Forwarded to ``json.JSONEncoder.iterencode`` for values
                that are not handled by one of the branches below.
            level: Nesting depth of the enclosing mappings; it determines how
                far the keys of a mapping are indented.

        Returns:
            str: The encoded text. Mappings are always rendered one key per
            line, using ``self.indent`` spaces per level (4 when ``indent`` is
            ``None``); sequences are rendered on a single line.
        """
        # The ``collections.Mapping`` / ``collections.Sequence`` aliases were
        # removed in Python 3.10; the ABCs live in ``collections.abc``. The
        # import is local so this class does not depend on the file's imports.
        from collections.abc import Mapping, Sequence

        indent = self.indent if self.indent is not None else 4
        prefix_close = " " * level * indent
        level += 1
        prefix = " " * level * indent

        if isinstance(o, bool):
            # bool is a subclass of int, so it must be tested before numbers.
            return "true" if o else "false"
        elif isinstance(o, (float, int)):
            if o > 1e3:
                return f"{o:e}"
            return f"{o}"
        elif isinstance(o, Mapping):
            entries = [
                f'\n{prefix}"{k}": {self.iterencode(v, level=level)}'
                for k, v in o.items()
            ]
            return "{" + ", ".join(entries) + f"\n{prefix_close}" + "}"
        elif isinstance(o, Sequence) and not isinstance(o, str):
            # Elements are encoded with the default ``level`` of 0.
            return f"[{', '.join(map(self.iterencode, o))}]"

        # Strings, None and anything else go through the stock encoder.
        return "\n, ".join(super().iterencode(o, _one_shot))
|
|
|
|
|
|
class DeepSpeedConfigObject(object):
    """
    Base class whose instances print themselves as JSON.
    """

    def repr(self):
        """Return the instance's attribute dictionary itself (not a copy)."""
        return vars(self)

    def __repr__(self):
        """Return the instance attributes as JSON text.

        The attributes are serialized with ``ScientificNotationEncoder``,
        requesting sorted keys and an indent of 4.
        """
        attributes = vars(self)
        return json.dumps(
            attributes,
            cls=ScientificNotationEncoder,
            indent=4,
            sort_keys=True,
        )
|
|
|
|
|
|
def get_scalar_param(param_dict, param_name, param_default_value):
    """Look up ``param_name`` in ``param_dict``.

    Returns the stored value when the key is present, otherwise
    ``param_default_value``. No type checking is done on the value.
    """
    value = param_dict.get(param_name, param_default_value)
    return value
|
|
|
|
|
|
def get_list_param(param_dict, param_name, param_default_value):
    """Look up ``param_name`` in ``param_dict``.

    Returns the stored value when the key is present, otherwise
    ``param_default_value``. The value is not checked to be a list.
    """
    value = param_dict.get(param_name, param_default_value)
    return value
|
|
|
|
|
|
def get_dict_param(param_dict, param_name, param_default_value):
    """Look up ``param_name`` in ``param_dict``.

    Returns the stored value when the key is present, otherwise
    ``param_default_value``. The value is not checked to be a dict.
    """
    value = param_dict.get(param_name, param_default_value)
    return value
|
|
|
|
|
|
def dict_raise_error_on_duplicate_keys(ordered_pairs):
    """Build a dict from ``(key, value)`` pairs, rejecting duplicate keys.

    Args:
        ordered_pairs: A sized collection of ``(key, value)`` pairs.

    Returns:
        dict: The mapping built from the pairs.

    Raises:
        ValueError: If a key occurs more than once; the message lists every
            duplicated key.
    """
    result = {key: value for key, value in ordered_pairs}

    # A duplicate key collapses into one entry, so equal sizes mean no repeats.
    if len(result) == len(ordered_pairs):
        return result

    occurrences = collections.Counter(key for key, _ in ordered_pairs)
    keys = [key for key, count in occurrences.items() if count > 1]
    raise ValueError("Duplicate keys in DeepSpeed config: {}".format(keys))
|