mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	This reverts commit a02c573a8996d5d47585410ceaf81c87104cfd43. Reverted https://github.com/pytorch/pytorch/pull/103185 on behalf of https://github.com/izaitsevfb due to Breaks internal builds, see D46629734 ([comment](https://github.com/pytorch/pytorch/pull/103185#issuecomment-1588258206))
		
			
				
	
	
		
			11196 lines
		
	
	
		
			424 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			11196 lines
		
	
	
		
			424 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: autograd"]
 | |
| 
 | |
| import contextlib
 | |
| import gc
 | |
| import io
 | |
| import math
 | |
| import os
 | |
| import random
 | |
| import sys
 | |
| import tempfile
 | |
| import threading
 | |
| import time
 | |
| import unittest
 | |
| import uuid
 | |
| import warnings
 | |
| import operator
 | |
| import subprocess
 | |
| from copy import deepcopy
 | |
| from collections import OrderedDict
 | |
| from itertools import product
 | |
| from operator import mul
 | |
| from functools import reduce, partial
 | |
| import torch
 | |
| 
 | |
| from torch import nn
 | |
| from torch import inf, nan
 | |
| from torch.autograd.function import once_differentiable
 | |
| from torch.autograd.profiler import (profile, record_function, emit_nvtx, emit_itt)
 | |
| from torch.autograd.profiler_util import (_format_time, EventList, FunctionEvent, FunctionEventAvg)
 | |
| from torch.utils.checkpoint import checkpoint
 | |
| from torch.testing import make_tensor
 | |
| from torch.testing._internal.common_cuda import TEST_CUDA
 | |
| from torch.testing._internal.common_utils import (
 | |
|     TestCase, run_tests, skipIfNoLapack, slowTest, IS_WINDOWS, IS_MACOS,
 | |
|     disable_gc, gradcheck, gradgradcheck, parametrize,
 | |
|     instantiate_parametrized_tests, skipIfMps, set_warn_always_context)
 | |
| from torch.autograd import Variable, Function, detect_anomaly, kineto_available, _calculate_shape
 | |
| from torch.autograd.function import InplaceFunction
 | |
| import torch.autograd.forward_ad as fwAD
 | |
| from torch.testing._internal.common_methods_invocations import mask_not_all_zeros
 | |
| from torch.testing._internal.common_device_type import (instantiate_device_type_tests,
 | |
|                                                         onlyCPU, onlyCUDA, dtypes, dtypesIfCUDA,
 | |
|                                                         deviceCountAtLeast, skipMeta, dtypesIfMPS)
 | |
| from torch.testing._internal.common_dtype import floating_types_and
 | |
| from torch.utils._mode_utils import no_dispatch
 | |
| from torch.utils._python_dispatch import TorchDispatchMode
 | |
| import weakref
 | |
| import collections
 | |
| import pickle
 | |
| 
 | |
| 
 | |
def graph_desc(fn):
    """Render the autograd graph rooted at ``fn`` as a nested string.

    A node is written as its type name followed by the parenthesised,
    comma-separated descriptions of the nodes found in ``fn.next_functions``.
    A missing node (``None``) is written as ``'None'``.
    """
    if fn is None:
        return 'None'
    # next_functions yields pairs; only the first element (the node) is used.
    children = ', '.join(graph_desc(child) for child, _ in fn.next_functions)
    return type(fn).__name__ + '(' + children + ')'
 | |
| 
 | |
| 
 | |
| class TestAutograd(TestCase):
 | |
    def test_copy_slices_graph_task_updates(self):
        # Runs backward through an in-place update of a view (`out += y` on
        # `x.clone().view(-1)`) for several choices of requested inputs.
        def f1(x, y):
            # In-place add on a view of a clone of x.
            out = x.clone().view(-1)
            out += y
            return out

        def f2(x, y):
            # Like f1, but `b` reads the view before the in-place add.
            out = x.clone().view(-1)
            b = out * 2
            out += y
            return out + b

        x = torch.rand(2, requires_grad=True)
        y = torch.rand(2, requires_grad=True)

        # y routed through a DelayedError node built with the message "Boom!";
        # the checks below expect that message whenever backward continues
        # past y_safe towards y.
        y_safe = torch._C._functions.DelayedError("Boom!", 1)(y)

        for f in [f1, f2]:
            # Ensure that the error Node works
            out = f(x, y_safe)
            with self.assertRaisesRegex(RuntimeError, "Boom!"):
                out.sum().backward()

            out = f(x, y_safe)
            with self.assertRaisesRegex(RuntimeError, "Boom!"):
                torch.autograd.grad(out.sum(), y)

            # Ensure that if we don't ask for y, it doesn't crash
            out = f(x, y_safe)
            torch.autograd.grad(out.sum(), x)

            # Stopping at y_safe itself must not raise either.
            out = f(x, y_safe)
            torch.autograd.grad(out.sum(), y_safe)

            out = f(x, y_safe)
            torch.autograd.grad(out.sum(), (x, y_safe))

        # Ensure that we don't run extra view Node
        def f3(x, y):
            out = x.clone().view(-1)

            def hook(*args):
                # This should never be called!
                self.assertTrue(False)
            # The hook is attached to the view before it is modified in-place.
            out.register_hook(hook)

            b = out + y
            out += y
            return out + b, b

        # Only b and y_safe are requested, so the hook above must stay silent.
        out, b = f3(x, y_safe)
        torch.autograd.grad(out.sum(), (b, y_safe))
 | |
| 
 | |
| 
 | |
    def test_grad_mode_class_decoration(self):
        # Decorating a class with torch.no_grad() must emit a UserWarning;
        # decorating functions and methods must emit nothing.

        # Decorating class is deprecated and should not be used
        with self.assertWarnsRegex(UserWarning, "Decorating classes is deprecated"):
            @torch.no_grad()
            class Foo():
                def __init__(self):
                    # Grad mode is expected to be off while constructing.
                    assert not torch.is_grad_enabled()

                def foo(self):
                    # Not applied to methods
                    assert torch.is_grad_enabled()

            # Show that we can actually construct the class
            foo = Foo()
            foo.foo()

        # Decorating functions or methods is fine though
        with warnings.catch_warnings(record=True) as w:
            @torch.no_grad()
            def foo():
                assert not torch.is_grad_enabled()

            foo()

            class Foo2():
                @torch.no_grad()
                def __init__(self):
                    assert not torch.is_grad_enabled()

                @torch.no_grad()
                def foo(self):
                    assert not torch.is_grad_enabled()

            foo2 = Foo2()
            foo2.foo()

        # No warning may have been recorded inside the block above.
        self.assertEqual(len(w), 0)
 | |
| 
 | |
|     def test_tensor_grad_warnings(self):
 | |
|         dummy = torch.empty(1)
 | |
| 
 | |
|         with warnings.catch_warnings(record=True) as w:
 | |
|             # Accessing .grad on leaf
 | |
|             dummy.requires_grad_()
 | |
|             foo = dummy.grad
 | |
|             self.assertEqual(len(w), 0)
 | |
| 
 | |
|             # Accessing .grad on non-leaf
 | |
|             dummy = dummy.clone()
 | |
|             foo = dummy.grad
 | |
|             self.assertEqual(len(w), 1)
 | |
| 
 | |
|             # Accessing .grad on non-leaf that retains gradients
 | |
|             dummy.retain_grad()
 | |
|             foo = dummy.grad
 | |
|             self.assertEqual(len(w), 1)
 | |
| 
 | |
|     def _function_test(self, cls):
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         y = torch.randn(5, 5, requires_grad=True)
 | |
|         result = cls.apply(x, 2, y)
 | |
|         go = torch.ones((), requires_grad=True)
 | |
|         result.sum().backward(go, create_graph=True)
 | |
| 
 | |
|         self.assertEqual(x.grad, y + torch.ones(5, 5))
 | |
|         self.assertEqual(y.grad, x + torch.ones(5, 5) * 2)
 | |
|         self.assertIsNotNone(x.grad.grad_fn)
 | |
|         self.assertIsNotNone(y.grad.grad_fn)
 | |
| 
 | |
|         return x, y
 | |
| 
 | |
    def test_function(self):
        # Runs the shared custom-Function check with a differentiable
        # backward and compares the resulting gradient graphs with the stored
        # expected descriptions.
        class MyFunction(Function):

            @staticmethod
            def forward(ctx, tensor1, pyscalar, tensor2):
                # The Python scalar is kept on ctx; tensors go through
                # save_for_backward.
                ctx.pyscalar = pyscalar
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + pyscalar * tensor2 + tensor1 * tensor2

            @staticmethod
            def backward(ctx, grad_output):
                var1, var2 = ctx.saved_tensors
                # NOTE: self is the test case here
                self.assertIsInstance(var1, torch.Tensor)
                self.assertIsInstance(var2, torch.Tensor)
                self.assertIsInstance(grad_output, torch.Tensor)
                # One entry per forward input; None for the Python scalar.
                return (grad_output + grad_output * var2, None,
                        grad_output * ctx.pyscalar + grad_output * var1)

        x, y = self._function_test(MyFunction)

        # Textual form of each gradient's graph, checked via assertExpected.
        x_grad_desc = graph_desc(x.grad.grad_fn)
        y_grad_desc = graph_desc(y.grad.grad_fn)
        self.assertExpected(x_grad_desc, "x_grad_desc")
        self.assertExpected(y_grad_desc, "y_grad_desc")
 | |
| 
 | |
|     def test_once_differentiable(self):
 | |
|         class MyFunction(Function):
 | |
| 
 | |
|             @staticmethod
 | |
|             def forward(ctx, tensor1, pyscalar, tensor2):
 | |
|                 ctx.pyscalar = pyscalar
 | |
|                 ctx.save_for_backward(tensor1, tensor2)
 | |
|                 return tensor1 + pyscalar * tensor2 + tensor1 * tensor2
 | |
| 
 | |
|             @staticmethod
 | |
|             @once_differentiable
 | |
|             def backward(ctx, grad_output):
 | |
|                 self.assertFalse(torch.is_grad_enabled())
 | |
|                 t1, t2 = ctx.saved_tensors
 | |
|                 return (grad_output + grad_output * t2, None,
 | |
|                         grad_output * ctx.pyscalar + grad_output * t1)
 | |
| 
 | |
|         x, y = self._function_test(MyFunction)
 | |
|         self.assertEqual(graph_desc(x.grad.grad_fn),
 | |
|                          'CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))')
 | |
|         self.assertEqual(graph_desc(y.grad.grad_fn),
 | |
|                          'CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))')
 | |
| 
 | |
|     def test_function_returns_input(self):
 | |
|         class MyFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad):
 | |
|                 return grad * 2
 | |
| 
 | |
|         for shape in [(1,), ()]:
 | |
|             v = torch.ones(shape, requires_grad=True)
 | |
|             MyFunction.apply(v).backward()
 | |
|             self.assertEqual(v.grad, torch.full(shape, 2.))
 | |
| 
 | |
|             with torch.no_grad():
 | |
|                 v.grad.zero_()
 | |
|             MyFunction.apply(v.clone()).backward()
 | |
|             self.assertEqual(v.grad, torch.full(shape, 2.))
 | |
| 
 | |
|     def test_function_returns_undefined_tensor(self):
 | |
|         class MyFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x * 2
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad):
 | |
|                 return None
 | |
| 
 | |
|         # Test that undefined tensors returned from custom backward function
 | |
|         # are propagated as undefined and not tensor full of zeroes
 | |
|         x = torch.ones(1, requires_grad=True)
 | |
| 
 | |
|         MyFunction.apply(x).backward()
 | |
|         self.assertIsNone(x.grad)
 | |
| 
 | |
|         MyFunction.apply(x ** 2).backward()
 | |
|         self.assertIsNone(x.grad)
 | |
| 
 | |
|         MyFunction.apply(x).sum().backward()
 | |
|         self.assertIsNone(x.grad)
 | |
| 
 | |
|         self.assertIsNone(torch.autograd.grad(MyFunction.apply(x), x, allow_unused=True)[0])
 | |
| 
 | |
|     def test_materialize_grads(self):
 | |
|         class MyFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad):
 | |
|                 self.assertEqual(grad, torch.zeros(1))
 | |
|                 return grad
 | |
| 
 | |
|         x = torch.ones(1, requires_grad=True)
 | |
|         torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward()
 | |
| 
 | |
|     def test_dont_materialize_grads(self):
 | |
|         class MyFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 ctx.set_materialize_grads(False)
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad):
 | |
|                 self.assertIsNone(grad)
 | |
|                 return grad
 | |
| 
 | |
|         x = torch.ones(1, requires_grad=True)
 | |
|         torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward()
 | |
| 
 | |
|     def test_legacy_function_deprecation_exception(self):
 | |
|         # Trigger exception
 | |
|         class MyFunction(Function):
 | |
|             def forward(self, x):
 | |
|                 return x
 | |
| 
 | |
|             def backward(self, grad_output):
 | |
|                 return grad_output
 | |
| 
 | |
|         # Check exception occurs
 | |
|         with self.assertRaisesRegex(
 | |
|                 RuntimeError,
 | |
|                 'Legacy autograd function with non-static forward method is deprecated'):
 | |
|             MyFunction()(torch.randn(3, 4))
 | |
| 
 | |
    # Custom Function whose forward clones its input and whose backward
    # always raises; referenced as TestAutograd.SimulateBackwardError.
    class SimulateBackwardError(Function):

        @staticmethod
        def forward(ctx, input):
            # Return a copy rather than the input itself.
            return input.clone()

        @staticmethod
        @once_differentiable
        def backward(ctx, input):
            # The message is matched by test_custom_function_exception.
            raise Exception("Simulate error on backward pass")
 | |
| 
 | |
|     def test_custom_function_exception(self):
 | |
| 
 | |
|         t1 = torch.rand((3, 3), requires_grad=True)
 | |
|         t2 = torch.rand((3, 3), requires_grad=True)
 | |
| 
 | |
|         tmp = (t1 + t2) * (t1 + t2)
 | |
|         t3 = TestAutograd.SimulateBackwardError.apply(tmp)
 | |
|         with self.assertRaisesRegex(Exception, "Simulate error on backward pass"):
 | |
|             t3.sum().backward()
 | |
| 
 | |
    def test_custom_function_non_tensor_inputs_outputs(self):
        # A custom Function that takes a Python int among its inputs and
        # returns a mix of tensors and non-tensor values; checks the forward
        # values, the gradients seen by backward, and gradcheck.
        class MyFunction(Function):

            @staticmethod
            def forward(ctx, t1, t2, scale, t3):
                t4 = t1 + t2 * t3
                t5 = t1 * t2 + t3
                t4 *= scale
                t5 *= scale

                # Save scale
                ctx.scale = scale
                ctx.save_for_backward(t1, t2, t3)
                # Outputs at indices 0, 2, 3 and 5 are not tensors.
                return scale, t4, None, True, t5, "bar", t1

            @staticmethod
            @once_differentiable
            def backward(ctx, *grads):
                # Verify grads
                # One gradient slot per forward output; the non-tensor
                # outputs must come back as None.
                self.assertEqual(7, len(grads))
                self.assertIsNone(grads[0])
                self.assertIsNone(grads[2])
                self.assertIsNone(grads[3])
                self.assertIsNone(grads[5])

                scale = ctx.scale
                var1, var2, var3 = ctx.saved_tensors
                # One entry per forward input (t1, t2, scale, t3); grads[1],
                # grads[4] and grads[6] belong to outputs t4, t5 and t1.
                return (
                    grads[1] * scale + grads[4] * var2 * scale + grads[6],
                    grads[1] * var3 * scale + grads[4] * var1 * scale,
                    None,
                    grads[1] * var2 * scale + grads[4] * scale,
                )

        t1 = torch.rand(10, dtype=torch.double, requires_grad=True)
        t2 = torch.rand(10, dtype=torch.double, requires_grad=True)
        # t3 does not require grad, so its .grad must remain None below.
        t3 = torch.rand(10, dtype=torch.double)
        scale = random.randint(0, 10)
        res = MyFunction.apply(t1, t2, scale, t3)
        self.assertEqual(scale, res[0])
        self.assertEqual((t1 + t2 * t3) * scale, res[1])
        self.assertEqual(None, res[2])
        self.assertEqual(True, res[3])
        self.assertEqual((t1 * t2 + t3) * scale, res[4])
        self.assertEqual("bar", res[5])
        self.assertEqual(t1, res[6])

        # Validate running backward.
        torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
        self.assertIsNotNone(t1.grad)
        self.assertIsNotNone(t2.grad)
        self.assertIsNone(t3.grad)

        # Test gradcheck
        def foo(t1, t2, t3):
            # Expose only the tensor outputs to gradcheck.
            res = MyFunction.apply(t1, t2, scale, t3)
            return res[1], res[4], res[6]

        gradcheck(foo, (t1, t2, t3))
 | |
| 
 | |
|     def test_custom_function_no_tensors(self):
 | |
|         class MyFunction(Function):
 | |
| 
 | |
|             @staticmethod
 | |
|             def forward(ctx, t1, t2, scale, t3):
 | |
|                 t4 = t1 + t2 * t3
 | |
|                 t5 = t1 * t2 + t3
 | |
|                 t4 *= scale
 | |
|                 t5 *= scale
 | |
|                 return scale, t4, None, True, t5, "bar", t1
 | |
| 
 | |
|             @staticmethod
 | |
|             @once_differentiable
 | |
|             def backward(ctx, *args):
 | |
|                 return (args[0], args[1], None, args[2])
 | |
| 
 | |
|         t1 = random.random()
 | |
|         t2 = random.random()
 | |
|         t3 = random.random()
 | |
|         scale = random.randint(0, 10)
 | |
|         res = MyFunction.apply(t1, t2, scale, t3)
 | |
|         self.assertEqual(scale, res[0])
 | |
|         self.assertEqual((t1 + t2 * t3) * scale, res[1])
 | |
|         self.assertEqual(None, res[2])
 | |
|         self.assertEqual(True, res[3])
 | |
|         self.assertEqual((t1 * t2 + t3) * scale, res[4])
 | |
|         self.assertEqual("bar", res[5])
 | |
|         self.assertEqual(t1, res[6])
 | |
| 
 | |
|     def test_invalid_gradients(self):
 | |
|         class MyFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x * 2
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 return torch.randn(10, dtype=torch.float)
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, 'expected shape'):
 | |
|             input = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
 | |
|             MyFunction.apply(input).sum().backward()
 | |
| 
 | |
|     def test_unrelated_inputs(self):
 | |
|         # test to ensure grad(grad)check runs successfully even if there is an
 | |
|         # unrelated (but differentiable) inputs
 | |
| 
 | |
|         def my_function(x, y):
 | |
|             return x * x
 | |
| 
 | |
|         x = torch.rand(10, dtype=torch.double, requires_grad=True)
 | |
|         y = torch.rand(10, dtype=torch.double, requires_grad=True)
 | |
| 
 | |
|         gradcheck(my_function, (x, y))
 | |
|         gradgradcheck(my_function, (x, y))
 | |
| 
 | |
|     def test_not_implemented_grad(self):
 | |
|         a = torch.rand(2, requires_grad=True)
 | |
|         # if grad for nextafter ends up being implemented, this should be changed
 | |
|         y = torch.nextafter(a, a).sum()
 | |
|         with self.assertRaisesRegex(
 | |
|                 NotImplementedError,
 | |
|                 'the derivative for .* is not implemented'):
 | |
|             y.backward()
 | |
| 
 | |
|     def test_not_implemented_fwad(self):
 | |
|         x = torch.randn(3)
 | |
|         v = torch.rand(3)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual_x = fwAD.make_dual(x, v)
 | |
| 
 | |
|             err_msg = r"Trying to use forward AD with .* that does not support it"
 | |
|             hint_msg = "Running forward AD for an OP that does not implement it should raise a NotImplementedError"
 | |
| 
 | |
|             with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
 | |
|                 # if forward AD ends up being implemented for torch.igamma, choose a different op
 | |
|                 torch.igamma(dual_x, dual_x)
 | |
| 
 | |
    def test_will_engine_execute_node(self):
        # Checks torch._C._will_engine_execute_node from inside tensor hooks
        # for backward(inputs=...), plain backward() and autograd.grad(), and
        # checks the errors it raises when misused.

        # One-element list so the hooks below can mutate the count.
        counter = [0]

        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, gO):
                return gO * 2

        def get_grad_fn(t):
            # A leaf that requires grad has no grad_fn of its own, so take the
            # first next function of a clone's grad_fn instead (presumably the
            # leaf's gradient accumulator -- confirm against autograd docs).
            if t.requires_grad and t.grad_fn is None:
                return t.clone().grad_fn.next_functions[0][0]
            else:
                return t.grad_fn

        a = torch.randn(2, 3, 4, requires_grad=True)
        a2 = torch.randn(2, 3, 4, requires_grad=True)
        b = a * a2
        # b2 is never part of `out` below.
        b2 = b.cos()
        c = MyFunction.apply(b)

        should_execute = list(map(get_grad_fn, (a, b, c)))
        should_not_execute = list(map(get_grad_fn, (a2, b2)))

        def fn(x):
            # Hook body: reads the two lists above at call time, so the
            # reassignments further down change what is checked.
            counter[0] += 1

            for g in should_execute:
                self.assertTrue(torch._C._will_engine_execute_node(g))

            for g in should_not_execute:
                self.assertFalse(torch._C._will_engine_execute_node(g))

        b.register_hook(fn)
        c.register_hook(fn)

        # .backward(inputs=) is OK
        out = c.sum()
        torch.autograd.backward(out, inputs=(a, b), retain_graph=True)
        # fn is registered on both b and c, hence two calls.
        self.assertEqual(counter[0], 2)

        # .backward() is OK
        should_execute = list(map(get_grad_fn, (a, a2, b, c)))
        should_not_execute = list(map(get_grad_fn, (b2,)))
        torch.autograd.backward(out, retain_graph=True)

        # .grad is NOT OK when leaf is passed (this is the current state, subject to change)
        with self.assertRaisesRegex(RuntimeError, "are currently running autograd.grad()"):
            torch.autograd.grad(out, (a,))

        # .grad is OK when non-leaf is passed
        a = torch.randn(1, 2, 3, requires_grad=True) * 2
        b = a * 2

        def fn(x):
            # Check a non-leaf
            counter[0] += 1
            self.assertTrue(torch._C._will_engine_execute_node(b.grad_fn))
        b.register_hook(fn)
        counter[0] = 0
        torch.autograd.grad(b.sum(), (a,))
        self.assertEqual(counter[0], 1)

        # Verify other errors are raised
        # Called here outside of any hook.
        with self.assertRaisesRegex(RuntimeError, "during the backward pass"):
            torch._C._will_engine_execute_node(out.grad_fn)

        # Passing a tensor instead of a grad_fn.
        with self.assertRaisesRegex(RuntimeError, "expects an grad_fn"):
            torch._C._will_engine_execute_node(out)
 | |
| 
 | |
|     def test_custom_function_vmap_defaults(self):
 | |
|         class MySquare(Function):
 | |
|             @staticmethod
 | |
|             def forward(x):
 | |
|                 return x ** 2
 | |
| 
 | |
|             @staticmethod
 | |
|             def setup_context(ctx, inputs, output):
 | |
|                 x, = inputs
 | |
|                 ctx.save_for_backward(x)
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 x, = ctx.saved_tensors
 | |
|                 return gO * 2 * x
 | |
| 
 | |
|         self.assertFalse(MySquare.generate_vmap_rule)
 | |
|         self.assertTrue(hasattr(MySquare, 'vmap'))
 | |
| 
 | |
|     def test_custom_function_setup_context_simple(self):
 | |
|         class MySquare(Function):
 | |
|             @staticmethod
 | |
|             def forward(x):
 | |
|                 return x ** 2
 | |
| 
 | |
|             @staticmethod
 | |
|             def setup_context(ctx, inputs, output):
 | |
|                 x, = inputs
 | |
|                 ctx.save_for_backward(x)
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 x, = ctx.saved_tensors
 | |
|                 return gO * 2 * x
 | |
| 
 | |
|         x = torch.randn([], requires_grad=True)
 | |
|         y = MySquare.apply(x)
 | |
|         gx, = torch.autograd.grad(y, x)
 | |
|         self.assertEqual(gx, 2 * x)
 | |
| 
 | |
|     def test_custom_function_setup_context_multi_output(self):
 | |
|         # Multiple outputs with some non-Tensor outputs.
 | |
|         class MySquare(Function):
 | |
|             @staticmethod
 | |
|             def forward(x):
 | |
|                 two_x = x.item() * 2
 | |
|                 return x ** 2, two_x
 | |
| 
 | |
|             @staticmethod
 | |
|             def setup_context(ctx, inputs, output):
 | |
|                 x, = inputs
 | |
|                 _, two_x = output
 | |
|                 ctx.two_x = two_x
 | |
| 
 | |
|             @staticmethod
 | |
|             @once_differentiable
 | |
|             def backward(ctx, gO, _):
 | |
|                 return gO * ctx.two_x
 | |
| 
 | |
|         x = torch.randn([], requires_grad=True)
 | |
|         y, _ = MySquare.apply(x)
 | |
|         gx, = torch.autograd.grad(y, x)
 | |
|         self.assertEqual(gx, 2 * x)
 | |
| 
 | |
|     def test_custom_function_setup_context_multi_input(self):
 | |
|         class MyReshape(Function):
 | |
|             @staticmethod
 | |
|             def forward(x, shape, scale_forward, scale_backward):
 | |
|                 return x.reshape(shape) * scale_forward
 | |
| 
 | |
|             @staticmethod
 | |
|             def setup_context(ctx, inputs, output):
 | |
|                 x, shape, scale_forward, scale_backward = inputs
 | |
|                 ctx.scale_backward = scale_backward
 | |
|                 ctx.x_shape = x.shape
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None
 | |
| 
 | |
|         class MyReshapeRef(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x, shape, scale_forward, scale_backward):
 | |
|                 ctx.scale_backward = scale_backward
 | |
|                 ctx.x_shape = x.shape
 | |
|                 return x.reshape(shape) * scale_forward
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None
 | |
| 
 | |
|         def test(x, shape, scale_forward, scale_backward):
 | |
|             y = MyReshape.apply(x, shape, scale_forward, scale_backward).sum()
 | |
|             gx, = torch.autograd.grad(y, x)
 | |
| 
 | |
|             y_expected = MyReshapeRef.apply(x, shape, scale_forward, scale_backward).sum()
 | |
|             gx_expected, = torch.autograd.grad(y_expected, x)
 | |
| 
 | |
|             self.assertEqual(y_expected, y)
 | |
|             self.assertEqual(gx_expected, gx)
 | |
| 
 | |
|         test(torch.randn(24, requires_grad=True), (3, 8), 7, 11)
 | |
|         test(torch.randn(2, 3, 4, requires_grad=True), (6, 4), -1, 2)
 | |
| 
 | |
|     def test_accumulate_grad(self):
 | |
|         grad_output = torch.ones(5, 5)
 | |
| 
 | |
|         def compute_grad(create_graph):
 | |
|             x = torch.randn(5, 5, requires_grad=True)
 | |
|             y = x + 2
 | |
|             y.backward(grad_output, retain_graph=True)
 | |
|             x_grad = x.grad
 | |
|             x_grad_clone = x.grad.clone()
 | |
|             y.backward(grad_output, create_graph=create_graph)
 | |
|             return x_grad, x_grad_clone
 | |
| 
 | |
|         # Accumulate in-place when create_graph is False
 | |
|         x_grad, x_grad_clone = compute_grad(create_graph=False)
 | |
|         self.assertEqual(x_grad, x_grad_clone * 2)
 | |
| 
 | |
|         # Accumulate out-of-place when create_graph is False
 | |
|         x_grad, x_grad_clone = compute_grad(create_graph=True)
 | |
|         self.assertEqual(x_grad, x_grad_clone)
 | |
| 
 | |
|     def test_accumulate_grad_tensor_reference(self):
 | |
|         def _test_grad_tensor(params_grad_tensor, backward_grad_tensor, should_preserve_reference, create_graph):
 | |
|             params = torch.tensor([1.5, 1.5]).requires_grad_()
 | |
|             params.grad = params_grad_tensor
 | |
|             grad_saved = params.grad
 | |
|             params.backward(backward_grad_tensor, create_graph=create_graph)
 | |
|             self.assertEqual(id(grad_saved) == id(params.grad), should_preserve_reference)
 | |
| 
 | |
|         for create_graph in (False, True):
 | |
|             # Accumulate dense gradient to sparse gradient will change the `params.grad` reference
 | |
|             _test_grad_tensor(
 | |
|                 torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.])),
 | |
|                 torch.tensor([1.5, 1.5]),
 | |
|                 False,  # never accumulates in-place
 | |
|                 create_graph)
 | |
| 
 | |
|             # Accumulate dense gradient to dense gradient will preserve the `params.grad` reference,
 | |
|             # but only if create_graph=False.
 | |
|             _test_grad_tensor(
 | |
|                 torch.tensor([1.5, 1.5]),
 | |
|                 torch.tensor([1.5, 1.5]),
 | |
|                 not create_graph,
 | |
|                 create_graph)
 | |
| 
 | |
|             # Accumulate sparse gradient to sparse gradient will preserve the `params.grad` reference,
 | |
|             # but only if create_graph=False.
 | |
|             _test_grad_tensor(
 | |
|                 torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.])),
 | |
|                 torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.])),
 | |
|                 not create_graph,
 | |
|                 create_graph)
 | |
| 
 | |
|     def test_accumulate_grad_with_zero_numel_grad(self):
 | |
|         a = torch.rand(4, 0, requires_grad=True)
 | |
|         b = torch.rand(4, 1, requires_grad=True)
 | |
|         c = a + b
 | |
|         assert c.shape == (4, 0)
 | |
|         c.sum().backward()
 | |
| 
 | |
|         self.assertEqual(b.grad, torch.zeros(4, 1))
 | |
|         self.assertEqual(a.grad, torch.zeros(4, 0))
 | |
| 
 | |
|     def test_hessian_vector(self):
 | |
|         x = torch.randn(2, 2, requires_grad=True)
 | |
|         y = torch.randn(2, 2, requires_grad=True)
 | |
| 
 | |
|         z = x ** 2 + y * x + y ** 2
 | |
|         z.backward(torch.ones(2, 2), create_graph=True)
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             x_grad = 2 * x + y
 | |
|             y_grad = x + 2 * y
 | |
|         self.assertEqual(x.grad, x_grad)
 | |
|         self.assertEqual(y.grad, y_grad)
 | |
| 
 | |
|         grad_sum = 2 * x.grad + y.grad
 | |
|         grad_sum.backward(torch.ones(2, 2))
 | |
|         x_hv = torch.ones(2, 2) * 5
 | |
|         y_hv = torch.ones(2, 2) * 4
 | |
|         self.assertEqual(x.grad, x_grad + x_hv)
 | |
|         self.assertEqual(y.grad, y_grad + y_hv)
 | |
| 
 | |
|     def test_grad(self):
 | |
|         x = torch.randn(2, 2, requires_grad=True)
 | |
|         y = torch.randn(2, 2, requires_grad=True)
 | |
|         z = x ** 2 + y * x + y ** 2
 | |
|         z.backward(torch.ones(2, 2), create_graph=True)
 | |
| 
 | |
|         x_grad = 2 * x + y
 | |
|         y_grad = x + 2 * y
 | |
|         self.assertEqual(x.grad, x_grad)
 | |
|         self.assertEqual(y.grad, y_grad)
 | |
| 
 | |
|         grad_sum = 2 * x.grad + y.grad
 | |
|         x_hv = torch.autograd.grad(
 | |
|             outputs=[grad_sum], grad_outputs=[torch.ones(2, 2)],
 | |
|             inputs=[x], create_graph=True)
 | |
|         expected_x_hv = torch.ones(2, 2) * 5
 | |
|         expected_y_hv = torch.ones(2, 2) * 4
 | |
| 
 | |
|         self.assertEqual(x_hv[0], expected_x_hv)
 | |
|         self.assertEqual(x.grad, x_grad)
 | |
|         self.assertEqual(y.grad, y_grad)
 | |
| 
 | |
|         # Test that grad_outputs and outputs have the same shape
 | |
|         grad_out = torch.ones(2)
 | |
|         try:
 | |
|             torch.autograd.grad(
 | |
|                 outputs=[grad_sum], grad_outputs=[grad_out],
 | |
|                 inputs=[x], create_graph=True)
 | |
|             self.assertFail()
 | |
|         except RuntimeError as error:
 | |
|             self.assertEqual(str(error), "Mismatch in shape: grad_output[0] has a shape of "
 | |
|                              + str(grad_out.shape) + " and output[0] has a shape of "
 | |
|                              + str(grad_sum.shape) + ".")
 | |
| 
 | |
|     def test_grad_nonleaf(self):
 | |
|         x_init = torch.randn(2, 2, requires_grad=True)
 | |
|         x = x_init
 | |
|         y = torch.randn(2, 2, requires_grad=True)
 | |
|         grad_output = torch.ones(2, 2)
 | |
| 
 | |
|         def fn(x):
 | |
|             return x ** 2 + y * x + y ** 2
 | |
| 
 | |
|         for _ in range(5):
 | |
|             grad_x, = torch.autograd.grad(
 | |
|                 fn(x), x, grad_outputs=grad_output, create_graph=True)
 | |
| 
 | |
|             grad_x_expected = 2 * x + y
 | |
|             self.assertIsNone(y.grad)
 | |
|             self.assertIsNone(x.grad)
 | |
|             self.assertEqual(grad_x, grad_x_expected)
 | |
| 
 | |
|             x = x + 0.05 * grad_x
 | |
| 
 | |
|         val_init = fn(x_init).sum()
 | |
|         val_final = fn(x).sum()
 | |
|         self.assertGreater(val_final, val_init)
 | |
| 
 | |
|         x.backward(grad_output)
 | |
|         self.assertIsNotNone(y.grad)
 | |
|         self.assertIsNotNone(x_init.grad)
 | |
| 
 | |
|     def test_grad_nonleaf_many_outputs(self):
 | |
|         # This checks an edge case for function callbacks
 | |
|         # We want to capture two grads of a function, but can only
 | |
|         # register a single callback.
 | |
|         x = torch.randn(4, 2, requires_grad=True)
 | |
|         a, b = x.chunk(2)
 | |
| 
 | |
|         def hook(*grads):
 | |
|             hook_called[0] = True
 | |
|         hook_called = [False]
 | |
|         x.register_hook(hook)
 | |
| 
 | |
|         go = torch.randn(2, 2)
 | |
|         grad_a, grad_b = torch.autograd.grad(
 | |
|             (a + 2 * b), [a, b], grad_outputs=go, create_graph=True)
 | |
| 
 | |
|         self.assertEqual(grad_a, go)
 | |
|         self.assertEqual(grad_b, go * 2)
 | |
|         self.assertFalse(hook_called[0])
 | |
|         self.assertIsNone(x.grad)
 | |
| 
 | |
|     def test_grad_nonleaf_register_hook(self):
 | |
|         # This checks an edge case for register_hook.
 | |
|         # We want to capture grad of a nonleaf tensor,
 | |
|         # but avoid segfault during backward of other nonleaf tensors
 | |
|         x = torch.randn(5, requires_grad=True)
 | |
|         x_list = x.unbind()
 | |
| 
 | |
|         x0 = x_list[0]
 | |
|         hook_results = [None]
 | |
| 
 | |
|         def hook(grad):
 | |
|             hook_results[0] = grad
 | |
|         x0.register_hook(hook)
 | |
| 
 | |
|         x_list[0].backward()
 | |
|         self.assertEqual(hook_results[0], torch.tensor(1.))
 | |
|         expected_grad = torch.tensor([1., 0, 0, 0, 0])
 | |
|         self.assertEqual(x.grad, expected_grad)
 | |
|         self.assertIsNone(x_list[0].grad)
 | |
| 
 | |
|         for i in range(1, 5, 1):
 | |
|             x_list[i].backward()
 | |
|             self.assertEqual(hook_results[0], None)
 | |
|             expected_grad[i] = 1.0
 | |
|             self.assertEqual(x.grad, expected_grad)
 | |
|             self.assertIsNone(x_list[i].grad)
 | |
| 
 | |
|     def test_grad_materialize_grads(self):
 | |
|         x = torch.tensor(0.5, requires_grad=True)
 | |
|         a = torch.tensor(1.0, requires_grad=True)
 | |
|         y = x * a
 | |
|         dydx = torch.autograd.grad(y, x, create_graph=True)
 | |
|         d2ydx2_none = torch.autograd.grad(dydx, x, create_graph=True, allow_unused=True)
 | |
|         d2ydx2 = torch.autograd.grad(dydx, x, create_graph=True, allow_unused=True, materialize_grads=True)
 | |
|         # `allow_unused` set to True implicitly
 | |
|         d3ydx3 = torch.autograd.grad(d2ydx2, x, materialize_grads=True)
 | |
|         self.assertIsNone(d2ydx2_none[0])
 | |
|         self.assertEqual(d2ydx2[0].item(), 0)
 | |
|         self.assertEqual(d3ydx3[0].item(), 0)
 | |
|         with self.assertRaisesRegex(ValueError, "Expected allow_unused to be True or not passed when"):
 | |
|             torch.autograd.grad(y, x, allow_unused=False, materialize_grads=True)
 | |
| 
 | |
|     def test_hook_with_no_name(self):
 | |
|         # Create a hook that do not have a __name__ attribute
 | |
|         class MyHookClass:
 | |
|             def __call__(self, grad):
 | |
|                 return grad.clone()
 | |
| 
 | |
|         x = torch.randn(5, requires_grad=True).clone()
 | |
|         x.register_hook(MyHookClass())
 | |
|         x.sum().backward()
 | |
|         # Should run fine
 | |
| 
 | |
|     def test_prehook_ordering(self):
 | |
|         # Hooks registered to tensor are ordered before those
 | |
|         # that are registered to grad_fn
 | |
|         log = []
 | |
| 
 | |
|         def hook1(g):
 | |
|             log.append(1)
 | |
|             return g * 3
 | |
| 
 | |
|         def hook2(gs):
 | |
|             log.append(2)
 | |
|             return tuple(g * 2 for g in gs)
 | |
| 
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         b = a.clone()
 | |
| 
 | |
|         b.grad_fn.register_prehook(hook2)
 | |
|         b.register_hook(hook1)
 | |
|         b.grad_fn.register_prehook(hook2)
 | |
| 
 | |
|         acc = b.grad_fn.next_functions[0][0]
 | |
|         a.register_hook(hook1)
 | |
|         acc.register_prehook(hook2)
 | |
|         a.register_hook(hook1)
 | |
| 
 | |
|         b.sum().backward(retain_graph=True)
 | |
|         self.assertEqual(log, [1, 2, 2, 1, 1, 2])
 | |
| 
 | |
|         # grad also runs hooks on accumulate grad nodes, even though
 | |
|         # the accumulate grad nodes are not actually executed
 | |
|         log = []
 | |
|         torch.autograd.grad(b.sum(), inputs=(a,), retain_graph=True)
 | |
|         self.assertEqual(log, [1, 2, 2, 1, 1])
 | |
| 
 | |
|         log = []
 | |
|         b.sum().backward(inputs=(b,))
 | |
|         self.assertEqual(log, [1, 2, 2])
 | |
|         # retains_grad hooks would not observe modifications by all pre hooks
 | |
|         # because they are executed after
 | |
|         self.assertEqual(b.grad.item(), 3)
 | |
| 
 | |
|     def test_retains_grad_can_always_observe_tensor_prehook(self):
 | |
|         def tensor_prehook(g):
 | |
|             return g * 2
 | |
| 
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         b = a.clone()
 | |
|         b.register_hook(tensor_prehook)
 | |
|         b.retain_grad()
 | |
|         b.register_hook(tensor_prehook)
 | |
| 
 | |
|         b.clone().backward()
 | |
|         self.assertEqual(b.grad.item(), 4)
 | |
| 
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         b = a.clone()
 | |
|         b.retain_grad()
 | |
|         b.register_hook(tensor_prehook)
 | |
| 
 | |
|         b.clone().backward()
 | |
|         self.assertEqual(b.grad.item(), 2)
 | |
| 
 | |
|     def test_accumulate_grad_posthooks_can_observe_tensor_prehook(self):
 | |
|         # Post hooks on accumulate should be able to observe changes to
 | |
|         # grad made by tensor prehooks
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
| 
 | |
|         def tensor_prehook(g):
 | |
|             return g * 2
 | |
| 
 | |
|         def posthook(gO, gI):
 | |
|             self.assertTrue(torch.allclose(gI[0], a * 2))
 | |
|             self.assertEqual(len(gO), 0)
 | |
| 
 | |
|         def prehook(gI):
 | |
|             self.assertTrue(torch.allclose(gI[0], a * 2))
 | |
|             self.assertEqual(len(gI), 1)
 | |
| 
 | |
|         b = a.clone()
 | |
|         acc = b.grad_fn.next_functions[0][0]
 | |
|         acc.register_hook(posthook)
 | |
|         acc.register_prehook(prehook)
 | |
|         a.register_hook(tensor_prehook)
 | |
| 
 | |
|         b.backward()
 | |
| 
 | |
|     def test_hook_edge_case_when_called_with_grad(self):
 | |
|         # grad executes the tensor hooks of the next node but not
 | |
|         # grad_fn pre hooks or the post hooks
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         b = a * 2
 | |
|         c = b * 2
 | |
| 
 | |
|         tensor_hook_count = [0]
 | |
|         prehook_count = [0]
 | |
|         posthook_count = [0]
 | |
| 
 | |
|         def reset_counts():
 | |
|             nonlocal tensor_hook_count, prehook_count, posthook_count
 | |
|             tensor_hook_count = [0]
 | |
|             prehook_count = [0]
 | |
|             posthook_count = [0]
 | |
| 
 | |
|         def tensor_prehook(g):
 | |
|             tensor_hook_count[0] += 1
 | |
| 
 | |
|         def prehook(g):
 | |
|             prehook_count[0] += 1
 | |
| 
 | |
|         def posthook(gI, gO):
 | |
|             posthook_count[0] += 1
 | |
| 
 | |
|         a.register_hook(tensor_prehook)
 | |
|         b.register_hook(tensor_prehook)
 | |
|         acc = b.grad_fn.next_functions[0][0]
 | |
|         acc.register_hook(posthook)
 | |
|         acc.register_prehook(prehook)
 | |
|         b.grad_fn.register_hook(posthook)
 | |
|         b.grad_fn.register_prehook(prehook)
 | |
| 
 | |
|         torch.autograd.grad(c, inputs=(b), retain_graph=True)
 | |
|         self.assertEqual(tensor_hook_count[0], 1)
 | |
|         self.assertEqual(posthook_count[0], 0)
 | |
|         self.assertEqual(prehook_count[0], 0)
 | |
|         reset_counts()
 | |
| 
 | |
|         torch.autograd.grad(c, inputs=(a, b), retain_graph=True)
 | |
|         self.assertEqual(tensor_hook_count[0], 2)
 | |
|         self.assertEqual(posthook_count[0], 1)
 | |
|         self.assertEqual(prehook_count[0], 1)
 | |
|         reset_counts()
 | |
| 
 | |
|         c.backward(retain_graph=True)
 | |
|         self.assertEqual(tensor_hook_count[0], 2)
 | |
|         self.assertEqual(posthook_count[0], 2)
 | |
|         self.assertEqual(prehook_count[0], 2)
 | |
|         reset_counts()
 | |
| 
 | |
|         c.backward(inputs=(a, b), retain_graph=True)
 | |
|         self.assertEqual(tensor_hook_count[0], 2)
 | |
|         self.assertEqual(posthook_count[0], 2)
 | |
|         self.assertEqual(prehook_count[0], 2)
 | |
| 
 | |
|     def test_sharded_grad(self):
 | |
|         leaves = [torch.zeros(5, 5, requires_grad=True) for _ in range(10)]
 | |
|         intermediates = [l * i + l * l for i, l in enumerate(leaves)]
 | |
|         loss = sum(v * i for i, v in enumerate(intermediates)).sum()
 | |
| 
 | |
|         # define a helper for dividing intermediates into groups
 | |
|         def group(l, group_size):
 | |
|             return (l[i:i + group_size] for i in range(0, len(l), group_size))
 | |
| 
 | |
|         # Compute the d loss / d intermediates in chunks of shard_size
 | |
|         shard_size = 2
 | |
|         d_intermediates = [d_i for intermediates_batch in group(intermediates, shard_size)
 | |
|                            for d_i in torch.autograd.grad(loss, intermediates_batch)]
 | |
|         # Compute rest of backward pass
 | |
|         torch.autograd.backward(intermediates, d_intermediates)
 | |
| 
 | |
|         for i, l in enumerate(leaves):
 | |
|             self.assertEqual(l.grad, i * i * (1 + l))
 | |
| 
 | |
|     def test_backward_badcalls(self):
 | |
|         x = torch.ones(1)
 | |
|         with self.assertRaisesRegex(RuntimeError, 'does not require grad'):
 | |
|             x.backward()
 | |
| 
 | |
|     def test_grad_badcalls(self):
 | |
|         x = torch.ones(1)
 | |
|         y = x ** 2
 | |
|         with self.assertRaisesRegex(RuntimeError, 'does not require grad'):
 | |
|             torch.autograd.grad(x, y)
 | |
|         with self.assertRaisesRegex(RuntimeError, 'does not require grad'):
 | |
|             torch.autograd.grad(y, x)
 | |
| 
 | |
|         x = torch.ones(1, requires_grad=True)
 | |
|         y = x ** 2
 | |
|         torch.autograd.grad(y, x)  # this should succeed now
 | |
| 
 | |
|     def test_grad_empty_inputs(self):
 | |
|         x = torch.tensor([1.0], requires_grad=True)
 | |
|         with self.assertRaisesRegex(ValueError, "grad requires non-empty inputs."):
 | |
|             torch.autograd.grad(2 * x, [], grad_outputs=torch.tensor([1.0]))
 | |
| 
 | |
|     def test_grad_fn_badcalls(self):
 | |
|         error_regex = 'expected .* arguments, got .* instead'
 | |
|         x = torch.ones(1, requires_grad=True)
 | |
|         y = x ** 2
 | |
|         with self.assertRaisesRegex(TypeError, error_regex):
 | |
|             y.grad_fn(x.detach(), x.detach())  # too many
 | |
|         with self.assertRaisesRegex(TypeError, error_regex):
 | |
|             y.grad_fn()  # too few
 | |
| 
 | |
|         y.grad_fn(x.detach())  # this should succeed
 | |
| 
 | |
|     def test_grad_unreachable(self):
 | |
|         x = torch.ones(1, requires_grad=True)
 | |
|         y = torch.ones(1, requires_grad=True)
 | |
|         # Make sure x and y have grad accumulators allocated
 | |
|         z = x * 2
 | |
|         w = y * 2
 | |
| 
 | |
|         grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=True)
 | |
|         self.assertEqual(grad_x, x * 2)
 | |
|         self.assertIsNone(grad_y)
 | |
| 
 | |
|         # This is slightly different than the case above, because z doesn't even
 | |
|         # have a grad accumulator allocated.
 | |
|         z = torch.ones(1, requires_grad=True)
 | |
|         grad_x, grad_z = torch.autograd.grad(x * 2, [x, z], allow_unused=True)
 | |
|         self.assertEqual(grad_x, x * 2)
 | |
|         self.assertIsNone(grad_z)
 | |
| 
 | |
|         # allow_unused=False, but grads contains None inside, should throw
 | |
|         with self.assertRaisesRegex(RuntimeError,
 | |
|                                     "Set allow_unused=True"):
 | |
|             grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=False)
 | |
| 
 | |
|     def test_grad_unreachable_discovery(self):
 | |
|         # Test that certain nodes are not erroneously executed when an input
 | |
|         # is unreachable. See #39784
 | |
|         class MyFunc(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, x):
 | |
|                 self.fail("This node should not be executed!")
 | |
| 
 | |
|         x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
 | |
|         y = torch.randn(1, requires_grad=True)
 | |
|         (gY,) = torch.autograd.grad(x, (y, ), allow_unused=True)
 | |
|         self.assertIsNone(gY)
 | |
| 
 | |
|         x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
 | |
|         y = torch.randn(1, requires_grad=True)
 | |
|         z = torch.randn(1, requires_grad=True)
 | |
|         (gY, gZ) = torch.autograd.grad(x + z, (y, z), allow_unused=True)
 | |
|         self.assertIsNone(gY)
 | |
|         self.assertIsNotNone(gZ)
 | |
| 
 | |
|         x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
 | |
|         y = torch.randn(1, requires_grad=True)
 | |
|         torch.autograd.backward(x, inputs=(y, ))  # allow_unused is implicitly True!
 | |
|         self.assertIsNone(y.grad)
 | |
| 
 | |
|     def test_grad_batched_grad(self):
 | |
|         x = torch.randn(2, 2, requires_grad=True)
 | |
| 
 | |
|         out = x.clone()  # Size([2, 2])
 | |
|         batched_grad = torch.arange(3).expand(2, 2, 3).transpose(0, 2)  # Size([3, 2, 2])
 | |
|         grad, = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
 | |
|         self.assertEqual(grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype))
 | |
| 
 | |
|         # Detect shape mismatch
 | |
|         grad_out = torch.ones(2, 2)
 | |
|         with self.assertRaisesRegex(RuntimeError, "If `is_grads_batched=True`, we interpret the first"):
 | |
|             torch.autograd.grad(outputs=out, grad_outputs=(grad_out,), inputs=(x,), is_grads_batched=True)
 | |
| 
 | |
|         # Scalar outputs
 | |
|         out = x.sum()  # Size([])
 | |
|         batched_grad = torch.arange(3)  # Size([3])
 | |
|         grad, = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
 | |
|         self.assertEqual(grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype))
 | |
| 
 | |
|         # We consider scalar and sized-1 to be a mismatch. This is consistent with current non-batched behavior.
 | |
|         grad_out = torch.ones(2).unsqueeze(1)
 | |
|         with self.assertRaisesRegex(RuntimeError, "If `is_grads_batched=True`, we interpret the first"):
 | |
|             torch.autograd.grad(outputs=out, grad_outputs=(grad_out,), inputs=(x,), is_grads_batched=True)
 | |
| 
 | |
|     def test_hooks(self):
 | |
|         x = torch.ones(5, 5, requires_grad=True)
 | |
|         y = torch.ones(5, 5) * 4
 | |
|         y.requires_grad_(True)
 | |
| 
 | |
|         counter = [0]
 | |
| 
 | |
|         def bw_hook(inc, grad):
 | |
|             self.assertIsInstance(grad, torch.Tensor)
 | |
|             counter[0] += inc
 | |
| 
 | |
|         z = x ** 2 + x * 2 + x * y + y
 | |
|         x.register_hook(lambda *args: bw_hook(0, *args))
 | |
|         test = z.register_hook(lambda *args: bw_hook(1, *args))
 | |
|         z.backward(torch.ones(5, 5), retain_graph=True)
 | |
|         self.assertEqual(counter[0], 1)
 | |
| 
 | |
|         test2 = z.register_hook(lambda *args: bw_hook(2, *args))
 | |
|         z.backward(torch.ones(5, 5), retain_graph=True)
 | |
|         self.assertEqual(counter[0], 4)
 | |
| 
 | |
|         test2.remove()
 | |
|         z.backward(torch.ones(5, 5), retain_graph=True)
 | |
|         self.assertEqual(counter[0], 5)
 | |
| 
 | |
|         def bw_hook_modify(grad):
 | |
|             return grad.mul(2)
 | |
| 
 | |
|         test.remove()
 | |
|         z.register_hook(bw_hook_modify)
 | |
|         with torch.no_grad():
 | |
|             y.grad.zero_()
 | |
|         z.backward(torch.ones(5, 5), retain_graph=True)
 | |
|         self.assertEqual(y.grad, (x + 1) * 2)
 | |
| 
 | |
|         y.register_hook(bw_hook_modify)
 | |
|         with torch.no_grad():
 | |
|             y.grad.zero_()
 | |
|         z.backward(torch.ones(5, 5))
 | |
|         self.assertEqual(y.grad, (x + 1) * 4)
 | |
| 
 | |
|     def _get_mul2(self, use_custom_function):
 | |
|         if use_custom_function:
 | |
|             class Mul2(Function):
 | |
|                 @staticmethod
 | |
|                 def forward(ctx, x):
 | |
|                     return x * 2
 | |
| 
 | |
|                 @staticmethod
 | |
|                 def backward(ctx, gO):
 | |
|                     return gO * 2
 | |
| 
 | |
|             return Mul2.apply
 | |
|         else:
 | |
|             return lambda x: x * 2
 | |
| 
 | |
|     def test_grad_fn_prehooks(self):
 | |
|         for use_custom_function in (True, False):
 | |
|             mul2 = self._get_mul2(use_custom_function)
 | |
| 
 | |
|             a = torch.tensor([1.], requires_grad=True)
 | |
|             b = mul2(a)
 | |
| 
 | |
|             post_counter = [0]
 | |
|             pre_counter = [0]
 | |
| 
 | |
|             def posthook(grad_input, grad_output):
 | |
|                 self.assertEqual(pre_counter[0], 3)
 | |
|                 self.assertTrue(torch.allclose(grad_output[0], torch.ones(1) * 8))
 | |
|                 self.assertTrue(torch.allclose(grad_input[0], torch.ones(1) * 16))
 | |
|                 post_counter[0] += 1
 | |
|                 return grad_input
 | |
| 
 | |
|             def prehook(grad_output):
 | |
|                 pre_counter[0] += 1
 | |
|                 return (grad_output[0] * 2,)
 | |
| 
 | |
|             # register posthook x 2
 | |
|             b.grad_fn.register_hook(posthook)
 | |
|             b.grad_fn.register_hook(posthook)
 | |
|             # register prehook x 3
 | |
|             b.grad_fn.register_prehook(prehook)
 | |
|             b.grad_fn.register_prehook(lambda x: None)
 | |
|             b.grad_fn.register_prehook(prehook)
 | |
|             b.grad_fn.register_prehook(prehook)
 | |
|             b.grad_fn.register_prehook(lambda x: x)
 | |
|             b.grad_fn.register_prehook(lambda x: None)
 | |
| 
 | |
|             b.sum().backward()
 | |
| 
 | |
|             self.assertEqual(post_counter[0], 2)
 | |
|             self.assertEqual(pre_counter[0], 3)
 | |
| 
 | |
|             # Return None
 | |
|             a = torch.rand(3, 3, requires_grad=True)
 | |
|             b = mul2(a)
 | |
| 
 | |
|             def prehook(grad_output):
 | |
|                 pre_counter[0] += 1
 | |
|                 return None
 | |
| 
 | |
|             b.grad_fn.register_prehook(prehook)
 | |
|             b.sum().backward()
 | |
|             self.assertEqual(pre_counter[0], 4)
 | |
|             self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
 | |
| 
 | |
|     def test_grad_fn_prehooks_multiple_outputs(self):
 | |
|         # Compute gradients without hooks
 | |
|         b = torch.rand(3, 3, requires_grad=True)
 | |
|         var, mean = torch.var_mean(b, dim=0)
 | |
|         (var + mean).sum().backward()
 | |
| 
 | |
|         # Compute gradients with hooks
 | |
|         a = b.detach().requires_grad_()
 | |
|         counter = [0]
 | |
| 
 | |
|         def prehook(grad_output):
 | |
|             gvar, gmean = grad_output
 | |
|             counter[0] += 1
 | |
|             return (gvar * 2, gmean * 2)
 | |
| 
 | |
|         var, mean = torch.var_mean(a, dim=0)
 | |
|         mean.grad_fn.register_prehook(prehook)
 | |
|         (var + mean).sum().backward()
 | |
| 
 | |
|         self.assertEqual(counter[0], 1)
 | |
|         # Compare
 | |
|         self.assertTrue(torch.allclose(a.grad, b.grad * 2))
 | |
| 
 | |
|         # Test with custom Function
 | |
|         class DoubleMul2(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x, a, y):
 | |
|                 ctx.a = a
 | |
|                 return a * x * 2, a, a * y * 2
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, g1, _a, g2):
 | |
|                 return ctx.a * g1 * 2, None, ctx.a * g2 * 2
 | |
| 
 | |
|         counter = [0]
 | |
| 
 | |
|         def prehook(grad_output):
 | |
|             g1, ga, g2 = grad_output
 | |
|             self.assertIsNone(ga)
 | |
|             counter[0] += 1
 | |
|             return (g1 * 2, None, g2 * 2)
 | |
| 
 | |
|         a = torch.randn(3, 3, requires_grad=True)
 | |
|         b = torch.randn(3, 3, requires_grad=True)
 | |
|         k = 3
 | |
|         c, _, d = DoubleMul2.apply(a, k, b)
 | |
|         c.grad_fn.register_prehook(prehook)
 | |
|         (c + d).sum().backward()
 | |
| 
 | |
|         self.assertEqual(counter[0], 1)
 | |
|         self.assertTrue(torch.allclose(a.grad, torch.ones(1) * 4 * k))
 | |
|         self.assertTrue(torch.allclose(b.grad, torch.ones(1) * 4 * k))
 | |
| 
 | |
|     def test_grad_fn_prehooks_remove_hooks(self):
 | |
|         for use_custom_function in (True, False):
 | |
|             mul2 = self._get_mul2(use_custom_function)
 | |
| 
 | |
|             # Simply remove hooks
 | |
| 
 | |
|             a = torch.rand(3, 3, requires_grad=True)
 | |
|             b = mul2(a)
 | |
|             counter = [0]
 | |
| 
 | |
|             def prehook(grad_output):
 | |
|                 counter[0] += 1
 | |
|                 return None
 | |
| 
 | |
|             handle = b.grad_fn.register_prehook(prehook)
 | |
|             b.grad_fn.register_prehook(prehook)
 | |
|             handle.remove()
 | |
|             b.sum().backward()
 | |
|             self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
 | |
|             self.assertEqual(counter[0], 1)
 | |
| 
 | |
|             # Remove hooks during backward
 | |
|             a = torch.rand(3, 3, requires_grad=True)
 | |
|             b = mul2(a)
 | |
|             counter = [0]
 | |
| 
 | |
|             def prehook1(grad_output):
 | |
|                 handle2.remove()
 | |
|                 # Remove hook that is already removed is OK
 | |
|                 handle3.remove()
 | |
|                 return None
 | |
| 
 | |
|             def prehook2(grad_output):
 | |
|                 counter[0] += 1
 | |
|                 return None
 | |
| 
 | |
|             # Hooks that registered first run first
 | |
|             b.grad_fn.register_prehook(prehook1)
 | |
|             handle2 = b.grad_fn.register_prehook(prehook2)
 | |
|             handle3 = b.grad_fn.register_prehook(prehook2)
 | |
|             handle3.remove()
 | |
|             b.sum().backward()
 | |
|             self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
 | |
|             self.assertEqual(counter[0], 1)
 | |
| 
 | |
|     def test_hooks_cpp(self):
 | |
|         # Tests hooks for autograd function implemented in C++
 | |
|         bn = torch.nn.BatchNorm1d(5, affine=False)
 | |
|         bn.double()
 | |
|         bn.eval()
 | |
| 
 | |
|         counter = [0]
 | |
| 
 | |
|         def bw_hook(grad):
 | |
|             counter[0] += 1
 | |
|             return grad * 2
 | |
| 
 | |
|         x = torch.ones(5, 5, dtype=torch.double, requires_grad=True)
 | |
|         z = bn(x)
 | |
|         z.register_hook(bw_hook)
 | |
|         z.sum().backward()
 | |
| 
 | |
|         self.assertEqual(counter[0], 1, msg='bw_hook not called')
 | |
|         self.assertEqual(x.grad, torch.ones(5, 5, dtype=torch.double) * 2, atol=1e-5, rtol=0)
 | |
| 
 | |
|     def test_hook_none(self):
 | |
|         # WARNING: this is a test for autograd internals.
 | |
|         # You should never have to use such things in your code.
 | |
|         class NoneGradientFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x, y):
 | |
|                 assert ctx.needs_input_grad[0]
 | |
|                 assert not ctx.needs_input_grad[1]
 | |
|                 return x, y
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_x, grad_y):
 | |
|                 return grad_x, None
 | |
| 
 | |
|         was_called = [False]
 | |
| 
 | |
|         def hook(grad):
 | |
|             self.assertIsNotNone(grad)
 | |
|             was_called[0] = True
 | |
| 
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         y = torch.randn(5, 5)
 | |
|         rx, ry = NoneGradientFunction.apply(x, y)
 | |
|         rx.register_hook(hook)
 | |
|         ry.register_hook(hook)
 | |
|         sum(rx, ry).sum().backward()
 | |
|         self.assertTrue(was_called[0])
 | |
| 
 | |
|     def test_retain_grad(self):
 | |
|         input = torch.rand(1, 3, requires_grad=True)
 | |
|         h1 = input * 3
 | |
|         out = (h1 * h1).sum()
 | |
| 
 | |
|         # It should be possible to call retain_grad() multiple times
 | |
|         h1.retain_grad()
 | |
|         h1.retain_grad()
 | |
| 
 | |
|         # Gradient should be accumulated
 | |
|         out.backward(retain_graph=True)
 | |
|         self.assertEqual(h1 * 2, h1.grad)
 | |
|         out.backward(retain_graph=True)
 | |
|         self.assertEqual(h1 * 4, h1.grad)
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             input.grad.zero_()
 | |
|         # It should be a no-op for leaves
 | |
|         input.retain_grad()
 | |
|         input.retain_grad()
 | |
|         out.backward()
 | |
|         self.assertEqual(input * 18, input.grad)
 | |
| 
 | |
|     # NB: See test/cpp/api/autograd.cpp for more tests on the interaction between
 | |
|     #     retains_grad and hooks in cpp
 | |
|     def test_retain_grad_inplace(self):
 | |
|         a = torch.tensor([1.], requires_grad=True).clone()
 | |
|         a.retain_grad()
 | |
|         a.mul_(2)
 | |
|         a.sum().backward()
 | |
|         self.assertEqual(a.grad, torch.tensor([1.]))
 | |
| 
 | |
|         a = torch.tensor([1.], requires_grad=True).clone()
 | |
|         a.retain_grad()
 | |
|         # Inplace multiple times is OK
 | |
|         a.mul_(2)
 | |
|         a.mul_(2)
 | |
|         a.sum().backward()
 | |
|         self.assertEqual(a.grad, torch.tensor([1.]))
 | |
| 
 | |
|     def test_retains_grad_inplace_multiple_outputs(self):
 | |
|         class DoubleMul(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x * 2, x * 3
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, g1, g2):
 | |
|                 return g1 * 2 + g2 * 3
 | |
| 
 | |
|         var_mean = partial(torch.var_mean, dim=0)
 | |
| 
 | |
|         for fn in (DoubleMul.apply, var_mean):
 | |
|             b = torch.rand(3, 3, requires_grad=True)
 | |
|             var, mean = fn(b)
 | |
|             var.retain_grad()
 | |
|             mean.retain_grad()
 | |
|             # node has two retains_grad hooks
 | |
|             var.mul_(2)
 | |
|             # the retain_grad hook multi-output node refers should now be a nullptr
 | |
|             (var + mean).sum().backward()
 | |
|             gvar = var.grad
 | |
|             gmean = mean.grad
 | |
| 
 | |
|             a = b.detach().requires_grad_(True)
 | |
|             var, mean = fn(a)
 | |
|             var.mul_(2)
 | |
|             out = (var + mean).sum()
 | |
|             gvar_expected, gmean_expected = torch.autograd.grad(out, inputs=(var, mean))
 | |
|             self.assertTrue(torch.allclose(gvar, gvar_expected))
 | |
|             self.assertTrue(torch.allclose(gmean, gmean_expected))
 | |
| 
 | |
|     def test_retain_grad_inplace_over_view(self):
 | |
|         base = torch.tensor([1.], requires_grad=True).clone()
 | |
|         view = base[:]
 | |
|         view2 = base[:]
 | |
|         view.retain_grad()
 | |
|         view2.retain_grad()
 | |
|         view.mul_(2)
 | |
|         (view + view2).sum().backward()
 | |
| 
 | |
|         # The old grad_fn, slice, wouldn't be part of the graph during backward
 | |
|         # so if the retains grad were not properly updated to the new grad_fn,
 | |
|         # the grad would still be None
 | |
|         self.assertEqual(view.grad, view2.grad)
 | |
|         self.assertEqual(view.grad, torch.tensor([1.]))
 | |
| 
 | |
|     def test_tensor_hooks_inplace(self):
 | |
|         # Check that the second hook gets registered to the new version of tensor
 | |
|         count1 = [0]
 | |
|         count2 = [0]
 | |
| 
 | |
|         def fn1(grad):
 | |
|             count1[0] += 1
 | |
|             # x2 from mul, x2 from fn2
 | |
|             self.assertEqual(grad, torch.tensor([4.]))
 | |
|             return grad * 2
 | |
| 
 | |
|         def fn2(grad):
 | |
|             count2[0] += 1
 | |
|             self.assertEqual(grad, torch.tensor([1.]))
 | |
|             return grad * 2
 | |
| 
 | |
|         a = torch.tensor([1.], requires_grad=True)
 | |
|         b = a.clone()
 | |
|         b.register_hook(fn1)
 | |
|         b.mul_(2)
 | |
|         b.register_hook(fn2)
 | |
|         b.sum().backward()
 | |
|         self.assertEqual(count1[0], 1)
 | |
|         self.assertEqual(count2[0], 1)
 | |
|         self.assertEqual(a.grad, torch.tensor([8.]))
 | |
| 
 | |
|         count3 = [0]
 | |
| 
 | |
|         def fn3(grad):
 | |
|             count3[0] += 1
 | |
|             self.assertEqual(grad, torch.tensor([4.]))
 | |
|             return grad * 2
 | |
| 
 | |
|         a = torch.tensor([1.], requires_grad=True)
 | |
|         b = a.clone()
 | |
|         b.register_hook(fn3)
 | |
|         # Inplace multiple times is OK
 | |
|         b.mul_(2)
 | |
|         b.mul_(2)
 | |
|         b.sum().backward()
 | |
|         self.assertEqual(count1[0], 1)
 | |
|         self.assertEqual(a.grad, torch.tensor([8.]))
 | |
| 
 | |
|     def test_tensor_hooks_inplace_multiple_outputs(self):
 | |
|         class DoubleMul(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x * 2, x * 3
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, g1, g2):
 | |
|                 return g1 * 2 + g2 * 3
 | |
| 
 | |
|         var_mean = partial(torch.var_mean, dim=0)
 | |
| 
 | |
|         for fn in (DoubleMul.apply, var_mean):
 | |
|             counts = [0, 0, 0]
 | |
| 
 | |
|             def fn0(grad):
 | |
|                 counts[0] += 1
 | |
|                 self.assertEqual(grad, torch.ones_like(out1) * 2)
 | |
| 
 | |
|             def fn1(grad):
 | |
|                 counts[1] += 1
 | |
|                 self.assertEqual(grad, torch.ones_like(out1) * 3)
 | |
| 
 | |
|             def fn2(grad):
 | |
|                 counts[2] += 1
 | |
|                 self.assertEqual(grad, torch.ones_like(out1))
 | |
| 
 | |
|             b = torch.rand(3, 3, requires_grad=True)
 | |
|             out1, out2 = fn(b)
 | |
|             out1.register_hook(fn0)
 | |
|             out2.register_hook(fn1)
 | |
|             # node refers to two hook dicts
 | |
|             # out1 no longer no longer points to its old hook dict
 | |
|             out1.mul_(2)
 | |
|             # fn2 is registered to out1's new hook dict
 | |
|             out1.register_hook(fn2)
 | |
|             (out1 + out2 * 3).sum().backward()
 | |
|             self.assertEqual(counts, [1, 1, 1])
 | |
| 
 | |
|     def test_tensor_hooks_inplace_over_view(self):
 | |
|         # There might be a better UX here, but this is the way it is now
 | |
|         count = [0]
 | |
| 
 | |
|         def fn0(grad):
 | |
|             self.fail()
 | |
| 
 | |
|         def fn1(grad):
 | |
|             self.fail()
 | |
| 
 | |
|         def fn2(grad):
 | |
|             count[0] += 1
 | |
|             self.assertEqual(grad, torch.tensor([1.]))
 | |
| 
 | |
|         base = torch.tensor([1.], requires_grad=True).clone()
 | |
|         view = base[:]
 | |
|         view2 = base[:]
 | |
|         view.register_hook(fn0)
 | |
|         view2.register_hook(fn1)
 | |
|         view.mul_(2)
 | |
|         # We need to explicitly trigger an update to view to update its grad_fn
 | |
|         view2.grad_fn
 | |
|         view2.register_hook(fn2)
 | |
|         (view + view2).sum().backward()
 | |
|         # The hooks originally registered to view are not fired, one must explicitly
 | |
|         # trigger an update to the view's grad_fn, and then register a new hook
 | |
|         self.assertEqual(count[0], 1)
 | |
| 
 | |
|     def test_retain_grad_cycle(self):
 | |
|         x = torch.ones(5, 5, requires_grad=True)
 | |
| 
 | |
|         def run_test():
 | |
|             y = x * 2
 | |
|             y.retain_grad()
 | |
| 
 | |
|             return y / 2, torch._C._WeakTensorRef(y)
 | |
| 
 | |
|         z, ref = run_test()
 | |
|         self.assertTrue(ref.expired())
 | |
|         z.sum().backward()
 | |
| 
 | |
|     def test_backward(self):
 | |
|         v = torch.randn(5, 5, requires_grad=True)
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         y = (torch.rand(5, 5) + 0.1).requires_grad_(True)
 | |
|         z = torch.randn(5, 5, requires_grad=True)
 | |
|         grad_output = torch.randn(5, 5)
 | |
| 
 | |
|         v.backward(grad_output)
 | |
|         self.assertEqual(v.grad, grad_output)
 | |
| 
 | |
|         a = x + (y * z) + 4 * z ** 2 * x / y
 | |
|         a.backward(grad_output)
 | |
|         x_grad = 4 * z.pow(2) / y + 1
 | |
|         y_grad = z - 4 * x * z.pow(2) / y.pow(2)
 | |
|         z_grad = 8 * x * z / y + y
 | |
|         self.assertEqual(x.grad, x_grad * grad_output)
 | |
|         self.assertEqual(y.grad, y_grad * grad_output)
 | |
|         self.assertEqual(z.grad, z_grad * grad_output)
 | |
| 
 | |
|     def test_to_sparse_backward(self):
 | |
|         to_attr_names = (
 | |
|             'to_dense',
 | |
|             'to_sparse',
 | |
|             'to_sparse_csr',
 | |
|             'to_sparse_csc',
 | |
|             'to_sparse_bsr',
 | |
|             'to_sparse_bsc',
 | |
|         )
 | |
|         to_params = ((), (), (), (), (2,), (2,))
 | |
|         to_attr_names_params = dict(zip(to_attr_names, to_params))
 | |
| 
 | |
|         def check_inversion_possible(t, layout1, layout1_params, layout2, layout2_params):
 | |
|             l = (layout1, layout2)
 | |
|             p = (layout1_params, layout2_params)
 | |
|             for l1, l2, p1, p2 in ((*l, *p), (*l[::-1], *p[::-1])):
 | |
|                 try:
 | |
|                     to_l1 = getattr(t, l1)(*p1)
 | |
|                     to_l2 = getattr(to_l1, l2)(*p2)
 | |
|                 except RuntimeError:
 | |
|                     return False
 | |
| 
 | |
|             return True
 | |
| 
 | |
|         self_strided = torch.rand(4, 4, dtype=torch.double) + 1
 | |
|         grad_strided = torch.rand(4, 4, dtype=torch.double) + 1
 | |
| 
 | |
|         for from_to_attr in to_attr_names:
 | |
|             from_params = to_attr_names_params[from_to_attr]
 | |
|             self_from = getattr(self_strided, from_to_attr)(*from_params).requires_grad_(True)
 | |
| 
 | |
|             for to_to_attr in to_attr_names[1:]:
 | |
|                 to_params = to_attr_names_params[to_to_attr]
 | |
| 
 | |
|                 if check_inversion_possible(self_strided, from_to_attr, from_params, to_to_attr, to_params):
 | |
|                     self_to = getattr(self_from, to_to_attr)(*to_params)
 | |
|                     grad_to = getattr(grad_strided, to_to_attr)(*to_params)
 | |
| 
 | |
|                     # No gradcheck support for BSR/BSC, so the grads are checked explicitly
 | |
|                     grad_res = torch.autograd.grad(self_to, self_from, grad_to)[0]
 | |
| 
 | |
|                     self.assertEqual(grad_res.layout, self_from.layout)
 | |
|                     self.assertEqual(grad_res.to_dense(), grad_strided)
 | |
| 
 | |
|     def test_sparse_mm_backward(self):
 | |
|         size = (3, 3)
 | |
| 
 | |
|         mm_test_cases = product(*(([False, True],) * 4))
 | |
| 
 | |
|         for a_req_grad, a_is_sparse, b_req_grad, b_is_sparse in mm_test_cases:
 | |
|             # We should only be testing cases with sparse inputs, and at least one
 | |
|             # input needs to require grad so we can call a backward pass
 | |
|             if not ((a_is_sparse or b_is_sparse) and (a_req_grad or b_req_grad)):
 | |
|                 continue
 | |
|             a = torch.randn(size)
 | |
|             if a_is_sparse:
 | |
|                 # detaching as `a` needs to be a leaf
 | |
|                 a = a.to_sparse().detach()
 | |
|             b = torch.randn(size)
 | |
|             if b_is_sparse:
 | |
|                 # detaching as `b` needs to be a leaf
 | |
|                 b = b.to_sparse().detach()
 | |
| 
 | |
|             a = a.requires_grad_(a_req_grad)
 | |
|             b = b.requires_grad_(b_req_grad)
 | |
| 
 | |
|             r = a.mm(b)
 | |
|             s = r.sum().backward()
 | |
|             a_grad = None if a.grad is None else a.grad.clone().detach()
 | |
|             b_grad = None if b.grad is None else b.grad.clone().detach()
 | |
| 
 | |
|             # Redo with only dense tensors
 | |
|             a = (a.to_dense() if a.is_sparse else a).clone().detach().requires_grad_(a_req_grad)
 | |
|             b = (b.to_dense() if b.is_sparse else b).clone().detach().requires_grad_(b_req_grad)
 | |
| 
 | |
|             r = a.mm(b)
 | |
|             r.sum().backward()
 | |
| 
 | |
|             self.assertEqual(a_grad, a.grad)
 | |
|             self.assertEqual(b_grad, b.grad)
 | |
| 
 | |
|     def test_multi_backward(self):
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         y = torch.randn(5, 5, requires_grad=True)
 | |
| 
 | |
|         q = torch.randn(5, 5, requires_grad=True)
 | |
| 
 | |
|         a = torch.randn(5, 5, requires_grad=True)
 | |
|         b = torch.randn(5, 5, requires_grad=True)
 | |
| 
 | |
|         q2 = q * 2
 | |
|         z = x + y + q2
 | |
|         c = a * b + q2
 | |
|         grad_z = torch.randn(5, 5)
 | |
|         grad_c = torch.randn(5, 5)
 | |
|         torch.autograd.backward([z, c], [grad_z, grad_c])
 | |
| 
 | |
|         self.assertEqual(x.grad, grad_z)
 | |
|         self.assertEqual(y.grad, grad_z)
 | |
|         self.assertEqual(a.grad, grad_c * b)
 | |
|         self.assertEqual(b.grad, grad_c * a)
 | |
|         self.assertEqual(q.grad, (grad_c + grad_z) * 2)
 | |
| 
 | |
|     def test_multi_backward_no_grad(self):
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         y = torch.randn(5, 5, requires_grad=False)
 | |
| 
 | |
|         z = x + y
 | |
|         q = y * 2
 | |
| 
 | |
|         # NB: we currently raise an exception if any arguments to backwards
 | |
|         # have requires_grad=False and don't have a grad_fn. We may want to
 | |
|         # relax that check to a warning.
 | |
|         def call_backwards():
 | |
|             torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)])
 | |
|         self.assertRaises(RuntimeError, call_backwards)
 | |
| 
 | |
|     def test_backward_with_inputs(self):
 | |
|         x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
 | |
|         y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
 | |
| 
 | |
|         def fn():
 | |
|             return x ** 2 + y * x + y ** 2
 | |
| 
 | |
|         gradient = torch.ones(2, 2)
 | |
|         x_grad_expected = 2 * x + y
 | |
|         y_grad_expected = x + 2 * y
 | |
| 
 | |
|         @torch.no_grad()
 | |
|         def reset_grad():
 | |
|             x.grad.zero_()
 | |
|             y.grad.zero_()
 | |
| 
 | |
|         torch.autograd.backward(fn(), gradient, inputs=[x, y])
 | |
|         self.assertEqual(x.grad, x_grad_expected)
 | |
|         self.assertEqual(y.grad, y_grad_expected)
 | |
| 
 | |
|         reset_grad()
 | |
|         torch.autograd.backward(fn(), gradient, inputs=[x])
 | |
|         self.assertEqual(x.grad, x_grad_expected)
 | |
|         self.assertEqual(y.grad, torch.zeros(2, 2), exact_dtype=False)
 | |
| 
 | |
|         reset_grad()
 | |
|         torch.autograd.backward(fn(), gradient, inputs=[y])
 | |
|         self.assertEqual(y.grad, y_grad_expected)
 | |
|         self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)
 | |
| 
 | |
|         reset_grad()
 | |
|         torch.autograd.backward(fn(), gradient, inputs=y)
 | |
|         self.assertEqual(y.grad, y_grad_expected)
 | |
|         self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)
 | |
| 
 | |
|         reset_grad()
 | |
|         self.assertRaisesRegex(RuntimeError, 'cannot be empty',
 | |
|                                lambda: torch.autograd.backward(fn(), gradient, inputs=[]))
 | |
| 
 | |
|     def test_backward_with_nonleaf_inputs(self):
 | |
|         x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
 | |
|         x_nonleaf = x * 1
 | |
|         y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
 | |
|         z = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
 | |
| 
 | |
|         out = x_nonleaf ** 2 + y * x_nonleaf + y ** 2
 | |
| 
 | |
|         out.backward(torch.ones(2, 2, dtype=torch.double), create_graph=True, inputs=[x, y, x_nonleaf])
 | |
|         x_grad_expected = 2 * x + y
 | |
|         y_grad_expected = x + 2 * y
 | |
|         x_non_leaf_expected = 2 * x_nonleaf + y
 | |
| 
 | |
|         self.assertEqual(y.grad, y_grad_expected)
 | |
|         self.assertEqual(x.grad, x_grad_expected)
 | |
|         self.assertEqual(x_nonleaf.grad, x_non_leaf_expected)
 | |
| 
 | |
|         # backward doesn't have an allow_unused flag, so the behavior of backward
 | |
|         # when variable is not part of the graph is as if allow_used were true
 | |
|         # x.grad will simply be None.
 | |
|         out.backward(torch.ones(2, 2, dtype=torch.double), create_graph=True, inputs=[z])
 | |
|         self.assertIsNone(z.grad)
 | |
| 
 | |
|     def test_dependent_backward(self):
 | |
|         x = torch.randn(10, requires_grad=True)
 | |
|         y = x ** 2
 | |
|         z = y ** 3
 | |
| 
 | |
|         go_y = torch.randn(10)
 | |
|         go_z = torch.randn(10)
 | |
|         torch.autograd.backward([y, z], [go_y, go_z])
 | |
| 
 | |
|         xd = x
 | |
|         self.assertEqual(x.grad, 2 * xd * go_y + 6 * xd.pow(5) * go_z)
 | |
| 
 | |
|     def test_save_output_nr(self):
 | |
|         x = torch.randn(10, requires_grad=True)
 | |
| 
 | |
|         class MultiOutputFn(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x[:5], x[5:]
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, *grad):
 | |
|                 return torch.cat(grad)
 | |
| 
 | |
|         a, b = MultiOutputFn.apply(x)
 | |
|         self.assertEqual(b.output_nr, 1)
 | |
| 
 | |
|         class TestFn(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, b):
 | |
|                 ctx.save_for_backward(b)
 | |
|                 return b * 2
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_b):
 | |
|                 b, = ctx.saved_tensors
 | |
|                 self.assertEqual(b.output_nr, 1)
 | |
| 
 | |
|         TestFn.apply(b).sum().backward()
 | |
| 
 | |
|     def test_first_grad_fn_access_in_no_grad_mode(self):
 | |
|         a = torch.tensor([1 + 1j], requires_grad=True).clone()
 | |
|         v = a.real
 | |
|         a.add_(1)
 | |
|         with torch.autograd.grad_mode.no_grad():
 | |
|             v.grad_fn
 | |
| 
 | |
    def test_free_deep_graph(self):
        """Build a very deep chain-shaped graph and let it be freed.

        No assertion is made; the test passes when tearing the graph down
        completes without crashing.
        """
        def scope():
            depth = 150000
            x = torch.randn(1, requires_grad=True)
            y = x.clone()

            # build a "chain" computation graph
            for _ in range(depth):
                y = y + y * 0.000001

            # graph deletion occurs when the above locals go out of scope.
            # In this case `del y` will trigger it but it's easier to leave
            # it to Python to delete the locals.

        # Should not stack overflow
        scope()
 | |
| 
 | |
    def test_free_deep_graph_complicated(self):
        """Build a very deep graph with in-place accumulation and let it be freed.

        No assertion is made; the test passes when tearing the graph down
        completes without crashing.
        """
        def scope():
            depth = 100000
            randchoice = torch.randint(2, [depth, 2])
            x = torch.randn(1, requires_grad=True)
            y = x.clone()

            # Hold the two previous values
            prev_values = [None, None]

            # Build a "chain with skip connections" graph
            for _ in range(depth):
                # NOTE(review): prev_values always has exactly two entries and
                # the slice drops the last one, so prev_tensors holds at most
                # one tensor - confirm whether `prev_values` was intended.
                prev_tensors = [tensor for tensor in prev_values[:-1]
                                if tensor is not None]
                prev_values.append(y)
                prev_values.pop(0)

                # Definitely pick one tensor to add
                y += y * 0.000001

                # Possibly add other tensors
                nprev = len(prev_tensors)
                # NOTE(review): given the slice above this condition looks
                # unreachable, so no skip connection is ever added. If it were
                # reached, `randchoice[depth]` indexes one past the last row of
                # a tensor with `depth` rows - presumably the loop index was
                # meant; verify before relying on this branch.
                if nprev == 2:
                    y += randchoice[depth].mul(torch.cat(prev_tensors)).sum()

            # graph deletion occurs when the above locals go out of scope.

        # Should not stack overflow
        scope()
 | |
| 
 | |
    def test_free_deep_graph_pyfunction(self):
        """Build a very deep graph out of custom Python Functions and let it be freed.

        No assertion is made; the test passes when tearing the graph down
        completes without crashing.
        """
        class MyOp(Function):
            # Two-input addition whose backward forwards the incoming
            # gradient unchanged to both inputs.
            @staticmethod
            def forward(ctx, tensor1, tensor2):
                return tensor1 + tensor2

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output, grad_output

        def scope():
            depth = 150000
            x = torch.randn(1, requires_grad=True)
            y = x.clone()

            # build deeply nested computation graph
            for _ in range(depth):
                # the same tensor is passed as both inputs of the op
                y = MyOp.apply(y, y)

            # graph deletion occurs when the above locals go out of scope.

        # Should not stack overflow
        scope()
 | |
| 
 | |
|     def test_no_unnecessary_save(self):
 | |
|         # If we kept x in the derivative Function of x * 2 we would
 | |
|         # get an error in the backward that would complain that we've
 | |
|         # modified x, which was needed for gradient computation.
 | |
|         # Since we should elide unnecessary saves, this test should pass.
 | |
|         mu = torch.ones(1, requires_grad=True)
 | |
|         x = torch.empty(1)
 | |
|         loss = 0
 | |
|         for i in range(3):
 | |
|             x.detach_()
 | |
|             x.copy_(mu + i)
 | |
|             ft = torch.tensor([float(i)])
 | |
|             multiplied = x * ft
 | |
|             s = multiplied.sum()
 | |
|             loss += s
 | |
|         loss.backward()
 | |
| 
 | |
|     def test_no_grad(self):
 | |
|         x = torch.ones(5, 5, requires_grad=True)
 | |
|         y = torch.ones(5, 5) * 4
 | |
|         with torch.no_grad():
 | |
|             w = x + y
 | |
| 
 | |
|         @torch.no_grad()
 | |
|         def adder(x, y):
 | |
|             return x + y
 | |
| 
 | |
|         z = adder(x, y)
 | |
| 
 | |
|         self.assertFalse(w.requires_grad)
 | |
|         self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
 | |
|         self.assertIsNone(w.grad_fn)
 | |
|         self.assertFalse(z.requires_grad)
 | |
|         self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
 | |
|         self.assertIsNone(z.grad_fn)
 | |
| 
 | |
|         # test nested decorator and with-statement on no_grad
 | |
|         with torch.no_grad():
 | |
|             self.assertFalse(torch.is_grad_enabled())
 | |
|             w = adder(x, y)
 | |
|             self.assertFalse(torch.is_grad_enabled())
 | |
| 
 | |
|     def test_set_grad_generator_functions(self):
 | |
|         @torch.no_grad()
 | |
|         def gen_no_grad():
 | |
|             for i in range(10):
 | |
|                 self.assertEqual(torch.is_grad_enabled(), False)
 | |
|                 yield i
 | |
| 
 | |
|         with torch.enable_grad():
 | |
|             for _ in gen_no_grad():
 | |
|                 self.assertEqual(torch.is_grad_enabled(), True)
 | |
| 
 | |
|         @torch.enable_grad()
 | |
|         def gen_enable_grad():
 | |
|             for i in range(10):
 | |
|                 self.assertEqual(torch.is_grad_enabled(), True)
 | |
|                 yield i
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             for _ in gen_enable_grad():
 | |
|                 self.assertEqual(torch.is_grad_enabled(), False)
 | |
| 
 | |
|     def test_set_grad_generator_functions_recursive(self):
 | |
|         # enable_grad_decorator_recursive and no_grad_decorator_recursive call each other
 | |
|         # recursively, to ensure that the decorators preserve the caller's setting
 | |
|         @torch.enable_grad()
 | |
|         def enable_grad_decorator_recursive(depth):
 | |
|             self.assertTrue(torch.is_grad_enabled())
 | |
|             if depth > 0:
 | |
|                 no_grad_decorator_recursive(depth - 1)
 | |
|                 self.assertTrue(torch.is_grad_enabled())
 | |
| 
 | |
|         @torch.no_grad()
 | |
|         def no_grad_decorator_recursive(depth):
 | |
|             self.assertFalse(torch.is_grad_enabled())
 | |
|             if depth > 0:
 | |
|                 enable_grad_decorator_recursive(depth - 1)
 | |
|                 self.assertFalse(torch.is_grad_enabled())
 | |
| 
 | |
|         # enable_grad_context_manager_recursive and no_grad_context_manager_recursive call
 | |
|         # each other recursively, to ensure that the decorators preserve the caller's setting
 | |
|         def enable_grad_context_manager_recursive(depth):
 | |
|             with torch.enable_grad():
 | |
|                 self.assertTrue(torch.is_grad_enabled())
 | |
|                 if depth > 0:
 | |
|                     no_grad_context_manager_recursive(depth - 1)
 | |
|                     self.assertTrue(torch.is_grad_enabled())
 | |
| 
 | |
|         def no_grad_context_manager_recursive(depth):
 | |
|             with torch.no_grad():
 | |
|                 self.assertFalse(torch.is_grad_enabled())
 | |
|                 if depth > 0:
 | |
|                     enable_grad_context_manager_recursive(depth - 1)
 | |
|                     self.assertFalse(torch.is_grad_enabled())
 | |
| 
 | |
|         with torch.enable_grad():
 | |
|             self.assertTrue(torch.is_grad_enabled())
 | |
|             enable_grad_decorator_recursive(10)
 | |
|             self.assertTrue(torch.is_grad_enabled())
 | |
|             enable_grad_context_manager_recursive(10)
 | |
|             self.assertTrue(torch.is_grad_enabled())
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             self.assertFalse(torch.is_grad_enabled())
 | |
|             enable_grad_decorator_recursive(10)
 | |
|             self.assertFalse(torch.is_grad_enabled())
 | |
|             enable_grad_context_manager_recursive(10)
 | |
|             self.assertFalse(torch.is_grad_enabled())
 | |
| 
 | |
|     def test_set_grad_coroutines(self):
 | |
|         @torch.no_grad()
 | |
|         def coro_no_grad(n=10):
 | |
|             self.assertFalse(torch.is_grad_enabled())
 | |
|             for i in range(n):
 | |
|                 self.assertFalse(torch.is_grad_enabled())
 | |
|                 r = yield i
 | |
|                 self.assertFalse(torch.is_grad_enabled())
 | |
|                 self.assertEqual(i, r)
 | |
|             self.assertFalse(torch.is_grad_enabled())
 | |
| 
 | |
|         @torch.enable_grad()
 | |
|         def coro_enable_grad(n=10):
 | |
|             self.assertTrue(torch.is_grad_enabled())
 | |
|             for i in range(n):
 | |
|                 self.assertTrue(torch.is_grad_enabled())
 | |
|                 r = yield i
 | |
|                 self.assertTrue(torch.is_grad_enabled())
 | |
|                 self.assertEqual(i, r)
 | |
|             self.assertTrue(torch.is_grad_enabled())
 | |
| 
 | |
|         with torch.enable_grad():
 | |
|             self.assertTrue(torch.is_grad_enabled())
 | |
|             coro, r = coro_no_grad(), None
 | |
|             try:
 | |
|                 while True:
 | |
|                     self.assertTrue(torch.is_grad_enabled())
 | |
|                     r = coro.send(r)
 | |
|                     self.assertTrue(torch.is_grad_enabled())
 | |
| 
 | |
|             except StopIteration:
 | |
|                 pass
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             self.assertFalse(torch.is_grad_enabled())
 | |
|             coro, r = coro_enable_grad(), None
 | |
|             try:
 | |
|                 while True:
 | |
|                     self.assertFalse(torch.is_grad_enabled())
 | |
|                     r = coro.send(r)
 | |
|                     self.assertFalse(torch.is_grad_enabled())
 | |
| 
 | |
|             except StopIteration:
 | |
|                 pass
 | |
| 
 | |
|     def test_set_grad_coroutines_benign_exceptions(self):
 | |
|         class RecoverableException(Exception):
 | |
|             pass
 | |
| 
 | |
|         @torch.no_grad()
 | |
|         def coro_no_grad(n=10):
 | |
|             has_raised = False
 | |
|             for i in range(n):
 | |
|                 try:
 | |
|                     self.assertFalse(torch.is_grad_enabled())
 | |
|                     yield (-i if has_raised else i)
 | |
| 
 | |
|                 except RecoverableException:
 | |
|                     self.assertFalse(torch.is_grad_enabled())
 | |
|                     has_raised = True
 | |
| 
 | |
|         @torch.enable_grad()
 | |
|         def coro_enable_grad(n=10):
 | |
|             has_raised = False
 | |
|             for i in range(n):
 | |
|                 try:
 | |
|                     self.assertTrue(torch.is_grad_enabled())
 | |
|                     yield (-i if has_raised else i)
 | |
| 
 | |
|                 except RecoverableException:
 | |
|                     self.assertTrue(torch.is_grad_enabled())
 | |
|                     has_raised = True
 | |
| 
 | |
|         with torch.enable_grad():
 | |
|             coro = coro_no_grad()
 | |
|             assert 0 == next(coro)
 | |
|             try:
 | |
|                 while True:
 | |
|                     r = coro.throw(RecoverableException)
 | |
|                     self.assertLess(r, 0)
 | |
| 
 | |
|             except StopIteration:
 | |
|                 pass
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             coro = coro_enable_grad()
 | |
|             assert 0 == next(coro)
 | |
|             try:
 | |
|                 while True:
 | |
|                     r = coro.throw(RecoverableException)
 | |
|                     self.assertLess(r, 0)
 | |
| 
 | |
|             except StopIteration:
 | |
|                 pass
 | |
| 
 | |
|     def test_set_grad_coroutines_critical_exceptions(self):
 | |
|         class UnrecoverableException(Exception):
 | |
|             pass
 | |
| 
 | |
|         class SecondaryException(Exception):
 | |
|             pass
 | |
| 
 | |
|         @torch.no_grad()
 | |
|         def coro_no_grad(n=10):
 | |
|             has_raised = False
 | |
|             for i in range(n):
 | |
|                 try:
 | |
|                     self.assertFalse(torch.is_grad_enabled())
 | |
|                     yield (-i if has_raised else i)
 | |
| 
 | |
|                 except UnrecoverableException:
 | |
|                     self.assertFalse(torch.is_grad_enabled())
 | |
|                     raise SecondaryException
 | |
| 
 | |
|         @torch.enable_grad()
 | |
|         def coro_enable_grad(n=10):
 | |
|             has_raised = False
 | |
|             for i in range(n):
 | |
|                 try:
 | |
|                     self.assertTrue(torch.is_grad_enabled())
 | |
|                     yield (-i if has_raised else i)
 | |
| 
 | |
|                 except UnrecoverableException :
 | |
|                     self.assertTrue(torch.is_grad_enabled())
 | |
|                     raise SecondaryException
 | |
| 
 | |
|         with torch.enable_grad():
 | |
|             coro = coro_no_grad()
 | |
|             assert 0 == next(coro)
 | |
|             with self.assertRaises(SecondaryException):
 | |
|                 coro.throw(UnrecoverableException)
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             coro = coro_enable_grad()
 | |
|             assert 0 == next(coro)
 | |
|             with self.assertRaises(SecondaryException):
 | |
|                 coro.throw(UnrecoverableException)
 | |
| 
 | |
|     def test_set_grad_coroutines_exit(self):
 | |
|         @torch.no_grad()
 | |
|         def coro_no_grad(state):
 | |
|             for i in range(10):
 | |
|                 try:
 | |
|                     self.assertFalse(torch.is_grad_enabled())
 | |
|                     yield i
 | |
| 
 | |
|                 except GeneratorExit:
 | |
|                     self.assertFalse(torch.is_grad_enabled())
 | |
|                     state.add('GeneratorExit')
 | |
|                     raise
 | |
| 
 | |
|         @torch.enable_grad()
 | |
|         def coro_enable_grad(state):
 | |
|             for i in range(10):
 | |
|                 try:
 | |
|                     self.assertTrue(torch.is_grad_enabled())
 | |
|                     yield i
 | |
| 
 | |
|                 except GeneratorExit:
 | |
|                     self.assertTrue(torch.is_grad_enabled())
 | |
|                     state.add('GeneratorExit')
 | |
|                     raise
 | |
| 
 | |
|         state = set()
 | |
|         with torch.enable_grad():
 | |
|             coro = coro_no_grad(state)
 | |
|             for i in range(5):
 | |
|                 next(coro)
 | |
| 
 | |
|             coro.close()
 | |
|         self.assertTrue('GeneratorExit' in state)
 | |
| 
 | |
|         state = set()
 | |
|         with torch.no_grad():
 | |
|             coro = coro_enable_grad(state)
 | |
|             for i in range(5):
 | |
|                 next(coro)
 | |
| 
 | |
|             coro.close()
 | |
|         self.assertTrue('GeneratorExit' in state)
 | |
| 
 | |
|     def test_no_grad_python_function(self):
 | |
|         """Python Functions should respect grad mode."""
 | |
|         x = torch.ones(5, 5, requires_grad=True)
 | |
| 
 | |
|         class MyOp(Function):
 | |
|             @staticmethod
 | |
|             def forward(self, x):
 | |
|                 return x + 1
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(self, dy):
 | |
|                 return dy
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             y = MyOp.apply(x)
 | |
|         self.assertFalse(y.requires_grad)
 | |
| 
 | |
    def test_indexing(self):
        """Check forward values and gradients of indexing a tensor that requires grad.

        For each index, the indexed result must equal indexing the plain
        tensor, and after ``sum().backward()`` the gradient must be 1 at the
        positions assigned through the same index and 0 elsewhere.
        """
        x = torch.arange(1., 17).view(4, 4)
        y = Variable(x, requires_grad=True)

        def compare(x, y, idx, indexed_tensor, indexed_var):
            # Forward check: the result's data must match the plain-tensor result
            indexed_var_t = indexed_var.data
            if not isinstance(indexed_tensor, torch.Tensor):
                # the plain-tensor result is not a Tensor, so compare it with
                # the first element of the indexed data
                indexed_var_t = indexed_var_t[0]
            self.assertEqual(indexed_tensor, indexed_var_t)

            # Backward check: build the expected gradient by writing 1 through
            # the same index into a zero-filled tensor of x's size
            indexed_var.sum().backward()
            expected_grad = torch.empty(x.size()).fill_(0)
            expected_grad[idx] = 1
            self.assertEqual(y.grad, expected_grad)

        def check_index(x, y, idx):
            # Clear any gradient left over from the previous index
            if y.grad is not None:
                with torch.no_grad():
                    y.grad.zero_()
            indexed_tensor = x[idx]
            indexed_var = y[idx]
            compare(x, y, idx, indexed_tensor, indexed_var)

        check_index(x, y, 1)
        check_index(x, y, (1, 1))
        check_index(x, y, slice(1, None))
        check_index(x, y, slice(None, 2))
        check_index(x, y, (slice(None, 2), 2))
        check_index(x, y, (slice(1, 2), 2))
        check_index(x, y, (1, slice(2, None)))
        check_index(x, y, (slice(None, None), slice(2, None)))
        check_index(x, y, torch.LongTensor([0, 2]))
        check_index(x, y, torch.rand(4, 4).bernoulli().bool())
        check_index(x, y, (Ellipsis, slice(2, None)))
        check_index(x, y, ([0], [0]))
        check_index(x, y, ([1, 2, 3], [0]))
        check_index(x, y, ([1, 2], [2, 1]))
        check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 3]]))
        check_index(x, y, ([slice(None), [2, 3]]))
        check_index(x, y, ([[2, 3], slice(None)]))

        # advanced indexing, with less dim, or ellipsis
        check_index(x, y, ([0]))
        check_index(x, y, ([0], ))

        # Repeat with a three-dimensional tensor
        x = torch.arange(1., 49).view(4, 3, 4)
        y = Variable(x, requires_grad=True)

        check_index(x, y, (slice(None), [0], [0]))
        check_index(x, y, ([0], [0], slice(None)))
        check_index(x, y, (slice(None), [0, 1, 2], [0]))
        check_index(x, y, ([0, 1, 2], [0], slice(None)))
        check_index(x, y, (slice(None), [1, 2], [2, 1]))
        check_index(x, y, ([1, 2], [2, 1], slice(None)))
        check_index(x, y, (slice(None), [[1, 2], [2, 0]], [[0, 1], [2, 3]]))
        check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 2]], slice(None)))
        check_index(x, y, (slice(None), slice(None), [2, 1]))
        check_index(x, y, (slice(None), [2, 1], slice(None)))
        check_index(x, y, ([2, 1], slice(None), slice(None)))

        # advanced indexing, with less dim, or ellipsis
        check_index(x, y, ([0], ))
        check_index(x, y, ([0], slice(None)))
        check_index(x, y, ([0], Ellipsis))
        check_index(x, y, ([1, 2], [0, 1]))
        check_index(x, y, ([1, 2], [0, 1], Ellipsis))
        check_index(x, y, (Ellipsis, [1, 2], [0, 1]))

        # advanced indexing, with a tensor wrapped in a variable
        z = torch.LongTensor([0, 1])
        zv = Variable(z, requires_grad=False)
        seq = [z, Ellipsis]
        seqv = [zv, Ellipsis]

        # Same steps as check_index, except that the plain tensor and the
        # Variable are indexed with different (tensor vs Variable) sequences
        if y.grad is not None:
            with torch.no_grad():
                y.grad.zero_()
        indexed_tensor = x[seq]
        indexed_var = y[seqv]
        compare(x, y, seq, indexed_tensor, indexed_var)
 | |
| 
 | |
|     def test_indexing_duplicates(self):
 | |
|         x = torch.arange(1., 17).view(4, 4)
 | |
|         y = Variable(x, requires_grad=True)
 | |
| 
 | |
|         idx = torch.LongTensor([1, 1, 3, 2, 1, 2])
 | |
|         y[idx].sum().backward()
 | |
|         expected_grad = torch.zeros(4, 4)
 | |
|         for i in idx:
 | |
|             expected_grad[i] += 1
 | |
|         self.assertEqual(y.grad, expected_grad)
 | |
| 
 | |
|         # with advanced indexing
 | |
|         x = torch.arange(1., 17).view(4, 4)
 | |
|         y = Variable(x, requires_grad=True)
 | |
| 
 | |
|         idx = [[1, 1, 3, 2, 1, 2], [0]]
 | |
|         y[idx].sum().backward()
 | |
|         expected_grad = torch.zeros(4, 4)
 | |
|         for i in idx[0]:
 | |
|             for j in idx[1]:
 | |
|                 expected_grad[i][j] += 1
 | |
| 
 | |
|         self.assertEqual(y.grad, expected_grad)
 | |
| 
 | |
|         x = torch.arange(1., 17).view(4, 4)
 | |
|         y = Variable(x, requires_grad=True)
 | |
|         idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]]
 | |
|         y[idx].sum().backward()
 | |
|         expected_grad = torch.tensor([[0., 2., 0., 0.],
 | |
|                                       [1., 0., 0., 0.],
 | |
|                                       [0., 1., 0., 0.],
 | |
|                                       [0., 0., 0., 0.]])
 | |
|         self.assertEqual(y.grad, expected_grad)
 | |
| 
 | |
|         x = torch.arange(1., 65).view(4, 4, 4)
 | |
|         y = Variable(x, requires_grad=True)
 | |
| 
 | |
|         idx = [[1, 1, 1], slice(None), slice(None)]
 | |
|         y[idx].sum().backward()
 | |
|         expected_grad = torch.empty(4, 4, 4).zero_()
 | |
|         expected_grad[1].fill_(3)
 | |
|         self.assertEqual(y.grad, expected_grad)
 | |
| 
 | |
|     def test_index_backward_does_not_save_tensor(self):
 | |
|         # Example from https://github.com/pytorch/pytorch/issues/24853.
 | |
|         # if `index(tensor, indices)` saves `tensor` for backwards, then it will
 | |
|         # trigger a version check on `tensor` during the backward pass, which
 | |
|         # will cause the following code to error because `tensor` gets modified
 | |
|         # by the indexing line.
 | |
|         a = torch.tensor([1., 0, 0])
 | |
|         b = torch.zeros(3, requires_grad=True)
 | |
|         tensor = b + 0
 | |
|         tensor[a != 0] = tensor[a != 0]
 | |
|         tensor.backward(torch.zeros_like(tensor))
 | |
| 
 | |
|     def test_volatile_deprecated(self):
 | |
|         v = torch.autograd.torch.randn(3, 3)
 | |
|         with warnings.catch_warnings(record=True) as w:
 | |
|             self.assertFalse(v.volatile)
 | |
|         self.assertIn('volatile', str(w[0].message))
 | |
| 
 | |
|     def test_saved_variables_deprecated(self):
 | |
|         class MyFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, tensor1, tensor2):
 | |
|                 ctx.save_for_backward(tensor1, tensor2)
 | |
|                 return tensor1 + tensor2
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 var1, var2 = ctx.saved_variables
 | |
|                 return (grad_output, grad_output)
 | |
| 
 | |
|         with warnings.catch_warnings(record=True) as warns:
 | |
|             warnings.simplefilter("always")
 | |
|             x = torch.randn((3, 3), requires_grad=True)
 | |
|             y = torch.randn((3, 3), requires_grad=True)
 | |
|             MyFunction.apply(x, y).sum().backward()
 | |
| 
 | |
|             has_deprecated = ('deprecated' in str(warn) and
 | |
|                               'saved_variables' in str(warn) for warn in warns)
 | |
|             has_deprecated = reduce(lambda x, y: x or y, has_deprecated)
 | |
|             self.assertTrue(has_deprecated)
 | |
| 
 | |
|     def test_requires_grad(self):
 | |
|         x = torch.randn(5, 5)
 | |
|         y = torch.randn(5, 5)
 | |
|         z = torch.randn(5, 5, requires_grad=True)
 | |
|         a = x + y
 | |
|         self.assertFalse(a.requires_grad)
 | |
|         b = a + z
 | |
|         self.assertTrue(b.requires_grad)
 | |
| 
 | |
|         def error():
 | |
|             raise RuntimeError
 | |
|         # Make sure backward isn't called on these
 | |
|         a._backward_hooks = OrderedDict()
 | |
|         x._backward_hooks = OrderedDict()
 | |
|         y._backward_hooks = OrderedDict()
 | |
|         a._backward_hooks['test'] = error
 | |
|         x._backward_hooks['test'] = error
 | |
|         y._backward_hooks['test'] = error
 | |
|         b.backward(torch.ones(5, 5))
 | |
| 
 | |
|     def test_requires_grad_(self):
 | |
|         x = torch.randn(5, 5)
 | |
|         y = torch.randn(5, 5, requires_grad=True)
 | |
|         self.assertIs(x, x.requires_grad_())
 | |
|         self.assertTrue(x.requires_grad)
 | |
|         self.assertIs(y, y.requires_grad_())
 | |
|         self.assertTrue(y.requires_grad)
 | |
|         self.assertIs(x, x.requires_grad_(True))
 | |
|         self.assertTrue(x.requires_grad)
 | |
|         self.assertIs(y, y.requires_grad_(True))
 | |
|         self.assertTrue(y.requires_grad)
 | |
|         z = x * y
 | |
|         self.assertRaises(RuntimeError, lambda: z.requires_grad_(False))
 | |
|         self.assertIs(z, z.requires_grad_())
 | |
|         self.assertTrue(z.requires_grad)
 | |
|         self.assertIs(z, z.requires_grad_(True))
 | |
|         self.assertTrue(z.requires_grad)
 | |
| 
 | |
|         self.assertIs(x, x.requires_grad_(False))
 | |
|         self.assertFalse(x.requires_grad)
 | |
|         self.assertIs(y, y.requires_grad_(False))
 | |
|         self.assertFalse(y.requires_grad)
 | |
| 
 | |
|     def test_requires_grad_inplace(self):
 | |
|         a = torch.randn(5, 5)
 | |
|         b = torch.randn(5, 5, requires_grad=True)
 | |
|         a += b
 | |
|         self.assertTrue(a.requires_grad)
 | |
| 
 | |
|         # non-leaf
 | |
|         a = torch.randn(5, 5) + 0
 | |
|         b = torch.randn(5, 5, requires_grad=True)
 | |
|         a += b
 | |
|         self.assertTrue(a.requires_grad)
 | |
| 
 | |
|     def test_no_requires_grad_inplace(self):
 | |
|         # basic case, should be able to modify inplace while requires_grad is False
 | |
|         a = torch.randn(2, 3)
 | |
|         a.add_(5)
 | |
|         a.requires_grad = True
 | |
|         a.sum().backward()
 | |
|         self.assertEqual(a.grad, torch.ones(2, 3))
 | |
| 
 | |
|         # same but with a view
 | |
|         a = torch.randn(2, 3)
 | |
|         b = a[:]
 | |
|         b.add_(5)
 | |
|         a.requires_grad = True
 | |
|         a.sum().backward()
 | |
|         self.assertEqual(a.grad, torch.ones(2, 3))
 | |
| 
 | |
|         # should fail if requires_grad = True when we modify inplace
 | |
|         a = torch.randn(2, 3)
 | |
|         b = a[:]
 | |
|         a.requires_grad = True
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             a.add_(5)
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             b.add_(5)
 | |
| 
 | |
|     def test_attribute_deletion(self):
 | |
|         x = torch.randn((5, 5), requires_grad=True)
 | |
|         del x.grad
 | |
|         self.assertIsNone(x.grad)
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             del x.data
 | |
|         with self.assertRaises(TypeError):
 | |
|             x.data = None
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             del x.requires_grad
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             del x._grad_fn
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             del x._backward_hooks
 | |
| 
 | |
|     def test_duplicate_backward_root(self):
 | |
|         a = torch.randn(5, 5, requires_grad=True)
 | |
|         b = torch.randn(5, 5, requires_grad=True)
 | |
| 
 | |
|         x = a * b
 | |
|         grad_output = torch.randn_like(x)
 | |
|         torch.autograd.backward([x, x], [grad_output, grad_output])
 | |
| 
 | |
|         self.assertEqual(a.grad, b * grad_output * 2)
 | |
|         self.assertEqual(b.grad, a * grad_output * 2)
 | |
| 
 | |
|     def test_backward_no_grad(self):
 | |
|         a = torch.randn(5, 5, requires_grad=True)
 | |
|         b = a + 2
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             torch.autograd.backward([b], [None])
 | |
| 
 | |
|     def test_backward_twice_with_saved_values(self):
 | |
|         b = torch.randn(3, requires_grad=True, dtype=torch.double)
 | |
|         c = torch.zeros(3, dtype=torch.double)
 | |
|         c[[1, 2]] = b[[1, 1]]
 | |
|         c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
 | |
|         self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True',
 | |
|                                lambda: c.backward(torch.tensor([1, 1, 1], dtype=torch.double)))
 | |
| 
 | |
|     def test_backward_twice_retained_graph_with_saved_values(self):
 | |
|         b = torch.randn(3, requires_grad=True, dtype=torch.double)
 | |
|         c = torch.zeros(3, dtype=torch.double)
 | |
|         c[[1, 2]] = b[[1, 1]]
 | |
|         c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True)
 | |
|         c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
 | |
| 
 | |
|     def test_backward_twice_without_saved_values(self):
 | |
|         b = torch.randn(3, requires_grad=True, dtype=torch.double)
 | |
|         c = b + 1
 | |
|         c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
 | |
|         c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
 | |
| 
 | |
|     def test_backward_twice_retained_graph_without_saved_values(self):
 | |
|         b = torch.randn(3, requires_grad=True, dtype=torch.double)
 | |
|         c = torch.zeros(3, dtype=torch.double)
 | |
|         c[[1, 2]] = b[[1, 1]]
 | |
|         c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True)
 | |
|         c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
 | |
| 
 | |
|     def test_backward_create_graph_warns(self):
 | |
|         with set_warn_always_context(True):
 | |
|             b = torch.randn(3, requires_grad=True, dtype=torch.double)
 | |
|             c = b * b
 | |
|             with warnings.catch_warnings(record=True) as ws:
 | |
|                 c.backward(torch.ones_like(c), create_graph=True)
 | |
|             b.grad = None
 | |
|             self.assertTrue(any('Using backward() with create_graph=True' in str(w.message) for w in ws))
 | |
| 
 | |
|             # Should not warn for grad
 | |
|             with warnings.catch_warnings(record=True) as ws:
 | |
|                 torch.autograd.grad(c, b, torch.ones_like(c), create_graph=True)
 | |
|             self.assertFalse(any('Using backward() with create_graph=True' in str(w.message) for w in ws))
 | |
| 
 | |
|     def test_next_functions(self):
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         y = torch.randn(5, 5, requires_grad=True)
 | |
| 
 | |
|         a = x + y
 | |
|         self.assertIsNotNone(a.grad_fn)
 | |
|         next_functions = a.grad_fn.next_functions
 | |
|         self.assertEqual(len(next_functions), 2)
 | |
|         self.assertIsInstance(next_functions[0][0], torch._C._functions.AccumulateGrad)
 | |
|         self.assertEqual(next_functions[0][1], 0)
 | |
|         self.assertIsInstance(next_functions[1][0], torch._C._functions.AccumulateGrad)
 | |
|         self.assertEqual(next_functions[1][1], 0)
 | |
| 
 | |
|         b = a + 5
 | |
|         next_functions = b.grad_fn.next_functions
 | |
|         self.assertEqual(len(next_functions), 2)
 | |
|         self.assertIs(next_functions[0][0], a.grad_fn)
 | |
|         self.assertIs(next_functions[1][0], None)
 | |
| 
 | |
    def test_inplace(self):
        """Check which backward passes survive an in-place edit of an intermediate.

        Four scenarios build small graphs, modify an intermediate tensor in
        place, and then assert which ``backward`` calls succeed and which
        raise ``RuntimeError``.
        """
        x = torch.ones(5, 5, requires_grad=True)
        y = Variable(torch.ones(5, 5) * 4, requires_grad=True)

        # Scenario 1: z is modified after both q and w were built from it.
        z = x * y
        q = z + y
        w = z * y
        z.add_(2)
        # Add doesn't need its inputs to do backward, so it shouldn't raise
        q.backward(torch.ones(5, 5), retain_graph=True)
        # Mul saves both inputs in forward, so it should raise
        self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))

        # Scenario 2: the in-place op's own result is used as an output.
        z = x * y
        q = z * y
        r = z + y
        w = z.add_(y)
        # w is the last expression, so this should succeed
        w.backward(torch.ones(5, 5), retain_graph=True)
        # r doesn't use the modified value in backward, so it should succeed
        r.backward(torch.ones(5, 5), retain_graph=True)
        # q uses dirty z, so it should raise
        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

        # Scenario 3: in-place exp_; the version counter must change and the
        # accumulated gradient values on x are checked.
        with torch.no_grad():
            x.grad.zero_()
        m = x / 2
        z = m + y / 8
        q = z * y
        r = z + y
        prev_version = z._version
        w = z.exp_()
        self.assertNotEqual(z._version, prev_version)
        r.backward(torch.ones(5, 5), retain_graph=True)
        # Only r has contributed so far: expected gradient is 1/2.
        self.assertEqual(x.grad, torch.ones(5, 5) / 2)
        w.backward(torch.ones(5, 5), retain_graph=True)
        # w's contribution is accumulated on top: expected (1 + e) / 2.
        self.assertEqual(x.grad, torch.empty(5, 5).fill_((1 + math.e) / 2))
        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

        # Scenario 4: in-place edit of a clone of a leaf.
        leaf = torch.ones(5, 5, requires_grad=True)
        x = leaf.clone()
        x.add_(10)
        self.assertEqual(x, torch.ones(5, 5) * 11)
        # x should be still usable
        y = x + 2
        y.backward(torch.ones(5, 5))
        self.assertEqual(leaf.grad, torch.ones(5, 5))
        # z is built from x before x is modified again, so its backward raises.
        z = x * y
        x.add_(2)
        self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
 | |
| 
 | |
|     def test_mark_non_differentiable(self):
 | |
|         class MyFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, input):
 | |
|                 output = input > 0
 | |
|                 ctx.mark_non_differentiable(output)
 | |
|                 return output
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 return (grad_output * 0).to(torch.double)
 | |
| 
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         mask = MyFunction.apply(x)
 | |
|         self.assertFalse(mask.requires_grad)
 | |
|         y = x.masked_fill(mask, 0)
 | |
|         y.sum().backward()
 | |
| 
 | |
|     def test_mark_non_differentiable_mixed(self):
 | |
|         class MyFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, input):
 | |
|                 a = input + 1
 | |
|                 b = input + 2
 | |
|                 ctx.mark_non_differentiable(a)
 | |
|                 return a, b
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_a, grad_b):
 | |
|                 self.assertTrue((grad_a == 0).all())
 | |
|                 self.assertTrue((grad_b == 1).all())
 | |
|                 return grad_b
 | |
| 
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         a, b = MyFunction.apply(x)
 | |
|         self.assertFalse(a.requires_grad)
 | |
|         self.assertTrue(b.requires_grad)
 | |
|         b.sum().backward()
 | |
|         self.assertEqual(x.grad, torch.ones(5, 5))
 | |
| 
 | |
|     def test_mark_non_differentiable_none(self):
 | |
|         # This used to segfault because MyFunction would send back null
 | |
|         # gradients to MulBackward, which is implemented in C++. C++
 | |
|         # implemented functions expect incoming grad_outputs to be non-null.
 | |
|         class MyFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, input):
 | |
|                 output = input.clone()
 | |
|                 ctx.mark_non_differentiable(output)
 | |
|                 return output
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 return None
 | |
| 
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         r = MyFunction.apply(x * x)
 | |
|         (r * x).sum().backward()
 | |
| 
 | |
|     def test_return_duplicate(self):
 | |
|         class DoubleDuplicate(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 output = x * 2
 | |
|                 return output, output
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad1, grad2):
 | |
|                 return grad1 * 2 + grad2 * 2
 | |
| 
 | |
|         def fn(x):
 | |
|             a, b = DoubleDuplicate.apply(x)
 | |
|             self.assertIs(a, b)
 | |
|             return a + b
 | |
| 
 | |
|         x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
 | |
|         gradcheck(fn, [x])
 | |
|         gradgradcheck(fn, [x])
 | |
| 
 | |
|     def test_return_duplicate_inplace(self):
 | |
|         class DoubleInplace(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 x.mul_(2)
 | |
|                 ctx.mark_dirty(x)
 | |
|                 return x, x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad1, grad2):
 | |
|                 return grad1 * 2 + grad2 * 2
 | |
| 
 | |
|         def inplace_fn(x):
 | |
|             a, b = DoubleInplace.apply(x.clone())
 | |
|             self.assertIs(a, b)
 | |
|             return a + b
 | |
| 
 | |
|         x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
 | |
|         gradcheck(inplace_fn, [x])
 | |
|         gradgradcheck(inplace_fn, [x])
 | |
| 
 | |
|         # Can't modify leaf variables in-place
 | |
|         self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x))
 | |
|         # Functions which modify views in-place must return only one output
 | |
|         self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x.clone()[0]))
 | |
| 
 | |
|     def _test_setitem(self, size, index):
 | |
|         x = torch.ones(*size, requires_grad=True)
 | |
|         y = x + 2
 | |
|         y_version = y._version
 | |
|         y[index] = 2
 | |
|         self.assertNotEqual(y._version, y_version)
 | |
|         y.backward(torch.ones(*size))
 | |
|         expected_grad = torch.ones(*size)
 | |
|         expected_grad[index] = 0
 | |
|         self.assertEqual(x.grad, expected_grad)
 | |
| 
 | |
|     def _test_setitem_tensor(self, size, index):
 | |
|         x = torch.ones(*size, requires_grad=True)
 | |
|         y = x + 2
 | |
|         y_version = y._version
 | |
|         value = x.new(x[index].size()).fill_(7)
 | |
|         value.requires_grad = True
 | |
|         y[index] = value
 | |
|         self.assertNotEqual(y._version, y_version)
 | |
|         y.backward(torch.ones(*size))
 | |
|         expected_grad_input = torch.ones(*size)
 | |
|         expected_grad_input[index] = 0
 | |
|         self.assertEqual(x.grad, expected_grad_input)
 | |
|         self.assertEqual(value.grad, torch.ones_like(value))
 | |
| 
 | |
|         # case when x broadcasts to as y[1]
 | |
|         x = torch.randn(4, requires_grad=True)
 | |
|         y = torch.zeros(2, 3, 4)
 | |
|         y[1] = x
 | |
|         y.backward(torch.randn(2, 3, 4))
 | |
|         self.assertEqual(x.size(), x.grad.size())
 | |
| 
 | |
|     def test_setitem(self):
 | |
|         self._test_setitem((5, 5), 1)
 | |
|         self._test_setitem((5,), 1)
 | |
|         self._test_setitem((1,), 0)
 | |
|         self._test_setitem((10,), [[0, 4, 2]])
 | |
|         self._test_setitem((5, 5), [[0, 4], [2, 2]])
 | |
|         self._test_setitem((5, 5, 5), [slice(None), slice(None), [1, 3]])
 | |
|         self._test_setitem((5, 5, 5), [slice(None), [1, 3], slice(None)])
 | |
|         self._test_setitem((5, 5, 5), [[1, 3], slice(None), slice(None)])
 | |
|         self._test_setitem((5, 5, 5), [slice(None), [2, 4], [1, 3]])
 | |
|         self._test_setitem((5, 5, 5), [[1, 3], [2, 4], slice(None)])
 | |
|         self._test_setitem_tensor((5, 5), 3)
 | |
|         self._test_setitem_tensor((5, 5), [[0, 1], [1, 0]])
 | |
|         self._test_setitem_tensor((5,), 3)
 | |
|         self._test_setitem_tensor((5,), Variable(torch.LongTensor([3]), requires_grad=False).sum())
 | |
|         self._test_setitem_tensor((5,), [[0, 1, 2, 3]])
 | |
|         self._test_setitem_tensor((5, 5, 5), [slice(None), slice(None), [1, 3]])
 | |
|         self._test_setitem_tensor((5, 5, 5), [slice(None), [1, 3], slice(None)])
 | |
|         self._test_setitem_tensor((5, 5, 5), [[1, 3], slice(None), slice(None)])
 | |
|         self._test_setitem_tensor((5, 5, 5), [slice(None), [2, 4], [1, 3]])
 | |
|         self._test_setitem_tensor((5, 5, 5), [[1, 3], [2, 4], slice(None)])
 | |
|         self._test_setitem_tensor((5, 5, 5), [Variable(torch.LongTensor([1,
 | |
|                                               3]), requires_grad=False), [2, 4], slice(None)])
 | |
| 
 | |
|     def test_setitem_mask(self):
 | |
|         mask = torch.BoolTensor(5, 5).bernoulli_()
 | |
|         self._test_setitem((5, 5), Variable(mask))
 | |
|         self._test_setitem((5,), Variable(mask[0]))
 | |
|         self._test_setitem((1,), Variable(mask[0, 0:1]))
 | |
|         self._test_setitem_tensor((5, 5), Variable(mask))
 | |
|         self._test_setitem_tensor((5,), Variable(mask[0]))
 | |
| 
 | |
|     def test_select_sum(self):
 | |
|         # both select and sum return Scalars in ATen; ensure they work together.
 | |
|         x = torch.randn(10, dtype=torch.double, requires_grad=True)
 | |
| 
 | |
|         def func(x):
 | |
|             return x.select(0, 1).sum()
 | |
| 
 | |
|         gradcheck(func, [x])
 | |
|         gradgradcheck(func, [x])
 | |
| 
 | |
|     def test_diagonal_expanded_v(self):
 | |
|         value = torch.rand([])
 | |
|         v_expanded = torch.tensor(value).expand(10)
 | |
|         a = torch.rand(10, 10, dtype=torch.double, requires_grad=True)
 | |
|         result, = torch.autograd.grad(a.diagonal(), a, v_expanded)
 | |
|         self.assertEqual(result, torch.eye(10, dtype=torch.double) * value)
 | |
| 
 | |
|     def test_select_expanded_v(self):
 | |
|         v_expanded = torch.rand(10).expand(10, 10)
 | |
|         a = torch.rand(10, 10, 10, requires_grad=True)
 | |
|         result, = torch.autograd.grad(a[0], a, v_expanded)
 | |
|         expected = torch.zeros(10, 10, 10)
 | |
|         expected[0] = v_expanded
 | |
|         self.assertEqual(result, expected)
 | |
| 
 | |
|     def test_slice_expanded_v(self):
 | |
|         v_expanded = torch.rand(10, 1).expand(2, 10, 10)
 | |
|         a = torch.rand(10, 10, 10, requires_grad=True)
 | |
|         result, = torch.autograd.grad(a[3:5], a, v_expanded)
 | |
|         expected = torch.zeros(10, 10, 10)
 | |
|         expected[3:5] = v_expanded
 | |
|         self.assertEqual(result, expected)
 | |
| 
 | |
|     def test_unused_output(self):
 | |
|         x = torch.randn(10, 10, requires_grad=True)
 | |
|         outputs = x.chunk(5)
 | |
|         o = outputs[2]
 | |
|         o = o * 4 + 2
 | |
|         o.sum().backward()
 | |
|         expected_grad = torch.zeros(10, 10)
 | |
|         expected_grad[4:6] = 4
 | |
|         self.assertEqual(x.grad, expected_grad)
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             x.grad.zero_()
 | |
|         grad_output = torch.randn(2, 10)
 | |
|         outputs = x.chunk(5)
 | |
|         outputs[0].backward(grad_output)
 | |
|         expected_grad = torch.zeros(10, 10)
 | |
|         expected_grad[:2] = grad_output
 | |
|         self.assertEqual(x.grad, expected_grad)
 | |
| 
 | |
|     # TODO: opinfo this or move to the sparse test suite
 | |
|     def _test_sparse_gather(self, size_x, size_ind, dim):
 | |
|         x = torch.randn(size_x, requires_grad=True)
 | |
|         if len(size_ind) > 0 and len(size_x) > 0:
 | |
|             ind = torch.randint(x.size(dim), size_ind)
 | |
|         else:
 | |
|             ind = torch.zeros(size_ind, dtype=torch.int64)
 | |
|         out = torch.gather(x, dim, ind, sparse_grad=False)
 | |
|         grad = torch.rand_like(out)
 | |
|         out.backward(grad)
 | |
|         grad_dense = x.grad.clone()
 | |
|         x.grad = None
 | |
|         out = torch.gather(x, dim, ind, sparse_grad=True)
 | |
|         out.backward(grad)
 | |
|         self.assertEqual(grad_dense, x.grad.to_dense())
 | |
| 
 | |
|     def test_sparse_gather_dim0(self):
 | |
|         self._test_sparse_gather((10, 10), (5, 10), 0)
 | |
| 
 | |
|     def test_sparse_gather_dim1(self):
 | |
|         self._test_sparse_gather((10, 10, 5), (10, 5, 5), 1)
 | |
| 
 | |
|     def test_sparse_gather_dim_neg(self):
 | |
|         self._test_sparse_gather((10, 10, 5), (10, 10, 2), -1)
 | |
| 
 | |
|     def test_sparse_gather_ind_scalar(self):
 | |
|         self._test_sparse_gather((10,), (), 0)
 | |
| 
 | |
|     def test_sparse_gather_x_scalar(self):
 | |
|         self._test_sparse_gather((), (2,), 0)
 | |
| 
 | |
|     def test_sparse_gather_both_scalar(self):
 | |
|         self._test_sparse_gather((), (), 0)
 | |
| 
 | |
    def test_gc_in_destructor(self):
        """
        Previously, if a Function destructor triggered a garbage collection,
        the Variable's tp_dealloc handler would get called twice leading to a
        segfault.
        """
        class CollectOnDelete(Function):
            # forward/backward are plain instance methods here: the loop below
            # instantiates the class and calls ``forward`` on the instance
            # directly instead of going through ``apply``.
            def forward(self, x):
                return x

            def backward(self, grad_output):
                return grad_output

            def __del__(self):
                # Force a collection while this instance is being destroyed;
                # this is the situation described in the docstring.
                gc.collect()

        # Repeat the create / forward / backward / destroy cycle ten times.
        for _ in range(10):
            CollectOnDelete().forward(torch.randn(1, requires_grad=True)).backward()
 | |
| 
 | |
    def test_naughty_autograd_function_attribute_access(self):
        """Misusing an instantiated custom Function must error or warn cleanly.

        Instantiating the Function warns, calling the instance raises, and
        node-only accessors on the instance raise ``RuntimeError``.
        """
        class Id(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, grad_x):
                return grad_x

        with self.assertWarnsRegex(DeprecationWarning, "should not be instantiated"):
            f = Id()

        # After raising warning, should still return an instance
        self.assertIsInstance(f, Id)
        x = torch.zeros(1, requires_grad=True)
        # Calling the instance (rather than ``Id.apply``) is rejected.
        with self.assertRaisesRegex(RuntimeError, "non-static forward method is deprecated"):
            f(x)
        # The supported entry point still works and names its node after the class.
        t = Id.apply(x)
        self.assertEqual(t.grad_fn.name(), "IdBackward")

        # THPFunction is the base class of both grad_fn and autograd functions,
        # which means that a lot of accessors on them may segfault. Test that we
        # properly error in this case.
        t = torch.ones(1, requires_grad=True)
        t._backward_hooks = {}
        with self.assertRaisesRegex(RuntimeError, "Attribute '_register_hook_dict' is invalid"):
            f._register_hook_dict(t)
        with self.assertRaisesRegex(RuntimeError, "Attribute 'register_hook' is invalid"):
            f.register_hook(lambda x, y: None)
        with self.assertRaisesRegex(RuntimeError, "Attribute 'next_functions' is invalid"):
            f.next_functions
        with self.assertRaisesRegex(RuntimeError, "Attribute 'name' is invalid"):
            f.name()
        with self.assertRaisesRegex(RuntimeError, "underlying PyNode has already been deallocated"):
            f.metadata
 | |
| 
 | |
    @unittest.expectedFailure
    def test_naughty_anomaly_access(self):
        """Access ``grad_fn.metadata`` after the output tensor has been deleted.

        Marked as an expected failure: the final access currently fails.
        """
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, g):
                return g

        x = torch.zeros(1, requires_grad=True)
        y = MyFunction.apply(x)
        y.backward()
        # Access metadata while the output tensor is still alive.
        y.grad_fn.metadata
        # Keep only the node, drop the tensor, then access metadata again.
        g = y.grad_fn
        del y
        g.metadata  # this currently fails, but shouldn't
 | |
| 
 | |
    def test_naughty_autograd_function_stashing_ctx(self):
        """Reading ``saved_tensors`` from a ctx stashed outside backward must raise."""
        # Filled by ``Id.backward`` so the ctx stays reachable after the
        # backward call has returned.
        saved_ctx = []

        class Id(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, grad_x):
                saved_ctx.append(ctx)
                return ctx.saved_tensors

        p = torch.zeros(1, requires_grad=True)
        loss = Id.apply(p)
        loss.backward(retain_graph=True)
        del loss
        # At this point in time, it complains that the graph has been freed
        # (which is indeed true, although a somewhat indirect way of stating
        # the problem).
        self.assertRaises(RuntimeError, lambda: saved_ctx[0].saved_tensors)
 | |
| 
 | |
|     def test_custom_autograd_repeated_grad_grad(self):
 | |
|         # This test failed the equality check in PR #22983; it's an interesting
 | |
|         # and different test case worth enshrining.  mult1 is not testing
 | |
|         # anything that interesting, but mult2 is the interesting case.
 | |
| 
 | |
|         def mult1(x):
 | |
|             return x.prod(dim=-1).prod(dim=-1)
 | |
| 
 | |
|         class Mult(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 y = mult1(x)
 | |
|                 ctx.save_for_backward(x, y)
 | |
|                 return y
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 x, y = ctx.saved_tensors
 | |
|                 return (grad_output * y)[:, None, None] / x
 | |
| 
 | |
|         mult2 = Mult.apply
 | |
| 
 | |
|         def check_gradgrad_repeated(x, y):
 | |
|             gy, = torch.autograd.grad(y[0], x, create_graph=True)
 | |
|             ggy_1, = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
 | |
|             gy, = torch.autograd.grad(y[0], x, create_graph=True)
 | |
|             ggy_2, = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
 | |
|             self.assertEqual(ggy_1[0, 0, 1], ggy_2[0, 0, 1])
 | |
| 
 | |
|         x = torch.ones(2, 4, 4).requires_grad_()
 | |
|         check_gradgrad_repeated(x, mult1(x))
 | |
|         check_gradgrad_repeated(x, mult2(x))
 | |
| 
 | |
|     def test_custom_autograd_no_early_free(self):
 | |
|         # This test failed complaining that buffers had already been freed
 | |
|         # prior to #22983.  Also pretty interesting test case.
 | |
|         class Double(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 y = x ** 2
 | |
|                 ctx.save_for_backward(x, y)
 | |
|                 return y
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 x, _ = ctx.saved_tensors
 | |
|                 return grad_output * 2 * x
 | |
| 
 | |
|         # this is equivalent, but uses the output of .forward() in .backward()
 | |
|         class Double2(Double):
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 x, y = ctx.saved_tensors
 | |
|                 return grad_output * 2 * y / x
 | |
| 
 | |
|         double = Double.apply
 | |
|         double2 = Double2.apply
 | |
| 
 | |
|         x = torch.tensor(2).double().requires_grad_()
 | |
| 
 | |
|         self.assertTrue(gradcheck(double, x))
 | |
|         self.assertTrue(gradgradcheck(double, x))
 | |
|         self.assertTrue(gradcheck(double2, x))
 | |
|         self.assertTrue(gradgradcheck(double2, x))
 | |
| 
 | |
|         y = double(x)
 | |
|         torch.autograd.grad(y, x, create_graph=True)
 | |
|         torch.autograd.grad(y, x)
 | |
| 
 | |
|         y = double2(x)
 | |
|         torch.autograd.grad(y, x, create_graph=True)
 | |
|         torch.autograd.grad(y, x)  # should not error!
 | |
| 
 | |
|     def test_detach(self):
 | |
|         x = torch.randn(10, 10, requires_grad=True)
 | |
|         y = x + 2
 | |
|         y = y.detach()
 | |
|         z = y * 4 + 2
 | |
|         self.assertFalse(y.requires_grad)
 | |
|         self.assertFalse(z.requires_grad)
 | |
| 
 | |
|         x = torch.randn(10, 10, requires_grad=True)
 | |
|         y = x * 2
 | |
|         y = y.detach()
 | |
|         self.assertFalse(y.requires_grad)
 | |
|         self.assertIsNone(y.grad_fn)
 | |
|         z = x + y
 | |
|         z.sum().backward()
 | |
|         # This is an incorrect gradient, but we assume that's what the user
 | |
|         # wanted. detach() is an advanced option.
 | |
|         self.assertEqual(x.grad, torch.ones(10, 10))
 | |
| 
 | |
|         # in-place detach
 | |
|         x = torch.randn(10, 10, requires_grad=True)
 | |
|         y = torch.randn(10, 10, requires_grad=True)
 | |
|         a = x * 2
 | |
|         (y + a).sum().backward(retain_graph=True)
 | |
|         a.detach_()
 | |
|         self.assertFalse(a.requires_grad)
 | |
|         (y + a).sum().backward()  # this won't backprop to x
 | |
|         self.assertEqual(x.grad, torch.ones(10, 10) * 2)
 | |
|         self.assertEqual(y.grad, torch.ones(10, 10) * 2)
 | |
| 
 | |
|         # in-place detach on a view raises an exception
 | |
|         view = x.narrow(0, 1, 4)
 | |
|         self.assertRaisesRegex(RuntimeError, 'view', lambda: view.detach_())
 | |
| 
 | |
|     def test_detach_base(self):
 | |
|         "detaching base does not detach view"
 | |
|         x = torch.randn(10, 10, requires_grad=True)
 | |
|         view = x.narrow(0, 1, 4)
 | |
|         x.detach_()
 | |
|         self.assertFalse(x.requires_grad)
 | |
|         self.assertTrue(view.requires_grad)
 | |
|         self.assertIsNotNone(view.grad_fn)
 | |
|         self.assertIs(view._base, x)
 | |
| 
 | |
|     def test_detach_then_inplace_raises_in_autograd(self):
 | |
|         x = torch.randn([], requires_grad=True)
 | |
|         orig_x = x.detach().clone()
 | |
| 
 | |
|         y = x ** 2  # saves x
 | |
|         z = x.detach()
 | |
|         z.zero_()
 | |
|         with self.assertRaisesRegex(RuntimeError, "has been modified by an inplace"):
 | |
|             y.backward()
 | |
| 
 | |
|     def _test_type_conversion_backward(self, t, ):
 | |
|         fvar = Variable(t(torch.randn(5, 5).float()), requires_grad=True)
 | |
|         fvar.double().sum().backward()
 | |
|         self.assertEqual(fvar.grad, torch.ones_like(fvar))
 | |
|         self.assertEqual(type(fvar.grad), type(fvar))
 | |
|         dvar = Variable(t(torch.randn(5, 5).double()), requires_grad=True)
 | |
|         dvar.float().sum().backward()
 | |
|         self.assertEqual(dvar.grad, torch.ones_like(dvar))
 | |
|         self.assertEqual(type(dvar.grad), type(dvar))
 | |
| 
 | |
|     def test_type_conversions(self):
 | |
|         x = torch.randn(5, 5)
 | |
|         self.assertIsInstance(x.float(), torch.FloatTensor)
 | |
|         self.assertIsInstance(x.int(), torch.IntTensor)
 | |
|         if torch.cuda.is_available():
 | |
|             self.assertIsInstance(x.float().cuda(), torch.cuda.FloatTensor)
 | |
|             self.assertIsInstance(x.int().cuda(), torch.cuda.IntTensor)
 | |
|             self.assertIsInstance(x.int().cuda().cpu(), torch.IntTensor)
 | |
|             if torch.cuda.device_count() >= 2:
 | |
|                 x2 = x.float().cuda(1)
 | |
|                 self.assertIsInstance(x2, torch.cuda.FloatTensor)
 | |
|                 self.assertIs(x2.get_device(), 1)
 | |
|                 x2 = x.float().cuda()
 | |
|                 self.assertIsInstance(x2, torch.cuda.FloatTensor)
 | |
|                 self.assertIs(x2.get_device(), 0)
 | |
|                 x2 = x2.cuda(1)
 | |
|                 self.assertIsInstance(x2, torch.cuda.FloatTensor)
 | |
|                 self.assertIs(x2.get_device(), 1)
 | |
|                 y = Variable(torch.randn(5).cuda(1), requires_grad=True)
 | |
|                 y.cpu().sum().backward()
 | |
|                 self.assertIs(y.grad.get_device(), 1)
 | |
|                 self.assertIs(y.long().get_device(), 1)
 | |
| 
 | |
|         for t in [torch.DoubleTensor, torch.FloatTensor, torch.IntTensor, torch.ByteTensor]:
 | |
|             for y_var in (True, False):
 | |
|                 y = torch.randint(5, (5, 5), dtype=t.dtype)
 | |
|                 y = Variable(y) if y_var else y
 | |
|                 self.assertIsInstance(x.type(t), t)
 | |
|                 self.assertIsInstance(x.type_as(y), t)
 | |
|                 # TODO: t.dtype should work
 | |
|                 t_dtype = t().dtype
 | |
|                 self.assertIsInstance(x.type(t_dtype), t)
 | |
|                 self.assertIs(t_dtype, x.type(t_dtype).dtype)
 | |
|                 self.assertEqual(y.data_ptr(), y.type(t).data_ptr())
 | |
|                 if torch.cuda.is_available():
 | |
|                     for x_cuda in (True, False):
 | |
|                         for y_cuda in (True, False):
 | |
|                             x_c = x.cuda() if x_cuda else x
 | |
|                             y_c = y.cuda() if y_cuda else y
 | |
|                             _, y_type = y_c.type().rsplit('.', 1)
 | |
|                             y_typestr = ('torch.cuda.' if y_cuda else 'torch.') + y_type
 | |
|                             self.assertEqual(y_c.type(), x_c.type(y_typestr).type())
 | |
|                             self.assertIs(y_c.dtype, x_c.type(y_c.dtype).dtype)
 | |
|                             self.assertEqual(y_c.data_ptr(), y_c.cuda().data_ptr() if y_cuda else y_c.data_ptr())
 | |
| 
 | |
|         self._test_type_conversion_backward(lambda x: x)
 | |
|         if torch.cuda.is_available():
 | |
|             self._test_type_conversion_backward(lambda x: x.cuda())
 | |
|             if torch.cuda.device_count() >= 2:
 | |
|                 # one of these has to be the non-default device
 | |
|                 self._test_type_conversion_backward(lambda x: x.cuda(0))
 | |
|                 self._test_type_conversion_backward(lambda x: x.cuda(1))
 | |
| 
 | |
|     def test_isolated_node(self):
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         y = torch.randn(5, 5, requires_grad=True)
 | |
| 
 | |
|         a = x + y
 | |
|         b = torch.max(a, 1, True)[1].repeat(1, 5).double()
 | |
|         o = (b + a).sum()
 | |
|         o.backward()
 | |
| 
 | |
|     def test_shape(self):
 | |
|         x = torch.randn(3, 4)
 | |
|         self.assertEqual(2, len(x.shape))
 | |
|         self.assertEqual(x.shape[0], 3)
 | |
|         self.assertEqual(x.shape[1], 4)
 | |
| 
 | |
|     def test_numpy_requires_grad(self):
 | |
|         x = torch.randn(2, 2, requires_grad=True)
 | |
|         err_msg_outputs = r"Can't call numpy\(\) on Tensor that requires grad. Use tensor.detach\(\).numpy\(\) instead."
 | |
|         with self.assertRaisesRegex(RuntimeError, err_msg_outputs):
 | |
|             x.numpy()
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             x.numpy()
 | |
| 
 | |
|         x = torch.randn(2, 2)
 | |
|         x.numpy()
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             x.numpy()
 | |
| 
 | |
|     def test_return_leaf(self):
 | |
|         class Identity(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, a, b):
 | |
|                 return a, a + b
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_a, grad_b):
 | |
|                 return grad_a + grad_b, grad_b
 | |
| 
 | |
|         hook_called = [False]
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         y = torch.randn(5, 5, requires_grad=True)
 | |
| 
 | |
|         q, p = Identity.apply(x, y)
 | |
| 
 | |
|         # Make sure hooks only receive grad from usage of q, not x.
 | |
|         def hook(grad):
 | |
|             hook_called[0] = True
 | |
|             self.assertEqual(grad, torch.ones(5, 5))
 | |
| 
 | |
|         q.register_hook(hook)
 | |
|         (q + p + x).sum().backward()
 | |
|         self.assertEqual(x.grad, torch.ones(5, 5) * 3)
 | |
|         self.assertEqual(y.grad, torch.ones(5, 5))
 | |
|         self.assertTrue(hook_called[0])
 | |
| 
 | |
|     def test_return_leaf_inplace(self):
 | |
|         class Inplace(InplaceFunction):
 | |
|             @staticmethod
 | |
|             def forward(ctx, a, b):
 | |
|                 ctx.mark_dirty(a)
 | |
|                 return a.add_(b), b + 2
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_a, grad_b):
 | |
|                 return grad_a, grad_a + grad_b
 | |
| 
 | |
|         x = torch.randn(5, 5)
 | |
|         y = torch.randn(5, 5, requires_grad=True)
 | |
| 
 | |
|         q, p = Inplace.apply(x, y)
 | |
|         self.assertIs(q, x)
 | |
|         self.assertIs(q.grad_fn.__class__, Inplace._backward_cls)
 | |
|         self.assertTrue(q.requires_grad)
 | |
|         q.sum().backward()
 | |
|         self.assertEqual(y.grad, torch.ones(5, 5))
 | |
| 
 | |
|     def test_leaf_assignment(self):
 | |
|         x = torch.randn(5, 5)
 | |
|         y = torch.randn(5, requires_grad=True)
 | |
|         z = torch.randn(5, requires_grad=True)
 | |
| 
 | |
|         x[0] = y
 | |
|         x[1] = 2 * z
 | |
|         self.assertTrue(x.requires_grad)
 | |
|         self.assertIsNot(x.grad_fn, None)
 | |
|         x.sum().backward()
 | |
|         self.assertEqual(y.grad, torch.ones(5))
 | |
|         self.assertEqual(z.grad, torch.ones(5) * 2)
 | |
| 
 | |
|     def test_no_grad_assignment(self):
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         y = torch.randn(5)
 | |
|         with torch.no_grad():
 | |
|             x[0] = y
 | |
| 
 | |
|         self.assertTrue(x.requires_grad)
 | |
|         self.assertIsNone(x.grad_fn)
 | |
| 
 | |
|     def test_no_grad_modifies_version(self):
 | |
|         x = torch.randn(5, requires_grad=True)
 | |
|         y = torch.randn(5, requires_grad=True)
 | |
|         z = (x * y).sum()
 | |
|         with torch.no_grad():
 | |
|             x *= 2
 | |
|         self.assertRaisesRegex(RuntimeError, 'modified by an inplace operation',
 | |
|                                lambda: z.backward())
 | |
| 
 | |
|     def test_increment_version(self):
 | |
|         a = torch.rand(5, requires_grad=True)
 | |
|         v = a._version
 | |
|         torch.autograd.graph.increment_version(a)
 | |
|         self.assertEqual(a._version, v + 1)
 | |
| 
 | |
|         a = torch.zeros(5, dtype=torch.int)
 | |
|         v = a._version
 | |
|         torch.autograd.graph.increment_version(a)
 | |
|         self.assertEqual(a._version, v + 1)
 | |
| 
 | |
|         with torch.inference_mode():
 | |
|             a = torch.rand(5, requires_grad=True)
 | |
|         msg = "update to inference tensor outside InferenceMode"
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             torch.autograd.graph.increment_version(a)
 | |
| 
 | |
| 
 | |
|     def test_no_grad_input(self):
 | |
|         class MyFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(self, x):
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(self, grad_output):
 | |
|                 return grad_output
 | |
| 
 | |
|         x = torch.randn(5, requires_grad=True)
 | |
|         with torch.no_grad():
 | |
|             y = MyFunction.apply(x)
 | |
| 
 | |
|         self.assertTrue(x.requires_grad)
 | |
|         self.assertIsNone(y.grad_fn)
 | |
| 
 | |
|     def test_backward_copy(self):
 | |
|         # This tests checks backward engine for a very subtle bug that appreared
 | |
|         # in one of the initial versions of autograd. Gradients tensors were
 | |
|         # simply stored in lists while the function waited for all its gradients
 | |
|         # to be computed. However, sometimes an output was used multiple times,
 | |
|         # so the gradients needed to be summed. Engine used to keep a need_copy
 | |
|         # set of tensors that will need a clone upon next addition and removed
 | |
|         # them from the set as soon as the clone was performed. However, this
 | |
|         # could lead to incorrect results if the same gradient tensor was
 | |
|         # buffered in three places in the graph:
 | |
|         # 1. When accumulating gradients in one of these places it was cloned
 | |
|         #    and removed from need_copy set.
 | |
|         # 2. When accumulating in second place, it wasn't in the need_copy set,
 | |
|         #    so the gradients were simply accumulated in-place (which already
 | |
|         #    modified the grad in 3rd place)
 | |
|         # 3. When accumulating in the third place, it wasn't in the need_copy set
 | |
|         #    as well, so the incoming gradient was summed in-place, yielding
 | |
|         #    incorrect results in all functions, except the first one.
 | |
|         x = torch.ones(5, 5, requires_grad=True)
 | |
|         y = torch.ones(5, 5, requires_grad=True)
 | |
|         # Simulate that we're in the middle of the graph
 | |
|         a = x + 2
 | |
|         b = y + 2
 | |
|         c = x + 2
 | |
|         # This op will just return grad_output two times in backward
 | |
|         add1 = a + b
 | |
|         add2 = add1 + c
 | |
|         # Simulate a long branch, so grad_output will get buffered.
 | |
|         for _ in range(4):
 | |
|             a = a * 2
 | |
|             b = b * 2
 | |
|             c = c * 2
 | |
|         branch = a + b + c
 | |
|         out = add2 + branch
 | |
|         # expected gradients are:
 | |
|         # for x: 34 (16 from final a, 16 from final c, 2 from add2)
 | |
|         # for y: 17 (16 from final b, 1 from add2)
 | |
|         grad_output = torch.ones(5, 5)
 | |
|         out.backward(grad_output)
 | |
|         self.assertEqual(x.grad, torch.ones(5, 5) * 34)
 | |
|         self.assertEqual(y.grad, torch.ones(5, 5) * 17)
 | |
| 
 | |
|     def test_save_none_for_backward(self):
 | |
|         test_case = self
 | |
| 
 | |
|         class MyFn(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, input):
 | |
|                 ctx.save_for_backward(None, input, None)
 | |
|                 return input * input
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 n1, input, n2 = ctx.saved_tensors
 | |
|                 test_case.assertIsNone(n1)
 | |
|                 test_case.assertIsNone(n2)
 | |
|                 return 2 * input * grad_output
 | |
| 
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         y = MyFn.apply(x)
 | |
|         y.sum().backward()
 | |
|         self.assertEqual(x.grad, 2 * x)
 | |
| 
 | |
|     def test_too_many_grads(self):
 | |
|         class MyFn(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, input):
 | |
|                 return input
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 return grad_output, None, None
 | |
| 
 | |
|         x = torch.randn(5, 5, requires_grad=True)
 | |
|         y = MyFn.apply(x)
 | |
|         y.sum().backward()
 | |
|         self.assertEqual(x.grad, torch.ones_like(x))
 | |
| 
 | |
|     def test_pickle(self):
 | |
|         x = torch.randn(10, 10, requires_grad=True)
 | |
|         y = torch.randn(10, 10, requires_grad=False)
 | |
| 
 | |
|         def assert_strict_equal(var1, var2):
 | |
|             self.assertEqual(var1, var2)
 | |
|             self.assertEqual(var1.requires_grad, var2.requires_grad)
 | |
| 
 | |
|         serialized = [pickle.dumps([x, y], protocol=p) for p in range(3)]
 | |
|         for dump in serialized:
 | |
|             xc, yc = pickle.loads(dump)
 | |
|             assert_strict_equal(xc, x)
 | |
|             assert_strict_equal(yc, y)
 | |
| 
 | |
|     def test_dep_nograd(self):
 | |
|         class F1(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, input):
 | |
|                 out = torch.randn(input.size())
 | |
|                 ctx.mark_non_differentiable(out)
 | |
|                 return input, out
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output, ignored):
 | |
|                 return grad_output
 | |
| 
 | |
|         class F2(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, input, ignored):
 | |
|                 return input
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 return grad_output, None
 | |
| 
 | |
|         x = torch.randn(5, requires_grad=True)
 | |
|         a, b = F1.apply(x)
 | |
|         b = b + 1  # separate F1 from F2 by another op
 | |
|         self.assertTrue(a.requires_grad)
 | |
|         self.assertFalse(b.requires_grad)
 | |
|         c = F2.apply(a, b)
 | |
|         c.backward(torch.ones(c.size()))
 | |
|         self.assertEqual(x.grad, torch.ones(x.size()))
 | |
| 
 | |
|     def test_set_grad_enabled(self):
 | |
|         x = torch.tensor([1.], requires_grad=True)
 | |
|         with torch.set_grad_enabled(False):
 | |
|             y = x * 2
 | |
|         self.assertFalse(y.requires_grad)
 | |
|         with torch.set_grad_enabled(True):
 | |
|             y = x * 2
 | |
|         self.assertTrue(y.requires_grad)
 | |
|         with torch.set_grad_enabled(False):
 | |
|             torch.set_grad_enabled(True)
 | |
|             y = x * 2
 | |
|         self.assertTrue(y.requires_grad)
 | |
| 
 | |
|     def test_simple_reentrant(self):
 | |
|         y_data = torch.randn(2, 2)
 | |
| 
 | |
|         class Reenter(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 with torch.enable_grad():
 | |
|                     ctx.x = Variable(x, requires_grad=True)
 | |
|                     ctx.y = Variable(y_data, requires_grad=True)
 | |
|                     ctx.output_var = ctx.x * ctx.y
 | |
|                 return ctx.output_var.detach()
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 with torch.enable_grad():
 | |
|                     ctx.output_var.sum().backward()
 | |
|                 return ctx.x.grad * grad_output
 | |
| 
 | |
|         # Reentrant starts on CPU thread, finishs on GPU thread
 | |
|         x = torch.randn(2, 2, requires_grad=True)
 | |
|         out = Reenter.apply(x)
 | |
|         out.sum().backward()
 | |
|         self.assertEqual(x.grad, y_data)
 | |
| 
 | |
|     def test_reentrant_child_error(self):
 | |
|         # Parent graph.
 | |
|         a = torch.rand(3, 3, requires_grad=True)
 | |
|         c = a * a
 | |
| 
 | |
|         # Reentrant child graph.
 | |
|         b = torch.rand(3, 3, requires_grad=True)
 | |
|         e = b * b
 | |
|         f = TestAutograd.SimulateBackwardError.apply(e)
 | |
|         reentrant_root = f.sum()
 | |
| 
 | |
|         class ReentrantFunc(Function):
 | |
| 
 | |
|             @staticmethod
 | |
|             def forward(ctx, inp):
 | |
|                 return inp.clone()
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad):
 | |
|                 # Reentrant backward in child will throw an error.
 | |
|                 reentrant_root.backward()
 | |
|                 return grad
 | |
| 
 | |
|         d = ReentrantFunc.apply(c)
 | |
|         with self.assertRaisesRegex(Exception, 'Simulate error'):
 | |
|             d.sum().backward()
 | |
| 
 | |
|     def test_var_mean_differentiable(self):
 | |
|         dim = [2, 4]
 | |
|         keepdim = False
 | |
|         input1 = torch.randn(3, 4, 5, 6, 2, 3, requires_grad=True)
 | |
|         input2 = deepcopy(input1)
 | |
|         var1, mean1 = torch.var_mean(input1, dim=dim, keepdim=keepdim)
 | |
|         var2 = input2.var(dim=dim, keepdim=keepdim)
 | |
|         mean2 = input2.mean(dim=dim, keepdim=keepdim)
 | |
|         grad = torch.randn(3, 4, 6, 3, requires_grad=True)
 | |
| 
 | |
|         r1 = var1 * var1 * mean1 * mean1
 | |
|         r2 = var2 * var2 * mean2 * mean2
 | |
|         self.assertEqual(r1, r2, rtol=0.01, atol=0.0)
 | |
| 
 | |
|         torch.autograd.backward(r1, grad)
 | |
|         torch.autograd.backward(r2, grad)
 | |
|         self.assertEqual(input1.grad, input2.grad, rtol=0.01, atol=0.0)
 | |
| 
 | |
|     @skipIfNoLapack
 | |
|     def test_lobpcg(self):
 | |
| 
 | |
|         def func(k, A, largest=True, B=None):
 | |
|             X_shape = list(A.shape)
 | |
|             X_shape[-1] = k
 | |
|             X = torch.eye(A.size(-2), k, dtype=A.dtype, device=A.device)
 | |
|             if A.dim() > 2:
 | |
|                 X = X.expand(X_shape)
 | |
| 
 | |
|             D, U = torch.lobpcg(A=A, k=k, B=B, X=X, largest=largest)
 | |
| 
 | |
|             # LOBPCG uses a random initial eigenspace approximation
 | |
|             # if parameter `X` is not provided.
 | |
|             # This may cause a non-deterministic behavior
 | |
|             # when it comes to the sign of an eigenvector
 | |
|             # (note if v is an eigenvector, so is -v),
 | |
|             # hence we eliminate this non-determinism
 | |
|             # by making sure that each column of U
 | |
|             # gets multiplied by the sign of its max (in absolute value) element.
 | |
|             # Also, gradcheck changes the content of the input by +/- eps (default to 1e-06)
 | |
|             # to compute the numerical gradient which can also cause the signs to flip.
 | |
|             _, idx = U.abs().max(-2, keepdim=True)
 | |
|             sign = U.gather(-2, idx).sign()
 | |
|             U = U * sign
 | |
|             return D, U
 | |
| 
 | |
|         # TODO: review if this can be ported to OpInfos or moved to test_linalg.py
 | |
|         def run_symeig_test(k, sizes, largest=True):
 | |
|             A = torch.rand(*sizes).double()
 | |
|             A = (A @ A.mT) / 10
 | |
|             A.requires_grad_(True)
 | |
| 
 | |
|             gradcheck(lambda A: func(k, A, largest), A, check_batched_grad=False)
 | |
| 
 | |
|             # Custom gradient vectors for better stability due to some
 | |
|             # non-determinism in the lobpcg's forward.
 | |
|             # Note it is not required if symeig is in forward instead (tested).
 | |
|             D_grad = torch.rand(*A.shape[:-2], k) / 100
 | |
|             U_grad = torch.rand(*A.shape[:-1], k) / 100
 | |
|             gradgradcheck(lambda A: func(k, A, largest), A, [D_grad, U_grad], atol=1e-4, check_batched_grad=False)
 | |
| 
 | |
|             # check whether A.grad is symmetric
 | |
|             A = A.detach().requires_grad_(True)
 | |
|             D, U = func(k, A, largest)
 | |
|             (D.sum() + U.sum()).backward()
 | |
|             self.assertEqual(A.grad, A.grad.mT)
 | |
| 
 | |
|         for largest in [True, False]:
 | |
|             run_symeig_test(1, (6, 6), largest=largest)
 | |
|             run_symeig_test(1, (2, 6, 6), largest=largest)
 | |
|             run_symeig_test(1, (2, 2, 6, 6), largest=largest)
 | |
|             run_symeig_test(2, (6, 6), largest=largest)
 | |
|             run_symeig_test(2, (2, 6, 6), largest=largest)
 | |
|             run_symeig_test(2, (2, 2, 6, 6), largest=largest)
 | |
|             run_symeig_test(3, (9, 9), largest=largest)
 | |
|             run_symeig_test(3, (2, 9, 9), largest=largest)
 | |
|             run_symeig_test(3, (2, 2, 9, 9), largest=largest)
 | |
| 
 | |
|     def test_variable_traverse(self):
 | |
|         def get_out_and_unrefed_cycle():
 | |
|             inp = torch.randn(10, requires_grad=True)
 | |
|             tmp = inp.view(10, 1)
 | |
|             out = tmp.view(10)
 | |
| 
 | |
|             # Create a reference cycle that contains an
 | |
|             # intermediary Variable in the graph
 | |
|             my_list = []
 | |
|             my_list.append(tmp)
 | |
|             my_list.append(my_list)
 | |
| 
 | |
|             return out
 | |
| 
 | |
|         out = get_out_and_unrefed_cycle()
 | |
|         gc.collect()
 | |
|         # This will segfault if things have been erroneously released
 | |
|         out.backward(torch.randn(out.size()))
 | |
| 
 | |
|     # TODO: review porting these to OpInfo tests
 | |
|     def test_pow_zero_tensor_gradient(self):
 | |
|         def run_test(input_size, exponent):
 | |
|             input = torch.zeros(*input_size, requires_grad=True)
 | |
|             input.pow(exponent).sum().backward()
 | |
|             self.assertEqual(input.grad.abs().sum(), 0)
 | |
| 
 | |
|         run_test((10,), torch.zeros(10))
 | |
|         run_test((10, 10), torch.zeros(10, 10))
 | |
|         run_test((10,), 0)
 | |
| 
 | |
|     def test_current_graph_task_id(self):
 | |
|         id = [-1]
 | |
| 
 | |
|         def hook(_):
 | |
|             id[0] = (torch._C._current_graph_task_id())
 | |
| 
 | |
|         t = torch.tensor(1., requires_grad=True).clone()
 | |
|         t.register_hook(hook)
 | |
| 
 | |
|         t.backward(retain_graph=True)
 | |
|         base = id[0]
 | |
|         t.backward(retain_graph=True)
 | |
|         self.assertEqual(id[0] - base, 1)
 | |
|         t.backward(retain_graph=True)
 | |
|         self.assertEqual(id[0] - base, 2)
 | |
| 
 | |
|         self.assertEqual(torch._C._current_graph_task_id(), -1)
 | |
| 
 | |
|     def test_current_graph_task_execution_order(self):
 | |
|         # Checks torch._C._current_graph_task_execution_order(): a tensor hook
 | |
|         # captures the order it reports during backward, and that order is then
 | |
|         # compared either with an expected list of node names or with the order
 | |
|         # in which per-tensor logging hooks actually fired.
 | |
|         predicted = [None]
 | |
| 
 | |
|         # Tensor hook that stores the reported execution order in predicted[0].
 | |
|         def hook(_):
 | |
|             predicted[0] = torch._C._current_graph_task_execution_order()
 | |
| 
 | |
|         # Render nodes as one comma-separated line, keeping only the last
 | |
|         # space-separated token of each node name, followed by a newline.
 | |
|         def names(nodes):
 | |
|             return ", ".join([node.name().split(' ')[-1] for node in nodes]) + '\n'
 | |
| 
 | |
|         def grad_fns(*tensors):
 | |
|             # Map each tensor to its grad_fn. A tensor that requires grad but has
 | |
|             # no grad_fn is mapped to the first next-function of a clone's
 | |
|             # grad_fn instead (its grad accumulator).
 | |
|             out = []
 | |
|             for t in tensors:
 | |
|                 if t.requires_grad and t.grad_fn is None:
 | |
|                     out.append(t.clone().grad_fn.next_functions[0][0])
 | |
|                 else:
 | |
|                     out.append(t.grad_fn)
 | |
|             return out
 | |
| 
 | |
|         # Tensors in the order their logging hooks fired; rebound to a fresh
 | |
|         # list after each scenario below.
 | |
|         actual = []
 | |
| 
 | |
|         def register_logging_hooks(*tensors):
 | |
|             # register hooks that log the order in which they are called
 | |
|             def get_hook(i):
 | |
|                 # Factory so that each hook captures its own index i.
 | |
|                 def hook(t_):
 | |
|                     actual.append(tensors[i])
 | |
|                 return hook
 | |
| 
 | |
|             for i, t in enumerate(tensors):
 | |
|                 t.register_hook(get_hook(i))
 | |
| 
 | |
|         # Basic example: single path
 | |
|         t = torch.tensor(1., requires_grad=True).clone().sin().exp()
 | |
|         t.register_hook(hook)
 | |
|         with torch.autograd.set_multithreading_enabled(False):
 | |
|             t.backward()
 | |
|         self.assertExpectedInline(names(predicted[0]), """\
 | |
| ExpBackward0, SinBackward0, CloneBackward0, torch::autograd::AccumulateGrad
 | |
| """)
 | |
| 
 | |
|         # We don't exactly follow sequence_nr order
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         b = torch.tensor(2., requires_grad=True)
 | |
|         c = b.sin()
 | |
|         d = a.cos()
 | |
|         out = c * d
 | |
|         register_logging_hooks(a, b, c, d, out)
 | |
|         out.register_hook(hook)
 | |
|         with torch.autograd.set_multithreading_enabled(False):
 | |
|             out.backward()
 | |
|         self.assertEqual(predicted[0], grad_fns(*actual))
 | |
|         # Start a fresh log for the next scenario.
 | |
|         actual = []
 | |
| 
 | |
|         # Multiple roots are also OK
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         b = a * 2
 | |
|         out = b.sin()
 | |
|         out2 = b.cos()
 | |
|         out3 = b.cos()
 | |
|         register_logging_hooks(a, b, out, out2, out3)
 | |
|         out3.register_hook(hook)
 | |
|         with torch.autograd.set_multithreading_enabled(False):
 | |
|             torch.autograd.grad((out, out3, out2), inputs=(a,))
 | |
|         self.assertExpectedInline(names(predicted[0]), """\
 | |
| CosBackward0, CosBackward0, SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
 | |
| """)
 | |
|         # TODO: Uncomment after update to hooks behavior
 | |
|         # self.assertEqual(predicted[0], grad_fns(*actual))
 | |
|         actual = []
 | |
| 
 | |
|         # Case where next node is nullptr
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         b = a * 2
 | |
|         out = b.sin()
 | |
|         register_logging_hooks(a, b, out)
 | |
|         out.register_hook(hook)
 | |
|         with torch.autograd.set_multithreading_enabled(False):
 | |
|             out.backward()
 | |
|         self.assertEqual(predicted[0], grad_fns(*actual))
 | |
|         actual = []
 | |
| 
 | |
|         # Case where two `inputs` on the same path
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         b = a * 2
 | |
|         out = b.sin()
 | |
|         register_logging_hooks(a, b, out)
 | |
|         out.register_hook(hook)
 | |
|         with torch.autograd.set_multithreading_enabled(False):
 | |
|             torch.autograd.grad((out,), inputs=(a, b,))
 | |
|         self.assertEqual(names(predicted[0]), """\
 | |
| SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
 | |
| """)
 | |
|         # TODO: Uncomment after update to hooks behavior
 | |
|         # self.assertEqual(predicted[0], grad_fns(*actual))
 | |
|         actual = []
 | |
| 
 | |
|         # Case where `inputs` specifies a subgraph
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         b = torch.tensor(1., requires_grad=True)
 | |
|         c = a * b
 | |
|         out = c.sin()
 | |
|         register_logging_hooks(a, b, c, out)
 | |
|         out.register_hook(hook)
 | |
|         with torch.autograd.set_multithreading_enabled(False):
 | |
|             torch.autograd.grad((out,), inputs=(a,))
 | |
|         self.assertEqual(names(predicted[0]), """\
 | |
| SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
 | |
| """)
 | |
|         # TODO: Uncomment after update to hooks behavior
 | |
|         # self.assertEqual(predicted[0], grad_fns(*actual))
 | |
|         actual = []
 | |
| 
 | |
|         # Errors when not called in a backward
 | |
|         with self.assertRaisesRegex(RuntimeError, "should only be called during the backward pass"):
 | |
|             torch._C._current_graph_task_execution_order()
 | |
| 
 | |
|         # Errors when context manager not enabled
 | |
|         # (backward runs here without set_multithreading_enabled(False))
 | |
|         t = torch.tensor(1., requires_grad=True).clone().sin().exp()
 | |
|         t.register_hook(hook)
 | |
|         with self.assertRaisesRegex(RuntimeError, "expects the current backward to be executed with multithreading disabled"):
 | |
|             t.backward()
 | |
| 
 | |
|     def test_view_replay_enabled(self):
 | |
|         def f(x):
 | |
|             out = x.clone().view(-1)
 | |
|             # mutate the view, triggering autograd view-replay logic
 | |
|             out.add_(1)
 | |
|             return out
 | |
| 
 | |
|         x = torch.ones(2, 2, requires_grad=True)
 | |
|         with torch.autograd._force_original_view_tracking(True):
 | |
|             out = f(x)
 | |
| 
 | |
|         # view-replay was enabled, so we should see ViewBackward in the graph
 | |
|         # instead of AsStridedBackward.
 | |
|         self.assertTrue("ViewBackward" in str(out.grad_fn))
 | |
| 
 | |
|         # Without view-replay we should as an AsStridedBackward
 | |
|         out = f(x)
 | |
|         self.assertTrue("AsStridedBackward" in str(out.grad_fn))
 | |
| 
 | |
|     def test_unsafe_set_version_counter(self):
 | |
|         x = torch.ones(2, requires_grad=True).clone()
 | |
|         x.add_(1)
 | |
|         x.add_(2)
 | |
|         self.assertEqual(2, x._version)
 | |
|         with torch.autograd._unsafe_preserve_version_counter(x):
 | |
|             x.mul_(2)
 | |
|             x.mul_(3)
 | |
|         # version counter doesn't change inside of the context manager
 | |
|         self.assertEqual(2, x._version)
 | |
| 
 | |
|         torch._C._autograd._unsafe_set_version_counter(x, 0)
 | |
|         self.assertEqual(0, x._version)
 | |
|         with self.assertRaisesRegex(RuntimeError, "Cannot set"):
 | |
|             torch._C._autograd._unsafe_set_version_counter(x, -1)
 | |
| 
 | |
| 
 | |
|     def test_current_node(self):
 | |
|         pr = []
 | |
| 
 | |
|         class MyMode(TorchDispatchMode):
 | |
|             def __torch_dispatch__(self, func, types, args, kwargs=None):
 | |
|                 node = torch._C._current_autograd_node()
 | |
|                 # Don't use node.name() here as it is not consistent on windows
 | |
|                 node_name = node.__class__.__name__ if node else "None"
 | |
|                 pr.append(f"Running {func} from within {node_name}")
 | |
|                 return func(*args, **kwargs or {})
 | |
| 
 | |
|         with MyMode():
 | |
|             pr.append("FW")
 | |
|             a = torch.rand(10, requires_grad=True)
 | |
|             b = a.mul(2).div(3).sum()
 | |
|             pr.append("BW")
 | |
|             b.backward()
 | |
|             pr.append("Done")
 | |
| 
 | |
|         self.assertExpectedInline("\n".join(pr), """\
 | |
| FW
 | |
| Running aten.rand.default from within None
 | |
| Running aten.mul.Tensor from within None
 | |
| Running aten.div.Tensor from within None
 | |
| Running aten.sum.default from within None
 | |
| BW
 | |
| Running aten.ones_like.default from within None
 | |
| Running aten.expand.default from within SumBackward0
 | |
| Running aten.div.Tensor from within DivBackward0
 | |
| Running aten.mul.Tensor from within MulBackward0
 | |
| Running aten.detach.default from within AccumulateGrad
 | |
| Running aten.detach.default from within AccumulateGrad
 | |
| Done""")
 | |
| 
 | |
|     def test_profiler(self):
 | |
|         x = torch.randn(10, 10)
 | |
| 
 | |
|         with profile(use_kineto=kineto_available()) as p:
 | |
|             self.assertTrue(torch.autograd._profiler_enabled())
 | |
|             y = x * 2 + 4
 | |
| 
 | |
|         self.assertFalse(torch.autograd._profiler_enabled())
 | |
| 
 | |
|         names = ['aten::mul', 'aten::add']
 | |
|         found_indices = set()
 | |
|         for evt in p.function_events:
 | |
|             if evt.name in names:
 | |
|                 found_indices.add(names.index(evt.name))
 | |
|         self.assertEqual(len(found_indices), len(names))
 | |
| 
 | |
|     def test_profiler_seq_nr(self):
 | |
|         with profile(use_kineto=kineto_available()) as p:
 | |
|             x = torch.randn(10, 10, requires_grad=True)
 | |
|             y = torch.randn(10, 10, requires_grad=True)
 | |
|             z = x + y
 | |
|             s = z.sum(dim=None)
 | |
|             s.backward()
 | |
|         print(p.key_averages().table(
 | |
|             sort_by="self_cpu_time_total", row_limit=-1))
 | |
|         # expecting aten::add, aten::sum to have the sequence numbers,
 | |
|         # expecting the corresponding backward nodes to have the same numbers
 | |
|         # as the forward ops
 | |
|         autograd_ops = {
 | |
|             ("aten::add", "Add"): [],
 | |
|             ("aten::sum", "Sum"): [],
 | |
|         }
 | |
|         accumulate_ops = []
 | |
|         found_empty = False
 | |
|         for e in p.function_events:
 | |
|             for (fwd_name, bwd_name), ops in autograd_ops.items():
 | |
|                 if e.name == fwd_name or (bwd_name in e.name and "Backward" in e.name):
 | |
|                     ops.append(e)
 | |
| 
 | |
|             if "AccumulateGrad" in e.name:
 | |
|                 accumulate_ops.append(e)
 | |
| 
 | |
|             # check that nested ops (e.g. empty) don't have
 | |
|             # sequence number
 | |
|             if e.name == "aten::empty":
 | |
|                 self.assertEqual(e.sequence_nr, -1)
 | |
|                 found_empty = True
 | |
| 
 | |
|         for idx, ((fwd_name, bwd_name), ops) in enumerate(autograd_ops.items()):
 | |
|             self.assertEqual(len(ops), 3)
 | |
|             self.assertEqual(ops[0].name, fwd_name)
 | |
|             self.assertEqual(ops[1].name, f"autograd::engine::evaluate_function: {bwd_name}Backward{idx}")
 | |
|             self.assertEqual(ops[2].name, f"{bwd_name}Backward{idx}")
 | |
|             self.assertGreaterEqual(ops[0].sequence_nr, 0)
 | |
|             self.assertEqual(ops[1].sequence_nr, ops[0].sequence_nr)
 | |
|             self.assertEqual(ops[2].sequence_nr, ops[0].sequence_nr)
 | |
|             self.assertEqual(ops[0].fwd_thread, 0)
 | |
|             self.assertEqual(ops[1].fwd_thread, ops[0].thread)
 | |
|             self.assertEqual(ops[2].fwd_thread, ops[0].thread)
 | |
|         self.assertTrue(found_empty)
 | |
| 
 | |
|     def test_profiler_unboxed_only(self):
 | |
|         x = torch.rand(3, 4)
 | |
| 
 | |
|         with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
 | |
|             x.resize_([3, 2])
 | |
| 
 | |
|     def test_profiler_propagation(self):
 | |
|         def foo(x):
 | |
|             with record_function("in_foo") as rf:
 | |
|                 return x * 2
 | |
| 
 | |
|         x = torch.rand(3, 4)
 | |
|         traced_foo = torch.jit.trace(foo, x)
 | |
| 
 | |
|         def bar(x):
 | |
|             with record_function("in_bar") as rf:
 | |
|                 # we expect that profiler will be able
 | |
|                 # propagate across fork
 | |
|                 fut = torch.jit._fork(traced_foo, x)
 | |
|                 y = torch.jit._wait(fut)
 | |
|                 # note: continuation (and rf's end) can
 | |
|                 # be executed in a different thread
 | |
|                 with record_function("in_bar_after_wait") as rf2:
 | |
|                     y = y * 2
 | |
|                 return y
 | |
| 
 | |
|         traced_bar = torch.jit.trace(bar, x)
 | |
| 
 | |
|         with profile(use_kineto=kineto_available()) as p:
 | |
|             traced_bar(x)
 | |
| 
 | |
|         found_foo = False
 | |
|         found_bar = False
 | |
|         found_bar_after_wait = False
 | |
|         for info in p.function_events:
 | |
|             if info.name == "in_foo":
 | |
|                 self.assertFalse(found_foo)
 | |
|                 found_foo = True
 | |
|             elif info.name == "in_bar":
 | |
|                 self.assertFalse(found_bar)
 | |
|                 found_bar = True
 | |
|             elif info.name == "in_bar_after_wait":
 | |
|                 self.assertFalse(found_bar_after_wait)
 | |
|                 found_bar_after_wait = True
 | |
|         self.assertTrue(found_foo)
 | |
|         self.assertTrue(found_bar)
 | |
|         self.assertTrue(found_bar_after_wait)
 | |
| 
 | |
|     def test_record_function_callbacks(self):
 | |
|         x = torch.randn(10, 10)
 | |
|         with profile(use_kineto=kineto_available()) as p:
 | |
|             with record_function("foo"):
 | |
|                 y = x * 2 + 4
 | |
| 
 | |
|         function_events = p.function_events
 | |
|         foo_event = [event for event in function_events if "foo" in event.name][0]
 | |
|         self.assertEqual(foo_event.count, 1)
 | |
| 
 | |
|     def test_record_function_legacy(self):
 | |
|         # Test the new _record_function ops work
 | |
|         # Note: Remove once record_function uses these directly
 | |
|         x = torch.randn(10, 10)
 | |
|         with profile(use_kineto=kineto_available()) as p:
 | |
|             handle = torch.ops.profiler._record_function_enter("bar", None)
 | |
|             try:
 | |
|                 y = x * 2 + 4
 | |
|             finally:
 | |
|                 torch.ops.profiler._record_function_exit(handle)
 | |
| 
 | |
|         function_events = p.function_events
 | |
|         foo_event = [event for event in function_events if "bar" in event.name][0]
 | |
|         self.assertEqual(foo_event.count, 1)
 | |
| 
 | |
|     def test_profiler_aggregation_fake(self):
 | |
|         events = EventList()
 | |
|         id = [0]
 | |
| 
 | |
|         def get_id():
 | |
|             id[0] = id[0] + 1
 | |
|             return id[0]
 | |
| 
 | |
|         # [[thread_id, [(start, end, id), ....]], ...]
 | |
|         # Using list instead of a dict so order is guaranteed for any Python
 | |
|         # version
 | |
|         threads = [
 | |
|             [1, [(0, 1, get_id()), (1, 2, get_id())]],
 | |
|             [0, [(0, 2, get_id()), (1, 2, get_id()), (1, 3, get_id())]],
 | |
|         ]
 | |
|         for thread, ranges in threads:
 | |
|             for range in ranges:
 | |
|                 assert(len(range) == 3)
 | |
|                 events.append(
 | |
|                     FunctionEvent(
 | |
|                         id=range[2],
 | |
|                         node_id=0,
 | |
|                         name="",
 | |
|                         thread=thread,
 | |
|                         start_us=range[0],
 | |
|                         end_us=range[1],
 | |
|                     )
 | |
|                 )
 | |
| 
 | |
|         events._populate_cpu_children()
 | |
| 
 | |
|         # Note that [1, 3] pushes out [0, 2] first. Then we record [1, 2]
 | |
|         # as a child of [1, 3]
 | |
|         res = [[], [], [], [], [4]]
 | |
| 
 | |
|         def get_children_ids(event):
 | |
|             return [child.id for child in event.cpu_children]
 | |
| 
 | |
|         assert([get_children_ids(event) for event in events] == res)
 | |
| 
 | |
|     def test_profiler_aggregation_table(self):
 | |
|         """
 | |
|         Test if the profiling result is aggregated for `str(prof)`
 | |
| 
 | |
|         See: https://github.com/pytorch/pytorch/issues/37500
 | |
|         """
 | |
| 
 | |
|         x = torch.randn(1024)
 | |
|         with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
 | |
|             torch.einsum("i->", x)
 | |
| 
 | |
|         prof_str = str(prof)
 | |
|         prof_table = prof.table()
 | |
| 
 | |
|         self.assertEqual(prof_table, prof_str)
 | |
| 
 | |
|     def test_profiler_function_event_avg(self):
 | |
|         avg = FunctionEventAvg()
 | |
|         avg.add(FunctionEvent(id=0, node_id=0, name="foo", thread=0, start_us=10, end_us=15))
 | |
|         avg.add(FunctionEvent(id=1, node_id=0, name="foo", thread=0, start_us=20, end_us=30))
 | |
|         avg.add(avg)
 | |
|         self.assertEqual(avg.key, "foo")
 | |
| 
 | |
|         # aggregate stats
 | |
|         self.assertEqual(avg.count, 4)
 | |
|         self.assertEqual(avg.cpu_time_total, 30)
 | |
|         self.assertEqual(avg.self_cpu_time_total, 30)
 | |
|         self.assertEqual(avg.cuda_time_total, 0)
 | |
| 
 | |
|         # average stats
 | |
|         self.assertEqual(avg.cpu_time, 7.5)
 | |
|         self.assertEqual(avg.cuda_time_total, 0)
 | |
| 
 | |
|     def test_profiler_shapes(self):
 | |
|         print("")
 | |
|         layer1 = torch.nn.Linear(20, 30)
 | |
|         layer2 = torch.nn.Linear(30, 40)
 | |
|         input = torch.randn(128, 20)
 | |
|         with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
 | |
|             layer2(layer1(input))
 | |
| 
 | |
|         print(prof.function_events)
 | |
| 
 | |
|         linear_expected_shapes = [
 | |
|             [[128, 20], [30, 20], [30]],
 | |
|             [[128, 30], [40, 30], [40]],
 | |
|         ]
 | |
| 
 | |
|         found_indices = set()
 | |
|         for event in prof.function_events:
 | |
|             if event.name == "aten::linear":
 | |
|                 self.assertTrue(event.input_shapes in linear_expected_shapes)
 | |
|                 found_indices.add(linear_expected_shapes.index(event.input_shapes))
 | |
|         self.assertEqual(len(found_indices), len(linear_expected_shapes))
 | |
| 
 | |
|     def test_profiler_aggregation_lstm(self):
 | |
|         print("")
 | |
|         rnn = torch.nn.LSTM(10, 20, 2)
 | |
|         total_time_s = 0
 | |
|         with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
 | |
|             for i in range(20):
 | |
|                 input = torch.randn(5, 3, 10)
 | |
|                 h = torch.randn(2, 3, 20)
 | |
|                 c = torch.randn(2, 3, 20)
 | |
|                 start = time.time()
 | |
|                 rnn(input, (h, c))
 | |
|                 end = time.time()
 | |
|                 total_time_s += end - start
 | |
| 
 | |
|         print(prof.table(
 | |
|             sort_by="self_cpu_time_total", row_limit=10, header="TEST"))
 | |
|         print(prof.key_averages(group_by_input_shape=True).table(
 | |
|             sort_by="self_cpu_time_total", row_limit=10))
 | |
|         print(prof.table(
 | |
|             sort_by="self_cpu_time_total", row_limit=10, max_src_column_width=300, header="TEST", top_level_events_only=True))
 | |
|         print(prof.key_averages(group_by_input_shape=True).table(
 | |
|             sort_by="self_cpu_time_total", row_limit=10, top_level_events_only=True))
 | |
| 
 | |
|         total_time_us = total_time_s * 1000.0 * 1000.0  # make it us which is profiler default
 | |
|         print(
 | |
|             "Total time based on python measurements: ",
 | |
|             _format_time(total_time_us)
 | |
|         )
 | |
|         print(
 | |
|             "CPU time measurement python side overhead: {:.2f}%".format(
 | |
|                 (total_time_us / prof.self_cpu_time_total - 1.0) * 100.0
 | |
|             )
 | |
|         )
 | |
| 
 | |
|         if sys.platform != "win32":
 | |
|             with tempfile.NamedTemporaryFile() as trace_file:
 | |
|                 prof.export_chrome_trace(trace_file.name)
 | |
| 
 | |
|     def test_record_function(self):
 | |
|         x = torch.randn(10, 10)
 | |
| 
 | |
|         def forward(x):
 | |
|             with record_function("outer"):
 | |
|                 y = x * 2 + 4
 | |
|                 with record_function("inner"):
 | |
|                     y = y - 1
 | |
|             y = y / 1
 | |
| 
 | |
|         forward(x)
 | |
| 
 | |
|         with profile(use_kineto=kineto_available()) as p:
 | |
|             forward(x)
 | |
| 
 | |
|         events = p.function_events
 | |
|         important_events = [
 | |
|             'outer',
 | |
|             'aten::mul',
 | |
|             'aten::add',
 | |
|             'inner',
 | |
|             'aten::sub',
 | |
|             'aten::div'
 | |
|         ]
 | |
|         idx = 0
 | |
|         for info in events:
 | |
|             if info.name == important_events[idx]:
 | |
|                 idx = idx + 1
 | |
|             if idx == len(important_events):
 | |
|                 break
 | |
|         self.assertEqual(idx, len(important_events))
 | |
| 
 | |
|         # We can also use record_function to decorate arbitrary function
 | |
|         @record_function('my_func')
 | |
|         def f(x, y):
 | |
|             return x + y
 | |
| 
 | |
|         with profile(use_kineto=kineto_available()) as p:
 | |
|             f(1, 2)
 | |
| 
 | |
|         self.assertTrue('my_func' in str(p))
 | |
| 
 | |
|     def test_record_function_multithreaded(self):
 | |
|         rf = record_function("outer")
 | |
|         rf.__enter__()
 | |
|         with record_function("inner"):
 | |
|             # test that exiting the record function after starting another one
 | |
|             # doesn't throw.
 | |
|             rf.__exit__(None, None, None)
 | |
| 
 | |
|         with record_function("inner"):
 | |
|             rf.__enter__()
 | |
|         # test that exiting the record function after ending another one
 | |
|         # doesn't throw.
 | |
|         rf.__exit__(None, None, None)
 | |
| 
 | |
| 
 | |
|     def test_dir(self):
 | |
|         x = torch.randn(10, 10)
 | |
|         keys = dir(x)
 | |
|         self.assertIn('shape', keys)
 | |
| 
 | |
|         # real and imag are only implemented for complex tensors.
 | |
|         y = torch.randn(10, 10, dtype=torch.cfloat)
 | |
|         imag_key = 'imag'
 | |
|         self.assertRaises(RuntimeError, lambda: hasattr(x, imag_key))
 | |
|         self.assertTrue(hasattr(y, imag_key))
 | |
|         keys.remove(imag_key)
 | |
| 
 | |
|         for key in keys:
 | |
|             self.assertTrue(hasattr(x, key))
 | |
| 
 | |
| 
 | |
|     def test_inplace_on_view_saved_output(self):
 | |
|         # Test an in-place operation on a view in which the in-place op saves
 | |
|         # its output. Previously, this created a reference cycle.
 | |
|         dealloc = [0]
 | |
| 
 | |
|         class IncrementOnDelete:
 | |
|             def __del__(self):
 | |
|                 dealloc[0] += 1
 | |
| 
 | |
|         def test():
 | |
|             root = torch.randn(3, 3, requires_grad=True)
 | |
|             copy = root.clone()
 | |
|             copy.grad_fn.register_hook(IncrementOnDelete())
 | |
|             view = copy.view(9)
 | |
|             torch.nn.functional.relu(view, inplace=True)
 | |
| 
 | |
|         test()
 | |
|         self.assertEqual(dealloc[0], 1)
 | |
| 
 | |
|     def test_inplace_on_view_leaf_errors(self):
 | |
|         # Issue #21875: Fail faster (when we try to modify the view vs. in backward())
 | |
|         x = torch.zeros(1, requires_grad=True)
 | |
|         y = x.view_as(x)
 | |
|         with self.assertRaisesRegex(RuntimeError,
 | |
|                                     "a view of a leaf Variable that "
 | |
|                                     "requires grad is being used in "
 | |
|                                     "an in-place operation."):
 | |
|             y.add_(1)
 | |
| 
 | |
|     def test_inplace_on_view_backward(self):
 | |
|         # Issue #10532: Make sure that this does not raise RuntimeError.
 | |
|         net = nn.Sequential(
 | |
|             nn.InstanceNorm2d(2),
 | |
|             nn.ReLU(True)
 | |
|         )
 | |
| 
 | |
|         x = torch.tensor([[[[1.0, 1.0]]]], requires_grad=True)
 | |
|         g, = torch.autograd.grad(net(x).pow(2), [x], grad_outputs=x.new_ones(x.shape) , create_graph=True)
 | |
|         torch.autograd.grad(g.sum(), [x])
 | |
|         self.assertEqual(x, torch.tensor([[[[1.0, 1.0]]]]))
 | |
| 
 | |
|         # https://discuss.pytorch.org/t/freeing-buffer-strange-behavior/31955/8
 | |
|         inputs = torch.ones((1, 3, 256, 256), requires_grad=True)
 | |
| 
 | |
|         tmp1 = (inputs + 1).view_as(inputs)
 | |
|         tmp2 = torch.nn.functional.threshold(tmp1, 0., 0., True)
 | |
|         prob_interpolated = torch.sigmoid(tmp2)
 | |
| 
 | |
|         gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=inputs,
 | |
|                                         grad_outputs=torch.ones(prob_interpolated.size()),
 | |
|                                         create_graph=True, retain_graph=True)[0]
 | |
| 
 | |
|         gradient_penalty = gradients.sum()
 | |
|         gradient_penalty.backward()
 | |
| 
 | |
|         fn = gradient_penalty.grad_fn.next_functions[0][0].next_functions[1][0]
 | |
|         self.assertEqual(fn.name(), "ThresholdBackwardBackward0")
 | |
| 
 | |
|     def test_inplace_on_view_weak_grad_fn(self):
 | |
|         # Issue 23502: Test that b's grad_fn is preserved.
 | |
|         a = torch.arange(10.0, requires_grad=True)
 | |
| 
 | |
|         b = a.narrow(0, 0, 2).clone().view(-1)
 | |
|         b.relu_()
 | |
| 
 | |
|         c = b.clone()
 | |
|         del b
 | |
|         gc.collect()
 | |
| 
 | |
|         s = c.sum()
 | |
|         s.backward()
 | |
|         self.assertEqual(s, torch.tensor(1.0))
 | |
| 
 | |
|         # Issue #21875: Fail faster (when we try to modify the view vs. in backward())
 | |
|         a = torch.rand(10, requires_grad=True).narrow(0, 0, 10)
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             b = a.relu_()
 | |
| 
 | |
|     def test_out_variant_raises_when_inputs_require_grad(self):
 | |
|         a = torch.randn(2, 2, requires_grad=True)
 | |
|         b = torch.randn(2, 2, requires_grad=True)
 | |
|         x = torch.zeros_like(a)
 | |
| 
 | |
|         # out=... functions don't support automatic differentiation currently
 | |
|         self.assertRaisesRegex(RuntimeError, 'out=', lambda: torch.mul(a, b, out=x))
 | |
| 
 | |
|         # the inputs can require grad if we're in no_grad() mode
 | |
|         with torch.no_grad():
 | |
|             torch.mul(a, b, out=x)
 | |
|             self.assertEqual(x, a * b)
 | |
| 
 | |
|         a = torch.randn(2, 2)
 | |
|         b = torch.randn(2, 2)
 | |
|         x = torch.zeros(2, 2, requires_grad=True)
 | |
|         # we should throw an exception if the output requires grad
 | |
|         self.assertRaisesRegex(RuntimeError, 'out=', lambda: torch.mul(a, b, out=x))
 | |
| 
 | |
|     def test_anomaly_detect_nan(self):
 | |
|         size = 10
 | |
| 
 | |
|         class MyFunc(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, inp1, inp2, fail_0th):
 | |
|                 ctx.fail_0th = fail_0th
 | |
|                 return inp1.sum(0, keepdim=True)
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 gI = gO.clone().expand(size)
 | |
|                 gI[0] = 0
 | |
|                 gI[0] /= 0  # Generate a nan
 | |
|                 if ctx.fail_0th:
 | |
|                     return gI, None, None
 | |
|                 else:
 | |
|                     return None, gI, None
 | |
| 
 | |
|         inp = torch.rand(size, requires_grad=True)
 | |
|         out = MyFunc.apply(inp, inp, True)
 | |
|         out.backward()  # Should not fail
 | |
| 
 | |
|         inp = torch.rand(size, requires_grad=True)
 | |
|         out = MyFunc.apply(inp, inp, True)
 | |
|         with self.assertRaisesRegex(RuntimeError, "Function 'MyFuncBackward' returned nan values in its 0th output."):
 | |
|             with warnings.catch_warnings(record=True) as w:
 | |
|                 with detect_anomaly():
 | |
|                     out.backward()
 | |
|             self.assertIn('No forward pass information', str(w[0].message))
 | |
| 
 | |
|         inp = torch.rand(size, requires_grad=True)
 | |
|         with self.assertRaisesRegex(RuntimeError, "Function 'MyFuncBackward' returned nan values in its 1th output."):
 | |
|             with warnings.catch_warnings(record=True) as w:
 | |
|                 with detect_anomaly():
 | |
|                     out = MyFunc.apply(inp, inp, False)
 | |
|                     out.backward()
 | |
|             self.assertIn('MyFunc.apply', str(w[0].message))
 | |
| 
 | |
|     def test_calculate_shape_util(self):
 | |
|         out = torch.randn(10, 5, requires_grad=True)
 | |
|         grad = torch.randn(5, 10, requires_grad=True)
 | |
|         out_shape, grad_shape = _calculate_shape(out, grad, False)
 | |
| 
 | |
|         assert out_shape == torch.Size([10, 5])
 | |
|         assert grad_shape == torch.Size([5, 10])
 | |
| 
 | |
|         out = torch.nested.as_nested_tensor([
 | |
|             torch.randn(10, 5, requires_grad=True),
 | |
|             torch.randn(10, 5, requires_grad=True),
 | |
|             torch.randn(10, 5, requires_grad=True)]
 | |
|         )
 | |
|         grad = torch.nested.as_nested_tensor([torch.randn(5, 10, requires_grad=True), torch.randn(5, 10, requires_grad=True)])
 | |
|         out_shape, grad_shape = _calculate_shape(out, grad, False)
 | |
| 
 | |
|         assert torch.equal(out_shape, torch.tensor([[10, 5], [10, 5], [10, 5]]))
 | |
|         assert torch.equal(grad_shape, torch.tensor([[5, 10], [5, 10]]))
 | |
| 
 | |
|     def test_nested_anomaly_detect_nan(self):
 | |
|         size = 10
 | |
| 
 | |
|         class MyFunc(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, inp1, fail_0th):
 | |
|                 ctx.fail_0th = fail_0th
 | |
|                 ctx.save_for_backward(inp1)
 | |
|                 return inp1.sum(0, keepdim=True)
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 inp, = ctx.saved_tensors
 | |
|                 fail_0th = ctx.fail_0th
 | |
|                 g = gO.clone().expand(size)
 | |
|                 gI = MyFunc2.apply(g * inp, g + inp, fail_0th)
 | |
|                 return gI, None
 | |
| 
 | |
|         class MyFunc2(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, inp1, inp2, fail_0th):
 | |
|                 ctx.fail_0th = fail_0th
 | |
|                 return inp1 * 2.0 + inp2
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 fail_0th = ctx.fail_0th
 | |
|                 g1 = gO.clone()
 | |
|                 g2 = gO.clone()
 | |
|                 g1[0] = 0
 | |
|                 g2[0] = 0
 | |
|                 # generate a nan
 | |
|                 if fail_0th:
 | |
|                     g1[0] /= 0
 | |
|                 else:
 | |
|                     g2[0] /= 0
 | |
|                 return g1, g2, None
 | |
| 
 | |
|         inp = torch.rand(size, requires_grad=True)
 | |
|         out = MyFunc.apply(inp, True)
 | |
|         ginp, = torch.autograd.grad(out, (inp,), create_graph=True)
 | |
|         gsum = ginp.sum()
 | |
|         gsum.backward()  # should not fail
 | |
| 
 | |
|         inp = torch.rand(size, requires_grad=True)
 | |
|         out = MyFunc.apply(inp, True)
 | |
|         ginp, = torch.autograd.grad(out, (inp,), create_graph=True)
 | |
|         gsum = ginp.sum()
 | |
|         with warnings.catch_warnings(record=True) as w:
 | |
|             with self.assertRaisesRegex(RuntimeError, "Function 'MyFunc2Backward' returned nan values in its 0th output."):
 | |
|                 with detect_anomaly():
 | |
|                     gsum.backward()
 | |
|         self.assertIn('No forward pass information', str(w[1].message))
 | |
| 
 | |
|         inp = torch.rand(size, requires_grad=True)
 | |
|         with warnings.catch_warnings(record=True) as w:
 | |
|             with self.assertRaisesRegex(RuntimeError, "Function 'MyFunc2Backward' returned nan values in its 1th output."):
 | |
|                 with detect_anomaly():
 | |
|                     out = MyFunc.apply(inp, False)
 | |
|                     ginp, = torch.autograd.grad(out, (inp,), create_graph=True)
 | |
|                     gsum = ginp.sum()
 | |
|                     gsum.backward()
 | |
|         self.assertIn('MyFunc2.apply', str(w[1].message))
 | |
|         self.assertIn('MyFunc.apply', str(w[2].message))
 | |
| 
 | |
|     def test_anomaly_grad_warnings(self):
 | |
|         # PyTorch won't throw warnings if there is an error
 | |
|         # but we'd want to at least see them in stderr
 | |
| 
 | |
|         class StdErrDiverter:
 | |
|             def __enter__(self):
 | |
|                 self.stderr_orig = sys.stderr
 | |
|                 self.stderr_new = io.StringIO()
 | |
|                 sys.stderr = self.stderr_new
 | |
|                 return self
 | |
| 
 | |
|             def __exit__(self, *args):
 | |
|                 self.captured = self.stderr_new.getvalue()
 | |
|                 sys.stderr = self.stderr_orig
 | |
| 
 | |
| 
 | |
|         # if the warnings don't throw, they will be handled as regular warnings
 | |
|         with self.assertRaisesRegex(RuntimeError,
 | |
|                                     "one of the variables needed for gradient computation has been "
 | |
|                                     "modified by an inplace operation"):
 | |
|             with warnings.catch_warnings(record=True) as w:
 | |
|                 with detect_anomaly():
 | |
|                     a = torch.randn(5, requires_grad=True)
 | |
|                     d1 = a + 1
 | |
|                     d2 = d1 ** 2
 | |
|                     d1 += 1
 | |
|                     torch.autograd.grad(d2.sum(), a)
 | |
| 
 | |
|         self.assertEqual(len(w), 2)
 | |
|         self.assertIn('Anomaly Detection has been enabled', str(w[0].message))
 | |
|         self.assertIn('Error detected in PowBackward0', str(w[1].message))
 | |
| 
 | |
|         # if the warning throws, it will be printed to sys.stderr
 | |
|         with self.assertRaisesRegex(RuntimeError,
 | |
|                                     "one of the variables needed for gradient computation has been "
 | |
|                                     "modified by an inplace operation"):
 | |
|             with warnings.catch_warnings(record=True) as w:
 | |
|                 with detect_anomaly():
 | |
|                     warnings.simplefilter("error")
 | |
|                     with StdErrDiverter() as s:
 | |
|                         a = torch.randn(5, requires_grad=True)
 | |
|                         d1 = a + 1
 | |
|                         d2 = d1 ** 2
 | |
|                         d1 += 1
 | |
|                         torch.autograd.grad(d2.sum(), a)
 | |
| 
 | |
|         self.assertEqual(len(w), 1)
 | |
|         self.assertIn('Anomaly Detection has been enabled', str(w[0].message))
 | |
|         self.assertIn('Error detected in PowBackward0', s.captured)
 | |
| 
 | |
    def test_anomaly_assign_parent_cleanup(self):
        # Test that python objects created are properly cleaned up when assign_parent is called

        def get_ref():
            # Builds the graph under test and returns (t, ref): the forward output `t` and a
            # weak reference to an object stored in t.grad_fn's metadata dict.
            #
            # we use torch.exp here but any function that will construct a new node in its
            # backward call in grad mode will work
            x = torch.randn(2, 2, requires_grad=True)
            t = x.exp()

            # ExpBackward calls mul, creating the MulBackward node when create_graph=True.
            # In anomaly mode, a PyObject referencing MulBackward's "parent" ExpBackward is added to
            # MulBackward's anomaly metadata dict, creating the following reference chain:
            #
            # grad -> MulBackward -> PyObject -> ExpBackward
            #
            with detect_anomaly():
                grad = torch.autograd.grad(t, x, torch.ones_like(t), create_graph=True)

            # We add a weak reference to a new Foo object, which we insert into ExpBackward's metadata dict
            #
            # (PyObject) -> ExpBackward -> dict -> *Foo*
            #            t ----^        WeakRef ---^
            #
            # We want to test that when grad goes out of scope at the end of this function that PyObject is destroyed
            # We can test this by seeing whether Foo is not kept alive once t is destroyed
            class Foo:
                pass
            my_obj = Foo()
            meta_dict = t.grad_fn.metadata
            meta_dict[0] = my_obj
            ref = weakref.ref(my_obj)
            # Only `t` and the weak reference are returned; `grad`, `my_obj` and `meta_dict`
            # are locals of this function and go out of scope here.
            return t, ref

        t, ref = get_ref()
        # While `t` is still held, the object stored in the metadata dict must be alive ...
        self.assertIsNotNone(ref())
        del t
        # ... and after dropping `t` the weak reference must be dead.
        self.assertIsNone(ref())
 | |
| 
 | |
    def test_nested_anomaly_printstack_cleanup(self):
        # Test if metadata dict PyObject is properly destroyed
        def get_ref():
            # This is similar to the construction in test_anomaly_assign_parent_cleanup:
            #
            # MyFunc2Backward -> PyObject -> MyFuncBackward -> dict -> Foo
            #                               out ---^         WeakRef ---^
            #
            # We want to check that Foo is still properly destroyed even when MyFunc2Backward's
            # AnomalyMetadata calls printstack, which does some python object manipulation.
            #
            # You might be wondering why we still need test_anomaly_assign_parent_cleanup,
            # since if PyObject is not destroyed here, wouldn't this test detect that also?
            # The answer is that a custom function's PyObject (THPFunction) actually only holds
            # a weak reference to the c++ node!
            class MyFunc(Function):
                @staticmethod
                def forward(ctx, x):
                    ctx.save_for_backward(x)
                    return x

                @staticmethod
                def backward(ctx, gO):
                    # Calls a second custom Function from inside backward, so that a
                    # MyFunc2Backward node is created when running with create_graph=True.
                    x, = ctx.saved_tensors
                    return MyFunc2.apply(x)

            class MyFunc2(Function):
                @staticmethod
                def forward(ctx, x):
                    return x

                @staticmethod
                def backward(ctx, gO):
                    # Deliberately produce NaN so that anomaly mode raises for this node.
                    return gO + float("NaN")

            inp = torch.rand(1, requires_grad=True)
            out = MyFunc.apply(inp)
            ginp, = torch.autograd.grad(out, (inp,), create_graph=True)

            # Record warnings instead of emitting them while the expected anomaly error fires.
            with warnings.catch_warnings(record=True) as w:
                with self.assertRaisesRegex(RuntimeError, "Function 'MyFunc2Backward' returned nan values in its 0th output."):
                    with detect_anomaly():
                        ginp.backward()

            class Foo:
                pass
            my_obj = Foo()
            meta_dict = out.grad_fn.metadata
            meta_dict[0] = my_obj
            ref = weakref.ref(my_obj)
            # Only `out` and the weak reference are returned; every other local goes out of scope here.
            return out, ref

        t, ref = get_ref()
        # The stored object must be alive while `t` is held and dead once `t` is dropped.
        self.assertIsNotNone(ref())
        del t
        self.assertIsNone(ref())
 | |
| 
 | |
|     def test_anomaly_mode_no_check_nan(self):
 | |
|         class MyFunc(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, inp):
 | |
|                 return inp.clone()
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 return torch.tensor(float("nan")).expand(10, 10)
 | |
| 
 | |
|         def run_fn(a):
 | |
|             out = MyFunc.apply(a)
 | |
|             return out.sum()
 | |
| 
 | |
|         with warnings.catch_warnings(record=True) as w:
 | |
|             with torch.autograd.detect_anomaly(check_nan=False):
 | |
|                 inp = torch.rand(10, 10, requires_grad=True)
 | |
|                 out = run_fn(inp)
 | |
|                 out.backward(retain_graph=True)
 | |
| 
 | |
|                 with torch.autograd.detect_anomaly(check_nan=True):
 | |
|                     with self.assertRaisesRegex(RuntimeError, "Function 'MyFuncBackward' returned nan values in its 0th output."):
 | |
|                         out.backward(retain_graph=True)
 | |
| 
 | |
|                 out.backward()
 | |
| 
 | |
    def test_no_grad_copy(self):
        # Checks when the gradient buffer produced by a backward is reused as `.grad` of a leaf
        # and when it is copied instead, by comparing data pointers.
        # NOTE(review): the outcome presumably depends on how many references to the grad
        # tensor exist, so avoid keeping extra references around when editing this test.

        # create autograd function that saves grad pointer as class static
        class MyFunc(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp1, inp2):
                return inp1 + inp2

            @staticmethod
            def backward(ctx, grad):
                # Remember which buffer came in, then hand the very same tensor to both inputs.
                MyFunc.static_grad_ptr = grad.data_ptr()
                return grad, grad

        class NonContGradFunc(Function):
            @staticmethod
            def forward(ctx, inp1):
                ctx.size = inp1.size()
                return torch.tensor([1.])

            @staticmethod
            def backward(ctx, grad):
                # A one-element tensor expanded to the input's size.
                return torch.ones(1).expand(ctx.size)

        a = torch.randn(5, 6, requires_grad=True)
        b = torch.randn(5, 6, requires_grad=True)
        # non-contiguous grad should be copied
        NonContGradFunc.apply(MyFunc.apply(a, b)).backward()
        self.assertFalse(a.grad.data_ptr() == MyFunc.static_grad_ptr)
        self.assertFalse(b.grad.data_ptr() == MyFunc.static_grad_ptr)
        # test case that should trigger no copy for one of a,b
        a.grad = b.grad = None
        MyFunc.apply(a, b)[1][0].backward()
        p_g = MyFunc.static_grad_ptr
        p_a = a.grad.data_ptr()
        p_b = b.grad.data_ptr()
        # check a,b uses different grad buffer
        self.assertFalse(p_a == p_b)
        # check one of them is using the computed buffer
        self.assertTrue(p_a == p_g or p_b == p_g)
 | |
| 
 | |
    def test_no_grad_copy_sparse(self):
        # Sparse counterpart of test_no_grad_copy: compares the data pointers of the sparse
        # gradients' values to see whether the computed buffer was reused or copied.
        # NOTE(review): as with the dense test, this presumably depends on reference counts of
        # the grad tensors; avoid introducing extra references when editing.

        # create autograd function that saves grad pointer as class static
        class MyFunc(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp1, inp2):
                return inp1 + inp2

            @staticmethod
            def backward(ctx, grad):
                # Record the values buffer of the incoming grad and return it for both inputs.
                MyFunc.static_grad_ptr = grad._values().data_ptr()
                return grad, grad

        class NonContGradFunc(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp1, inp2):
                return inp1 + inp2

            @staticmethod
            def backward(ctx, grad):
                # Create a sparse tensor with non-contiguous indices and values
                # and return as grad.
                v = torch.rand(1, 3)
                i = torch.ones(1, 1, dtype=torch.long)
                nv = v.expand(8, 3)
                ni = i.expand(1, 8)
                ngrad = torch.sparse_coo_tensor(ni, nv, (10, 3), dtype=torch.float32)
                NonContGradFunc.static_grad_ptr = ngrad._values().data_ptr()
                return ngrad, ngrad

        a = torch.randn(10, 3, requires_grad=True)
        b = torch.randn(10, 3, requires_grad=True)
        # Arguments passed to F.embedding_bag below, after the embedding matrix.
        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.tensor([0, 4])
        import torch.nn.functional as F

        # test case that should trigger no copy for one of a,b
        emb_matrix = MyFunc.apply(a, b)
        loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
        loss.backward(retain_graph=True)
        p_g = MyFunc.static_grad_ptr
        p_a = a.grad._values().data_ptr()
        p_b = b.grad._values().data_ptr()
        # check a,b uses different grad buffer
        self.assertFalse(p_a == p_b)
        # check one of them is using the computed buffer
        self.assertTrue(p_a == p_g or p_b == p_g)

        # Run backwards multiple times to ensure accumulation works.
        for i in range(10):
            loss.backward(retain_graph=True)

        # non-contiguous indices and value, we should trigger a copy.
        a.grad = b.grad = None
        emb_matrix = NonContGradFunc.apply(a, b)
        loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
        loss.backward(retain_graph=True)
        p_g = NonContGradFunc.static_grad_ptr
        p_a = a.grad._values().data_ptr()
        p_b = b.grad._values().data_ptr()
        # check a,b uses different grad buffer
        self.assertFalse(p_a == p_b)
        # Verify we cloned both grads.
        self.assertFalse(p_a == p_g)
        self.assertFalse(p_b == p_g)

        # Run backwards multiple times to ensure accumulation works.
        for i in range(10):
            loss.backward(retain_graph=True)
 | |
| 
 | |
|     def test_gradcheck_single_input(self):
 | |
|         def check(fast_mode):
 | |
|             def f(inp):
 | |
|                 return inp.mul(5)
 | |
| 
 | |
|             gradcheck(f, torch.rand(10, dtype=torch.float64, requires_grad=True), fast_mode=fast_mode)
 | |
|             gradgradcheck(f, torch.rand(10, dtype=torch.float64, requires_grad=True), fast_mode=fast_mode)
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     @parametrize('layout', (torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc))
 | |
|     def test_gradcheck_input(self, layout):
 | |
|         if layout in {torch.sparse_bsr, torch.sparse_bsc}:
 | |
|             blocksize = (2, 2)
 | |
|             size = (4, 8)
 | |
|         else:
 | |
|             blocksize = None
 | |
|             size = (2, 2)
 | |
| 
 | |
|         def check(fast_mode, masked):
 | |
|             def fn(sparse):
 | |
|                 return torch.sum(sparse)
 | |
| 
 | |
|             gradcheck(fn, torch.rand(size, dtype=torch.double).to_sparse(layout=layout, blocksize=blocksize).requires_grad_(),
 | |
|                       masked=masked, check_batched_grad=False, fast_mode=fast_mode)
 | |
| 
 | |
|         for fast_mode, masked in product(*[(True, False)] * 2):
 | |
|             check(fast_mode=fast_mode, masked=masked)
 | |
| 
 | |
|     def test_gradcheck_nondeterministic(self):
 | |
|         class NonDetFunc(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x, jitter=0.0):
 | |
|                 ctx._jitter = jitter
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_out):
 | |
|                 return NonDetFunc.apply(grad_out, ctx._jitter) * (1 + torch.rand_like(grad_out) * ctx._jitter), None
 | |
| 
 | |
|         def check(fast_mode):
 | |
|             inp = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
 | |
|             gradcheck(lambda x: NonDetFunc.apply(x, 0.0), inp, check_batched_grad=False, fast_mode=fast_mode)
 | |
|             with self.assertRaisesRegex(RuntimeError, 'Backward is not reentrant'):
 | |
|                 gradcheck(lambda x: NonDetFunc.apply(x, 1e-6), inp, check_batched_grad=False, fast_mode=fast_mode)
 | |
|             with self.assertRaisesRegex(RuntimeError, 'Backward is not reentrant'):
 | |
|                 gradgradcheck(lambda x: NonDetFunc.apply(x, 1e-12), inp, check_batched_grad=False, fast_mode=fast_mode)
 | |
|             gradcheck(lambda x: NonDetFunc.apply(x, 0.0), inp, nondet_tol=1e-5, check_batched_grad=False,
 | |
|                       fast_mode=fast_mode)
 | |
|             gradcheck(lambda x: NonDetFunc.apply(x, 1e-6), inp, nondet_tol=1e-5, check_batched_grad=False,
 | |
|                       fast_mode=fast_mode)
 | |
|             gradgradcheck(lambda x: NonDetFunc.apply(x, 1e-12), inp, nondet_tol=1e-5, check_batched_grad=False,
 | |
|                           fast_mode=fast_mode)
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     def test_gradcheck_validates_inputs(self):
 | |
|         def check(fast_mode):
 | |
|             x = torch.rand(10, requires_grad=True).to_sparse()
 | |
|             self.assertTrue(gradcheck(lambda x: x.to_dense(), (x,), check_batched_grad=False,
 | |
|                                       atol=1e-1, fast_mode=fast_mode, masked=True))
 | |
|             self.assertFalse(gradcheck(lambda x: x.to_dense(), (x,), masked=False,
 | |
|                                        check_batched_grad=False, raise_exception=False, fast_mode=fast_mode))
 | |
|             self.assertTrue(gradcheck(lambda x: x.to_dense(masked_grad=False), (x,), masked=False,
 | |
|                                       atol=1e-1, check_batched_grad=False, raise_exception=False, fast_mode=fast_mode))
 | |
| 
 | |
|             # when none of the inputs require grad (always raises even if raise_exception=False)
 | |
|             x = torch.rand(10, requires_grad=False)
 | |
|             with self.assertRaisesRegex(ValueError, 'at least one input tensor to require gradient'):
 | |
|                 gradcheck(lambda x: x, (x,), raise_exception=False, fast_mode=fast_mode)
 | |
| 
 | |
|             # (warning) when inputs are not double precision
 | |
|             x = torch.ones(1, dtype=torch.float32, requires_grad=True)
 | |
|             with self.assertWarnsRegex(UserWarning, "Input #0 requires gradient and is not a double precision"):
 | |
|                 self.assertTrue(gradcheck(lambda x: x, (x,), atol=1e-1, fast_mode=fast_mode))
 | |
| 
 | |
|             # when layout is not mkldnn(aka has strides) and input has a dimension with stride 0. (always raises
 | |
|             # even if raise_exception=False)
 | |
|             x = torch.ones(1, dtype=torch.float64, requires_grad=True)
 | |
|             x = x.expand((2, 2))
 | |
|             with self.assertRaisesRegex(RuntimeError, 'The 0th input has a dimension with stride 0'):
 | |
|                 gradcheck(lambda x: x, (x,), raise_exception=False, fast_mode=fast_mode)
 | |
| 
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
 | |
|     def test_gradcheck_validates_input_mkldnn(self):
 | |
|         # when mkldnn inputs, forward mode testing is not allowed
 | |
|         # Update tolerances below to make sure the gradient match even in single precision floats
 | |
|         # Use the warning assert to hide the float32 warning
 | |
|         x = torch.ones(1).to_mkldnn().requires_grad_()
 | |
|         with self.assertWarnsRegex(UserWarning, "Input #0 requires gradient and is not a double precision"):
 | |
|             with self.assertRaisesRegex(ValueError, 'MKLDNN inputs are not support for forward AD gradcheck.'):
 | |
|                 gradcheck(lambda x: x.to_dense(), (x,), raise_exception=False, fast_mode=False, check_forward_ad=True,
 | |
|                           atol=1e-1, rtol=1e-1)
 | |
| 
 | |
|         with self.assertWarnsRegex(UserWarning, "Input #0 requires gradient and is not a double precision"):
 | |
|             with self.assertRaisesRegex(ValueError, 'MKLDNN inputs are not support for forward AD gradcheck.'):
 | |
|                 gradcheck(lambda x: x.to_dense(), (x,), raise_exception=False, fast_mode=True, check_forward_ad=True,
 | |
|                           atol=1e-1, rtol=1e-1)
 | |
| 
 | |
|     @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
 | |
|     def test_gradcheck_test_outputs(self):
 | |
|         def check(fast_mode):
 | |
|             # when sparse outputs (always raise even if raise_exception=False)
 | |
|             x = torch.rand(10, requires_grad=True).to_sparse()
 | |
|             with self.assertRaisesRegex(ValueError, 'Sparse output is not supported at gradcheck yet'):
 | |
|                 gradcheck(lambda x: x, (x,), masked=True, check_batched_grad=False, raise_exception=False,
 | |
|                           fast_mode=fast_mode)
 | |
| 
 | |
|             # when mkldnn outputs (always raise even if raise_exception=False)
 | |
|             root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)
 | |
|             with self.assertRaisesRegex(ValueError, 'MKLDNN output is not supported at gradcheck yet'):
 | |
|                 gradcheck(lambda x: x.to_mkldnn(), (root,), check_batched_grad=False, raise_exception=False, fast_mode=fast_mode)
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     def test_gradcheck_check_no_differentiable_outputs(self):
 | |
|         def check(fast_mode):
 | |
|             # When none of the outputs are differentiable, but numerical gradient is not zero
 | |
|             x = torch.ones((1,), requires_grad=True)
 | |
|             with self.assertRaisesRegex(RuntimeError, 'Numerical gradient for function expected to be zero'):
 | |
|                 gradcheck(lambda x: torch.tensor([x]), x)
 | |
|             self.assertFalse(gradcheck(lambda x: torch.tensor([x]), x, raise_exception=False, fast_mode=fast_mode))
 | |
| 
 | |
|             # succeed when no outputs at all
 | |
|             self.assertTrue(gradcheck(lambda x: (), (x,), fast_mode=fast_mode))
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     def test_gradcheck_check_batched_grad(self):
 | |
|         def check(fast_mode):
 | |
|             x = torch.rand(10, dtype=torch.double, requires_grad=True).to_sparse()
 | |
|             # runtime error while compute batched grad (print big error)
 | |
|             with self.assertRaisesRegex(RuntimeError, 'gradcheck or gradgradcheck failed while testing batched gradient'):
 | |
|                 gradcheck(lambda x: x.to_dense(), (x,), masked=True, check_batched_grad=True, fast_mode=fast_mode)
 | |
|             self.assertFalse(gradcheck(lambda x: x.to_dense(), (x,), masked=True, check_batched_grad=True,
 | |
|                                        raise_exception=False, fast_mode=fast_mode))
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     def test_gradcheck_backward_mul_by_grad_output(self):
 | |
|         # when grad_input is sparse and has incorrect sparse_dim/dense_dim
 | |
|         def check(fast_mode):
 | |
|             def fn(x):
 | |
|                 def hook(grad):
 | |
|                     if grad is not None:
 | |
|                         return grad.to_dense().to_sparse(1)
 | |
|                     return grad
 | |
|                 y = x.clone()
 | |
|                 y.register_hook(hook)
 | |
|                 return y.to_dense()
 | |
|             x = torch.ones((2, 2), dtype=torch.double, requires_grad=True).to_sparse()
 | |
|             with self.assertRaisesRegex(RuntimeError, 'grad is sparse tensor, but has incorrect sparse_dim'):
 | |
|                 gradcheck(fn, (x,), atol=1e-1, masked=True, check_batched_grad=False, fast_mode=fast_mode)
 | |
|             self.assertFalse(gradcheck(fn, (x,), atol=1e-1, masked=True, check_batched_grad=False,
 | |
|                                        raise_exception=False, fast_mode=fast_mode))
 | |
| 
 | |
|             # when backward not multiplied by grad_output (non-sparse case)
 | |
|             def fn2(x):
 | |
|                 y = x.clone()
 | |
|                 y.register_hook(lambda x: x + 1e-2)
 | |
|                 return y
 | |
|             x = torch.ones(1, dtype=torch.double, requires_grad=True)
 | |
|             with self.assertRaisesRegex(RuntimeError, 'backward not multiplied by grad_output'):
 | |
|                 gradcheck(fn2, (x,), atol=1e-1, fast_mode=fast_mode)
 | |
|             self.assertFalse(gradcheck(fn2, (x,), atol=1e-1, raise_exception=False, fast_mode=fast_mode))
 | |
| 
 | |
|             # when backward not multiplied by grad_output (sparse case)
 | |
|             def fn3(x):
 | |
|                 y = x.clone().to_dense()
 | |
|                 y.register_hook(lambda x: x + 1e-2)
 | |
|                 return y
 | |
|             x = torch.ones(1, dtype=torch.double, requires_grad=True).to_sparse()
 | |
|             with self.assertRaisesRegex(RuntimeError, 'backward not multiplied by grad_output'):
 | |
|                 gradcheck(fn3, (x,), atol=1e-1, masked=True, check_batched_grad=False, fast_mode=fast_mode)
 | |
|             self.assertFalse(gradcheck(fn3, (x,), atol=1e-1, masked=True, check_batched_grad=False,
 | |
|                                        raise_exception=False, fast_mode=fast_mode))
 | |
| 
 | |
|             # when layout of grad_input is not the same as input
 | |
|             class Test(Function):
 | |
|                 @staticmethod
 | |
|                 def forward(ctx, x):
 | |
|                     return x
 | |
| 
 | |
|                 @staticmethod
 | |
|                 def backward(ctx, x):
 | |
|                     return x.to_sparse()
 | |
|             x = torch.ones(1, dtype=torch.double, requires_grad=True)
 | |
|             with self.assertRaisesRegex(RuntimeError, 'grad is incorrect layout'):
 | |
|                 gradcheck(Test.apply, (x,), check_batched_grad=False, fast_mode=fast_mode)
 | |
|             self.assertFalse(gradcheck(Test.apply, (x,), check_batched_grad=False, raise_exception=False, fast_mode=fast_mode))
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     def test_gradcheck_undefined_grad(self):
 | |
|         def check(fast_mode):
 | |
|             # when encounter runtime error while running backward
 | |
|             def fn(x):
 | |
|                 def hook(x):
 | |
|                     if x is None:
 | |
|                         raise RuntimeError("x is undefined")
 | |
|                 y = x.clone()
 | |
|                 y.register_hook(hook)
 | |
|                 return y
 | |
|             x = torch.ones(1, dtype=torch.double, requires_grad=True)
 | |
|             with self.assertWarnsRegex(UserWarning, "Backwards compatibility: New undefined gradient support checking feature"):
 | |
|                 with self.assertRaisesRegex(RuntimeError, 'Expected backward function to handle undefined output grads'):
 | |
|                     gradcheck(fn, (x,), fast_mode=fast_mode)
 | |
|                 self.assertFalse(gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode))
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     def test_gradcheck_jacobian_mismatch(self):
 | |
|         def check(fast_mode):
 | |
|             def fn(x):  # R -> R, C -> C
 | |
|                 y = x.clone()
 | |
|                 y.register_hook(lambda x: x + 1e-2)
 | |
|                 return y
 | |
|             x = torch.ones(2, 2, requires_grad=True)
 | |
|             with self.assertRaisesRegex(RuntimeError, 'Jacobian mismatch for output 0 with respect to input 0'):
 | |
|                 gradcheck(fn, (x,), fast_mode=fast_mode)
 | |
|             self.assertFalse(gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode))
 | |
| 
 | |
|             x_c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128)
 | |
|             with self.assertRaisesRegex(RuntimeError, 'While considering the imaginary part of complex outputs only'):
 | |
|                 gradcheck(fn, (x_c,), fast_mode=False)
 | |
|             self.assertFalse(gradcheck(fn, (x_c,), raise_exception=False, fast_mode=False))
 | |
| 
 | |
|             def fn2(x):  # R -> C
 | |
|                 y = torch.complex(x, x)
 | |
|                 y.register_hook(lambda x: x + 1e-2)
 | |
|                 return y
 | |
|             x = torch.ones(2, 2, requires_grad=True)
 | |
|             with self.assertRaisesRegex(RuntimeError, 'While considering the imaginary part of complex outputs only'):
 | |
|                 gradcheck(fn2, (x,), fast_mode=False)
 | |
|             self.assertFalse(gradcheck(fn2, (x,), raise_exception=False, fast_mode=False))
 | |
| 
 | |
|             def fn3(x):  # C -> R
 | |
|                 y = torch.real(x)
 | |
|                 y.register_hook(lambda x: x + 1e-2)
 | |
|                 return y
 | |
|             with self.assertRaisesRegex(RuntimeError, 'Jacobian mismatch for output 0 with respect to input 0'):
 | |
|                 gradcheck(fn3, (x_c,), fast_mode=False)
 | |
|             self.assertFalse(gradcheck(fn3, (x_c,), raise_exception=False, fast_mode=False))
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     def test_gradcheck_dense_and_sparse_inputs(self):
 | |
|         def check(fast_mode):
 | |
|             def fn(x, y):
 | |
|                 return x * y.coalesce().to_dense()
 | |
|             a = torch.rand(2, 2, dtype=torch.double, requires_grad=True)
 | |
|             b = torch.rand(2, 2, dtype=torch.double,).to_sparse().requires_grad_(True)
 | |
|             self.assertTrue(gradcheck(fn, (a, b), masked=True, check_batched_grad=False, fast_mode=fast_mode))
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
 | |
|     def test_gradcheck_multiple_mkldnn_inputs(self):
 | |
|         def check(fast_mode):
 | |
|             def fn(x, y):
 | |
|                 return x + y.to_dense()
 | |
|             a = torch.rand(10, requires_grad=True)
 | |
|             b = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True)
 | |
|             self.assertTrue(gradcheck(fn, (a, b), atol=1e-1, check_batched_grad=False, fast_mode=fast_mode))
 | |
| 
 | |
|             def fn2(x, y):
 | |
|                 return x.to_dense() + y.to_dense()
 | |
|             c = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True)
 | |
|             self.assertTrue(gradcheck(fn, (a, c), atol=1e-1, check_batched_grad=False, fast_mode=fast_mode))
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     def test_gradcheck_output_shape_or_dtype_depend_on_values(self):
 | |
|         def check(fast_mode):
 | |
|             def fn(x):
 | |
|                 if torch.all(x >= 1):
 | |
|                     return torch.cat([x, x])
 | |
|                 else:
 | |
|                     return x
 | |
|             a = torch.ones(1, dtype=torch.double, requires_grad=True)
 | |
|             with self.assertRaisesRegex(AssertionError, 'return outputs with the same shape when inputs are perturbed'):
 | |
|                 self.assertTrue(gradcheck(fn, (a,), fast_mode=fast_mode))
 | |
| 
 | |
|             def fn2(x):
 | |
|                 if torch.all(x >= 1):
 | |
|                     return x.to(torch.float32)
 | |
|                 else:
 | |
|                     return x
 | |
|             with self.assertRaisesRegex(AssertionError, 'return outputs with the same dtype when inputs are perturbed'):
 | |
|                 self.assertTrue(gradcheck(fn2, (a,), fast_mode=fast_mode))
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     def test_gradcheck_complex_non_complex_outputs(self):
 | |
|         def fn(x, y):
 | |
|             z = torch.complex(x, y)
 | |
|             return z, x + 1
 | |
|         a = torch.ones(2, 2, requires_grad=True, dtype=torch.float64)
 | |
|         b = torch.ones(2, 2, requires_grad=True, dtype=torch.float64)
 | |
|         self.assertTrue(gradcheck(fn, (a, b)))
 | |
| 
 | |
|         def fn2(z):
 | |
|             return z, torch.real(z)
 | |
|         c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128)
 | |
|         self.assertTrue(gradcheck(fn2, (c)))
 | |
| 
 | |
|     def test_gradcheck_get_numerical_jacobian(self):
 | |
|         # get_numerical_jacobian is deprecated and no longer used internally by gradcheck
 | |
|         from torch.autograd.gradcheck import get_numerical_jacobian
 | |
| 
 | |
|         def fn(inputs):
 | |
|             # get_numerical_jacobian requires fn to take inputs as a tuple
 | |
|             # and returns the jacobian wrt the first output
 | |
|             x = inputs[0]
 | |
|             y = inputs[1]
 | |
|             return 2 * x + y, x + 2 * y
 | |
|         a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
 | |
|         b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
 | |
| 
 | |
|         with self.assertWarnsRegex(UserWarning, "get_numerical_jacobian was part of PyTorch's private API"):
 | |
|             jacobian = get_numerical_jacobian(fn, (a, b), target=a, eps=1e-6)
 | |
|         self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))
 | |
| 
 | |
|         with self.assertWarnsRegex(UserWarning, "get_numerical_jacobian was part of PyTorch's private API"):
 | |
|             jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6)
 | |
|         self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))
 | |
|         self.assertEqual(jacobian[1], 1 * torch.eye(4, dtype=torch.double))
 | |
| 
 | |
|         with self.assertRaisesRegex(ValueError, "Expected grad_out to be 1.0"):
 | |
|             jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6, grad_out=2.0)
 | |
| 
 | |
|     def test_gradcheck_get_analytical_jacobian(self):
 | |
|         from torch.autograd.gradcheck import get_analytical_jacobian
 | |
| 
 | |
|         def fn(x, y):
 | |
|             return 2 * x + y, x + 2 * y
 | |
| 
 | |
|         a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
 | |
|         b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
 | |
| 
 | |
|         outputs = fn(a, b)
 | |
|         with self.assertWarnsRegex(UserWarning, "get_analytical_jacobian was part of PyTorch's private API"):
 | |
|             jacobians, reentrant, correct_grad_sizes, correct_grad_types = get_analytical_jacobian((a, b), outputs[0])
 | |
|         self.assertEqual(jacobians[0], 2 * torch.eye(4, dtype=torch.double))
 | |
|         self.assertEqual(jacobians[1], 1 * torch.eye(4, dtype=torch.double))
 | |
|         self.assertTrue(reentrant)
 | |
| 
 | |
|         class NonDetFunc(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x, jitter=0.0):
 | |
|                 ctx._jitter = jitter
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_out):
 | |
|                 return NonDetFunc.apply(grad_out, ctx._jitter) * (1 + torch.rand_like(grad_out) * ctx._jitter), None
 | |
| 
 | |
|         outputs = NonDetFunc.apply(a, 1e-6)
 | |
|         with self.assertWarnsRegex(UserWarning, "get_analytical_jacobian was part of PyTorch's private API"):
 | |
|             jacobians, reentrant, correct_grad_sizes, correct_grad_types = get_analytical_jacobian((a,), outputs)
 | |
|         self.assertFalse(reentrant)
 | |
| 
 | |
|         with self.assertRaisesRegex(ValueError, "Expected grad_out to be 1.0"):
 | |
|             jacobians, _, _, _ = get_analytical_jacobian((a,), outputs, grad_out=2.0)
 | |
| 
 | |
|     def test_gradcheck_custom_error(self):
 | |
|         from torch.autograd.gradcheck import GradcheckError
 | |
| 
 | |
|         def check(fast_mode):
 | |
|             def fn(x):
 | |
|                 y = x.clone()
 | |
|                 y.register_hook(lambda x: x + 1e-2)
 | |
|                 return y
 | |
|             x = torch.ones(2, 2, requires_grad=True)
 | |
|             with self.assertRaisesRegex(GradcheckError, 'Jacobian mismatch for output 0 with respect to input 0'):
 | |
|                 gradcheck(fn, (x,), fast_mode=fast_mode)
 | |
|             with self.assertRaisesRegex(RuntimeError, 'Jacobian mismatch for output 0 with respect to input 0'):
 | |
|                 gradcheck(fn, (x,), fast_mode=fast_mode)
 | |
|             self.assertFalse(gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode))
 | |
| 
 | |
|             def fn2(x):
 | |
|                 raise RuntimeError("Not a GradcheckError!")
 | |
|             # Checks that when raise_exception=False, non-GradcheckErrors are not caught by gradcheck
 | |
|             with self.assertRaisesRegex(RuntimeError, "Not a GradcheckError!"):
 | |
|                 gradcheck(fn2, (x,), fast_mode=fast_mode, raise_exception=False)
 | |
| 
 | |
|         check(fast_mode=True)
 | |
|         check(fast_mode=False)
 | |
| 
 | |
|     def test_gradcheck_forward_ad(self):
 | |
|         def fn(x, y):
 | |
|             return x + y, y
 | |
| 
 | |
|         def bad_fn(x, y):
 | |
|             # Hacky way to check if we're currently inside a forward ad level
 | |
|             is_running_forward_ad = fwAD._current_level >= 0
 | |
| 
 | |
|             if is_running_forward_ad:
 | |
|                 y_p, y_d = fwAD.unpack_dual(y)
 | |
|                 y = fwAD.make_dual(y_p, y_d * 1.1)
 | |
| 
 | |
|             return x + y, y
 | |
| 
 | |
|         err_msg = "Jacobian computed with forward mode mismatch for output 0 with respect to input 1"
 | |
| 
 | |
|         for fast_mode in [True, False]:
 | |
|             # Test for all inputs and outputs being real
 | |
|             x = torch.rand(2, dtype=torch.double, requires_grad=True)
 | |
|             y = torch.rand(2, dtype=torch.double, requires_grad=True)
 | |
| 
 | |
|             gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
 | |
|             with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                 gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
 | |
| 
 | |
|             def basic_mul(x):
 | |
|                 return torch.view_as_real(torch.resolve_conj(x * 1j))
 | |
|             gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)
 | |
| 
 | |
|             # Test for one input and one output being complex
 | |
|             x = torch.rand(2, dtype=torch.cdouble, requires_grad=True)
 | |
| 
 | |
|             gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
 | |
|             with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                 gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
 | |
| 
 | |
|             # Test for all inputs and outputs being complex
 | |
|             y = torch.rand(2, dtype=torch.cdouble, requires_grad=True)
 | |
| 
 | |
|             gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
 | |
|             with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                 gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
 | |
| 
 | |
|     def test_gradcheck_forward_ad_runs_with_no_requires_grad(self):
 | |
|         # Currently requires_grad is used as a easy way for gradcheck to know
 | |
|         # which inputs of the function are meant to be differentiable
 | |
|         # This test checks that when the inputs are passed to the function they should not have
 | |
|         # requires_grad=True even though they may have requires_grad=True when passed
 | |
|         # to gradcheck
 | |
|         class UserFn(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x, y):
 | |
|                 if fwAD._current_level >= 0:
 | |
|                     self.assertFalse(x.requires_grad)
 | |
|                     self.assertFalse(y.requires_grad)
 | |
|                 return x.clone(), y.clone()
 | |
| 
 | |
|             @staticmethod
 | |
|             def jvp(ctx, x_t, y_t):
 | |
|                 return x_t, y_t
 | |
| 
 | |
|         x = torch.rand(2, dtype=torch.double, requires_grad=True)
 | |
|         y = torch.rand(2, dtype=torch.double, requires_grad=True)
 | |
| 
 | |
|         gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=False, check_backward_ad=False,
 | |
|                   check_batched_grad=False, check_batched_forward_grad=False)
 | |
| 
 | |
|         gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
 | |
|                   check_batched_grad=False, check_batched_forward_grad=False)
 | |
| 
 | |
|         gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
 | |
|                   check_batched_grad=False, check_batched_forward_grad=True)
 | |
| 
 | |
|         x = torch.rand(2, dtype=torch.double, requires_grad=True)
 | |
|         y = torch.rand(2, dtype=torch.double, requires_grad=False)
 | |
|         gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
 | |
|                   check_batched_grad=False, check_batched_forward_grad=True)
 | |
| 
 | |
    def test_gradcheck_forward_ad_respects_requires_grad(self):
        """Count custom ``jvp`` invocations made by gradcheck's forward AD checks.

        The expected counts are asserted for several combinations of
        ``check_undefined_grad`` / ``check_batched_forward_grad``, first with
        both inputs requiring grad and then with only one of them.
        """
        # Currently requires_grad is used as an easy way for gradcheck to know
        # which inputs of the function are meant to be differentiable
        jvp_count = [0]

        class UserFn(Function):
            @staticmethod
            def forward(ctx, x, y):
                return x.clone(), y.clone()

            @staticmethod
            def jvp(ctx, x_t, y_t):
                # `jvp_count` is looked up in the enclosing scope at call time,
                # so this increments whichever list is currently bound there.
                jvp_count[0] += 1
                return x_t, y_t

        # NB: In slow gradcheck we need to loop through numel times so use numel = 1 to ensure
        #     that fast and slow have the same counts
        x = torch.rand(1, dtype=torch.double, requires_grad=True)
        y = torch.rand(1, dtype=torch.double, requires_grad=True)
        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=False, check_backward_ad=False,
                  check_batched_grad=False, check_batched_forward_grad=False)
        self.assertEqual(jvp_count[0], 2)  # (2) once per input
        # Rebinding the name to a fresh list resets the counter seen by UserFn.jvp.
        jvp_count = [0]

        # The asserted value is the total for this call; the "(+N)" in the
        # trailing comments is the increase relative to the previous configuration.
        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
                  check_batched_grad=False, check_batched_forward_grad=False)
        self.assertEqual(jvp_count[0], 6)  # (+4): (once with normal ZT (+1), once with efficient ZT (+1)) for each input (x2)
        jvp_count = [0]

        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
                  check_batched_grad=False, check_batched_forward_grad=True)
        self.assertEqual(jvp_count[0], 12)  # (+6): (compute batch of 2 with vmap (+1), with a loop (+2)) for each input (x2)
        jvp_count = [0]

        # Repeat the previous test except we mark one input with requires_grad=False
        # NB: _test_undefined_forward_mode is only (+1), when function has single differentiable input, not (+2)!
        #     Otherwise, other counts are halved.
        x = torch.rand(1, dtype=torch.double, requires_grad=True)
        y = torch.rand(1, dtype=torch.double, requires_grad=False)
        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
                  check_batched_grad=False, check_batched_forward_grad=True)
        self.assertEqual(jvp_count[0], 5)  # 1 + 1 + 3
 | |
| 
 | |
|     def test_gradcheck_check_forward_or_backward_only(self):
 | |
|         """Depending on settings for check_forward_ad and check_backward_ad, the
 | |
|         correct codepaths should be reached (or not reached)
 | |
|         """
 | |
|         fwd_fail_err_msg = "FAIL FWD"
 | |
|         bwd_fail_err_msg = "FAIL BWD"
 | |
| 
 | |
|         class UserFn(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, foo, fwd_bad, bwd_bad):
 | |
|                 ctx.fwd_bad = fwd_bad
 | |
|                 ctx.bwd_bad = bwd_bad
 | |
|                 return foo * 2
 | |
| 
 | |
|             @staticmethod
 | |
|             def vjp(ctx, gO):
 | |
|                 if ctx.bwd_bad:
 | |
|                     raise RuntimeError(bwd_fail_err_msg)
 | |
|                 else:
 | |
|                     return 2 * gO, None, None
 | |
| 
 | |
|             @staticmethod
 | |
|             def jvp(ctx, gI, _1, _2):
 | |
|                 if ctx.fwd_bad:
 | |
|                     raise RuntimeError(fwd_fail_err_msg)
 | |
|                 else:
 | |
|                     return 2 * gI
 | |
| 
 | |
|         for fast_mode in (True, False):
 | |
|             for check_forward_ad in (True, False):
 | |
|                 for check_backward_ad in (True, False):
 | |
|                     for fwd_bad in (True, False):
 | |
|                         for bwd_bad in (True, False):
 | |
|                             fwd_should_fail = fwd_bad and check_forward_ad
 | |
|                             bwd_should_fail = bwd_bad and check_backward_ad
 | |
| 
 | |
|                             def run():
 | |
|                                 gradcheck(UserFn.apply, (x, fwd_bad, bwd_bad), check_forward_ad=check_forward_ad,
 | |
|                                           check_backward_ad=check_backward_ad, check_undefined_grad=check_backward_ad,
 | |
|                                           check_batched_grad=check_backward_ad, fast_mode=fast_mode)
 | |
| 
 | |
|                             x = torch.rand(2, dtype=torch.double, requires_grad=True)
 | |
| 
 | |
|                             if not check_forward_ad and not check_backward_ad:
 | |
|                                 with self.assertRaisesRegex(AssertionError, "Expected at least one of"):
 | |
|                                     run()
 | |
|                                 continue
 | |
| 
 | |
|                             if not fwd_should_fail and not bwd_should_fail:
 | |
|                                 run()
 | |
|                             else:
 | |
|                                 # If both fail, backward AD failure "hides" forward AD failure
 | |
|                                 if fwd_should_fail:
 | |
|                                     fail_msg = fwd_fail_err_msg
 | |
|                                 if bwd_should_fail:
 | |
|                                     fail_msg = bwd_fail_err_msg
 | |
|                                 with self.assertRaisesRegex(RuntimeError, fail_msg):
 | |
|                                     run()
 | |
| 
 | |
|     def test_gradcheck_forward_ad_batched_grad(self):
 | |
|         x = torch.rand(2, dtype=torch.double, requires_grad=True)
 | |
| 
 | |
|         # multiple inputs and outputs with non-tensors inputs
 | |
|         def fn1(a: torch.Tensor, b: int):
 | |
|             return a.clone(), a + 1
 | |
|         gradcheck(fn1, (x, 1), check_forward_ad=True, check_backward_ad=False, check_batched_grad=False,
 | |
|                   check_undefined_grad=False, check_batched_forward_grad=True)
 | |
| 
 | |
|         # unrelated inputs: tangent for c is None
 | |
|         def fn2(a: torch.Tensor, c: torch.Tensor):
 | |
|             return a.clone()
 | |
|         gradcheck(fn2, (x, x.clone()), check_forward_ad=True, check_backward_ad=False, check_batched_grad=False,
 | |
|                   check_undefined_grad=False, check_batched_forward_grad=True)
 | |
| 
 | |
|         class Fn(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, foo):
 | |
|                 return foo * 2
 | |
| 
 | |
|             @staticmethod
 | |
|             def vjp(ctx, gO):
 | |
|                 return gO * 2
 | |
| 
 | |
|             @staticmethod
 | |
|             def jvp(ctx, gI):
 | |
|                 torch.randn_like(gI)
 | |
|                 return gI * 2
 | |
| 
 | |
|         msg = "vmap: We do not yet support calling random operations inside of vmap"
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             gradcheck(Fn.apply, (x,), check_forward_ad=True, check_batched_forward_grad=True)
 | |
| 
 | |
|     def test_version_counter(self):
 | |
|         x = torch.randn(1, 2)
 | |
| 
 | |
|         # In-place op bumps version
 | |
|         x_saved_version = x._version
 | |
|         x.add_(1).add_(1)
 | |
|         self.assertTrue(x._version > x_saved_version)
 | |
| 
 | |
|         # Differentiable view shares version counter
 | |
|         xz = x[:]
 | |
|         self.assertTrue(x._version == xz._version)
 | |
|         xz.add_(1)
 | |
|         self.assertTrue(x._version == xz._version)
 | |
| 
 | |
|         # `x.data = y` preserves version counter of `x`
 | |
|         x_saved_version = x._version
 | |
|         x.data = torch.randn(2, 3)
 | |
|         self.assertTrue(x._version == x_saved_version)
 | |
|         x.add_(1)
 | |
|         self.assertTrue(x._version > x_saved_version)
 | |
|         # Make sure `x` is still using the same version counter it shares with `xz`
 | |
|         self.assertTrue(x._version == xz._version)
 | |
| 
 | |
|         # In-place op on `xz` also updates version of `x`,
 | |
|         # because they share the version counter
 | |
|         xz.add_(1)
 | |
|         self.assertTrue(x._version == xz._version)
 | |
| 
 | |
|     def test_set_data_tensorimpl_type(self):
 | |
|         # Dense tensor has impl of type `TensorImpl`, while sparse tensor has impl
 | |
|         # of type `SparseTensorImpl`.
 | |
|         x = torch.randn(1, 2)
 | |
|         x_s = torch.sparse_coo_tensor(torch.zeros([1, 1]), torch.ones([1]))
 | |
|         with self.assertRaisesRegex(RuntimeError, 'incompatible tensor type'):
 | |
|             x.data = x_s
 | |
| 
 | |
|     def test_set_data_preserve_pyobj(self):
 | |
|         a = torch.randn(1, 2)
 | |
|         b = torch.randn(1, 2)
 | |
|         b_id_saved = id(b)
 | |
|         b.data = a
 | |
|         self.assertTrue(b_id_saved == id(b))
 | |
| 
 | |
|     def test_set_data_self_requires_grad(self):
 | |
|         a = torch.tensor(1.0, requires_grad=True)
 | |
|         b = torch.tensor(2.0)
 | |
|         c = torch.tensor(3, dtype=torch.int64)
 | |
|         a.data = b
 | |
|         with self.assertRaisesRegex(RuntimeError, 'must be floating point or complex dtype'):
 | |
|             a.data = c
 | |
| 
 | |
    @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows")
    def test_thread_shutdown(self):
        """Check whether the autograd thread-shutdown API usage message is logged.

        A small script running a custom Function's backward is handed to
        ``TestCase.runWithPytorchAPIUsageStderr`` (presumably run in a separate
        process with its stderr captured - confirm against the helper), and the
        returned text is searched for the thread-shutdown marker.
        """
        # Script text executed by the helper; it must stay self-contained.
        code = """import torch
from torch.autograd import Function
class MyFunction(Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad

# Run on cuda if it is available to ensure that the worker thread
# is properly initialized by the time we exit.
device = "cuda" if torch.cuda.is_available() else "cpu"

for shape in [(1,), ()]:
    v = torch.ones(shape, requires_grad=True, device=device)
    MyFunction.apply(v).backward()
"""
        s = TestCase.runWithPytorchAPIUsageStderr(code)
        # The autograd engine creates worker threads only when GPU devices are present.
        # So make sure that we do shutdown threads when we're testing cuda and make sure
        # that there is no thread to shutdown when we're not using cuda.
        # NOTE(review): the script above only ever selects "cuda" or "cpu", yet the
        # shutdown message is also expected when MPS is available - confirm intent.
        if TEST_CUDA or torch.backends.mps.is_available():
            self.assertRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")
        else:
            self.assertNotRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")
 | |
| 
 | |
|     @unittest.skipIf(IS_MACOS, "Fails with SIGBUS on macOS; https://github.com/pytorch/pytorch/issues/25941")
 | |
|     def test_deep_reentrant(self):
 | |
| 
 | |
|         class DeepReentrant(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 with torch.enable_grad():
 | |
|                     ctx.x = Variable(x.detach(), requires_grad=True)
 | |
|                     ctx.x = ctx.x - 1
 | |
|                 return ctx.x.detach()
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, x):
 | |
|                 if ctx.x < 0:
 | |
|                     return x
 | |
|                 with torch.enable_grad():
 | |
|                     DeepReentrant.apply(ctx.x).sum().backward()
 | |
|                 return x
 | |
| 
 | |
|         # Test stack overflow escape mechanism
 | |
|         v = torch.tensor(2000.0, requires_grad=True)
 | |
|         # This will cause stack overflow if reentrant calls are handled
 | |
|         # in the same thread recursively
 | |
|         DeepReentrant.apply(v).sum().backward()
 | |
| 
 | |
|         # Test stack overflow escape mechanism multiple times
 | |
|         # to ensure reusing workers in the pool works fine
 | |
|         v2 = torch.tensor(200.0, requires_grad=True)
 | |
|         DeepReentrant.apply(v2).sum().backward()
 | |
| 
 | |
    def test_reentrant_priority(self):
        """Check that reentrant backward tasks run before a pending ordinary task.

        The order in which the two custom backward functions execute is
        recorded in ``order`` and compared with the expected schedule.
        """
        # Names of the backward functions, appended in execution order.
        order = []

        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, x):
                order.append("MyFunction")
                return x

        class Reentrant(Function):
            @staticmethod
            def forward(ctx, x):
                with torch.enable_grad():
                    ctx.x = Variable(x.detach(), requires_grad=True)
                    ctx.x = ctx.x - 1
                return ctx.x.detach()

            @staticmethod
            def backward(ctx, x):
                order.append("Reentrant")
                # Stop re-entering once the stored value has dropped below zero.
                if ctx.x < 0:
                    return x
                with torch.enable_grad():
                    Reentrant.apply(ctx.x).backward()
                return x

        a = MyFunction.apply(torch.tensor(6.0, requires_grad=True))
        b = Reentrant.apply(torch.tensor(9.0, requires_grad=True))
        v = a * b
        v.backward()
        # The tasks for the Reentrant and MyFunction backward() will be added
        # to the queue in the autograd engine at the same time. The backward
        # for Reentrant will be executed first, which will then add other
        # backward tasks to the queue. We want to ensure all the reentrant tasks
        # are prioritized over the MyFunction backward task regardless of their
        # sequence numbers
        # Starting from 9.0, each forward stores a value one smaller, so backward
        # runs for stored values 8 down to -1: ten Reentrant entries plus one
        # MyFunction entry.
        self.assertEqual(len(order), 11)
        self.assertEqual(order.count("Reentrant"), 10)
        self.assertEqual(order[-1], "MyFunction")
 | |
| 
 | |
| 
 | |
|     @slowTest
 | |
|     def test_checkpointing(self):
 | |
|         num_inp = 2000
 | |
|         nz_inp = 10
 | |
|         nz_out = 10
 | |
|         nz_bottleneck = 1000
 | |
| 
 | |
|         # small proxy network for some complex reasoning we want to do per input
 | |
|         module = nn.Sequential(
 | |
|             nn.Linear(nz_inp, nz_bottleneck),
 | |
|             nn.ReLU(),
 | |
|             nn.Linear(nz_bottleneck, nz_inp)
 | |
|         )
 | |
| 
 | |
|         feat_combined = []
 | |
|         for r in range(num_inp):
 | |
|             data_r = torch.empty(1, nz_inp)
 | |
|             data_r.uniform_()
 | |
|             data_r.requires_grad = True
 | |
|             feat_r = checkpoint(module, data_r, use_reentrant=True)
 | |
|             feat_combined.append(feat_r)
 | |
| 
 | |
|         # compute mean as a proxy for some joint reasoning
 | |
|         mean_combined = torch.stack(feat_combined).mean()
 | |
|         mean_combined.backward()
 | |
| 
 | |
|     def _test_checkpointing_non_reentrant_autocast(self, device_type):
 | |
|         for enabled in [True, False]:
 | |
|             def foo(x, y, z):
 | |
|                 # torch.mm is on autocast's list of ops that should run in
 | |
|                 # the autocast precision
 | |
|                 x = torch.mm(x, y)
 | |
|                 y = torch.mm(x, z)
 | |
|                 z = torch.mm(z, z)
 | |
|                 expected_dtype = (
 | |
|                     torch.float32 if not enabled else torch.bfloat16
 | |
|                 )
 | |
|                 self.assertEqual(expected_dtype, z.dtype)
 | |
|                 return z
 | |
| 
 | |
|             x = torch.randn(3, 3, requires_grad=True)
 | |
|             y = torch.randn(3, 3, requires_grad=True)
 | |
|             z = torch.randn(3, 3, requires_grad=True)
 | |
|             if device_type == 'cuda':
 | |
|                 x = x.cuda()
 | |
|                 y = y.cuda()
 | |
|                 z = z.cuda()
 | |
| 
 | |
|             with torch.autocast(enabled=enabled, device_type=device_type, dtype=torch.bfloat16):
 | |
|                 loss = checkpoint(foo, x, y, z, use_reentrant=False)
 | |
|                 loss = loss.sum()
 | |
| 
 | |
|             # Without saving + recasting the autocast type, would raise error in autograd
 | |
|             # about mismatched dtypes.
 | |
|             loss.backward()  # triggers recomputation to check it runs in bfloat
 | |
| 
 | |
    def test_checkpointing_non_reentrant_autocast_cpu(self):
        """
        Test that autocast args such as the dtype are preserved during non-reentrant
        checkpoint recomputation on CPU.

        Delegates to ``_test_checkpointing_non_reentrant_autocast`` with
        ``device_type='cpu'``.
        """
        self._test_checkpointing_non_reentrant_autocast(device_type='cpu')
 | |
| 
 | |
    # Skipped unless CUDA is available and reports bfloat16 support.
    @unittest.skipIf(
        not torch.cuda.is_available() or not torch.cuda.is_bf16_supported(),
        "Test requires CUDA bf16 support"
    )
    def test_checkpointing_non_reentrant_autocast_gpu(self):
        """
        Test that autocast args/kwargs such as the dtype are preserved during
        non-reentrant checkpoint recomputation on GPU.

        Delegates to ``_test_checkpointing_non_reentrant_autocast`` with
        ``device_type='cuda'``.
        """
        self._test_checkpointing_non_reentrant_autocast(device_type='cuda')
 | |
| 
 | |
|     @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
 | |
|     @slowTest
 | |
|     def test_checkpointing_without_reentrant_memory_savings(self):
 | |
|         class MyModel(nn.Module):
 | |
|             def __init__(self, n, use_checkpoint, use_reentrant):
 | |
|                 super().__init__()
 | |
|                 self.n = n
 | |
|                 self.use_checkpoint = use_checkpoint
 | |
|                 self.use_reentrant = use_reentrant
 | |
|                 self.layers = nn.ModuleList()
 | |
|                 for i in range(self.n):
 | |
|                     layer = nn.Sequential(
 | |
|                         nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)
 | |
|                     )
 | |
|                     self.layers.append(layer)
 | |
|                 # pre-allocate the grad so that increased memory usage is mainly
 | |
|                 # due to activations.
 | |
|                 for layer in self.layers:
 | |
|                     for lin in layer:
 | |
|                         lin.weight.grad = torch.ones_like(lin.weight)
 | |
|                         lin.bias.grad = torch.ones_like(lin.bias)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 for i in range(self.n):
 | |
|                     if not self.use_checkpoint:
 | |
|                         x = self.layers[i](x)
 | |
|                     else:
 | |
|                         x = checkpoint(self.layers[i], x, use_reentrant=self.use_reentrant)
 | |
| 
 | |
|                 return x
 | |
| 
 | |
|         model_no_checkpoint = MyModel(8, use_checkpoint=False, use_reentrant=False).cuda()
 | |
|         model_reentrant_checkpoint = MyModel(8, use_checkpoint=True, use_reentrant=True).cuda()
 | |
|         model_no_reentrant_checkpoint = MyModel(8, use_checkpoint=True, use_reentrant=False).cuda()
 | |
| 
 | |
|         x = torch.randn(100, 256, requires_grad=True, device='cuda')
 | |
| 
 | |
|         torch.cuda.reset_peak_memory_stats()
 | |
|         loss = model_no_checkpoint(x.clone()).sum()
 | |
|         loss.backward()
 | |
|         mem_no_checkpoint = torch.cuda.max_memory_allocated()
 | |
| 
 | |
|         torch.cuda.reset_peak_memory_stats()
 | |
|         loss = model_reentrant_checkpoint(x.clone()).sum()
 | |
|         loss.backward()
 | |
|         mem_reentrant_checkpoint = torch.cuda.max_memory_allocated()
 | |
| 
 | |
|         torch.cuda.reset_peak_memory_stats()
 | |
|         loss = model_no_reentrant_checkpoint(x.clone()).sum()
 | |
|         loss.backward()
 | |
|         mem_no_reentrant_checkpoint = torch.cuda.max_memory_allocated()
 | |
| 
 | |
|         self.assertTrue(mem_reentrant_checkpoint < mem_no_checkpoint)
 | |
|         self.assertTrue(mem_no_reentrant_checkpoint < mem_no_checkpoint)
 | |
| 
 | |
|     def test_checkpointing_without_reentrant_custom_function_works(self):
 | |
|         msg = "torch.utils.checkpoint: unpack is being triggered for a tensor that was either"
 | |
| 
 | |
|         class MyFunc(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x, y, z):
 | |
|                 w = x * y * z
 | |
|                 out = w + w
 | |
|                 ctx.save_for_backward(x, y, z, w, out)
 | |
|                 return out
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_out):
 | |
|                 x, y, z, w, out = ctx.saved_tensors
 | |
|                 # Accessing the saved Tensors a second time will raise because
 | |
|                 # recomputed tensors get cleared as soon as they are unpacked.
 | |
|                 # A recomputation is only triggered if your backward has a new
 | |
|                 # graph-task id.
 | |
|                 with self.assertRaisesRegex(RuntimeError, msg):
 | |
|                     x_2, y_2, z_2, w_2, out_2 = ctx.saved_tensors
 | |
|                 return x, y, z
 | |
| 
 | |
|         x = torch.tensor(1., requires_grad=True)
 | |
|         y = torch.tensor(2., requires_grad=True)
 | |
|         z = torch.tensor(3., requires_grad=True)
 | |
| 
 | |
|         def foo(x, y, z):
 | |
|             x = x * y * z
 | |
|             y = y * y * z
 | |
|             z = z * z
 | |
|             out = MyFunc.apply(x, y, z)
 | |
|             return out
 | |
| 
 | |
|         out = checkpoint(foo, x, y, z, use_reentrant=False)
 | |
|         out.sum().backward()
 | |
| 
 | |
|     def test_checkpointing_without_reentrant_with_context_fn(self):
 | |
|         class VerboseTorchDispatchMode(TorchDispatchMode):
 | |
|             def __init__(self):
 | |
|                 self.operators = []
 | |
| 
 | |
|             def __torch_dispatch__(self, func, types, args=(), kwargs=None):
 | |
|                 if kwargs is None:
 | |
|                     kwargs = {}
 | |
|                 self.operators.append(func.__name__)
 | |
|                 return func(*args, **kwargs)
 | |
| 
 | |
|         x = torch.tensor(1., requires_grad=True)
 | |
|         verbose_mode = VerboseTorchDispatchMode()
 | |
| 
 | |
|         def context_fn():
 | |
|             return verbose_mode, contextlib.nullcontext()
 | |
|         out = checkpoint(lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn)
 | |
|         self.assertEqual(verbose_mode.operators, ['exp.default'])
 | |
| 
 | |
|         verbose_mode.operators = []
 | |
| 
 | |
|         def context_fn():
 | |
|             return contextlib.nullcontext(), verbose_mode
 | |
|         out = checkpoint(lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn)
 | |
|         out.backward()
 | |
|         self.assertEqual(
 | |
|             verbose_mode.operators,
 | |
|             ['exp.default', 'detach.default', 'detach.default']
 | |
|         )
 | |
| 
 | |
|         with self.assertRaisesRegex(Exception, "only supported when use_reentrant=False"):
 | |
|             out = checkpoint(lambda x: x.sin(), x, use_reentrant=True, context_fn=context_fn)
 | |
| 
 | |
|     def test_checkpoint_warns_if_use_reentrant_not_passed_explcitly(self):
 | |
|         a = torch.randn(1, requires_grad=True)
 | |
| 
 | |
|         # Passing explicitly should not warn
 | |
|         with warnings.catch_warnings(record=True) as w:
 | |
|             checkpoint(lambda x: x, a, use_reentrant=False)
 | |
|         self.assertEqual(len(w), 0)
 | |
| 
 | |
|         # Not passing explicitly warns
 | |
|         with warnings.catch_warnings(record=True) as w:
 | |
|             checkpoint(lambda x: x, a)
 | |
|         self.assertEqual(len(w), 1)
 | |
|         self.assertIn(
 | |
|             "please pass in use_reentrant=True or use_reentrant=False explicitly",
 | |
|             str(w[0].message)
 | |
|         )
 | |
| 
 | |
|     def test_access_saved_tensor_twice_without_recomputation_works(self):
 | |
|         count = [0]
 | |
| 
 | |
|         def foo(a):
 | |
|             count[0] += 1
 | |
|             b = a * a
 | |
|             c = a * b
 | |
|             d = torch.exp(a)
 | |
|             return d
 | |
| 
 | |
|         a = torch.randn(5, requires_grad=True)
 | |
|         d = checkpoint(foo, a, use_reentrant=False)
 | |
|         self.assertEqual(count[0], 1)
 | |
|         # Recomputed variables only persist within a particular backward call.
 | |
|         # If _saved_result is accessed outside of a backward, it will trigger
 | |
|         # a recompute. And afterwards, those recomputed results are immediately
 | |
|         # cleared.
 | |
|         d.grad_fn._saved_result
 | |
|         self.assertEqual(count[0], 2)
 | |
|         # Second access will trigger another recompute
 | |
|         d.grad_fn._saved_result
 | |
|         self.assertEqual(count[0], 3)
 | |
|         # Backward clears the saved variable
 | |
|         d.sum().backward()
 | |
|         self.assertEqual(count[0], 4)
 | |
|         # Now it raises an error
 | |
|         with self.assertRaisesRegex(
 | |
|             RuntimeError,
 | |
|             "or directly access saved tensors after they have already been freed"
 | |
|         ):
 | |
|             d.grad_fn._saved_result
 | |
| 
 | |
|     @slowTest
 | |
|     @parametrize("input_requires_grad", [True, False])
 | |
|     def test_checkpointing_without_reentrant(self, input_requires_grad):
 | |
|         """
 | |
|         Basic test for checkpoint without reentrant autograd.
 | |
|         """
 | |
|         num_inp = 2000
 | |
|         nz_inp = 10
 | |
|         nz_out = 10
 | |
|         nz_bottleneck = 1000
 | |
| 
 | |
|         # small proxy network for some complex reasoning we want to do per input
 | |
|         module = nn.Sequential(
 | |
|             nn.Linear(nz_inp, nz_bottleneck),
 | |
|             nn.ReLU(),
 | |
|             nn.Linear(nz_bottleneck, nz_inp)
 | |
|         )
 | |
| 
 | |
|         # Module holder for testing activation checkpointing with no_reentrant
 | |
|         # supports kwargs.
 | |
|         class MyModule(nn.Module):
 | |
|             def __init__(self, mod):
 | |
|                 super().__init__()
 | |
|                 self.module = mod
 | |
| 
 | |
|             def forward(self, data):
 | |
|                 return self.module(data)
 | |
| 
 | |
|         module = MyModule(mod=module)
 | |
| 
 | |
|         # Run model with and without checkpointing and verify gradients are
 | |
|         # equivalent, regardless of if inputs require grads or not.
 | |
|         module_copy = deepcopy(module)
 | |
| 
 | |
|         feat_combined = []
 | |
|         feat_combined_no_checkpoint = []
 | |
|         for r in range(num_inp):
 | |
|             data_r = torch.empty(1, nz_inp)
 | |
|             data_r.uniform_()
 | |
|             data_r.requires_grad = input_requires_grad
 | |
|             data_r_copy = data_r.clone()
 | |
|             feat_r = checkpoint(module, data=data_r, use_reentrant=False)
 | |
|             feat_combined.append(feat_r)
 | |
|             feat_r_no_checkpoint = module_copy(data_r)
 | |
|             feat_combined_no_checkpoint.append(feat_r_no_checkpoint)
 | |
| 
 | |
| 
 | |
|         # compute mean as a proxy for some joint reasoning
 | |
|         mean_combined = torch.stack(feat_combined).mean()
 | |
|         mean_combined.backward()
 | |
|         mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean()
 | |
|         mean_combined_no_checkpoint.backward()
 | |
| 
 | |
|         for checkpoint_param, param in zip(module.parameters(), module_copy.parameters()):
 | |
|             self.assertEqual(checkpoint_param.grad, param.grad)
 | |
| 
 | |
|     def test_checkpoint_valid_reset_on_error(self):
 | |
|         a = torch.randn(2, 2, requires_grad=True)
 | |
| 
 | |
|         with self.assertRaisesRegex(Exception, "Checkpointing is not compatible with .grad()"):
 | |
|             b = checkpoint(torch.exp, a, use_reentrant=True).sum()
 | |
|             torch.autograd.grad(b, (a,))
 | |
| 
 | |
|         c = checkpoint(torch.exp, a, use_reentrant=True).sum()
 | |
|         c.backward()
 | |
| 
 | |
|     @parametrize("use_reentrant", [True, False])
 | |
|     def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant):
 | |
|         class NoGradModule(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.linear = nn.Linear(2, 2, bias=False)
 | |
|                 self.lin2 = nn.Linear(2, 2, bias=False)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 with torch.no_grad():
 | |
|                     return self.lin2(self.linear(x))
 | |
| 
 | |
|         module = NoGradModule()
 | |
| 
 | |
|         err_ctx = (
 | |
|             self.assertRaisesRegex(
 | |
|                 RuntimeError,
 | |
|                 "none of output has requires_grad=True"
 | |
|             )
 | |
|             if use_reentrant
 | |
|             else contextlib.nullcontext()
 | |
|         )
 | |
| 
 | |
|         a = torch.randn(2, 2, requires_grad=True)
 | |
|         for _ in range(3):
 | |
|             with err_ctx:
 | |
|                 # out does not require grad
 | |
|                 out = checkpoint(module, a, use_reentrant=use_reentrant)
 | |
|                 # Make loss require grad, otherwise we would run into
 | |
|                 # "element 0 of tensors does not require grad and does not have a grad_fn"
 | |
|                 out += a
 | |
|                 out.sum().backward()
 | |
| 
 | |
|     def test_checkpointing_without_reentrant_correct_grad(self):
 | |
|         """
 | |
|         Verifies that correct gradients are calculated for checkpoint
 | |
|         without reentrant autograd, for both backward() and autograd.grad().
 | |
|         """
 | |
|         a = torch.randn(2, 2, requires_grad=True)
 | |
| 
 | |
|         b = torch.exp(a).sum()
 | |
|         b.backward()
 | |
|         b_grad = a.grad
 | |
| 
 | |
|         a.grad = None
 | |
|         c = checkpoint(torch.exp, a, use_reentrant=False).sum()
 | |
|         c.backward()
 | |
|         c_grad = a.grad
 | |
| 
 | |
|         a.grad = None
 | |
|         d = checkpoint(torch.exp, a, use_reentrant=False).sum()
 | |
|         d_grad, = torch.autograd.grad(d, (a,))
 | |
| 
 | |
|         self.assertEqual(b_grad, c_grad)
 | |
|         self.assertEqual(b_grad, d_grad)
 | |
| 
 | |
|     def test_checkpointing_without_reentrant_dataparallel(self):
 | |
|         """
 | |
|         Verifies gradient correctness when checkpoint without reentrant autograd
 | |
|         is used in conjunction with DataParallel.
 | |
|         """
 | |
|         class LinearModule(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.linear = nn.Linear(2, 2, bias=False)
 | |
| 
 | |
|             def forward(self, inp):
 | |
|                 return self.linear(inp)
 | |
| 
 | |
|         a = torch.randn(2, 2, requires_grad=True)
 | |
|         if torch.cuda.is_available():
 | |
|             a = a.cuda()
 | |
| 
 | |
|         model = LinearModule()
 | |
|         if torch.cuda.is_available():
 | |
|             model = model.cuda()
 | |
| 
 | |
|         b = deepcopy(model)(a).sum()
 | |
|         b.backward()
 | |
|         b_grad = a.grad
 | |
| 
 | |
|         a.grad = None
 | |
| 
 | |
|         module = torch.nn.DataParallel(deepcopy(model))
 | |
|         c = checkpoint(module, a, use_reentrant=False).sum()
 | |
|         c.backward()
 | |
|         c_grad = a.grad
 | |
| 
 | |
|         self.assertEqual(b_grad, c_grad)
 | |
| 
 | |
|     def test_checkpointing_without_reentrant_parameter_used_in_an_out(self):
 | |
|         """
 | |
|         Ensures that gradient hooks are only called once per tensor.
 | |
|         """
 | |
|         w = torch.randn(10, 10, requires_grad=True)
 | |
|         count = 0
 | |
| 
 | |
|         def hook(grad):
 | |
|             nonlocal count
 | |
|             count += 1
 | |
| 
 | |
|         w.register_hook(hook)
 | |
|         x = torch.rand(10, 10, requires_grad=True)
 | |
|         h = w * x  # Using w outside the checkpoint
 | |
|         out = checkpoint(lambda x: w * x, h, use_reentrant=False)  # Using w inside the checkpoint
 | |
| 
 | |
|         out.sum().backward()
 | |
|         # should only call hook once
 | |
|         self.assertEqual(count, 1)
 | |
| 
 | |
|     def test_checkpointing_without_reentrant_arbitrary_input_output(self):
 | |
|         """
 | |
|         Ensures checkpointing without reentrant autograd works with functions
 | |
|         with arbitrary input/output structures.
 | |
|         """
 | |
| 
 | |
|         class MyModel(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.layer = torch.nn.Linear(5, 5, bias=False)
 | |
| 
 | |
|             def forward(self, dict_input):
 | |
|                 tensor = dict_input["tensor"]
 | |
|                 return {
 | |
|                     "result": self.layer(tensor)
 | |
|                 }
 | |
| 
 | |
|         model_no_checkpoint = MyModel()
 | |
|         model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint)
 | |
| 
 | |
|         inp = {
 | |
|             "tensor": torch.randn(5, 5)
 | |
|         }
 | |
| 
 | |
|         out_no_checkpoint = model_no_checkpoint(inp)["result"].sum()
 | |
| 
 | |
|         out_checkpoint = checkpoint(
 | |
|             model_checkpoint_without_reentrant,
 | |
|             inp,
 | |
|             use_reentrant=False
 | |
|         )["result"].sum()
 | |
| 
 | |
|         self.assertEqual(out_checkpoint, out_no_checkpoint)
 | |
| 
 | |
|         out_no_checkpoint.backward()
 | |
|         out_checkpoint.backward()
 | |
| 
 | |
|         for param, checkpoint_param in zip(model_no_checkpoint.parameters(), model_checkpoint_without_reentrant.parameters()):
 | |
|             self.assertEqual(param.grad, checkpoint_param.grad)
 | |
| 
 | |
|     def test_callback_adds_callback(self):
 | |
|         called = [0]
 | |
| 
 | |
|         def callback_final():
 | |
|             called[0] += 1
 | |
| 
 | |
|         def callback_adds_callback():
 | |
|             called[0] += 1
 | |
|             Variable._execution_engine.queue_callback(callback_final)
 | |
| 
 | |
|         class MyFunc(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, input):
 | |
|                 return input
 | |
| 
 | |
|             @staticmethod
 | |
|             @once_differentiable
 | |
|             def backward(ctx, grad):
 | |
|                 Variable._execution_engine.queue_callback(callback_adds_callback)
 | |
|                 return grad
 | |
| 
 | |
|         a = torch.rand((3, 3), requires_grad=True)
 | |
|         b = MyFunc.apply(a)
 | |
|         b.sum().backward()
 | |
| 
 | |
|         self.assertEqual(called[0], 2)
 | |
| 
 | |
|     def _test_reentrant_with_callbacks(self, install_callbacks_in_depths):
 | |
|         counter = {}
 | |
|         counter["inner"] = 0
 | |
|         counter["outer"] = 0
 | |
| 
 | |
|         def inc_inner_counter():
 | |
|             counter["inner"] += 1
 | |
| 
 | |
|         def inc_outer_counter():
 | |
|             counter["outer"] += 1
 | |
| 
 | |
|         class MyFunc(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, input):
 | |
|                 return input
 | |
| 
 | |
|             @staticmethod
 | |
|             @once_differentiable
 | |
|             def backward(ctx, input):
 | |
|                 if 1 in install_callbacks_in_depths:
 | |
|                     # Add a callback to execute.
 | |
|                     Variable._execution_engine.queue_callback(inc_inner_counter)
 | |
| 
 | |
|                 return input
 | |
| 
 | |
|         class MyReentrantFunc(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, input):
 | |
|                 return input
 | |
| 
 | |
|             @staticmethod
 | |
|             @once_differentiable
 | |
|             def backward(ctx, input):
 | |
|                 if 0 in install_callbacks_in_depths:
 | |
|                     # Add a callback to execute.
 | |
|                     Variable._execution_engine.queue_callback(inc_outer_counter)
 | |
|                 # Reentrant backward call.
 | |
|                 tmp_inp = input.detach().requires_grad_()
 | |
|                 with torch.enable_grad():
 | |
|                     tmp_out = (MyFunc.apply(tmp_inp)).sum()
 | |
|                 tmp_out.backward()
 | |
|                 return input
 | |
| 
 | |
|         t1 = torch.rand((3, 3), requires_grad=True)
 | |
|         t2 = MyReentrantFunc.apply(t1)
 | |
|         t3 = t2.sum()
 | |
|         torch.autograd.backward([t3])
 | |
| 
 | |
|         return counter
 | |
| 
 | |
|     def test_reentrant_with_callbacks_depth_0(self):
 | |
|         # Verify callback is called only once.
 | |
|         ret = self._test_reentrant_with_callbacks([0])
 | |
|         self.assertEqual(1, ret["outer"])
 | |
|         self.assertEqual(0, ret["inner"])
 | |
| 
 | |
|     def test_reentrant_with_callbacks_depth_1(self):
 | |
|         # Verify callback is called only once.
 | |
|         ret = self._test_reentrant_with_callbacks([1])
 | |
|         self.assertEqual(0, ret["outer"])
 | |
|         self.assertEqual(1, ret["inner"])
 | |
| 
 | |
|     def test_reentrant_with_callbacks_both_depths(self):
 | |
|         # Verify callback is called twice.
 | |
|         ret = self._test_reentrant_with_callbacks([0, 1])
 | |
|         self.assertEqual(1, ret["outer"])
 | |
|         self.assertEqual(1, ret["inner"])
 | |
| 
 | |
|     def test_reentrant_with_leaf_variable_hook(self):
 | |
|         handle = None
 | |
|         param = torch.rand(10, requires_grad=True)
 | |
| 
 | |
|         def add_gradient_penalty_to_grad(grad):
 | |
|             handle.remove()
 | |
|             old_param_grad = grad
 | |
|             param.grad = None
 | |
|             # Add some sort of gradient penalty by directly updating the gradients
 | |
|             with torch.enable_grad():
 | |
|                 g = grad.detach().requires_grad_()
 | |
|                 new_param = param.detach().requires_grad_()
 | |
|                 out = ((g * 2) + new_param).sum()
 | |
|                 out.backward()
 | |
|             res = g.grad + grad
 | |
|             param.grad = old_param_grad
 | |
|             return res
 | |
| 
 | |
|         handle = param.register_hook(add_gradient_penalty_to_grad)
 | |
|         # Forward pass
 | |
|         tmp = (param * param)
 | |
|         loss = tmp.sum()
 | |
|         # Compute the gradients
 | |
|         loss.backward()
 | |
| 
 | |
|     def test_reentrant_with_non_leaf_variable_hook(self):
 | |
|         handle = None
 | |
|         param = torch.rand(10, requires_grad=True)
 | |
| 
 | |
|         def manual_increase_gradient(grad):
 | |
|             handle.remove()
 | |
|             # Add some sort of gradient penalty by directly updating the gradients
 | |
|             with torch.enable_grad():
 | |
|                 g = grad.detach().requires_grad_()
 | |
|                 out = ((g * 2) + 5).sum()
 | |
|                 out.backward()
 | |
|             res = g.grad + grad
 | |
|             return res
 | |
| 
 | |
|         # Forward pass
 | |
|         tmp = (param * param)
 | |
|         handle = tmp.register_hook(manual_increase_gradient)
 | |
|         loss = tmp.sum()
 | |
|         # Compute the gradients
 | |
|         loss.backward()
 | |
|         self.assertEqual(param.grad, 6 * param)
 | |
| 
 | |
|     def test_grad_fn_attr_bindings(self):
 | |
|         # Check that the getter of each type returns what we want
 | |
|         # See `gen_autograd_functions.py` for how the getters are generated
 | |
|         #
 | |
|         # This test is only meant to check if the codegen'd bindings work
 | |
|         # Please help update this test if you update the names of any the fields we check!
 | |
|         #
 | |
|         a = torch.ones(1, requires_grad=True)
 | |
|         b = torch.zeros(1, requires_grad=True)
 | |
|         out1 = torch.stack([a, b], dim=0)
 | |
|         out2 = (a * 2) * b
 | |
|         # TODO: I don't think we have a backward saving a list of tensors
 | |
|         #       at the moment. It used to be stack, but for no reason...
 | |
|         #       see discussion in #84993
 | |
|         # self.assertEqual(out.grad_fn._saved_tensors, (a, b))              # TewnsorList -> Tuple[Tensor]
 | |
|         self.assertEqual(out2.grad_fn._saved_self, a * 2)
 | |
|         self.assertIsInstance(out2.grad_fn._saved_self, torch.Tensor)
 | |
|         self.assertIsInstance(out2.grad_fn._raw_saved_self, torch._C._autograd.SavedTensor)
 | |
|         self.assertEqual(out1.grad_fn._saved_dim, 0)                       # int64_t -> int
 | |
|         self.assertIsInstance(out1.grad_fn._saved_dim, int)
 | |
| 
 | |
|         out2.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
 | |
| 
 | |
|         out2.sum().backward()
 | |
|         with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
 | |
|             out2.grad_fn._saved_self
 | |
|         # TODO: interestingly, this only happens if indexing into a list grad_fn._raw_saved_tensors[0],
 | |
|         #       not when using a saved tensor, see discussion in #84993
 | |
|         # with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
 | |
|         #     out2.grad_fn._raw_saved_self
 | |
|         self.assertEqual(out1.grad_fn._saved_dim, 0)
 | |
| 
 | |
|         a = torch.ones(2, 2, requires_grad=True)
 | |
|         indices = torch.tensor([0, 1])
 | |
|         out = a[:, indices]
 | |
|         self.assertEqual(out.grad_fn._saved_indices, (None, indices))     # c10::List<c10::optional<Tensor>> -> Tuple[Tensor?]
 | |
|         self.assertIsInstance(out.grad_fn._saved_indices[1], torch.Tensor)
 | |
|         self.assertIsInstance(out.grad_fn._raw_saved_indices[1], torch._C._autograd.SavedTensor)
 | |
|         self.assertEqual(out.grad_fn._saved_self_sym_sizes, a.shape)          # SymIntArrayRef -> Tuple[SymInt]
 | |
|         self.assertIsInstance(out.grad_fn._saved_self_sym_sizes[0], int)
 | |
| 
 | |
|         out.grad_fn._raw_saved_indices[1].register_hooks(lambda x: x, lambda x: x)
 | |
|         with self.assertRaisesRegex(RuntimeError, "None is forbidden"):
 | |
|             out.grad_fn._raw_saved_indices[0].register_hooks(lambda x: x, lambda x: x)
 | |
| 
 | |
|         out = a.mean()
 | |
|         self.assertEqual(out.grad_fn._saved_self_sym_sizes, a.shape)          # IntArrayRef -> Tuple[int]
 | |
| 
 | |
|         a = torch.ones(2, 2, requires_grad=True)
 | |
|         out = a * a
 | |
|         out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
 | |
|         out.sum().backward()
 | |
|         with self.assertRaisesRegex(RuntimeError, "after it has been freed"):
 | |
|             out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
 | |
| 
 | |
|         a = torch.ones(1, 1, 2, requires_grad=True)
 | |
|         out = torch.nn.functional.interpolate(a, 4, mode="linear")
 | |
|         self.assertEqual(out.grad_fn._saved_output_size, (4,))            # c10::optional<IntArrayRef> -> int[]?
 | |
|         self.assertIsInstance(out.grad_fn._saved_output_size[0], int)
 | |
|         self.assertEqual(out.grad_fn._saved_align_corners, False)         # bool -> bool
 | |
|         self.assertIsInstance(out.grad_fn._saved_align_corners, bool)
 | |
|         if hasattr(out.grad_fn, '_saved_scale_factors'):
 | |
|             self.assertIsNone(out.grad_fn._saved_scale_factors)           # c10::optional<ArrayRef<double>> -> float[]?
 | |
|         else:
 | |
|             self.assertIsNone(out.grad_fn._saved_scales)                  # c10::optional<ArrayRef<double>> -> float[]?
 | |
| 
 | |
|         a = torch.ones(1, 1, 3, 3, requires_grad=True)
 | |
|         out = nn.Conv2d(1, 1, 3)(a)
 | |
|         self.assertEqual(out.grad_fn._saved_bias_sym_sizes_opt, (1,))     # c10::optional<SymIntArrayRef> -> SymInt[]?
 | |
|         out = nn.Conv2d(1, 1, 3, bias=False)(a)
 | |
|         # TODO: This is BAD! we converted a c10::nullopt into a (0,)
 | |
|         self.assertEqual(out.grad_fn._saved_bias_sym_sizes_opt, (0,))
 | |
| 
 | |
|         a = torch.ones(1, 3, 3, requires_grad=True)
 | |
|         out = torch.addbmm(a.squeeze(0), a, a)
 | |
|         self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_0, 1)      # int64_t
 | |
|         self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_1, 3)      # int64_t
 | |
| 
 | |
|         a = torch.ones(1, 1, 3, 3, requires_grad=True)
 | |
|         out = torch.nn.functional.unfold(a, 3)
 | |
|         self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_2, 3)  # SymInt
 | |
|         self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_1, 3)  # SymInt
 | |
| 
 | |
|         a = torch.ones(1, 1, 2, requires_grad=True)
 | |
|         out = torch.nn.functional.interpolate(a, scale_factor=0.5, mode="linear")
 | |
|         self.assertEqual(out.grad_fn._saved_scales, 0.5)
 | |
| 
 | |
|         a = torch.ones(2, 2, requires_grad=True)
 | |
|         out = torch.pdist(a, p=1)
 | |
|         self.assertEqual(out.grad_fn._saved_p, 1.)                        # double -> float
 | |
|         self.assertIsInstance(out.grad_fn._saved_p, float)
 | |
| 
 | |
|         a = torch.ones(1, 1, 2, requires_grad=True)
 | |
|         out = torch.logit(a, 1.)
 | |
|         self.assertEqual(out.grad_fn._saved_eps, 1.)                      # c10:optional<double> -> float?
 | |
|         self.assertIsInstance(out.grad_fn._saved_eps, float)
 | |
|         out = torch.logit(a)
 | |
|         self.assertIsNone(out.grad_fn._saved_eps)
 | |
| 
 | |
|         if torch._C.has_lapack:
 | |
|             a = torch.ones(1, 1, requires_grad=True)
 | |
|             q, r = torch.linalg.qr(a, mode="reduced")
 | |
|             self.assertEqual(q.grad_fn._saved_mode, "reduced")                # std::string -> str
 | |
| 
 | |
|         a = torch.tensor([1.], requires_grad=True)
 | |
|         out = torch.div(a, 2., rounding_mode="trunc")
 | |
|         self.assertEqual(out.grad_fn._saved_rounding_mode, "trunc")       # c10::optional<std::string> -> str?
 | |
|         out = torch.div(a, 2., rounding_mode=None)
 | |
|         self.assertIsNone(out.grad_fn._saved_rounding_mode)               # c10::optional<std::string> -> str?
 | |
| 
 | |
|         x = torch.zeros(5, requires_grad=True)
 | |
|         out = torch.threshold(x, threshold=(1 + 0j), value=(1 + 0j))
 | |
|         self.assertIsInstance(out.grad_fn._saved_threshold, complex)      # Scalar(complex double) -> complex
 | |
|         cfloat = torch.tensor(1 + 0j, dtype=torch.complex64)
 | |
|         out = torch.threshold(x, threshold=cfloat, value=(1 + 0j))
 | |
|         self.assertIsInstance(out.grad_fn._saved_threshold, complex)      # Scalar(complex float) -> complex
 | |
|         out = torch.threshold(x, threshold=1., value=1.)
 | |
|         self.assertIsInstance(out.grad_fn._saved_threshold, float)        # Scalar(floating point) -> float
 | |
|         out = torch.threshold(x, threshold=1, value=1)
 | |
|         self.assertIsInstance(out.grad_fn._saved_threshold, int)          # Scalar(integral) -> int
 | |
|         out = torch.threshold(x, threshold=False, value=False)
 | |
|         self.assertIsInstance(out.grad_fn._saved_threshold, bool)         # Scalar(bool) -> bool
 | |
| 
 | |
|         a = torch.ones(2, 2, requires_grad=True)
 | |
|         out = a.as_strided((3,), (1,), 1)
 | |
|         self.assertEqual(out.grad_fn._saved_storage_offset, 1)            # c10:optional<int64_t> -> int?
 | |
|         self.assertIsInstance(out.grad_fn._saved_storage_offset, int)
 | |
|         out = a.as_strided((3,), (1,))
 | |
|         self.assertIsNone(out.grad_fn._saved_storage_offset)
 | |
| 
 | |
|         a = torch.ones(2, requires_grad=True)
 | |
|         out = torch.tanh(a)
 | |
|         self.assertEqual(out, out.grad_fn._saved_result)                  # saved variable when output
 | |
| 
 | |
|         a = torch.randn(3, 5, requires_grad=True)
 | |
|         b = torch.tensor([1, 0, 4])
 | |
|         loss = nn.NLLLoss()
 | |
|         out = loss(a, b)
 | |
|         self.assertIsNone(out.grad_fn._saved_weight)
 | |
|         loss = nn.NLLLoss(weight=torch.ones((5,)))
 | |
|         out = loss(a, b)
 | |
|         self.assertEqual(out.grad_fn._saved_weight, torch.ones((5,)))     # c10:optional<Tensor> -> Tensor?
 | |
| 
 | |
|         out.sum().backward()
 | |
|         with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
 | |
|             out.grad_fn._saved_weight
 | |
| 
 | |
|         num_tensors = 3
 | |
|         input_tensors = [torch.ones(2, 2, requires_grad=True) for _ in range(num_tensors)]
 | |
|         scalars = [0.0 for _ in range(num_tensors)]                       # ArrayRef<Scalar> -> Tuple[Scalar, ...]
 | |
|         results = torch._foreach_maximum(input_tensors, scalars)
 | |
|         for t in results:
 | |
|             self.assertEqual(t.grad_fn._saved_scalars, scalars)
 | |
| 
 | |
| 
 | |
|     def test_cant_create_saved_tensors(self):
 | |
|         with self.assertRaisesRegex(RuntimeError, "Trying to create a SavedTensor object from Python is forbidden"):
 | |
|             torch.autograd.SavedTensor()
 | |
| 
 | |
|     def test_custom_function_saved_tensors(self):
 | |
|         def getFn(save=True):
 | |
|             class MyFn(Function):
 | |
|                 @staticmethod
 | |
|                 def forward(ctx, x):
 | |
|                     if save:
 | |
|                         ctx.save_for_backward(x, None)
 | |
|                     return x
 | |
| 
 | |
|                 @staticmethod
 | |
|                 def backward(ctx, g):
 | |
|                     return g
 | |
| 
 | |
|             return MyFn
 | |
| 
 | |
|         a = torch.randn(5, requires_grad=True)
 | |
| 
 | |
|         y = getFn(True).apply(a)
 | |
| 
 | |
|         self.assertEqual((a, None), y.grad_fn.saved_tensors)
 | |
|         saved = y.grad_fn._raw_saved_tensors
 | |
|         self.assertIsInstance(saved[0], torch._C._autograd.SavedTensor)
 | |
|         # We can't tell the underlying tensor is None without unpacking it
 | |
|         self.assertIsInstance(saved[1], torch._C._autograd.SavedTensor)
 | |
| 
 | |
|         # We catch that error when the user calls register_hooks on it
 | |
|         with self.assertRaisesRegex(RuntimeError, "None is forbidden"):
 | |
|             saved[1].register_hooks(lambda x: x, lambda x: x)
 | |
| 
 | |
|         with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
 | |
|             saved[0].register_hooks(lambda x: x)
 | |
|         with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
 | |
|             saved[0].register_hooks(1, 1)
 | |
|         saved[0].register_hooks(lambda x: x, lambda x: x)
 | |
|         with self.assertRaisesRegex(RuntimeError, "already been set"):
 | |
|             saved[0].register_hooks(lambda x: x, lambda x: x)
 | |
|         y.sum().backward()
 | |
| 
 | |
|         # Using a reference to the SavedTensor object after the
 | |
|         # saved variables have been released can lead to undefined behavior
 | |
|         del saved
 | |
|         with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
 | |
|             y.grad_fn._raw_saved_tensors
 | |
|         with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
 | |
|             y.grad_fn.saved_tensors
 | |
| 
 | |
|         y = getFn(False).apply(a)
 | |
|         self.assertEqual(y.grad_fn.saved_tensors, ())
 | |
|         self.assertEqual(y.grad_fn._raw_saved_tensors, ())
 | |
| 
 | |
|     def test_autograd_node_isinstance(self):
 | |
|         # Node is a "virtual" base class of codegen'd nodes. This means that
 | |
|         # isinstance and issubclass are overridden, but mro is unchanged
 | |
|         Node = torch.autograd.graph.Node
 | |
| 
 | |
|         a = torch.rand(3, 3, requires_grad=True)
 | |
|         b = a.exp()
 | |
| 
 | |
|         # Some nodes have codegened registrations to the torch._C._function module
 | |
|         self.assertIsInstance(b.grad_fn, Node)
 | |
|         self.assertTrue(issubclass(type(b.grad_fn), Node))
 | |
|         self.assertTrue(Node not in type(b.grad_fn).mro())
 | |
| 
 | |
|         # Other nodes have manual registrations to the torch._C._function module
 | |
|         self.assertNotIsInstance(torch._C._functions.AccumulateGrad, Node)
 | |
|         self.assertTrue(issubclass(torch._C._functions.AccumulateGrad, Node))
 | |
|         self.assertIsInstance(b.grad_fn.next_functions[0][0], Node)
 | |
|         self.assertTrue(issubclass(torch._C._functions.DelayedError, Node))
 | |
| 
 | |
|         # Special cases
 | |
|         self.assertNotIsInstance(None, Node)
 | |
|         self.assertNotIsInstance(1, Node)
 | |
|         self.assertNotIsInstance(Node, Node)
 | |
|         self.assertTrue(issubclass(Node, Node))
 | |
| 
 | |
|         # Custom function case
 | |
|         self.assertTrue(issubclass(torch.autograd.function.BackwardCFunction, Node))
 | |
| 
 | |
|         class Func(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 self.assertIsInstance(ctx, Node)
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, x):
 | |
|                 self.assertIsInstance(ctx, Node)
 | |
|                 return x
 | |
| 
 | |
|         out = Func.apply(a)
 | |
|         self.assertIsInstance(out.grad_fn, Node)
 | |
|         self.assertTrue(issubclass(type(out.grad_fn), Node))
 | |
|         self.assertTrue(Node not in type(out.grad_fn).mro())
 | |
|         out.sum().backward()
 | |
| 
 | |
|     def test_autograd_views_codegen(self):
 | |
|         # This is not necessarily the absolute correct behavior, but this is the current
 | |
|         # one. This test is here to make sure that any change to this behavior is detected
 | |
|         # and not silent. The TODOs below mark the places with unexpected behavior.
 | |
|         # Note that any change in these test will be BC-breaking and should be done carefully.
 | |
| 
 | |
|         # This test checks the behavior of two codegen functions (view_as and unbind)
 | |
|         # with respect to view tracking and inplace operation on the output.
 | |
| 
 | |
|         def run_test(grad_mode, requires_grad, is_view, should_raise_tuple):
 | |
|             def maybe_check_raise(fn, should_raise):
 | |
|                 self.assertTrue(should_raise is None or isinstance(should_raise, str))
 | |
|                 if should_raise is not None:
 | |
|                     with self.assertRaisesRegex(RuntimeError, should_raise):
 | |
|                         fn()
 | |
|                 else:
 | |
|                     fn()
 | |
| 
 | |
|             inp = torch.rand(2, requires_grad=requires_grad).clone()
 | |
|             with torch.set_grad_enabled(grad_mode):
 | |
|                 out = inp.view_as(inp)
 | |
|             # Are they differentiable views?
 | |
|             self.assertTrue(out._is_view() == is_view)
 | |
|             # Are inplace allowed?
 | |
|             maybe_check_raise(lambda: out.add_(1), should_raise_tuple[0])
 | |
| 
 | |
|             inp = torch.rand(2, requires_grad=requires_grad).clone()
 | |
|             with torch.set_grad_enabled(grad_mode):
 | |
|                 out = inp.unbind()
 | |
|             # Are they differentiable views?
 | |
|             self.assertTrue(out[0]._is_view() == is_view)
 | |
|             self.assertTrue(out[1]._is_view() == is_view)
 | |
|             # Are inplace allowed?
 | |
|             maybe_check_raise(lambda: out[0].add_(1), should_raise_tuple[1])
 | |
|             maybe_check_raise(lambda: out[1].add_(1), should_raise_tuple[2])
 | |
| 
 | |
|         # should_raise contains None if it should not raise
 | |
|         # should_raise contains a string of the error if it should raise
 | |
|         # The 3 elements are for view_as, first output of unbind and second output of unbind
 | |
|         run_test(grad_mode=True, requires_grad=False, is_view=True,
 | |
|                  should_raise_tuple=(None, None, None))
 | |
|         inp_change_err = "Output {} of UnbindBackward0 is a view and is being modified inplace."
 | |
|         run_test(grad_mode=True, requires_grad=True, is_view=True,
 | |
|                  should_raise_tuple=(None, inp_change_err.format("0"), inp_change_err.format("1")))
 | |
|         leaf_grad_err = "A view was created in no_grad mode and is being modified inplace"
 | |
|         run_test(grad_mode=False, requires_grad=True, is_view=True,
 | |
|                  should_raise_tuple=(leaf_grad_err, leaf_grad_err, leaf_grad_err))
 | |
|         run_test(grad_mode=False, requires_grad=False, is_view=True,
 | |
|                  should_raise_tuple=(None, None, None))
 | |
| 
 | |
|     def test_inplace_not_requires_grad(self):
 | |
|         class MyFn(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, inp):
 | |
|                 return inp.view_as(inp)
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad):
 | |
|                 return grad
 | |
| 
 | |
|         # Original Tensor does not require grad
 | |
|         a = torch.rand(1, 2)
 | |
| 
 | |
|         # Tensor being written does require grad
 | |
|         b = torch.rand(1, requires_grad=True)
 | |
| 
 | |
|         # Take an invalid view on 'a' that should raise an error (warns during deprecation)
 | |
|         view_a = MyFn.apply(a)
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "This view was created inside a custom Function"):
 | |
|             view_a += b
 | |
| 
 | |
|         # Extra test for copy_ that is a manual implementation and could be easily
 | |
|         # forgotten when the codegen is updated (warns during deprecation)
 | |
|         a = torch.rand(1, 2)
 | |
|         b = torch.rand(1, requires_grad=True)
 | |
|         view_a = MyFn.apply(a)
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "This view was created inside a custom Function"):
 | |
|             view_a.copy_(b)
 | |
| 
 | |
|         # Functions that should throw must properly throw
 | |
|         a = torch.rand(1, 2)
 | |
|         b = torch.rand(1, requires_grad=True)
 | |
|         view_a = a.unbind()[0]
 | |
|         with self.assertRaisesRegex(RuntimeError, "This view is the output of a function that returns "
 | |
|                                                   "multiple views."):
 | |
|             view_a.copy_(b)
 | |
| 
 | |
|         # Sanity check that views that should work still work
 | |
|         a = torch.rand(1, 2)
 | |
|         b = torch.rand(1, requires_grad=True)
 | |
|         a.select(1, 0).copy_(b)
 | |
| 
 | |
|     def _do_test_autograd_simple_views_python(self, dtype):
 | |
|         # This is not necessarily the absolute correct behavior, but this is the current
 | |
|         # one. This test is here to make sure that any change to this behavior is detected
 | |
|         # and not silent. The TODOs below mark the places with unexpected behavior.
 | |
|         # Note that any change in these test will be BC-breaking and should be done carefully.
 | |
| 
 | |
|         # This checks the autograd.Function behavior when we return one or multiple outputs
 | |
|         # while one of these is an input, a view of an input or of a temporary tensor.
 | |
| 
 | |
|         # This indicator is used to track how many times the backward function was called
 | |
|         bw_called = [0]
 | |
|         # This indicator is used to check if the argument `ga` contains non-zero values
 | |
|         ga_nz = [False]
 | |
| 
 | |
|         class IdOneOutput(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, a, b, make_view):
 | |
|                 if make_view:
 | |
|                     a = a.narrow(0, 0, 2)
 | |
|                 else:
 | |
|                     a = a.clone()
 | |
|                 return a
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, ga):
 | |
|                 bw_called[0] += 1
 | |
|                 return ga, None, None
 | |
| 
 | |
|         class IdTwoOutput(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, a, b, make_view):
 | |
|                 if make_view:
 | |
|                     a = a.narrow(0, 0, 2)
 | |
|                 else:
 | |
|                     a = a.clone()
 | |
|                 return a, a + b
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, ga, gab):
 | |
|                 bw_called[0] += 1
 | |
|                 if ga.eq(0).all():
 | |
|                     ga_nz[0] = False
 | |
|                 else:
 | |
|                     ga_nz[0] = True
 | |
|                 return ga + gab, gab, None
 | |
| 
 | |
|         class ViewOfTemp(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, a, make_view):
 | |
|                 ctx.save_for_backward(a)
 | |
|                 if make_view:
 | |
|                     a = a.narrow(0, 0, 2)
 | |
|                 else:
 | |
|                     a = a.clone()
 | |
|                 b = a.clone()
 | |
|                 return b.select(0, 0)
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad):
 | |
|                 bw_called[0] += 1
 | |
|                 a, = ctx.saved_tensors
 | |
|                 res = torch.zeros_like(a)
 | |
|                 res.select(0, 0).copy_(grad)
 | |
|                 return res, None
 | |
| 
 | |
|         fn_id_to_inplace_on_view_err_msg = {
 | |
|             "one_output": ("Output 0 of IdOneOutputBackward is a view and is being "
 | |
|                            "modified inplace. This view was created inside a custom Function"),
 | |
|             "two_output": ("Output 0 of IdTwoOutputBackward is a view and is being modified inplace."
 | |
|                            " This view is the output of a function that returns multiple views."),
 | |
|             "view_of_temp": ("Output 0 of ViewOfTempBackward is a view and is being "
 | |
|                              "modified inplace. This view was created inside a custom Function")
 | |
|         }
 | |
| 
 | |
|         for fn_id in ["one_output", "two_output", "view_of_temp"]:
 | |
|             for inplace in [True, False]:
 | |
|                 for make_view in [True, False]:
 | |
|                     # Used for special casing the tests below
 | |
|                     output_is_a_view = (make_view or fn_id == "view_of_temp")
 | |
| 
 | |
|                     def fn(a, b):
 | |
|                         # never modify a, b inplace for gracheck
 | |
|                         a = a.clone()
 | |
|                         b = b.clone()
 | |
|                         if fn_id == "two_output":
 | |
|                             tmp1, tmp2 = IdTwoOutput.apply(a, b, make_view)
 | |
|                             if inplace:
 | |
|                                 tmp1 += 3
 | |
|                                 tmp2 += 3
 | |
|                             else:
 | |
|                                 tmp1 = tmp1 + 3
 | |
|                                 tmp2 = tmp2 + 3
 | |
|                             tmp = tmp1 * tmp2
 | |
|                         else:
 | |
|                             if fn_id == "one_output":
 | |
|                                 tmp = IdOneOutput.apply(a, b, make_view)
 | |
|                             else:
 | |
|                                 tmp = ViewOfTemp.apply(a + b, make_view)
 | |
|                             if inplace:
 | |
|                                 tmp += 3
 | |
|                             else:
 | |
|                                 tmp = tmp + 3
 | |
| 
 | |
|                         return tmp.sum()
 | |
| 
 | |
|                     a = torch.ones(2, dtype=dtype, requires_grad=True)
 | |
|                     b = torch.ones(2, dtype=dtype, requires_grad=True)
 | |
| 
 | |
|                     err_msg = fn_id_to_inplace_on_view_err_msg[fn_id]
 | |
| 
 | |
|                     if not inplace or not output_is_a_view:
 | |
|                         gradcheck(fn, (a, b), check_batched_grad=False)
 | |
| 
 | |
|                     # Was the custom backward called properly
 | |
|                     bw_called[0] = 0
 | |
|                     ga_nz[0] = True  # For the case where the backward is called
 | |
| 
 | |
|                     if inplace and output_is_a_view:
 | |
|                         with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                             fn(a, b)
 | |
|                     else:
 | |
|                         fn(a, b).abs().backward()
 | |
| 
 | |
|                     expected_called = 1
 | |
|                     expected_ga_nz = True
 | |
| 
 | |
|                     if output_is_a_view and inplace:
 | |
|                         expected_called = 0
 | |
| 
 | |
|                     self.assertTrue(bw_called[0] == expected_called)
 | |
|                     self.assertTrue(ga_nz[0] == expected_ga_nz)
 | |
| 
 | |
|     def test_autograd_simple_views_python(self):
 | |
|         self._do_test_autograd_simple_views_python(torch.double)
 | |
|         self._do_test_autograd_simple_views_python(torch.cdouble)
 | |
| 
 | |
|     def test_autograd_inplace_views_creation_meta(self):
 | |
|         # Tests creation_meta properly handled for inplace views
 | |
| 
 | |
|         class Func(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x.view_as(x)
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, x):
 | |
|                 return x
 | |
|         view_custom = Func.apply
 | |
| 
 | |
|         def run_test(fn, fn_type, grad_mode_view, grad_mode_iview, requires_grad, error1, error2):
 | |
|             # This test checks the behavior of inplace-view functions when
 | |
|             # the views are created in grad mode or not
 | |
|             base = torch.rand(2, 3, requires_grad=requires_grad).clone()
 | |
|             # 1. Create a view with `grad_mode=grad_mode_view`
 | |
|             with torch.set_grad_enabled(grad_mode_view):
 | |
|                 if fn_type == "multi_view":
 | |
|                     inp = base.unbind()[0]
 | |
|                 elif fn_type == "custom" :
 | |
|                     inp = view_custom(base)
 | |
|                 else:
 | |
|                     inp = base.view_as(base)
 | |
| 
 | |
|             # 2. Perform inplace view with `grad_mode=grad_mode_iview`
 | |
|             with torch.set_grad_enabled(grad_mode_iview):
 | |
|                 if error1 is not None:
 | |
|                     with self.assertRaisesRegex(RuntimeError, error1):
 | |
|                         fn(inp)
 | |
|                     return
 | |
|                 else:
 | |
|                     # If error is None, check that runs without error
 | |
|                     fn(inp)
 | |
|             # 3. Do inplace on the (new) view
 | |
|             if error2 is not None:
 | |
|                 with self.assertRaisesRegex(RuntimeError, error2):
 | |
|                     inp.add_(1)
 | |
|             else:
 | |
|                 # If error is None, check that runs without error
 | |
|                 inp.add_(1)
 | |
| 
 | |
|         no_grad_err = "A view was created in no_grad mode"
 | |
|         multi_view_err = "function that returns multiple views"
 | |
|         custom_err = "view was created inside a custom Function"
 | |
| 
 | |
|         def run_tests(fn):
 | |
|             for fn_type in ("normal", "multi_view", "custom"):
 | |
|                 for grad_mode_view in (True, False):
 | |
|                     for grad_mode_iview in (True, False):
 | |
|                         for requires_grad in (True, False):
 | |
|                             error1 = None  # expected error when we do inplace_view on original view
 | |
|                             error2 = None  # expected error when we do inplace on the resulting view
 | |
| 
 | |
|                             if requires_grad:
 | |
|                                 if not grad_mode_view and grad_mode_iview:
 | |
|                                     error1 = no_grad_err
 | |
|                                 if not grad_mode_view and not grad_mode_iview:
 | |
|                                     error2 = no_grad_err
 | |
| 
 | |
|                                 if fn_type == "multi_view":
 | |
|                                     if grad_mode_view and grad_mode_iview:
 | |
|                                         error1 = multi_view_err
 | |
|                                     if grad_mode_view and not grad_mode_iview:
 | |
|                                         error2 = multi_view_err
 | |
| 
 | |
|                                 if fn_type == "custom":
 | |
|                                     if grad_mode_view and grad_mode_iview:
 | |
|                                         error1 = custom_err
 | |
|                                     if grad_mode_view and not grad_mode_iview:
 | |
|                                         error2 = custom_err
 | |
| 
 | |
|                             run_test(fn, fn_type, grad_mode_view, grad_mode_iview, requires_grad, error1, error2)
 | |
| 
 | |
|         # This list was created by logging gen_inplace_or_view_type.py
 | |
|         #   detach_ is excluded for this test because it cannot be applied to
 | |
|         #   views and thus does not return a view
 | |
|         run_tests(lambda v: v.as_strided_((1, 0), (2, 2)))
 | |
|         run_tests(lambda v: v.transpose_(0, 0))
 | |
|         run_tests(lambda v: v.t_())
 | |
|         run_tests(lambda v: v.squeeze_(0))
 | |
|         run_tests(lambda v: v.unsqueeze_(0))
 | |
|         run_tests(lambda v: v.swapdims_(0, 0))
 | |
|         run_tests(lambda v: v.swapaxes_(0, 0))
 | |
| 
 | |
|     def test_autograd_inplace_view_of_view(self):
 | |
|         x = torch.zeros(2)
 | |
|         with torch.no_grad():
 | |
|             y = x.view(2)
 | |
|         y.requires_grad_(True)
 | |
|         z = y.view(2)
 | |
|         with self.assertRaisesRegex(RuntimeError, "a view of a view .* is being .* inside the no_grad block"):
 | |
|             z /= 2
 | |
| 
 | |
|         x = torch.zeros(2)
 | |
|         with torch.inference_mode():
 | |
|             y = x.view(2)
 | |
|         y.requires_grad_(True)
 | |
|         z = y.view(2)
 | |
|         with self.assertRaisesRegex(RuntimeError, "a view of a view .* is being .* inside the inference_mode"):
 | |
|             z /= 2
 | |
| 
 | |
|     # TODO This is not the correct behavior -
 | |
|     # See https://github.com/pytorch/pytorch/issues/49825#issuecomment-794466627
 | |
|     def test_autograd_inplace_views_cross_dtype(self):
 | |
|         # This test is here to make sure that any change to this behavior is detected
 | |
|         # and not silent. The TODOs below mark the places with unexpected behavior.
 | |
|         a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
 | |
|         a = a_orig.clone()
 | |
|         b = torch.view_as_real(a)
 | |
|         b = b.transpose(0, 1)
 | |
|         b += 1
 | |
|         b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
 | |
|         non_inplace_grad = a_orig.grad
 | |
| 
 | |
|         a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
 | |
|         a = a_orig.clone()
 | |
|         b = torch.view_as_real(a)
 | |
|         b.transpose_(0, 1)
 | |
|         b += 1
 | |
|         b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
 | |
|         inplace_grad = a_orig.grad
 | |
| 
 | |
|         # TODO: this is a bug!
 | |
|         # once this is fixed, it should have the transpose removed:
 | |
|         # self.assertEqual(non_inplace_grad, inplace_grad)
 | |
|         self.assertEqual(non_inplace_grad.T, inplace_grad)
 | |
| 
 | |
|     def test_autograd_multiple_views_python(self):
 | |
|         # This is not necessarily the absolute correct behavior, but this is the current
 | |
|         # one. This test is here to make sure that any change to this behavior is detected
 | |
|         # and not silent. The TODOs below mark the places with unexpected behavior.
 | |
|         # Note that any change in these test will be BC-breaking and should be done carefully.
 | |
| 
 | |
|         # This checks that multiples views in the forward are properly traced and how they
 | |
|         # behave with respect to inplace operations.
 | |
| 
 | |
|         # This indicator is used to track how many times the backward function was called
 | |
|         bw_called = [0]
 | |
| 
 | |
|         class ComplexView(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, a, idx):
 | |
|                 res = a.narrow(0, idx, 1)
 | |
|                 res = a.select(0, idx)
 | |
|                 ctx.save_for_backward(a)
 | |
|                 ctx.idx = idx
 | |
|                 return res
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad):
 | |
|                 bw_called[0] += 1
 | |
|                 a, = ctx.saved_tensors
 | |
|                 res = torch.zeros_like(a)
 | |
|                 res.select(0, ctx.idx).copy_(grad)
 | |
|                 return res, None
 | |
| 
 | |
|         a = torch.ones(2, requires_grad=True)
 | |
|         idx = 1
 | |
| 
 | |
|         bw_called[0] = 0
 | |
|         out = ComplexView.apply(a.clone(), idx)
 | |
|         out.sum().backward()
 | |
|         self.assertTrue(bw_called[0] == 1)
 | |
| 
 | |
|         out = ComplexView.apply(a.clone(), idx)
 | |
|         with self.assertRaisesRegex(RuntimeError,
 | |
|                                     "Output 0 of ComplexViewBackward is a view and is being modified inplace"):
 | |
|             out += 1
 | |
| 
 | |
    def test_autograd_python_custom_function_inplace(self):
        # This is not necessarily the absolute correct behavior, but this is the current
        # one. This test is here to make sure that any change to this behavior is detected
        # and not silent. The TODOs below mark the places with unexpected behavior.
        # Note that any change in these tests will be BC-breaking and should be done carefully.

        # This test checks custom autograd.Function that perform inplace operations

        # Counts how many times a custom backward below has run; reset before
        # each scenario that asserts on it.
        bw_called = [0]

        # I) Single output
        # Modifies its first input in place, marks it dirty and returns it.
        class MyAdder(Function):
            @staticmethod
            def forward(ctx, a, b):
                a.add_(b)
                ctx.mark_dirty(a)
                return a

            @staticmethod
            def backward(ctx, grad):
                bw_called[0] += 1
                return grad, grad


        a = torch.ones(2, requires_grad=True)
        b = torch.ones(2, requires_grad=True)

        # No extra inplace
        c = MyAdder.apply(a.clone(), b)
        c.sum().backward()
        self.assertTrue(bw_called[0] == 1)

        # With extra inplace on the output
        bw_called[0] = 0
        c = MyAdder.apply(a.clone(), b)
        c += 2
        c.sum().backward()
        self.assertTrue(bw_called[0] == 1)

        # The input is a view
        bw_called[0] = 0
        c = MyAdder.apply(a.clone().view_as(a), b)
        c.sum().backward()
        self.assertTrue(bw_called[0] == 1)

        # Should not give non-inputs to mark_dirty
        # Here the tensor marked dirty is `c`, computed inside forward, not an input.
        class MyAdderBad(Function):
            @staticmethod
            def forward(ctx, a, b):
                c = 3 * a
                c.add_(b)
                ctx.mark_dirty(c)
                return c

            @staticmethod
            def backward(ctx, grad):
                bw_called[0] += 1
                grad = 3 * grad
                return grad, grad

        a = torch.ones(2, requires_grad=True)
        b = torch.ones(2, requires_grad=True)

        # Applying the Function is expected to record exactly one warning
        with warnings.catch_warnings(record=True) as w:
            MyAdderBad.apply(a.clone(), b)
        self.assertEqual(len(w), 1)

        # II) Multiple outputs
        # Returns the in-place modified input together with a second, new tensor.
        class MyBadAdder(Function):
            @staticmethod
            def forward(ctx, a, b):
                a.add_(b)
                ctx.mark_dirty(a)
                return a, a + b

            @staticmethod
            def backward(ctx, ga, gab):
                bw_called[0] += 1
                return ga + gab, ga + gab

        # No extra inplace
        bw_called[0] = 0
        c, d = MyBadAdder.apply(a.clone(), b)
        (c * d).sum().backward()
        self.assertTrue(bw_called[0] == 1)

        # With extra inplace on the output
        bw_called[0] = 0
        c, d = MyBadAdder.apply(a.clone(), b)
        c += 2
        (c * d).sum().backward()
        self.assertTrue(bw_called[0] == 1)

        # The input is a view
        # Unlike the single-output case above, this is expected to raise.
        inplace_on_view_err = "your Function modifies inplace an input that is a view of another Tensor"
        with self.assertRaisesRegex(RuntimeError, inplace_on_view_err):
            c, d = MyBadAdder.apply(a.clone().view_as(a), b)

        # III) Inplace + other op
        # Marks `a` as dirty but returns a clone of it rather than `a` itself.
        class MyOutPlaceAdder(Function):
            @staticmethod
            def forward(ctx, a, b):
                a.add_(b)
                ctx.mark_dirty(a)
                return a.clone(), a + b

            @staticmethod
            def backward(ctx, ga, gab):
                bw_called[0] += 1
                return ga + gab, ga + 2 * gab

        # We don't reuse the input
        def fn(a, b):
            orig_a = a.clone().view_as(a)
            c, d = MyOutPlaceAdder.apply(orig_a, b)
            return (c * d).sum()

        # A tensor marked dirty that is not among the outputs must be rejected
        bad_mark_dirty_err = "Some elements marked as dirty during the forward method were not returned as output."
        with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err):
            fn(a, b)
 | |
| 
 | |
|     def test_custom_function_mark_dirty_not_differentiable(self):
 | |
|         def get_custom_fn(jvp_err):
 | |
|             class InplaceMul(torch.autograd.Function):
 | |
|                 @staticmethod
 | |
|                 def forward(ctx, x):
 | |
|                     result = x.mul_(2)
 | |
|                     ctx.mark_dirty(result)
 | |
|                     return result
 | |
| 
 | |
|                 @staticmethod
 | |
|                 def backward(ctx, grad_output):
 | |
|                     pass
 | |
| 
 | |
|                 @staticmethod
 | |
|                 def jvp(ctx, x_t):
 | |
|                     if jvp_err:
 | |
|                         return x_t
 | |
|                     else:
 | |
|                         return x_t.mul_(2)
 | |
|             return InplaceMul
 | |
| 
 | |
|         for requires_grad, jvp_err in product([True, False], repeat=2):
 | |
|             InplaceMul = get_custom_fn(jvp_err)
 | |
|             # Make sure that tensor is always returned as-is if marked dirty
 | |
|             z = torch.tensor(1., requires_grad=requires_grad)
 | |
|             x = z.clone()
 | |
|             y = InplaceMul.apply(x)
 | |
|             self.assertTrue(x is y)
 | |
|             self.assertEqual(x, z * 2)
 | |
| 
 | |
|             # jvp must properly modify the input grad if mark_dirty is set
 | |
|             with fwAD.dual_level():
 | |
|                 x_tangent = torch.ones_like(x)
 | |
|                 x_dual = fwAD.make_dual(x, x_tangent)
 | |
| 
 | |
|                 if jvp_err:
 | |
|                     bad_mark_dirty_err = "jvp function must modify the corresponding gradient inplace"
 | |
|                     with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err):
 | |
|                         InplaceMul.apply(x_dual)
 | |
|                 else:
 | |
|                     out_dual = InplaceMul.apply(x_dual)
 | |
|                     _, out_tangent = fwAD.unpack_dual(out_dual)
 | |
|                     self.assertTrue(out_dual is x_dual)
 | |
|                     self.assertTrue(out_tangent is x_tangent)
 | |
| 
 | |
|     def test_named_tensor_for_complex_views(self):
 | |
|         names = ["batch", "height", "width", "complex"]
 | |
|         z = torch.ones((2, 1, 2, 2), requires_grad=True)
 | |
|         z_named = z.refine_names(*names)
 | |
|         z_complex = torch.view_as_complex(z_named.rename(None)).refine_names(*names[:-1])
 | |
|         z_complex.sum().abs().backward()
 | |
|         expected = torch.ones_like(z_complex).rename(None)
 | |
|         abs_1_1j = abs(1 + 1j)
 | |
|         expected.fill_(complex(abs_1_1j / 2, abs_1_1j / 2))
 | |
|         self.assertEqual(z.grad, torch.view_as_real(expected))
 | |
| 
 | |
|     def test_custom_function_return_view_in_nograd(self):
 | |
|         class Alias(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x[:]
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gx):
 | |
|                 return gx
 | |
| 
 | |
|         inp = torch.rand(2, requires_grad=True)
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             output = Alias.apply(inp)
 | |
| 
 | |
|         with torch.no_grad():
 | |
|             expected_output = inp[:]
 | |
| 
 | |
|         # Calling the custom function should operate as if we called an equivalent op
 | |
|         self.assertEqual(output.requires_grad, expected_output.requires_grad)
 | |
| 
 | |
|         # Check that in-place modification on view throws
 | |
|         leaf_grad_err = "A view was created in no_grad mode and is being modified inplace"
 | |
|         with self.assertRaisesRegex(RuntimeError, leaf_grad_err):
 | |
|             output.zero_()
 | |
| 
 | |
|     def test_grad_mode_restored_reentrant(self):
 | |
|         class MyFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, inp):
 | |
|                 return inp.clone()
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, go):
 | |
|                 original = torch._C.is_grad_enabled()
 | |
|                 with torch.enable_grad():
 | |
|                     self.assertTrue(torch._C.is_grad_enabled())
 | |
|                     foo = torch.rand(go.size(), requires_grad=True)
 | |
|                     grad, = torch.autograd.grad(
 | |
|                         foo ** 3, foo, grad_outputs=go
 | |
|                     )
 | |
|                     self.assertTrue(torch._C.is_grad_enabled())
 | |
|                 self.assertTrue(torch._C.is_grad_enabled() == original)
 | |
|                 return grad
 | |
| 
 | |
|         inp = torch.rand(3, requires_grad=True)
 | |
| 
 | |
|         # Case where original==False
 | |
|         MyFunction.apply(inp).sum().backward()
 | |
|         # Case where original==True
 | |
|         MyFunction.apply(inp).sum().backward(create_graph=True)
 | |
| 
 | |
|     def test_power_function(self):
 | |
|         a = torch.tensor([0., 0., 0.])
 | |
|         b = torch.tensor([-1., 0., 1.], requires_grad=True)
 | |
|         c = torch.sum(a**b)
 | |
|         c.backward()
 | |
|         self.assertEqual(b.grad, torch.tensor([-inf, 0., 0.]))
 | |
| 
 | |
|         s = 0
 | |
|         b = torch.tensor([-1., 0., 1.], requires_grad=True)
 | |
|         c = torch.sum(s**b)
 | |
|         c.backward()
 | |
|         self.assertEqual(b.grad, torch.tensor([-inf, 0., 0.]))
 | |
| 
 | |
|     def test_custom_function_error(self):
 | |
|         class BadFw(Function):
 | |
|             @staticmethod
 | |
|             def backward(ctx, foo):
 | |
|                 return foo
 | |
| 
 | |
|         class BadBw(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, foo):
 | |
|                 return foo.clone()
 | |
| 
 | |
|         class BadBw2(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, foo):
 | |
|                 return foo.clone()
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, foo):
 | |
|                 return foo
 | |
| 
 | |
|             @staticmethod
 | |
|             def vjp(ctx, foo):
 | |
|                 return foo
 | |
| 
 | |
|         class BadJvp(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, foo):
 | |
|                 return foo.clone()
 | |
| 
 | |
|         inp = torch.rand(1, requires_grad=True)
 | |
|         with self.assertRaisesRegex(NotImplementedError, "must implement the forward"):
 | |
|             BadFw.apply(inp)
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "must implement either the backward"):
 | |
|             BadBw.apply(inp).sum().backward()
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "Implementing both 'backward' and 'vjp'"):
 | |
|             BadBw2.apply(inp).sum().backward()
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "must implement the jvp function"):
 | |
|             with fwAD.dual_level():
 | |
|                 d = fwAD.make_dual(inp, torch.rand_like(inp))
 | |
|                 res = BadJvp.apply(d)
 | |
| 
 | |
def test_custom_function_forward_mode_view_checks(self):
    # forward returns a view (narrow) of its input. Each key below selects a
    # way for jvp to build its result; the value is the error regex that
    # gradcheck is expected to raise for it (None: no error expected).
    expected_errors = {
        "ok": None,
        "not_a_view": "jvp is not returning a view",
        "not_a_view_of_inp": "jvp is not returning a view of the given",
        "not_a_view_of_inp_base": "jvp is not returning a view of the same base",
    }
    rewrapped_flags = ("not_a_view_of_inp", "not_a_view_of_inp_base")

    class ViewFn(Function):
        @staticmethod
        def forward(ctx, tensor, mode):
            ctx.flag = mode
            ctx.size = tensor.size()
            return tensor.narrow(0, 0, 2)

        @staticmethod
        def vjp(ctx, grad_out):
            # Copy the incoming gradient into the first two rows of a zero
            # tensor that has the size recorded in forward.
            grad_in = grad_out.new_zeros(ctx.size)
            grad_in.narrow(0, 0, 2).copy_(grad_out)
            return grad_in, None

        @staticmethod
        def jvp(ctx, tangent, _):
            out = tangent.narrow(0, 0, 2)
            if ctx.flag != "ok":
                # The clone breaks the view relationship in the gradients.
                out = out.clone()
            if ctx.flag in rewrapped_flags:
                # A view again, but of the clone instead of the given tangent.
                out = out.view_as(out)
            return out

    inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True)

    for flag, msg in expected_errors.items():
        def test_fn(value):
            if flag == "not_a_view_of_inp_base":
                value = value.view_as(value)
            return ViewFn.apply(value, flag)

        if msg is None:
            expectation = contextlib.nullcontext()
        else:
            expectation = self.assertRaisesRegex(RuntimeError, msg)
        with expectation:
            gradcheck(test_fn, inp, check_forward_ad=True)
 | |
| 
 | |
def test_custom_function_forward_mode_inplace_checks(self):
    # forward marks its input dirty and multiplies it in place. jvp either
    # mirrors that in place on the tangent, or (when the flag is set) returns
    # a new tensor, which is expected to trigger the error checked below.
    class InplaceFn(Function):
        @staticmethod
        def forward(ctx, tensor, out_of_place_jvp):
            ctx.mark_dirty(tensor)
            ctx.flag = out_of_place_jvp
            tensor.mul_(2)
            return tensor

        @staticmethod
        def vjp(ctx, grad_out):
            return 2 * grad_out, None

        @staticmethod
        def jvp(ctx, tangent, _):
            if not ctx.flag:
                tangent.mul_(2)
                return tangent
            # Deliberately NOT done in place.
            return 2 * tangent

    inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True)

    def test_fn(value, flag):
        # Work on a clone so the in-place forward never touches the leaf.
        return InplaceFn.apply(value.clone(), flag)

    gradcheck(test_fn, (inp, False), check_forward_ad=True)

    expected_msg = "inplace custom Function is not modifying the forward mode gradients inplace"
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        gradcheck(test_fn, (inp, True), check_forward_ad=True)
 | |
| 
 | |
def test_custom_function_forward_mode_wrong_formula(self):
    # gradcheck(check_forward_ad=True) passes for a jvp that matches forward
    # and must report a forward-mode Jacobian mismatch for one that does not.
    class UserFn(Function):
        @staticmethod
        def forward(ctx, value, should_fail):
            ctx.should_fail = should_fail
            return value * 2

        @staticmethod
        def vjp(ctx, grad_out):
            return 2 * grad_out, None

        @staticmethod
        def jvp(ctx, tangent, _):
            # forward doubles its input: 2 is the right factor, 3 is the
            # intentionally wrong gradient formula.
            factor = 3 if ctx.should_fail else 2
            return factor * tangent

    inp = torch.rand(10, dtype=torch.double, requires_grad=True)
    gradcheck(UserFn.apply, (inp, False), check_forward_ad=True)

    mismatch_msg = "Jacobian computed with forward mode mismatch for output 0"
    with self.assertRaisesRegex(RuntimeError, mismatch_msg):
        gradcheck(UserFn.apply, (inp, True), check_forward_ad=True)
 | |
| 
 | |
def test_custom_function_forward_mode_non_tensor_before_tensor_args(self):
    # Non-tensor arguments are interleaved with tensor arguments; jvp asserts
    # that it receives None in the positions of the non-tensor inputs.
    class MyFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, nt, x, nt2, y):
            return x * 2 + y * 3

        @staticmethod
        def jvp(ctx, nt, x_t, nt2, y_t):
            for non_tensor_tangent in (nt, nt2):
                self.assertIsNone(non_tensor_tangent)
            return x_t * 2 + y_t * 3

    def scalar_one():
        return torch.tensor(1., dtype=torch.double)

    x, t, y = scalar_one(), scalar_one(), scalar_one()

    # Direct call under a dual level, with only x carrying a tangent.
    with fwAD.dual_level():
        dual_x = fwAD.make_dual(x, t)
        MyFn.apply(1, dual_x, 1, y)

    # Forward-mode-only gradcheck over both tensor inputs.
    x.requires_grad_(True)
    y.requires_grad_(True)
    gradcheck(
        MyFn.apply,
        (1, x, 1, y),
        check_forward_ad=True,
        check_backward_ad=False,
        check_batched_grad=False,
    )
 | |
| 
 | |
def test_custom_function_forward_mode_forward_is_no_op(self):
    error_regex = "A custom Function's forward is returning a view \\(or an input as-is\\)"

    # Ways for jvp to build the tangent of the output that forward returns
    # as-is. Returning an input as-is in forward is treated as if
    # self.view_as(self) was performed, so a jvp returning x.view_as(x) is OK;
    # the other two variants are expected to raise error_regex.
    tangent_builders = {
        "view_as": lambda x: x.view_as(x),
        "self": lambda x: x,
        "mul_by_2": lambda x: x * 2,
    }

    def double_scalar(**kwargs):
        return torch.tensor(1., dtype=torch.double, **kwargs)

    for variant, build_tangent in tangent_builders.items():
        accepted = variant == "view_as"

        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x + y, x

            @staticmethod
            def vjp(ctx, gO1, gO2):
                return gO1 + gO2, gO1

            @staticmethod
            def jvp(ctx, x_t, y_t):
                return x_t + y_t, build_tangent(x_t)

        # First set of inputs requires grad, second set does not.
        a = double_scalar(requires_grad=True)
        t = double_scalar()
        b = double_scalar(requires_grad=True)

        c = double_scalar()
        t2 = double_scalar()
        d = double_scalar()

        with fwAD.dual_level():
            a_dual = fwAD.make_dual(a, t)
            c_dual = fwAD.make_dual(c, t2)
            cases = ((a_dual, b, t), (c_dual, d, t2))

            for primal_dual, other, tangent in cases:
                if accepted:
                    _, out2 = MyFn.apply(primal_dual, other)
                    self.assertTrue(fwAD.unpack_dual(out2).tangent._base is tangent)
                else:
                    with self.assertRaisesRegex(RuntimeError, error_regex):
                        MyFn.apply(primal_dual, other)

        if accepted:
            gradcheck(MyFn.apply, (a, c), check_forward_ad=True)
        else:
            with self.assertRaisesRegex(RuntimeError, error_regex):
                gradcheck(MyFn.apply, (a, c), check_forward_ad=True)
 | |
| 
 | |
def test_custom_function_save_for_forward(self):
    # Part 1: tensors are saved both for backward and for forward; jvp and
    # vjp each read them back through ctx.saved_tensors.
    class Func(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
            ctx.save_for_backward(x, y)
            ctx.save_for_forward(x, y)
            ctx.z = z
            ctx.prod = x * y
            return z * ctx.prod

        @staticmethod
        def jvp(ctx, x_t, y_t, _):
            x_p, y_p = ctx.saved_tensors
            # Product rule for x * y, scaled by the constant z.
            return ctx.z * (y_p * x_t + x_p * y_t)

        @staticmethod
        def vjp(ctx, grad_out):
            x, y = ctx.saved_tensors
            scaled = ctx.z * grad_out
            return scaled * y, scaled * x, None

    primal = torch.tensor(1., requires_grad=True, dtype=torch.double)
    tangent = torch.tensor(1., dtype=torch.double)
    other = torch.tensor(2., requires_grad=True, dtype=torch.double)
    scale = 4

    with fwAD.dual_level():
        dual = fwAD.make_dual(primal, tangent)
        result = Func.apply(dual, other, scale)
        result.backward()

    gradcheck(Func.apply, (primal, other, scale), check_forward_ad=True)

    # Part 2: saved for backward only. jvp asserts it sees no saved tensors,
    # vjp asserts it sees exactly one.
    class Func(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor):
            ctx.save_for_backward(x)
            return x.clone()

        @staticmethod
        def jvp(ctx, x_t):
            self.assertEqual(len(ctx.saved_tensors), 0)
            return x_t

        @staticmethod
        def vjp(ctx, grad_out):
            (x,) = ctx.saved_tensors
            self.assertEqual(len(ctx.saved_tensors), 1)
            return grad_out

    with fwAD.dual_level():
        dual = fwAD.make_dual(primal, tangent)
        result = Func.apply(dual)
        result.backward()

    gradcheck(Func.apply, (primal,), check_forward_ad=True)
 | |
| 
 | |
def test_custom_function_forward_mode_non_differentiable(self):
    x = torch.tensor(2.)
    x_tangent = torch.tensor(1.)
    y_float = torch.tensor(3.)
    y_int = torch.tensor(3)

    # Case 1: the second output has a differentiable (float) dtype but is
    # explicitly marked non-differentiable; jvp returns None for it.
    class Func(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            out = y.clone()
            ctx.mark_non_differentiable(out)
            return x.clone(), out

        @staticmethod
        def jvp(ctx, x_tangent, y_tangent):
            return x_tangent, None

    with fwAD.dual_level():
        x_dual = fwAD.make_dual(x, x_tangent)
        _, out2_dual = Func.apply(x_dual, y_float)
        self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, None)

    # Case 2: the second output has an integer dtype and is NOT marked
    # non-differentiable; jvp asserts that its incoming tangent is None.
    class Func(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            return x.clone(), y.clone()

        @staticmethod
        def jvp(ctx, x_tangent, y_tangent):
            self.assertIsNone(y_tangent)
            return x_tangent, None

    with fwAD.dual_level():
        x_dual = fwAD.make_dual(x, x_tangent)
        _, out2_dual = Func.apply(x_dual, y_int)
        self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, None)

    # Case 3: jvp returns a tensor for an output that was marked
    # non-differentiable; this is expected to raise.
    class FuncWrong(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            out = y.clone()
            ctx.mark_non_differentiable(out)
            return x.clone(), out

        @staticmethod
        def jvp(ctx, x_tangent, y_tangent):
            return x_tangent, x_tangent.clone()

    with fwAD.dual_level():
        x_dual = fwAD.make_dual(x, x_tangent)
        with self.assertRaisesRegex(RuntimeError, "You should return None at that position instead"):
            FuncWrong.apply(x_dual, y_int)
 | |
| 
 | |
| 
 | |
def test_custom_function_local_inplace(self):
    # forward builds a slice of a local clone and optionally modifies it in
    # place before returning it. In both modes the returned tensor's grad_fn
    # is expected to be the custom Function's own backward node.
    class MyFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp, inplace):
            head = inp.clone()[:3]
            if inplace:
                head += 2
            return head

        @staticmethod
        def backward(ctx, grad):
            return grad, None

    base = torch.rand(10, requires_grad=True)

    for inplace in (False, True):
        result = MyFn.apply(base, inplace)
        self.assertEqual(result.grad_fn.__class__.__name__, "MyFnBackward")
 | |
| 
 | |
def test_integer_outputs(self):
    # Index/count style outputs must not require grad even when the inputs
    # they were computed from do.
    inp = torch.rand(4, requires_grad=True)

    def check_integer_no_grad(result):
        self.assertFalse(result.dtype.is_floating_point)
        self.assertFalse(result.requires_grad)

    for method_name in ("argmax", "argmin", "argsort"):
        check_integer_no_grad(getattr(inp, method_name)())

    val = torch.rand((), requires_grad=True)
    check_integer_no_grad(torch.searchsorted(inp, val))

    bins = torch.linspace(0, 1.0, steps=100, requires_grad=True)
    vals = torch.rand(5, 5, requires_grad=True)
    check_integer_no_grad(torch.bucketize(vals, bins))

    # For count_nonzero only the requires_grad flag is checked.
    val = torch.empty(5).requires_grad_()
    self.assertFalse(val.count_nonzero().requires_grad)

    def assert_only_first_requires_grad(res):
        outputs = res if isinstance(res, tuple) else (res,)
        self.assertTrue(outputs[0].requires_grad)
        for extra in outputs[1:]:
            if extra is not None:
                self.assertFalse(extra.requires_grad)

    for sort, return_inverse, return_counts in product([True, False], repeat=3):
        flags = dict(return_inverse=return_inverse, return_counts=return_counts)

        # Public API, with and without an explicit dim.
        assert_only_first_requires_grad(torch.unique(inp, sorted=sort, **flags))
        assert_only_first_requires_grad(torch.unique(inp, sorted=sort, dim=0, **flags))
        assert_only_first_requires_grad(torch.unique_consecutive(inp, **flags))
        assert_only_first_requires_grad(torch.unique_consecutive(inp, dim=0, **flags))

        # Here we test the internal functions to make sure all of them are
        # covered on top of the public API
        assert_only_first_requires_grad(
            torch._unique(inp, sorted=sort, return_inverse=return_inverse)
        )

        # This looks public but is actually manually deleted from the
        # torch namespace in torch/functional.py
        assert_only_first_requires_grad(
            torch._VF.unique_dim(inp, dim=0, sorted=sort, **flags)
        )

        # We don't test `unique_dim_consecutive` here.
        # It looks public but the python binding is actually manually disabled in
        # tools/autograd/gen_python_functions.py

        assert_only_first_requires_grad(torch._unique2(inp, sorted=sort, **flags))
 | |
| 
 | |
def test_custom_function_cycle(self):
    # The Function's ctx stores a dict which later receives the Function's
    # own output. The test checks that this output stays alive while gc is
    # disabled and is released once gc.collect() runs.
    class MyFn(Function):
        @staticmethod
        def forward(ctx, x, metadata):
            result = x.clone()
            ctx.meta = metadata
            ctx.save_for_backward(result)
            return result

        @staticmethod
        def backward(ctx, gO):
            (saved,) = ctx.saved_tensors
            self.assertEqual(saved, 3.14)
            self.assertEqual(ctx.meta["foo"], 3.14)
            return gO * saved, None

    def get_refs(with_backward):
        leaf = torch.tensor(3.14, requires_grad=True)

        metadata = {}
        out = MyFn.apply(leaf, metadata)

        # The dict held by ctx now points back at the output.
        metadata["foo"] = out

        if with_backward:
            out.sum().backward()
            self.assertEqual(leaf.grad, leaf)

        return torch._C._WeakTensorRef(out)

    # Second pass runs backward first: the backward clears the
    # saved_variables but not the __dict__
    for with_backward in (False, True):
        with disable_gc():
            ref = get_refs(with_backward)
            self.assertFalse(ref.expired())
        gc.collect()
        self.assertTrue(ref.expired())
 | |
| 
 | |
def test_create_graph_and_full_backward_hook_cycle(self):
    # Scenario being guarded against: if BackwardHook saved grad_output,
    # running backward with create_graph=True could create the cycle
    #
    #   grad_output -> grad_output.grad_fn -> graph -> hook -> grad_output
    #
    class TestCls:
        # Empty class whose only purpose is to be weakly referenceable.
        pass

    def get_ref(input_requires_grad, nb_hooks):
        inp = torch.randn(10, requires_grad=input_requires_grad)
        weight = torch.tensor(1., requires_grad=True)

        class Test(nn.Module):
            def forward(self, x):
                return x ** 2 * weight ** 2

        module = Test()
        for _ in range(nb_hooks):
            module.register_full_backward_hook(lambda mod, grad_in, grad_out: None)

        out = module(inp)

        # Attach a sentinel to the graph and only keep a weak ref to it.
        sentinel = TestCls()
        ref = weakref.ref(sentinel)
        out.grad_fn.metadata["a"] = sentinel

        with set_warn_always_context(True), warnings.catch_warnings(record=True) as caught:
            out.exp().sum().backward(create_graph=True)
            self.assertTrue(len(caught) == 1)
            self.assertTrue("Using backward() with create_graph=True" in str(caught[0].message))

        # Remove the backward + create_graph=True cycle
        weight.grad = None
        inp.grad = None

        return ref

    for nb_hooks, input_requires_grad in product((1, 2, 3), (True, False)):
        ref_ = get_ref(
            input_requires_grad=input_requires_grad,
            nb_hooks=nb_hooks,
        )
        gc.collect()
        self.assertIsNone(ref_())
 | |
| 
 | |
@parametrize("use_custom_function", [True, False])
@parametrize("use_tensor_hook", [True, False])
def test_hook_closure_cycle(self, use_custom_function, use_tensor_hook):
    # The hook's closure holds the python grad_fn object, giving the cycle
    #   hook -> closure -> grad_fn_b (python) -> grad_fn (cpp) -> hook (cpp)
    #   -> dict -> hook
    #
    # What is being tested: grad_fn_b (python) only traverses the dict if it
    # is the only one holding a reference to the grad_fn_b (cpp) shared_ptr.
    #
    # See: https://github.com/pytorch/pytorch/issues/102174
    class Function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, grad):
            return grad

    class Test:
        pass

    fired = [0]

    def scope():
        leaf = torch.tensor(1., requires_grad=True)
        mid = Function.apply(leaf) if use_custom_function else leaf.clone()
        node = mid.grad_fn
        sentinel = Test()

        def hook(*args):
            # Referencing `node` here is what makes the closure hold onto
            # the grad_fn and form the cycle. `sentinel` is referenced too,
            # so that a weakref to it tells whether the cycle is still alive.
            node
            sentinel
            fired[0] += 1

        if use_tensor_hook:
            mid.register_hook(hook)
        else:
            mid.grad_fn.register_hook(hook)
        out = mid.clone()
        ref = weakref.ref(sentinel)
        return out, ref

    with disable_gc():
        out, ref = scope()
        out.backward(retain_graph=True)

        gc.collect()

        # gc must not have cleared the cycle: the hook is still alive and
        # fires again on the second backward.
        out.backward(retain_graph=True)
        self.assertEqual(fired[0], 2)

        # ref is still alive because the use_count of the cpp grad_fn
        # shared_ptr > 1 since (1) the python grad_fn is alive, and (2) the
        # rest of the graph holds onto the shared_ptr
        self.assertIsNotNone(ref())

        # Drop the rest of the graph; now the sentinel must be collectable.
        del out
        gc.collect()
        self.assertIsNone(ref())
 | |
| 
 | |
def test_full_backward_hook_double_backward(self):
    # The full backward hook is counted after the first-order gradient and
    # again after differentiating that gradient; the second step must not
    # error and must not bump the counter.
    x = torch.rand(1, requires_grad=True)
    y = torch.rand_like(x)

    loss_fn = torch.nn.MSELoss()
    calls = [0]

    def hook(module, grad_input, grad_output):
        calls[0] += 1

    loss_fn.register_full_backward_hook(hook)
    loss = loss_fn(x, y)

    (grad_x,) = torch.autograd.grad(loss, x, create_graph=True)
    self.assertEqual(calls[0], 1)

    torch.autograd.grad(grad_x, x)
    self.assertEqual(calls[0], 1)
 | |
| 
 | |
def test_input_buffer_accum(self):
    # Two branches hang off the same leaf, so their gradients get accumulated
    # for it; the gradient tensor supplied by the caller must come out
    # unmodified.
    leaf = torch.rand(2, 2, requires_grad=True)

    # Branch requested with sparse_grad=True
    index = torch.tensor([[0, 0]], dtype=torch.long)
    sparse_branch = leaf.gather(0, index, sparse_grad=True)

    # Branch that returns the gradients as-is
    dense_branch = leaf.clone()

    dense_grad_reference = torch.rand_like(dense_branch)
    dense_grad = dense_grad_reference.clone()
    sparse_grad = torch.rand_like(sparse_branch)

    torch.autograd.backward(
        (dense_branch, sparse_branch),
        (dense_grad, sparse_grad),
    )

    # Given gradients should not be modified inplace
    self.assertEqual(dense_grad, dense_grad_reference)
 | |
| 
 | |
def test_no_unnecessary_unwrapping(self):
    leaf = torch.randn(5, requires_grad=True)
    leaf_snapshot = leaf.detach().clone()
    square = leaf * leaf
    cube = leaf * square
    exp_out = torch.exp(leaf)

    # The leaf comes back from the saved slots as the very same object.
    saved_leaves = (
        square.grad_fn._saved_self,
        square.grad_fn._saved_other,
        cube.grad_fn._saved_self,
    )
    for saved in saved_leaves:
        self.assertIs(saved, leaf)

    # `square` is saved by cube's node as an input (it is not that node's
    # output) and also comes back as the same object.
    self.assertIs(cube.grad_fn._saved_other, square)

    # `exp_out` is the output of the node that saves it: the unpacked value
    # is equal but is a different object.
    self.assertEqual(exp_out.grad_fn._saved_result, exp_out)
    self.assertIsNot(exp_out.grad_fn._saved_result, exp_out)

    cube.sum().backward()

    with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
        cube.grad_fn._saved_self

    # The leaf's values are left untouched.
    self.assertEqual(leaf, leaf_snapshot)
 | |
| 
 | |
def test_saved_variable_version_counter(self):
    # The tensor unpacked from the saved result must track both the value and
    # the version of the original, before and after an in-place update.
    inp = torch.rand(2, requires_grad=True)
    out = torch.exp(inp)
    unpacked = out.grad_fn._saved_result

    def assert_in_sync():
        self.assertEqual(out, unpacked)
        self.assertEqual(out._version, unpacked._version)

    assert_in_sync()

    with torch.no_grad():
        out += 1

    assert_in_sync()
 | |
| 
 | |
def test_saved_variable_packing_unpacking_saved_original_with_hooks(self):
    # Exercises user-defined pack/unpack hooks registered on a SavedVariable.
    # "saved_original" (vs. "did_not_save_original") refers to the
    # `save_original` attribute of `SavedVariable`.

    def test(get_input, is_leaf):
        def backward_and_check(inp, inp_grad_fn, out, root, grad_factor):
            # Non-leaf inputs: the unpacked tensor must keep the input's
            # grad_fn. Leaf inputs: check the accumulated gradient instead.
            if not is_leaf:
                self.assertIs(inp_grad_fn, out.grad_fn._saved_self.grad_fn)
            root.sum().backward()
            if is_leaf:
                self.assertEqual(grad_factor * inp, inp.grad)

        # pack doubles and unpack halves: the round trip is an identity
        inp = get_input()
        inp_grad_fn = inp.grad_fn
        out = inp * inp
        out.grad_fn._raw_saved_self.register_hooks(lambda x: 2 * x, lambda x: x / 2)
        self.assertEqual(inp, out.grad_fn._saved_self)
        backward_and_check(inp, inp_grad_fn, out, out, 2)

        # pack doubles and unpack does nothing: saved self reads back doubled
        inp = get_input()
        inp_grad_fn = inp.grad_fn
        out = inp * inp
        out.grad_fn._raw_saved_self.register_hooks(lambda x: 2 * x, lambda x: x)
        self.assertEqual(2 * inp, out.grad_fn._saved_self)
        backward_and_check(inp, inp_grad_fn, out, out, 3)

        # double backward
        inp = get_input()
        inp_grad_fn = inp.grad_fn
        out = inp ** 3
        out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
        total = torch.sum(out)
        (first_grad,) = torch.autograd.grad(total, (inp, ), create_graph=True)
        backward_and_check(inp, inp_grad_fn, out, first_grad, 6)

        # unpack hook returning a non-Tensor fails when the value is read
        inp = get_input()
        out = inp * inp
        out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: 1)
        with self.assertRaisesRegex(TypeError, "Output of saved tensor unpack_hook expected to be a Tensor"):
            print(out.grad_fn._saved_self)

        # pack hooks with a second required argument fail at registration
        inp = get_input()
        out = inp * inp
        with self.assertRaisesRegex(TypeError, "missing 1 required positional argument"):
            out.grad_fn._raw_saved_self.register_hooks(lambda x, b: x, lambda x: x)

        inp = get_input()
        out = inp * inp
        with self.assertRaisesRegex(TypeError, "missing 1 required positional argument"):
            out.grad_fn._raw_saved_self.register_hooks(lambda x, b: (x, b), lambda x: x)

        # pack hook that mutates its argument in place is rejected
        def inplace_double(x):
            x *= 2
            return x

        inp = get_input()
        out = inp * inp

        with self.assertRaisesRegex(RuntimeError, "A saved tensor pack hook is modifying its input in place."):
            out.grad_fn._raw_saved_self.register_hooks(inplace_double, lambda x: x / 2)

    # leaf
    test(lambda: torch.randn(5, requires_grad=True), True)

    # not leaf, not output
    test(lambda: (1 + torch.randn(5, requires_grad=True)), False)
 | |
| 
 | |
def test_saved_variable_saved_original_inplace_detach(self):
    def make_non_leaf():
        return torch.tensor(1., requires_grad=True).clone()

    # In-place detaching a tensor that was saved as an input: backward raises
    inp = make_non_leaf()
    out = inp.sin()
    inp.detach_()
    detached_msg = "Trying to use a saved tensor that has been detached"
    with self.assertRaisesRegex(RuntimeError, detached_msg):
        out.backward()

    # Same sequence with exp, where the saved tensor is the output: this is OK
    inp = make_non_leaf()
    out = inp.exp()
    inp.detach_()
    out.backward()
 | |
| 
 | |
def test_saved_variable_packing_unpacking_did_not_save_original_with_hooks(self):
    # User-defined pack/unpack hooks on a SavedVariable holding the result of
    # exp. "did_not_save_original" (vs. "saved_original") refers to the
    # `save_original` attribute of `SavedVariable`.

    def identity(x):
        return x

    inp = torch.randn(5, requires_grad=True)
    out = torch.exp(inp)
    out.grad_fn._raw_saved_result.register_hooks(identity, identity)

    # The unpacked result equals the output and carries the same grad_fn.
    self.assertEqual(out, out.grad_fn._saved_result)
    self.assertIs(out.grad_fn, out.grad_fn._saved_result.grad_fn)

    out.sum().backward()
    self.assertEqual(inp.grad, out)
 | |
| 
 | |
def test_saved_variable_packing_unpacking_saved_original_with_default_hooks(self):
    # Default (context-manager) saved tensor hooks must be registered, used
    # and reset properly. "saved_original" refers to the `save_original`
    # attribute of `SavedVariable`.
    # See also:
    #  - test_saved_variable_packing_unpacking_saved_original_with_hooks
    hooks = torch.autograd.graph.saved_tensors_hooks

    def identity(x):
        return x

    def double(x):
        return 2 * x

    def pack(x):
        warnings.warn("pack")
        return x

    def square_and_check(expected_saved, expected_grad):
        # Computes a * a on a fresh leaf, compares both saved operands with
        # expected_saved(a), runs backward and compares a.grad with
        # expected_grad(a).
        a = torch.randn(5, requires_grad=True)
        y = a * a
        self.assertEqual(expected_saved(a), y.grad_fn._saved_self)
        self.assertEqual(expected_saved(a), y.grad_fn._saved_other)
        y.sum().backward()
        self.assertEqual(expected_grad(a), a.grad)

    with hooks(pack, identity):
        a = torch.ones(5, requires_grad=True)

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            y = a * a
            # a is saved twice, so pack should have warned twice
            self.assertEqual(len(w), 2)

    # Identity hooks
    with hooks(identity, identity):
        square_and_check(identity, double)

    # Pack doubles, unpack halves
    with hooks(double, lambda x: x / 2):
        square_and_check(identity, double)

    # Pack doubles, unpack leaves the value doubled
    with hooks(double, identity):
        square_and_check(double, lambda a: 4 * a)

    # Exited hooks correctly
    square_and_check(identity, double)
 | |
| 
 | |
|     def test_saved_variable_packing_unpacking_did_not_save_original_with_default_hooks(self):
 | |
|         # See also test_saved_variable_packing_unpacking_did_not_save_original_with_hooks
 | |
| 
 | |
|         with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
 | |
|             a = torch.randn(5, requires_grad=True)
 | |
|             y = torch.exp(a)
 | |
|             self.assertEqual(y, y.grad_fn._saved_result)
 | |
|             y.sum().backward()
 | |
|             self.assertEqual(a.grad, y)
 | |
| 
 | |
|     def test_setting_default_saved_variable_hooks_twice_should_not_fail(self):
 | |
|         with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
 | |
|             with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
 | |
|                 pass
 | |
| 
 | |
|     def test_setting_default_saved_variable_hooks_twice_should_use_inner(self):
 | |
|         with torch.autograd.graph.saved_tensors_hooks(lambda x: 3 * x, lambda x: 3 * x):
 | |
|             b = torch.randn(5, requires_grad=True)
 | |
|             with torch.autograd.graph.saved_tensors_hooks(lambda x: 5 * x, lambda x: 5 * x):
 | |
|                 a = torch.randn(5, requires_grad=True)
 | |
|                 y = a * a
 | |
|             z = b * b
 | |
|         y.sum().backward()
 | |
|         z.sum().backward()
 | |
|         self.assertEqual(2 * 5 * 5 * a, a.grad)
 | |
|         self.assertEqual(2 * 3 * 3 * b, b.grad)
 | |
| 
 | |
|     def test_disabling_saved_tensor_hooks(self):
 | |
|         with torch.autograd.graph.disable_saved_tensors_hooks("error message"):
 | |
|             with self.assertRaisesRegex(RuntimeError, "error message"):
 | |
|                 with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
 | |
|                     pass
 | |
| 
 | |
|         self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
 | |
| 
 | |
|         with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
 | |
|             with self.assertRaisesRegex(RuntimeError, "error message"):
 | |
|                 with torch.autograd.graph.disable_saved_tensors_hooks("error message"):
 | |
|                     pass
 | |
| 
 | |
|         self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
 | |
| 
 | |
|     def test_disabling_saved_tensor_hooks_nested(self):
 | |
|         with torch.autograd.graph.disable_saved_tensors_hooks("outer"):
 | |
|             with torch.autograd.graph.disable_saved_tensors_hooks("inner"):
 | |
|                 with self.assertRaisesRegex(RuntimeError, "inner"):
 | |
|                     with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
 | |
|                         pass
 | |
| 
 | |
|             self.assertFalse(torch._C._autograd._saved_tensors_hooks_is_enabled())
 | |
| 
 | |
|         self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
 | |
| 
 | |
|     def test_saved_tensor_hooks_custom_error_propagaation(self):
 | |
|         class CustomError(Exception):
 | |
|             pass
 | |
| 
 | |
|         class error_on_pack_hook(torch.autograd.graph.saved_tensors_hooks):
 | |
|             def __init__(self):
 | |
|                 def pack_hook(x):
 | |
|                     raise CustomError("pack")
 | |
| 
 | |
|                 super().__init__(pack_hook, lambda x: x)
 | |
| 
 | |
|         class error_on_unpack_hook(torch.autograd.graph.saved_tensors_hooks):
 | |
|             def __init__(self):
 | |
|                 def unpack_hook(x):
 | |
|                     raise CustomError("unpack")
 | |
| 
 | |
|                 super().__init__(lambda x: x, unpack_hook)
 | |
| 
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
| 
 | |
|         with error_on_pack_hook():
 | |
|             with self.assertRaisesRegex(CustomError, "pack"):
 | |
|                 out = torch.sin(a)
 | |
| 
 | |
|         with error_on_unpack_hook():
 | |
|             out = torch.sin(a)
 | |
|             with self.assertRaisesRegex(CustomError, "unpack"):
 | |
|                 out.backward()
 | |
| 
 | |
| 
 | |
|     def test_save_on_cpu_and_checkpoint(self):
 | |
|         a = torch.randn(2, 2, requires_grad=True)
 | |
| 
 | |
|         b = a.pow(2).pow(2).pow(2).pow(2)
 | |
|         b.sum().backward()
 | |
|         b_grad = a.grad.clone()
 | |
|         a.grad.zero_()
 | |
| 
 | |
|         with torch.autograd.graph.save_on_cpu():
 | |
|             h = a.pow(2)
 | |
|             h = checkpoint(lambda x: x.pow(2).pow(2), h, use_reentrant=False)
 | |
|             c = h.pow(2)
 | |
|         c.sum().backward()
 | |
|         c_grad = a.grad.clone()
 | |
|         a.grad.zero_()
 | |
| 
 | |
|         def f(a):
 | |
|             h = a.pow(2)
 | |
|             with torch.autograd.graph.save_on_cpu():
 | |
|                 h = h.pow(2).pow(2)
 | |
|             return h.pow(2)
 | |
| 
 | |
|         d = checkpoint(f, a, use_reentrant=False)
 | |
|         d.sum().backward()
 | |
|         d_grad = a.grad.clone()
 | |
| 
 | |
|         self.assertEqual(b_grad, c_grad)
 | |
|         self.assertEqual(b_grad, d_grad)
 | |
| 
 | |
|     def test_pack_hook_with_inplace_modification_should_fail(self):
 | |
|         a = torch.randn(5, requires_grad=True)
 | |
| 
 | |
|         def inc(x):
 | |
|             x += 1
 | |
|             return x
 | |
|         with torch.autograd.graph.saved_tensors_hooks(inc, lambda x: x):
 | |
|             with self.assertRaisesRegex(RuntimeError, "A saved tensor pack hook is modifying its input in place."):
 | |
|                 y = torch.exp(a)
 | |
| 
 | |
|         y = torch.exp(a)
 | |
|         with self.assertRaisesRegex(RuntimeError, "A saved tensor pack hook is modifying its input in place."):
 | |
|             y.grad_fn._raw_saved_result.register_hooks(inc, lambda x: x)
 | |
| 
 | |
|     def test_saving_variable_to_disk(self):
 | |
|         with tempfile.TemporaryDirectory() as tmp_dir:
 | |
|             def pack(x):
 | |
|                 name = os.path.join(tmp_dir, str(uuid.uuid4()))
 | |
|                 torch.save(x, name)
 | |
|                 return name
 | |
| 
 | |
|             def unpack(name):
 | |
|                 return torch.load(name)
 | |
| 
 | |
|             with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
 | |
|                 a = torch.ones(5, requires_grad=True)
 | |
|                 y = a * a
 | |
|                 self.assertEqual(a, y.grad_fn._saved_self)
 | |
| 
 | |
|                 y.sum().backward()
 | |
|                 self.assertEqual(2 * a, a.grad)
 | |
| 
 | |
|     def test_default_saved_variable_hooks_double_backward(self):
 | |
|         with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
 | |
|             a = torch.randn(5, requires_grad=True)
 | |
|             y = a ** 3
 | |
|             s = torch.sum(y)
 | |
|             g, = torch.autograd.grad(s, (a, ), create_graph=True)
 | |
|             g.sum().backward()
 | |
|             self.assertEqual(6 * a, a.grad)
 | |
| 
 | |
| 
 | |
|         with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
 | |
|             a = torch.randn(5, requires_grad=True)
 | |
|             y = a ** 3
 | |
|             s = torch.sum(y)
 | |
|         g, = torch.autograd.grad(s, (a, ), create_graph=True)
 | |
|         g.sum().backward()
 | |
|         # factor 2 because only a is saved once
 | |
|         self.assertEqual(6 * 2 * a, a.grad)
 | |
| 
 | |
| 
 | |
|         a = torch.randn(5, requires_grad=True)
 | |
|         y = a ** 3
 | |
|         s = torch.sum(y)
 | |
|         with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
 | |
|             g, = torch.autograd.grad(s, (a, ), create_graph=True)
 | |
|             g.sum().backward()
 | |
|             # factor 4 because pow_backward is grad * (exp * self.pow(exp - 1))
 | |
|             # so grad is saved and self (i.e. a) is saved
 | |
|             self.assertEqual(6 * 4 * a, a.grad)
 | |
| 
 | |
| 
 | |
|         with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
 | |
|             a = torch.randn(5, requires_grad=True)
 | |
|             y = a ** 3
 | |
|             s = torch.sum(y)
 | |
|             g, = torch.autograd.grad(s, (a, ), create_graph=True)
 | |
|             g.sum().backward()
 | |
|             # combining the two above blocks: 2 * 4 = 8
 | |
|             # note that in that sense, a is saved twice
 | |
|             self.assertEqual(6 * 8 * a, a.grad)
 | |
| 
 | |
|     def test_wrapped_number_saved_variable_hooks(self):
 | |
|         def err_hook(x):
 | |
|             raise RuntimeError("this hook should not be called")
 | |
| 
 | |
|         with torch.autograd.graph.saved_tensors_hooks(err_hook, err_hook):
 | |
|             a = torch.randn(5, requires_grad=True)
 | |
|             out = (a * 3).sum()
 | |
|             # 3 is saved as a saved tensor because it is a wrapped number, but
 | |
|             # wrapped numbers should be special cased to not trigger saved variable hooks
 | |
|             torch.autograd.grad(out, (a,))
 | |
| 
 | |
|     def test_graph_save_on_cpu(self):
 | |
|         def test(get_input, cuda, pin_memory):
 | |
|             with torch.autograd.graph.save_on_cpu(pin_memory):
 | |
|                 a = get_input()
 | |
|                 if cuda:
 | |
|                     a.cuda()
 | |
|                 y = a * a
 | |
|                 self.assertEqual(a, y.grad_fn._saved_self)
 | |
|                 self.assertEqual(a, y.grad_fn._saved_other)
 | |
|                 self.assertEqual(a.dtype, y.grad_fn._saved_self.dtype)
 | |
|                 self.assertEqual(a.layout, y.grad_fn._saved_self.layout)
 | |
|                 if y.is_sparse:
 | |
|                     y = y.to_dense()
 | |
|                 y.sum().backward()
 | |
| 
 | |
|                 actual = 2 * a
 | |
|                 expected = a.grad
 | |
|                 if a.is_sparse:
 | |
|                     actual = actual.coalesce()
 | |
|                     expected = expected.coalesce()
 | |
| 
 | |
|                 self.assertEqual(actual, expected)
 | |
| 
 | |
|         for cuda in [False] + ([True] if torch.cuda.is_available() else []):
 | |
|             for pin_memory in [True, False]:
 | |
|                 # FloatTensor
 | |
|                 test(lambda: torch.randn(5, requires_grad=True), cuda, pin_memory)
 | |
|                 # DoubleTensor
 | |
|                 test(lambda: torch.randn(5, requires_grad=True, dtype=torch.double), cuda, pin_memory)
 | |
|                 # Sparse tensor
 | |
|                 x = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]), requires_grad=True)
 | |
|                 test(lambda: x, cuda, pin_memory)
 | |
| 
 | |
    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    def test_graph_save_on_cpu_cuda(self):
        """Compare `torch.cuda.memory_allocated()` with grad, without grad and under `save_on_cpu`.

        Asserts that recording a graph allocates more CUDA memory than running under
        `no_grad`, and that running under `save_on_cpu` allocates exactly as much as
        the `no_grad` run.
        """
        def f(x):
            a = x + 1
            return a * a

        # with grad
        a = torch.ones(1, requires_grad=True, device="cuda")
        y = f(a)
        memory_with_grad = torch.cuda.memory_allocated()

        # Drop the tensors before the next measurement
        del a
        del y

        # without grad
        a = torch.ones(1, requires_grad=True, device="cuda")
        with torch.no_grad():
            y = f(a)
        memory_without_grad = torch.cuda.memory_allocated()

        self.assertGreater(memory_with_grad, memory_without_grad)

        del a
        del y

        # with hooks
        with torch.autograd.graph.save_on_cpu():
            a = torch.ones(1, requires_grad=True, device="cuda")
            y = f(a)
            memory_with_hooks = torch.cuda.memory_allocated()
            self.assertEqual(memory_with_hooks, memory_without_grad)
 | |
| 
 | |
|     def test_multi_grad_hooks(self):
 | |
|         t1 = torch.rand(2, requires_grad=True)
 | |
|         t2 = torch.rand(2, requires_grad=True)
 | |
|         t3 = torch.rand(2, requires_grad=True)
 | |
|         t4 = torch.rand(2, requires_grad=True)
 | |
| 
 | |
|         res = [None] * 4
 | |
|         count = [0]
 | |
| 
 | |
|         def hook(grads):
 | |
|             nonlocal res
 | |
|             count[0] += 1
 | |
|             res = [g is not None for g in grads]
 | |
| 
 | |
|         handle = torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)
 | |
| 
 | |
|         out = t2 * t3
 | |
| 
 | |
|         out.sum().backward(inputs=(t2, t3), retain_graph=True)
 | |
|         self.assertEqual(count[0], 1)
 | |
|         self.assertEqual(res, [False, True, True, False])
 | |
| 
 | |
|         out.sum().backward(inputs=(t1, t4), retain_graph=True)
 | |
|         self.assertEqual(count[0], 1)
 | |
| 
 | |
|         out.sum().backward(inputs=(t1, t3), retain_graph=True)
 | |
|         self.assertEqual(count[0], 2)
 | |
|         self.assertEqual(res, [False, False, True, False])
 | |
| 
 | |
|         class Func(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 raise RuntimeError("error message")
 | |
| 
 | |
|         out = Func.apply(t2) * t3
 | |
|         with self.assertRaisesRegex(RuntimeError, "error message"):
 | |
|             out.sum().backward(inputs=(t2, t3), retain_graph=True)
 | |
|         self.assertEqual(count[0], 2)
 | |
| 
 | |
|         handle.remove()
 | |
|         out.sum().backward(inputs=(t1, t3), retain_graph=True)
 | |
|         self.assertEqual(count[0], 2)
 | |
| 
 | |
|     def test_pynode_destruction_deadlock(self):
 | |
|         script = """
 | |
| import torch
 | |
| 
 | |
| class Foo(torch.autograd.Function):
 | |
|     @staticmethod
 | |
|     def forward(ctx, x):
 | |
|         return x.clone()
 | |
| 
 | |
|     @staticmethod
 | |
|     def forward(ctx, gO):
 | |
|         return gO.clone()
 | |
| 
 | |
| def get_out():
 | |
|     inp = torch.rand(2, requires_grad=True)
 | |
| 
 | |
|     # The python function is first so that it runs
 | |
|     # last in the backward pass
 | |
|     right = Foo.apply(inp)
 | |
| 
 | |
|     # An op that creates new memory
 | |
|     left1 = inp.clone()
 | |
|     # An op that saves its input
 | |
|     left2 = left1 ** 2
 | |
| 
 | |
|     # Inplace modify so that the backward for
 | |
|     # left2 always raises an error
 | |
|     left1 += 1
 | |
| 
 | |
|     # An op that takes both side as input.
 | |
|     # After running, both side's last op will be in
 | |
|     # the ready queue
 | |
|     # And the op for left will run first as it was
 | |
|     # executed last during the forward
 | |
|     out = left2 + right
 | |
| 
 | |
|     return out
 | |
| 
 | |
| # Nothing should be global variables here as, from what
 | |
| # I can see, python leaks all the global objects
 | |
| get_out().sum().backward()
 | |
| 
 | |
| # This used to deadlock when the PyNode is being destroyed after
 | |
| # the error is raised.
 | |
| """
 | |
|         try:
 | |
|             subprocess.check_output(
 | |
|                 [sys.executable, '-c', script],
 | |
|                 stderr=subprocess.STDOUT,
 | |
|                 # On Windows, opening the subprocess with the default CWD makes `import torch`
 | |
|                 # fail, so just set CWD to this script's directory
 | |
|                 cwd=os.path.dirname(os.path.realpath(__file__)),
 | |
|                 # It is ok to have an extra long timeout here as a timeout means the test failed
 | |
|                 timeout=20)
 | |
|         except subprocess.TimeoutExpired as e:
 | |
|             self.fail(msg="Example code timed out! See the code sample in the test for details.")
 | |
|         except subprocess.CalledProcessError as e:
 | |
|             err_msg = "RuntimeError: one of the variables needed for gradient computation"
 | |
|             self.assertTrue(err_msg in e.output.decode("utf-8"))
 | |
| 
 | |
|     def test_view_func_replay(self):
 | |
|         def _assert_match_metadata(a, b):
 | |
|             self.assertEqual(a.size(), b.size())
 | |
|             self.assertEqual(a.stride(), b.stride())
 | |
|             self.assertEqual(a.storage_offset(), b.storage_offset())
 | |
| 
 | |
|         def _test_op(fn, inp, args):
 | |
|             out = fn(inp, *args)
 | |
|             self.assertTrue(out._is_view)
 | |
|             self.assertTrue(out._base is inp)
 | |
| 
 | |
|             new_inp = inp.clone()
 | |
|             _assert_match_metadata(new_inp, inp)
 | |
|             new_out = out._view_func(new_inp)
 | |
|             _assert_match_metadata(new_out, out)
 | |
| 
 | |
|         _test_op(torch.select, torch.rand(2, 2), (0, 0))
 | |
|         _test_op(torch.as_strided, torch.rand(2, 2), ((4,), (1,)))
 | |
|         _test_op(torch.view_as_complex, torch.rand(2, 2), ())
 | |
|         _test_op(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat), ())
 | |
| 
 | |
| 
 | |
def index_perm_variable(shape, max_indices):
    """Return a tensor of the given shape holding distinct indices below `max_indices`.

    `shape` may be a tuple or a single value; a non-tuple is treated as a
    one-dimensional shape. The indices are the leading entries of
    `torch.randperm(max_indices)`.
    """
    dims = shape if isinstance(shape, tuple) else (shape,)
    permutation = torch.randperm(max_indices)
    # Keep as many leading entries as the requested shape has elements
    wanted = reduce(mul, dims)
    return permutation.narrow(0, 0, wanted).view(dims)
 | |
| 
 | |
def bernoulli_scalar():
    """Return a 0-dim uint8 tensor created as 0 and then resampled in place with `bernoulli_()`."""
    scalar = torch.tensor(0, dtype=torch.uint8)
    return scalar.bernoulli_()
 | |
| 
 | |
| 
 | |
| class TestAutogradForwardModeBatchedGrad(TestCase):
 | |
|     def test_out_of_place_basic(self):
 | |
|         a = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
 | |
|         b = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
 | |
|         self.assertTrue(gradcheck(torch.sin, a, check_forward_ad=True, check_batched_grad=True,
 | |
|                                   check_batched_forward_grad=True))
 | |
|         self.assertTrue(gradcheck(torch.add, (a, b), check_forward_ad=True, check_batched_grad=True,
 | |
|                                   check_batched_forward_grad=True))
 | |
| 
 | |
|     def test_out_of_place_not_same_layout(self):
 | |
|         input = torch.zeros([2, 2]).transpose(0, 1)
 | |
|         tangent = torch.zeros([2, 2, 2])
 | |
| 
 | |
|         def jvp(tangent):
 | |
|             with fwAD.dual_level():
 | |
|                 x = fwAD.make_dual(input, tangent)
 | |
|                 return fwAD.unpack_dual(x)[1]
 | |
|         x_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(tangent)
 | |
| 
 | |
|         self.assertIsNot(x_tangent, tangent)
 | |
| 
 | |
|     def test_inplace_on_view_same_layout(self):
 | |
|         input = torch.zeros([2, 2])
 | |
|         tangent = torch.zeros([2, 2, 2])
 | |
|         base = torch.zeros([2, 2])
 | |
|         view = base.view_as(base)
 | |
| 
 | |
|         def jvp(tangent):
 | |
|             with fwAD.dual_level():
 | |
|                 x = fwAD.make_dual(input, tangent)
 | |
|                 view.copy_(x)
 | |
|                 return fwAD.unpack_dual(x)[1], fwAD.unpack_dual(view)[1], fwAD.unpack_dual(view._base)[1]
 | |
|         x_tangent, view_tangent, base_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(tangent)
 | |
| 
 | |
|         self.assertFalse(view_tangent._is_view())  # Optimization to share the same tensor!
 | |
|         self.assertIs(view_tangent, base_tangent)
 | |
|         self.assertIs(x_tangent, tangent)
 | |
| 
 | |
|     def test_inplace_on_view_not_same_layout(self):
 | |
|         input = torch.zeros([2, 2])
 | |
|         tangent = torch.zeros([2, 2, 2])
 | |
|         view = torch.zeros([2, 2]).transpose(0, 1)
 | |
| 
 | |
|         def jvp(tangent):
 | |
|             with fwAD.dual_level():
 | |
|                 x = fwAD.make_dual(input, tangent)
 | |
|                 view.copy_(x)
 | |
|                 return fwAD.unpack_dual(x)[1], fwAD.unpack_dual(view)[1], fwAD.unpack_dual(view._base)[1]
 | |
|         x_tangent, view_tangent, base_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(tangent)
 | |
| 
 | |
|         self.assertIs(view_tangent._base, base_tangent)
 | |
|         self.assertIs(x_tangent, tangent)
 | |
|         self.assertIsNot(view_tangent, tangent)
 | |
| 
 | |
|     def test_metadata_check_for_storage_numel_skipped(self):
 | |
|         # See: test_metadata_check_checks_storage_numel for the reverse of this test
 | |
|         primal = torch.randn(5)[:4].detach()
 | |
|         self.assertEqual(len(primal.storage()), 5)
 | |
|         tangent = torch.randn(10, 4)
 | |
| 
 | |
|         def jvp(tangent):
 | |
|             with fwAD.dual_level():
 | |
|                 dual = fwAD.make_dual(primal, tangent)
 | |
|                 _, unpacked_tangent = fwAD.unpack_dual(dual)
 | |
| 
 | |
|                 # No copy is made
 | |
|                 self.assertIs(tangent, unpacked_tangent)
 | |
| 
 | |
|                 # as_strided raises
 | |
|                 with self.assertRaisesRegex(RuntimeError, "can access memory outside of `tensor`"):
 | |
|                     dual.as_strided((5,), (1,), 0)
 | |
|             return unpacked_tangent
 | |
| 
 | |
|         torch._vmap_internals._vmap(jvp, 0, 0)(tangent)
 | |
| 
 | |
| 
 | |
| class TestAutogradForwardMode(TestCase):
 | |
|     def tearDown(self):
 | |
|         # Ensure that a failing test won't make others fail
 | |
|         while fwAD._current_level >= 0:
 | |
|             fwAD.exit_dual_level()
 | |
| 
 | |
|         super().tearDown()
 | |
| 
 | |
    def test_forward_level_cleanup(self):
        """Track a tangent's lifetime through a weak reference while it is attached to a dual.

        The tangent must stay alive while the dual tensor references it and must be
        released once both the Python tangent and the dual are deleted, even though
        the dual level is still open.
        """
        def get_tensor_and_weak_ref():
            # Create a new Tensor and weak reference
            t = torch.rand(2, requires_grad=True)
            return t, torch._C._WeakTensorRef(t)

        # Sanity check that the helper function works as expected
        t, t_ref = get_tensor_and_weak_ref()
        self.assertFalse(t_ref.expired())

        del t
        self.assertTrue(t_ref.expired())

        # Main test code
        foo = torch.rand(2)

        with fwAD.dual_level():
            tangent, tangent_ref = get_tensor_and_weak_ref()
            self.assertFalse(tangent_ref.expired())

            dual = fwAD.make_dual(foo, tangent)
            self.assertFalse(tangent_ref.expired())

            # Make sure that the tangent we provided has been re-used as is
            self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent)

            # Make sure that dual is keeping the tangent alive
            del tangent
            self.assertFalse(tangent_ref.expired())

            # Make sure that the dual level does not keep the c++
            # version of the tangent alive
            del dual
            self.assertTrue(tangent_ref.expired())
 | |
| 
 | |
|     def test_size_check(self):
 | |
|         foo = torch.rand(2)
 | |
|         tangent = torch.rand(3)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             with self.assertRaisesRegex(RuntimeError, "Trying to set a forward gradient that has a different size"):
 | |
|                 dual = fwAD.make_dual(foo, tangent)
 | |
| 
 | |
|             dual = fwAD.make_dual(foo, tangent[1:])
 | |
| 
 | |
|     def test_metadata_check_checks_storage_numel(self):
 | |
|         primal = torch.randn(5)[:4].detach()
 | |
|         self.assertEqual(len(primal.storage()), 5)
 | |
|         tangent = torch.randn(4)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual = fwAD.make_dual(primal, tangent)
 | |
|             _, unpacked_tangent = fwAD.unpack_dual(dual)
 | |
| 
 | |
|             # # Verify that mutating unpacked tangent does not affect the original tangent
 | |
|             tangent_clone = tangent.clone()
 | |
|             unpacked_tangent *= 2
 | |
|             self.assertTrue(torch.allclose(tangent_clone, tangent))
 | |
| 
 | |
|             # as_strided runs without error
 | |
|             dual.as_strided((5,), (1,), 0)
 | |
| 
 | |
|     def test_metadata_check_checks_ignores_size_zero(self):
 | |
|         a = torch.ones(0).as_strided((0, 1,), (1, 1,), 0)
 | |
|         b = torch.ones(0).as_strided((0, 1,), (1, 0,), 0)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual = fwAD.make_dual(a, b)
 | |
|             torch.diagonal(dual, offset=0)
 | |
| 
 | |
|         input = torch.rand([0, 1], dtype=torch.complex128, requires_grad=True)
 | |
|         func = partial(torch.diagonal, offset=0)
 | |
|         torch.autograd.gradcheck(func, (input,), check_forward_ad=True)
 | |
| 
 | |
|     def test_metadata_check_when_primal_has_conj_bit(self):
 | |
|         # Make sure the _has_same_storage_numel is a fallthrough, so that
 | |
|         # conj bit does not materialize. If it materializes it would
 | |
|         # cause the layout check to fail for views that do not index the
 | |
|         # the entire storage.
 | |
|         a = torch.randn(2, 2, dtype=torch.cdouble).conj()
 | |
|         b = torch.rand_like(a)
 | |
| 
 | |
|         self.assertTrue(torch.is_conj(a))
 | |
|         self.assertEqual(len(a.storage()), len(b.storage()))
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual = fwAD.make_dual(a, b)
 | |
|             dual[1:]
 | |
| 
 | |
|     def test_metadata_check_when_primal_has_neg_bit(self):
 | |
|         # Make sure the _has_same_storage_numel is a fallthrough, so that
 | |
|         # conj bit does not materialize. If it materializes it would
 | |
|         # cause the layout check to fail for views that do not index the
 | |
|         # the entire storage.
 | |
|         a = torch.randn(2, 2, dtype=torch.cdouble).conj().imag
 | |
|         b = torch.randn(2, 2, dtype=torch.cdouble).imag
 | |
| 
 | |
|         self.assertTrue(torch.is_neg(a))
 | |
|         self.assertEqual(len(a.storage()), len(b.storage()))
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual = fwAD.make_dual(a, b)
 | |
|             dual[1:]
 | |
| 
 | |
|     def test_metadata_check_check_conj(self):
 | |
|         keys = {
 | |
|             "NEITHER": lambda x: x,
 | |
|             "CONJ": lambda x: x.conj(),
 | |
|             "NEG": lambda x: x._neg_view()
 | |
|         }
 | |
| 
 | |
|         for primal_key, tangent_key in product(keys, keys):
 | |
|             x = keys[primal_key](torch.randn(2, 3, 4, dtype=torch.cdouble))
 | |
|             t = keys[tangent_key](torch.randn(2, 3, 4, dtype=torch.cdouble))
 | |
| 
 | |
|             if primal_key == tangent_key:
 | |
|                 with fwAD.dual_level():
 | |
|                     dual = fwAD.make_dual(x, t)
 | |
|                     self.assertTrue(fwAD.unpack_dual(dual).tangent is t)
 | |
|                     torch.real(dual)
 | |
|                     torch.imag(dual)
 | |
|             else:
 | |
|                 with fwAD.dual_level():
 | |
|                     dual = fwAD.make_dual(x, t)
 | |
|                     self.assertTrue(fwAD.unpack_dual(dual).tangent is not t)
 | |
|                     torch.real(dual)
 | |
|                     torch.imag(dual)
 | |
| 
 | |
|     def test_metadata_check_ignore_storage_offset_for_zero_numel_tensor(self):
 | |
|         # See https://github.com/pytorch/pytorch/issues/80507
 | |
|         a = torch.tensor([1.]).as_strided((0,), (1,), 1)
 | |
|         b = torch.tensor([1.]).as_strided((0,), (1,), 2)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual_input = fwAD.make_dual(a, b)
 | |
|             # Check that no copy is made
 | |
|             self.assertIs(fwAD.unpack_dual(dual_input).tangent, b)
 | |
| 
 | |
|         a = torch.tensor([1.]).as_strided((1,), (2,), 0)
 | |
|         b = torch.tensor([1.]).as_strided((1,), (1,), 0)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual_input = fwAD.make_dual(a, b)
 | |
|             dual_input[1:]
 | |
| 
 | |
|     # The following test functions want to ensure all the following behaviors:
 | |
|     #   - Ensure that default level system in the python binding works
 | |
|     #   - Ensure that only level 0 exists and nesting is properly disabled
 | |
|     #   - Ensure that printing works fine
 | |
|     #   - Ensure that basic packing/unpacking works
 | |
|     #   - Ensure that advanced packing/unpacking works
 | |
|     #     - For memory / version counter share
 | |
|     #     - For backward AD (regular ops)
 | |
|     #   - Ensure that view + inplace for both modes work fine
 | |
|     #   - Ensure we do proper cleanup on exit of a level
 | |
| 
 | |
| 
 | |
|     def test_default_level(self):
 | |
|         foo = torch.rand(2)
 | |
|         bar = torch.rand(2)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             baz = fwAD.make_dual(foo, bar)
 | |
|             baz_primal, baz_tangent = fwAD.unpack_dual(baz)
 | |
|         self.assertEqual(baz_primal, foo)
 | |
|         # We don't actually need to enforce that these two are the exact same python
 | |
|         # object, feel free to relax in the future
 | |
|         self.assertIs(baz_tangent, bar)
 | |
| 
 | |
|         baz_primal, baz_tangent = fwAD.unpack_dual(baz)
 | |
|         self.assertEqual(baz_primal, foo)
 | |
|         self.assertEqual(baz_tangent, None)
 | |
| 
 | |
|     def test_fwd_grad_enabled(self):
 | |
|         # Tests some private helper functions to enable/disable fwd grad mode
 | |
|         enabled = fwAD._is_fwd_grad_enabled()
 | |
|         self.assertTrue(enabled)
 | |
| 
 | |
|         try:
 | |
|             torch._C._set_fwd_grad_enabled(False)
 | |
|             enabled = fwAD._is_fwd_grad_enabled()
 | |
|             self.assertFalse(enabled)
 | |
|         finally:
 | |
|             torch._C._set_fwd_grad_enabled(True)
 | |
| 
 | |
|         enabled = fwAD._is_fwd_grad_enabled()
 | |
|         self.assertTrue(enabled)
 | |
| 
 | |
|     def test_set_fwd_grad_enabled(self):
 | |
|         # Tests a private helper function
 | |
|         try:
 | |
|             torch._C._set_fwd_grad_enabled(False)
 | |
|             enabled = fwAD._is_fwd_grad_enabled()
 | |
|             self.assertFalse(enabled)
 | |
| 
 | |
|             with fwAD._set_fwd_grad_enabled(True):
 | |
|                 enabled = fwAD._is_fwd_grad_enabled()
 | |
|                 self.assertTrue(enabled)
 | |
| 
 | |
|             enabled = fwAD._is_fwd_grad_enabled()
 | |
|             self.assertFalse(enabled)
 | |
|         finally:
 | |
|             torch._C._set_fwd_grad_enabled(True)
 | |
| 
 | |
|     def test_nested_level(self):
 | |
|         with fwAD.dual_level() as level:
 | |
|             # For now only level 0 exists
 | |
|             self.assertEqual(level, 0)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             with self.assertRaisesRegex(RuntimeError, "Nested forward mode AD is not supported at the moment"):
 | |
|                 nest_level = fwAD.enter_dual_level()
 | |
| 
 | |
|     def test_set_fw_grad_having_own_fw_grad_at_same_level(self):
 | |
|         foo = torch.rand(2)
 | |
|         bar = torch.rand(2)
 | |
|         baz = torch.rand(2)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual = fwAD.make_dual(foo, bar)
 | |
|             with self.assertRaisesRegex(RuntimeError, "has a forward gradient at the same level"):
 | |
|                 fwAD.make_dual(baz, dual)
 | |
| 
 | |
|     def test_codegen_ignores_undefined_outputs(self):
 | |
|         # This test checks that codegen silently ignores undefined outputs
 | |
|         # Below, grad_input is specified as False in grad_output_mask, so
 | |
|         # convolution backward will return a undefined tensor in that position.
 | |
|         # Note that for this test to work we need to make sure either grad_output
 | |
|         # or weight to be a dual tensor, so grad_input requires forward grad
 | |
|         weight = torch.randn(6, 1, 30, 30)
 | |
|         inp = torch.rand((1, 1, 32, 32))
 | |
|         out = torch.nn.functional.conv2d(inp, weight)
 | |
|         grad_out = torch.ones_like(out)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual_weight = fwAD.make_dual(weight, torch.ones_like(weight))
 | |
|             grad_input, _, _ = torch.ops.aten.convolution_backward(
 | |
|                 grad_out, inp, dual_weight, (0,),
 | |
|                 (1, 1), (0, 0), (1, 1), False, (0, 0), 1, (False, True, False))
 | |
|         self.assertIsNone(grad_input)
 | |
| 
 | |
|     def test_make_dual_inference_tensor_in_inference_mode(self):
 | |
|         with torch.inference_mode():
 | |
|             foo = torch.rand(2)
 | |
|             bar = torch.rand(2)
 | |
|             foo_copy = foo.clone()
 | |
| 
 | |
|             with fwAD.dual_level():
 | |
|                 dual = fwAD.make_dual(foo, bar)
 | |
|                 self.assertFalse(dual._is_view())
 | |
| 
 | |
|                 dual += 1
 | |
|                 self.assertFalse(torch.allclose(foo, foo_copy))
 | |
| 
 | |
|     def test_make_dual_torch_dispatch(self):
 | |
|         counter = [0]
 | |
| 
 | |
|         class MySubclass(torch.Tensor):
 | |
|             def __new__(cls, data=None):
 | |
|                 return torch.Tensor._make_subclass(cls, data)
 | |
| 
 | |
|             __torch_function__ = torch._C._disabled_torch_function_impl
 | |
| 
 | |
|             @classmethod
 | |
|             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
 | |
|                 if func.overloadpacket == torch.ops.aten.alias:
 | |
|                     counter[0] += 1
 | |
| 
 | |
|                     # Make sure we can re-enable autograd here
 | |
|                     with torch.overrides.enable_reentrant_dispatch():
 | |
|                         foo = torch.rand(1, requires_grad=True)
 | |
|                         self.assertIsNotNone(foo.exp().grad_fn)
 | |
| 
 | |
|                 with no_dispatch():
 | |
|                     return func(*args, **kwargs)
 | |
| 
 | |
|         a = torch.tensor(1.)
 | |
|         s = MySubclass(a)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             # Only the primal has "alias" called on it
 | |
|             fwAD.make_dual(s, torch.rand_like(s))
 | |
|             self.assertEqual(counter[0], 1)
 | |
|             fwAD.make_dual(torch.rand_like(s), s)
 | |
|             self.assertEqual(counter[0], 1)
 | |
| 
 | |
|     def test_make_dual_forbid_integral_dtype(self):
 | |
|         primal_f = torch.ones(2, 2, dtype=torch.float)
 | |
|         primal_l = torch.ones(2, 2, dtype=torch.long)
 | |
| 
 | |
|         tangent_f = torch.ones(2, 2, dtype=torch.float)
 | |
|         tangent_l = torch.ones(2, 2, dtype=torch.long)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             # Float Primal and Long Tangent
 | |
|             with self.assertRaisesRegex(ValueError, "Expected tangent to be floating point or complex"):
 | |
|                 fwAD.make_dual(primal_f, tangent_l)
 | |
| 
 | |
|             # Long Primal and Long Tangent
 | |
|             with self.assertRaisesRegex(ValueError, "Expected primal to be floating point or complex"):
 | |
|                 fwAD.make_dual(primal_l, tangent_l)
 | |
| 
 | |
|             # Long Primal and Float Tangent
 | |
|             with self.assertRaisesRegex(ValueError, "Expected primal to be floating point or complex"):
 | |
|                 fwAD.make_dual(primal_l, tangent_f)
 | |
| 
 | |
|     def test_print(self):
 | |
|         with fwAD.dual_level() as level:
 | |
|             a = torch.rand(3)
 | |
|             self.assertFalse("tangent=" in str(a))
 | |
| 
 | |
|             b = fwAD.make_dual(a, torch.rand(3))
 | |
|             self.assertFalse("tangent=" in str(a))
 | |
|             self.assertTrue("tangent=" in str(b))
 | |
| 
 | |
|             b_primal, b_tangent = fwAD.unpack_dual(b)
 | |
|             self.assertFalse("tangent=" in str(b_primal))
 | |
|             self.assertFalse("tangent=" in str(b_tangent))
 | |
| 
 | |
|     def test_basic_packing_unpacking(self):
 | |
|         foo = torch.rand(2)
 | |
|         bar = torch.rand(2)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             baz = fwAD.make_dual(foo, bar)
 | |
|             baz_primal, baz_tangent = fwAD.unpack_dual(baz)
 | |
|             self.assertEqual(baz_primal, foo)
 | |
|             self.assertIs(baz_tangent, bar)
 | |
| 
 | |
|             # Check unpacked dual is returned as a named tuple
 | |
|             # NB: Every invocation of unpack_dual returns a new tensor view
 | |
|             self.assertIsNot(baz_primal, fwAD.unpack_dual(baz).primal)
 | |
|             self.assertEqual(baz_primal, fwAD.unpack_dual(baz).primal)
 | |
|             self.assertIs(baz_tangent, fwAD.unpack_dual(baz).tangent)
 | |
| 
 | |
|             # Check that packing/unpacking did not change the input
 | |
|             foo_primal, foo_tangent = fwAD.unpack_dual(foo)
 | |
|             self.assertEqual(foo_primal, foo)
 | |
|             self.assertIsNone(foo_tangent)
 | |
| 
 | |
|     def test_advanced_packing_unpacking(self):
 | |
|         foo = torch.rand(2)
 | |
|         bar = torch.ones(2)
 | |
| 
 | |
|         # Memory and version counter check
 | |
|         with fwAD.dual_level():
 | |
|             dual = fwAD.make_dual(foo, bar)
 | |
| 
 | |
|             # Ensure that they are sharing memory and version counter
 | |
|             self.assertEqual(dual.storage().data_ptr(), foo.storage().data_ptr())
 | |
| 
 | |
|             # Ensure we properly share the version counter
 | |
|             self.assertEqual(foo._version, dual._version)
 | |
|             foo.add_(1)
 | |
|             self.assertEqual(foo._version, dual._version)
 | |
| 
 | |
|             # Unpacking should only create aliases as well
 | |
|             dual_primal, dual_tangent = fwAD.unpack_dual(dual)
 | |
|             self.assertEqual(dual_primal.storage().data_ptr(), foo.storage().data_ptr())
 | |
|             self.assertEqual(dual_tangent.storage().data_ptr(), bar.storage().data_ptr())
 | |
|             # And the tangent is actually re-used as-is so it is still the same Tensor
 | |
|             self.assertIs(dual_tangent, bar)
 | |
| 
 | |
|             # Ensure we properly share the version counter
 | |
|             self.assertEqual(foo._version, dual_primal._version)
 | |
|             foo.add_(1)
 | |
|             self.assertEqual(foo._version, dual_primal._version)
 | |
|             self.assertEqual(bar._version, dual_tangent._version)
 | |
|             bar.add_(1)
 | |
|             self.assertEqual(bar._version, dual_tangent._version)
 | |
| 
 | |
|         # backward mode check
 | |
|         with fwAD.dual_level():
 | |
|             foo.requires_grad_()
 | |
|             bar.requires_grad_()
 | |
| 
 | |
|             # Check that backward gradients properly propagates through packing/unpacking
 | |
|             dual = fwAD.make_dual(foo, bar)
 | |
|             p, t = fwAD.unpack_dual(dual)
 | |
| 
 | |
|             gfoo, gbar = torch.autograd.grad(p.sum(), (foo, bar), retain_graph=True, allow_unused=True)
 | |
|             self.assertEqual(gfoo, torch.ones_like(foo))
 | |
|             self.assertIsNone(gbar)
 | |
| 
 | |
|             gfoo, gbar = torch.autograd.grad(t.sum(), (foo, bar), retain_graph=True, allow_unused=True)
 | |
|             self.assertIsNone(gfoo)
 | |
|             self.assertEqual(gbar, torch.ones_like(bar))
 | |
| 
 | |
|             # Check that forward gradients are impacted by detach()
 | |
|             detached_dual = dual.detach()
 | |
|             out = detached_dual * 2
 | |
|             p, t = fwAD.unpack_dual(out)
 | |
|             self.assertFalse(p.requires_grad)
 | |
|             self.assertEqual(p, foo * 2)
 | |
|             self.assertIsNone(t)
 | |
| 
 | |
|             # Check that forward gradients are not impacted by no_grad
 | |
|             with torch.no_grad():
 | |
|                 out = dual * 3
 | |
|             p, t = fwAD.unpack_dual(out)
 | |
|             self.assertFalse(p.requires_grad)
 | |
|             self.assertFalse(t.requires_grad)
 | |
|             self.assertEqual(p, foo * 3)
 | |
|             self.assertEqual(t, bar * 3)
 | |
| 
 | |
|             # Check that forward gradients are not impacted by inplace detach
 | |
|             dual = dual.clone()
 | |
|             dual.detach_()
 | |
|             out = dual * 2
 | |
|             p, t = fwAD.unpack_dual(out)
 | |
|             self.assertFalse(p.requires_grad)
 | |
|             self.assertEqual(p, foo * 2)
 | |
|             self.assertIsNone(t)
 | |
| 
 | |
|     def test_view_inplace_non_differentiable_views(self):
 | |
|         original_foo = torch.rand(2, dtype=torch.double)
 | |
|         original_bar = torch.ones(2, dtype=torch.double)
 | |
| 
 | |
|         # Do clones to be able to compare the values updated inplace
 | |
|         # with the original content of these Tensors
 | |
|         foo = original_foo.clone()
 | |
|         bar = original_bar.clone()
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             # Note that in this test, we use "update" to mean computing the right tangent for the dual
 | |
|             # All the inplace operations here are expected to update the primal value of the Tensors but
 | |
|             # not always their tangents.
 | |
|             # Also all mentions of "non differentiable view" here means non forward differentiable view
 | |
|             # unless specified otherwise.
 | |
|             # See note [Forward Grad View/inplace] for more details on how these views work.
 | |
| 
 | |
|             # Check that inplace ops do not update non-differentiable views
 | |
|             # Non differentiable view
 | |
|             dual = fwAD.make_dual(foo, bar)
 | |
|             dual *= 2
 | |
|             # Check that non differentiable view's tangent was not updated
 | |
|             self.assertIsNone(fwAD.unpack_dual(foo)[1])
 | |
|             # Check that the computed result is correct
 | |
|             self.assertEqual(bar, original_bar * 2)
 | |
|             self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 2)
 | |
|             self.assertEqual(foo, original_foo * 2)
 | |
|             self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 2)
 | |
|             # Other non differentiable view
 | |
|             dual_primal, dual_tangent = fwAD.unpack_dual(dual)
 | |
|             self.assertIsNone(fwAD.unpack_dual(dual_primal)[1])
 | |
|             self.assertIsNone(fwAD.unpack_dual(dual_tangent)[1])
 | |
|             dual_primal *= 2
 | |
|             # Ensure dual's tangent did not change
 | |
|             self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 4)
 | |
|             self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 2)
 | |
|             dual_tangent *= 2
 | |
|             # Ensure dual's primal did not change
 | |
|             self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 4)
 | |
|             self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 4)
 | |
| 
 | |
| 
 | |
|     def test_view_inplace_differentiable_views(self):
 | |
|         original_foo = torch.rand(2)
 | |
|         original_bar = torch.ones(2)
 | |
| 
 | |
|         # Do clones to be able to compare the values updated inplace
 | |
|         # with the original content of these Tensors
 | |
|         foo = original_foo.clone()
 | |
|         bar = original_bar.clone()
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             # Check that inplace ops do update differentiable view but stop at non differentiable ones
 | |
|             # A non differentiable view
 | |
|             dual = fwAD.make_dual(foo, bar)
 | |
|             # A differentiable view
 | |
|             view = dual.narrow(0, 0, 1)
 | |
|             view *= 2
 | |
|             # Check that non differentiable view was not updated
 | |
|             self.assertIsNone(fwAD.unpack_dual(foo)[1])
 | |
|             # Check that differentiable view was updated
 | |
|             self.assertEqual(fwAD.unpack_dual(dual)[1], torch.tensor([2., 1.]))
 | |
|             self.assertEqual(fwAD.unpack_dual(view)[1], torch.tensor([2.]))
 | |
| 
 | |
|             # Check that we track differentiable view even for Tensors that are not dual
 | |
|             baz = torch.rand(2)
 | |
|             baz += dual
 | |
|             self.assertEqual(fwAD.unpack_dual(baz)[1], fwAD.unpack_dual(dual)[1])
 | |
|             # Updates on view should as well
 | |
|             baz = torch.rand(2)
 | |
|             baz[0] = dual[0]
 | |
|             self.assertEqual(fwAD.unpack_dual(baz)[1][0], fwAD.unpack_dual(dual)[1][0])
 | |
|             # Unused values get a gradient of 0
 | |
|             self.assertEqual(fwAD.unpack_dual(baz)[1][1], 0.)
 | |
| 
 | |
|             # Check that forward non-differentiable views do prevent gradient update
 | |
|             baz = torch.rand(2)
 | |
|             view = baz.detach()
 | |
|             view += dual
 | |
|             self.assertIsNone(fwAD.unpack_dual(baz)[1])
 | |
| 
 | |
|     def test_view_inplace_always_creates_a_view(self):
 | |
|         # See https://github.com/pytorch/pytorch/issues/67800
 | |
|         # The codepath may depend on the op. At the time writing, when self is not a dual tensor
 | |
|         # the resulting forward grad for self for...
 | |
|         # - add_ has the same layout as self
 | |
|         # - mul_ has the same layout as other
 | |
|         # This is kind of fragile because the above depends on how the forward grad expression
 | |
|         # is written. For add and mul at least, the output inherits the layout of LHS.
 | |
|         # We want to handle at least these two cases.
 | |
|         inplace_binary_ops = (  # Add more to this list?
 | |
|             lambda x, y: x.add_(y),
 | |
|             lambda x, y: x.mul_(y),
 | |
|             lambda x, y: x.copy_(y),
 | |
|         )
 | |
| 
 | |
|         for inplace_binary_op in inplace_binary_ops:
 | |
|             base = torch.randn(2, 2)
 | |
|             view = base.transpose(0, 1)
 | |
| 
 | |
|             primal = torch.randn(2, 2)
 | |
|             tangent = torch.randn(2, 2)
 | |
| 
 | |
|             with fwAD.dual_level():
 | |
|                 dual = fwAD.make_dual(primal, tangent)
 | |
|                 inplace_binary_op(view, dual)
 | |
| 
 | |
|                 # Verify that a view relationship is created for both the primal and tangent
 | |
|                 p, t = fwAD.unpack_dual(base)
 | |
|                 p_clone = p.clone()
 | |
|                 t_clone = t.clone()
 | |
|                 view *= 2
 | |
|                 p, t = fwAD.unpack_dual(base)
 | |
| 
 | |
|                 self.assertTrue(torch.allclose(p_clone * 2, p))
 | |
|                 self.assertTrue(torch.allclose(t_clone * 2, t))
 | |
| 
 | |
|     def test_grad_cleanup(self):
 | |
|         foo = torch.rand(2)
 | |
|         bar = torch.rand(2)
 | |
|         baz = torch.rand(2)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual = fwAD.make_dual(foo, bar)
 | |
|             self.assertIsNone(fwAD.unpack_dual(foo)[1])
 | |
|             self.assertIs(fwAD.unpack_dual(dual)[1], bar)
 | |
| 
 | |
|         self.assertIsNone(fwAD.unpack_dual(dual)[1])
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             self.assertIsNone(fwAD.unpack_dual(foo)[1])
 | |
|             new_dual = fwAD.make_dual(foo, baz)
 | |
| 
 | |
|             dual_primal, dual_tangent = fwAD.unpack_dual(dual)
 | |
|             new_dual_primal, new_dual_tangent = fwAD.unpack_dual(new_dual)
 | |
|             self.assertEqual(dual_primal, new_dual_primal)
 | |
|             self.assertIsNone(dual_tangent)
 | |
|             self.assertEqual(new_dual_tangent, baz)
 | |
| 
 | |
|     def test_detach_view_tracking(self):
 | |
|         # Default detach is both forward and backward non-differentiable
 | |
|         foo = torch.rand(2)
 | |
|         foo_weak = torch._C._WeakTensorRef(foo)
 | |
| 
 | |
|         out = foo.detach()
 | |
| 
 | |
|         del foo
 | |
|         self.assertTrue(foo_weak.expired())
 | |
| 
 | |
|     def test_out_variant(self):
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             foo = fwAD.make_dual(torch.rand(2), torch.rand(2))
 | |
|             bar = torch.rand(2)
 | |
| 
 | |
|             with self.assertRaisesRegex(RuntimeError, "out= function"):
 | |
|                 torch.add(bar, bar, out=foo)
 | |
| 
 | |
|             with self.assertRaisesRegex(RuntimeError, "out= function"):
 | |
|                 torch.add(foo, bar, out=bar)
 | |
| 
 | |
|     def test_non_differentiable(self):
 | |
|         with fwAD.dual_level():
 | |
|             foo = fwAD.make_dual(torch.rand(2), torch.rand(2))
 | |
|             bar = torch.rand(2)
 | |
| 
 | |
|             # No differentiable outputs, shouldn't error
 | |
|             eq = foo == bar
 | |
| 
 | |
|             # Inplace
 | |
|             foo.eq_(bar)
 | |
| 
 | |
|     def test_create_new_zeros_with_same_meta(self):
 | |
|         new_zeroes_fn = torch.ops.aten._new_zeros_with_same_feature_meta
 | |
| 
 | |
|         def check(a, b):
 | |
|             def assert_same_meta(t, target):
 | |
|                 for num_bdim in range(t.dim()):
 | |
|                     result = new_zeroes_fn(t, target, self_num_batch_dims=num_bdim)
 | |
| 
 | |
|                     self.assertEqual(result.dim(), target.dim() + num_bdim)
 | |
| 
 | |
|                     # Check size/strides match for feature dims only
 | |
|                     for i in range(num_bdim, result.dim()):
 | |
|                         self.assertEqual(result.size()[i], target.size()[i - num_bdim])
 | |
|                         self.assertEqual(result.stride()[i], target.stride()[i - num_bdim])
 | |
| 
 | |
|                     # Check that we generate strides reasonably
 | |
|                     if target.is_contiguous():
 | |
|                         self.assertTrue(result.is_contiguous())
 | |
| 
 | |
|                     self.assertEqual(result.storage_offset(), target.storage_offset())
 | |
| 
 | |
|                     prod_of_t_bdims = reduce(operator.mul, t.size()[:num_bdim], 1)
 | |
|                     self.assertEqual(len(result.storage()), len(target.storage()) * prod_of_t_bdims)
 | |
| 
 | |
|                     # TensorOptions is same
 | |
|                     self.assertEqual(result.dtype, target.dtype)
 | |
| 
 | |
|             assert_same_meta(a, b)
 | |
|             assert_same_meta(b, a)
 | |
| 
 | |
|         a = torch.randn(5, dtype=torch.float)
 | |
|         b = torch.randn(2, 3, 4, dtype=torch.double)
 | |
|         check(a, b)
 | |
| 
 | |
|         # non-contiguous case
 | |
|         a = torch.randn(2, 3, 4).transpose(0, 1).contiguous().transpose(0, 1)
 | |
|         b = torch.randn(2, 3, 4)
 | |
|         check(a, b)
 | |
| 
 | |
|         a = torch.randn(5).narrow(0, 1, 2)
 | |
|         b = torch.randn(2)
 | |
|         check(a, b)
 | |
| 
 | |
|         # tensor is not a view, but still does not index entirety of storage
 | |
|         a = torch.randn(5).resize_(4)
 | |
|         b = torch.randn(4)
 | |
|         check(a, b)
 | |
| 
 | |
|         # Zero-numel tensors
 | |
|         a = torch.randn(1, 0, 2)
 | |
|         b = torch.randn(1, 2)
 | |
|         check(a, b)
 | |
| 
 | |
|         # Scalar tensor
 | |
|         a = torch.tensor(1.)
 | |
|         b = torch.randn(1, 2)
 | |
|         check(a, b)
 | |
| 
 | |
|     def test_backward_graph_destruction(self):
 | |
|         def fn():
 | |
|             a = torch.rand(10, requires_grad=True)
 | |
| 
 | |
|             da = fwAD.make_dual(torch.rand_like(a), a)
 | |
| 
 | |
|             # Create an object with a c++ cycle as:
 | |
|             # db -> AutogradMeta -> ForwardGrad -> db's grad
 | |
|             # db's grad -> AutogradMeta -> MulBackward
 | |
|             # MulBackward -> SavedVariable -> db
 | |
|             db = da.exp()
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             fn()
 | |
|         # This test make sure that we don't deadlock on exit of this
 | |
|         # context manager. If you do, there is something wrong with the
 | |
|         # locking of the forward ad level most likely
 | |
| 
 | |
| # Generic device type autograd tests.
 | |
| class TestAutogradDeviceType(TestCase):
 | |
| 
 | |
|     def test_min_max_median_backprops_to_all_values(self, device):
 | |
|         for f in [torch.min, torch.max, torch.median, torch.nanmedian]:
 | |
|             x1 = torch.tensor([1., 0., 1., 0., 1., 0.], device=device, requires_grad=True)
 | |
|             x2 = torch.tensor([float('nan'), float('nan'), float('nan')], requires_grad=True)
 | |
|             for x in [x1, x2]:
 | |
|                 y = f(x)
 | |
|                 y.backward()
 | |
|                 self.assertEqual(x.grad.sum(), 1.)
 | |
|                 self.assertEqual((x.grad == 1 / 3).sum(), 3)
 | |
| 
 | |
|     def test_scatter_index_reduce_amin_amax_backprops_to_all_values(self, device):
 | |
|         # tests that gradients are evenly distributed when there are multiple max/min values
 | |
|         # tested here instead of adding a SampleInput as the backward for this case is non-differentiable for gradgrad
 | |
|         # as is the case for test_min_max_median_backprops_to_all_values above
 | |
|         fns = (torch.scatter_reduce, torch.index_reduce)
 | |
|         reduces = ('amin', 'amax')
 | |
|         for fn, reduction in product(fns, reduces):
 | |
|             input = torch.randn((2, 3), device=device, dtype=torch.float64, requires_grad=True)
 | |
|             src = input.clone().detach_().requires_grad_(True)
 | |
|             idx = torch.arange(2).to(dtype=torch.long, device=device)
 | |
|             if fn == torch.scatter_reduce:
 | |
|                 idx = idx.unsqueeze(-1).expand((2, 3))
 | |
| 
 | |
|             gradcheck(fn, (input, 0, idx, src, reduction), check_batched_grad=False)
 | |
| 
 | |
|     def test_scatter_index_reduce_prod_gradgrad_error(self, device):
 | |
|         # test that double backward raises an error for the case where 2 zeros in src
 | |
|         # are scattered to the same position in self
 | |
|         input = torch.tensor([1.], device=device, dtype=torch.float64, requires_grad=True)
 | |
|         src = torch.tensor([0., 0.], device=device, dtype=torch.float64, requires_grad=True)
 | |
|         idx = torch.tensor([0, 0], device=device, dtype=torch.long)
 | |
| 
 | |
|         for fn in (torch.scatter_reduce, torch.index_reduce):
 | |
|             # check that this case passes on gradcheck
 | |
|             gradcheck(fn, (input, 0, idx, src, 'prod'), check_batched_grad=False)
 | |
|             with self.assertRaisesRegex(RuntimeError, "Double backward is unsupported for"):
 | |
|                 gradgradcheck(fn, (input, 0, idx, src, 'prod'))
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work on MPS as double types are not supported
 | |
|     def test_parameter_resize(self, device):
 | |
|         asd = torch.nn.Parameter(torch.ones(16, dtype=torch.double, device=device))
 | |
| 
 | |
|         for i in range(2):
 | |
|             with torch.no_grad():
 | |
|                 asd.set_(asd[1:])
 | |
|                 asd.grad = None
 | |
| 
 | |
|             m = torch.cat((asd, asd))
 | |
|             m.sum().backward()
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work on MPS as double types are not supported
 | |
|     @dtypes(torch.double, torch.cdouble)
 | |
|     def test_sparse_ctor_getter_backward(self, device, dtype):
 | |
|         # See NOTE [ Sparse: autograd and API ] on the expected behavior of this test
 | |
|         def _test(size, sparse_dim, nnz, device):
 | |
|             v_size = [nnz] + list(size[sparse_dim:])
 | |
|             i = torch.rand(sparse_dim, nnz)
 | |
|             i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
 | |
|             i = i.to(torch.long)
 | |
| 
 | |
|             inp = torch.randn(v_size, dtype=torch.double, device=device, requires_grad=True)
 | |
|             other = self.genSparseTensor(size, sparse_dim, nnz, is_uncoalesced=True, device=device,
 | |
|                                          dtype=dtype)[0]
 | |
| 
 | |
|             def fn(v):
 | |
|                 x = torch.sparse_coo_tensor(i, v, size, dtype=dtype, device=device)
 | |
|                 y = (x + other).coalesce()
 | |
|                 yv = y.values()
 | |
|                 new_v = yv.tanh()
 | |
|                 z = torch.sparse_coo_tensor(y.indices(), new_v, y.size())
 | |
|                 return z.coalesce().values()
 | |
| 
 | |
|             gradcheck(fn, (inp,), check_batched_grad=False)
 | |
|             # FIXME: make gradgradcheck work.
 | |
|             # gradgradcheck(fn, (inp,), check_batched_grad=False)
 | |
| 
 | |
|             # assert that _values is non-differentiable
 | |
|             with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"):
 | |
|                 other.detach().requires_grad_()._values().backward(torch.ones_like(other._values()))
 | |
| 
 | |
|         for empty_i, empty_v, empty_nnz in product([True, False], repeat=3):
 | |
|             sparse_size = [] if empty_i else [2, 1]
 | |
|             dense_size = [1, 0, 2] if empty_v else [1, 2]
 | |
|             nnz = 0 if empty_nnz else 5
 | |
|             _test(sparse_size + dense_size, len(sparse_size), nnz, device)
 | |
| 
 | |
|     @skipMeta
 | |
|     @skipIfMps
 | |
|     @dtypes(torch.double, torch.cdouble)
 | |
|     def test_sparse_backward(self, device, dtype):
 | |
|         class FixedGradientFunction(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x, grad_x):
 | |
|                 ctx.save_for_backward(grad_x)
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_x):
 | |
|                 saved_grad_x, = ctx.saved_tensors
 | |
|                 return saved_grad_x, None
 | |
| 
 | |
|         size = torch.Size([6, 3, 2])
 | |
|         i1 = torch.tensor([
 | |
|             [0, 3, 4],
 | |
|             [0, 2, 2],
 | |
|         ], dtype=torch.long)
 | |
|         v1 = make_tensor([3, 2], dtype=dtype, device=device)
 | |
|         sparse_grad1 = torch.sparse_coo_tensor(i1, v1, size, dtype=dtype, device=device)
 | |
|         i2 = torch.tensor([
 | |
|             [0, 1, 3, 4],
 | |
|             [0, 1, 2, 2],
 | |
|         ], dtype=torch.long)
 | |
|         v2 = make_tensor([4, 2], dtype=dtype, device=device)
 | |
|         sparse_grad2 = torch.sparse_coo_tensor(i2, v2, size, dtype=dtype, device=device)
 | |
|         dense_grad = torch.rand(size, device=device, dtype=dtype)
 | |
|         fn = FixedGradientFunction
 | |
| 
 | |
|         # sparse first
 | |
|         x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
 | |
|         (fn.apply(x, sparse_grad1) + fn.apply(x, dense_grad) + fn.apply(x, sparse_grad2)).sum().abs().backward()
 | |
|         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
 | |
|         # dense first
 | |
|         x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
 | |
|         (fn.apply(x, dense_grad) + fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().abs().backward()
 | |
|         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
 | |
|         # sparse only
 | |
|         x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
 | |
|         (fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().abs().backward()
 | |
|         self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
 | |
| 
 | |
|     @skipIfMps
 | |
|     def test_sparse_mask_autograd(self, device):
 | |
|         tensor = torch.randn(3, requires_grad=True, device=device)
 | |
|         mask = torch.ones(3, device=device)
 | |
|         mask[1] = 0
 | |
|         mask = mask.to_sparse()
 | |
|         converted = tensor.sparse_mask(mask).to_dense()
 | |
|         converted.sum().backward()
 | |
|         self.assertEqual(tensor.grad, mask.to_dense())
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work on MPS as double types are not supported
 | |
|     def test_pyscalar_conversions(self, device):
 | |
|         def _test_pyscalar_conversions(t, integral_conv):
 | |
|             # integral -> integral
 | |
|             l = t(torch.zeros(1, 1, 1, dtype=torch.long))
 | |
|             pyscalar = -12345
 | |
|             l[0] = pyscalar
 | |
|             self.assertEqual(integral_conv(l), pyscalar)
 | |
| 
 | |
|             # floating point -> floating point
 | |
|             f = Variable(t(torch.randn(1, 1, dtype=torch.double)))
 | |
|             pyscalar = -12345.1
 | |
|             f[0] = pyscalar
 | |
|             self.assertEqual(float(f), pyscalar)
 | |
|             f[0] = nan
 | |
|             self.assertTrue(math.isnan(float(f)))
 | |
|             f[0] = inf
 | |
|             self.assertEqual(float(f), inf)
 | |
|             f[0] = -inf
 | |
|             self.assertEqual(float(f), -inf)
 | |
| 
 | |
|             # integral -> floating point
 | |
|             # check we can convert something that loses precision
 | |
|             pyscalar = 1234567890123456789
 | |
|             self.assertNotEqual(pyscalar, integral_conv(float(pyscalar)))
 | |
|             l[0] = pyscalar
 | |
|             self.assertEqual(float(l), float(pyscalar))
 | |
| 
 | |
|             # floating point -> integral
 | |
|             f[0] = nan
 | |
|             self.assertRaises(ValueError, lambda: integral_conv(f[0]))
 | |
|             f[0] = inf
 | |
|             self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
 | |
|             f[0] = -inf
 | |
|             self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
 | |
|             f[0] = sys.float_info.max
 | |
|             self.assertEqual(integral_conv(f), sys.float_info.max)
 | |
| 
 | |
|             # bool, nonzero
 | |
|             def test_nonzero(tensor, value, expected):
 | |
|                 tensor[0] = value
 | |
|                 self.assertEqual(expected, bool(tensor))
 | |
|                 self.assertEqual(expected, True if tensor else False)
 | |
| 
 | |
|             test_nonzero(l, 0, False)
 | |
|             test_nonzero(l, -2, True)
 | |
|             test_nonzero(f, 0.0, False)
 | |
|             test_nonzero(f, sys.float_info.min, True)
 | |
|             test_nonzero(f, nan, bool(nan))
 | |
|             test_nonzero(f, inf, bool(inf))
 | |
|             test_nonzero(f, -inf, bool(-inf))
 | |
| 
 | |
| 
 | |
|         _test_pyscalar_conversions(lambda x: x.to(device), lambda x: int(x))
 | |
| 
 | |
|     @dtypesIfMPS(torch.float32)
 | |
|     @dtypesIfCUDA(torch.half, torch.float, torch.double, torch.int8, torch.int16, torch.int32, torch.int64)
 | |
|     @dtypes(torch.float, torch.double, torch.int8, torch.int16, torch.int32, torch.int64)
 | |
|     def test_set_requires_grad_only_for_floats(self, device, dtype):
 | |
|         def f1():
 | |
|             a = torch.ones(1, dtype=dtype, device=device)
 | |
|             a.requires_grad_()
 | |
| 
 | |
|         def f2():
 | |
|             a = torch.ones(1, dtype=dtype, device=device)
 | |
|             a.requires_grad = True
 | |
| 
 | |
|         def f3():
 | |
|             torch.ones(1, dtype=dtype, device=device, requires_grad=True)
 | |
| 
 | |
|         a = torch.ones(1, dtype=dtype, device=device)
 | |
|         a.requires_grad = False  # should always work
 | |
|         a.requires_grad_(False)
 | |
| 
 | |
|         for f in [f1, f2, f3]:
 | |
|             if dtype.is_floating_point:
 | |
|                 f()
 | |
|             else:
 | |
|                 with self.assertRaisesRegex(RuntimeError, 'floating point', msg="dt: {} device: {}".format(a.dtype, a.device)):
 | |
|                     f()
 | |
| 
 | |
|     @onlyCUDA
 | |
|     def test_advanced_indexing_backwards_large(self, device):
 | |
|         # See https://github.com/pytorch/pytorch/issues/22843
 | |
|         n = (1 << 16)
 | |
|         x = torch.rand(n, 1, device=device, requires_grad=True)
 | |
|         a = x[:, [0]]
 | |
|         a.sum().backward()
 | |
|         self.assertEqual(x.grad, torch.ones(n, 1, device=device))
 | |
| 
 | |
|     def test_advanced_indexing_backwards_memory_format(self, device):
 | |
|         # See https://github.com/pytorch/pytorch/issues/36956
 | |
|         shape = (2, 8, 1, 2)
 | |
|         i = torch.randint(1, shape, device=device).contiguous(memory_format=torch.channels_last)
 | |
|         x = torch.randn(shape, requires_grad=True, device=device)
 | |
|         x[i].sum().backward()
 | |
| 
 | |
    def _test_reentrant_parent_error_on_cpu(self, device):
        """Run a backward pass whose CPU branch fails while a reentrant branch is pending.

        Two roots are passed to a single ``torch.autograd.backward`` call: a
        short CPU graph ending in ``TestAutograd.SimulateBackwardError``
        (defined elsewhere; presumably raises "Simulate error" from its
        backward) and a graph on ``device`` whose custom backward re-enters
        autograd on a longer chain. The test asserts that the error is raised
        and that none of the three leaves received a gradient.
        """
        t1 = torch.rand([3, 3], requires_grad=True)
        t2 = torch.rand([3, 3], device=device, requires_grad=True)
        t3 = torch.rand([3, 3], device=device, requires_grad=True)

        # Parent graph cpu graph.
        t4 = t1 * t1
        t5 = TestAutograd.SimulateBackwardError.apply(t4)

        # Child gpu graph (much longer than parent graph).
        prev = t2 * t2
        for i in range(10):
            prev = prev * t2
        reentrant_root = prev

        class ReentrantFunc(Function):
            # Identity-like function whose backward launches a nested backward
            # on ``reentrant_root`` before passing the incoming grad through.
            @staticmethod
            def forward(ctx, inp):
                return inp.clone()

            @staticmethod
            def backward(ctx, grad):
                # Reentrant backward in child will take much longer.
                reentrant_root.backward()
                return grad

        # Parent gpu graph.
        t6 = ReentrantFunc.apply(t3)
        t7 = t6 * t6

        # Parent graph will error out first, while child graph will continue executing.
        with self.assertRaisesRegex(Exception, "Simulate error"):
            torch.autograd.backward([t5.sum(), t7.sum()])

        # No grads should be accumulated since child graph will stop execution
        # after parent receives error.
        self.assertIsNone(t2.grad)
        self.assertIsNone(t1.grad)
        self.assertIsNone(t3.grad)
 | |
| 
 | |
    @onlyCUDA
    def test_reentrant_parent_error_on_cpu(self, device):
        """Check that a failed reentrant backward does not leave CUDA memory allocated.

        Records ``torch.cuda.memory_allocated`` for every device, runs
        ``_test_reentrant_parent_error_on_cpu``, then polls for up to 30
        seconds until the per-device figures return to their initial values.
        """
        def _get_cuda_memory_usage():
            # we don't need CUDA synchronize because the statistics are not tracked at
            # actual freeing, but at when marking the block as free.
            num_devices = torch.cuda.device_count()
            # Collect first so unreachable tensors are released before measuring.
            gc.collect()
            return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices))

        before = _get_cuda_memory_usage()

        # Run as separate function so that gc can clean up everything when we
        # check for memory usage.
        self._test_reentrant_parent_error_on_cpu(device)

        # Wait for autograd thread to cleanup failed tasks.
        after = _get_cuda_memory_usage()
        start = time.time()
        # Poll every 0.1s; give up after 30s and let the assertion below report it.
        while before != after and time.time() - start < 30:
            time.sleep(0.1)
            after = _get_cuda_memory_usage()

        self.assertEqual(before, after)
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work on MPS
 | |
|     # TODO: see if these tests can be ported to OpInfos or moved to where's test suite
 | |
|     def test_where_functional(self, device):
 | |
|         x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
 | |
|         y = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
 | |
|         cond = mask_not_all_zeros((5, 5)).to(device=device)
 | |
| 
 | |
|         def where(cond, x, y):
 | |
|             return torch.where(cond, x, y)
 | |
| 
 | |
|         gradcheck(where, [cond, x, y], raise_exception=True)
 | |
|         gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, device=device)])
 | |
| 
 | |
|         x = torch.randn(5, 1, 5, dtype=torch.double, device=device, requires_grad=True)
 | |
|         y = torch.randn(5, 5, 1, dtype=torch.double, device=device, requires_grad=True)
 | |
|         gradcheck(where, [cond, x, y], raise_exception=True)
 | |
|         gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, 5, device=device)])
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work on MPS
 | |
|     def test_where_scalar(self, device):
 | |
|         x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
 | |
|         scalar = 4.
 | |
|         cond = mask_not_all_zeros((5, 5)).to(device=device)
 | |
| 
 | |
|         def where_scalar_first(cond, x):
 | |
|             return torch.where(cond, scalar, x)
 | |
| 
 | |
|         def where_scalar_second(cond, x):
 | |
|             return torch.where(cond, x, scalar)
 | |
| 
 | |
|         gradcheck(where_scalar_first, (cond, x))
 | |
|         gradgradcheck(where_scalar_first, (cond, x))
 | |
| 
 | |
|         gradcheck(where_scalar_second, (cond, x))
 | |
|         gradgradcheck(where_scalar_second, (cond, x))
 | |
| 
 | |
    @onlyCUDA
    def test_free_unneeded_tensor(self, device):
        """Check that intermediates not needed for backward are not kept allocated.

        Compares ``torch.cuda.memory_allocated()`` before and after evaluating
        ``((x + 2) * m).sum()`` and expects the two readings to be equal.
        """
        x = torch.randn(2, 3, 10, 10, device=device, requires_grad=True)
        m = torch.randn(1, 3, 1, 1, device=device)

        # Bind ``z`` once before taking the baseline; the later rebinding of the
        # same name is what lets the old result be released.
        z = x.sum()
        base_mem = torch.cuda.memory_allocated()
        z = ((x + 2) * m).sum()
        end_mem = torch.cuda.memory_allocated()

        # In the end the memory usage should remain equal, because neither of
        # (x + 2) and ((x + 2) * m) should be kept alive for backward, while the
        # previous allocation of z had the same size as the current one.
        self.assertEqual(base_mem, end_mem)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     def test_pin_memory(self, device):
 | |
|         x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
 | |
|         self.assertEqual(x, x.pin_memory())
 | |
|         self.assertIsNot(x, x.pin_memory())
 | |
|         self.assertTrue(x.pin_memory().requires_grad)
 | |
|         gradcheck(lambda x: x.pin_memory(), [x])
 | |
|         gradgradcheck(lambda x: x.pin_memory(), [x])
 | |
| 
 | |
|     @onlyCUDA
 | |
|     def test_profiler_emit_nvtx(self, device):
 | |
|         # This test is not intended to ensure correctness of nvtx ranges.
 | |
|         # That would require something a great deal more complex (you'd have to create a
 | |
|         # profile in a subprocess, open it, and parse the sql somehow).
 | |
|         # This test is merely intended to catch if emit_nvtx breaks on construction.
 | |
|         a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device)
 | |
|         with torch.cuda.profiler.profile():
 | |
|             with emit_nvtx():
 | |
|                 a.add(1.0)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     def test_rnn_backward_to_input_but_not_parameters(self, device):
 | |
|         # this checks whether it is possible to not require
 | |
|         # weight parameters, but require inputs, see #7722
 | |
|         l = torch.nn.LSTM(2, 3).to(device)
 | |
|         for p in l.parameters():
 | |
|             p.requires_grad = False
 | |
|         s = torch.randn(1, 1, 2, requires_grad=True, device=device)
 | |
|         out, _ = l(s)
 | |
|         out.sum().backward()
 | |
|         self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0)
 | |
| 
 | |
|     @unittest.skipIf(not torch.profiler.itt.is_available(), "ITT is required")
 | |
|     def test_profiler_emit_itt(self, device):
 | |
|         # This test is not intended to ensure correctness of itt ranges.
 | |
|         # That would require something a great deal more complex (you'd have to create a
 | |
|         # profile in a subprocess, open it, and parse the sql somehow).
 | |
|         # This test is merely intended to catch if emit_itt breaks on construction.
 | |
|         a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device)
 | |
|         with emit_itt():
 | |
|             a.add(1.0)
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work as randn is not supported with type long
 | |
|     @deviceCountAtLeast(1)
 | |
|     def test_grad_assignment(self, devices):
 | |
|         x = torch.randn(5, 5, device=devices[0])
 | |
| 
 | |
|         # Tests that the wrong type raises
 | |
|         with self.assertRaisesRegex(TypeError, "expected to be a Tensor or None"):
 | |
|             x.grad = 0
 | |
| 
 | |
|         # Tests that the wrong shape raises
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             x.grad = torch.randn(2, 2, device=devices[0])
 | |
| 
 | |
|         # Tests that the wrong dtype raises
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             x.grad = torch.randn(5, 5, dtype=torch.long, device=devices[0])
 | |
| 
 | |
|         # Tests that self-assignment raises
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             x.grad = x
 | |
| 
 | |
|         # Tests device -> cpu grad assignment raises
 | |
|         if self.device_type != 'cpu':
 | |
|             with self.assertRaises(RuntimeError):
 | |
|                 t_cpu = torch.rand(5, 5)
 | |
|                 t_cpu.grad = torch.randn(5, 5, device=devices[0])
 | |
| 
 | |
|         # Tests half type on CUDA
 | |
|         if self.device_type == 'cuda':
 | |
|             x = x.to(dtype=torch.half, device=devices[0])
 | |
|             x.grad = torch.zeros_like(x)
 | |
| 
 | |
|         # Tests cross-device assignment raises
 | |
|         if len(devices) > 1:
 | |
|             x = torch.randn(5, 5, device=devices[0])
 | |
|             with self.assertRaises(RuntimeError):
 | |
|                 x.grad = torch.randn(5, 5, device=devices[1])
 | |
| 
 | |
|     @dtypesIfMPS(torch.float32)
 | |
|     @deviceCountAtLeast(1)
 | |
|     @dtypes(torch.float, torch.double)
 | |
|     def test_requires_grad_factory(self, devices, dtype):
 | |
|         fns = [torch.ones_like, torch.randn_like]
 | |
|         x = torch.randn(2, 3, dtype=dtype, device=devices[0])
 | |
| 
 | |
|         for fn in fns:
 | |
|             for requires_grad in [True, False]:
 | |
|                 output = fn(x, dtype=dtype, device=devices[0], requires_grad=requires_grad)
 | |
|                 self.assertEqual(requires_grad, output.requires_grad)
 | |
|                 self.assertIs(dtype, output.dtype)
 | |
|                 self.assertEqual(devices[0], str(x.device))
 | |
| 
 | |
|     @deviceCountAtLeast(2)
 | |
|     def test_unused_output_device(self, devices):
 | |
|         from torch.nn.parallel._functions import Broadcast
 | |
|         x = torch.randn(5, 5, dtype=torch.float, device=devices[0], requires_grad=True)
 | |
|         outputs = Broadcast.apply(list(range(len(devices))), x)
 | |
|         y = outputs[-1] * 2
 | |
|         y.sum().backward()
 | |
|         self.assertEqual(x.grad, torch.ones(5, 5) * 2)
 | |
| 
 | |
|     @deviceCountAtLeast(2)
 | |
|     def test_backward_device(self, devices):
 | |
|         # check that current device matches the variable's device
 | |
|         device = [None]
 | |
| 
 | |
|         class Identity(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x.clone()
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad_output):
 | |
|                 device[0] = grad_output.device
 | |
|                 return grad_output.clone()
 | |
| 
 | |
|         v = torch.randn(1, device=devices[1], requires_grad=True)
 | |
|         Identity.apply(v).backward()
 | |
|         self.assertEqual(str(device[0]), devices[1])
 | |
| 
 | |
|     @deviceCountAtLeast(2)
 | |
|     def test_inputbuffer_add_multidevice(self, devices):
 | |
|         input = torch.randn(1, device=devices[0], requires_grad=True)
 | |
|         output = input.to(device=devices[1]) + input.to(device=devices[1])
 | |
|         output.backward()
 | |
| 
 | |
|     @onlyCPU
 | |
|     def test_copy_(self, device):
 | |
|         # At the time of writing this test, copy_ is not generated from native_functions.yaml
 | |
|         # there was a bug that bfloat16 was not recognized as floating.
 | |
|         x = torch.randn(10, device=device, requires_grad=True)
 | |
|         floating_dt = floating_types_and(torch.half, torch.bfloat16)
 | |
|         for dt in floating_dt:
 | |
|             y = torch.empty(10, device=device, dtype=dt)
 | |
|             y.copy_(x)
 | |
|             self.assertTrue(y.requires_grad)
 | |
|             z = x.to(torch.bfloat16)
 | |
|             self.assertTrue(z.requires_grad)
 | |
| 
 | |
|     def test_copy_forward_ad_broadcasting(self, device):
 | |
|         # copy_ allows the src to have a different shape from self as long as src is
 | |
|         # broadcastable to self. Make sure forward AD handles this case.
 | |
|         primal = torch.rand(3, 3, device=device)
 | |
|         tangent = torch.rand(3, 3, device=device)
 | |
|         non_dual = torch.rand(1, 3, 3, device=device)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual = fwAD.make_dual(primal, tangent)
 | |
|             non_dual.copy_(dual)
 | |
| 
 | |
|     def test_copy_forward_ad_same_layout_copies_grad(self, device):
 | |
|         primal = torch.tensor([[3.], [4.]], device=device)
 | |
|         tangent = torch.tensor([[5.], [6.]], device=device)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             x_dual = fwAD.make_dual(primal, tangent)
 | |
|             non_dual = torch.tensor([[1.], [2.]])
 | |
|             non_dual.copy_(x_dual)
 | |
|             self.assertTrue(fwAD.unpack_dual(non_dual).tangent is not tangent)
 | |
| 
 | |
    @onlyCUDA
    def test_simple_reentrant_cross_device(self, device):
        """Run reentrant backward passes whose inner and outer graphs sit on different devices.

        ``ReentrantFunc.backward`` starts a nested backward on a freshly
        created leaf that lives either on the CPU or on ``device``, selected
        by the class attribute ``_cpu_mode``. The test passes if none of the
        three combinations below raises.
        """
        class ReentrantFunc(Function):
            # Selects where the nested backward's leaf is created.
            _cpu_mode = True

            @staticmethod
            def forward(ctx, x):
                return x * (x + 2)

            @staticmethod
            def backward(ctx, grad_output):
                # Grad mode is re-enabled explicitly so the nested graph below
                # is recorded inside backward.
                with torch.enable_grad():
                    if ReentrantFunc._cpu_mode:
                        new_param = torch.randn(2, 2, requires_grad=True)
                        (new_param ** 2).sum().backward()
                    else:
                        new_param = torch.randn(2, 2, device=device, requires_grad=True)
                        (new_param ** 2).sum().backward()
                return grad_output

        # Reentrant starts on GPU thread, finishes on GPU thread
        x = torch.randn(2, 2, device=device, requires_grad=True)
        out = ReentrantFunc.apply(x)
        out.sum().backward()

        # Reentrant starts on CPU thread, finishes on GPU thread
        x = torch.randn(2, 2, requires_grad=True)
        # set ReentrantFunc node to GPU to emit tasks to GPU queue
        ReentrantFunc._cpu_mode = False
        out = ReentrantFunc.apply(x)
        out.sum().backward()

        # Reentrant starts on GPU thread, finishes on CPU thread
        x = torch.randn(2, 2, device=device, requires_grad=True)
        # set ReentrantFunc node to CPU to emit tasks to CPU queue
        ReentrantFunc._cpu_mode = True
        out = ReentrantFunc.apply(x)
        out.sum().backward()
 | |
| 
 | |
    @onlyCUDA
    def test_cross_device_reentrant_autograd(self, device):
        """Backward through a reentrant checkpoint that crosses from CPU to ``device``.

        A CPU leaf feeds two branches: a chain of divisions on ``device`` and a
        reentrant ``checkpoint`` of a function that moves its input to
        ``device``. The test passes if the final backward completes.
        """
        # Output on gpu so that this task will be associated with the gpu thread
        def fn_on_gpu(inp):
            # Artificially increase the priority of the next op to make sure it runs
            # as soon as we reach it before the ops of branch1.
            # NOTE(review): ``dummy`` is never read afterwards; only the ops
            # that produce it matter here.
            dummy = inp * 2 * 2 * 2 * 2
            return inp.to(device=device)

        def parent_on_cpu(inp):
            # Slow branch of ops on gpu so that the work queue for the gpu thread
            # won't empty too quickly. They also have smaller priorities than the
            # ones created by fn_on_gpu
            branch1 = inp.to(device=device)
            branch1 = branch1 / branch1
            branch1 = branch1 / branch1
            branch1 = branch1 / branch1
            # Perform checkpoint on cpu tensors. So the last op performed in the reentrant
            # autograd is an AccumulateGrad that runs on the cpu thread for the gpu thread.
            # So the cpu thread will notify the gpu thread with an empty NodeTask.
            branch2 = checkpoint(fn_on_gpu, inp, use_reentrant=True)
            out = branch2 + branch1
            return out

        inp = torch.rand(2, requires_grad=True)
        out = parent_on_cpu(inp)
        # This will segfault if the empty NodeTask is not handled properly in the
        # gpu thread ReadyQueue
        out.sum().backward()
 | |
| 
 | |
|     def test_inplace_on_view_backprop_base(self, device):
 | |
|         # modify view and back-prop through base
 | |
|         root = torch.randn(2, 2, device=device, requires_grad=True)
 | |
|         x = root.clone()
 | |
|         v1 = x.narrow(0, 0, 1)
 | |
|         v1.mul_(2)
 | |
|         x.sum().backward()
 | |
|         self.assertEqual(root.grad.tolist(), [[2, 2], [1, 1]])
 | |
| 
 | |
|     def test_inplace_on_view_backprop_view_of_view(self, device):
 | |
|         # modify view and backprop through view-of-view
 | |
|         root = torch.randn(2, 2, device=device, requires_grad=True)
 | |
|         x = root.clone()
 | |
|         v1 = x.narrow(0, 0, 1)
 | |
|         v2 = x.narrow(0, 0, 1)
 | |
|         v1.mul_(2)
 | |
|         v2.sum().backward()
 | |
|         self.assertEqual(root.grad.tolist(), [[2, 2], [0, 0]])
 | |
| 
 | |
|     def test_inplace_on_view_of_view(self, device):
 | |
|         # modify view-of-view and backprop through base
 | |
|         root = torch.randn(2, 2, device=device, requires_grad=True)
 | |
|         x = root.clone()
 | |
| 
 | |
|         v1 = x.narrow(0, 0, 1)
 | |
|         v2 = v1.narrow(1, 1, 1)
 | |
|         v2.mul_(2)
 | |
|         x.sum().backward()
 | |
|         self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1]])
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work on MPS as double types are not supported
 | |
|     def test_inplace_on_view_then_no_grad(self, device):
 | |
|         # Perform an in-place operation on a view of a non-leaf variable.
 | |
|         a = torch.ones(3, 1, dtype=torch.double, device=device, requires_grad=True)
 | |
|         b = a * 2
 | |
|         c = b.view_as(b)
 | |
|         c[0][0] = 3
 | |
| 
 | |
|         # Force a graph update with grad disabled.
 | |
|         with torch.no_grad():
 | |
|             c.grad_fn
 | |
| 
 | |
|         c.sum().backward()
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work on MPS as double types are not supported
 | |
|     def test_inplace_on_view_gradcheck(self, device):
 | |
|         # gradcheck modifications to views
 | |
|         a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
 | |
|         b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
 | |
| 
 | |
|         def func(root, b):
 | |
|             x = root.clone()
 | |
|             x.narrow(1, 2, 2).narrow(0, 1, 2).mul_(b)
 | |
|             x.narrow(1, 0, 2).narrow(0, 1, 2).mul_(b)
 | |
|             return x
 | |
| 
 | |
|         gradcheck(func, [a, b], raise_exception=True)
 | |
|         go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
 | |
|         gradgradcheck(func, (a, b), (go,))
 | |
| 
 | |
|     def test_inplace_on_view_multiple_outputs(self, device):
 | |
|         root = torch.arange(9., dtype=torch.double).reshape(3, 3).requires_grad_()
 | |
|         x = root.clone()
 | |
|         v1 = x.unbind()
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             v1[0].mul_(2)
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work on MPS as double types are not supported
 | |
|     def test_inplace_on_view_of_multiple_output_view(self, device):
 | |
|         a = torch.rand(10, dtype=torch.double, device=device, requires_grad=True).clone()
 | |
|         b = a.unbind(0)
 | |
|         c = b[0].view_as(b[0])
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             c.mul_(2)
 | |
| 
 | |
|     @skipIfMps  # MPS backend doesn't support double types
 | |
|     def test_inplace_multiple_output_view_of_view(self, device):
 | |
|         a = torch.rand(10, dtype=torch.double, device=device, requires_grad=True).clone()
 | |
|         b = a.view_as(a)
 | |
|         c = b.unbind(0)
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             c[0].mul_(2)
 | |
| 
 | |
|     @skipIfMps  # MPS backend doesn't support double types
 | |
|     def test_inplace_on_view_makes_base_require_grad(self, device):
 | |
|         # in-place modification to view makes base require grad
 | |
|         a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=False)
 | |
|         b = torch.randn(4, 2, dtype=torch.double, device=device, requires_grad=True)
 | |
| 
 | |
|         def func(root, b):
 | |
|             x = root.clone()
 | |
|             self.assertFalse(x.requires_grad)
 | |
|             x.narrow(1, 2, 2).mul_(b)
 | |
|             self.assertTrue(x.requires_grad)
 | |
|             return x
 | |
| 
 | |
|         gradcheck(func, [a, b], raise_exception=True)
 | |
|         go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
 | |
|         gradgradcheck(func, (a, b), (go,))
 | |
| 
 | |
|     def test_inplace_on_view_backprop_view(self, device):
 | |
|         # modify view and backprop through view
 | |
|         a = torch.tensor([2., 5.], device=device, requires_grad=False)
 | |
|         b = torch.tensor([3.], device=device, requires_grad=True)
 | |
|         res = a.narrow(0, 1, 1).mul_(b)
 | |
|         res.sum().backward()
 | |
|         self.assertEqual(b.grad.tolist(), [5])
 | |
|         self.assertIsNone(a.grad)
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work on MPS as double types are not supported
 | |
|     def test_inplace_on_view_modify_base(self, device):
 | |
|         # Test that an in-place operation on a base that forced it to require
 | |
|         # grad also forces any previous views to require grad and backprop
 | |
|         # correctly
 | |
|         r = torch.ones(1, dtype=torch.double, device=device, requires_grad=True)
 | |
| 
 | |
|         def fn(r):
 | |
|             x = torch.ones(5, dtype=torch.double, device=device)
 | |
|             v = x.select(0, 1)
 | |
|             self.assertFalse(v.requires_grad)
 | |
|             self.assertIsNone(v.grad_fn)
 | |
|             x.add_(r)  # v is now dependent on r due to the in-place op on x
 | |
|             self.assertTrue(v.requires_grad)
 | |
|             return v
 | |
| 
 | |
|         gradcheck(fn, [r])
 | |
|         gradgradcheck(fn, [r])
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work on MPS as double types are not supported
 | |
|     def test_inplace_on_view_python(self, device):
 | |
|         # in-place modifications of Python-autograd created view
 | |
|         a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
 | |
|         b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
 | |
| 
 | |
|         class PyAdd(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x, y):
 | |
|                 ctx.mark_dirty(x)
 | |
|                 x.add_(y)
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, grad):
 | |
|                 return grad, grad
 | |
| 
 | |
|         def func(root, b):
 | |
|             x = root.clone()
 | |
|             PyAdd.apply(x.narrow(1, 2, 2).narrow(0, 1, 2), b)
 | |
|             PyAdd.apply(x.narrow(1, 0, 2).narrow(0, 1, 2), b)
 | |
|             return x
 | |
| 
 | |
|         gradcheck(func, [a, b], raise_exception=True)
 | |
|         go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
 | |
|         gradgradcheck(func, (a, b), (go,))
 | |
| 
 | |
|     def test_inplace_on_view_non_contig(self, device):
 | |
|         root = torch.ones(2, 3, 2, device=device).select(2, 1).t().requires_grad_(True)
 | |
|         x = root.clone()
 | |
|         v1 = x.narrow(0, 0, 1)
 | |
|         v2 = v1.narrow(1, 1, 1)
 | |
|         v2.mul_(2)
 | |
|         x.sum().backward()
 | |
|         self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1], [1, 1]])
 | |
| 
 | |
|     def test_inplace_on_view_multi_output_unsafe(self, device):
 | |
|         for f in [lambda t: t.unsafe_split(1),
 | |
|                   lambda t: t.unsafe_split_with_sizes((1, 1, 1)),
 | |
|                   lambda t: t.unsafe_chunk(3)]:
 | |
|             a = torch.randn(3, 3, device=device, requires_grad=True)
 | |
|             b = a + a
 | |
|             s1, s2, s3 = f(b)
 | |
|             s1.mul_(s2)
 | |
|             s1.sum().backward()
 | |
| 
 | |
|     def test_inplace_on_view_multi_output_safe(self, device):
 | |
|         for f in [lambda t: t.split(1),
 | |
|                   lambda t: t.split_with_sizes((1, 1, 1)),
 | |
|                   lambda t: t.chunk(3)]:
 | |
|             a = torch.randn(3, 3, device=device, requires_grad=True)
 | |
|             b = a + a
 | |
|             s1, s2, s3 = f(b)
 | |
|             error_msg = 'This view is the output of a function that returns multiple views.'
 | |
|             with self.assertRaisesRegex(RuntimeError, error_msg):
 | |
|                 s1.mul_(s2)
 | |
| 
 | |
|     @skipIfMps  # the test doesn't work on MPS as double types are not supported
 | |
|     def test_mv_grad_stride_0(self, device):
 | |
|         # Reference: https://github.com/pytorch/pytorch/issues/38315
 | |
|         mat = torch.randn(2, 2, dtype=torch.double, device=device)
 | |
|         vec = torch.randn(1, dtype=torch.double, device=device).requires_grad_(True)
 | |
| 
 | |
|         def fn(vec):
 | |
|             # Expand inside the function to make sure the input to
 | |
|             # gradcheck does not have overlapping memory
 | |
|             vec = vec.expand(2)
 | |
|             return (mat @ vec).sum()
 | |
| 
 | |
|         gradcheck(fn, (vec))
 | |
|         gradgradcheck(fn, (vec))
 | |
| 
 | |
|     @onlyCUDA
 | |
|     def test_gradcheck_input_output_different_device(self, device):
 | |
|         x = torch.ones((1,), dtype=torch.double, device="cuda", requires_grad=True)
 | |
|         gradcheck(lambda x: x.to("cpu"), (x,))
 | |
| 
 | |
|         x = torch.ones((1,), dtype=torch.double, device="cpu", requires_grad=True)
 | |
|         gradcheck(lambda x: x.to("cuda"), (x,))
 | |
| 
 | |
|     def test_strided_leaf_grad_layout(self, device):
 | |
|         # (1) If leaf is non-overlapping and dense, grad's layout should match its leaf.
 | |
|         for fmt_a in (torch.contiguous_format, torch.channels_last):
 | |
|             for fmt_b in (torch.contiguous_format, torch.channels_last):
 | |
|                 a = torch.rand((2, 3, 4, 5), device=device).to(memory_format=fmt_a)
 | |
|                 b = torch.rand((2, 3, 4, 5), device=device).to(memory_format=fmt_b)
 | |
|                 a.requires_grad_()
 | |
|                 b.requires_grad_()
 | |
|                 # checks (1) for broadcasted gradients
 | |
|                 a.sum().backward()
 | |
|                 self.assertEqual(a.grad.stride(), a.stride())
 | |
|                 b.sum().backward()
 | |
|                 self.assertEqual(b.grad.stride(), b.stride())
 | |
|                 # checks (1) for non-broadcasted gradients
 | |
|                 a.grad = None
 | |
|                 b.grad = None
 | |
|                 (a * b).sum().backward()
 | |
|                 self.assertEqual(a.grad.stride(), a.stride())
 | |
|                 self.assertEqual(b.grad.stride(), b.stride())
 | |
| 
 | |
|         # (2) If leaf isn't dense, checks that grads are rowmajor contiguous.
 | |
|         c = torch.empty_strided((2, 2), (4, 2), device=device).copy_(torch.rand((2, 2), device=device))
 | |
|         c.requires_grad_()
 | |
|         d = torch.rand((2, 2), device=device)
 | |
|         # checks (2) for broadcasted gradients
 | |
|         c.sum().backward()
 | |
|         self.assertEqual(c.grad.stride(), (2, 1))
 | |
|         # checks (2) for non-broadcasted gradients
 | |
|         c.grad = None
 | |
|         (c * d).sum().backward()
 | |
|         self.assertEqual(c.grad.stride(), (2, 1))
 | |
| 
 | |
|     @skipIfMps
 | |
|     def test_copy_r_to_c(self, device):
 | |
|         out_c = torch.empty(3, 2, dtype=torch.cdouble, device=device)
 | |
|         inp_r = torch.randn(3, 2, dtype=torch.double, device=device,
 | |
|                             requires_grad=True)
 | |
| 
 | |
|         def do_test():
 | |
|             out_c.copy_(inp_r)
 | |
|             out_c_inter = out_c.sum()
 | |
|             out_c_inter.abs().backward()
 | |
|             with torch.no_grad():
 | |
|                 self.assertEqual(inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_c_inter).real)
 | |
| 
 | |
|         self.assertNotWarn(do_test)
 | |
| 
 | |
|     def test_to_r_to_c(self, device):
 | |
|         def do_test():
 | |
|             inp_r = torch.randn(3, 2, dtype=torch.double, device=device,
 | |
|                                 requires_grad=True)
 | |
|             out = inp_r.to(torch.complex128)
 | |
|             out_inter = out.sum()
 | |
|             out_inter.abs().backward()
 | |
|             with torch.no_grad():
 | |
|                 self.assertEqual(inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_inter).real)
 | |
| 
 | |
|         self.assertNotWarn(do_test)
 | |
| 
 | |
|     def test_non_differentiable_ops(self, device):
 | |
|         # Just make sure the op doesn't raise an error
 | |
|         # and resulting tensor has requires_grad=False.
 | |
|         x = torch.tensor([[1, 2], [3, 4.]], requires_grad=True, device=device)
 | |
|         out = torch.isin(x, torch.tensor([2, 3], device=device))
 | |
|         self.assertFalse(out.requires_grad)
 | |
| 
 | |
|         x = torch.randn(3, 3, requires_grad=True)
 | |
|         out = torch.signbit(x)
 | |
|         self.assertFalse(out.requires_grad)
 | |
| 
 | |
|     def test_warning_in_backward(self, device):
 | |
|         # Test warning during backward are always propagated as python warnings (gh-50209)
 | |
|         # NOTE: For device=cuda, warning gets propagated from a worker thread
 | |
|         a = torch.zeros((), device=device, requires_grad=True)
 | |
|         b = torch._C._nn._test_warn_in_autograd(a)
 | |
| 
 | |
|         with self.assertWarnsRegex(UserWarning, "Warn from backward"):
 | |
|             b.backward()
 | |
| 
 | |
|     def test_complex_scalar_backward(self, device):
 | |
|         a = torch.zeros(1, device=device, requires_grad=True)
 | |
|         b = a * 0.5j
 | |
| 
 | |
|         msg = "grad can be implicitly created only for real scalar outputs"
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             b.backward()
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             torch.autograd.grad(b, a)
 | |
| 
 | |
|     def test_pow_real_negative_base_complex_exponent(self, device):
 | |
|         # OpInfo doesn't naturally support input of mixed types, hence this test here.
 | |
|         base = -torch.ones(2, device=device, dtype=torch.double)
 | |
|         exponent = torch.randn(2, device=device, dtype=torch.cdouble, requires_grad=True)
 | |
| 
 | |
|         def fn(exponent):
 | |
|             return torch.pow(base, exponent)
 | |
| 
 | |
|         torch.autograd.gradcheck(fn, (exponent,))
 | |
| 
 | |
|         def fn(exponent):
 | |
|             return torch.pow(-1, exponent)
 | |
| 
 | |
|         torch.autograd.gradcheck(fn, (exponent,))
 | |
| 
 | |
|     def test_resize_version_bump(self, device):
 | |
|         x = torch.rand((1,), device=device)
 | |
|         y = torch.randn((3,), device=device)
 | |
|         x.resize_((1, 2))
 | |
|         self.assertEqual(x._version, 1)
 | |
|         x.resize_as_(y)
 | |
|         self.assertEqual(x._version, 2)
 | |
| 
 | |
|         # In the following cases, `resize` is no-op,
 | |
|         # so no version bumps.
 | |
|         x.resize_((3,))
 | |
|         self.assertEqual(x._version, 2)
 | |
| 
 | |
|         x.resize_as_(y)
 | |
|         self.assertEqual(x._version, 2)
 | |
| 
 | |
| 
 | |
| class TestAllowMutationOnSaved(TestCase):
 | |
|     def assertClonedLenEqual(self, ctx, n):
 | |
|         self.assertEqual(len(list(ctx.cloned.items())), n)
 | |
| 
 | |
|     def assertTIDMapLenEqual(self, ctx, n):
 | |
|         self.assertEqual(len(list(ctx.tid_to_weakhandle.items())), n)
 | |
| 
 | |
|     def test_basic(self):
 | |
|         a = torch.rand(2, 3, requires_grad=True)
 | |
| 
 | |
|         def fn(a):
 | |
|             b = a.clone()
 | |
|             out = (b**2).sum()
 | |
|             b.sin_()
 | |
|             out.sum().backward()
 | |
|             return a.grad
 | |
|         msg = "variables needed for gradient computation has been modified by an inplace"
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             fn(a)
 | |
| 
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             da = fn(a)
 | |
| 
 | |
|         self.assertTrue(torch.allclose(a * 2, da))
 | |
|         self.assertClonedLenEqual(ctx, 0)
 | |
| 
 | |
|     def test_views(self):
 | |
|         a = torch.rand(2, 3, requires_grad=True)
 | |
| 
 | |
|         def fn(a):
 | |
|             b = a.clone()
 | |
|             c = b.view_as(b)
 | |
|             out = (b**2).sum()  # How does this work?
 | |
|             c.sin_()
 | |
|             out.sum().backward()
 | |
|             return a.grad
 | |
| 
 | |
|         msg = "variables needed for gradient computation has been modified by an inplace"
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             fn(a)
 | |
| 
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             da = fn(a)
 | |
| 
 | |
|         self.assertClonedLenEqual(ctx, 0)
 | |
|         self.assertTrue(torch.allclose(a * 2, da))
 | |
| 
 | |
|     def test_save_base_and_modify_view(self):
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             a = torch.rand(2, 3, requires_grad=True)
 | |
|             b = a.clone()
 | |
|             c = b[:1]
 | |
|             out = b**2
 | |
|             # modify the view
 | |
|             c *= 10
 | |
|             # self.assertClonedLenEqual(ctx, 1)
 | |
|             out.sum().backward()
 | |
|             self.assertClonedLenEqual(ctx, 0)
 | |
| 
 | |
|         self.assertClonedLenEqual(ctx, 0)
 | |
|         self.assertTrue(torch.allclose(a * 2, a.grad))
 | |
| 
 | |
|     def test_save_view_modify_base(self):
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             a = torch.rand(2, 3, requires_grad=True)
 | |
|             b = a.clone()
 | |
|             c = b[:]
 | |
|             out = (c**2).sum()
 | |
|             b *= 2
 | |
|             out.backward()
 | |
|             self.assertTrue(torch.allclose(a * 2, a.grad))
 | |
| 
 | |
|     def test_double_backward(self):
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             a = torch.rand(2, 3, requires_grad=True)
 | |
|             b = a.clone()
 | |
|             out = (b**2).sum()
 | |
|             b.sin_()
 | |
|             torch.autograd.grad(out, a, create_graph=True)
 | |
|             da, = torch.autograd.grad(out, a, create_graph=True)
 | |
|             d2a, = torch.autograd.grad(da.sum(), a)
 | |
| 
 | |
|         self.assertTrue(torch.allclose(torch.ones_like(a) * 2, d2a))
 | |
|         self.assertClonedLenEqual(ctx, 0)
 | |
| 
 | |
|     def test_saved_but_not_anymore(self):
 | |
|         # Make sure we don't clone if the tensor was once saved, but
 | |
|         # by the time we do in-place, it is no longer saved
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             a = torch.randn(2, 3, requires_grad=True).clone()
 | |
|             out = (a**2).sum()
 | |
|             self.assertTIDMapLenEqual(ctx, 1)
 | |
|             self.assertClonedLenEqual(ctx, 0)
 | |
|             out.backward()
 | |
|             a.sin_()
 | |
|             self.assertClonedLenEqual(ctx, 0)
 | |
|             out = (a**2).sum()
 | |
|             a.sin_()
 | |
|             self.assertClonedLenEqual(ctx, 1)
 | |
|             del out
 | |
|             self.assertClonedLenEqual(ctx, 0)
 | |
| 
 | |
|     def test_saved_same_tensor_many_times(self):
 | |
|         # We should only clone once
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             a = torch.randn(2, 3, requires_grad=True).clone()
 | |
|             b = a**2
 | |
|             c = a**2
 | |
|             a.sin_()
 | |
|             self.assertClonedLenEqual(ctx, 1)
 | |
|             del b, c
 | |
|             self.assertClonedLenEqual(ctx, 0)
 | |
| 
 | |
|     def test_saved_same_tensor_different_versions(self):
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             a = torch.randn(2, 3, requires_grad=True).clone()
 | |
|             b = a**2
 | |
|             a.sin_()
 | |
|             c = a**2
 | |
|             a.sin_()
 | |
|             self.assertClonedLenEqual(ctx, 2)
 | |
|             del b
 | |
|             self.assertClonedLenEqual(ctx, 1)
 | |
|             del c
 | |
|             self.assertClonedLenEqual(ctx, 0)
 | |
| 
 | |
|     def test_with_math_views(self):
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             a = torch.tensor([1 + 1j], requires_grad=True).clone()
 | |
|             b = a.conj()
 | |
|             out = (b**2).sum()
 | |
|             a.sin_()
 | |
|             out.abs().backward()
 | |
| 
 | |
|             a = torch.tensor([1 + 1j], requires_grad=True).clone()
 | |
|             b = a.conj()
 | |
|             out = (b**2).sum()
 | |
|             # in this case, it is no longer a view it seems
 | |
|             b.sin_()
 | |
|             out.abs().backward()
 | |
| 
 | |
|     def test_with_out_variant(self):
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             a = torch.tensor([1.], requires_grad=True)
 | |
|             b = torch.tensor([1.])
 | |
|             c = torch.tensor([2.])
 | |
|             out = a * b
 | |
|             self.assertTIDMapLenEqual(ctx, 1)
 | |
|             torch.sin(c, out=b)
 | |
|             self.assertClonedLenEqual(ctx, 1)
 | |
|             out.backward()
 | |
|             self.assertClonedLenEqual(ctx, 0)
 | |
| 
 | |
|     def test_backward_out_of_context(self):
 | |
|         # Out of context
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             a = torch.rand(2, 3, requires_grad=True)
 | |
|             out = (a**2).sum()
 | |
| 
 | |
|         msg = "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context"
 | |
|         with self.assertRaisesRegex(AssertionError, msg):
 | |
|             out.backward()
 | |
| 
 | |
|         # Different context
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             a = torch.rand(2, 3, requires_grad=True)
 | |
|             out = (a**2).sum()
 | |
| 
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             with self.assertRaisesRegex(AssertionError, msg):
 | |
|                 out.backward()
 | |
| 
 | |
|     def test_disallow_nesting(self):
 | |
|         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|             msg = "allow_mutation_on_saved_tensors contexts cannot be nested"
 | |
|             with self.assertRaisesRegex(RuntimeError, msg):
 | |
|                 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
 | |
|                     pass
 | |
| 
 | |
| class TestAutogradInferenceMode(TestCase):
 | |
|     def _is_inference_tensor(self, tensor):
 | |
|         try:
 | |
|             err_msg = "Inference tensors do not track version counter"
 | |
|             with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                 tensor._version
 | |
|             return True
 | |
|         except AssertionError as e:
 | |
|             return False
 | |
| 
 | |
|     def test_inference_mode_context_manager(self):
 | |
|         self.assertFalse(torch.is_inference_mode_enabled())
 | |
|         with torch.inference_mode():
 | |
|             self.assertTrue(torch.is_inference_mode_enabled())
 | |
|             with torch.inference_mode(False):
 | |
|                 self.assertFalse(torch.is_inference_mode_enabled())
 | |
|             self.assertTrue(torch.is_inference_mode_enabled())
 | |
|         self.assertFalse(torch.is_inference_mode_enabled())
 | |
| 
 | |
|     def test_inference_mode_decorator(self):
 | |
|         for mode in (True, False):
 | |
|             @torch.inference_mode(mode)
 | |
|             def func(x):
 | |
|                 self.assertEqual(torch.is_inference_mode_enabled(), mode)
 | |
|                 return x * x
 | |
| 
 | |
|             for requires_grad in (True, False):
 | |
|                 c = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
|                 d = func(c)
 | |
|                 self.assertTrue(not mode or torch.is_inference(d))
 | |
|                 self.assertEqual(d.requires_grad, requires_grad and not mode)
 | |
| 
 | |
|     def test_inference_mode_tensor_creation(self):
 | |
|         with torch.inference_mode():
 | |
|             # new tensors created through constructors are inference tensors
 | |
|             c = torch.ones(1, 2, 3)
 | |
|             self.assertFalse(c.requires_grad)
 | |
|             self.assertTrue(torch.is_inference(c))
 | |
| 
 | |
|             # requires_grad doesn't change inference tensor behavior in InferenceMode
 | |
|             tmp = torch.ones(1, 2, 3, requires_grad=True)
 | |
|             self.assertTrue(tmp.requires_grad)
 | |
|             self.assertTrue(torch.is_inference(tmp))
 | |
| 
 | |
|             tmp = torch.ones(1, 2, 3).requires_grad_(False)
 | |
|             self.assertFalse(tmp.requires_grad)
 | |
|             self.assertTrue(torch.is_inference(tmp))
 | |
| 
 | |
|     def test_inference_mode_existing_autograd_session(self):
 | |
|         s = torch.ones(1, 2, 3, requires_grad=True)
 | |
|         a = s.clone()
 | |
| 
 | |
|         # `a` gets saved outside of inference mode
 | |
|         out = a * a
 | |
|         with torch.inference_mode():
 | |
|             a.add_(2)
 | |
| 
 | |
|         self.assertFalse(torch.is_inference(a))
 | |
|         # tensors created outside of inference mode aren't
 | |
|         # inference tensors, so they will still have their
 | |
|         # version counters tracked
 | |
|         err_msg = ("one of the variables needed for gradient computation has been "
 | |
|                    "modified by an inplace operation")
 | |
|         with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|             out.backward(torch.ones_like(out))
 | |
| 
 | |
|     def test_inference_mode_inf_tensor_in_inf_mode_functional_op(self):
 | |
|         def functional_op(x):
 | |
|             return x * x
 | |
| 
 | |
|         with torch.inference_mode():
 | |
|             for requires_grad in (True, False):
 | |
|                 c = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
| 
 | |
|                 # performing a non-view operation produces a inference tensor
 | |
|                 # that does not require grad
 | |
|                 func_out = functional_op(c)
 | |
|                 self.assertTrue(torch.is_inference(func_out))
 | |
|                 self.assertFalse(func_out.requires_grad)
 | |
| 
 | |
|     def test_inference_mode_inf_tensor_in_inf_mode_inplace_op(self):
 | |
|         @torch.inference_mode()
 | |
|         def run_test(fn):
 | |
|             for requires_grad in (True, False):
 | |
|                 c = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
| 
 | |
|                 # after performing inplace operation, tensor is still
 | |
|                 # an inference tensor
 | |
|                 fn(c)
 | |
|                 self.assertTrue(torch.is_inference(c))
 | |
|                 self.assertEqual(c.requires_grad, requires_grad)
 | |
|         run_test(lambda x: x.add_(2))
 | |
|         run_test(lambda x: x.transpose_(0, 1))
 | |
| 
 | |
|         # inplace ops with manual kernel for ADInplaceOrView key in VariableTypeManual.cpp
 | |
|         run_test(lambda x: x.resize_(1, 2))
 | |
|         run_test(lambda x: x.resize_as_(torch.ones(1, 2)))
 | |
|         run_test(lambda x: x.copy_(torch.ones(1, 2, 3)))
 | |
| 
 | |
|     def test_inference_mode_inf_tensor_in_inf_mode_view_op(self):
 | |
|         with torch.inference_mode():
 | |
|             for requires_grad in (True, False):
 | |
|                 c = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
| 
 | |
|                 # perform view operation produces inference tensor
 | |
|                 # that does not require grad
 | |
|                 view_out = c.view(-1)
 | |
|                 self.assertTrue(torch.is_inference(view_out))
 | |
|                 self.assertFalse(view_out.requires_grad)
 | |
| 
 | |
|     def test_inference_mode_inf_tensor_in_normal_mode_functional_op(self):
 | |
|         def functional_op(x):
 | |
|             return x * x
 | |
| 
 | |
|         for requires_grad in (True, False):
 | |
|             with torch.inference_mode():
 | |
|                 c = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
| 
 | |
|         func_out = functional_op(c)
 | |
|         self.assertFalse(torch.is_inference(func_out))
 | |
|         self.assertFalse(func_out.requires_grad)
 | |
|         self.assertTrue(func_out.is_leaf)
 | |
| 
 | |
|     def test_inference_mode_inf_tensor_in_normal_mode_inplace_op(self):
 | |
|         def run_test(fn):
 | |
|             for requires_grad in (False, True):
 | |
|                 with torch.inference_mode():
 | |
|                     c = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
| 
 | |
|                 if requires_grad:
 | |
|                     # leaf variable that requires grad is being used in an inplace
 | |
|                     # operation when requires_grad=True
 | |
|                     pass
 | |
|                 else:
 | |
|                     err_msg = "Inplace update to inference tensor outside InferenceMode"
 | |
|                     with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                         fn(c)
 | |
|         run_test(lambda x: x.add_(2))
 | |
|         run_test(lambda x: x.transpose_(0, 1))
 | |
| 
 | |
|     def test_inference_mode_inf_tensor_in_normal_mode_view_op(self):
 | |
|         for requires_grad in (True, False):
 | |
|             with torch.inference_mode():
 | |
|                 c = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
| 
 | |
|             out = c.view(-1)
 | |
|             self.assertTrue(torch.is_inference(out))
 | |
|             self.assertFalse(out.requires_grad)
 | |
|             self.assertFalse(out._is_view())
 | |
|             self.assertTrue(out.is_leaf)
 | |
| 
 | |
|     def test_normal_tensor_inplace_output_in_inference_mode(self):
 | |
|         def run_test(fn):
 | |
|             for requires_grad in (True, False):
 | |
|                 s = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
|                 a = s.clone()
 | |
| 
 | |
|                 with torch.inference_mode():
 | |
|                     fn(a)
 | |
|                     self.assertFalse(torch.is_inference(a))
 | |
|                     self.assertEqual(a.requires_grad, requires_grad)
 | |
| 
 | |
|                     # inplace -> inplace
 | |
|                     fn(a)
 | |
|                     self.assertFalse(torch.is_inference(a))
 | |
|                     self.assertEqual(a.requires_grad, requires_grad)
 | |
| 
 | |
|                     # inplace -> inplace -> view
 | |
|                     view_out = a.view(-1)
 | |
|                     self.assertFalse(torch.is_inference(view_out))
 | |
|                     self.assertEqual(view_out.requires_grad, requires_grad)
 | |
|         run_test(lambda x: x.add_(2))
 | |
|         run_test(lambda x: x.transpose_(0, 1))
 | |
| 
 | |
|     def test_normal_tensor_inplace_output_in_normal_mode(self):
 | |
|         def run_test(fn):
 | |
|             for requires_grad in (True, False):
 | |
|                 s = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
|                 a = s.clone()
 | |
| 
 | |
|                 with torch.inference_mode():
 | |
|                     fn(a)
 | |
|                     self.assertFalse(torch.is_inference(a))
 | |
|                     self.assertEqual(a.requires_grad, requires_grad)
 | |
| 
 | |
|                 fn(a)
 | |
|                 self.assertFalse(torch.is_inference(a))
 | |
|                 self.assertEqual(a.requires_grad, requires_grad)
 | |
| 
 | |
|                 # inplace -> inplace
 | |
|                 fn(a)
 | |
|                 self.assertFalse(torch.is_inference(a))
 | |
|                 self.assertEqual(a.requires_grad, requires_grad)
 | |
| 
 | |
|                 # inplace -> inplace -> view
 | |
|                 view_out = a.view(-1)
 | |
|                 self.assertFalse(torch.is_inference(view_out))
 | |
|                 self.assertEqual(view_out.requires_grad, requires_grad)
 | |
|             run_test(lambda x: x.add_(2))
 | |
|             run_test(lambda x: x.transpose_(0, 1))
 | |
| 
 | |
|     def test_normal_tensor_view_output_in_inference_mode(self):
 | |
|         for requires_grad in (True, False):
 | |
|             s = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
|             a = s.clone()
 | |
| 
 | |
|             with torch.inference_mode():
 | |
|                 out = a.view(-1)
 | |
|                 self.assertFalse(torch.is_inference(out))
 | |
|                 self.assertEqual(out.requires_grad, requires_grad)
 | |
|                 self.assertTrue(out._is_view())
 | |
| 
 | |
|                 # view -> view
 | |
|                 tmp = out.view(-1)
 | |
|                 self.assertFalse(torch.is_inference(tmp))
 | |
|                 self.assertEqual(tmp.requires_grad, requires_grad)
 | |
|                 self.assertTrue(tmp._is_view())
 | |
|                 self.assertTrue(tmp.is_leaf)
 | |
| 
 | |
|                 # view -> view -> inplace
 | |
|                 self.assertTrue(torch.is_inference_mode_enabled())
 | |
|                 tmp.add_(2)
 | |
|                 self.assertFalse(torch.is_inference(tmp))
 | |
|                 self.assertEqual(tmp.requires_grad, requires_grad)
 | |
|                 # Accessing is_leaf in python tries to update grad_fn and raises:
 | |
|                 # A view was created in inference mode and its base or
 | |
|                 # another view of its base has been modified inplace in normal mode
 | |
|                 # tmp.is_leaf
 | |
|                 self.assertEqual(a._version, tmp._version)
 | |
| 
 | |
|     def test_normal_tensor_view_output_in_normal_mode(self):
 | |
|         def functional_op(x):
 | |
|             return x * x
 | |
| 
 | |
|         for requires_grad in (True, False):
 | |
|             s = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
|             a = s.clone()
 | |
| 
 | |
|             with torch.inference_mode():
 | |
|                 out = a.view(-1)
 | |
|                 self.assertFalse(torch.is_inference(out))
 | |
|                 self.assertEqual(out.requires_grad, requires_grad)
 | |
|                 self.assertTrue(out._is_view())
 | |
|                 self.assertTrue(out.is_leaf)
 | |
| 
 | |
|             tmp = functional_op(out)
 | |
|             self.assertFalse(torch.is_inference(tmp))
 | |
|             self.assertEqual(tmp.requires_grad, requires_grad)
 | |
| 
 | |
|             if requires_grad:
 | |
|                 err_msg = "A view was created in inference mode and is being modified inplace"
 | |
|                 with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                     out.add_(2)
 | |
|                 pass
 | |
|             else:
 | |
|                 out.add_(2)
 | |
| 
 | |
|             tmp = out.view(2, 3)
 | |
|             self.assertFalse(torch.is_inference(tmp))
 | |
|             self.assertEqual(tmp.requires_grad, requires_grad)
 | |
| 
 | |
|     def test_mix_inference_and_normal_tensor_functional_op(self):
 | |
|         for requires_grad in (True, False):
 | |
|             s = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
| 
 | |
|             with torch.inference_mode():
 | |
|                 c = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
| 
 | |
|             # add is safe since it doesn't save any variable for backward
 | |
|             out = c.add(s)
 | |
|             self.assertFalse(torch.is_inference(out))
 | |
|             self.assertEqual(out.requires_grad, requires_grad)
 | |
|             if requires_grad:
 | |
|                 # leaf inference tensor with requires_grad=True can still have gradient
 | |
|                 out.backward(torch.ones_like(out))
 | |
|                 self.assertEqual(c.grad, torch.ones_like(c))
 | |
| 
 | |
|             if requires_grad:
 | |
|                 err_msg = "Inference tensors cannot be saved for backward"
 | |
|                 with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                     c * s
 | |
| 
 | |
|                 # TODO: Test this with an autograd.Function when it works
 | |
|                 #       stack stopped capturing a TensorList input
 | |
|                 # # inference tensor in TensorList input
 | |
|                 # inputs = [s, c]
 | |
|                 # with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                 #     torch.stack(inputs)
 | |
| 
 | |
| 
 | |
|     def test_mix_inference_and_normal_tensor_inplace_op(self):
 | |
|         for requires_grad in (True, False):
 | |
|             s = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
|             a = s.clone()
 | |
| 
 | |
|             with torch.inference_mode():
 | |
|                 c = torch.ones(1, 2, 3)
 | |
| 
 | |
|             self.assertTrue(torch.is_inference(c))
 | |
|             if requires_grad:
 | |
|                 err_msg = "Inference tensors cannot be saved for backward"
 | |
|                 with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                     a.mul_(c)
 | |
| 
 | |
|                 # inference tensor in TensorList input
 | |
|                 err_msg = ("out=... arguments don't support automatic differentiation, "
 | |
|                            "but one of the arguments requires grad")
 | |
|                 with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                     torch.mul(s, s, out=c)
 | |
|             else:
 | |
|                 a.mul_(c)
 | |
|                 err_msg = "Inplace update to inference tensor outside InferenceMode is not allowed"
 | |
|                 with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                     torch.mul(s, s, out=c)
 | |
| 
 | |
|     def test_mix_inference_and_normal_tensor_view_op(self):
 | |
|         for requires_grad in (True, False):
 | |
|             s = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
| 
 | |
|             with torch.inference_mode():
 | |
|                 c = torch.ones(1, 2, 3)
 | |
| 
 | |
|             # view_as is a composite op which calls view with only one
 | |
|             # tensor argument. So there isn't a mixed inference and normal
 | |
|             # tensor inputs for view ops
 | |
|             tmp1 = c.view_as(s)
 | |
|             self.assertTrue(torch.is_inference(tmp1))
 | |
|             self.assertFalse(tmp1.requires_grad)
 | |
| 
 | |
|             # this is fine since its equivalent as s.view(c.sizes()) which
 | |
|             # isn't a mixed input scenario
 | |
|             tmp2 = s.view_as(c)
 | |
|             self.assertFalse(torch.is_inference(tmp2))
 | |
|             self.assertEqual(tmp2.requires_grad, requires_grad)
 | |
| 
 | |
|     def test_inference_mode_handle_direct_view_on_rebase(self):
 | |
|         def run_test(fn):
 | |
|             for requires_grad in (True, False):
 | |
|                 s = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
|                 a = s.clone()
 | |
| 
 | |
|                 with torch.inference_mode():
 | |
|                     view_out = a.view_as(a)
 | |
| 
 | |
|                 if requires_grad:
 | |
|                     err_msg = "A view was created in inference mode and is being modified inplace"
 | |
|                     with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                         fn(view_out)
 | |
|                     pass
 | |
|                 else:
 | |
|                     fn(view_out)
 | |
|         run_test(lambda x: x.add_(2))
 | |
|         run_test(lambda x: x.transpose_(0, 1))
 | |
| 
 | |
|     def test_inference_mode_handle_indirect_view_on_rebase(self):
 | |
|         def run_test(fn):
 | |
|             for requires_grad in (True, False):
 | |
|                 s = torch.ones(1, 2, 3, requires_grad=requires_grad)
 | |
|                 a = s.clone()
 | |
| 
 | |
|                 with torch.inference_mode():
 | |
|                     view_out = a.view(-1)
 | |
| 
 | |
|                 fn(a)
 | |
|                 if requires_grad:
 | |
|                     err_msg = "A view was created in inference mode and its base or another view "
 | |
|                     with self.assertRaisesRegex(RuntimeError, err_msg):
 | |
|                         view_out.grad_fn
 | |
|                     pass
 | |
|                 else:
 | |
|                     view_out.grad_fn
 | |
|         run_test(lambda x: x.add_(2))
 | |
|         run_test(lambda x: x.transpose_(0, 1))
 | |
| 
 | |
| 
 | |
| class TestMultithreadAutograd(TestCase):
 | |
|     def _run_py_multithread_fn(self, fn, args=(), num_threads=10, kwargs=None, pass_idx=False):
 | |
| 
 | |
|         class PropagatingThread(threading.Thread):
 | |
|             '''Helper class to propagate exception from child
 | |
|             thread to main thread on join.
 | |
| 
 | |
|             Reference: https://stackoverflow.com/a/31614591/5602957
 | |
|             '''
 | |
| 
 | |
|             def run(self):
 | |
|                 self.exception = None
 | |
|                 try:
 | |
|                     self.ret = super().run()
 | |
|                 except Exception as e:
 | |
|                     self.exception = e
 | |
| 
 | |
|             def join(self, timeout=None):
 | |
|                 super().join(timeout)
 | |
|                 if self.exception:
 | |
|                     raise self.exception from self.exception
 | |
|                 return self.ret
 | |
| 
 | |
|         threads = []
 | |
|         for idx in range(num_threads):
 | |
|             p = PropagatingThread(target=fn, args=((idx, *args) if pass_idx else args))
 | |
|             p.start()
 | |
|             threads.append(p)
 | |
| 
 | |
|         for p in threads:
 | |
|             p.join()
 | |
| 
 | |
|     def test_multithreaded_exception_propagation(self):
 | |
|         # Test whether exception in child thread
 | |
|         # are propagated to main thread.
 | |
|         def fn():
 | |
|             self.assertTrue(False)
 | |
| 
 | |
|         with self.assertRaises(AssertionError):
 | |
|             self._run_py_multithread_fn(fn)
 | |
| 
 | |
|     def test_simple_backward(self):
 | |
|         # simple multithreaded backward that create threads in the beginning of training
 | |
|         # and everything else is training separately, i.e. inputs, operations, etc.
 | |
|         def train_fn():
 | |
|             x = torch.ones(5, 5, requires_grad=True)
 | |
|             y = (x + 3) * (x + 4) * 0.5
 | |
|             y.sum().backward()
 | |
|             self.assertEqual(x.grad, x + 3.5)
 | |
| 
 | |
|         self._run_py_multithread_fn(train_fn)
 | |
| 
 | |
|     def test_simple_backward_same_input(self):
 | |
|         # simple multithreaded backward with only shared inputs (i.e. This is common
 | |
|         # for things like Hogwild multithreaded training with multiple CPU threads)
 | |
|         def train_fn_backward(x):
 | |
|             y = (x + 3) * (x + 4) * 0.5
 | |
|             y.sum().backward()
 | |
| 
 | |
|         x = torch.ones(5, 5, requires_grad=True)
 | |
|         self._run_py_multithread_fn(train_fn_backward, (x,))
 | |
|         # Since we are calling backward from multiple threads
 | |
|         # and all threads share the same input, when we do backward
 | |
|         # concurrently, different backwards will all accumulate to
 | |
|         # the same .grad for each input, and the gradients should
 | |
|         # be equal to num_threads * gradient
 | |
|         self.assertEqual(x.grad, 10 * (x + 3.5))
 | |
| 
 | |
|         def train_fn_grad(x):
 | |
|             y = (x + 3) * (x + 4) * 0.5
 | |
|             grads = torch.autograd.grad(y.sum(), x)
 | |
|             self.assertEqual(len(grads), 1)
 | |
|             self.assertEqual(grads[0], x + 3.5)
 | |
| 
 | |
|         # since we use functional grad() api, gradients will not
 | |
|         # be accumulate to the same place and should be the same
 | |
|         self._run_py_multithread_fn(train_fn_grad, (x,))
 | |
| 
 | |
|     def test_multi_grad_hooks(self):
 | |
|         # Multihooks should behave independently per execution of backward
 | |
|         # Test that the hook fired the number of times we ran backward
 | |
|         # even if those executions occur concurrently on different threads
 | |
|         t1 = torch.rand(2, requires_grad=True)
 | |
|         t2 = torch.rand(2, requires_grad=True)
 | |
|         t3 = torch.rand(2, requires_grad=True)
 | |
|         t4 = torch.rand(2, requires_grad=True)
 | |
| 
 | |
|         res = None
 | |
|         count = [0]
 | |
| 
 | |
|         def hook(grads):
 | |
|             nonlocal res
 | |
|             count[0] += 1
 | |
|             grad_is_none = [g is not None for g in grads]
 | |
|             if res is None:
 | |
|                 res = grad_is_none
 | |
|             else:
 | |
|                 self.assertEqual(res, grad_is_none)
 | |
| 
 | |
|         torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)
 | |
| 
 | |
|         out = (t2 * t3).sum()
 | |
| 
 | |
|         def backward_retain_graph(out, t2, t3):
 | |
|             out.backward(inputs=(t2, t3), retain_graph=True)
 | |
| 
 | |
|         self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)
 | |
| 
 | |
|         self.assertEqual(count[0], 5)
 | |
|         self.assertEqual(res, [False, True, True, False])
 | |
| 
 | |
|         # Leave one hook partially applied
 | |
|         res = None
 | |
|         count = [0]
 | |
|         err_count = [0]
 | |
|         bw_count = [0]
 | |
| 
 | |
|         class Func(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x):
 | |
|                 return x
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 bw_count[0] += 1
 | |
|                 if bw_count[0] == 1:
 | |
|                     raise RuntimeError("error message")
 | |
|                 else:
 | |
|                     return gO
 | |
| 
 | |
|         out = (Func.apply(t2) * t3).sum()
 | |
| 
 | |
|         def backward_retain_graph(out, t2, t3):
 | |
|             try:
 | |
|                 out.backward(inputs=(t2, t3), retain_graph=True)
 | |
|             except RuntimeError:
 | |
|                 err_count[0] += 1
 | |
| 
 | |
|         self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)
 | |
| 
 | |
|         self.assertEqual(count[0], 4)
 | |
|         self.assertEqual(err_count[0], 1)
 | |
|         self.assertEqual(res, [False, True, True, False])
 | |
| 
 | |
| 
 | |
|     def test_dataparallel_saved_tensors_hooks(self):
 | |
|         def pack(x):
 | |
|             warnings.warn("pack")
 | |
|             return x
 | |
| 
 | |
|         _self = self
 | |
| 
 | |
|         class Model(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 with warnings.catch_warnings(record=True) as w:
 | |
|                     y = x * x
 | |
|                     if torch.cuda.device_count() >= 2:
 | |
|                         # DataParallel is calling the forward in different threads
 | |
|                         # without progating TLS, so hooks should not be called here
 | |
|                         _self.assertEqual(len(w), 0)
 | |
|                     else:
 | |
|                         # DataParallel only uses one thread
 | |
|                         # so hooks should be called here
 | |
|                         _self.assertGreater(len(w), 0)
 | |
| 
 | |
|         x = torch.ones(5, 5, requires_grad=True)
 | |
|         model = torch.nn.DataParallel(Model())
 | |
| 
 | |
|         with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
 | |
|             model(x)
 | |
|             with warnings.catch_warnings(record=True) as w:
 | |
|                 y = x * x
 | |
|                 # hooks should be called here
 | |
|                 _self.assertGreater(len(w), 0)
 | |
| 
 | |
|     def test_python_thread_in_middle(self):
 | |
|         # User might write a network that starts on one CPU thread, then runs its second half
 | |
|         # concurrently with other threads (either via python threading or fork/join calls),
 | |
|         # then calls backward()/grad() on BOTH threads, like a Y pattern from input at the
 | |
|         # bottom to output at the top. This way part of the GraphTask is being shared across
 | |
|         # different threads and we need to ensure user specify retain_graph=True, otherwise
 | |
|         # error out with the correct error message
 | |
| 
 | |
|         # Case 1: multiple backward with python threads, retain_graph=False
 | |
|         # should throw error in some threads with no retain_graph.
 | |
|         success_vs_raises = [0, 0]
 | |
| 
 | |
|         def train_fn_no_retain_graph(x):
 | |
|             y = x + x ** 2
 | |
|             try:
 | |
|                 y.sum().backward()
 | |
|                 success_vs_raises[0] += 1
 | |
|             except RuntimeError as error:
 | |
|                 success_vs_raises[1] += 1
 | |
|                 self.assertRegex(str(error), "Specify retain_graph=True")
 | |
| 
 | |
|         x_no_retain = torch.ones(5, 5, requires_grad=True)
 | |
|         y_no_retain = x_no_retain + x_no_retain ** 2
 | |
|         self._run_py_multithread_fn(train_fn_no_retain_graph, (y_no_retain,), num_threads=5)
 | |
|         # at least one thread will be success in this case, all other threads should raise
 | |
|         # with the error that throw to user to recommend them specify retain_graph=True
 | |
|         self.assertTrue(success_vs_raises[0] >= 1)
 | |
| 
 | |
|         # multiple backward with python threads, no error with retain_graph=True
 | |
|         def train_fn_retain_graph(x):
 | |
|             y = x + x ** 2
 | |
|             y.sum().backward(retain_graph=True)
 | |
| 
 | |
|         x_retain = torch.ones(5, 5, requires_grad=True)
 | |
|         y_retain = x_retain + x_retain ** 2
 | |
|         self._run_py_multithread_fn(train_fn_retain_graph, (y_retain,), num_threads=5)
 | |
|         # result should equal to num_thread * gradients
 | |
|         self.assertEqual(x_retain.grad, 5 * (4 * x_retain ** 3 + 6 * (x_retain ** 2) + 4 * x_retain + 1))
 | |
| 
 | |
|     def test_fork_join_in_middle(self):
 | |
|         # multiple backward with jit threads (fork/join primitive)
 | |
|         # similar to test_python_thread_in_middle, we test with retain_graph=False/True
 | |
| 
 | |
|         # Case 1: multiple grad() calls with jit threads, retain_graph=False
 | |
|         # should throw error in some threads with no retain_graph.
 | |
|         @torch.jit.script
 | |
|         def train_fn_jit_no_retain(middle, orig_x):
 | |
|             y = middle + middle ** 2
 | |
|             return torch.autograd.grad([y.sum()], [orig_x])
 | |
| 
 | |
|         @torch.jit.script
 | |
|         def train_fn_fork_join_calls_no_retain(x):
 | |
|             y_no_retain = (x + 3) * (x + 4) * 0.5
 | |
| 
 | |
|             fut = torch.jit._fork(train_fn_jit_no_retain, y_no_retain, x)
 | |
|             grad_hat = train_fn_jit_no_retain(y_no_retain, x)
 | |
|             grad = torch.jit._wait(fut)
 | |
|             return grad, grad_hat
 | |
| 
 | |
|         try:
 | |
|             train_fn_fork_join_calls_no_retain(torch.randn(5, 5, requires_grad=True))
 | |
|         except RuntimeError as error:
 | |
|             self.assertRegex(str(error), "Specify retain_graph=True")
 | |
| 
 | |
|         # Case 2: no error with retain_graph=True
 | |
|         @torch.jit.script
 | |
|         def train_fn_jit_retain(middle, orig_x):
 | |
|             y = middle + middle ** 2
 | |
|             return torch.autograd.grad([y.sum()], [orig_x], retain_graph=True)
 | |
| 
 | |
|         @torch.jit.script
 | |
|         def train_fn_fork_join_calls_retain(x):
 | |
|             y_retain = (x + 3) * (x + 4) * 0.5
 | |
|             fut1 = torch.jit._fork(train_fn_jit_retain, y_retain, x)
 | |
|             fut2 = torch.jit._fork(train_fn_jit_retain, y_retain, x)
 | |
|             grad = train_fn_jit_retain(y_retain, x)
 | |
|             grad1 = torch.jit._wait(fut1)
 | |
|             grad2 = torch.jit._wait(fut2)
 | |
|             return grad, grad1, grad2
 | |
| 
 | |
|         grad, grad1, grad2 = train_fn_fork_join_calls_retain(torch.randn(5, 5, requires_grad=True))
 | |
|         self.assertEqual(grad, grad1)
 | |
|         self.assertEqual(grad, grad2)
 | |
| 
 | |
|     def test_preserve_backtrace(self):
 | |
|         class Foo(torch.autograd.Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, input):
 | |
|                 return input
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, *grad):
 | |
|                 raise ValueError("something")
 | |
| 
 | |
|         t = torch.rand(10, requires_grad=True)
 | |
|         try:
 | |
|             Foo.apply(t).sum().backward()
 | |
|         except Exception:
 | |
|             import traceback
 | |
|             tb = sys.exc_info()[2]
 | |
|             tb_str = "\n".join(traceback.format_tb(tb))
 | |
|             self.assertTrue('raise ValueError("something")' in tb_str)
 | |
| 
 | |
|     # TODO(@anjali411): add an OpInfo based test for torch.cat
 | |
|     # Issue: https://github.com/pytorch/pytorch/issues/51627
 | |
|     #        https://github.com/pytorch/pytorch/issues/75852
 | |
|     def test_cat_stack_r_to_c(self):
 | |
|         inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)
 | |
|         inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
 | |
| 
 | |
|         def fn(x1, x2):
 | |
|             return torch.cat((x1, x2), dim=-1)
 | |
| 
 | |
|         def fn2(x1, x2):
 | |
|             return torch.stack((x1, x2), dim=-1)
 | |
| 
 | |
|         torch.autograd.gradcheck(fn, [inp_r, inp_c], check_forward_ad=True)
 | |
|         torch.autograd.gradcheck(fn, [inp_c, inp_r], check_forward_ad=True)
 | |
| 
 | |
|         torch.autograd.gradcheck(fn2, [inp_r, inp_c], check_forward_ad=True)
 | |
|         torch.autograd.gradcheck(fn2, [inp_c, inp_r], check_forward_ad=True)
 | |
| 
 | |
| class TestNestedCheckpoint(TestCase):
 | |
|     @staticmethod
 | |
|     def grad(fn):
 | |
|         def wrapper(x):
 | |
|             with torch.enable_grad():
 | |
|                 out = fn(x)
 | |
|                 grad_input, = torch.autograd.grad(out, inputs=(x,), create_graph=True)
 | |
|             return grad_input
 | |
|         return wrapper
 | |
| 
 | |
|     @staticmethod
 | |
|     def sum(fn):
 | |
|         def wrapped(x):
 | |
|             return fn(x).sum()
 | |
|         return wrapped
 | |
| 
 | |
|     @staticmethod
 | |
|     def checkpoint(fn):
 | |
|         def wrapped(*args, **kwargs):
 | |
|             return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=False, **kwargs)
 | |
|         return wrapped
 | |
| 
 | |
|     def get_tests(self, fn):
 | |
|         grad, c = self.grad, self.checkpoint
 | |
| 
 | |
|         tests = (
 | |
|             # function <> tuple of function arbitrarily wrapped in checkpoint in various ways
 | |
|             (fn, (c(fn), c(c(fn)))),
 | |
|             (grad(fn), (grad(c(fn)), grad(c(c(fn))))),
 | |
|             (grad(grad(fn)), (grad(c(grad(fn))), c(grad(grad(c(fn)))), grad(c(grad(c(fn)))))),
 | |
|             (grad(grad(grad(fn))), (grad(c(grad(grad(c(fn))))), grad(c(grad(c(grad(c(fn)))))))),
 | |
|         )
 | |
|         return tests
 | |
| 
 | |
|     def check_graph_dies(self, fn):
 | |
|         def iter_graph(roots):
 | |
|             if not roots:
 | |
|                 return
 | |
|             seen = set()
 | |
|             q = collections.deque()
 | |
|             for node in roots:
 | |
|                 if node is not None:
 | |
|                     seen.add(node)
 | |
|                     q.append(node)
 | |
| 
 | |
|             while q:
 | |
|                 node = q.popleft()
 | |
|                 for fn, _idx in node.next_functions:
 | |
|                     if fn in seen or fn is None:
 | |
|                         continue
 | |
|                     seen.add(fn)
 | |
|                     q.append(fn)
 | |
| 
 | |
|                 yield node
 | |
| 
 | |
|         class Handle():
 | |
|             __slot__ = ["node_name"]
 | |
| 
 | |
|             def __init__(self, node_name):
 | |
|                 self.node_name = node_name
 | |
| 
 | |
|         def scope():
 | |
|             a = torch.randn((), requires_grad=True)
 | |
|             out = fn(a)
 | |
|             refs = []
 | |
|             for node in iter_graph([out.grad_fn]):
 | |
|                 handle = Handle(node.name())
 | |
|                 refs.append(weakref.ref(handle))
 | |
|                 node.metadata["blah"] = handle
 | |
|             return refs
 | |
| 
 | |
|         refs = scope()
 | |
|         node_names = [ref().node_name for ref in refs if ref() is not None]
 | |
|         if len(node_names) > 0:
 | |
|             print("Nodes still alive:", node_names)
 | |
| 
 | |
|         self.assertEqual(len(node_names), 0)
 | |
| 
 | |
|     @parametrize("early_stop", [True, False])
 | |
|     def test_nested_checkpoint(self, early_stop):
 | |
|         with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
 | |
|             x = torch.randn((), requires_grad=True)
 | |
| 
 | |
|             def f(x):
 | |
|                 out = x.sin().exp().sin()
 | |
|                 return out
 | |
| 
 | |
|             def g(x):
 | |
|                 a = x.sin().exp().sin()
 | |
|                 b = x.sin().exp().sin()
 | |
|                 ga, = torch.autograd.grad(a, x)
 | |
|                 gb, = torch.autograd.grad(b, x)
 | |
|                 return x.sin()
 | |
| 
 | |
|             for fn in (f, g):
 | |
|                 for expected_fn, actual_fns in self.get_tests(fn):
 | |
|                     expected = expected_fn(x)
 | |
| 
 | |
|                     for actual_fn in actual_fns:
 | |
|                         actual = actual_fn(x)
 | |
|                         self.assertTrue(torch.allclose(expected, actual))
 | |
|                         self.check_graph_dies(actual_fn)
 | |
| 
 | |
| 
 | |
|     @parametrize("early_stop", [True, False])
 | |
|     def test_nested_checkpoint_two_children(self, early_stop):
 | |
|         with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
 | |
|             grad, sum, c = self.grad, self.sum, self.checkpoint
 | |
| 
 | |
|             def f(x):
 | |
|                 return x.sin().exp().sin()
 | |
| 
 | |
|             def g(x):
 | |
|                 return x.cos().sin().exp()
 | |
| 
 | |
|             def hc(x):
 | |
|                 return c(g)(c(f)(x))
 | |
| 
 | |
|             def h(x):
 | |
|                 return g(f(x))
 | |
| 
 | |
|             a = torch.randn(3, 3, requires_grad=True)
 | |
|             expected = grad(sum(grad(sum(h))))(a)
 | |
|             actual = grad(sum(grad(sum(c(hc)))))(a)
 | |
|             self.assertTrue(torch.allclose(expected, actual))
 | |
| 
 | |
|             actual = grad(sum(c(grad(sum(c(hc))))))(a)
 | |
|             self.assertTrue(torch.allclose(expected, actual))
 | |
| 
 | |
|             self.check_graph_dies(grad(c(hc)))
 | |
|             self.check_graph_dies(grad(sum(grad(sum(c(hc))))))
 | |
|             self.check_graph_dies(grad(sum(c(grad(sum(c(hc)))))))
 | |
| 
 | |
|     @parametrize("early_stop", [True, False])
 | |
|     def test_nested_checkpoint_non_tensor_inputs_and_outputs(self, early_stop):
 | |
|         def fn(k, a, b, f):
 | |
|             return f(k * a * b.exp()), 1, "abcd"
 | |
| 
 | |
|         k = 3
 | |
|         a = torch.tensor(2., requires_grad=True)
 | |
|         b = torch.tensor(3., requires_grad=True)
 | |
| 
 | |
|         def f(x):
 | |
|             return x.sin()
 | |
| 
 | |
|         with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
 | |
|             out, _unused1, _unused2 = checkpoint(fn, k, a, b, f, use_reentrant=False)
 | |
|         actual_grads = torch.autograd.grad(out, (a, b))
 | |
| 
 | |
|         out, _unused1, _unused2 = fn(k, a, b, f)
 | |
|         expected_grads = torch.autograd.grad(out, (a, b))
 | |
|         for actual, expected in zip(actual_grads, expected_grads):
 | |
|             self.assertTrue(torch.allclose(actual, expected))
 | |
| 
 | |
|     @parametrize("early_stop", [True, False])
 | |
|     def test_nested_checkpoint_kwargs(self, early_stop):
 | |
|         def fn(a, blah=None):
 | |
|             out = a.sin().exp()
 | |
|             if blah is not None:
 | |
|                 out = out * blah
 | |
|             return out.sin().exp()
 | |
| 
 | |
|         a = torch.tensor(2., requires_grad=True)
 | |
|         b = torch.tensor(3., requires_grad=True)
 | |
| 
 | |
|         with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
 | |
|             out = checkpoint(fn, a, blah=b, use_reentrant=False)
 | |
|             actual_grads = torch.autograd.grad(out, (a, b))
 | |
| 
 | |
|             out = fn(a, blah=b)
 | |
|             expected_grads = torch.autograd.grad(out, (a, b))
 | |
|             for actual, expected in zip(actual_grads, expected_grads):
 | |
|                 self.assertTrue(torch.allclose(actual, expected))
 | |
| 
 | |
|     @parametrize("early_stop", [True, False])
 | |
|     def test_nested_checkpoint_same_graph(self, early_stop):
 | |
|         counter = [0]
 | |
| 
 | |
|         def hook(*_unused_args):
 | |
|             counter[0] += 1
 | |
| 
 | |
|         def fn(a):
 | |
|             return a.sin().cos().sin()
 | |
| 
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
| 
 | |
|         with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
 | |
|             out = checkpoint(fn, a, use_reentrant=False)
 | |
|         # The hook is registered on the original graph
 | |
|         out.grad_fn.next_functions[0][0].register_hook(hook)
 | |
|         # And backward is performed on the original graph
 | |
|         out.backward()
 | |
| 
 | |
|         self.assertEqual(counter[0], 1)
 | |
| 
 | |
|     @parametrize("early_stop", [True, False])
 | |
|     def test_nested_checkpoint_reentrant_backwards(self, early_stop):
 | |
|         def fn(a):
 | |
|             x = a.sin().cos()
 | |
|             out = x.sin()
 | |
|             return x, out
 | |
| 
 | |
|         def hook(*_unused_args):
 | |
|             # do backward again, but skip over the part of the graph where
 | |
|             # the hook was registered
 | |
|             x.backward(retain_graph=True)
 | |
| 
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
 | |
|             x, out = checkpoint(fn, a, use_reentrant=False)
 | |
|         out.grad_fn.register_hook(hook)
 | |
|         out.backward(retain_graph=True)
 | |
| 
 | |
|     def test_nested_checkpoint_set_early_stop(self):
 | |
|         counter = [0]
 | |
| 
 | |
|         def clone(x):
 | |
|             counter[0] += 1
 | |
|             return x.clone()
 | |
| 
 | |
|         def fn(x):
 | |
|             # Since clone does not save anything, it is not recomputed iff
 | |
|             # early stop is enabled.
 | |
|             return clone(x.sin().cos())
 | |
| 
 | |
|         # Early stopping is enabled by default
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         out = checkpoint(fn, a, use_reentrant=False)
 | |
|         out.backward()
 | |
|         self.assertEqual(counter[0], 1)
 | |
| 
 | |
|         # Try using the context manager to set early stopping to False.
 | |
|         # Expect early stopping to be disabled for all checkpoints ran under
 | |
|         # the context manager, even though context manager is no longer active
 | |
|         # when backward/recomputation is performed.
 | |
|         counter = [0]
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         with torch.utils.checkpoint.set_checkpoint_early_stop(False):
 | |
|             out = checkpoint(fn, a, use_reentrant=False)
 | |
| 
 | |
|         out.backward()
 | |
|         self.assertEqual(counter[0], 2)
 | |
| 
 | |
|     def test_nested_checkpoint_set_early_stop_no_recompution_needed(self):
 | |
|         # Case 1: We have one tensor saved and its the input
 | |
| 
 | |
|         # We have two different counters here because in this case we actually
 | |
|         # do call into x.sin() at the python level during recomputation whether
 | |
|         # or not early stop is enabled. This is because the early stopping
 | |
|         # only happens at the autograd level (preventing us from reaching the
 | |
|         # backend).
 | |
|         python_dispatch_counter = [0]
 | |
|         counter = [0]
 | |
| 
 | |
|         class SinCounterMode(TorchDispatchMode):
 | |
|             def __init__(self):
 | |
|                 self.count = 0
 | |
| 
 | |
|             def __torch_dispatch__(self, func, types, args=(), kwargs=None):
 | |
|                 kwargs = {} if kwargs is None else kwargs
 | |
|                 if func is torch.ops.aten.sin.default:
 | |
|                     self.count += 1
 | |
|                 return func(*args, **kwargs)
 | |
| 
 | |
|         def fn(x):
 | |
|             counter[0] += 1
 | |
|             return x.sin()
 | |
| 
 | |
|         # With early stopping (enabled by default)
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         with SinCounterMode() as python_dispatch_counter:
 | |
|             out = checkpoint(fn, a, use_reentrant=False)
 | |
|             out.backward()
 | |
|         self.assertEqual(counter[0], 2)
 | |
|         self.assertEqual(python_dispatch_counter.count, 1)
 | |
| 
 | |
|         # Without early stopping
 | |
|         counter = [0]
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         with SinCounterMode() as python_dispatch_counter:
 | |
|             with torch.utils.checkpoint.set_checkpoint_early_stop(False):
 | |
|                 out = checkpoint(fn, a, use_reentrant=False)
 | |
|             out.backward()
 | |
|         self.assertEqual(counter[0], 2)
 | |
|         self.assertEqual(python_dispatch_counter.count, 2)
 | |
| 
 | |
|         # Case 2: Forward saves no tensors
 | |
| 
 | |
|         # Since unpack isn't even called, counter is 1 whether or not early stop
 | |
|         # is enabled!
 | |
|         counter = [0]
 | |
| 
 | |
|         def fn2(x):
 | |
|             counter[0] += 1
 | |
|             return x.clone()
 | |
| 
 | |
|         # With early stopping (enabled by default)
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         out = checkpoint(fn2, a, use_reentrant=False)
 | |
|         out.backward()
 | |
|         self.assertEqual(counter[0], 1)
 | |
| 
 | |
|         # Without early stopping
 | |
|         counter = [0]
 | |
|         a = torch.tensor(1., requires_grad=True)
 | |
|         with torch.utils.checkpoint.set_checkpoint_early_stop(False):
 | |
|             out = checkpoint(fn2, a, use_reentrant=False)
 | |
|         out.backward()
 | |
|         self.assertEqual(counter[0], 1)
 | |
| 
 | |
| 
 | |
| class TestAutogradMultipleDispatch(TestCase):
 | |
|     def test_autograd_multiple_dispatch_registrations(self, device):
 | |
|         t = torch.randn(3, 3, device=device, requires_grad=True)
 | |
|         # using _test_autograd_multiple_dispatch.fullcoverage which has
 | |
|         # registrations in derivatives.yaml for Default, AutogradCUDA and NestedTensorAutograd
 | |
|         out = torch._test_autograd_multiple_dispatch(t)
 | |
|         grad = torch.randn(3, 3, device=device)
 | |
|         out.backward(grad)
 | |
| 
 | |
|         if 'cuda' not in device:
 | |
|             # bogus default gradient registered for Autograd is grad + 1
 | |
|             self.assertEqual(t.grad, grad + 1)
 | |
|         else:
 | |
|             # bogus gradient registered for AutogradCUDA is grad * 2
 | |
|             self.assertEqual(t.grad, grad * 2)
 | |
| 
 | |
|         # test registered AutogradNestedTensor formula
 | |
|         a = torch.arange(6, dtype=torch.float, device=device).reshape(2, 3).requires_grad_(True)
 | |
|         b = torch.arange(8, dtype=torch.float, device=device).reshape(2, 4).requires_grad_(True)
 | |
|         nt = torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device)
 | |
| 
 | |
|         nt_out = torch._test_autograd_multiple_dispatch(nt)
 | |
|         c = torch.randn(2, 3, device=device)
 | |
|         d = torch.randn(2, 4, device=device)
 | |
|         nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device)
 | |
|         nt_out.backward(nt_grad)
 | |
| 
 | |
|         # bogus gradient for AutogradNestedTensor is grad * grad
 | |
|         self.assertEqual(a.grad, c * c)
 | |
|         self.assertEqual(b.grad, d * d)
 | |
| 
 | |
|     def test_autograd_composite_implicit_and_dispatch_registration(self, device):
 | |
|         t = torch.randn(3, 3, device=device, requires_grad=True)
 | |
|         # using _test_autograd_multiple_dispatch.ntonly
 | |
|         # which has registrations in derivatives.yaml for NestedTensorAutograd and otherwise is CompositeImplicit
 | |
|         out = torch._test_autograd_multiple_dispatch(t, True)
 | |
|         grad = torch.randn(3, 3, device=device)
 | |
|         out.backward(grad)
 | |
| 
 | |
|         # t.grad is just out.grad by composite op since _test_autograd_multiple_dispatch is just a clone
 | |
|         self.assertEqual(t.grad, grad)
 | |
| 
 | |
|         # test registered AutogradNestedTensor formula
 | |
|         a = torch.arange(6, dtype=torch.float, device=device).reshape(2, 3).requires_grad_(True)
 | |
|         b = torch.arange(8, dtype=torch.float, device=device).reshape(2, 4).requires_grad_(True)
 | |
|         nt = torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device)
 | |
| 
 | |
|         nt_out = torch._test_autograd_multiple_dispatch(nt, True)
 | |
|         c = torch.randn(2, 3, device=device)
 | |
|         d = torch.randn(2, 4, device=device)
 | |
|         nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device)
 | |
|         nt_out.backward(nt_grad)
 | |
| 
 | |
|         # bogus gradient for AutogradNestedTensor is grad * grad + grad
 | |
|         self.assertEqual(a.grad, c * c + c)
 | |
|         self.assertEqual(b.grad, d * d + d)
 | |
| 
 | |
|     def test_foward_mode_AD(self, device):
 | |
|         # check that forward mode AD is only registered for the Default
 | |
|         # dispatch for _test_autograd_multiple_dispatch.fullcoverage and not AutogradCUDA
 | |
| 
 | |
|         primal = torch.randn(3, device=device)
 | |
|         tangent = torch.randn(3, device=device)
 | |
| 
 | |
|         with fwAD.dual_level():
 | |
|             dual_input = fwAD.make_dual(primal, tangent)
 | |
| 
 | |
|             err_msg = r"Trying to use forward AD with .* that does not support it"
 | |
|             hint_msg = "Running forward AD for an OP that does not implement it should raise a NotImplementedError"
 | |
| 
 | |
|             if 'cuda' in device:
 | |
|                 with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
 | |
|                     torch._test_autograd_multiple_dispatch(dual_input)
 | |
|             else:
 | |
|                 torch._test_autograd_multiple_dispatch(dual_input)
 | |
| 
 | |
|     def test_view_copy(self, device):
 | |
|         # tests that view_copy derivative formulas are also generated per dispatch key
 | |
|         # from their respective view ops in derivatives.yaml
 | |
|         t = torch.randn(2, 2, device=device, requires_grad=True)
 | |
|         t_ref = t.clone().detach().requires_grad_()
 | |
|         # _test_autograd_multiple_dispatch_view does a .view(-1) on the input
 | |
|         t_view = torch._test_autograd_multiple_dispatch_view(t_ref)
 | |
|         t_view_copy = torch._test_autograd_multiple_dispatch_view_copy(t)
 | |
| 
 | |
|         grad = torch.randn(4, device=device)
 | |
|         t_view_copy.backward(grad)
 | |
|         t_view.backward(grad.clone())
 | |
| 
 | |
|         # forward and backward give the same shape + result
 | |
|         self.assertEqual(t_view_copy, t_view)
 | |
|         self.assertEqual(t.grad, t_ref.grad)
 | |
|         # backward results are per-dispatch-key in derivatives.yaml
 | |
|         if 'cuda' in device:
 | |
|             # gradient registered to AutogradCUDA is grad.reshape_as(self) + 1
 | |
|             self.assertEqual(t.grad, grad.reshape_as(t) + 1)
 | |
|         else:
 | |
|             # Default gradient registered is grad.reshape_as(self)
 | |
|             self.assertEqual(t.grad, grad.reshape_as(t))
 | |
| 
 | |
|     @onlyCPU
 | |
|     def test_per_dispatch_key_input_saving(self, device):
 | |
|         # Tests that sum.dim_IntList's input is not saved for regular tensors but is saved for nested tensors
 | |
|         def foo(x):
 | |
|             # Don't modify the input inplace
 | |
|             x = x.clone()
 | |
|             res = x.sum(-1, keepdim=True)
 | |
|             x.add_(x)
 | |
|             return res
 | |
| 
 | |
|         inp = torch.rand(2, device=device, requires_grad=True)
 | |
|         # sum's input is not saved for regular Tensors
 | |
|         foo(inp).backward()
 | |
| 
 | |
|         # sum's input is saved for Nested Tensors
 | |
|         nt = torch.nested.nested_tensor([torch.rand(2), torch.rand(2)], device=device, requires_grad=True)
 | |
|         with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
 | |
|             foo(nt).backward(torch.nested.nested_tensor([torch.rand(1), torch.rand(1)], device=device))
 | |
| 
 | |
|     @onlyCUDA
 | |
|     def test_backward_single_threaded(self):
 | |
| 
 | |
|         threads_eq = None
 | |
| 
 | |
|         class TestFn(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x, self):
 | |
|                 ctx.self = self
 | |
|                 ctx.tid = threading.get_ident()
 | |
|                 return x.clone()
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 nonlocal threads_eq
 | |
|                 threads_eq = ctx.tid == threading.get_ident()
 | |
|                 return gO, None
 | |
| 
 | |
|         inp = torch.rand(10, device="cuda", requires_grad=True)
 | |
| 
 | |
|         with torch.autograd.set_multithreading_enabled(False):
 | |
|             TestFn.apply(inp, None).sum().backward()
 | |
|         self.assertTrue(threads_eq)
 | |
| 
 | |
|         TestFn.apply(inp, None).sum().backward()
 | |
|         self.assertFalse(threads_eq)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     def test_backward_tls_stash(self):
 | |
| 
 | |
|         local = threading.local()
 | |
|         local.my_obj = {}
 | |
|         local.my_obj[10] = 10
 | |
|         test_self = self
 | |
|         torch._C._stash_obj_in_tls("my_obj", local.my_obj)
 | |
| 
 | |
|         class TestFn(Function):
 | |
|             @staticmethod
 | |
|             def forward(ctx, x, self):
 | |
|                 return x.clone()
 | |
| 
 | |
|             @staticmethod
 | |
|             def backward(ctx, gO):
 | |
|                 test_self.assertTrue(torch._C._is_key_in_tls("my_obj"))
 | |
|                 test_self.assertTrue(torch._C._get_obj_in_tls("my_obj")[10] == 10)
 | |
|                 torch._C._get_obj_in_tls("my_obj")[10] = 5
 | |
|                 return gO, None
 | |
| 
 | |
|         inp = torch.rand(10, device="cuda", requires_grad=True)
 | |
| 
 | |
|         TestFn.apply(inp, None).sum().backward()
 | |
|         self.assertEqual(local.my_obj[10], 5)
 | |
| 
 | |
| 
 | |
| # Import test cases from below autograd/ here. These are found
 | |
| # implicitly by the loader, so Flake8 thinks they are unused, hence
 | |
| # the suppressions.
 | |
| 
 | |
| from autograd.test_complex import TestAutogradComplex  # noqa: F401
 | |
| from autograd.test_functional import TestAutogradFunctional  # noqa: F401
 | |
| 
 | |
# Generate the per-device test classes,
# e.g., TestAutogradDeviceTypeCPU and TestAutogradDeviceTypeCUDA
instantiate_device_type_tests(TestAutogradDeviceType, globals(), except_for=None)

# The multiple-dispatch tests are only instantiated for CPU and CUDA.
instantiate_device_type_tests(TestAutogradMultipleDispatch, globals(), only_for=('cpu', 'cuda'))

# Expand the @parametrize-decorated tests of these classes.
instantiate_parametrized_tests(TestAutograd)
instantiate_parametrized_tests(TestNestedCheckpoint)

if __name__ == '__main__':
    run_tests()
 |