Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54864

Support primitive-type (int, float, bool) attributes in the ONNX exporter. Needed for the Silero model.

Test Plan: Imported from OSS
Reviewed By: nikithamalgifb
Differential Revision: D27408982
Pulled By: SplitInfinity
fbshipit-source-id: 16b291eedbe9f9bb31d7664a29a484555df53755
Committed by: Facebook GitHub Bot
Parent: ce48b14060
Commit: cd9dd653e9
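For context, a minimal sketch of the pattern this commit enables: a scripted module that mutates a primitive (int/float/bool) attribute inside forward() can now be exported to ONNX. The module and file names below are illustrative, and the example assumes the torch.onnx.export API of this era (example_outputs was still required for ScriptModules):

    import torch

    class ScaleModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.const = 2  # primitive (int) attribute

        def forward(self, x):
            self.const = 3  # mutated in forward; previously unsupported
            return x * self.const

    model = torch.jit.script(ScaleModule())
    x = torch.randn(2, 3, 4)
    torch.onnx.export(model, (x,), "scale.onnx", example_outputs=model(x))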
@@ -1342,6 +1342,74 @@ class TestONNXRuntime(unittest.TestCase):
         x = torch.randn(2, 3, 4)
         self.run_test(ArithmeticModule(), x)
 
+    def test_arithmetic_prim_long(self):
+        class ArithmeticModule(torch.nn.Module):
+            def forward(self, x, y: int):
+                x = x + y
+                x = x - y
+                x = x * (y * 3)
+                x = x / (y * 4)
+                return x
+
+        x = torch.randn(2, 3, 4)
+        y = 2
+        self.run_test(ArithmeticModule(), (x, y))
+
+        class ArithmeticModule(torch.nn.Module):
+            def forward(self, x):
+                x = x + 2
+                x = x - 3
+                return x.shape[0]
+
+        x = torch.randn(2, 3, 4)
+        self.run_test(ArithmeticModule(), x)
+
+    def test_arithmetic_prim_float(self):
+        class ArithmeticModule(torch.nn.Module):
+            def forward(self, x, y: float):
+                x = x + y
+                x = x - y
+                x = x * (y * 3)
+                x = x / (y * 4)
+                return x
+
+        x = torch.randn(2, 3, 4)
+        y = 2.5
+        self.run_test(ArithmeticModule(), (x, y))
+
+        class ArithmeticModule(torch.nn.Module):
+            def forward(self, x):
+                x = x + 2
+                x = x - 3
+                return x.shape[1] / 2
+
+        x = torch.randn(2, 3, 4)
+        self.run_test(ArithmeticModule(), x)
+
+    def test_arithmetic_prim_bool(self):
+        class ArithmeticModule(torch.nn.Module):
+            def forward(self, x, y: int, z: bool, t: float):
+                x = x + y
+                x = x - y
+                if z:
+                    x = x * (y * 3)
+                    x = x / (y * 4)
+                return x / t, z
+
+        x = torch.randn(2, 3, 4)
+        y = 2
+        z = False
+        t = 2.5
+        self.run_test(ArithmeticModule(), (x, y, z, t))
+
+        class ArithmeticModule(torch.nn.Module):
+            def forward(self, x: float, y: float):
+                return x == y
+
+        x = 3
+        y = 2
+        self.run_test(ArithmeticModule(), (x, y))
+
     # In scripting, the first transpose node does not carry shape and dtype info.
     # The following test only works when ONNX shape inference is enabled.
     @skipIfONNXShapeInference(False)
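For reference, run_test in this suite exports the module and compares ONNX Runtime outputs against eager PyTorch; a simplified sketch of that flow (run_test_sketch is a hypothetical stand-in, not the actual helper):

    import io

    import onnxruntime
    import torch

    def run_test_sketch(model, inputs):
        if isinstance(inputs, torch.Tensor):
            inputs = (inputs,)
        f = io.BytesIO()
        torch.onnx.export(model, inputs, f, opset_version=12)
        sess = onnxruntime.InferenceSession(f.getvalue())
        # Primitive inputs (int/float/bool) become scalar tensor inputs in ONNX.
        feeds = {arg.name: (v.numpy() if isinstance(v, torch.Tensor)
                            else torch.tensor(v).numpy())
                 for arg, v in zip(sess.get_inputs(), inputs)}
        return sess.run(None, feeds)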
@@ -6107,7 +6175,7 @@ class TestONNXRuntime(unittest.TestCase):
                 return torch.nn.functional.embedding(input, emb, padding_idx=1)
 
         model = EmbedModel()
-        x = torch.randint(4, (4, ))
+        x = torch.randint(4, (4,))
         x[2] = x[0] = 1
         embedding_matrix = torch.rand(10, 3)
         self.run_test(model, (x, embedding_matrix))
@@ -6140,14 +6208,14 @@ class TestONNXRuntime(unittest.TestCase):
                 return self.emb(input), self.emb2(input)
 
         model = EmbedModel()
-        x = torch.randint(4, (4, ))
+        x = torch.randint(4, (4,))
         x[2] = x[0] = 1
-        self.run_test(model, (x, ))
+        self.run_test(model, (x,))
 
         x = torch.randint(4, (4, 3, 2))
         x[2] = 1
         x[0][1] = 1
-        self.run_test(model, (x, ))
+        self.run_test(model, (x,))
 
         class EmbedModelWithoutPaddingIdx(torch.nn.Module):
             def __init__(self):
@@ -6159,7 +6227,7 @@ class TestONNXRuntime(unittest.TestCase):
 
         model = EmbedModelWithoutPaddingIdx()
         x = torch.randint(4, (4, 3, 2))
-        self.run_test(model, (x, ))
+        self.run_test(model, (x,))
 
     def _dispatch_rnn_test(self, name, *args, **kwargs):
         if name == 'elman':
@@ -6956,6 +7024,7 @@ class TestONNXRuntime(unittest.TestCase):
                 super().__init__()
                 self.weights = InnerModule2.get_embedding(embedding_dim)
                 self.register_buffer("_float_tensor", torch.FloatTensor(1))
+                self.const = 2
 
             @staticmethod
             def get_embedding(embedding_dim: int):
@@ -6965,9 +7034,11 @@ class TestONNXRuntime(unittest.TestCase):
 
             def forward(self, input, incremental_state: Optional[torch.Tensor] = None):
                 bsz, seq_len = input.shape[0], input.shape[1]
+                self.const = 3
                 if self.weights is None:
                     self.weights = InnerModule.get_embedding(self.embedding_dim)
                 self.weights = self.weights.to(self._float_tensor)
+                self.weights = self.weights * self.const
                 if incremental_state is not None:
                     pos = seq_len
                     return self.weights[1 + pos, :].expand(bsz, 1, -1)
@@ -7007,6 +7078,7 @@ class TestONNXRuntime(unittest.TestCase):
             def __init__(self, embedding_dim):
                 super().__init__()
                 self.embedding_dim = embedding_dim
+                self.const = 2.5
                 self.weights = InnerModule.get_embedding(self.embedding_dim)
                 self.register_buffer("_float_tensor", torch.FloatTensor(1))
 
@@ -7018,10 +7090,11 @@ class TestONNXRuntime(unittest.TestCase):
 
             def forward(self, input, incremental_state: Optional[torch.Tensor] = None):
                 bsz, seq_len = input.shape[0], input.shape[1]
+                self.const = 1.5
                 self.weights = InnerModule.get_embedding(self.embedding_dim)
                 return (
                     self.weights.index_select(0, torch.ones((bsz * seq_len), dtype=torch.int64)).view(bsz, seq_len, -1)
-                )
+                ) * self.const
 
         class Module(torch.nn.Module):
             def __init__(self):
@@ -7039,12 +7112,17 @@ class TestONNXRuntime(unittest.TestCase):
             def __init__(self):
                 super(MyModule, self).__init__()
                 self.conv = torch.nn.Conv1d(3, 10, 2)
+                self.b = False
 
             def forward(self, box_regression, weight):
+                self.b = True
                 self.conv.weight = weight
                 w = torch.softmax(self.conv.weight, dim=0)
                 self.conv.weight = w + w
-                return box_regression + self.conv.weight
+                if self.b:
+                    return box_regression + self.conv.weight
+                else:
+                    return box_regression - self.conv.weight
 
         model = torch.jit.script(MyModule())
         weight = torch.ones(3, 2)
@@ -32,26 +32,37 @@ bool IsInplaceNode(const Node* n) {
   return false;
 }
 
-Node* addDummyCloneToBlock(Block* b, Value* orig_data) {
-  auto graph = b->owningGraph();
+Node* addDummyClone(
+    Graph* graph,
+    Value* orig_data,
+    bool insertBefore,
+    Node* referenceNode) {
   Node* newNode = nullptr;
   if (orig_data->type()->kind() == TypeKind::ListType) {
     newNode = graph->create(aten::list, /*num_outputs =*/1);
     newNode->addInput(orig_data);
     newNode->output()->setType(orig_data->type());
-    b->prependNode(newNode);
-  } else if (orig_data->type()->kind() == TypeKind::TensorType) {
+    if (insertBefore)
+      newNode->insertBefore(referenceNode);
+    else
+      referenceNode->owningBlock()->prependNode(newNode);
+  } else if (
+      orig_data->type()->kind() == TypeKind::TensorType ||
+      orig_data->type()->kind() == TypeKind::IntType ||
+      orig_data->type()->kind() == TypeKind::FloatType ||
+      orig_data->type()->kind() == TypeKind::BoolType) {
     auto* noneNode = graph->create(prim::Constant);
     noneNode->output()->setType(NoneType::get());
     newNode = graph->create(aten::clone, /*num_outputs =*/1);
     newNode->addInput(orig_data);
     newNode->addInput(noneNode->output());
     newNode->output()->setType(orig_data->type());
-    b->prependNode(newNode);
+    if (insertBefore)
+      newNode->insertBefore(referenceNode);
+    else
+      referenceNode->owningBlock()->prependNode(newNode);
     noneNode->insertBefore(newNode);
-  } // TODO: Handle float/int attributes
+  }
   return newNode;
 }
 
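The reworked helper above inserts a placeholder clone so a value mutated inside a block can be surfaced as a fresh block output; with this change, the clone path covers IntType, FloatType, and BoolType values as well as tensors and lists. A minimal sketch of the kind of input mutation these passes guard against (hypothetical function; the exact printed IR varies by version):

    import torch

    @torch.jit.script
    def f(x: torch.Tensor) -> torch.Tensor:
        x += 1  # in-place update of a graph input
        return x

    # Before export, the remove-mutations pass clones x so the in-place
    # update cannot silently change semantics for other users of x.
    print(f.graph)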
@@ -82,7 +93,8 @@ Value* MatchIfBlocksOutputForValue(
 
   for (Block* b : outer_block->owningNode()->blocks()) {
     if (b->outputs().size() < output_size) {
-      auto clone_node = addDummyCloneToBlock(b, orig_data);
+      auto clone_node =
+          addDummyClone(b->owningGraph(), orig_data, false, b->return_node());
       b->registerOutput(clone_node->output());
       b->outputs()
           .at(b->outputs().size() - 1)
@@ -492,7 +504,8 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) {
                 << (*it)->debugName() << "'. This changes graph semantics."
                 << std::endl;
 
-        Node* newNode = addDummyCloneToBlock(b, input);
+        Node* newNode =
+            addDummyClone(b->owningGraph(), input, false, b->return_node());
         TORCH_INTERNAL_ASSERT(nullptr != newNode);
         node->replaceInput(index, newNode->output());
         input->replaceAllUsesAfterNodeWith(node, newNode->output());
@@ -531,7 +544,7 @@ std::deque<std::string> findSubModuleAttr(
       moduleNames.push_front(node->s(attr::name));
       node = node->inputs()[0]->node();
     } else {
-      break;
+      return moduleNames;
     }
   }
   // Assign the inner module to attrModule.
@@ -554,31 +567,6 @@ Value* findArgumentAsInputParam(
       name);
 }
 
-Node* insertCloneBeforeNode(
-    const std::shared_ptr<Graph>& graph,
-    Value* orig_data,
-    Node* node) {
-  Node* newNode = nullptr;
-  if (orig_data->type()->kind() == TypeKind::ListType) {
-    // Create an aten::list to clone the list in graph inputs
-    newNode = graph->create(aten::list, /*num_outputs =*/1);
-    newNode->addInput(orig_data);
-    newNode->output()->setType(orig_data->type());
-    newNode->insertBefore(node);
-  } else if (orig_data->type()->kind() == TypeKind::TensorType) {
-    auto* noneNode = graph->create(prim::Constant);
-    noneNode->output()->setType(NoneType::get());
-    newNode = graph->create(aten::clone, /*num_outputs =*/1);
-    newNode->addInput(orig_data);
-
-    newNode->addInput(noneNode->output());
-    newNode->output()->setType(orig_data->type());
-    newNode->insertBefore(node);
-    noneNode->insertBefore(newNode);
-  } // TODO: Handle float/int attributes
-  return newNode;
-}
-
 Value* registerSetAttrInBlocks(
     const std::shared_ptr<Graph>& graph,
     Block* block,
@@ -694,7 +682,8 @@ void trackAndRegisterAttributesInBlocks(
         // If inside a block, keep the output value to register in block
         // output.
         auto block_ = n->owningBlock();
-        Node* cloneNode = insertCloneBeforeNode(graph, n->inputs().at(1), n);
+        Node* cloneNode =
+            addDummyClone(block_->owningGraph(), n->inputs().at(1), true, n);
         if (block_->owningNode() &&
             (block_->owningNode()->kind() == prim::If ||
              block_->owningNode()->kind() == prim::Loop)) {
@@ -20,6 +20,9 @@ static constexpr char ListClose = ']';
 static constexpr char TupleOpen = '(';
 static constexpr char TupleClose = ')';
 static constexpr char Variable = 'v';
+static constexpr char Bool = 'b';
+static constexpr char Long = 'l';
+static constexpr char Double = 'd';
 static constexpr char String = 's';
 static constexpr char NoneType = 'n';
 } // namespace D
@@ -36,6 +39,12 @@ py::object cast_handle_sequence(std::vector<py::handle> objs) {
 }
 
 void flatten_rec(PyObject* obj, ParsedArgs& args) {
+  auto as_variable = [](at::Tensor& tensor) // Wrap a tensor as a Variable
+  {
+    PyObject* wrapped_obj = THPVariable_Wrap(tensor);
+    return reinterpret_cast<THPVariable*>(wrapped_obj)->cdata;
+  };
+
   auto& structure = args.desc.structure;
   if (six::isTuple(obj)) {
     structure.push_back(D::TupleOpen);
@@ -65,6 +74,25 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) {
     args.desc.structure.push_back(D::Variable);
   } else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) {
     args.desc.structure.push_back(D::NoneType);
+  } else if (PyBool_Check(obj)) { // Wrap booleans in bool tensors
+    at::Tensor tensor = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj)));
+    auto var = as_variable(tensor);
+    args.vars.push_back(var);
+    args.desc.metadata.emplace_back(var);
+    args.desc.structure.push_back(D::Bool);
+  } else if (PyLong_Check(obj)) { // Wrap integers in long tensors
+    at::Tensor tensor = scalar_to_tensor(
+        at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(obj))));
+    auto var = as_variable(tensor);
+    args.vars.push_back(var);
+    args.desc.metadata.emplace_back(var);
+    args.desc.structure.push_back(D::Long);
+  } else if (PyFloat_Check(obj)) { // Wrap floating points in double tensors
+    at::Tensor tensor = scalar_to_tensor(THPUtils_unpackDouble(obj));
+    auto var = as_variable(tensor);
+    args.vars.push_back(var);
+    args.desc.metadata.emplace_back(var);
+    args.desc.structure.push_back(D::Double);
   } else {
     std::string msg =
         "Only tuples, lists and Variables are supported as JIT inputs/outputs. "
@@ -142,6 +170,9 @@ py::object unflatten_rec(
   } else if (type == D::NoneType) {
     return py::reinterpret_borrow<py::object>(py::none());
   } else {
+    // Here type is D::Variable, D::Long, D::Double, or D::Bool: unwrap
+    // variables (D::Variable) directly, or unwrap primitive types (Long,
+    // Double, Bool) that were wrapped as variables for the tracer.
     if (var_it == var_it_end)
       throw std::runtime_error("Not enough Variables given to unflatten");
     auto var = *var_it++;
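Taken together, the descriptor characters above encode the flattened structure of tracer inputs and outputs. A small Python illustration of the scheme (the encode helper is hypothetical, mirroring the D:: constants: 'v' Variable, 'b' bool, 'l' long, 'd' double, 'n' None, '(' ')' tuple, '[' ']' list):

    import torch

    def encode(obj) -> str:
        if isinstance(obj, tuple):
            return "(" + "".join(encode(o) for o in obj) + ")"
        if isinstance(obj, list):
            return "[" + "".join(encode(o) for o in obj) + "]"
        if isinstance(obj, torch.Tensor):
            return "v"
        if obj is None:
            return "n"
        if isinstance(obj, bool):  # must come before int: bool subclasses int
            return "b"
        if isinstance(obj, int):
            return "l"
        if isinstance(obj, float):
            return "d"
        raise TypeError(f"unsupported input type: {type(obj)}")

    print(encode((torch.randn(2), 3, 2.5, False, None)))  # -> "(vldbn)"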
@@ -436,10 +436,10 @@ def _model_to_graph(model, args, verbose=False,
                     training=None, dynamic_axes=None):
     from torch.onnx.symbolic_helper import _export_onnx_opset_version
     # Special case for the common case of passing a single Tensor
-    if isinstance(args, torch.Tensor):
+    if isinstance(args, (torch.Tensor, int, float, bool)):
         args = (args, )
 
-    if isinstance(example_outputs, torch.Tensor):
+    if isinstance(example_outputs, (torch.Tensor, int, float, bool)):
         example_outputs = (example_outputs,)
 
     graph, params, torch_out, module = _create_jit_graph(model, args,
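With _model_to_graph accepting primitives, a single int/float/bool (or a tuple mixing tensors and primitives) can be passed straight to the exporter. A hedged sketch matching the new tests (the output file name is illustrative):

    import torch

    class ArithmeticModule(torch.nn.Module):
        def forward(self, x, y: int, z: bool, t: float):
            x = x + y
            x = x - y
            if z:
                x = x * (y * 3)
                x = x / (y * 4)
            return x / t, z

    model = torch.jit.script(ArithmeticModule())
    args = (torch.randn(2, 3, 4), 2, False, 2.5)
    torch.onnx.export(model, args, "arith.onnx", example_outputs=model(*args))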