mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This is a cleaner implementation of opaque objects (https://github.com/pytorch/pytorch/pull/162660). Instead now we just need to do: Call `register_opaque_type` to register the type as being "opaque" and allowed by custom ops. You also need to pass a unique name that maps to the type. ```python class OpaqueQueue: def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None: super().__init__() self.queue = queue self.init_tensor_ = init_tensor_ def push(self, tensor: torch.Tensor) -> None: self.queue.append(tensor) def pop(self) -> torch.Tensor: if len(self.queue) > 0: return self.queue.pop(0) return self.init_tensor_ def size(self) -> int: return len(self.queue) register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue") ``` When creating the custom op, the schema will then use the unique name: ```python self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") torch.library.define( "_TestOpaqueObject::queue_push", "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()", tags=torch.Tag.pt2_compliant_tag, lib=self.lib, ) @torch.library.impl( "_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib ) def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None: assert isinstance(queue, OpaqueQueue) queue.push(b) ``` Using the custom op: ```python queue = OpaqueQueue([], torch.zeros(3)) torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3)) self.assertTrue(queue.size(), 1) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165004 Approved by: https://github.com/albanD
186 lines
6.4 KiB
Python
186 lines
6.4 KiB
Python
from typing import Any, NewType, Optional
|
|
|
|
import torch
|
|
|
|
from .fake_class_registry import FakeScriptObject, register_fake_class
|
|
|
|
|
|
@register_fake_class("aten::OpaqueObject")
class FakeOpaqueObject:
    """Fake class registered for the ``aten::OpaqueObject`` TorchBind class.

    It carries no state. Constructing it through ``__obj_unflatten__`` is
    rejected; per the error message below, instances are expected to be
    special-handled elsewhere.
    """

    def __init__(self) -> None:
        """Nothing to initialize."""

    @classmethod
    def __obj_unflatten__(cls, flattened_ctx: dict[str, Any]) -> None:
        """Always raises: this fake must not be rebuilt from flattened context.

        Raises:
            RuntimeError: unconditionally.
        """
        message = (
            "FakeOpaqueObject should not be created through __obj_unflatten__ "
            "and should be special handled. Please file an issue to Github."
        )
        raise RuntimeError(message)
|
|
|
|
|
|
# Qualified TorchScript type name of the opaque object class. `get_payload()`
# and `set_payload()` compare `_type().qualified_name()` against this string,
# and it is the type to write in custom operator schemas for opaque arguments.
OpaqueTypeStr = "__torch__.torch.classes.aten.OpaqueObject"
|
|
|
|
# Static-typing alias (`typing.NewType`) over `torch._C.ScriptObject`; at
# runtime, calling `OpaqueType(x)` simply returns `x` unchanged.
OpaqueType = NewType("OpaqueType", torch._C.ScriptObject)
|
|
|
|
|
|
def make_opaque(payload: Any = None) -> torch._C.ScriptObject:
    """
    Creates an opaque object which stores the given Python object.

    This opaque object can be passed to any custom operator as an argument.
    The Python object can then be accessed from the opaque object using the
    `get_payload()` API.

    The opaque object's ``_type().qualified_name()`` is
    ``"__torch__.torch.classes.aten.OpaqueObject"``, which is the type that
    should be used for the argument when creating custom operator schemas.

    Args:
        payload (Any): The Python object to store in the opaque object. This
            can be omitted (it defaults to ``None``) and set later with
            `set_payload()`.

    Returns:
        torch._C.ScriptObject: The opaque object that stores the given Python
        object.

    Example:
        >>> # xdoctest: +SKIP("output depends on random state")
        >>> import random
        >>> import torch
        >>> from torch._library.opaque_object import (
        ...     make_opaque,
        ...     get_payload,
        ...     set_payload,
        ... )
        >>> class RNGState:
        ...     def __init__(self, seed):
        ...         self.rng = random.Random(seed)
        >>> rng = RNGState(0)
        >>> obj = make_opaque()
        >>> set_payload(obj, rng)
        >>> assert get_payload(obj) == rng
        >>> lib = torch.library.Library("mylib", "FRAGMENT")
        >>> torch.library.define(
        ...     "mylib::noisy_inject",
        ...     "(Tensor x, __torch__.torch.classes.aten.OpaqueObject obj) -> Tensor",
        ...     tags=torch.Tag.pt2_compliant_tag,
        ...     lib=lib,
        ... )
        >>> @torch.library.impl(
        ...     "mylib::noisy_inject", "CompositeExplicitAutograd", lib=lib
        ... )
        ... def noisy_inject(x: torch.Tensor, obj: torch._C.ScriptObject) -> torch.Tensor:
        ...     rng_state = get_payload(obj)
        ...     assert isinstance(rng_state, RNGState)
        ...     out = x.clone()
        ...     for i in range(out.numel()):
        ...         out.view(-1)[i] += rng_state.rng.random()
        ...     return out
        >>> print(torch.ops.mylib.noisy_inject(torch.ones(3), obj))
    """
    return torch._C._make_opaque_object(payload)
|
|
|
|
|
|
def get_payload(opaque_object: torch._C.ScriptObject) -> Any:
    """
    Retrieves the Python object stored in the given opaque object.

    Args:
        opaque_object (torch._C.ScriptObject): The opaque object (as created
            by `make_opaque()`) that stores the Python object.

    Returns:
        Any: The Python object stored in the opaque object, as given to
        `make_opaque()` or `set_payload()`.

    Raises:
        ValueError: If ``opaque_object`` is a ``FakeScriptObject`` (i.e. this
            was called from inside a fake kernel), or if it is not a
            ``torch._C.ScriptObject`` of the OpaqueObject type.
    """
    # Fake kernels receive a FakeScriptObject stand-in; letting them read the
    # real payload would make fake behavior depend on real data, so refuse.
    if isinstance(opaque_object, FakeScriptObject):
        raise ValueError(
            "get_payload: this function was called with a FakeScriptObject "
            "implying that you are calling get_payload inside of a fake kernel. "
            "The fake kernel should not depend on the contents of the "
            "OpaqueObject at all, so we're erroring out. If you need this "
            "functionality, consider creating a custom TorchBind Object instead "
            "(but note that this is more difficult)."
        )

    if not isinstance(opaque_object, torch._C.ScriptObject):
        raise ValueError(
            "Tried to get the payload from a non-OpaqueObject of type "
            f"`{type(opaque_object)}`"
        )

    # Other ScriptObjects (e.g. custom TorchBind classes) are rejected too.
    type_name = opaque_object._type().qualified_name()
    if type_name != OpaqueTypeStr:
        raise ValueError(
            f"Tried to get the payload from a non-OpaqueObject of type `{type_name}`"
        )

    return torch._C._get_opaque_object_payload(opaque_object)
|
|
|
|
|
|
def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
    """
    Sets the Python object stored in the given opaque object.

    Args:
        opaque_object (torch._C.ScriptObject): The opaque object (as created
            by `make_opaque()`) whose stored Python object is replaced.
        payload (Any): The Python object to store in the opaque object.

    Raises:
        ValueError: If ``opaque_object`` is a ``FakeScriptObject`` (i.e. this
            was called from inside a fake kernel), or if it is not a
            ``torch._C.ScriptObject`` of the OpaqueObject type.
    """
    # Fake kernels receive a FakeScriptObject stand-in, which has no real
    # payload to modify, so refuse instead of silently doing nothing.
    if isinstance(opaque_object, FakeScriptObject):
        raise ValueError(
            "set_payload: this function was called with a FakeScriptObject "
            "implying that you are calling set_payload inside of a fake kernel. "
            "The fake kernel should not depend on the contents of the "
            "OpaqueObject at all, so we're erroring out. If you need this "
            "functionality, consider creating a custom TorchBind Object instead "
            "(but note that this is more difficult)."
        )

    if not isinstance(opaque_object, torch._C.ScriptObject):
        raise ValueError(
            "Tried to set the payload of a non-OpaqueObject of type "
            f"`{type(opaque_object)}`"
        )

    # Other ScriptObjects (e.g. custom TorchBind classes) are rejected too.
    type_name = opaque_object._type().qualified_name()
    if type_name != OpaqueTypeStr:
        raise ValueError(
            f"Tried to set the payload of a non-OpaqueObject of type `{type_name}`"
        )

    torch._C._set_opaque_object_payload(opaque_object, payload)
|
|
|
|
|
|
# Registry filled by `register_opaque_type`: maps each registered class to the
# name it was registered under (the type name used in custom op schemas).
# `is_opaque_type` consults this mapping.
_OPAQUE_TYPES: dict[Any, str] = {}
|
|
|
|
|
|
def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
    """
    Registers the given type as an opaque type which allows this to be consumed
    by a custom operator.

    Once registered, ``name`` can be used as an argument type in custom
    operator schemas, e.g. ``"(MyOpaqueName a, Tensor b) -> ()"``.

    Args:
        cls (type): The class to register as an opaque type.
        name (str, optional): A unique name for the type, used in custom
            operator schemas. Defaults to ``cls.__name__``. It must not
            contain a ``'.'``.

    Raises:
        ValueError: If ``name`` contains a ``'.'``.
    """
    if name is None:
        name = cls.__name__

    if "." in name:
        # The schema_type_parser will break up types with periods
        raise ValueError(
            f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
        )

    # Register on the C++ side first: if that call raises, the Python-side
    # registry is not left holding an entry for a name that was never
    # registered.
    torch._C._register_opaque_type(name)
    _OPAQUE_TYPES[cls] = name
|
|
|
|
|
|
def is_opaque_type(cls: Any) -> bool:
    """
    Checks if the given type is an opaque type.

    Returns ``False`` when ``cls`` has no entry in the Python-side registry.
    Otherwise the answer comes from ``torch._C._is_opaque_type_registered``
    for the name ``cls`` was registered under.
    """
    try:
        registered_name = _OPAQUE_TYPES[cls]
    except KeyError:
        return False
    return torch._C._is_opaque_type_registered(registered_name)
|