mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
User-controlled sparse tensor validation when loading data from external storage (#154610)
This PR lets users to control sparse tensor invariants validation (that can be expensive, especially, for sparse tensors with many indices) when loading data from external sources. By default, the validation of sparse tensor invariants is disabled. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154610 Approved by: https://github.com/amjames, https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
9258cfc227
commit
3f3c1f419f
@ -439,7 +439,7 @@ class SerializationMixin:
|
||||
torch.save({"spoofed": TensorSerializationSpoofer(x)}, f)
|
||||
for weights_only in (False, True):
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(
|
||||
with torch.sparse.check_sparse_tensor_invariants(), self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"size is inconsistent with indices"):
|
||||
y = torch.load(f, weights_only=weights_only)
|
||||
@ -471,14 +471,15 @@ class SerializationMixin:
|
||||
torch.save(sd, f)
|
||||
for weights_only in (True,):
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(
|
||||
with torch.sparse.check_sparse_tensor_invariants(), self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"size is inconsistent with indices"):
|
||||
"size is inconsistent with indices|found negative index"):
|
||||
y = torch.load(f, weights_only=weights_only)
|
||||
finally:
|
||||
if prev_invariant_check_enabled:
|
||||
torch.sparse.check_sparse_tensor_invariants.enable()
|
||||
|
||||
@torch.sparse.check_sparse_tensor_invariants(enable=True)
|
||||
def _test_serialization_sparse_compressed_invalid(self,
|
||||
conversion,
|
||||
get_compressed_indices,
|
||||
|
@ -274,6 +274,13 @@ _sparse_tensors_to_validate: list["torch.Tensor"] = []
|
||||
# to Pickler semantics, we have to use the same (non-validating) function for
|
||||
# unpickling sparse tensors, regardless of the caller.
|
||||
def _validate_loaded_sparse_tensors():
|
||||
if not torch.sparse.check_sparse_tensor_invariants().is_enabled():
|
||||
# Skip sparse tensor invariants validation for better
|
||||
# performance. See check_sparse_tensor_invariants
|
||||
# documentation for how to control sparse tensor invariants
|
||||
# checking.
|
||||
_sparse_tensors_to_validate.clear()
|
||||
return
|
||||
try:
|
||||
for t in _sparse_tensors_to_validate:
|
||||
if True:
|
||||
|
Reference in New Issue
Block a user