diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 828796e08f0a..ee37ddeaf71a 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -2595,23 +2595,28 @@ TEST(StaticRuntime, JIT_Aten_Numel) { } TEST(StaticRuntime, JIT_Aten_List) { - const std::string script = R"IR( + const auto script_str = R"IR( graph(%a: str): - %1 : int = prim::Constant[value=0]() %ret: str[] = aten::list(%a) return (%ret) )IR"; - - auto graph = std::make_shared(); - std::unordered_map vmap; - vmap.reserve(0); - parseIR(script, graph.get(), vmap); - torch::jit::StaticModule smodule(graph); - - string a = "abcd"; + std::string a = "abcd"; std::vector args0{a}; + testStaticRuntime(script_str, args0); - testStaticRuntime(script, args0); + // Update the result of aten::list to ensure that a deep copy + // took place + const auto script_list = R"IR( + graph(%a : int[]): + %idx : int = prim::Constant[value=0]() + %value : int = prim::Constant[value=42]() + %res : int[] = aten::list(%a) + %updated : int[] = aten::_set_item(%res, %idx, %value) + return (%res, %a) + )IR"; + + std::vector args1{c10::List{1, 2, 3}}; + testStaticRuntime(script_list, args1); } TEST(StaticRuntime, JIT_Aten_Range_Length) { diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp index 16e357d8f459..0695c3f70a39 100644 --- a/torch/csrc/jit/runtime/static/native_ops.cpp +++ b/torch/csrc/jit/runtime/static/native_ops.cpp @@ -204,15 +204,27 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( aten::list, aten_list, [](Node* n) -> SROperator { - return [](ProcessedNode* p_node) { - const auto str = p_node->Input(0).toStringRef(); - c10::List chars; - chars.reserve(str.size()); - for (auto c : str) { - chars.emplace_back(1, c); - } - p_node->Output(0) = std::move(chars); - }; + if (n->matches(torch::schema("aten::list(str t) -> str[]"))) { + return [](ProcessedNode* p_node) { + const auto str = p_node->Input(0).toStringRef(); + c10::List chars; + chars.reserve(str.size()); + for (auto c : str) { + chars.emplace_back(1, c); + } + p_node->Output(0) = std::move(chars); + }; + } + + if (n->matches(torch::schema("aten::list.t(t[] l) -> t[]"))) { + return [](ProcessedNode* p_node) { + const auto input = p_node->Input(0).toList(); + p_node->Output(0) = input.copy(); + }; + } + + LogAndDumpSchema(n); + return nullptr; }); REGISTER_NATIVE_OPERATOR_FUNCTOR( diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index 7125bb696878..5e5290275f17 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -1043,6 +1043,10 @@ void CreateOwnedRefsForSpecialValuesHelper(Graph& graph, Block* block) { } auto outputs = block->outputs(); + // Create owned refs for inputs. Otherwise, the input cleanup process + // will destroy our outputs before we return. + FastSet inputs = {block->inputs().begin(), block->inputs().end()}; + for (const auto i : c10::irange(outputs.size())) { auto* output = outputs[i]; @@ -1052,7 +1056,7 @@ void CreateOwnedRefsForSpecialValuesHelper(Graph& graph, Block* block) { continue; } - if (toIValue(output).has_value() || + if ((inputs.find(output) != inputs.end()) || toIValue(output).has_value() || // If the output's owning block is not this one, it's from an outer // scope output->node()->owningBlock() != block) {