mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow NJT by default for weights_only torch.load (#140304)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140304 Approved by: https://github.com/jbschlosser
This commit is contained in:
committed by
PyTorch MergeBot
parent
096929c1e8
commit
1f28235ee2
@ -3712,21 +3712,8 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
|
||||
|
||||
with tempfile.TemporaryFile() as f:
|
||||
torch.save(nt, f)
|
||||
safe_globals = [
|
||||
torch.nested._internal.nested_tensor.NestedTensor,
|
||||
torch.nested._internal.nested_tensor._rebuild_njt,
|
||||
set,
|
||||
torch._dynamo.decorators._DimRange,
|
||||
]
|
||||
f.seek(0)
|
||||
ctx = (
|
||||
torch.serialization.safe_globals(safe_globals)
|
||||
if weights_only
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
|
||||
with ctx:
|
||||
nt_loaded = torch.load(f, weights_only=weights_only)
|
||||
nt_loaded = torch.load(f, weights_only=weights_only)
|
||||
|
||||
self.assertIsNot(nt, nt_loaded)
|
||||
# we expect a new offsets tensor -> different nested int upon load
|
||||
|
@ -182,6 +182,16 @@ def _get_allowed_globals():
|
||||
"torch.distributed.tensor.placement_types.Shard": torch.distributed.tensor.placement_types.Shard,
|
||||
}
|
||||
rc.update(dtensor_rc)
|
||||
|
||||
# nested tensor related
|
||||
rc["torch.nested._internal.nested_tensor.NestedTensor"] = (
|
||||
torch.nested._internal.nested_tensor.NestedTensor
|
||||
)
|
||||
rc["torch.nested._internal.nested_tensor._rebuild_njt"] = (
|
||||
torch.nested._internal.nested_tensor._rebuild_njt
|
||||
)
|
||||
rc["torch._dynamo.decorators._DimRange"] = torch._dynamo.decorators._DimRange
|
||||
|
||||
# dtype
|
||||
for t in torch.storage._dtype_to_storage_type_map().keys():
|
||||
rc[str(t)] = t
|
||||
|
Reference in New Issue
Block a user