# Owner(s): ["module: autograd"]

import contextlib
import warnings

import numpy as np

import torch
from torch.library import _scoped_library, Library
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


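# Context manager that temporarily switches the global autograd fallback mode
# ("nothing" or "warn") and restores the previous mode on exit.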
@contextlib.contextmanager
def autograd_fallback_mode(mode):
    prev = torch._C._get_autograd_fallback_mode()
    try:
        torch._C._set_autograd_fallback_mode(mode)
        yield
    finally:
        torch._C._set_autograd_fallback_mode(prev)


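# Tests for the autograd fallback kernel: the behavior of calling an operator
# that has no autograd kernel registered, under each fallback mode.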
class TestAutogradFallback(TestCase):
    test_ns = "_test_autograd_fallback"

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        if hasattr(self, "lib"):
            del self.lib.m
            del self.lib

    def get_op(self, name):
        return getattr(getattr(torch.ops, self.test_ns), name).default

    def get_lib(self):
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.lib = lib
        return lib

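    # Under torch.no_grad(), or with inputs that do not require grad, the op's
    # outputs should not require grad and no warning should be emitted.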
    @parametrize("mode", ("nothing", "warn"))
    def test_no_grad(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            lib.impl("foo", lambda a, b, c: a + b + c, "CPU")
            op = self.get_op("foo")

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                with torch.no_grad():
                    a = torch.randn([], requires_grad=True)
                    b = torch.randn([], requires_grad=True)
                    out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                a = torch.randn([])
                b = torch.randn([])
                out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

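    # The CPU kernel detaches its inputs (round-trips through numpy), so
    # backward should warn in "warn" mode, raise in "nothing" mode, and in
    # neither case produce a gradient for the input.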
    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            op = self.get_op("foo")

            def foo_impl(a, b, c):
                result = a.detach().numpy() + b.detach().numpy() + c
                return torch.tensor(result)

            lib.impl("foo", foo_impl, "CPU")

            # Some inputs requiring grad
            a = torch.randn([], requires_grad=False)
            b = torch.randn([], requires_grad=True)
            out = op(a, b, 1).sum()
            with self._check_ctx(mode, mode_nothing_raises=True):
                out.backward()
            self.assertIsNone(b.grad)

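    # Helper encoding the expected outcome per fallback mode: "warn" should
    # emit a UserWarning; "nothing" should either raise (when the output does
    # not require grad) or do nothing.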
    def _check_ctx(self, mode, *, mode_nothing_raises=False):
        if mode == "warn":
            return self.assertWarnsRegex(
                UserWarning, "an autograd kernel was not registered"
            )
        assert mode == "nothing"
        if mode_nothing_raises:
            return self.assertRaisesRegex(RuntimeError, "does not require grad")
        return contextlib.nullcontext()

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel_inplace(self, mode):
        with autograd_fallback_mode(mode):
            # input modified in-place gets returned as output
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))")
            op = self.get_op("foo")

            def foo_impl(x, y):
                with torch.no_grad():
                    x.sin_()
                    y.cos_()
                return x, y

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            w = x.clone()
            v = x.clone()
            y0 = w[0]
            y1 = v[1]
            z0, z1 = op(y0, y1)
            for tensor in [w, v, z0, z1, y0, y1]:
                with self._check_ctx(mode):
                    tensor.sum().backward(retain_graph=True)

            # no outputs: we don't do anything. Maybe we should in the future.
            # This is not a common failure mode.
            lib.define("bar(Tensor(a!) self) -> ()")
            op = self.get_op("bar")

            def bar_impl(x):
                with torch.no_grad():
                    x.sin_()

            lib.impl("bar", bar_impl, "CPU")
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                x = torch.randn([], requires_grad=True)
                y = x.clone()
                op(y)
                y.backward()
                self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_cpu_return_self(self, mode):
        with autograd_fallback_mode(mode):
            # To be clear, none of these situations are OK and will lead
            # to other problems down the line. We're testing them because
            # it is fairly common to actually do these things.
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

                lib.define("bar(Tensor(a!) self) -> Tensor(a!)")
                lib.impl("bar", lambda x: x, "CPU")
                op = self.get_op("bar")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

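    # A Python impl built from differentiable ops and registered only to CPU:
    # gradients should still flow through it.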
    @parametrize("mode", ("nothing", "warn"))
    def test_composite_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x.sin().sum(), "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x)
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

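    # An autograd.Function registered to the CPU key (not the Autograd key):
    # its backward should still be used and produce correct gradients.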
    @parametrize("mode", ("nothing", "warn"))
    def test_autograd_function_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")

                class NumpySin(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x)
                        return torch.tensor(np.sin(x.cpu().numpy()))

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_autograd_function_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor(a!) self) -> Tensor(a!)")

                class NumpySin_(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x.clone())
                        x_np = x.detach().numpy()
                        np.sin(x_np, out=x_np)
                        ctx.mark_dirty(x)
                        return x

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin_.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                z = x.clone()
                w = z[0]
                y = op(w)

                expected = torch.zeros_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(
                        y, x, torch.ones_like(y), retain_graph=True
                    )
                    self.assertEqual(gx, expected)

                expected = torch.ones_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(z, x, torch.ones_like(z))
                    self.assertEqual(gx, expected)

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_on_tensor_that_does_not_require_grad(self, mode):
        # We don't do anything special (that is, we don't rebase history).
        # See NOTE [autograd fallback and in-place operations] for why
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                # Correct usage of (a!)
                lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def foo_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x

                lib.impl("foo", foo_impl, "CPU")
                foo = self.get_op("foo")

                # Incorrect usage of (a!): user doesn't return tensor as-is
                lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def bar_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x_d.clone()

                lib.impl("bar", bar_impl, "CPU")
                bar = self.get_op("bar")

                # User mutated input tensor but didn't return it.
                lib.define("baz(Tensor(a!) self, Tensor other) -> ()")

                def baz_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)

                lib.impl("baz", baz_impl, "CPU")
                baz = self.get_op("baz")

                # Test in-place on non-view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x.clone()
                        op(z, y)
                        torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True)

                # Test in-place on view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x[:]
                        op(z, y)
                        torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True)

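    # One output is a fresh leaf that requires grad; backward through it should
    # succeed (with a warning in "warn" mode).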
    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_leaf(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            lib.impl(
                "foo", lambda a: (a.clone(), a.detach().clone().requires_grad_()), "CPU"
            )
            x = torch.randn(3, requires_grad=True)
            _, z = op(x)
            with self._check_ctx(mode):
                z.sum().backward()

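    # The fallback should tolerate undefined (None) tensors both as inputs to
    # and as outputs of the op.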
    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_inputs_outputs(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return None, b.clone()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            # NB: PyTorch dispatcher treats "None" as undefined Tensor.
            _, z = op(None, x)
            with self._check_ctx(mode):
                z.sum().backward()

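    # Backward where some output gradients are undefined (via UndefinedGrad)
    # should be handled by the fallback.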
    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_grads(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return a.sin(), b.cos()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            y = torch.randn(3)
            w, z = op(x, y)
            w = torch._C._functions.UndefinedGrad()(w)
            z = torch._C._functions.UndefinedGrad()(z)
            with self._check_ctx(mode):
                (z + w).sum().backward()

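    # In-place op on a view whose base does not require grad: the fallback's
    # hook should be registered on the view (w), not on w._base.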
    @parametrize("mode", ("nothing", "warn"))
    def test_base_does_not_require_grad(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
            op = self.get_op("foo")

            def foo_impl(a):
                with torch.no_grad():
                    return a.zero_()

            lib.impl("foo", foo_impl, "CPU")
            x = torch.randn(3)
            y = x[:]
            y.requires_grad_()
            w = y[:]
            self.assertTrue(w._base is x)

            # Hook should be registered on w, but not w._base
            op(w)
            with self._check_ctx(mode):
                w.sum().backward()

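    # The op returns a mix of tensors that do and do not require grad: grads
    # flow for the differentiable output, while the detached outputs warn or
    # raise depending on the mode.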
    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                with torch.no_grad():
                    x = a.clone()
                    z = b.clone()
                y = a * b
                return x, y, z

            lib.impl("foo", foo_impl, "CPU")
            a = torch.randn(3, requires_grad=True)
            b = torch.randn(3, requires_grad=True)
            x, y, z = op(a, b)

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    x, (a, b), torch.ones_like(x), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=False):
                torch.autograd.grad(
                    y, (a, b), torch.ones_like(y), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    z, (a, b), torch.ones_like(z), allow_unused=True, retain_graph=True
                )

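    # The fallback should also handle operators whose schema takes and returns
    # Tensor[] rather than plain Tensors.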
    @parametrize("mode", ("nothing", "warn"))
    def test_supports_tensor_lists(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor[] a) -> Tensor[]")
            op = self.get_op("foo")

            def foo_impl(a):
                x, y, z = a
                with torch.no_grad():
                    return x + y + z, x * y * z

            lib.impl("foo", foo_impl, "CPU")
            x = torch.randn(3, requires_grad=True)
            y = torch.randn(1, requires_grad=True)
            z = torch.randn(2, 1, requires_grad=True)
            a, b = op([x, y, z])
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    a,
                    (x, y, z),
                    torch.ones_like(a),
                    allow_unused=True,
                    retain_graph=True,
                )
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    b,
                    (x, y, z),
                    torch.ones_like(b),
                    allow_unused=True,
                    retain_graph=True,
                )


instantiate_parametrized_tests(TestAutogradFallback)

if __name__ == "__main__":
    run_tests()