[dtensor] to_local backward grad placement passthrough (#121474)

`to_local` accepts a `grad_placements` argument if the user chooses to pass one. Previously
we enforced that the gradient output had the "same" placements as the current
DTensor, for safety.

But we do not actually need to enforce this constraint. Why?
The backward placement does not need to match the forward tensor's placement; this
is already the case for param vs. param.grad (i.e. a param can be Replicate
while its grad is Partial), so we should not impose this restriction on activations
vs. activation grads either.
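
For illustration, a minimal usage sketch of what this enables, mirroring the test added
below. The 2-process world size, the "cuda" device type, and launching via torchrun are
assumptions; the placements, shapes, and API calls follow the test in this PR.

# Minimal sketch, assuming 2 processes launched with torchrun and a CUDA backend;
# placements and shapes mirror the test added in this PR.
import torch
from torch.distributed._tensor import distribute_tensor, init_device_mesh
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard

mesh = init_device_mesh("cuda", (2,))

# Parameter-like DTensor, sharded on dim 0.
global_tensor = torch.ones(8, 3, requires_grad=True)
sharded = distribute_tensor(global_tensor, mesh, [Shard(0)])

# Forward: all-gather to Replicate, then view as a plain local tensor.
# grad_placements hints that the incoming local gradient is _Partial(), a
# placement different from the forward Replicate() placement -- which this
# change now simply passes through instead of redistributing eagerly.
local_out = sharded.redistribute(placements=[Replicate()]).to_local(
    grad_placements=[_Partial()]
)
local_out.sum().backward()

# The partial gradient is reduced back to the parameter's Shard(0) placement.
print(sharded.grad.placements)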

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121474
Approved by: https://github.com/awgu, https://github.com/yoyoyocmu, https://github.com/yifuwang
Author: Wanchao Liang
Date: 2024-03-07 20:13:38 -08:00
Committed by: PyTorch MergeBot
Commit: bc02fca358 (parent: 9373ad0bb8)
2 changed files with 19 additions and 17 deletions


@@ -12,6 +12,7 @@ from torch.distributed._tensor import (
     DTensor,
     init_device_mesh,
 )
+from torch.distributed._tensor.debug import CommDebugMode
 from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -30,6 +31,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 )
 
 
+c10d_functional = torch.ops.c10d_functional
+
+
 class DummyMLP(torch.nn.Module):
     def __init__(self, device):
         super().__init__()
@@ -335,10 +339,20 @@ class DTensorTest(DTensorTestBase):
         global_tensor = torch.ones(8, 3, requires_grad=True)
         sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
-        local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
-            grad_placements=[_Partial()]
-        )
-        local_out.sum().backward()
+        comm_mode = CommDebugMode()
+        with comm_mode:
+            local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
+                grad_placements=[_Partial()]
+            )
+            local_out.backward(torch.ones_like(local_out))
+
+        self.assertEqual(
+            comm_mode.comm_counts[c10d_functional.all_gather_into_tensor], 1
+        )
+        self.assertEqual(
+            comm_mode.comm_counts[c10d_functional.reduce_scatter_tensor], 1
+        )
         replica_grad = sharded_dtensor.grad.full_tensor()
         self.assertEqual(replica_grad, global_tensor * self.world_size)
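
Note on the comm counts asserted above: the forward `redistribute` from `Shard(0)` to
`Replicate()` maps to one all-gather, while the `_Partial()` gradient hint means the
backward pass reduce-scatters the gradient back to the parameter's `Shard(0)` placement,
hence exactly one `all_gather_into_tensor` and one `reduce_scatter_tensor`.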


@@ -89,25 +89,13 @@ class _ToTorchTensor(torch.autograd.Function):
             grad_output, mesh, dtensor_spec.placements
         )
         tensor_stride = tuple(tensor_stride)
-        if grad_placements is not None:
-            grad_spec = DTensorSpec(
-                mesh,
-                grad_placements,
-                tensor_meta=TensorMeta(
-                    shape=dtensor_meta.shape,
-                    stride=tensor_stride,
-                    dtype=dtensor_meta.dtype,
-                ),
-            )
-            grad_output = redistribute_local_tensor(
-                grad_output, grad_spec, dtensor_spec
-            )
+        grad_placements = grad_placements or dtensor_spec.placements
         return (
             DTensor(
                 grad_output,
                 mesh,
-                dtensor_spec.placements,
+                grad_placements,
                 shape=dtensor_meta.shape,
                 dtype=dtensor_meta.dtype,
                 requires_grad=grad_output.requires_grad,