Support training multiple models, such as in [HF](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed_multiple_model)

Here is an update on supporting multiple DS engines with a single `loss.backward()`. The main message is that I think we can support this.

First, some context. The backward pass in ZeRO is complicated because its optimizations and features require special handling of gradients, such as:
1. Gradient partitioning
2. Overlapping backward and reduction
3. Upcasting for fp32 gradient accumulation

So, we created `engine.backward(loss)` as a wrapper function that gives us fine-grained control over backward, as below:

```python
def backward(loss):
    backward_prologue()  # setup logic for special gradient handling
    loss.backward()
    backward_epilogue()  # cleanup/teardown logic
```

As demonstrated by @muellerzr, this approach breaks down when the loss originates from multiple DS engines. Our proposed solution is to use backward hooks on the module to launch `backward_prologue()` and `backward_epilogue()`. Specifically:
1. A backward pre hook on `engine.module` to launch `backward_prologue()` before any module gradient is created.
2. A backward post hook on `engine.module` to launch `backward_epilogue()` after all module gradients are created.

We plan for this solution to preserve backward compatibility (BC), i.e., `engine.backward()` will remain correct for single-engine scenarios.

The current status is that (1) is completed, while (2) is in progress. To unblock e2e testing for multi-engine scenarios (there are probably other issues), we have temporarily added `engine._backward_prologue()`. You can try this out via the following artifacts:
1. Simple multi-engine test code: https://gist.github.com/tjruwase/f1adccf087b8fa269ffce2ab91c4f1c6#file-multi_engine-py
2. DS branch: https://github.com/microsoft/DeepSpeed/tree/olruwase/zero_multi_models

---------

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
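To make the multi-engine problem with the `engine.backward(loss)` wrapper concrete, here is a toy sketch in plain PyTorch. `WrapperEngine` is hypothetical (not DeepSpeed code), and the failure shown is one illustrative way the breakdown surfaces, not necessarily the exact case @muellerzr hit. A single `loss.backward()` already produces gradients for both modules, but only the calling engine's prologue/epilogue run; asking the second engine to run its own wrapper triggers a second backward through an already-freed graph.

```python
import torch
import torch.nn as nn


class WrapperEngine:
    """Hypothetical engine that brackets loss.backward() with a prologue and
    epilogue, mirroring the engine.backward(loss) pattern (not DeepSpeed code)."""

    def __init__(self, name, module, log):
        self.name, self.module, self.log = name, module, log

    def __call__(self, *args):
        return self.module(*args)

    def backward(self, loss):
        self.log.append(f"{self.name}:prologue")  # setup for special grad handling
        loss.backward()
        self.log.append(f"{self.name}:epilogue")  # cleanup/teardown


log = []
engine_a = WrapperEngine("a", nn.Linear(4, 4), log)
engine_b = WrapperEngine("b", nn.Linear(4, 1), log)
loss = engine_b(engine_a(torch.randn(8, 4))).pow(2).mean()

# Gradients reach BOTH modules here, yet only engine_a's prologue/epilogue run.
engine_a.backward(loss)

try:
    # engine_b never got to prepare; a second backward over the freed graph fails.
    engine_b.backward(loss)
    second_call_failed = False
except RuntimeError:
    second_call_failed = True

print(log, second_call_failed)
```

Retaining the graph would avoid the error but would compute (and accumulate) gradients twice, so neither call order gives each engine a correct prologue/epilogue around one backward pass.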