[Static Runtime] Add schema checks for aten::list (#83753)

Summary:
The previous implementation assumed that there was only one overload and unconditionally tried to convert its input into a string. Some users were running into crashes because of this. Added a new overload for the list overload and schema checks.

Also, I managed to uncover another bug when writing tests for this case (yikes). Returning inputs didn't work because the input cleanup process would destroy the output. Extended `CreateOwnedRefsForSpecialIValues` to fix that.

Test Plan: CI + new unit tests

Differential Revision: D38870803

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83753
Approved by: https://github.com/tenpercent, https://github.com/albanD
This commit is contained in:
Mike Iovine
2022-08-22 13:42:47 +00:00
committed by PyTorch MergeBot
parent d46dba18f7
commit 09157c76c0
3 changed files with 42 additions and 21 deletions

View File

@ -2595,23 +2595,28 @@ TEST(StaticRuntime, JIT_Aten_Numel) {
}
TEST(StaticRuntime, JIT_Aten_List) {
const std::string script = R"IR(
const auto script_str = R"IR(
graph(%a: str):
%1 : int = prim::Constant[value=0]()
%ret: str[] = aten::list(%a)
return (%ret)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
vmap.reserve(0);
parseIR(script, graph.get(), vmap);
torch::jit::StaticModule smodule(graph);
string a = "abcd";
std::string a = "abcd";
std::vector<IValue> args0{a};
testStaticRuntime(script_str, args0);
testStaticRuntime(script, args0);
// Update the result of aten::list to ensure that a deep copy
// took place
const auto script_list = R"IR(
graph(%a : int[]):
%idx : int = prim::Constant[value=0]()
%value : int = prim::Constant[value=42]()
%res : int[] = aten::list(%a)
%updated : int[] = aten::_set_item(%res, %idx, %value)
return (%res, %a)
)IR";
std::vector<IValue> args1{c10::List<int64_t>{1, 2, 3}};
testStaticRuntime(script_list, args1);
}
TEST(StaticRuntime, JIT_Aten_Range_Length) {

View File

@ -204,15 +204,27 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::list,
aten_list,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto str = p_node->Input(0).toStringRef();
c10::List<std::string> chars;
chars.reserve(str.size());
for (auto c : str) {
chars.emplace_back(1, c);
}
p_node->Output(0) = std::move(chars);
};
if (n->matches(torch::schema("aten::list(str t) -> str[]"))) {
return [](ProcessedNode* p_node) {
const auto str = p_node->Input(0).toStringRef();
c10::List<std::string> chars;
chars.reserve(str.size());
for (auto c : str) {
chars.emplace_back(1, c);
}
p_node->Output(0) = std::move(chars);
};
}
if (n->matches(torch::schema("aten::list.t(t[] l) -> t[]"))) {
return [](ProcessedNode* p_node) {
const auto input = p_node->Input(0).toList();
p_node->Output(0) = input.copy();
};
}
LogAndDumpSchema(n);
return nullptr;
});
REGISTER_NATIVE_OPERATOR_FUNCTOR(

View File

@ -1043,6 +1043,10 @@ void CreateOwnedRefsForSpecialValuesHelper(Graph& graph, Block* block) {
}
auto outputs = block->outputs();
// Create owned refs for inputs. Otherwise, the input cleanup process
// will destroy our outputs before we return.
FastSet<Value*> inputs = {block->inputs().begin(), block->inputs().end()};
for (const auto i : c10::irange(outputs.size())) {
auto* output = outputs[i];
@ -1052,7 +1056,7 @@ void CreateOwnedRefsForSpecialValuesHelper(Graph& graph, Block* block) {
continue;
}
if (toIValue(output).has_value() ||
if ((inputs.find(output) != inputs.end()) || toIValue(output).has_value() ||
// If the output's owning block is not this one, it's from an outer
// scope
output->node()->owningBlock() != block) {