mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[autograd] match 0-dim gradients device type regardless of subclassness (#160165)
Not sure if there are some subclasses where the outer.dim() == 0 but you wouldn't want to move it?

FIXES https://github.com/pytorch/pytorch/issues/160084

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160165
Approved by: https://github.com/ezyang, https://github.com/albanD
committed by PyTorch MergeBot
parent d25c4f954d
commit c8205cb354
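For context, a minimal sketch of the scenario this change covers, distilled from the tests added below (it assumes a CUDA device is available): 0-dim parameters that live on the CPU receive gradients computed on the accelerator, and after this change the autograd engine moves those 0-dim gradients back to the parameter's device even when the gradient is a tensor subclass or python-dispatch tensor, which the compiled (aot_eager) path can produce.

    import torch

    class RegressionModel(torch.nn.Module):
        def __init__(self, a=0, b=0):
            super().__init__()
            # 0-dim parameters that stay on the CPU
            self.a = torch.nn.Parameter(torch.tensor(a).float())
            self.b = torch.nn.Parameter(torch.tensor(b).float())

        def forward(self, x):
            return x * self.a + self.b

    model = RegressionModel()  # parameters stay on the CPU
    compiled = torch.compile(model.forward, backend="aot_eager", fullgraph=True)
    out = compiled(torch.randn(4, 10, device="cuda"))  # inputs live on the accelerator
    out.sum().backward()
    print(model.a.grad.device, model.b.grad.device)  # cpu cpu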
@@ -7673,6 +7673,31 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
         out2 = torch.compile(model, backend="eager")(input.clone())
         self.assertEqual(out1, out2)
 
+    @requires_cuda
+    def test_zero_dim_param_mixed_device_grad(self):
+        # cpu 0-dim params with cuda grads
+        # https://github.com/pytorch/pytorch/issues/160084
+        class RegressionModel(torch.nn.Module):
+            def __init__(self, a=0, b=0):
+                super().__init__()
+                self.a = torch.nn.Parameter(torch.tensor(a).float())
+                self.b = torch.nn.Parameter(torch.tensor(b).float())
+
+            def forward(self, x):
+                return x * self.a + self.b
+
+        model = RegressionModel()
+        model.forward = torch.compile(
+            model.forward, backend="aot_eager", fullgraph=True
+        )
+        inputs = torch.randn(4, 10).to("cuda")
+        out = model(inputs)
+        out.sum().backward()
+        self.assertIsNotNone(model.a.grad)
+        self.assertIsNotNone(model.b.grad)
+        self.assertEqual(model.a.grad.device, torch.device("cpu"))
+        self.assertEqual(model.b.grad.device, torch.device("cpu"))
+
     def test_filter_warnings(self):
         x = torch.ones(2, 2, requires_grad=True)
@@ -12396,6 +12396,29 @@ class TestAutogradDeviceType(TestCase):
         x.resize_as_(y)
         self.assertEqual(x._version, 2)
 
+    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
+    def test_zero_dim_param_mixed_device_grad(self, device):
+        # cpu 0-dim params with an accelerator device grad
+        # https://github.com/pytorch/pytorch/issues/160084
+        class RegressionModel(torch.nn.Module):
+            def __init__(self, a=0, b=0):
+                super().__init__()
+                self.a = torch.nn.Parameter(torch.tensor(a).float())
+                self.b = torch.nn.Parameter(torch.tensor(b).float())
+
+            def forward(self, x):
+                return x * self.a + self.b
+
+        # Keep the model on cpu as we do want to test the mixed cpu/accelerator behavior here
+        model = RegressionModel()
+        inputs = torch.randn(4, 10, device=device)
+        out = model(inputs)
+        out.sum().backward()
+        self.assertIsNotNone(model.a.grad)
+        self.assertIsNotNone(model.b.grad)
+        self.assertEqual(model.a.grad.device, torch.device("cpu"))
+        self.assertEqual(model.b.grad.device, torch.device("cpu"))
+
 
 class TestAllowMutationOnSaved(TestCase):
     def assertClonedLenEqual(self, ctx, n):
@@ -1,7 +1,6 @@
 # Owner(s): ["module: __torch_dispatch__"]
 # ruff: noqa: F841
 
-import logging
 import pickle
 import sys
 import tempfile
@@ -1718,49 +1717,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
         self.assertEqual(s.device_index, 2)
         self.assertEqual(s.device_type, 3)
 
-    def test_subclass_autograd_device_check(self) -> None:
-        class NonWrapperSubclass(torch.Tensor):
-            elem: torch.Tensor
-
-            __slots__ = ["elem"]
-
-            @staticmethod
-            def __new__(cls, elem, *args, **kwargs):
-                # Wrong device here!
-                r = torch.Tensor._make_subclass(
-                    cls, elem.to("meta"), elem.requires_grad
-                )
-                # ...the real tensor is held as an element on the tensor.
-                r.elem = elem
-                return r
-
-            @classmethod
-            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-                def unwrap(e):
-                    return e.elem if isinstance(e, NonWrapperSubclass) else e
-
-                def wrap(e):
-                    return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e
-
-                rs = tree_map(
-                    wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
-                )
-                logging.getLogger("NonWrapperSubclass").info(
-                    f"{func.__module__}.{func.__name__}",  # noqa: G004
-                    args,
-                    kwargs,
-                    rs,
-                )
-                return rs
-
-        x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
-        y = torch.randn(2, requires_grad=True)
-        z = x * y
-        self.assertIsInstance(z, NonWrapperSubclass)
-        z.sum().backward(torch.tensor(1))
-        self.assertEqual(x.grad, y)
-        self.assertEqual(y.grad, x)
-
     def test_none_wrapping(self):
         # A Tensor subclass that returns None when doing add
         # See LoggingTensor above for more details on the subclass
@@ -979,13 +979,13 @@ static void validate_outputs_impl(
     }
 
     if (grad.device() != metadata.device()) {
-      // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but
-      // should be eventually removed
-      if (!(metadata.is_tensor_subclass() ||
-            grad.unsafeGetTensorImpl()->is_python_dispatch())) {
-        if (grad.dim() == 0) {
-          grad = grad.to(metadata.device());
-        } else {
+      if (grad.dim() == 0) {
+        grad = grad.to(metadata.device());
+      } else {
+        // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but
+        // should be eventually removed
+        if (!(metadata.is_tensor_subclass() ||
+              grad.unsafeGetTensorImpl()->is_python_dispatch())) {
           std::stringstream ss;
           ss << "invalid gradient at index " << i << " - expected device ";
           ss << metadata.device() << " but got " << grad.device();
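Read as a whole, the reordering above means the 0-dim fast path now runs before the subclass/python-dispatch exemption: a 0-dim gradient on the wrong device is always moved, and only larger gradients can still hit the device-mismatch error. A hedged Python-level sketch of the new control flow, for illustration only; the function and parameter names below are hypothetical, not the real autograd API, and the actual check also exempts python-dispatch tensors from the error:

    import torch

    def validate_grad_device(grad, expected_device, is_tensor_subclass):
        # Hypothetical sketch mirroring the reordered check in validate_outputs_impl.
        if grad.device != expected_device:
            if grad.dim() == 0:
                # New behavior: 0-dim grads are moved regardless of subclass-ness.
                return grad.to(expected_device)
            # Non-0-dim grads keep the pre-existing subclass exemption from the error.
            if not is_tensor_subclass:
                raise RuntimeError(
                    f"invalid gradient - expected device {expected_device} "
                    f"but got {grad.device}"
                )
        return grad

    # Example: a 0-dim gradient on another device is moved back to the expected device.
    if torch.cuda.is_available():
        g = validate_grad_device(torch.ones((), device="cuda"), torch.device("cpu"), False)
        assert g.device.type == "cpu"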