[ONNX] Support optional type (#68793) (#73284)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73284

Some important ops won't support optional type until opset 16,
so we can't fully test things end-to-end, but I believe this should
be all that's needed. Once ONNX Runtime supports opset 16,
we can do more testing and fix any remaining bugs.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D34625646

Pulled By: malfet

fbshipit-source-id: 537fcbc1e9d87686cc61f5bd66a997e99cec287b

Co-authored-by: BowenBao <bowbao@microsoft.com>
Co-authored-by: neginraoof <neginmr@utexas.edu>
Co-authored-by: Nikita Shulga <nshulga@fb.com>
(cherry picked from commit 822e79f31ae54d73407f34f166b654f4ba115ea5)
This commit is contained in:
BowenBao
2022-05-04 13:19:46 -07:00
committed by PyTorch MergeBot
parent b8776e143f
commit 679fc90cdb
24 changed files with 1105 additions and 494 deletions

View File

@ -30,6 +30,10 @@ static constexpr char NoneType = 'n';
namespace {
inline bool PyNone_Check(PyObject* o) {
return o == Py_None;
}
template <typename T>
py::object cast_handle_sequence(std::vector<py::handle> objs) {
auto num_objs = objs.size();
@ -68,7 +72,7 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) {
args.vars.push_back(var);
args.desc.metadata.emplace_back(var);
args.desc.structure.push_back(D::Variable);
} else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) {
} else if (PyNone_Check(obj)) {
args.desc.structure.push_back(D::NoneType);
} else if (PyBool_Check(obj)) { // Wrap bools in Bool tensors
at::Tensor var = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj)));