Files
pytorch/torch/_library/opaque_object.py
Maggie Moss d795fb225a [RFC] Add pyrefly to lintrunner (#165179)
This will add pyrefly to lintrunner as a warning only, allowing us to collect feedback about the tool before switching to pyrefly as the main type checker.

References the steps outlined here: https://github.com/pytorch/pytorch/issues/163283

test plan:
`lintrunner init`
`lintrunner`
confirm that, when pyrefly errors are present, the results look like: https://gist.github.com/maggiemoss/e6cb2d015dd1ded560ae1329098cf33f

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165179
Approved by: https://github.com/ezyang
2025-10-16 20:07:09 +00:00

188 lines
6.5 KiB
Python

from typing import Any, NewType, Optional
import torch
from .fake_class_registry import FakeScriptObject, register_fake_class
@register_fake_class("aten::OpaqueObject")
class FakeOpaqueObject:
    """
    Fake class registered through ``register_fake_class`` for "aten::OpaqueObject".

    It holds no state. Reconstructing it through ``__obj_unflatten__`` is
    deliberately unsupported; as the error message states, fake opaque objects
    are expected to be special-handled elsewhere.
    """

    def __init__(self) -> None:
        # Stateless: nothing to initialize.
        return None

    @classmethod
    def __obj_unflatten__(cls, flattened_ctx: dict[str, Any]) -> None:
        # Always raises; ``flattened_ctx`` is intentionally unused.
        error_text = (
            "FakeOpaqueObject should not be created through __obj_unflatten__ "
            "and should be special handled. Please file an issue to Github."
        )
        raise RuntimeError(error_text)
# Qualified TorchScript type name that opaque objects report through
# ``_type().qualified_name()``; this is also the type name to use when writing
# custom operator schemas that accept an opaque object.
OpaqueTypeStr: str = "__torch__.torch.classes.aten.OpaqueObject"

# Distinct static type for opaque objects; at runtime this is a ``NewType``
# wrapper over ``torch._C.ScriptObject``.
OpaqueType = NewType("OpaqueType", torch._C.ScriptObject)
def make_opaque(payload: Any = None) -> torch._C.ScriptObject:
    """
    Creates an opaque object which stores the given Python object.

    This opaque object can be passed to any custom operator as an argument.
    The Python object can then be accessed from the opaque object using the
    `get_payload()` API. The opaque object's ``_type().qualified_name()`` is
    "__torch__.torch.classes.aten.OpaqueObject", which should be the type used
    when creating custom operator schemas.

    Args:
        payload (Any): The Python object to store in the opaque object. This can
            be omitted (it defaults to ``None``), and can be set with
            `set_payload()` later.

    Returns:
        torch._C.ScriptObject: The opaque object that stores the given Python object.

    Example:
        >>> import random
        >>> import torch
        >>> from torch._library.opaque_object import (
        ...     make_opaque,
        ...     get_payload,
        ...     set_payload,
        ... )
        >>>
        >>> class RNGState:
        ...     def __init__(self, seed):
        ...         self.rng = random.Random(seed)
        >>>
        >>> rng = RNGState(0)
        >>> obj = make_opaque()
        >>> set_payload(obj, rng)
        >>>
        >>> assert get_payload(obj) == rng
        >>>
        >>> lib = torch.library.Library("mylib", "FRAGMENT")
        >>>
        >>> torch.library.define(
        ...     "mylib::noisy_inject",
        ...     "(Tensor x, __torch__.torch.classes.aten.OpaqueObject obj) -> Tensor",
        ...     tags=torch.Tag.pt2_compliant_tag,
        ...     lib=lib,
        ... )
        >>>
        >>> @torch.library.impl(
        ...     "mylib::noisy_inject", "CompositeExplicitAutograd", lib=lib
        ... )
        ... def noisy_inject(x: torch.Tensor, obj: torch._C.ScriptObject) -> torch.Tensor:
        ...     rng_state = get_payload(obj)
        ...     assert isinstance(rng_state, RNGState)
        ...     out = x.clone()
        ...     for i in range(out.numel()):
        ...         out.view(-1)[i] += rng_state.rng.random()
        ...     return out
        >>>
        >>> print(torch.ops.mylib.noisy_inject(torch.ones(3), obj))
    """
    # The object is constructed entirely by the C++ binding; this wrapper only
    # forwards the payload.
    return torch._C._make_opaque_object(payload)
def get_payload(opaque_object: torch._C.ScriptObject) -> Any:
    """
    Retrieves the Python object stored in the given opaque object.

    Args:
        opaque_object (torch._C.ScriptObject): The opaque object that stores
            the given Python object.

    Returns:
        payload (Any): The Python object stored in the opaque object. This can
        be set with `set_payload()`.

    Raises:
        ValueError: If ``opaque_object`` is a ``FakeScriptObject`` (i.e. this
            was called from inside a fake kernel), or if it is not a
            ``torch._C.ScriptObject`` whose qualified type name is
            ``OpaqueTypeStr``.
    """
    # Fake kernels must not depend on the payload, so reject FakeScriptObjects
    # up front with an explanatory error.
    if isinstance(opaque_object, FakeScriptObject):
        raise ValueError(
            "get_payload: this function was called with a FakeScriptObject "
            "implying that you are calling get_payload inside of a fake kernel. "
            "The fake kernel should not depend on the contents of the "
            "OpaqueObject at all, so we're erroring out. If you need this "
            "functionality, consider creating a custom TorchBind Object instead "
            "(but note that this is more difficult)."
        )
    is_script_object = isinstance(opaque_object, torch._C.ScriptObject)
    # Query the TorchScript type name once; it is both the validity check and,
    # on failure, what the error message reports.
    qualified_name = (
        opaque_object._type().qualified_name() if is_script_object else None
    )
    if qualified_name != OpaqueTypeStr:
        # For non-ScriptObjects, report the Python type instead.
        type_ = qualified_name if is_script_object else type(opaque_object)
        raise ValueError(
            f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
        )
    return torch._C._get_opaque_object_payload(opaque_object)
def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
    """
    Sets the Python object stored in the given opaque object.

    Args:
        opaque_object (torch._C.ScriptObject): The opaque object that stores
            the given Python object.
        payload (Any): The Python object to store in the opaque object.

    Raises:
        ValueError: If ``opaque_object`` is a ``FakeScriptObject`` (i.e. this
            was called from inside a fake kernel), or if it is not a
            ``torch._C.ScriptObject`` whose qualified type name is
            ``OpaqueTypeStr``.
    """
    # Fake kernels must not depend on the payload, so reject FakeScriptObjects
    # up front with an explanatory error that names *this* function.
    if isinstance(opaque_object, FakeScriptObject):
        raise ValueError(
            "set_payload: this function was called with a FakeScriptObject "
            "implying that you are calling set_payload inside of a fake kernel. "
            "The fake kernel should not depend on the contents of the "
            "OpaqueObject at all, so we're erroring out. If you need this "
            "functionality, consider creating a custom TorchBind Object instead "
            "(but note that this is more difficult)."
        )
    is_script_object = isinstance(opaque_object, torch._C.ScriptObject)
    # Query the TorchScript type name once; it is both the validity check and,
    # on failure, what the error message reports.
    qualified_name = (
        opaque_object._type().qualified_name() if is_script_object else None
    )
    if qualified_name != OpaqueTypeStr:
        # For non-ScriptObjects, report the Python type instead.
        type_ = qualified_name if is_script_object else type(opaque_object)
        raise ValueError(
            f"Tried to set the payload on a non-OpaqueObject of type `{type_}`"
        )
    torch._C._set_opaque_object_payload(opaque_object, payload)
_OPAQUE_TYPES: dict[Any, str] = {}
def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
    """
    Register ``cls`` as an opaque type so that it can be consumed by a custom
    operator.

    Args:
        cls (type): The class to register as an opaque type.
        name (str, optional): A unique qualified name of the type. When omitted,
            ``cls.__name__`` is used.

    Raises:
        ValueError: If the resolved name contains a ``'.'``.
    """
    type_name = cls.__name__ if name is None else name
    # The schema_type_parser will break up types with periods
    if "." in type_name:
        raise ValueError(
            f"Unable to accept name, {type_name}, for this opaque type as it contains a '.'"
        )
    # Record the Python-side mapping first, then register the name with C++.
    _OPAQUE_TYPES[cls] = type_name
    # pyrefly: ignore # missing-attribute
    torch._C._register_opaque_type(type_name)
def is_opaque_type(cls: Any) -> bool:
    """
    Checks if the given type is an opaque type.

    Returns ``False`` unless ``cls`` has an entry in the Python-side registry;
    otherwise the recorded name is checked with
    ``torch._C._is_opaque_type_registered``.
    """
    missing = object()
    registered_name = _OPAQUE_TYPES.get(cls, missing)
    if registered_name is missing:
        # Not registered through register_opaque_type().
        return False
    # pyrefly: ignore # missing-attribute
    return torch._C._is_opaque_type_registered(registered_name)