mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Static Runtime] Add schema checks for aten::list (#83753)
Summary: The previous implementation assumed that there was only one overload and unconditionally tried to convert its input into a string. Some users were running into crashes because of this. Added a new overload for the list overload and schema checks. Also, I managed to uncover another bug when writing tests for this case (yikes). Returning inputs didn't work because the input cleanup process would destroy the output. Extended `CreateOwnedRefsForSpecialIValues` to fix that. Test Plan: CI + new unit tests Differential Revision: D38870803 Pull Request resolved: https://github.com/pytorch/pytorch/pull/83753 Approved by: https://github.com/tenpercent, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
d46dba18f7
commit
09157c76c0
@ -2595,23 +2595,28 @@ TEST(StaticRuntime, JIT_Aten_Numel) {
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, JIT_Aten_List) {
|
||||
const std::string script = R"IR(
|
||||
const auto script_str = R"IR(
|
||||
graph(%a: str):
|
||||
%1 : int = prim::Constant[value=0]()
|
||||
%ret: str[] = aten::list(%a)
|
||||
return (%ret)
|
||||
)IR";
|
||||
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
vmap.reserve(0);
|
||||
parseIR(script, graph.get(), vmap);
|
||||
torch::jit::StaticModule smodule(graph);
|
||||
|
||||
string a = "abcd";
|
||||
std::string a = "abcd";
|
||||
std::vector<IValue> args0{a};
|
||||
testStaticRuntime(script_str, args0);
|
||||
|
||||
testStaticRuntime(script, args0);
|
||||
// Update the result of aten::list to ensure that a deep copy
|
||||
// took place
|
||||
const auto script_list = R"IR(
|
||||
graph(%a : int[]):
|
||||
%idx : int = prim::Constant[value=0]()
|
||||
%value : int = prim::Constant[value=42]()
|
||||
%res : int[] = aten::list(%a)
|
||||
%updated : int[] = aten::_set_item(%res, %idx, %value)
|
||||
return (%res, %a)
|
||||
)IR";
|
||||
|
||||
std::vector<IValue> args1{c10::List<int64_t>{1, 2, 3}};
|
||||
testStaticRuntime(script_list, args1);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, JIT_Aten_Range_Length) {
|
||||
|
Reference in New Issue
Block a user