ONNX Export Argmin and Argmax ops
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17382

Differential Revision: D14338811

Pulled By: houseroad

fbshipit-source-id: be07548d8063d1aa94f1801c18137738365b85fb
commit 073634612f (parent 97eb139a94)
committed by Facebook Github Bot
test/onnx/expect/TestOperators.test_argmax.expect (new file, 53 lines)
@@ -0,0 +1,53 @@
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "1.1"
+graph {
+  node {
+    input: "x"
+    output: "1"
+    op_type: "ArgMax"
+    attribute {
+      name: "axis"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "keepdims"
+      i: 0
+      type: INT
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "x"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 4
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: 7
+        shape {
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 10
+}
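Editorial note, not part of the commit: expect files like the one above are regenerated by exporting the traced expression to ONNX and serializing the resulting ModelProto as text. A minimal sketch of reproducing this graph by hand, assuming a hypothetical ArgmaxModule wrapper (any names below that are not in the diff are illustrative):

# Hypothetical sketch: regenerate the graph captured in the expect file above.
import io

import onnx
import torch


class ArgmaxModule(torch.nn.Module):  # hypothetical wrapper, not in the commit
    def forward(self, x):
        # dim=1 becomes the ONNX "axis" attribute; keepdim defaults to False,
        # which the symbolic exports as keepdims = 0.
        return torch.argmax(x, dim=1)


buf = io.BytesIO()
torch.onnx.export(ArgmaxModule(), torch.randn(4, 4), buf)

model = onnx.load_model_from_string(buf.getvalue())
print(model.graph)  # should show a single ArgMax node, as in the expect file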
test/onnx/test_operators.py
@@ -426,6 +426,10 @@ class TestOperators(TestCase):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.assertONNX(lambda x: torch.flatten(x, 1), x)
 
+    def test_argmax(self):
+        x = torch.randn(4, 4, requires_grad=True)
+        self.assertONNX(lambda x: torch.argmax(x, dim=1), x)
+
     def test_logsoftmax(self):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.assertONNX(nn.LogSoftmax(dim=3), x)
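As a quick reminder of the semantics under test (an illustrative snippet, not part of the diff): torch.argmax with dim=1 returns the int64 index of each row's maximum, and the reduced dimension is dropped unless keepdim=True. That is why the expect file above declares a rank-1 output with elem_type: 7 (INT64).

import torch

t = torch.tensor([[1.0, 3.0, 2.0],
                  [9.0, 0.0, 4.0]])
print(torch.argmax(t, dim=1))                # tensor([1, 0]), shape (2,), int64
print(torch.argmax(t, dim=1, keepdim=True))  # tensor([[1], [0]]), shape (2, 1)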
test/onnx/test_pytorch_onnx_caffe2.py
@@ -1074,6 +1074,22 @@ class TestCaffe2Backend(unittest.TestCase):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.run_model_test(FlattenModel(), train=False, input=x, batch_size=BATCH_SIZE)
 
+    def test_argmax(self):
+        class ArgmaxModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.argmax(input, dim=1)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        self.run_model_test(ArgmaxModel(), train=False, input=x, batch_size=BATCH_SIZE)
+
+    def test_argmin(self):
+        class ArgminModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.argmin(input, dim=1)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        self.run_model_test(ArgminModel(), train=False, input=x, batch_size=BATCH_SIZE)
+
     def test_reshape(self):
         class ReshapeModel(torch.nn.Module):
             def forward(self, input):
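run_model_test exports the module to ONNX and compares the Caffe2 backend's output against eager PyTorch. A hand-rolled version of that check might look like the sketch below; the Caffe2 ONNX backend calls are assumed APIs from that era of the codebase, not part of this commit:

# Hedged sketch of what run_model_test verifies for ArgminModel.
import io

import numpy as np
import onnx
import torch
import caffe2.python.onnx.backend as c2  # assumed import path


class ArgminModel(torch.nn.Module):
    def forward(self, input):
        return torch.argmin(input, dim=1)


x = torch.randn(4, 4)
buf = io.BytesIO()
torch.onnx.export(ArgminModel(), x, buf)

model = onnx.load_model_from_string(buf.getvalue())
rep = c2.prepare(model)            # compile the ONNX graph for Caffe2
out = rep.run([x.numpy()])[0]      # run with numpy inputs (assumed API)
np.testing.assert_array_equal(out, torch.argmin(x, dim=1).numpy())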
torch/onnx/symbolic.py
@@ -1617,3 +1617,13 @@ def flatten(g, input, start_dim, end_dim):
 @parse_args('v')
 def nonzero(g, input):
     return g.op('NonZero', input)
+
+
+@parse_args('v', 'i', 'i')
+def _argmax(g, input, dim, keepdim):
+    return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim)
+
+
+@parse_args('v', 'i', 'i')
+def _argmin(g, input, dim, keepdim):
+    return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim)
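For context (editorial note, not part of the diff): @parse_args('v', 'i', 'i') declares that the symbolic receives one graph value ('v') followed by two constant Python ints ('i'), and the _i suffix on g.op keyword arguments emits an integer ONNX attribute. That is exactly how axis: 1 and keepdims: 0 land in the expect file above. Since expect files are plain text-format protobufs, those attributes can be read back directly; a small sketch, assuming the expect file path added in this commit:

# Sketch: read back the attributes the symbolic emitted, from the
# text-format expect file added in this commit.
import onnx
from google.protobuf import text_format

with open('test/onnx/expect/TestOperators.test_argmax.expect') as f:
    model = text_format.Parse(f.read(), onnx.ModelProto())

node = model.graph.node[0]
print(node.op_type)                           # ArgMax
print({a.name: a.i for a in node.attribute})  # {'axis': 1, 'keepdims': 0}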