mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Control trace cache warnings (#7039)
Make trace cache warnings configurable, and disabled by default. Fix #6985, #4081, #5033, #5006, #5662 --------- Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
This commit is contained in:
@ -983,6 +983,9 @@ class DeepSpeedEngine(Module):
|
||||
def zeropp_loco_param(self):
|
||||
return self._config.zero_config.zeropp_loco_param
|
||||
|
||||
def zero_log_trace_cache_warnings(self):
|
||||
return self._config.zero_config.log_trace_cache_warnings
|
||||
|
||||
def dump_state(self):
|
||||
return self._config.dump_state
|
||||
|
||||
@ -1692,6 +1695,7 @@ class DeepSpeedEngine(Module):
|
||||
zero_quantized_weights=self.zero_quantized_weights(),
|
||||
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
|
||||
zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
|
||||
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
|
||||
)
|
||||
else:
|
||||
log_dist(
|
||||
@ -1740,6 +1744,7 @@ class DeepSpeedEngine(Module):
|
||||
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
|
||||
zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
|
||||
zeropp_loco_param=self.zeropp_loco_param(),
|
||||
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
|
||||
)
|
||||
|
||||
else:
|
||||
|
@ -45,6 +45,7 @@ ZeRO optimization should be enabled as:
|
||||
"memory_efficient_linear": [true|false],
|
||||
"override_module_apply": [true|false],
|
||||
"zeropp_loco_param": {...},
|
||||
"log_trace_cache_warnings" : [true|false],
|
||||
}
|
||||
}
|
||||
"""
|
||||
@ -340,6 +341,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
|
||||
Override nn.Module apply function, for Stage 3.
|
||||
"""
|
||||
|
||||
log_trace_cache_warnings: bool = False
|
||||
"""
|
||||
Whether to log warnings from trace cache, such as invalidation events.
|
||||
"""
|
||||
|
||||
# Validators
|
||||
@model_validator(mode="after")
|
||||
def overlap_comm_valid(self):
|
||||
|
@ -103,6 +103,7 @@ class DeepSpeedZeRoOffload(object):
|
||||
zero_quantized_weights=False,
|
||||
zero_quantized_nontrainable_weights=False,
|
||||
zero_module_granularity_threshold=0,
|
||||
log_trace_cache_warnings=False,
|
||||
):
|
||||
|
||||
see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)
|
||||
@ -118,6 +119,7 @@ class DeepSpeedZeRoOffload(object):
|
||||
self.zero_param_parallel_group = zero_param_parallel_group
|
||||
self.zero_quantized_weights = zero_quantized_weights
|
||||
self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights
|
||||
self.log_trace_cache_warnings = log_trace_cache_warnings
|
||||
|
||||
if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
|
||||
self.offload_device = offload_param_config.device
|
||||
@ -165,7 +167,9 @@ class DeepSpeedZeRoOffload(object):
|
||||
timers=self.timers,
|
||||
zero_quantized_weights=self.zero_quantized_weights,
|
||||
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
|
||||
fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module)
|
||||
fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module,
|
||||
log_trace_cache_warnings=self.log_trace_cache_warnings,
|
||||
)
|
||||
|
||||
self.forward_hooks = []
|
||||
self.backward_hooks = []
|
||||
|
@ -76,7 +76,8 @@ class PartitionedParameterCoordinator:
|
||||
param: Parameter
|
||||
step_id_last_used_at: int
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
prefetch_bucket_sz: int,
|
||||
max_reuse_distance_in_numel: int,
|
||||
max_available_parameters_in_numel: int,
|
||||
@ -86,7 +87,9 @@ class PartitionedParameterCoordinator:
|
||||
timers=None,
|
||||
zero_quantized_weights=False,
|
||||
zero_quantized_nontrainable_weights=False,
|
||||
fast_sharding_for_leaf_module=False) -> None:
|
||||
fast_sharding_for_leaf_module=False,
|
||||
log_trace_cache_warnings=False,
|
||||
) -> None:
|
||||
# mapping of param -> handle for each param that is currently in flight
|
||||
self.__inflight_param_registry = inflight_param_registry
|
||||
# keeps track of the number of submodules invoked so far.
|
||||
@ -129,6 +132,9 @@ class PartitionedParameterCoordinator:
|
||||
self.__max_ongoing_fetch_events: int = 2
|
||||
self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None)
|
||||
|
||||
# Whether to log trace cache warnings, e.g. invalidation events
|
||||
self.__log_trace_cache_warnings = log_trace_cache_warnings
|
||||
|
||||
# whether to enable fast fetch for the z3 leaf module.
|
||||
# this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure.
|
||||
self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module
|
||||
@ -177,7 +183,7 @@ class PartitionedParameterCoordinator:
|
||||
print_rank_0(
|
||||
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: "
|
||||
f"cache has only {len(self.__submodule_order)} modules",
|
||||
force=True)
|
||||
force=self.__log_trace_cache_warnings)
|
||||
self._invalidate_trace()
|
||||
return
|
||||
|
||||
@ -186,7 +192,7 @@ class PartitionedParameterCoordinator:
|
||||
print_rank_0(
|
||||
f"Invalidate trace cache @ step {self.__step_id}: "
|
||||
f"expected module {expected_module_id}, but got module {sub_module.ds_id}",
|
||||
force=True)
|
||||
force=self.__log_trace_cache_warnings)
|
||||
self._invalidate_trace()
|
||||
|
||||
@compiler.disable
|
||||
|
@ -160,6 +160,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
|
||||
zero_quantized_nontrainable_weights=False,
|
||||
zero_module_granularity_threshold=0,
|
||||
zeropp_loco_param=None,
|
||||
log_trace_cache_warnings=False,
|
||||
):
|
||||
see_memory_usage("Stage 3 initialize beginning", force=True)
|
||||
|
||||
@ -231,7 +232,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
|
||||
zero_param_parallel_group=zero_param_parallel_group,
|
||||
zero_quantized_weights=zero_quantized_weights,
|
||||
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
|
||||
zero_module_granularity_threshold=zero_module_granularity_threshold)
|
||||
zero_module_granularity_threshold=zero_module_granularity_threshold,
|
||||
log_trace_cache_warnings=log_trace_cache_warnings,
|
||||
)
|
||||
|
||||
self.persistent_parameters = self.parameter_offload.persistent_parameters
|
||||
self._configure_offloading(offload_optimizer_config, offload_param_config)
|
||||
@ -465,6 +468,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
|
||||
zero_quantized_weights,
|
||||
zero_quantized_nontrainable_weights,
|
||||
zero_module_granularity_threshold,
|
||||
log_trace_cache_warnings,
|
||||
):
|
||||
return DeepSpeedZeRoOffload(module=module,
|
||||
timers=timers,
|
||||
@ -481,7 +485,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
|
||||
zero_param_parallel_group=zero_param_parallel_group,
|
||||
zero_quantized_weights=zero_quantized_weights,
|
||||
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
|
||||
zero_module_granularity_threshold=zero_module_granularity_threshold)
|
||||
zero_module_granularity_threshold=zero_module_granularity_threshold,
|
||||
log_trace_cache_warnings=log_trace_cache_warnings)
|
||||
|
||||
def _get_trainable_parameter_groups(self):
|
||||
param_groups = []
|
||||
|
@ -371,11 +371,12 @@ Enabling and configuring ZeRO memory optimizations
|
||||
"sub_group_size" : 1e12,
|
||||
"elastic_checkpoint" : [true|false],
|
||||
"stage3_gather_16bit_weights_on_model_save": [true|false],
|
||||
"ignore_unused_parameters": [true|false]
|
||||
"round_robin_gradients": [true|false]
|
||||
"zero_hpz_partition_size": 1
|
||||
"zero_quantized_weights": [true|false]
|
||||
"zero_quantized_gradients": [true|false]
|
||||
"ignore_unused_parameters": [true|false],
|
||||
"round_robin_gradients": [true|false],
|
||||
"zero_hpz_partition_size": 1,
|
||||
"zero_quantized_weights": [true|false],
|
||||
"zero_quantized_gradients": [true|false],
|
||||
"log_trace_cache_warnings": [true|false],
|
||||
}
|
||||
```
|
||||
|
||||
@ -512,6 +513,12 @@ Enabling and configuring ZeRO memory optimizations
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------- | ------- |
|
||||
|Boolean indicating whether to enable communication efficient quantized gradients of ZeRO++. | `False` |
|
||||
|
||||
<i>**log_trace_cache_warnings**</i>: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------------------------------------------------------------- | ------- |
|
||||
| Log warnings from trace cache optimization of parameter sharding, such as cache invalidation events. | `False` |
|
||||
|
||||
***cpu_offload***: [boolean]
|
||||
|
||||
**Deprecated:** **cpu_offload** is deprecated and will be removed in future, please use `offload_optimizer` instead.
|
||||
|
Reference in New Issue
Block a user