mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
A big pain point people have with custom ops is that they do not accept arbitrary inputs/outputs. In this PR we create the concept of an "OpaqueObject", which allows users to pass arbitrary Python objects into custom operators.

Some parts of this implementation are still slightly annoying:
- The schema of the operator is `__torch__.torch.classes.aten.OpaqueObject` instead of whatever the Python type is.
- `@torch.library.custom_op` doesn't work (yet).

UX:
```python
from torch._library.opaque_object import make_opaque, get_payload

# your custom python class
class OpaqueQueue:
    def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
        super().__init__()
        self.queue = queue
        self.init_tensor_ = init_tensor_

    def push(self, tensor: torch.Tensor) -> None:
        self.queue.append(tensor)

    def pop(self) -> torch.Tensor:
        if len(self.queue) > 0:
            return self.queue.pop(0)
        return self.init_tensor_

    def size(self) -> int:
        return len(self.queue)

queue = OpaqueQueue([], torch.zeros(3))
obj: torch._C.ScriptObject = make_opaque(queue)

# obj.payload stores a direct reference to this python queue object
self.assertEqual(get_payload(obj), queue)

# This is able to be passed through the dispatcher
torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3))
self.assertEqual(queue.size(), 1)
```

Authoring a custom op:
```python
libname = "_TestOpaqueObject"
lib = torch.library.Library(libname, "FRAGMENT")

torch.library.define(
    f"{libname}::queue_push",
    "(__torch__.torch.classes.aten.OpaqueObject a, Tensor b) -> ()",
    tags=torch.Tag.pt2_compliant_tag,
    lib=lib,
)

@torch.library.impl(f"{libname}::queue_push", "CompositeExplicitAutograd", lib=lib)
def push_impl(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
    # We can get the payload directly by get_payload(q)
    queue = get_payload(q)
    assert isinstance(queue, OpaqueQueue)
    queue.push(b)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162660
Approved by: https://github.com/zou3519
17 lines
526 B
Python
17 lines
526 B
Python
from typing import Any
|
|
|
|
import torch
|
|
|
|
|
|
def make_opaque(payload: Any) -> torch._C.ScriptObject:
    """
    Wrap an arbitrary Python object in an opaque ``ScriptObject``.

    The returned object can be handed to a custom operator as an argument
    (the operator's schema declares that argument as
    ``__torch__.torch.classes.aten.OpaqueObject``). Inside the operator, the
    original Python object is recovered with the ``get_payload()`` API.

    Args:
        payload: Any Python object that should travel through the dispatcher.

    Returns:
        A ``torch._C.ScriptObject`` that holds a reference to ``payload``.
    """
    # The wrapping itself is implemented in the C++ extension.
    wrap_in_opaque = torch._C._make_opaque_object
    return wrap_in_opaque(payload)
|
|
|
|
|
|
def get_payload(opaque_object: torch._C.ScriptObject) -> Any:
    """
    Return the Python object stored inside an opaque object.

    This is the counterpart of ``make_opaque()``: given the ``ScriptObject``
    that ``make_opaque(payload)`` produced (for example, as received by a
    custom operator implementation), it hands back ``payload``.

    Args:
        opaque_object: A ``torch._C.ScriptObject`` created by ``make_opaque()``.

    Returns:
        The Python object that was originally passed to ``make_opaque()``.
    """
    # NOTE(review): what happens for a ScriptObject that was not created by
    # make_opaque() is decided by the C++ binding, not visible here — confirm
    # before relying on a particular exception type.
    return torch._C._get_opaque_object_payload(opaque_object)
|