mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[opaque_obj_v2] PyObject custom op schema type (#165004)
This is a cleaner implementation of opaque objects (https://github.com/pytorch/pytorch/pull/162660). Instead now we just need to do: Call `register_opaque_type` to register the type as being "opaque" and allowed by custom ops. You also need to pass a unique name that maps to the type. ```python class OpaqueQueue: def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None: super().__init__() self.queue = queue self.init_tensor_ = init_tensor_ def push(self, tensor: torch.Tensor) -> None: self.queue.append(tensor) def pop(self) -> torch.Tensor: if len(self.queue) > 0: return self.queue.pop(0) return self.init_tensor_ def size(self) -> int: return len(self.queue) register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue") ``` When creating the custom op, the schema will then use the unique name: ```python self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") torch.library.define( "_TestOpaqueObject::queue_push", "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()", tags=torch.Tag.pt2_compliant_tag, lib=self.lib, ) @torch.library.impl( "_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib ) def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None: assert isinstance(queue, OpaqueQueue) queue.push(b) ``` Using the custom op: ```python queue = OpaqueQueue([], torch.zeros(3)) torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3)) self.assertTrue(queue.size(), 1) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165004 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
3f83e8915e
commit
2b4ef6b4d6
84
test/test_opaque_obj_v2.py
Normal file
84
test/test_opaque_obj_v2.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
# Owner(s): ["module: custom-operators"]
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch._dynamo.test_case import run_tests, TestCase
|
||||||
|
from torch._library.opaque_object import register_opaque_type
|
||||||
|
|
||||||
|
|
||||||
|
class OpaqueQueue:
|
||||||
|
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.queue = queue
|
||||||
|
self.init_tensor_ = init_tensor_
|
||||||
|
|
||||||
|
def push(self, tensor: torch.Tensor) -> None:
|
||||||
|
self.queue.append(tensor)
|
||||||
|
|
||||||
|
def pop(self) -> torch.Tensor:
|
||||||
|
if len(self.queue) > 0:
|
||||||
|
return self.queue.pop(0)
|
||||||
|
return self.init_tensor_
|
||||||
|
|
||||||
|
def size(self) -> int:
|
||||||
|
return len(self.queue)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpaqueObject(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") # noqa: TOR901
|
||||||
|
|
||||||
|
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
|
||||||
|
|
||||||
|
torch.library.define(
|
||||||
|
"_TestOpaqueObject::queue_push",
|
||||||
|
"(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
|
||||||
|
tags=torch.Tag.pt2_compliant_tag,
|
||||||
|
lib=self.lib,
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.library.impl(
|
||||||
|
"_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
|
||||||
|
)
|
||||||
|
def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
|
||||||
|
assert isinstance(queue, OpaqueQueue)
|
||||||
|
queue.push(b)
|
||||||
|
|
||||||
|
self.lib.define(
|
||||||
|
"queue_pop(_TestOpaqueObject_OpaqueQueue a) -> Tensor",
|
||||||
|
)
|
||||||
|
|
||||||
|
def pop_impl(queue: OpaqueQueue) -> torch.Tensor:
|
||||||
|
assert isinstance(queue, OpaqueQueue)
|
||||||
|
return queue.pop()
|
||||||
|
|
||||||
|
self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
|
||||||
|
|
||||||
|
@torch.library.custom_op(
|
||||||
|
"_TestOpaqueObject::queue_size",
|
||||||
|
mutates_args=[],
|
||||||
|
)
|
||||||
|
def size_impl(queue: OpaqueQueue) -> int:
|
||||||
|
assert isinstance(queue, OpaqueQueue)
|
||||||
|
return queue.size()
|
||||||
|
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.lib._destroy()
|
||||||
|
|
||||||
|
super().tearDown()
|
||||||
|
|
||||||
|
def test_ops(self):
|
||||||
|
queue = OpaqueQueue([], torch.zeros(3))
|
||||||
|
|
||||||
|
torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3) + 1)
|
||||||
|
size = torch.ops._TestOpaqueObject.queue_size(queue)
|
||||||
|
self.assertEqual(size, 1)
|
||||||
|
popped = torch.ops._TestOpaqueObject.queue_pop(queue)
|
||||||
|
self.assertEqual(popped, torch.ones(3) + 1)
|
||||||
|
size = torch.ops._TestOpaqueObject.queue_size(queue)
|
||||||
|
self.assertEqual(size, 0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_tests()
|
@ -1627,6 +1627,8 @@ def _jit_pass_lint(Graph) -> None: ...
|
|||||||
def _make_opaque_object(payload: Any) -> ScriptObject: ...
|
def _make_opaque_object(payload: Any) -> ScriptObject: ...
|
||||||
def _get_opaque_object_payload(obj: ScriptObject) -> Any: ...
|
def _get_opaque_object_payload(obj: ScriptObject) -> Any: ...
|
||||||
def _set_opaque_object_payload(obj: ScriptObject, payload: Any) -> None: ...
|
def _set_opaque_object_payload(obj: ScriptObject, payload: Any) -> None: ...
|
||||||
|
def _register_opaque_type(type_name: str) -> None: ...
|
||||||
|
def _is_opaque_type_registered(type_name: str) -> _bool: ...
|
||||||
|
|
||||||
# Defined in torch/csrc/jit/python/python_custom_class.cpp
|
# Defined in torch/csrc/jit/python/python_custom_class.cpp
|
||||||
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
|
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
|
||||||
|
@ -9,7 +9,7 @@ import torch
|
|||||||
from torch import device, dtype, Tensor, types
|
from torch import device, dtype, Tensor, types
|
||||||
from torch.utils._exposed_in import exposed_in
|
from torch.utils._exposed_in import exposed_in
|
||||||
|
|
||||||
from .opaque_object import OpaqueType, OpaqueTypeStr
|
from .opaque_object import _OPAQUE_TYPES, is_opaque_type, OpaqueType, OpaqueTypeStr
|
||||||
|
|
||||||
|
|
||||||
# This is used as a negative test for
|
# This is used as a negative test for
|
||||||
@ -125,8 +125,11 @@ def infer_schema(
|
|||||||
# we convert it to the actual type.
|
# we convert it to the actual type.
|
||||||
annotation_type, _ = unstringify_type(param.annotation)
|
annotation_type, _ = unstringify_type(param.annotation)
|
||||||
|
|
||||||
|
schema_type = None
|
||||||
if annotation_type not in SUPPORTED_PARAM_TYPES:
|
if annotation_type not in SUPPORTED_PARAM_TYPES:
|
||||||
if annotation_type == torch._C.ScriptObject:
|
if is_opaque_type(annotation_type):
|
||||||
|
schema_type = _OPAQUE_TYPES[annotation_type]
|
||||||
|
elif annotation_type == torch._C.ScriptObject:
|
||||||
error_fn(
|
error_fn(
|
||||||
f"Parameter {name}'s type cannot be inferred from the schema "
|
f"Parameter {name}'s type cannot be inferred from the schema "
|
||||||
"as it is a ScriptObject. Please manually specify the schema "
|
"as it is a ScriptObject. Please manually specify the schema "
|
||||||
@ -152,8 +155,11 @@ def infer_schema(
|
|||||||
f"Parameter {name} has unsupported type {param.annotation}. "
|
f"Parameter {name} has unsupported type {param.annotation}. "
|
||||||
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
|
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
|
||||||
|
|
||||||
|
assert schema_type is not None
|
||||||
|
|
||||||
schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
|
|
||||||
if type(mutates_args) is str:
|
if type(mutates_args) is str:
|
||||||
if mutates_args != UNKNOWN_MUTATES:
|
if mutates_args != UNKNOWN_MUTATES:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, NewType
|
from typing import Any, NewType, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -150,3 +150,36 @@ def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
|
|||||||
f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
|
f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
|
||||||
)
|
)
|
||||||
torch._C._set_opaque_object_payload(opaque_object, payload)
|
torch._C._set_opaque_object_payload(opaque_object, payload)
|
||||||
|
|
||||||
|
|
||||||
|
_OPAQUE_TYPES: dict[Any, str] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
Registers the given type as an opaque type which allows this to be consumed
|
||||||
|
by a custom operator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls (type): The class to register as an opaque type.
|
||||||
|
name (str): A unique qualified name of the type.
|
||||||
|
"""
|
||||||
|
if name is None:
|
||||||
|
name = cls.__name__
|
||||||
|
|
||||||
|
if "." in name:
|
||||||
|
# The schema_type_parser will break up types with periods
|
||||||
|
raise ValueError(
|
||||||
|
f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
|
||||||
|
)
|
||||||
|
_OPAQUE_TYPES[cls] = name
|
||||||
|
torch._C._register_opaque_type(name)
|
||||||
|
|
||||||
|
|
||||||
|
def is_opaque_type(cls: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if the given type is an opaque type.
|
||||||
|
"""
|
||||||
|
if cls not in _OPAQUE_TYPES:
|
||||||
|
return False
|
||||||
|
return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
#include <torch/csrc/jit/frontend/parse_string_literal.h>
|
#include <torch/csrc/jit/frontend/parse_string_literal.h>
|
||||||
#include <torch/custom_class.h>
|
#include <torch/custom_class.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
using c10::AliasInfo;
|
using c10::AliasInfo;
|
||||||
using c10::AwaitType;
|
using c10::AwaitType;
|
||||||
@ -42,6 +43,25 @@ using c10::VarType;
|
|||||||
|
|
||||||
namespace torch::jit {
|
namespace torch::jit {
|
||||||
|
|
||||||
|
static std::unordered_set<std::string>& getOpaqueTypes() {
|
||||||
|
static std::unordered_set<std::string> global_opaque_types;
|
||||||
|
return global_opaque_types;
|
||||||
|
}
|
||||||
|
|
||||||
|
void registerOpaqueType(const std::string& type_name) {
|
||||||
|
auto& global_opaque_types = getOpaqueTypes();
|
||||||
|
auto [_, inserted] = global_opaque_types.insert(type_name);
|
||||||
|
if (!inserted) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Type '" + type_name + "' is already registered as an opaque type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isRegisteredOpaqueType(const std::string& type_name) {
|
||||||
|
auto& global_opaque_types = getOpaqueTypes();
|
||||||
|
return global_opaque_types.find(type_name) != global_opaque_types.end();
|
||||||
|
}
|
||||||
|
|
||||||
TypePtr SchemaTypeParser::parseBaseType() {
|
TypePtr SchemaTypeParser::parseBaseType() {
|
||||||
static std::unordered_map<std::string, TypePtr> type_map = {
|
static std::unordered_map<std::string, TypePtr> type_map = {
|
||||||
{"Generator", c10::TypeFactory::get<GeneratorType>()},
|
{"Generator", c10::TypeFactory::get<GeneratorType>()},
|
||||||
@ -81,6 +101,11 @@ TypePtr SchemaTypeParser::parseBaseType() {
|
|||||||
}
|
}
|
||||||
std::string text = tok.text();
|
std::string text = tok.text();
|
||||||
|
|
||||||
|
// Check if this type is registered as an opaque type first
|
||||||
|
if (isRegisteredOpaqueType(text)) {
|
||||||
|
return c10::PyObjectType::get();
|
||||||
|
}
|
||||||
|
|
||||||
auto it = type_map.find(text);
|
auto it = type_map.find(text);
|
||||||
if (it == type_map.end()) {
|
if (it == type_map.end()) {
|
||||||
if (allow_typevars_ && !text.empty() && islower(text[0])) {
|
if (allow_typevars_ && !text.empty() && islower(text[0])) {
|
||||||
|
@ -10,6 +10,9 @@ namespace torch::jit {
|
|||||||
|
|
||||||
using TypePtr = c10::TypePtr;
|
using TypePtr = c10::TypePtr;
|
||||||
|
|
||||||
|
TORCH_API void registerOpaqueType(const std::string& type_name);
|
||||||
|
TORCH_API bool isRegisteredOpaqueType(const std::string& type_name);
|
||||||
|
|
||||||
struct TORCH_API SchemaTypeParser {
|
struct TORCH_API SchemaTypeParser {
|
||||||
TypePtr parseBaseType();
|
TypePtr parseBaseType();
|
||||||
std::optional<c10::AliasInfo> parseAliasAnnotation();
|
std::optional<c10::AliasInfo> parseAliasAnnotation();
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
#endif
|
#endif
|
||||||
#include <c10/core/SymNodeImpl.h>
|
#include <c10/core/SymNodeImpl.h>
|
||||||
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
||||||
|
#include <torch/csrc/jit/frontend/schema_type_parser.h>
|
||||||
#include <torch/csrc/jit/frontend/tracer.h>
|
#include <torch/csrc/jit/frontend/tracer.h>
|
||||||
#include <torch/csrc/jit/ir/irparser.h>
|
#include <torch/csrc/jit/ir/irparser.h>
|
||||||
#include <torch/csrc/jit/jit_log.h>
|
#include <torch/csrc/jit/jit_log.h>
|
||||||
@ -1890,6 +1891,18 @@ void initJITBindings(PyObject* module) {
|
|||||||
customObj->setPayload(std::move(payload));
|
customObj->setPayload(std::move(payload));
|
||||||
},
|
},
|
||||||
R"doc(Sets the payload of the given opaque object with the given Python object.)doc");
|
R"doc(Sets the payload of the given opaque object with the given Python object.)doc");
|
||||||
|
m.def(
|
||||||
|
"_register_opaque_type",
|
||||||
|
[](const std::string& type_name) {
|
||||||
|
torch::jit::registerOpaqueType(type_name);
|
||||||
|
},
|
||||||
|
R"doc(Registers a type name to be treated as an opaque type (PyObject) in schema parsing.)doc");
|
||||||
|
m.def(
|
||||||
|
"_is_opaque_type_registered",
|
||||||
|
[](const std::string& type_name) -> bool {
|
||||||
|
return torch::jit::isRegisteredOpaqueType(type_name);
|
||||||
|
},
|
||||||
|
R"doc(Checks if a type name is registered as an opaque type.)doc");
|
||||||
m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
|
m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
|
||||||
std::ostringstream s;
|
std::ostringstream s;
|
||||||
auto type = unifyTypeList(types, s);
|
auto type = unifyTypeList(types, s);
|
||||||
|
Reference in New Issue
Block a user