mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix copy_ forward AD to handle broadcasting (#69592)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69592 Currently, the forward AD function for `copy_` (in `VariableTypeManual`) does not handle the broadcasting case. ~EDIT: but that is a design decision, not a bug. In this PR, we make that clear as a comment.~ Note: `broadcast_to` does not have a batching rule in core, so the ops that rely on `copy_` to broadcast will still fail batched forward grad computation. Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D33020603 Pulled By: soulitzer fbshipit-source-id: 09cb702bffc74061964a9c05cfef5121f8164814
This commit is contained in:
committed by
Facebook GitHub Bot
parent
db32daf4b2
commit
baf92f9d5a
@ -8395,6 +8395,17 @@ class TestAutogradDeviceType(TestCase):
|
||||
z = x.to(torch.bfloat16)
|
||||
self.assertTrue(z.requires_grad)
|
||||
|
||||
def test_copy_forward_ad_broadcasting(self, device):
|
||||
# copy_ allows the src to have a different shape from self as long as src is
|
||||
# broadcastable to self. Make sure forward AD handles this case.
|
||||
primal = torch.rand(3, 3, device=device)
|
||||
tangent = torch.rand(3, 3, device=device)
|
||||
non_dual = torch.rand(1, 3, 3, device=device)
|
||||
|
||||
with fwAD.dual_level():
|
||||
dual = fwAD.make_dual(primal, tangent)
|
||||
non_dual.copy_(dual)
|
||||
|
||||
@onlyCUDA
|
||||
def test_simple_reentrant_cross_device(self, device):
|
||||
class ReentrantFunc(Function):
|
||||
|
@ -177,7 +177,11 @@ Tensor & copy_(c10::DispatchKeySet ks, Tensor & self, const Tensor & src, bool n
|
||||
new_fw_grad = self_fw_grad.fill_(0);
|
||||
}
|
||||
} else {
|
||||
new_fw_grad = src_fw_grad;
|
||||
if (!self.is_same_size(src_fw_grad)) {
|
||||
new_fw_grad = src_fw_grad.broadcast_to(self.sizes());
|
||||
} else {
|
||||
new_fw_grad = src_fw_grad;
|
||||
}
|
||||
}
|
||||
self._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ true);
|
||||
}
|
||||
|
Reference in New Issue
Block a user