Add hacked_twin overloads for _unsafe indexing functions (#104127)

Fixes #104037

This hacky workaround already exists for the normal overloads.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104127
Approved by: https://github.com/ezyang
This commit is contained in:
Peter Bell
2023-07-04 21:52:36 +00:00
committed by PyTorch MergeBot
parent 2385dad4b3
commit 12ca224662
3 changed files with 68 additions and 0 deletions

View File

@ -332,9 +332,11 @@ PT_OPS_PRIM = [
"aten::copy_.float",
"aten::backward",
"aten::index.Tensor_hacked_twin",
"aten::_unsafe_index.Tensor_hacked_twin",
"aten::_index_put_impl_.hacked_twin",
"aten::index_put_.hacked_twin",
"aten::index_put.hacked_twin",
"aten::_unsafe_index_put.hacked_twin",
"aten::to.prim_Device",
"aten::to.prim_dtype",
"prim::is_cuda",