mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Add check-sparse-tensor-invariants flag to Context - 2nd try. (#92094)
This PR is a copy of https://github.com/pytorch/pytorch/pull/90849 that merge was reverted. The PR adds "check sparse tensor invariants" flag to Context that when enabled will trigger sparse tensor data invariants checks in unsafe methods of constructing sparse COO/CSR/CSC/BSR/BSC tensors. The feature includes the following changes to UI: `torch.sparse.check_sparse_tensor_invariants` class provides different ways to enable/disable the invariant checking. `torch.sparse_coo/csr/csc/bsr/bsc/compressed_tensor` functions have a new optional argument `check_invariants` to enable/disable the invariant checks explicitly. When the `check_invariants` argument is specified, the global state of the feature is temporarily overridden. The PR fixes https://github.com/pytorch/pytorch/issues/90833 Pull Request resolved: https://github.com/pytorch/pytorch/pull/92094 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
a111dd9014
commit
b3e4f5029b
@ -831,6 +831,27 @@ PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
PyObject* THPModule_setCheckSparseTensorInvariants(
|
||||
PyObject* _unused,
|
||||
PyObject* arg) {
|
||||
THPUtils_assert(
|
||||
PyBool_Check(arg),
|
||||
"set_check_sparse_tensor_invariants expects a bool, "
|
||||
"but got %s",
|
||||
THPUtils_typename(arg));
|
||||
at::globalContext().setCheckSparseTensorInvariants(arg == Py_True);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject* THPModule_checkSparseTensorInvariants(
|
||||
PyObject* _unused,
|
||||
PyObject* noargs) {
|
||||
if (at::globalContext().checkSparseTensorInvariants())
|
||||
Py_RETURN_TRUE;
|
||||
else
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
bool isTHPFunction = THPFunction_Check(arg);
|
||||
@ -1122,6 +1143,14 @@ static PyMethodDef TorchMethods[] = {
|
||||
{"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
|
||||
{"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
|
||||
{"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
|
||||
{"_set_check_sparse_tensor_invariants",
|
||||
THPModule_setCheckSparseTensorInvariants,
|
||||
METH_O,
|
||||
nullptr},
|
||||
{"_check_sparse_tensor_invariants",
|
||||
THPModule_checkSparseTensorInvariants,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_will_engine_execute_node",
|
||||
THPModule_willEngineExecuteNode,
|
||||
METH_O,
|
||||
|
Reference in New Issue
Block a user