User-controlled sparse tensor validation when loading data from external storage (#154610)

This PR lets users control the validation of sparse tensor invariants (which can be expensive, especially for sparse tensors with many indices) when loading data from external sources.

By default, the validation of sparse tensor invariants is disabled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154610
Approved by: https://github.com/amjames, https://github.com/ngimel
This commit is contained in:
Pearu Peterson
2025-06-01 08:39:04 +03:00
committed by PyTorch MergeBot
parent 9258cfc227
commit 3f3c1f419f
2 changed files with 11 additions and 3 deletions

View File

@ -439,7 +439,7 @@ class SerializationMixin:
torch.save({"spoofed": TensorSerializationSpoofer(x)}, f)
for weights_only in (False, True):
f.seek(0)
with self.assertRaisesRegex(
with torch.sparse.check_sparse_tensor_invariants(), self.assertRaisesRegex(
RuntimeError,
"size is inconsistent with indices"):
y = torch.load(f, weights_only=weights_only)
@ -471,14 +471,15 @@ class SerializationMixin:
torch.save(sd, f)
for weights_only in (True,):
f.seek(0)
with self.assertRaisesRegex(
with torch.sparse.check_sparse_tensor_invariants(), self.assertRaisesRegex(
RuntimeError,
"size is inconsistent with indices"):
"size is inconsistent with indices|found negative index"):
y = torch.load(f, weights_only=weights_only)
finally:
if prev_invariant_check_enabled:
torch.sparse.check_sparse_tensor_invariants.enable()
@torch.sparse.check_sparse_tensor_invariants(enable=True)
def _test_serialization_sparse_compressed_invalid(self,
conversion,
get_compressed_indices,

View File

@ -274,6 +274,13 @@ _sparse_tensors_to_validate: list["torch.Tensor"] = []
# to Pickler semantics, we have to use the same (non-validating) function for
# unpickling sparse tensors, regardless of the caller.
def _validate_loaded_sparse_tensors():
if not torch.sparse.check_sparse_tensor_invariants().is_enabled():
# Skip sparse tensor invariants validation for better
# performance. See check_sparse_tensor_invariants
# documentation for how to control sparse tensor invariants
# checking.
_sparse_tensors_to_validate.clear()
return
try:
for t in _sparse_tensors_to_validate:
if True: