mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 14:15:01 +08:00
101 lines
4.6 KiB
C++
101 lines
4.6 KiB
C++
#include "THP.h"
|
|
|
|
PyObject* sparse_tensor_classes;
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
// SPARSE MODULE INITIALIZATION
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static bool THSPModule_loadClasses(PyObject *sparse_module)
|
|
{
|
|
if (!THSPDoubleTensor_postInit(sparse_module)) return false;
|
|
if (!THSPFloatTensor_postInit(sparse_module)) return false;
|
|
if (!THSPLongTensor_postInit(sparse_module)) return false;
|
|
if (!THSPIntTensor_postInit(sparse_module)) return false;
|
|
if (!THSPShortTensor_postInit(sparse_module)) return false;
|
|
if (!THSPCharTensor_postInit(sparse_module)) return false;
|
|
if (!THSPByteTensor_postInit(sparse_module)) return false;
|
|
return true;
|
|
}
|
|
|
|
static bool THSPModule_assignStateless()
|
|
{
|
|
#define INIT_STATELESS(type) \
|
|
stateless = PyObject_Call((PyObject*)&TH_CONCAT_3(Sparse, type, TensorStatelessType), arg, NULL); \
|
|
if (!stateless) { \
|
|
THPUtils_setError("stateless method initialization error"); \
|
|
return false; \
|
|
} \
|
|
if (PyObject_SetAttrString(TH_CONCAT_3(THSP,type,TensorClass), THP_STATELESS_ATTRIBUTE_NAME, stateless) == -1) { \
|
|
THPUtils_setError("stateless method initialization error (on assignment)");\
|
|
}
|
|
PyObject *arg = PyTuple_New(0);
|
|
PyObject *stateless;
|
|
INIT_STATELESS(Double);
|
|
INIT_STATELESS(Float);
|
|
INIT_STATELESS(Long);
|
|
INIT_STATELESS(Int);
|
|
INIT_STATELESS(Short);
|
|
INIT_STATELESS(Char);
|
|
INIT_STATELESS(Byte);
|
|
Py_DECREF(arg);
|
|
return true;
|
|
#undef INIT_STATELESS
|
|
}
|
|
|
|
// Callback for python part. Used for additional initialization of python classes
|
|
PyObject *THSPModule_initExtension(PyObject *self)
|
|
{
|
|
PyObject *module = PyImport_ImportModule("torch.sparse");
|
|
if (!module) return NULL;
|
|
if (!THSPModule_loadClasses(module)) return NULL;
|
|
if (!THSPModule_assignStateless()) return NULL;
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
// Sparse Stateless Functions
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
bool THPModule_isSparseTensor(PyObject *obj)
|
|
{
|
|
int result = PySet_Contains(sparse_tensor_classes, (PyObject*)Py_TYPE(obj));
|
|
if (result == -1)
|
|
throw std::logic_error("FATAL: sparse_tensor_classes isn't a set!");
|
|
return result;
|
|
}
|
|
|
|
|
|
#define IMPLEMENT_SPARSE_STATELESS(name) \
|
|
static PyObject * TH_CONCAT_2(THSPModule_, name)(PyObject *_unused, PyObject *args, PyObject *kwargs) \
|
|
{ \
|
|
PyObject *tensor = THSPFloatTensorClass; \
|
|
PyObject *key, *value; \
|
|
Py_ssize_t pos = 0; \
|
|
for (int i = 0; i < PyTuple_Size(args); i++) { \
|
|
PyObject *item = PyTuple_GET_ITEM(args, i); \
|
|
if (THPModule_isTensor(item) || THPVariable_Check(item)) { \
|
|
tensor = item; \
|
|
goto dispatch; \
|
|
} \
|
|
} \
|
|
if (kwargs) { \
|
|
while (PyDict_Next(kwargs, &pos, &key, &value)) { \
|
|
if (THPModule_isTensor(value) || THPVariable_Check(value)) { \
|
|
tensor = value; \
|
|
goto dispatch; \
|
|
} \
|
|
} \
|
|
} \
|
|
\
|
|
dispatch: \
|
|
return THPUtils_dispatchStateless(tensor, #name, args, kwargs); \
|
|
}
|
|
|
|
IMPLEMENT_SPARSE_STATELESS(spmm);
|
|
IMPLEMENT_SPARSE_STATELESS(sspmm);
|
|
IMPLEMENT_SPARSE_STATELESS(sspaddmm);
|
|
IMPLEMENT_SPARSE_STATELESS(hspmm);
|
|
|
|
#undef IMPLEMENT_SPARSE_STATELESS
|