# torch.distributed.tensor

:::{currentmodule} torch.distributed.tensor
:::
:::{note}
`torch.distributed.tensor` is currently in alpha state and under
development. We are committing to backward compatibility for most of the APIs listed
in the doc, but there might be API changes if necessary.
:::
## PyTorch DTensor (Distributed Tensor)
PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handle distributed
logic, including sharded storage, operator computation and collective communications across devices/hosts.
`DTensor` can be used to build different parallelism solutions and supports sharded state_dict representation
when working with multi-dimensional sharding.

Please see examples from the PyTorch native parallelism solutions that are built on top of `DTensor`:
.. automodule:: torch.distributed.tensor
{class}`DTensor` follows the SPMD (single program, multiple data) programming model to empower users to
write distributed programs as if they were single-device programs, with the same convergence property. It
provides a uniform tensor sharding layout (DTensor Layout) through specifying the {class}`DeviceMesh`
and {class}`Placement`:
- {class}`DeviceMesh` represents the device topology and the communicators of the cluster using
  an n-dimensional array.
- {class}`Placement` describes the sharding layout of the logical tensor on the {class}`DeviceMesh`.
  DTensor supports three types of placements: {class}`Shard`, {class}`Replicate` and {class}`Partial`.
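To make the three placements concrete, the sketch below models a single mesh dimension with plain Python lists standing in for tensors. It is an illustration only and does not call the DTensor API. `Shard` is modeled with `torch.chunk`-style sizing, so 5 elements on 4 ranks leave the last rank empty. The `Partial` split shown (rank 0 holds everything, the others hold zeros) is just one valid decomposition, assumed here for illustration.

```python
# Plain-Python model of DTensor placements on ONE mesh dimension.
# Lists stand in for tensors; nothing here calls torch.distributed.

def chunk(values, n):
    """Split like torch.chunk: pieces of size ceil(len / n). Ranks past the
    last piece get an empty list, mirroring uneven Shard placements."""
    size = -(-len(values) // n)  # ceiling division
    return [values[i * size:(i + 1) * size] for i in range(n)]

def place(values, kind, world_size):
    """Return the local value held by each rank for a given placement."""
    if kind == "shard":  # Shard(0): rank r keeps chunk r
        return chunk(values, world_size)
    if kind == "replicate":  # Replicate(): every rank keeps a full copy
        return [list(values) for _ in range(world_size)]
    if kind == "partial":  # Partial("sum"): local values sum to the logical one;
        # rank 0 holding everything and the rest zeros is one valid split
        return [list(values) if r == 0 else [0] * len(values)
                for r in range(world_size)]
    raise ValueError(kind)

def logical_value(local, kind):
    """Recover the logical tensor from the per-rank local values."""
    if kind == "shard":
        return [x for piece in local for x in piece]  # concatenate
    if kind == "replicate":
        return list(local[0])  # any rank's copy
    if kind == "partial":
        return [sum(col) for col in zip(*local)]  # elementwise sum
    raise ValueError(kind)

logical = [1, 2, 3, 4, 5]
print(place(logical, "shard", 4))  # [[1, 2], [3, 4], [5], []]
for kind in ("shard", "replicate", "partial"):
    assert logical_value(place(logical, kind, 4), kind) == logical
```

Every placement is a different way of storing the same logical tensor. What changes is how much each rank stores and what must happen (concatenation, nothing, or a sum) to recover it.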
### DTensor Class APIs
.. currentmodule:: torch.distributed.tensor
{class}`DTensor` is a `torch.Tensor` subclass. This means once a {class}`DTensor` is created, it can be
used in a very similar way to `torch.Tensor`, including running different types of PyTorch operators as if
running them on a single device, allowing proper distributed computation for PyTorch operators.

In addition to existing `torch.Tensor` methods, it also offers a set of additional methods to interact with
`torch.Tensor`, `redistribute` the DTensor Layout to a new DTensor, get the full tensor content
on all devices, etc.
.. autoclass:: DTensor
:members: from_local, to_local, full_tensor, redistribute, device_mesh, placements
:member-order: groupwise
:special-members: __create_chunk_list__
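The layout changes performed by `redistribute` map onto familiar collectives on each mesh dimension:

- `Shard` to `Replicate` is an all-gather.
- `Partial` to `Replicate` is an all-reduce.
- `Partial` to `Shard` is a reduce-scatter.
- `Replicate` to `Shard` is a purely local slice.

The sketch below models these transitions with plain lists, one entry per rank. It is an illustration of the data movement only, not DTensor's implementation. The function names are hypothetical.

```python
# Plain-Python model of redistribute() on ONE mesh dimension (an illustration,
# not DTensor's implementation). Each list entry is one rank's local value.

def chunk(values, n):
    size = -(-len(values) // n)  # torch.chunk-style sizing
    return [values[i * size:(i + 1) * size] for i in range(n)]

def all_gather(local_shards):
    # Shard(0) -> Replicate(): every rank receives the concatenated shards.
    full = [x for piece in local_shards for x in piece]
    return [list(full) for _ in local_shards]

def all_reduce(local_partials):
    # Partial("sum") -> Replicate(): every rank receives the elementwise sum.
    total = [sum(col) for col in zip(*local_partials)]
    return [list(total) for _ in local_partials]

def reduce_scatter(local_partials):
    # Partial("sum") -> Shard(0): reduce, then each rank keeps its own chunk.
    total = [sum(col) for col in zip(*local_partials)]
    return chunk(total, len(local_partials))

def replicate_to_shard(local_copies):
    # Replicate() -> Shard(0): no communication, each rank slices its copy.
    n = len(local_copies)
    return [chunk(copy, n)[r] for r, copy in enumerate(local_copies)]

REDISTRIBUTE = {
    ("shard", "replicate"): all_gather,
    ("partial", "replicate"): all_reduce,
    ("partial", "shard"): reduce_scatter,
    ("replicate", "shard"): replicate_to_shard,
}

partials = [[1, 2, 3, 4], [10, 20, 30, 40]]  # 2 ranks holding Partial("sum")
print(REDISTRIBUTE[("partial", "shard")](partials))  # [[11, 22], [33, 44]]
```

In this model, `full_tensor()` corresponds to redistributing to `Replicate()` on every mesh dimension and then reading one rank's local copy.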
### DeviceMesh as the distributed communicator
.. currentmodule:: torch.distributed.device_mesh
{class}`DeviceMesh` was built from DTensor as the abstraction to describe the cluster's device topology and represent
multi-dimensional communicators (on top of `ProcessGroup`). To see the details of how to create/use a DeviceMesh,
please refer to the DeviceMesh recipe.
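As a rough picture of the "multi-dimensional communicators", consider a 2×4 mesh such as one built with `init_device_mesh("cuda", (2, 4))`. Its ranks can be viewed as laid out row-major, like `torch.arange(8).reshape(2, 4)`. Each mesh dimension then yields one communicator per line of ranks along that dimension. The sketch below computes those groups in plain Python. It assumes that row-major layout and does not use the DeviceMesh API.

```python
# Plain-Python model of how an n-D DeviceMesh groups ranks (an illustration,
# not the DeviceMesh API). Assumes the row-major layout of
# torch.arange(world_size).reshape(mesh_shape).
from math import prod

def mesh_coords(rank, mesh_shape):
    """Row-major coordinates of `rank` inside the mesh."""
    coords = []
    for size in reversed(mesh_shape):
        coords.append(rank % size)
        rank //= size
    return tuple(reversed(coords))

def groups_along(mesh_shape, dim):
    """Ranks that communicate along mesh dimension `dim`: those whose
    coordinates agree on every other mesh dimension."""
    groups = {}
    for rank in range(prod(mesh_shape)):
        coords = mesh_coords(rank, mesh_shape)
        key = coords[:dim] + coords[dim + 1:]
        groups.setdefault(key, []).append(rank)
    return sorted(groups.values())

print(groups_along((2, 4), 0))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(groups_along((2, 4), 1))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

A collective issued on mesh dimension 1 involves only the four ranks in the same row. This is why one mesh can serve, for example, a different parallelism strategy on each dimension.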
### DTensor Placement Types
.. automodule:: torch.distributed.tensor.placement_types
.. currentmodule:: torch.distributed.tensor.placement_types
DTensor supports the following types of {class}`Placement` on each {class}`DeviceMesh` dimension:
.. autoclass:: Shard
:members:
:undoc-members:
.. autoclass:: Replicate
:members:
:undoc-members:
.. autoclass:: Partial
:members:
:undoc-members:
.. autoclass:: Placement
:members:
:undoc-members:
(create_dtensor)=
## Different ways to create a DTensor
.. currentmodule:: torch.distributed.tensor
- There are three ways to construct a {class}`DTensor`:
  - {meth}`distribute_tensor` creates a {class}`DTensor` from a logical or "global" `torch.Tensor` on
    each rank. This can be used to shard the leaf `torch.Tensor`s (e.g. model parameters/buffers
    and inputs).
  - {meth}`DTensor.from_local` creates a {class}`DTensor` from a local `torch.Tensor` on each rank,
    which can be used to create {class}`DTensor`s from non-leaf `torch.Tensor`s (e.g. intermediate
    activation tensors during forward/backward).
  - DTensor provides dedicated tensor factory functions (e.g. {meth}`empty`, {meth}`ones`, {meth}`randn`, etc.)
    to allow different {class}`DTensor` creations by directly specifying the {class}`DeviceMesh` and
    {class}`Placement`. Compared to {meth}`distribute_tensor`, this can directly materialize the sharded
    memory on device, instead of performing sharding after initializing the logical Tensor memory.
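The first two entry points differ in what each rank passes in:

- With `distribute_tensor`, every rank supplies the same logical tensor and keeps its slice.
- With `DTensor.from_local`, each rank supplies only its own shard, and no data moves. If no global shape is given, the shape is inferred assuming even sharding.

The sketch below models both for `Shard(0)` on a 1-D mesh using plain lists. The `model_*` names are hypothetical and this is not the DTensor API.

```python
# Plain-Python model (not the DTensor API) contrasting distribute_tensor and
# DTensor.from_local for Shard(0) on a 1-D mesh. Each list entry is one rank.

def chunk(values, n):
    size = -(-len(values) // n)  # torch.chunk-style sizing
    return [values[i * size:(i + 1) * size] for i in range(n)]

def model_distribute_tensor(global_values, world_size):
    # Every rank passes the SAME logical tensor and keeps only its chunk.
    return chunk(global_values, world_size)

def model_from_local(local_per_rank, shape=None):
    # Each rank passes only its own shard; no data is moved. Without an
    # explicit global shape, it is inferred assuming even sharding
    # (modeled here from rank 0's point of view).
    world_size = len(local_per_rank)
    global_len = shape if shape is not None else len(local_per_rank[0]) * world_size
    return local_per_rank, global_len

shards = model_distribute_tensor([1, 2, 3], 2)  # [[1, 2], [3]]
print(model_from_local(shards)[1])           # 4: wrong, the shards are uneven
print(model_from_local(shards, shape=3)[1])  # 3: an explicit shape fixes it
```

The uneven case is the one to watch. When local shards differ in size across ranks, the global shape has to be stated rather than inferred.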
### Create DTensor from a logical torch.Tensor
The SPMD (single program, multiple data) programming model in `torch.distributed` launches multiple processes
(e.g. via `torchrun`) to execute the same program. This means that the model inside the program is first
initialized on each process (e.g. the model might be initialized on CPU, on the meta device, or directly
on GPU if there is enough memory).

`DTensor` offers a {meth}`distribute_tensor` API that can shard the model weights or Tensors to `DTensor`s,
creating a DTensor from the "logical" Tensor on each process. This enables the created
`DTensor`s to comply with single-device semantics, which is critical for **numerical correctness**.
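Why single-device semantics depend on every process building the same logical tensor can be shown with a small model. Each simulated rank initializes the full tensor, keeps its chunk, and the chunks are concatenated back. The result matches a single-device run only when all ranks used the same seed. `random.Random` stands in for a seeded parameter initializer, and the seed values are arbitrary assumptions.

```python
# Plain-Python model (not the DTensor API): sharding gives single-device
# results only if every rank builds the same logical tensor first.
# random.Random stands in for a seeded initializer; seeds are arbitrary.
import random

def chunk(values, n):
    size = -(-len(values) // n)  # torch.chunk-style sizing
    return [values[i * size:(i + 1) * size] for i in range(n)]

def init_logical(seed, numel):
    rng = random.Random(seed)
    return [rng.random() for _ in range(numel)]

def assemble(seeds, numel):
    # Rank r initializes the full logical tensor with seeds[r], keeps chunk r,
    # and the chunks are concatenated back in rank order.
    world_size = len(seeds)
    local = [chunk(init_logical(s, numel), world_size)[r] for r, s in enumerate(seeds)]
    return [x for piece in local for x in piece]

single_device = init_logical(1234, 6)
print(assemble([1234, 1234, 1234], 6) == single_device)  # True
print(assemble([1234, 1234, 99], 6) == single_device)    # False
```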
.. autofunction:: distribute_tensor
Along with {meth}`distribute_tensor`, DTensor also offers a {meth}`distribute_module` API to allow easier
sharding at the {class}`nn.Module` level:
.. autofunction:: distribute_module
### DTensor Factory Functions
DTensor also provides dedicated tensor factory functions to allow creating {class}`DTensor` directly
using `torch.Tensor`-like factory function APIs (e.g. `torch.ones`, `torch.empty`, etc.), by additionally
specifying the {class}`DeviceMesh` and {class}`Placement` for the {class}`DTensor` created:
.. autofunction:: zeros
.. autofunction:: ones
.. autofunction:: empty
.. autofunction:: full
.. autofunction:: rand
.. autofunction:: randn
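The memory benefit of the factory functions comes from each rank allocating only its local shard. The sketch below computes that local shape from a global shape, a mesh shape and per-mesh-dimension placements.

It is a plain-Python model under these assumptions, not DTensor's internal code:

- Placements are encoded as hypothetical tuples, `("shard", tensor_dim)` or `("replicate",)`.
- `Shard` uses `torch.chunk`-style sizing.
- `Partial` is left out for brevity.

```python
# Plain-Python model (not the DTensor API) of the local shape each rank
# allocates when a factory such as ones() or empty() creates a sharded DTensor.
# Placements: ("shard", tensor_dim) or ("replicate",), one per mesh dimension;
# Shard uses torch.chunk-style sizing. Partial is left out for brevity.
from math import prod

def mesh_coords(rank, mesh_shape):
    """Row-major coordinates of `rank` inside the mesh."""
    coords = []
    for size in reversed(mesh_shape):
        coords.append(rank % size)
        rank //= size
    return tuple(reversed(coords))

def local_shape(global_shape, mesh_shape, placements, rank):
    shape = list(global_shape)
    coords = mesh_coords(rank, mesh_shape)
    for mesh_dim, placement in enumerate(placements):
        if placement[0] == "shard":
            dim, n = placement[1], mesh_shape[mesh_dim]
            size = -(-shape[dim] // n)  # ceil: size of the leading chunks
            start = min(coords[mesh_dim] * size, shape[dim])
            shape[dim] = min(size, shape[dim] - start)  # trailing chunks may shrink to 0
    return tuple(shape)

mesh, placements = (2, 3), [("shard", 0), ("shard", 1)]
print(local_shape((8, 6), mesh, placements, rank=0))  # (4, 2)
# Fully sharded: per-rank allocations add up to the logical size, never more.
print(sum(prod(local_shape((8, 6), mesh, placements, r)) for r in range(6)))  # 48
```

With `distribute_tensor`, by contrast, each rank first holds all 48 elements of the logical tensor before keeping its 8-element shard.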
### Random Operations
DTensor provides distributed RNG functionality to ensure that random operations on sharded tensors get unique values, and random operations on replicated tensors get the same values. This system requires that all participating ranks (e.g. SPMD ranks) start out using the same generator state before each DTensor random operation is performed. If this holds, it ensures they all end up at the same state after each DTensor random operation completes. No communication is performed during random operations to synchronize RNG states.
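The "same state before, same state after, no communication" property can be achieved with a counter-based generator. Each rank jumps to the counters for its own shard, and every rank then advances past the whole logical tensor. The sketch below models that idea in plain Python, with SHA-256 standing in for Philox. It illustrates the general technique under these assumptions and is not DTensor's actual RNG tracker, whose offset bookkeeping differs in detail.

```python
# Plain-Python model of offset-based distributed RNG (an illustration, not
# DTensor's implementation). A counter-based generator maps (seed, counter)
# to a value, so a rank can jump straight to the counters of its own shard.
# SHA-256 stands in for Philox here.
import hashlib

def counter_rand(seed, counter):
    digest = hashlib.sha256(f"{seed}:{counter}".encode()).digest()
    return int.from_bytes(digest[:8], "little") / 2**64  # float in [0, 1)

class GeneratorState:
    def __init__(self, seed):
        self.seed, self.offset = seed, 0

def sharded_rand(state, numel, rank, world_size):
    # Each rank draws only the counters covering its chunk of the logical tensor...
    size = -(-numel // world_size)  # torch.chunk-style sizing
    start = min(rank * size, numel)
    stop = min(start + size, numel)
    values = [counter_rand(state.seed, state.offset + i) for i in range(start, stop)]
    # ...then every rank skips past the WHOLE logical tensor, so all ranks
    # end in the same state with no communication.
    state.offset += numel
    return values

def replicated_rand(state, numel):
    # Every rank draws the same counters, so every rank gets the same values.
    values = [counter_rand(state.seed, state.offset + i) for i in range(numel)]
    state.offset += numel
    return values

states = [GeneratorState(seed=0) for _ in range(3)]  # same seed on all ranks
shards = [sharded_rand(states[r], 7, r, 3) for r in range(3)]
reference = replicated_rand(GeneratorState(seed=0), 7)
print([x for s in shards for x in s] == reference)  # True: single-device result
print({st.offset for st in states})                 # {7}: states stay in sync
```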
Operators that accept a `generator` kwarg will use the user-passed generator if one is passed, or the default generator for the device otherwise. Whichever generator is used, it will be advanced after the DTensor operation. It is valid to use the same generator for both DTensor and non-DTensor operations, but if so, care must be taken to ensure the non-DTensor operations advance the generator state equally on all ranks.
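The risk of sharing a generator can be seen in a small model. Two simulated ranks seed identically and run a non-DTensor op on the shared generator, then a replicated DTensor random op. The replicas agree only if that non-DTensor op consumed the same amount of randomness on both ranks. The `Generator` class here is a hypothetical stand-in, not `torch.Generator`.

```python
# Plain-Python model (not torch.Generator): sharing one generator between
# DTensor and non-DTensor ops is fine only if the non-DTensor ops advance it
# equally on every rank. Generator below is a hypothetical stand-in.
import hashlib

def counter_rand(seed, counter):
    digest = hashlib.sha256(f"{seed}:{counter}".encode()).digest()
    return int.from_bytes(digest[:8], "little") / 2**64

class Generator:
    def __init__(self, seed):
        self.seed, self.offset = seed, 0

    def draw(self, numel):
        out = [counter_rand(self.seed, self.offset + i) for i in range(numel)]
        self.offset += numel
        return out

def run_rank(local_draws):
    gen = Generator(seed=42)  # every rank is seeded identically
    gen.draw(local_draws)     # a non-DTensor op sharing the generator
    return gen.draw(4)        # stands in for a replicated DTensor random op

print(run_rank(3) == run_rank(3))  # True: equal advance, replicas agree
print(run_rank(3) == run_rank(5))  # False: replicas silently diverge
```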
When using DTensor together with Pipeline Parallelism, ranks for each pipeline stage should use a distinct seed, and ranks within a pipeline stage should use the same seed.
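One way to satisfy the pipeline-parallel seeding rule is to derive the seed from the stage index. The helper below is hypothetical, not a PyTorch API. It assumes ranks are numbered stage-major, and the base seed is arbitrary.

```python
# Hypothetical helper (not a PyTorch API) for choosing seeds under pipeline
# parallelism. Assumes ranks are numbered stage-major:
# rank = stage * ranks_per_stage + index_within_stage.

def pp_seed(rank, ranks_per_stage, base_seed=1234):
    stage = rank // ranks_per_stage
    return base_seed + stage  # same within a stage, distinct across stages

print([pp_seed(r, ranks_per_stage=4) for r in range(8)])
# [1234, 1234, 1234, 1234, 1235, 1235, 1235, 1235]
```

Each rank would then call `torch.manual_seed(...)` with its derived seed before its first DTensor random operation.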
DTensor's RNG infrastructure is based on the Philox RNG algorithm and supports any Philox-based backend (CUDA and other CUDA-like devices), but unfortunately it does not yet support the CPU backend.
## Debugging
.. automodule:: torch.distributed.tensor.debug
.. currentmodule:: torch.distributed.tensor.debug
### Logging
When launching the program, you can turn on additional logging using the `TORCH_LOGS` environment variable from
`torch._logging`:

- `TORCH_LOGS=+dtensor` will display `logging.DEBUG` messages and all levels above it.
- `TORCH_LOGS=dtensor` will display `logging.INFO` messages and above.
- `TORCH_LOGS=-dtensor` will display `logging.WARNING` messages and above.
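The prefix rule above can be encoded as a small lookup. The helper below is hypothetical and not part of `torch._logging`. It assumes comma-separated entries, as in `TORCH_LOGS="+dtensor,dynamo"`.

```python
# Hypothetical helper (not part of torch._logging) encoding the TORCH_LOGS
# rule above for the "dtensor" component. Assumes comma-separated entries.
import logging

LEVELS = {"+dtensor": logging.DEBUG, "dtensor": logging.INFO, "-dtensor": logging.WARNING}

def dtensor_log_level(torch_logs):
    for entry in torch_logs.split(","):
        level = LEVELS.get(entry.strip())
        if level is not None:
            return level  # minimum level shown for DTensor messages
    return None  # the setting does not mention dtensor
```

In practice the variable is set on the launch command, for example `TORCH_LOGS=+dtensor torchrun --nproc_per_node=2 train.py`, where `train.py` is a placeholder script name.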
### Debugging Tools
To debug programs that use DTensor, and to understand more details about which collectives happen under the
hood, DTensor provides a {class}`CommDebugMode`:
.. autoclass:: CommDebugMode
:members:
:undoc-members:
To visualize the sharding of a DTensor that has fewer than 3 dimensions, DTensor provides {meth}`visualize_sharding`:
.. autofunction:: visualize_sharding
## Experimental Features
`DTensor` also provides a set of experimental features. These features are either in the prototyping stage, or the basic
functionality is done but we are looking for user feedback. Please submit an issue to PyTorch if you have feedback on
these features.
.. automodule:: torch.distributed.tensor.experimental
.. currentmodule:: torch.distributed.tensor.experimental
.. autofunction:: context_parallel
.. autofunction:: local_map
.. autofunction:: register_sharding
% modules that are missing docs, add the doc later when necessary
.. py:module:: torch.distributed.tensor.device_mesh