[DTensor] Add guide for what to do about mixed torch.Tensor and DTensor operations (#162651)
Also updates the error message to point to the guide.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162651
Approved by: https://github.com/ezyang
ghstack dependencies: #162117, #162307

@@ -260,3 +260,73 @@ these features.
```{eval-rst}
.. py:module:: torch.distributed.tensor.device_mesh
```

## Mixed Tensor and DTensor operations

So you got the following error message:

```
got mixed torch.Tensor and DTensor, need to convert all
torch.Tensor to DTensor before calling distributed operators!
```

There are two cases.

### Case 1: this is user error

The most common way to run into this error is to create a regular Tensor
(using a factory function) and then perform a Tensor-DTensor operation,
like the following:

```python
# `dtensor` is an existing DTensor; adding a plain Tensor to it raises the
# "mixed torch.Tensor and DTensor" error above.
tensor = torch.arange(10)
return tensor + dtensor
```

We disallow mixed Tensor-DTensor operations: if any input to an operation
(e.g. `torch.add`) is a DTensor, then all Tensor inputs must be DTensors.
This is because the semantics are ambiguous: we don't know whether `tensor`
is the same across ranks or different, so we ask that the user figure out
how to construct a DTensor with accurate placements from `tensor`.

If each rank does have the same `tensor`, then please construct a replicated
DTensor:

```python
tensor = torch.arange(10)
# Every rank ran the same torch.arange, so Replicate() is an accurate placement.
tensor = DTensor.from_local(tensor, placements=(Replicate(),))
return tensor + dtensor
```
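
For reference, here is a minimal self-contained sketch of the replicated fix. It assumes two processes launched with `torchrun --nproc_per_node=2` and the CPU `gloo` backend; the `dtensor` built via `distribute_tensor` is a stand-in for whatever DTensor your program already has:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Stand-in for the DTensor the program already has.
dtensor = distribute_tensor(torch.ones(10), mesh, placements=(Replicate(),))

# Every rank runs the same torch.arange, so Replicate() is accurate.
tensor = torch.arange(10)
tensor = DTensor.from_local(tensor, mesh, placements=(Replicate(),))

out = tensor + dtensor        # DTensor + DTensor: no error
print(out.full_tensor())      # tensor([ 1.,  2., ..., 10.]) on every rank
```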

If you want to create a DTensor with shards, here is how to do it.
Semantically, this means that your Tensor data is split between the shards
and that operations act on the "full stacked data".

```python
# RANK stands for this process's rank id; each rank contributes a length-1 shard.
tensor = torch.full([1], RANK)
tensor = DTensor.from_local(tensor, placements=(Shard(0),))
return tensor + dtensor
```
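
To make the "full stacked data" semantics concrete, here is a sketch under the same assumptions as above (two `torchrun` processes on CPU): each rank contributes a length-1 shard, and the resulting DTensor behaves like the length-2 tensor formed by concatenating them.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))
rank = dist.get_rank()

# Rank 0 holds [0] and rank 1 holds [1]; Shard(0) declares that the logical
# tensor is their concatenation along dim 0, i.e. tensor([0, 1]).
local = torch.full([1], rank)
dtensor = DTensor.from_local(local, mesh, placements=(Shard(0),))

print(dtensor.full_tensor())  # tensor([0, 1]) on every rank
```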

There are other things you may wish to do with your tensor beyond
these situations (these are not the only two options!).

### Case 2: the error came from PyTorch framework code

Sometimes the problem is that PyTorch framework code attempts to perform mixed
Tensor-DTensor operations. These are bugs in PyTorch; please file an issue
so that we can fix them.

On the user side, the only thing you can do is avoid the operation
that caused the issue and file a bug report.

For PyTorch developers: one approach to fixing this is to rewrite PyTorch
framework code to avoid mixed Tensor-DTensor operations (as in the previous
section).

For PyTorch developers: the second approach is to turn on DTensor implicit
replication in the right places in PyTorch framework code. When it is on,
any mixed Tensor-DTensor operation assumes that the non-DTensors can be
replicated. Please be careful when using this, as it can lead to silent
incorrectness.

- [Turning on implicit replication in Python](https://github.com/pytorch/pytorch/blob/d8e6b2fddc54c748d976e8f0ebe4b63ebe36d85b/torch/distributed/tensor/experimental/__init__.py#L15)
- [Turning on implicit replication in C++](https://github.com/pytorch/pytorch/blob/7a0f93344e2c851b9bcf2b9c3225a323d48fde26/aten/src/ATen/DTensorState.h#L10)
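
As a rough illustration of the Python switch: `implicit_replication` is used as a context manager, and inside it plain Tensors mixed into DTensor operations are assumed to be replicated. This is a sketch only (it assumes `dtensor` is a DTensor from earlier); note the silent-incorrectness caveat above:

```python
import torch
from torch.distributed.tensor.experimental import implicit_replication

# Inside this context, a plain torch.Tensor mixed into a DTensor operation
# is treated as if it were replicated across ranks instead of raising.
with implicit_replication():
    out = dtensor + torch.arange(10)
```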
@@ -521,5 +521,7 @@ class OpDispatcher:
             raise RuntimeError(
                 f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                 " torch.Tensor to DTensor before calling distributed operators!"
+                " Please see https://docs.pytorch.org/docs/main/distributed.tensor.html#mixed-tensor-and-dtensor-operations"
+                " for more details."
             )
         return replication_spec