mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-12 06:44:55 +08:00
Add check-sparse-tensor-invariants flag to Context. (#90849)
This PR adds "check sparse tensor invariants" flag to Context that when enabled will trigger sparse tensor data invariants checks in unsafe methods of constructing sparse COO/CSR/CSC/BSR/BSC tensors. The feature includes the following changes to UI: - `torch.enable_check_sparse_tensor_invariants` and `torch.is_check_sparse_tensor_invariants_enabled` functions to globally enable/disable the invariant checks and to retrieve the state of the feature, respectively - `torch.sparse_coo/csr/csc/bsr/bsc/compressed_tensor` functions have a new optional argument `check_invariants` to enable/disable the invariant checks explicitly. When the `check_invariants` argument is specified, the global state of the feature is temporarily overridden. The PR also fixes https://github.com/pytorch/pytorch/issues/90833 # Main issue *The following content is outdated after merging the PRs in this ghstack but kept for the record.* The importance of this feature is that when enabling the invariants checks by default, say, via <details> ``` $ git diff diff --git a/torch/__init__.py b/torch/__init__.py index c8543057c7..19a91d0482 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1239,3 +1239,8 @@ if 'TORCH_CUDA_SANITIZER' in os.environ: # Populate magic methods on SymInt and SymFloat import torch.fx.experimental.symbolic_shapes + +# temporarily enable sparse tensor arguments validation in unsafe +# constructors: + +torch._C._set_check_sparse_tensor_invariants(True) ``` </details> a massive number of test failures/errors occur in test_sparse_csr.py tests: ``` $ pytest -sv test/test_sparse_csr.py <snip> ==== 4293 failed, 1557 passed, 237 skipped, 2744 errors in 69.71s (0:01:09) ==== ``` that means that we are silently constructing sparse compressed tensors that do not satisfy the sparse tensor invariants. In particular, the following errors are raised: ``` AssertionError: "resize_as_sparse_compressed_tensor_: self and src must have the same layout" does not match "expected values to be a strided and contiguous tensor" RuntimeError: CUDA error: device-side assert triggered RuntimeError: `col_indices[..., crow_indices[..., i - 1]:crow_indices[..., i]] for all i = 1, ..., nrows are sorted and distinct along the last dimension values` is not satisfied. RuntimeError: expected col_indices to be a strided and contiguous tensor RuntimeError: expected row_indices to be a strided and contiguous tensor RuntimeError: expected values to be a strided and contiguous tensor RuntimeError: for_each: failed to synchronize: cudaErrorAssert: device-side assert triggered RuntimeError: tensor dimensionality must be sum of batch, base, and dense dimensionalities (=0 + 2 + 0) but got 3 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/90849 Approved by: https://github.com/amjames, https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
949f25be0c
commit
b9a035c1c5
@ -18,6 +18,7 @@ else:
|
||||
|
||||
__all__ = [
|
||||
'addmm',
|
||||
'check_sparse_tensor_invariants',
|
||||
'mm',
|
||||
'sum',
|
||||
'softmax',
|
||||
@ -356,3 +357,108 @@ Specifying a positive offset::
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]])
|
||||
""")
|
||||
|
||||
|
||||
class check_sparse_tensor_invariants(object):
|
||||
"""A tool to control checking sparse tensor invariants.
|
||||
|
||||
The following options exists to manage sparsr tensor invariants
|
||||
checking in sparse tensor construction:
|
||||
|
||||
1. Using a context manager:
|
||||
|
||||
.. code:: python
|
||||
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
run_my_model()
|
||||
|
||||
2. Using a procedural approach:
|
||||
|
||||
.. code:: python
|
||||
|
||||
prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
|
||||
torch.sparse.check_sparse_tensor_invariants.enable()
|
||||
|
||||
run_my_model()
|
||||
|
||||
if not prev_checks_enabled:
|
||||
torch.sparse.check_sparse_tensor_invariants.disable()
|
||||
|
||||
3. Using function decoration:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@torch.sparse.check_sparse_tensor_invariants()
|
||||
def run_my_model():
|
||||
...
|
||||
|
||||
run_my_model()
|
||||
|
||||
4. Using ``check_invariants`` keyword argument in sparse tensor constructor call.
|
||||
For example:
|
||||
|
||||
>>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True)
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 1, in <module>
|
||||
RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def is_enabled():
|
||||
r"""Returns True if the sparse tensor invariants checking is enabled.
|
||||
|
||||
.. note::
|
||||
|
||||
Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or
|
||||
:func:`torch.sparse.check_sparse_tensor_invariants.disable` to
|
||||
manage the state of the sparse tensor invariants checks.
|
||||
"""
|
||||
return torch._C._check_sparse_tensor_invariants()
|
||||
|
||||
@staticmethod
|
||||
def enable():
|
||||
r"""Enable sparse tensor invariants checking in sparse tensor constructors.
|
||||
|
||||
.. note::
|
||||
|
||||
By default, the sparse tensor invariants checks are disabled. Use
|
||||
:func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to
|
||||
retrieve the current state of sparse tensor invariants checking.
|
||||
|
||||
.. note::
|
||||
|
||||
The sparse tensor invariants check flag is effective to all sparse
|
||||
tensor constructors, both in Python and ATen.
|
||||
|
||||
The flag can be locally overridden by the ``check_invariants``
|
||||
optional argument of the sparse tensor constructor functions.
|
||||
"""
|
||||
torch._C._set_check_sparse_tensor_invariants(True)
|
||||
|
||||
@staticmethod
|
||||
def disable():
|
||||
r"""Disable sparse tensor invariants checking in sparse tensor constructors.
|
||||
|
||||
See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information.
|
||||
"""
|
||||
torch._C._set_check_sparse_tensor_invariants(False)
|
||||
|
||||
# context manager support
|
||||
def __init__(self, enable=True):
|
||||
self.state = enable
|
||||
self.saved_state = self.is_enabled()
|
||||
|
||||
def __enter__(self):
|
||||
torch._C._set_check_sparse_tensor_invariants(self.state)
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
torch._C._set_check_sparse_tensor_invariants(self.saved_state)
|
||||
|
||||
# decorator support
|
||||
def __call__(self, mth):
|
||||
|
||||
def test_mth(*args, **kwargs):
|
||||
with type(self)(self.state):
|
||||
return mth(*args, **kwargs)
|
||||
|
||||
return test_mth
|
||||
|
||||
Reference in New Issue
Block a user