Olatunji Ruwase b418cf6c1b Training multiple models (#7018)
Support training multiple models, as in
[HF Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed_multiple_model).

Here is an update on supporting multiple DeepSpeed (DS) engines with a single
`loss.backward()` call. The main message is that I think we can support this.
First, some context. The backward pass in ZeRO is complicated because its
optimizations and features require special handling of gradients, such as:

1. Gradient partitioning
2. Overlapping backward and reduction
3. Upcasting for fp32 grad accumulation
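
To make item 3 concrete: accumulating many small gradients directly in half precision silently drops them once the running total grows. The toy below is a standard-library-only illustration, not DeepSpeed code. It emulates fp16 and fp32 rounding with `struct` and sums 10,000 hypothetical gradients of `1e-4` (exact total 1.0), rounding the running total to the target precision after every step.

```python
import struct


def to_fp16(x):
    # Round a Python float to the nearest IEEE half-precision value.
    return struct.unpack("e", struct.pack("e", x))[0]


def to_fp32(x):
    # Round a Python float to the nearest IEEE single-precision value.
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate(grads, cast):
    # Sum gradients, rounding the running total to `cast`'s precision after each add.
    total = cast(0.0)
    for g in grads:
        total = cast(total + cast(g))
    return total


grads = [1e-4] * 10_000  # hypothetical small gradients; exact sum is 1.0
fp16_total = accumulate(grads, to_fp16)
fp32_total = accumulate(grads, to_fp32)
print(fp16_total, fp32_total)
```

The fp16 total stalls far below 1.0. Once the gap between adjacent half-precision values near the running total exceeds twice the increment, each addition rounds back to the same value. The fp32 total stays close to 1.0, which is why gradients are upcast before they are accumulated.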

So we created `engine.backward(loss)` as a wrapper function that gives us
fine-grained control over the backward pass, as below:

```python
def backward(loss):
    backward_prologue()  # setup logic for special gradient handling
    loss.backward()
    backward_epilogue()  # cleanup/teardown logic
```
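
A framework-free toy model of this wrapper shows where it falls short. `ToyEngine`, `ToyLoss` and the event log are hypothetical stand-ins, not DeepSpeed classes. When the loss depends on two engines, a single `backward()` call produces gradients for both, but only the calling engine's prologue and epilogue run.

```python
# Toy model of the wrapper approach; all names here are hypothetical.
class ToyLoss:
    def __init__(self, engines):
        self.engines = engines  # engines whose parameters contributed to the loss

    def backward(self, log):
        # One backward call produces gradients for every contributing engine.
        for engine in self.engines:
            log.append(f"{engine.name}: grads")


class ToyEngine:
    def __init__(self, name, log):
        self.name = name
        self.log = log

    def backward_prologue(self):
        self.log.append(f"{self.name}: prologue")

    def backward_epilogue(self):
        self.log.append(f"{self.name}: epilogue")

    def backward(self, loss):
        # Same shape as engine.backward(loss): prologue, backward, epilogue.
        self.backward_prologue()
        loss.backward(self.log)
        self.backward_epilogue()


log = []
engine_a = ToyEngine("a", log)
engine_b = ToyEngine("b", log)
loss = ToyLoss([engine_a, engine_b])  # loss depends on both engines
engine_a.backward(loss)
print(log)
# ['a: prologue', 'a: grads', 'b: grads', 'a: epilogue']
```

Engine `b` receives gradients without its setup or teardown. In this model, also calling `engine_b.backward(loss)` is not a fix, because it would run the backward pass a second time.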

As demonstrated by @muellerzr, this approach breaks down when the loss
originates from multiple DS engines, because a single `loss.backward()` call
cannot be wrapped by every engine's `backward()`. Our proposed solution is to
use backward hooks on the module to launch `backward_prologue()` and
`backward_epilogue()`. Specifically,

1. A backward pre-hook on `engine.module` to launch `backward_prologue()`
before any module gradient is created.
2. A backward post-hook on `engine.module` to launch `backward_epilogue()`
after all module gradients are created.
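
A framework-free toy model of this hook-based design follows. Every class and method name here is hypothetical rather than a PyTorch or DeepSpeed API. Each engine registers its prologue as a pre-hook and its epilogue as a post-hook on its module, so one plain `loss.backward()` triggers setup and teardown for both engines. The sketch assumes a simple sequential graph in which each module's gradients are produced contiguously. In practice, detecting when a module's last gradient has been created is the hard part, which is item 2 above.

```python
# Toy model of the hook-based approach; all names here are hypothetical.
class ToyModule:
    def __init__(self, name, params):
        self.name = name
        self.params = params
        self.pre_hooks = []
        self.post_hooks = []

    def register_backward_pre_hook(self, fn):
        self.pre_hooks.append(fn)

    def register_backward_post_hook(self, fn):
        self.post_hooks.append(fn)


class ToyLoss:
    def __init__(self, modules, log):
        self.modules = modules
        self.log = log

    def backward(self):
        # Visit modules in reverse forward order, as for a sequential graph.
        for module in reversed(self.modules):
            for hook in module.pre_hooks:
                hook()  # runs before any gradient of this module exists
            for p in reversed(module.params):
                self.log.append(f"{module.name}.{p}: grad")
            for hook in module.post_hooks:
                hook()  # runs after all gradients of this module exist


class ToyEngine:
    def __init__(self, module, log):
        self.module = module
        self.log = log
        module.register_backward_pre_hook(self.backward_prologue)
        module.register_backward_post_hook(self.backward_epilogue)

    def backward_prologue(self):
        self.log.append(f"{self.module.name}: prologue")

    def backward_epilogue(self):
        self.log.append(f"{self.module.name}: epilogue")


log = []
engine_a = ToyEngine(ToyModule("a", ["w"]), log)
engine_b = ToyEngine(ToyModule("b", ["w"]), log)
loss = ToyLoss([engine_a.module, engine_b.module], log)
loss.backward()  # one plain backward call, no engine.backward()
print(log)
# ['b: prologue', 'b.w: grad', 'b: epilogue', 'a: prologue', 'a.w: grad', 'a: epilogue']
```

For real modules, PyTorch provides `nn.Module.register_full_backward_pre_hook` and `nn.Module.register_full_backward_hook`. This description does not say whether the final implementation relies on exactly these.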

We intend this solution to preserve backward compatibility (BC), i.e.,
`engine.backward()` will remain correct for single-engine scenarios.
Currently, (1) is complete, while (2) is in progress.
Since multi-engine scenarios probably have other issues, we have
temporarily added `engine._backward_prologue()` to unblock end-to-end (e2e)
testing of them. You can try this out via the following artifacts:

1. Simple multi-engine test code:
https://gist.github.com/tjruwase/f1adccf087b8fa269ffce2ab91c4f1c6#file-multi_engine-py
2. DS branch:
https://github.com/microsoft/DeepSpeed/tree/olruwase/zero_multi_models

---------

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
2025-03-11 20:59:23 +00:00