This is going to be used in https://github.com/pytorch/torchtitan/issues/1682
Add a `register_custom_function` to the `_PipelineScheduleRuntime` which allows users to implement any custom function to replace the runtime operation dynamically.
The signature of the callback should look like:
```python
class _CustomFunctionProtocol(Protocol):
def __call__(self, action: _Action, ctx: _PipelineContext) -> None: ...
```
`_PipelineContext` contains a reference to the schedule which is executing the operations.
### Testing
Added a test which adds custom methods for `FORWARD` and `OVERLAP_F_B` which are just the same implementations as those used in the default schedule runtime. Check that the schedule can still run, numerics are correct, and the callbacks are executed the correct number of times.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162016
Approved by: https://github.com/fegin
First fix for https://github.com/pytorch/pytorch/issues/164756
In the pipeline IR we call `UNSHARD` and `RESHARD`, but there is a bug because when we call `module.unshard()` these do not recursively call the FSDP modules, hence leading to sometime call allgather before the module forward.
Since we want the pipeline IR to explicitly handle this, we can call `group.unshard` instead which ensures that all the modules are unsharded.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164775
Approved by: https://github.com/weifengpy
Previously, an eval() call before a training step() would not correctly initialize the backward pass of the pipeline stages, leading to errors during the subsequent training step. This PR ensures that the backward stages can still be initialized after an eval() call.
Fixes#162822
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162823
Approved by: https://github.com/dcci, https://github.com/H-Huang
These change add an `eval()` API to PP schedules
## Context
Currently, you can run "Forward only" for a schedule in two ways:
1. Use a custom schedule `_ScheduleForwardOnly`
2. Do not pass in `loss_fn` in schedule constructor, and no backward computations will be executed.
However, this is still limiting because we may want to run forward through the pipeline / calculate the loss, but without backward, e.g. during validation. These changes allow for this.
```python
if self.rank == 0:
schedule.eval(x)
elif self.rank == self.world_size - 1:
losses = []
schedule.eval(target=target, losses=losses)
else:
schedule.eval()
```
TODO:
- in later PRs, we will deprecate the `_ScheduleForwardOnly`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157795
Approved by: https://github.com/wconstab
Change logging.error to logging.exception to log additional information when relevant. A few places have slipped in logging.errors in try except since I last did a clean up here and the rule is stabilized so I am enabling it codebase wide. I have NOQA'd much of our custom exception stack trace handling for RPC calls and distributed and tried to a fix a few errors based on whether we immediately reraised it or if we didn't print any exception handling where it could be useful.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153473
Approved by: https://github.com/albanD, https://github.com/cyyever
Fixes#132644
`_batch_p2p` incorrectly assumes that `dist.batch_isend_irecv` returns a single-element list of `dist.Work`, likely due to NCCL's coalescing behaviour.
For none NCCL backends like Gloo, multiple `dist.Work` objects are returned, causing the code to discard some operations via `.pop()`. This leads to deadlocks during pipeline parallelism.
## Changes:
* Modified `_batch_p2p` to return `list[dist.Work]` instead of popping a single element.
* Added `_wait_batch_p2p` to call `wait()` on multiple `dist.Work` objects, consuming the result of `_batch_p2p`.
* Updated references from `dist.Work` to `list[dist.Work]`.
## Testing:
* `pippy_bert.py` from #132644 now works with gloo.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152938
Approved by: https://github.com/kwen2501, https://github.com/H-Huang
This PR allows schedules loaded via CSV to automatically set their `stage_index_to_group_rank ` and removes the `stage_index_to_group_rank ` argument from the `PipelineScheduleMulti` constructor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146217
Approved by: https://github.com/wconstab
ghstack dependencies: #146193
We use `stage_index_to_group_rank` in the stage to determine what send/recv ops and in the schedule for IR generation. However, we don't need to expose this as an argument in our schedule class, so this stack of PRs is to remove it.
This PR creates a `stage_index_to_group_rank` utility function and removes the arg for the ZBVschedule. In a following PR I will add code to infer the `stage_index_to_group_rank` for the CSV schedule path and we will be able to remove this argument from our classes entirely.
Related comment from @wconstab https://github.com/pytorch/torchtitan/issues/774#issuecomment-2619793741
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146193
Approved by: https://github.com/wconstab
Adds a grad-scaling method `perform_pp_grad_scaling()` which divides grads by num_microbatches.
Enables grad scaling by default, unless disabled due to using a loss function that sums instead of averaging losses.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144352
Approved by: https://github.com/H-Huang
**Overview**
This PR moves `torch/distributed/_composable/fsdp` to `torch/distributed/fsdp/_fully_shard` and makes public APIs available from `torch.distributed.fsdp`, e.g.:
```
from torch.distributed.fsdp import fully_shard
```
This is targeting 2.6 release. I rewrote some of the documentation with (hopefully) improved phrasing.
**Follow-Ups**
- [x] Add some explanation in the docs about FSDP1 vs. FSDP2
- [ ] Move unit tests from `test/distributed/_composable/fsdp` to `test/distributed/fsdp/fully_shard/`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141868
Approved by: https://github.com/kwen2501, https://github.com/wconstab, https://github.com/weifengpy
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>