[DTensor][Test] Remove safe global context for weights_only torch.load() DTensor (#140173)

We have added DTensor related classes to allowed globals so we can torch.load(DTensor) with weights_only=True. So we don't need the safe_globals context for this test anymore.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140173
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #139949
This commit is contained in:
wz337
2024-11-08 13:11:44 -08:00
committed by PyTorch MergeBot
parent 72976b2486
commit 4893e248a8

View File

@ -28,7 +28,6 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.serialization import safe_globals
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@ -539,11 +538,9 @@ class DTensorTest(DTensorTestBase):
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=False)
self.assertEqual(sharded_tensor, reloaded_st)
# Test weights_only load
with safe_globals([DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]):
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=True)
self.assertEqual(sharded_tensor, reloaded_st)
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=True)
self.assertEqual(sharded_tensor, reloaded_st)
class DTensorMeshTest(DTensorTestBase):