mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[PyTorch] Deprecate numpy serialization for MTIA (#157884)
Summary: NumPy based tensor rebuilding from serialization has been deprecated by other backends (eg. [XLA](https://github.com/pytorch/pytorch/pull/137444)). The new flow has CPU storage being constructed with data from the file and then moved to the target backend device. Furthermore, relying on numpy for serialization will fail loudly when torch.load flips weights_only. Reviewed By: andyanwang Differential Revision: D77843238 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157884 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
157683d862
commit
1cb0597a89
@ -330,7 +330,7 @@ class Tensor(torch._C.TensorBase):
|
||||
torch.serialization._serialization_tls.materialize_fake_tensors
|
||||
)
|
||||
|
||||
if self.device.type in ["xla", "maia"] or (
|
||||
if self.device.type in ["xla", "maia", "mtia"] or (
|
||||
not torch._C._has_storage(self)
|
||||
and self.device.type == torch._C._get_privateuse1_backend_name()
|
||||
):
|
||||
@ -343,34 +343,6 @@ class Tensor(torch._C.TensorBase):
|
||||
torch._utils._rebuild_device_tensor_from_cpu_tensor,
|
||||
(cpu_tensor, self.dtype, str(self.device), self.requires_grad),
|
||||
)
|
||||
# Legacy comment that does not hold anymore.
|
||||
# Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors.
|
||||
# We considered a few options:
|
||||
# 1. CPU tensor can't be used here.
|
||||
# Otherwise in torch.load CPU storage is reconstructed with randomly
|
||||
# initialized data, moved onto backend device, and then storage is updated
|
||||
# to the serialized content. This works perfectly for CPU/CUDA but not these backends;
|
||||
# their tensors are disconnected with storage so they don't get the update.
|
||||
# 2. Python list is not a good fit due to performance reason.
|
||||
# `tolist()` converts every single element in the tensor into python objects
|
||||
# and serialize them one by one.
|
||||
if self.device.type in ["mtia"]:
|
||||
# Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't
|
||||
# support BFloat16. The rebuild tensor from numpy takes in the original self.dtype,
|
||||
# this would reconstruct the BFloat16 tensor from numpy.
|
||||
if skip_data:
|
||||
raise RuntimeError(
|
||||
"Cannot serialize tensors on backends with no storage under skip_data context manager"
|
||||
)
|
||||
numpy_tensor = (
|
||||
self.cpu().numpy()
|
||||
if self.dtype != torch.bfloat16
|
||||
else self.cpu().to(torch.float32).numpy()
|
||||
)
|
||||
return (
|
||||
torch._utils._rebuild_device_tensor_from_numpy,
|
||||
(numpy_tensor, self.dtype, str(self.device), self.requires_grad),
|
||||
)
|
||||
if self.device.type == "meta":
|
||||
# NB: This implementation BREAKS storage sharing. Current
|
||||
# hypothesis is that no one cares for meta tensors.
|
||||
|
Reference in New Issue
Block a user