mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 13:34:57 +08:00
Fix handling of unicode in torch._C._add_docstr (#487)
This commit is contained in:
@ -421,31 +421,49 @@ PyObject *THPModule_safeCall(PyObject *_unused, PyObject *args, PyObject *kwargs
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string parseString(PyObject *obj)
|
||||
{
|
||||
if (PyBytes_Check(obj)) {
|
||||
return std::string(PyBytes_AS_STRING(obj));
|
||||
#if PY_MAJOR_VERSION == 3
|
||||
} else if (PyUnicode_Check(obj)) {
|
||||
return std::string(PyUnicode_AsUTF8(obj));
|
||||
#else
|
||||
} else if (PyUnicode_Check(obj)) {
|
||||
THPObjectPtr utf8 = PyUnicode_AsUTF8String(obj);
|
||||
return std::string(PyBytes_AS_STRING(utf8.get()));
|
||||
#endif
|
||||
}
|
||||
return "<invalid string>";
|
||||
}
|
||||
|
||||
PyObject *THPModule_addDocStr(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
// adds a __doc__ string to a function, similar to numpy's arr_add_docstring
|
||||
static std::vector<std::string> all_docs;
|
||||
PyObject *obj;
|
||||
PyObject *doc;
|
||||
if (!PyArg_ParseTuple(args, "OO!", &obj, &THPUtils_stringType, &doc)) {
|
||||
PyObject *doc_obj;
|
||||
if (!PyArg_ParseTuple(args, "OO", &obj, &doc_obj)) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
all_docs.push_back(parseString(doc_obj));
|
||||
const char* doc_str = all_docs.back().c_str();
|
||||
|
||||
if (Py_TYPE(obj) == &PyCFunction_Type) {
|
||||
PyCFunctionObject* f = (PyCFunctionObject *)obj;
|
||||
if (f->m_ml->ml_doc) {
|
||||
return PyErr_Format(PyExc_RuntimeError,
|
||||
"function '%s' already has a docstring", f->m_ml->ml_name);
|
||||
}
|
||||
f->m_ml->ml_doc = THPUtils_stringAsString(doc);
|
||||
Py_INCREF(doc);
|
||||
f->m_ml->ml_doc = doc_str;
|
||||
} else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) {
|
||||
PyMethodDescrObject* m = (PyMethodDescrObject *)obj;
|
||||
if (m->d_method->ml_doc) {
|
||||
return PyErr_Format(PyExc_RuntimeError,
|
||||
"method '%s' already has a docstring", m->d_method->ml_name);
|
||||
}
|
||||
m->d_method->ml_doc = THPUtils_stringAsString(doc);
|
||||
Py_INCREF(doc);
|
||||
m->d_method->ml_doc = doc_str;
|
||||
} else {
|
||||
return PyErr_Format(PyExc_TypeError,
|
||||
"don't know how to add docstring to type '%s'", Py_TYPE(obj)->tp_name);
|
||||
|
||||
Reference in New Issue
Block a user