Allow NJT by default for weights_only torch.load (#140304)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140304
Approved by: https://github.com/jbschlosser
This commit is contained in:
Mikayla Gawarecki
2024-11-12 11:57:47 -08:00
committed by PyTorch MergeBot
parent 096929c1e8
commit 1f28235ee2
2 changed files with 11 additions and 14 deletions

View File

@@ -3712,21 +3712,8 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
with tempfile.TemporaryFile() as f:
torch.save(nt, f)
safe_globals = [
torch.nested._internal.nested_tensor.NestedTensor,
torch.nested._internal.nested_tensor._rebuild_njt,
set,
torch._dynamo.decorators._DimRange,
]
f.seek(0)
ctx = (
torch.serialization.safe_globals(safe_globals)
if weights_only
else contextlib.nullcontext()
)
with ctx:
nt_loaded = torch.load(f, weights_only=weights_only)
nt_loaded = torch.load(f, weights_only=weights_only)
self.assertIsNot(nt, nt_loaded)
# we expect a new offsets tensor -> different nested int upon load

View File

@@ -182,6 +182,16 @@ def _get_allowed_globals():
"torch.distributed.tensor.placement_types.Shard": torch.distributed.tensor.placement_types.Shard,
}
rc.update(dtensor_rc)
# nested tensor related
rc["torch.nested._internal.nested_tensor.NestedTensor"] = (
torch.nested._internal.nested_tensor.NestedTensor
)
rc["torch.nested._internal.nested_tensor._rebuild_njt"] = (
torch.nested._internal.nested_tensor._rebuild_njt
)
rc["torch._dynamo.decorators._DimRange"] = torch._dynamo.decorators._DimRange
# dtype
for t in torch.storage._dtype_to_storage_type_map().keys():
rc[str(t)] = t