[Static Runtime] Schema checks for index_put (#84152)

Summary:
`index_put` can take a list of tensors, but Static Runtime always tries to convert its argument to a list of optional tensors. This was causing crashes for some users. Add some schema checks to prevent this, and add a new overload for the new case.

Also, I found a clear bug in the JIT interpreter (mutating the argument when its not supposed to), so I fixed that too.

Test Plan: New unit test

Differential Revision: D39072214

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84152
Approved by: https://github.com/tenpercent
This commit is contained in:
Mike Iovine
2022-08-31 01:20:14 +00:00
committed by PyTorch MergeBot
parent 7532d5b125
commit db7784e722
4 changed files with 47 additions and 24 deletions

View File

@ -1090,7 +1090,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
}
auto self = pop(stack).toTensor();
auto result =
at::index_put_(self, opt_list_indices, values, accumulate);
at::index_put(self, opt_list_indices, values, accumulate);
push(stack, std::move(result));
},
aliasAnalysisFromSchema()),