mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix failures when default is flipped for weights_only (#127627)
Tests on XLA shard not fixed yet but there is an issue here https://github.com/pytorch/xla/issues/7799 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127627 Approved by: https://github.com/albanD ghstack dependencies: #132349
This commit is contained in:
committed by
PyTorch MergeBot
parent
c8ad5e37e8
commit
d9576c9440
@ -27,6 +27,7 @@ from torch.distributed.tensor.parallel import (
|
||||
parallelize_module,
|
||||
RowwiseParallel,
|
||||
)
|
||||
from torch.serialization import safe_globals
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
@ -535,18 +536,13 @@ class DTensorTest(DTensorTestBase):
|
||||
buffer = io.BytesIO()
|
||||
torch.save(sharded_tensor, buffer)
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer)
|
||||
reloaded_st = torch.load(buffer, weights_only=False)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
# Test weights_only load
|
||||
try:
|
||||
torch.serialization.add_safe_globals(
|
||||
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
|
||||
)
|
||||
with safe_globals([DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]):
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer, weights_only=True)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
finally:
|
||||
torch.serialization.clear_safe_globals()
|
||||
|
||||
|
||||
class DTensorMeshTest(DTensorTestBase):
|
||||
|
||||
Reference in New Issue
Block a user