Revert "[opaque_obj_v2] PyObject custom op schema type (#165004)"

This reverts commit 3faee200674c0c2bca3f395a063264cfd8a9a5b7.

Reverted https://github.com/pytorch/pytorch/pull/165004 on behalf of https://github.com/seemethere due to This fails internal tests, see D84399300 ([comment](https://github.com/pytorch/pytorch/pull/165004#issuecomment-3398906856))
This commit is contained in:
PyTorch MergeBot
2025-10-13 20:08:38 +00:00
parent c44d638b15
commit a71ca4dcb9
7 changed files with 4 additions and 170 deletions

View File

@ -1,84 +0,0 @@
# Owner(s): ["module: custom-operators"]
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._library.opaque_object import register_opaque_type
class OpaqueQueue:
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
super().__init__()
self.queue = queue
self.init_tensor_ = init_tensor_
def push(self, tensor: torch.Tensor) -> None:
self.queue.append(tensor)
def pop(self) -> torch.Tensor:
if len(self.queue) > 0:
return self.queue.pop(0)
return self.init_tensor_
def size(self) -> int:
return len(self.queue)
class TestOpaqueObject(TestCase):
def setUp(self):
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") # noqa: TOR901
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
torch.library.define(
"_TestOpaqueObject::queue_push",
"(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
tags=torch.Tag.pt2_compliant_tag,
lib=self.lib,
)
@torch.library.impl(
"_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
)
def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
assert isinstance(queue, OpaqueQueue)
queue.push(b)
self.lib.define(
"queue_pop(_TestOpaqueObject_OpaqueQueue a) -> Tensor",
)
def pop_impl(queue: OpaqueQueue) -> torch.Tensor:
assert isinstance(queue, OpaqueQueue)
return queue.pop()
self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
@torch.library.custom_op(
"_TestOpaqueObject::queue_size",
mutates_args=[],
)
def size_impl(queue: OpaqueQueue) -> int:
assert isinstance(queue, OpaqueQueue)
return queue.size()
super().setUp()
def tearDown(self):
self.lib._destroy()
super().tearDown()
def test_ops(self):
queue = OpaqueQueue([], torch.zeros(3))
torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3) + 1)
size = torch.ops._TestOpaqueObject.queue_size(queue)
self.assertEqual(size, 1)
popped = torch.ops._TestOpaqueObject.queue_pop(queue)
self.assertEqual(popped, torch.ones(3) + 1)
size = torch.ops._TestOpaqueObject.queue_size(queue)
self.assertEqual(size, 0)
if __name__ == "__main__":
run_tests()

View File

@ -1627,8 +1627,6 @@ def _jit_pass_lint(Graph) -> None: ...
def _make_opaque_object(payload: Any) -> ScriptObject: ...
def _get_opaque_object_payload(obj: ScriptObject) -> Any: ...
def _set_opaque_object_payload(obj: ScriptObject, payload: Any) -> None: ...
def _register_opaque_type(type_name: str) -> None: ...
def _is_opaque_type_registered(type_name: str) -> _bool: ...
# Defined in torch/csrc/jit/python/python_custom_class.cpp
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...

View File

