Fixed issue with GradTrackingTensor not properly propagating sparse layout (#165765)

Fixes #164286

GradTrackingTensor did not propagate a wrapped tensor's sparse layout. This change adds the sparse dispatch keys (SparseCPU, SparseCUDA, SparseCsrCPU, SparseCsrCUDA) to the set of keys propagated to the functorch wrapper, and adds a regression test checking that the layout is reported correctly inside a grad transform and that the resulting gradient is itself sparse.

@ezyang @jcaip
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165765
Approved by: https://github.com/ezyang
Author: Chris Leonard
Date: 2025-10-18 01:00:50 +00:00
Committed by: PyTorch MergeBot
Parent: a25a649e70
Commit: 29b029648e
2 changed files with 22 additions and 0 deletions
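
For context, a minimal reproduction of the symptom tracked in #164286 might look like the sketch below (my reconstruction from the regression test, not code taken from the issue; it assumes the torch.func.grad API and a CPU sparse COO input):

import torch
from torch.func import grad

indices = torch.tensor([[0, 1], [1, 0]])
values = torch.tensor([3.0, 4.0])
x = torch.sparse_coo_tensor(indices, values, (2, 2))

def f(t):
    # Before this fix, the GradTrackingTensor wrapper dropped the sparse
    # dispatch keys, so t.layout could be misreported inside the transform.
    print(t.layout)  # should print torch.sparse_coo
    return t.coalesce()._values().sum()

grad(f)(x)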


@@ -160,6 +160,10 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
     DispatchKey::CUDA,
     DispatchKey::CPU,
     DispatchKey::PrivateUse1,
+    DispatchKey::SparseCPU,
+    DispatchKey::SparseCUDA,
+    DispatchKey::SparseCsrCPU,
+    DispatchKey::SparseCsrCUDA,
 });
 
 inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
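
kKeysToPropagateToWrapper is the fixed set of backend keys that getKeysToPropagateToWrapper copies from a wrapped tensor onto its functorch wrapper; adding the four sparse keys is what lets a sparse input's layout survive wrapping. A rough way to inspect the keys a tensor carries is sketched below (torch._C._dispatch_keys is an internal API, and its output format varies across PyTorch versions):

import torch

x = torch.sparse_coo_tensor(torch.tensor([[0], [1]]), torch.tensor([1.0]), (2, 2))
# Prints the tensor's DispatchKeySet; for a COO tensor on CPU it should
# include a sparse backend key (e.g. SparseCPU).
print(torch._C._dispatch_keys(x))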


@@ -313,6 +313,24 @@ class TestGradTransform(TestCase):
     def test_numel(self, device):
         self._test_attributes(lambda x: x.numel(), device)
 
+    def test_layout_sparse(self, device):
+        indices = torch.tensor([[0, 1, 1], [2, 0, 2]], device=device)
+        values = torch.tensor([3.0, 4.0, 5.0], device=device)
+        sparse_x = torch.sparse_coo_tensor(indices, values, (2, 3), device=device)
+
+        # Verify the input is sparse
+        self.assertEqual(sparse_x.layout, torch.sparse_coo)
+
+        def foo(x):
+            # assert GradTrackingTensor still reports sparse layout
+            self.assertEqual(x.layout, torch.sparse_coo)
+            return x.coalesce()._values().sum()
+
+        result = grad(foo)(sparse_x)
+
+        # The gradient should also be sparse
+        self.assertEqual(result.layout, torch.sparse_coo)
+
     def test_inplace(self, device):
         x = torch.randn([], device=device)
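
For reference, the gradient computed in test_layout_sparse above should carry a 1 at each stored element, since foo sums the stored values; the expected dense view below is my own reading of the math, not an assertion made by the test:

import torch
from torch.func import grad

indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
sparse_x = torch.sparse_coo_tensor(indices, values, (2, 3))

result = grad(lambda x: x.coalesce()._values().sum())(sparse_x)
print(result.to_dense())
# Expected:
# tensor([[0., 0., 1.],
#         [1., 0., 1.]])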