[MPS] Disallow reshape in slice (#95905)
Disallow reshapes for arrayViews. The current code allows a base shape of `[2, 4, 256]` to be sliced into `[4, 1, 256]` (the view's shape), which is not possible: slicing a smaller dimension into a bigger one will always error out.

Fixes https://github.com/pytorch/pytorch/issues/95883

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95905
Approved by: https://github.com/razarmehr, https://github.com/kulinseth
commit 304a95435d
parent a32be76a53
committed by PyTorch MergeBot
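To make the rejected case concrete, here is a small Python sketch (an illustration of the rule, not part of the commit): a slice can only shrink dimensions, so a `[2, 4, 256]` buffer can be viewed as `[2, 1, 256]` by slicing, but a `[4, 1, 256]` view over it can only come from a reshape.

```python
import torch

base = torch.zeros(2, 4, 256)

# A genuine slice: every view dimension fits inside the matching base
# dimension ([2, 1, 256] <= [2, 4, 256] element-wise).
ok = base[:, :1, :]
print(ok.shape)  # torch.Size([2, 1, 256])

# A [4, 1, 256] view of this buffer is NOT a slice: its dim 0 (4) is larger
# than the base's dim 0 (2). It can only be produced by a reshape, e.g. a
# reshape of a slice, which may have to copy:
not_a_slice = base[:, :2, :].reshape(4, 1, 256)
print(not_a_slice.shape)  # torch.Size([4, 1, 256])
```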
```diff
@@ -478,25 +478,19 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
   }
 
   IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data());
-  std::vector<int64_t> src_base_squeezed_shape = getSqueezedBaseShape(src, src_base_shape);
   size_t src_ndim_base = src_base_shape.size();
-  size_t src_squeezed_ndim_base = src_base_squeezed_shape.size();
-  std::vector<int64_t> src_view_squeezed_shape = getViewShape(src, mpsShape, true);
-  size_t src_ndim_view = getViewShape(src, mpsShape, false).size();
-  size_t src_squeezed_ndim_view = src_view_squeezed_shape.size();
+  std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape, false);
+  size_t src_ndim_view = src_view_shape.size();
 
   if (src_ndim_base != src_ndim_view) {
     return false;
   }
 
-  if (src_squeezed_ndim_base == src_squeezed_ndim_view) {
-    for (const auto i: c10::irange(src_squeezed_ndim_base)) {
-      if (src_view_squeezed_shape[i] > src_base_squeezed_shape[i]) {
-        return false;
-      }
-    }
-  }
-
+  for (const auto i: c10::irange(src_ndim_base)) {
+    if (src_view_shape[i] > src_base_shape[i]) {
+      return false;
+    }
+  }
   return true;
 }
 
```
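The new check in `canSliceViewTensor`, restated as a short Python sketch (`can_slice_view` is a hypothetical name for illustration, not a PyTorch API):

```python
def can_slice_view(base_shape, view_shape):
    # Ranks must match: a slice never adds or removes dimensions.
    if len(base_shape) != len(view_shape):
        return False
    # Every view dimension must fit inside the matching base dimension.
    return all(v <= b for v, b in zip(view_shape, base_shape))

assert can_slice_view([2, 4, 256], [2, 1, 256])      # a real slice
assert not can_slice_view([2, 4, 256], [4, 1, 256])  # the issue 95883 case
```

As far as the removed lines show, the old code ran this comparison only on the squeezed shapes, and only when the squeezed ranks happened to match; `[4, 1, 256]` squeezes to rank 2 while `[2, 4, 256]` stays at rank 3, so the check was skipped entirely and the invalid view was accepted. The new code compares the full (unsqueezed) shapes unconditionally.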
```diff
@@ -1954,6 +1954,78 @@ class TestMPS(TestCaseMPS):
         x_cpu = x_cpu + 2
         self.assertEqual(x, x_cpu)
 
+    def test_reshape_storage_offset(self):
+        # https://github.com/pytorch/pytorch/issues/95883
+        B = 4
+        T = 1
+
+        lin_cpu = nn.Linear(10, 256)
+        lin_mps = nn.Linear(10, 256, device="mps")
+
+        # Use the same weights and bias as the ones from the cpu
+        lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
+        lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()
+
+        x_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
+        x_cpu = x_mps.detach().clone().cpu().requires_grad_()
+        x_mps = lin_mps(x_mps)
+        x_cpu = lin_cpu(x_cpu)
+
+        self.assertEqual(x_mps.shape, (B, T, 256))
+        self.assertEqual(x_cpu.shape, (B, T, 256))
+
+        cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
+        cls_token_cpu = cls_token_mps.detach().clone().cpu()
+        x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
+        x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)
+
+        x_mps = x_mps.transpose(0, 1)
+        x_cpu = x_cpu.transpose(0, 1)
+
+        target_mps = torch.rand_like(x_mps)
+        target_cpu = target_mps.detach().clone().cpu()
+        loss_mps = F.mse_loss(x_mps, target_mps)
+        loss_cpu = F.mse_loss(x_cpu, target_cpu)
+        self.assertEqual(loss_mps, loss_cpu)
+
+        loss_mps.backward()
+        loss_cpu.backward()
+        self.assertEqual(x_mps.grad, x_cpu.grad)
+
+    def test_stack(self):
+        # https://github.com/pytorch/pytorch/issues/87856
+        x_cpu = torch.tensor([[1, 2]])
+        x_mps = x_cpu.detach().clone().to("mps")
+
+        y_cpu = torch.stack((x_cpu[:, :1], x_cpu[:, -1:]), dim=-1)
+        y_mps = torch.stack((x_mps[:, :1], x_mps[:, -1:]), dim=-1)
+
+        self.assertEqual(y_cpu, y_mps)
+
+        t_mps = torch.tensor([1, 2, 3, 4], device="mps")
+        t_cpu = t_mps.detach().cpu()
+
+        x_mps = t_mps[2:]
+        y_mps = t_mps[:2]
+
+        x_cpu = t_cpu[2:]
+        y_cpu = t_cpu[:2]
+
+        res_mps = torch.stack((y_mps, x_mps), dim=-1)
+        res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
+
+        self.assertEqual(res_mps, res_cpu)
+
+    def test_unsafe_chunk(self):
+        # https://github.com/pytorch/pytorch/issues/91065
+        a = torch.rand(5, dtype=torch.float32, device="cpu")
+        ret = a.unsafe_chunk(4, 0)
+        y = ret[0] * ret[2]
+        a_mps = a.to("mps")
+        ret_mps = a_mps.unsafe_chunk(4, 0)
+        y_mps = ret_mps[0] * ret_mps[2]
+        self.assertEqual(y, y_mps)
+
     def test_slice_casting(self):
         # generate random binary numbers
         cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8)
```
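For reference, the added tests can be run individually with the usual unittest selection, e.g. `python test/test_mps.py TestMPS.test_reshape_storage_offset` (requires a macOS machine with an MPS device).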