mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PyTorch][TorchScript] Add support for join on List of strings in TorchScript
Summary: Add support for join on List of strings in TorchScript. Test Plan: (pytorch) smummadi@smummadi-mbp pytorch % python test/test_jit_string.py Fail to import hypothesis in common_utils, tests are not derandomized . ---------------------------------------------------------------------- Ran 1 test in 1.090s OK Differential Revision: D19611800 fbshipit-source-id: cef66356abc14dfd100a806d25dd1a8bc9af0a11
This commit is contained in:
committed by
Facebook Github Bot
parent
cccf5e7011
commit
8ead65a946
@ -65,6 +65,7 @@ TESTS = [
|
||||
'test_jit_disabled',
|
||||
'test_function_schema',
|
||||
'test_overrides',
|
||||
'test_jit_string',
|
||||
]
|
||||
|
||||
# skip < 3.3 because mock is added in 3.3 and is used in rpc_spawn
|
||||
|
@ -1,5 +1,5 @@
|
||||
from test_jit import JitTestCase
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
class TestScript(JitTestCase):
|
||||
def test_str_ops(self):
|
||||
@ -310,3 +310,17 @@ class TestScript(JitTestCase):
|
||||
|
||||
for i in range(len(inputs) - 1):
|
||||
self.checkScript(test_str_cmp, (inputs[i], inputs[i + 1]))
|
||||
|
||||
def test_str_join():
|
||||
return (
|
||||
",".join(["a"]),
|
||||
",".join(["a", "b", "c"]),
|
||||
",".join(["aa", "bb", "cc"]),
|
||||
",".join(["a,a", "bb", "c,c"]),
|
||||
"**a**".join(["b", "c", "d", "e"]),
|
||||
"".join(["a", "b", "c"]),
|
||||
)
|
||||
self.checkScript(test_str_join, ())
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -705,7 +705,22 @@ auto reg_str_ops_2 =
|
||||
std::reverse(substr.begin(), substr.end());
|
||||
splits.emplace(splits.begin(), substr);
|
||||
return splits;
|
||||
}));
|
||||
}))
|
||||
|
||||
.op("aten::join(str self, str[] values) -> str",
|
||||
torch::RegisterOperators::options()
|
||||
.aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)
|
||||
.catchAllKernel([](const std::string& string,
|
||||
const c10::List<std::string>& values) {
|
||||
std::stringstream ss;
|
||||
for (auto it = values.begin(); it != values.end(); ++it) {
|
||||
ss << static_cast<std::string>(*it);
|
||||
if (it != values.end() - 1) {
|
||||
ss << string;
|
||||
}
|
||||
}
|
||||
return ss.str();
|
||||
}));
|
||||
} // namespace
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user