Handle the case where DeepCompile is enabled but not activated (#7603)

This PR improves state management for DeepCompile in the engine.

Previously, the system relied only on the config flag indicating whether
DeepCompile was enabled. However, DeepCompile is actually activated only
when `compile()` is called. This meant that if DeepCompile was enabled
in the config but `compile()` was never called, it could lead to invalid
internal states (as shown in #7598).
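
For illustration, a minimal sketch of the problematic pattern, assuming the usual JSON layout of the compile section (the exact config keys here are an assumption, not taken from this PR):

```python
# Hypothetical sketch: DeepCompile enabled in the config, but compile() never called.
import torch
import deepspeed

model = torch.nn.Linear(8, 8)

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 3},
    "compile": {"deepcompile": True},  # assumed key backing compile_config.deepcompile
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# engine.compile() is never called here. Before this PR the engine could reach
# an invalid internal state in this situation (see #7598); with this PR the
# config flag alone changes nothing and a one-time notice is logged instead.
```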

Since `enabled == True` should be interpreted as an option that modifies
the behavior of `compile()`, this PR introduces clearer state
management:
- If `compile()` is not called, the DeepCompile config has no effect on
behavior; a one-time message is logged instead.
- A new state, "DeepCompile activated", is introduced. It represents the
condition where DeepCompile is both enabled in the config and `compile()`
has been called (see the sketch below).
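
A minimal sketch of the resulting state model, continuing the setup above (the query methods come from the diff below; calling `compile()` with its defaults is assumed to be sufficient here):

```python
# Enabled vs. activated, as exposed by the engine after this PR.
assert engine.is_deepcompile_enabled()      # config flag is set
assert not engine.is_deepcompile_active()   # but compile() has not been called yet

engine.compile()                            # compiles the module and installs DeepCompile passes

assert engine.is_compiled
assert engine.is_deepcompile_active()       # enabled in the config AND compile() called
```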

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Masahiro Tanaka
2025-09-30 17:21:55 -07:00
committed by GitHub
parent 177c25c9d7
commit e32e817306
3 changed files with 67 additions and 12 deletions


@@ -74,7 +74,7 @@ from deepspeed.runtime.sparse_tensor import SparseTensor
from deepspeed.runtime import lr_schedules
from deepspeed.utils import groups
from deepspeed.utils import logger, log_dist, instrument_w_nvtx
from deepspeed.utils import logger, log_dist, log_dist_once, instrument_w_nvtx
from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
STEP_MICRO_TIMER, \
@@ -287,10 +287,12 @@ class DeepSpeedEngine(Module):
self.pipeline_parallelism = isinstance(model, PipelineModule)
self._deepcompile_active = False
# Configure distributed model
self._configure_distributed_model(model)
if not self.is_deepcompile_enabled():
# These hooks should be disabled later if DeepCompile is not active.
self.module_forward_pre_hook = self._create_module_forward_pre_hook()
self.module_forward_post_hook = self._create_module_forward_post_hook()
@@ -521,7 +523,7 @@ class DeepSpeedEngine(Module):
def destroy(self):
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
self.optimizer.destroy()
if self.is_deepcompile_enabled():
if self.is_deepcompile_active():
get_deepcompile_handle().cleanup()
debug_clear_module_and_param_names()
@@ -2155,7 +2157,12 @@ class DeepSpeedEngine(Module):
if self.autotuning_profile_model_info():
ma = get_ma_status()
if self.is_deepcompile_enabled() and hasattr(self, "launch_compile_passes"):
if self.is_deepcompile_enabled() and not self.is_deepcompile_active() and not self.is_compiled:
log_dist_once(
"DeepCompile is enabled but engine.compile() has not been called; executing without DeepCompile until compile() runs.",
ranks=[0])
if self.is_deepcompile_active() and hasattr(self, "launch_compile_passes"):
# We can't have this in forward prologue as the compiler compiles hooks including the forward prologue.
self.launch_compile_passes(self.global_steps)
@@ -2213,7 +2220,7 @@ class DeepSpeedEngine(Module):
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
# Skip gradient reduction when DeepCompile is enabled
# DeepCompile handles its own gradient reduction through compiled graph operations
if self.is_deepcompile_enabled():
if self.is_deepcompile_active():
return
# Pass (PP) gas boundary flag to optimizer (required for zero)
@@ -2239,7 +2246,7 @@ class DeepSpeedEngine(Module):
scale_wrt_gas = self.scale_wrt_gas
# scale loss w.r.t. gradient accumulation if reduction is not disabled
do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt and not self.is_deepcompile_enabled(
do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt and not self.is_deepcompile_active(
)
if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
loss = self._scale_loss_by_gas(loss.float())
@@ -2257,7 +2264,7 @@ class DeepSpeedEngine(Module):
)]
self.monitor.write_events(self.summary_events)
if self.is_deepcompile_enabled():
if self.is_deepcompile_active():
deepcompile_backward_prologue(self.is_gradient_accumulation_boundary())
if self.zenflow and self.auto_update:
@@ -4083,11 +4090,40 @@ class DeepSpeedEngine(Module):
elif self.zero_optimization_stage() == ZeroStageEnum.weights:
backend = init_z3(self, backend, compile_config, compile_kwargs, schedule)
# Hook state must align with whether DeepCompile is active.
self._set_deepcompile_active(enable_deepcompile)
# create new dict to avoid modifying original dict
try:
self.module.compile(**{**compile_kwargs, 'backend': backend})
except Exception:
if enable_deepcompile:
# Restore default hooks if compilation fails before completing.
self._set_deepcompile_active(False)
raise
self._is_compiled = True
def _set_deepcompile_active(self, active: bool) -> None:
"""Toggle DeepCompile runtime state and manage forward hooks accordingly."""
if self._deepcompile_active == active:
return
if active:
if self.module_forward_pre_hook is not None:
self.module_forward_pre_hook.remove()
self.module_forward_pre_hook = None
if self.module_forward_post_hook is not None:
self.module_forward_post_hook.remove()
self.module_forward_post_hook = None
else:
if self.module_forward_pre_hook is None:
self.module_forward_pre_hook = self._create_module_forward_pre_hook()
if self.module_forward_post_hook is None:
self.module_forward_post_hook = self._create_module_forward_post_hook()
self._deepcompile_active = active
def get_compile_time(self):
from deepspeed.compile.backend import opt_pass_times
return opt_pass_times
@@ -4095,9 +4131,12 @@ class DeepSpeedEngine(Module):
def register_compile_pass(self, pass_name: str, pass_fn: Callable) -> None:
register_compile_pass(pass_name, pass_fn)
def is_deepcompile_enabled(self):
def is_deepcompile_enabled(self) -> bool:
return self._config.compile_config.deepcompile
def is_deepcompile_active(self) -> bool:
return self._deepcompile_active
@property
def is_compiled(self) -> bool:
return self._is_compiled


@@ -3,7 +3,7 @@
# DeepSpeed Team
from .logging import logger, log_dist, set_log_level_from_string
from .logging import logger, log_dist, log_dist_once, set_log_level_from_string
from .comms_logging import get_caller_func
#from .distributed import init_distributed
from .init_on_device import OnDevice


@@ -107,6 +107,22 @@ def log_dist(message, ranks=None, level=logging.INFO):
logger.log(level, final_message)
@functools.lru_cache(None)
def _log_dist_once_cached(message, ranks_key, level):
ranks_arg = list(ranks_key) if ranks_key is not None else None
log_dist(message, ranks=ranks_arg, level=level)
def log_dist_once(message, ranks=None, level=logging.INFO):
# Identical to `log_dist`, but will emit each unique message only once per process.
# ranks is a list which is unhashable, so convert to tuple for caching
ranks_key = tuple(ranks) if ranks is not None else None
_log_dist_once_cached(message, ranks_key, level)
logger.log_dist_once = log_dist_once
def print_json_dist(message, ranks=None, path=None):
from deepspeed import comm as dist
"""Print message when one of following condition meets