add batch of string ops (#20826)

Summary:
First batch of https://github.com/pytorch/pytorch/issues/20769, handles `isupper`, `islower`, `isdigit`, `isspace`, `isalnum`, `isalpha`, `upper`, `lower`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20826

Differential Revision: D15459166

Pulled By: eellison

fbshipit-source-id: 0ed908022475e27011803cc4af7cf393a4312783
This commit is contained in:
Elias Ellison
2019-05-22 17:22:58 -07:00
committed by Facebook Github Bot
parent 7aa3887f43
commit aebcd80ae4
2 changed files with 82 additions and 5 deletions

View File

@ -12079,14 +12079,29 @@ a")
self.checkScript(fn, ("abcde",))
def test_str_cmp(self):
def test(a, b):
def test_str_ops(self):
def test_str_is(s):
# type: (str) -> Tuple[bool, bool, bool, bool, bool, bool]
return s.isupper(), s.islower(), s.isdigit(), s.isspace(), \
s.isalnum(), s.isalpha()
def test_str_to(s):
# type: (str) -> Tuple[str, str]
return s.upper(), s.lower()
inputs = ["", "12a", "!B", "12", "a", "B", "aB", "$12", "B12", "AB ",
" \t", " \n", "\na", "abc"]
for input in inputs:
self.checkScript(test_str_is, (input,))
self.checkScript(test_str_to, (input,))
def test_str_cmp(a, b):
# type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool]
return a != b, a == b, a < b, a > b, a <= b, a >= b
self.checkScript(test, ("1", "2"))
self.checkScript(test, ("2", "1"))
self.checkScript(test, ("1", "1"))
for i in range(len(inputs) - 1):
self.checkScript(test_str_cmp, (inputs[i], inputs[i + 1]))
def test_ord(self):
def fn(x):

View File

@ -1913,6 +1913,68 @@ RegisterOperators reg2({
Operator(
"aten::slice(str string, int start, int end=9223372036854775807, int step=1) -> str",
stringSlice),
// python string is methods return false if empty
#define DEFINE_STRING_IS_OP(op_name, char_op) \
Operator(#op_name "(str self) -> bool", [](Stack& stack) { \
auto string = pop(stack).toStringRef(); \
push( \
stack, \
string.size() != 0 && \
std::all_of(string.begin(), string.end(), [](char c) { \
return char_op(c); \
})); \
return 0; \
})
// upper and lower require there to be at least one alpha character,
// and ignore all other characters
Operator(
"aten::isupper(str self) -> bool",
[](Stack& stack) {
auto string = pop(stack).toStringRef();
bool found_alpha = false;
bool is_upper = true;
for (char c : string) {
found_alpha |= std::isalpha(c);
is_upper &= (!std::isalpha(c) || std::isupper(c));
}
push(stack, found_alpha && is_upper);
return 0;
}),
Operator(
"aten::islower(str self) -> bool",
[](Stack& stack) {
auto string = pop(stack).toStringRef();
bool found_alpha = false;
bool is_lower = true;
for (char c : string) {
found_alpha |= std::isalpha(c);
is_lower &= (!std::isalpha(c) || std::islower(c));
}
push(stack, found_alpha && is_lower);
return 0;
}),
DEFINE_STRING_IS_OP(aten::isdigit, std::isdigit),
DEFINE_STRING_IS_OP(aten::isspace, std::isspace),
DEFINE_STRING_IS_OP(aten::isalnum, std::isalnum),
DEFINE_STRING_IS_OP(aten::isalpha, std::isalpha),
#define DEFINE_STRING_CHAR_MAP_OP(op_name, char_op) \
Operator(#op_name "(str self) -> str", [](Stack& stack) { \
auto string = pop(stack).toStringRef(); \
std::stringstream ss; \
for (char c : string) { \
ss << static_cast<char>(char_op(c)); \
} \
push(stack, ss.str()); \
return 0; \
})
DEFINE_STRING_CHAR_MAP_OP(aten::upper, std::toupper),
DEFINE_STRING_CHAR_MAP_OP(aten::lower, std::tolower),
Operator(
"prim::StringIndex(str string, int index) -> str",
[](Stack& stack) {