Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66144

* Add logic and tests
* Minor edits
* Eliminate expand ops
* Fix flake and editing
* Modified error message
* Add overrun check
* Add overrun descriptions
* Remove empty line

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D31424095

fbshipit-source-id: 5b8ef6ac21c32d43c3dbc8e51e1ef30bffb19c25
Commit a0fc14c20f (parent b18c298f24), committed by Facebook GitHub Bot.
@@ -5329,6 +5329,63 @@ class TestONNXRuntime(unittest.TestCase):
         self.run_test(TensorFactory(), x, test_with_inputs=[another_x],
                       input_names=["input_1"], dynamic_axes={"input_1": [0, 1, 2]})
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_diagonal(self):
+        class DiagonalModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.diagonal(x)
+
+        x = torch.randn(2, 4, 5, 2)
+        # Other test inputs to test dynamic behavior
+        another_x = torch.randn(5, 6, 7, 8)
+        self.run_test(DiagonalModel(), x, test_with_inputs=[another_x],
+                      input_names=["input_1"],
+                      dynamic_axes={"input_1": [0, 1, 2, 3]})
+
+        class DiagonalModelNegOffset(torch.nn.Module):
+            def forward(self, x):
+                return torch.diagonal(x, offset=-1)
+
+        x = torch.randn(2, 4, 5, 2)
+        # Other test inputs to test dynamic behavior
+        another_x = torch.randn(5, 6, 7, 8)
+        self.run_test(DiagonalModelNegOffset(), x, test_with_inputs=[another_x],
+                      input_names=["input_1"],
+                      dynamic_axes={"input_1": [0, 1, 2, 3]})
+
+        class DiagonalModelPosOffset(torch.nn.Module):
+            def forward(self, x):
+                return torch.diagonal(x, offset=1)
+
+        x = torch.randn(2, 4, 5, 2)
+        # Other test inputs to test dynamic behavior
+        another_x = torch.randn(5, 6, 7, 8)
+        self.run_test(DiagonalModelPosOffset(), x, test_with_inputs=[another_x],
+                      input_names=["input_1"],
+                      dynamic_axes={"input_1": [0, 1, 2, 3]})
+
+        class DiagonalModelWithDims(torch.nn.Module):
+            def forward(self, x):
+                return torch.diagonal(x, offset=-1, dim1=1, dim2=2)
+
+        x = torch.randn(2, 4, 5, 2)
+        # Other test inputs to test dynamic behavior
+        another_x = torch.randn(5, 6, 7, 8)
+        self.run_test(DiagonalModelWithDims(), x, test_with_inputs=[another_x],
+                      input_names=["input_1"],
+                      dynamic_axes={"input_1": [0, 1, 2, 3]})
+
+        class DiagonalModelOffsetOverrun(torch.nn.Module):
+            def forward(self, x):
+                return torch.diagonal(x, offset=-2), torch.diagonal(x, offset=5)
+
+        x = torch.randn(2, 4, 5, 2)
+        # Other test inputs to test dynamic behavior
+        another_x = torch.randn(5, 6, 7, 8)
+        self.run_test(DiagonalModelOffsetOverrun(), x, test_with_inputs=[another_x],
+                      input_names=["input_1"],
+                      dynamic_axes={"input_1": [0, 1, 2, 3]})
+
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_inplace_zero(self):
         class Zero_(torch.nn.Module):
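For context, run_test in this suite exports the module to ONNX and compares ONNX Runtime's output against eager PyTorch. Below is a minimal standalone sketch of the same round trip for one of the new cases; it is illustrative only and not taken from the test harness (the in-memory buffer and tolerance values are assumptions, and onnxruntime must be installed):

import io
import numpy as np
import onnxruntime
import torch

class DiagonalModelNegOffset(torch.nn.Module):
    def forward(self, x):
        return torch.diagonal(x, offset=-1)

model = DiagonalModelNegOffset()
x = torch.randn(2, 4, 5, 2)
f = io.BytesIO()
# The new symbolic is exercised at opset 13, matching the test gate above.
torch.onnx.export(model, x, f, opset_version=13,
                  input_names=["input_1"],
                  dynamic_axes={"input_1": [0, 1, 2, 3]})

sess = onnxruntime.InferenceSession(f.getvalue(), providers=["CPUExecutionProvider"])
(ort_out,) = sess.run(None, {"input_1": x.numpy()})
np.testing.assert_allclose(model(x).numpy(), ort_out, rtol=1e-3, atol=1e-5)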
@@ -5,7 +5,8 @@
 import torch
 import torch.onnx.symbolic_helper as sym_help
 from torch.onnx.symbolic_helper import parse_args, _unimplemented
-from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero, expand
+from torch.onnx.symbolic_opset9 import (overload_by_arg_count, _maybe_cast_reduce_op_input,
+                                        nonzero, expand, zeros, ones, size)
 from torch.onnx.symbolic_opset11 import unsqueeze
 from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
@@ -318,3 +319,88 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None):
     loop_out = loop.node().output()
     loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
     return loop_out
+
+
+@parse_args("v", "i", "i", "i")
+def diagonal(g, self, offset, dim1, dim2):
+    dim1_size = size(g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1])))
+    dim2_size = size(g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2])))
+
+    # Create appropriate mask
+    mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
+    mask = zeros(g, mask_shape, None, None, None)
+    mask = g.op("EyeLike", mask, k_i=offset)
+
+    # dim1 and dim2 appended as a dimension at the end of the shape
+    rank = sym_help._get_tensor_rank(self)
+    if rank is not None:
+        axes = list(range(rank))
+        axes.remove(dim1)
+        axes.remove(dim2)
+        self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])
+    else:
+        return _unimplemented("diagonal", "unknown input rank")
+
+    # Multiply input and mask to calculate values along diagonal
+    # The mask consists of ones where diagonal values are to be calculated
+    # For example:
+    # [[1.1, 1.2, 1.3],   *   [[1, 0, 0]   =   [[1.1, 0, 0],
+    #  [2.1, 2.2, 2.3],        [0, 1, 0]        [0, 2.2, 0],
+    #  [3.1, 3.2, 3.3]]        [0, 0, 1]]       [0, 0, 3.3]]
+    result = g.op("Mul", self, mask)
+    result = sym_help._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)
+
+    # Calculate gather indices based on offset and dims
+    # If offset is greater than zero, set offset to zero as this aids in
+    # calculation of selection window
+    offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
+    if offset >= 0:
+        diag_size = g.op("Max", g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
+                         g.op("Constant", value_t=torch.LongTensor([0])))
+        offset = 0
+    else:
+        diag_size = g.op("Max", g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
+                         g.op("Constant", value_t=torch.LongTensor([0])))
+    diag_size = g.op("Concat", diag_size, axis_i=0)
+
+    # Calculate which diagonal values to select
+    # For example, in cases with offsets:
+    # [[0, 1.1, 0]
+    #  [0, 0, 2.2]]
+    # we need to select the last two columns, so we create a tensor
+    # with all columns that are to be selected
+    # So in this example, it is [1, 2]
+    select_window_ones_fill = ones(g, diag_size, 4, None, None)
+    select_window = g.op("CumSum", select_window_ones_fill, g.op("Constant", value_t=torch.LongTensor([0])))
+    select_window = g.op("Add", select_window, g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])))
+
+    gather_shape = [size(g, result,
+                         dim=g.op("Constant", value_t=torch.LongTensor([axis]))) for axis in list(range(rank))[:-2]]
+    gather_shape.append(diag_size)
+    gather_shape = g.op("Concat", *gather_shape, axis_i=0)
+    gather_indices = zeros(g, gather_shape, 4, None, None)
+
+    # There might be cases where the offset value is greater than the number of rows/columns
+    # and might cause the diagonal to overrun; as a result, diag_size would be zero.
+    # For example, if
+    #       offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
+    #       diag_size = max(min(2, (4 - 9)), 0) = 0, based on the calculation above
+    # Cases with diagonal overrun always result in diag_size = max(0, negative value) = 0
+    # In cases without diagonal overrun, we select the appropriate rows/columns along which we
+    # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has
+    # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially
+    # returning an empty tensor
+    overrun_cond = g.op("Not", g.op("Equal", diag_size, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))))
+    if_op = g.op("If", overrun_cond)
+    if_node = if_op.node()
+
+    if_block = _add_block(if_node)
+    gather_indices_if_block = if_block.op("Add", gather_indices, select_window)
+    gather_indices_if_block = sym_help._unsqueeze_helper(if_block, gather_indices_if_block, [rank - 1])
+    final_non_overrun_ = if_block.op("GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2)
+    _add_output_to_block(if_block, final_non_overrun_)
+
+    else_block = _add_block(if_node)
+    final_overrun_ = zeros(else_block, gather_shape, 6, None, None)
+    _add_output_to_block(else_block, final_overrun_)
+    return if_op
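The graph ops above are easier to follow in eager mode. Below is a rough eager-PyTorch re-enactment of the same steps (EyeLike mask, Transpose of dim1/dim2 to the back, Mul plus ReduceSum, then the selection window). The helper name diagonal_reference is hypothetical; this sketch only illustrates the algorithm and is not part of the patch:

import torch

def diagonal_reference(x, offset=0, dim1=0, dim2=1):
    # Hypothetical eager re-enactment of the ONNX symbolic above (illustrative only)
    dim1_size, dim2_size = x.shape[dim1], x.shape[dim2]

    # Mask with ones on the requested diagonal (the EyeLike step)
    mask = torch.zeros(dim1_size, dim2_size)
    mask.diagonal(offset=offset).fill_(1)

    # Move dim1/dim2 to the end (the Transpose step)
    axes = [a for a in range(x.dim()) if a not in (dim1, dim2)]
    x = x.permute(axes + [dim1, dim2])

    # Zero out everything off the diagonal, then sum out the columns
    # (the Mul + ReduceSum steps)
    result = (x * mask).sum(dim=-1)

    # diag_size mirrors the Max/Min arithmetic in the symbolic
    if offset >= 0:
        diag_size = max(min(dim1_size, dim2_size - offset), 0)
        offset = 0  # selection window starts at row 0 for non-negative offsets
    else:
        diag_size = max(min(dim1_size + offset, dim2_size), 0)

    # Select the rows that actually hold diagonal values
    # (the CumSum + Add + GatherND steps); empty when the offset overruns
    select_window = torch.arange(abs(offset), abs(offset) + diag_size)
    return result[..., select_window]

x = torch.randn(2, 4, 5, 2)
for off in (-2, -1, 0, 1, 5):
    assert torch.allclose(diagonal_reference(x, offset=off),
                          torch.diagonal(x, offset=off))
assert torch.allclose(diagonal_reference(x, offset=-1, dim1=1, dim2=2),
                      torch.diagonal(x, offset=-1, dim1=1, dim2=2))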
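As a quick numeric check of the overrun arithmetic described in the comments: with offset=5 on an input whose dim1 has size 2 and dim2 has size 4, diag_size = max(min(2, 4 - 5), 0) = 0, so eager PyTorch (and the exported graph's else branch) produces an empty diagonal, matching the DiagonalModelOffsetOverrun test:

import torch

x = torch.randn(2, 4, 5, 2)
print(torch.diagonal(x, offset=5).shape)   # torch.Size([5, 2, 0])
print(torch.diagonal(x, offset=-2).shape)  # torch.Size([5, 2, 0]); min(2 - 2, 4) = 0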