Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[autograd] match 0-dim gradients device type regardless of subclassness (#160165)
Not sure if there are some subclasses where outer.dim() == 0 but you wouldn't want to move it?

Fixes https://github.com/pytorch/pytorch/issues/160084

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160165
Approved by: https://github.com/ezyang, https://github.com/albanD
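For context, a minimal sketch of the mixed-device scenario this change addresses, assuming a CUDA device is available (the added test below instead uses torch.accelerator and the device-type test framework):

    import torch

    # A 0-dim parameter stays on CPU while the input, and therefore the
    # incoming gradient, lives on the accelerator.
    a = torch.nn.Parameter(torch.tensor(1.0))   # CPU scalar parameter
    x = torch.randn(4, 10, device="cuda")       # accelerator input

    (x * a).sum().backward()

    # With this change the 0-dim gradient is matched to the parameter's
    # device, so a.grad lands on CPU alongside the parameter it belongs to.
    print(a.grad.device)  # cpu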
Committed by: PyTorch MergeBot
Parent: d25c4f954d
Commit: c8205cb354
@@ -12396,6 +12396,29 @@ class TestAutogradDeviceType(TestCase):
         x.resize_as_(y)
         self.assertEqual(x._version, 2)
 
+    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
+    def test_zero_dim_param_mixed_device_grad(self, device):
+        # cpu 0-dim params with an accelerator device grad
+        # https://github.com/pytorch/pytorch/issues/160084
+        class RegressionModel(torch.nn.Module):
+            def __init__(self, a=0, b=0):
+                super().__init__()
+                self.a = torch.nn.Parameter(torch.tensor(a).float())
+                self.b = torch.nn.Parameter(torch.tensor(b).float())
+
+            def forward(self, x):
+                return x * self.a + self.b
+
+        # Keep the model on cpu as we do want to test the mixed cpu/accelerator behavior here
+        model = RegressionModel()
+        inputs = torch.randn(4, 10, device=device)
+        out = model(inputs)
+        out.sum().backward()
+        self.assertIsNotNone(model.a.grad)
+        self.assertIsNotNone(model.b.grad)
+        self.assertEqual(model.a.grad.device, torch.device("cpu"))
+        self.assertEqual(model.b.grad.device, torch.device("cpu"))
+
 
 class TestAllowMutationOnSaved(TestCase):
     def assertClonedLenEqual(self, ctx, n):