[autograd] match 0-dim gradients device type regardless of subclassness (#160165)

Not sure if there are subclasses where the outer tensor's dim() == 0 but you wouldn't want to move the gradient?

FIXES https://github.com/pytorch/pytorch/issues/160084
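
A condensed repro of the scenario from the issue (it mirrors the tests added below and assumes a CUDA build):

    import torch

    class RegressionModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # 0-dim parameters stay on cpu
            self.a = torch.nn.Parameter(torch.tensor(0.0))
            self.b = torch.nn.Parameter(torch.tensor(0.0))

        def forward(self, x):
            return x * self.a + self.b

    model = RegressionModel()
    model.forward = torch.compile(model.forward, backend="aot_eager", fullgraph=True)

    out = model(torch.randn(4, 10, device="cuda"))
    out.sum().backward()

    # after this change, the 0-dim grads land on the params' (cpu) device
    print(model.a.grad.device, model.b.grad.device)  # cpu cpu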

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160165
Approved by: https://github.com/ezyang, https://github.com/albanD
Simon Fan
2025-08-09 12:02:47 -07:00
committed by PyTorch MergeBot
parent d25c4f954d
commit c8205cb354
4 changed files with 55 additions and 51 deletions

@@ -7673,6 +7673,31 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
         out2 = torch.compile(model, backend="eager")(input.clone())
         self.assertEqual(out1, out2)

+    @requires_cuda
+    def test_zero_dim_param_mixed_device_grad(self):
+        # cpu 0-dim params with cuda grads
+        # https://github.com/pytorch/pytorch/issues/160084
+        class RegressionModel(torch.nn.Module):
+            def __init__(self, a=0, b=0):
+                super().__init__()
+                self.a = torch.nn.Parameter(torch.tensor(a).float())
+                self.b = torch.nn.Parameter(torch.tensor(b).float())
+
+            def forward(self, x):
+                return x * self.a + self.b
+
+        model = RegressionModel()
+        model.forward = torch.compile(
+            model.forward, backend="aot_eager", fullgraph=True
+        )
+        inputs = torch.randn(4, 10).to("cuda")
+        out = model(inputs)
+        out.sum().backward()
+        self.assertIsNotNone(model.a.grad)
+        self.assertIsNotNone(model.b.grad)
+        self.assertEqual(model.a.grad.device, torch.device("cpu"))
+        self.assertEqual(model.b.grad.device, torch.device("cpu"))
+
     def test_filter_warnings(self):
         x = torch.ones(2, 2, requires_grad=True)

@@ -12396,6 +12396,29 @@ class TestAutogradDeviceType(TestCase):
         x.resize_as_(y)
         self.assertEqual(x._version, 2)

+    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
+    def test_zero_dim_param_mixed_device_grad(self, device):
+        # cpu 0-dim params with an accelerator device grad
+        # https://github.com/pytorch/pytorch/issues/160084
+        class RegressionModel(torch.nn.Module):
+            def __init__(self, a=0, b=0):
+                super().__init__()
+                self.a = torch.nn.Parameter(torch.tensor(a).float())
+                self.b = torch.nn.Parameter(torch.tensor(b).float())
+
+            def forward(self, x):
+                return x * self.a + self.b
+
+        # Keep the model on cpu as we do want to test the mixed cpu/accelerator behavior here
+        model = RegressionModel()
+        inputs = torch.randn(4, 10, device=device)
+        out = model(inputs)
+        out.sum().backward()
+        self.assertIsNotNone(model.a.grad)
+        self.assertIsNotNone(model.b.grad)
+        self.assertEqual(model.a.grad.device, torch.device("cpu"))
+        self.assertEqual(model.b.grad.device, torch.device("cpu"))
+
 class TestAllowMutationOnSaved(TestCase):
     def assertClonedLenEqual(self, ctx, n):

@@ -1,7 +1,6 @@
 # Owner(s): ["module: __torch_dispatch__"]
 # ruff: noqa: F841

-import logging
 import pickle
 import sys
 import tempfile
@@ -1718,49 +1717,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
         self.assertEqual(s.device_index, 2)
         self.assertEqual(s.device_type, 3)

-    def test_subclass_autograd_device_check(self) -> None:
-        class NonWrapperSubclass(torch.Tensor):
-            elem: torch.Tensor
-
-            __slots__ = ["elem"]
-
-            @staticmethod
-            def __new__(cls, elem, *args, **kwargs):
-                # Wrong device here!
-                r = torch.Tensor._make_subclass(
-                    cls, elem.to("meta"), elem.requires_grad
-                )
-                # ...the real tensor is held as an element on the tensor.
-                r.elem = elem
-                return r
-
-            @classmethod
-            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-                def unwrap(e):
-                    return e.elem if isinstance(e, NonWrapperSubclass) else e
-
-                def wrap(e):
-                    return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e
-
-                rs = tree_map(
-                    wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
-                )
-                logging.getLogger("NonWrapperSubclass").info(
-                    f"{func.__module__}.{func.__name__}",  # noqa: G004
-                    args,
-                    kwargs,
-                    rs,
-                )
-                return rs
-
-        x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
-        y = torch.randn(2, requires_grad=True)
-        z = x * y
-        self.assertIsInstance(z, NonWrapperSubclass)
-        z.sum().backward(torch.tensor(1))
-        self.assertEqual(x.grad, y)
-        self.assertEqual(y.grad, x)
-
     def test_none_wrapping(self):
         # A Tensor subclass that returns None when doing add
         # See LoggingTensor above for more details on the subclass

@@ -979,13 +979,13 @@ static void validate_outputs_impl(
     }
     if (grad.device() != metadata.device()) {
-      // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but
-      // should be eventually removed
-      if (!(metadata.is_tensor_subclass() ||
-            grad.unsafeGetTensorImpl()->is_python_dispatch())) {
-        if (grad.dim() == 0) {
-          grad = grad.to(metadata.device());
-        } else {
+      if (grad.dim() == 0) {
+        grad = grad.to(metadata.device());
+      } else {
+        // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but
+        // should be eventually removed
+        if (!(metadata.is_tensor_subclass() ||
+              grad.unsafeGetTensorImpl()->is_python_dispatch())) {
           std::stringstream ss;
           ss << "invalid gradient at index " << i << " - expected device ";
           ss << metadata.device() << " but got " << grad.device();
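
For readers skimming the C++ diff: the change swaps the nesting so that the 0-dim move happens before the subclass check, instead of only inside the non-subclass branch. A rough Python rendering of the new control flow (GradMetadata and reconcile_grad_device are illustrative stand-ins, not PyTorch API):

    import torch
    from dataclasses import dataclass

    @dataclass
    class GradMetadata:
        # stand-in for the fields of autograd's input metadata used in this check
        device: torch.device
        is_tensor_subclass: bool

    def reconcile_grad_device(grad: torch.Tensor, meta: GradMetadata) -> torch.Tensor:
        if grad.device != meta.device:
            if grad.dim() == 0:
                # new behavior: 0-dim grads are moved to the expected device,
                # regardless of whether a tensor subclass is involved
                return grad.to(meta.device)
            if not meta.is_tensor_subclass:
                raise RuntimeError(
                    f"invalid gradient - expected device {meta.device} "
                    f"but got {grad.device}"
                )
        return grad

    # e.g. a 0-dim CUDA grad destined for a CPU parameter is now moved rather
    # than rejected, even when the grad is a subclass / python-dispatch tensor:
    # reconcile_grad_device(torch.tensor(1.0, device="cuda"),
    #                       GradMetadata(torch.device("cpu"), is_tensor_subclass=True))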