mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add batch of string ops (#20826)
Summary: First batch of https://github.com/pytorch/pytorch/issues/20769; handles `isupper`, `islower`, `isdigit`, `isspace`, `isalnum`, `isalpha`, `upper`, and `lower`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20826
Differential Revision: D15466986
Pulled By: eellison
fbshipit-source-id: d1df65721da803dfa30e28fdd9b874405be6bc7d
This commit is contained in:
committed by
Facebook Github Bot
parent
90182a7332
commit
8fc069fa17
@ -12079,14 +12079,29 @@ a")
|
||||
|
||||
self.checkScript(fn, ("abcde",))
|
||||
|
||||
def test_str_cmp(self):
|
||||
def test(a, b):
|
||||
def test_str_ops(self):
|
||||
def test_str_is(s):
|
||||
# type: (str) -> Tuple[bool, bool, bool, bool, bool, bool]
|
||||
return s.isupper(), s.islower(), s.isdigit(), s.isspace(), \
|
||||
s.isalnum(), s.isalpha()
|
||||
|
||||
def test_str_to(s):
|
||||
# type: (str) -> Tuple[str, str]
|
||||
return s.upper(), s.lower()
|
||||
|
||||
# Sample strings covering empty, digit-only, mixed-case, punctuation and
# whitespace-only inputs; reused below for the comparison checks.
inputs = ["", "12a", "!B", "12", "a", "B", "aB", "$12", "B12", "AB ",
          " \t", " \n", "\na", "abc"]

# Run both single-string helpers against every sample, predicates first.
for sample in inputs:
    for scripted_fn in (test_str_is, test_str_to):
        self.checkScript(scripted_fn, (sample,))
|
||||
|
||||
def test_str_cmp(a, b):
|
||||
# type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool]
|
||||
return a != b, a == b, a < b, a > b, a <= b, a >= b
|
||||
|
||||
self.checkScript(test, ("1", "2"))
|
||||
self.checkScript(test, ("2", "1"))
|
||||
self.checkScript(test, ("1", "1"))
|
||||
# Compare each sample string with its successor in the list.
for left, right in zip(inputs, inputs[1:]):
    self.checkScript(test_str_cmp, (left, right))
|
||||
|
||||
def test_ord(self):
|
||||
def fn(x):
|
||||
|
@ -24,6 +24,7 @@
|
||||
#include <c10/core/thread_pool.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
|
||||
#include <cctype>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <exception>
|
||||
@ -1913,6 +1914,70 @@ RegisterOperators reg2({
|
||||
Operator(
|
||||
"aten::slice(str string, int start, int end=9223372036854775807, int step=1) -> str",
|
||||
stringSlice),
|
||||
|
||||
// Registers a `<op_name>(str self) -> bool` operator that is true when the
// string is non-empty and `char_op` accepts every character. The non-empty
// guard mirrors Python, whose str.is*() methods return False for "".
//
// Each character is converted to unsigned char before it reaches `char_op`:
// the <cctype> classifiers have undefined behavior for negative arguments
// other than EOF, which is what a plain `char` holding a byte >= 0x80 becomes
// on platforms where char is signed.
#define DEFINE_STRING_IS_OP(op_name, char_op)                          \
  Operator(#op_name "(str self) -> bool", [](Stack& stack) {           \
    auto string = pop(stack).toStringRef();                            \
    push(                                                              \
        stack,                                                         \
        string.size() != 0 &&                                          \
            std::all_of(string.begin(), string.end(), [](char c) {     \
              return char_op(static_cast<unsigned char>(c)) != 0;      \
            }));                                                       \
    return 0;                                                          \
  })
|
||||
|
||||
// upper and lower require there to be at least one alpha character,
|
||||
// and ignore all other characters
|
||||
Operator(
|
||||
"aten::isupper(str self) -> bool",
|
||||
[](Stack& stack) {
|
||||
auto string = pop(stack).toStringRef();
|
||||
bool found_alpha = false;
|
||||
bool is_upper = true;
|
||||
for (size_t i = 0; i < string.size() && is_upper; ++i) {
|
||||
char c = string[i];
|
||||
found_alpha |= std::isalpha(c);
|
||||
is_upper &= (!std::isalpha(c) || std::isupper(c));
|
||||
}
|
||||
push(stack, found_alpha && is_upper);
|
||||
return 0;
|
||||
}),
|
||||
// aten::islower: true when the string contains an alphabetic character and no
// alphabetic character fails std::islower; non-alphabetic characters are
// ignored.
Operator(
    "aten::islower(str self) -> bool",
    [](Stack& stack) {
      auto string = pop(stack).toStringRef();
      bool found_alpha = false;
      bool is_lower = true;
      for (const char raw : string) {
        // <cctype> functions have undefined behavior for negative arguments,
        // so go through unsigned char before classifying.
        const auto c = static_cast<unsigned char>(raw);
        if (std::isalpha(c) == 0) {
          continue;
        }
        found_alpha = true;
        if (std::islower(c) == 0) {
          // One non-lowercase letter settles the answer; stop scanning.
          is_lower = false;
          break;
        }
      }
      push(stack, found_alpha && is_lower);
      return 0;
    }),
|
||||
|
||||
DEFINE_STRING_IS_OP(aten::isdigit, std::isdigit),
|
||||
DEFINE_STRING_IS_OP(aten::isspace, std::isspace),
|
||||
DEFINE_STRING_IS_OP(aten::isalnum, std::isalnum),
|
||||
DEFINE_STRING_IS_OP(aten::isalpha, std::isalpha),
|
||||
|
||||
// Registers a `<op_name>(str self) -> str` operator that applies `char_op` to
// every character of the input and pushes the resulting string.
//
// Each character is converted to unsigned char before it reaches `char_op`:
// std::toupper / std::tolower have undefined behavior for negative arguments
// other than EOF, which is what a plain `char` holding a byte >= 0x80 becomes
// on platforms where char is signed. The int result is narrowed back to char
// before being written.
#define DEFINE_STRING_CHAR_MAP_OP(op_name, char_op)                       \
  Operator(#op_name "(str self) -> str", [](Stack& stack) {               \
    auto string = pop(stack).toStringRef();                               \
    std::stringstream ss;                                                 \
    for (char c : string) {                                               \
      ss << static_cast<char>(char_op(static_cast<unsigned char>(c)));    \
    }                                                                     \
    push(stack, ss.str());                                                \
    return 0;                                                             \
  })
|
||||
|
||||
DEFINE_STRING_CHAR_MAP_OP(aten::upper, std::toupper),
|
||||
DEFINE_STRING_CHAR_MAP_OP(aten::lower, std::tolower),
|
||||
|
||||
Operator(
|
||||
"prim::StringIndex(str string, int index) -> str",
|
||||
[](Stack& stack) {
|
||||
|
Reference in New Issue
Block a user