Control trace cache warnings (#7039)

Make trace cache warnings configurable, and disable them by default.

Fix #6985, #4081, #5033, #5006, #5662
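
For context, the new flag lives under `zero_optimization` in the DeepSpeed config. A minimal sketch of opting back in to the warnings (the model, batch size, optimizer, and precision fields below are illustrative placeholders, not part of this change):

```python
import torch
import deepspeed

# Illustrative stand-in model; any torch.nn.Module works.
model = torch.nn.Linear(8, 8)

# Only "log_trace_cache_warnings" comes from this change; the rest is an
# ordinary placeholder setup for a ZeRO stage 3 run.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "log_trace_cache_warnings": True,  # default is now False (silent)
    },
}

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```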

---------

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Author: Olatunji Ruwase
Date: 2025-02-18 14:16:17 -05:00
Committed by: GitHub
Parent: 14b3cce4aa
Commit: ee3f19bf66

6 changed files with 54 additions and 21 deletions

deepspeed/runtime/engine.py

@@ -983,6 +983,9 @@ class DeepSpeedEngine(Module):
     def zeropp_loco_param(self):
         return self._config.zero_config.zeropp_loco_param
 
+    def zero_log_trace_cache_warnings(self):
+        return self._config.zero_config.log_trace_cache_warnings
+
     def dump_state(self):
         return self._config.dump_state
@@ -1692,6 +1695,7 @@ class DeepSpeedEngine(Module):
                 zero_quantized_weights=self.zero_quantized_weights(),
                 zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
                 zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
+                log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
             )
         else:
             log_dist(
@@ -1740,6 +1744,7 @@ class DeepSpeedEngine(Module):
                 zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
                 zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
                 zeropp_loco_param=self.zeropp_loco_param(),
+                log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
             )
         else:

deepspeed/runtime/zero/config.py

@@ -45,6 +45,7 @@ ZeRO optimization should be enabled as:
     "memory_efficient_linear": [true|false],
     "override_module_apply": [true|false],
     "zeropp_loco_param": {...},
+    "log_trace_cache_warnings" : [true|false],
     }
 }
 """
@@ -340,6 +341,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
     Override nn.Module apply function, for Stage 3.
     """
 
+    log_trace_cache_warnings: bool = False
+    """
+    Whether to log warnings from trace cache, such as invalidation events.
+    """
+
     # Validators
     @model_validator(mode="after")
     def overlap_comm_valid(self):

deepspeed/runtime/zero/parameter_offload.py

@@ -103,6 +103,7 @@ class DeepSpeedZeRoOffload(object):
         zero_quantized_weights=False,
         zero_quantized_nontrainable_weights=False,
         zero_module_granularity_threshold=0,
+        log_trace_cache_warnings=False,
     ):
 
         see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)
@@ -118,6 +119,7 @@ class DeepSpeedZeRoOffload(object):
         self.zero_param_parallel_group = zero_param_parallel_group
         self.zero_quantized_weights = zero_quantized_weights
         self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights
+        self.log_trace_cache_warnings = log_trace_cache_warnings
 
         if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
             self.offload_device = offload_param_config.device
@@ -165,7 +167,9 @@ class DeepSpeedZeRoOffload(object):
             timers=self.timers,
             zero_quantized_weights=self.zero_quantized_weights,
             zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
-            fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module)
+            fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module,
+            log_trace_cache_warnings=self.log_trace_cache_warnings,
+        )
 
         self.forward_hooks = []
         self.backward_hooks = []

deepspeed/runtime/zero/partitioned_param_coordinator.py

@@ -76,7 +76,8 @@ class PartitionedParameterCoordinator:
         param: Parameter
         step_id_last_used_at: int
 
-    def __init__(self,
+    def __init__(
+        self,
         prefetch_bucket_sz: int,
         max_reuse_distance_in_numel: int,
         max_available_parameters_in_numel: int,
@@ -86,7 +87,9 @@ class PartitionedParameterCoordinator:
         timers=None,
         zero_quantized_weights=False,
         zero_quantized_nontrainable_weights=False,
-        fast_sharding_for_leaf_module=False) -> None:
+        fast_sharding_for_leaf_module=False,
+        log_trace_cache_warnings=False,
+    ) -> None:
         # mapping of param -> handle for each param that is currently in flight
         self.__inflight_param_registry = inflight_param_registry
         # keeps track of the number of submodules invoked so far.
@@ -129,6 +132,9 @@ class PartitionedParameterCoordinator:
         self.__max_ongoing_fetch_events: int = 2
         self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None)
 
+        # Whether to log trace cache warnings, e.g. invalidation events
+        self.__log_trace_cache_warnings = log_trace_cache_warnings
+
         # whether to enable fast fetch for the z3 leaf module.
         # this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure.
         self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module
@@ -177,7 +183,7 @@ class PartitionedParameterCoordinator:
             print_rank_0(
                 f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: "
                 f"cache has only {len(self.__submodule_order)} modules",
-                force=True)
+                force=self.__log_trace_cache_warnings)
             self._invalidate_trace()
             return
@@ -186,7 +192,7 @@ class PartitionedParameterCoordinator:
             print_rank_0(
                 f"Invalidate trace cache @ step {self.__step_id}: "
                 f"expected module {expected_module_id}, but got module {sub_module.ds_id}",
-                force=True)
+                force=self.__log_trace_cache_warnings)
             self._invalidate_trace()
 
     @compiler.disable
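
The net effect of the two hunks above is that trace cache invalidation messages now reach `print_rank_0` with `force=self.__log_trace_cache_warnings` instead of `force=True`, so they stay silent unless the user opts in. A simplified, hypothetical stand-in for the gating, reduced to the `force` check (DeepSpeed's real helper also restricts output to rank 0):

```python
# Hypothetical stand-in for the rank-0 logging helper, kept only to show
# why force=self.__log_trace_cache_warnings suppresses the message by default.
def print_rank_0(message: str, force: bool = False) -> None:
    if force:
        print(message)

print_rank_0("Invalidate trace cache @ step 3: ...", force=False)  # silent (new default)
print_rank_0("Invalidate trace cache @ step 3: ...", force=True)   # printed (opt-in)
```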

deepspeed/runtime/zero/stage3.py

@@ -160,6 +160,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         zero_quantized_nontrainable_weights=False,
         zero_module_granularity_threshold=0,
         zeropp_loco_param=None,
+        log_trace_cache_warnings=False,
     ):
 
         see_memory_usage("Stage 3 initialize beginning", force=True)
@@ -231,7 +232,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             zero_param_parallel_group=zero_param_parallel_group,
             zero_quantized_weights=zero_quantized_weights,
             zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
-            zero_module_granularity_threshold=zero_module_granularity_threshold)
+            zero_module_granularity_threshold=zero_module_granularity_threshold,
+            log_trace_cache_warnings=log_trace_cache_warnings,
+        )
         self.persistent_parameters = self.parameter_offload.persistent_parameters
 
         self._configure_offloading(offload_optimizer_config, offload_param_config)
@@ -465,6 +468,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         zero_quantized_weights,
         zero_quantized_nontrainable_weights,
         zero_module_granularity_threshold,
+        log_trace_cache_warnings,
     ):
         return DeepSpeedZeRoOffload(module=module,
                                     timers=timers,
@@ -481,7 +485,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
                                     zero_param_parallel_group=zero_param_parallel_group,
                                     zero_quantized_weights=zero_quantized_weights,
                                     zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
-                                    zero_module_granularity_threshold=zero_module_granularity_threshold)
+                                    zero_module_granularity_threshold=zero_module_granularity_threshold,
+                                    log_trace_cache_warnings=log_trace_cache_warnings)
 
     def _get_trainable_parameter_groups(self):
         param_groups = []

docs/_pages/config-json.md

@@ -371,11 +371,12 @@ Enabling and configuring ZeRO memory optimizations
     "sub_group_size" : 1e12,
     "elastic_checkpoint" : [true|false],
     "stage3_gather_16bit_weights_on_model_save": [true|false],
-    "ignore_unused_parameters": [true|false]
-    "round_robin_gradients": [true|false]
-    "zero_hpz_partition_size": 1
-    "zero_quantized_weights": [true|false]
-    "zero_quantized_gradients": [true|false]
+    "ignore_unused_parameters": [true|false],
+    "round_robin_gradients": [true|false],
+    "zero_hpz_partition_size": 1,
+    "zero_quantized_weights": [true|false],
+    "zero_quantized_gradients": [true|false],
+    "log_trace_cache_warnings": [true|false],
     }
 ```
@@ -512,6 +513,12 @@ Enabling and configuring ZeRO memory optimizations
 | ----------------------------------------------------------------------------------------------------------------------------------- | ------- |
 |Boolean indicating whether to enable communication efficient quantized gradients of ZeRO++. | `False` |
 
+<i>**log_trace_cache_warnings**</i>: [boolean]
+
+| Description | Default |
+| ------------------------------------------------------------------------------------------------------------------- | ------- |
+| Log warnings from trace cache optimization of parameter sharding, such as cache invalidation events. | `False` |
+
 ***cpu_offload***: [boolean]
 
 **Deprecated:** **cpu_offload** is deprecated and will be removed in future, please use `offload_optimizer` instead.