[PyTorch][TorchScript] Add support for join on List of strings in TorchScript

Summary: Add support for join on List of strings in TorchScript.

Test Plan:
(pytorch) smummadi@smummadi-mbp pytorch % python test/test_jit_string.py
Fail to import hypothesis in common_utils, tests are not derandomized
.
----------------------------------------------------------------------
Ran 1 test in 1.090s

OK

Differential Revision: D19611800

fbshipit-source-id: cef66356abc14dfd100a806d25dd1a8bc9af0a11
This commit is contained in:
Sampath Mummadi
2020-01-29 18:21:07 -08:00
committed by Facebook Github Bot
parent cccf5e7011
commit 8ead65a946
3 changed files with 32 additions and 2 deletions

View File

@ -65,6 +65,7 @@ TESTS = [
'test_jit_disabled',
'test_function_schema',
'test_overrides',
'test_jit_string',
]
# skip < 3.3 because mock is added in 3.3 and is used in rpc_spawn

View File

@ -1,5 +1,5 @@
from test_jit import JitTestCase
from torch.testing._internal.common_utils import run_tests
class TestScript(JitTestCase):
def test_str_ops(self):
@ -310,3 +310,17 @@ class TestScript(JitTestCase):
for i in range(len(inputs) - 1):
self.checkScript(test_str_cmp, (inputs[i], inputs[i + 1]))
def test_str_join():
return (
",".join(["a"]),
",".join(["a", "b", "c"]),
",".join(["aa", "bb", "cc"]),
",".join(["a,a", "bb", "c,c"]),
"**a**".join(["b", "c", "d", "e"]),
"".join(["a", "b", "c"]),
)
self.checkScript(test_str_join, ())
if __name__ == '__main__':
run_tests()

View File

@ -705,7 +705,22 @@ auto reg_str_ops_2 =
std::reverse(substr.begin(), substr.end());
splits.emplace(splits.begin(), substr);
return splits;
}));
}))
.op("aten::join(str self, str[] values) -> str",
torch::RegisterOperators::options()
.aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)
.catchAllKernel([](const std::string& string,
const c10::List<std::string>& values) {
std::stringstream ss;
for (auto it = values.begin(); it != values.end(); ++it) {
ss << static_cast<std::string>(*it);
if (it != values.end() - 1) {
ss << string;
}
}
return ss.str();
}));
} // namespace
} // namespace jit
} // namespace torch