mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19660 Implementation of an aggregated Scale operator. The operator takes a list of tensors as input and scales all of them by the float argument value. The tensor sizes can differ, so the GPU version of the kernel needs to keep track of the sizes of, and pointers to, the tensors. Reviewed By: BIT-silence Differential Revision: D14984233 fbshipit-source-id: 37cc97159a4f2c38cd6fff4f5710ab7d3a773611
59 lines
1.4 KiB
C++
59 lines
1.4 KiB
C++
#ifndef CAFFE2_OPERATORS_SCALE_BLOBS_OP_H_
|
|
#define CAFFE2_OPERATORS_SCALE_BLOBS_OP_H_
|
|
|
|
#include "caffe2/core/context.h"
|
|
#include "caffe2/core/operator.h"
|
|
#include "caffe2/utils/math.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
template <class Context>
|
|
class ScaleBlobsOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
template <class... Args>
|
|
explicit ScaleBlobsOp(Args&&... args)
|
|
: Operator<Context>(std::forward<Args>(args)...),
|
|
OP_SINGLE_ARG(float, "scale", scale_, 1.0f) {}
|
|
|
|
template <typename T>
|
|
bool DoRunWithType() {
|
|
int batchSize = InputSize();
|
|
|
|
for (int i = 0; i < batchSize; ++i) {
|
|
const auto& X = Input(i);
|
|
auto* Y = Output(i, X.sizes(), at::dtype<T>());
|
|
math::Scale<float, T, Context>(
|
|
X.numel(),
|
|
scale_,
|
|
X.template data<T>(),
|
|
Y->template mutable_data<T>(),
|
|
&context_);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool RunOnDevice() override {
|
|
for (int i = 0; i < InputSize(); ++i) {
|
|
auto& input = this->template Input<Tensor>(i, CPU);
|
|
auto* output = this->template Output<Tensor>(i, CPU);
|
|
output->ResizeLike(input);
|
|
}
|
|
return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
|
|
}
|
|
|
|
private:
|
|
const float scale_;
|
|
Tensor blobSizes_;
|
|
Tensor inputs_;
|
|
Tensor outputs_;
|
|
|
|
Tensor hostBlobSizes_;
|
|
Tensor hostInputs_;
|
|
Tensor hostOutputs_;
|
|
};
|
|
|
|
} // namespace caffe2
|
|
|
|
#endif // CAFFE2_OPERATORS_SCALE_BLOBS_OP_H_
|