[MPS][BE] Delete as_strided_tensorimpl_mps (#157772)

Because it's just copy-n-paste of `as_strided_tensorimpl` with call to `updateTensorBaseShape`, which is not called/used anywhere else.

Fixes https://github.com/pytorch/pytorch/issues/152701
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157772
Approved by: https://github.com/Skylion007
This commit is contained in:
Nikita Shulga
2025-07-08 08:49:10 -07:00
committed by PyTorch MergeBot
parent bbe681ed51
commit a5c61eb78d
4 changed files with 13 additions and 48 deletions

View File

@ -1408,9 +1408,6 @@ Tensor as_strided_tensorimpl(
IntArrayRef size,
IntArrayRef stride,
std::optional<int64_t> storage_offset_) {
TORCH_INTERNAL_ASSERT(
!self.is_mps(),
"as_strided_tensorimpl does not work with MPS; call self.as_strided(...) instead");
auto storage_offset = storage_offset_.value_or(self.storage_offset());
auto result = at::detail::make_tensor<TensorImpl>(
c10::TensorImpl::VIEW,

View File

@ -17,26 +17,7 @@
#include <ATen/ops/view_as_real.h>
#endif
namespace at::native {
namespace mps {
static IntArrayRef updateTensorBaseShape(const Tensor& self) {
IntArrayRef base_shape = getIMPSAllocator()->getBufferShape(self.storage().data());
// if there's no base_shape stored in MPSAllocator, then infer it from tensor's size and store it
if (base_shape.size() == 0) {
// IntArrayRef wouldn't own the data, so we use a static storage
static const int64_t shape_1d = 1;
// self.sizes().size() could be zero
base_shape = self.sizes().size()
? self.sizes()
: ((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1));
// base_shape will be retained in MPSAllocator until buffer gets recycled
if (self.storage().data())
getIMPSAllocator()->setBufferShape(self.storage().data(), base_shape);
}
return base_shape;
}
namespace at::native::mps {
// For both scatter and gather kernels, there are 4 specized ones (for 1D to 4D tensor)
// and one generic, for 5+D ones. Assumption (to be tested) about specialized kernels
@ -198,26 +179,4 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output) {
return output;
}
} // namespace mps
// implementation of as_strided() op
Tensor as_strided_tensorimpl_mps(const Tensor& self,
IntArrayRef size,
IntArrayRef stride,
std::optional<int64_t> storage_offset_) {
auto storage_offset = storage_offset_.value_or(self.storage_offset());
auto result =
detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
setStrided(result, size, stride, storage_offset);
// creating the view graph will be deferred until gatherViewTensor() or scatterViewTensor() are called.
// In as_strided, we just update the base shape of the buffer in order to retrieve it later
// when we create/run the view graph.
IntArrayRef base_shape = mps::updateTensorBaseShape(self);
TORCH_INTERNAL_ASSERT(
!base_shape.empty(), "Failed to update the base shape of tensor's buffer at ", self.storage().data());
return result;
}
} // namespace at::native
} // namespace at::native::mps

View File

@ -941,9 +941,8 @@
- func: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)
variants: function, method
dispatch:
ZeroTensor, CPU, CUDA, MTIA: as_strided_tensorimpl
ZeroTensor, CPU, CUDA, MTIA, MPS: as_strided_tensorimpl
Meta: as_strided_tensorimpl_meta_symint
MPS: as_strided_tensorimpl_mps
QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl
device_check: NoCheck
device_guard: False

View File

@ -3592,6 +3592,16 @@ class TestMPS(TestCaseMPS):
# TODO: enable memory format test
# self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())
# See https://github.com/pytorch/pytorch/issues/152701
def test_jacfwd_cat(self):
def fn(x, y):
return torch.cat((x, y))
x = torch.rand(2, device="mps")
y = torch.rand(3, device="mps")
rc = torch.func.jacfwd(fn)(x, y)
self.assertEqual(rc.shape, (5, 2))
# See https://github.com/pytorch/pytorch/issues/85967
def test_from_numpy_non_contiguous(self):
a = np.arange(9).reshape(3, 3)[:, :2]