mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Python string standard lib (#21059)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21059 ghimport-source-id: f813585cde1b275c134b19009a2f5c0b3d70fc6e Reviewed By: jamesr66a Differential Revision: D15830704 Pulled By: bwasti fbshipit-source-id: e55a8c6bf910a163b9a5260235e315af9532b129
This commit is contained in:
committed by
Facebook Github Bot
parent
65a3dbdfb0
commit
dddc65db9e
@ -421,6 +421,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
${TORCH_SRC_DIR}/csrc/jit/print_handler.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/register_string_ops.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/register_quantized_ops.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/scope.cpp
|
||||
|
||||
@ -12126,30 +12126,6 @@ a")
|
||||
|
||||
self.checkScript(fn, ("abcde",))
|
||||
|
||||
def test_str_ops(self):
    """Compare scripted and eager results for basic str predicates,
    case conversion and comparison operators."""

    def test_str_is(s):
        # type: (str) -> Tuple[bool, bool, bool, bool, bool, bool]
        return s.isupper(), s.islower(), s.isdigit(), s.isspace(), \
            s.isalnum(), s.isalpha()

    def test_str_to(s):
        # type: (str) -> Tuple[str, str]
        return s.upper(), s.lower()

    inputs = ["", "12a", "!B", "12", "a", "B", "aB", "$12", "B12", "AB ",
              " \t", " \n", "\na", "abc"]

    # Every sample string is run through both helper functions.
    for input in inputs:
        self.checkScript(test_str_is, (input,))
        self.checkScript(test_str_to, (input,))

    def test_str_cmp(a, b):
        # type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool]
        return a != b, a == b, a < b, a > b, a <= b, a >= b

    # Comparison operators are checked on each pair of adjacent samples.
    for i in range(len(inputs) - 1):
        self.checkScript(test_str_cmp, (inputs[i], inputs[i + 1]))
|
||||
|
||||
def test_ord(self):
|
||||
def fn(x):
|
||||
# type: (str) -> int
|
||||
|
||||
305
test/test_jit_string.py
Normal file
305
test/test_jit_string.py
Normal file
@ -0,0 +1,305 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
from test_jit import JitTestCase
|
||||
|
||||
class TestScript(JitTestCase):
    """Checks that scripted versions of Python's str methods agree with eager Python."""

    def test_str_ops(self):
        """Run every supported str method through checkScript on a set of sample inputs."""

        def test_str_is(s):
            # type: (str) -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]
            return s.isupper(), s.islower(), s.isdigit(), s.isspace(), \
                s.isalnum(), s.isalpha(), s.isdecimal(), s.isnumeric(), \
                s.isidentifier(), s.istitle(), s.isprintable()

        def test_str_to(s):
            # type: (str) -> Tuple[str, str, str, str, str]
            return s.upper(), s.lower(), s.capitalize(), s.title(), s.swapcase()

        def test_str_strip(s):
            # type: (str) -> Tuple[str, str, str]
            return (
                s.lstrip(),
                s.rstrip(),
                s.strip(),
            )

        def test_str_strip_char_set(s, char_set):
            # type: (str, str) -> Tuple[str, str, str]
            return (
                s.lstrip(char_set),
                s.rstrip(char_set),
                s.strip(char_set),
            )

        inputs = ["", "12a", "!B", "12", "a", "B", "aB", "$12", "B12", "AB ",
                  " \t", " \n", "\na", "abc", "123.3", "s a", "b12a ",
                  "more strings with spaces", "Titular Strings", "\x0acan'tprintthis",
                  "spaces at the end ", " begin"]

        def test_str_center(i, s):
            # type: (int, str) -> str
            return s.center(i)

        def test_str_center_fc(i, s):
            # type: (int, str) -> str
            return s.center(i, '*')

        def test_str_center_error(s):
            # type: (str) -> str
            return s.center(10, '**')

        def test_ljust(s, i):
            # type: (str, int) -> str
            return s.ljust(i)

        def test_ljust_fc(s, i, fc):
            # type: (str, int, str) -> str
            return s.ljust(i, fc)

        def test_ljust_fc_err(s):
            # type: (str) -> str
            return s.ljust(10, '**')

        def test_rjust(s, i):
            # type: (str, int) -> str
            return s.rjust(i)

        def test_rjust_fc(s, i, fc):
            # type: (str, int, str) -> str
            return s.rjust(i, fc)

        def test_rjust_fc_err(s):
            # type: (str) -> str
            return s.rjust(10, '**')

        def test_zfill(s, i):
            # type: (str, int) -> str
            return s.zfill(i)

        for input in inputs:
            self.checkScript(test_str_is, (input,))
            self.checkScript(test_str_to, (input,))
            self.checkScript(test_str_strip, (input,))
            for char_set in ["abc", "123", " ", "\t"]:
                self.checkScript(test_str_strip_char_set, (input, char_set))
            for i in range(7):
                self.checkScript(test_str_center, (i, input,))
                self.checkScript(test_str_center_fc, (i, input,))
                self.checkScript(test_ljust, (input, i))
                self.checkScript(test_ljust_fc, (input, i, '*'))
                self.checkScript(test_rjust, (input, i))
                self.checkScript(test_rjust_fc, (input, i, '*'))
                self.checkScript(test_zfill, (input, i))

        # A multi-character fill string must be rejected.  Each helper gets its
        # own assertRaises block: statements placed after the first raising call
        # inside a single block would never execute.
        # NOTE(review): these calls exercise the eager Python functions only;
        # confirm whether the scripted versions should be checked as well.
        for bad_fill_fn in (test_str_center_error, test_ljust_fc_err, test_rjust_fc_err):
            with self.assertRaises(Exception):
                bad_fill_fn("error")

        def test_count():
            # type: () -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int]
            return (
                "hello".count("h"),
                "hello".count("h", 0, 1),
                "hello".count("h", -3),
                "hello".count("h", -10, 1),
                "hello".count("h", 0, -10),
                "hello".count("h", 0, 10),
                "hello".count("ell"),
                "hello".count("ell", 0, 1),
                "hello".count("ell", -3),
                "hello".count("ell", -10, 1),
                "hello".count("ell", 0, -10),
                "hello".count("ell", 0, 10)
            )
        self.checkScript(test_count, ())

        def test_endswith():
            # type: () -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]
            return (
                "hello".endswith("lo"),
                "hello".endswith("lo", 0),
                "hello".endswith("lo", -2),
                "hello".endswith("lo", -8),
                "hello".endswith("lo", 0, -5),
                "hello".endswith("lo", -2, 3),
                "hello".endswith("lo", -8, 4),
                "hello".endswith("l"),
                "hello".endswith("l", 0),
                "hello".endswith("l", -2),
                "hello".endswith("l", -8),
                "hello".endswith("l", 0, -5),
                "hello".endswith("l", -2, 3),
                "hello".endswith("l", -8, 4)
            )
        self.checkScript(test_endswith, ())

        def test_startswith():
            # type: () -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]
            return (
                "hello".startswith("lo"),
                "hello".startswith("lo", 0),
                "hello".startswith("lo", -2),
                "hello".startswith("lo", -8),
                "hello".startswith("lo", 0, -5),
                "hello".startswith("lo", -2, 3),
                "hello".startswith("lo", -8, 4),
                "hello".startswith("l"),
                "hello".startswith("l", 0),
                "hello".startswith("l", -2),
                "hello".startswith("l", -8),
                "hello".startswith("l", 0, -5),
                "hello".startswith("l", -2, 3),
                "hello".startswith("l", -8, 4)
            )
        self.checkScript(test_startswith, ())

        def test_expandtabs():
            # type: () -> Tuple[str, str, str, str, str, str]
            return (
                'xyz\t82345\tabc'.expandtabs(),
                'xyz\t32345\tabc'.expandtabs(3),
                'xyz\t52345\tabc'.expandtabs(5),
                'xyz\t62345\tabc'.expandtabs(6),
                'xyz\t72345\tabc'.expandtabs(7),
                'xyz\t62345\tabc'.expandtabs(-5),
            )
        self.checkScript(test_expandtabs, ())

        def test_rfind():
            # type: () -> Tuple[int, int, int, int, int, int, int, int, int]
            return (
                "hello123abc".rfind("llo"),
                "hello123abc".rfind("12"),
                "hello123abc".rfind("ab"),
                "hello123abc".rfind("ll", -1),
                "hello123abc".rfind("12", 4),
                "hello123abc".rfind("ab", -7),
                "hello123abc".rfind("ll", -1, 8),
                "hello123abc".rfind("12", 4, -4),
                "hello123abc".rfind("ab", -7, -20),
            )
        self.checkScript(test_rfind, ())

        def test_find():
            # type: () -> Tuple[int, int, int, int, int, int, int, int, int]
            return (
                "hello123abc".find("llo"),
                "hello123abc".find("12"),
                "hello123abc".find("ab"),
                "hello123abc".find("ll", -1),
                "hello123abc".find("12", 4),
                "hello123abc".find("ab", -7),
                "hello123abc".find("ll", -1, 8),
                "hello123abc".find("12", 4, -4),
                "hello123abc".find("ab", -7, -20),
            )
        self.checkScript(test_find, ())

        def test_index():
            # type: () -> Tuple[int, int, int, int, int, int]
            return (
                "hello123abc".index("llo"),
                "hello123abc".index("12"),
                "hello123abc".index("ab"),
                "hello123abc".index("12", 4),
                "hello123abc".index("ab", -7),
                "hello123abc".index("12", 4, -4),
            )
        self.checkScript(test_index, ())

        def test_rindex():
            # type: () -> Tuple[int, int, int, int, int, int]
            return (
                "hello123abc".rindex("llo"),
                "hello123abc".rindex("12"),
                "hello123abc".rindex("ab"),
                "hello123abc".rindex("12", 4),
                "hello123abc".rindex("ab", -7),
                "hello123abc".rindex("12", 4, -4),
            )
        self.checkScript(test_rindex, ())

        def test_replace():
            # type: () -> Tuple[str, str, str, str, str, str, str]
            return (
                "hello123abc".replace("llo", "sdf"),
                "ff".replace("f", "ff"),
                "abc123".replace("a", "testing"),
                "aaaaaa".replace("a", "testing", 3),
                "bbb".replace("a", "testing", 3),
                "ccc".replace("c", "ccc", 3),
                "cc".replace("c", "ccc", -3),
            )
        self.checkScript(test_replace, ())

        def test_partition():
            # type: () -> Tuple[Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str]]
            return (
                "hello123abc".partition("llo"),
                "ff".partition("f"),
                "abc123".partition("a"),
                "aaaaaa".partition("testing"),
                "bbb".partition("a"),
                "ccc".partition("ccc"),
                "cc".partition("ccc"),
            )
        self.checkScript(test_partition, ())

        def test_rpartition():
            # type: () -> Tuple[Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str]]
            return (
                "hello123abc".rpartition("llo"),
                "ff".rpartition("f"),
                "abc123".rpartition("a"),
                "aaaaaa".rpartition("testing"),
                "bbb".rpartition("a"),
                "ccc".rpartition("ccc"),
                "cc".rpartition("ccc"),
            )
        self.checkScript(test_rpartition, ())

        def test_split():
            # type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str]]
            return (
                "a a a a a".split(),
                " a a a a a ".split(" "),
                "a a a a a ".split(" ", 10),
                "a a a a a ".split(" ", -1),
                "a a a a a ".split(" ", 3),
                " a a a a a ".split("*"),
                " a*a a*a a".split("*"),
                " a*a a*a a ".split("*", -1),
                " a*a a*a a ".split("a*", 10),
            )
        self.checkScript(test_split, ())

        def test_rsplit():
            # type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str]]
            return (
                "a a a a a".rsplit(),
                " a a a a a ".rsplit(" "),
                "a a a a a ".rsplit(" ", 10),
                "a a a a a ".rsplit(" ", -1),
                "a a a a a ".rsplit(" ", 3),
                " a a a a a ".rsplit("*"),
                " a*a a*a a ".rsplit("*"),
                " a*a a*a a ".rsplit("*", -1),
                " a*a a*a a".rsplit("a*", 10),
            )
        self.checkScript(test_rsplit, ())

        def test_splitlines():
            # type: () -> Tuple[ List[str], List[str], List[str], List[str], List[str], List[str] ]
            return (
                "hello\ntest".splitlines(),
                "hello\n\ntest\n".splitlines(),
                "hello\ntest\n\n".splitlines(),
                "hello\vtest".splitlines(),
                "hello\v\f\ntest".splitlines(),
                "hello\ftest".splitlines(),
            )
        self.checkScript(test_splitlines, ())

        def test_str_cmp(a, b):
            # type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool]
            return a != b, a == b, a < b, a > b, a <= b, a >= b

        # Comparison operators are checked on each pair of adjacent samples.
        for i in range(len(inputs) - 1):
            self.checkScript(test_str_cmp, (inputs[i], inputs[i + 1]))
|
||||
|
||||
|
||||
@ -107,6 +107,7 @@ libtorch_sources = [
|
||||
"torch/csrc/jit/passes/utils/memory_dag.cpp",
|
||||
"torch/csrc/jit/print_handler.cpp",
|
||||
"torch/csrc/jit/register_prim_ops.cpp",
|
||||
"torch/csrc/jit/register_string_ops.cpp",
|
||||
"torch/csrc/jit/register_special_ops.cpp",
|
||||
"torch/csrc/jit/register_quantized_ops.cpp",
|
||||
"torch/csrc/jit/scope.cpp",
|
||||
|
||||
@ -1201,30 +1201,6 @@ RegisterOperators logging_operators(
|
||||
return 0; \
|
||||
})
|
||||
|
||||
// Stack-based implementation of aten::slice for strings.
//
// Pops, in this order, step, end, start and the string from `stack`, and
// pushes the resulting substring. Only step == 1 is accepted. Negative
// indices are resolved with normalizeIndex and the range is clamped to the
// string, so an out-of-range or inverted range yields "" instead of failing.
// Always returns 0.
int stringSlice(Stack& stack) {
  auto step = pop(stack).toInt();
  TORCH_CHECK(step == 1, "Slicing a string only supports step=1");

  auto end = pop(stack).toInt();
  auto start = pop(stack).toInt();
  auto string = pop(stack).toStringRef();
  const int64_t size = string.size();

  // Clamp start and end to the bounds of the string.
  start = std::max(int64_t(0), normalizeIndex(start, size));
  end = std::min(size, normalizeIndex(end, size));

  if (end <= start) {
    // Slice is empty.
    push(stack, std::string(""));
    return 0;
  }

  push(stack, string.substr(start, end - start));
  return 0;
}
|
||||
|
||||
// Equivalent to list.at(idx)
|
||||
template <typename T>
|
||||
T getItem(const c10::ListPtr<T>& list, int64_t idx) {
|
||||
@ -1972,72 +1948,6 @@ RegisterOperators reg2({
|
||||
"aten::ne(Tensor[] a, Tensor[] b) -> bool",
|
||||
listNe<at::Tensor>),
|
||||
Operator("aten::ne(bool[] a, bool[] b) -> bool", listNe<bool>),
|
||||
Operator(
|
||||
"aten::slice(str string, int start, int end=9223372036854775807, int step=1) -> str",
|
||||
stringSlice),
|
||||
|
||||
// python string is methods return false if empty
|
||||
#define DEFINE_STRING_IS_OP(op_name, char_op) \
|
||||
Operator(#op_name "(str self) -> bool", [](Stack& stack) { \
|
||||
auto string = pop(stack).toStringRef(); \
|
||||
push( \
|
||||
stack, \
|
||||
string.size() != 0 && \
|
||||
std::all_of(string.begin(), string.end(), [](char c) { \
|
||||
return char_op(c); \
|
||||
})); \
|
||||
return 0; \
|
||||
})
|
||||
|
||||
// upper and lower require there to be at least one alpha character,
|
||||
// and ignore all other characters
|
||||
Operator(
|
||||
"aten::isupper(str self) -> bool",
|
||||
[](Stack& stack) {
|
||||
auto string = pop(stack).toStringRef();
|
||||
bool found_alpha = false;
|
||||
bool is_upper = true;
|
||||
for (size_t i = 0; i < string.size() && is_upper; ++i) {
|
||||
char c = string[i];
|
||||
found_alpha |= std::isalpha(c);
|
||||
is_upper &= (!std::isalpha(c) || std::isupper(c));
|
||||
}
|
||||
push(stack, found_alpha && is_upper);
|
||||
return 0;
|
||||
}),
|
||||
Operator(
|
||||
"aten::islower(str self) -> bool",
|
||||
[](Stack& stack) {
|
||||
auto string = pop(stack).toStringRef();
|
||||
bool found_alpha = false;
|
||||
bool is_lower = true;
|
||||
for (size_t i = 0; i < string.size() && is_lower; ++i) {
|
||||
char c = string[i];
|
||||
found_alpha |= std::isalpha(c);
|
||||
is_lower &= (!std::isalpha(c) || std::islower(c));
|
||||
}
|
||||
push(stack, found_alpha && is_lower);
|
||||
return 0;
|
||||
}),
|
||||
|
||||
DEFINE_STRING_IS_OP(aten::isdigit, std::isdigit),
|
||||
DEFINE_STRING_IS_OP(aten::isspace, std::isspace),
|
||||
DEFINE_STRING_IS_OP(aten::isalnum, std::isalnum),
|
||||
DEFINE_STRING_IS_OP(aten::isalpha, std::isalpha),
|
||||
|
||||
#define DEFINE_STRING_CHAR_MAP_OP(op_name, char_op) \
|
||||
Operator(#op_name "(str self) -> str", [](Stack& stack) { \
|
||||
auto string = pop(stack).toStringRef(); \
|
||||
std::stringstream ss; \
|
||||
for (char c : string) { \
|
||||
ss << static_cast<char>(char_op(c)); \
|
||||
} \
|
||||
push(stack, ss.str()); \
|
||||
return 0; \
|
||||
})
|
||||
|
||||
DEFINE_STRING_CHAR_MAP_OP(aten::upper, std::toupper),
|
||||
DEFINE_STRING_CHAR_MAP_OP(aten::lower, std::tolower),
|
||||
|
||||
Operator(
|
||||
"prim::StringIndex(str string, int index) -> str",
|
||||
|
||||
599
torch/csrc/jit/register_string_ops.cpp
Normal file
599
torch/csrc/jit/register_string_ops.cpp
Normal file
@ -0,0 +1,599 @@
|
||||
#include <torch/csrc/jit/operator.h>
|
||||
#include <torch/csrc/jit/custom_operator.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace {
|
||||
|
||||
// Convert a Python-style index (which may be negative) into an offset usable
// with a C++ container of `list_size` elements.
//
// A negative `idx` counts back from the end (-1 becomes list_size - 1);
// a non-negative `idx` is returned unchanged. No bounds check is performed:
// the result can still be negative or >= list_size, so callers must clamp
// or validate it themselves.
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
  return idx < 0 ? list_size + idx : idx;
}
|
||||
|
||||
// Implements string[start:end:step] for aten::slice.
//
// Only step == 1 is supported (enforced with TORCH_CHECK). Negative `start`
// and `end` are resolved with normalizeIndex and the range is then clamped to
// [0, size], so an out-of-range or inverted range returns "" instead of
// failing.
std::string stringSlice(std::string string, int64_t start, int64_t end, int64_t step) {
  TORCH_CHECK(step == 1, "Slicing a string only supports step=1");

  const int64_t size = string.size();

  // Clamp start and end to the bounds of the string.
  start = std::max(int64_t(0), normalizeIndex(start, size));
  end = std::min(size, normalizeIndex(end, size));

  if (end <= start) {
    // Slice is empty.
    return std::string("");
  }

  return string.substr(start, end - start);
}
|
||||
|
||||
// Shared implementation of str.find / rfind / index / rindex.
//
// Searches for `substr` inside the window [start, end) of `string` and
// returns the index of the match relative to the start of `string`, or -1
// when there is none. With `reverse` set, the last match in the window is
// returned instead of the first.
//
// Index handling:
//  * a negative `start` counts from the end and is clamped to 0;
//  * a negative `end` is resolved as size + end + 1, so the schema default
//    of -1 means "up to the end of the string";
//    NOTE(review): this differs by one from Python for explicitly passed
//    negative `end` values - confirm whether that is intended.
//  * `end` is clamped to the string size;
//  * a `start` beyond the end of the string never matches, not even for an
//    empty `substr` (as in Python). Previously this case either threw
//    std::out_of_range from substr() or returned `start`.
int64_t stringFindImpl(
    const std::string& string,
    const std::string& substr,
    int64_t start,
    int64_t end,
    bool reverse = false) {
  const int64_t size = static_cast<int64_t>(string.size());
  if (start < 0) {
    start = std::max(int64_t(0), size + start);
  }
  if (end < 0) {
    end = std::max(int64_t(0), size + end + 1);
  }
  end = std::min(end, size);
  if (start > size) {
    return -1;
  }

  // The search is done on a copy of the window so that a match can never
  // extend past `end`.
  const std::string window =
      end > start ? string.substr(start, end - start) : std::string();

  const std::string::size_type pos =
      reverse ? window.rfind(substr) : window.find(substr);
  if (pos == std::string::npos) {
    return -1;
  }
  return static_cast<int64_t>(pos) + start;
}
|
||||
|
||||
|
||||
// Registers the string operators that are a pure per-character predicate
// (isdigit, isspace, ...) or a pure per-character mapping (upper, lower,
// swapcase). Each operator pops one string from the interpreter stack and
// pushes a bool or a string respectively.
//
// The <cctype> functions have undefined behaviour when given a negative
// value other than EOF, so every character is converted to unsigned char
// before it is classified or converted.
RegisterOperators reg_str_ops({

// python string is methods return false if empty
#define DEFINE_STRING_IS_OP(op_name, char_op)                          \
  Operator(#op_name "(str self) -> bool", [](Stack& stack) {           \
    auto string = pop(stack).toStringRef();                            \
    push(                                                              \
        stack,                                                         \
        string.size() != 0 &&                                          \
            std::all_of(string.begin(), string.end(), [](char c) {     \
              return char_op(static_cast<unsigned char>(c));           \
            }));                                                       \
    return 0;                                                          \
  })

    DEFINE_STRING_IS_OP(aten::isdigit, std::isdigit),
    DEFINE_STRING_IS_OP(aten::isspace, std::isspace),
    DEFINE_STRING_IS_OP(aten::isalnum, std::isalnum),
    DEFINE_STRING_IS_OP(aten::isalpha, std::isalpha),
    // Both of these are implemented with std::isdigit.
    DEFINE_STRING_IS_OP(aten::isdecimal, std::isdigit),
    DEFINE_STRING_IS_OP(aten::isnumeric, std::isdigit),

// Applies char_op to every character and pushes the resulting string.
#define DEFINE_STRING_CHAR_MAP_OP(op_name, char_op)                      \
  Operator(#op_name "(str self) -> str", [](Stack& stack) {              \
    auto string = pop(stack).toStringRef();                              \
    std::string mapped;                                                  \
    mapped.reserve(string.size());                                       \
    for (char c : string) {                                              \
      mapped.push_back(                                                  \
          static_cast<char>(char_op(static_cast<unsigned char>(c))));    \
    }                                                                    \
    push(stack, mapped);                                                 \
    return 0;                                                            \
  })

    DEFINE_STRING_CHAR_MAP_OP(aten::upper, std::toupper),
    DEFINE_STRING_CHAR_MAP_OP(aten::lower, std::tolower),
    // A character that is already its own upper-case form is lowered,
    // anything else is raised.
    DEFINE_STRING_CHAR_MAP_OP(aten::swapcase, ([](unsigned char c) {
      if (c == std::toupper(c)) {
        return static_cast<char>(std::tolower(c));
      } else {
        return static_cast<char>(std::toupper(c));
      }
    }))

});
|
||||
|
||||
auto reg_str_ops_2 = torch::jit::RegisterOperators()
|
||||
.op("aten::splitlines(str self, bool keepends=False) -> str[]",
|
||||
[](std::string string, bool keepends) {
|
||||
std::string delimiters =
|
||||
"\n\r\r\n\v\x0b\f\x0c\x1c\x1d\x1e\x85\u2028\u2029";
|
||||
std::vector<std::string> splits;
|
||||
|
||||
auto prev_pos = 0;
|
||||
auto pos = 0;
|
||||
while ((pos = string.find_first_of(delimiters, pos)) !=
|
||||
std::string::npos) {
|
||||
splits.emplace_back(string.substr(prev_pos, pos - prev_pos));
|
||||
if (keepends) {
|
||||
splits.emplace_back(string.substr(pos, 1));
|
||||
}
|
||||
pos++;
|
||||
prev_pos = pos;
|
||||
}
|
||||
if (prev_pos != string.size()) {
|
||||
splits.emplace_back(
|
||||
string.substr(prev_pos, string.size() - prev_pos));
|
||||
}
|
||||
|
||||
return splits;
|
||||
})
|
||||
.op("aten::slice(str string, int start, int end=9223372036854775807, int step=1) -> str",
|
||||
stringSlice)
|
||||
|
||||
// upper and lower require there to be at least one alpha character,
|
||||
// and ignore all other characters
|
||||
.op("aten::isupper(str self) -> bool",
|
||||
[](std::string string) {
|
||||
bool found_alpha = false;
|
||||
bool is_upper = true;
|
||||
for (size_t i = 0; i < string.size() && is_upper; ++i) {
|
||||
char c = string[i];
|
||||
found_alpha |= std::isalpha(c);
|
||||
is_upper &= (!std::isalpha(c) || std::isupper(c));
|
||||
}
|
||||
return found_alpha && is_upper;
|
||||
})
|
||||
.op("aten::islower(str self) -> bool",
|
||||
[](std::string string) {
|
||||
bool found_alpha = false;
|
||||
bool is_lower = true;
|
||||
for (size_t i = 0; i < string.size() && is_lower; ++i) {
|
||||
char c = string[i];
|
||||
found_alpha |= std::isalpha(c);
|
||||
is_lower &= (!std::isalpha(c) || std::islower(c));
|
||||
}
|
||||
return found_alpha && is_lower;
|
||||
})
|
||||
|
||||
.op("aten::capitalize(str self) -> str",
|
||||
[](std::string string) {
|
||||
std::stringstream ss;
|
||||
auto first_char = true;
|
||||
for (char c : string) {
|
||||
if (first_char) {
|
||||
ss << static_cast<char>(std::toupper(c));
|
||||
first_char = false;
|
||||
} else {
|
||||
ss << static_cast<char>(std::tolower(c));
|
||||
}
|
||||
}
|
||||
return ss.str();
|
||||
})
|
||||
|
||||
.op("aten::title(str self) -> str",
|
||||
[](std::string string) {
|
||||
std::stringstream ss;
|
||||
bool prev_is_nonalpha = true;
|
||||
for (char c : string) {
|
||||
if (prev_is_nonalpha) {
|
||||
ss << static_cast<char>(std::toupper(c));
|
||||
} else {
|
||||
ss << static_cast<char>(std::tolower(c));
|
||||
}
|
||||
if (std::isalpha(c)) {
|
||||
prev_is_nonalpha = false;
|
||||
} else {
|
||||
prev_is_nonalpha = true;
|
||||
}
|
||||
}
|
||||
return ss.str();
|
||||
})
|
||||
|
||||
.op("aten::center(str self, int width, str fillchar=' ') -> str",
|
||||
[](std::string string, int64_t width, std::string fillchar) {
|
||||
if (fillchar.size() != 1) {
|
||||
// TODO: this should be a TypeError
|
||||
throw std::runtime_error(
|
||||
"TypeError: The fill character must be exactly one character long");
|
||||
}
|
||||
if (string.size() > width) {
|
||||
return string;
|
||||
}
|
||||
std::stringstream ss;
|
||||
auto full_padding = width - string.size();
|
||||
auto l_pad = full_padding / 2;
|
||||
auto r_pad = (full_padding + 1) / 2;
|
||||
if (width % 2) {
|
||||
auto tmp = r_pad;
|
||||
r_pad = l_pad;
|
||||
l_pad = tmp;
|
||||
}
|
||||
for (auto i = 0; i < l_pad; ++i) {
|
||||
ss << fillchar;
|
||||
}
|
||||
ss << string;
|
||||
for (auto i = 0; i < r_pad; ++i) {
|
||||
ss << fillchar;
|
||||
}
|
||||
return ss.str();
|
||||
})
|
||||
|
||||
// Adapted from
|
||||
// https://stackoverflow.com/questions/22489073/counting-the-number-of-occurrences-of-a-string-within-a-string
|
||||
.op("aten::count(str self, str substr, int start=0, int end=-1) -> int",
|
||||
[](std::string string, std::string substr, int64_t start, int64_t end) {
|
||||
int64_t size = string.size();
|
||||
if (start > size) {
|
||||
return int64_t(0);
|
||||
}
|
||||
if (start < 0) {
|
||||
start = std::max(int64_t(0), int64_t(size + start));
|
||||
}
|
||||
if (end < 0) {
|
||||
end = std::max(int64_t(0), int64_t(size + end + 1));
|
||||
}
|
||||
|
||||
int64_t occurrences = 0;
|
||||
std::string::size_type pos = start;
|
||||
while ((pos = string.find(substr, pos)) != std::string::npos) {
|
||||
if (pos < end) {
|
||||
++occurrences;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
pos += substr.length();
|
||||
}
|
||||
return occurrences;
|
||||
})
|
||||
|
||||
.op("aten::endswith(str self, str substr, int start=0, int end=-1) -> bool",
|
||||
[](std::string string, std::string substr, int64_t start, int64_t end) {
|
||||
|
||||
int64_t size = string.size();
|
||||
if (start < 0) {
|
||||
start = std::max(int64_t(0), int64_t(size + start));
|
||||
}
|
||||
if (end < 0) {
|
||||
end = std::max(int64_t(0), int64_t(size + end + 1));
|
||||
}
|
||||
|
||||
string = string.substr(start, end - start);
|
||||
|
||||
auto result = false;
|
||||
if (string.length() >= substr.length()) {
|
||||
result = !string.compare(
|
||||
string.length() - substr.length(), substr.length(), substr);
|
||||
}
|
||||
return result;
|
||||
})
|
||||
|
||||
.op("aten::startswith(str self, str substr, int start=0, int end=-1) -> bool",
|
||||
[](std::string string, std::string substr, int64_t start, int64_t end) {
|
||||
|
||||
int64_t size = string.size();
|
||||
if (start < 0) {
|
||||
start = std::max(int64_t(0), int64_t(size + start));
|
||||
}
|
||||
if (end < 0) {
|
||||
end = std::max(int64_t(0), int64_t(size + end + 1));
|
||||
}
|
||||
|
||||
string = string.substr(start, end - start);
|
||||
|
||||
auto result = false;
|
||||
if (string.length() >= substr.length()) {
|
||||
result = !string.compare(0, substr.length(), substr);
|
||||
}
|
||||
return result;
|
||||
})
|
||||
|
||||
.op("aten::expandtabs(str self, int tabsize=8) -> str",
|
||||
[](std::string string, int64_t tabsize) {
|
||||
std::stringstream ss;
|
||||
size_t index = 0;
|
||||
for (const auto& c : string) {
|
||||
if (c != '\t') {
|
||||
ss << c;
|
||||
index++;
|
||||
} else {
|
||||
if (tabsize <= 0) {
|
||||
continue;
|
||||
}
|
||||
do {
|
||||
ss << ' ';
|
||||
index++;
|
||||
} while (index % tabsize);
|
||||
}
|
||||
}
|
||||
return ss.str();
|
||||
})
|
||||
|
||||
.op("aten::find(str self, str substr, int start=0, int end=-1) -> int",
|
||||
[](std::string string, std::string substr, int64_t start, int64_t end) {
|
||||
return stringFindImpl(string, substr, start, end);
|
||||
})
|
||||
|
||||
.op("aten::rfind(str self, str substr, int start=0, int end=-1) -> int",
|
||||
[](std::string string, std::string substr, int64_t start, int64_t end) {
|
||||
return stringFindImpl(string, substr, start, end, true);
|
||||
})
|
||||
|
||||
.op("aten::index(str self, str substr, int start=0, int end=-1) -> int",
|
||||
[](std::string string, std::string substr, int64_t start, int64_t end) {
|
||||
auto result = stringFindImpl(string, substr, start, end);
|
||||
if (result < 0) {
|
||||
throw std::runtime_error("ValueError: substring not found");
|
||||
}
|
||||
return result;
|
||||
})
|
||||
|
||||
.op("aten::rindex(str self, str substr, int start=0, int end=-1) -> int",
|
||||
[](std::string string, std::string substr, int64_t start, int64_t end) {
|
||||
auto result = stringFindImpl(string, substr, start, end, true);
|
||||
if (result < 0) {
|
||||
throw std::runtime_error("ValueError: substring not found");
|
||||
}
|
||||
return result;
|
||||
})
|
||||
|
||||
.op("aten::isidentifier(str self) -> bool",
|
||||
[](std::string string) {
|
||||
LOG(WARNING)
|
||||
<< "The isidentifier() implementation being used is from Python 2\n";
|
||||
if (string.size() < 1) {
|
||||
return false;
|
||||
}
|
||||
if (std::isdigit(string[0])) {
|
||||
return false;
|
||||
}
|
||||
auto result = std::all_of(string.begin(), string.end(), [](char c) {
|
||||
return std::isalnum(c);
|
||||
});
|
||||
return result;
|
||||
})
|
||||
|
||||
.op("aten::istitle(str self) -> bool",
|
||||
[](std::string string) {
|
||||
auto result = false;
|
||||
|
||||
bool prev_is_alpha = false;
|
||||
for (char c : string) {
|
||||
if (prev_is_alpha) {
|
||||
if (c != static_cast<char>(std::tolower(c))) {
|
||||
result = false;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
if (c != static_cast<char>(std::toupper(c))) {
|
||||
result = false;
|
||||
break;
|
||||
}
|
||||
// Only true if there exists at least one alpha
|
||||
if (std::isalpha(c)) {
|
||||
result = true;
|
||||
}
|
||||
}
|
||||
if (std::isalpha(c)) {
|
||||
prev_is_alpha = true;
|
||||
} else {
|
||||
prev_is_alpha = false;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
})
|
||||
|
||||
// Can't reuse DEFINE_STRING_IS_OP because "" is printable
|
||||
.op("aten::isprintable(str self) -> bool",
|
||||
[](std::string string) {
|
||||
auto result = std::all_of(string.begin(), string.end(), [](char c) {
|
||||
return std::isalnum(c) || std::ispunct(c) || c == ' ';
|
||||
});
|
||||
return result;
|
||||
})
|
||||
|
||||
.op("aten::ljust(str self, int width, str fillchar=' ') -> str",
|
||||
[](std::string string, int64_t width, std::string fillchar) {
|
||||
if (fillchar.size() != 1) {
|
||||
// TODO: this should be a TypeError
|
||||
throw std::runtime_error(
|
||||
"TypeError: The fill character must be exactly one character long");
|
||||
}
|
||||
auto to_append = std::max(int64_t(0), width - static_cast<int64_t>(string.size()));
|
||||
|
||||
std::stringstream ss;
|
||||
ss << string;
|
||||
for (auto i = 0; i < to_append; ++i) {
|
||||
ss << fillchar;
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
})
|
||||
|
||||
.op("aten::rjust(str self, int width, str fillchar=' ') -> str",
|
||||
[](std::string string, int64_t width, std::string fillchar) {
|
||||
if (fillchar.size() != 1) {
|
||||
// TODO: this should be a TypeError
|
||||
throw std::runtime_error(
|
||||
"TypeError: The fill character must be exactly one character long");
|
||||
}
|
||||
auto to_append = std::max(int64_t(0), width - static_cast<int64_t>(string.size()));
|
||||
|
||||
std::stringstream ss;
|
||||
for (auto i = 0; i < to_append; ++i) {
|
||||
ss << fillchar;
|
||||
}
|
||||
ss << string;
|
||||
return ss.str();
|
||||
})
|
||||
|
||||
.op("aten::zfill(str self, int width) -> str",
|
||||
[](std::string string, int64_t width) {
|
||||
auto to_append = std::max(int64_t(0), width - static_cast<int64_t>(string.size()));
|
||||
|
||||
std::stringstream ss;
|
||||
for (auto i = 0; i < to_append; ++i) {
|
||||
ss << '0';
|
||||
}
|
||||
ss << string;
|
||||
|
||||
return ss.str();
|
||||
})
|
||||
|
||||
.op("aten::lstrip(str self, str chars=' \\n\\t\\f\\v') -> str",
|
||||
[](std::string string, std::string chars) {
|
||||
auto index = string.find_first_not_of(chars);
|
||||
if (index != std::string::npos) {
|
||||
string = string.substr(index, string.size());
|
||||
} else {
|
||||
string = "";
|
||||
}
|
||||
return string;
|
||||
})
|
||||
|
||||
.op("aten::rstrip(str self, str chars=' \\n\\t\\f\\v') -> str",
|
||||
[](std::string string, std::string chars) {
|
||||
auto index = string.find_last_not_of(chars);
|
||||
if (index != std::string::npos) {
|
||||
string = string.substr(0, index + 1);
|
||||
} else {
|
||||
string = "";
|
||||
}
|
||||
return string;
|
||||
})
|
||||
|
||||
.op("aten::strip(str self, str chars=' \\n\\t\\f\\v') -> str",
|
||||
[](std::string string, std::string chars) {
|
||||
auto rindex = string.find_last_not_of(chars);
|
||||
if (rindex != std::string::npos) {
|
||||
string = string.substr(0, rindex + 1);
|
||||
} else {
|
||||
string = "";
|
||||
}
|
||||
auto lindex = string.find_first_not_of(chars);
|
||||
if (lindex != std::string::npos) {
|
||||
string = string.substr(lindex, string.size());
|
||||
} else {
|
||||
string = "";
|
||||
}
|
||||
return string;
|
||||
})
|
||||
|
||||
.op("aten::replace(str self, str old, str new, int max=-1) -> str",
|
||||
[](std::string string, std::string old_str, std::string new_str, int64_t max) {
|
||||
int64_t occurrences = 0;
|
||||
std::string::size_type pos = 0;
|
||||
while ((pos = string.find(old_str, pos)) != std::string::npos) {
|
||||
if (max >= 0 && ++occurrences > max) {
|
||||
break;
|
||||
}
|
||||
string = string.replace(pos, old_str.length(), new_str);
|
||||
pos += new_str.length();
|
||||
}
|
||||
|
||||
return string;
|
||||
})
|
||||
|
||||
.op("aten::partition(str self, str separator) -> (str, str, str)",
|
||||
[](std::string string, std::string separator) {
|
||||
|
||||
auto pos = string.find(separator, 0);
|
||||
if (pos == std::string::npos) {
|
||||
pos = string.size();
|
||||
separator = "";
|
||||
}
|
||||
auto pre_partition = string.substr(0, pos);
|
||||
auto post_partition =
|
||||
string.substr(pos + separator.size(), string.size());
|
||||
|
||||
return std::make_tuple(pre_partition,
|
||||
separator,
|
||||
post_partition);
|
||||
})
|
||||
|
||||
.op("aten::rpartition(str self, str separator) -> (str, str, str)",
|
||||
[](std::string string, std::string separator) {
|
||||
|
||||
auto pos = string.find(separator, 0);
|
||||
auto rpos = pos;
|
||||
do {
|
||||
pos = rpos;
|
||||
rpos = string.find(separator, pos + 1);
|
||||
} while (rpos != std::string::npos);
|
||||
|
||||
if (pos == std::string::npos) {
|
||||
pos = 0;
|
||||
separator = "";
|
||||
}
|
||||
|
||||
auto pre_partition = string.substr(0, pos);
|
||||
auto post_partition =
|
||||
string.substr(pos + separator.size(), string.size());
|
||||
|
||||
return std::make_tuple(pre_partition,
|
||||
separator,
|
||||
post_partition);
|
||||
})
|
||||
|
||||
.op("aten::split(str self, str separator=' ', int max=-1) -> str[]",
|
||||
[](std::string string, std::string separator, int64_t max) {
|
||||
std::string::size_type prev_pos = 0;
|
||||
std::string::size_type pos = 0;
|
||||
std::vector<std::string> splits;
|
||||
auto count = 0;
|
||||
while ((pos = string.find(separator, pos)) != std::string::npos) {
|
||||
count++;
|
||||
if (max >= 0 && count > max) {
|
||||
break;
|
||||
} else {
|
||||
splits.emplace_back(string.substr(prev_pos, pos - prev_pos));
|
||||
}
|
||||
pos += separator.size();
|
||||
prev_pos = pos;
|
||||
}
|
||||
splits.emplace_back(
|
||||
string.substr(prev_pos, string.size() - prev_pos));
|
||||
return splits;
|
||||
})
|
||||
|
||||
.op("aten::rsplit(str self, str separator=' ', int max=-1) -> str[]",
|
||||
[](std::string string, std::string separator, int64_t max) {
|
||||
std::reverse(separator.begin(), separator.end());
|
||||
std::reverse(string.begin(), string.end());
|
||||
|
||||
std::string::size_type prev_pos = 0;
|
||||
std::string::size_type pos = 0;
|
||||
std::vector<std::string> splits;
|
||||
auto count = 0;
|
||||
while ((pos = string.find(separator, pos)) != std::string::npos) {
|
||||
count++;
|
||||
if (max >= 0 && count > max) {
|
||||
break;
|
||||
} else {
|
||||
auto substr = string.substr(prev_pos, pos - prev_pos);
|
||||
std::reverse(substr.begin(), substr.end());
|
||||
splits.emplace(splits.begin(), substr);
|
||||
}
|
||||
pos += separator.size();
|
||||
prev_pos = pos;
|
||||
}
|
||||
auto substr = string.substr(prev_pos, string.size() - prev_pos);
|
||||
std::reverse(substr.begin(), substr.end());
|
||||
splits.emplace(splits.begin(), substr);
|
||||
return splits;
|
||||
});
|
||||
} // namespace
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
Reference in New Issue
Block a user