[ONNX] Enable NoneType inputs to export API (#45792)

Summary:
Enables the use of NoneType arguments to inputs tuple in the export API

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45792

Reviewed By: heitorschueroff

Differential Revision: D24312784

Pulled By: bzinodev

fbshipit-source-id: 1717e856b56062add371af7dc09cdd9c7b5646da
This commit is contained in:
shubhambhokare1
2020-10-29 13:54:44 -07:00
committed by Facebook GitHub Bot
parent c556d4550c
commit 1ea14e30f5
3 changed files with 42 additions and 1 deletions

View File

@ -21,6 +21,7 @@ static constexpr char TupleOpen = '(';
static constexpr char TupleClose = ')';
static constexpr char Variable = 'v';
static constexpr char String = 's';
static constexpr char NoneType = 'n';
} // namespace D
namespace {
@ -62,6 +63,8 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) {
args.vars.push_back(var);
args.desc.metadata.emplace_back(var);
args.desc.structure.push_back(D::Variable);
} else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) {
args.desc.structure.push_back(D::NoneType);
} else {
std::string msg =
"Only tuples, lists and Variables supported as JIT inputs/outputs. "
@ -136,6 +139,8 @@ py::object unflatten_rec(
throw std::runtime_error("Not enough Variables given to unflatten");
auto str = *str_it++;
return py::reinterpret_borrow<py::object>(THPUtils_packString(str));
} else if (type == D::NoneType) {
return py::reinterpret_borrow<py::object>(py::none());
} else {
if (var_it == var_it_end)
throw std::runtime_error("Not enough Variables given to unflatten");