bug fix 19374 - fix for upsample export

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20116

Differential Revision: D15256899

Pulled By: houseroad

fbshipit-source-id: cf0dfd679d528fbb77f483e23071f4a96fb27091
This commit is contained in:
Peyman Manikashani
2019-05-23 14:45:22 -07:00
committed by Facebook GitHub Bot
parent 48bf7b9be8
commit 93d5503f34
5 changed files with 609 additions and 15 deletions

View File

@ -0,0 +1,272 @@
ir_version: 4
producer_name: "pytorch"
producer_version: "1.1"
graph {
node {
output: "1"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\002\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "input"
output: "2"
op_type: "Shape"
}
node {
input: "2"
input: "1"
output: "3"
op_type: "Gather"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
output: "4"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\002\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "3"
input: "4"
output: "5"
op_type: "Mul"
}
node {
input: "5"
output: "6"
op_type: "Floor"
}
node {
output: "7"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\003\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "input"
output: "8"
op_type: "Shape"
}
node {
input: "8"
input: "7"
output: "9"
op_type: "Gather"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
output: "10"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\002\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "9"
input: "10"
output: "11"
op_type: "Mul"
}
node {
input: "11"
output: "12"
op_type: "Floor"
}
node {
input: "6"
output: "13"
op_type: "Unsqueeze"
attribute {
name: "axes"
ints: 0
type: INTS
}
}
node {
input: "12"
output: "14"
op_type: "Unsqueeze"
attribute {
name: "axes"
ints: 0
type: INTS
}
}
node {
input: "13"
input: "14"
output: "15"
op_type: "Concat"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
output: "16"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 2
data_type: 1
raw_data: "\000\000\200?\000\000\200?"
}
type: TENSOR
}
}
node {
input: "15"
output: "17"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
}
node {
input: "input"
output: "18"
op_type: "Shape"
}
node {
input: "18"
output: "19"
op_type: "Slice"
attribute {
name: "axes"
ints: 0
type: INTS
}
attribute {
name: "ends"
ints: 4
type: INTS
}
attribute {
name: "starts"
ints: 2
type: INTS
}
}
node {
input: "19"
output: "20"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
}
node {
input: "17"
input: "20"
output: "21"
op_type: "Div"
}
node {
input: "16"
input: "21"
output: "22"
op_type: "Concat"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
input: "input"
input: "22"
output: "23"
op_type: "Upsample"
attribute {
name: "mode"
s: "linear"
type: STRING
}
}
name: "torch-jit-export"
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "23"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 6
}
dim {
dim_value: 8
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@ -0,0 +1,272 @@
ir_version: 4
producer_name: "pytorch"
producer_version: "1.1"
graph {
node {
output: "1"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\002\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "input"
output: "2"
op_type: "Shape"
}
node {
input: "2"
input: "1"
output: "3"
op_type: "Gather"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
output: "4"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\002\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "3"
input: "4"
output: "5"
op_type: "Mul"
}
node {
input: "5"
output: "6"
op_type: "Floor"
}
node {
output: "7"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\003\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "input"
output: "8"
op_type: "Shape"
}
node {
input: "8"
input: "7"
output: "9"
op_type: "Gather"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
output: "10"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\002\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "9"
input: "10"
output: "11"
op_type: "Mul"
}
node {
input: "11"
output: "12"
op_type: "Floor"
}
node {
input: "6"
output: "13"
op_type: "Unsqueeze"
attribute {
name: "axes"
ints: 0
type: INTS
}
}
node {
input: "12"
output: "14"
op_type: "Unsqueeze"
attribute {
name: "axes"
ints: 0
type: INTS
}
}
node {
input: "13"
input: "14"
output: "15"
op_type: "Concat"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
output: "16"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 2
data_type: 1
raw_data: "\000\000\200?\000\000\200?"
}
type: TENSOR
}
}
node {
input: "15"
output: "17"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
}
node {
input: "input"
output: "18"
op_type: "Shape"
}
node {
input: "18"
output: "19"
op_type: "Slice"
attribute {
name: "axes"
ints: 0
type: INTS
}
attribute {
name: "ends"
ints: 4
type: INTS
}
attribute {
name: "starts"
ints: 2
type: INTS
}
}
node {
input: "19"
output: "20"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
}
node {
input: "17"
input: "20"
output: "21"
op_type: "Div"
}
node {
input: "16"
input: "21"
output: "22"
op_type: "Concat"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
input: "input"
input: "22"
output: "23"
op_type: "Upsample"
attribute {
name: "mode"
s: "nearest"
type: STRING
}
}
name: "torch-jit-export"
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "23"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 6
}
dim {
dim_value: 8
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@ -8,7 +8,6 @@ import torch.nn as nn
import itertools
import io
import unittest
import inspect
import glob
import os
@ -477,8 +476,11 @@ class TestOperators(TestCase):
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.assertONNX(lambda x: x.norm(p=2, dim=2), (x))
def test_upsample_nearest(self):
    # Export a 2x nearest-neighbor interpolate and compare the ONNX graph
    # against the recorded expect file. The @unittest.skip that waited on
    # onnx/onnx#1773 is removed: this commit's exporter fix makes the test
    # runnable, and the leftover old header `def test_upsample(self):` from
    # the unmarked diff is dropped (two back-to-back `def` lines are a
    # syntax error).
    x = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.assertONNX(lambda x: nn.functional.interpolate(x, scale_factor=2., mode='nearest'), x)
def test_upsample_bilinear(self):
    """Check ONNX export of a 2x bilinear interpolate against the expect file."""
    inp = torch.randn(1, 2, 3, 4, requires_grad=True)

    def upsample(t):
        return nn.functional.interpolate(t, scale_factor=2., mode='bilinear')

    self.assertONNX(upsample, inp)

View File

@ -978,6 +978,22 @@ class TestCaffe2Backend(unittest.TestCase):
self.run_model_test(model, train=False, input=(x),
batch_size=BATCH_SIZE, use_gpu=False)
def test_interpolate_upsample_dynamic_sizes(self):
    """Exercise export of interpolate() whose target size is computed from
    the input's runtime shape (dynamic sizes) through the Caffe2 backend."""

    class DoubleSpatial(torch.nn.Module):
        def forward(self, x):
            # Target size = 2x each spatial dimension, derived from x at
            # runtime rather than hard-coded.
            target = [2 * d for d in x.size()[2:]]
            return nn.functional.interpolate(x, size=target, mode='nearest')

    inp = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.run_model_test(DoubleSpatial(), train=False, input=(inp),
                        batch_size=BATCH_SIZE, use_gpu=False)
def test_repeat_dim_overflow(self):
class MyModel(torch.nn.Module):
def __init__(self):

View File

@ -662,25 +662,57 @@ replication_pad2d = replication_pad
replication_pad3d = replication_pad
def upsample_nearest2d(g, input, output_size):
    """Export aten::upsample_nearest2d as an ONNX Upsample node.

    `output_size` may be a compile-time constant (list of ints) or a runtime
    Value (bug 19374: dynamic sizes). In the dynamic case the per-axis scales
    are computed in-graph as output_size / input_spatial_shape, prefixed with
    1.0 for the batch and channel axes. The old `@parse_args('v', 'is')`
    decorator is intentionally gone: it would eagerly coerce `output_size`
    to a Python int list and make the dynamic branch unreachable.
    """
    output_size = sym_help._maybe_get_const(output_size, 'is')
    if sym_help._is_value(output_size):
        # Dynamic sizes: scales = [1, 1] ++ (output_size / input.shape[2:]).
        offset = 2  # skip batch and channel dimensions
        input_length = len(input.type().sizes())
        offsets = g.op("Constant", value_t=torch.tensor([1. for i in range(offset)]))
        dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
        divisor = g.op(
            "Slice",
            g.op("Shape", input),
            axes_i=[0],
            ends_i=[input_length],
            starts_i=[offset]
        )
        divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
        scale_dims = g.op("Div", dividend, divisor)
        scales = g.op("Concat", offsets, scale_dims, axis_i=0)
    else:
        # Static sizes: fold the scales into a single float constant.
        height_scale = float(output_size[-2]) / input.type().sizes()[-2]
        width_scale = float(output_size[-1]) / input.type().sizes()[-1]
        scales = g.op("Constant",
                      value_t=torch.tensor([1., 1., height_scale, width_scale]))
    return g.op("Upsample", input, scales, mode_s="nearest")
def upsample_bilinear2d(g, input, output_size, align_corners):
    """Export aten::upsample_bilinear2d as an ONNX Upsample node (mode "linear").

    Mirrors upsample_nearest2d: `output_size` may be a constant int list or a
    runtime Value, in which case the scales are computed in-graph as
    output_size / input_spatial_shape with 1.0 prepended for batch/channel.
    The old `@parse_args('v', 'is', 'i')` decorator is dropped so the
    arguments stay Values; `align_corners` is therefore unwrapped with
    _maybe_get_scalar below. align_corners == True has no ONNX Upsample
    equivalent and is reported as unimplemented.
    """
    align_corners = sym_help._maybe_get_scalar(align_corners)
    if align_corners:
        return _unimplemented("upsample_bilinear2d", "align_corners == True")
    output_size = sym_help._maybe_get_const(output_size, 'is')
    if sym_help._is_value(output_size):
        # Dynamic sizes: scales = [1, 1] ++ (output_size / input.shape[2:]).
        offset = 2  # skip batch and channel dimensions
        input_length = len(input.type().sizes())
        offsets = g.op("Constant", value_t=torch.tensor([1. for i in range(offset)]))
        dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
        divisor = g.op(
            "Slice",
            g.op("Shape", input),
            axes_i=[0],
            ends_i=[input_length],
            starts_i=[offset]
        )
        divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
        scale_dims = g.op("Div", dividend, divisor)
        scales = g.op("Concat", offsets, scale_dims, axis_i=0)
    else:
        # Static sizes: fold the scales into a single float constant.
        height_scale = float(output_size[-2]) / input.type().sizes()[-2]
        width_scale = float(output_size[-1]) / input.type().sizes()[-1]
        scales = g.op("Constant",
                      value_t=torch.tensor([1., 1., height_scale, width_scale]))
    return g.op("Upsample", input, scales, mode_s="linear")