Fix copy_ forward AD to handle broadcasting (#69592)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69592

Currently, the forward AD function for `copy_` (in `VariableTypeManual`) does not handle the broadcasting case. ~EDIT: but that is a design decision, not a bug. In this PR, we make that clear as a comment.~

Note: `broadcast_to` does not have a batching rule in core, so the ops that rely on `copy_` to broadcast will still fail batched forward grad computation.
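For illustration only (not part of this diff), a minimal sketch of the behavior this change enables, assuming a build that includes this fix; the tensor names and shapes are made up:

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.rand(3, 3)
tangent = torch.rand(3, 3)
dest = torch.rand(2, 3, 3)  # src is broadcastable to dest's shape

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    dest.copy_(dual)  # before this fix, the tangent was not broadcast along with the primal
    # The tangent is broadcast to dest's shape, mirroring the primal copy
    _, dest_tangent = fwAD.unpack_dual(dest)
    assert torch.allclose(dest_tangent, tangent.broadcast_to((2, 3, 3)))
```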

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33020603

Pulled By: soulitzer

fbshipit-source-id: 09cb702bffc74061964a9c05cfef5121f8164814
soulitzer
2021-12-12 00:07:50 -08:00
committed by Facebook GitHub Bot
parent db32daf4b2
commit baf92f9d5a
2 changed files with 16 additions and 1 deletion


@@ -8395,6 +8395,17 @@ class TestAutogradDeviceType(TestCase):
         z = x.to(torch.bfloat16)
         self.assertTrue(z.requires_grad)
 
+    def test_copy_forward_ad_broadcasting(self, device):
+        # copy_ allows the src to have a different shape from self as long as src is
+        # broadcastable to self. Make sure forward AD handles this case.
+        primal = torch.rand(3, 3, device=device)
+        tangent = torch.rand(3, 3, device=device)
+        non_dual = torch.rand(1, 3, 3, device=device)
+
+        with fwAD.dual_level():
+            dual = fwAD.make_dual(primal, tangent)
+            non_dual.copy_(dual)
+
     @onlyCUDA
     def test_simple_reentrant_cross_device(self, device):
         class ReentrantFunc(Function):


@@ -177,7 +177,11 @@ Tensor & copy_(c10::DispatchKeySet ks, Tensor & self, const Tensor & src, bool n
         new_fw_grad = self_fw_grad.fill_(0);
       }
     } else {
-      new_fw_grad = src_fw_grad;
+      if (!self.is_same_size(src_fw_grad)) {
+        new_fw_grad = src_fw_grad.broadcast_to(self.sizes());
+      } else {
+        new_fw_grad = src_fw_grad;
+      }
     }
     self._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ true);
   }