#include #include #include #include #include #include #include #include "THP.h" #include "torch/csrc/utils/python_strings.h" #include "torch/csrc/utils/invalid_arguments.h" #include "torch/csrc/DynamicTypes.h" #include "generic/utils.cpp" #include #include "generic/utils.cpp" #include #ifdef WITH_CUDA #include "torch/csrc/cuda/THCP.h" #endif int THPUtils_getCallable(PyObject *arg, PyObject **result) { if (!PyCallable_Check(arg)) return 0; *result = arg; return 1; } THLongStoragePtr THPUtils_unpackSize(PyObject *arg) { THLongStoragePtr result; if (!THPUtils_tryUnpackLongs(arg, result)) { std::string msg = "THPUtils_unpackSize() expects a torch.Size (got '"; msg += Py_TYPE(arg)->tp_name; msg += "')"; throw std::runtime_error(msg); } return result; } bool THPUtils_tryUnpackLongs(PyObject *arg, THLongStoragePtr& result) { bool tuple = PyTuple_Check(arg); bool list = PyList_Check(arg); if (tuple || list) { int nDim = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); THLongStoragePtr storage(THLongStorage_newWithSize(nDim)); for (int i = 0; i != nDim; ++i) { PyObject* item = tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i); if (!THPUtils_checkLong(item)) { return false; } storage->data[i] = THPUtils_unpackLong(item); } result = std::move(storage); return true; } return false; } bool THPUtils_tryUnpackLongVarArgs(PyObject *args, int ignore_first, THLongStoragePtr& result) { Py_ssize_t length = PyTuple_Size(args) - ignore_first; if (length < 1) { return false; } PyObject *first_arg = PyTuple_GET_ITEM(args, ignore_first); if (length == 1 && THPUtils_tryUnpackLongs(first_arg, result)) { return true; } // Try to parse the numbers result = THLongStorage_newWithSize(length); for (Py_ssize_t i = 0; i < length; ++i) { PyObject *arg = PyTuple_GET_ITEM(args, i + ignore_first); if (!THPUtils_checkLong(arg)) { return false; } result->data[i] = THPUtils_unpackLong(arg); } return true; } bool THPUtils_checkIntTuple(PyObject *arg) { if (!PyTuple_Check(arg)) { return false; } for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) { if (!THPUtils_checkLong(PyTuple_GET_ITEM(arg, i))) { return false; } } return true; } std::vector THPUtils_unpackIntTuple(PyObject *arg) { if (!THPUtils_checkIntTuple(arg)) { throw std::runtime_error("Couldn't unpack int tuple"); } std::vector values(PyTuple_GET_SIZE(arg)); for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) { values[i] = (int)THPUtils_unpackLong(PyTuple_GET_ITEM(arg, i)); } return values; } void THPUtils_setError(const char *format, ...) { static const size_t ERROR_BUFFER_SIZE = 1000; char buffer[ERROR_BUFFER_SIZE]; va_list fmt_args; va_start(fmt_args, format); vsnprintf(buffer, ERROR_BUFFER_SIZE, format, fmt_args); va_end(fmt_args); PyErr_SetString(PyExc_RuntimeError, buffer); } void THPUtils_addPyMethodDefs(std::vector& vector, PyMethodDef* methods) { if (!vector.empty()) { // remove NULL terminator vector.pop_back(); } while (1) { vector.push_back(*methods); if (!methods->ml_name) { break; } methods++; } } static const char* classOrTypename(PyObject* obj) { if (PyType_Check(obj)) { return ((PyTypeObject*)obj)->tp_name; } return Py_TYPE(obj)->tp_name; } PyObject * THPUtils_dispatchStateless( PyObject *tensor, const char *name, PyObject *args, PyObject *kwargs) { THPObjectPtr methods(PyObject_GetAttrString(tensor, THP_STATELESS_ATTRIBUTE_NAME)); if (!methods) { return PyErr_Format( PyExc_TypeError, "Type %s doesn't implement stateless methods", classOrTypename(tensor)); } THPObjectPtr method(PyObject_GetAttrString(methods, name)); if (!method) { return PyErr_Format( PyExc_TypeError, "Type %s doesn't implement stateless method %s", classOrTypename(tensor), name); } return PyObject_Call(method.get(), args, kwargs); } void THPUtils_invalidArguments(PyObject *given_args, PyObject *given_kwargs, const char *function_name, size_t num_options, ...) { std::vector option_strings; va_list option_list; va_start(option_list, num_options); for (size_t i = 0; i < num_options; i++) option_strings.push_back(va_arg(option_list, const char*)); va_end(option_list); PyErr_SetString(PyExc_TypeError, torch::format_invalid_args( given_args, given_kwargs, function_name, option_strings).c_str()); } template<> void THPPointer::free() { if (ptr) Py_DECREF(ptr); } template class THPPointer; static bool backCompatBroadcastWarn = false; void setBackCompatBroadcastWarn(bool warn) { backCompatBroadcastWarn = warn; } bool getBackCompatBroadcastWarn() { return backCompatBroadcastWarn; } static bool backCompatKeepdimWarn = false; void setBackCompatKeepdimWarn(bool warn) { backCompatKeepdimWarn = warn; } bool getBackCompatKeepdimWarn() { return backCompatKeepdimWarn; } bool maybeThrowBackCompatKeepdimWarn(char *func) { if(getBackCompatKeepdimWarn()) { std::ostringstream ss; ss << "backwards compatibility: call to \"" << func << "\" uses default value for keepdim which has changed default to False. Consider passing as kwarg.", PyErr_WarnEx(PyExc_UserWarning, ss.str().c_str(), 1); } return true; } std::vector THPUtils_PySequence_to_TensorList(PyObject *obj) { if (!PySequence_Check(obj)) { throw std::runtime_error("Expected a sequence in THPUtils_PySequence_to_TensorList"); } THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, NULL)); if (seq.get() == NULL) { throw std::runtime_error("expected PySequence, but got " + std::string(THPUtils_typename(obj))); } std::vector list; Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get()); for (Py_ssize_t i = 0; i < length; i++) { at::Tensor tensor = torch::createTensor(PySequence_Fast_GET_ITEM(seq.get(), i)); list.push_back(tensor); } return list; } #ifdef WITH_CUDA std::vector THPUtils_PySequence_to_THCStreamList(PyObject *obj) { if (!PySequence_Check(obj)) { throw std::runtime_error("Expected a sequence in THPUtils_PySequence_to_THCStreamList"); } THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, NULL)); if (seq.get() == NULL) { throw std::runtime_error("expected PySequence, but got " + std::string(THPUtils_typename(obj))); } std::vector streams; Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get()); for (Py_ssize_t i = 0; i < length; i++) { PyObject *stream = PySequence_Fast_GET_ITEM(seq.get(), i); if (PyObject_IsInstance(stream, THCPStreamClass)) { streams.push_back( ((THCPStream *)stream)->cdata); } else if (stream == Py_None) { streams.push_back(NULL); } else { std::runtime_error("Unknown data type found in stream list. Need THCStream or None"); } } return streams; } #endif