mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Handle the case where DeepCompile is enabled but not activated (#7603)
This PR improves state management for DeepCompile in the engine. Previously, the system relied only on the config flag indicating whether DeepCompile was enabled. However, DeepCompile is actually activated only when `compile()` is called. This meant that if DeepCompile was enabled in the config but `compile()` was never called, the engine could end up in an invalid internal state (as shown in #7598).

Since `enabled == True` should be interpreted as an option that modifies the behavior of `compile()`, this PR introduces clearer state management:

- If `compile()` is not called, the DeepCompile config has no effect on behavior. A one-time message is shown instead.
- A new state, "DeepCompile activated", is introduced. This represents the condition where DeepCompile is both enabled in the config and `compile()` has been called.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
This commit is contained in:
@ -74,7 +74,7 @@ from deepspeed.runtime.sparse_tensor import SparseTensor
|
||||
|
||||
from deepspeed.runtime import lr_schedules
|
||||
from deepspeed.utils import groups
|
||||
from deepspeed.utils import logger, log_dist, instrument_w_nvtx
|
||||
from deepspeed.utils import logger, log_dist, log_dist_once, instrument_w_nvtx
|
||||
from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
|
||||
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
|
||||
STEP_MICRO_TIMER, \
|
||||
@ -287,10 +287,12 @@ class DeepSpeedEngine(Module):
|
||||
|
||||
self.pipeline_parallelism = isinstance(model, PipelineModule)
|
||||
|
||||
self._deepcompile_active = False
|
||||
|
||||
# Configure distributed model
|
||||
self._configure_distributed_model(model)
|
||||
|
||||
if not self.is_deepcompile_enabled():
|
||||
# These hooks are removed later if DeepCompile becomes active (see _set_deepcompile_active).
|
||||
self.module_forward_pre_hook = self._create_module_forward_pre_hook()
|
||||
self.module_forward_post_hook = self._create_module_forward_post_hook()
|
||||
|
||||
@ -521,7 +523,7 @@ class DeepSpeedEngine(Module):
|
||||
def destroy(self):
|
||||
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
|
||||
self.optimizer.destroy()
|
||||
if self.is_deepcompile_enabled():
|
||||
if self.is_deepcompile_active():
|
||||
get_deepcompile_handle().cleanup()
|
||||
debug_clear_module_and_param_names()
|
||||
|
||||
@ -2155,7 +2157,12 @@ class DeepSpeedEngine(Module):
|
||||
if self.autotuning_profile_model_info():
|
||||
ma = get_ma_status()
|
||||
|
||||
if self.is_deepcompile_enabled() and hasattr(self, "launch_compile_passes"):
|
||||
if self.is_deepcompile_enabled() and not self.is_deepcompile_active() and not self.is_compiled:
|
||||
log_dist_once(
|
||||
"DeepCompile is enabled but engine.compile() has not been called; executing without DeepCompile until compile() runs.",
|
||||
ranks=[0])
|
||||
|
||||
if self.is_deepcompile_active() and hasattr(self, "launch_compile_passes"):
|
||||
# We can't have this in forward prologue as the compiler compiles hooks including the forward prologue.
|
||||
self.launch_compile_passes(self.global_steps)
|
||||
|
||||
@ -2213,7 +2220,7 @@ class DeepSpeedEngine(Module):
|
||||
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
|
||||
# Skip gradient reduction when DeepCompile is active
|
||||
# DeepCompile handles its own gradient reduction through compiled graph operations
|
||||
if self.is_deepcompile_enabled():
|
||||
if self.is_deepcompile_active():
|
||||
return
|
||||
|
||||
# Pass (PP) gas boundary flag to optimizer (required for zero)
|
||||
@ -2239,7 +2246,7 @@ class DeepSpeedEngine(Module):
|
||||
scale_wrt_gas = self.scale_wrt_gas
|
||||
|
||||
# scale loss w.r.t. gradient accumulation if reduction is not disabled
|
||||
do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt and not self.is_deepcompile_enabled(
|
||||
do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt and not self.is_deepcompile_active(
|
||||
)
|
||||
if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
|
||||
loss = self._scale_loss_by_gas(loss.float())
|
||||
@ -2257,7 +2264,7 @@ class DeepSpeedEngine(Module):
|
||||
)]
|
||||
self.monitor.write_events(self.summary_events)
|
||||
|
||||
if self.is_deepcompile_enabled():
|
||||
if self.is_deepcompile_active():
|
||||
deepcompile_backward_prologue(self.is_gradient_accumulation_boundary())
|
||||
|
||||
if self.zenflow and self.auto_update:
|
||||
@ -4083,11 +4090,40 @@ class DeepSpeedEngine(Module):
|
||||
elif self.zero_optimization_stage() == ZeroStageEnum.weights:
|
||||
backend = init_z3(self, backend, compile_config, compile_kwargs, schedule)
|
||||
|
||||
# Hook state must align with whether DeepCompile is active.
|
||||
self._set_deepcompile_active(enable_deepcompile)
|
||||
|
||||
# create new dict to avoid modifying original dict
|
||||
try:
|
||||
self.module.compile(**{**compile_kwargs, 'backend': backend})
|
||||
except Exception:
|
||||
if enable_deepcompile:
|
||||
# Restore default hooks if compilation fails before completing.
|
||||
self._set_deepcompile_active(False)
|
||||
raise
|
||||
|
||||
self._is_compiled = True
|
||||
|
||||
def _set_deepcompile_active(self, active: bool) -> None:
    """Toggle DeepCompile runtime state and manage forward hooks accordingly.

    Args:
        active: ``True`` marks DeepCompile as active and detaches the engine's
            module forward pre/post hooks. ``False`` marks it inactive and
            re-creates whichever of those hooks is currently missing.
    """
    # Nothing to do when the requested state is already in effect.
    if self._deepcompile_active == active:
        return

    if active:
        # Detach each registered hook and clear its stored handle.
        if self.module_forward_pre_hook is not None:
            self.module_forward_pre_hook.remove()
            self.module_forward_pre_hook = None
        if self.module_forward_post_hook is not None:
            self.module_forward_post_hook.remove()
            self.module_forward_post_hook = None
    else:
        # Re-create only the hooks that are currently absent.
        if self.module_forward_pre_hook is None:
            self.module_forward_pre_hook = self._create_module_forward_pre_hook()
        if self.module_forward_post_hook is None:
            self.module_forward_post_hook = self._create_module_forward_post_hook()

    # Record the new state only after the hooks have been brought in line with it.
    self._deepcompile_active = active
|
||||
|
||||
def get_compile_time(self):
    """Return ``opt_pass_times`` from ``deepspeed.compile.backend``."""
    # Imported at call time rather than at module import.
    from deepspeed.compile.backend import opt_pass_times

    return opt_pass_times
|
||||
@ -4095,9 +4131,12 @@ class DeepSpeedEngine(Module):
|
||||
def register_compile_pass(self, pass_name: str, pass_fn: Callable) -> None:
    """Forward ``pass_name`` and ``pass_fn`` to the module-scope ``register_compile_pass``.

    Args:
        pass_name: Name under which the pass is registered.
        pass_fn: Callable implementing the pass.
    """
    register_compile_pass(pass_name, pass_fn)
|
||||
|
||||
def is_deepcompile_enabled(self) -> bool:
    """Return the ``deepcompile`` option from the engine's compile config.

    This reflects configuration only; ``is_deepcompile_active`` reports the
    runtime flag.
    """
    return self._config.compile_config.deepcompile
|
||||
|
||||
def is_deepcompile_active(self) -> bool:
    """Return the engine's ``_deepcompile_active`` flag.

    The flag is maintained by ``_set_deepcompile_active``.
    """
    return self._deepcompile_active
|
||||
|
||||
@property
def is_compiled(self) -> bool:
    """Return the engine's ``_is_compiled`` flag."""
    return self._is_compiled
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .logging import logger, log_dist, set_log_level_from_string
|
||||
from .logging import logger, log_dist, log_dist_once, set_log_level_from_string
|
||||
from .comms_logging import get_caller_func
|
||||
#from .distributed import init_distributed
|
||||
from .init_on_device import OnDevice
|
||||
|
@ -107,6 +107,22 @@ def log_dist(message, ranks=None, level=logging.INFO):
|
||||
logger.log(level, final_message)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
def _log_dist_once_cached(message, ranks_key, level):
    """Forward to ``log_dist``; repeated identical calls are absorbed by the cache.

    The unbounded ``lru_cache`` is the de-duplication mechanism: a later call
    with the same ``(message, ranks_key, level)`` hits the cache, so the body
    does not run again.

    Args:
        message: Text handed to ``log_dist``.
        ranks_key: Hashable stand-in for ``ranks`` (a tuple), or ``None``.
        level: Logging level handed to ``log_dist``.
    """
    # ``log_dist`` receives a list again, rebuilt from the hashable key.
    ranks_arg = None if ranks_key is None else list(ranks_key)
    log_dist(message, ranks=ranks_arg, level=level)
|
||||
|
||||
|
||||
def log_dist_once(message, ranks=None, level=logging.INFO):
    """Log like ``log_dist``, but emit each unique call only once per process.

    Uniqueness is judged on the combination of ``message``, ``ranks`` and
    ``level``.

    Args:
        message: Text to log.
        ranks: Optional iterable of ranks forwarded to ``log_dist``; ``None``
            is passed through unchanged.
        level: Logging level; defaults to ``logging.INFO``.
    """
    # ``ranks`` is a list, which is unhashable, so convert it to a tuple for caching.
    ranks_key = None if ranks is None else tuple(ranks)
    _log_dist_once_cached(message, ranks_key, level)
|
||||
|
||||
|
||||
logger.log_dist_once = log_dist_once
|
||||
|
||||
|
||||
def print_json_dist(message, ranks=None, path=None):
|
||||
from deepspeed import comm as dist
|
||||
"""Print message when one of following condition meets
|
||||
|
Reference in New Issue
Block a user