Fixed issue with GradTrackingTensor not properly propagating sparse layout (#165765)

Fixes #164286. Fixed an issue where GradTrackingTensor did not properly propagate the sparse layout of the tensor it wraps.

@ezyang @jcaip

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165765
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: a25a649e70
Commit: 29b029648e
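The user-visible symptom (reported in #164286 and exercised by the new test below) is that, inside a torch.func.grad transform, a sparse COO input gets wrapped in a GradTrackingTensor that did not report the sparse layout of the tensor it wraps. A minimal reproduction sketch of the fixed behavior, using only the public torch.func.grad API and the same data as the new test:

import torch
from torch.func import grad

# Small sparse COO tensor (same indices/values as the new test below).
indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
sparse_x = torch.sparse_coo_tensor(indices, values, (2, 3))

def foo(x):
    # Inside grad, x is a GradTrackingTensor wrapping the sparse input.
    # With this fix it reports torch.sparse_coo, matching the wrapped tensor.
    assert x.layout == torch.sparse_coo
    return x.coalesce()._values().sum()

result = grad(foo)(sparse_x)
assert result.layout == torch.sparse_coo  # the gradient is sparse as well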
@@ -160,6 +160,10 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
   DispatchKey::CUDA,
   DispatchKey::CPU,
   DispatchKey::PrivateUse1,
+  DispatchKey::SparseCPU,
+  DispatchKey::SparseCUDA,
+  DispatchKey::SparseCsrCPU,
+  DispatchKey::SparseCsrCUDA,
 });
 
 inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
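kKeysToPropagateToWrapper is the allow-list of backend dispatch keys that functorch copies from the wrapped tensor onto its wrapper, and getKeysToPropagateToWrapper presumably intersects the tensor's key set with that list. Adding the SparseCPU/SparseCUDA/SparseCsrCPU/SparseCsrCUDA entries lets a sparse input's backend key reach the GradTrackingTensor, which is what makes the wrapper report the right layout. A rough Python model of that intersection (a sketch of the idea only, not the actual C++ implementation):

# Sketch: model the propagate-set intersection with plain Python sets.
KEYS_TO_PROPAGATE_TO_WRAPPER = {
    "CUDA", "CPU", "PrivateUse1",
    # Newly added entries, so sparse backends survive the intersection:
    "SparseCPU", "SparseCUDA", "SparseCsrCPU", "SparseCsrCUDA",
}

def get_keys_to_propagate_to_wrapper(tensor_keys, to_propagate=frozenset(KEYS_TO_PROPAGATE_TO_WRAPPER)):
    # Keep only the wrapped tensor's keys that are on the allow-list.
    return tensor_keys & to_propagate

# A sparse CPU tensor carries SparseCPU rather than CPU, so before this
# change the intersection dropped its backend key and the wrapper looked
# like a dense (strided) tensor.
print(get_keys_to_propagate_to_wrapper({"SparseCPU", "AutogradCPU"}))  # {'SparseCPU'}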
@@ -313,6 +313,24 @@ class TestGradTransform(TestCase):
     def test_numel(self, device):
         self._test_attributes(lambda x: x.numel(), device)
 
+    def test_layout_sparse(self, device):
+        indices = torch.tensor([[0, 1, 1], [2, 0, 2]], device=device)
+        values = torch.tensor([3.0, 4.0, 5.0], device=device)
+        sparse_x = torch.sparse_coo_tensor(indices, values, (2, 3), device=device)
+
+        # Verify the input is sparse
+        self.assertEqual(sparse_x.layout, torch.sparse_coo)
+
+        def foo(x):
+            # assert GradTrackingTensor still reports sparse layout
+            self.assertEqual(x.layout, torch.sparse_coo)
+            return x.coalesce()._values().sum()
+
+        result = grad(foo)(sparse_x)
+
+        # The gradient should also be sparse
+        self.assertEqual(result.layout, torch.sparse_coo)
+
     def test_inplace(self, device):
         x = torch.randn([], device=device)
 
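Assuming the test hunk above lands in test/functorch/test_eager_transforms.py, where TestGradTransform is defined (the file path is not shown in this excerpt), the new test can be run on its own with: python -m pytest test/functorch/test_eager_transforms.py -k test_layout_sparse. Since the test takes a device argument, it is presumably instantiated per device (e.g. test_layout_sparse_cpu), which the -k filter still matches.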