This is a BC breaking change to distribute_module. The underlying rationle
for this change is that sometimes in the input_fn/output_fn, user would want
to access to the current module for some attributes. This might not be
common enough, but in some cases it's worth to access to the module.
An outstanding use case we want to support is float8, if we want to make
float8 works with the TP API, the input_fn/output_fn of TP parallel
styles would need to get access to the module, where the module might
encapsulates `dynamic_linear.emulate` attribute, that is useful for
input/output casting
Since this is needed for fp8 and DTensor still under prototype release,
I feel it's worth the change and it's better we make the change as
early.
Right now making it a soft BC breaking, which means we maintain BC still
but throw deprecation messages.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120895
Approved by: https://github.com/tianyu-l
The existing warning in `DTensor.__new__()` checks `if requires_grad != local_tensor.requires_grad:` and warns with:
> To construct DTensor from `torch.Tensor`, it's recommended to use `local_tensor.detach()` and make `requires_grad` consistent.
Calling `local_tensor.detach()` will have the returned `Tensor` have `requires_grad=False`, so the error message refers to the case where `local_tensor.requires_grad is True` but the user passed `requires_grad=False` to `to_local()`.
However, there is the converse case, where `local_tensor.requires_grad is False` but the user passed `requires_grad=True`. In this case, the original `if requires_grad != local_tensor.requires_grad:` check succeeds, and the warning is emitted. However, the warning message does not apply in that case.
This can happen via `_prepare_output_fn` -> `redistribute` -> `Redistribute.forward()`, where `output.requires_grad is False` but it passes `requires_grad=input.requires_grad` which can be `True`.
We should not warn in this case since `Redistribute.forward()` is our own framework code, so I was proposing to relax the warning.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118186
Approved by: https://github.com/XilunWu, https://github.com/wanchaol
ghstack dependencies: #117994
This PR:
- refactors the redistribute implementation logic to make it more
sound, by figuring out the transform informations first and then apply
transformation step by step, we also cache the decisions so that it
could be reuse again
- for uneven sharding, refactor uneven sharding logic, and use a logical
shape concept for each transform information to fix the uneven sharding
multi-mesh redistribute bug
fixes https://github.com/pytorch/pytorch/issues/115310
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115525
Approved by: https://github.com/XilunWu
Summary:
This change makes the `DTensor.from_local()` placements in backward pass from `Partial()` to `Replicate()` as pass through for following reasons:
1. When we run backward pass of DTensor.from_local, if the target placement is partial() (i.e. from user manual overwrite code instead of torch_dispatch) we keep the grad as replicate. This is because converting the gradients back to `Partial()` is meaningless.
2. The current div logic will lead to wrong numerical value in the above case.
Test Plan:
**CI**:
CI Tests
**Unit test**:
`buck2 test mode/dev-nosan //caffe2/test/distributed/_tensor:redistribute`
- Passed
**With model training**:
```
# We tested the case where input tensor is manually overwrite as Partial() and
# output tensor manually overwrite to Shard() then to local.
# Before the change: numerical value not correct
Forward pass:
collective: ReduceScatter
backward pass:
collective: AllGather + div by process group size
# After the change: div is removed as expected.
Forward pass:
collective: ReduceScatter
Backward pas:
collective: AllGather
```
Differential Revision: D52175709
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115967
Approved by: https://github.com/wanchaol
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, add documentation.
We created stubs for public class and methods in torch.distributed.device_mesh so that torch.distributed.device_mesh can be imported with or without distributed is available().
Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/115099
Prior to landing, CI signals are all passed. Shipit added the "ci/trunk" label to the PR and DID NOT wait for it and went ahead committing. More context can be found in the reverted PR above.
Test Plan: CI.
Differential Revision: D51861018
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115193
Approved by: https://github.com/fegin
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).
Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
* Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
* Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
* Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
* Signatures now:
```python
# attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
# ctx is anything useful for rebuilding the class we want to guard on
attrs, ctx = x.__tensor_flatten__()
...
# inner_tensors is a dict of {attr -> tensor}
# ctx is taken unmodified from flattening and (eventually) guarded on
# outer_size is the expected size of the output; possibly symbolic
# outer_stride is the expected strides of the output; possibly symbolic
y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
# at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
# the assert simplifies symbols when there are relationships between outer and inner symbols
```
* Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
* Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
* Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, adds documentation.
Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/114991
It was failing because failing a public module binding tests in MacOS, and this is due to the change in import order for torch/distributed/fsdp/_common_utils.py. Since this original import would still work, we remove the changes in this file.
Test Plan: CI.
Differential Revision: D51825114
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115099
Approved by: https://github.com/wanchaol, https://github.com/fegin
torch.equal/is_same_size currently skips sharding prop and directly do
local tensor compute, this is wrong. for these two ops:
- torch.equal: should not skip sharding prop, need to have two DTensor
have the SAME sharding before compare local shard values
- torch.is_same_size: need to completely skip both sharding prop and
local compute
This PR refactors the existing op_dispatch to make it a class instance
so that we can do custom op handling, then fixes both torch.equal and
torch.is_same_size
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112927
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
full_tensor API should return synchronously instead of
AsyncCollectiveTensor and if the return is that, we do the wait
directly, this makes the full_tensor API be more percise
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113322
Approved by: https://github.com/wz337
Fixes#110762
This PR:
fixes issue described in #110762 by adding kwarg for shape and stride when creating DTensor using `DTensor.from_local()`. When `shape` and `stride` are provided, we skip calcualtion for `tensor_shape` and `tensor_stride` using `compute_global_tensor_info()`, as `compute_global_tensor_info()` always assume even sharding.
Test plan:
```
python3 test/distributed/_tensor/test_dtensor.py -k test_from_local_uneven_sharding
python3 test/distributed/_tensor/test_dtensor.py -k test_from_local_uneven_sharding_raise_error
```
cc. @wanchaol
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110781
Approved by: https://github.com/wanchaol
This PR introduces a `full_tensor` API to DTensor, there were so many
callsites that exercises the `redistribute(replicate)` path and I feel
it deserves a separate API, mostly just a syntactic sugar
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112224
Approved by: https://github.com/wz337
TP style still have some regression due to negative dim specifications,
fix it by allow DTensor API to handle negative dims and normalize them.
i.e. TP uses `Shard(-1)`, and then try to redistribute `Shard(1) -> Shard(-1)`, this should ideally be no-op but current it runs a decompose sharding phrase and it would turn this transformation to `Shard(1) -> Replicate -> Shard(-1)`, which is wrong and triggers unnecessary allgathers
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111750
Approved by: https://github.com/rohan-varma
We are refactoring parallel style to solve the following things:
1. To further simplifying code logic to make more readable for users.
2. To remove tuple check so that we can work with dynamo for now. Ideally dynamo needs to support this case and we will fix it in parallel.
3. Add tests for newly added parallel style in UT and torch compile test so that we can capture regression due to code change.
4. Move placements early return check into DTensor since it is by passed by dynamo.
5. Remove PairwiseParallelStyle from unit tests to use the new Col/Rowwise parallel style.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111625
Approved by: https://github.com/wanchaol
This PR updates DTensor to support torch.compile
Cool stuff: there are some new tests in `test_dtensor.py` that show both the forward and backward graphs that we can send to inductor, when running a matmul with DTensor's. In particular, for this user code:
```
def fn(x, y):
dt = DTensor.from_local(x.reshape(2, 4), mesh, [Shard(0)], run_check=False)
dt2 = DTensor.from_local(y.reshape(4, 2), mesh, [Shard(1)], run_check=False)
dt_out = torch.matmul(dt, dt2)
dt_out_redistribute = dt_out.redistribute(mesh, [Replicate()])
return dt_out.to_local()
```
We generate the following fw and backward graphs.
Forward graph:
```
def forward(self, primals_1, primals_2):
view = torch.ops.aten.view.default(primals_1, [2, 4]); primals_1 = None
_to_copy = torch.ops.aten._to_copy.default(view, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)); view = None
detach = torch.ops.aten.detach.default(_to_copy); _to_copy = None
detach_1 = torch.ops.aten.detach.default(detach); detach = None
view_1 = torch.ops.aten.view.default(primals_2, [4, 2]); primals_2 = None
_to_copy_1 = torch.ops.aten._to_copy.default(view_1, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)); view_1 = None
detach_2 = torch.ops.aten.detach.default(_to_copy_1); _to_copy_1 = None
detach_3 = torch.ops.aten.detach.default(detach_2); detach_2 = None
detach_4 = torch.ops.aten.detach.default(detach_1)
all_gather_into_tensor = torch.ops.c10d_functional.all_gather_into_tensor.default(detach_3, 'ptd:0', [0, 1], 2)
wait_tensor = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
split = torch.ops.aten.split.Tensor(wait_tensor, 4); wait_tensor = None
getitem = split[0]
getitem_1 = split[1]; split = None
cat = torch.ops.aten.cat.default([getitem, getitem_1], 1); getitem = getitem_1 = None
detach_5 = torch.ops.aten.detach.default(cat); cat = None
mm = torch.ops.aten.mm.default(detach_4, detach_5); detach_4 = detach_5 = None
detach_6 = torch.ops.aten.detach.default(mm); mm = None
detach_9 = torch.ops.aten.detach.default(detach_6); detach_6 = None
detach_10 = torch.ops.aten.detach.default(detach_9); detach_9 = None
t = torch.ops.aten.t.default(detach_1); detach_1 = None
detach_13 = torch.ops.aten.detach.default(t); t = None
t_1 = torch.ops.aten.t.default(detach_3); detach_3 = None
detach_15 = torch.ops.aten.detach.default(t_1); t_1 = None
clone = torch.ops.aten.clone.default(detach_15, memory_format = torch.contiguous_format); detach_15 = None
return [detach_10, detach_13, clone]
```
Backward graph:
```
def forward(self, detach_13, clone, tangents_1):
detach_11 = torch.ops.aten.detach.default(tangents_1); tangents_1 = None
detach_12 = torch.ops.aten.detach.default(detach_11); detach_11 = None
mm_1 = torch.ops.aten.mm.default(detach_13, detach_12); detach_13 = None
detach_14 = torch.ops.aten.detach.default(mm_1); mm_1 = None
detach_16 = torch.ops.aten.detach.default(detach_12); detach_12 = None
all_gather_into_tensor_2 = torch.ops.c10d_functional.all_gather_into_tensor.default(clone, 'ptd:0', [0, 1], 2); clone = None
wait_tensor_2 = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor_2);
detach_17 = torch.ops.aten.detach.default(wait_tensor_2); wait_tensor_2 = None
mm_2 = torch.ops.aten.mm.default(detach_16, detach_17); detach_16 = detach_17 = None
detach_18 = torch.ops.aten.detach.default(mm_2); mm_2 = None
split_1 = torch.ops.aten.split.Tensor(detach_14, 2, 1); detach_14 = None
getitem_2 = split_1[0]
getitem_3 = split_1[1]; split_1 = None
cat_1 = torch.ops.aten.cat.default([getitem_2, getitem_3]); getitem_2 = getitem_3 = None
reduce_scatter_tensor = torch.ops.c10d_functional.reduce_scatter_tensor.default(cat_1, 'SUM', 'ptd:0', [0, 1], 2); cat_1 = None
wait_tensor_3 = torch.ops.c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None
detach_19 = torch.ops.aten.detach.default(wait_tensor_3); wait_tensor_3 = None
detach_20 = torch.ops.aten.detach.default(detach_19); detach_19 = None
detach_21 = torch.ops.aten.detach.default(detach_20); detach_20 = None
detach_22 = torch.ops.aten.detach.default(detach_21); detach_21 = None
_to_copy_2 = torch.ops.aten._to_copy.default(detach_22, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); detach_22 = None
view_2 = torch.ops.aten.view.default(_to_copy_2, [8]); _to_copy_2 = None
detach_23 = torch.ops.aten.detach.default(detach_18); detach_18 = None
detach_24 = torch.ops.aten.detach.default(detach_23); detach_23 = None
_to_copy_3 = torch.ops.aten._to_copy.default(detach_24, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); detach_24 = None
view_3 = torch.ops.aten.view.default(_to_copy_3, [8]); _to_copy_3 = None
return [view_3, view_2]
```
Some of the stuff in this graph looks kinda of silly though (e.g. an unnecessary split() + cat(), and all the extra detach() calls).
Stuff that's broken:
- functionalization is pretty horribly broken. In particular, the original strategy I used in this stack was to have functionalization run **above** subclass desugaring. But that doesn't play well with with the way we want to compile DTensor. DTensor has a few API's like `.redistribute()`, `.to_local()`, and the `DTensor()` constructor, that we want to put directly into the graph so that we can compile them (e.g. redistribute() will desugar into collective ops). Doing this requires functionalization to run **underneath** the subclass though. I hacked around this for now, by forcing these functions to run functionalization first if they need to.
- the backward test that I have is... wrong. The backward graph that we trace out looks kind of reasonable, but it gives incorrect gradients on one of the two inputs. This needs further debugging (presumably we should be able to stare at the graph and identify which part of it is wrong?).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105236
Approved by: https://github.com/wanchaol
skip tensor.to in from_local and distribute_tensor when device_type of
device mesh matches tensor.device type, since from_local on the critial
path of TP, this might also reduce some overhead
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110774
Approved by: https://github.com/fduwjj
When we convert to local tensor, dtensor can't track autograd or
gradient layout of the local tensor anymore, if user do sth not expected, there
needs to be a way for user to hint about the gradient layout of the
local tensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110629
Approved by: https://github.com/zdevito
resolves https://github.com/pytorch/pytorch/issues/109101
The problem is essentially because we were hashing all the arguments, including
the scalar too (i.e. aten.div(tensor, scalar)), in the optimizer, the scalar might
change everytime we call the op, thus cache miss everytime we call the op
This PR improves the sharding cache behavior by introducing a
RuntimeSchemaInfo, used to record some runtime necessary hashing
information during op registration time. This enable us to:
* only hash arguments that are tensor or have static_argnum, this is to
enable many cases like aten.div.Tensor(tensor, 0.23231) hit the cache.
as we currently hashing all args which exclude those cases
* with the correct cache behavior, optimizers will hit the cache again
and resolve the high cpu overhead issue.
simple MLP shows all cache hit and for a single addmm -> 0.319ms (from 0.341ms), shows some hashing improvements:
<img width="1172" alt="Screenshot 2023-09-14 at 11 06 07 AM" src="https://github.com/pytorch/pytorch/assets/9443650/3406d673-dd8d-4ad9-9b80-9d4721c430e3">
Adam optimizer shows aten.div hit sharding cache again
<img width="1016" alt="Screenshot 2023-09-14 at 11 02 10 AM" src="https://github.com/pytorch/pytorch/assets/9443650/4280e8e3-af44-4fc2-8360-ea80b768f1d9">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109306
Approved by: https://github.com/fduwjj
This PR switches the usage of fx's shape prop TensorMetadata to
dtensor's own dedicated defined TensorMeta, this is because DTensor
only cares three fields: shape/stride/dtype, all other fields are not
necessary and can be inferred from local_tensor directly. This would
help significantly simplify how we deal with the tensor metadata by not
caring other fields.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108261
Approved by: https://github.com/fduwjj
ghstack dependencies: #107306
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:
(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests
(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass, that represents a source coming from a tensor field on a wrapper subclass, which I use in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.
(3) `_make_wrapper_subclass()` marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second.. does not support kwargs). I think that later we probably want to consolidate / at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
This PR fixes the requires_grad set when calling distribute_tensor, we
should set the requires_grad of the local tensor after the detach call
to make sure we create the leaf correctly, otherwise it would raise
warnings
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107606
Approved by: https://github.com/fduwjj
This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly, this makes it not only helps TP, but general DTensor
operations could be faster!
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.
This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms

after (with this change), aten.addmm latency: 0.341ms

overall one layer of mlp time reduced from 13.535 -> 9.665ms
Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactor needed to make things more clean, which will be done in later PRs)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
Move the remaining collectives to a separate file to prepare device mesh
to become a public distributed API
For those remaining utils, we need to upstream them to functional
collectives with proper implementation, added TODO there for a follow up
PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107012
Approved by: https://github.com/fduwjj
This PR adds initial dynamo support for DTensor, in particular, it:
- allows DTensor be passed into a compiled function, and allow fakify
DTensor during dynamo tracing by turning the inner local tensor to meta
tensor.
- We use `allow_in_graph` to include `DTensor` and `DTensor.from_local` to be represented as `TorchVariable`
- The dtensor created becomes a normal `TensorVariable` and it would insert any tensor operations to the output graph just like torch.Tensor
- note that dtensor have a new instance method `redistribute` compare to plain tensor, and we currently special handle it in `TensorVariable`
`from_local` and `redistribute` both accepts some non-trival metadata as arguments (i.e. DeviceMesh, Placement) which fx.Graph does not support. In order to let these two APIs appear in the dynamo captured graph, we encoded the metadata into a new_function (like `functools.partial`) and the new function only accepts prim args (i.e. tensor), then we put `call_function` with this new_function to the graph. This is suggested by @ezyang. The underlying rationale here is that the metadata will not change across the graph invocations so it's safe to encode them.
Captured graph:
```
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
# File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False); l_x_ = None
# File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local); prim_from_local = None
to_local = prim_redistribute.to_local(); prim_redistribute = None
add = to_local + 2; to_local = None
return (add,)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym
This PR adds necessary plumbing through torchdynamo to allow tensor
subclasses with certain contract (i.e. with `__tensor_flatten__` and
`__tensor_unflatten__`) to goes through the dynamo fakification pass by
fakifying the tensor subclass internal components.
Some of the tensor subclass contract logic mostly borrowed from
https://github.com/pytorch/pytorch/pull/97540
Added some tests to verify simply passing through a tensor subclass
(i.e. DTensor) through dynamo eager works as expected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105308
Approved by: https://github.com/ezyang
This PR canonicalize the detach callsite to only call the detach
from `distribute_tensor`. Change other callsite to view_as and remove the
tensor constructor detach call
This is so that we don't detach local tensor for every op run when
rewrapping the DTensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105239
Approved by: https://github.com/albanD
# Change
This PR adds two classes to DTensor:
1. `CudaRNGStateTracker`: `CudaRNGStateTracker` stores Random Number Generator (RNG) state (a `ByteTensor` object) in a `dict`, mapping from a corresponding tag to each state tensor. It also provides a set of convenient utility methods to help access/modify the state tensors. The most important interface is `_distribute_region` which will be used when DTensor executes a random op (an operator that calls RNG).
2. `OffsetBasedRNGTracker`: This subclass of `CudaRNGStateTracker` defines the default policy of how RNG states should be shared and synchronized among all ranks to respect the semantics of DTensor random operators.
# Warning
- With `Multi-threaded ProcessGroup`, the global variable `_rng_tracker` will be shared among threads(ranks) and cause issue. We need to figure out a compatible solution for that.
- The RNG state may be asynchronous outside of participating ranks. It is harmless in our current use case of submesh though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103235
Approved by: https://github.com/wanchaol
This PR changes the context manager behavior of device mesh, now we use
a mesh env to track the current mesh and save the mesh to a stack so
that we can allow nested context manager
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101202
Approved by: https://github.com/wz337