Fix failures when default is flipped for weights_only (#127627)

Tests on XLA shard not fixed yet but there is an issue here https://github.com/pytorch/xla/issues/7799

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127627
Approved by: https://github.com/albanD
ghstack dependencies: #132349
This commit is contained in:
Mikayla Gawarecki
2024-08-15 19:48:35 +00:00
committed by PyTorch MergeBot
parent c8ad5e37e8
commit d9576c9440
22 changed files with 135 additions and 78 deletions

View File

@ -27,6 +27,7 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.serialization import safe_globals
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@ -535,18 +536,13 @@ class DTensorTest(DTensorTestBase):
buffer = io.BytesIO()
torch.save(sharded_tensor, buffer)
buffer.seek(0)
reloaded_st = torch.load(buffer)
reloaded_st = torch.load(buffer, weights_only=False)
self.assertEqual(sharded_tensor, reloaded_st)
# Test weights_only load
try:
torch.serialization.add_safe_globals(
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
)
with safe_globals([DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]):
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=True)
self.assertEqual(sharded_tensor, reloaded_st)
finally:
torch.serialization.clear_safe_globals()
class DTensorMeshTest(DTensorTestBase):