mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix condition for weights_only unpickler for DTensor (#140740)
Same as #140739 but for DTensor (move safe globals for DTensor to `torch.distributed.tensor.__init__` and update error message to let user know `torch.distributed.tensor` must be imported to load DTensor) Differential Revision: [D65961690](https://our.internmc.facebook.com/intern/diff/D65961690) Pull Request resolved: https://github.com/pytorch/pytorch/pull/140740 Approved by: https://github.com/malfet ghstack dependencies: #140739
This commit is contained in:
committed by
PyTorch MergeBot
parent
b63a84804c
commit
f3f305ef3e
@ -2,6 +2,9 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
@ -28,7 +31,7 @@ from torch.distributed.tensor.parallel import (
|
||||
parallelize_module,
|
||||
RowwiseParallel,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.common_utils import IS_FBCODE, run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
@ -542,6 +545,33 @@ class DTensorTest(DTensorTestBase):
|
||||
reloaded_st = torch.load(buffer, weights_only=True)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
|
||||
@with_comms
|
||||
@unittest.skipIf(
|
||||
IS_FBCODE,
|
||||
"subprocess import torch fails with ModuleNotFoundError: No module named 'torch' in fbcode",
|
||||
)
|
||||
def test_dtensor_save_load_import(self):
|
||||
for should_import in [True, False]:
|
||||
device_mesh = self.build_device_mesh()
|
||||
placements = [Shard(0)]
|
||||
local_tensor = torch.randn(3, 3)
|
||||
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(sharded_tensor, f)
|
||||
import_string = (
|
||||
"import torch.distributed.tensor;" if should_import else ""
|
||||
)
|
||||
filename = pathlib.Path(f.name)
|
||||
err_msg = (
|
||||
(
|
||||
"_pickle.UnpicklingError: Weights only load failed. "
|
||||
"``torch.distributed.tensor`` must be imported to load DTensors"
|
||||
)
|
||||
if not should_import
|
||||
else None
|
||||
)
|
||||
self._attempt_load_from_subprocess(filename, import_string, err_msg)
|
||||
|
||||
|
||||
class DTensorMeshTest(DTensorTestBase):
|
||||
@property
|
||||
@ -943,9 +973,11 @@ class TestDTensorPlacementTypes(DTensorTestBase):
|
||||
from torch.distributed.tensor._collective_utils import unpad_tensor
|
||||
|
||||
unpadded_list = [
|
||||
unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
|
||||
if pad_sizes[i] > 0
|
||||
else tensor
|
||||
(
|
||||
unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
|
||||
if pad_sizes[i] > 0
|
||||
else tensor
|
||||
)
|
||||
for i, tensor in enumerate(splitted_tensor_list)
|
||||
]
|
||||
expected_is_tensor_empty = [
|
||||
|
@ -169,19 +169,6 @@ def _get_allowed_globals():
|
||||
"builtins.bytearray": bytearray, # for bytearray
|
||||
"builtins.set": set, # for set
|
||||
}
|
||||
# Only add the dtensor related classes if the dtensor module is available
|
||||
if hasattr(torch.distributed, "tensor"):
|
||||
dtensor_rc: Dict[str, Any] = {
|
||||
# DTensor related
|
||||
"torch.distributed.device_mesh.DeviceMesh": torch.distributed.device_mesh.DeviceMesh,
|
||||
"torch.distributed.tensor._dtensor_spec.DTensorSpec": torch.distributed.tensor._dtensor_spec.DTensorSpec,
|
||||
"torch.distributed.tensor._dtensor_spec.TensorMeta": torch.distributed.tensor._dtensor_spec.TensorMeta,
|
||||
"torch.distributed.tensor.DTensor": torch.distributed.tensor.DTensor,
|
||||
"torch.distributed.tensor.placement_types.Partial": torch.distributed.tensor.placement_types.Partial,
|
||||
"torch.distributed.tensor.placement_types.Replicate": torch.distributed.tensor.placement_types.Replicate,
|
||||
"torch.distributed.tensor.placement_types.Shard": torch.distributed.tensor.placement_types.Shard,
|
||||
}
|
||||
rc.update(dtensor_rc)
|
||||
|
||||
# dtype
|
||||
for t in torch.storage._dtype_to_storage_type_map().keys():
|
||||
@ -341,6 +328,20 @@ class Unpickler:
|
||||
raise UnpicklingError(
|
||||
"``torch.nested`` and ``torch._dynamo`` must be imported to load nested jagged tensors (NJTs)"
|
||||
)
|
||||
elif full_path in (
|
||||
[
|
||||
"torch.distributed.device_mesh.DeviceMesh",
|
||||
"torch.distributed.tensor._dtensor_spec.DTensorSpec",
|
||||
"torch.distributed.tensor._dtensor_spec.TensorMeta",
|
||||
"torch.distributed.tensor.DTensor",
|
||||
"torch.distributed.tensor.placement_types.Partial",
|
||||
"torch.distributed.tensor.placement_types.Replicate",
|
||||
"torch.distributed.tensor.placement_types.Shard",
|
||||
]
|
||||
):
|
||||
raise UnpicklingError(
|
||||
"``torch.distributed.tensor`` must be imported to load DTensors"
|
||||
)
|
||||
else:
|
||||
raise UnpicklingError(
|
||||
f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
|
||||
|
@ -45,6 +45,22 @@ __all__ = [
|
||||
"zeros",
|
||||
]
|
||||
|
||||
# For weights_only torch.load
|
||||
from ._dtensor_spec import DTensorSpec as _DTensorSpec, TensorMeta as _TensorMeta
|
||||
|
||||
|
||||
torch.serialization.add_safe_globals(
|
||||
[
|
||||
DeviceMesh,
|
||||
_DTensorSpec,
|
||||
_TensorMeta,
|
||||
DTensor,
|
||||
Partial,
|
||||
Replicate,
|
||||
Shard,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# Append DTensor to the list of supported types for foreach implementation for optimizer
|
||||
# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.
|
||||
|
Reference in New Issue
Block a user