mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
585 lines
16 KiB
C++
585 lines
16 KiB
C++
#include <Python.h>
|
|
#include <stdarg.h>
|
|
#include <string>
|
|
#include <vector>
|
|
#include <sstream>
|
|
#include <algorithm>
|
|
#include <unordered_map>
|
|
#include "THP.h"
|
|
|
|
#include "generic/utils.cpp"
|
|
#include <TH/THGenerateAllTypes.h>
|
|
|
|
// Stores `arg` in `*result` when `arg` is callable.
// Returns 1 on success; returns 0 and leaves `*result` untouched otherwise.
// No reference count is changed here, so `*result` simply aliases `arg`.
int THPUtils_getCallable(PyObject *arg, PyObject **result) {
  if (PyCallable_Check(arg)) {
    *result = arg;
    return 1;
  }
  return 0;
}
|
|
|
|
// Converts `arg` into a long storage by delegating to
// THPUtils_tryUnpackLongs.
// Throws std::runtime_error (mentioning the Python type name of `arg`) when
// the conversion is rejected.
THLongStoragePtr THPUtils_unpackSize(PyObject *arg) {
  THLongStoragePtr result;
  if (THPUtils_tryUnpackLongs(arg, result)) {
    return result;
  }
  std::string msg("THPUtils_unpackSize() expects a torch.Size (got '");
  msg.append(Py_TYPE(arg)->tp_name);
  msg.append("')");
  throw std::runtime_error(msg);
}
|
|
|
|
// Attempts to convert a Python tuple or list of integers into a long storage.
//
// arg:    object to convert (borrowed).
// result: receives the newly allocated storage on success; it is not
//         assigned on failure.
// Returns true when `arg` is a tuple or list whose items all pass
// THPUtils_checkLong, false otherwise.
bool THPUtils_tryUnpackLongs(PyObject *arg, THLongStoragePtr& result) {
  const bool tuple = PyTuple_Check(arg);
  const bool list = PyList_Check(arg);
  if (!tuple && !list) {
    return false;
  }
  // Keep the element count in Py_ssize_t: that is what the size macros
  // yield, and narrowing it to int would truncate very large containers.
  const Py_ssize_t nDim = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  // Kept in a THLongStoragePtr until fully populated; presumably the pointer
  // wrapper frees the storage on the early return below.
  THLongStoragePtr storage = THLongStorage_newWithSize(nDim);
  for (Py_ssize_t i = 0; i != nDim; ++i) {
    PyObject *item = tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i);
    if (!THPUtils_checkLong(item)) {
      return false;
    }
    storage->data[i] = THPUtils_unpackLong(item);
  }
  result = storage.release();
  return true;
}
|
|
|
|
// Unpacks integers passed as Python varargs into `result`.
//
// The first `ignore_first` entries of the `args` tuple are skipped. When
// exactly one argument remains it may itself be a tuple/list of integers;
// otherwise every remaining argument must pass THPUtils_checkLong.
// Returns false when nothing remains or when a non-integer is encountered
// (in the latter case `result` has already been reassigned to a fresh
// storage that is only partially filled).
bool THPUtils_tryUnpackLongVarArgs(PyObject *args, int ignore_first, THLongStoragePtr& result) {
  const Py_ssize_t length = PyTuple_Size(args) - ignore_first;
  if (length < 1) return false;

  // A lone argument is first tried as a container of integers.
  if (length == 1) {
    PyObject *only_arg = PyTuple_GET_ITEM(args, ignore_first);
    if (THPUtils_tryUnpackLongs(only_arg, result)) return true;
  }

  // Otherwise treat each remaining argument as one integer.
  result = THLongStorage_newWithSize(length);
  Py_ssize_t slot = 0;
  while (slot < length) {
    PyObject *candidate = PyTuple_GET_ITEM(args, slot + ignore_first);
    if (!THPUtils_checkLong(candidate)) return false;
    result->data[slot] = THPUtils_unpackLong(candidate);
    ++slot;
  }
  return true;
}
|
|
|
|
// Returns true when `arg` is a tuple and each of its items passes
// THPUtils_checkLong (an empty tuple therefore qualifies).
bool THPUtils_checkIntTuple(PyObject *arg)
{
  if (!PyTuple_Check(arg)) return false;

  const Py_ssize_t count = PyTuple_GET_SIZE(arg);
  Py_ssize_t idx = 0;
  // Advance while items are integers; stopping early means a mismatch.
  while (idx < count && THPUtils_checkLong(PyTuple_GET_ITEM(arg, idx))) {
    ++idx;
  }
  return idx == count;
}
|
|
|
|
std::vector<int> THPUtils_unpackIntTuple(PyObject *arg)
|
|
{
|
|
if (!THPUtils_checkIntTuple(arg)) {
|
|
throw std::runtime_error("Couldn't unpack int tuple");
|
|
}
|
|
std::vector<int> values(PyTuple_GET_SIZE(arg));
|
|
for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) {
|
|
values[i] = (int)THPUtils_unpackLong(PyTuple_GET_ITEM(arg, i));
|
|
}
|
|
return values;
|
|
}
|
|
|
|
void THPUtils_setError(const char *format, ...)
|
|
{
|
|
static const size_t ERROR_BUFFER_SIZE = 1000;
|
|
char buffer[ERROR_BUFFER_SIZE];
|
|
va_list fmt_args;
|
|
|
|
va_start(fmt_args, format);
|
|
vsnprintf(buffer, ERROR_BUFFER_SIZE, format, fmt_args);
|
|
va_end(fmt_args);
|
|
PyErr_SetString(PyExc_RuntimeError, buffer);
|
|
}
|
|
|
|
void THPUtils_addPyMethodDefs(std::vector<PyMethodDef>& vector, PyMethodDef* methods)
|
|
{
|
|
if (!vector.empty()) {
|
|
// remove NULL terminator
|
|
vector.pop_back();
|
|
}
|
|
while (1) {
|
|
vector.push_back(*methods);
|
|
if (!methods->ml_name) {
|
|
break;
|
|
}
|
|
methods++;
|
|
}
|
|
}
|
|
|
|
std::string _THPUtils_typename(PyObject *object)
|
|
{
|
|
std::string type_name = Py_TYPE(object)->tp_name;
|
|
std::string result;
|
|
if (type_name.find("Storage") != std::string::npos ||
|
|
type_name.find("Tensor") != std::string::npos) {
|
|
PyObject *module_name = PyObject_GetAttrString(object, "__module__");
|
|
#if PY_MAJOR_VERSION == 2
|
|
if (module_name && PyString_Check(module_name)) {
|
|
result = PyString_AS_STRING(module_name);
|
|
}
|
|
#else
|
|
if (module_name && PyUnicode_Check(module_name)) {
|
|
PyObject *module_name_bytes = PyUnicode_AsASCIIString(module_name);
|
|
if (module_name_bytes) {
|
|
result = PyBytes_AS_STRING(module_name_bytes);
|
|
Py_DECREF(module_name_bytes);
|
|
}
|
|
}
|
|
#endif
|
|
Py_XDECREF(module_name);
|
|
result += ".";
|
|
result += type_name;
|
|
} else {
|
|
result = std::move(type_name);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
|
|
struct Type {
|
|
virtual bool is_matching(PyObject *object) = 0;
|
|
virtual ~Type() {};
|
|
};
|
|
|
|
struct SimpleType: public Type {
|
|
SimpleType(std::string& name): name(name) {};
|
|
|
|
bool is_matching(PyObject *object) {
|
|
return _THPUtils_typename(object) == name;
|
|
}
|
|
|
|
std::string name;
|
|
};
|
|
|
|
struct MultiType: public Type {
|
|
MultiType(std::initializer_list<std::string> accepted_types):
|
|
types(accepted_types) {};
|
|
|
|
bool is_matching(PyObject *object) {
|
|
auto it = std::find(types.begin(), types.end(), _THPUtils_typename(object));
|
|
return it != types.end();
|
|
}
|
|
|
|
std::vector<std::string> types;
|
|
};
|
|
|
|
struct NullableType: public Type {
|
|
NullableType(std::unique_ptr<Type> type): type(std::move(type)) {};
|
|
|
|
bool is_matching(PyObject *object) {
|
|
return object == Py_None || type->is_matching(object);
|
|
}
|
|
|
|
std::unique_ptr<Type> type;
|
|
};
|
|
|
|
struct TupleType: public Type {
|
|
TupleType(std::vector<std::unique_ptr<Type>> types):
|
|
types(std::move(types)) {};
|
|
|
|
bool is_matching(PyObject *object) {
|
|
if (!PyTuple_Check(object)) return false;
|
|
auto num_elements = PyTuple_GET_SIZE(object);
|
|
if (num_elements != (long)types.size()) return false;
|
|
for (int i = 0; i < num_elements; i++) {
|
|
if (!types[i]->is_matching(PyTuple_GET_ITEM(object, i)))
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
std::vector<std::unique_ptr<Type>> types;
|
|
};
|
|
|
|
struct SequenceType: public Type {
|
|
SequenceType(std::unique_ptr<Type> type):
|
|
type(std::move(type)) {};
|
|
|
|
bool is_matching(PyObject *object) {
|
|
if (!PySequence_Check(object)) return false;
|
|
auto num_elements = PySequence_Length(object);
|
|
for (int i = 0; i < num_elements; i++) {
|
|
if (!type->is_matching(PySequence_GetItem(object, i)))
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
std::unique_ptr<Type> type;
|
|
};
|
|
|
|
struct Argument {
|
|
Argument(std::string name, std::unique_ptr<Type> type):
|
|
name(name), type(std::move(type)) {};
|
|
|
|
std::string name;
|
|
std::unique_ptr<Type> type;
|
|
};
|
|
|
|
struct Option {
|
|
Option(std::vector<Argument> arguments, bool is_variadic, bool has_out):
|
|
arguments(std::move(arguments)), is_variadic(is_variadic), has_out(has_out) {};
|
|
Option(bool is_variadic, bool has_out):
|
|
arguments(), is_variadic(is_variadic), has_out(has_out) {};
|
|
Option(const Option&) = delete;
|
|
Option(Option&& other):
|
|
arguments(std::move(other.arguments)), is_variadic(other.is_variadic) {};
|
|
|
|
std::vector<Argument> arguments;
|
|
bool is_variadic;
|
|
bool has_out;
|
|
};
|
|
|
|
// Splits `s` at every occurrence of `delim`.
//
// Returns the pieces in order; the delimiter itself is not included. The
// result always has at least one element: a string without the delimiter
// (including the empty string) yields a single token equal to `s`, and
// leading/trailing/adjacent delimiters yield empty tokens.
// An empty `delim` cannot split anything, so `s` is returned as the only
// token (without this guard the search would never advance).
std::vector<std::string> _splitString(const std::string &s, const std::string& delim) {
  std::vector<std::string> tokens;
  if (delim.empty()) {
    tokens.push_back(s);
    return tokens;
  }
  std::size_t start = 0;
  std::size_t end;
  while ((end = s.find(delim, start)) != std::string::npos) {
    tokens.push_back(s.substr(start, end - start));
    start = end + delim.length();
  }
  // Whatever follows the last delimiter (possibly nothing).
  tokens.push_back(s.substr(start));
  return tokens;
}
|
|
|
|
std::unique_ptr<Type> _buildType(std::string type_name, bool is_nullable) {
|
|
std::unique_ptr<Type> result;
|
|
if (type_name == "float") {
|
|
result.reset(new MultiType({"float", "int", "long"}));
|
|
} else if (type_name == "int") {
|
|
result.reset(new MultiType({"int", "long"}));
|
|
} else if (type_name.find("tuple[") == 0) {
|
|
auto type_list = type_name.substr(6);
|
|
type_list.pop_back();
|
|
std::vector<std::unique_ptr<Type>> types;
|
|
for (auto& type: _splitString(type_list, ","))
|
|
types.emplace_back(_buildType(type, false));
|
|
result.reset(new TupleType(std::move(types)));
|
|
} else if (type_name.find("sequence[") == 0) {
|
|
auto subtype = type_name.substr(9);
|
|
subtype.pop_back();
|
|
result.reset(new SequenceType(_buildType(subtype, false)));
|
|
} else {
|
|
result.reset(new SimpleType(type_name));
|
|
}
|
|
if (is_nullable)
|
|
result.reset(new NullableType(std::move(result)));
|
|
return result;
|
|
}
|
|
|
|
std::pair<Option, std::string> _parseOption(const std::string& _option_str,
|
|
const std::unordered_map<std::string, PyObject*> kwargs)
|
|
{
|
|
if (_option_str == "no arguments")
|
|
return std::pair<Option, std::string>(Option(false, false), _option_str);
|
|
bool has_out = false;
|
|
std::vector<Argument> arguments;
|
|
std::string printable_option = _option_str;
|
|
std::string option_str = _option_str.substr(1, _option_str.length()-2);
|
|
|
|
/// XXX: this is a hack only for the out arg in TensorMethods
|
|
auto out_pos = printable_option.find('#');
|
|
if (out_pos != std::string::npos) {
|
|
if (kwargs.count("out") > 0) {
|
|
std::string kwonly_part = printable_option.substr(out_pos+1);
|
|
printable_option.erase(out_pos);
|
|
printable_option += "*, ";
|
|
printable_option += kwonly_part;
|
|
} else {
|
|
printable_option.erase(out_pos-2);
|
|
printable_option += ")";
|
|
}
|
|
has_out = true;
|
|
}
|
|
|
|
for (auto& arg: _splitString(option_str, ", ")) {
|
|
bool is_nullable = false;
|
|
auto type_start_idx = 0;
|
|
if (arg[type_start_idx] == '#') {
|
|
type_start_idx++;
|
|
}
|
|
if (arg[type_start_idx] == '[') {
|
|
is_nullable = true;
|
|
type_start_idx++;
|
|
arg.erase(arg.length() - std::string(" or None]").length());
|
|
}
|
|
|
|
auto type_end_idx = arg.find_last_of(' ');
|
|
auto name_start_idx = type_end_idx + 1;
|
|
|
|
// "type ... name" => "type ... name"
|
|
// ^ ^
|
|
auto dots_idx = arg.find("...");
|
|
if (dots_idx != std::string::npos)
|
|
type_end_idx -= 4;
|
|
|
|
std::string type_name =
|
|
arg.substr(type_start_idx, type_end_idx-type_start_idx);
|
|
std::string name =
|
|
arg.substr(name_start_idx);
|
|
|
|
arguments.emplace_back(name, _buildType(type_name, is_nullable));
|
|
}
|
|
|
|
bool is_variadic = option_str.find("...") != std::string::npos;
|
|
return std::pair<Option, std::string>(
|
|
Option(std::move(arguments), is_variadic, has_out),
|
|
std::move(printable_option)
|
|
);
|
|
}
|
|
|
|
bool _argcountMatch(
|
|
const Option& option,
|
|
const std::vector<PyObject*>& arguments,
|
|
const std::unordered_map<std::string, PyObject*>& kwargs)
|
|
{
|
|
auto num_expected = option.arguments.size();
|
|
auto num_got = arguments.size() + kwargs.size();
|
|
// Note: variadic functions don't accept kwargs, so it's ok
|
|
if (option.has_out && kwargs.count("out") == 0)
|
|
num_expected--;
|
|
return num_got == num_expected ||
|
|
(option.is_variadic && num_got > num_expected);
|
|
}
|
|
|
|
std::string _formattedArgDesc(
|
|
const Option& option,
|
|
const std::vector<PyObject*>& arguments,
|
|
const std::unordered_map<std::string, PyObject*>& kwargs)
|
|
{
|
|
std::string red;
|
|
std::string reset_red;
|
|
std::string green;
|
|
std::string reset_green;
|
|
if (isatty(1) && isatty(2)) {
|
|
red = "\33[31;1m";
|
|
reset_red = "\33[0m";
|
|
green = "\33[32;1m";
|
|
reset_green = "\33[0m";
|
|
} else {
|
|
red = "!";
|
|
reset_red = "!";
|
|
green = "";
|
|
reset_green = "";
|
|
}
|
|
|
|
auto num_args = arguments.size() + kwargs.size();
|
|
std::string result = "(";
|
|
for (size_t i = 0; i < num_args; i++) {
|
|
bool is_kwarg = i >= arguments.size();
|
|
PyObject *arg = is_kwarg ? kwargs.at(option.arguments[i].name) : arguments[i];
|
|
|
|
bool is_matching = false;
|
|
if (i < option.arguments.size()) {
|
|
is_matching = option.arguments[i].type->is_matching(arg);
|
|
} else if (option.is_variadic) {
|
|
is_matching = option.arguments.back().type->is_matching(arg);
|
|
}
|
|
|
|
if (is_matching)
|
|
result += green;
|
|
else
|
|
result += red;
|
|
if (is_kwarg) result += option.arguments[i].name + "=";
|
|
result += _THPUtils_typename(arg);
|
|
if (is_matching)
|
|
result += reset_green;
|
|
else
|
|
result += reset_red;
|
|
result += ", ";
|
|
}
|
|
if (arguments.size() > 0)
|
|
result.erase(result.length()-2);
|
|
result += ")";
|
|
return result;
|
|
}
|
|
|
|
std::string _argDesc(const std::vector<PyObject *>& arguments,
|
|
const std::unordered_map<std::string, PyObject *>& kwargs)
|
|
{
|
|
std::string result = "(";
|
|
for (auto& arg: arguments)
|
|
result += std::string(_THPUtils_typename(arg)) + ", ";
|
|
for (auto& kwarg: kwargs)
|
|
result += kwarg.first + "=" + _THPUtils_typename(kwarg.second) + ", ";
|
|
if (arguments.size() > 0)
|
|
result.erase(result.length()-2);
|
|
result += ")";
|
|
return result;
|
|
}
|
|
|
|
std::vector<std::string> _tryMatchKwargs(const Option& option,
|
|
const std::unordered_map<std::string, PyObject*>& kwargs) {
|
|
std::vector<std::string> unmatched;
|
|
int start_idx = option.arguments.size() - kwargs.size();
|
|
if (option.has_out && kwargs.count("out") == 0)
|
|
start_idx--;
|
|
if (start_idx < 0)
|
|
start_idx = 0;
|
|
for (auto& entry: kwargs) {
|
|
bool found = false;
|
|
for (unsigned int i = start_idx; i < option.arguments.size(); i++) {
|
|
if (option.arguments[i].name == entry.first) {
|
|
found = true;
|
|
break;
|
|
}
|
|
}
|
|
if (!found)
|
|
unmatched.push_back(entry.first);
|
|
}
|
|
return unmatched;
|
|
}
|
|
|
|
// Converts a keyword-argument dictionary key into a std::string.
//
// Python 3: the key is encoded as UTF-8, so non-ASCII keyword names are
// preserved. If encoding fails (NULL result) the Python error is cleared and
// a placeholder is returned instead of dereferencing the NULL pointer.
// Python 2: the key is read directly as a byte string.
std::string _parseDictKey(PyObject *key_str) {
#if PY_MAJOR_VERSION == 3
  THPObjectPtr encoded = PyUnicode_AsUTF8String(key_str);
  if (!encoded.get()) {
    // This helper only feeds an error message; do not let the encoding
    // failure escape or crash the error-reporting path.
    PyErr_Clear();
    return std::string("<unprintable keyword>");
  }
  return std::string(PyBytes_AS_STRING(encoded.get()));
#else
  return std::string(PyString_AS_STRING(key_str));
#endif
}
|
|
|
|
// Raises a Python TypeError explaining why a call did not match any of the
// accepted signatures.
//
// given_args:    tuple of positional arguments of the failed call.
// given_kwargs:  dict of keyword arguments, or NULL.
// function_name: name used at the start of the message.
// num_options:   number of signature strings that follow.
// ...:           `num_options` values of type `const char*`, each a
//                signature in the format understood by _parseOption.
//
// With exactly one signature the message either lists unrecognized keywords
// or shows what was received next to what was expected. Otherwise every
// signature is listed, and those whose argument count fits get an extra
// line explaining the mismatch.
void THPUtils_invalidArguments(PyObject *given_args, PyObject *given_kwargs,
    const char *function_name, size_t num_options, ...) {
  // Signature strings from the variadic tail.
  std::vector<std::string> signatures;
  va_list ap;
  va_start(ap, num_options);
  for (size_t n = 0; n != num_options; ++n) {
    signatures.emplace_back(va_arg(ap, const char*));
  }
  va_end(ap);

  // Positional arguments (borrowed).
  std::vector<PyObject *> positional;
  const Py_ssize_t positional_count = PyTuple_Size(given_args);
  for (Py_ssize_t idx = 0; idx < positional_count; ++idx) {
    positional.push_back(PyTuple_GET_ITEM(given_args, idx));
  }

  // Keyword arguments keyed by name (values borrowed).
  std::unordered_map<std::string, PyObject *> keywords;
  const bool has_kwargs = given_kwargs && PyDict_Size(given_kwargs) > 0;
  if (has_kwargs) {
    PyObject *key = nullptr;
    PyObject *value = nullptr;
    Py_ssize_t cursor = 0;
    while (PyDict_Next(given_kwargs, &cursor, &key, &value)) {
      keywords.emplace(_parseDictKey(key), value);
    }
  }

  std::string message;
  message.reserve(2000);
  message += function_name;
  message += " received an invalid combination of arguments - ";

  // Appends "a, b, c" for a non-empty list of names.
  auto append_keywords = [&message](const std::vector<std::string>& names) {
    for (const std::string& name : names) {
      message += name;
      message += ", ";
    }
    message.erase(message.length() - 2);
  };

  if (num_options == 1) {
    auto parsed = _parseOption(signatures[0], keywords);
    const Option& option = parsed.first;

    std::vector<std::string> unmatched;
    if (has_kwargs) {
      unmatched = _tryMatchKwargs(option, keywords);
    }

    if (!unmatched.empty()) {
      message += "got unrecognized keyword arguments: ";
      append_keywords(unmatched);
    } else {
      message += "got ";
      message += _argcountMatch(option, positional, keywords)
          ? _formattedArgDesc(option, positional, keywords)
          : _argDesc(positional, keywords);
      message += ", but expected ";
      message += parsed.second;
    }
  } else {
    message += "got ";
    message += _argDesc(positional, keywords);
    message += ", but expected one of:\n";

    for (const std::string& signature : signatures) {
      auto parsed = _parseOption(signature, keywords);
      const Option& option = parsed.first;

      message += " * ";
      message += parsed.second;
      message += "\n";

      // Only signatures with a fitting argument count get a detail line.
      if (!_argcountMatch(option, positional, keywords)) {
        continue;
      }

      std::vector<std::string> unmatched;
      if (has_kwargs) {
        unmatched = _tryMatchKwargs(option, keywords);
      }

      if (!unmatched.empty()) {
        message += " didn't match because some of the keywords were incorrect: ";
        append_keywords(unmatched);
        message += "\n";
      } else {
        message += " didn't match because some of the arguments have invalid types: ";
        message += _formattedArgDesc(option, positional, keywords);
        message += "\n";
      }
    }
  }

  PyErr_SetString(PyExc_TypeError, message.c_str());
}
|
|
|
|
|
|
|
|
// Resolves a Python slice object against a sequence of length `len`.
//
// On success writes the resolved start and stop to `*ostart` / `*ostop`,
// writes the slice length to `*oslicelength` when that pointer is non-NULL,
// and returns true. Returns false when PySlice_GetIndicesEx fails, or - after
// setting an error through THPUtils_setError - when the step is not 1.
bool THPUtils_parseSlice(PyObject *slice, Py_ssize_t len, Py_ssize_t *ostart, Py_ssize_t *ostop, Py_ssize_t *oslicelength)
{
  // The first parameter type of PySlice_GetIndicesEx depends on the Python
  // version: https://bugsfiles.kde.org/attachment.cgi?id=61186
#if PY_VERSION_HEX >= 0x03020000
  PyObject *slice_arg = slice;
#else
  PySliceObject *slice_arg = (PySliceObject *)slice;
#endif

  Py_ssize_t start, stop, step, slicelength;
  const int status =
      PySlice_GetIndicesEx(slice_arg, len, &start, &stop, &step, &slicelength);
  if (status < 0) {
    return false;
  }

  if (step != 1) {
    THPUtils_setError("Trying to slice with a step of %ld, but only a step of "
        "1 is supported", (long)step);
    return false;
  }

  *ostart = start;
  *ostop = stop;
  if (oslicelength) {
    *oslicelength = slicelength;
  }
  return true;
}
|
|
|
|
template<>
|
|
// Drops one Python reference to the held generator; a null pointer is a
// no-op (Py_XDECREF performs the null check).
void THPPointer<THPGenerator>::free() {
  Py_XDECREF(ptr);
}
|
|
|
|
template<>
|
|
// Drops one Python reference to the held object; a null pointer is a no-op
// (Py_XDECREF performs the null check).
void THPPointer<PyObject>::free() {
  Py_XDECREF(ptr);
}
|
|
|
|
// Explicit instantiation of THPPointer for THPGenerator in this translation unit.
template class THPPointer<THPGenerator>;
|
|
// Explicit instantiation of THPPointer for PyObject in this translation unit.
template class THPPointer<PyObject>;
|