Adapt ONNX Slice op changes (#3316)

This commit is contained in:
bddppq
2017-10-27 21:03:29 -07:00
committed by Edward Z. Yang
parent dc6c9e8df8
commit ac8f56656d

View File

@ -26,15 +26,13 @@ class Index(Function):
# We use "Slice" to get the index-th element in i,
# Then we reduce the dimension using "Reshape".
if isinstance(index, int_classes):
axes = g.constant(0, [1], "int")
starts = g.constant(index, [1], "long")
ends = g.constant(index + 1, [1], "long")
slice_node = g.op("Slice", i, axes, starts, ends)
slice_node = g.op("Slice", i,
axes_i=[0],
starts_i=[index],
ends_i=[index + 1])
return g.op("Squeeze", slice_node, axes_i=[0])
elif isinstance(index, tuple):
dims = i.type().sizes()
axes_ten = torch.IntTensor([idx for idx in range(len(index))])
axes = g.op("Constant", value_t=axes_ten)
starts_list = []
ends_list = []
squeeze_indices = []
@ -67,11 +65,10 @@ class Index(Function):
if index[idx].step is not None:
raise ValueError("Strided slice is not supported at this time")
starts_ten = torch.LongTensor(starts_list)
starts = g.op("Constant", value_t=starts_ten)
ends_ten = torch.LongTensor(ends_list)
ends = g.op("Constant", value_t=ends_ten)
slice_node = g.op("Slice", i, axes, starts, ends)
slice_node = g.op("Slice", i,
axes_i=list(range(len(index))),
starts_i=starts_list,
ends_i=ends_list)
if squeeze_indices:
return g.op('Squeeze', slice_node, axes_i=squeeze_indices)
else: