[ONNX] Support primitive type input/outputs and attributes (#53550) (#54864)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54864

Support primitive type (int/float/bool) inputs, outputs, and attributes in the ONNX exporter. Needed for the Silero model.
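
As a rough usage sketch of what this enables (the module and file names below are illustrative, not part of this change, and it assumes a PyTorch build from around this commit where torch.onnx.export still takes example_outputs), a scripted model whose forward takes plain int/float/bool arguments alongside tensors can be exported directly; internally the primitives are flattened into 0-dim tensors for the tracer:

    import torch

    class AddAndScale(torch.nn.Module):
        # Mixes a tensor input with primitive int/float arguments, which
        # this change accepts as ONNX model inputs.
        def forward(self, x, y: int, scale: float):
            return (x + y) * scale

    model = torch.jit.script(AddAndScale())
    x = torch.randn(2, 3, 4)
    # example_outputs is required when exporting a ScriptModule on this
    # PyTorch version; 2 and 0.5 become inputs of the exported graph.
    torch.onnx.export(model, (x, 2, 0.5), "add_and_scale.onnx",
                      example_outputs=(x + 2) * 0.5,
                      opset_version=12)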

Test Plan: Imported from OSS

Reviewed By: nikithamalgifb

Differential Revision: D27408982

Pulled By: SplitInfinity

fbshipit-source-id: 16b291eedbe9f9bb31d7664a29a484555df53755
Author: Negin Raoof
Date: 2021-03-31 21:11:25 -07:00
Committed by: Facebook GitHub Bot
Parent: ce48b14060
Commit: cd9dd653e9

4 changed files with 146 additions and 48 deletions

View File

@@ -1342,6 +1342,74 @@ class TestONNXRuntime(unittest.TestCase):
        x = torch.randn(2, 3, 4)
        self.run_test(ArithmeticModule(), x)

    def test_arithmetic_prim_long(self):
        class ArithmeticModule(torch.nn.Module):
            def forward(self, x, y: int):
                x = x + y
                x = x - y
                x = x * (y * 3)
                x = x / (y * 4)
                return x

        x = torch.randn(2, 3, 4)
        y = 2
        self.run_test(ArithmeticModule(), (x, y))

        class ArithmeticModule(torch.nn.Module):
            def forward(self, x):
                x = x + 2
                x = x - 3
                return x.shape[0]

        x = torch.randn(2, 3, 4)
        self.run_test(ArithmeticModule(), x)

    def test_arithmetic_prim_float(self):
        class ArithmeticModule(torch.nn.Module):
            def forward(self, x, y: float):
                x = x + y
                x = x - y
                x = x * (y * 3)
                x = x / (y * 4)
                return x

        x = torch.randn(2, 3, 4)
        y = 2.5
        self.run_test(ArithmeticModule(), (x, y))

        class ArithmeticModule(torch.nn.Module):
            def forward(self, x):
                x = x + 2
                x = x - 3
                return x.shape[1] / 2

        x = torch.randn(2, 3, 4)
        self.run_test(ArithmeticModule(), x)

    def test_arithmetic_prim_bool(self):
        class ArithmeticModule(torch.nn.Module):
            def forward(self, x, y: int, z: bool, t: float):
                x = x + y
                x = x - y
                if z:
                    x = x * (y * 3)
                    x = x / (y * 4)
                return x / t, z

        x = torch.randn(2, 3, 4)
        y = 2
        z = False
        t = 2.5
        self.run_test(ArithmeticModule(), (x, y, z, t))

        class ArithmeticModule(torch.nn.Module):
            def forward(self, x: float, y: float):
                return x == y

        x = 3
        y = 2
        self.run_test(ArithmeticModule(), (x, y))
    # In scripting, the first transpose node does not carry shape and dtype info.
    # The following test only works when onnx shape inference is enabled.
    @skipIfONNXShapeInference(False)
@@ -6107,7 +6175,7 @@ class TestONNXRuntime(unittest.TestCase):
                return torch.nn.functional.embedding(input, emb, padding_idx=1)

        model = EmbedModel()
        x = torch.randint(4, (4, ))
        x = torch.randint(4, (4,))
        x[2] = x[0] = 1
        embedding_matrix = torch.rand(10, 3)
        self.run_test(model, (x, embedding_matrix))
@@ -6140,14 +6208,14 @@ class TestONNXRuntime(unittest.TestCase):
                return self.emb(input), self.emb2(input)

        model = EmbedModel()
        x = torch.randint(4, (4, ))
        x = torch.randint(4, (4,))
        x[2] = x[0] = 1
        self.run_test(model, (x, ))
        self.run_test(model, (x,))

        x = torch.randint(4, (4, 3, 2))
        x[2] = 1
        x[0][1] = 1
        self.run_test(model, (x, ))
        self.run_test(model, (x,))

        class EmbedModelWithoutPaddingIdx(torch.nn.Module):
            def __init__(self):
@@ -6159,7 +6227,7 @@ class TestONNXRuntime(unittest.TestCase):
        model = EmbedModelWithoutPaddingIdx()
        x = torch.randint(4, (4, 3, 2))
        self.run_test(model, (x, ))
        self.run_test(model, (x,))

    def _dispatch_rnn_test(self, name, *args, **kwargs):
        if name == 'elman':
@@ -6956,6 +7024,7 @@ class TestONNXRuntime(unittest.TestCase):
                super().__init__()
                self.weights = InnerModule2.get_embedding(embedding_dim)
                self.register_buffer("_float_tensor", torch.FloatTensor(1))
                self.const = 2

            @staticmethod
            def get_embedding(embedding_dim: int):
@@ -6965,9 +7034,11 @@ class TestONNXRuntime(unittest.TestCase):
            def forward(self, input, incremental_state: Optional[torch.Tensor] = None):
                bsz, seq_len = input.shape[0], input.shape[1]
                self.const = 3
                if self.weights is None:
                    self.weights = InnerModule.get_embedding(self.embedding_dim)
                self.weights = self.weights.to(self._float_tensor)
                self.weights = self.weights * self.const
                if incremental_state is not None:
                    pos = seq_len
                    return self.weights[1 + pos, :].expand(bsz, 1, -1)
@@ -7007,6 +7078,7 @@ class TestONNXRuntime(unittest.TestCase):
            def __init__(self, embedding_dim):
                super().__init__()
                self.embedding_dim = embedding_dim
                self.const = 2.5
                self.weights = InnerModule.get_embedding(self.embedding_dim)
                self.register_buffer("_float_tensor", torch.FloatTensor(1))
@@ -7018,10 +7090,11 @@ class TestONNXRuntime(unittest.TestCase):
            def forward(self, input, incremental_state: Optional[torch.Tensor] = None):
                bsz, seq_len = input.shape[0], input.shape[1]
                self.const = 1.5
                self.weights = InnerModule.get_embedding(self.embedding_dim)
                return (
                    self.weights.index_select(0, torch.ones((bsz * seq_len), dtype=torch.int64)).view(bsz, seq_len, -1)
                )
                ) * self.const

        class Module(torch.nn.Module):
            def __init__(self):
@@ -7039,12 +7112,17 @@ class TestONNXRuntime(unittest.TestCase):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv = torch.nn.Conv1d(3, 10, 2)
                self.b = False

            def forward(self, box_regression, weight):
                self.b = True
                self.conv.weight = weight
                w = torch.softmax(self.conv.weight, dim=0)
                self.conv.weight = w + w
                return box_regression + self.conv.weight
                if self.b:
                    return box_regression + self.conv.weight
                else:
                    return box_regression - self.conv.weight

        model = torch.jit.script(MyModule())
        weight = torch.ones(3, 2)

View File

@@ -32,26 +32,37 @@ bool IsInplaceNode(const Node* n) {
  return false;
}

Node* addDummyCloneToBlock(Block* b, Value* orig_data) {
  auto graph = b->owningGraph();
Node* addDummyClone(
    Graph* graph,
    Value* orig_data,
    bool insertBefore,
    Node* referenceNode) {
  Node* newNode = nullptr;
  if (orig_data->type()->kind() == TypeKind::ListType) {
    newNode = graph->create(aten::list, /*num_outputs =*/1);
    newNode->addInput(orig_data);
    newNode->output()->setType(orig_data->type());
    b->prependNode(newNode);
  } else if (orig_data->type()->kind() == TypeKind::TensorType) {
    newNode = graph->create(aten::clone, /*num_outputs =*/1);
    newNode->addInput(orig_data);
    if (insertBefore)
      newNode->insertBefore(referenceNode);
    else
      referenceNode->owningBlock()->prependNode(newNode);
  } else if (
      orig_data->type()->kind() == TypeKind::TensorType ||
      orig_data->type()->kind() == TypeKind::IntType ||
      orig_data->type()->kind() == TypeKind::FloatType ||
      orig_data->type()->kind() == TypeKind::BoolType) {
    auto* noneNode = graph->create(prim::Constant);
    noneNode->output()->setType(NoneType::get());
    newNode = graph->create(aten::clone, /*num_outputs =*/1);
    newNode->addInput(orig_data);
    newNode->addInput(noneNode->output());
    newNode->output()->setType(orig_data->type());
    b->prependNode(newNode);
    if (insertBefore)
      newNode->insertBefore(referenceNode);
    else
      referenceNode->owningBlock()->prependNode(newNode);
    noneNode->insertBefore(newNode);
  } // TODO: Handle float/int attributes
  }
  return newNode;
}
@@ -82,7 +93,8 @@ Value* MatchIfBlocksOutputForValue(
  for (Block* b : outer_block->owningNode()->blocks()) {
    if (b->outputs().size() < output_size) {
      auto clone_node = addDummyCloneToBlock(b, orig_data);
      auto clone_node =
          addDummyClone(b->owningGraph(), orig_data, false, b->return_node());
      b->registerOutput(clone_node->output());
      b->outputs()
          .at(b->outputs().size() - 1)
@@ -492,7 +504,8 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) {
            << (*it)->debugName() << "'. This changes graph semantics."
            << std::endl;
        Node* newNode = addDummyCloneToBlock(b, input);
        Node* newNode =
            addDummyClone(b->owningGraph(), input, false, b->return_node());
        TORCH_INTERNAL_ASSERT(nullptr != newNode);
        node->replaceInput(index, newNode->output());
        input->replaceAllUsesAfterNodeWith(node, newNode->output());
@@ -531,7 +544,7 @@ std::deque<std::string> findSubModuleAttr(
      moduleNames.push_front(node->s(attr::name));
      node = node->inputs()[0]->node();
    } else {
      return moduleNames;
      break;
    }
  }
  // Assign the inner module to attrModule.
@@ -554,31 +567,6 @@ Value* findArgumentAsInputParam(
      name);
}

Node* insertCloneBeforeNode(
    const std::shared_ptr<Graph>& graph,
    Value* orig_data,
    Node* node) {
  Node* newNode = nullptr;
  if (orig_data->type()->kind() == TypeKind::ListType) {
    // Create an aten::list to clone the list in graph inputs
    newNode = graph->create(aten::list, /*num_outputs =*/1);
    newNode->addInput(orig_data);
    newNode->output()->setType(orig_data->type());
    newNode->insertBefore(node);
  } else if (orig_data->type()->kind() == TypeKind::TensorType) {
    auto* noneNode = graph->create(prim::Constant);
    noneNode->output()->setType(NoneType::get());
    newNode = graph->create(aten::clone, /*num_outputs =*/1);
    newNode->addInput(orig_data);
    newNode->addInput(noneNode->output());
    newNode->output()->setType(orig_data->type());
    newNode->insertBefore(node);
    noneNode->insertBefore(newNode);
  } // TODO: Handle float/int attributes
  return newNode;
}

Value* registerSetAttrInBlocks(
    const std::shared_ptr<Graph>& graph,
    Block* block,
@@ -694,7 +682,8 @@ void trackAndRegisterAttributesInBlocks(
      // If inside a block, keep the output value to register in block
      // output.
      auto block_ = n->owningBlock();
      Node* cloneNode = insertCloneBeforeNode(graph, n->inputs().at(1), n);
      Node* cloneNode =
          addDummyClone(block_->owningGraph(), n->inputs().at(1), true, n);
      if (block_->owningNode() &&
          (block_->owningNode()->kind() == prim::If ||
           block_->owningNode()->kind() == prim::Loop)) {

View File

@@ -20,6 +20,9 @@ static constexpr char ListClose = ']';
static constexpr char TupleOpen = '(';
static constexpr char TupleClose = ')';
static constexpr char Variable = 'v';
static constexpr char Bool = 'b';
static constexpr char Long = 'l';
static constexpr char Double = 'd';
static constexpr char String = 's';
static constexpr char NoneType = 'n';
} // namespace D
@@ -36,6 +39,12 @@ py::object cast_handle_sequence(std::vector<py::handle> objs) {
}

void flatten_rec(PyObject* obj, ParsedArgs& args) {
  // Wrap a tensor as a Variable.
  auto as_variable = [](at::Tensor& tensor) {
    PyObject* wrapped_obj = THPVariable_Wrap(tensor);
    return reinterpret_cast<THPVariable*>(wrapped_obj)->cdata;
  };
  auto& structure = args.desc.structure;
  if (six::isTuple(obj)) {
    structure.push_back(D::TupleOpen);
@@ -65,6 +74,25 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) {
    args.desc.structure.push_back(D::Variable);
  } else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) {
    args.desc.structure.push_back(D::NoneType);
  } else if (PyBool_Check(obj)) { // Wrap booleans in bool tensors
    at::Tensor tensor = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj)));
    auto var = as_variable(tensor);
    args.vars.push_back(var);
    args.desc.metadata.emplace_back(var);
    args.desc.structure.push_back(D::Bool);
  } else if (PyLong_Check(obj)) { // Wrap integers in long tensors
    at::Tensor tensor = scalar_to_tensor(
        at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(obj))));
    auto var = as_variable(tensor);
    args.vars.push_back(var);
    args.desc.metadata.emplace_back(var);
    args.desc.structure.push_back(D::Long);
  } else if (PyFloat_Check(obj)) { // Wrap floating points in double tensors
    at::Tensor tensor = scalar_to_tensor(THPUtils_unpackDouble(obj));
    auto var = as_variable(tensor);
    args.vars.push_back(var);
    args.desc.metadata.emplace_back(var);
    args.desc.structure.push_back(D::Double);
  } else {
    std::string msg =
        "Only tuples, lists and Variables are supported as JIT inputs/outputs. "
@@ -142,6 +170,9 @@ py::object unflatten_rec(
  } else if (type == D::NoneType) {
    return py::reinterpret_borrow<py::object>(py::none());
  } else {
    // D::Variable, D::Long, D::Double, and D::Bool all reach this branch:
    // primitive types were wrapped as variables for the tracer, so they are
    // unwrapped here the same way as variables.
    if (var_it == var_it_end)
      throw std::runtime_error("Not enough Variables given to unflatten");
    auto var = *var_it++;

View File

@@ -436,10 +436,10 @@ def _model_to_graph(model, args, verbose=False,
                    training=None, dynamic_axes=None):
    from torch.onnx.symbolic_helper import _export_onnx_opset_version

    # Special case for the common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
    if isinstance(args, (torch.Tensor, int, float, bool)):
        args = (args, )

    if isinstance(example_outputs, torch.Tensor):
    if isinstance(example_outputs, (torch.Tensor, int, float, bool)):
        example_outputs = (example_outputs,)

    graph, params, torch_out, module = _create_jit_graph(model, args,