mirror of https://github.com/pytorch/pytorch.git
Add Cost Inference for AdaGrad and RowWiseSparseAdagrad
Summary: Add cost inference for AdaGrad and RowWiseSparseAdagrad.

Test Plan: Ran `buck test caffe2/caffe2/python/operator_test:adagrad_test`
Result: https://our.intern.facebook.com/intern/testinfra/testrun/5629499567799494

Reviewed By: bwasti

Differential Revision: D23442607

fbshipit-source-id: 67800fb82475696512ad19a43067774247f8b230
commit 041573c8cd
parent 2f044d4ee5
committed by Facebook GitHub Bot
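Before the diff itself, a note on how a registered cost inference function is typically consumed: the operator's OpSchema can be looked up by name and asked to estimate flops and bytes from an OperatorDef plus input TensorShapes. The snippet below is a minimal sketch, not part of the commit; it assumes the OpSchemaRegistry::Schema and OpSchema::InferCost entry points from caffe2/core/operator_schema.h, and the 1000x1000 float shapes are purely illustrative.

#include <iostream>
#include <vector>

#include "caffe2/core/operator_schema.h"
#include "caffe2/proto/caffe2_pb.h"

int main() {
  using namespace caffe2;

  // Operator definition matching the Adagrad schema: 4 inputs, 2 outputs.
  OperatorDef def;
  def.set_type("Adagrad");
  for (const auto* name : {"param", "moment", "grad", "lr"}) {
    def.add_input(name);
  }
  def.add_output("param");
  def.add_output("moment");

  // Illustrative shapes: a 1M-element float parameter and matching buffers.
  auto make_shape = [](std::vector<int64_t> dims) {
    TensorShape ts;
    for (auto d : dims) {
      ts.add_dims(d);
    }
    ts.set_data_type(TensorProto::FLOAT);
    return ts;
  };
  std::vector<TensorShape> shapes{
      make_shape({1000, 1000}), // param
      make_shape({1000, 1000}), // moment
      make_shape({1000, 1000}), // grad
      make_shape({1})};         // lr

  // Requires the Adagrad operator (and its schema) to be linked into
  // this binary; otherwise the lookup returns nullptr.
  const OpSchema* schema = OpSchemaRegistry::Schema("Adagrad");
  if (schema == nullptr) {
    std::cerr << "Adagrad schema not registered" << std::endl;
    return 1;
  }
  OpSchema::Cost cost = schema->InferCost(def, shapes);
  std::cout << "flops=" << cost.flops << " bytes_read=" << cost.bytes_read
            << " bytes_written=" << cost.bytes_written << std::endl;
  return 0;
}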
@@ -2,6 +2,46 @@
 
 namespace caffe2 {
 
+static OpSchema::Cost CostInferenceForAdagrad(
+    const OperatorDef& def,
+    const vector<TensorShape>& inputs) {
+  CAFFE_ENFORCE_GE(inputs.size(), 4, "Adagrad requires at least 4 inputs");
+
+  const TensorShape param = inputs[0];
+  const TensorShape moment = inputs[1];
+  const TensorShape grad = inputs[2];
+  const TensorShape lr = inputs[3];
+
+  uint64_t grad_size = nElemFromDim(grad);
+  int output_size = def.output_size();
+
+  OpSchema::Cost c;
+  // +2: applying weight decay and add to grads
+  // +3: updating moments
+  // +3: updating effective lr (including 1 sqrt)
+  // +2: updating params
+  c.flops = grad_size * 10;
+
+  uint64_t bytes_written =
+      grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type()));
+
+  if (output_size == 3) {
+    // also need to output effective learning rate in this case
+    // assume it's the same data type as lr
+    bytes_written += grad_size * sizeof(lr.data_type());
+  } else if (output_size == 4) {
+    // also need to output effective learning rate and updates in this case
+    // assume update is the same data type as param
+    bytes_written +=
+        grad_size * (sizeof(lr.data_type()) + sizeof(param.data_type()));
+  }
+  c.bytes_written = bytes_written;
+  c.bytes_read = c.bytes_written +
+      grad_size * (sizeof(grad.data_type()) + sizeof(lr.data_type()));
+
+  return c;
+}
+
 REGISTER_CPU_OPERATOR(Adagrad, AdagradOp<CPUContext>);
 // For backward compatibility
 REGISTER_CPU_OPERATOR_WITH_ENGINE(Adagrad, SIMD, AdagradOp<CPUContext>);
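For a concrete feel of the per-element accounting in CostInferenceForAdagrad above, here is a standalone arithmetic check. The 1M-element size and the 4-byte (float32) element width are illustrative assumptions; only the 2+3+3+2 flop breakdown and the read/write formulas come from the diff.

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical example: 1M float32 parameters, two outputs (param, moment).
  const uint64_t grad_size = 1000000;
  const uint64_t elem_bytes = 4; // assuming float32 elements throughout

  // 2 (weight decay) + 3 (moment update) + 3 (effective lr, incl. sqrt)
  // + 2 (param update) = 10 flops per gradient element.
  const uint64_t flops = grad_size * (2 + 3 + 3 + 2);

  // Writes: updated param and updated moment, one element each per grad element.
  const uint64_t bytes_written = grad_size * (elem_bytes + elem_bytes);
  // Reads: everything written is also read, plus grad and lr.
  const uint64_t bytes_read =
      bytes_written + grad_size * (elem_bytes + elem_bytes);

  std::cout << "flops=" << flops                   // 10000000
            << " bytes_written=" << bytes_written  // 8000000
            << " bytes_read=" << bytes_read        // 16000000
            << std::endl;
  return 0;
}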
@@ -37,7 +77,9 @@ Optionally returns effective_lr and update as well.
     .Arg(
         "decay",
         "Default 1. If it is in (0, 1), the gradient square sum "
-        "is decayed by this factor.");
+        "is decayed by this factor.")
+    .CostInferenceFunction(
+        OpSchema::CostInferenceFunctionType(CostInferenceForAdagrad));
 
 static OpSchema::Cost CostInferenceForSparseAdagrad(
     const OperatorDef& /* unused */,
@@ -91,6 +133,49 @@ new_moment) as in the dense case.
     .CostInferenceFunction(
         OpSchema::CostInferenceFunctionType(CostInferenceForSparseAdagrad));
 
+static OpSchema::Cost CostInferenceForRowWiseSparseAdagrad(
+    const OperatorDef& /* unused */,
+    const vector<TensorShape>& inputs) {
+  CAFFE_ENFORCE_GE(
+      inputs.size(), 5, "RowWiseSparseAdagrad requires at least 5 inputs");
+
+  const TensorShape param = inputs[0];
+  const TensorShape moment = inputs[1];
+  const TensorShape indices = inputs[2];
+  const TensorShape grad = inputs[3];
+  const TensorShape lr = inputs[4];
+
+  uint64_t n = nElemFromDim(indices);
+  uint64_t grad_size = nElemFromDim(grad);
+  auto block_size = grad_size / n;
+
+  OpSchema::Cost c;
+  if (block_size == 1) {
+    // +2: applying weight decay and add to grads
+    // +2: updating moments
+    // +5: updating params
+    c.flops = n * 9;
+    c.bytes_written =
+        n * (sizeof(param.data_type()) + sizeof(moment.data_type()));
+    c.bytes_read = c.bytes_written +
+        n *
+            (sizeof(grad.data_type()) + sizeof(indices.data_type()) +
+             sizeof(lr.data_type()));
+  } else {
+    // 5 per block (not counting index transforms)
+    // 8 for each value of a block
+    c.flops = n * (5 + (block_size * 8));
+    c.bytes_written = n * sizeof(moment.data_type()) +
+        n * block_size * sizeof(param.data_type());
+
+    c.bytes_read = c.bytes_written + n * (sizeof(lr.data_type())) +
+        2 * n * block_size *
+            (sizeof(grad.data_type()) + sizeof(param.data_type()));
+  }
+
+  return c;
+}
+
 REGISTER_CPU_OPERATOR(RowWiseSparseAdagrad, RowWiseSparseAdagradOp<CPUContext>);
 // For backward compatibility
 REGISTER_CPU_OPERATOR_WITH_ENGINE(
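A similar standalone check for the block_size > 1 branch of CostInferenceForRowWiseSparseAdagrad, using a hypothetical embedding-style update of n = 1000 indexed rows with 64 float32 values each. All concrete numbers are assumptions; the formulas mirror the diff above.

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical example: 1000 indexed rows, 64 float32 values per row.
  const uint64_t n = 1000;        // number of indices (rows touched)
  const uint64_t block_size = 64; // values per row
  const uint64_t elem_bytes = 4;  // assuming float32 parameters/gradients

  // block_size > 1 branch: 5 flops per row plus 8 per value in the block.
  const uint64_t flops = n * (5 + block_size * 8);

  // Writes: one moment scalar per row plus the updated parameter block.
  const uint64_t bytes_written =
      n * elem_bytes + n * block_size * elem_bytes;

  // Reads: everything written, lr counted once per row, and the grad and
  // param blocks counted twice (the factor of 2 mirrors the formula above).
  const uint64_t bytes_read = bytes_written + n * elem_bytes +
      2 * n * block_size * (elem_bytes + elem_bytes);

  std::cout << "flops=" << flops                   // 517000
            << " bytes_written=" << bytes_written  // 260000
            << " bytes_read=" << bytes_read        // 1288000
            << std::endl;
  return 0;
}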
@@ -119,7 +204,9 @@ also be a 1D tensor indexing into the rows of param.
     .Input(4, "lr", "learning rate")
     .Output(0, "output_param", "Updated parameters")
     .Output(1, "output_moment_1", "Updated moment")
-    .Arg("epsilon", "Default 1e-5");
+    .Arg("epsilon", "Default 1e-5")
+    .CostInferenceFunction(
+        OpSchema::CostInferenceFunctionType(CostInferenceForRowWiseSparseAdagrad));
 
 SHOULD_NOT_DO_GRADIENT(Adagrad);
 SHOULD_NOT_DO_GRADIENT(SparseAdagrad);