Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
refactor caffe2 operator constructors - 6/9 (#17087)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17087

clangr codemod

Reviewed By: ezyang

Differential Revision: D14078525

fbshipit-source-id: 7cc03b30b0d4eb99818e35406be4119b27bdb1bc
Committed by: Facebook Github Bot
Parent: a4ed7126ca
Commit: 28b5df1c8f
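Every hunk below applies the same pattern: the hand-written Op(const OperatorDef&, Workspace*) constructor is replaced by an explicit variadic constructor that perfectly forwards its arguments to the base class, so the base can accept other argument sets without touching each operator. The sketch below shows the shape of the change with stub types; ExampleOp, OperatorDefStub, and WorkspaceStub are illustrative stand-ins and are not part of this commit.

// Standalone sketch of the constructor codemod (stub types, not caffe2 code):
// a fixed (def, ws) constructor becomes a perfect-forwarding variadic one.
#include <iostream>
#include <string>
#include <utility>

struct OperatorDefStub { std::string name; };  // stand-in for caffe2::OperatorDef
struct WorkspaceStub {};                       // stand-in for caffe2::Workspace

class OperatorBase {
 public:
  OperatorBase(const OperatorDefStub& def, WorkspaceStub* /*ws*/) : name_(def.name) {}
  const std::string& name() const { return name_; }

 private:
  std::string name_;
};

class ExampleOp final : public OperatorBase {
 public:
  // Old form (what the codemod removes):
  //   ExampleOp(const OperatorDefStub& def, WorkspaceStub* ws)
  //       : OperatorBase(def, ws) {}
  // New form: accept whatever the base accepts and forward it unchanged.
  template <class... Args>
  explicit ExampleOp(Args&&... args)
      : OperatorBase(std::forward<Args>(args)...) {}
};

int main() {
  OperatorDefStub def{"Percentile"};
  WorkspaceStub ws;
  ExampleOp op(def, &ws);  // still reaches OperatorBase(const OperatorDefStub&, WorkspaceStub*)
  std::cout << op.name() << "\n";
  return 0;
}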
@@ -17,8 +17,9 @@ template <class Context>
 class PercentileOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
-  PercentileOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws) {}
+  template <class... Args>
+  explicit PercentileOp(Args&&... args)
+      : Operator<Context>(std::forward<Args>(args)...) {}

   bool RunOnDevice() override;

@@ -11,8 +11,9 @@ class PiecewiseLinearTransformOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;

-  PiecewiseLinearTransformOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws) {
+  template <class... Args>
+  explicit PiecewiseLinearTransformOp(Args&&... args)
+      : Operator<Context>(std::forward<Args>(args)...) {
     binary_ = this->template GetSingleArgument<bool>("binary", false);

     // Retrieve transform params (i.e., the linear functions).
@@ -16,8 +16,9 @@ class PoolOp final : public ConvPoolOpBase<Context> {
  public:
   USE_CONV_POOL_BASE_FUNCTIONS(Context);

-  PoolOp(const OperatorDef& operator_def, Workspace* ws)
-      : ConvPoolOpBase<Context>(operator_def, ws), functor_(*this) {
+  template <class... Args>
+  explicit PoolOp(Args&&... args)
+      : ConvPoolOpBase<Context>(std::forward<Args>(args)...), functor_(*this) {
     const int kernel_size = kernel_.size();
     for (int i = 0; i < kernel_size; ++i) {
       CAFFE_ENFORCE_EQ(
@@ -107,8 +108,9 @@ template <typename T, class Context, class Functor>
 class PoolGradientOp final : public ConvPoolOpBase<Context> {
  public:
   USE_CONV_POOL_BASE_FUNCTIONS(Context);
-  PoolGradientOp(const OperatorDef& operator_def, Workspace* ws)
-      : ConvPoolOpBase<Context>(operator_def, ws), functor_(*this) {}
+  template <class... Args>
+  explicit PoolGradientOp(Args&&... args)
+      : ConvPoolOpBase<Context>(std::forward<Args>(args)...), functor_(*this) {}

   ~PoolGradientOp() = default;

@@ -50,8 +50,9 @@ void SetTensorDescriptor(
 template <class Functor>
 class CuDNNPoolOp final : public ConvPoolOpBase<CUDAContext> {
  public:
-  CuDNNPoolOp(const OperatorDef& operator_def, Workspace* ws)
-      : ConvPoolOpBase<CUDAContext>(operator_def, ws),
+  template <class... Args>
+  explicit CuDNNPoolOp(Args&&... args)
+      : ConvPoolOpBase<CUDAContext>(std::forward<Args>(args)...),
         cudnn_wrapper_(&context_),
         functor_(*this),
         equal_padding_(std::equal(
@@ -190,8 +191,9 @@ class CuDNNPoolOp final : public ConvPoolOpBase<CUDAContext> {
 template <class Functor>
 class CuDNNPoolGradientOp final : public ConvPoolOpBase<CUDAContext> {
  public:
-  CuDNNPoolGradientOp(const OperatorDef& operator_def, Workspace* ws)
-      : ConvPoolOpBase<CUDAContext>(operator_def, ws),
+  template <class... Args>
+  explicit CuDNNPoolGradientOp(Args&&... args)
+      : ConvPoolOpBase<CUDAContext>(std::forward<Args>(args)...),
         cudnn_wrapper_(&context_),
         functor_(*this),
         equal_padding_(std::equal(
@@ -20,8 +20,9 @@ class PowOp : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;

-  PowOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws),
+  template <class... Args>
+  explicit PowOp(Args&&... args)
+      : Operator<Context>(std::forward<Args>(args)...),
         OP_SINGLE_ARG(bool, "broadcast", enable_broadcast_, 0),
         OP_SINGLE_ARG(int, "axis", axis_, -1),
         OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
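Operators such as PowOp above and PReluOp, PrependDimOp, and the Int8 fill ops below read their arguments inside the member-initializer list (OP_SINGLE_ARG, GetSingleArgument). That keeps working with the forwarding constructor because a base-class subobject is always initialized before the derived class's data members, so the argument-lookup machinery is available by the time the member initializers run. A standalone sketch of that ordering, using hypothetical stub types rather than caffe2 code:

// Stub-type sketch (not caffe2 code) showing why argument reads in the member
// initializer list keep working: the base subobject is constructed before the
// members, so a base method can be called while initializing axis_.
#include <string>
#include <utility>

struct OperatorDefStub { int axis = 3; };  // hypothetical stand-in
struct WorkspaceStub {};

class OperatorBase {
 public:
  OperatorBase(const OperatorDefStub& def, WorkspaceStub* /*ws*/) : def_(def) {}
  // Hypothetical stand-in for GetSingleArgument<int>(name, default).
  int GetIntArgument(const std::string& /*name*/, int /*default_value*/) const {
    return def_.axis;
  }

 private:
  OperatorDefStub def_;
};

class ExamplePowOp final : public OperatorBase {
 public:
  template <class... Args>
  explicit ExamplePowOp(Args&&... args)
      : OperatorBase(std::forward<Args>(args)...),  // 1. base is built first
        axis_(this->GetIntArgument("axis", -1)) {}  // 2. then members may use it

  int axis() const { return axis_; }

 private:
  int axis_;
};

int main() {
  OperatorDefStub def;
  WorkspaceStub ws;
  return ExamplePowOp(def, &ws).axis() == 3 ? 0 : 1;  // exits 0 when axis == 3
}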
@@ -10,8 +10,9 @@ namespace caffe2 {
 template <typename T, class Context>
 class PReluOp final : public Operator<Context> {
  public:
-  PReluOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws),
+  template <class... Args>
+  explicit PReluOp(Args&&... args)
+      : Operator<Context>(std::forward<Args>(args)...),
         order_(StringToStorageOrder(
             this->template GetSingleArgument<string>("order", "NCHW"))) {}

@@ -26,8 +27,9 @@ class PReluOp final : public Operator<Context> {
 template <typename T, class Context>
 class PReluGradientOp final : public Operator<Context> {
  public:
-  PReluGradientOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws),
+  template <class... Args>
+  explicit PReluGradientOp(Args&&... args)
+      : Operator<Context>(std::forward<Args>(args)...),
         order_(StringToStorageOrder(
             this->template GetSingleArgument<string>("order", "NCHW"))) {}
   USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -13,8 +13,9 @@ template <class Context>
 class PrependDimOp : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
-  PrependDimOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws),
+  template <class... Args>
+  explicit PrependDimOp(Args&&... args)
+      : Operator<Context>(std::forward<Args>(args)...),
         dim_size_(this->template GetSingleArgument<int64_t>("dim_size", 0)) {
     CAFFE_ENFORCE_GT(
         dim_size_, 0, "Argument dim_size must be greater than zero.");
@@ -57,8 +58,9 @@ template <class Context>
 class MergeDimOp : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
-  MergeDimOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws) {}
+  template <class... Args>
+  explicit MergeDimOp(Args&&... args)
+      : Operator<Context>(std::forward<Args>(args)...) {}

   bool RunOnDevice() override {
     auto& input = Input(0);
@@ -102,8 +102,9 @@ template <QuantDecodeRunTy QuantDecodeRun>
 class QuantDecodeOp final : public Operator<CPUContext> {
  public:
   USE_OPERATOR_FUNCTIONS(CPUContext);
-  QuantDecodeOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws) {}
+  template <class... Args>
+  explicit QuantDecodeOp(Args&&... args)
+      : Operator<CPUContext>(std::forward<Args>(args)...) {}

   ~QuantDecodeOp() {}

@@ -138,8 +139,9 @@ class QuantDecodeOp final : public Operator<CPUContext> {
 class QuantDecodeGradientOp final : public Operator<CPUContext> {
  public:
   USE_OPERATOR_FUNCTIONS(CPUContext);
-  QuantDecodeGradientOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws) {}
+  template <class... Args>
+  explicit QuantDecodeGradientOp(Args&&... args)
+      : Operator<CPUContext>(std::forward<Args>(args)...) {}
   ~QuantDecodeGradientOp() {}

   bool RunOnDevice() override {
@@ -15,9 +15,8 @@ namespace int8 {
 template <Activation Ac>
 class Int8AddOp final : public Operator<CPUContext> {
  public:
-  Int8AddOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws),
-        ws_(ws) {}
+  explicit Int8AddOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<CPUContext>(operator_def, ws), ws_(ws) {}

   ~Int8AddOp() {
     if (this->qnnpackOperator_ != nullptr) {
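The int8 operators that store the Workspace pointer (Int8AddOp above, and Int8ChannelShuffleOp, Int8FCOp, Int8LeakyReluOp, Int8ReluOp, and Int8SigmoidOp below) get a lighter change: a variadic forwarding constructor would have no named ws parameter to assign to ws_, so the two-argument signature is kept, marked explicit, and the initializer list is collapsed onto one line. A minimal standalone sketch of that variant, again with hypothetical stub types rather than caffe2 code:

// Stub-type sketch (not caffe2 code) of the variant used by the int8 ops that
// keep the Workspace pointer: the (def, ws) signature stays, becomes explicit,
// and ws_ is initialized directly from the named parameter.
struct OperatorDefStub {};   // hypothetical stand-in for caffe2::OperatorDef
struct WorkspaceStub {};     // hypothetical stand-in for caffe2::Workspace

class OperatorBase {
 public:
  OperatorBase(const OperatorDefStub& /*def*/, WorkspaceStub* /*ws*/) {}
};

class ExampleInt8Op final : public OperatorBase {
 public:
  // The variadic forwarding form has no named `ws` to store, so the
  // two-argument signature is kept here and only gains `explicit`.
  explicit ExampleInt8Op(const OperatorDefStub& def, WorkspaceStub* ws)
      : OperatorBase(def, ws), ws_(ws) {}

  WorkspaceStub* workspace() const { return ws_; }

 private:
  WorkspaceStub* ws_;
};

int main() {
  OperatorDefStub def;
  WorkspaceStub ws;
  ExampleInt8Op op(def, &ws);
  return op.workspace() == &ws ? 0 : 1;  // exits 0 when the pointer is captured
}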
@@ -16,8 +16,9 @@ namespace int8 {
 template <Activation Ac>
 class Int8AveragePoolOp final : public ConvPoolOpBase<CPUContext> {
  public:
-  Int8AveragePoolOp(const OperatorDef& operator_def, Workspace* ws)
-      : ConvPoolOpBase<CPUContext>(operator_def, ws) {
+  template <class... Args>
+  explicit Int8AveragePoolOp(Args&&... args)
+      : ConvPoolOpBase<CPUContext>(std::forward<Args>(args)...) {
     OPERATOR_NEEDS_FEATURE(
         this->order_ == StorageOrder::NHWC, "Int8 only supports NHWC order.");
   }
@@ -15,9 +15,8 @@ namespace int8 {

 class Int8ChannelShuffleOp final : public ConvPoolOpBase<CPUContext> {
  public:
-  Int8ChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws)
-      : ConvPoolOpBase<CPUContext>(operator_def, ws),
-        ws_(ws) {
+  explicit Int8ChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws)
+      : ConvPoolOpBase<CPUContext>(operator_def, ws), ws_(ws) {
     OPERATOR_NEEDS_FEATURE(
         this->order_ == StorageOrder::NHWC,
         "Int8ChannelShuffleOp only supports NHWC order");
@@ -12,8 +12,9 @@ namespace int8 {

 class Int8ConcatOp final : public Operator<CPUContext> {
  public:
-  Int8ConcatOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws) {
+  template <class... Args>
+  explicit Int8ConcatOp(Args&&... args)
+      : Operator<CPUContext>(std::forward<Args>(args)...) {
     // concat supports more than NHWC format
     if (this->template GetSingleArgument<string>("order", "") == "NHWC") {
       // Default to C axis
@@ -18,8 +18,9 @@ template <Activation Ac>
 class Int8ConvOp final : public ConvPoolOpBase<CPUContext> {
  public:
   USE_CONV_POOL_BASE_FUNCTIONS(CPUContext);
-  Int8ConvOp(const OperatorDef& def, Workspace* ws)
-      : ConvPoolOpBase(def, ws) {
+  template <class... Args>
+  explicit Int8ConvOp(Args&&... args)
+      : ConvPoolOpBase(std::forward<Args>(args)...) {
     OPERATOR_NEEDS_FEATURE(
         this->order_ == StorageOrder::NHWC,
         "Int8Conv only supports NHWC order");
@@ -17,8 +17,9 @@ namespace int8 {
 class Int8ConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
  public:
   USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(CPUContext);
-  Int8ConvTransposeOp(const OperatorDef& def, Workspace* ws)
-      : ConvTransposeUnpoolBase(def, ws) {
+  template <class... Args>
+  explicit Int8ConvTransposeOp(Args&&... args)
+      : ConvTransposeUnpoolBase(std::forward<Args>(args)...) {
     OPERATOR_NEEDS_FEATURE(
         this->order_ == StorageOrder::NHWC,
         "Int8ConvTransposeOp only supports NHWC order");
@@ -15,9 +15,8 @@ namespace int8 {

 class Int8FCOp final : public Operator<CPUContext> {
  public:
-  Int8FCOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws),
-        ws_(ws) {
+  explicit Int8FCOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<CPUContext>(operator_def, ws), ws_(ws) {
     createSharedBuffer<CPUContext>(ws_);
   }

@@ -12,8 +12,9 @@ namespace int8 {

 class Int8FlattenOp : public Operator<CPUContext> {
  public:
-  Int8FlattenOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws),
+  template <class... Args>
+  explicit Int8FlattenOp(Args&&... args)
+      : Operator<CPUContext>(std::forward<Args>(args)...),
         axis_(this->template GetSingleArgument<int>("axis", 1)) {}

   bool RunOnDevice() override {
@@ -14,8 +14,9 @@ namespace int8 {

 class Int8GivenTensorFillOp final : public Operator<CPUContext> {
  public:
-  Int8GivenTensorFillOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws),
+  template <class... Args>
+  explicit Int8GivenTensorFillOp(Args&&... args)
+      : Operator<CPUContext>(std::forward<Args>(args)...),
         scale_(this->template GetSingleArgument<float>("Y_scale", 1.0)),
         zero_point_(
             this->template GetSingleArgument<int32_t>("Y_zero_point", 0)),
@@ -63,8 +64,9 @@ class Int8GivenTensorFillOp final : public Operator<CPUContext> {

 class Int8GivenIntTensorFillOp final : public Operator<CPUContext> {
  public:
-  Int8GivenIntTensorFillOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws),
+  template <class... Args>
+  explicit Int8GivenIntTensorFillOp(Args&&... args)
+      : Operator<CPUContext>(std::forward<Args>(args)...),
         scale_(this->template GetSingleArgument<float>("Y_scale", 1.0)),
         zero_point_(
             this->template GetSingleArgument<int32_t>("Y_zero_point", 0)),
@@ -14,9 +14,8 @@ namespace int8 {

 class Int8LeakyReluOp final : public Operator<CPUContext> {
  public:
-  Int8LeakyReluOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws),
-        ws_(ws) {
+  explicit Int8LeakyReluOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<CPUContext>(operator_def, ws), ws_(ws) {
     const float alpha = this->template GetSingleArgument<float>("alpha", 0.01);
     CAFFE_ENFORCE_GT(alpha, 0.0);
     CAFFE_ENFORCE_LT(alpha, 1.0);
@@ -16,8 +16,9 @@ namespace int8 {
 template <Activation Ac>
 class Int8MaxPoolOp final : public ConvPoolOpBase<CPUContext> {
  public:
-  Int8MaxPoolOp(const OperatorDef& operator_def, Workspace* ws)
-      : ConvPoolOpBase<CPUContext>(operator_def, ws) {
+  template <class... Args>
+  explicit Int8MaxPoolOp(Args&&... args)
+      : ConvPoolOpBase<CPUContext>(std::forward<Args>(args)...) {
     OPERATOR_NEEDS_FEATURE(
         this->order_ == StorageOrder::NHWC, "Int8 only supports NHWC order.");
   }
@@ -14,9 +14,8 @@ namespace int8 {

 class Int8ReluOp final : public Operator<CPUContext> {
  public:
-  Int8ReluOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws),
-        ws_(ws) {}
+  explicit Int8ReluOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<CPUContext>(operator_def, ws), ws_(ws) {}

   ~Int8ReluOp() {
     if (this->qnnpackOperator_ != nullptr) {
@@ -13,8 +13,9 @@ namespace int8 {

 class Int8ReshapeOp final : public ReshapeOp<uint8_t, CPUContext> {
  public:
-  Int8ReshapeOp(const OperatorDef& operator_def, Workspace* ws)
-      : ReshapeOp(operator_def, ws) {}
+  template <class... Args>
+  explicit Int8ReshapeOp(Args&&... args)
+      : ReshapeOp(std::forward<Args>(args)...) {}

   bool RunOnDevice() override {
     if (InputSize() == 2) {
@@ -12,8 +12,9 @@ namespace int8 {

 class Int8ResizeNearestOp final : public Operator<CPUContext> {
  public:
-  Int8ResizeNearestOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws) {
+  template <class... Args>
+  explicit Int8ResizeNearestOp(Args&&... args)
+      : Operator<CPUContext>(std::forward<Args>(args)...) {
     width_scale_ = this->template GetSingleArgument<float>("width_scale", 1);
     height_scale_ = this->template GetSingleArgument<float>("height_scale", 1);
     CAFFE_ENFORCE_GT(width_scale_, 0);
@@ -258,8 +258,9 @@ void ROIAlignForward(

 class Int8RoIAlignOp final : public Operator<CPUContext> {
  public:
-  Int8RoIAlignOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws),
+  template <class... Args>
+  explicit Int8RoIAlignOp(Args&&... args)
+      : Operator<CPUContext>(std::forward<Args>(args)...),
         order_(StringToStorageOrder(
             this->template GetSingleArgument<string>("order", "NHWC"))),
         spatial_scale_(
@@ -13,9 +13,8 @@ namespace int8 {

 class Int8SigmoidOp final : public Operator<CPUContext> {
  public:
-  Int8SigmoidOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<CPUContext>(operator_def, ws),
-        ws_(ws) {}
+  explicit Int8SigmoidOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<CPUContext>(operator_def, ws), ws_(ws) {}

   ~Int8SigmoidOp() {
     if (this->qnnpackOperator_ != nullptr) {
@@ -13,8 +13,8 @@ namespace int8 {

 class Int8SliceOp final : public SliceOp<CPUContext> {
  public:
-  Int8SliceOp(const OperatorDef& operator_def, Workspace* ws)
-      : SliceOp(operator_def, ws) {}
+  template <class... Args>
+  explicit Int8SliceOp(Args&&... args) : SliceOp(std::forward<Args>(args)...) {}

   bool RunOnDevice() override {
     if (InputSize() > 1) {