mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Static Runtime] Add schema checks for aten::list (#83753)
Summary: The previous implementation assumed that there was only one overload and unconditionally tried to convert its input into a string. Some users were running into crashes because of this. Added a new overload for the list overload and schema checks. Also, I managed to uncover another bug when writing tests for this case (yikes). Returning inputs didn't work because the input cleanup process would destroy the output. Extended `CreateOwnedRefsForSpecialIValues` to fix that. Test Plan: CI + new unit tests Differential Revision: D38870803 Pull Request resolved: https://github.com/pytorch/pytorch/pull/83753 Approved by: https://github.com/tenpercent, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
d46dba18f7
commit
09157c76c0
@ -2595,23 +2595,28 @@ TEST(StaticRuntime, JIT_Aten_Numel) {
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, JIT_Aten_List) {
|
||||
const std::string script = R"IR(
|
||||
const auto script_str = R"IR(
|
||||
graph(%a: str):
|
||||
%1 : int = prim::Constant[value=0]()
|
||||
%ret: str[] = aten::list(%a)
|
||||
return (%ret)
|
||||
)IR";
|
||||
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
vmap.reserve(0);
|
||||
parseIR(script, graph.get(), vmap);
|
||||
torch::jit::StaticModule smodule(graph);
|
||||
|
||||
string a = "abcd";
|
||||
std::string a = "abcd";
|
||||
std::vector<IValue> args0{a};
|
||||
testStaticRuntime(script_str, args0);
|
||||
|
||||
testStaticRuntime(script, args0);
|
||||
// Update the result of aten::list to ensure that a deep copy
|
||||
// took place
|
||||
const auto script_list = R"IR(
|
||||
graph(%a : int[]):
|
||||
%idx : int = prim::Constant[value=0]()
|
||||
%value : int = prim::Constant[value=42]()
|
||||
%res : int[] = aten::list(%a)
|
||||
%updated : int[] = aten::_set_item(%res, %idx, %value)
|
||||
return (%res, %a)
|
||||
)IR";
|
||||
|
||||
std::vector<IValue> args1{c10::List<int64_t>{1, 2, 3}};
|
||||
testStaticRuntime(script_list, args1);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, JIT_Aten_Range_Length) {
|
||||
|
@ -204,15 +204,27 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
aten::list,
|
||||
aten_list,
|
||||
[](Node* n) -> SROperator {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto str = p_node->Input(0).toStringRef();
|
||||
c10::List<std::string> chars;
|
||||
chars.reserve(str.size());
|
||||
for (auto c : str) {
|
||||
chars.emplace_back(1, c);
|
||||
}
|
||||
p_node->Output(0) = std::move(chars);
|
||||
};
|
||||
if (n->matches(torch::schema("aten::list(str t) -> str[]"))) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto str = p_node->Input(0).toStringRef();
|
||||
c10::List<std::string> chars;
|
||||
chars.reserve(str.size());
|
||||
for (auto c : str) {
|
||||
chars.emplace_back(1, c);
|
||||
}
|
||||
p_node->Output(0) = std::move(chars);
|
||||
};
|
||||
}
|
||||
|
||||
if (n->matches(torch::schema("aten::list.t(t[] l) -> t[]"))) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto input = p_node->Input(0).toList();
|
||||
p_node->Output(0) = input.copy();
|
||||
};
|
||||
}
|
||||
|
||||
LogAndDumpSchema(n);
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
|
@ -1043,6 +1043,10 @@ void CreateOwnedRefsForSpecialValuesHelper(Graph& graph, Block* block) {
|
||||
}
|
||||
|
||||
auto outputs = block->outputs();
|
||||
// Create owned refs for inputs. Otherwise, the input cleanup process
|
||||
// will destroy our outputs before we return.
|
||||
FastSet<Value*> inputs = {block->inputs().begin(), block->inputs().end()};
|
||||
|
||||
for (const auto i : c10::irange(outputs.size())) {
|
||||
auto* output = outputs[i];
|
||||
|
||||
@ -1052,7 +1056,7 @@ void CreateOwnedRefsForSpecialValuesHelper(Graph& graph, Block* block) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (toIValue(output).has_value() ||
|
||||
if ((inputs.find(output) != inputs.end()) || toIValue(output).has_value() ||
|
||||
// If the output's owning block is not this one, it's from an outer
|
||||
// scope
|
||||
output->node()->owningBlock() != block) {
|
||||
|
Reference in New Issue
Block a user