mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[opaque_obj] Add set_payload + docs (#163276)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163276 Approved by: https://github.com/zou3519 ghstack dependencies: #162660
This commit is contained in:
committed by
PyTorch MergeBot
parent
3be9c86c74
commit
dd30667f6c
@ -2,7 +2,7 @@
|
||||
|
||||
import torch
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch._library.opaque_object import get_payload, make_opaque
|
||||
from torch._library.opaque_object import get_payload, make_opaque, set_payload
|
||||
|
||||
|
||||
class OpaqueQueue:
|
||||
@ -74,7 +74,8 @@ class TestOpaqueObject(TestCase):
|
||||
|
||||
def test_ops(self):
|
||||
queue = OpaqueQueue([], torch.zeros(3))
|
||||
obj = make_opaque(queue)
|
||||
obj = make_opaque()
|
||||
set_payload(obj, queue)
|
||||
|
||||
torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3) + 1)
|
||||
self.assertEqual(queue.size(), 1)
|
||||
|
@ -1623,6 +1623,7 @@ def _jit_pass_lint(Graph) -> None: ...
|
||||
|
||||
def _make_opaque_object(payload: Any) -> ScriptObject: ...
|
||||
def _get_opaque_object_payload(obj: ScriptObject) -> Any: ...
|
||||
def _set_opaque_object_payload(obj: ScriptObject, payload: Any) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/python_custom_class.cpp
|
||||
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
|
||||
|
@ -3,14 +3,114 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
|
||||
def make_opaque(payload: Any) -> torch._C.ScriptObject:
|
||||
OPAQUE_OBJ_TYPE = "__torch__.torch.classes.aten.OpaqueObject"
|
||||
|
||||
|
||||
def make_opaque(payload: Any = None) -> torch._C.ScriptObject:
|
||||
"""
|
||||
Creates an opaque object which stores the given Python object.
|
||||
This opaque object can be passed to any custom operator as an argument.
|
||||
The Python object can then be accessed from the opaque object using the `get_payload()` API.
|
||||
The opaque object has `._type()`
|
||||
"__torch__.torch.classes.aten.OpaqueObject", which should be the type used
|
||||
when creating custom operator schemas.
|
||||
|
||||
Args:
|
||||
payload (Any): The Python object to store in the opaque object. This can
|
||||
be empty, and can be set with `set_payload()` later.
|
||||
|
||||
Returns:
|
||||
torch._C.ScriptObject: The opaque object that stores the given Python object.
|
||||
|
||||
Example:
|
||||
|
||||
>>> import random
|
||||
>>> import torch
|
||||
>>> from torch._library.opaque_object import (
|
||||
... make_opaque,
|
||||
... get_payload,
|
||||
... set_payload,
|
||||
... )
|
||||
>>>
|
||||
>>> class RNGState:
|
||||
>>> def __init__(self, seed):
|
||||
>>> self.rng = random.Random(seed)
|
||||
>>>
|
||||
>>> rng = RNGState(0)
|
||||
>>> obj = make_opaque()
|
||||
>>> set_payload(obj, rng)
|
||||
>>>
|
||||
>>> assert get_payload(obj) == rng
|
||||
>>>
|
||||
>>> lib = torch.library.Library("mylib", "FRAGMENT")
|
||||
>>>
|
||||
>>> torch.library.define(
|
||||
>>> "mylib::noisy_inject",
|
||||
>>> "(Tensor x, __torch__.torch.classes.aten.OpaqueObject obj) -> Tensor",
|
||||
>>> tags=torch.Tag.pt2_compliant_tag,
|
||||
>>> lib=lib,
|
||||
>>> )
|
||||
>>>
|
||||
>>> @torch.library.impl(
|
||||
>>> "mylib::noisy_inject", "CompositeExplicitAutograd", lib=lib
|
||||
>>> )
|
||||
>>> def noisy_inject(x: torch.Tensor, obj: torch._C.ScriptObject) -> torch.Tensor:
|
||||
>>> rng_state = get_payload(obj)
|
||||
>>> assert isinstance(rng_state, RNGState)
|
||||
>>> out = x.clone()
|
||||
>>> for i in range(out.numel()):
|
||||
>>> out.view(-1)[i] += rng_state.rng.random()
|
||||
>>> return out
|
||||
>>>
|
||||
>>> print(torch.ops.mylib.noisy_inject(torch.ones(3), obj))
|
||||
"""
|
||||
return torch._C._make_opaque_object(payload)
|
||||
|
||||
|
||||
def get_payload(opaque_object: torch._C.ScriptObject) -> Any:
|
||||
"""
|
||||
Retrieves the Python object stored in the given opaque object.
|
||||
|
||||
Args:
|
||||
torch._C.ScriptObject: The opaque object that stores the given Python object.
|
||||
|
||||
Returns:
|
||||
payload (Any): The Python object stored in the opaque object. This can
|
||||
be set with `set_payload()`.
|
||||
"""
|
||||
if not (
|
||||
isinstance(opaque_object, torch._C.ScriptObject)
|
||||
and opaque_object._type().qualified_name() == OPAQUE_OBJ_TYPE
|
||||
):
|
||||
type_ = (
|
||||
opaque_object._type().qualified_name()
|
||||
if isinstance(opaque_object, torch._C.ScriptObject)
|
||||
else type(opaque_object)
|
||||
)
|
||||
raise ValueError(
|
||||
f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
|
||||
)
|
||||
return torch._C._get_opaque_object_payload(opaque_object)
|
||||
|
||||
|
||||
def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
|
||||
"""
|
||||
Sets the Python object stored in the given opaque object.
|
||||
|
||||
Args:
|
||||
torch._C.ScriptObject: The opaque object that stores the given Python object.
|
||||
payload (Any): The Python object to store in the opaque object.
|
||||
"""
|
||||
if not (
|
||||
isinstance(opaque_object, torch._C.ScriptObject)
|
||||
and opaque_object._type().qualified_name() == OPAQUE_OBJ_TYPE
|
||||
):
|
||||
type_ = (
|
||||
opaque_object._type().qualified_name()
|
||||
if isinstance(opaque_object, torch._C.ScriptObject)
|
||||
else type(opaque_object)
|
||||
)
|
||||
raise ValueError(
|
||||
f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
|
||||
)
|
||||
torch._C._set_opaque_object_payload(opaque_object, payload)
|
||||
|
@ -1885,6 +1885,16 @@ void initJITBindings(PyObject* module) {
|
||||
return customObj->getPayload();
|
||||
},
|
||||
R"doc(Returns the Python object stored on the given opaque object.)doc");
|
||||
m.def(
|
||||
"_set_opaque_object_payload",
|
||||
[](py::object obj, py::object payload) {
|
||||
auto typePtr =
|
||||
torch::getCustomClass("__torch__.torch.classes.aten.OpaqueObject");
|
||||
auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
|
||||
auto customObj = ivalue.toCustomClass<OpaqueObject>();
|
||||
customObj->setPayload(std::move(payload));
|
||||
},
|
||||
R"doc(Sets the payload of the given opaque object with the given Python object.)doc");
|
||||
m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
|
||||
std::ostringstream s;
|
||||
auto type = unifyTypeList(types, s);
|
||||
|
Reference in New Issue
Block a user