mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
code-generate non-aliasing {view}_copy kernels (#73442)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73442 Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D35016025 Pulled By: bdhirsh fbshipit-source-id: 2a7f303ec76f5913b744c7822a531d55a57589c9 (cherry picked from commit 3abe13c2a787bcbe9c41b0a335c96e5a3d3642fb)
This commit is contained in:
committed by
PyTorch MergeBot
parent
dfcb7035a0
commit
23b8414391
@ -908,6 +908,24 @@ class TestViewOps(TestCase):
|
||||
op = partial(fn, source=0, destination=1)
|
||||
run_test(device, op)
|
||||
|
||||
# Testing that the generated view_copy kernel and its derivative are implemented correctly
|
||||
def test_view_copy(self, device):
|
||||
a = torch.randn(4, device=device, requires_grad=True)
|
||||
a_ref = a.clone().detach().requires_grad_()
|
||||
a_view = a_ref.view(2, 2)
|
||||
a_view_copy = torch.view_copy(a, (2, 2))
|
||||
|
||||
# view_copy ops don't preserve view relationship
|
||||
self.assertTrue(self.is_view_of(a_ref, a_view))
|
||||
self.assertFalse(self.is_view_of(a, a_view_copy))
|
||||
|
||||
a_view_copy.sum().backward()
|
||||
a_view.sum().backward()
|
||||
|
||||
# forward and backward give the same shape + result
|
||||
self.assertEqual(a_view_copy, a_view)
|
||||
self.assertEqual(a.grad, a_ref.grad)
|
||||
|
||||
class TestOldViewOps(TestCase):
|
||||
def test_ravel(self, device):
|
||||
|
||||
|
Reference in New Issue
Block a user