Add hacked_twin overloads for _unsafe indexing functions (#104127)

Fixes #104037

This hacky workaround already exists for the normal overloads.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104127
Approved by: https://github.com/ezyang
This commit is contained in:
Peter Bell
2023-07-04 21:52:36 +00:00
committed by PyTorch MergeBot
parent 2385dad4b3
commit 12ca224662
3 changed files with 68 additions and 0 deletions

View File

@ -1058,6 +1058,21 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
push(stack, std::move(result));
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::_unsafe_index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"),
[](Stack& stack) {
auto indices = pop(stack).to<c10::List<at::Tensor>>();
c10::List<c10::optional<at::Tensor>> opt_list_indices;
opt_list_indices.reserve(indices.size());
for (const auto& ten : indices) {
opt_list_indices.push_back(ten);
}
auto self = pop(stack).toTensor();
auto result = at::_unsafe_index(self, opt_list_indices);
push(stack, std::move(result));
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::_index_put_impl_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)"),
@ -1113,6 +1128,24 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
push(stack, std::move(result));
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::_unsafe_index_put.hacked_twin(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"),
[](Stack& stack) {
auto accumulate = pop(stack).toBool();
auto values = pop(stack).toTensor();
auto indices = pop(stack).to<c10::List<at::Tensor>>();
c10::List<c10::optional<at::Tensor>> opt_list_indices;
opt_list_indices.reserve(indices.size());
for (const auto& ten : indices) {
opt_list_indices.push_back(ten);
}
auto self = pop(stack).toTensor();
auto result =
at::_unsafe_index_put(self, opt_list_indices, values, accumulate);
push(stack, std::move(result));
},
aliasAnalysisFromSchema()),
// reference function parse_to_conversion in python_arg_parsing.h
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(