diff --git a/test/test_opaque_obj_v2.py b/test/test_opaque_obj_v2.py
new file mode 100644
index 000000000000..aea2441c61b9
--- /dev/null
+++ b/test/test_opaque_obj_v2.py
@@ -0,0 +1,84 @@
+# Owner(s): ["module: custom-operators"]
+
+import torch
+from torch._dynamo.test_case import run_tests, TestCase
+from torch._library.opaque_object import register_opaque_type
+
+
+class OpaqueQueue:
+    def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
+        super().__init__()
+        self.queue = queue
+        self.init_tensor_ = init_tensor_
+
+    def push(self, tensor: torch.Tensor) -> None:
+        self.queue.append(tensor)
+
+    def pop(self) -> torch.Tensor:
+        if len(self.queue) > 0:
+            return self.queue.pop(0)
+        return self.init_tensor_
+
+    def size(self) -> int:
+        return len(self.queue)
+
+
+class TestOpaqueObject(TestCase):
+    def setUp(self):
+        self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")  # noqa: TOR901
+
+        register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
+
+        torch.library.define(
+            "_TestOpaqueObject::queue_push",
+            "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
+            tags=torch.Tag.pt2_compliant_tag,
+            lib=self.lib,
+        )
+
+        @torch.library.impl(
+            "_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
+        )
+        def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
+            assert isinstance(queue, OpaqueQueue)
+            queue.push(b)
+
+        self.lib.define(
+            "queue_pop(_TestOpaqueObject_OpaqueQueue a) -> Tensor",
+        )
+
+        def pop_impl(queue: OpaqueQueue) -> torch.Tensor:
+            assert isinstance(queue, OpaqueQueue)
+            return queue.pop()
+
+        self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
+
+        @torch.library.custom_op(
+            "_TestOpaqueObject::queue_size",
+            mutates_args=[],
+        )
+        def size_impl(queue: OpaqueQueue) -> int:
+            assert isinstance(queue, OpaqueQueue)
+            return queue.size()
+
+        super().setUp()
+
+    def tearDown(self):
+        self.lib._destroy()
+
+        super().tearDown()
+
+    def test_ops(self):
+        queue = OpaqueQueue([], torch.zeros(3))
+
+        torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3) + 1)
+        size = torch.ops._TestOpaqueObject.queue_size(queue)
+        self.assertEqual(size, 1)
+        popped = torch.ops._TestOpaqueObject.queue_pop(queue)
+        self.assertEqual(popped, torch.ones(3) + 1)
+        size = torch.ops._TestOpaqueObject.queue_size(queue)
+        self.assertEqual(size, 0)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 2f6ad3f6de67..9597690fd28d 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1627,6 +1627,8 @@
 def _jit_pass_lint(Graph) -> None: ...
 def _make_opaque_object(payload: Any) -> ScriptObject: ...
 def _get_opaque_object_payload(obj: ScriptObject) -> Any: ...
 def _set_opaque_object_payload(obj: ScriptObject, payload: Any) -> None: ...
+def _register_opaque_type(type_name: str) -> None: ...
+def _is_opaque_type_registered(type_name: str) -> _bool: ...
 # Defined in torch/csrc/jit/python/python_custom_class.cpp
 def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
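Note: the two `torch._C` stubs added above are the raw bindings registered in `torch/csrc/jit/python/init.cpp` further down; the public entry point is `torch._library.opaque_object.register_opaque_type`. A minimal sketch of exercising the raw bindings directly — the type name here is a made-up example, not part of the patch:

```python
import torch

# Any name without a "." is accepted; "MyLib_Counter" is hypothetical.
torch._C._register_opaque_type("MyLib_Counter")
assert torch._C._is_opaque_type_registered("MyLib_Counter")

# Registering the same name a second time raises, since the C++-side
# registry rejects duplicate insertions.
```
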
diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py
index 05fe47cd3733..51986d08e23c 100644
--- a/torch/_library/infer_schema.py
+++ b/torch/_library/infer_schema.py
@@ -9,7 +9,7 @@ import torch
 from torch import device, dtype, Tensor, types
 from torch.utils._exposed_in import exposed_in
 
-from .opaque_object import OpaqueType, OpaqueTypeStr
+from .opaque_object import _OPAQUE_TYPES, is_opaque_type, OpaqueType, OpaqueTypeStr
 
 
 # This is used as a negative test for
@@ -125,8 +125,11 @@ def infer_schema(
         # we convert it to the actual type.
         annotation_type, _ = unstringify_type(param.annotation)
 
+        schema_type = None
         if annotation_type not in SUPPORTED_PARAM_TYPES:
-            if annotation_type == torch._C.ScriptObject:
+            if is_opaque_type(annotation_type):
+                schema_type = _OPAQUE_TYPES[annotation_type]
+            elif annotation_type == torch._C.ScriptObject:
                 error_fn(
                     f"Parameter {name}'s type cannot be inferred from the schema "
                     "as it is a ScriptObject. Please manually specify the schema "
@@ -152,8 +155,11 @@ def infer_schema(
                     f"Parameter {name} has unsupported type {param.annotation}. "
                     f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
                 )
+        else:
+            schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
+
+        assert schema_type is not None
 
-        schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
         if type(mutates_args) is str:
             if mutates_args != UNKNOWN_MUTATES:
                 raise ValueError(
diff --git a/torch/_library/opaque_object.py b/torch/_library/opaque_object.py
index ba02970d5504..b3460fa2dda8 100644
--- a/torch/_library/opaque_object.py
+++ b/torch/_library/opaque_object.py
@@ -1,4 +1,4 @@
-from typing import Any, NewType
+from typing import Any, NewType, Optional
 
 import torch
 
@@ -150,3 +150,36 @@ def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
             f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
         )
     torch._C._set_opaque_object_payload(opaque_object, payload)
+
+
+_OPAQUE_TYPES: dict[Any, str] = {}
+
+
+def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
+    """
+    Registers the given class as an opaque type, which allows instances of it
+    to be consumed by custom operators.
+
+    Args:
+        cls (type): The class to register as an opaque type.
+        name (str, optional): A unique name for the type; defaults to cls.__name__.
+    """
+    if name is None:
+        name = cls.__name__
+
+    if "." in name:
+        # The schema_type_parser would split a name containing periods
+        raise ValueError(
+            f"Unable to accept name '{name}' for this opaque type as it contains a '.'"
+        )
+    _OPAQUE_TYPES[cls] = name
+    torch._C._register_opaque_type(name)
+
+
+def is_opaque_type(cls: Any) -> bool:
+    """
+    Checks if the given type is registered as an opaque type.
+ """ + if cls not in _OPAQUE_TYPES: + return False + return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls]) diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index 4df9fb663984..735856dc10a7 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -8,6 +8,7 @@ #include #include #include +#include using c10::AliasInfo; using c10::AwaitType; @@ -42,6 +43,25 @@ using c10::VarType; namespace torch::jit { +static std::unordered_set& getOpaqueTypes() { + static std::unordered_set global_opaque_types; + return global_opaque_types; +} + +void registerOpaqueType(const std::string& type_name) { + auto& global_opaque_types = getOpaqueTypes(); + auto [_, inserted] = global_opaque_types.insert(type_name); + if (!inserted) { + throw std::runtime_error( + "Type '" + type_name + "' is already registered as an opaque type"); + } +} + +bool isRegisteredOpaqueType(const std::string& type_name) { + auto& global_opaque_types = getOpaqueTypes(); + return global_opaque_types.find(type_name) != global_opaque_types.end(); +} + TypePtr SchemaTypeParser::parseBaseType() { static std::unordered_map type_map = { {"Generator", c10::TypeFactory::get()}, @@ -81,6 +101,11 @@ TypePtr SchemaTypeParser::parseBaseType() { } std::string text = tok.text(); + // Check if this type is registered as an opaque type first + if (isRegisteredOpaqueType(text)) { + return c10::PyObjectType::get(); + } + auto it = type_map.find(text); if (it == type_map.end()) { if (allow_typevars_ && !text.empty() && islower(text[0])) { diff --git a/torch/csrc/jit/frontend/schema_type_parser.h b/torch/csrc/jit/frontend/schema_type_parser.h index ca5a00ecaa3f..19f108fa17e8 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.h +++ b/torch/csrc/jit/frontend/schema_type_parser.h @@ -10,6 +10,9 @@ namespace torch::jit { using TypePtr = c10::TypePtr; +TORCH_API void registerOpaqueType(const std::string& type_name); +TORCH_API bool isRegisteredOpaqueType(const std::string& type_name); + struct TORCH_API SchemaTypeParser { TypePtr parseBaseType(); std::optional parseAliasAnnotation(); diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 9b6f1b5ee3de..beb6f8951980 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -15,6 +15,7 @@ #endif #include #include +#include #include #include #include @@ -1890,6 +1891,18 @@ void initJITBindings(PyObject* module) { customObj->setPayload(std::move(payload)); }, R"doc(Sets the payload of the given opaque object with the given Python object.)doc"); + m.def( + "_register_opaque_type", + [](const std::string& type_name) { + torch::jit::registerOpaqueType(type_name); + }, + R"doc(Registers a type name to be treated as an opaque type (PyObject) in schema parsing.)doc"); + m.def( + "_is_opaque_type_registered", + [](const std::string& type_name) -> bool { + return torch::jit::isRegisteredOpaqueType(type_name); + }, + R"doc(Checks if a type name is registered as an opaque type.)doc"); m.def("unify_type_list", [](const std::vector& types) { std::ostringstream s; auto type = unifyTypeList(types, s);