mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix inconsistent results of string split
func on JIT mode (#38772)
Summary: Resolve https://github.com/pytorch/pytorch/issues/38207 Below is the description of the split function according to the [Python doc](https://docs.python.org/3.8/library/stdtypes.html?highlight=split#str.split). ``` If sep is not specified or is None, a different splitting algorithm is applied: runs of consecutive whitespace are regarded as a single separator, and the result will contain no empty strings at the start or end if the string has leading or trailing whitespace. ``` As a fix, logic to handle both None and empty separators is added in register_string_ops.cpp. Signed-off-by: Xiong Wei <xiongw.fnst@cn.fujitsu.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/38772 Differential Revision: D21789612 Pulled By: suo fbshipit-source-id: 4dfd74eda71e0bfd757378daedc927a4a63ec0e4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5e77999ecb
commit
55bcb5dccc
@ -62,7 +62,7 @@ TORCH_LIBRARY(aten, m) {
|
||||
m.def("replace(str self, str old, str new, int max=-1) -> str");
|
||||
m.def("partition(str self, str separator) -> (str, str, str)");
|
||||
m.def("rpartition(str self, str separator) -> (str, str, str)");
|
||||
m.def("split.str(str self, str separator=' ', int max=-1) -> str[]");
|
||||
m.def("split.str(str self, str? separator=None, int max=-1) -> str[]");
|
||||
m.def("rsplit(str self, str separator=' ', int max=-1) -> str[]");
|
||||
m.def("join(str self, str[] values) -> str");
|
||||
|
||||
|
@ -101,6 +101,7 @@ white_list = [
|
||||
('aten::__and__', datetime.date(2020, 6, 30)),
|
||||
('aten::__or__', datetime.date(2020, 6, 30)),
|
||||
('aten::__xor__', datetime.date(2020, 6, 30)),
|
||||
('aten::split', datetime.date(2020, 6, 30)),
|
||||
]
|
||||
|
||||
|
||||
|
@ -263,9 +263,14 @@ class TestScript(JitTestCase):
|
||||
self.checkScript(test_rpartition, ())
|
||||
|
||||
def test_split():
|
||||
# type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str]]
|
||||
"""
|
||||
type: () -> Tuple[List[str], List[str], List[str], List[str], List[str],
|
||||
List[str], List[str], List[str], List[str], List[str], List[str]]
|
||||
"""
|
||||
return (
|
||||
"a a a a a".split(),
|
||||
"a a a a a".split(),
|
||||
" a a\ta \v a \v\f\n a \t ".split(),
|
||||
" a a a a a ".split(" "),
|
||||
"a a a a a ".split(" ", 10),
|
||||
"a a a a a ".split(" ", -1),
|
||||
@ -277,6 +282,14 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_split, ())
|
||||
|
||||
# test raising error for empty separator
|
||||
def test_split_empty_separator():
|
||||
s = "test"
|
||||
return s.split("")
|
||||
|
||||
self.checkScriptRaisesRegex(test_split_empty_separator, (), Exception,
|
||||
"empty separator")
|
||||
|
||||
def test_rsplit():
|
||||
# type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str]]
|
||||
return (
|
||||
|
@ -129,6 +129,33 @@ RegisterOperators reg_str_ops({
|
||||
|
||||
});
|
||||
|
||||
// consecutive whitespace are regarded as a single separator,
|
||||
// the result will contain no empty strings at the start or end
|
||||
// if the string has leading or trailing whitespace.
|
||||
c10::List<std::string> splitNoneSeparator(const std::string& string) {
|
||||
c10::List<std::string> splits;
|
||||
// whitespaces includes tab, space and
|
||||
// the delimiters defined in the implementation of splitlines
|
||||
std::string whitespaces =
|
||||
" \t\n\r\r\n\v\x0b\f\x0c\x1c\x1d\x1e\x85\u2028\u2029";
|
||||
std::string::size_type prev_pos = 0;
|
||||
std::string::size_type pos = 0;
|
||||
|
||||
while ((pos = string.find_first_of(whitespaces, pos)) != std::string::npos) {
|
||||
auto substr = string.substr(prev_pos, pos - prev_pos);
|
||||
// skip the whitespaces as the Python split() method
|
||||
if (!substr.empty()) {
|
||||
splits.emplace_back(substr);
|
||||
}
|
||||
pos++;
|
||||
prev_pos = pos;
|
||||
}
|
||||
if (prev_pos != string.size()) {
|
||||
splits.emplace_back(string.substr(prev_pos));
|
||||
}
|
||||
return splits;
|
||||
}
|
||||
|
||||
// String Ops
|
||||
// Implementations located in torch/csrc/jit/runtime/register_string_ops.cpp
|
||||
TORCH_LIBRARY_IMPL(aten, CatchAll, m) {
|
||||
@ -546,19 +573,34 @@ TORCH_LIBRARY_IMPL(aten, CatchAll, m) {
|
||||
});
|
||||
|
||||
m.impl(
|
||||
"split.str", [](std::string string, std::string separator, int64_t max) {
|
||||
"split.str",
|
||||
[](const std::string& string,
|
||||
c10::optional<std::string> separator,
|
||||
int64_t max) {
|
||||
if (!separator.has_value()) {
|
||||
// if separator is not specified,
|
||||
// a different splitting algorithm is applied as Python
|
||||
return splitNoneSeparator(string);
|
||||
;
|
||||
}
|
||||
if (separator.value().empty()) {
|
||||
throw std::runtime_error("ValueError: empty separator");
|
||||
}
|
||||
|
||||
std::string::size_type prev_pos = 0;
|
||||
std::string::size_type pos = 0;
|
||||
c10::List<std::string> splits;
|
||||
auto count = 0;
|
||||
while ((pos = string.find(separator, pos)) != std::string::npos) {
|
||||
|
||||
while ((pos = string.find(separator.value(), pos)) !=
|
||||
std::string::npos) {
|
||||
count++;
|
||||
if (max >= 0 && count > max) {
|
||||
break;
|
||||
} else {
|
||||
splits.emplace_back(string.substr(prev_pos, pos - prev_pos));
|
||||
}
|
||||
pos += separator.size();
|
||||
pos += separator.value().size();
|
||||
prev_pos = pos;
|
||||
}
|
||||
splits.emplace_back(string.substr(prev_pos, string.size() - prev_pos));
|
||||
|
Reference in New Issue
Block a user