mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Fix index_copy
for scalars (#161267)
By `squeezing the input` when copying into scalar tensor from a 1d one And enable `test_index_copy_scalars_mps` Fixes https://github.com/pytorch/pytorch/issues/160737 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161267 Approved by: https://github.com/manuelcandales, https://github.com/Skylion007, https://github.com/dcci ghstack dependencies: #161206
This commit is contained in:
committed by
PyTorch MergeBot
parent
4c36c8a994
commit
c8bb0e4720
@ -230,7 +230,7 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self,
|
||||
index.numel());
|
||||
int64_t idx = index.item<int64_t>();
|
||||
TORCH_CHECK(idx == 0, "index_copy_(): the only valid index for a 0-dim tensor is 0, but got ", idx);
|
||||
result.copy_(source);
|
||||
result.copy_(source.squeeze());
|
||||
return;
|
||||
}
|
||||
|
||||
@ -254,11 +254,12 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self,
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_CHECK(source.size(dim) == index.numel(),
|
||||
const auto source_size_dim = source.dim() > 0 ? source.size(dim) : 1;
|
||||
TORCH_CHECK(index.numel() == source_size_dim,
|
||||
"index_copy_(): Number of indices (",
|
||||
index.numel(),
|
||||
") should be equal to source.size(dim) (",
|
||||
source.size(dim),
|
||||
source_size_dim,
|
||||
")");
|
||||
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
@ -1913,8 +1913,8 @@ class TestIndexing(TestCase):
|
||||
# onlyNativeDeviceTypes due to an XLA error:
|
||||
# https://github.com/pytorch/pytorch/issues/53256
|
||||
@onlyNativeDeviceTypes
|
||||
@expectedFailureMPS # See https://github.com/pytorch/pytorch/issues/160737
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat))
|
||||
def test_index_copy_scalars(self, device, dtype):
|
||||
# Create the 8 possible combinations of scalar sizes for target / index / source
|
||||
scalars = (
|
||||
|
Reference in New Issue
Block a user