mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
refactor caffe2 operator constructors - 9/9 (#17090)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17090 clangr codemod Reviewed By: ezyang Differential Revision: D14078550 fbshipit-source-id: 68e6de4298e55ce83039b7806c1a275c4d6593c8
This commit is contained in:
committed by
Facebook Github Bot
parent
9bcceb75b5
commit
a9395ce259
@ -20,8 +20,9 @@ template <class Context>
|
||||
class NanCheckOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
NanCheckOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit NanCheckOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override;
|
||||
|
||||
@ -46,8 +47,9 @@ class WallClockTimeOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
WallClockTimeOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit WallClockTimeOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
int64_t nanoseconds = static_cast<long int>(
|
||||
@ -70,7 +72,7 @@ class PrintOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
USE_DISPATCH_HELPER;
|
||||
PrintOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
explicit PrintOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
tensor_printer_(
|
||||
operator_def.input(0),
|
||||
@ -395,8 +397,9 @@ class WeightedSumGradientOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
WeightedSumGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit WeightedSumGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
grad_on_w_(this->template GetSingleArgument<bool>("grad_on_w", false)) {
|
||||
}
|
||||
|
||||
@ -597,8 +600,9 @@ class ScatterAssignOp : public Operator<Context> {
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
virtual ~ScatterAssignOp() {}
|
||||
|
||||
ScatterAssignOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit ScatterAssignOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
runners_({{{TensorProto_DataType_INT32, TensorProto_DataType_FLOAT},
|
||||
&ScatterAssignOp::DoRun<int32_t, float>},
|
||||
{{TensorProto_DataType_INT32, TensorProto_DataType_FLOAT16},
|
||||
@ -871,8 +875,9 @@ template <class Context>
|
||||
class LengthsToWeightsOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
LengthsToWeightsOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit LengthsToWeightsOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
power_(this->template GetSingleArgument<float>("power", 0.5)) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
@ -1149,8 +1154,9 @@ class LengthsGatherOp : public Operator<Context> {
|
||||
template <typename T, class Context>
|
||||
class AccumulateHistogramOp : public Operator<Context> {
|
||||
public:
|
||||
AccumulateHistogramOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<Context>(def, ws),
|
||||
template <class... Args>
|
||||
explicit AccumulateHistogramOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
lower_bound_(
|
||||
this->template GetSingleArgument<float>("lower_bound", 0.0)),
|
||||
upper_bound_(
|
||||
@ -1288,8 +1294,9 @@ class RangeOp : public Operator<Context> {
|
||||
|
||||
class ThrowExceptionOp : public Operator<CPUContext> {
|
||||
public:
|
||||
ThrowExceptionOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CPUContext>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit ThrowExceptionOp(Args&&... args)
|
||||
: Operator<CPUContext>(std::forward<Args>(args)...),
|
||||
message_(GetSingleArgument<std::string>(
|
||||
"message",
|
||||
"Exception from ThrowExceptionOp")) {}
|
||||
@ -1304,8 +1311,9 @@ class ThrowExceptionOp : public Operator<CPUContext> {
|
||||
|
||||
class ThrowChildThreadExceptionOp : public Operator<CPUContext> {
|
||||
public:
|
||||
ThrowChildThreadExceptionOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CPUContext>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit ThrowChildThreadExceptionOp(Args&&... args)
|
||||
: Operator<CPUContext>(std::forward<Args>(args)...),
|
||||
message_(GetSingleArgument<std::string>(
|
||||
"message",
|
||||
"Exception from ThrowChildThreadExceptionOp")) {}
|
||||
@ -1323,8 +1331,9 @@ class ThrowChildThreadExceptionOp : public Operator<CPUContext> {
|
||||
|
||||
class LogFatalOp : public Operator<CPUContext> {
|
||||
public:
|
||||
LogFatalOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CPUContext>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit LogFatalOp(Args&&... args)
|
||||
: Operator<CPUContext>(std::forward<Args>(args)...),
|
||||
message_(GetSingleArgument<std::string>(
|
||||
"message",
|
||||
"Logging from LogFatalOp")) {}
|
||||
@ -1340,8 +1349,9 @@ class LogFatalOp : public Operator<CPUContext> {
|
||||
|
||||
class FailOp : public Operator<CPUContext> {
|
||||
public:
|
||||
FailOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CPUContext>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit FailOp(Args&&... args)
|
||||
: Operator<CPUContext>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
return false;
|
||||
|
@ -12,8 +12,10 @@ class CuDNNWeightedSumOp : public Operator<CUDAContext> {
|
||||
public:
|
||||
USE_OPERATOR_FUNCTIONS(CUDAContext);
|
||||
|
||||
CuDNNWeightedSumOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) {
|
||||
template <class... Args>
|
||||
explicit CuDNNWeightedSumOp(Args&&... args)
|
||||
: Operator<CUDAContext>(std::forward<Args>(args)...),
|
||||
cudnn_wrapper_(&context_) {
|
||||
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
|
||||
CUDNN_ENFORCE(cudnnCreateOpTensorDescriptor(&add_desc_));
|
||||
// Both float and at::Half require opTensorCompType to be CUDNN_DATA_FLOAT.
|
||||
|
@ -29,10 +29,9 @@ void VariableLengthSequencePadding(
|
||||
template <typename T, typename Context>
|
||||
class VariableLengthSequencePaddingOp : public Operator<Context> {
|
||||
public:
|
||||
VariableLengthSequencePaddingOp(
|
||||
const OperatorDef& operator_def,
|
||||
Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit VariableLengthSequencePaddingOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
@ -9,8 +9,9 @@ class WeightedMultiSamplingOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
WeightedMultiSamplingOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit WeightedMultiSamplingOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
num_samples_(
|
||||
this->template GetSingleArgument<int64_t>("num_samples", 0)) {
|
||||
CAFFE_ENFORCE_GE(num_samples_, 0);
|
||||
|
@ -13,8 +13,9 @@ namespace caffe2 {
|
||||
template <typename T, class Context>
|
||||
class WeightedSampleOp final : public Operator<Context> {
|
||||
public:
|
||||
WeightedSampleOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit WeightedSampleOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
|
@ -10,7 +10,7 @@ namespace caffe2 {
|
||||
template <class Context>
|
||||
class WhileOp final : public Operator<Context> {
|
||||
public:
|
||||
WhileOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
explicit WhileOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {
|
||||
CAFFE_ENFORCE(
|
||||
this->template HasSingleArgumentOfType<NetDef>("loop_net"),
|
||||
|
@ -6,7 +6,7 @@ namespace {
|
||||
|
||||
class GetAllBlobNamesOp final : public Operator<CPUContext> {
|
||||
public:
|
||||
GetAllBlobNamesOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
explicit GetAllBlobNamesOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CPUContext>(operator_def, ws),
|
||||
include_shared_(GetSingleArgument<int>("include_shared", true)),
|
||||
ws_(ws) {}
|
||||
|
Reference in New Issue
Block a user