From 29b029648ed3871b83c28d4625bb5f969fe4cb41 Mon Sep 17 00:00:00 2001
From: Chris Leonard
Date: Sat, 18 Oct 2025 01:00:50 +0000
Subject: [PATCH] Fixed issue with GradTrackingTensor not properly propagating sparse layout (#165765)

Fixes #164286

Fixed issue with GradTrackingTensor not properly propagating sparse layout.

@ezyang @jcaip

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165765
Approved by: https://github.com/ezyang
---
 aten/src/ATen/functorch/BatchedTensorImpl.h |  4 ++++
 test/functorch/test_eager_transforms.py     | 18 ++++++++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.h b/aten/src/ATen/functorch/BatchedTensorImpl.h
index 3eccc94d3ea6..985b289b3fe0 100644
--- a/aten/src/ATen/functorch/BatchedTensorImpl.h
+++ b/aten/src/ATen/functorch/BatchedTensorImpl.h
@@ -160,6 +160,10 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
   DispatchKey::CUDA,
   DispatchKey::CPU,
   DispatchKey::PrivateUse1,
+  DispatchKey::SparseCPU,
+  DispatchKey::SparseCUDA,
+  DispatchKey::SparseCsrCPU,
+  DispatchKey::SparseCsrCUDA,
 });

 inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index ca19be644466..0a5d03f9dd1f 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -313,6 +313,24 @@ class TestGradTransform(TestCase):
     def test_numel(self, device):
         self._test_attributes(lambda x: x.numel(), device)

+    def test_layout_sparse(self, device):
+        indices = torch.tensor([[0, 1, 1], [2, 0, 2]], device=device)
+        values = torch.tensor([3.0, 4.0, 5.0], device=device)
+        sparse_x = torch.sparse_coo_tensor(indices, values, (2, 3), device=device)
+
+        # Verify the input is sparse
+        self.assertEqual(sparse_x.layout, torch.sparse_coo)
+
+        def foo(x):
+            # assert GradTrackingTensor still reports sparse layout
+            self.assertEqual(x.layout, torch.sparse_coo)
+            return x.coalesce()._values().sum()
+
+        result = grad(foo)(sparse_x)
+
+        # The gradient should also be sparse
+        self.assertEqual(result.layout, torch.sparse_coo)
+
     def test_inplace(self, device):
         x = torch.randn([], device=device)
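
Note: for reference, a minimal standalone sketch of the behavior the new test
covers. This is not part of the patch; it assumes a PyTorch build that
includes this change and uses the public torch.func.grad entry point. Per the
linked issue (#164286), before this fix the grad-tracking wrapper did not
carry the sparse dispatch keys, so the wrapped tensor misreported its layout
inside the transformed function.

    import torch
    from torch.func import grad

    # Build a small sparse COO input, mirroring the new test case.
    indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
    values = torch.tensor([3.0, 4.0, 5.0])
    x = torch.sparse_coo_tensor(indices, values, (2, 3))

    def f(t):
        # With the fix, the GradTrackingTensor wrapper still reports the
        # layout of the sparse tensor it wraps.
        assert t.layout == torch.sparse_coo
        return t.coalesce()._values().sum()

    g = grad(f)(x)
    assert g.layout == torch.sparse_coo  # the gradient stays sparse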