gpu sequence op step 1: clean headers
Summary: @public This makes no functional changes yet; it only cleans up the sequence_op file so that the header is context-independent. The GPU parts will be implemented separately.
Reviewed By: pietern
Differential Revision: D4777140
fbshipit-source-id: 9b4aea6c36f06a64a53e235a125cd3477d54a045
Committed by: Facebook GitHub Bot
Parent: 58f7f2b441
Commit: 8efb762fcd
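The refactoring pattern the summary describes is what the diff below implements: each operator class is declared in sequence_ops.h templated on a Context, with DoRunWithType() left undefined, while sequence_ops.cc supplies only the CPUContext specializations and registrations. A condensed sketch of the split, including what a later, separate GPU step could look like (the GPU portion is hypothetical and not part of this commit):

// In the header (this commit): declaration only, independent of the device context.
template <class Context>
class AddPaddingOp final : public Operator<Context> {
  // ... constructor and RunOnDevice() as in sequence_ops.h below ...
  template <typename T>
  bool DoRunWithType(); // defined separately for each context
};

// In sequence_ops.cc (this commit): the CPU specialization and registration.
template <>
template <typename T>
bool AddPaddingOp<CPUContext>::DoRunWithType() { /* CPU implementation */ }
REGISTER_CPU_OPERATOR(AddPadding, AddPaddingOp<CPUContext>);

// Hypothetical follow-up in a GPU translation unit (not in this commit):
// template <>
// template <typename T>
// bool AddPaddingOp<CUDAContext>::DoRunWithType() { /* CUDA implementation */ }
// REGISTER_CUDA_OPERATOR(AddPadding, AddPaddingOp<CUDAContext>);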
caffe2/operators/sequence_ops.cc (new version; the operator class definitions that previously lived here move to the new header below)
@@ -1,385 +1,288 @@
#include "caffe2/operators/sequence_ops.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"

namespace caffe2 {

template <>
template <typename T>
bool GatherPaddingOp<CPUContext>::DoRunWithType() {
  const auto& in = Input(0);
  CAFFE_ENFORCE_GE(in.ndim(), 1);
  const int32_t outer_size = in.dims()[0];
  const auto block_size = std::accumulate(
      in.dims().begin() + 1, in.dims().end(), 1, std::multiplies<TIndex>());
  const auto pad_width = startPaddingWidth_ + endPaddingWidth_;

  // if no lengths is provided, assume it is a single full-span entry
  const int32_t* lengths_ptr = &outer_size;
  int64_t lengths_size = 1;
  if (InputSize() > 1) {
    const auto& lengths = Input(1);
    lengths_ptr = lengths.data<int32_t>();
    lengths_size = lengths.size();
  }

  std::vector<TIndex> padShape(in.dims().begin() + 1, in.dims().end());
  // output will contain accumulator over paddings
  Output(0)->Resize(padShape);
  T* padding_start_ptr = Output(0)->template mutable_data<T>();
  memset(padding_start_ptr, 0, sizeof(T) * block_size);

  // if no end_padding is provided, assume it's the same as start_padding
  T* padding_end_ptr = padding_start_ptr;
  if (OutputSize() == 2) {
    Output(1)->Resize(padShape);
    padding_end_ptr = Output(1)->template mutable_data<T>();
    memset(padding_end_ptr, 0, sizeof(T) * block_size);
  }

  const auto* in_ptr = in.template data<T>();
  int64_t total_length = 0;
  for (int i = 0; i < lengths_size; ++i) {
    // check total length consistency
    const auto length = lengths_ptr[i];
    total_length += length;
    CAFFE_ENFORCE_LE(total_length, outer_size);

    // accumulate start paddings
    for (int j = 0; j < startPaddingWidth_; ++j) {
      for (int k = 0; k < block_size; ++k) {
        padding_start_ptr[k] += in_ptr[k];
      }
      in_ptr += block_size;
    }
    in_ptr += block_size * (length - pad_width);
    // accumulate end paddings
    for (int j = 0; j < endPaddingWidth_; ++j) {
      for (int k = 0; k < block_size; ++k) {
        padding_end_ptr[k] += in_ptr[k];
      }
      in_ptr += block_size;
    }
  }
  return true;
}

template <>
template <typename T>
bool RemovePaddingOp<CPUContext>::DoRunWithType() {
  const auto& in = Input(0);
  CAFFE_ENFORCE_GE(in.ndim(), 1);
  const int32_t outer_size = in.dims()[0];
  const auto block_size = std::accumulate(
      in.dims().begin() + 1, in.dims().end(), 1, std::multiplies<TIndex>());
  const auto pad_width = startPaddingWidth_ + endPaddingWidth_;

  // if no lengths is provided, assume it is a single full-span entry
  const int32_t* lengths_ptr = &outer_size;
  int64_t lengths_size = 1;
  if (InputSize() > 1) {
    const auto& lengths = Input(1);
    lengths_ptr = lengths.data<int32_t>();
    lengths_size = lengths.size();
  }

  auto* out = Output(0);
  {
    auto out_dims = in.dims();
    out_dims[0] -= pad_width * lengths_size;
    out->Resize(std::move(out_dims));
  }
  const auto* in_ptr = in.template data<T>();
  auto* out_ptr = out->template mutable_data<T>();
  int64_t total_length = 0;
  for (int i = 0; i < lengths_size; ++i) {
    // check that total length is consistent
    const auto length = lengths_ptr[i];
    total_length += length;
    CAFFE_ENFORCE_LE(total_length, outer_size);
    std::copy(
        in_ptr + block_size * startPaddingWidth_,
        in_ptr + block_size * (length - endPaddingWidth_),
        out_ptr);
    in_ptr += block_size * length;
    out_ptr += block_size * (length - pad_width);
  }
  if (OutputSize() == 1) {
    return true;
  }
  auto* lengths_out = Output(1);
  lengths_out->Resize(lengths_size);
  std::transform(
      lengths_ptr,
      lengths_ptr + lengths_size,
      lengths_out->mutable_data<int32_t>(),
      [pad_width](int32_t x) { return x - pad_width; });
  return true;
}

template <>
template <typename T>
bool AddPaddingOp<CPUContext>::DoRunWithType() {
  const auto& in = Input(0);
  CAFFE_ENFORCE_GE(in.ndim(), 1);
  const int32_t outer_size = in.dims()[0];
  const auto block_size = std::accumulate(
      in.dims().begin() + 1, in.dims().end(), 1, std::multiplies<TIndex>());

  // if no lengths is provided, assume it is a single full-span entry
  const int32_t* lengths_ptr = &outer_size;
  int64_t lengths_size = 1;
  if (InputSize() > 1) {
    const auto& lengths = Input(1);
    lengths_ptr = lengths.data<int32_t>();
    lengths_size = lengths.size();
  }

  // fetch paddings
  // input_size == 2 : pad with zeros
  // input_size == 3 : start and end paddings are the same
  // input_size == 4 : different start and end paddings
  const T* padding_start_ptr = nullptr;
  const T* padding_end_ptr = nullptr;
  if (InputSize() >= 3) {
    auto& padding_start = Input(2);
    CAFFE_ENFORCE_EQ(block_size, padding_start.size());
    padding_start_ptr = padding_start.template data<T>();
  }
  if (InputSize() == 4) {
    auto& padding_end = Input(3);
    CAFFE_ENFORCE_EQ(block_size, padding_end.size());
    padding_end_ptr = padding_end.template data<T>();
  } else {
    padding_end_ptr = padding_start_ptr;
  }

  auto* out = Output(0);
  {
    auto out_dims = in.dims();
    out_dims[0] += (startPaddingWidth_ + endPaddingWidth_) * lengths_size;
    out->Resize(std::move(out_dims));
  }
  const auto* in_ptr = in.template data<T>();
  auto* out_ptr = out->template mutable_data<T>();
  int64_t total_length = 0;
  for (int i = 0; i < lengths_size; ++i) {
    // check that total length is consistent
    const auto length = lengths_ptr[i];
    total_length += length;
    CAFFE_ENFORCE_LE(total_length, outer_size);
    // copy padding before
    if (!padding_start_ptr) {
      memset(out_ptr, 0, block_size * startPaddingWidth_ * sizeof(T));
      out_ptr += block_size * startPaddingWidth_;
    } else {
      for (int j = 0; j < startPaddingWidth_; ++j) {
        std::copy(padding_start_ptr, padding_start_ptr + block_size, out_ptr);
        out_ptr += block_size;
      }
    }
    // copy payload
    const auto num_elems = block_size * length;
    std::copy(in_ptr, in_ptr + num_elems, out_ptr);
    in_ptr += num_elems;
    out_ptr += num_elems;
    // copy padding after
    if (!padding_end_ptr) {
      memset(out_ptr, 0, block_size * endPaddingWidth_ * sizeof(T));
      out_ptr += block_size * endPaddingWidth_;
    } else {
      for (int j = 0; j < endPaddingWidth_; ++j) {
        std::copy(padding_end_ptr, padding_end_ptr + block_size, out_ptr);
        out_ptr += block_size;
      }
    }
  }
  if (OutputSize() == 1) {
    return true;
  }
  auto* lengths_out = Output(1);
  lengths_out->Resize(lengths_size);
  const auto pad_width = startPaddingWidth_ + endPaddingWidth_;
  std::transform(
      lengths_ptr,
      lengths_ptr + lengths_size,
      lengths_out->mutable_data<int32_t>(),
      [pad_width](int32_t x) { return x + pad_width; });
  return true;
}

template <>
bool PadEmptySamplesOp<CPUContext>::RunOnDevice() {
  auto& lengths = Input(0);
  auto* lengthsPtr = lengths.template data<int32_t>();
  CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
  CAFFE_ENFORCE(InputSize() >= 1, "Input size must be no less than 1");

  auto* out_lengths = Output(0);
  int needPadding = 0;
  int sumLen = 0;
  for (int i = 0; i < lengths.size(); ++i) {
    if (lengthsPtr[i] == 0) {
      needPadding++;
    }
    sumLen += lengthsPtr[i];
  }

  out_lengths->Resize(lengths.size());
  auto* outLengthsPtr = out_lengths->template mutable_data<int32_t>();
  for (int i = 0; i < lengths.size(); ++i) {
    if (lengthsPtr[i] == 0) {
      outLengthsPtr[i] = 1;
    } else {
      outLengthsPtr[i] = lengthsPtr[i];
    }
  }

  for (int k = 0; k < InputSize() - 1; k++) {
    auto& features = Input(1 + k);
    CAFFE_ENFORCE(features.ndim() >= 1, "FEATURE should at least 1-D");
    CAFFE_ENFORCE(
        features.dim(0) == sumLen, "FEATURE and LENGTH should be consistent");
    const auto block_size = features.size_from_dim(1);

    auto* out_features = Output(1 + k);
    auto outDim = features.dims();
    outDim.at(0) += needPadding;
    out_features->Resize(outDim);
    auto dst =
        static_cast<char*>(out_features->raw_mutable_data(features.meta()));
    auto src_base = static_cast<const char*>(features.raw_data());
    // copy data and add padding index as zero
    Tensor<CPUContext> zero;
    zero.Resize(block_size);
    auto zeroPtr =
        static_cast<const char*>(zero.raw_mutable_data(features.meta()));
    int start_dest = 0;
    int start_src = 0;
    for (int i = 0; i < lengths.size(); ++i) {
      if (lengthsPtr[i] == 0) {
        context_.template CopyItems<CPUContext, CPUContext>(
            features.meta(),
            block_size,
            zeroPtr,
            dst + start_dest * features.meta().itemsize());
        start_dest += block_size;
      } else {
        auto src = src_base + start_src * features.meta().itemsize();
        context_.template CopyItems<CPUContext, CPUContext>(
            features.meta(),
            lengthsPtr[i] * block_size,
            src,
            dst + start_dest * features.meta().itemsize());
        start_src += lengthsPtr[i] * block_size;
        start_dest += lengthsPtr[i] * block_size;
      }
    }
  }
  return true;
}

REGISTER_CPU_OPERATOR(AddPadding, AddPaddingOp<CPUContext>);
REGISTER_CPU_OPERATOR(RemovePadding, RemovePaddingOp<CPUContext>);
REGISTER_CPU_OPERATOR(GatherPadding, GatherPaddingOp<CPUContext>);
REGISTER_CPU_OPERATOR(PadEmptySamples, PadEmptySamplesOp<CPUContext>);

struct GetAddPadingGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

@@ -528,5 +431,5 @@ PadEmptySamples is thread safe.
        0,
        "out_lengths",
        "Tensor containing lengths with empty sample padded.");

}  // namespace caffe2
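As a quick check of the length bookkeeping in the listing above (AddPadding adds pad_width = startPaddingWidth_ + endPaddingWidth_ to every length, RemovePadding subtracts it), here is a small self-contained snippet with illustrative values, independent of caffe2:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // padding_width = 1 with end_padding_width unset means start = end = 1.
  const int32_t pad_width = 1 + 1;
  const std::vector<int32_t> lengths = {2, 4}; // 6 input rows in total

  // AddPadding's second output: each length grows by pad_width.
  std::vector<int32_t> padded;
  for (const int32_t len : lengths) {
    padded.push_back(len + pad_width);
  }
  assert(padded[0] == 4 && padded[1] == 6); // 10 output rows in total

  // RemovePadding's second output applies the inverse transform.
  std::vector<int32_t> recovered;
  for (const int32_t len : padded) {
    recovered.push_back(len - pad_width);
  }
  assert(recovered == lengths);
  return 0;
}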
caffe2/operators/sequence_ops.h (new file, 129 lines)
@@ -0,0 +1,129 @@
#ifndef CAFFE2_OPERATORS_SEQUENCE_OPS_H_
#define CAFFE2_OPERATORS_SEQUENCE_OPS_H_

#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"

namespace caffe2 {

template <class Context>
class GatherPaddingOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  GatherPaddingOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        startPaddingWidth_(
            OperatorBase::GetSingleArgument<int>("padding_width", 1)),
        endPaddingWidth_(
            OperatorBase::GetSingleArgument<int>("end_padding_width", -1)) {
    CAFFE_ENFORCE_GE(startPaddingWidth_, 0);
    if (endPaddingWidth_ < 0) {
      endPaddingWidth_ = startPaddingWidth_;
    }
  }

  bool RunOnDevice() override {
    if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
      Output(0)->Resize(std::vector<TIndex>(0));
      if (OutputSize() == 2) {
        Output(1)->Resize(std::vector<TIndex>(0));
      }
      return true;
    }
    return DispatchHelper<TensorTypes<float, double, int, int64_t, bool>>::call(
        this, Input(0));
  }

  template <typename T>
  bool DoRunWithType();

 private:
  int startPaddingWidth_;
  int endPaddingWidth_;
};

template <class Context>
class RemovePaddingOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  RemovePaddingOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        startPaddingWidth_(
            OperatorBase::GetSingleArgument<int>("padding_width", 1)),
        endPaddingWidth_(
            OperatorBase::GetSingleArgument<int>("end_padding_width", -1)) {
    CAFFE_ENFORCE_GE(startPaddingWidth_, 0);
    if (endPaddingWidth_ < 0) {
      endPaddingWidth_ = startPaddingWidth_;
    }
  }

  bool RunOnDevice() override {
    if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
      Output(0)->CopyFrom(Input(0), &context_);
      if (OutputSize() == 2) {
        Output(1)->CopyFrom(Input(1), &context_);
      }
      return true;
    }
    return DispatchHelper<TensorTypes<float, double, int, int64_t, bool>>::call(
        this, Input(0));
  }

  template <typename T>
  bool DoRunWithType();

 private:
  int startPaddingWidth_;
  int endPaddingWidth_;
};

template <class Context>
class AddPaddingOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  AddPaddingOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        startPaddingWidth_(
            OperatorBase::GetSingleArgument<int>("padding_width", 1)),
        endPaddingWidth_(
            OperatorBase::GetSingleArgument<int>("end_padding_width", -1)) {
    CAFFE_ENFORCE_GE(startPaddingWidth_, 0);
    if (endPaddingWidth_ < 0) {
      endPaddingWidth_ = startPaddingWidth_;
    }
  }

  bool RunOnDevice() override {
    if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
      Output(0)->CopyFrom(Input(0), &context_);
      if (OutputSize() == 2) {
        Output(1)->CopyFrom(Input(1), &context_);
      }
      return true;
    }
    return DispatchHelper<TensorTypes<float, double, int, int64_t, bool>>::call(
        this, Input(0));
  }

  template <typename T>
  bool DoRunWithType();

 private:
  int startPaddingWidth_;
  int endPaddingWidth_;
};

template <class Context>
class PadEmptySamplesOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  PadEmptySamplesOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_SEQUENCE_OPS_H_
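For context, a rough sketch of driving the CPU operator through the Caffe2 C++ API of this era follows; the exact tensor and workspace calls are written from memory and may differ slightly from the tree at this revision:

#include "caffe2/core/init.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

int main(int argc, char** argv) {
  caffe2::GlobalInit(&argc, &argv);
  caffe2::Workspace ws;

  // Feed a (6, 3) float tensor and per-sequence lengths {2, 4}.
  auto* data = ws.CreateBlob("data")->GetMutable<caffe2::TensorCPU>();
  data->Resize(6, 3);
  float* data_ptr = data->mutable_data<float>();
  for (int i = 0; i < data->size(); ++i) {
    data_ptr[i] = static_cast<float>(i);
  }
  auto* lengths = ws.CreateBlob("lengths")->GetMutable<caffe2::TensorCPU>();
  lengths->Resize(2);
  lengths->mutable_data<int32_t>()[0] = 2;
  lengths->mutable_data<int32_t>()[1] = 4;

  // AddPadding with padding_width = 1 pads every sequence with one zero
  // row at the start and one at the end.
  caffe2::OperatorDef def;
  def.set_type("AddPadding");
  def.add_input("data");
  def.add_input("lengths");
  def.add_output("padded");
  def.add_output("padded_lengths");
  auto* arg = def.add_arg();
  arg->set_name("padding_width");
  arg->set_i(1);

  auto op = caffe2::CreateOperator(def, &ws);
  op->Run();

  // Expect a (10, 3) output and padded lengths {4, 6}.
  const auto& padded = ws.GetBlob("padded")->Get<caffe2::TensorCPU>();
  return padded.dim(0) == 10 ? 0 : 1;
}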