mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Make nonstrict_trace
work with some pytree.register_constant
-ed instances (#148007)
As title, this enables `nonstrict_trace`-ed function to take in object whose type has been `pytree.register_constant`-ed, as long as the object existed outside the `torch.compile` region. This also forces Dynamo to emit a `EQUALS_MATCH` guard on the object. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148007 Approved by: https://github.com/zou3519 ghstack dependencies: #148385
This commit is contained in:
committed by
PyTorch MergeBot
parent
a10f577ee0
commit
ad9a10aff0
@ -471,6 +471,49 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
res = opt_fn(x, y)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_nonstrict_trace_pre_existing_register_constant_type_guard(self):
|
||||
class State:
|
||||
def __init__(self, n):
|
||||
self.n = n
|
||||
|
||||
def get_num(self):
|
||||
torch._dynamo.graph_break()
|
||||
return self.n
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, State) and self.n == other.n
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.n)
|
||||
|
||||
# Assume `State` is implemented in C, and the author didn't bother to
|
||||
# provide a pytree decomposition for it, and its instances are safe to
|
||||
# treat as a constant by `torch.compile`.
|
||||
torch.utils._pytree.register_constant(State)
|
||||
|
||||
@torch._dynamo.nonstrict_trace
|
||||
def trace_me(x, s):
|
||||
return x * s.get_num()
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
|
||||
@torch.compile(fullgraph=True, backend=cnts)
|
||||
def fn(x, s):
|
||||
res = trace_me(x, s)
|
||||
return res
|
||||
|
||||
x = torch.ones(10)
|
||||
# Make sure recompilation didn't happen.
|
||||
self.assertEqual(cnts.frame_count, 0)
|
||||
fn(x, State(42))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
fn(x, State(42))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
# Make sure recompilation did happen.
|
||||
fn(x, State(41))
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
def test_nonstrict_trace_tuple_and_sym_int_output(self):
|
||||
@torch._dynamo.nonstrict_trace
|
||||
def trace_me(x):
|
||||
@ -632,6 +675,7 @@ Applying `nonstrict_trace` to function <trace_me>; however, `nonstrict_trace` cu
|
||||
except torch._dynamo.exc.Unsupported as e:
|
||||
msg = """
|
||||
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
|
||||
* `torch.utils._pytree.register_constant`
|
||||
* `torch.utils._pytree.register_dataclass`
|
||||
* `torch.utils._pytree.register_pytree_node`
|
||||
""" # NOQA: B950
|
||||
@ -683,39 +727,104 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types
|
||||
except torch._dynamo.exc.Unsupported as e:
|
||||
msg = """
|
||||
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_nested_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
|
||||
* `torch.utils._pytree.register_constant`
|
||||
* `torch.utils._pytree.register_dataclass`
|
||||
* `torch.utils._pytree.register_pytree_node`
|
||||
""" # NOQA: B950
|
||||
self.assertIn(msg, str(e))
|
||||
|
||||
def test_nonstrict_trace_pytree_register_constant_error(self):
|
||||
def test_nonstrict_newly_constructed_trace_register_constant_type_error(self):
|
||||
class State:
|
||||
def __init__(self, n):
|
||||
self.n = n
|
||||
|
||||
def get_num(self):
|
||||
torch._dynamo.graph_break()
|
||||
return self.n
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, State) and self.n == other.n
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.n)
|
||||
|
||||
# Assume `State` is implemented in C, and the author didn't bother to
|
||||
# provide a pytree decomposition for it, and its instances are safe to
|
||||
# treat as a constant by `torch.compile`.
|
||||
torch.utils._pytree.register_constant(State)
|
||||
|
||||
@torch._dynamo.nonstrict_trace
|
||||
def trace_me(x, s):
|
||||
return x * s.get_num()
|
||||
|
||||
@torch.compile(fullgraph=True, backend="aot_eager")
|
||||
def fn(x):
|
||||
s = State(10)
|
||||
res = trace_me(x, s)
|
||||
return res
|
||||
|
||||
try:
|
||||
x = torch.ones(10)
|
||||
fn(x)
|
||||
self.assertFalse(True) # must raise error before this
|
||||
except torch._dynamo.exc.Unsupported as e:
|
||||
msg = """
|
||||
You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <DecoratorTests.test_nonstrict_newly_constructed_trace_register_constant_type_error.<locals>.State>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region.
|
||||
|
||||
Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.
|
||||
""" # NOQA: B950
|
||||
self.assertIn(msg, str(e))
|
||||
|
||||
def test_nonstrict_trace_object_in_context_error(self):
|
||||
class Point:
|
||||
x: int
|
||||
y: int
|
||||
x: torch.Tensor
|
||||
y: torch.Tensor
|
||||
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
torch.utils._pytree.register_constant(Point)
|
||||
class PointTensor:
|
||||
p: Point
|
||||
t: torch.Tensor
|
||||
|
||||
def __init__(self, p, t):
|
||||
self.p = p
|
||||
self.t = t
|
||||
|
||||
torch.utils._pytree.register_pytree_node(
|
||||
PointTensor,
|
||||
lambda pt: ((pt.t,), pt.p),
|
||||
lambda ts, p: PointTensor(p, ts[0]),
|
||||
)
|
||||
|
||||
@torch._dynamo.nonstrict_trace
|
||||
def trace_me(x, p):
|
||||
def trace_me(pt):
|
||||
torch._dynamo.graph_break()
|
||||
return x * p.x + p.y
|
||||
return pt.t + pt.p.x * pt.p.y
|
||||
|
||||
@torch.compile(fullgraph=True, backend="aot_eager")
|
||||
def fn(x, p):
|
||||
res = trace_me(x, p)
|
||||
return res + 1
|
||||
def fn(x, y):
|
||||
p = Point(x, y)
|
||||
t = x + y
|
||||
pt = PointTensor(p, t)
|
||||
res = trace_me(pt)
|
||||
return res
|
||||
|
||||
try:
|
||||
p = Point(3, 4)
|
||||
fn(torch.ones(10), p)
|
||||
x, y = torch.ones(10), torch.ones(1)
|
||||
fn(x, y)
|
||||
self.assertFalse(True) # must raise error before this
|
||||
except torch._dynamo.exc.Unsupported as e:
|
||||
msg = """
|
||||
This error is most likely due to a call to `nonstrict_trace`-ed function, where one of the argument contains object of a type that has been (or needs to be) `torch.utils._pytree.register_constant`-ed. We currently don't support that.
|
||||
You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type <DecoratorTests.test_nonstrict_trace_object_in_context_error.<locals>.Point> into the context.
|
||||
|
||||
Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <DecoratorTests.test_nonstrict_trace_object_in_context_error.<locals>.Point>
|
||||
* `torch.utils._pytree.register_constant`
|
||||
* `torch.utils._pytree.register_dataclass`
|
||||
* `torch.utils._pytree.register_pytree_node`
|
||||
|
||||
If the above doesn't work, please subtmit an issue to GitHub.
|
||||
""" # NOQA: B950
|
||||
self.assertIn(msg, str(e))
|
||||
|
||||
|
@ -24,7 +24,7 @@ def distance(a, b, norm):
|
||||
return (a.x - b.x).abs() + (a.y - b.y).abs()
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class Norm:
|
||||
typ: str
|
||||
|
||||
|
@ -9754,7 +9754,7 @@ graph():
|
||||
@testing.expectedFailureSerDerNonStrict # register_constant needs to handle serialization
|
||||
@testing.expectedFailureSerDer # register_constant needs to handle serialization
|
||||
def test_register_constant(self):
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class MyInput:
|
||||
int_1: int
|
||||
int_2: int
|
||||
|
@ -1180,7 +1180,9 @@ if "optree" in sys.modules:
|
||||
self.assertEqual(point.y, torch.tensor(2))
|
||||
|
||||
def test_constant(self):
|
||||
@dataclass
|
||||
# Either use `frozen=True` or `unsafe_hash=True` so we have a
|
||||
# non-default `__hash__`.
|
||||
@dataclass(unsafe_hash=True)
|
||||
class Config:
|
||||
norm: str
|
||||
|
||||
@ -1191,6 +1193,33 @@ if "optree" in sys.modules:
|
||||
self.assertEqual(elements, [])
|
||||
self.assertEqual(spec.context.value, config)
|
||||
|
||||
def test_constant_default_eq_error(self):
|
||||
class Config:
|
||||
def __init__(self, norm: str):
|
||||
self.norm = norm
|
||||
|
||||
try:
|
||||
py_pytree.register_constant(Config)
|
||||
self.assertFalse(True) # must raise error before this
|
||||
except TypeError as e:
|
||||
msg = "register_constant(cls) expects `cls` to have a non-default `__eq__` implementation."
|
||||
self.assertIn(msg, str(e))
|
||||
|
||||
def test_constant_default_hash_error(self):
|
||||
class Config:
|
||||
def __init__(self, norm: str):
|
||||
self.norm = norm
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.norm == other.norm
|
||||
|
||||
try:
|
||||
py_pytree.register_constant(Config)
|
||||
self.assertFalse(True) # must raise error before this
|
||||
except TypeError as e:
|
||||
msg = "register_constant(cls) expects `cls` to have a non-default `__hash__` implementation."
|
||||
self.assertIn(msg, str(e))
|
||||
|
||||
def test_tree_map_with_path_multiple_trees(self):
|
||||
@dataclass
|
||||
class ACustomPytree:
|
||||
|
@ -1628,15 +1628,15 @@ class GuardBuilder(GuardBuilderBase):
|
||||
DeviceMesh,
|
||||
)
|
||||
|
||||
def check_type(obj):
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
return istype(obj, ok_types) or pytree.is_constant_class(type(obj))
|
||||
|
||||
if istype(val, dict):
|
||||
assert all(
|
||||
istype(x, ok_types) for x in itertools.chain(val.keys(), val.values())
|
||||
)
|
||||
assert all(check_type(x) for x in itertools.chain(val.keys(), val.values()))
|
||||
else:
|
||||
assert istype(
|
||||
val,
|
||||
ok_types,
|
||||
), f"Unexpected type {type(val)}, not in {ok_types}"
|
||||
assert check_type(val), f"Unexpected type {type(val)}"
|
||||
|
||||
# Special case for nan because float("nan") == float("nan") evaluates to False
|
||||
if istype(val, float) and math.isnan(val):
|
||||
|
@ -194,6 +194,17 @@ def is_side_effect_safe(m: MutationType):
|
||||
return m.scope == scope_id
|
||||
|
||||
|
||||
# This helps users of `as_python_constant` to catch unimplemented error with
|
||||
# more information; it inherits `NotImplementedError` for backward
|
||||
# compatibility reasons.
|
||||
class AsPythonConstantNotImplementedError(NotImplementedError):
|
||||
vt: "VariableTracker"
|
||||
|
||||
def __init__(self, vt: "VariableTracker"):
|
||||
super().__init__(self, f"{vt} is not a constant")
|
||||
self.vt = vt
|
||||
|
||||
|
||||
class VariableTrackerMeta(type):
|
||||
all_subclasses = []
|
||||
|
||||
@ -319,7 +330,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||
|
||||
def as_python_constant(self):
|
||||
"""For constants"""
|
||||
raise NotImplementedError(f"{self} is not a constant")
|
||||
raise AsPythonConstantNotImplementedError(self)
|
||||
|
||||
def guard_as_python_constant(self):
|
||||
"""Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
|
||||
|
@ -994,6 +994,8 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
from torch._subclasses.fake_tensor import fake_tensor_tls
|
||||
from torch.utils._pytree import tree_flatten
|
||||
|
||||
from .base import AsPythonConstantNotImplementedError
|
||||
|
||||
# 1. Convert `args, kwargs` into pytree-flattened proxy forms.
|
||||
#
|
||||
# Rather than reconstructing `args, kwargs` into python objects and
|
||||
@ -1018,6 +1020,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
unimplemented(
|
||||
f"""
|
||||
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <{type_name}>, please use one of the following to register the type with pytree:
|
||||
* `torch.utils._pytree.register_constant`
|
||||
* `torch.utils._pytree.register_dataclass`
|
||||
* `torch.utils._pytree.register_pytree_node`
|
||||
""" # NOQA: B950
|
||||
@ -1033,16 +1036,34 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types
|
||||
# the spec not a graphable type, so we still have to reconstruct it
|
||||
# into a python object, and store it as a constant attribute on the
|
||||
# fx graph.
|
||||
#
|
||||
# TODO handle `pytree._register_constant`-ed values.
|
||||
try:
|
||||
input_spec = input_spec_vt.as_python_constant()
|
||||
except NotImplementedError:
|
||||
unimplemented(
|
||||
"""
|
||||
This error is most likely due to a call to `nonstrict_trace`-ed function, where one of the argument contains object of a type that has been (or needs to be) `torch.utils._pytree.register_constant`-ed. We currently don't support that.
|
||||
except AsPythonConstantNotImplementedError as e:
|
||||
typ = e.vt.python_type()
|
||||
type_name = typ.__qualname__
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
if pytree.is_constant_class(typ):
|
||||
unimplemented(
|
||||
f"""
|
||||
You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region.
|
||||
|
||||
Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.
|
||||
""" # NOQA: B950
|
||||
)
|
||||
else:
|
||||
unimplemented(
|
||||
f"""
|
||||
You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type <{type_name}> into the context.
|
||||
|
||||
Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <{type_name}>
|
||||
* `torch.utils._pytree.register_constant`
|
||||
* `torch.utils._pytree.register_dataclass`
|
||||
* `torch.utils._pytree.register_pytree_node`
|
||||
|
||||
If the above doesn't work, please subtmit an issue to GitHub.
|
||||
""" # NOQA: B950
|
||||
)
|
||||
)
|
||||
|
||||
fn = self.value
|
||||
|
||||
|
@ -764,6 +764,17 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
def python_type(self):
|
||||
return self.value_type
|
||||
|
||||
def as_python_constant(self):
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
if pytree.is_constant_class(self.value_type):
|
||||
if self.source is not None:
|
||||
install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))
|
||||
return self.value
|
||||
# TODO else try reconstructing the object by, e.g., leveraging side
|
||||
# effects and `as_python_constant`.
|
||||
return super().as_python_constant()
|
||||
|
||||
def guard_as_python_constant(self):
|
||||
if self.source:
|
||||
install_guard(self.source.make_guard(GuardBuilder.ID_MATCH))
|
||||
@ -1352,7 +1363,9 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):
|
||||
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
if not istype(self.value, (pytree.TreeSpec, pytree.LeafSpec)):
|
||||
if not istype(
|
||||
self.value, (pytree.TreeSpec, pytree.LeafSpec, pytree.ConstantNode)
|
||||
):
|
||||
# TODO loosen this restriction and fix `as_proxy`.
|
||||
raise NotImplementedError(
|
||||
"currently can't reconstruct arbitrary frozen dataclass instances"
|
||||
|
@ -47,7 +47,7 @@ def func_to_graphable(func):
|
||||
return pytree.tree_flatten(_ConstantFunction(func))
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class _ConstantFunction:
|
||||
func: Callable
|
||||
|
||||
|
@ -290,19 +290,37 @@ def register_dataclass(cls: type[Any]) -> None:
|
||||
torch.export.register_dataclass(cls)
|
||||
|
||||
|
||||
CONSTANT_NODES: set[type] = set()
|
||||
|
||||
|
||||
def register_constant(cls: type[Any]) -> None:
|
||||
"""Registers a type as a pytree node with no leaves.
|
||||
|
||||
Instances of these types are treated as a constant (sometimes referred to as
|
||||
"static") by :func:`torch.compile`. When used in a function compiled by
|
||||
:func:`torch.compile`, :func:`torch.compile` guards on the instance
|
||||
object's hash: if :func:`torch.compile` sees a new hash then
|
||||
In a :func:`torch.compile` region, if instances of these types get passed to
|
||||
:func:`torch._dynamo.nonstrict_trace`-ed function, they treated as a
|
||||
constant (sometimes referred to as "static"):
|
||||
|
||||
1. if the instance object existed before the :func:`torch.compile` region,
|
||||
we _assume_ no mutation will happen to it inside the :func:`torch.compile`
|
||||
region, require that it has non-default `__eq__` and `__hash__` methods, and
|
||||
we guard on the instance based on its `__eq__` method, i.e., if a new
|
||||
instance fails to match any instances from the previous compilations,
|
||||
:func:`torch.compile` will recompile the function using the new instance.
|
||||
|
||||
2. else if the instance object is created inside the :func:`torch.compile`
|
||||
region, we currently don't support using it in a
|
||||
:func:`torch._dynamo.nonstrict_trace`-ed function.
|
||||
|
||||
In general, if your class holds Tensors or dynamic int/float/bool (values that
|
||||
may change from run-to-run of a function being compiled), then you probably
|
||||
do not want to register it as a constant.
|
||||
|
||||
Otherwise if you want to pass instance of a class to a
|
||||
:func:`torch._dynamo.nonstrict_trace`-ed function, but you either can't use
|
||||
:func:`register_pytree_node` on the class, or the class is "constant" enough
|
||||
that you don't want to bother using :func:`register_pytree_node`, you should
|
||||
consider using this function.
|
||||
|
||||
Args:
|
||||
cls: the type to register as a constant. This type must be hashable.
|
||||
|
||||
@ -311,7 +329,7 @@ def register_constant(cls: type[Any]) -> None:
|
||||
>>> from dataclasses import dataclass
|
||||
>>> import torch.utils._pytree as pytree
|
||||
>>>
|
||||
>>> @dataclass
|
||||
>>> @dataclass(frozen=True)
|
||||
>>> class Config:
|
||||
>>> norm: str
|
||||
>>>
|
||||
@ -322,6 +340,17 @@ def register_constant(cls: type[Any]) -> None:
|
||||
>>> assert len(values) == 0
|
||||
|
||||
"""
|
||||
if cls.__eq__ is object.__eq__: # type: ignore[comparison-overlap]
|
||||
raise TypeError(
|
||||
"register_constant(cls) expects `cls` to have a non-default `__eq__` implementation."
|
||||
)
|
||||
|
||||
# Class with a custom `__eq__` without `__hash__` won't inherit the default
|
||||
# `__hash__` from object; see https://stackoverflow.com/a/1608907.
|
||||
if cls.__hash__ is None: # type: ignore[comparison-overlap]
|
||||
raise TypeError(
|
||||
"register_constant(cls) expects `cls` to have a non-default `__hash__` implementation."
|
||||
)
|
||||
|
||||
def _flatten(x): # type: ignore[no-untyped-def]
|
||||
return [], ConstantNode(x)
|
||||
@ -338,9 +367,14 @@ def register_constant(cls: type[Any]) -> None:
|
||||
_unflatten,
|
||||
flatten_with_keys_fn=_flatten_with_keys,
|
||||
)
|
||||
CONSTANT_NODES.add(cls)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
def is_constant_class(cls: type[Any]) -> bool:
|
||||
return isinstance(cls, type) and cls in CONSTANT_NODES
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ConstantNode:
|
||||
value: Any
|
||||
|
||||
@ -455,6 +489,7 @@ def _deregister_pytree_node(
|
||||
node_def = SUPPORTED_SERIALIZED_TYPES[cls]
|
||||
del SERIALIZED_TYPE_TO_PYTHON_TYPE[node_def.serialized_type_name]
|
||||
del SUPPORTED_SERIALIZED_TYPES[cls]
|
||||
CONSTANT_NODES.discard(cls)
|
||||
|
||||
|
||||
def _private_register_pytree_node(
|
||||
|
Reference in New Issue
Block a user