[MPS][BE] Delete as_strided_tensorimpl_mps (#157772)
Because it's just a copy-n-paste of `as_strided_tensorimpl` with a call to `updateTensorBaseShape`, which is not called/used anywhere else.

Fixes https://github.com/pytorch/pytorch/issues/152701

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157772
Approved by: https://github.com/Skylion007
committed by PyTorch MergeBot
parent bbe681ed51
commit a5c61eb78d
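For context, issue 152701 is a functorch forward-mode AD failure on MPS that exercises this `as_strided` path. A minimal repro, mirroring the regression test added below (sketch assumes a macOS build with MPS available):

import torch

def fn(x, y):
    return torch.cat((x, y))

x = torch.rand(2, device="mps")
y = torch.rand(3, device="mps")

# Before this change the MPS-only as_strided path broke this call;
# with the shared as_strided_tensorimpl it returns the Jacobian of
# fn with respect to x, of shape (5, 2).
jac = torch.func.jacfwd(fn)(x, y)
print(jac.shape)  # torch.Size([5, 2])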
@@ -1408,9 +1408,6 @@ Tensor as_strided_tensorimpl(
     IntArrayRef size,
     IntArrayRef stride,
     std::optional<int64_t> storage_offset_) {
-  TORCH_INTERNAL_ASSERT(
-      !self.is_mps(),
-      "as_strided_tensorimpl does not work with MPS; call self.as_strided(...) instead");
   auto storage_offset = storage_offset_.value_or(self.storage_offset());
   auto result = at::detail::make_tensor<TensorImpl>(
       c10::TensorImpl::VIEW,
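The assertion deleted above existed because MPS used to route `as_strided` to its own kernel; with this change the generic implementation serves MPS too. A small illustrative sketch (hypothetical values; assumes an MPS device):

import torch

base = torch.arange(6.0, device="mps")
# A (2, 2) view with row stride 3 picks elements 0, 1, 3, 4 out of
# the underlying buffer; this is now served by the generic
# as_strided_tensorimpl rather than a dispatch-time MPS special case.
v = torch.as_strided(base, (2, 2), (3, 1))
print(v.cpu())  # tensor([[0., 1.], [3., 4.]])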
@@ -17,26 +17,7 @@
 #include <ATen/ops/view_as_real.h>
 #endif
 
-namespace at::native {
-namespace mps {
-
-static IntArrayRef updateTensorBaseShape(const Tensor& self) {
-  IntArrayRef base_shape = getIMPSAllocator()->getBufferShape(self.storage().data());
-  // if there's no base_shape stored in MPSAllocator, then infer it from tensor's size and store it
-  if (base_shape.size() == 0) {
-    // IntArrayRef wouldn't own the data, so we use a static storage
-    static const int64_t shape_1d = 1;
-    // self.sizes().size() could be zero
-    base_shape = self.sizes().size()
-        ? self.sizes()
-        : ((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1));
-
-    // base_shape will be retained in MPSAllocator until buffer gets recycled
-    if (self.storage().data())
-      getIMPSAllocator()->setBufferShape(self.storage().data(), base_shape);
-  }
-  return base_shape;
-}
+namespace at::native::mps {
 
 // For both scatter and gather kernels, there are 4 specialized ones (for 1D to 4D tensor)
 // and one generic, for 5+D ones. Assumption (to be tested) about specialized kernels
@@ -198,26 +179,4 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output) {
   return output;
 }
 
-} // namespace mps
-
-// implementation of as_strided() op
-Tensor as_strided_tensorimpl_mps(const Tensor& self,
-                                 IntArrayRef size,
-                                 IntArrayRef stride,
-                                 std::optional<int64_t> storage_offset_) {
-  auto storage_offset = storage_offset_.value_or(self.storage_offset());
-  auto result =
-      detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
-  setStrided(result, size, stride, storage_offset);
-
-  // creating the view graph will be deferred until gatherViewTensor() or scatterViewTensor() are called.
-  // In as_strided, we just update the base shape of the buffer in order to retrieve it later
-  // when we create/run the view graph.
-  IntArrayRef base_shape = mps::updateTensorBaseShape(self);
-  TORCH_INTERNAL_ASSERT(
-      !base_shape.empty(), "Failed to update the base shape of tensor's buffer at ", self.storage().data());
-
-  return result;
-}
-
-} // namespace at::native
+} // namespace at::native::mps
@@ -941,9 +941,8 @@
 - func: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)
   variants: function, method
   dispatch:
-    ZeroTensor, CPU, CUDA, MTIA: as_strided_tensorimpl
+    ZeroTensor, CPU, CUDA, MTIA, MPS: as_strided_tensorimpl
     Meta: as_strided_tensorimpl_meta_symint
-    MPS: as_strided_tensorimpl_mps
     QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl
   device_check: NoCheck
   device_guard: False
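The YAML change folds MPS into the shared dispatch line, so the same `as_strided_tensorimpl` kernel now backs every listed backend. A hedged parity check one might run (assumes an MPS build):

import torch

# CPU and MPS now hit the same as_strided_tensorimpl kernel, so the
# resulting views should agree element-for-element.
a_cpu = torch.arange(6.0)
a_mps = a_cpu.to("mps")
v_cpu = torch.as_strided(a_cpu, (2, 2), (3, 1))
v_mps = torch.as_strided(a_mps, (2, 2), (3, 1))
assert torch.equal(v_cpu, v_mps.cpu())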
@@ -3592,6 +3592,16 @@ class TestMPS(TestCaseMPS):
         # TODO: enable memory format test
         # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())
 
+    # See https://github.com/pytorch/pytorch/issues/152701
+    def test_jacfwd_cat(self):
+        def fn(x, y):
+            return torch.cat((x, y))
+
+        x = torch.rand(2, device="mps")
+        y = torch.rand(3, device="mps")
+        rc = torch.func.jacfwd(fn)(x, y)
+        self.assertEqual(rc.shape, (5, 2))
+
     # See https://github.com/pytorch/pytorch/issues/85967
     def test_from_numpy_non_contiguous(self):
         a = np.arange(9).reshape(3, 3)[:, :2]
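The expected (5, 2) shape in the new test follows from `fn` mapping x in R^2 and y in R^3 to a length-5 output, with `jacfwd` differentiating with respect to the first argument by default. The same structure can be sanity-checked on CPU:

import torch

def fn(x, y):
    return torch.cat((x, y))

x, y = torch.rand(2), torch.rand(3)
jac = torch.func.jacfwd(fn)(x, y)  # entries are d(output_i)/d(x_j)

# output[:2] == x, so the top block is the 2x2 identity;
# output[2:] == y does not depend on x, so the bottom block is zero.
assert jac.shape == (5, 2)
assert torch.equal(jac[:2], torch.eye(2))
assert torch.equal(jac[2:], torch.zeros(3, 2))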