[dynamo] Make nonstrict_trace work with some pytree.register_constant-ed instances (#148007)

As title, this enables a `nonstrict_trace`-ed function to take in
objects whose types have been `pytree.register_constant`-ed, as long as
the objects were created outside the `torch.compile` region. This also
forces Dynamo to emit an `EQUALS_MATCH` guard on each such object.
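
For illustration, a minimal sketch of the pattern this enables, adapted from the new test below (`State` stands in for any `register_constant`-ed type):

import torch
import torch._dynamo
import torch.utils._pytree as pytree

class State:
    def __init__(self, n):
        self.n = n

    # `register_constant` now requires non-default `__eq__`/`__hash__`,
    # because Dynamo guards on the instance via `EQUALS_MATCH`.
    def __eq__(self, other):
        return isinstance(other, State) and self.n == other.n

    def __hash__(self):
        return hash(self.n)

pytree.register_constant(State)

@torch._dynamo.nonstrict_trace
def trace_me(x, s):
    return x * s.n

@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x, s):
    return trace_me(x, s)

s = State(42)          # constructed *outside* the compile region: supported
fn(torch.ones(10), s)  # guarded by `__eq__`; passing State(41) later recompiles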

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148007
Approved by: https://github.com/zou3519
ghstack dependencies: #148385
Author: Ryan Guo
Date: 2025-03-04 15:35:04 -08:00
Committed by: PyTorch MergeBot
Parent: a10f577ee0
Commit: ad9a10aff0
10 changed files with 256 additions and 38 deletions

View File

@@ -471,6 +471,49 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x, y)
self.assertEqual(ref, res)
def test_nonstrict_trace_pre_existing_register_constant_type_guard(self):
class State:
def __init__(self, n):
self.n = n
def get_num(self):
torch._dynamo.graph_break()
return self.n
def __eq__(self, other):
return isinstance(other, State) and self.n == other.n
def __hash__(self):
return hash(self.n)
# Assume `State` is implemented in C, the author didn't bother to
# provide a pytree decomposition for it, and it's safe for
# `torch.compile` to treat its instances as constants.
torch.utils._pytree.register_constant(State)
@torch._dynamo.nonstrict_trace
def trace_me(x, s):
return x * s.get_num()
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
@torch.compile(fullgraph=True, backend=cnts)
def fn(x, s):
res = trace_me(x, s)
return res
x = torch.ones(10)
# No compilation has happened yet.
self.assertEqual(cnts.frame_count, 0)
fn(x, State(42))
self.assertEqual(cnts.frame_count, 1)
# Equal `State` instance -- make sure recompilation didn't happen.
fn(x, State(42))
self.assertEqual(cnts.frame_count, 1)
# Different `State` instance -- make sure recompilation did happen.
fn(x, State(41))
self.assertEqual(cnts.frame_count, 2)
def test_nonstrict_trace_tuple_and_sym_int_output(self):
@torch._dynamo.nonstrict_trace
def trace_me(x):
@@ -632,6 +675,7 @@ Applying `nonstrict_trace` to function <trace_me>; however, `nonstrict_trace` cu
except torch._dynamo.exc.Unsupported as e:
msg = """
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
* `torch.utils._pytree.register_constant`
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
""" # NOQA: B950
@@ -683,39 +727,104 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types
except torch._dynamo.exc.Unsupported as e:
msg = """
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_nested_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
* `torch.utils._pytree.register_constant`
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
""" # NOQA: B950
self.assertIn(msg, str(e))
def test_nonstrict_trace_pytree_register_constant_error(self):
def test_nonstrict_newly_constructed_trace_register_constant_type_error(self):
class State:
def __init__(self, n):
self.n = n
def get_num(self):
torch._dynamo.graph_break()
return self.n
def __eq__(self, other):
return isinstance(other, State) and self.n == other.n
def __hash__(self):
return hash(self.n)
# Assume `State` is implemented in C, the author didn't bother to
# provide a pytree decomposition for it, and it's safe for
# `torch.compile` to treat its instances as constants.
torch.utils._pytree.register_constant(State)
@torch._dynamo.nonstrict_trace
def trace_me(x, s):
return x * s.get_num()
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x):
s = State(10)
res = trace_me(x, s)
return res
try:
x = torch.ones(10)
fn(x)
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
msg = """
You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <DecoratorTests.test_nonstrict_newly_constructed_trace_register_constant_type_error.<locals>.State>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region.
Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.
""" # NOQA: B950
self.assertIn(msg, str(e))
def test_nonstrict_trace_object_in_context_error(self):
class Point:
x: int
y: int
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
torch.utils._pytree.register_constant(Point)
class PointTensor:
p: Point
t: torch.Tensor
def __init__(self, p, t):
self.p = p
self.t = t
torch.utils._pytree.register_pytree_node(
PointTensor,
lambda pt: ((pt.t,), pt.p),
lambda ts, p: PointTensor(p, ts[0]),
)
@torch._dynamo.nonstrict_trace
def trace_me(x, p):
def trace_me(pt):
torch._dynamo.graph_break()
return x * p.x + p.y
return pt.t + pt.p.x * pt.p.y
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x, p):
res = trace_me(x, p)
return res + 1
def fn(x, y):
p = Point(x, y)
t = x + y
pt = PointTensor(p, t)
res = trace_me(pt)
return res
try:
p = Point(3, 4)
fn(torch.ones(10), p)
x, y = torch.ones(10), torch.ones(1)
fn(x, y)
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
msg = """
This error is most likely due to a call to `nonstrict_trace`-ed function, where one of the argument contains object of a type that has been (or needs to be) `torch.utils._pytree.register_constant`-ed. We currently don't support that.
You are calling a `nonstrict_trace`-ed function where one of the inputs has been registered with a `pytree_flatten` that puts an object of type <DecoratorTests.test_nonstrict_trace_object_in_context_error.<locals>.Point> into the context.
Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <DecoratorTests.test_nonstrict_trace_object_in_context_error.<locals>.Point>:
* `torch.utils._pytree.register_constant`
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
If the above doesn't work, please submit an issue to GitHub.
""" # NOQA: B950
self.assertIn(msg, str(e))
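
The `@dataclass(frozen=True)` / `@dataclass(unsafe_hash=True)` changes in the files below follow from the stricter `register_constant` checks added in this PR: a plain `@dataclass` generates `__eq__` but sets `__hash__` to `None`, which `register_constant` now rejects. A quick standalone illustration (class names are illustrative):

from dataclasses import dataclass

@dataclass
class Plain:  # generated `__eq__` sets `__hash__` to None
    norm: str

@dataclass(frozen=True)
class Frozen:  # frozen=True (or unsafe_hash=True) also generates a usable `__hash__`
    norm: str

assert Plain.__hash__ is None
assert Frozen.__hash__ is not None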

View File

@@ -24,7 +24,7 @@ def distance(a, b, norm):
return (a.x - b.x).abs() + (a.y - b.y).abs()
@dataclass
@dataclass(frozen=True)
class Norm:
typ: str

View File

@@ -9754,7 +9754,7 @@ graph():
@testing.expectedFailureSerDerNonStrict # register_constant needs to handle serialization
@testing.expectedFailureSerDer # register_constant needs to handle serialization
def test_register_constant(self):
@dataclass
@dataclass(frozen=True)
class MyInput:
int_1: int
int_2: int

View File

@@ -1180,7 +1180,9 @@ if "optree" in sys.modules:
self.assertEqual(point.y, torch.tensor(2))
def test_constant(self):
@dataclass
# Either use `frozen=True` or `unsafe_hash=True` so we have a
# non-default `__hash__`.
@dataclass(unsafe_hash=True)
class Config:
norm: str
@@ -1191,6 +1193,33 @@
self.assertEqual(elements, [])
self.assertEqual(spec.context.value, config)
def test_constant_default_eq_error(self):
class Config:
def __init__(self, norm: str):
self.norm = norm
try:
py_pytree.register_constant(Config)
self.assertFalse(True) # must raise error before this
except TypeError as e:
msg = "register_constant(cls) expects `cls` to have a non-default `__eq__` implementation."
self.assertIn(msg, str(e))
def test_constant_default_hash_error(self):
class Config:
def __init__(self, norm: str):
self.norm = norm
def __eq__(self, other):
return self.norm == other.norm
try:
py_pytree.register_constant(Config)
self.assertFalse(True) # must raise error before this
except TypeError as e:
msg = "register_constant(cls) expects `cls` to have a non-default `__hash__` implementation."
self.assertIn(msg, str(e))
def test_tree_map_with_path_multiple_trees(self):
@dataclass
class ACustomPytree:

View File

@@ -1628,15 +1628,15 @@ class GuardBuilder(GuardBuilderBase):
DeviceMesh,
)
def check_type(obj):
import torch.utils._pytree as pytree
return istype(obj, ok_types) or pytree.is_constant_class(type(obj))
if istype(val, dict):
assert all(
istype(x, ok_types) for x in itertools.chain(val.keys(), val.values())
)
assert all(check_type(x) for x in itertools.chain(val.keys(), val.values()))
else:
assert istype(
val,
ok_types,
), f"Unexpected type {type(val)}, not in {ok_types}"
assert check_type(val), f"Unexpected type {type(val)}"
# Special case for nan because float("nan") == float("nan") evaluates to False
if istype(val, float) and math.isnan(val):

View File

@@ -194,6 +194,17 @@ def is_side_effect_safe(m: MutationType):
return m.scope == scope_id
# This helps users of `as_python_constant` catch the unimplemented error
# with more information; it inherits from `NotImplementedError` for
# backward-compatibility reasons.
class AsPythonConstantNotImplementedError(NotImplementedError):
vt: "VariableTracker"
def __init__(self, vt: "VariableTracker"):
super().__init__(f"{vt} is not a constant")
self.vt = vt
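
Callers can catch the new subclass to recover the offending `VariableTracker`, as the `torch.py` hunk further down does when building its error messages. A minimal sketch of the pattern (the helper name is illustrative):

from torch._dynamo.variables.base import AsPythonConstantNotImplementedError

def try_as_constant(vt):
    # Recover the failing VariableTracker itself, rather than only a
    # generic NotImplementedError message.
    try:
        return vt.as_python_constant()
    except AsPythonConstantNotImplementedError as e:
        raise RuntimeError(
            f"not a constant: {e.vt.python_type().__qualname__}"
        ) from e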
class VariableTrackerMeta(type):
all_subclasses = []
@@ -319,7 +330,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
def as_python_constant(self):
"""For constants"""
raise NotImplementedError(f"{self} is not a constant")
raise AsPythonConstantNotImplementedError(self)
def guard_as_python_constant(self):
"""Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""

View File

@@ -994,6 +994,8 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
from torch._subclasses.fake_tensor import fake_tensor_tls
from torch.utils._pytree import tree_flatten
from .base import AsPythonConstantNotImplementedError
# 1. Convert `args, kwargs` into pytree-flattened proxy forms.
#
# Rather than reconstructing `args, kwargs` into python objects and
@@ -1018,6 +1020,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
unimplemented(
f"""
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <{type_name}>, please use one of the following to register the type with pytree:
* `torch.utils._pytree.register_constant`
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
""" # NOQA: B950
@@ -1033,16 +1036,34 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types
# the spec not a graphable type, so we still have to reconstruct it
# into a python object, and store it as a constant attribute on the
# fx graph.
#
# TODO handle `pytree._register_constant`-ed values.
try:
input_spec = input_spec_vt.as_python_constant()
except NotImplementedError:
unimplemented(
"""
This error is most likely due to a call to `nonstrict_trace`-ed function, where one of the argument contains object of a type that has been (or needs to be) `torch.utils._pytree.register_constant`-ed. We currently don't support that.
except AsPythonConstantNotImplementedError as e:
typ = e.vt.python_type()
type_name = typ.__qualname__
import torch.utils._pytree as pytree
if pytree.is_constant_class(typ):
unimplemented(
f"""
You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region.
Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.
""" # NOQA: B950
)
else:
unimplemented(
f"""
You are calling a `nonstrict_trace`-ed function where one of the inputs has been registered with a `pytree_flatten` that puts an object of type <{type_name}> into the context.
Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <{type_name}>:
* `torch.utils._pytree.register_constant`
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
If the above doesn't work, please submit an issue to GitHub.
""" # NOQA: B950
)
)
fn = self.value

View File

@@ -764,6 +764,17 @@ class UserDefinedObjectVariable(UserDefinedVariable):
def python_type(self):
return self.value_type
def as_python_constant(self):
import torch.utils._pytree as pytree
if pytree.is_constant_class(self.value_type):
if self.source is not None:
install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))
return self.value
# TODO else try reconstructing the object by, e.g., leveraging side
# effects and `as_python_constant`.
return super().as_python_constant()
def guard_as_python_constant(self):
if self.source:
install_guard(self.source.make_guard(GuardBuilder.ID_MATCH))
@@ -1352,7 +1363,9 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):
import torch.utils._pytree as pytree
if not istype(self.value, (pytree.TreeSpec, pytree.LeafSpec)):
if not istype(
self.value, (pytree.TreeSpec, pytree.LeafSpec, pytree.ConstantNode)
):
# TODO loosen this restriction and fix `as_proxy`.
raise NotImplementedError(
"currently can't reconstruct arbitrary frozen dataclass instances"

View File

@@ -47,7 +47,7 @@ def func_to_graphable(func):
return pytree.tree_flatten(_ConstantFunction(func))
@dataclass
@dataclass(frozen=True)
class _ConstantFunction:
func: Callable

View File

@@ -290,19 +290,37 @@ def register_dataclass(cls: type[Any]) -> None:
torch.export.register_dataclass(cls)
CONSTANT_NODES: set[type] = set()
def register_constant(cls: type[Any]) -> None:
"""Registers a type as a pytree node with no leaves.
Instances of these types are treated as a constant (sometimes referred to as
"static") by :func:`torch.compile`. When used in a function compiled by
:func:`torch.compile`, :func:`torch.compile` guards on the instance
object's hash: if :func:`torch.compile` sees a new hash then
In a :func:`torch.compile` region, if instances of these types get passed to a
:func:`torch._dynamo.nonstrict_trace`-ed function, they are treated as
constants (sometimes referred to as "static"):
1. If the instance object existed before the :func:`torch.compile` region, we
_assume_ no mutation happens to it inside the region, require that its type
has non-default `__eq__` and `__hash__` methods, and guard on the instance
via its `__eq__` method, i.e., if a new instance fails to match any instance
from previous compilations, :func:`torch.compile` will recompile the function
using the new instance.
2. Otherwise, if the instance object is created inside the :func:`torch.compile`
region, we currently don't support using it in a
:func:`torch._dynamo.nonstrict_trace`-ed function.
In general, if your class holds Tensors or dynamic int/float/bool (values that
may change from run-to-run of a function being compiled), then you probably
do not want to register it as a constant.
Otherwise, if you want to pass an instance of a class to a
:func:`torch._dynamo.nonstrict_trace`-ed function, but you either can't use
:func:`register_pytree_node` on the class, or the class is "constant" enough
that registering a full pytree node isn't worth the trouble, you should
consider using this function.
Args:
cls: the type to register as a constant. This type must be hashable.
@@ -311,7 +329,7 @@ def register_constant(cls: type[Any]) -> None:
>>> from dataclasses import dataclass
>>> import torch.utils._pytree as pytree
>>>
>>> @dataclass
>>> @dataclass(frozen=True)
>>> class Config:
>>> norm: str
>>>
@@ -322,6 +340,17 @@ def register_constant(cls: type[Any]) -> None:
>>> assert len(values) == 0
"""
if cls.__eq__ is object.__eq__: # type: ignore[comparison-overlap]
raise TypeError(
"register_constant(cls) expects `cls` to have a non-default `__eq__` implementation."
)
# A class that defines a custom `__eq__` without `__hash__` has `__hash__`
# set to None rather than inheriting object's default; see
# https://stackoverflow.com/a/1608907.
if cls.__hash__ is None: # type: ignore[comparison-overlap]
raise TypeError(
"register_constant(cls) expects `cls` to have a non-default `__hash__` implementation."
)
def _flatten(x): # type: ignore[no-untyped-def]
return [], ConstantNode(x)
@@ -338,9 +367,14 @@
_unflatten,
flatten_with_keys_fn=_flatten_with_keys,
)
CONSTANT_NODES.add(cls)
@dataclasses.dataclass
def is_constant_class(cls: type[Any]) -> bool:
return isinstance(cls, type) and cls in CONSTANT_NODES
@dataclasses.dataclass(frozen=True)
class ConstantNode:
value: Any
@@ -455,6 +489,7 @@ def _deregister_pytree_node(
node_def = SUPPORTED_SERIALIZED_TYPES[cls]
del SERIALIZED_TYPE_TO_PYTHON_TYPE[node_def.serialized_type_name]
del SUPPORTED_SERIALIZED_TYPES[cls]
CONSTANT_NODES.discard(cls)
def _private_register_pytree_node(
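
Putting the new validation and the `is_constant_class` helper together, a small usage sketch (class names are illustrative):

import torch.utils._pytree as pytree

class NoEq:
    def __init__(self, n):
        self.n = n

try:
    pytree.register_constant(NoEq)  # default `__eq__` -> rejected
except TypeError as e:
    print(e)  # register_constant(cls) expects `cls` to have a non-default `__eq__` ...

class Ok:
    def __init__(self, n):
        self.n = n

    def __eq__(self, other):
        return isinstance(other, Ok) and self.n == other.n

    def __hash__(self):
        return hash(self.n)

pytree.register_constant(Ok)
assert pytree.is_constant_class(Ok)
assert not pytree.is_constant_class(NoEq)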