Files
pytorch/torch/_library/opaque_object.py
angelayi 2b4ef6b4d6 [opaque_obj_v2] PyObject custom op schema type (#165004)
This is a cleaner implementation of opaque objects (https://github.com/pytorch/pytorch/pull/162660). Instead, we now only need to do the following:

Call `register_opaque_type` to register the type as being "opaque" and allowed by custom ops. You also need to pass a unique name that maps to the type.
```python
class OpaqueQueue:
    def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
        super().__init__()
        self.queue = queue
        self.init_tensor_ = init_tensor_

    def push(self, tensor: torch.Tensor) -> None:
        self.queue.append(tensor)

    def pop(self) -> torch.Tensor:
        if len(self.queue) > 0:
            return self.queue.pop(0)
        return self.init_tensor_

    def size(self) -> int:
        return len(self.queue)

register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
```

When creating the custom op, the schema will then use the unique name:
```python
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")

torch.library.define(
    "_TestOpaqueObject::queue_push",
    "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
    tags=torch.Tag.pt2_compliant_tag,
    lib=self.lib,
)

@torch.library.impl(
    "_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
)
def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
    assert isinstance(queue, OpaqueQueue)
    queue.push(b)
```

Using the custom op:
```python
queue = OpaqueQueue([], torch.zeros(3))
torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3))
self.assertEqual(queue.size(), 1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165004
Approved by: https://github.com/albanD
2025-10-14 20:21:04 +00:00

186 lines
6.4 KiB
Python

from typing import Any, NewType, Optional
import torch
from .fake_class_registry import FakeScriptObject, register_fake_class
@register_fake_class("aten::OpaqueObject")
class FakeOpaqueObject:
    """Fake class registered for the ``aten::OpaqueObject`` TorchBind class.

    It holds no state of its own. Fake kernels are not expected to look at an
    opaque object's payload (``get_payload``/``set_payload`` reject a
    ``FakeScriptObject``), so there is nothing to mirror here.
    """

    def __init__(self) -> None:
        # Intentionally stateless.
        return None

    @classmethod
    def __obj_unflatten__(cls, flattened_ctx: dict[str, Any]) -> None:
        """Refuse to build a FakeOpaqueObject from a flattened context.

        Args:
            flattened_ctx: The flattened context (unused).

        Raises:
            RuntimeError: Always; opaque objects must be handled specially by
                the caller instead of going through this generic path.
        """
        message = (
            "FakeOpaqueObject should not be created through __obj_unflatten__ "
            "and should be special handled. Please file an issue to Github."
        )
        raise RuntimeError(message)
# TorchScript qualified type name of objects created by ``make_opaque``; this
# is what ``obj._type().qualified_name()`` is compared against, and the type
# name to use for opaque arguments in custom operator schemas.
OpaqueTypeStr: str = "__torch__.torch.classes.aten.OpaqueObject"

# Static-typing alias for opaque objects. At runtime calling a NewType just
# returns its argument, so values remain plain ``torch._C.ScriptObject``s.
OpaqueType = NewType("OpaqueType", torch._C.ScriptObject)
def make_opaque(payload: Any = None) -> torch._C.ScriptObject:
    """
    Creates an opaque object which stores the given Python object.

    The opaque object can be passed as an argument to a custom operator whose
    schema declares that argument with the opaque object type. The stored
    Python object can then be retrieved with ``get_payload()``.

    The returned object's ``_type().qualified_name()`` is
    ``"__torch__.torch.classes.aten.OpaqueObject"``; this is the type name to
    use when writing custom operator schemas.

    Args:
        payload (Any): The Python object to store in the opaque object.
            Defaults to ``None``; the payload can also be set later with
            ``set_payload()``.

    Returns:
        torch._C.ScriptObject: The opaque object that stores the given Python object.

    Example:
        >>> import random
        >>> import torch
        >>> from torch._library.opaque_object import (
        ...     make_opaque,
        ...     get_payload,
        ...     set_payload,
        ... )
        >>>
        >>> class RNGState:
        >>>     def __init__(self, seed):
        >>>         self.rng = random.Random(seed)
        >>>
        >>> rng = RNGState(0)
        >>> obj = make_opaque()
        >>> set_payload(obj, rng)
        >>>
        >>> assert get_payload(obj) is rng
        >>>
        >>> lib = torch.library.Library("mylib", "FRAGMENT")
        >>>
        >>> torch.library.define(
        >>>     "mylib::noisy_inject",
        >>>     "(Tensor x, __torch__.torch.classes.aten.OpaqueObject obj) -> Tensor",
        >>>     tags=torch.Tag.pt2_compliant_tag,
        >>>     lib=lib,
        >>> )
        >>>
        >>> @torch.library.impl(
        >>>     "mylib::noisy_inject", "CompositeExplicitAutograd", lib=lib
        >>> )
        >>> def noisy_inject(x: torch.Tensor, obj: torch._C.ScriptObject) -> torch.Tensor:
        >>>     rng_state = get_payload(obj)
        >>>     assert isinstance(rng_state, RNGState)
        >>>     out = x.clone()
        >>>     for i in range(out.numel()):
        >>>         out.view(-1)[i] += rng_state.rng.random()
        >>>     return out
        >>>
        >>> print(torch.ops.mylib.noisy_inject(torch.ones(3), obj))
    """
    return torch._C._make_opaque_object(payload)
def get_payload(opaque_object: torch._C.ScriptObject) -> Any:
"""
Retrieves the Python object stored in the given opaque object.
Args:
torch._C.ScriptObject: The opaque object that stores the given Python object.
Returns:
payload (Any): The Python object stored in the opaque object. This can
be set with `set_payload()`.
"""
if isinstance(opaque_object, FakeScriptObject):
raise ValueError(
"get_payload: this function was called with a FakeScriptObject "
"implying that you are calling get_payload inside of a fake kernel."
"The fake kernel should not depend on the contents of the "
"OpaqueObject at all, so we're erroring out. If you need this"
"functionality, consider creating a custom TorchBind Object instead"
"(but note that this is more difficult)."
)
if not (
isinstance(opaque_object, torch._C.ScriptObject)
and opaque_object._type().qualified_name() == OpaqueTypeStr
):
type_ = (
opaque_object._type().qualified_name()
if isinstance(opaque_object, torch._C.ScriptObject)
else type(opaque_object)
)
raise ValueError(
f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
)
return torch._C._get_opaque_object_payload(opaque_object)
def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
"""
Sets the Python object stored in the given opaque object.
Args:
torch._C.ScriptObject: The opaque object that stores the given Python object.
payload (Any): The Python object to store in the opaque object.
"""
if isinstance(opaque_object, FakeScriptObject):
raise ValueError(
"set_payload: this function was called with a FakeScriptObject "
"implying that you are calling get_payload inside of a fake kernel."
"The fake kernel should not depend on the contents of the "
"OpaqueObject at all, so we're erroring out. If you need this"
"functionality, consider creating a custom TorchBind Object instead"
"(but note that this is more difficult)."
)
if not (
isinstance(opaque_object, torch._C.ScriptObject)
and opaque_object._type().qualified_name() == OpaqueTypeStr
):
type_ = (
opaque_object._type().qualified_name()
if isinstance(opaque_object, torch._C.ScriptObject)
else type(opaque_object)
)
raise ValueError(
f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
)
torch._C._set_opaque_object_payload(opaque_object, payload)
_OPAQUE_TYPES: dict[Any, str] = {}
def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
"""
Registers the given type as an opaque type which allows this to be consumed
by a custom operator.
Args:
cls (type): The class to register as an opaque type.
name (str): A unique qualified name of the type.
"""
if name is None:
name = cls.__name__
if "." in name:
# The schema_type_parser will break up types with periods
raise ValueError(
f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
)
_OPAQUE_TYPES[cls] = name
torch._C._register_opaque_type(name)
def is_opaque_type(cls: Any) -> bool:
    """
    Returns whether ``cls`` was registered through ``register_opaque_type``
    and the C++ side reports its registered name as an opaque type.
    """
    registered_name = _OPAQUE_TYPES.get(cls)
    if registered_name is None:
        # Never registered from Python.
        return False
    return torch._C._is_opaque_type_registered(registered_name)