Add environment variable to force no weights_only load (#138225)

In preparation for `weights_only` flip, if users don't have access to the `torch.load` call

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138225
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-10-21 19:59:13 +00:00
committed by PyTorch MergeBot
parent ec4ce094b2
commit e24871eb3c
3 changed files with 57 additions and 10 deletions

View File

@ -8,7 +8,11 @@ Miscellaneous Environment Variables
* - Variable
- Description
* - ``TORCH_FORCE_WEIGHTS_ONLY_LOAD``
- If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weight_only=True``. For more documentation on this, see :func:`torch.load`.
- If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weights_only=True``. This will happen even if
``weights_only=False`` was passed at the callsite. For more documentation on this, see :func:`torch.load`.
* - ``TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD``
- If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weights_only=False`` if the ``weights_only`` variable was not
passed at the callsite. For more documentation on this, see :func:`torch.load`.
* - ``TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT``
- Under some conditions, autograd threads can hang on shutdown, therefore we do not wait for them to shutdown indefinitely but rely on timeout that is default set to ``10`` seconds. This environment variable can be used to set the timeout in seconds.
* - ``TORCH_DEVICE_BACKEND_AUTOLOAD``

View File

@ -4308,6 +4308,32 @@ class TestSerialization(TestCase, SerializationMixin):
f.seek(0)
torch.load(f, weights_only=True)
@parametrize("force_weights_only", (True, False))
def test_weights_only_env_variables(self, force_weights_only):
env_var = "TORCH_FORCE_WEIGHTS_ONLY_LOAD" if force_weights_only else "TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"
args = (
(pickle.UnpicklingError, "Weights only load failed")
if force_weights_only
else (UserWarning, "forcing weights_only=False")
)
ctx = self.assertRaisesRegex if force_weights_only else self.assertWarnsRegex
m = torch.nn.Linear(3, 5)
with TemporaryFileName() as f:
torch.save(m, f)
try:
old_value = os.environ[env_var] if env_var in os.environ else None
os.environ[env_var] = "1"
# if weights_only is explicitly set, TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD cannot override it
with self.assertRaisesRegex(pickle.UnpicklingError, "Weights only load failed"):
m = torch.load(f, weights_only=not force_weights_only)
with ctx(*args):
m = torch.load(f, weights_only=None)
finally:
if old_value is None:
del os.environ[env_var]
else:
os.environ[env_var] = old_value
def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
return super().run(*args, **kwargs)

View File

@ -1277,20 +1277,37 @@ def load(
"is not supported yet. Please call torch.load outside the skip_data context manager."
)
true_values = ["1", "y", "yes", "true"]
# Add ability to force safe only or non-safe weight loads via environment variables
force_weights_only_load = (
os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") in true_values
)
force_no_weights_only_load = (
os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "0") in true_values
)
if force_weights_only_load and force_no_weights_only_load:
raise RuntimeError(
"Only one of `TORCH_FORCE_WEIGHTS_ONLY_LOAD` or `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` "
"should be set, but both were set."
)
elif force_weights_only_load:
weights_only = True
elif force_no_weights_only_load:
if weights_only is None:
warnings.warn(
"Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the"
"`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.",
UserWarning,
stacklevel=2,
)
weights_only = False
if weights_only is None:
weights_only, warn_weights_only = False, True
else:
warn_weights_only = False
# Add ability to force safe only weight loads via environment variable
if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in [
"1",
"y",
"yes",
"true",
]:
weights_only = True
if weights_only:
if pickle_module is not None:
raise RuntimeError(