mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Disable pinning check when loading sparse tensors (#154638)
Disables the pinning check when loading sparse tensors from external storage with the sparse tensor invariants check enabled; the check is unnecessary there and its failure is what https://github.com/pytorch/pytorch/issues/153143 reports.

Fixes https://github.com/pytorch/pytorch/issues/153143.

For forward compatibility (FC), to be landed two weeks after https://github.com/pytorch/pytorch/pull/154617; see https://github.com/pytorch/pytorch/pull/154617#issuecomment-2919643612.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154638
Approved by: https://github.com/amjames, https://github.com/ngimel
committed by PyTorch MergeBot
parent 8f02161d10
commit c177abd217
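For context on the scenario described in the commit message above, here is a minimal sketch, not code from this PR (the file name and tensor values are made up): saving a sparse COO tensor and loading it back while the sparse tensor invariants check is enabled routes the tensor through _validate_loaded_sparse_tensors, where the pinning check could previously fail spuriously (gh-153143).

    import torch

    # Save a sparse COO tensor; this stands in for sparse data from external storage.
    indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
    values = torch.tensor([3.0, 4.0, 5.0])
    torch.save(torch.sparse_coo_tensor(indices, values, (2, 3)), "sparse.pt")

    # With the invariants check enabled, torch.load validates the deserialized
    # sparse tensor via _validate_loaded_sparse_tensors; the pinning part of
    # that validation is what this change disables.
    with torch.sparse.check_sparse_tensor_invariants():
        t = torch.load("sparse.pt")
    print(t.shape, t.is_coalesced())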
@@ -133,11 +133,28 @@ supported_multiprocessing_contexts = [None] + list(
 )
 
 
-# collate_fn that returns the batch cloned; defined globally here for pickle purposes.
+# The following collate functions are defined globally here for pickle purposes.
+
+
+# collate_fn that returns the batch cloned
 def _clone_collate(b):
     return [x.clone() for x in b]
 
 
+# collate_fn that returns the batch of sparse coo tensors cloned
+def _sparse_coo_collate(b):
+    lst = []
+    for x in b:
+        t = x.clone()
+        lst.append(t)
+        # Force sparse tensor invariants checks. check_pinning=True
+        # reproduces gh-153143.
+        torch._validate_sparse_coo_tensor_args(
+            t._indices(), t._values(), t.size(), t.is_coalesced(), check_pinning=False
+        )
+    return lst
+
+
 @unittest.skipIf(
     TEST_WITH_TSAN,
     "Fails with TSAN with the following error: starting new threads after multi-threaded "
@@ -2893,8 +2910,9 @@ class TestDataLoaderDeviceType(TestCase):
     def test_nested_tensor_multiprocessing(self, device, context):
         # The 'fork' multiprocessing context doesn't work for CUDA so skip it
         if "cuda" in device and context == "fork":
-            # TODO: Skip this better in a better way when the test framework allows
-            return
+            self.skipTest(
+                f"{context} multiprocessing context not supported for {device}"
+            )
 
         dataset = [
             torch.nested.nested_tensor([torch.randn(5)], device=device)
@@ -2932,6 +2950,37 @@ class TestDataLoaderDeviceType(TestCase):
 
         next(iter(loader))
 
+    @parametrize(
+        "context",
+        [ctx for ctx in supported_multiprocessing_contexts if ctx is not None],
+    )
+    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
+    def test_sparse_tensor_multiprocessing(self, device, context):
+        # The 'fork' multiprocessing context doesn't work for CUDA so skip it
+        if "cuda" in device and context == "fork":
+            self.skipTest(
+                f"{context} multiprocessing context not supported for {device}"
+            )
+
+        dataset = [torch.randn(5, 5).to_sparse().to(device) for _ in range(10)]
+
+        pin_memory_settings = [False]
+        if device == "cpu" and torch.cuda.is_available():
+            pin_memory_settings.append(True)
+
+        for pin_memory in pin_memory_settings:
+            loader = torch.utils.data.DataLoader(
+                dataset,
+                batch_size=1,
+                num_workers=4,
+                collate_fn=_sparse_coo_collate,
+                pin_memory=pin_memory,
+                multiprocessing_context=context,
+            )
+
+            for i, batch in enumerate(loader):
+                self.assertEqual(batch[0], dataset[i])
+
 
 class IntegrationTestDataLoaderDataPipe(TestCase):
     r"""
@@ -282,14 +282,17 @@ def _validate_loaded_sparse_tensors():
        _sparse_tensors_to_validate.clear()
        return
     try:
+        # We disable the pinning check (see check_pinning=False below) to
+        # avoid gh-153143. In fact, the pinning check is unnecessary
+        # anyway when loading sparse data from external sources.
         for t in _sparse_tensors_to_validate:
-            if True:
-                # Temporarily disable sparse tensor validation due to
-                # gh-153143.
-                pass
-            elif t.layout is torch.sparse_coo:
+            if t.layout is torch.sparse_coo:
                 torch._validate_sparse_coo_tensor_args(
-                    t._indices(), t._values(), t.size(), t.is_coalesced()
+                    t._indices(),
+                    t._values(),
+                    t.size(),
+                    t.is_coalesced(),
+                    check_pinning=False,
                 )
             elif t.layout in {
                 torch.sparse_csr,
@@ -310,7 +313,12 @@ def _validate_loaded_sparse_tensors():
                     t.row_indices(),
                 )
                 torch._validate_sparse_compressed_tensor_args(
-                    compressed_indices, plain_indices, t.values(), t.size(), t.layout
+                    compressed_indices,
+                    plain_indices,
+                    t.values(),
+                    t.size(),
+                    t.layout,
+                    check_pinning=False,
                 )
             else:
                 raise NotImplementedError(
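As a usage note, here is a hedged sketch that is not part of the diff (it assumes a PyTorch build that already includes https://github.com/pytorch/pytorch/pull/154617, which added the check_pinning keyword): both validators accept check_pinning explicitly, so callers that do want the memory-pinning consistency check can still opt back in with check_pinning=True.

    import torch

    # Standalone illustration of the calls the loader now makes: validate a COO
    # tensor's indices/values/shape invariants while skipping the pinning check.
    t = torch.sparse_coo_tensor(
        torch.tensor([[0, 1], [1, 0]]), torch.tensor([1.0, 2.0]), (2, 2)
    ).coalesce()
    torch._validate_sparse_coo_tensor_args(
        t._indices(), t._values(), t.size(), t.is_coalesced(), check_pinning=False
    )

    # The compressed-layout validator takes the same keyword.
    c = t.to_sparse_csr()
    torch._validate_sparse_compressed_tensor_args(
        c.crow_indices(), c.col_indices(), c.values(), c.size(), c.layout,
        check_pinning=False,
    )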