mirror of https://github.com/pytorch/pytorch.git
[dtensor] to_local backward grad placement passthrough (#121474)
`to_local` accepts a `grad_placements` argument if the user chooses to pass one. Previously we enforced, for safety, that grad_out carry the "same" placements as the current DTensor. But I realized that we do NOT need to enforce this constraint. Why? The backward placement does not need to match the forward tensor placement; this is already the case for param vs. param.grad (i.e. a param can be replicate while its grad is partial), so we should not restrict activation vs. activation grad either.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121474
Approved by: https://github.com/awgu, https://github.com/yoyoyocmu, https://github.com/yifuwang
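A minimal sketch (not part of the PR) of what this passthrough allows: the forward output is Replicate(), but the caller declares the incoming gradient as _Partial(), and DTensor reduces it back onto the sharded leaf. It assumes a process group launched via torchrun and uses the private torch.distributed._tensor API of this PyTorch version; the mesh setup, device choice, and tensor shape are illustrative.

import os

import torch
from torch.distributed._tensor import distribute_tensor, init_device_mesh
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard

# Run under torchrun so WORLD_SIZE/RANK are set; build a 1-D mesh over all ranks.
world_size = int(os.environ["WORLD_SIZE"])
device_type = "cuda" if torch.cuda.is_available() else "cpu"
mesh = init_device_mesh(device_type, (world_size,))

# Forward: a Shard(0) leaf, redistributed to Replicate() and dropped to a plain local tensor.
global_tensor = torch.ones(8, 3, requires_grad=True)
sharded = distribute_tensor(global_tensor, mesh, [Shard(0)])
local_out = sharded.redistribute(placements=[Replicate()]).to_local(
    # The gradient flowing back is declared _Partial() even though the forward
    # output placement is Replicate() -- the mismatch this PR now passes through.
    grad_placements=[_Partial()]
)

# Backward: the local grad is wrapped as a _Partial() DTensor and reduced back
# onto the Shard(0) grad of `sharded` (a reduce_scatter under the hood).
local_out.sum().backward()
print(sharded.grad.placements)

Before this change, the backward would enforce that the incoming grad carry the same placements as the forward DTensor; with it, the user-specified grad_placements is passed through and the partial gradient is simply reduce-scattered back to the sharded grad, which is what the updated test below checks with CommDebugMode.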
committed by PyTorch MergeBot
parent 9373ad0bb8
commit bc02fca358
@@ -12,6 +12,7 @@ from torch.distributed._tensor import (
     DTensor,
     init_device_mesh,
 )
+from torch.distributed._tensor.debug import CommDebugMode
 from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -30,6 +31,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 )


+c10d_functional = torch.ops.c10d_functional
+
+
 class DummyMLP(torch.nn.Module):
     def __init__(self, device):
         super().__init__()
@@ -335,10 +339,20 @@ class DTensorTest(DTensorTestBase):
         global_tensor = torch.ones(8, 3, requires_grad=True)

         sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
-        local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
-            grad_placements=[_Partial()]
-        )
-        local_out.sum().backward()
+        comm_mode = CommDebugMode()
+
+        with comm_mode:
+            local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
+                grad_placements=[_Partial()]
+            )
+            local_out.backward(torch.ones_like(local_out))
+
+        self.assertEqual(
+            comm_mode.comm_counts[c10d_functional.all_gather_into_tensor], 1
+        )
+        self.assertEqual(
+            comm_mode.comm_counts[c10d_functional.reduce_scatter_tensor], 1
+        )

         replica_grad = sharded_dtensor.grad.full_tensor()
         self.assertEqual(replica_grad, global_tensor * self.world_size)
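Reading the updated test: the single all_gather_into_tensor counted by CommDebugMode comes from redistributing `sharded_dtensor` to Replicate() in the forward pass, and the single reduce_scatter_tensor comes from reducing the _Partial() gradient back onto the sharded `placements` of `sharded_dtensor` in the backward pass, which is exactly the grad placement passthrough this PR enables.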