[opaque_obj] Add set_payload + docs (#163276)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163276
Approved by: https://github.com/zou3519
ghstack dependencies: #162660
This commit is contained in:
angelayi
2025-09-21 17:52:04 -07:00
committed by PyTorch MergeBot
parent 3be9c86c74
commit dd30667f6c
4 changed files with 115 additions and 3 deletions

View File

@ -2,7 +2,7 @@
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._library.opaque_object import get_payload, make_opaque
from torch._library.opaque_object import get_payload, make_opaque, set_payload
class OpaqueQueue:
@ -74,7 +74,8 @@ class TestOpaqueObject(TestCase):
def test_ops(self):
queue = OpaqueQueue([], torch.zeros(3))
obj = make_opaque(queue)
obj = make_opaque()
set_payload(obj, queue)
torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3) + 1)
self.assertEqual(queue.size(), 1)

View File

@ -1623,6 +1623,7 @@ def _jit_pass_lint(Graph) -> None: ...
def _make_opaque_object(payload: Any) -> ScriptObject: ...
def _get_opaque_object_payload(obj: ScriptObject) -> Any: ...
def _set_opaque_object_payload(obj: ScriptObject, payload: Any) -> None: ...
# Defined in torch/csrc/jit/python/python_custom_class.cpp
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...

View File

@ -3,14 +3,114 @@ from typing import Any
import torch
def make_opaque(payload: Any) -> torch._C.ScriptObject:
OPAQUE_OBJ_TYPE = "__torch__.torch.classes.aten.OpaqueObject"
def make_opaque(payload: Any = None) -> torch._C.ScriptObject:
"""
Creates an opaque object which stores the given Python object.
This opaque object can be passed to any custom operator as an argument.
The Python object can then be accessed from the opaque object using the `get_payload()` API.
The opaque object has `._type()`
"__torch__.torch.classes.aten.OpaqueObject", which should be the type used
when creating custom operator schemas.
Args:
payload (Any): The Python object to store in the opaque object. This can
be empty, and can be set with `set_payload()` later.
Returns:
torch._C.ScriptObject: The opaque object that stores the given Python object.
Example:
>>> import random
>>> import torch
>>> from torch._library.opaque_object import (
... make_opaque,
... get_payload,
... set_payload,
... )
>>>
>>> class RNGState:
>>> def __init__(self, seed):
>>> self.rng = random.Random(seed)
>>>
>>> rng = RNGState(0)
>>> obj = make_opaque()
>>> set_payload(obj, rng)
>>>
>>> assert get_payload(obj) == rng
>>>
>>> lib = torch.library.Library("mylib", "FRAGMENT")
>>>
>>> torch.library.define(
>>> "mylib::noisy_inject",
>>> "(Tensor x, __torch__.torch.classes.aten.OpaqueObject obj) -> Tensor",
>>> tags=torch.Tag.pt2_compliant_tag,
>>> lib=lib,
>>> )
>>>
>>> @torch.library.impl(
>>> "mylib::noisy_inject", "CompositeExplicitAutograd", lib=lib
>>> )
>>> def noisy_inject(x: torch.Tensor, obj: torch._C.ScriptObject) -> torch.Tensor:
>>> rng_state = get_payload(obj)
>>> assert isinstance(rng_state, RNGState)
>>> out = x.clone()
>>> for i in range(out.numel()):
>>> out.view(-1)[i] += rng_state.rng.random()
>>> return out
>>>
>>> print(torch.ops.mylib.noisy_inject(torch.ones(3), obj))
"""
return torch._C._make_opaque_object(payload)
def get_payload(opaque_object: torch._C.ScriptObject) -> Any:
"""
Retrieves the Python object stored in the given opaque object.
Args:
torch._C.ScriptObject: The opaque object that stores the given Python object.
Returns:
payload (Any): The Python object stored in the opaque object. This can
be set with `set_payload()`.
"""
if not (
isinstance(opaque_object, torch._C.ScriptObject)
and opaque_object._type().qualified_name() == OPAQUE_OBJ_TYPE
):
type_ = (
opaque_object._type().qualified_name()
if isinstance(opaque_object, torch._C.ScriptObject)
else type(opaque_object)
)
raise ValueError(
f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
)
return torch._C._get_opaque_object_payload(opaque_object)
def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
"""
Sets the Python object stored in the given opaque object.
Args:
torch._C.ScriptObject: The opaque object that stores the given Python object.
payload (Any): The Python object to store in the opaque object.
"""
if not (
isinstance(opaque_object, torch._C.ScriptObject)
and opaque_object._type().qualified_name() == OPAQUE_OBJ_TYPE
):
type_ = (
opaque_object._type().qualified_name()
if isinstance(opaque_object, torch._C.ScriptObject)
else type(opaque_object)
)
raise ValueError(
f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
)
torch._C._set_opaque_object_payload(opaque_object, payload)

View File

@ -1885,6 +1885,16 @@ void initJITBindings(PyObject* module) {
return customObj->getPayload();
},
R"doc(Returns the Python object stored on the given opaque object.)doc");
m.def(
"_set_opaque_object_payload",
[](py::object obj, py::object payload) {
auto typePtr =
torch::getCustomClass("__torch__.torch.classes.aten.OpaqueObject");
auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
auto customObj = ivalue.toCustomClass<OpaqueObject>();
customObj->setPayload(std::move(payload));
},
R"doc(Sets the payload of the given opaque object with the given Python object.)doc");
m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
std::ostringstream s;
auto type = unifyTypeList(types, s);