refactor caffe2 operator constructors - 9/9 (#17090)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17090

clangr codemod

Reviewed By: ezyang

Differential Revision: D14078550

fbshipit-source-id: 68e6de4298e55ce83039b7806c1a275c4d6593c8
This commit is contained in:
Sebastian Messmer
2019-02-28 09:50:19 -08:00
committed by Facebook Github Bot
parent 9bcceb75b5
commit a9395ce259
7 changed files with 46 additions and 33 deletions

View File

@ -20,8 +20,9 @@ template <class Context>
class NanCheckOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
NanCheckOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit NanCheckOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override;
@ -46,8 +47,9 @@ class WallClockTimeOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
WallClockTimeOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit WallClockTimeOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
int64_t nanoseconds = static_cast<long int>(
@ -70,7 +72,7 @@ class PrintOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_DISPATCH_HELPER;
PrintOp(const OperatorDef& operator_def, Workspace* ws)
explicit PrintOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
tensor_printer_(
operator_def.input(0),
@ -395,8 +397,9 @@ class WeightedSumGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
WeightedSumGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit WeightedSumGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
grad_on_w_(this->template GetSingleArgument<bool>("grad_on_w", false)) {
}
@ -597,8 +600,9 @@ class ScatterAssignOp : public Operator<Context> {
USE_OPERATOR_CONTEXT_FUNCTIONS;
virtual ~ScatterAssignOp() {}
ScatterAssignOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit ScatterAssignOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
runners_({{{TensorProto_DataType_INT32, TensorProto_DataType_FLOAT},
&ScatterAssignOp::DoRun<int32_t, float>},
{{TensorProto_DataType_INT32, TensorProto_DataType_FLOAT16},
@ -871,8 +875,9 @@ template <class Context>
class LengthsToWeightsOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
LengthsToWeightsOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit LengthsToWeightsOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
power_(this->template GetSingleArgument<float>("power", 0.5)) {}
bool RunOnDevice() override {
@ -1149,8 +1154,9 @@ class LengthsGatherOp : public Operator<Context> {
template <typename T, class Context>
class AccumulateHistogramOp : public Operator<Context> {
public:
AccumulateHistogramOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
template <class... Args>
explicit AccumulateHistogramOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
lower_bound_(
this->template GetSingleArgument<float>("lower_bound", 0.0)),
upper_bound_(
@ -1288,8 +1294,9 @@ class RangeOp : public Operator<Context> {
class ThrowExceptionOp : public Operator<CPUContext> {
public:
ThrowExceptionOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws),
template <class... Args>
explicit ThrowExceptionOp(Args&&... args)
: Operator<CPUContext>(std::forward<Args>(args)...),
message_(GetSingleArgument<std::string>(
"message",
"Exception from ThrowExceptionOp")) {}
@ -1304,8 +1311,9 @@ class ThrowExceptionOp : public Operator<CPUContext> {
class ThrowChildThreadExceptionOp : public Operator<CPUContext> {
public:
ThrowChildThreadExceptionOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws),
template <class... Args>
explicit ThrowChildThreadExceptionOp(Args&&... args)
: Operator<CPUContext>(std::forward<Args>(args)...),
message_(GetSingleArgument<std::string>(
"message",
"Exception from ThrowChildThreadExceptionOp")) {}
@ -1323,8 +1331,9 @@ class ThrowChildThreadExceptionOp : public Operator<CPUContext> {
class LogFatalOp : public Operator<CPUContext> {
public:
LogFatalOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws),
template <class... Args>
explicit LogFatalOp(Args&&... args)
: Operator<CPUContext>(std::forward<Args>(args)...),
message_(GetSingleArgument<std::string>(
"message",
"Logging from LogFatalOp")) {}
@ -1340,8 +1349,9 @@ class LogFatalOp : public Operator<CPUContext> {
class FailOp : public Operator<CPUContext> {
public:
FailOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws) {}
template <class... Args>
explicit FailOp(Args&&... args)
: Operator<CPUContext>(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
return false;

View File

@ -12,8 +12,10 @@ class CuDNNWeightedSumOp : public Operator<CUDAContext> {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
CuDNNWeightedSumOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) {
template <class... Args>
explicit CuDNNWeightedSumOp(Args&&... args)
: Operator<CUDAContext>(std::forward<Args>(args)...),
cudnn_wrapper_(&context_) {
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(cudnnCreateOpTensorDescriptor(&add_desc_));
// Both float and at::Half require opTensorCompType to be CUDNN_DATA_FLOAT.

View File

@ -29,10 +29,9 @@ void VariableLengthSequencePadding(
template <typename T, typename Context>
class VariableLengthSequencePaddingOp : public Operator<Context> {
public:
VariableLengthSequencePaddingOp(
const OperatorDef& operator_def,
Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit VariableLengthSequencePaddingOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override {

View File

@ -9,8 +9,9 @@ class WeightedMultiSamplingOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
WeightedMultiSamplingOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit WeightedMultiSamplingOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
num_samples_(
this->template GetSingleArgument<int64_t>("num_samples", 0)) {
CAFFE_ENFORCE_GE(num_samples_, 0);

View File

@ -13,8 +13,9 @@ namespace caffe2 {
template <typename T, class Context>
class WeightedSampleOp final : public Operator<Context> {
public:
WeightedSampleOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit WeightedSampleOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;

View File

@ -10,7 +10,7 @@ namespace caffe2 {
template <class Context>
class WhileOp final : public Operator<Context> {
public:
WhileOp(const OperatorDef& operator_def, Workspace* ws)
explicit WhileOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {
CAFFE_ENFORCE(
this->template HasSingleArgumentOfType<NetDef>("loop_net"),

View File

@ -6,7 +6,7 @@ namespace {
class GetAllBlobNamesOp final : public Operator<CPUContext> {
public:
GetAllBlobNamesOp(const OperatorDef& operator_def, Workspace* ws)
explicit GetAllBlobNamesOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws),
include_shared_(GetSingleArgument<int>("include_shared", true)),
ws_(ws) {}