mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adding gradient to Boolean Mask operator (#21423)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21423 - add gradient for boolean mask - add test for gradient checking Reviewed By: BIT-silence Differential Revision: D15640036 fbshipit-source-id: 79f40c6901e805bf1b8e9b01b57903e30b00f654
This commit is contained in:
committed by
Facebook Github Bot
parent
d3d195e0b1
commit
696b2c89b4
@ -110,7 +110,35 @@ bool BooleanMaskOp<CPUContext>::RunOnDevice() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
template <class T>
|
||||
bool BooleanMaskOpGradient<CPUContext>::DoRunWithType() {
|
||||
const auto& mask = Input(0);
|
||||
const auto& dY = Input(1);
|
||||
auto* dX = Output(0);
|
||||
|
||||
const int data_length_before_mask = mask.size(0);
|
||||
|
||||
dX->Resize(data_length_before_mask);
|
||||
|
||||
// TODO: we should support any type, not just float
|
||||
T* dXdata = dX->template mutable_data<T>();
|
||||
const T* dYdata = dY.template data<T>();
|
||||
const bool* mask_data = mask.template data<bool>();
|
||||
|
||||
int ind = 0;
|
||||
|
||||
for (int i = 0; i < data_length_before_mask; i++) {
|
||||
dXdata[i] = mask_data[i] ? dYdata[ind++] : 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
REGISTER_CPU_OPERATOR(BooleanMask, BooleanMaskOp<CPUContext>);
|
||||
REGISTER_CPU_GRADIENT_OPERATOR(
|
||||
BooleanMaskGradient,
|
||||
BooleanMaskOpGradient<CPUContext>);
|
||||
REGISTER_CPU_OPERATOR(BooleanMaskLengths, BooleanMaskLengthsOp<CPUContext>);
|
||||
|
||||
OPERATOR_SCHEMA(BooleanMask)
|
||||
@ -164,9 +192,18 @@ masked_indices: [0 3 4]
|
||||
|
||||
)DOC")
|
||||
.Input(0, "data", "(*Tensor*): 1D input tensor")
|
||||
.Input(1, "mask", "(*Tensor`<bool>`*): tensor of bools which determines the input elements that will be left in the `masked_data` output tensor; same shape as `data`")
|
||||
.Output(0, "masked_data", "(*Tensor*): 1D tensor of same type as `data` input that contains the masked input tensor")
|
||||
.Output(1, "masked_indices", "(*Tensor`<int>`*): 1D tensor of indices of the True elements in the `mask` tensor");
|
||||
.Input(
|
||||
1,
|
||||
"mask",
|
||||
"(*Tensor`<bool>`*): tensor of bools which determines the input elements that will be left in the `masked_data` output tensor; same shape as `data`")
|
||||
.Output(
|
||||
0,
|
||||
"masked_data",
|
||||
"(*Tensor*): 1D tensor of same type as `data` input that contains the masked input tensor")
|
||||
.Output(
|
||||
1,
|
||||
"masked_indices",
|
||||
"(*Tensor`<int>`*): 1D tensor of indices of the True elements in the `mask` tensor");
|
||||
|
||||
OPERATOR_SCHEMA(BooleanMaskLengths)
|
||||
.NumInputs(2)
|
||||
@ -218,13 +255,35 @@ masked_lengths: [0 2 2]
|
||||
</details>
|
||||
|
||||
)DOC")
|
||||
.Input(0, "lengths", "(*Tensor`<int>`*): input tensor containing segment lengths")
|
||||
.Input(
|
||||
0,
|
||||
"lengths",
|
||||
"(*Tensor`<int>`*): input tensor containing segment lengths")
|
||||
.Input(1, "mask", "(*Tensor`<bool>`*): A 1D bool tensor of values to keep.")
|
||||
.Output(0, "masked_lengths", "(*Tensor`<int>`*): 1D tensor of same type as inputs that contains the sequence");
|
||||
.Output(
|
||||
0,
|
||||
"masked_lengths",
|
||||
"(*Tensor`<int>`*): 1D tensor of same type as inputs that contains the sequence");
|
||||
|
||||
NO_GRADIENT(BooleanMask)
|
||||
GRADIENT_OPERATOR_SCHEMA(BooleanMaskGradient).NumInputs(2).NumOutputs(1);
|
||||
|
||||
namespace {
|
||||
class GetBooleanMaskGradient : public GradientMakerBase {
|
||||
using GradientMakerBase::GradientMakerBase;
|
||||
vector<OperatorDef> GetGradientDefs() override {
|
||||
return SingleGradientDef(
|
||||
"BooleanMaskGradient",
|
||||
"",
|
||||
vector<string>{I(1), GO(0)},
|
||||
vector<string>{GI(0)});
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_GRADIENT(BooleanMask, GetBooleanMaskGradient);
|
||||
NO_GRADIENT(BooleanMaskLengths);
|
||||
|
||||
} // namespace
|
||||
|
||||
const float minf = -1.0f * std::numeric_limits<float>::infinity();
|
||||
|
||||
// Template this on a functor object so we can generate different
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef BOOLEAN_MASK_OPS_H
|
||||
#define BOOLEAN_MASK_OPS_H
|
||||
#ifndef CAFFE2_OPERATORS_BOOLEAN_MASK_OPS_H_
|
||||
#define CAFFE2_OPERATORS_BOOLEAN_MASK_OPS_H_
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
@ -19,6 +19,27 @@ class BooleanMaskOp final : public Operator<Context> {
|
||||
bool RunOnDevice() override;
|
||||
};
|
||||
|
||||
template <class Context>
|
||||
class BooleanMaskOpGradient final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
BooleanMaskOpGradient(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
|
||||
/* Calculating the gradient of the Boolean Mask operator
|
||||
* requires access to the original mask that's passed in,
|
||||
* and the gradient to backpropagate.
|
||||
*/
|
||||
bool RunOnDevice() override {
|
||||
return DispatchHelper<
|
||||
TensorTypes<bool, std::int32_t, std::int64_t, float, double>>::
|
||||
call(this, Input(1));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool DoRunWithType();
|
||||
};
|
||||
|
||||
template <class Context>
|
||||
class SequenceMaskOp final : public Operator<Context> {
|
||||
public:
|
||||
|
@ -264,7 +264,7 @@ class GradientChecker:
|
||||
# hack.
|
||||
grad_ops, g_input = getGradientForOp(op)
|
||||
|
||||
|
||||
|
||||
_input_device_options = input_device_options or \
|
||||
core.InferOpBlobDevicesAsDict(op)[0]
|
||||
# First, feed in the input.
|
||||
@ -285,7 +285,7 @@ class GradientChecker:
|
||||
raise Exception(
|
||||
"Mismatched gradient shapes: estimated ({}), grad ({})".format(
|
||||
grad_estimate.shape, grad.shape))
|
||||
|
||||
|
||||
dims_to_check = inputs[input_to_check].size
|
||||
for current_dim in range(dims_to_check):
|
||||
# Positive gradient
|
||||
|
@ -12,10 +12,23 @@ import numpy as np
|
||||
|
||||
|
||||
class TestBooleanMaskOp(serial.SerializedTestCase):
|
||||
@given(x=hu.tensor1d(min_len=1,
|
||||
max_len=100,
|
||||
elements=st.floats(min_value=0.5, max_value=1.0)),
|
||||
**hu.gcs_cpu_only)
|
||||
def test_boolean_mask_gradient(self, x, gc, dc):
|
||||
op = core.CreateOperator("BooleanMask",
|
||||
["data", "mask"],
|
||||
"masked_data")
|
||||
mask = np.random.choice(a=[True, False], size=x.shape[0])
|
||||
expected_gradient = np.copy(mask).astype(int)
|
||||
self.assertDeviceChecks(dc, op, [x, mask], [0])
|
||||
self.assertGradientChecks(gc, op, [x, mask], 0, [0])
|
||||
|
||||
@serial.given(x=hu.tensor(min_dim=1,
|
||||
max_dim=5,
|
||||
elements=st.floats(min_value=0.5, max_value=1.0)),
|
||||
|
||||
@given(x=hu.tensor1d(min_len=1,
|
||||
max_len=5,
|
||||
elements=st.floats(min_value=0.5, max_value=1.0)),
|
||||
**hu.gcs)
|
||||
def test_boolean_mask(self, x, gc, dc):
|
||||
op = core.CreateOperator("BooleanMask",
|
||||
@ -25,13 +38,12 @@ class TestBooleanMaskOp(serial.SerializedTestCase):
|
||||
|
||||
def ref(x, mask):
|
||||
return (x[mask],)
|
||||
|
||||
self.assertReferenceChecks(gc, op, [x, mask], ref)
|
||||
self.assertDeviceChecks(dc, op, [x, mask], [0])
|
||||
|
||||
@given(x=hu.tensor(min_dim=1,
|
||||
max_dim=5,
|
||||
elements=st.floats(min_value=0.5, max_value=1.0)),
|
||||
@given(x=hu.tensor1d(min_len=1,
|
||||
max_len=5,
|
||||
elements=st.floats(min_value=0.5, max_value=1.0)),
|
||||
**hu.gcs)
|
||||
def test_boolean_mask_indices(self, x, gc, dc):
|
||||
op = core.CreateOperator("BooleanMask",
|
||||
@ -54,7 +66,7 @@ class TestBooleanMaskOp(serial.SerializedTestCase):
|
||||
x = x.astype(dtype)
|
||||
return x, dc
|
||||
|
||||
@serial.given(x=hu.tensor(min_dim=2,
|
||||
@given(x=hu.tensor(min_dim=2,
|
||||
max_dim=5,
|
||||
elements=st.floats(min_value=0.5, max_value=1.0)),
|
||||
dtype=st.sampled_from([np.float32, np.float16]),
|
||||
|
Reference in New Issue
Block a user