mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary:
Partially fixes: https://github.com/pytorch/pytorch/issues/394

Implementation detail:

Codegen is modified to generate code like the following:
```C++
static PyObject * THPVariable_svd(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "svd(Tensor input, bool some=True, bool compute_uv=True, *, TensorList[3] out=None)",
  }, /*traceable=*/true);

  ParsedArgs<6> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  static PyStructSequence_Field fields0[] = {
    {"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
  };
  static PyStructSequence_Desc desc0 = {
    "torch.return_types.svd_out", nullptr,
    fields0, 3
  };
  static PyTypeObject type0;
  static bool namedtuple_type_initialized0 = false;
  if (!namedtuple_type_initialized0) {
    PyStructSequence_InitType(&type0, &desc0);
    namedtuple_type_initialized0 = true;
  }
  static PyStructSequence_Field fields1[] = {
    {"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
  };
  static PyStructSequence_Desc desc1 = {
    "torch.return_types.svd", nullptr,
    fields1, 3
  };
  static PyTypeObject type1;
  static bool namedtuple_type_initialized1 = false;
  if (!namedtuple_type_initialized1) {
    PyStructSequence_InitType(&type1, &desc1);
    namedtuple_type_initialized1 = true;
  }
  if (r.idx == 0) {
    if (r.isNone(3)) {
      return wrap(&type1, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2)));
    } else {
      auto results = r.tensorlist_n<3>(3);
      return wrap(&type0, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2), results[0], results[1], results[2]));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
```
The types are defined as static local variables of the `THPVariable_${op_name}` functions and are initialized the first time the function is called.

When parsing function prototypes in `native_functions.yaml`, the parser sets the specified name as `field_name` when it sees a return declaration such as `-> (Tensor t1, ...)`. These names become the field names of the namedtuple. The namedtuple class is named `torch.return_types.${op_name}`.

In some Python 2 versions, `PyStructSequence` is not a subtype of tuple, so helper functions are added to check whether an object is a tuple or a namedtuple, for compatibility.

Operators in `native_functions.yaml` are changed so that only `max` and `svd` return namedtuples. Tests are added for these two operators to check that the return value works as expected. The docs for these two ops are also updated to state explicitly that the return value is a namedtuple. More ops will be added in later PRs.

On the Windows build, the linker is unable to resolve `PyStructSequence_UnnamedField`, so a workaround is added for this case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15429
Differential Revision: D13709678
Pulled By: ezyang
fbshipit-source-id: 23a511c9436977098afc49374e9a748b6e30bccf
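
The first-call initialization described in the commit message can be sketched without the Python headers. This is only an illustration of the guard-flag pattern, not the generated binding: `DemoType`, `demo_init_type`, `g_init_calls`, and `demo_svd_return_type` are hypothetical stand-ins for `PyTypeObject`, `PyStructSequence_InitType`, and the `THPVariable_svd` body.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for PyTypeObject; not the CPython struct.
struct DemoType {
  std::string name;
  int n_fields = 0;
};

// Counts how often initialization runs (illustration only).
static int g_init_calls = 0;

// Hypothetical stand-in for PyStructSequence_InitType.
static void demo_init_type(DemoType* type, const char* name, int n_fields) {
  type->name = name;
  type->n_fields = n_fields;
  ++g_init_calls;
}

// Same shape as the generated code: function-local statics guarded by a
// bool flag, so the type is set up on the first call only.
const DemoType* demo_svd_return_type() {
  static DemoType type0;
  static bool type_initialized0 = false;
  if (!type_initialized0) {
    demo_init_type(&type0, "torch.return_types.svd", 3);
    type_initialized0 = true;
  }
  return &type0;
}
```

Calling `demo_svd_return_type()` twice returns the same pointer and runs the initializer once. Note that a plain `bool` flag like this is not synchronized by itself.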
14 lines
318 B
C
#pragma once

// workaround for Python 2 issue: https://bugs.python.org/issue17120
#pragma push_macro("_XOPEN_SOURCE")
#pragma push_macro("_POSIX_C_SOURCE")
#undef _XOPEN_SOURCE
#undef _POSIX_C_SOURCE

#include <Python.h>
#include <structseq.h>

#pragma pop_macro("_XOPEN_SOURCE")
#pragma pop_macro("_POSIX_C_SOURCE")
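
The push/undef/pop sequence in this header can be tried in isolation. The sketch below assumes a compiler that supports `#pragma push_macro` and `#pragma pop_macro` (GCC, Clang, and MSVC do). `DEMO_FEATURE_LEVEL` and the helper functions are hypothetical names standing in for `_XOPEN_SOURCE` / `_POSIX_C_SOURCE` and for whatever the included header defines.

```cpp
#include <cassert>

// Hypothetical macro standing in for _XOPEN_SOURCE / _POSIX_C_SOURCE.
#define DEMO_FEATURE_LEVEL 700

// Save the current definition, then clear it so a later #define cannot clash.
#pragma push_macro("DEMO_FEATURE_LEVEL")
#undef DEMO_FEATURE_LEVEL

// Stand-in for an included header that defines the macro itself.
#define DEMO_FEATURE_LEVEL 600
constexpr int kSeenInsideHeader = DEMO_FEATURE_LEVEL;

// Restore the definition that push_macro saved.
#pragma pop_macro("DEMO_FEATURE_LEVEL")
constexpr int kSeenAfterPop = DEMO_FEATURE_LEVEL;

int seen_inside_header() { return kSeenInsideHeader; }
int seen_after_pop() { return kSeenAfterPop; }

// Small self-check of the two values captured above.
void check_demo() {
  assert(seen_inside_header() == 600);
  assert(seen_after_pop() == 700);
}
```

Inside the guarded region the macro has the value set by the stand-in header, and after `pop_macro` the includer's original value is back, which is the effect the header relies on around `<Python.h>`.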