Revert "Allow NJT by default for weights_only torch.load (#140304)"

This reverts commit 1f28235ee2984dbad45b55aa65358b59a7aeea33.

Reverted https://github.com/pytorch/pytorch/pull/140304 on behalf of https://github.com/mikaylagawarecki due to Breaking internal tests due to missing torch.nested._internal ([comment](https://github.com/pytorch/pytorch/pull/140304#issuecomment-2473928461))
Author: PyTorch MergeBot
Date:   2024-11-13 15:24:00 +00:00
Parent: b4cc5d38b4
Commit: 5dc6b8c19e

2 changed files with 14 additions and 11 deletions

test/test_nestedtensor.py

@@ -3712,8 +3712,21 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
         with tempfile.TemporaryFile() as f:
             torch.save(nt, f)
+            safe_globals = [
+                torch.nested._internal.nested_tensor.NestedTensor,
+                torch.nested._internal.nested_tensor._rebuild_njt,
+                set,
+                torch._dynamo.decorators._DimRange,
+            ]
             f.seek(0)
-            nt_loaded = torch.load(f, weights_only=weights_only)
+            ctx = (
+                torch.serialization.safe_globals(safe_globals)
+                if weights_only
+                else contextlib.nullcontext()
+            )
+            with ctx:
+                nt_loaded = torch.load(f, weights_only=weights_only)
             self.assertIsNot(nt, nt_loaded)
             # we expect a new offsets tensor -> different nested int upon load
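
For reference, a minimal standalone sketch of the pattern this test restores, assuming a build where the revert is applied (so weights_only loading no longer allowlists the NJT internals by default); the jagged-layout constructor call is illustrative, not part of the diff:

import tempfile

import torch

# Build a small jagged-layout NestedTensor (NJT) to round-trip through
# torch.save / torch.load.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5)], layout=torch.jagged
)

# After this revert, these internals must be allowlisted explicitly
# before an NJT can be loaded with weights_only=True.
safe_globals = [
    torch.nested._internal.nested_tensor.NestedTensor,
    torch.nested._internal.nested_tensor._rebuild_njt,
    set,
    torch._dynamo.decorators._DimRange,
]

with tempfile.TemporaryFile() as f:
    torch.save(nt, f)
    f.seek(0)
    # safe_globals extends the weights_only allowlist only for the
    # duration of this block.
    with torch.serialization.safe_globals(safe_globals):
        nt_loaded = torch.load(f, weights_only=True)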

torch/_weights_only_unpickler.py

@@ -182,16 +182,6 @@ def _get_allowed_globals():
         "torch.distributed.tensor.placement_types.Shard": torch.distributed.tensor.placement_types.Shard,
     }
     rc.update(dtensor_rc)
-    # nested tensor related
-    rc["torch.nested._internal.nested_tensor.NestedTensor"] = (
-        torch.nested._internal.nested_tensor.NestedTensor
-    )
-    rc["torch.nested._internal.nested_tensor._rebuild_njt"] = (
-        torch.nested._internal.nested_tensor._rebuild_njt
-    )
-    rc["torch._dynamo.decorators._DimRange"] = torch._dynamo.decorators._DimRange
     # dtype
     for t in torch.storage._dtype_to_storage_type_map().keys():
         rc[str(t)] = t
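
With these entries removed from the default allowlist, callers who still need NJT checkpoints under weights_only=True can opt in process-wide instead of wrapping every call site in a context manager — a sketch using the public torch.serialization.add_safe_globals API:

import torch

# Register the NJT internals once; subsequent torch.load(..., weights_only=True)
# calls in this process can then unpickle NestedTensor payloads.
torch.serialization.add_safe_globals(
    [
        torch.nested._internal.nested_tensor.NestedTensor,
        torch.nested._internal.nested_tensor._rebuild_njt,
        set,
        torch._dynamo.decorators._DimRange,
    ]
)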