mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add hacked_twin overloads for _unsafe indexing functions (#104127)
Fixes #104037 This hacky workaround already exists for the normal overloads. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104127 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
2385dad4b3
commit
12ca224662
@ -332,9 +332,11 @@ PT_OPS_PRIM = [
|
||||
"aten::copy_.float",
|
||||
"aten::backward",
|
||||
"aten::index.Tensor_hacked_twin",
|
||||
"aten::_unsafe_index.Tensor_hacked_twin",
|
||||
"aten::_index_put_impl_.hacked_twin",
|
||||
"aten::index_put_.hacked_twin",
|
||||
"aten::index_put.hacked_twin",
|
||||
"aten::_unsafe_index_put.hacked_twin",
|
||||
"aten::to.prim_Device",
|
||||
"aten::to.prim_dtype",
|
||||
"prim::is_cuda",
|
||||
|
@ -171,6 +171,39 @@ class TestMisc(JitTestCase):
|
||||
torch.index_put_(input1, [index1], value1, accumulate=False)
|
||||
self.assertEqual(input, input1)
|
||||
|
||||
def test_unsafe_hacked_twin(self):
|
||||
|
||||
def gen_data():
|
||||
with freeze_rng_state():
|
||||
return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
|
||||
|
||||
input, index, value, = gen_data()
|
||||
input1, index1, value1, = gen_data()
|
||||
out1 = torch.ops.aten._unsafe_index_put.hacked_twin(input, [index], value, accumulate=False)
|
||||
out2 = torch.index_put(input1, [index1], value1, accumulate=False)
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index])
|
||||
torch.index_put(input1, [index1], value1, accumulate=False)
|
||||
self.assertEqual(input, input1)
|
||||
|
||||
def index_put_fn(input, index, value):
|
||||
return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)
|
||||
|
||||
input2, index2, value2 = gen_data()
|
||||
script_index_put_fn = torch.jit.script(index_put_fn)
|
||||
expect = index_put_fn(input2.clone(), index2, value2)
|
||||
actual = script_index_put_fn(input2.clone(), index2, value2)
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
def index_fn(input, index, value):
|
||||
return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)
|
||||
|
||||
script_index_fn = torch.jit.script(index_fn)
|
||||
expect = index_fn(input2.clone(), index2, value2)
|
||||
actual = script_index_fn(input2.clone(), index2, value2)
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
def test_export_opnames_interface(self):
|
||||
|
||||
@torch.jit.interface
|
||||
|
@ -1058,6 +1058,21 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
push(stack, std::move(result));
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::_unsafe_index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"),
|
||||
[](Stack& stack) {
|
||||
auto indices = pop(stack).to<c10::List<at::Tensor>>();
|
||||
c10::List<c10::optional<at::Tensor>> opt_list_indices;
|
||||
opt_list_indices.reserve(indices.size());
|
||||
for (const auto& ten : indices) {
|
||||
opt_list_indices.push_back(ten);
|
||||
}
|
||||
auto self = pop(stack).toTensor();
|
||||
auto result = at::_unsafe_index(self, opt_list_indices);
|
||||
push(stack, std::move(result));
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::_index_put_impl_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)"),
|
||||
@ -1113,6 +1128,24 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
push(stack, std::move(result));
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::_unsafe_index_put.hacked_twin(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"),
|
||||
[](Stack& stack) {
|
||||
auto accumulate = pop(stack).toBool();
|
||||
auto values = pop(stack).toTensor();
|
||||
auto indices = pop(stack).to<c10::List<at::Tensor>>();
|
||||
c10::List<c10::optional<at::Tensor>> opt_list_indices;
|
||||
opt_list_indices.reserve(indices.size());
|
||||
for (const auto& ten : indices) {
|
||||
opt_list_indices.push_back(ten);
|
||||
}
|
||||
auto self = pop(stack).toTensor();
|
||||
auto result =
|
||||
at::_unsafe_index_put(self, opt_list_indices, values, accumulate);
|
||||
push(stack, std::move(result));
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
// reference function parse_to_conversion in python_arg_parsing.h
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
|
Reference in New Issue
Block a user