mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: As GoogleTest `TEST` macro is non-compliant with it as well as `DEFINE_DISPATCH` All changes but the ones to `.clang-tidy` are generated using following script: ``` for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008 Reviewed By: driazati, r-barnes Differential Revision: D29838584 Pulled By: malfet fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
147 lines
4.0 KiB
C++
147 lines
4.0 KiB
C++
#include "caffe2/operators/slice_op.h"
|
|
#include "caffe2/utils/math.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
// CPU registration: the operator name `Slice` is bound to the
// SliceOp<CPUContext> instantiation.
REGISTER_CPU_OPERATOR(Slice, SliceOp<CPUContext>);
|
|
// CPU registration of the matching gradient operator: `SliceGradient` is
// bound to SliceGradientOp<CPUContext>.
REGISTER_CPU_GRADIENT_OPERATOR(SliceGradient, SliceGradientOp<CPUContext>);
|
|
|
|
// Schema of the Slice operator.
//
// The chain below declares between one and three inputs (the data tensor "X"
// and, optionally, the "starts" and "ends" index tensors described further
// down) and exactly one output ("Y"). The slice indices may alternatively be
// given through the `starts` / `ends` operator arguments.
OPERATOR_SCHEMA(Slice)
|
|
// Input 0 is mandatory; inputs 1 and 2 carry the indices when they are not
// passed as arguments.
.NumInputs(1, 3)
|
|
// A single sliced tensor is produced.
.NumOutputs(1)
|
|
.DisallowInputFillers() // the filler cannot be enabled without output dims
|
|
.SetDoc(R"DOC(
|
|
Produces a slice of the input tensor.
|
|
|
|
- Currently, only slicing in a single dimension is supported.
|
|
|
|
- Start and end indices are either passed as two 1D input tensors or using the `starts` and `ends` arguments.
|
|
|
|
- If a negative value is passed for any of the start or end indices, it represents |value| - 1 elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element).
|
|
|
|
Github Links:
|
|
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/slice_op.cc
|
|
|
|
<details>
|
|
|
|
<summary> <b>Example</b> </summary>
|
|
|
|
**Code**
|
|
|
|
```
|
|
|
|
workspace.ResetWorkspace()
|
|
|
|
op = core.CreateOperator(
|
|
"Slice",
|
|
["X"],
|
|
["Y"],
|
|
starts=(0,1),
|
|
ends=(-1,3)
|
|
)
|
|
|
|
workspace.FeedBlob("X", np.array([[1,2,3,4],[5,6,7,8]]))
|
|
print("X:", workspace.FetchBlob("X"))
|
|
workspace.RunOperatorOnce(op)
|
|
print("Y:", workspace.FetchBlob("Y"))
|
|
|
|
```
|
|
|
|
**Result**
|
|
|
|
```
|
|
|
|
X:
|
|
[[1 2 3 4]
|
|
[5 6 7 8]]
|
|
Y:
|
|
[[2 3]
|
|
[6 7]]
|
|
|
|
```
|
|
|
|
</details>
|
|
|
|
)DOC")
|
|
.Input(0, "X", "(*Tensor*): tensor to extract slices from")
|
|
.Input(
|
|
1,
|
|
"starts",
|
|
"(*Tensor`<int>`*): 1D tensor of start-indices for each dimension of data (dimensions following the sliced one might be omitted)")
|
|
.Input(
|
|
2,
|
|
"ends",
|
|
"(*Tensor`<int>`*): 1D tensor of end-indices for each dimension of data (dimensions following the sliced one might be omitted)")
|
|
.Arg("starts", "(*Tuple(int)*): list of starting indices")
|
|
.Arg("ends", "(*Tuple(int)*): list of ending indices")
|
|
// Static shape inference for Slice.
//
// Only the argument form can be inferred: a single tensor input whose slice
// indices come from the `starts` / `ends` operator arguments. For every
// dimension covered by `starts` the output extent is `end - start`, where a
// negative index v is first mapped to `dim + 1 + v`. Dimensions that have no
// entry in `starts` keep the extent of the input.
//
// Returns one TensorShape carrying the input's data type, or an empty vector
// when the shape cannot be inferred:
//   - the indices are supplied as input tensors (their values are not known
//     during inference),
//   - no input shape was provided at all, or
//   - `ends` has no entry for a dimension that `starts` covers.
.TensorInferenceFunction([](const OperatorDef& def,
                            const vector<TensorShape>& in) {
  if (in.size() != 1) {
    // Either there is nothing to read the shape from, or the start/end
    // indices are defined in data (inputs 1 and 2) and cannot be evaluated
    // here.
    return vector<TensorShape>();
  }
  auto const& data = in[0];

  ArgumentHelper helper(def);
  auto starts = helper.GetRepeatedArgument<int>("starts", vector<int>());
  auto ends = helper.GetRepeatedArgument<int>("ends", vector<int>());
  vector<int> dst_sizes(data.dims_size());

  for (int i = 0; i < data.dims_size(); ++i) {
    // `i` is never negative, so the conversion for the size comparisons
    // below is lossless.
    const auto idx = static_cast<size_t>(i);
    if (idx >= starts.size()) {
      // No explicit indices for this dimension: it is passed through.
      dst_sizes[i] = data.dims(i);
      continue;
    }
    if (idx >= ends.size()) {
      // `ends` is shorter than `starts`, so this dimension has a start but
      // no end. Reading ends[i] would be out of bounds; report the shape as
      // not inferable instead.
      return vector<TensorShape>();
    }
    if (data.dims(i) > 0) {
      auto start = starts[i];
      auto end = ends[i];
      // Negative indices count from the end of the dimension: v maps to
      // dim + 1 + v, so an end index of -1 includes the last element.
      if (start < 0) {
        start = data.dims(i) + 1 + start;
      }
      if (end < 0) {
        end = data.dims(i) + 1 + end;
      }
      dst_sizes[i] = end - start;
    } else {
      // Non-positive input extent: the output dimension is reported as 0.
      dst_sizes[i] = 0;
    }
  }
  return vector<TensorShape>{
      CreateTensorShape(dst_sizes, data.data_type())};
})
|
|
.Output(0, "Y", "(*Tensor*): sliced output tensor")
|
|
.InheritOnnxSchema();
|
|
|
|
// Schema of SliceGradient. Its shape inference reports a single output whose
// shape is a copy of the shape of input 0.
GRADIENT_OPERATOR_SCHEMA(SliceGradient)
    .TensorInferenceFunction(
        [](const OperatorDef& /*def*/, const vector<TensorShape>& in) {
          // Bounds-checked access to the first input shape, duplicated into
          // a one-element result.
          return vector<TensorShape>(1, in.at(0));
        });
|
|
|
|
namespace {
|
|
struct GetSliceGradient : public GradientMakerBase {
|
|
using GradientMakerBase::GradientMakerBase;
|
|
vector<OperatorDef> GetGradientDefs() override {
|
|
if (def_.input_size() > 1) {
|
|
return vector<OperatorDef>{CreateOperatorDef(
|
|
"SliceGradient",
|
|
"",
|
|
std::vector<string>{I(0), I(1), I(2), GO(0)},
|
|
std::vector<string>{GI(0)})};
|
|
} else {
|
|
return vector<OperatorDef>{CreateOperatorDef(
|
|
"SliceGradient",
|
|
"",
|
|
std::vector<string>{I(0), GO(0)},
|
|
std::vector<string>{GI(0)})};
|
|
}
|
|
}
|
|
};
|
|
} // namespace
|
|
// Associates the GetSliceGradient maker defined above with the `Slice`
// operator.
REGISTER_GRADIENT(Slice, GetSliceGradient);
|
|
} // namespace caffe2
|