ONNX Export Argmin and Argmax ops

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17382

Differential Revision: D14338811

Pulled By: houseroad

fbshipit-source-id: be07548d8063d1aa94f1801c18137738365b85fb
This commit is contained in:
Lara Haidar-Ahmad
2019-03-06 12:05:48 -08:00
committed by Facebook Github Bot
parent 97eb139a94
commit 073634612f
4 changed files with 83 additions and 0 deletions

View File

@ -0,0 +1,53 @@
ir_version: 4
producer_name: "pytorch"
producer_version: "1.1"
graph {
node {
input: "x"
output: "1"
op_type: "ArgMax"
attribute {
name: "axis"
i: 1
type: INT
}
attribute {
name: "keepdims"
i: 0
type: INT
}
}
name: "torch-jit-export"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "1"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 10
}

View File

@ -426,6 +426,10 @@ class TestOperators(TestCase):
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.flatten(x, 1), x)
def test_argmax(self):
x = torch.randn(4, 4, requires_grad=True)
self.assertONNX(lambda x: torch.argmax(x, dim=1), x)
def test_logsoftmax(self):
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.assertONNX(nn.LogSoftmax(dim=3), x)

View File

@ -1074,6 +1074,22 @@ class TestCaffe2Backend(unittest.TestCase):
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.run_model_test(FlattenModel(), train=False, input=x, batch_size=BATCH_SIZE)
def test_argmax(self):
class ArgmaxModel(torch.nn.Module):
def forward(self, input):
return torch.argmax(input, dim=1)
x = torch.randn(4, 4, requires_grad=True)
self.run_model_test(ArgmaxModel(), train=False, input=x, batch_size=BATCH_SIZE)
def test_argmin(self):
class ArgminModel(torch.nn.Module):
def forward(self, input):
return torch.argmin(input, dim=1)
x = torch.randn(4, 4, requires_grad=True)
self.run_model_test(ArgminModel(), train=False, input=x, batch_size=BATCH_SIZE)
def test_reshape(self):
class ReshapeModel(torch.nn.Module):
def forward(self, input):

View File

@ -1617,3 +1617,13 @@ def flatten(g, input, start_dim, end_dim):
@parse_args('v')
def nonzero(g, input):
return g.op('NonZero', input)
@parse_args('v', 'i', 'i')
def _argmax(g, input, dim, keepdim):
return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim)
@parse_args('v', 'i', 'i')
def _argmin(g, input, dim, keepdim):
return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim)