[itemwise-dropout][1/x][low-level module] Implement Itemwise Sparse Feature Dropout in Dper3 (#59322)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59322

Implement sparse feature dropout (with replacement) that can drop out individual items in each sparse feature. For example, the existing sparse feature dropout with replacement drops out whole feature (e.g., a list of page ids) when the feature is selected for drop out. This itemwise dropout assigns probability and drops out to individual items in sparse features.

Test Plan:
```
buck test mode/dev caffe2/torch/fb/sparsenn:test
```

https://www.internalfb.com/intern/testinfra/testrun/281475166777899/

```
buck test mode/dev //dper3/dper3/modules/tests:sparse_itemwise_dropout_with_replacement_test
```
https://www.internalfb.com/intern/testinfra/testrun/6473924504443423

```
buck test mode/opt caffe2/caffe2/python:layers_test
```
https://www.internalfb.com/intern/testinfra/testrun/2533274848456607

```
buck test mode/opt caffe2/caffe2/python/operator_test:sparse_itemwise_dropout_with_replacement_op_test
```
https://www.internalfb.com/intern/testinfra/testrun/8725724318782701

Reviewed By: Wakeupbuddy

Differential Revision: D27867213

fbshipit-source-id: 8e173c7b3294abbc8bf8a3b04f723cb170446b96
This commit is contained in:
Jeongmin Lee
2021-06-04 19:56:10 -07:00
committed by Facebook GitHub Bot
parent 68df4d40d2
commit bca25d97ad
5 changed files with 365 additions and 0 deletions

View File

@ -0,0 +1,125 @@
#include "caffe2/operators/sparse_itemwise_dropout_with_replacement_op.h"
#include <algorithm>
#include <iterator>
namespace caffe2 {
template <>
bool SparseItemwiseDropoutWithReplacementOp<CPUContext>::RunOnDevice() {
auto& X = Input(0);
CAFFE_ENFORCE_EQ(X.ndim(), 1, "Input tensor should be 1-D");
const int64_t* Xdata = X.data<int64_t>();
auto& Lengths = Input(1);
CAFFE_ENFORCE_EQ(Lengths.ndim(), 1, "Lengths tensor should be 1-D");
auto* OutputLengths = Output(1, Lengths.size(), at::dtype<int32_t>());
int32_t const* input_lengths_data = Lengths.template data<int32_t>();
int32_t* output_lengths_data =
OutputLengths->template mutable_data<int32_t>();
// Check that input lengths add up to the length of input data
int total_input_length = 0;
for (int i = 0; i < Lengths.numel(); ++i) {
total_input_length += input_lengths_data[i];
}
CAFFE_ENFORCE_EQ(
total_input_length,
X.numel(),
"Inconsistent input data. Number of elements should match total length.");
at::bernoulli_distribution<double> dist(1. - ratio_);
auto* gen = context_.RandGenerator();
const float _BARNUM = 0.5;
vector<bool> selected(total_input_length, false);
for (int i = 0; i < total_input_length; ++i) {
if (dist(gen) > _BARNUM) {
selected[i] = true;
}
}
for (int i = 0; i < Lengths.numel(); ++i) {
output_lengths_data[i] = input_lengths_data[i];
}
auto* Y = Output(0, {total_input_length}, at::dtype<int64_t>());
int64_t* Ydata = Y->template mutable_data<int64_t>();
for (int i = 0; i < total_input_length; ++i) {
if (selected[i]) {
// Copy logical elements from input to output
Ydata[i] = Xdata[i];
} else {
Ydata[i] = replacement_value_;
}
}
return true;
}
REGISTER_CPU_OPERATOR(
SparseItemwiseDropoutWithReplacement,
SparseItemwiseDropoutWithReplacementOp<CPUContext>);
OPERATOR_SCHEMA(SparseItemwiseDropoutWithReplacement)
.NumInputs(2)
.SameNumberOfOutput()
.SetDoc(R"DOC(
`SparseItemwiseDropoutWithReplacement` takes a 1-d input tensor and a lengths tensor.
Values in the Lengths tensor represent how many input elements consitute each
example in a given batch. The each input value in the tensor of an example can be
replaced with the replacement value with probability given by the `ratio`
argument.
<details>
<summary> <b>Example</b> </summary>
**Code**
```
workspace.ResetWorkspace()
op = core.CreateOperator(
"SparseItemwiseDropoutWithReplacement",
["X", "Lengths"],
["Y", "OutputLengths"],
ratio=0.5,
replacement_value=-1
)
workspace.FeedBlob("X", np.array([1, 2, 3, 4, 5]).astype(np.int64))
workspace.FeedBlob("Lengths", np.array([2, 3]).astype(np.int32))
print("X:", workspace.FetchBlob("X"))
print("Lengths:", workspace.FetchBlob("Lengths"))
workspace.RunOperatorOnce(op)
print("Y:", workspace.FetchBlob("Y"))
print("OutputLengths:", workspace.FetchBlob("OutputLengths"))
```
**Result**
```
X: [1, 2, 3, 4, 5]
Lengths: [2, 3]
Y: [1, 2, -1]
OutputLengths: [2, 1]
```
</details>
)DOC")
.Arg(
"ratio",
"*(type: float; default: 0.0)* Probability of an element to be replaced.")
.Arg(
"replacement_value",
"*(type: int64_t; default: 0)* Value elements are replaced with.")
.Input(0, "X", "*(type: Tensor`<int64_t>`)* Input data tensor.")
.Input(
1,
"Lengths",
"*(type: Tensor`<int32_t>`)* Lengths tensor for input.")
.Output(0, "Y", "*(type: Tensor`<int64_t>`)* Output tensor.")
.Output(1, "OutputLengths", "*(type: Tensor`<int32_t>`)* Output tensor.");
NO_GRADIENT(SparseItemwiseDropoutWithReplacement);
} // namespace caffe2

View File

@ -0,0 +1,35 @@
#ifndef CAFFE2_OPERATORS_SPARSE_ITEMWISE_DROPOUT_WITH_REPLACEMENT_OP_H_
#define CAFFE2_OPERATORS_SPARSE_ITEMWISE_DROPOUT_WITH_REPLACEMENT_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
template <class Context>
class SparseItemwiseDropoutWithReplacementOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit SparseItemwiseDropoutWithReplacementOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
ratio_(this->template GetSingleArgument<float>("ratio", 0.0)),
replacement_value_(
this->template GetSingleArgument<int64_t>("replacement_value", 0)) {
// It is allowed to drop all or drop none.
CAFFE_ENFORCE_GE(ratio_, 0.0, "Ratio should be a valid probability");
CAFFE_ENFORCE_LE(ratio_, 1.0, "Ratio should be a valid probability");
}
bool RunOnDevice() override;
private:
float ratio_;
int64_t replacement_value_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_SPARSE_ITEMWISE_DROPOUT_WITH_REPLACEMENT_OP_H_

View File

@ -0,0 +1,103 @@
from caffe2.python import schema
from caffe2.python.layers.layers import (
IdList,
ModelLayer,
)
# Model layer for implementing probabilistic replacement of individual elements in
# IdLists. Takes probabilities for train, eval and predict nets as input, as
# well as the replacement value when dropout happens. For features we may have
# available to us in train net but not in predict net, we'd set dropout
# probability for predict net to be 1.0 and set the feature to the replacement
# value given here. This way, the value is tied to the particular model and not
# to any specific logic in feature processing in serving.
# Consider the following example where X is the values in the IdList and Lengths
# is the number of values corresponding to each example.
# X: [1, 2, 3, 4, 5]
# Lengths: [2, 3]
# This IdList contains 2 IdList features of lengths 2, 3. Let's assume we used a
# ratio of 0.5 and ended up dropping out 2nd item in 2nd IdList feature, and used a
# replacement value of -1. We will end up with the following IdList.
# Y: [1, 2, 3, -1, 5]
# OutputLengths: [2, 3]
# where the 2nd item in 2nd IdList feature [4] was replaced with [-1].
class SparseItemwiseDropoutWithReplacement(ModelLayer):
def __init__(
self,
model,
input_record,
dropout_prob_train,
dropout_prob_eval,
dropout_prob_predict,
replacement_value,
name='sparse_itemwise_dropout',
**kwargs):
super(SparseItemwiseDropoutWithReplacement, self).__init__(model, name, input_record, **kwargs)
assert schema.equal_schemas(input_record, IdList), "Incorrect input type"
self.dropout_prob_train = float(dropout_prob_train)
self.dropout_prob_eval = float(dropout_prob_eval)
self.dropout_prob_predict = float(dropout_prob_predict)
self.replacement_value = int(replacement_value)
assert (self.dropout_prob_train >= 0 and
self.dropout_prob_train <= 1.0), \
"Expected 0 <= dropout_prob_train <= 1, but got %s" \
% self.dropout_prob_train
assert (self.dropout_prob_eval >= 0 and
self.dropout_prob_eval <= 1.0), \
"Expected 0 <= dropout_prob_eval <= 1, but got %s" \
% dropout_prob_eval
assert (self.dropout_prob_predict >= 0 and
self.dropout_prob_predict <= 1.0), \
"Expected 0 <= dropout_prob_predict <= 1, but got %s" \
% dropout_prob_predict
assert(self.dropout_prob_train > 0 or
self.dropout_prob_eval > 0 or
self.dropout_prob_predict > 0), \
"Ratios all set to 0.0 for train, eval and predict"
self.output_schema = schema.NewRecord(model.net, IdList)
if input_record.lengths.metadata:
self.output_schema.lengths.set_metadata(
input_record.lengths.metadata)
if input_record.items.metadata:
self.output_schema.items.set_metadata(
input_record.items.metadata)
def _add_ops(self, net, ratio):
input_values_blob = self.input_record.items()
input_lengths_blob = self.input_record.lengths()
output_lengths_blob = self.output_schema.lengths()
output_values_blob = self.output_schema.items()
net.SparseItemwiseDropoutWithReplacement(
[
input_values_blob,
input_lengths_blob
],
[
output_values_blob,
output_lengths_blob
],
ratio=ratio,
replacement_value=self.replacement_value
)
def add_train_ops(self, net):
self._add_ops(net, self.dropout_prob_train)
def add_eval_ops(self, net):
self._add_ops(net, self.dropout_prob_eval)
def add_ops(self, net):
self._add_ops(net, self.dropout_prob_predict)

View File

@ -2480,3 +2480,37 @@ class TestLayers(LayersTestCase):
predict_net = self.get_predict_net()
self.assertNetContainOps(predict_net, [sparse_lookup_op_spec])
def testSparseItemwiseDropoutWithReplacement(self):
input_record = schema.NewRecord(self.model.net, IdList)
self.model.output_schema = schema.Struct()
lengths_blob = input_record.field_blobs()[0]
values_blob = input_record.field_blobs()[1]
lengths = np.array([1] * 10).astype(np.int32)
values = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int64)
workspace.FeedBlob(lengths_blob, lengths)
workspace.FeedBlob(values_blob, values)
out = self.model.SparseItemwiseDropoutWithReplacement(
input_record, 0.0, 0.5, 1.0, -1, output_names_or_num=1)
self.assertEqual(schema.List(schema.Scalar(np.int64,)), out)
train_init_net, train_net = self.get_training_nets()
eval_net = self.get_eval_net()
predict_net = self.get_predict_net()
workspace.RunNetOnce(train_init_net)
workspace.RunNetOnce(train_net)
out_values = workspace.FetchBlob(out.items())
out_lengths = workspace.FetchBlob(out.lengths())
self.assertBlobsEqual(out_values, values)
self.assertBlobsEqual(out_lengths, lengths)
workspace.RunNetOnce(eval_net)
workspace.RunNetOnce(predict_net)
predict_values = workspace.FetchBlob("values_auto_0")
predict_lengths = workspace.FetchBlob("lengths_auto_0")
self.assertBlobsEqual(predict_values, np.array([-1] * 10).astype(np.int64))
self.assertBlobsEqual(predict_lengths, lengths)

