mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[DTensor][Test] Remove safe global context for weights_only torch.load() DTensor (#140173)
We have added DTensor related classes to allowed globals so we can torch.load(DTensor) with weights_only=True. So we don't need the safe_globals context for this test anymore. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140173 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #139949
This commit is contained in:
@ -28,7 +28,6 @@ from torch.distributed.tensor.parallel import (
|
||||
parallelize_module,
|
||||
RowwiseParallel,
|
||||
)
|
||||
from torch.serialization import safe_globals
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
@ -539,11 +538,9 @@ class DTensorTest(DTensorTestBase):
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer, weights_only=False)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
# Test weights_only load
|
||||
with safe_globals([DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]):
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer, weights_only=True)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer, weights_only=True)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
|
||||
|
||||
class DTensorMeshTest(DTensorTestBase):
|
||||
|
Reference in New Issue
Block a user