[dyanmo] support subclasses of namedtuple type (#140534)

Allow subclassing namedtuple type. Allow assign attributes to instances of these subtypes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140534
Approved by: https://github.com/jansel
This commit is contained in:
Xuehai Pan
2024-11-13 22:36:05 +08:00
committed by PyTorch MergeBot
parent ab5c8857ef
commit 90d3584147
4 changed files with 74 additions and 12 deletions

View File

@ -748,21 +748,26 @@ class SizeVariable(TupleVariable):
class NamedTupleVariable(TupleVariable):
_nonvar_fields = {
"tuple_cls",
"dynamic_attributes",
*TupleVariable._nonvar_fields,
}
def __init__(self, items, tuple_cls, **kwargs) -> None:
super().__init__(items, **kwargs)
self.tuple_cls = tuple_cls
self.dynamic_attributes = {}
def is_namedtuple(self):
return hasattr(self.tuple_cls, "_fields") and callable(
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
getattr(self.tuple_cls, "_make", None)
)
def is_structseq(self):
return not self.is_namedtuple()
def fields(self):
return namedtuple_fields(self.tuple_cls)
def debug_repr(self):
if self.is_structseq():
# StructSequenceType(iterable)
@ -805,6 +810,33 @@ class NamedTupleVariable(TupleVariable):
+ create_call_function(1, False)
)
def call_method(
self,
tx,
name,
args: List[VariableTracker],
kwargs: Dict[str, VariableTracker],
) -> VariableTracker:
if name == "__setattr__":
assert len(args) == 2
assert len(kwargs) == 0
attr, value = args
attr = attr.as_python_constant()
if (
# structseq is immutable
self.is_structseq()
# namedtuple directly created by `collections.namedtuple` is immutable
or self.tuple_cls.__bases__ == (tuple,)
# fields are immutable
or attr in self.fields()
):
raise_observed_exception(AttributeError, tx)
# Subclass of namedtuple type can have dynamic attributes
tx.output.side_effects.mutation(self)
self.dynamic_attributes[attr] = value
return ConstantVariable.create(None)
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name):
def check_and_create_method():
method = inspect.getattr_static(self.tuple_cls, name, None)
@ -821,7 +853,10 @@ class NamedTupleVariable(TupleVariable):
else:
return None
fields = namedtuple_fields(self.tuple_cls)
if name in self.dynamic_attributes:
return self.dynamic_attributes[name]
fields = self.fields()
if name not in fields:
method = check_and_create_method()
if not method:
@ -830,7 +865,9 @@ class NamedTupleVariable(TupleVariable):
return self.items[fields.index(name)]
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
return variables.ConstantVariable.create(hasattr(self.tuple_cls, name))
return variables.ConstantVariable.create(
name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
)
class SliceVariable(BaseListVariable):