Needed this class for because `parallelize_module` takes a dict, which doesn't allow `PrepareModuleInput` and `PrepareModuleOutput` to be applied at the same time.
The `PrepareModuleInputOutput` in this PR initializes two variables `prepare_module_input` and `prepare_module_output` and uses them to process module / inputs / outputs.
I had another implementation which put all code in `PrepareModuleInputOutput` and let `PrepareModuleInput` and `PrepareModuleOutput` inherit the monolithic `PrepareModuleInputOutput`. But it is
1. less cleaner
2. conceptually abusing inheritance because `PrepareModuleInput` shouldn't be able to access class methods of `PrepareModuleOutput` and vice versa
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150372
Approved by: https://github.com/wanchaol
As titled, this PR introduces a dedicated `ParallelStyle` to shard the
nn.LayerNorm/nn.Dropout/RMSNorm layers. We were mainly using a manual
distribute_module calls before when sharding the RMSNorm layer, but I
think we should have a dedicate TP API to easily shard those layers,
instead of user manually using DTensors.
I call this SequenceParallel, which might bring some confusion that we
technically "deprecated" a SequenceParallel style months ago. But this
time the SeuqenceParallel style is significantly different with the
previous ones (which used to shard two consecutive Linear layers). I
believe making it the right name is the first priority, instead of
worrying about the issue of reusing the old name
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121295
Approved by: https://github.com/awgu, https://github.com/tianyu-l
ghstack dependencies: #121294
Loss parallel is the last piece of sequence parallelism to enable. It enables efficient distributed cross entropy computation when the input is sharded on the class dimension (in a classification problem with many classes). The implementation is via a context manager `loss_parallel`, after enabling which users can directly use `torch.nn.functional.cross_entropy` or `torch.nn.CrossEntropyLoss` without modifying other parts of their code.
Here are the underlying rationales why we are going through these op replacements:
1. `nn.functional.cross_entropy` is the common method that OSS user is using for things like transformer training, to avoid changing user code, we want user to still use this function for loss calculation if they are already using it.
2. `nn.functional.cross_entropy` boils down into `aten.log_softmax` and `aten.nll_loss_foward/backward`, and DTensor now supports those ops already (#117723#119255#118917#119256). They are doing computation with input *replicated* on the class dimension.
3. However when the input of this loss calculation is **sharded on the class dimension**, to run sharded computation efficiently, we need to run both `aten.log_softmax` and `aten.nll_loss_foward` with multiple all-reduce collectives **in the middle of** those aten ops. This is not possible if we are just overriding these two ops, so we need to have some way to **decompose** these two ops into smaller ops to have collectives run in the middle of these two ops.
4. We explored the existing decompositions (#118950). It seems working, except that `log_softmax_backward` and `nll_loss_backward` combined together in aten are implemented in a inefficient way, which would trigger an additional expensive collective. Recently some user also reported similar issues https://github.com/pytorch/pytorch/issues/119261.
5. Therefore, currently we are doing our own decomposition inside a context manager for sequence parallelism specifically. Once we have a better decomposition in core, we can possibly take that instead of reinventing the wheels here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119877
Approved by: https://github.com/wanchaol
This PR rewrites Tensor Parallel implementation. Tensor Parallel APIs
supposed to be a very thin-wrapper to DTensor APIs, but the current
implementation got too messy and buggy. It's really hard to debug what
went wrong when using it. It's crucially important for advanced users or
developers to understand the API and its implementation easily without
going through all different types of functions and utils, so that
they could trust what happen under the hood.
In particular this PR:
* Make ParallelStyle to be a real contract API for parallelize_module to
take, each concrete ParallelStyle only needs to implement `apply` to
apply the sharding to nn.Module, remove all non-necessary fields. This
also enable easier ParallelStyle authoring going forward.
* Keep the ColwiseParallel and RowwiseParallel public interface, but
refactor them in a way that makes the parameter sharding, inputs and
outputs handling lives within the style itself, so that it's easy to
understand how Linear/Embedding layers are sharded and how the inputs/outputs
transformations are performed.
* remove all those private _prepare_input/_prepare_output_fn fields for
both ColwiseParallel/RowwiseParallel. Since we throw deprecation
messages in nightly for a while and TP is on prototype release, the
fields are also private, it should be safe to remove them
* Refactor the recently landed PrepareModuleInput/Output style, change
output_layouts to desired_input/output_layouts, group
the function inside the style itself, no default arguments for these
two styles and user need to specify them to think about the sharding
layouts. Fixed bugs about not handling
`use_local_output` flag.
* Make default arguments be None instead of Placement object, this is
standard python practice to not have custom object instance as default
argument
* Remove all dead APIs (i.e. PairwiseParallel and SequenceParallel
style, all prepare input/output functions) as we throw deprecation
msgs for a while, and in the progress of removing all of them from the tests.
* throw deprecation warning for `tp_mesh_dim` as we recomemnd use device
mesh slice/indexing instead of manually specify mesh dim
* Rewrite all documentations for every ParallelStyle and make the
documentation more clear about what each style is doing
TODOs:
* Rewrite TP tests to adjust for the changes we have in this PR
* add more tests to guard the bug fixes
Differential Revision: [D51761183](https://our.internmc.facebook.com/intern/diff/D51761183)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114732
Approved by: https://github.com/wz337, https://github.com/fduwjj
As part of TP UX improvements, we want to keep our API simple (not easy) so that users get the flexibility to do what they want and avoid a too generic API which tries to solve everything and get things too complicated. We are updating the doc accordingly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111176
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160, #111166
In some use cases, we found that users might want to annote the input/output DTensor layout for the parent module rather than the submodule whose parameters are to be distributed so that we want to have these two class for users to annote input/output DTensor layouts so that we register pre-FWD/FWD hook for the TP-lized module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111166
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160
To make TP more generic for Attention module, we come up with this new col/rowwise parallel style.
Basically, the idea behind is that:
We only do DTensor op for Col/Rowwise sharded part. For the rest of ATen ops, we will leave it to Tensor ops.
And we set this behavior as default for Colwise and Rowwise parallel style. If people want to customize it, they can always pass in different prepare_input or prepare_output
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100508
Approved by: https://github.com/wanchaol