[MPS] Fix index_copy for scalars (#161267)

By `squeezing the input` when copying into scalar tensor from a 1d one
And enable `test_index_copy_scalars_mps`

Fixes https://github.com/pytorch/pytorch/issues/160737
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161267
Approved by: https://github.com/manuelcandales, https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #161206
This commit is contained in:
Nikita Shulga
2025-08-22 07:10:56 -07:00
committed by PyTorch MergeBot
parent 4c36c8a994
commit c8bb0e4720
2 changed files with 5 additions and 4 deletions

View File

@ -230,7 +230,7 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self,
index.numel());
int64_t idx = index.item<int64_t>();
TORCH_CHECK(idx == 0, "index_copy_(): the only valid index for a 0-dim tensor is 0, but got ", idx);
result.copy_(source);
result.copy_(source.squeeze());
return;
}
@ -254,11 +254,12 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self,
}
}
TORCH_CHECK(source.size(dim) == index.numel(),
const auto source_size_dim = source.dim() > 0 ? source.size(dim) : 1;
TORCH_CHECK(index.numel() == source_size_dim,
"index_copy_(): Number of indices (",
index.numel(),
") should be equal to source.size(dim) (",
source.size(dim),
source_size_dim,
")");
auto stream = getCurrentMPSStream();

View File

@ -1913,8 +1913,8 @@ class TestIndexing(TestCase):
# onlyNativeDeviceTypes due to an XLA error:
# https://github.com/pytorch/pytorch/issues/53256
@onlyNativeDeviceTypes
@expectedFailureMPS # See https://github.com/pytorch/pytorch/issues/160737
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat))
def test_index_copy_scalars(self, device, dtype):
# Create the 8 possible combinations of scalar sizes for target / index / source
scalars = (