Disable pinning check when loading sparse tensors (#154638)

Disables the pinning check when loading sparse tensors from external storage with the sparse tensor invariants check enabled. The pinning check is unnecessary in this context, and skipping it fixes https://github.com/pytorch/pytorch/issues/153143.

Fixes https://github.com/pytorch/pytorch/issues/153143.

For forward compatibility (FC), this is to be landed two weeks after https://github.com/pytorch/pytorch/pull/154617; see https://github.com/pytorch/pytorch/pull/154617#issuecomment-2919643612.
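
As a minimal sketch of the load path this change affects (the file name and tensor are illustrative, not taken from this PR): saving a sparse COO tensor and loading it back while the sparse tensor invariants check is enabled routes the deserialized tensor through torch.serialization._validate_loaded_sparse_tensors, which now calls the validators with check_pinning=False:

    import torch

    # Illustrative repro of the gh-153143 scenario: a sparse COO tensor
    # saved to external storage and loaded back later.
    t = torch.randn(5, 5).to_sparse()
    torch.save(t, "sparse_coo.pt")

    # With the invariants check enabled, loading validates the
    # deserialized sparse tensor; the pinning check is now skipped
    # during that validation.
    with torch.sparse.check_sparse_tensor_invariants():
        loaded = torch.load("sparse_coo.pt")

    assert torch.equal(loaded.to_dense(), t.to_dense())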

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154638
Approved by: https://github.com/amjames, https://github.com/ngimel
Pearu Peterson
2025-06-17 20:01:25 +00:00
committed by PyTorch MergeBot
parent 8f02161d10
commit c177abd217
2 changed files with 67 additions and 10 deletions


@@ -133,11 +133,28 @@ supported_multiprocessing_contexts = [None] + list(
)
# collate_fn that returns the batch cloned; defined globally here for pickle purposes.
# The following collate functions are defined globally here for pickle purposes.
# collate_fn that returns the batch cloned
def _clone_collate(b):
return [x.clone() for x in b]
# collate_fn that clones a batch of sparse COO tensors and validates their invariants
def _sparse_coo_collate(b):
lst = []
for x in b:
t = x.clone()
lst.append(t)
# Force the sparse tensor invariants check. Using check_pinning=True
# here reproduces gh-153143.
torch._validate_sparse_coo_tensor_args(
t._indices(), t._values(), t.size(), t.is_coalesced(), check_pinning=False
)
return lst
@unittest.skipIf(
TEST_WITH_TSAN,
"Fails with TSAN with the following error: starting new threads after multi-threaded "
@@ -2893,8 +2910,9 @@ class TestDataLoaderDeviceType(TestCase):
def test_nested_tensor_multiprocessing(self, device, context):
# The 'fork' multiprocessing context doesn't work for CUDA so skip it
if "cuda" in device and context == "fork":
# TODO: Skip this in a better way when the test framework allows
return
self.skipTest(
f"{context} multiprocessing context not supported for {device}"
)
dataset = [
torch.nested.nested_tensor([torch.randn(5)], device=device)
@@ -2932,6 +2950,37 @@ class TestDataLoaderDeviceType(TestCase):
next(iter(loader))
@parametrize(
"context",
[ctx for ctx in supported_multiprocessing_contexts if ctx is not None],
)
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
def test_sparse_tensor_multiprocessing(self, device, context):
# The 'fork' multiprocessing context doesn't work for CUDA so skip it
if "cuda" in device and context == "fork":
self.skipTest(
f"{context} multiprocessing context not supported for {device}"
)
dataset = [torch.randn(5, 5).to_sparse().to(device) for _ in range(10)]
pin_memory_settings = [False]
if device == "cpu" and torch.cuda.is_available():
pin_memory_settings.append(True)
for pin_memory in pin_memory_settings:
loader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
num_workers=4,
collate_fn=_sparse_coo_collate,
pin_memory=pin_memory,
multiprocessing_context=context,
)
for i, batch in enumerate(loader):
self.assertEqual(batch[0], dataset[i])
class IntegrationTestDataLoaderDataPipe(TestCase):
r"""


@@ -282,14 +282,17 @@ def _validate_loaded_sparse_tensors():
_sparse_tensors_to_validate.clear()
return
try:
# We disable the pinning check (see check_pinning=False below) to
# avoid gh-153143. The pinning check is unnecessary anyway when
# loading sparse data from external sources.
for t in _sparse_tensors_to_validate:
if True:
# Temporarily disable sparse tensor validation due to
# gh-153143.
pass
elif t.layout is torch.sparse_coo:
if t.layout is torch.sparse_coo:
torch._validate_sparse_coo_tensor_args(
t._indices(), t._values(), t.size(), t.is_coalesced()
t._indices(),
t._values(),
t.size(),
t.is_coalesced(),
check_pinning=False,
)
elif t.layout in {
torch.sparse_csr,
@@ -310,7 +313,12 @@ def _validate_loaded_sparse_tensors():
t.row_indices(),
)
torch._validate_sparse_compressed_tensor_args(
compressed_indices, plain_indices, t.values(), t.size(), t.layout
compressed_indices,
plain_indices,
t.values(),
t.size(),
t.layout,
check_pinning=False,
)
else:
raise NotImplementedError(