Add check-sparse-tensor-invariants flag to Context - 2nd try. (#92094)

This PR is a copy of https://github.com/pytorch/pytorch/pull/90849 that merge was reverted.

The PR adds "check sparse tensor invariants" flag to Context that when enabled will trigger sparse tensor data invariants checks in unsafe methods of constructing sparse COO/CSR/CSC/BSR/BSC tensors. The feature includes the following changes to UI:

`torch.sparse.check_sparse_tensor_invariants` class provides different ways to enable/disable the invariant checking.

`torch.sparse_coo/csr/csc/bsr/bsc/compressed_tensor` functions have a new optional argument `check_invariants` to enable/disable the invariant checks explicitly. When the `check_invariants` argument is specified, the global state of the feature is temporarily overridden.

The PR fixes https://github.com/pytorch/pytorch/issues/90833

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92094
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2023-01-12 21:00:28 +02:00
committed by PyTorch MergeBot
parent a111dd9014
commit b3e4f5029b
18 changed files with 493 additions and 96 deletions

View File

@ -831,6 +831,27 @@ PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) {
Py_RETURN_FALSE;
}
PyObject* THPModule_setCheckSparseTensorInvariants(
PyObject* _unused,
PyObject* arg) {
THPUtils_assert(
PyBool_Check(arg),
"set_check_sparse_tensor_invariants expects a bool, "
"but got %s",
THPUtils_typename(arg));
at::globalContext().setCheckSparseTensorInvariants(arg == Py_True);
Py_RETURN_NONE;
}
PyObject* THPModule_checkSparseTensorInvariants(
PyObject* _unused,
PyObject* noargs) {
if (at::globalContext().checkSparseTensorInvariants())
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
}
PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
bool isTHPFunction = THPFunction_Check(arg);
@ -1122,6 +1143,14 @@ static PyMethodDef TorchMethods[] = {
{"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
{"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
{"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
{"_set_check_sparse_tensor_invariants",
THPModule_setCheckSparseTensorInvariants,
METH_O,
nullptr},
{"_check_sparse_tensor_invariants",
THPModule_checkSparseTensorInvariants,
METH_NOARGS,
nullptr},
{"_will_engine_execute_node",
THPModule_willEngineExecuteNode,
METH_O,