mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixed issue with GradTrackingTensor not properly propagating sparse layout (#165765)
Fixes #164286 Fixed issue with GradTrackingTensor not properly propagating sparse layout. @ezyang @jcaip Pull Request resolved: https://github.com/pytorch/pytorch/pull/165765 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
a25a649e70
commit
29b029648e
@ -160,6 +160,10 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
|
|||||||
DispatchKey::CUDA,
|
DispatchKey::CUDA,
|
||||||
DispatchKey::CPU,
|
DispatchKey::CPU,
|
||||||
DispatchKey::PrivateUse1,
|
DispatchKey::PrivateUse1,
|
||||||
|
DispatchKey::SparseCPU,
|
||||||
|
DispatchKey::SparseCUDA,
|
||||||
|
DispatchKey::SparseCsrCPU,
|
||||||
|
DispatchKey::SparseCsrCUDA,
|
||||||
});
|
});
|
||||||
|
|
||||||
inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
|
inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
|
||||||
|
@ -313,6 +313,24 @@ class TestGradTransform(TestCase):
|
|||||||
def test_numel(self, device):
|
def test_numel(self, device):
|
||||||
self._test_attributes(lambda x: x.numel(), device)
|
self._test_attributes(lambda x: x.numel(), device)
|
||||||
|
|
||||||
|
def test_layout_sparse(self, device):
|
||||||
|
indices = torch.tensor([[0, 1, 1], [2, 0, 2]], device=device)
|
||||||
|
values = torch.tensor([3.0, 4.0, 5.0], device=device)
|
||||||
|
sparse_x = torch.sparse_coo_tensor(indices, values, (2, 3), device=device)
|
||||||
|
|
||||||
|
# Verify the input is sparse
|
||||||
|
self.assertEqual(sparse_x.layout, torch.sparse_coo)
|
||||||
|
|
||||||
|
def foo(x):
|
||||||
|
# assert GradTrackingTensor still reports sparse layout
|
||||||
|
self.assertEqual(x.layout, torch.sparse_coo)
|
||||||
|
return x.coalesce()._values().sum()
|
||||||
|
|
||||||
|
result = grad(foo)(sparse_x)
|
||||||
|
|
||||||
|
# The gradient should also be sparse
|
||||||
|
self.assertEqual(result.layout, torch.sparse_coo)
|
||||||
|
|
||||||
def test_inplace(self, device):
|
def test_inplace(self, device):
|
||||||
x = torch.randn([], device=device)
|
x = torch.randn([], device=device)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user