[MPS] Disallow reshape in slice (#95905)

Disallow reshapes for arrayViews.
The current code allows a base shape of `[2, 4, 256]` to be sliced into a view shape of `[4, 1, 256]` — which is not possible: slicing can never produce a dimension larger than the corresponding base dimension, and attempting to do so will always error out.

Fixes https://github.com/pytorch/pytorch/issues/95883
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95905
Approved by: https://github.com/razarmehr, https://github.com/kulinseth
This commit is contained in:
Denis Vieriu
2023-03-03 08:08:31 +00:00
committed by PyTorch MergeBot
parent a32be76a53
commit 304a95435d
2 changed files with 79 additions and 13 deletions

View File

@ -478,25 +478,19 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
}
IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data());
std::vector<int64_t> src_base_squeezed_shape = getSqueezedBaseShape(src, src_base_shape);
size_t src_ndim_base = src_base_shape.size();
size_t src_squeezed_ndim_base = src_base_squeezed_shape.size();
std::vector<int64_t> src_view_squeezed_shape = getViewShape(src, mpsShape, true);
size_t src_ndim_view = getViewShape(src, mpsShape, false).size();
size_t src_squeezed_ndim_view = src_view_squeezed_shape.size();
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape, false);
size_t src_ndim_view = src_view_shape.size();
if (src_ndim_base != src_ndim_view) {
return false;
}
if (src_squeezed_ndim_base == src_squeezed_ndim_view) {
for (const auto i: c10::irange(src_squeezed_ndim_base)) {
if (src_view_squeezed_shape[i] > src_base_squeezed_shape[i]) {
return false;
}
}
}
for (const auto i: c10::irange(src_ndim_base)) {
if (src_view_shape[i] > src_base_shape[i]) {
return false;
}
}
return true;
}

View File

@ -1954,6 +1954,78 @@ class TestMPS(TestCaseMPS):
x_cpu = x_cpu + 2
self.assertEqual(x, x_cpu)
def test_reshape_storage_offset(self):
    """Regression test for slicing/reshaping views with a storage offset on MPS.

    https://github.com/pytorch/pytorch/issues/95883
    Runs the same linear -> cat -> transpose -> mse pipeline on CPU and MPS
    and checks that forward results and input gradients agree.
    """
    B = 4  # batch size
    T = 1  # sequence length
    lin_cpu = nn.Linear(10, 256)
    lin_mps = nn.Linear(10, 256, device="mps")
    # Use the same weights and bias as the ones from the cpu
    lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
    lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()

    # Keep handles to the *leaf* inputs. The original test rebound
    # x_mps/x_cpu to intermediate (non-leaf) tensors, whose .grad is
    # always None, so its final gradient comparison was vacuous
    # (None == None). Comparing the leaf inputs' grads checks backward
    # for real.
    x_mps_in = torch.rand([B, T, 10], device="mps", requires_grad=True)
    x_cpu_in = x_mps_in.detach().clone().cpu().requires_grad_()
    x_mps = lin_mps(x_mps_in)
    x_cpu = lin_cpu(x_cpu_in)
    self.assertEqual(x_mps.shape, (B, T, 256))
    self.assertEqual(x_cpu.shape, (B, T, 256))

    # Prepend a class token so the concatenated tensor has a storage
    # offset relative to its base when sliced/transposed below.
    cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
    cls_token_cpu = cls_token_mps.detach().clone().cpu()
    x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
    x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)
    x_mps = x_mps.transpose(0, 1)
    x_cpu = x_cpu.transpose(0, 1)

    # Identical random targets on both devices.
    target_mps = torch.rand_like(x_mps)
    target_cpu = target_mps.detach().clone().cpu()
    loss_mps = F.mse_loss(x_mps, target_mps)
    loss_cpu = F.mse_loss(x_cpu, target_cpu)
    self.assertEqual(loss_mps, loss_cpu)

    loss_mps.backward()
    loss_cpu.backward()
    # Compare the gradients that actually flowed back to the leaf inputs.
    self.assertEqual(x_mps_in.grad, x_cpu_in.grad)
def test_stack(self):
    """torch.stack on sliced MPS tensors should match the CPU result.

    https://github.com/pytorch/pytorch/issues/87856
    """
    # Stack two column slices of the same tensor.
    x_cpu = torch.tensor([[1, 2]])
    x_mps = x_cpu.detach().clone().to("mps")
    y_cpu = torch.stack((x_cpu[:, :1], x_cpu[:, -1:]), dim=-1)
    y_mps = torch.stack((x_mps[:, :1], x_mps[:, -1:]), dim=-1)
    self.assertEqual(y_cpu, y_mps)

    # Stack two halves of a 1-D tensor (views with storage offsets).
    t_mps = torch.tensor([1, 2, 3, 4], device="mps")
    # A single .detach() suffices; the original chained .detach() twice
    # (.cpu() of a detached tensor is already detached).
    t_cpu = t_mps.detach().cpu()
    x_mps = t_mps[2:]
    y_mps = t_mps[:2]
    x_cpu = t_cpu[2:]
    y_cpu = t_cpu[:2]
    res_mps = torch.stack((y_mps, x_mps), dim=-1)
    res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
    self.assertEqual(res_mps, res_cpu)
def test_unsafe_chunk(self):
    """unsafe_chunk on MPS should produce the same chunks as on CPU.

    https://github.com/pytorch/pytorch/issues/91065
    """
    # Reference result computed on CPU.
    src = torch.rand(5, dtype=torch.float32, device="cpu")
    chunks_cpu = src.unsafe_chunk(4, 0)
    expected = chunks_cpu[0] * chunks_cpu[2]
    # Same computation on MPS.
    src_mps = src.to("mps")
    chunks_mps = src_mps.unsafe_chunk(4, 0)
    actual = chunks_mps[0] * chunks_mps[2]
    self.assertEqual(expected, actual)
def test_slice_casting(self):
# generate random binary numbers
cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8)