mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
[Static Runtime] Add schema checks for aten::list (#83753)
Summary: The previous implementation assumed that there was only one overload and unconditionally tried to convert its input into a string. Some users were running into crashes because of this. Added a new overload for the list overload and schema checks. Also, I managed to uncover another bug when writing tests for this case (yikes). Returning inputs didn't work because the input cleanup process would destroy the output. Extended `CreateOwnedRefsForSpecialIValues` to fix that. Test Plan: CI + new unit tests Differential Revision: D38870803 Pull Request resolved: https://github.com/pytorch/pytorch/pull/83753 Approved by: https://github.com/tenpercent, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
d46dba18f7
commit
09157c76c0
@ -2595,23 +2595,28 @@ TEST(StaticRuntime, JIT_Aten_Numel) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(StaticRuntime, JIT_Aten_List) {
|
TEST(StaticRuntime, JIT_Aten_List) {
|
||||||
const std::string script = R"IR(
|
const auto script_str = R"IR(
|
||||||
graph(%a: str):
|
graph(%a: str):
|
||||||
%1 : int = prim::Constant[value=0]()
|
|
||||||
%ret: str[] = aten::list(%a)
|
%ret: str[] = aten::list(%a)
|
||||||
return (%ret)
|
return (%ret)
|
||||||
)IR";
|
)IR";
|
||||||
|
std::string a = "abcd";
|
||||||
auto graph = std::make_shared<Graph>();
|
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
|
||||||
vmap.reserve(0);
|
|
||||||
parseIR(script, graph.get(), vmap);
|
|
||||||
torch::jit::StaticModule smodule(graph);
|
|
||||||
|
|
||||||
string a = "abcd";
|
|
||||||
std::vector<IValue> args0{a};
|
std::vector<IValue> args0{a};
|
||||||
|
testStaticRuntime(script_str, args0);
|
||||||
|
|
||||||
testStaticRuntime(script, args0);
|
// Update the result of aten::list to ensure that a deep copy
|
||||||
|
// took place
|
||||||
|
const auto script_list = R"IR(
|
||||||
|
graph(%a : int[]):
|
||||||
|
%idx : int = prim::Constant[value=0]()
|
||||||
|
%value : int = prim::Constant[value=42]()
|
||||||
|
%res : int[] = aten::list(%a)
|
||||||
|
%updated : int[] = aten::_set_item(%res, %idx, %value)
|
||||||
|
return (%res, %a)
|
||||||
|
)IR";
|
||||||
|
|
||||||
|
std::vector<IValue> args1{c10::List<int64_t>{1, 2, 3}};
|
||||||
|
testStaticRuntime(script_list, args1);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(StaticRuntime, JIT_Aten_Range_Length) {
|
TEST(StaticRuntime, JIT_Aten_Range_Length) {
|
||||||
|
|||||||
@ -204,15 +204,27 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
|||||||
aten::list,
|
aten::list,
|
||||||
aten_list,
|
aten_list,
|
||||||
[](Node* n) -> SROperator {
|
[](Node* n) -> SROperator {
|
||||||
return [](ProcessedNode* p_node) {
|
if (n->matches(torch::schema("aten::list(str t) -> str[]"))) {
|
||||||
const auto str = p_node->Input(0).toStringRef();
|
return [](ProcessedNode* p_node) {
|
||||||
c10::List<std::string> chars;
|
const auto str = p_node->Input(0).toStringRef();
|
||||||
chars.reserve(str.size());
|
c10::List<std::string> chars;
|
||||||
for (auto c : str) {
|
chars.reserve(str.size());
|
||||||
chars.emplace_back(1, c);
|
for (auto c : str) {
|
||||||
}
|
chars.emplace_back(1, c);
|
||||||
p_node->Output(0) = std::move(chars);
|
}
|
||||||
};
|
p_node->Output(0) = std::move(chars);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n->matches(torch::schema("aten::list.t(t[] l) -> t[]"))) {
|
||||||
|
return [](ProcessedNode* p_node) {
|
||||||
|
const auto input = p_node->Input(0).toList();
|
||||||
|
p_node->Output(0) = input.copy();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
LogAndDumpSchema(n);
|
||||||
|
return nullptr;
|
||||||
});
|
});
|
||||||
|
|
||||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||||
|
|||||||
@ -1043,6 +1043,10 @@ void CreateOwnedRefsForSpecialValuesHelper(Graph& graph, Block* block) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto outputs = block->outputs();
|
auto outputs = block->outputs();
|
||||||
|
// Create owned refs for inputs. Otherwise, the input cleanup process
|
||||||
|
// will destroy our outputs before we return.
|
||||||
|
FastSet<Value*> inputs = {block->inputs().begin(), block->inputs().end()};
|
||||||
|
|
||||||
for (const auto i : c10::irange(outputs.size())) {
|
for (const auto i : c10::irange(outputs.size())) {
|
||||||
auto* output = outputs[i];
|
auto* output = outputs[i];
|
||||||
|
|
||||||
@ -1052,7 +1056,7 @@ void CreateOwnedRefsForSpecialValuesHelper(Graph& graph, Block* block) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (toIValue(output).has_value() ||
|
if ((inputs.find(output) != inputs.end()) || toIValue(output).has_value() ||
|
||||||
// If the output's owning block is not this one, it's from an outer
|
// If the output's owning block is not this one, it's from an outer
|
||||||
// scope
|
// scope
|
||||||
output->node()->owningBlock() != block) {
|
output->node()->owningBlock() != block) {
|
||||||
|
|||||||
Reference in New Issue
Block a user