mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73284 Some important ops won't support optional type until opset 16, so we can't fully test things end-to-end, but I believe this should be all that's needed. Once ONNX Runtime supports opset 16, we can do more testing and fix any remaining bugs. Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D34625646 Pulled By: malfet fbshipit-source-id: 537fcbc1e9d87686cc61f5bd66a997e99cec287b Co-authored-by: BowenBao <bowbao@microsoft.com> Co-authored-by: neginraoof <neginmr@utexas.edu> Co-authored-by: Nikita Shulga <nshulga@fb.com> (cherry picked from commit 822e79f31ae54d73407f34f166b654f4ba115ea5)
This commit is contained in:
committed by
PyTorch MergeBot
parent
b8776e143f
commit
679fc90cdb
@ -30,6 +30,10 @@ static constexpr char NoneType = 'n';
|
||||
|
||||
namespace {
|
||||
|
||||
inline bool PyNone_Check(PyObject* o) {
|
||||
return o == Py_None;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
py::object cast_handle_sequence(std::vector<py::handle> objs) {
|
||||
auto num_objs = objs.size();
|
||||
@ -68,7 +72,7 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) {
|
||||
args.vars.push_back(var);
|
||||
args.desc.metadata.emplace_back(var);
|
||||
args.desc.structure.push_back(D::Variable);
|
||||
} else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) {
|
||||
} else if (PyNone_Check(obj)) {
|
||||
args.desc.structure.push_back(D::NoneType);
|
||||
} else if (PyBool_Check(obj)) { // Wrap bools in Bool tensors
|
||||
at::Tensor var = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj)));
|
||||
|
Reference in New Issue
Block a user