mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adapt ONNX Slice op changes (#3316)
This commit is contained in:
@ -26,15 +26,13 @@ class Index(Function):
|
||||
# We use "Slice" to get the index-th element in i,
|
||||
# Then we reduce the dimension using "Reshape".
|
||||
if isinstance(index, int_classes):
|
||||
axes = g.constant(0, [1], "int")
|
||||
starts = g.constant(index, [1], "long")
|
||||
ends = g.constant(index + 1, [1], "long")
|
||||
slice_node = g.op("Slice", i, axes, starts, ends)
|
||||
slice_node = g.op("Slice", i,
|
||||
axes_i=[0],
|
||||
starts_i=[index],
|
||||
ends_i=[index + 1])
|
||||
return g.op("Squeeze", slice_node, axes_i=[0])
|
||||
elif isinstance(index, tuple):
|
||||
dims = i.type().sizes()
|
||||
axes_ten = torch.IntTensor([idx for idx in range(len(index))])
|
||||
axes = g.op("Constant", value_t=axes_ten)
|
||||
starts_list = []
|
||||
ends_list = []
|
||||
squeeze_indices = []
|
||||
@ -67,11 +65,10 @@ class Index(Function):
|
||||
if index[idx].step is not None:
|
||||
raise ValueError("Strided slice is not supported at this time")
|
||||
|
||||
starts_ten = torch.LongTensor(starts_list)
|
||||
starts = g.op("Constant", value_t=starts_ten)
|
||||
ends_ten = torch.LongTensor(ends_list)
|
||||
ends = g.op("Constant", value_t=ends_ten)
|
||||
slice_node = g.op("Slice", i, axes, starts, ends)
|
||||
slice_node = g.op("Slice", i,
|
||||
axes_i=list(range(len(index))),
|
||||
starts_i=starts_list,
|
||||
ends_i=ends_list)
|
||||
if squeeze_indices:
|
||||
return g.op('Squeeze', slice_node, axes_i=squeeze_indices)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user