Add Cost Inference for AdaGrad and RowWiseSparseAdagrad

Summary: Add cost inference for AdaGrad and RowWiseSparseAdagrad

Test Plan:
Ran `buck test caffe2/caffe2/python/operator_test:adagrad_test`
Result: https://our.intern.facebook.com/intern/testinfra/testrun/5629499567799494

Reviewed By: bwasti

Differential Revision: D23442607

fbshipit-source-id: 67800fb82475696512ad19a43067774247f8b230
Author: Jiyuan Qian
Date: 2020-09-02 17:50:50 -07:00
Committed by: Facebook GitHub Bot
Parent: 2f044d4ee5
Commit: 041573c8cd


@@ -2,6 +2,46 @@
namespace caffe2 {
static OpSchema::Cost CostInferenceForAdagrad(
    const OperatorDef& def,
    const vector<TensorShape>& inputs) {
  CAFFE_ENFORCE_GE(inputs.size(), 4, "Adagrad requires at least 4 inputs");
  const TensorShape param = inputs[0];
  const TensorShape moment = inputs[1];
  const TensorShape grad = inputs[2];
  const TensorShape lr = inputs[3];
  uint64_t grad_size = nElemFromDim(grad);
  int output_size = def.output_size();
  OpSchema::Cost c;
  // +2: applying weight decay and adding to grads
  // +3: updating moments
  // +3: updating effective lr (including 1 sqrt)
  // +2: updating params
  c.flops = grad_size * 10;
  uint64_t bytes_written =
      grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type()));
  if (output_size == 3) {
    // also need to output effective learning rate in this case
    // assume it's the same data type as lr
    bytes_written += grad_size * sizeof(lr.data_type());
  } else if (output_size == 4) {
    // also need to output effective learning rate and updates in this case
    // assume update is the same data type as param
    bytes_written +=
        grad_size * (sizeof(lr.data_type()) + sizeof(param.data_type()));
  }
  c.bytes_written = bytes_written;
  c.bytes_read = c.bytes_written +
      grad_size * (sizeof(grad.data_type()) + sizeof(lr.data_type()));
  return c;
}
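
// Illustrative check of the accounting above (not part of this change):
// assuming each sizeof(...data_type()) term evaluates to 4 bytes, a gradient
// of 1,000,000 elements with the default two outputs (param, moment) gives
//   flops         = 1,000,000 * 10                   = 10,000,000
//   bytes_written = 1,000,000 * (4 + 4)              =  8,000,000
//   bytes_read    = 8,000,000 + 1,000,000 * (4 + 4)  = 16,000,000
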
REGISTER_CPU_OPERATOR(Adagrad, AdagradOp<CPUContext>);
// For backward compatibility
REGISTER_CPU_OPERATOR_WITH_ENGINE(Adagrad, SIMD, AdagradOp<CPUContext>);
@@ -37,7 +77,9 @@ Optionally returns effective_lr and update as well.
    .Arg(
        "decay",
        "Default 1. If it is in (0, 1), the gradient square sum "
        "is decayed by this factor.")
    .CostInferenceFunction(
        OpSchema::CostInferenceFunctionType(CostInferenceForAdagrad));
static OpSchema::Cost CostInferenceForSparseAdagrad(
    const OperatorDef& /* unused */,
@@ -91,6 +133,49 @@ new_moment) as in the dense case.
    .CostInferenceFunction(
        OpSchema::CostInferenceFunctionType(CostInferenceForSparseAdagrad));
static OpSchema::Cost CostInferenceForRowWiseSparseAdagrad(
    const OperatorDef& /* unused */,
    const vector<TensorShape>& inputs) {
  CAFFE_ENFORCE_GE(
      inputs.size(), 5, "RowWiseSparseAdagrad requires at least 5 inputs");
  const TensorShape param = inputs[0];
  const TensorShape moment = inputs[1];
  const TensorShape indices = inputs[2];
  const TensorShape grad = inputs[3];
  const TensorShape lr = inputs[4];
  uint64_t n = nElemFromDim(indices);
  uint64_t grad_size = nElemFromDim(grad);
  auto block_size = grad_size / n;
  OpSchema::Cost c;
  if (block_size == 1) {
    // +2: applying weight decay and adding to grads
    // +2: updating moments
    // +5: updating params
    c.flops = n * 9;
    c.bytes_written =
        n * (sizeof(param.data_type()) + sizeof(moment.data_type()));
    c.bytes_read = c.bytes_written +
        n *
            (sizeof(grad.data_type()) + sizeof(indices.data_type()) +
             sizeof(lr.data_type()));
  } else {
    // 5 per block (not counting index transforms)
    // 8 for each value of a block
    c.flops = n * (5 + (block_size * 8));
    c.bytes_written =
        n * sizeof(moment.data_type()) +
        n * block_size * sizeof(param.data_type());
    c.bytes_read = c.bytes_written + n * (sizeof(lr.data_type())) +
        2 * n * block_size *
            (sizeof(grad.data_type()) + sizeof(param.data_type()));
  }
  return c;
}
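
// Illustrative check of the block_size > 1 branch (not part of this change):
// assuming 4-byte sizeof(...data_type()) terms, n = 100 indexed rows with
// block_size = 64 gives
//   flops         = 100 * (5 + 64 * 8)                   =  51,700
//   bytes_written = 100 * 4 + 100 * 64 * 4               =  26,000
//   bytes_read    = 26,000 + 100 * 4 + 2 * 100 * 64 * 8  = 128,800
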
REGISTER_CPU_OPERATOR(RowWiseSparseAdagrad, RowWiseSparseAdagradOp<CPUContext>);
// For backward compatibility
REGISTER_CPU_OPERATOR_WITH_ENGINE(
@@ -119,7 +204,9 @@ also be a 1D tensor indexing into the rows of param.
.Input(4, "lr", "learning rate")
.Output(0, "output_param", "Updated parameters")
.Output(1, "output_moment_1", "Updated moment")
.Arg("epsilon", "Default 1e-5");
.Arg("epsilon", "Default 1e-5")
.CostInferenceFunction(
OpSchema::CostInferenceFunctionType(CostInferenceForRowWiseSparseAdagrad));
SHOULD_NOT_DO_GRADIENT(Adagrad);
SHOULD_NOT_DO_GRADIENT(SparseAdagrad);
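
For reference, a minimal sketch of how these estimates can be queried once the
functions are registered, assuming the OpSchemaRegistry::Schema and
OpSchema::InferCost API from caffe2/core/operator_schema.h; the
EstimateAdagradCost helper and the shapes used below are hypothetical and not
part of this change:

#include <vector>
#include "caffe2/core/operator_schema.h" // OpSchema, OpSchemaRegistry
#include "caffe2/proto/caffe2_pb.h" // OperatorDef, TensorShape

namespace caffe2 {

OpSchema::Cost EstimateAdagradCost() {
  // Build an OperatorDef matching the Adagrad schema (4 inputs, 2 outputs).
  OperatorDef def;
  def.set_type("Adagrad");
  for (const char* name : {"param", "moment", "grad", "lr"}) {
    def.add_input(name);
  }
  def.add_output("param");
  def.add_output("moment");

  // param, moment, grad: 1,000,000 floats; lr: a single float.
  std::vector<TensorShape> shapes(4);
  for (int i = 0; i < 3; ++i) {
    shapes[i].add_dims(1000000);
    shapes[i].set_data_type(TensorProto_DataType_FLOAT);
  }
  shapes[3].add_dims(1);
  shapes[3].set_data_type(TensorProto_DataType_FLOAT);

  // Look up the schema and dispatch to CostInferenceForAdagrad above.
  const OpSchema* schema = OpSchemaRegistry::Schema("Adagrad");
  CAFFE_ENFORCE(schema != nullptr && schema->HasCostInferenceFunction());
  return schema->InferCost(def, shapes);
}

} // namespace caffe2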