mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[C2] Add string equality operator
Summary: This diff adds a string equality checking operator. Test Plan: Unit tests Differential Revision: D24042344 fbshipit-source-id: c8997c6130e3438f2ae95dae69f76978e2e95527
This commit is contained in:
committed by
Facebook GitHub Bot
parent
162717e527
commit
cf48872d28
@ -71,6 +71,17 @@ struct EndsWith {
|
||||
std::string suffix_;
|
||||
};
|
||||
|
||||
struct Equals {
|
||||
explicit Equals(OperatorBase& op)
|
||||
: text_(op.GetSingleArgument<std::string>("text", "")) {}
|
||||
bool operator()(const std::string& str) {
|
||||
return str == text_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string text_;
|
||||
};
|
||||
|
||||
struct Prefix {
|
||||
explicit Prefix(OperatorBase& op)
|
||||
: length_(op.GetSingleArgument<int>("length", 3)) {}
|
||||
@ -108,6 +119,9 @@ REGISTER_CPU_OPERATOR(
|
||||
REGISTER_CPU_OPERATOR(
|
||||
StringEndsWith,
|
||||
StringElementwiseOp<EndsWith, FixedType<bool>>);
|
||||
REGISTER_CPU_OPERATOR(
|
||||
StringEquals,
|
||||
StringElementwiseOp<Equals, FixedType<bool>>);
|
||||
REGISTER_CPU_OPERATOR(StringJoin, StringJoinOp<CPUContext>);
|
||||
|
||||
OPERATOR_SCHEMA(StringPrefix)
|
||||
@ -164,6 +178,17 @@ Returns tensor of boolean of the same dimension of input.
|
||||
.Input(0, "strings", "Tensor of std::string.")
|
||||
.Output(0, "bools", "Tensor of bools of same shape as input.");
|
||||
|
||||
OPERATOR_SCHEMA(StringEquals)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Performs equality check on each string in the input tensor.
|
||||
Returns tensor of booleans of the same dimension as input.
|
||||
)DOC")
|
||||
.Arg("text", "The text to check input strings equality against.")
|
||||
.Input(0, "strings", "Tensor of std::string.")
|
||||
.Output(0, "bools", "Tensor of bools of same shape as input.");
|
||||
|
||||
OPERATOR_SCHEMA(StringJoin)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
@ -187,6 +212,7 @@ SHOULD_NOT_DO_GRADIENT(StringPrefix);
|
||||
SHOULD_NOT_DO_GRADIENT(StringSuffix);
|
||||
SHOULD_NOT_DO_GRADIENT(StringStartsWith);
|
||||
SHOULD_NOT_DO_GRADIENT(StringEndsWith);
|
||||
SHOULD_NOT_DO_GRADIENT(StringEquals);
|
||||
SHOULD_NOT_DO_GRADIENT(StringJoin);
|
||||
}
|
||||
} // namespace caffe2
|
||||
|
||||
@ -119,6 +119,33 @@ class TestStringOps(serial.SerializedTestCase):
|
||||
[strings],
|
||||
string_ends_with_ref)
|
||||
|
||||
@given(strings=st.text(alphabet=['a', 'b']))
|
||||
@settings(deadline=1000)
|
||||
def test_string_equals(self, strings):
|
||||
text = ""
|
||||
if strings:
|
||||
text = strings[0]
|
||||
|
||||
strings = np.array(
|
||||
[str(a) for a in strings], dtype=np.object
|
||||
)
|
||||
|
||||
def string_equals_ref(strings):
|
||||
return (
|
||||
np.array([a == text for a in strings], dtype=bool),
|
||||
)
|
||||
|
||||
op = core.CreateOperator(
|
||||
'StringEquals',
|
||||
['strings'],
|
||||
['bools'],
|
||||
text=text)
|
||||
self.assertReferenceChecks(
|
||||
hu.cpu_do,
|
||||
op,
|
||||
[strings],
|
||||
string_equals_ref)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user