[Static Runtime] Add schema checks for aten::list (#83753)

Summary:
The previous implementation assumed that there was only one overload and unconditionally tried to convert its input into a string. Some users were running into crashes because of this. Added a new overload for the list overload and schema checks.

Also, I managed to uncover another bug when writing tests for this case (yikes). Returning inputs didn't work because the input cleanup process would destroy the output. Extended `CreateOwnedRefsForSpecialIValues` to fix that.

Test Plan: CI + new unit tests

Differential Revision: D38870803

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83753
Approved by: https://github.com/tenpercent, https://github.com/albanD
This commit is contained in:
Mike Iovine
2022-08-22 13:42:47 +00:00
committed by PyTorch MergeBot
parent d46dba18f7
commit 09157c76c0
3 changed files with 42 additions and 21 deletions

View File

@ -2595,23 +2595,28 @@ TEST(StaticRuntime, JIT_Aten_Numel) {
}
TEST(StaticRuntime, JIT_Aten_List) {
const std::string script = R"IR(
const auto script_str = R"IR(
graph(%a: str):
%1 : int = prim::Constant[value=0]()
%ret: str[] = aten::list(%a)
return (%ret)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
vmap.reserve(0);
parseIR(script, graph.get(), vmap);
torch::jit::StaticModule smodule(graph);
string a = "abcd";
std::string a = "abcd";
std::vector<IValue> args0{a};
testStaticRuntime(script_str, args0);
testStaticRuntime(script, args0);
// Update the result of aten::list to ensure that a deep copy
// took place
const auto script_list = R"IR(
graph(%a : int[]):
%idx : int = prim::Constant[value=0]()
%value : int = prim::Constant[value=42]()
%res : int[] = aten::list(%a)
%updated : int[] = aten::_set_item(%res, %idx, %value)
return (%res, %a)
)IR";
std::vector<IValue> args1{c10::List<int64_t>{1, 2, 3}};
testStaticRuntime(script_list, args1);
}
TEST(StaticRuntime, JIT_Aten_Range_Length) {