[C2] Add string equality operator

Summary: This diff adds a string equality checking operator.

Test Plan: Unit tests

Differential Revision: D24042344

fbshipit-source-id: c8997c6130e3438f2ae95dae69f76978e2e95527
This commit is contained in:
Pawel Garbacki
2020-10-05 10:45:09 -07:00
committed by Facebook GitHub Bot
parent 162717e527
commit cf48872d28
2 changed files with 53 additions and 0 deletions

View File

@ -71,6 +71,17 @@ struct EndsWith {
std::string suffix_;
};
struct Equals {
explicit Equals(OperatorBase& op)
: text_(op.GetSingleArgument<std::string>("text", "")) {}
bool operator()(const std::string& str) {
return str == text_;
}
private:
std::string text_;
};
struct Prefix {
explicit Prefix(OperatorBase& op)
: length_(op.GetSingleArgument<int>("length", 3)) {}
@ -108,6 +119,9 @@ REGISTER_CPU_OPERATOR(
REGISTER_CPU_OPERATOR(
StringEndsWith,
StringElementwiseOp<EndsWith, FixedType<bool>>);
REGISTER_CPU_OPERATOR(
StringEquals,
StringElementwiseOp<Equals, FixedType<bool>>);
REGISTER_CPU_OPERATOR(StringJoin, StringJoinOp<CPUContext>);
OPERATOR_SCHEMA(StringPrefix)
@ -164,6 +178,17 @@ Returns tensor of boolean of the same dimension of input.
.Input(0, "strings", "Tensor of std::string.")
.Output(0, "bools", "Tensor of bools of same shape as input.");
OPERATOR_SCHEMA(StringEquals)
.NumInputs(1)
.NumOutputs(1)
.SetDoc(R"DOC(
Performs equality check on each string in the input tensor.
Returns tensor of booleans of the same dimension as input.
)DOC")
.Arg("text", "The text to check input strings equality against.")
.Input(0, "strings", "Tensor of std::string.")
.Output(0, "bools", "Tensor of bools of same shape as input.");
OPERATOR_SCHEMA(StringJoin)
.NumInputs(1)
.NumOutputs(1)
@ -187,6 +212,7 @@ SHOULD_NOT_DO_GRADIENT(StringPrefix);
SHOULD_NOT_DO_GRADIENT(StringSuffix);
SHOULD_NOT_DO_GRADIENT(StringStartsWith);
SHOULD_NOT_DO_GRADIENT(StringEndsWith);
SHOULD_NOT_DO_GRADIENT(StringEquals);
SHOULD_NOT_DO_GRADIENT(StringJoin);
}
} // namespace caffe2

View File

@ -119,6 +119,33 @@ class TestStringOps(serial.SerializedTestCase):
[strings],
string_ends_with_ref)
@given(strings=st.text(alphabet=['a', 'b']))
@settings(deadline=1000)
def test_string_equals(self, strings):
text = ""
if strings:
text = strings[0]
strings = np.array(
[str(a) for a in strings], dtype=np.object
)
def string_equals_ref(strings):
return (
np.array([a == text for a in strings], dtype=bool),
)
op = core.CreateOperator(
'StringEquals',
['strings'],
['bools'],
text=text)
self.assertReferenceChecks(
hu.cpu_do,
op,
[strings],
string_equals_ref)
if __name__ == "__main__":
import unittest
unittest.main()