Add hacked_twin overloads for _unsafe indexing functions (#104127)

Fixes #104037

This hacky workaround already exists for the normal overloads.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104127
Approved by: https://github.com/ezyang
Author: Peter Bell
Date: 2023-07-04 21:52:36 +00:00
Committed by: PyTorch MergeBot
Parent: 2385dad4b3
Commit: 12ca224662
3 changed files with 68 additions and 0 deletions
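
The new overloads mirror the existing hacked_twin pattern: the regular indexing schemas take Tensor?[] indices, while the hacked_twin variants accept a plain Tensor[], which is what the JIT produces for a non-optional list of index tensors. Below is a minimal eager-mode sketch of the intended behaviour of the new _unsafe_index overload, assuming that _unsafe_index only drops bounds checking and otherwise matches ordinary advanced indexing:

    import torch

    x = torch.randn(10)
    idx = torch.randint(10, (20,))

    # Overload added by this PR: indices are passed as Tensor[], not Tensor?[].
    out_twin = torch.ops.aten._unsafe_index.Tensor_hacked_twin(x, [idx])

    # For in-bounds indices this should agree with plain advanced indexing.
    assert torch.equal(out_twin, x[idx])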

@@ -332,9 +332,11 @@ PT_OPS_PRIM = [
    "aten::copy_.float",
    "aten::backward",
    "aten::index.Tensor_hacked_twin",
    "aten::_unsafe_index.Tensor_hacked_twin",
    "aten::_index_put_impl_.hacked_twin",
    "aten::index_put_.hacked_twin",
    "aten::index_put.hacked_twin",
    "aten::_unsafe_index_put.hacked_twin",
    "aten::to.prim_Device",
    "aten::to.prim_dtype",
    "prim::is_cuda",

@@ -171,6 +171,39 @@ class TestMisc(JitTestCase):
        torch.index_put_(input1, [index1], value1, accumulate=False)
        self.assertEqual(input, input1)

    def test_unsafe_hacked_twin(self):

        def gen_data():
            with freeze_rng_state():
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

        input, index, value, = gen_data()
        input1, index1, value1, = gen_data()
        out1 = torch.ops.aten._unsafe_index_put.hacked_twin(input, [index], value, accumulate=False)
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index])
        torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(input, input1)

        def index_put_fn(input, index, value):
            return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)

        input2, index2, value2 = gen_data()
        script_index_put_fn = torch.jit.script(index_put_fn)
        expect = index_put_fn(input2.clone(), index2, value2)
        actual = script_index_put_fn(input2.clone(), index2, value2)
        self.assertEqual(expect, actual)

        def index_fn(input, index, value):
            return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)

        script_index_fn = torch.jit.script(index_fn)
        expect = index_fn(input2.clone(), index2, value2)
        actual = script_index_fn(input2.clone(), index2, value2)
        self.assertEqual(expect, actual)

    def test_export_opnames_interface(self):
        @torch.jit.interface

@@ -1058,6 +1058,21 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
          push(stack, std::move(result));
        },
        aliasAnalysisFromSchema()),
    OperatorGeneratorArgs(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_unsafe_index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"),
        [](Stack& stack) {
          auto indices = pop(stack).to<c10::List<at::Tensor>>();
          c10::List<c10::optional<at::Tensor>> opt_list_indices;
          opt_list_indices.reserve(indices.size());
          for (const auto& ten : indices) {
            opt_list_indices.push_back(ten);
          }
          auto self = pop(stack).toTensor();
          auto result = at::_unsafe_index(self, opt_list_indices);
          push(stack, std::move(result));
        },
        aliasAnalysisFromSchema()),
    OperatorGeneratorArgs(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_index_put_impl_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)"),
@@ -1113,6 +1128,24 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
          push(stack, std::move(result));
        },
        aliasAnalysisFromSchema()),
    OperatorGeneratorArgs(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_unsafe_index_put.hacked_twin(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"),
        [](Stack& stack) {
          auto accumulate = pop(stack).toBool();
          auto values = pop(stack).toTensor();
          auto indices = pop(stack).to<c10::List<at::Tensor>>();
          c10::List<c10::optional<at::Tensor>> opt_list_indices;
          opt_list_indices.reserve(indices.size());
          for (const auto& ten : indices) {
            opt_list_indices.push_back(ten);
          }
          auto self = pop(stack).toTensor();
          auto result =
              at::_unsafe_index_put(self, opt_list_indices, values, accumulate);
          push(stack, std::move(result));
        },
        aliasAnalysisFromSchema()),
    // reference function parse_to_conversion in python_arg_parsing.h
    OperatorGeneratorArgs(
        TORCH_SELECTIVE_SCHEMA(
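
The _unsafe_index_put.hacked_twin kernel follows the same pattern on the write side: pop accumulate, values, the Tensor[] indices and self, widen the indices to Tensor?[], and forward to at::_unsafe_index_put. A quick eager-mode check of the intended equivalence, assuming the unsafe op matches the out-of-place torch.index_put for valid, non-duplicated indices:

    import torch

    x = torch.randn(10)
    idx = torch.randperm(10)   # unique indices keep the comparison deterministic
    vals = torch.randn(10)

    out_twin = torch.ops.aten._unsafe_index_put.hacked_twin(
        x.clone(), [idx], vals, accumulate=False)
    out_ref = torch.index_put(x.clone(), [idx], vals, accumulate=False)
    assert torch.equal(out_twin, out_ref)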