Update on "[dynamo] Support namedtuple subclass"

Fixes #133762. This involves
1. support tuple subclass constructed inside compile region.
2. handle the "fake" global scope associated with NamedTuple-generated
   `__new__`.
3. handle `namedtuple._tuplegetter` more faithfully.
4. use `object.__getattribute__(obj, "__dict__")[key] = value` to replay
   side effects onto pre-existing user-defined objects, because
   `object.__setattr__` can still trigger user code on via descriptors
   setter.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
This commit is contained in:
Ryan Guo
2025-05-22 18:20:12 -07:00
4 changed files with 96 additions and 9 deletions

View File

@ -2003,6 +2003,30 @@ utils_device.CURRENT_DEVICE == None""".split(
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
res = opt_fn(g, torch.ones(2, 2))
def test_set_descriptor(self):
class Field:
def __set__(self, obj, value):
obj._value = value * 2
def __get__(self, obj, owner):
return obj._value + 1
class Foo:
field = Field()
def fn(x, foo):
foo.field = 10
return x + foo.field
opt_fn = torch.compile(fn, fullgraph=True, backend="eager")
x = torch.zeros(2)
foo1, foo2 = Foo(), Foo()
ref = fn(x, foo1)
res = opt_fn(x, foo2)
self.assertEqual(ref, res)
self.assertEqual(foo1.field, foo2.field)
def test_get_attr_function(self):
def fn(g, x):
return g(x)
@ -12223,6 +12247,27 @@ fn
with torch.compiler.set_stance("fail_on_recompile"):
self.assertEqual(fn(*inputs), inputs[0])
def test_is_op(self):
def fn(x, obj, d):
# `obj.__dict__` is meant to make Dynamo create a `GetAttrVariable`.
# `x += 1` makes sure we don't skip the frame.
x += 1
b1 = d is obj.__dict__
x += 1
b2 = d is not obj.__dict__
return b1, b2
class Foo:
pass
opt_fn = torch.compile(fn, backend="eager")
x = torch.ones(2)
obj = Foo()
ref = fn(x, obj, obj.__dict__)
res = opt_fn(x, obj, obj.__dict__)
self.assertEqual(ref, res)
class TestTracer(JitTestCase):
def test_jit_save(self):

View File

@ -668,6 +668,13 @@ class BuiltinVariable(VariableTracker):
)
def handle_is(tx: "InstructionTranslator", left, right):
# Stay conservative when we see `GetAttrVariable`, because
# it might represent the other VariableTracker under the
# hood.
if isinstance(left, variables.GetAttrVariable) or isinstance(
right, variables.GetAttrVariable
):
return None
# If the two objects are of different type, we can safely return False
# and True for `is` and `is not`, respectively
if type(left) is not type(right):
@ -2342,10 +2349,13 @@ class BuiltinVariable(VariableTracker):
and id(extract_fake_example_value(left.as_proxy().node))
== id(extract_fake_example_value(right.as_proxy().node))
)
if op is operator.is_:
return ConstantVariable.create(is_result)
else:
return ConstantVariable.create(not is_result)
if is_result:
return ConstantVariable.create(op is operator.is_)
# Else we stay conservative, because we might have `GetAttrVariable`
# which represents a `TensorVariable` under the hood and happens to
# be the same as the other `TensorVariable`. This happens with
# numpy's `flatiter.base` descriptor.
return None
if op not in supported_tensor_comparison_op_values:
unimplemented_v2(

View File

@ -1169,7 +1169,9 @@ class GetAttrVariable(VariableTracker):
elif name == "__setitem__" and self.name == "__dict__" and not kwargs:
if isinstance(self.obj, variables.UserDefinedObjectVariable):
# Bypass any custom setattr as we are updating the `__dict__` itself
return self.obj.method_setattr_standard(tx, args[0], args[1])
return self.obj.method_setattr_standard(
tx, args[0], args[1], bypass_descriptor=True
)
if isinstance(self.obj, variables.NNModuleVariable):
# This matches how `setattr` is handled for NNModuleVariable
self.obj.convert_to_unspecialized(tx)

View File

@ -898,7 +898,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return super().call_method(tx, name, args, kwargs)
def method_setattr_standard(self, tx: "InstructionTranslator", name, value):
def method_setattr_standard(
self, tx: "InstructionTranslator", name, value, bypass_descriptor=False
):
try:
name = name.as_python_constant()
except NotImplementedError:
@ -906,6 +908,28 @@ class UserDefinedObjectVariable(UserDefinedVariable):
if not tx.output.side_effects.is_attribute_mutation(self):
unimplemented(f"setattr({self}, {name}, ...)")
if not bypass_descriptor:
# Emulate
# https://github.com/python/cpython/blob/3.11/Objects/object.c#L1371-L1452
# NOTE we use `type(...)` to ignore instance attrs.
setter = NO_SUCH_SUBOBJ
descriptor = inspect.getattr_static(type(self.value), name, NO_SUCH_SUBOBJ)
if descriptor is not NO_SUCH_SUBOBJ:
setter = inspect.getattr_static(
type(descriptor), "__set__", NO_SUCH_SUBOBJ
)
if setter is not NO_SUCH_SUBOBJ:
desc_source = None
func_source = None
if self.cls_source:
desc_source = self.get_source_by_walking_mro(name)
func_source = AttrSource(TypeSource(desc_source), "__set__")
desc_var = VariableTracker.build(tx, descriptor, desc_source)
func_var = VariableTracker.build(tx, setter, func_source)
args = [desc_var, self, value]
return func_var.call_function(tx, args, {})
# Emulate the standard setattr on instance dict.
tx.output.side_effects.store_attr(self, name, value)
return variables.ConstantVariable(None)
@ -1189,9 +1213,13 @@ class UserDefinedObjectVariable(UserDefinedVariable):
# e.g.: inspect.getattr_static({}, "fromkeys")
func = subobj.__get__(self.value, None)
return VariableTracker.build(tx, func, source)
elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor(
subobj.__get__
elif inspect.getattr_static(
type(subobj), "__get__", NO_SUCH_SUBOBJ
) is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor(
type(subobj).__get__
):
# Emulate https://github.com/python/cpython/blob/3.11/Objects/object.c#L1271-L1285
#
# Attribute has a __get__ method. Create a user defined object vt
# for the subobj, and then trace the __get__ method.
descriptor_source = None
@ -1200,7 +1228,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
# To access the method descriptor from the udf object w/o using
# inspect.getattr_static, we can look into the class mro
descriptor_source = self.get_source_by_walking_mro(name)
descriptor_get_source = AttrSource(descriptor_source, "__get__")
descriptor_get_source = AttrSource(
TypeSource(descriptor_source), "__get__"
)
descriptor_var = VariableTracker.build(tx, subobj, descriptor_source)
else:
# Sourceless Builder does not support user defined objects