Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 13:44:15 +08:00)
Summary: Action following https://github.com/pytorch/pytorch/issues/66232

This change requires some context. There were several suggestions for what to do about this group of tests: tests that are core and crucial to all of PyTorch, and too broad to be owned by any one team.

1. Add a "module: core" label and put people behind it. This idea sounds appealing unless you are one of the people backing the label. From talking to albanD among others, putting all of these core tests on the shoulders of a few people or one team isn't fair, and I have not yet found anyone willing to take on the job.
2. Take advantage of the fact that we already have a triaging oncall that takes turns triaging issues: leave these tests essentially unlabeled and let the oncall triage them. Since these tests are crucial to PyTorch, we'll add the "high priority" label to mark them as different from other unowned tests (see https://github.com/pytorch/pytorch/issues/67552).
3. Still create an unbacked "module: core" label and attribute these tests to it. I don't like this option because it creates a facade that the tests are "triaged" to a label when no one is actually taking a look.

This PR takes option 2. We could potentially break these tests down into smaller files so that each piece _could_ be owned by a team, but (a) I don't know whether that is currently feasible, and (b) this approach does not prevent it from happening in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67553
Reviewed By: albanD
Differential Revision: D32025004
Pulled By: janeyx99
fbshipit-source-id: 1fb1aa4c27e305695ab6e80ae3d02f90519939c0
# Owner(s): ["high priority"]

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.utils._pytree import tree_map
from torch.utils._python_dispatch import enable_python_mode

from typing import Iterator, List
import logging
import contextlib
import itertools

# TODO: move this into library proper
@contextlib.contextmanager
def no_dispatch() -> Iterator[None]:
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard
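

# Minimal usage sketch (illustrative only, not part of the upstream file; the
# helper name is hypothetical): ops executed under no_dispatch() bypass
# __torch_dispatch__ handlers entirely, which is how LoggingTensor below avoids
# infinite recursion when it re-invokes the op it is handling under
# enable_python_mode.
def _example_no_dispatch_usage(t: torch.Tensor) -> torch.Tensor:
    with no_dispatch():
        # This add goes straight to the regular kernels; any __torch_dispatch__
        # defined on type(t) is not consulted while the guard is alive.
        return torch.add(t, t)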


# How the chain of calls works for LoggingTensor:
# 1. Call torch.sin
# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely
# 3. Enter dispatcher, wind your way through Autograd
# 4. Hit Python dispatch key, call __torch_dispatch__

# TODO: TensorBase should work
class LoggingTensor(torch.Tensor):
    elem: torch.Tensor

    __slots__ = ['elem']

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapping tensor (LoggingTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(),
            # TODO: clone strides and storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=elem.requires_grad
        )
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem
        return r

    def __repr__(self):
        return f"LoggingTensor({self.elem})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, LoggingTensor) else e

        def wrap(e):
            return LoggingTensor(e) if isinstance(e, torch.Tensor) else e

        # no_dispatch is only needed if you use enable_python_mode.
        # It prevents infinite recursion.
        with no_dispatch():
            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
        return rs

# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
class LoggingTensorHandler(logging.Handler):
    log_list: List[str]
    next_shortid: int

    def __init__(self, log_list: List[str]) -> None:
        logging.Handler.__init__(self)
        self.log_list = log_list
        self.next_shortid = 0

    # WARNING: not deterministic over multiple threads, this matters for
    # autograd
    def _shortid(self, o: object) -> int:
        if not hasattr(o, '_shortid'):
            o._shortid = self.next_shortid
            self.next_shortid += 1
        return o._shortid

    def _fmt(self, a: object) -> str:
        return f'${self._shortid(a)}' if isinstance(a, LoggingTensor) else repr(a)

    def emit(self, record):
        fmt_args = ", ".join(itertools.chain(
            (self._fmt(a) for a in record.args[0]),
            (f"{k}={self._fmt(v)}" for k, v in record.args[1].items())
        ))
        fmt_rets = ", ".join(self._fmt(a) for a in record.args[2]) \
            if isinstance(record.args[2], (list, tuple)) else self._fmt(record.args[2])
        self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})')

def log_input(name: str, var: object):
    logging.getLogger("LoggingTensor").info("input", (name,), {}, (var,))

@contextlib.contextmanager
def capture_logs() -> Iterator[List[str]]:
    logger = logging.getLogger("LoggingTensor")
    log_list = []
    handler = LoggingTensorHandler(log_list)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    try:
        yield log_list
    finally:
        logger.removeHandler(handler)
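
# Illustrative sketch (not a test and not part of the upstream file; the helper
# name is hypothetical): ties the chain-of-calls comment above to the
# LoggingTensor/capture_logs machinery. Calling a torch op on a LoggingTensor
# routes through the dispatcher into __torch_dispatch__, so the underlying aten
# call is recorded, roughly as "$1 = torch._ops.aten.sin($0)".
def _example_logging_tensor_usage() -> List[str]:
    with capture_logs() as logs:
        t = LoggingTensor(torch.tensor([0.0, 1.0]))
        log_input("t", t)
        torch.sin(t)  # hits the Python dispatch key and is logged
    return logs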

class TestPythonDispatch(TestCase):
    def test_basic(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0], requires_grad=True))
            log_input("x", x)
            y = x * x
            saved_x = y.grad_fn._saved_self
            grad_y = LoggingTensor(torch.tensor([1.0]))
            log_input("grad_y", grad_y)
            g, = torch.autograd.grad((y,), (x,), (grad_y,))

        self.assertEqual(g.elem, torch.tensor([6.0]))
        with torch.no_grad():
            self.assertEqual(saved_x, x)
            self.assertEqual(saved_x._version, x._version)
            x.add_(2)
            self.assertEqual(saved_x, x)
            # TODO: figure out why broken
            # self.assertEqual(saved_x._version, x._version)
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.mul($0, $0)
$2 = input('grad_y')
$3 = torch._ops.aten.mul($2, $0)
$4 = torch._ops.aten.mul($2, $0)
$5 = torch._ops.aten.add($4, $3)''')

    def test_out(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.zeros(1))
            log_input("x", x)
            log_input("y", y)
            torch.abs(x, out=y)

        self.assertEqual(y.elem, torch.ones(1))
        # TODO: arguably this shouldn't pass and we should complain
        # that out isn't a kwarg
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.abs($0, out=$1)''')


    def test_kwarg_only(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.ones(1, 1))
            z = LoggingTensor(torch.ones(1))
            log_input("x", x)
            log_input("y", y)
            log_input("z", z)
            torch.addmv(x, y, z)
            torch.addmv(x, y, z, beta=1)
            torch.addmv(x, y, z, beta=2)
            torch.addmv(x, y, z, alpha=2)
            torch.addmv(x, y, z, beta=2, alpha=2)

        # The expectation is that beta/alpha don't show up when they're
        # defaulted, even if the user explicitly specified the default value.
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = input('z')
$3 = torch._ops.aten.addmv($0, $1, $2)
$4 = torch._ops.aten.addmv($0, $1, $2)
$5 = torch._ops.aten.addmv($0, $1, $2, beta=2)
$6 = torch._ops.aten.addmv($0, $1, $2, alpha=2)
$7 = torch._ops.aten.addmv($0, $1, $2, beta=2, alpha=2)''')

    def test_kwarg_only_and_positional_default(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.ones(1))
            log_input("x", x)
            log_input("y", y)
            torch.ops.aten.kl_div(x, y)
            torch.ops.aten.kl_div(x, y, 2)
            torch.ops.aten.kl_div(x, y, log_target=True)
            torch.ops.aten.kl_div(x, y, 2, log_target=True)

        # What we are testing here is that we omit reduction
        # if it is defaulted, even if a kwarg is set
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.kl_div($0, $1)
$3 = torch._ops.aten.kl_div($0, $1, 2)
$4 = torch._ops.aten.kl_div($0, $1, log_target=True)
$5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')

    def test_list_ret(self) -> None:
        # test all sequence types are permissible returns
        for list_type in (list, tuple):
            class A(torch._C._TensorBase):
                @staticmethod
                def __new__(cls, elem):
                    return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

                @classmethod
                def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                    if func == torch.ops.aten.split:
                        with no_dispatch():
                            return list_type(torch.split(*args))
                    else:
                        raise AssertionError(f"unrecognized func: {func}")

            self.assertEqual(
                torch.split(A(torch.tensor([0, 1])), 2),
                torch.split(torch.tensor([0, 1]), 2)
            )

    def test_invalid_ret(self) -> None:
        # test invalid return gets reasonable error message
        class A(torch._C._TensorBase):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return "arf"

        # Wobbles depending on NDEBUG mode of pybind11
        self.assertRaisesRegexp(
            RuntimeError, "Unable to cast", lambda: A(torch.zeros(1)).neg(),
        )
        self.assertExpectedRaisesInline(
            RuntimeError, lambda: A(torch.zeros(1)).detach(),
            """detach returned invalid type str, expected Tensor"""
        )

    def test_metadata_change_not_allowed(self) -> None:
        x = LoggingTensor(torch.ones(1))
        y = x.data
        self.assertIsInstance(y, LoggingTensor)
        self.assertRaises(RuntimeError, lambda: y.resize_(4))

    def test_storage(self) -> None:
        # For now, just make sure it doesn't crash. Ideally, we should
        # return some virtual storage that is safe to work with
        x = LoggingTensor(torch.ones(1))
        self.assertRaises(RuntimeError, lambda: x.storage())

    def test_make_wrapper_subclass_noalloc(self) -> None:
        # This is ludicrously big (8TB) and this should pass because wrapper
        # subclasses don't allocate
        torch.Tensor._make_wrapper_subclass(LoggingTensor, (1000000000000,))

    def test_version(self) -> None:
        x = LoggingTensor(torch.ones(1))
        prev_vc = x._version
        x.detach().add_(2)
        cur_vc = x._version
        self.assertNotEqual(prev_vc, cur_vc)
        x.data.add_(2)
        self.assertEqual(cur_vc, x._version)

    def test_subclass_priority(self) -> None:
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        # The big tests for code coverage are test_precedence_semantics in
        # test_overrides.py; this is just to make sure it is wired up at all
        # correctly for __torch_dispatch__
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorB

        self.assertRaises(ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1))))
        self.assertRaises(ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1))))
        self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1))))
        self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1))))

    def test_format(self) -> None:
        x = LoggingTensor(torch.ones(1))
        s1 = str(x)
        s2 = repr(x)
        s3 = f"{x}"
        self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
        self.assertEqual(s1, s2)
        self.assertEqual(s1, s3)

    def test_custom_autograd(self) -> None:
        escape = [None]

        class Square(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x ** 2
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                assert isinstance(grad_output, LoggingTensor)
                x, = ctx.saved_tensors
                assert isinstance(x, LoggingTensor)
                escape[0] = x
                return grad_output * 2 * x

        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1, requires_grad=True))
            log_input("x", x)
            x.grad = LoggingTensor(torch.zeros(1))
            log_input("x.grad", x.grad)
            y = Square.apply(x)
            grad_output = LoggingTensor(torch.ones(1))
            log_input("grad_output", grad_output)
            y.backward(grad_output)

        with torch.no_grad():
            self.assertEqual(escape[0], x)
            self.assertEqual(escape[0]._version, x._version)
            # TODO: figure out why x.requires_grad = False doesn't
            # trigger an error for LoggingTensor
            x.add_(2)
            self.assertEqual(escape[0], x)
            # TODO: figure out why this is broken
            # self.assertEqual(escape[0]._version, x._version)

        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('x.grad')
$2 = torch._ops.aten.pow($0, 2)
$3 = input('grad_output')
$4 = torch._ops.aten.mul($3, tensor(2))
$5 = torch._ops.aten.mul($4, $0)
$6 = torch._ops.aten.add_($1, $5)''')

    def test_subclass_creation(self):
        # Make sure these statements run without error.
        # In particular, check that when an internal detach returns a
        # subclass, it is cleanly overwritten.
        class Foo(torch.Tensor):
            pass

        err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
        with self.assertRaisesRegex(RuntimeError, err_msg):
            b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
        with self.assertRaisesRegex(RuntimeError, err_msg):
            Foo(LoggingTensor(torch.rand(2)))

        with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
            torch.Tensor._make_wrapper_subclass(Foo, (2, 2))

    def test_new_ones(self) -> None:
        class MyTensor(torch.Tensor):
            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        self.assertEqual(type(MyTensor(2).new_ones(3)), MyTensor)

    def test_like(self) -> None:
        class MyTensor(torch.Tensor):
            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        for f in ["empty", "ones", "rand", "randn", "zeros"]:
            f_name = f + "_like"
            self.assertEqual(type(getattr(torch, f_name)(MyTensor(2))), MyTensor)

        self.assertEqual(type(torch.full_like(MyTensor(2), 1.)), MyTensor)
        self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)

    def test_enable_python_mode_error(self) -> None:
        with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
            with enable_python_mode(torch.Tensor):
                pass
        z = LoggingTensor(torch.empty([]))
        with self.assertRaisesRegex(ValueError, "must be the type"):
            with enable_python_mode(z):
                pass

    def test_enable_python_mode_basic(self) -> None:
        with enable_python_mode(LoggingTensor):
            z = torch.empty([])
            self.assertTrue(isinstance(z, LoggingTensor))

    def test_enable_python_mode_unrelated_tensors(self) -> None:
        x = torch.randn([])
        y = torch.randn([])
        with enable_python_mode(LoggingTensor):
            z = x + y
            self.assertTrue(isinstance(z, LoggingTensor))

    def test_enable_python_mode_subclass_priority(self) -> None:
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorB

        a = A(torch.empty(1))
        b = B(torch.empty(1))
        with self.assertRaises(ErrorA):
            a + a

        # B has precedence over A due to the subclass relationship
        with self.assertRaises(ErrorB):
            with enable_python_mode(A):
                b + b
        with self.assertRaises(ErrorB):
            with enable_python_mode(B):
                a + a
        with self.assertRaises(ErrorB):
            with enable_python_mode(B):
                a + b

    def test_enable_python_mode_respects_no_dispatch(self) -> None:
        with enable_python_mode(LoggingTensor):
            z = torch.ones([2, 3])
            self.assertTrue(isinstance(z, LoggingTensor))
            with no_dispatch():
                expected = torch.ones([2, 3])
                self.assertEqual(z.elem, expected)

    def test_nested_enable_python_mode(self) -> None:
        with self.assertRaisesRegex(RuntimeError, "has already been set"):
            with enable_python_mode(LoggingTensor):
                with enable_python_mode(LoggingTensor):
                    pass

    def test_tolist_numpy_with_python_mode(self) -> None:
        x = LoggingTensor(torch.tensor([2.0, 3.0]))
        with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
            x.tolist()
        with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
            x.numpy()
        with self.assertRaises(AssertionError):
            self.assertEqual(x, None)

    def test_enable_python_mode_subclass_autograd_device_check(self) -> None:
        class NonWrapperSublass(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ['elem']

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                # Wrong device here!
                r = torch.Tensor._make_subclass(cls, elem.to("meta"), elem.requires_grad)
                # ...the real tensor is held as an element on the tensor.
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    return e.elem if isinstance(e, NonWrapperSublass) else e

                def wrap(e):
                    return NonWrapperSublass(e) if isinstance(e, torch.Tensor) else e

                # no_dispatch is only needed if you use enable_python_mode.
                # It prevents infinite recursion.
                with no_dispatch():
                    rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
                logging.getLogger("NonWrapperSublass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
                return rs

        x = NonWrapperSublass(torch.tensor([3.0, 4.0], requires_grad=True))
        y = torch.randn(2, requires_grad=True)
        z = x * y
        self.assertIsInstance(z, NonWrapperSublass)
        z.sum().backward(torch.tensor(1))
        self.assertEqual(x.grad, y)
        self.assertEqual(y.grad, x)


if __name__ == '__main__':
    run_tests()