mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[opaque_obj_v2] PyObject custom op schema type (#165004)
This is a cleaner implementation of opaque objects (https://github.com/pytorch/pytorch/pull/162660). Instead now we just need to do: Call `register_opaque_type` to register the type as being "opaque" and allowed by custom ops. You also need to pass a unique name that maps to the type. ```python class OpaqueQueue: def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None: super().__init__() self.queue = queue self.init_tensor_ = init_tensor_ def push(self, tensor: torch.Tensor) -> None: self.queue.append(tensor) def pop(self) -> torch.Tensor: if len(self.queue) > 0: return self.queue.pop(0) return self.init_tensor_ def size(self) -> int: return len(self.queue) register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue") ``` When creating the custom op, the schema will then use the unique name: ```python self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") torch.library.define( "_TestOpaqueObject::queue_push", "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()", tags=torch.Tag.pt2_compliant_tag, lib=self.lib, ) @torch.library.impl( "_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib ) def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None: assert isinstance(queue, OpaqueQueue) queue.push(b) ``` Using the custom op: ```python queue = OpaqueQueue([], torch.zeros(3)) torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3)) self.assertTrue(queue.size(), 1) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165004 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
3f83e8915e
commit
2b4ef6b4d6
84
test/test_opaque_obj_v2.py
Normal file
84
test/test_opaque_obj_v2.py
Normal file
@ -0,0 +1,84 @@
|
||||
# Owner(s): ["module: custom-operators"]
|
||||
|
||||
import torch
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch._library.opaque_object import register_opaque_type
|
||||
|
||||
|
||||
class OpaqueQueue:
|
||||
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
|
||||
super().__init__()
|
||||
self.queue = queue
|
||||
self.init_tensor_ = init_tensor_
|
||||
|
||||
def push(self, tensor: torch.Tensor) -> None:
|
||||
self.queue.append(tensor)
|
||||
|
||||
def pop(self) -> torch.Tensor:
|
||||
if len(self.queue) > 0:
|
||||
return self.queue.pop(0)
|
||||
return self.init_tensor_
|
||||
|
||||
def size(self) -> int:
|
||||
return len(self.queue)
|
||||
|
||||
|
||||
class TestOpaqueObject(TestCase):
|
||||
def setUp(self):
|
||||
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") # noqa: TOR901
|
||||
|
||||
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
|
||||
|
||||
torch.library.define(
|
||||
"_TestOpaqueObject::queue_push",
|
||||
"(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=self.lib,
|
||||
)
|
||||
|
||||
@torch.library.impl(
|
||||
"_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
|
||||
)
|
||||
def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
|
||||
assert isinstance(queue, OpaqueQueue)
|
||||
queue.push(b)
|
||||
|
||||
self.lib.define(
|
||||
"queue_pop(_TestOpaqueObject_OpaqueQueue a) -> Tensor",
|
||||
)
|
||||
|
||||
def pop_impl(queue: OpaqueQueue) -> torch.Tensor:
|
||||
assert isinstance(queue, OpaqueQueue)
|
||||
return queue.pop()
|
||||
|
||||
self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
|
||||
|
||||
@torch.library.custom_op(
|
||||
"_TestOpaqueObject::queue_size",
|
||||
mutates_args=[],
|
||||
)
|
||||
def size_impl(queue: OpaqueQueue) -> int:
|
||||
assert isinstance(queue, OpaqueQueue)
|
||||
return queue.size()
|
||||
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
self.lib._destroy()
|
||||
|
||||
super().tearDown()
|
||||
|
||||
def test_ops(self):
|
||||
queue = OpaqueQueue([], torch.zeros(3))
|
||||
|
||||
torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3) + 1)
|
||||
size = torch.ops._TestOpaqueObject.queue_size(queue)
|
||||
self.assertEqual(size, 1)
|
||||
popped = torch.ops._TestOpaqueObject.queue_pop(queue)
|
||||
self.assertEqual(popped, torch.ones(3) + 1)
|
||||
size = torch.ops._TestOpaqueObject.queue_size(queue)
|
||||
self.assertEqual(size, 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -1627,6 +1627,8 @@ def _jit_pass_lint(Graph) -> None: ...
|
||||
def _make_opaque_object(payload: Any) -> ScriptObject: ...
|
||||
def _get_opaque_object_payload(obj: ScriptObject) -> Any: ...
|
||||
def _set_opaque_object_payload(obj: ScriptObject, payload: Any) -> None: ...
|
||||
def _register_opaque_type(type_name: str) -> None: ...
|
||||
def _is_opaque_type_registered(type_name: str) -> _bool: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/python_custom_class.cpp
|
||||
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
|
||||
|
@ -9,7 +9,7 @@ import torch
|
||||
from torch import device, dtype, Tensor, types
|
||||
from torch.utils._exposed_in import exposed_in
|
||||
|
||||
from .opaque_object import OpaqueType, OpaqueTypeStr
|
||||
from .opaque_object import _OPAQUE_TYPES, is_opaque_type, OpaqueType, OpaqueTypeStr
|
||||
|
||||
|
||||
# This is used as a negative test for
|
||||
@ -125,8 +125,11 @@ def infer_schema(
|
||||
# we convert it to the actual type.
|
||||
annotation_type, _ = unstringify_type(param.annotation)
|
||||
|
||||
schema_type = None
|
||||
if annotation_type not in SUPPORTED_PARAM_TYPES:
|
||||
if annotation_type == torch._C.ScriptObject:
|
||||
if is_opaque_type(annotation_type):
|
||||
schema_type = _OPAQUE_TYPES[annotation_type]
|
||||
elif annotation_type == torch._C.ScriptObject:
|
||||
error_fn(
|
||||
f"Parameter {name}'s type cannot be inferred from the schema "
|
||||
"as it is a ScriptObject. Please manually specify the schema "
|
||||
@ -152,8 +155,11 @@ def infer_schema(
|
||||
f"Parameter {name} has unsupported type {param.annotation}. "
|
||||
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
|
||||
)
|
||||
|
||||
else:
|
||||
schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
|
||||
|
||||
assert schema_type is not None
|
||||
|
||||
if type(mutates_args) is str:
|
||||
if mutates_args != UNKNOWN_MUTATES:
|
||||
raise ValueError(
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, NewType
|
||||
from typing import Any, NewType, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -150,3 +150,36 @@ def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
|
||||
f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
|
||||
)
|
||||
torch._C._set_opaque_object_payload(opaque_object, payload)
|
||||
|
||||
|
||||
_OPAQUE_TYPES: dict[Any, str] = {}
|
||||
|
||||
|
||||
def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
|
||||
"""
|
||||
Registers the given type as an opaque type which allows this to be consumed
|
||||
by a custom operator.
|
||||
|
||||
Args:
|
||||
cls (type): The class to register as an opaque type.
|
||||
name (str): A unique qualified name of the type.
|
||||
"""
|
||||
if name is None:
|
||||
name = cls.__name__
|
||||
|
||||
if "." in name:
|
||||
# The schema_type_parser will break up types with periods
|
||||
raise ValueError(
|
||||
f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
|
||||
)
|
||||
_OPAQUE_TYPES[cls] = name
|
||||
torch._C._register_opaque_type(name)
|
||||
|
||||
|
||||
def is_opaque_type(cls: Any) -> bool:
|
||||
"""
|
||||
Checks if the given type is an opaque type.
|
||||
"""
|
||||
if cls not in _OPAQUE_TYPES:
|
||||
return False
|
||||
return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <torch/csrc/jit/frontend/parse_string_literal.h>
|
||||
#include <torch/custom_class.h>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
using c10::AliasInfo;
|
||||
using c10::AwaitType;
|
||||
@ -42,6 +43,25 @@ using c10::VarType;
|
||||
|
||||
namespace torch::jit {
|
||||
|
||||
static std::unordered_set<std::string>& getOpaqueTypes() {
|
||||
static std::unordered_set<std::string> global_opaque_types;
|
||||
return global_opaque_types;
|
||||
}
|
||||
|
||||
void registerOpaqueType(const std::string& type_name) {
|
||||
auto& global_opaque_types = getOpaqueTypes();
|
||||
auto [_, inserted] = global_opaque_types.insert(type_name);
|
||||
if (!inserted) {
|
||||
throw std::runtime_error(
|
||||
"Type '" + type_name + "' is already registered as an opaque type");
|
||||
}
|
||||
}
|
||||
|
||||
bool isRegisteredOpaqueType(const std::string& type_name) {
|
||||
auto& global_opaque_types = getOpaqueTypes();
|
||||
return global_opaque_types.find(type_name) != global_opaque_types.end();
|
||||
}
|
||||
|
||||
TypePtr SchemaTypeParser::parseBaseType() {
|
||||
static std::unordered_map<std::string, TypePtr> type_map = {
|
||||
{"Generator", c10::TypeFactory::get<GeneratorType>()},
|
||||
@ -81,6 +101,11 @@ TypePtr SchemaTypeParser::parseBaseType() {
|
||||
}
|
||||
std::string text = tok.text();
|
||||
|
||||
// Check if this type is registered as an opaque type first
|
||||
if (isRegisteredOpaqueType(text)) {
|
||||
return c10::PyObjectType::get();
|
||||
}
|
||||
|
||||
auto it = type_map.find(text);
|
||||
if (it == type_map.end()) {
|
||||
if (allow_typevars_ && !text.empty() && islower(text[0])) {
|
||||
|
@ -10,6 +10,9 @@ namespace torch::jit {
|
||||
|
||||
using TypePtr = c10::TypePtr;
|
||||
|
||||
TORCH_API void registerOpaqueType(const std::string& type_name);
|
||||
TORCH_API bool isRegisteredOpaqueType(const std::string& type_name);
|
||||
|
||||
struct TORCH_API SchemaTypeParser {
|
||||
TypePtr parseBaseType();
|
||||
std::optional<c10::AliasInfo> parseAliasAnnotation();
|
||||
|
@ -15,6 +15,7 @@
|
||||
#endif
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
||||
#include <torch/csrc/jit/frontend/schema_type_parser.h>
|
||||
#include <torch/csrc/jit/frontend/tracer.h>
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
@ -1890,6 +1891,18 @@ void initJITBindings(PyObject* module) {
|
||||
customObj->setPayload(std::move(payload));
|
||||
},
|
||||
R"doc(Sets the payload of the given opaque object with the given Python object.)doc");
|
||||
m.def(
|
||||
"_register_opaque_type",
|
||||
[](const std::string& type_name) {
|
||||
torch::jit::registerOpaqueType(type_name);
|
||||
},
|
||||
R"doc(Registers a type name to be treated as an opaque type (PyObject) in schema parsing.)doc");
|
||||
m.def(
|
||||
"_is_opaque_type_registered",
|
||||
[](const std::string& type_name) -> bool {
|
||||
return torch::jit::isRegisteredOpaqueType(type_name);
|
||||
},
|
||||
R"doc(Checks if a type name is registered as an opaque type.)doc");
|
||||
m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
|
||||
std::ostringstream s;
|
||||
auto type = unifyTypeList(types, s);
|
||||
|
Reference in New Issue
Block a user