[opaque obj] Initial OpaqueObject (#162660)

A big pain point ppl have with custom ops is that they do not accept arbitrary input/outputs. In this PR we create the concept of an "OpaqueObject" which allows users to pass arbitrary python objects into custom operators.

Some still slightly annoying parts with this implementation:
- The schema of the operator is `__torch__.torch.classes.aten.OpaqueObject` instead of whatever python type
- `@torch.library.custom_op` doesn't work.. yet?

UX:
```python
from torch._library.opaque_object import make_opaque, get_payload

# your custom python class
class OpaqueQueue:
    def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
        super().__init__()
        self.queue = queue
        self.init_tensor_ = init_tensor_

    def push(self, tensor: torch.Tensor) -> None:
        self.queue.append(tensor)

    def pop(self) -> torch.Tensor:
        if len(self.queue) > 0:
            return self.queue.pop(0)
        return self.init_tensor_

    def size(self) -> int:
        return len(self.queue)

queue = OpaqueQueue([], torch.zeros(3))
obj: torch._C.ScriptObject = make_opaque(queue)

# obj.payload stores a direct reference to this python queue object
self.assertEqual(get_payload(obj), queue)

# This is able to be passed through the dispatcher
torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3))
self.assertTrue(queue.size(), 1)
```

Authoring a custom op:

```python
lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")

torch.library.define(
    f"_TestOpaqueObject::queue_push",
    "(__torch__.torch.classes.aten.OpaqueObject a, Tensor b) -> ()",
    tags=torch.Tag.pt2_compliant_tag,
    lib=lib,
)

@torch.library.impl(f"{libname}::queue_push", "CompositeExplicitAutograd", lib=lib)
def push_impl(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
    # We can get the payload directly by get_payload(q)
    queue = get_payload(q)
    assert isinstance(queue, OpaqueQueue)
    queue.push(b)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162660
Approved by: https://github.com/zou3519
This commit is contained in:
angelayi
2025-09-21 17:52:03 -07:00
committed by PyTorch MergeBot
parent bec967eaa4
commit 3be9c86c74
5 changed files with 156 additions and 0 deletions

87
test/test_opaque_obj.py Normal file
View File

@ -0,0 +1,87 @@
# Owner(s): ["module: custom-operators"]
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._library.opaque_object import get_payload, make_opaque
class OpaqueQueue:
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
super().__init__()
self.queue = queue
self.init_tensor_ = init_tensor_
def push(self, tensor: torch.Tensor) -> None:
self.queue.append(tensor)
def pop(self) -> torch.Tensor:
if len(self.queue) > 0:
return self.queue.pop(0)
return self.init_tensor_
def size(self) -> int:
return len(self.queue)
class TestOpaqueObject(TestCase):
def setUp(self):
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") # noqa: TOR901
torch.library.define(
"_TestOpaqueObject::queue_push",
"(__torch__.torch.classes.aten.OpaqueObject a, Tensor b) -> ()",
tags=torch.Tag.pt2_compliant_tag,
lib=self.lib,
)
@torch.library.impl(
"_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
)
def push_impl(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
queue = get_payload(q)
assert isinstance(queue, OpaqueQueue)
queue.push(b)
self.lib.define(
"queue_pop(__torch__.torch.classes.aten.OpaqueObject a) -> Tensor",
)
def pop_impl(q: torch._C.ScriptObject) -> torch.Tensor:
queue = get_payload(q)
assert isinstance(queue, OpaqueQueue)
return queue.pop()
self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
super().setUp()
def tearDown(self):
self.lib._destroy()
super().tearDown()
def test_creation(self):
queue = OpaqueQueue([], torch.zeros(3))
obj = make_opaque(queue)
self.assertTrue(isinstance(obj, torch._C.ScriptObject))
self.assertEqual(str(obj._type()), "__torch__.torch.classes.aten.OpaqueObject")
# obj.payload stores a direct reference to this python queue object
payload = get_payload(obj)
self.assertEqual(payload, queue)
queue.push(torch.ones(3))
self.assertEqual(payload.size(), 1)
def test_ops(self):
queue = OpaqueQueue([], torch.zeros(3))
obj = make_opaque(queue)
torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3) + 1)
self.assertEqual(queue.size(), 1)
popped = torch.ops._TestOpaqueObject.queue_pop(obj)
self.assertEqual(popped, torch.ones(3) + 1)
self.assertEqual(queue.size(), 0)
if __name__ == "__main__":
run_tests()

View File

@ -1621,6 +1621,9 @@ def _jit_pass_dce(Graph) -> None: ...
def _jit_pass_dce_graph(Graph) -> None: ...
def _jit_pass_lint(Graph) -> None: ...
def _make_opaque_object(payload: Any) -> ScriptObject: ...
def _get_opaque_object_payload(obj: ScriptObject) -> Any: ...
# Defined in torch/csrc/jit/python/python_custom_class.cpp
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...

View File

@ -0,0 +1,16 @@
from typing import Any
import torch
def make_opaque(payload: Any) -> torch._C.ScriptObject:
"""
Creates an opaque object which stores the given Python object.
This opaque object can be passed to any custom operator as an argument.
The Python object can then be accessed from the opaque object using the `get_payload()` API.
"""
return torch._C._make_opaque_object(payload)
def get_payload(opaque_object: torch._C.ScriptObject) -> Any:
return torch._C._get_opaque_object_payload(opaque_object)

View File

@ -78,6 +78,7 @@
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/jit/python/init.h>
#include <torch/csrc/jit/python/opaque_obj.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_arg_flatten.h>
#include <torch/csrc/jit/python/python_custom_class.h>
@ -1865,6 +1866,25 @@ void initJITBindings(PyObject* module) {
&parseSchema,
py::arg("schema"),
py::arg("allow_typevars") = true);
m.def(
"_make_opaque_object",
[](py::object payload) {
auto obj = c10::make_intrusive<OpaqueObject>(payload);
auto typePtr =
torch::getCustomClass("__torch__.torch.classes.aten.OpaqueObject");
return torch::jit::toPyObject(c10::IValue(std::move(obj)));
},
R"doc(Creates an opaque object which stores the given Python object.)doc");
m.def(
"_get_opaque_object_payload",
[](py::object obj) {
auto typePtr =
torch::getCustomClass("__torch__.torch.classes.aten.OpaqueObject");
auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
auto customObj = ivalue.toCustomClass<OpaqueObject>();
return customObj->getPayload();
},
R"doc(Returns the Python object stored on the given opaque object.)doc");
m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
std::ostringstream s;
auto type = unifyTypeList(types, s);

View File

@ -0,0 +1,30 @@
#pragma once
#include <string>
#include <c10/macros/Macros.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/custom_class.h>
namespace torch::jit {
struct OpaqueObject : public CustomClassHolder {
OpaqueObject(py::object payload) : payload_(payload) {}
void setPayload(py::object payload) {
payload_ = payload;
}
py::object getPayload() {
return payload_;
}
py::object payload_;
};
static auto register_opaque_obj_class =
torch::class_<OpaqueObject>("aten", "OpaqueObject");
} // namespace torch::jit