[MPS] Fix index_copy for strided indices (#161333)

By passing the index tensor's strides to the strided variant of the kernel

Fixes https://github.com/pytorch/pytorch/issues/160993
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161333
Approved by: https://github.com/huydhn, https://github.com/wdvr
ghstack dependencies: #161206, #161267
Author: Nikita Shulga
Date: 2025-08-22 18:05:52 -07:00
Committed by: PyTorch MergeBot
Parent: f912c93344
Commit: 4acdbb8311
3 changed files with 5 additions and 3 deletions
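
For context, a minimal sketch (assuming an MPS-capable machine) of the kind of call this commit targets. It is not the exact reproducer from issue #160993, and the non-contiguous destination is only an assumption about when the strided kernel path is selected:

import torch

# Hypothetical reproducer: the index tensor is a strided (non-contiguous) view,
# so its elements are not packed densely in memory. Before this fix the strided
# MPS kernel read the index buffer as if it were contiguous.
base = torch.zeros(4, 2, device="mps")
dest = base[:, 0]                                     # non-contiguous destination view
src = torch.ones(2, device="mps")
idx = torch.tensor([0, 9, 2, 9], device="mps")[::2]   # strided view exposing [0, 2]

dest.index_copy_(0, idx, src)
print(dest.cpu())  # expected: tensor([1., 0., 1., 0.])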

View File

@@ -358,6 +358,7 @@ kernel void index_copy_strided(
     constant long* input_strides,
     constant long* output_strides,
     constant long* source_strides,
+    constant long& indices_stride,
     uint thread_index [[thread_position_in_grid]]) {
   int pos[max_ndim];
   pos_from_thread_index(int(thread_index), pos, sizes, ndim);
@@ -374,7 +375,7 @@ kernel void index_copy_strided(
   // find the last index in the indices array that equals this coordinate
   int last_matching_index = -1;
   for (uint i = 0; i < indices_numel; i++) {
-    if (indices[i] == orig_dim) {
+    if (indices[i * indices_stride] == orig_dim) {
       last_matching_index = int(i);
     }
   }
@@ -413,6 +414,7 @@ kernel void index_copy_strided(
       constant long*, \
       constant long*, \
       constant long*, \
+      constant long&, \
       uint);
 
 #define REGISTER_MASKED_FILL_SCALAR(SIZE, DTYPE) \
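
As a rough Python model of the scan above (a hypothetical helper, not the real Metal kernel): element i of a non-contiguous index tensor lives at offset i * indices_stride in its buffer, which is exactly what the one-line change accounts for, while the "last match wins" scan keeps duplicate indices behaving like sequential writes.

# Hypothetical sketch of the "last matching index" scan in index_copy_strided,
# with the stride applied when reading the index buffer.
def last_matching_index(indices_buf, indices_stride, indices_numel, orig_dim):
    last = -1
    for i in range(indices_numel):
        # Before the fix this read indices_buf[i], which is only correct
        # when the index tensor is contiguous (stride == 1).
        if indices_buf[i * indices_stride] == orig_dim:
            last = i
    return last

# A buffer [0, 9, 2, 9] viewed with stride 2 exposes the indices [0, 2]:
print(last_matching_index([0, 9, 2, 9], 2, 2, 2))  # -> 1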

View File

@@ -282,7 +282,7 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self,
   [computeEncoder setComputePipelineState:indexCopyPSO];
   mtl_setArgs(computeEncoder, result, self, source, index, dim_arg, self.sizes(), ndim, indices_numel);
   if (!is_dense) {
-    mtl_setArgs<8>(computeEncoder, self.strides(), result.strides(), source.strides());
+    mtl_setArgs<8>(computeEncoder, self.strides(), result.strides(), source.strides(), index.strides());
   }
   mtl_dispatch1DJob(computeEncoder, indexCopyPSO, result.numel());
 }

View File

@@ -1870,7 +1870,7 @@ class TestIndexing(TestCase):
         self.assertEqual(dest, expected)
 
     @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
-    @expectedFailureMPS # See https://github.com/pytorch/pytorch/issues/160993
+    @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat))
     def test_index_copy(self, device, dtype):
         # We just test for num_copy <= num_dest, as otherwise there are repeated indices
         # and the behavior is undefined