Fix inconsistent results of string split func on JIT mode (#38772)

Summary:
Resolve https://github.com/pytorch/pytorch/issues/38207

Below is the description of split function according to [Python doc](https://docs.python.org/3.8/library/stdtypes.html?highlight=split#str.split).

```
If sep is not specified or is None,  a different splitting algorithm is applied:
runs of consecutive whitespace are regarded as a single separator,
and the result will contain no empty strings at the start or end
if the string has leading or trailing whitespace.
```

The logic to handle both none and empty separators is added in register_string_ops.cpp as fix.

Signed-off-by: Xiong Wei <xiongw.fnst@cn.fujitsu.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38772

Differential Revision: D21789612

Pulled By: suo

fbshipit-source-id: 4dfd74eda71e0bfd757378daedc927a4a63ec0e4
This commit is contained in:
Xiong Wei
2020-06-17 12:41:16 -07:00
committed by Facebook GitHub Bot
parent 5e77999ecb
commit 55bcb5dccc
4 changed files with 61 additions and 5 deletions

View File

@ -62,7 +62,7 @@ TORCH_LIBRARY(aten, m) {
m.def("replace(str self, str old, str new, int max=-1) -> str");
m.def("partition(str self, str separator) -> (str, str, str)");
m.def("rpartition(str self, str separator) -> (str, str, str)");
m.def("split.str(str self, str separator=' ', int max=-1) -> str[]");
m.def("split.str(str self, str? separator=None, int max=-1) -> str[]");
m.def("rsplit(str self, str separator=' ', int max=-1) -> str[]");
m.def("join(str self, str[] values) -> str");

View File

@ -101,6 +101,7 @@ white_list = [
('aten::__and__', datetime.date(2020, 6, 30)),
('aten::__or__', datetime.date(2020, 6, 30)),
('aten::__xor__', datetime.date(2020, 6, 30)),
('aten::split', datetime.date(2020, 6, 30)),
]

View File

@ -263,9 +263,14 @@ class TestScript(JitTestCase):
self.checkScript(test_rpartition, ())
def test_split():
# type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str]]
"""
type: () -> Tuple[List[str], List[str], List[str], List[str], List[str],
List[str], List[str], List[str], List[str], List[str], List[str]]
"""
return (
"a a a a a".split(),
"a a a a a".split(),
" a a\ta \v a \v\f\n a \t ".split(),
" a a a a a ".split(" "),
"a a a a a ".split(" ", 10),
"a a a a a ".split(" ", -1),
@ -277,6 +282,14 @@ class TestScript(JitTestCase):
)
self.checkScript(test_split, ())
# test raising error for empty separator
def test_split_empty_separator():
s = "test"
return s.split("")
self.checkScriptRaisesRegex(test_split_empty_separator, (), Exception,
"empty separator")
def test_rsplit():
# type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str]]
return (

View File

@ -129,6 +129,33 @@ RegisterOperators reg_str_ops({
});
// consecutive whitespace are regarded as a single separator,
// the result will contain no empty strings at the start or end
// if the string has leading or trailing whitespace.
c10::List<std::string> splitNoneSeparator(const std::string& string) {
c10::List<std::string> splits;
// whitespaces includes tab, space and
// the delimiters defined in the implementation of splitlines
std::string whitespaces =
" \t\n\r\r\n\v\x0b\f\x0c\x1c\x1d\x1e\x85\u2028\u2029";
std::string::size_type prev_pos = 0;
std::string::size_type pos = 0;
while ((pos = string.find_first_of(whitespaces, pos)) != std::string::npos) {
auto substr = string.substr(prev_pos, pos - prev_pos);
// skip the whitespaces as the Python split() method
if (!substr.empty()) {
splits.emplace_back(substr);
}
pos++;
prev_pos = pos;
}
if (prev_pos != string.size()) {
splits.emplace_back(string.substr(prev_pos));
}
return splits;
}
// String Ops
// Implementations located in torch/csrc/jit/runtime/register_string_ops.cpp
TORCH_LIBRARY_IMPL(aten, CatchAll, m) {
@ -546,19 +573,34 @@ TORCH_LIBRARY_IMPL(aten, CatchAll, m) {
});
m.impl(
"split.str", [](std::string string, std::string separator, int64_t max) {
"split.str",
[](const std::string& string,
c10::optional<std::string> separator,
int64_t max) {
if (!separator.has_value()) {
// if separator is not specified,
// a different splitting algorithm is applied as Python
return splitNoneSeparator(string);
;
}
if (separator.value().empty()) {
throw std::runtime_error("ValueError: empty separator");
}
std::string::size_type prev_pos = 0;
std::string::size_type pos = 0;
c10::List<std::string> splits;
auto count = 0;
while ((pos = string.find(separator, pos)) != std::string::npos) {
while ((pos = string.find(separator.value(), pos)) !=
std::string::npos) {
count++;
if (max >= 0 && count > max) {
break;
} else {
splits.emplace_back(string.substr(prev_pos, pos - prev_pos));
}
pos += separator.size();
pos += separator.value().size();
prev_pos = pos;
}
splits.emplace_back(string.substr(prev_pos, string.size() - prev_pos));