mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Enable NoneType inputs to export API (#45792)
Summary: Enables the use of NoneType arguments to inputs tuple in the export API Pull Request resolved: https://github.com/pytorch/pytorch/pull/45792 Reviewed By: heitorschueroff Differential Revision: D24312784 Pulled By: bzinodev fbshipit-source-id: 1717e856b56062add371af7dc09cdd9c7b5646da
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c556d4550c
commit
1ea14e30f5
@ -462,6 +462,42 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
x = {"test_key_in": torch.randn(1, 2, 3)}
|
||||
self.run_test(MyModel(), (x,))
|
||||
|
||||
def test_none_as_input(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
if y is not None:
|
||||
return x + y
|
||||
return x
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
self.run_test(Model(), (x, None))
|
||||
|
||||
def test_none_as_tuple_input(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
if y[0] is not None:
|
||||
return x + y[0]
|
||||
if y[1] is not None:
|
||||
return x + y[1]
|
||||
return x
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
y = torch.randn(2, 3)
|
||||
self.run_test(Model(), (x, (None, y)))
|
||||
|
||||
def test_none_as_named_input(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y=None, z=None):
|
||||
if y is not None:
|
||||
return x + y
|
||||
if z is not None:
|
||||
return x + z
|
||||
return x
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
z = torch.randn(2, 3)
|
||||
self.run_test(Model(), (x, None, z))
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
|
||||
def test_cste_script(self):
|
||||
class MyModel(torch.jit.ScriptModule):
|
||||
|
@ -21,6 +21,7 @@ static constexpr char TupleOpen = '(';
|
||||
static constexpr char TupleClose = ')';
|
||||
static constexpr char Variable = 'v';
|
||||
static constexpr char String = 's';
|
||||
static constexpr char NoneType = 'n';
|
||||
} // namespace D
|
||||
|
||||
namespace {
|
||||
@ -62,6 +63,8 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) {
|
||||
args.vars.push_back(var);
|
||||
args.desc.metadata.emplace_back(var);
|
||||
args.desc.structure.push_back(D::Variable);
|
||||
} else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) {
|
||||
args.desc.structure.push_back(D::NoneType);
|
||||
} else {
|
||||
std::string msg =
|
||||
"Only tuples, lists and Variables supported as JIT inputs/outputs. "
|
||||
@ -136,6 +139,8 @@ py::object unflatten_rec(
|
||||
throw std::runtime_error("Not enough Variables given to unflatten");
|
||||
auto str = *str_it++;
|
||||
return py::reinterpret_borrow<py::object>(THPUtils_packString(str));
|
||||
} else if (type == D::NoneType) {
|
||||
return py::reinterpret_borrow<py::object>(py::none());
|
||||
} else {
|
||||
if (var_it == var_it_end)
|
||||
throw std::runtime_error("Not enough Variables given to unflatten");
|
||||
|
@ -44,7 +44,7 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
|
||||
model (torch.nn.Module): the model to be exported.
|
||||
args (tuple of arguments or torch.Tensor): the inputs to
|
||||
the model, e.g., such that ``model(*args)`` is a valid
|
||||
invocation of the model. Any non-Tensor arguments will
|
||||
invocation of the model. Any non-Tensor arguments (including None) will
|
||||
be hard-coded into the exported model; any Tensor arguments
|
||||
will become inputs of the exported model, in the order they
|
||||
occur in args. If args is a Tensor, this is equivalent
|
||||
|
Reference in New Issue
Block a user