View File

@ -0,0 +1,68 @@
from caffe2.python import core
from hypothesis import given
import caffe2.python.hypothesis_test_util as hu
import numpy as np
class SparseItemwiseDropoutWithReplacementTest(hu.HypothesisTestCase):
@given(**hu.gcs_cpu_only)
def test_no_dropout(self, gc, dc):
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int64)
Lengths = np.array([2, 2, 2, 2, 2]).astype(np.int32)
replacement_value = -1
self.ws.create_blob("X").feed(X)
self.ws.create_blob("Lengths").feed(Lengths)
sparse_dropout_op = core.CreateOperator(
"SparseItemwiseDropoutWithReplacement", ["X", "Lengths"], ["Y", "LY"],
ratio=0.0, replacement_value=replacement_value)
self.ws.run(sparse_dropout_op)
Y = self.ws.blobs["Y"].fetch()
OutputLengths = self.ws.blobs["LY"].fetch()
self.assertListEqual(X.tolist(), Y.tolist(),
"Values should stay unchanged")
self.assertListEqual(Lengths.tolist(), OutputLengths.tolist(),
"Lengths should stay unchanged.")
@given(**hu.gcs_cpu_only)
def test_all_dropout(self, gc, dc):
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int64)
Lengths = np.array([2, 2, 2, 2, 2]).astype(np.int32)
replacement_value = -1
self.ws.create_blob("X").feed(X)
self.ws.create_blob("Lengths").feed(Lengths)
sparse_dropout_op = core.CreateOperator(
"SparseItemwiseDropoutWithReplacement", ["X", "Lengths"], ["Y", "LY"],
ratio=1.0, replacement_value=replacement_value)
self.ws.run(sparse_dropout_op)
y = self.ws.blobs["Y"].fetch()
lengths = self.ws.blobs["LY"].fetch()
for elem in y:
self.assertEqual(elem, replacement_value, "Expected all \
negative elements when dropout ratio is 1.")
for length in lengths:
self.assertEqual(length, 2)
self.assertEqual(sum(lengths), len(y))
@given(**hu.gcs_cpu_only)
def test_all_dropout_empty_input(self, gc, dc):
X = np.array([]).astype(np.int64)
Lengths = np.array([0]).astype(np.int32)
replacement_value = -1
self.ws.create_blob("X").feed(X)
self.ws.create_blob("Lengths").feed(Lengths)
sparse_dropout_op = core.CreateOperator(
"SparseItemwiseDropoutWithReplacement", ["X", "Lengths"], ["Y", "LY"],
ratio=1.0, replacement_value=replacement_value)
self.ws.run(sparse_dropout_op)
y = self.ws.blobs["Y"].fetch()
lengths = self.ws.blobs["LY"].fetch()
self.assertEqual(len(y), 0, "Expected no dropout value")
self.assertEqual(len(lengths), 1, "Expected single element \
in lengths array")
self.assertEqual(lengths[0], 0, "Expected 0 as sole length")
self.assertEqual(sum(lengths), len(y))