mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Fix index_copy for strided indices (#161333)
By passing the index tensor's strides to the strided variant of the kernel. Fixes https://github.com/pytorch/pytorch/issues/160993. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161333. Approved by: https://github.com/huydhn, https://github.com/wdvr. ghstack dependencies: #161206, #161267.
This commit is contained in:
committed by
PyTorch MergeBot
parent
f912c93344
commit
4acdbb8311
@ -358,6 +358,7 @@ kernel void index_copy_strided(
|
||||
constant long* input_strides,
|
||||
constant long* output_strides,
|
||||
constant long* source_strides,
|
||||
constant long& indices_stride,
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
int pos[max_ndim];
|
||||
pos_from_thread_index(int(thread_index), pos, sizes, ndim);
|
||||
@ -374,7 +375,7 @@ kernel void index_copy_strided(
|
||||
// find the last index in the indices array that equals this coordinate
|
||||
int last_matching_index = -1;
|
||||
for (uint i = 0; i < indices_numel; i++) {
|
||||
if (indices[i] == orig_dim) {
|
||||
if (indices[i * indices_stride] == orig_dim) {
|
||||
last_matching_index = int(i);
|
||||
}
|
||||
}
|
||||
@ -413,6 +414,7 @@ kernel void index_copy_strided(
|
||||
constant long*, \
|
||||
constant long*, \
|
||||
constant long*, \
|
||||
constant long&, \
|
||||
uint);
|
||||
|
||||
#define REGISTER_MASKED_FILL_SCALAR(SIZE, DTYPE) \
|
||||
|
@ -282,7 +282,7 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self,
|
||||
[computeEncoder setComputePipelineState:indexCopyPSO];
|
||||
mtl_setArgs(computeEncoder, result, self, source, index, dim_arg, self.sizes(), ndim, indices_numel);
|
||||
if (!is_dense) {
|
||||
mtl_setArgs<8>(computeEncoder, self.strides(), result.strides(), source.strides());
|
||||
mtl_setArgs<8>(computeEncoder, self.strides(), result.strides(), source.strides(), index.strides());
|
||||
}
|
||||
mtl_dispatch1DJob(computeEncoder, indexCopyPSO, result.numel());
|
||||
}
|
||||
|
@ -1870,7 +1870,7 @@ class TestIndexing(TestCase):
|
||||
self.assertEqual(dest, expected)
|
||||
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
@expectedFailureMPS # See https://github.com/pytorch/pytorch/issues/160993
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat))
|
||||
def test_index_copy(self, device, dtype):
|
||||
# We just test for num_copy <= num_dest, as otherwise there are repeated indices
|
||||
# and the behavior is undefined
|
||||
|
Reference in New Issue
Block a user