#include #include #include #include namespace torch { namespace jit { namespace python { using namespace torch::autograd; using namespace at; // Alphabet used to describe structure of inputs/outputs (D for desc) namespace D { static constexpr char DictOpen = '<'; static constexpr char DictClose = '>'; static constexpr char ListOpen = '['; static constexpr char ListClose = ']'; static constexpr char TupleOpen = '('; static constexpr char TupleClose = ')'; static constexpr char Variable = 'v'; static constexpr char String = 's'; } // namespace D namespace { template py::object cast_handle_sequence(std::vector objs) { auto num_objs = objs.size(); T sequence{num_objs}; for (size_t i = 0; i < num_objs; ++i) sequence[i] = py::reinterpret_borrow(objs[i]); return sequence; } void flatten_rec(PyObject* obj, ParsedArgs& args) { auto& structure = args.desc.structure; if (six::isTuple(obj)) { structure.push_back(D::TupleOpen); for (auto item : py::reinterpret_borrow(obj)) flatten_rec(item.ptr(), args); structure.push_back(D::TupleClose); } else if (PyList_Check(obj)) { structure.push_back(D::ListOpen); for (auto item : py::reinterpret_borrow(obj)) flatten_rec(item.ptr(), args); structure.push_back(D::ListClose); } else if (PyDict_Check(obj)) { auto dict_items = PyDict_Items(obj); structure.push_back(D::DictOpen); for (auto item : py::reinterpret_borrow(dict_items)){ flatten_rec(item.ptr(), args); } structure.push_back(D::DictClose); } else if (THPUtils_checkString(obj)) { string str = THPUtils_unpackString(obj); args.desc.strings.emplace_back(str); args.desc.structure.push_back(D::String); } else if (THPVariable_Check(obj)) { auto& var = reinterpret_cast(obj)->cdata; args.vars.push_back(var); args.desc.metadata.emplace_back(var); args.desc.structure.push_back(D::Variable); } else { std::string msg = "Only tuples, lists and Variables supported as JIT inputs/outputs. " "Dictionaries and strings are also accepted but their usage is not " "recommended. But got unsupported type "; msg += THPUtils_typename(obj); throw std::runtime_error(msg); } } } // anonymous namespace ParsedArgs flatten(py::handle obj) { ParsedArgs args; args.desc.grad_enabled = autograd::GradMode::is_enabled(); flatten_rec(obj.ptr(), args); return args; } namespace { template py::object cast_sequence(std::vector objs) { auto num_objs = objs.size(); T sequence{num_objs}; for (size_t i = 0; i < num_objs; ++i) sequence[i] = std::move(objs[i]); return std::move(sequence); } py::object cast_dict(std::vector objs) { auto num_objs = objs.size(); py::dict sequence = {}; for (size_t i = 0; i < num_objs; ++i){ py::tuple obj = py::reinterpret_borrow(objs[i]); sequence[obj[0]] = std::move(obj[1]); } return std::move(sequence); } py::object unflatten_rec( ArrayRef::iterator& var_it, ArrayRef::iterator& var_it_end, std::string::const_iterator& desc_it, std::vector::const_iterator& str_it, std::vector::const_iterator& str_it_end) { char type = *desc_it++; if (type == D::TupleOpen) { std::vector objs; while (*desc_it != D::TupleClose) objs.push_back(unflatten_rec(var_it, var_it_end, desc_it, str_it, str_it_end)); ++desc_it; return cast_sequence(objs); } else if (type == D::ListOpen) { std::vector objs; while (*desc_it != D::ListClose) objs.push_back(unflatten_rec(var_it, var_it_end, desc_it,str_it, str_it_end)); ++desc_it; return cast_sequence(objs); } else if (type == D::DictOpen) { std::vector objs; while (*desc_it != D::DictClose){ objs.push_back(unflatten_rec(var_it, var_it_end, desc_it,str_it, str_it_end)); } ++desc_it; return cast_dict(objs); } else if (type == D::String) { if (str_it == str_it_end) throw std::runtime_error("Not enough Variables given to unflatten"); auto str = *str_it++; return py::reinterpret_borrow(THPUtils_packString(str)); } else { if (var_it == var_it_end) throw std::runtime_error("Not enough Variables given to unflatten"); auto var = *var_it++; return py::reinterpret_steal(THPVariable_Wrap(var)); } } } // anonymous namespace PyObject* unflatten(ArrayRef vars, const IODescriptor& desc) { // NB: We don't do correctness checking on descriptor. // It has to be a correct bytes object produced by unflatten. auto vars_it = vars.begin(); auto vars_it_end = vars.end(); auto desc_it = desc.structure.begin(); std::vector::const_iterator str_it = desc.strings.begin(); std::vector::const_iterator str_end = desc.strings.end(); auto output = unflatten_rec(vars_it, vars_it_end, desc_it, str_it, str_end); if (vars_it != vars_it_end) throw std::runtime_error("Too many Variables given to unflatten"); return output.release().ptr(); } } // namespace python } // namespace jit } // namespace torch