mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Update on "[dynamo] Support namedtuple subclass"
Fixes #133762. This involves 1. support tuple subclass constructed inside compile region. 2. handle the "fake" global scope associated with NamedTuple-generated `__new__`. 3. handle `namedtuple._tuplegetter` more faithfully. 4. use `object.__getattribute__(obj, "__dict__")[key] = value` to replay side effects onto pre-existing user-defined objects, because `object.__setattr__` can still trigger user code on via descriptors setter. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
This commit is contained in:
@ -2003,6 +2003,30 @@ utils_device.CURRENT_DEVICE == None""".split(
|
||||
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
|
||||
res = opt_fn(g, torch.ones(2, 2))
|
||||
|
||||
def test_set_descriptor(self):
|
||||
class Field:
|
||||
def __set__(self, obj, value):
|
||||
obj._value = value * 2
|
||||
|
||||
def __get__(self, obj, owner):
|
||||
return obj._value + 1
|
||||
|
||||
class Foo:
|
||||
field = Field()
|
||||
|
||||
def fn(x, foo):
|
||||
foo.field = 10
|
||||
return x + foo.field
|
||||
|
||||
opt_fn = torch.compile(fn, fullgraph=True, backend="eager")
|
||||
x = torch.zeros(2)
|
||||
foo1, foo2 = Foo(), Foo()
|
||||
|
||||
ref = fn(x, foo1)
|
||||
res = opt_fn(x, foo2)
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(foo1.field, foo2.field)
|
||||
|
||||
def test_get_attr_function(self):
|
||||
def fn(g, x):
|
||||
return g(x)
|
||||
@ -12223,6 +12247,27 @@ fn
|
||||
with torch.compiler.set_stance("fail_on_recompile"):
|
||||
self.assertEqual(fn(*inputs), inputs[0])
|
||||
|
||||
def test_is_op(self):
|
||||
def fn(x, obj, d):
|
||||
# `obj.__dict__` is meant to make Dynamo create a `GetAttrVariable`.
|
||||
# `x += 1` makes sure we don't skip the frame.
|
||||
x += 1
|
||||
b1 = d is obj.__dict__
|
||||
x += 1
|
||||
b2 = d is not obj.__dict__
|
||||
return b1, b2
|
||||
|
||||
class Foo:
|
||||
pass
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager")
|
||||
x = torch.ones(2)
|
||||
obj = Foo()
|
||||
|
||||
ref = fn(x, obj, obj.__dict__)
|
||||
res = opt_fn(x, obj, obj.__dict__)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
|
||||
class TestTracer(JitTestCase):
|
||||
def test_jit_save(self):
|
||||
|
@ -668,6 +668,13 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
|
||||
def handle_is(tx: "InstructionTranslator", left, right):
|
||||
# Stay conservative when we see `GetAttrVariable`, because
|
||||
# it might represent the other VariableTracker under the
|
||||
# hood.
|
||||
if isinstance(left, variables.GetAttrVariable) or isinstance(
|
||||
right, variables.GetAttrVariable
|
||||
):
|
||||
return None
|
||||
# If the two objects are of different type, we can safely return False
|
||||
# and True for `is` and `is not`, respectively
|
||||
if type(left) is not type(right):
|
||||
@ -2342,10 +2349,13 @@ class BuiltinVariable(VariableTracker):
|
||||
and id(extract_fake_example_value(left.as_proxy().node))
|
||||
== id(extract_fake_example_value(right.as_proxy().node))
|
||||
)
|
||||
if op is operator.is_:
|
||||
return ConstantVariable.create(is_result)
|
||||
else:
|
||||
return ConstantVariable.create(not is_result)
|
||||
if is_result:
|
||||
return ConstantVariable.create(op is operator.is_)
|
||||
# Else we stay conservative, because we might have `GetAttrVariable`
|
||||
# which represents a `TensorVariable` under the hood and happens to
|
||||
# be the same as the other `TensorVariable`. This happens with
|
||||
# numpy's `flatiter.base` descriptor.
|
||||
return None
|
||||
|
||||
if op not in supported_tensor_comparison_op_values:
|
||||
unimplemented_v2(
|
||||
|
@ -1169,7 +1169,9 @@ class GetAttrVariable(VariableTracker):
|
||||
elif name == "__setitem__" and self.name == "__dict__" and not kwargs:
|
||||
if isinstance(self.obj, variables.UserDefinedObjectVariable):
|
||||
# Bypass any custom setattr as we are updating the `__dict__` itself
|
||||
return self.obj.method_setattr_standard(tx, args[0], args[1])
|
||||
return self.obj.method_setattr_standard(
|
||||
tx, args[0], args[1], bypass_descriptor=True
|
||||
)
|
||||
if isinstance(self.obj, variables.NNModuleVariable):
|
||||
# This matches how `setattr` is handled for NNModuleVariable
|
||||
self.obj.convert_to_unspecialized(tx)
|
||||
|
@ -898,7 +898,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def method_setattr_standard(self, tx: "InstructionTranslator", name, value):
|
||||
def method_setattr_standard(
|
||||
self, tx: "InstructionTranslator", name, value, bypass_descriptor=False
|
||||
):
|
||||
try:
|
||||
name = name.as_python_constant()
|
||||
except NotImplementedError:
|
||||
@ -906,6 +908,28 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
if not tx.output.side_effects.is_attribute_mutation(self):
|
||||
unimplemented(f"setattr({self}, {name}, ...)")
|
||||
|
||||
if not bypass_descriptor:
|
||||
# Emulate
|
||||
# https://github.com/python/cpython/blob/3.11/Objects/object.c#L1371-L1452
|
||||
# NOTE we use `type(...)` to ignore instance attrs.
|
||||
setter = NO_SUCH_SUBOBJ
|
||||
descriptor = inspect.getattr_static(type(self.value), name, NO_SUCH_SUBOBJ)
|
||||
if descriptor is not NO_SUCH_SUBOBJ:
|
||||
setter = inspect.getattr_static(
|
||||
type(descriptor), "__set__", NO_SUCH_SUBOBJ
|
||||
)
|
||||
if setter is not NO_SUCH_SUBOBJ:
|
||||
desc_source = None
|
||||
func_source = None
|
||||
if self.cls_source:
|
||||
desc_source = self.get_source_by_walking_mro(name)
|
||||
func_source = AttrSource(TypeSource(desc_source), "__set__")
|
||||
desc_var = VariableTracker.build(tx, descriptor, desc_source)
|
||||
func_var = VariableTracker.build(tx, setter, func_source)
|
||||
args = [desc_var, self, value]
|
||||
return func_var.call_function(tx, args, {})
|
||||
|
||||
# Emulate the standard setattr on instance dict.
|
||||
tx.output.side_effects.store_attr(self, name, value)
|
||||
return variables.ConstantVariable(None)
|
||||
|
||||
@ -1189,9 +1213,13 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
# e.g.: inspect.getattr_static({}, "fromkeys")
|
||||
func = subobj.__get__(self.value, None)
|
||||
return VariableTracker.build(tx, func, source)
|
||||
elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor(
|
||||
subobj.__get__
|
||||
elif inspect.getattr_static(
|
||||
type(subobj), "__get__", NO_SUCH_SUBOBJ
|
||||
) is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor(
|
||||
type(subobj).__get__
|
||||
):
|
||||
# Emulate https://github.com/python/cpython/blob/3.11/Objects/object.c#L1271-L1285
|
||||
#
|
||||
# Attribute has a __get__ method. Create a user defined object vt
|
||||
# for the subobj, and then trace the __get__ method.
|
||||
descriptor_source = None
|
||||
@ -1200,7 +1228,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
# To access the method descriptor from the udf object w/o using
|
||||
# inspect.getattr_static, we can look into the class mro
|
||||
descriptor_source = self.get_source_by_walking_mro(name)
|
||||
descriptor_get_source = AttrSource(descriptor_source, "__get__")
|
||||
descriptor_get_source = AttrSource(
|
||||
TypeSource(descriptor_source), "__get__"
|
||||
)
|
||||
descriptor_var = VariableTracker.build(tx, subobj, descriptor_source)
|
||||
else:
|
||||
# Sourceless Builder does not support user defined objects
|
||||
|
Reference in New Issue
Block a user