@ -9,7 +9,7 @@ import torch
from torch import device, dtype, Tensor, types
from torch.utils._exposed_in import exposed_in
from .opaque_object import _OPAQUE_TYPES, is_opaque_type, OpaqueType, OpaqueTypeStr
from .opaque_object import OpaqueType, OpaqueTypeStr
# This is used as a negative test for
@ -125,11 +125,8 @@ def infer_schema(
# we convert it to the actual type.
annotation_type, _ = unstringify_type(param.annotation)
schema_type = None
if annotation_type not in SUPPORTED_PARAM_TYPES:
if is_opaque_type(annotation_type):
schema_type = _OPAQUE_TYPES[annotation_type]
elif annotation_type == torch._C.ScriptObject:
if annotation_type == torch._C.ScriptObject:
error_fn(
f"Parameter {name}'s type cannot be inferred from the schema "
"as it is a ScriptObject. Please manually specify the schema "
@ -152,11 +149,8 @@ def infer_schema(
f"Parameter {name} has unsupported type {param.annotation}. "
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
)
else:
schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
assert schema_type is not None
schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
if type(mutates_args) is str:
if mutates_args != UNKNOWN_MUTATES:
raise ValueError(

View File

@ -1,4 +1,4 @@
from typing import Any, NewType, Optional
from typing import Any, NewType
import torch
@ -150,36 +150,3 @@ def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
)
torch._C._set_opaque_object_payload(opaque_object, payload)
_OPAQUE_TYPES: dict[Any, str] = {}
def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
"""
Registers the given type as an opaque type which allows this to be consumed
by a custom operator.
Args:
cls (type): The class to register as an opaque type.
name (str): A unique qualified name of the type.
"""
if name is None:
name = cls.__name__
if "." in name:
# The schema_type_parser will break up types with periods
raise ValueError(
f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
)
_OPAQUE_TYPES[cls] = name
torch._C._register_opaque_type(name)
def is_opaque_type(cls: Any) -> bool:
"""
Checks if the given type is an opaque type.
"""
if cls not in _OPAQUE_TYPES:
return False
return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])

View File

@ -8,7 +8,6 @@
#include <torch/csrc/jit/frontend/parse_string_literal.h>
#include <torch/custom_class.h>
#include <string>
#include <unordered_set>
using c10::AliasInfo;
using c10::AwaitType;
@ -43,25 +42,6 @@ using c10::VarType;
namespace torch::jit {
static std::unordered_set<std::string>& getOpaqueTypes() {
static std::unordered_set<std::string> global_opaque_types;
return global_opaque_types;
}
void registerOpaqueType(const std::string& type_name) {
auto& global_opaque_types = getOpaqueTypes();
auto [_, inserted] = global_opaque_types.insert(type_name);
if (!inserted) {
throw std::runtime_error(
"Type '" + type_name + "' is already registered as an opaque type");
}
}
bool isRegisteredOpaqueType(const std::string& type_name) {
auto& global_opaque_types = getOpaqueTypes();
return global_opaque_types.find(type_name) != global_opaque_types.end();
}
TypePtr SchemaTypeParser::parseBaseType() {
static std::unordered_map<std::string, TypePtr> type_map = {
{"Generator", c10::TypeFactory::get<GeneratorType>()},
@ -101,11 +81,6 @@ TypePtr SchemaTypeParser::parseBaseType() {
}
std::string text = tok.text();
// Check if this type is registered as an opaque type first
if (isRegisteredOpaqueType(text)) {
return c10::TypeFactory::get<c10::PyObjectType>();
}
auto it = type_map.find(text);
if (it == type_map.end()) {
if (allow_typevars_ && !text.empty() && islower(text[0])) {

View File

@ -10,9 +10,6 @@ namespace torch::jit {
using TypePtr = c10::TypePtr;
TORCH_API void registerOpaqueType(const std::string& type_name);
TORCH_API bool isRegisteredOpaqueType(const std::string& type_name);
struct TORCH_API SchemaTypeParser {
TypePtr parseBaseType();
std::optional<c10::AliasInfo> parseAliasAnnotation();

View File

@ -15,7 +15,6 @@
#endif
#include <c10/core/SymNodeImpl.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/schema_type_parser.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/jit_log.h>
@ -1891,18 +1890,6 @@ void initJITBindings(PyObject* module) {
customObj->setPayload(std::move(payload));
},
R"doc(Sets the payload of the given opaque object with the given Python object.)doc");
m.def(
"_register_opaque_type",
[](const std::string& type_name) {
torch::jit::registerOpaqueType(type_name);
},
R"doc(Registers a type name to be treated as an opaque type (PyObject) in schema parsing.)doc");
m.def(
"_is_opaque_type_registered",
[](const std::string& type_name) -> bool {
return torch::jit::isRegisteredOpaqueType(type_name);
},
R"doc(Checks if a type name is registered as an opaque type.)doc");
m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
std::ostringstream s;
auto type = unifyTypeList(types, s);