[PyTorch] Deprecate numpy serialization for MTIA (#157884)

Summary:
NumPy-based tensor rebuilding during deserialization has already been deprecated for other backends (e.g. [XLA](https://github.com/pytorch/pytorch/pull/137444)). In the new flow, a CPU storage is constructed from the data in the file and then moved to the target backend device.
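
For illustration, here is a minimal sketch of what the CPU-tensor rebuild path amounts to. The helper referenced in the diff below is `torch._utils._rebuild_device_tensor_from_cpu_tensor`; the simplified body here is an assumption for readability, not the actual implementation.

```python
import torch


def rebuild_device_tensor_from_cpu_tensor(cpu_tensor, dtype, device, requires_grad):
    """Sketch of the CPU-tensor rebuild hook (assumed, simplified)."""
    # The pickle payload carries an ordinary CPU tensor; rebuilding moves it
    # onto the target backend device and restores dtype/autograd metadata.
    t = cpu_tensor.to(device=device, dtype=dtype)
    t.requires_grad_(requires_grad)
    return t
```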

Furthermore, relying on NumPy for serialization will fail loudly once `torch.load` flips its `weights_only` default to `True`.
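
By contrast, a checkpoint whose tensors were pickled through the old numpy path trips the restricted unpickler. Illustrative only; the file name is hypothetical and the exact error text varies by version:

```python
import torch

# Hypothetical checkpoint written by an older build via the numpy rebuild path.
try:
    state = torch.load("old_mtia_checkpoint.pt", weights_only=True)
except Exception as err:
    # Typically an UnpicklingError about a disallowed numpy global.
    print(f"weights_only load rejected the numpy-based payload: {err}")
```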

Reviewed By: andyanwang

Differential Revision: D77843238

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157884
Approved by: https://github.com/albanD
Author:    Simon Mahns
Date:      2025-07-11 17:57:30 +00:00
Committer: PyTorch MergeBot
parent 157683d862
commit 1cb0597a89


@@ -330,7 +330,7 @@ class Tensor(torch._C.TensorBase):
             torch.serialization._serialization_tls.materialize_fake_tensors
         )
 
-        if self.device.type in ["xla", "maia"] or (
+        if self.device.type in ["xla", "maia", "mtia"] or (
             not torch._C._has_storage(self)
             and self.device.type == torch._C._get_privateuse1_backend_name()
         ):
@@ -343,34 +343,6 @@ class Tensor(torch._C.TensorBase):
                 torch._utils._rebuild_device_tensor_from_cpu_tensor,
                 (cpu_tensor, self.dtype, str(self.device), self.requires_grad),
             )
-        # Legacy comment that does not hold anymore.
-        # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors.
-        # We considered a few options:
-        # 1. CPU tensor can't be used here.
-        #    Otherwise in torch.load CPU storage is reconstructed with randomly
-        #    initialized data, moved onto backend device, and then storage is updated
-        #    to the serialized content. This works perfectly for CPU/CUDA but not these backends;
-        #    their tensors are disconnected with storage so they don't get the update.
-        # 2. Python list is not a good fit due to performance reason.
-        #    `tolist()` converts every single element in the tensor into python objects
-        #    and serialize them one by one.
-        if self.device.type in ["mtia"]:
-            # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't
-            # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype,
-            # this would reconstruct the BFloat16 tensor from numpy.
-            if skip_data:
-                raise RuntimeError(
-                    "Cannot serialize tensors on backends with no storage under skip_data context manager"
-                )
-            numpy_tensor = (
-                self.cpu().numpy()
-                if self.dtype != torch.bfloat16
-                else self.cpu().to(torch.float32).numpy()
-            )
-            return (
-                torch._utils._rebuild_device_tensor_from_numpy,
-                (numpy_tensor, self.dtype, str(self.device), self.requires_grad),
-            )
         if self.device.type == "meta":
             # NB: This implementation BREAKS storage sharing. Current
             # hypothesis is that no one cares for meta tensors.
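
With the numpy path removed, MTIA tensors round-trip through CPU storage like other accelerators. A usage sketch, assuming an MTIA-enabled build with an available `mtia` device (not part of this commit, purely illustrative):

```python
import torch

# Assumes an MTIA-enabled PyTorch build; substitute another accelerator otherwise.
t = torch.randn(4, 4, dtype=torch.bfloat16, device="mtia")

# Serialization now goes through a plain CPU tensor, so bfloat16 needs no
# float32 detour and the checkpoint loads under weights_only=True.
torch.save(t, "t.pt")
loaded = torch.load("t.pt", weights_only=True)
assert loaded.device.type == "mtia" and loaded.dtype == torch.bfloat16
```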