Fix handling of unicode in torch._C._add_docstr (#487)

This commit is contained in:
Sam Gross
2017-01-18 17:22:30 -05:00
committed by GitHub
parent 99f4864674
commit c414bf0aaf
5 changed files with 81 additions and 68 deletions

View File

@ -421,31 +421,49 @@ PyObject *THPModule_safeCall(PyObject *_unused, PyObject *args, PyObject *kwargs
return result;
}
static std::string parseString(PyObject *obj)
{
if (PyBytes_Check(obj)) {
return std::string(PyBytes_AS_STRING(obj));
#if PY_MAJOR_VERSION == 3
} else if (PyUnicode_Check(obj)) {
return std::string(PyUnicode_AsUTF8(obj));
#else
} else if (PyUnicode_Check(obj)) {
THPObjectPtr utf8 = PyUnicode_AsUTF8String(obj);
return std::string(PyBytes_AS_STRING(utf8.get()));
#endif
}
return "<invalid string>";
}
PyObject *THPModule_addDocStr(PyObject *_unused, PyObject *args)
{
// adds a __doc__ string to a function, similar to numpy's arr_add_docstring
static std::vector<std::string> all_docs;
PyObject *obj;
PyObject *doc;
if (!PyArg_ParseTuple(args, "OO!", &obj, &THPUtils_stringType, &doc)) {
PyObject *doc_obj;
if (!PyArg_ParseTuple(args, "OO", &obj, &doc_obj)) {
return NULL;
}
all_docs.push_back(parseString(doc_obj));
const char* doc_str = all_docs.back().c_str();
if (Py_TYPE(obj) == &PyCFunction_Type) {
PyCFunctionObject* f = (PyCFunctionObject *)obj;
if (f->m_ml->ml_doc) {
return PyErr_Format(PyExc_RuntimeError,
"function '%s' already has a docstring", f->m_ml->ml_name);
}
f->m_ml->ml_doc = THPUtils_stringAsString(doc);
Py_INCREF(doc);
f->m_ml->ml_doc = doc_str;
} else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) {
PyMethodDescrObject* m = (PyMethodDescrObject *)obj;
if (m->d_method->ml_doc) {
return PyErr_Format(PyExc_RuntimeError,
"method '%s' already has a docstring", m->d_method->ml_name);
}
m->d_method->ml_doc = THPUtils_stringAsString(doc);
Py_INCREF(doc);
m->d_method->ml_doc = doc_str;
} else {
return PyErr_Format(PyExc_TypeError,
"don't know how to add docstring to type '%s'", Py_TYPE(obj)->tp_name);