mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add environment variable to force no weights_only load (#138225)
In preparation for `weights_only` flip, if users don't have access to the `torch.load` call Pull Request resolved: https://github.com/pytorch/pytorch/pull/138225 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
ec4ce094b2
commit
e24871eb3c
@ -8,7 +8,11 @@ Miscellaneous Environment Variables
|
||||
* - Variable
|
||||
- Description
|
||||
* - ``TORCH_FORCE_WEIGHTS_ONLY_LOAD``
|
||||
- If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weight_only=True``. For more documentation on this, see :func:`torch.load`.
|
||||
- If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weights_only=True``. This will happen even if
|
||||
``weights_only=False`` was passed at the callsite. For more documentation on this, see :func:`torch.load`.
|
||||
* - ``TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD``
|
||||
- If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weights_only=False`` if the ``weights_only`` variable was not
|
||||
passed at the callsite. For more documentation on this, see :func:`torch.load`.
|
||||
* - ``TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT``
|
||||
- Under some conditions, autograd threads can hang on shutdown, therefore we do not wait for them to shutdown indefinitely but rely on timeout that is default set to ``10`` seconds. This environment variable can be used to set the timeout in seconds.
|
||||
* - ``TORCH_DEVICE_BACKEND_AUTOLOAD``
|
||||
|
@ -4308,6 +4308,32 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
f.seek(0)
|
||||
torch.load(f, weights_only=True)
|
||||
|
||||
@parametrize("force_weights_only", (True, False))
|
||||
def test_weights_only_env_variables(self, force_weights_only):
|
||||
env_var = "TORCH_FORCE_WEIGHTS_ONLY_LOAD" if force_weights_only else "TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"
|
||||
args = (
|
||||
(pickle.UnpicklingError, "Weights only load failed")
|
||||
if force_weights_only
|
||||
else (UserWarning, "forcing weights_only=False")
|
||||
)
|
||||
ctx = self.assertRaisesRegex if force_weights_only else self.assertWarnsRegex
|
||||
m = torch.nn.Linear(3, 5)
|
||||
with TemporaryFileName() as f:
|
||||
torch.save(m, f)
|
||||
try:
|
||||
old_value = os.environ[env_var] if env_var in os.environ else None
|
||||
os.environ[env_var] = "1"
|
||||
# if weights_only is explicitly set, TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD cannot override it
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, "Weights only load failed"):
|
||||
m = torch.load(f, weights_only=not force_weights_only)
|
||||
with ctx(*args):
|
||||
m = torch.load(f, weights_only=None)
|
||||
finally:
|
||||
if old_value is None:
|
||||
del os.environ[env_var]
|
||||
else:
|
||||
os.environ[env_var] = old_value
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
with serialization_method(use_zip=True):
|
||||
return super().run(*args, **kwargs)
|
||||
|
@ -1277,20 +1277,37 @@ def load(
|
||||
"is not supported yet. Please call torch.load outside the skip_data context manager."
|
||||
)
|
||||
|
||||
true_values = ["1", "y", "yes", "true"]
|
||||
# Add ability to force safe only or non-safe weight loads via environment variables
|
||||
force_weights_only_load = (
|
||||
os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") in true_values
|
||||
)
|
||||
force_no_weights_only_load = (
|
||||
os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "0") in true_values
|
||||
)
|
||||
|
||||
if force_weights_only_load and force_no_weights_only_load:
|
||||
raise RuntimeError(
|
||||
"Only one of `TORCH_FORCE_WEIGHTS_ONLY_LOAD` or `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` "
|
||||
"should be set, but both were set."
|
||||
)
|
||||
elif force_weights_only_load:
|
||||
weights_only = True
|
||||
elif force_no_weights_only_load:
|
||||
if weights_only is None:
|
||||
warnings.warn(
|
||||
"Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the"
|
||||
"`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
weights_only = False
|
||||
|
||||
if weights_only is None:
|
||||
weights_only, warn_weights_only = False, True
|
||||
else:
|
||||
warn_weights_only = False
|
||||
|
||||
# Add ability to force safe only weight loads via environment variable
|
||||
if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in [
|
||||
"1",
|
||||
"y",
|
||||
"yes",
|
||||
"true",
|
||||
]:
|
||||
weights_only = True
|
||||
|
||||
if weights_only:
|
||||
if pickle_module is not None:
|
||||
raise RuntimeError(
|
||||
|
Reference in New Issue
Block a user