mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Allow NJT by default for weights_only torch.load (#140304)"
This reverts commit 1f28235ee2984dbad45b55aa65358b59a7aeea33. Reverted https://github.com/pytorch/pytorch/pull/140304 on behalf of https://github.com/mikaylagawarecki due to Breaking internal tests due to missing torch.nested._internal ([comment](https://github.com/pytorch/pytorch/pull/140304#issuecomment-2473928461))
This commit is contained in:
@ -3712,8 +3712,21 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
|
||||
|
||||
with tempfile.TemporaryFile() as f:
|
||||
torch.save(nt, f)
|
||||
safe_globals = [
|
||||
torch.nested._internal.nested_tensor.NestedTensor,
|
||||
torch.nested._internal.nested_tensor._rebuild_njt,
|
||||
set,
|
||||
torch._dynamo.decorators._DimRange,
|
||||
]
|
||||
f.seek(0)
|
||||
nt_loaded = torch.load(f, weights_only=weights_only)
|
||||
ctx = (
|
||||
torch.serialization.safe_globals(safe_globals)
|
||||
if weights_only
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
|
||||
with ctx:
|
||||
nt_loaded = torch.load(f, weights_only=weights_only)
|
||||
|
||||
self.assertIsNot(nt, nt_loaded)
|
||||
# we expect a new offsets tensor -> different nested int upon load
|
||||
|
@ -182,16 +182,6 @@ def _get_allowed_globals():
|
||||
"torch.distributed.tensor.placement_types.Shard": torch.distributed.tensor.placement_types.Shard,
|
||||
}
|
||||
rc.update(dtensor_rc)
|
||||
|
||||
# nested tensor related
|
||||
rc["torch.nested._internal.nested_tensor.NestedTensor"] = (
|
||||
torch.nested._internal.nested_tensor.NestedTensor
|
||||
)
|
||||
rc["torch.nested._internal.nested_tensor._rebuild_njt"] = (
|
||||
torch.nested._internal.nested_tensor._rebuild_njt
|
||||
)
|
||||
rc["torch._dynamo.decorators._DimRange"] = torch._dynamo.decorators._DimRange
|
||||
|
||||
# dtype
|
||||
for t in torch.storage._dtype_to_storage_type_map().keys():
|
||||
rc[str(t)] = t
|
||||
|
Reference in New Issue
Block a user