mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
[logging] less startup noise (#7526)
This PR removes some and enables removing other startup noise - especially when it's replicated rank-times and doesn't carry any informative payload. 1. add `--log_level` flag which sets the launcher's logger to a desired setting - defaulting to `logging.INFO` for now for BC, but will change to `logging.WARNING` in v1 2. add `--quiet/-q` flag which sets the launcher's logger to `logging.ERROR` which essentially disables startup info messages 3. change the logging defaults elsewhere to `logging.WARNING` (main impact is the accelerator.py), once deepspeed started the frameworks control its loglevel for each rank, so the tricky part is this pre-start stage info logs. this part is breaking BC as there is no machinery to set the logger level for `real_accelerator.py`) 4. builder is changed to non-verbose (BC breaking) --------- Signed-off-by: Stas Bekman <stas@stason.org> Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
This commit is contained in:
@ -25,7 +25,7 @@ from argparse import ArgumentParser, REMAINDER
|
|||||||
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, CROSS_RANK, CROSS_SIZE
|
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, CROSS_RANK, CROSS_SIZE
|
||||||
from deepspeed.accelerator import get_accelerator
|
from deepspeed.accelerator import get_accelerator
|
||||||
from ..nebula.constants import DLTS_POD_ENV_PATH
|
from ..nebula.constants import DLTS_POD_ENV_PATH
|
||||||
from ..utils import logger, get_numactl_cmd
|
from ..utils import logger, get_numactl_cmd, set_log_level_from_string
|
||||||
from ..elasticity import is_torch_elastic_compatible
|
from ..elasticity import is_torch_elastic_compatible
|
||||||
from .constants import ELASTIC_TRAINING_ID_DEFAULT
|
from .constants import ELASTIC_TRAINING_ID_DEFAULT
|
||||||
|
|
||||||
@ -102,6 +102,18 @@ def parse_args():
|
|||||||
"numbers and range. i.e. 1,3-5,7 => [1,3,4,5,7]. When not "
|
"numbers and range. i.e. 1,3-5,7 => [1,3,4,5,7]. When not "
|
||||||
"specified, all cores on system would be used rank binding")
|
"specified, all cores on system would be used rank binding")
|
||||||
|
|
||||||
|
# TODOV1: change the default to 'warning'
|
||||||
|
parser.add_argument("--log_level",
|
||||||
|
type=str,
|
||||||
|
default="info",
|
||||||
|
choices=['debug', 'info', 'warning', 'error', 'critical'],
|
||||||
|
help="Set launcher loglevel. The default is 'info'")
|
||||||
|
|
||||||
|
parser.add_argument("-q",
|
||||||
|
"--quiet",
|
||||||
|
action="store_true",
|
||||||
|
help="Try to be as quiet as possible. Aliases to `--log_level error`")
|
||||||
|
|
||||||
# positional
|
# positional
|
||||||
parser.add_argument("training_script",
|
parser.add_argument("training_script",
|
||||||
type=str,
|
type=str,
|
||||||
@ -134,6 +146,10 @@ def main():
|
|||||||
args = parse_args()
|
args = parse_args()
|
||||||
current_env = os.environ.copy()
|
current_env = os.environ.copy()
|
||||||
|
|
||||||
|
if args.quiet:
|
||||||
|
args.log_level = "error"
|
||||||
|
set_log_level_from_string(args.log_level)
|
||||||
|
|
||||||
for k in current_env.keys():
|
for k in current_env.keys():
|
||||||
if "NCCL" in k:
|
if "NCCL" in k:
|
||||||
logger.info(f"{args.node_rank} {k}={current_env[k]}")
|
logger.info(f"{args.node_rank} {k}={current_env[k]}")
|
||||||
|
@ -28,7 +28,7 @@ from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRun
|
|||||||
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER, IMPI_LAUNCHER
|
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER, IMPI_LAUNCHER
|
||||||
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
|
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
|
||||||
from ..nebula.constants import NEBULA_EXPORT_ENVS
|
from ..nebula.constants import NEBULA_EXPORT_ENVS
|
||||||
from ..utils import logger
|
from ..utils import logger, set_log_level_from_string
|
||||||
|
|
||||||
from ..autotuning import Autotuner
|
from ..autotuning import Autotuner
|
||||||
from deepspeed.accelerator import get_accelerator
|
from deepspeed.accelerator import get_accelerator
|
||||||
@ -212,6 +212,18 @@ def parse_args(args=None):
|
|||||||
default=None,
|
default=None,
|
||||||
help="Python virtual environment activation script for job.")
|
help="Python virtual environment activation script for job.")
|
||||||
|
|
||||||
|
# TODOV1: change the default to 'warning'
|
||||||
|
parser.add_argument("--log_level",
|
||||||
|
type=str,
|
||||||
|
default="info",
|
||||||
|
choices=['debug', 'info', 'warning', 'error', 'critical'],
|
||||||
|
help="Set runner loglevel. The default is 'info'")
|
||||||
|
|
||||||
|
parser.add_argument("-q",
|
||||||
|
"--quiet",
|
||||||
|
action="store_true",
|
||||||
|
help="Try to be as quiet as possible. Aliases to `--log_level error`")
|
||||||
|
|
||||||
return parser.parse_args(args=args)
|
return parser.parse_args(args=args)
|
||||||
|
|
||||||
|
|
||||||
@ -424,6 +436,10 @@ def parse_num_nodes(str_num_nodes: str, elastic_training: bool):
|
|||||||
def main(args=None):
|
def main(args=None):
|
||||||
args = parse_args(args)
|
args = parse_args(args)
|
||||||
|
|
||||||
|
if args.quiet:
|
||||||
|
args.log_level = "error"
|
||||||
|
set_log_level_from_string(args.log_level)
|
||||||
|
|
||||||
if args.elastic_training:
|
if args.elastic_training:
|
||||||
assert args.master_addr != "", "Master Addr is required when elastic training is enabled"
|
assert args.master_addr != "", "Master Addr is required when elastic training is enabled"
|
||||||
|
|
||||||
@ -553,6 +569,10 @@ def main(args=None):
|
|||||||
deepspeed_launch.append("--bind_cores_to_rank")
|
deepspeed_launch.append("--bind_cores_to_rank")
|
||||||
if args.bind_core_list is not None:
|
if args.bind_core_list is not None:
|
||||||
deepspeed_launch.append(f"--bind_core_list={args.bind_core_list}")
|
deepspeed_launch.append(f"--bind_core_list={args.bind_core_list}")
|
||||||
|
if args.quiet:
|
||||||
|
deepspeed_launch.append("--quiet")
|
||||||
|
deepspeed_launch.append(f"--log_level={args.log_level}")
|
||||||
|
|
||||||
cmd = deepspeed_launch + [args.user_script] + args.user_args
|
cmd = deepspeed_launch + [args.user_script] + args.user_args
|
||||||
else:
|
else:
|
||||||
args.launcher = args.launcher.lower()
|
args.launcher = args.launcher.lower()
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
# DeepSpeed Team
|
# DeepSpeed Team
|
||||||
|
|
||||||
from .logging import logger, log_dist
|
from .logging import logger, log_dist, set_log_level_from_string
|
||||||
from .comms_logging import get_caller_func
|
from .comms_logging import get_caller_func
|
||||||
#from .distributed import init_distributed
|
#from .distributed import init_distributed
|
||||||
from .init_on_device import OnDevice
|
from .init_on_device import OnDevice
|
||||||
|
@ -22,7 +22,7 @@ log_levels = {
|
|||||||
class LoggerFactory:
|
class LoggerFactory:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_logger(name=None, level=logging.INFO):
|
def create_logger(name=None, level=logging.WARNING):
|
||||||
"""create a logger
|
"""create a logger
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -58,7 +58,7 @@ class LoggerFactory:
|
|||||||
return logger_
|
return logger_
|
||||||
|
|
||||||
|
|
||||||
logger = LoggerFactory.create_logger(name="DeepSpeed", level=logging.INFO)
|
logger = LoggerFactory.create_logger(name="DeepSpeed", level=logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
@ -134,6 +134,29 @@ def print_json_dist(message, ranks=None, path=None):
|
|||||||
os.fsync(outfile)
|
os.fsync(outfile)
|
||||||
|
|
||||||
|
|
||||||
|
def get_log_level_from_string(log_level_str):
|
||||||
|
"""converts a log level string into its numerical equivalent. e.g. "info" => `logging.INFO`
|
||||||
|
"""
|
||||||
|
log_level_str = log_level_str.lower()
|
||||||
|
if log_level_str not in log_levels:
|
||||||
|
raise ValueError(
|
||||||
|
f"{log_level_str} is not one of the valid logging levels. Valid log levels are {log_levels.keys()}.")
|
||||||
|
return log_levels[log_level_str]
|
||||||
|
|
||||||
|
|
||||||
|
def set_log_level_from_string(log_level_str, custom_logger=None):
|
||||||
|
"""Sets a log level in the passed `logger` from string. e.g. "info" => `logging.INFO`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_level_str: one of 'debug', 'info', 'warning', 'error', 'critical'
|
||||||
|
custom_logger: if `None` will use the default `logger` object
|
||||||
|
"""
|
||||||
|
log_level = get_log_level_from_string(log_level_str)
|
||||||
|
if custom_logger is None:
|
||||||
|
custom_logger = logger
|
||||||
|
custom_logger.setLevel(log_level)
|
||||||
|
|
||||||
|
|
||||||
def get_current_level():
|
def get_current_level():
|
||||||
"""
|
"""
|
||||||
Return logger's current log level
|
Return logger's current log level
|
||||||
@ -156,8 +179,5 @@ def should_log_le(max_log_level_str):
|
|||||||
if not isinstance(max_log_level_str, str):
|
if not isinstance(max_log_level_str, str):
|
||||||
raise ValueError(f"{max_log_level_str} is not a string")
|
raise ValueError(f"{max_log_level_str} is not a string")
|
||||||
|
|
||||||
max_log_level_str = max_log_level_str.lower()
|
max_log_level = get_log_level_from_string(max_log_level_str)
|
||||||
if max_log_level_str not in log_levels:
|
return get_current_level() <= max_log_level
|
||||||
raise ValueError(f"{max_log_level_str} is not one of the logging levels")
|
|
||||||
|
|
||||||
return get_current_level() <= log_levels[max_log_level_str]
|
|
||||||
|
@ -522,7 +522,7 @@ class OpBuilder(ABC):
|
|||||||
extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())},
|
extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())},
|
||||||
extra_link_args=self.strip_empty_entries(self.extra_ldflags()))
|
extra_link_args=self.strip_empty_entries(self.extra_ldflags()))
|
||||||
|
|
||||||
def load(self, verbose=True):
|
def load(self, verbose=False):
|
||||||
if self.name in __class__._loaded_ops:
|
if self.name in __class__._loaded_ops:
|
||||||
return __class__._loaded_ops[self.name]
|
return __class__._loaded_ops[self.name]
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user