# pytorch/torch/_library/fake_class_registry.py
# mypy: allow-untyped-defs
import copy
import logging
from typing import Any, Optional, Protocol, Union
import torch
from torch._library.utils import parse_namespace
from torch.utils._python_dispatch import _disable_current_modes

log = logging.getLogger(__name__)


class FakeScriptObject:
    def __init__(
        self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObject
    ):
        self.wrapped_obj = wrapped_obj
        # The fully qualified name of the class of the original script object.
        self.script_class_name = script_class_name

        try:
            with _disable_current_modes():
                self.real_obj = copy.deepcopy(x)
        except RuntimeError as e:
            log.warning(
                "Unable to deepcopy the custom object %s due to %s. "
                "Defaulting to the user-given object. This might be "
                "dangerous as side effects may be directly applied "
                "to the object.",
                script_class_name,
                str(e),
            )
            self.real_obj = x


class FakeScriptMethod:
    def __init__(
        self,
        self_fake_obj: FakeScriptObject,
        method_name: str,
        schema: Optional[torch.FunctionSchema],
    ):
        self.self_fake_obj = self_fake_obj
        self.method_name = method_name
        self.schema = schema

    def __call__(self, *args, **kwargs):
        from torch._higher_order_ops.torchbind import call_torchbind

        return call_torchbind(self.self_fake_obj, self.method_name, *args, **kwargs)
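

# A brief sketch (with hypothetical objects) of how these wrappers interact:
# after maybe_to_fake_obj (below) attaches a FakeScriptMethod per real method,
# a call on the fake object dispatches through the call_torchbind higher-order op:
#
#   fake_tq = maybe_to_fake_obj(fake_mode, real_tensor_queue)
#   fake_tq.push(t)  # equivalent to call_torchbind(fake_tq, "push", t)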


class HasStaticMethodFromReal(Protocol):
    @classmethod
    def from_real(cls, real_obj: torch.ScriptObject):
        pass


class FakeClassRegistry:
    def __init__(self) -> None:
        self._registered_class: dict[str, Any] = {}

    def has_impl(self, full_qualname: str) -> bool:
        return full_qualname in self._registered_class

    def get_impl(self, full_qualname: str) -> Any:
        self._check_registered(full_qualname)
        return self._registered_class[full_qualname]

    def register(self, full_qualname: str, fake_class=None) -> None:
        if self.has_impl(full_qualname):
            log.warning(
                "%s is already registered. Previous fake class is overridden with %s.",
                full_qualname,
                fake_class,
            )
        self._registered_class[full_qualname] = fake_class

    def deregister(self, full_qualname: str) -> Any:
        if not self.has_impl(full_qualname):
            log.warning(
                "Cannot deregister %s. Please use register_fake_class to register it first,"
                " or did you deregister it twice?",
                full_qualname,
            )
        else:
            return self._registered_class.pop(full_qualname)

    def clear(self) -> None:
        self._registered_class.clear()

    def _check_registered(self, full_qualname: str) -> None:
        if full_qualname not in self._registered_class:
            raise RuntimeError(
                f"{full_qualname} is not registered. Please use register_fake_class to register it first."
            )


global_fake_class_registry = FakeClassRegistry()


# TODO: add this check at compile time for __obj_flatten__.
def _check_valid_flat_script_obj(flat_x):
    if not isinstance(flat_x, tuple):
        raise RuntimeError("Expect flat x to be a tuple.")

    for tp in flat_x:
        if not isinstance(tp, tuple):
            raise RuntimeError("Expect flat x to be a tuple of tuples.")

        if not len(tp) == 2 or not isinstance(tp[0], str):
            raise RuntimeError(
                "Expect each element of flat x to be a tuple of two elements, "
                "with the first element being a string."
            )
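

# For illustration, a valid flattened form for the TensorQueue example in the
# register_fake_class docstring below would be a tuple of (name, value) pairs,
# e.g. (("queue", [tensor0, tensor1]),).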


def tracing_with_real(x: torch.ScriptObject) -> bool:
    if not hasattr(x, "tracing_mode"):
        return False

    assert x.tracing_mode() in [
        "real",
        "fake",
    ], f"tracing_mode can be either real or fake but got {x.tracing_mode()}"

    return x.tracing_mode() == "real"


def maybe_to_fake_obj(
    fake_mode, x: torch.ScriptObject
) -> Union[FakeScriptObject, torch.ScriptObject]:
    import torch.utils._pytree as pytree
    from torch.utils._python_dispatch import _disable_current_modes

    # When tracing with real mode, users should implement meta kernels that can
    # handle the case of a real script object + fake tensor inputs.
    if tracing_with_real(x):
        return x

    from torch._library.opaque_object import FakeOpaqueObject, OpaqueTypeStr

    if str(x._type()) == OpaqueTypeStr:
        # In order to make OpaqueObjects truly opaque, the fake kernel should
        # not depend on the contents of the OpaqueObject at all.
        fake_x = FakeOpaqueObject()
    else:
        # x.__obj_flatten__() could call some tensor operations internally, but we
        # don't want those ops to run under the surrounding dispatch modes.
        # Otherwise, for example, fake tensor modes will error out when the tensors
        # inside the script object execute operations like clone if the
        # allow_non_fake_input flag is set.
        with _disable_current_modes():
            flat_x = x.__obj_flatten__()  # type: ignore[attr-defined]

        _check_valid_flat_script_obj(flat_x)

        with fake_mode:
            from torch._higher_order_ops.utils import _tensor_storage

            storage_map = {
                _tensor_storage(inp): i
                for i, inp in enumerate(flat_x)
                if isinstance(inp, torch.Tensor)
            }
            alias_map = {
                i: storage_map[_tensor_storage(inp)]
                for i, inp in enumerate(flat_x)
                if isinstance(inp, torch.Tensor)
                and storage_map[_tensor_storage(inp)] != i
            }
            if len(alias_map) > 0:
                log.warning(
                    "Detected script object %s has aliasing relationship among its tensors. "
                    "Flattened obj: %s. Aliasing tensor indices: %s. "
                    "This is not supported and may cause unexpected behavior.",
                    x,
                    flat_x,
                    alias_map,
                )

            # This breaks the aliasing relationship among the tensors inside the
            # torchbind object. This is not ideal, but since we don't need to
            # preserve the aliasing relationship anyway, and the docs state clearly
            # that it is not preserved, this should be OK.
            fake_flattened = pytree.tree_map_only(
                torch.Tensor,
                lambda t: torch.empty_strided(
                    t.size(),
                    t.stride(),
                    device=t.device,
                    dtype=t.dtype,
                    requires_grad=t.requires_grad,
                    layout=t.layout,
                ),
                flat_x,
            )

            fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(
                fake_flattened
            )

    fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name(), x)  # type: ignore[attr-defined]

    for name in x._method_names():  # type: ignore[attr-defined]
        attr = getattr(fake_x, name, None)
        if attr is not None:
            if not callable(attr):
                raise RuntimeError(f"Expect {name} to be a callable but got {attr}.")

            real_attr = getattr(x, name)  # type: ignore[attr-defined]

            # The real attr is sometimes not a torch.ScriptMethod and thus doesn't
            # have a schema, e.g. __init__ or __eq__.
            method_schema: Optional[torch.FunctionSchema] = None
            if isinstance(real_attr, torch.ScriptMethod):
                method_schema = real_attr.schema  # type: ignore[attr-defined]

            setattr(
                fake_x_wrapped,
                name,
                FakeScriptMethod(fake_x_wrapped, name, method_schema),
            )
        else:
            override_skip_list = {"__obj_flatten__", "__getstate__", "__setstate__"}
            if name not in override_skip_list:
                log.warning("fake object of %s doesn't implement method %s.", x, name)

    return fake_x_wrapped
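

# Illustrative usage (a sketch, assuming the _TensorQueue test class from the
# register_fake_class docstring below has been registered and its C++ library
# loaded):
#
#   from torch._subclasses.fake_tensor import FakeTensorMode
#
#   fake_mode = FakeTensorMode()
#   tq = torch.classes._TorchScriptTesting._TensorQueue(torch.zeros(2))
#   fake_tq = maybe_to_fake_obj(fake_mode, tq)
#   assert isinstance(fake_tq, FakeScriptObject)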


def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] = None):
    r"""Register a fake implementation for this class.

    It's in the same spirit as registering a fake implementation for an
    operator, with the difference that it associates a fake class with the
    original torch bind class (registered with torch::class_). In this way,
    torch.compile can handle them properly in components such as Dynamo and
    AOTAutograd.

    This API may be used as a decorator (see example). For the fake class,
    users are required to provide a from_real classmethod that takes a real
    object and returns an instance of the fake class. All tensors in the fake
    object should also be properly fakified with to_fake_tensor() in from_real.

    Examples:
        # For a custom class Foo defined in test_custom_class_registration.cpp:

        TORCH_LIBRARY(_TorchScriptTesting, m) {
          m.class_<TensorQueue>("_TensorQueue")
              .def(torch::init<at::Tensor>())
              .def("push", &TensorQueue::push)
              .def("pop", &TensorQueue::pop)
              .def("top", &TensorQueue::top)
              .def("size", &TensorQueue::size)
              .def("clone_queue", &TensorQueue::clone_queue)
              .def("__obj_flatten__", &TensorQueue::__obj_flatten__)
              .def_pickle(
                  // __getstate__
                  [](const c10::intrusive_ptr<TensorQueue>& self)
                      -> c10::Dict<std::string, at::Tensor> {
                    return self->serialize();
                  },
                  // __setstate__
                  [](c10::Dict<std::string, at::Tensor> data)
                      -> c10::intrusive_ptr<TensorQueue> {
                    return c10::make_intrusive<TensorQueue>(std::move(data));
                  });
        };

        # We could register a fake class FakeTensorQueue in Python as follows:
        import torch

        @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
        class FakeTensorQueue:
            def __init__(self, queue):
                self.queue = queue

            @classmethod
            def __obj_unflatten__(cls, flattened_ctx):
                return cls(**dict(flattened_ctx))

            def push(self, x):
                self.queue.append(x)

            def pop(self):
                return self.queue.pop(0)

            def size(self):
                return len(self.queue)

    In this example, the original TensorQueue needs to add a __obj_flatten__
    method, and the flattened result is passed into FakeTensorQueue's
    __obj_unflatten__ as inputs to create a fake class. This protocol allows
    PyTorch to look at the contents of the script object and properly handle
    them in subsystems like Dynamo, AOTAutograd, and more.
    """

    def inner(fake_class: HasStaticMethodFromReal):
        ns, name = parse_namespace(qualname)

        # This also checks whether the referred torch::class_ exists.
        torch._C._get_custom_class_python_wrapper(ns, name)

        from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
        if not from_method:
            raise RuntimeError(
                f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}."
            )

        if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod):
            raise RuntimeError(
                f"{_CONVERT_FROM_REAL_NAME} method is not a classmethod."
            )

        global_fake_class_registry.register(_full_qual_class_name(qualname), fake_class)
        return fake_class

    if fake_class is None:
        return inner
    return inner(fake_class)


def deregister_fake_class(qualname):
    return global_fake_class_registry.deregister(_full_qual_class_name(qualname))


def has_fake_class(full_qualname) -> bool:
    return global_fake_class_registry.has_impl(full_qualname)


def find_fake_class(full_qualname) -> Optional[Any]:
    if not has_fake_class(full_qualname):
        return None
    return global_fake_class_registry.get_impl(full_qualname)


def _full_qual_class_name(qualname: str) -> str:
    ns, name = parse_namespace(qualname)
    return "__torch__.torch.classes." + ns + "." + name


def _is_script_object(obj: Any) -> bool:
    return isinstance(
        obj, torch.ScriptObject
    ) and obj._type().qualified_name().startswith(  # type: ignore[attr-defined]
        "__torch__.torch.classes"
    )


# Return the namespace and class name from a fully qualified name.
def _ns_and_class_name(full_qualname: str) -> tuple[str, str]:
    splits = full_qualname.split(".")
    assert len(splits) == 5, f"Could not split {full_qualname=}"
    _torch, _torch_ns, _classes, ns, class_name = splits
    return ns, class_name
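

# For example, the inverse of the mapping above:
#   _ns_and_class_name("__torch__.torch.classes._TorchScriptTesting._TensorQueue")
#   == ("_TorchScriptTesting", "_TensorQueue")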


def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any:
    full_qualname = x._type().qualified_name()  # type: ignore[attr-defined]
    ns, class_name = _ns_and_class_name(full_qualname)
    fake_class = find_fake_class(full_qualname)
    if fake_class is None:
        raise RuntimeError(
            f"ScriptObject {full_qualname} hasn't registered a fake class."
            f" Please use register_fake_class({ns}::{class_name}) to annotate a fake class for the script object."
            f" Specifically, create a Python class that implements a fake version of all the methods"
            f" that are used in the program, and put the annotated class in the program, e.g. after loading the library."
            f" The fake methods can be written in the same way as a meta kernel for an operator, but they need to"
            f" additionally simulate the object's states. Be sure to add a {_CONVERT_FROM_REAL_NAME} classmethod"
            f" to enable creating a fake obj from a real one."
        )
    return fake_class


_CONVERT_FROM_REAL_NAME = "__obj_unflatten__"


def _fake_obj_from_real(fake_mode, x) -> Any:
    fake_class = _find_fake_class_for_script_object(x)

    from_real_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
    if not from_real_method:
        raise RuntimeError(
            f"{fake_class} must define a classmethod {_CONVERT_FROM_REAL_NAME}"
            f" that converts the real object to the fake object."
        )

    # The user-defined from_real needs the ctx to fakify the tensor states.
    ctx = torch._library.fake_impl.FakeImplCtx(fake_mode, None)
    with torch._library.fake_impl.set_ctx_getter(lambda: ctx):
        return fake_class.from_real(x)
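

# A minimal sketch (assuming a FakeTensorQueue like the one in the
# register_fake_class docstring) of a from_real classmethod; tensors are
# fakified via the ctx installed above:
#
#   class FakeTensorQueue:
#       ...
#       @classmethod
#       def from_real(cls, real_tq):
#           ctx = torch.library.get_ctx()
#           fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.clone_queue()]
#           return cls(fake_queue)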