From c7a22bb7c7c9c15feb9e71c65a144cc8deb8cf2d Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Thu, 12 Jan 2023 09:58:16 +0000
Subject: [PATCH] Revert "Add check-sparse-tensor-invariants flag to Context. (#90849)"

This reverts commit b9a035c1c58630f3eef5242cb4849881b8376b39.

Reverted https://github.com/pytorch/pytorch/pull/90849 on behalf of
https://github.com/DanilBaibak due to Break internal build
---
 aten/src/ATen/Context.cpp                     |   8 --
 aten/src/ATen/Context.h                       |   3 -
 .../ATen/native/sparse/SparseCsrTensor.cpp    |  10 +-
 aten/src/ATen/native/sparse/SparseTensor.cpp  |  20 ++-
 docs/source/sparse.rst                        |  30 +----
 docs/source/torch.rst                         |   4 -
 test/test_sparse.py                           | 107 ---------------
 test/test_sparse_csr.py                       |  34 ++---
 tools/pyi/gen_pyi.py                          |   9 +-
 torch/_C/__init__.pyi.in                      |   2 -
 torch/__init__.py                             |   3 +-
 torch/_torch_docs.py                          |  30 ++---
 torch/_utils.py                               |  11 +-
 torch/csrc/Module.cpp                         |  29 -----
 .../python_torch_functions_manual.cpp         |  38 +++---
 torch/csrc/utils/tensor_new.cpp               | 122 ++++--------
 torch/sparse/__init__.py                      | 106 ---------------
 torch/testing/_internal/common_utils.py       |  23 ----
 18 files changed, 96 insertions(+), 493 deletions(-)

diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index b6cda72cf1e9..dd33ded7615b 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -367,14 +367,6 @@ bool Context::isXNNPACKAvailable() {
 #endif
 }
 
-void Context::setCheckSparseTensorInvariants(bool e) {
-  enable_sparse_tensor_invariant_checks = e;
-}
-
-bool Context::checkSparseTensorInvariants() const {
-  return enable_sparse_tensor_invariant_checks;
-}
-
 bool Context::releaseWeightsWhenPrepacking() const {
   return release_original_weights;
 }
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 8816fb0872a8..9ab289b779e0 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -250,8 +250,6 @@ class TORCH_API Context {
   void setQEngine(at::QEngine e);
   static const std::vector<at::QEngine>& supportedQEngines();
   static bool isXNNPACKAvailable();
-  void setCheckSparseTensorInvariants(bool e);
-  bool checkSparseTensorInvariants() const;
   // This method is used to release the original weight after pre-packing.
   // It should be called once before loading/running the model.
   // NB: By default it is set to true for mobile builds.
@@ -307,7 +305,6 @@ class TORCH_API Context {
 #endif
   bool display_vmap_fallback_warnings_ = false;
   c10::optional<at::QEngine> quantized_engine = c10::nullopt;
-  bool enable_sparse_tensor_invariant_checks = false;
   Allocator* prev_allocator_ptr_{nullptr};
 };
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
index 3d2526c41204..9e3fa5f035b6 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -356,9 +356,6 @@ Tensor _sparse_compressed_tensor_unsafe(const Tensor& compressed_indices,
   }
   Layout layout_ = layout.value();
   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{});
-  if (at::globalContext().checkSparseTensorInvariants()) {
-    _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
-  }
   TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
   SparseCsrTensor self = new_compressed_tensor(options);
   get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
@@ -376,9 +373,6 @@ Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indice
     c10::optional<bool> pin_memory) {
   Layout layout_ = layout.value_or(required_layout);
   TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ", required_layout, " but got ", layout_);
-  if (at::globalContext().checkSparseTensorInvariants()) {
-    _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
-  }
   TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
   SparseCsrTensor self = new_compressed_tensor(options);
   get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
@@ -480,6 +474,8 @@ Tensor sparse_compressed_tensor(
   // See [Note: hacky wrapper removal for TensorOptions]
   TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
 
+  _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
+
   return at::native::_sparse_compressed_tensor_unsafe(
       compressed_indices,
       plain_indices,
@@ -511,6 +507,8 @@ Tensor sparse_compressed_tensor(
   // See [Note: hacky wrapper removal for TensorOptions]
   TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
 
+  _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
+
   return at::native::_sparse_compressed_tensor_unsafe(
       compressed_indices,
       plain_indices,
diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp
index d24068c0a05c..37f6380757d4 100644
--- a/aten/src/ATen/native/sparse/SparseTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -398,6 +398,8 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
       !options.has_layout() || options.layout() == kSparse,
       "expected sparse layout, but got layout ",
       options.layout());
+
+  at::native::_validate_sparse_coo_tensor_args(indices, values, size);
   return at::native::_sparse_coo_tensor_unsafe(
       indices,
       values,
@@ -413,10 +415,20 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, a
     c10::optional<Layout> layout,
     c10::optional<Device> device,
     c10::optional<bool> pin_memory) {
-  if (at::globalContext().checkSparseTensorInvariants()) {
-    at::native::_validate_sparse_coo_tensor_args(indices, values_, size);
-  }
-  return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+  return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+
+  Tensor values = expand_values_if_needed(values_);
+
+  auto sparse_dim = indices.size(0);
+  auto dense_dim = values.dim() - 1;
+
+  return at::_sparse_coo_tensor_with_dims_and_tensors(
+      sparse_dim,
+      dense_dim,
+      size,
+      indices,
+      values,
+      values.options().layout(kSparse));
 }
 
 // NOTE: _sparse_coo_tensor_unsafe() differs from sparse_coo_tensor()
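For orientation between the C++ and Python sides of this revert, the two COO construction paths that remain afterwards can be sketched as follows. This is illustrative only and not part of the patch; torch._sparse_coo_tensor_unsafe is a private helper whose behaviour may differ across versions.

    import torch

    indices = torch.tensor([[0, 1], [1, 0]])
    values = torch.tensor([1.0, 2.0])

    # Public constructor: with this revert, sparse_coo_tensor() calls
    # _validate_sparse_coo_tensor_args unconditionally before building the tensor.
    safe = torch.sparse_coo_tensor(indices, values, (2, 2))

    # Private fast path: skips validation entirely, so malformed inputs are the
    # caller's responsibility (used internally, e.g. when rebuilding pickled tensors).
    unsafe = torch._sparse_coo_tensor_unsafe(indices, values, (2, 2))

    print(safe.to_dense())
    print(unsafe.to_dense())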
diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst
index 377368e09c78..77e8dabec274 100644
--- a/docs/source/sparse.rst
+++ b/docs/source/sparse.rst
@@ -321,15 +321,6 @@ invariants:
    Dense dimensions always follow sparse dimensions, that is, mixing
    of dense and sparse dimensions is not supported.
 
-.. note::
-
-   To be sure that a constructed sparse tensor has consistent indices,
-   values, and size, the invariant checks can be enabled per tensor
-   creation via ``check_invariants=True`` keyword argument, or
-   globally using :class:`torch.sparse.check_sparse_tensor_invariants`
-   context manager instance. By default, the sparse tensor invariants
-   checks are disabled.
-
 .. _sparse-uncoalesced-coo-docs:
 
 Uncoalesced sparse COO tensors
@@ -539,13 +530,6 @@ __ https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_o
    where ``plain_dim_size`` is the number of plain dimensions
    (orthogonal to compressed dimensions, e.g. columns or rows).
 
-   To be sure that a constructed sparse tensor has consistent indices,
-   values, and size, the invariant checks can be enabled per tensor
-   creation via ``check_invariants=True`` keyword argument, or
-   globally using :class:`torch.sparse.check_sparse_tensor_invariants`
-   context manager instance. By default, the sparse tensor invariants
-   checks are disabled.
-
 .. note::
 
    The generalization of sparse compressed layouts to N-dimensional
@@ -662,9 +646,9 @@ argument is optional and will be deduced from the ``crow_indices`` and
    >>> csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.float64)
    >>> csr
    tensor(crow_indices=tensor([0, 2, 4]),
-          col_indices=tensor([0, 1, 0, 1]),
-          values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
-          dtype=torch.float64)
+       col_indices=tensor([0, 1, 0, 1]),
+       values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
+       dtype=torch.float64)
    >>> csr.to_dense()
    tensor([[1., 2.],
            [3., 4.]], dtype=torch.float64)
@@ -1176,14 +1160,6 @@ The following :mod:`torch` functions support sparse tensors:
 :func:`~torch.zeros`
 :func:`~torch.zeros_like`
 
-To manage checking sparse tensor invariants, see:
-
-.. autosummary::
-   :toctree: generated
-   :nosignatures:
-
-   sparse.check_sparse_tensor_invariants
-
 Unary functions
 ---------------
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 137605811795..111ee21f6d83 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -48,10 +48,6 @@ Creation Ops
     tensor
     sparse_coo_tensor
-    sparse_csr_tensor
-    sparse_csc_tensor
-    sparse_bsr_tensor
-    sparse_bsc_tensor
     asarray
     as_tensor
     as_strided
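The CSR behaviour that the documentation above keeps describing can be exercised directly. A minimal sketch, assuming a build with this revert applied so that the public constructors validate their inputs unconditionally (values match the docs example; the error text is taken from the removed test below):

    import torch

    crow_indices = torch.tensor([0, 2, 4])
    col_indices = torch.tensor([0, 1, 0, 1])
    values = torch.tensor([1., 2., 3., 4.], dtype=torch.float64)

    # Valid inputs: construction succeeds and round-trips to dense.
    csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, (2, 2))
    print(csr.to_dense())

    # Inconsistent inputs are rejected immediately, e.g. a column index that is
    # out of range for the declared size (2, 2).
    try:
        torch.sparse_csr_tensor(torch.tensor([0, 1, 2]), torch.tensor([0, 3]),
                                torch.tensor([1., 2.]), (2, 2))
    except RuntimeError as err:
        print("rejected:", err)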
diff --git a/test/test_sparse.py b/test/test_sparse.py
index d18e3bce8da8..5304ab7eaafc 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -4072,113 +4072,6 @@ class TestSparseMeta(TestCase):
 
 class TestSparseAny(TestCase):
 
-    @onlyCPU
-    @all_sparse_layouts('layout', include_strided=False)
-    @torch.sparse.check_sparse_tensor_invariants(enable=False)
-    def test_check_sparse_tensor_invariants(self, layout):
-
-        if layout is torch.sparse_coo:
-
-            def create_invalid_tensor(check_invariants=None):
-                shape = (2, 2)
-                invalid_indices = torch.tensor([[0], [3]])  # column index is out of range
-                values = torch.tensor([1])
-                if check_invariants is None:
-                    return torch.sparse_coo_tensor(invalid_indices, values, shape)
-                else:
-                    return torch.sparse_coo_tensor(invalid_indices, values, shape, check_invariants=check_invariants)
-
-            expected_exception_message = 'size is inconsistent with indices: for dim 1, size is 2 but found index 3'
-
-        elif layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
-
-            def create_invalid_tensor(check_invariants=None):
-                shape = (2, 2)
-                compressed_indices = torch.tensor([0, 0, 1])
-                invalid_plain_indices = torch.tensor([3])  # index is out of range
-                if layout in {torch.sparse_bsr, torch.sparse_bsc}:
-                    values = torch.tensor([[[1]]])
-                else:
-                    values = torch.tensor([1])
-                if check_invariants is None:
-                    return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout)
-                else:
-                    return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout,
-                                                          check_invariants=check_invariants)
-
-            if layout in {torch.sparse_csr, torch.sparse_bsr}:
-                expected_exception_message = r'`0 <= col_indices < ncols` is not satisfied.'
-            else:
-                expected_exception_message = r'`0 <= row_indices < nrows` is not satisfied.'
-
-        else:
-            raise NotImplementedError(layout)
-
-        # First, consider the case where invariant checks are disabled
-        # "globally" (read: within the context of this test method
-        # caller) as defined by check_sparse_tensor_invariants(False)
-        # decorator:
-        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
-
-        # Enable the invariant checks in a local context:
-        with torch.sparse.check_sparse_tensor_invariants():
-            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
-
-        # Leaving the local context must restore the "global" state of
-        # the invariant check feature:
-        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
-
-        # Since invariant checks are disabled by default, we can
-        # create an invalid sparse tensor without raising an
-        # exception:
-        r = create_invalid_tensor()
-        self.assertEqual(r.layout, layout)
-
-        # Or, when disabling the invariants check explicitly:
-        r = create_invalid_tensor(check_invariants=False)
-        self.assertEqual(r.layout, layout)
-
-        # Enabling invariant check via constructor's optional argument
-        # will raise an exception when sparse tensor invariants are
-        # violated:
-        with self.assertRaisesRegex(RuntimeError, expected_exception_message):
-            create_invalid_tensor(check_invariants=True)
-
-        # Check that the global invariant check flag has been restored
-        # after raising the exception above:
-        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
-
-        # Next, consider the case where invariant checks are enabled
-        # within a local context:
-        with torch.sparse.check_sparse_tensor_invariants():
-            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
-
-            # Since invariant checks are now enabled by default, an
-            # attempt to create an invalid sparse tensor will lead to
-            # an exception:
-            with self.assertRaisesRegex(RuntimeError, expected_exception_message):
-                create_invalid_tensor()
-
-            # Similarly, when enabling the invariant checks
-            # explicitly, invalid sparse tensor construction will lead
-            # to an exception:
-            with self.assertRaisesRegex(RuntimeError, expected_exception_message):
-                create_invalid_tensor(check_invariants=True)
-
-            # However, invariants check can be disabled via
-            # constructor's optional argument so that the invalid
-            # tensor is succesfully constructed:
-            r = create_invalid_tensor(check_invariants=False)
-            self.assertEqual(r.layout, layout)
-
-            # Check that the invariant check flag has been restored
-            # when leaving the constructor:
-            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
-
-        # Double-check restoring the global state when leaving the
-        # local context:
-        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
-
     def test_generate_simple_inputs(self):
         layouts = [torch.strided, torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc]
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 88047b54d492..77c6b0843ba2 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -1363,17 +1363,9 @@ class TestSparseCSR(TestCase):
     @onlyCUDA
     @unittest.skipIf(not (CUDA11OrLater or TEST_WITH_ROCM), "Only CUDA 11+ is supported")
-    # hmm, the test passes ok on CUDA when Rocm is not available:
     @skipCUDAIfRocmVersionLessThan((5, 2))
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
     def test_baddbmm(self, device, dtype):
-
-        # TODO: disable the invariant checks within torch.baddbmm that
-        # constructs unconventional csr tensors leading to
-        # RuntimeError: tensor dimensionality must be sum of batch,
-        # base, and dense dimensionalities (=0 + 2 + 0) but got 3
-        # when invariant checking is enabled. When done, undecorate run_test.
-        @torch.sparse.check_sparse_tensor_invariants(enable=False)
         def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
             alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random()
             beta = complex(random.random(), random.random()) if dtype.is_complex else random.random()
@@ -1396,8 +1388,8 @@ class TestSparseCSR(TestCase):
             a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
 
             # a_batched is a regular CSR tensor but with a batch dimension in the shape
-            a_batched = torch.sparse_csr_tensor(
-                a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)
+            a_batched = torch._sparse_csr_tensor_unsafe(
+                a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k))
 
             b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
             c = make_tensor((batch_size, m, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
@@ -1428,13 +1420,9 @@ class TestSparseCSR(TestCase):
                 nnz = random.randint(0, m * k)
                 a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
 
-                # a_batched is a regular CSR tensor but with a batch
-                # dimension in the shape. It is unorthodox in PyTorch
-                # to represent a batch sparse tensor in this way,
-                # hence checking the tensor invariants is locally
-                # turned off.
-                a_batched = torch.sparse_csr_tensor(
-                    a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)
+                # a_batched is a regular CSR tensor but with a batch dimension in the shape
+                a_batched = torch._sparse_csr_tensor_unsafe(
+                    a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k))
 
                 b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
                 for op_b, op_out in itertools.product([True, False], repeat=2):
@@ -1561,8 +1549,8 @@ class TestSparseCSR(TestCase):
             a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
             a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
             a_data = a_data.mT if noncontiguous else a_data
-            a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
-                                        a_data, (m * block_size, k * block_size), check_invariants=False)
+            a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(),
+                                                a_data, (m * block_size, k * block_size))
             b = make_tensor((k * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
             c = make_tensor((m * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
             for op_b, op_out in itertools.product([True, False], repeat=2):
@@ -1597,8 +1585,8 @@ class TestSparseCSR(TestCase):
             a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
             a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
             a_data = a_data.mT if noncontiguous else a_data  # Test column-major blocks
-            a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
-                                        a_data, (m * block_size, k * block_size), check_invariants=False)
+            a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(),
+                                                a_data, (m * block_size, k * block_size))
             b = make_tensor((k * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
             c = make_tensor((m * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
             self.run_test_block_addmm_addmv(torch.addmv, c, a, b, dtype=dtype, device=device)
@@ -1670,8 +1658,8 @@ class TestSparseCSR(TestCase):
             a = self.genSparseCSRTensor((m, m), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
             a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
             a_data = a_data.mT if noncontiguous else a_data  # Test column-major blocks
-            a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
-                                        a_data, (m * block_size, m * block_size), check_invariants=False)
+            a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(),
+                                                a_data, (m * block_size, m * block_size))
             b = make_tensor((m * block_size, k), dtype=dtype, device=device, noncontiguous=noncontiguous)
             for (upper, unitriangular, transpose, op_out) in itertools.product([True, False], repeat=4):
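The pattern the tests switch to above can be exercised on its own. A minimal sketch of building a batch-shaped CSR tensor through the private unsafe constructor; this is not part of the patch, and torch._sparse_csr_tensor_unsafe is internal API whose availability may vary by version:

    import torch

    # An ordinary 2x3 CSR tensor.
    a = torch.tensor([[1., 0., 2.],
                      [0., 3., 0.]]).to_sparse_csr()

    # Reinterpret it with a leading batch dimension in the shape. This layout is
    # unconventional, so a validating constructor would reject it; the unsafe
    # constructor skips the invariant checks (as in the test changes above).
    a_batched = torch._sparse_csr_tensor_unsafe(
        a.crow_indices(), a.col_indices(), a.values(), (4, 2, 3))

    print(a_batched.shape)  # torch.Size([4, 2, 3])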
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index a51144589d3f..447ac0a9b62e 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -364,8 +364,7 @@ def gen_pyi(
                 f"{n2}_indices: Union[Tensor, List],"
                 " values: Union[Tensor, List], size: Optional[_size]=None,"
                 " *, dtype: Optional[_dtype]=None,"
-                " device: Union[_device, str, None]=None, requires_grad:_bool=False,"
-                " check_invariants:_bool=None) -> Tensor: ..."
+                " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
             ],
             f"_sparse_{n}_tensor_unsafe": [
                 f"def _sparse_{n}_tensor_unsafe({n1}_indices: Union[Tensor, List],"
@@ -412,8 +411,7 @@ def gen_pyi(
             "sparse_coo_tensor": [
                 "def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],"
                 " size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,"
-                " device: Union[_device, str, None]=None, requires_grad:_bool=False,"
-                " check_invariants:_bool=None) -> Tensor: ..."
+                " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
             ],
             "_sparse_coo_tensor_unsafe": [
                 "def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],"
@@ -425,8 +423,7 @@ def gen_pyi(
                 "plain_indices: Union[Tensor, List],"
                 " values: Union[Tensor, List], size: Optional[_size]=None,"
                 " *, dtype: Optional[_dtype]=None, layout: Optional[_layout] = None,"
-                " device: Union[_device, str, None]=None, requires_grad:_bool=False,"
-                " check_invariants:_bool=None) -> Tensor: ..."
+                " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
             ],
             "_sparse_compressed_tensor_unsafe": [
                 "def _sparse_compressed_tensor_unsafe(comp_indices: Union[Tensor, List],"
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 4d3f020ddce3..9cb7b1a27af1 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -875,8 +875,6 @@ def _get_qengine() -> _int: ...  # THPModule_qEngine
 def _set_qengine(qegine: _int) -> None: ...  # THPModule_setQEngine
 def _supported_qengines() -> List[_int]: ...  # THPModule_supportedQEngines
 def _is_xnnpack_enabled() -> _bool: ...  # THPModule_isEnabledXNNPACK
-def _check_sparse_tensor_invariants() -> _bool: ...  # THPModule_checkSparseTensorInvariants
-def _set_check_sparse_tensor_invariants(arg: _bool) -> None: ...  # THPModule_setCheckSparseTensorInvariants
 def _set_default_mobile_cpu_allocator() -> None: ...  # THPModule_setDefaultMobileCPUAllocator
 def _unset_default_mobile_cpu_allocator() -> None: ...  # THPModule_unsetDefaultMobileCPUAllocator
 def _is_torch_function_enabled() -> _bool: ...  # THPModule_isEnabledTorchFunction
diff --git a/torch/__init__.py b/torch/__init__.py
index c76764092525..1783b6a0d4f7 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -50,7 +50,8 @@ __all__ = [
     'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
     'set_float32_matmul_precision', 'get_float32_matmul_precision',
     'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
-    'sym_int', 'sym_float', 'compile', 'vmap']
+    'sym_int', 'sym_float', 'compile', 'vmap'
+]
 
 ################################################################################
 # Load the extension module
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 9caa080551ef..a0f2e78e9df5 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -106,9 +106,6 @@ factory_common_args = merge_dicts(
         the pinned memory. Works only for CPU tensors. Default: ``False``.
     memory_format (:class:`torch.memory_format`, optional): the desired memory format of
         returned Tensor. Default: ``torch.contiguous_format``.
-    check_invariants (bool, optional): If sparse tensor invariants are checked.
-        Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`,
-        initially False.
 """
     ),
     {
@@ -10164,7 +10161,7 @@ Example::
 add_docstr(
     torch.sparse_compressed_tensor,
     r"""sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, """
-    r"""*, dtype=None, layout=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
+    r"""*, dtype=None, layout=None, device=None, requires_grad=False) -> Tensor
 
 Constructs a :ref:`sparse tensor in Compressed Sparse format - CSR,
 CSC, BSR, or BSC - <sparse-compressed-docs>` with specified values at
@@ -10216,7 +10213,6 @@ Keyword args:
         the CPU for CPU tensor types and the current CUDA device for
         CUDA tensor types.
     {requires_grad}
-    {check_invariants}
 
 Example::
     >>> compressed_indices = [0, 2, 4]
@@ -10236,8 +10232,8 @@ Example::
 
 add_docstr(
     torch.sparse_csr_tensor,
-    r"""sparse_csr_tensor(crow_indices, col_indices, values, size=None, """
-    r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
+    r"""
+sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
 
 Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) <sparse-csr-docs>` with specified
 values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations
@@ -10277,7 +10273,6 @@ Keyword args:
         the CPU for CPU tensor types and the current CUDA device for
         CUDA tensor types.
     {requires_grad}
-    {check_invariants}
 
 Example::
     >>> crow_indices = [0, 2, 4]
@@ -10297,8 +10292,8 @@ Example::
 
 add_docstr(
     torch.sparse_csc_tensor,
-    r"""sparse_csc_tensor(ccol_indices, row_indices, values, size=None, """
-    r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
+    r"""
+sparse_csc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
 
 Constructs a :ref:`sparse tensor in CSC (Compressed Sparse Column)
 <sparse-csc-docs>` with specified values at the given
@@ -10340,7 +10335,6 @@ Keyword args:
         the CPU for CPU tensor types and the current CUDA device for
         CUDA tensor types.
     {requires_grad}
-    {check_invariants}
 
 Example::
     >>> ccol_indices = [0, 2, 4]
@@ -10360,8 +10354,8 @@ Example::
 
 add_docstr(
     torch.sparse_bsr_tensor,
-    r"""sparse_bsr_tensor(crow_indices, col_indices, values, size=None, """
-    r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
+    r"""
+sparse_bsr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
 
 Constructs a :ref:`sparse tensor in BSR (Block Compressed Sparse Row))
 <sparse-bsr-docs>` with specified 2-dimensional blocks at the given
@@ -10405,7 +10399,6 @@ Keyword args:
         the CPU for CPU tensor types and the current CUDA device for
         CUDA tensor types.
     {requires_grad}
-    {check_invariants}
 
 Example::
     >>> crow_indices = [0, 1, 2]
@@ -10428,8 +10421,8 @@ Example::
 
 add_docstr(
     torch.sparse_bsc_tensor,
-    r"""sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, """
-    r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
+    r"""
+sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
 
 Constructs a :ref:`sparse tensor in BSC (Block Compressed Sparse Column))
 <sparse-bsc-docs>` with specified 2-dimensional blocks at the
@@ -10472,7 +10465,6 @@ Keyword args:
         the CPU for CPU tensor types and the current CUDA device for
         CUDA tensor types.
     {requires_grad}
-    {check_invariants}
 
 Example::
     >>> ccol_indices = [0, 1, 2]
@@ -10496,7 +10488,7 @@ Example::
 add_docstr(
     torch.sparse_coo_tensor,
     r"""
-sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
+sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
 
 Constructs a :ref:`sparse tensor in COO(rdinate) format
 <sparse-coo-docs>` with specified values at the given
@@ -10528,7 +10520,7 @@ Keyword args:
         (see :func:`torch.set_default_tensor_type`). :attr:`device` will be
         the CPU for CPU tensor types and the current CUDA device for
         CUDA tensor types.
     {requires_grad}
-    {check_invariants}
+
 
 Example::
diff --git a/torch/_utils.py b/torch/_utils.py
index e9a07e86a09d..d00d27571c25 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -237,7 +237,7 @@ def _rebuild_sparse_tensor(layout, data):
     """
     if layout == torch.sparse_coo:
         indices, values, size = data
-        result = torch.sparse_coo_tensor(indices, values, size, check_invariants=False)
+        result = torch._sparse_coo_tensor_unsafe(indices, values, size)
         _sparse_tensors_to_validate.append(result)
         return result
 
@@ -248,13 +248,8 @@ def _rebuild_sparse_tensor(layout, data):
         torch.sparse_bsc,
     }:
         compressed_indices, plain_indices, values, size = data
-        result = torch.sparse_compressed_tensor(
-            compressed_indices,
-            plain_indices,
-            values,
-            size,
-            layout=layout,
-            check_invariants=False,
+        result = torch._sparse_compressed_tensor_unsafe(
+            compressed_indices, plain_indices, values, size, layout=layout
         )
         _sparse_tensors_to_validate.append(result)
         return result
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 1f4d9ac30161..f5ee578fd2bd 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -831,27 +831,6 @@ PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) {
   Py_RETURN_FALSE;
 }
 
-PyObject* THPModule_setCheckSparseTensorInvariants(
-    PyObject* _unused,
-    PyObject* arg) {
-  THPUtils_assert(
-      PyBool_Check(arg),
-      "set_check_sparse_tensor_invariants expects a bool, "
-      "but got %s",
-      THPUtils_typename(arg));
-  at::globalContext().setCheckSparseTensorInvariants(arg == Py_True);
-  Py_RETURN_NONE;
-}
-
-PyObject* THPModule_checkSparseTensorInvariants(
-    PyObject* _unused,
-    PyObject* noargs) {
-  if (at::globalContext().checkSparseTensorInvariants())
-    Py_RETURN_TRUE;
-  else
-    Py_RETURN_FALSE;
-}
-
 PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
   bool isTHPFunction = THPFunction_Check(arg);
@@ -1143,14 +1122,6 @@ static PyMethodDef TorchMethods[] = {
     {"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
     {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
     {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
-    {"_set_check_sparse_tensor_invariants",
-     THPModule_setCheckSparseTensorInvariants,
-     METH_O,
-     nullptr},
-    {"_check_sparse_tensor_invariants",
-     THPModule_checkSparseTensorInvariants,
-     METH_NOARGS,
-     nullptr},
     {"_will_engine_execute_node",
      THPModule_willEngineExecuteNode,
     METH_O,
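The _rebuild_sparse_tensor change above keeps the existing deserialization strategy: construct each loaded sparse tensor without per-call checks, queue it, and validate the whole batch later. A rough sketch of that flow, not part of the patch and relying on internal helpers that may change between versions:

    import torch
    import torch._utils as _utils

    # Rebuild a COO tensor the way torch.load does after this revert: the unsafe
    # constructor skips per-call checks and the result is queued for validation.
    indices = torch.tensor([[0, 1], [1, 0]])
    values = torch.tensor([1.0, 2.0])
    t = _utils._rebuild_sparse_tensor(torch.sparse_coo, (indices, values, (2, 2)))

    # The rebuilt tensor sits in the module-level list until the serialization
    # code validates the collected tensors in one pass.
    print(len(_utils._sparse_tensors_to_validate) > 0)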
diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp
index f444ca869fbd..6aaaaf0eff6e 100644
--- a/torch/csrc/autograd/python_torch_functions_manual.cpp
+++ b/torch/csrc/autograd/python_torch_functions_manual.cpp
@@ -197,29 +197,29 @@ static PyObject* THPVariable_nonzero(
 
 THPVARIABLE_SPARSE_COMPRESSED_CTOR(
     sparse_compressed_tensor,
-    10,
-    ({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
-      "sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
+    9,
+    ({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
+      "sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
 THPVARIABLE_SPARSE_COMPRESSED_CTOR(
     sparse_csr_tensor,
-    10,
-    ({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
-      "sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
+    9,
+    ({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
+      "sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
 THPVARIABLE_SPARSE_COMPRESSED_CTOR(
     sparse_csc_tensor,
-    10,
-    ({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
-      "sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
+    9,
+    ({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
+      "sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
 THPVARIABLE_SPARSE_COMPRESSED_CTOR(
     sparse_bsr_tensor,
-    10,
-    ({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
-      "sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
+    9,
+    ({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
+      "sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
 THPVARIABLE_SPARSE_COMPRESSED_CTOR(
     sparse_bsc_tensor,
-    10,
-    ({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
-      "sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
+    9,
+    ({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
+      "sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
 
 THPVARIABLE_SPARSE_COMPRESSED_CTOR(
     _sparse_compressed_tensor_unsafe,
@@ -248,12 +248,12 @@ static PyObject* THPVariable_sparse_coo_tensor(
     PyObject* kwargs) {
   HANDLE_TH_ERRORS
   static PythonArgParser parser({
-      "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)",
-      "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)",
-      "sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)",
+      "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
+      "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
+      "sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
-  ParsedArgs<7> parsed_args;
+  ParsedArgs<6> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.has_torch_function()) {
     return handle_torch_function(
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 434ecfa9697c..37e121bd0392 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -784,19 +784,6 @@ Tensor indexing_tensor_from_data(
   }
 }
 
-class CheckSparseTensorInvariantsContext {
- public:
-  CheckSparseTensorInvariantsContext() {
-    state = at::globalContext().checkSparseTensorInvariants();
-  }
-  ~CheckSparseTensorInvariantsContext() {
-    at::globalContext().setCheckSparseTensorInvariants(state);
-  }
-
- private:
-  bool state;
-};
-
 Tensor sparse_compressed_tensor_ctor_worker(
     std::string name,
     c10::DispatchKey dispatch_key,
@@ -815,7 +802,6 @@ Tensor sparse_compressed_tensor_ctor_worker(
     ARG_DEVICE,
     ARG_PIN_MEMORY,
     ARG_REQUIRES_GRAD,
-    ARG_CHECK_INVARIANTS,
     ARGS_COUNT
   };
   enum {
@@ -825,7 +811,6 @@ Tensor sparse_compressed_tensor_ctor_worker(
     ARG_DEVICE1,
     ARG_PIN_MEMORY1,
     ARG_REQUIRES_GRAD1,
-    ARG_CHECK_INVARIANTS1,
     ARGS_COUNT1
   };
 
@@ -855,10 +840,6 @@ Tensor sparse_compressed_tensor_ctor_worker(
   at::ScalarType plain_indices_scalar_type = plain_indices_dtype_attr
       ? reinterpret_cast<THPDtype*>(plain_indices_dtype_attr.get())->scalar_type
       : kInt;
-  CheckSparseTensorInvariantsContext
-      restores_check_sparse_tensor_invariants_global_state{};
-  bool default_check_invariants =
-      at::globalContext().checkSparseTensorInvariants();
 
   if (r.idx == 0) {
     bool type_inference = r.isNone(ARG_TYPE);
@@ -867,10 +848,6 @@ Tensor sparse_compressed_tensor_ctor_worker(
     const auto inferred_scalar_type =
         r.scalartypeWithDefault(ARG_TYPE, scalar_type);
     at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE));
-    // the global state of invariants check flag will be restored via
-    // CheckSparseTensorInvariantsContext destructor
-    at::globalContext().setCheckSparseTensorInvariants(
-        r.toBoolWithDefault(ARG_CHECK_INVARIANTS, default_check_invariants));
 
     Tensor values = internal_new_from_data(
         inferred_options,
@@ -923,10 +900,6 @@ Tensor sparse_compressed_tensor_ctor_worker(
     const auto inferred_scalar_type =
         r.scalartypeWithDefault(ARG_TYPE1, scalar_type);
     at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE1));
-    // the global state of invariants check flag will be restored via
-    // CheckSparseTensorInvariantsContext destructor
-    at::globalContext().setCheckSparseTensorInvariants(
-        r.toBoolWithDefault(ARG_CHECK_INVARIANTS1, default_check_invariants));
 
     Tensor values = internal_new_from_data(
         inferred_options,
@@ -1197,54 +1170,17 @@ Tensor sparse_coo_tensor_ctor(
     PythonArgs& r) {
   TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
   TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
-  enum {
-    ARG_INDICES = 0,
-    ARG_VALUES,
-    ARG_TYPE,
-    ARG_DEVICE,
-    ARG_REQUIRES_GRAD,
-    ARG_CHECK_INVARIANTS,
-    ARGS_COUNT
-  };
-  enum {
-    ARG_INDICES1 = 0,
-    ARG_VALUES1,
-    ARG_SIZE1,
-    ARG_TYPE1,
-    ARG_DEVICE1,
-    ARG_REQUIRES_GRAD1,
-    ARG_CHECK_INVARIANTS1,
-    ARGS_COUNT1
-  };
-  enum {
-    ARG_SIZE2 = 0,
-    ARG_TYPE2,
-    ARG_DEVICE2,
-    ARG_REQUIRES_GRAD2,
-    ARG_CHECK_INVARIANTS2,
-    ARGS_COUNT2
-  };
-  CheckSparseTensorInvariantsContext
-      restores_check_sparse_tensor_invariants_global_state{};
-  bool default_check_invariants =
-      at::globalContext().checkSparseTensorInvariants();
-
   if (r.idx == 0) {
-    bool type_inference = r.isNone(ARG_TYPE);
-    const auto inferred_options =
-        typeIdWithDefault(r, ARG_DEVICE, dispatch_key);
-    const auto inferred_scalar_type =
-        r.scalartypeWithDefault(ARG_TYPE, scalar_type);
-    at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE));
-    at::globalContext().setCheckSparseTensorInvariants(
-        r.toBoolWithDefault(ARG_CHECK_INVARIANTS, default_check_invariants));
-
+    bool type_inference = r.isNone(2);
+    const auto inferred_options = typeIdWithDefault(r, 3, dispatch_key);
+    const auto inferred_scalar_type = r.scalartypeWithDefault(2, scalar_type);
+    at::OptionalDeviceGuard device_guard(r.deviceOptional(3));
     // if no dtype provided, infer type based on value type.
     Tensor values = internal_new_from_data(
         inferred_options,
         inferred_scalar_type,
-        r.deviceOptional(ARG_DEVICE),
-        r.pyobject(ARG_VALUES),
+        r.deviceOptional(3),
+        r.pyobject(1),
         /*copy_variables=*/false,
         /*copy_numpy=*/true,
         /*type_inference=*/type_inference);
@@ -1252,29 +1188,24 @@ Tensor sparse_coo_tensor_ctor(
     Tensor indices = internal_new_from_data(
         values.options(),
         kLong,
-        r.deviceOptional(ARG_DEVICE),
-        r.pyobject(ARG_INDICES),
+        r.deviceOptional(3),
+        r.pyobject(0),
         /*copy_variables=*/false,
         /*copy_numpy=*/true,
         /*type_inference=*/false);
     return at::sparse_coo_tensor(
                indices, values, values.options().layout(at::kSparse))
-        .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
+        .set_requires_grad(r.toBool(4));
   } else if (r.idx == 1) {
-    bool type_inference = r.isNone(ARG_TYPE1);
-    const auto inferred_options =
-        typeIdWithDefault(r, ARG_DEVICE1, dispatch_key);
-    const auto inferred_scalar_type =
-        r.scalartypeWithDefault(ARG_TYPE1, scalar_type);
-    at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE1));
-    at::globalContext().setCheckSparseTensorInvariants(
-        r.toBoolWithDefault(ARG_CHECK_INVARIANTS1, default_check_invariants));
-
+    bool type_inference = r.isNone(3);
+    const auto inferred_options = typeIdWithDefault(r, 4, dispatch_key);
+    const auto inferred_scalar_type = r.scalartypeWithDefault(3, scalar_type);
+    at::OptionalDeviceGuard device_guard(r.deviceOptional(4));
     Tensor values = internal_new_from_data(
         inferred_options,
         inferred_scalar_type,
-        r.deviceOptional(ARG_DEVICE1),
-        r.pyobject(ARG_VALUES1),
+        r.deviceOptional(4),
+        r.pyobject(1),
         /*copy_variables=*/false,
         /*copy_numpy=*/true,
         /*type_inference=*/type_inference);
@@ -1282,30 +1213,25 @@ Tensor sparse_coo_tensor_ctor(
     Tensor indices = internal_new_from_data(
         values.options(),
         kLong,
-        r.deviceOptional(ARG_DEVICE1),
-        r.pyobject(ARG_INDICES1),
+        r.deviceOptional(4),
+        r.pyobject(0),
         /*copy_variables=*/false,
         /*copy_numpy=*/true,
         /*type_inference=*/false);
     return at::sparse_coo_tensor(
                indices,
                values,
-               r.intlist(ARG_SIZE1),
+               r.intlist(2),
               values.options().layout(at::kSparse))
-        .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD1));
+        .set_requires_grad(r.toBool(5));
  } else if (r.idx == 2) {
-    const auto inferred_options =
-        typeIdWithDefault(r, ARG_DEVICE2, dispatch_key);
-    const auto inferred_scalar_type =
-        r.scalartypeWithDefault(ARG_TYPE2, scalar_type);
-    at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE2));
-    at::globalContext().setCheckSparseTensorInvariants(
-        r.toBoolWithDefault(ARG_CHECK_INVARIANTS2, default_check_invariants));
-
+    const auto inferred_options = typeIdWithDefault(r, 2, dispatch_key);
+    const auto inferred_scalar_type = r.scalartypeWithDefault(1, scalar_type);
+    at::OptionalDeviceGuard device_guard(r.deviceOptional(2));
     return at::sparse_coo_tensor(
-               r.intlist(ARG_SIZE2),
+               r.intlist(0),
                inferred_options.dtype(inferred_scalar_type).layout(at::kSparse))
-        .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD2));
+        .set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("sparse_coo_tensor(): invalid arguments");
 }
diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py
index e3386899d347..a130ef4784c9 100644
--- a/torch/sparse/__init__.py
+++ b/torch/sparse/__init__.py
@@ -18,7 +18,6 @@ else:
 
 __all__ = [
     'addmm',
-    'check_sparse_tensor_invariants',
     'mm',
     'sum',
     'softmax',
@@ -358,108 +357,3 @@ Specifying a positive offset::
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]])
 """)
-
-
-class check_sparse_tensor_invariants(object):
-    """A tool to control checking sparse tensor invariants.
-
-The following options exists to manage sparsr tensor invariants
-checking in sparse tensor construction:
-
-1. Using a context manager:
-
-   .. code:: python
-
-       with torch.sparse.check_sparse_tensor_invariants():
-           run_my_model()
-
-2. Using a procedural approach:
-
-   .. code:: python
-
-       prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
-       torch.sparse.check_sparse_tensor_invariants.enable()
-
-       run_my_model()
-
-       if not prev_checks_enabled:
-           torch.sparse.check_sparse_tensor_invariants.disable()
-
-3. Using function decoration:
-
-   .. code:: python
-
-       @torch.sparse.check_sparse_tensor_invariants()
-       def run_my_model():
-           ...
-
-       run_my_model()
-
-4. Using ``check_invariants`` keyword argument in sparse tensor constructor call.
-   For example:
-
-   >>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True)
-   Traceback (most recent call last):
-     File "<stdin>", line 1, in <module>
-   RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied.
-    """
-
-    @staticmethod
-    def is_enabled():
-        r"""Returns True if the sparse tensor invariants checking is enabled.
-
-.. note::
-
-    Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or
-    :func:`torch.sparse.check_sparse_tensor_invariants.disable` to
-    manage the state of the sparse tensor invariants checks.
-        """
-        return torch._C._check_sparse_tensor_invariants()
-
-    @staticmethod
-    def enable():
-        r"""Enable sparse tensor invariants checking in sparse tensor constructors.
-
-.. note::
-
-    By default, the sparse tensor invariants checks are disabled. Use
-    :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to
-    retrieve the current state of sparse tensor invariants checking.
-
-.. note::
-
-    The sparse tensor invariants check flag is effective to all sparse
-    tensor constructors, both in Python and ATen.
-
-    The flag can be locally overridden by the ``check_invariants``
-    optional argument of the sparse tensor constructor functions.
-        """
-        torch._C._set_check_sparse_tensor_invariants(True)
-
-    @staticmethod
-    def disable():
-        r"""Disable sparse tensor invariants checking in sparse tensor constructors.
-
-See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information.
-        """
-        torch._C._set_check_sparse_tensor_invariants(False)
-
-    # context manager support
-    def __init__(self, enable=True):
-        self.state = enable
-        self.saved_state = self.is_enabled()
-
-    def __enter__(self):
-        torch._C._set_check_sparse_tensor_invariants(self.state)
-
-    def __exit__(self, type, value, traceback):
-        torch._C._set_check_sparse_tensor_invariants(self.saved_state)
-
-    # decorator support
-    def __call__(self, mth):
-
-        def test_mth(*args, **kwargs):
-            with type(self)(self.state):
-                return mth(*args, **kwargs)
-
-        return test_mth
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 2d3b8836a7aa..edc7bba99ebf 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -2216,29 +2216,6 @@ class TestCase(expecttest.TestCase):
             check_if_enable(self)
         set_rng_seed(SEED)
 
-        # Save global check sparse tensor invariants state that can be
-        # restored from tearDown:
-        self._check_invariants = torch.sparse.check_sparse_tensor_invariants.is_enabled()
-
-        # Enable invariant checks for all sparse tensors constructions
-        # including the unsafe ones. If this is not desired for some
-        # test case, use check_invariants=False optional argument to
-        # sparse tensor constructors or
-        # @torch.sparse.check_sparse_tensor_invariants(False)
-        # decorator to disable the invariant checks.
-        torch.sparse.check_sparse_tensor_invariants.enable()
-
-    def tearDown(self):
-        # There exists test cases that override TestCase.setUp
-        # definition, so we cannot assume that _check_invariants
-        # attribute is defined in general.
-        if hasattr(self, '_check_invariants'):
-            # Restore the global check sparse tensor invariants state
-            if self._check_invariants:
-                torch.sparse.check_sparse_tensor_invariants.enable()
-            else:
-                torch.sparse.check_sparse_tensor_invariants.disable()
-
     @staticmethod
     def _make_crow_indices(n_rows, n_cols, nnz, *, device, dtype, random=True):
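For readers of this revert, the torch.sparse.check_sparse_tensor_invariants helper removed above was a thin wrapper around a global flag. A rough usage sketch, not part of the patch and valid only on builds where the feature exists (before this revert, or after the original PR was re-landed):

    import torch

    def build_model_inputs():
        # Any code that constructs sparse tensors.
        return torch.sparse_coo_tensor(torch.tensor([[0, 1], [1, 0]]),
                                       torch.tensor([1.0, 2.0]), (2, 2))

    # 1. Context manager: invariant checks are enabled only inside the block.
    with torch.sparse.check_sparse_tensor_invariants():
        build_model_inputs()

    # 2. Decorator: checks disabled (or enabled) for a single function.
    @torch.sparse.check_sparse_tensor_invariants(enable=False)
    def fast_path():
        return build_model_inputs()

    # 3. Per-call override via the keyword argument this revert also removes.
    torch.sparse_coo_tensor(torch.tensor([[0, 1], [1, 0]]),
                            torch.tensor([1.0, 2.0]), (2, 2),
                            check_invariants=True)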