mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00

Flip default on weights_only (#137602)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137602
Approved by: https://github.com/malfet, https://github.com/albanD
ghstack dependencies: #138936, #139221, #139433, #139541

committed by PyTorch MergeBot
parent f55dfbcf87
commit ca43ecd599
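
This commit flips the default of torch.load's weights_only argument from False to True (outside fbcode, and only when no custom pickle_module is passed). A minimal sketch of what the flip means for callers, with illustrative file names:

    import torch

    # Plain tensors and state_dicts keep loading under the new default,
    # which is now equivalent to passing weights_only=True explicitly.
    torch.save(torch.ones(3), "tensor.pt")
    t = torch.load("tensor.pt")

    # Pickles of arbitrary objects (e.g. a whole nn.Module) now require an
    # explicit opt-out of the restricted unpickler.
    torch.save(torch.nn.Linear(2, 2), "model.pt")
    model = torch.load("model.pt", weights_only=False)

The bulk of the diff updates tests accordingly: either passing weights_only=False for legacy module pickles, or allowlisting the classes a checkpoint needs.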
.github/ci_commit_pins/torchbench.txt (vendored, 2 changed lines)
@@ -1 +1 @@
-e522b45cd4535b9dfe067aa68d7315755df38f48
+766a5e3a189384659fd35a68c3b17b88c761aaac
.github/ci_commit_pins/xla.txt (vendored, 2 changed lines)
@@ -1 +1 @@
-2eb4a60ed14a38260b85b0c765161f0ce45be6d1
+f71c02d1f457d58371e013632efb016c01bd1866
@@ -1245,7 +1245,8 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
         module_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
 
         buffer.seek(0)
-        state_dict_deser = torch.load(buffer)
+        # weights_only=False as ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+        state_dict_deser = torch.load(buffer, weights_only=False)
         module_load.load_state_dict(state_dict_deser, strict=False)
 
         module_load._register_state_dict_hook(state_dict_hook)
@@ -1289,7 +1290,8 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
 
         buffer.seek(0)
         with load_with_process_group(pg):
-            state_dict_deser = torch.load(buffer)
+            # ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+            state_dict_deser = torch.load(buffer, weights_only=False)
         module_load.load_state_dict(state_dict_deser, strict=False)
 
         # Verify after load.
@@ -1361,20 +1363,23 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
         if self.rank != 0:
             with self.assertRaisesRegex(RuntimeError, "Local rank at save time was"):
                 with load_with_process_group(pg):
-                    state_dict_deser = torch.load(buffer)
+                    # ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+                    state_dict_deser = torch.load(buffer, weights_only=False)
         else:
             with self.assertRaisesRegex(
                 RuntimeError, "Local world size at save time was"
             ):
                 with load_with_process_group(pg):
-                    state_dict_deser = torch.load(buffer)
+                    # ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+                    state_dict_deser = torch.load(buffer, weights_only=False)
 
         dist.destroy_process_group()
         buffer.seek(0)
         with self.assertRaisesRegex(
             RuntimeError, "Need to initialize default process group"
         ):
-            state_dict_deser = torch.load(buffer)
+            # ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+            state_dict_deser = torch.load(buffer, weights_only=False)
         rpc.shutdown()
 
     @with_comms
@@ -16,6 +16,12 @@ from torch.distributed._shard.sharded_tensor import (
     Shard,
     ShardedTensor,
 )
+from torch.distributed._shard.sharded_tensor.metadata import (
+    MEM_FORMAT_ENCODING,
+    ShardedTensorMetadata,
+    TensorProperties,
+)
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec, ShardMetadata
 from torch.distributed._state_dict_utils import (
     _all_gather_sharded_tensor,
     _gather_state_dict,
@@ -37,6 +43,7 @@ from torch.distributed.fsdp import (
 from torch.distributed.fsdp._common_utils import FSDP_PREFIX
 from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
 from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
+from torch.distributed.remote_device import _remote_device
 from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim import SGD
@@ -1160,6 +1167,20 @@ class TestFSDPStateDict(FSDPTest):
         checkpoint = io.BytesIO()
         torch.save(state_dict, checkpoint)
         checkpoint.seek(0)
-        state_dict_saved = torch.load(checkpoint)
+        with torch.serialization.safe_globals(
+            [
+                Shard,
+                ShardMetadata,
+                ShardedTensor,
+                ShardedTensorMetadata,
+                TensorProperties,
+                MEM_FORMAT_ENCODING,
+                _remote_device,
+                getattr,
+                ShardedTensor.ProcessGroupState,
+                ChunkShardingSpec,
+            ]
+        ):
+            state_dict_saved = torch.load(checkpoint)
         for k, v in state_dict_saved.items():
             if isinstance(v, ShardedTensor):
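
For reference, a standalone sketch of the safe_globals pattern this test exercises; Config is a hypothetical user-defined class, not part of this diff:

    import torch
    from torch.serialization import safe_globals

    class Config:
        def __init__(self, lr):
            self.lr = lr

    torch.save(Config(lr=0.1), "config.pt")

    # Under weights_only=True, unpickling a non-allowlisted class raises an
    # UnpicklingError; the context manager allowlists it for this block only.
    with safe_globals([Config]):
        cfg = torch.load("config.pt", weights_only=True)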
@@ -46,7 +46,10 @@ class InPlaceCompilationTests(TestCase):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             torch.save(model, os.path.join(tmpdirname, "model.pt"))
-            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+            # weights_only=False as this is a legacy use case that loads a module
+            loaded_model = torch.load(
+                os.path.join(tmpdirname, "model.pt"), weights_only=False
+            )
             loaded_model(torch.randn(1, 10))
 
     def test_state_dict_save(self):
@@ -58,7 +61,8 @@ class InPlaceCompilationTests(TestCase):
             torch.save(model.state_dict(), os.path.join(tmpdirname, "model.pt"))
             loaded_model = ToyModel()
             loaded_model.load_state_dict(
-                torch.load(os.path.join(tmpdirname, "model.pt"))
+                # weights_only=False as this is a legacy use case that loads a module
+                torch.load(os.path.join(tmpdirname, "model.pt"), weights_only=False)
             )
             loaded_model(torch.randn(1, 10))
 
@@ -3002,7 +3002,10 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
-            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+            # weights_only=False as this is a legacy use case that loads a module
+            loaded_model = torch.load(
+                os.path.join(tmpdirname, "model.pt"), weights_only=False
+            )
             loaded_model(inp)
             self.assertTrue(same_two_models(loaded_model, mod, [inp]))
             self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))
@@ -3020,7 +3023,10 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
         opt_mod = torch.compile(mod, backend=backend)
         with tempfile.TemporaryDirectory() as tmpdirname:
             torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
-            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+            # weights_only=False as this is a legacy use case that loads a module
+            loaded_model = torch.load(
+                os.path.join(tmpdirname, "model.pt"), weights_only=False
+            )
             torch._dynamo.reset()  # force recompiles
             torch._inductor.metrics.generated_kernel_count = 0
             opt_mod(inp)
@@ -5,7 +5,8 @@ import torch
 
 if __name__ == "__main__":
     script_mod = torch.jit.load(sys.argv[1])
-    mod = torch.load(sys.argv[1] + ".orig")
+    # weights_only=False as this is loading a sharded model
+    mod = torch.load(sys.argv[1] + ".orig", weights_only=False)
     print(script_mod)
     inp = torch.rand(2, 28 * 28)
     _ = mod(inp)
@@ -8825,7 +8825,8 @@ class TestNNMPS(NNTestCase):
         path = download_file('https://download.pytorch.org/test_data/linear.pt')
         with warnings.catch_warnings():
             warnings.simplefilter('ignore', SourceChangeWarning)
-            m = torch.load(path)
+            # weights_only=False as this is a legacy use case that loads a module
+            m = torch.load(path, weights_only=False)
         input = torch.randn(2, 3, dtype=torch.float)
         self.assertEqual(m(input).size(), (2, 5))
 
@@ -8842,7 +8843,8 @@ class TestNNMPS(NNTestCase):
         path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
         with warnings.catch_warnings():
             warnings.simplefilter('ignore', SourceChangeWarning)
-            m = torch.load(path, encoding='utf-8')
+            # weights_only=False as this is a legacy use case that loads a module
+            m = torch.load(path, encoding='utf-8', weights_only=False)
         input = torch.randn((1, 1, 1, 1), dtype=torch.float)
         self.assertEqual(m(input).size(), (1, 1, 1, 1))
 
@@ -1,6 +1,7 @@
 # Owner(s): ["module: nestedtensor"]
 
 import ast
+import contextlib
 import io
 import itertools
 import math
@@ -3657,7 +3658,8 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
         ["contig", "noncontig_transposed", "noncontig_with_holes"],
         name_fn=lambda c: c,
     )
-    def test_serialization(self, device, dtype, contiguity):
+    @parametrize("weights_only", [True, False])
+    def test_serialization(self, device, dtype, contiguity, weights_only):
         # Test with 3 cases:
         # 1. contiguous
         # 2. non-contiguous transposed
@@ -3693,8 +3695,21 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
 
         with tempfile.TemporaryFile() as f:
             torch.save(nt, f)
+            safe_globals = [
+                torch.nested._internal.nested_tensor.NestedTensor,
+                torch.nested._internal.nested_tensor._rebuild_njt,
+                set,
+                torch._dynamo.decorators._DimRange,
+            ]
             f.seek(0)
-            nt_loaded = torch.load(f)
+            ctx = (
+                torch.serialization.safe_globals(safe_globals)
+                if weights_only
+                else contextlib.nullcontext()
+            )
+
+            with ctx:
+                nt_loaded = torch.load(f, weights_only=weights_only)
 
             self.assertIsNot(nt, nt_loaded)
             # we expect a new offsets tensor -> different nested int upon load
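
The conditional-context idiom above generalizes to any test or loader parametrized over weights_only; a minimal sketch, where load_checkpoint is an illustrative helper rather than anything added by this PR:

    import contextlib

    import torch

    def load_checkpoint(f, weights_only, allowed_globals):
        # Allowlist custom classes only on the weights_only=True path;
        # otherwise use a no-op context.
        ctx = (
            torch.serialization.safe_globals(allowed_globals)
            if weights_only
            else contextlib.nullcontext()
        )
        with ctx:
            return torch.load(f, weights_only=weights_only)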
@@ -1196,7 +1196,7 @@ class TestSerialization(TestCase, SerializationMixin):
             f.seek(0)
             if unsafe_global:
                 with self.assertRaisesRegex(pickle.UnpicklingError,
-                                            r"use `torch.serialization.add_safe_globals\(\[TwoTensor\]\)` to allowlist"):
+                                            r"use `torch.serialization.add_safe_globals\(\[TwoTensor\]\)` or .* to allowlist"):
                     torch.load(f, weights_only=True)
             else:
                 with self.assertRaisesRegex(pickle.UnpicklingError,
@@ -322,8 +322,9 @@ class Unpickler:
                 else:
                     raise UnpicklingError(
                         f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
-                        f"Please use `torch.serialization.add_safe_globals([{name}])` to allowlist "
-                        "this global if you trust this class/function."
+                        f"Please use `torch.serialization.add_safe_globals([{name}])` or the "
+                        f"`torch.serialization.safe_globals([{name}])` context manager to allowlist this global "
+                        "if you trust this class/function."
                     )
             elif key[0] == NEWOBJ[0]:
                 args = self.stack.pop()
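
A sketch of how calling code typically reacts to this error, assuming a hypothetical user class OptState that the checkpoint contains:

    import pickle

    import torch

    class OptState:
        def __init__(self, step=0):
            self.step = step

    torch.save(OptState(step=3), "state.pt")

    try:
        state = torch.load("state.pt", weights_only=True)
    except pickle.UnpicklingError:
        # Allowlist the class (here process-wide; safe_globals offers the
        # scoped alternative named in the message) and retry.
        torch.serialization.add_safe_globals([OptState])
        state = torch.load("state.pt", weights_only=True)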
@@ -67,6 +67,7 @@ __all__ = [
     "skip_data",
 ]
 
+IS_FBCODE = not hasattr(torch.version, "git_version")
 
 DEFAULT_PROTOCOL = 2
 
@@ -92,6 +93,10 @@ else:
     MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]
 
 
+def _default_to_weights_only(pickle_module):
+    return pickle_module is None and not IS_FBCODE
+
+
 # _serialization_tls is used to store thread local state specific to serialization
 # that needs to be propagated to other files, in particular we use this for
 # (1) map_location (needed for wrapper subclasses/third party devices to torch._utils)
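
Pulling the pieces together, the effective resolution of weights_only when the caller leaves it unset looks roughly like this; a condensed sketch of the hunks below, not the verbatim implementation:

    def _resolve_weights_only(weights_only, pickle_module, is_fbcode):
        if weights_only is None:  # caller did not pass weights_only explicitly
            # New default: True, except when a custom pickle_module is supplied
            # (weights_only=True cannot be combined with one) or when running
            # inside fbcode.
            return pickle_module is None and not is_fbcode
        return weights_only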
@@ -1205,7 +1210,7 @@ def load(
     # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
     # the build environment (e.g. `<module 'pickle' from '/leaked/path').
 
-    """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)
+    """load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)
 
     Loads an object saved with :func:`torch.save` from a file.
 
@@ -1347,6 +1352,11 @@ def load(
             "is not supported yet. Please call torch.load outside the skip_data context manager."
         )
 
+    weights_only_not_set = weights_only is None
+
+    if weights_only_not_set:
+        weights_only = _default_to_weights_only(pickle_module)
+
     true_values = ["1", "y", "yes", "true"]
     # Add ability to force safe only or non-safe weight loads via environment variables
     force_weights_only_load = (
@@ -1364,7 +1374,8 @@ def load(
     elif force_weights_only_load:
         weights_only = True
     elif force_no_weights_only_load:
-        if weights_only is None:
+        # TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD can only override if callsite did not explicitly set weights_only
+        if weights_only_not_set:
             warnings.warn(
                 "Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the"
                 "`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.",
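
As a usage note, both force-load environment variables are consulted inside torch.load itself, so they can be set at process level without touching call sites. A hedged sketch; the file name is illustrative, and per true_values above the accepted truthy values are "1", "y", "yes", and "true":

    import os

    # Only affects torch.load calls that did not pass weights_only explicitly,
    # and warns when it kicks in. Use only for trusted checkpoints; the
    # companion TORCH_FORCE_WEIGHTS_ONLY_LOAD forces the safe path instead.
    os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"

    import torch

    ckpt = torch.load("legacy.pt")  # falls back to weights_only=False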
@@ -1373,11 +1384,6 @@ def load(
         )
         weights_only = False
 
-    if weights_only is None:
-        weights_only, warn_weights_only = False, True
-    else:
-        warn_weights_only = False
-
     if weights_only:
         if pickle_module is not None:
             raise RuntimeError(
@@ -1385,21 +1391,6 @@ def load(
         )
     else:
         if pickle_module is None:
-            if warn_weights_only:
-                warnings.warn(
-                    "You are using `torch.load` with `weights_only=False` (the current default value), which uses "
-                    "the default pickle module implicitly. It is possible to construct malicious pickle data "
-                    "which will execute arbitrary code during unpickling (See "
-                    "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
-                    "In a future release, the default value for `weights_only` will be flipped to `True`. This "
-                    "limits the functions that could be executed during unpickling. Arbitrary objects will no "
-                    "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
-                    "user via `torch.serialization.add_safe_globals`. We recommend you start setting "
-                    "`weights_only=True` for any use case where you don't have full control of the loaded file. "
-                    "Please open an issue on GitHub for any issues related to this experimental feature.",
-                    FutureWarning,
-                    stacklevel=2,
-                )
             pickle_module = pickle
 
     # make flipping default BC-compatible