mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow BUILD/NEWOBJ instruction for items added via torch.serialization.add_safe_globals (#129251)
Previously, allowlisting functions/classes via `torch.serialization.add_safe_globals(obj)` for the `weights_only` Unpickler had the following effect:
- For a [`GLOBAL`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1926-L1939) instruction, `GLOBAL obj.__module__ obj.__name__` would be allowed and translated back to obj to be pushed back to the stack.
- For a [`REDUCE`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1926-L1982) instruction where we expect the stack to contain `func` and `args`, `func` is allowed if it was added via `add_safe_globals`
However, it did not have an effect on `BUILD` and `NEWOBJ` instructions
Some classes may be rebuilt via [`NEWOBJ`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L2091-L2104) instruction, which indicates that their constructor should be used to rebuild the class.
Further, a [`BUILD`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1984-L2007) instruction might be used if an object's `__reduce__`/`__reduce_ex__` returns a non-None value for `state`. Which indicates a `__setstate__` or `__dict__.update`.
**This PR makes sure that adding objects to the allowlist will also allow `NEWOBJ` and `BUILD` instructions for them.**
In particular, the update for `NEWOBJ` should unblock allowlisting of [`ScaledMMConfig`](d4ade877df/float8_experimental/float8_tensor.py (L26-L30)
) in float8_experimental @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129251
Approved by: https://github.com/albanD
ghstack dependencies: #129244
This commit is contained in:
committed by
PyTorch MergeBot
parent
1bb1e3463c
commit
c5f7755e86
@ -536,6 +536,16 @@ class DTensorTest(DTensorTestBase):
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
# Test weights_only load
|
||||
try:
|
||||
torch.serialization.add_safe_globals(
|
||||
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
|
||||
)
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer, weights_only=True)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
finally:
|
||||
torch.serialization.clear_safe_globals()
|
||||
|
||||
|
||||
class DTensorMeshTest(DTensorTestBase):
|
||||
|
Reference in New Issue
Block a user