mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[opaque obj] Initial OpaqueObject (#162660)
A big pain point ppl have with custom ops is that they do not accept arbitrary input/outputs. In this PR we create the concept of an "OpaqueObject" which allows users to pass arbitrary python objects into custom operators. Some still slightly annoying parts with this implementation: - The schema of the operator is `__torch__.torch.classes.aten.OpaqueObject` instead of whatever python type - `@torch.library.custom_op` doesn't work.. yet? UX: ```python from torch._library.opaque_object import make_opaque, get_payload # your custom python class class OpaqueQueue: def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None: super().__init__() self.queue = queue self.init_tensor_ = init_tensor_ def push(self, tensor: torch.Tensor) -> None: self.queue.append(tensor) def pop(self) -> torch.Tensor: if len(self.queue) > 0: return self.queue.pop(0) return self.init_tensor_ def size(self) -> int: return len(self.queue) queue = OpaqueQueue([], torch.zeros(3)) obj: torch._C.ScriptObject = make_opaque(queue) # obj.payload stores a direct reference to this python queue object self.assertEqual(get_payload(obj), queue) # This is able to be passed through the dispatcher torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3)) self.assertTrue(queue.size(), 1) ``` Authoring a custom op: ```python lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") torch.library.define( f"_TestOpaqueObject::queue_push", "(__torch__.torch.classes.aten.OpaqueObject a, Tensor b) -> ()", tags=torch.Tag.pt2_compliant_tag, lib=lib, ) @torch.library.impl(f"{libname}::queue_push", "CompositeExplicitAutograd", lib=lib) def push_impl(q: torch._C.ScriptObject, b: torch.Tensor) -> None: # We can get the payload directly by get_payload(q) queue = get_payload(q) assert isinstance(queue, OpaqueQueue) queue.push(b) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/162660 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
bec967eaa4
commit
3be9c86c74
87
test/test_opaque_obj.py
Normal file
87
test/test_opaque_obj.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Owner(s): ["module: custom-operators"]
|
||||
|
||||
import torch
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch._library.opaque_object import get_payload, make_opaque
|
||||
|
||||
|
||||
class OpaqueQueue:
|
||||
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
|
||||
super().__init__()
|
||||
self.queue = queue
|
||||
self.init_tensor_ = init_tensor_
|
||||
|
||||
def push(self, tensor: torch.Tensor) -> None:
|
||||
self.queue.append(tensor)
|
||||
|
||||
def pop(self) -> torch.Tensor:
|
||||
if len(self.queue) > 0:
|
||||
return self.queue.pop(0)
|
||||
return self.init_tensor_
|
||||
|
||||
def size(self) -> int:
|
||||
return len(self.queue)
|
||||
|
||||
|
||||
class TestOpaqueObject(TestCase):
|
||||
def setUp(self):
|
||||
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") # noqa: TOR901
|
||||
|
||||
torch.library.define(
|
||||
"_TestOpaqueObject::queue_push",
|
||||
"(__torch__.torch.classes.aten.OpaqueObject a, Tensor b) -> ()",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=self.lib,
|
||||
)
|
||||
|
||||
@torch.library.impl(
|
||||
"_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
|
||||
)
|
||||
def push_impl(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
|
||||
queue = get_payload(q)
|
||||
assert isinstance(queue, OpaqueQueue)
|
||||
queue.push(b)
|
||||
|
||||
self.lib.define(
|
||||
"queue_pop(__torch__.torch.classes.aten.OpaqueObject a) -> Tensor",
|
||||
)
|
||||
|
||||
def pop_impl(q: torch._C.ScriptObject) -> torch.Tensor:
|
||||
queue = get_payload(q)
|
||||
assert isinstance(queue, OpaqueQueue)
|
||||
return queue.pop()
|
||||
|
||||
self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
|
||||
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
self.lib._destroy()
|
||||
|
||||
super().tearDown()
|
||||
|
||||
def test_creation(self):
|
||||
queue = OpaqueQueue([], torch.zeros(3))
|
||||
obj = make_opaque(queue)
|
||||
self.assertTrue(isinstance(obj, torch._C.ScriptObject))
|
||||
self.assertEqual(str(obj._type()), "__torch__.torch.classes.aten.OpaqueObject")
|
||||
|
||||
# obj.payload stores a direct reference to this python queue object
|
||||
payload = get_payload(obj)
|
||||
self.assertEqual(payload, queue)
|
||||
queue.push(torch.ones(3))
|
||||
self.assertEqual(payload.size(), 1)
|
||||
|
||||
def test_ops(self):
|
||||
queue = OpaqueQueue([], torch.zeros(3))
|
||||
obj = make_opaque(queue)
|
||||
|
||||
torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3) + 1)
|
||||
self.assertEqual(queue.size(), 1)
|
||||
popped = torch.ops._TestOpaqueObject.queue_pop(obj)
|
||||
self.assertEqual(popped, torch.ones(3) + 1)
|
||||
self.assertEqual(queue.size(), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -1621,6 +1621,9 @@ def _jit_pass_dce(Graph) -> None: ...
|
||||
def _jit_pass_dce_graph(Graph) -> None: ...
|
||||
def _jit_pass_lint(Graph) -> None: ...
|
||||
|
||||
def _make_opaque_object(payload: Any) -> ScriptObject: ...
|
||||
def _get_opaque_object_payload(obj: ScriptObject) -> Any: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/python_custom_class.cpp
|
||||
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
|
||||
|
||||
|
16
torch/_library/opaque_object.py
Normal file
16
torch/_library/opaque_object.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def make_opaque(payload: Any) -> torch._C.ScriptObject:
|
||||
"""
|
||||
Creates an opaque object which stores the given Python object.
|
||||
This opaque object can be passed to any custom operator as an argument.
|
||||
The Python object can then be accessed from the opaque object using the `get_payload()` API.
|
||||
"""
|
||||
return torch._C._make_opaque_object(payload)
|
||||
|
||||
|
||||
def get_payload(opaque_object: torch._C.ScriptObject) -> Any:
|
||||
return torch._C._get_opaque_object_payload(opaque_object)
|
@ -78,6 +78,7 @@
|
||||
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
|
||||
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
|
||||
#include <torch/csrc/jit/python/init.h>
|
||||
#include <torch/csrc/jit/python/opaque_obj.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/jit/python/python_arg_flatten.h>
|
||||
#include <torch/csrc/jit/python/python_custom_class.h>
|
||||
@ -1865,6 +1866,25 @@ void initJITBindings(PyObject* module) {
|
||||
&parseSchema,
|
||||
py::arg("schema"),
|
||||
py::arg("allow_typevars") = true);
|
||||
m.def(
|
||||
"_make_opaque_object",
|
||||
[](py::object payload) {
|
||||
auto obj = c10::make_intrusive<OpaqueObject>(payload);
|
||||
auto typePtr =
|
||||
torch::getCustomClass("__torch__.torch.classes.aten.OpaqueObject");
|
||||
return torch::jit::toPyObject(c10::IValue(std::move(obj)));
|
||||
},
|
||||
R"doc(Creates an opaque object which stores the given Python object.)doc");
|
||||
m.def(
|
||||
"_get_opaque_object_payload",
|
||||
[](py::object obj) {
|
||||
auto typePtr =
|
||||
torch::getCustomClass("__torch__.torch.classes.aten.OpaqueObject");
|
||||
auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
|
||||
auto customObj = ivalue.toCustomClass<OpaqueObject>();
|
||||
return customObj->getPayload();
|
||||
},
|
||||
R"doc(Returns the Python object stored on the given opaque object.)doc");
|
||||
m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
|
||||
std::ostringstream s;
|
||||
auto type = unifyTypeList(types, s);
|
||||
|
30
torch/csrc/jit/python/opaque_obj.h
Normal file
30
torch/csrc/jit/python/opaque_obj.h
Normal file
@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/custom_class.h>
|
||||
|
||||
namespace torch::jit {
|
||||
struct OpaqueObject : public CustomClassHolder {
|
||||
OpaqueObject(py::object payload) : payload_(payload) {}
|
||||
|
||||
void setPayload(py::object payload) {
|
||||
payload_ = payload;
|
||||
}
|
||||
|
||||
py::object getPayload() {
|
||||
return payload_;
|
||||
}
|
||||
|
||||
py::object payload_;
|
||||
};
|
||||
|
||||
static auto register_opaque_obj_class =
|
||||
torch::class_<OpaqueObject>("aten", "OpaqueObject");
|
||||
|
||||
} // namespace torch::jit
|
Reference in New Issue
Block a user