Fix condition for weights_only unpickler for DTensor (#140740)
Same as #140739 but for DTensor: move the safe globals for DTensor to `torch.distributed.tensor.__init__`, and update the error message to tell the user that `torch.distributed.tensor` must be imported to load a DTensor.

Differential Revision: [D65961690](https://our.internmc.facebook.com/intern/diff/D65961690)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140740
Approved by: https://github.com/malfet
ghstack dependencies: #140739
Committed by: PyTorch MergeBot
Parent: b63a84804c
Commit: f3f305ef3e
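For concreteness, here is a minimal sketch of the contract this change enforces. It assumes a hypothetical single-rank gloo process group and a CPU device mesh; it is an illustration, not part of the diff:

    import os
    import torch
    import torch.distributed as dist

    # Hypothetical single-rank setup so a DTensor can be constructed at all.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # Importing torch.distributed.tensor now also registers the DTensor
    # classes as safe globals for weights_only torch.load.
    import torch.distributed.tensor as dt
    from torch.distributed.device_mesh import init_device_mesh

    mesh = init_device_mesh("cpu", (1,))
    d = dt.DTensor.from_local(torch.randn(3, 3), mesh, [dt.Shard(0)])
    torch.save(d, "dtensor.pt")

    # A fresh process that never imports torch.distributed.tensor will fail
    # torch.load("dtensor.pt", weights_only=True) with an UnpicklingError
    # telling it to import the module; with the import, the load succeeds.
    reloaded = torch.load("dtensor.pt", weights_only=True)

    dist.destroy_process_group()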
In test_dtensor.py:

@@ -2,6 +2,9 @@
 # Owner(s): ["oncall: distributed"]
 
 import os
+import pathlib
+import tempfile
+import unittest
 
 from numpy.testing import assert_array_equal
 
@@ -28,7 +31,7 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
     RowwiseParallel,
 )
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import IS_FBCODE, run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,
@@ -542,6 +545,33 @@ class DTensorTest(DTensorTestBase):
         reloaded_st = torch.load(buffer, weights_only=True)
         self.assertEqual(sharded_tensor, reloaded_st)
 
+    @with_comms
+    @unittest.skipIf(
+        IS_FBCODE,
+        "subprocess import torch fails with ModuleNotFoundError: No module named 'torch' in fbcode",
+    )
+    def test_dtensor_save_load_import(self):
+        for should_import in [True, False]:
+            device_mesh = self.build_device_mesh()
+            placements = [Shard(0)]
+            local_tensor = torch.randn(3, 3)
+            sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
+            with tempfile.NamedTemporaryFile() as f:
+                torch.save(sharded_tensor, f)
+                import_string = (
+                    "import torch.distributed.tensor;" if should_import else ""
+                )
+                filename = pathlib.Path(f.name)
+                err_msg = (
+                    (
+                        "_pickle.UnpicklingError: Weights only load failed. "
+                        "``torch.distributed.tensor`` must be imported to load DTensors"
+                    )
+                    if not should_import
+                    else None
+                )
+                self._attempt_load_from_subprocess(filename, import_string, err_msg)
+
 
 class DTensorMeshTest(DTensorTestBase):
     @property
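The test relies on the `_attempt_load_from_subprocess` helper, which runs the load in a fresh interpreter so that whether `torch.distributed.tensor` was imported is fully controlled by `import_string`. The helper's internals are not shown in this diff; a rough standalone equivalent, under those assumptions, could look like:

    import subprocess
    import sys

    def attempt_load_from_subprocess(filename, import_string, expected_err):
        # Build a one-liner that optionally imports torch.distributed.tensor
        # before attempting a weights_only load of the saved DTensor.
        script = (
            f"import torch;{import_string}"
            f"torch.load(r'{filename}', weights_only=True)"
        )
        result = subprocess.run(
            [sys.executable, "-c", script], capture_output=True, text=True
        )
        if expected_err is None:
            # The import was present, so the load should succeed cleanly.
            assert result.returncode == 0, result.stderr
        else:
            # Without the import, the new error message should appear.
            assert expected_err in result.stderr, result.stderr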
@@ -943,9 +973,11 @@ class TestDTensorPlacementTypes(DTensorTestBase):
         from torch.distributed.tensor._collective_utils import unpad_tensor
 
         unpadded_list = [
-            unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
-            if pad_sizes[i] > 0
-            else tensor
+            (
+                unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
+                if pad_sizes[i] > 0
+                else tensor
+            )
             for i, tensor in enumerate(splitted_tensor_list)
         ]
         expected_is_tensor_empty = [
In torch/_weights_only_unpickler.py:

@@ -169,19 +169,6 @@ def _get_allowed_globals():
         "builtins.bytearray": bytearray,  # for bytearray
         "builtins.set": set,  # for set
     }
-    # Only add the dtensor related classes if the dtensor module is available
-    if hasattr(torch.distributed, "tensor"):
-        dtensor_rc: Dict[str, Any] = {
-            # DTensor related
-            "torch.distributed.device_mesh.DeviceMesh": torch.distributed.device_mesh.DeviceMesh,
-            "torch.distributed.tensor._dtensor_spec.DTensorSpec": torch.distributed.tensor._dtensor_spec.DTensorSpec,
-            "torch.distributed.tensor._dtensor_spec.TensorMeta": torch.distributed.tensor._dtensor_spec.TensorMeta,
-            "torch.distributed.tensor.DTensor": torch.distributed.tensor.DTensor,
-            "torch.distributed.tensor.placement_types.Partial": torch.distributed.tensor.placement_types.Partial,
-            "torch.distributed.tensor.placement_types.Replicate": torch.distributed.tensor.placement_types.Replicate,
-            "torch.distributed.tensor.placement_types.Shard": torch.distributed.tensor.placement_types.Shard,
-        }
-        rc.update(dtensor_rc)
-
     # dtype
     for t in torch.storage._dtype_to_storage_type_map().keys():
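This deletion is possible because the allowlist no longer needs to hardcode DTensor: `torch.serialization.add_safe_globals`, the public extension point for the weights_only unpickler, lets any module register its own classes, and `torch.distributed.tensor` now does so at import time (see the `__init__.py` hunk below). A generic illustration of the mechanism, using a made-up class:

    import io
    import torch

    class MyPoint:
        def __init__(self, x, y):
            self.x, self.y = x, y

    # Allowlist the class for weights_only loading.
    torch.serialization.add_safe_globals([MyPoint])

    buf = io.BytesIO()
    torch.save(MyPoint(1, 2), buf)
    buf.seek(0)
    p = torch.load(buf, weights_only=True)  # succeeds: MyPoint is allowlisted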
@@ -341,6 +328,20 @@ class Unpickler:
                     raise UnpicklingError(
                         "``torch.nested`` and ``torch._dynamo`` must be imported to load nested jagged tensors (NJTs)"
                     )
+                elif full_path in (
+                    [
+                        "torch.distributed.device_mesh.DeviceMesh",
+                        "torch.distributed.tensor._dtensor_spec.DTensorSpec",
+                        "torch.distributed.tensor._dtensor_spec.TensorMeta",
+                        "torch.distributed.tensor.DTensor",
+                        "torch.distributed.tensor.placement_types.Partial",
+                        "torch.distributed.tensor.placement_types.Replicate",
+                        "torch.distributed.tensor.placement_types.Shard",
+                    ]
+                ):
+                    raise UnpicklingError(
+                        "``torch.distributed.tensor`` must be imported to load DTensors"
+                    )
                 else:
                     raise UnpicklingError(
                         f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
In torch/distributed/tensor/__init__.py:

@@ -45,6 +45,22 @@ __all__ = [
     "zeros",
 ]
 
+
+# For weights_only torch.load
+from ._dtensor_spec import DTensorSpec as _DTensorSpec, TensorMeta as _TensorMeta
+
+torch.serialization.add_safe_globals(
+    [
+        DeviceMesh,
+        _DTensorSpec,
+        _TensorMeta,
+        DTensor,
+        Partial,
+        Replicate,
+        Shard,
+    ]
+)
+
 
 # Append DTensor to the list of supported types for foreach implementation for optimizer
 # and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.
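Registering the safe globals at import time ties "can be loaded" to "module was imported", which is exactly the condition the unpickler's new error message points users at. In practice a loading process only needs the import before calling torch.load; the checkpoint path below is a placeholder:

    import torch
    import torch.distributed.tensor  # noqa: F401 -- registers DTensor safe globals

    state = torch.load("checkpoint_with_dtensors.pt", weights_only=True)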