Fix condition for weights_only unpickler for DTensor (#140740)

Same as #140739, but for DTensor: move the safe globals for DTensor to `torch.distributed.tensor.__init__` and update the error message to let the user know that `torch.distributed.tensor` must be imported to load DTensor.
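
With this change, loading a `weights_only` checkpoint that contains a DTensor only requires importing `torch.distributed.tensor` first. A minimal sketch of the intended usage (the checkpoint filename is a placeholder):

```python
import torch
import torch.distributed.tensor  # registers the DTensor safe globals on import

# "dtensor_ckpt.pt" stands in for any checkpoint produced by torch.save() on a
# DTensor; without the import above, the weights_only unpickler raises an
# UnpicklingError that points the user at torch.distributed.tensor.
dt = torch.load("dtensor_ckpt.pt", weights_only=True)
```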

Differential Revision: [D65961690](https://our.internmc.facebook.com/intern/diff/D65961690)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140740
Approved by: https://github.com/malfet
ghstack dependencies: #140739
Author: Mikayla Gawarecki
Date: 2024-11-18 14:03:09 -08:00
Committed by: PyTorch MergeBot
Parent: b63a84804c
Commit: f3f305ef3e
3 changed files with 66 additions and 17 deletions


@@ -2,6 +2,9 @@
 # Owner(s): ["oncall: distributed"]
 import os
+import pathlib
+import tempfile
+import unittest
 
 from numpy.testing import assert_array_equal
@@ -28,7 +31,7 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
     RowwiseParallel,
 )
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import IS_FBCODE, run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,
@@ -542,6 +545,33 @@ class DTensorTest(DTensorTestBase):
         reloaded_st = torch.load(buffer, weights_only=True)
         self.assertEqual(sharded_tensor, reloaded_st)
+
+    @with_comms
+    @unittest.skipIf(
+        IS_FBCODE,
+        "subprocess import torch fails with ModuleNotFoundError: No module named 'torch' in fbcode",
+    )
+    def test_dtensor_save_load_import(self):
+        for should_import in [True, False]:
+            device_mesh = self.build_device_mesh()
+            placements = [Shard(0)]
+            local_tensor = torch.randn(3, 3)
+            sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
+            with tempfile.NamedTemporaryFile() as f:
+                torch.save(sharded_tensor, f)
+                import_string = (
+                    "import torch.distributed.tensor;" if should_import else ""
+                )
+                filename = pathlib.Path(f.name)
+                err_msg = (
+                    (
+                        "_pickle.UnpicklingError: Weights only load failed. "
+                        "``torch.distributed.tensor`` must be imported to load DTensors"
+                    )
+                    if not should_import
+                    else None
+                )
+                self._attempt_load_from_subprocess(filename, import_string, err_msg)
 
 
 class DTensorMeshTest(DTensorTestBase):
     @property
@@ -943,9 +973,11 @@ class TestDTensorPlacementTypes(DTensorTestBase):
             from torch.distributed.tensor._collective_utils import unpad_tensor
 
             unpadded_list = [
-                unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
-                if pad_sizes[i] > 0
-                else tensor
+                (
+                    unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
+                    if pad_sizes[i] > 0
+                    else tensor
+                )
                 for i, tensor in enumerate(splitted_tensor_list)
             ]
             expected_is_tensor_empty = [
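
The new test round-trips a DTensor through `torch.save` and then loads it in a fresh interpreter via `_attempt_load_from_subprocess` (a test helper introduced in #140739), so the parent process's imports cannot mask the missing-import error path. Purely for illustration, a hypothetical standalone version of such a helper might look like this (the real helper lives in PyTorch's test utilities and may differ):

```python
import subprocess
import sys


def attempt_load_from_subprocess(filename, import_string, err_msg=None):
    # Run torch.load(..., weights_only=True) in a clean interpreter so the
    # presence or absence of `import torch.distributed.tensor` is controlled
    # entirely by `import_string`.
    script = f"import torch; {import_string} torch.load(r'{filename}', weights_only=True)"
    result = subprocess.run(
        [sys.executable, "-c", script], capture_output=True, text=True
    )
    if err_msg is None:
        assert result.returncode == 0, result.stderr
    else:
        assert err_msg in result.stderr, result.stderr
```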


@@ -169,19 +169,6 @@ def _get_allowed_globals():
         "builtins.bytearray": bytearray,  # for bytearray
         "builtins.set": set,  # for set
     }
-    # Only add the dtensor related classes if the dtensor module is available
-    if hasattr(torch.distributed, "tensor"):
-        dtensor_rc: Dict[str, Any] = {
-            # DTensor related
-            "torch.distributed.device_mesh.DeviceMesh": torch.distributed.device_mesh.DeviceMesh,
-            "torch.distributed.tensor._dtensor_spec.DTensorSpec": torch.distributed.tensor._dtensor_spec.DTensorSpec,
-            "torch.distributed.tensor._dtensor_spec.TensorMeta": torch.distributed.tensor._dtensor_spec.TensorMeta,
-            "torch.distributed.tensor.DTensor": torch.distributed.tensor.DTensor,
-            "torch.distributed.tensor.placement_types.Partial": torch.distributed.tensor.placement_types.Partial,
-            "torch.distributed.tensor.placement_types.Replicate": torch.distributed.tensor.placement_types.Replicate,
-            "torch.distributed.tensor.placement_types.Shard": torch.distributed.tensor.placement_types.Shard,
-        }
-        rc.update(dtensor_rc)
 
     # dtype
     for t in torch.storage._dtype_to_storage_type_map().keys():
@@ -341,6 +328,20 @@ class Unpickler:
                     raise UnpicklingError(
                         "``torch.nested`` and ``torch._dynamo`` must be imported to load nested jagged tensors (NJTs)"
                     )
+                elif full_path in (
+                    [
+                        "torch.distributed.device_mesh.DeviceMesh",
+                        "torch.distributed.tensor._dtensor_spec.DTensorSpec",
+                        "torch.distributed.tensor._dtensor_spec.TensorMeta",
+                        "torch.distributed.tensor.DTensor",
+                        "torch.distributed.tensor.placement_types.Partial",
+                        "torch.distributed.tensor.placement_types.Replicate",
+                        "torch.distributed.tensor.placement_types.Shard",
+                    ]
+                ):
+                    raise UnpicklingError(
+                        "``torch.distributed.tensor`` must be imported to load DTensors"
+                    )
                 else:
                     raise UnpicklingError(
                         f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "


@@ -45,6 +45,22 @@ __all__ = [
     "zeros",
 ]
 
+
+# For weights_only torch.load
+from ._dtensor_spec import DTensorSpec as _DTensorSpec, TensorMeta as _TensorMeta
+
+torch.serialization.add_safe_globals(
+    [
+        DeviceMesh,
+        _DTensorSpec,
+        _TensorMeta,
+        DTensor,
+        Partial,
+        Replicate,
+        Shard,
+    ]
+)
+
 # Append DTensor to the list of supported types for foreach implementation for optimizer
 # and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.
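
Because registration now happens as a side effect of importing `torch.distributed.tensor`, the allowlist can never be populated while the module is unavailable, and the error message's instruction (import the module) is exactly the fix. One way to sanity-check the registration, assuming `torch.serialization.get_safe_globals()` is available as in recent PyTorch releases:

```python
import torch
import torch.distributed.tensor  # the import itself calls add_safe_globals(...)
from torch.distributed.tensor import DTensor

# The import above should have added DTensor (plus its spec and placement
# classes) to the user-approved safe globals for weights_only loading.
assert DTensor in torch.serialization.get_safe_globals()
```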