mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dyanmo] support subclasses of namedtuple type (#140534)
Allow subclassing namedtuple type. Allow assign attributes to instances of these subtypes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140534 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
ab5c8857ef
commit
90d3584147
@ -748,21 +748,26 @@ class SizeVariable(TupleVariable):
|
||||
class NamedTupleVariable(TupleVariable):
|
||||
_nonvar_fields = {
|
||||
"tuple_cls",
|
||||
"dynamic_attributes",
|
||||
*TupleVariable._nonvar_fields,
|
||||
}
|
||||
|
||||
def __init__(self, items, tuple_cls, **kwargs) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
self.tuple_cls = tuple_cls
|
||||
self.dynamic_attributes = {}
|
||||
|
||||
def is_namedtuple(self):
|
||||
return hasattr(self.tuple_cls, "_fields") and callable(
|
||||
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
|
||||
getattr(self.tuple_cls, "_make", None)
|
||||
)
|
||||
|
||||
def is_structseq(self):
|
||||
return not self.is_namedtuple()
|
||||
|
||||
def fields(self):
|
||||
return namedtuple_fields(self.tuple_cls)
|
||||
|
||||
def debug_repr(self):
|
||||
if self.is_structseq():
|
||||
# StructSequenceType(iterable)
|
||||
@ -805,6 +810,33 @@ class NamedTupleVariable(TupleVariable):
|
||||
+ create_call_function(1, False)
|
||||
)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: List[VariableTracker],
|
||||
kwargs: Dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__setattr__":
|
||||
assert len(args) == 2
|
||||
assert len(kwargs) == 0
|
||||
attr, value = args
|
||||
attr = attr.as_python_constant()
|
||||
if (
|
||||
# structseq is immutable
|
||||
self.is_structseq()
|
||||
# namedtuple directly created by `collections.namedtuple` is immutable
|
||||
or self.tuple_cls.__bases__ == (tuple,)
|
||||
# fields are immutable
|
||||
or attr in self.fields()
|
||||
):
|
||||
raise_observed_exception(AttributeError, tx)
|
||||
# Subclass of namedtuple type can have dynamic attributes
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.dynamic_attributes[attr] = value
|
||||
return ConstantVariable.create(None)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
def check_and_create_method():
|
||||
method = inspect.getattr_static(self.tuple_cls, name, None)
|
||||
@ -821,7 +853,10 @@ class NamedTupleVariable(TupleVariable):
|
||||
else:
|
||||
return None
|
||||
|
||||
fields = namedtuple_fields(self.tuple_cls)
|
||||
if name in self.dynamic_attributes:
|
||||
return self.dynamic_attributes[name]
|
||||
|
||||
fields = self.fields()
|
||||
if name not in fields:
|
||||
method = check_and_create_method()
|
||||
if not method:
|
||||
@ -830,7 +865,9 @@ class NamedTupleVariable(TupleVariable):
|
||||
return self.items[fields.index(name)]
|
||||
|
||||
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||
return variables.ConstantVariable.create(hasattr(self.tuple_cls, name))
|
||||
return variables.ConstantVariable.create(
|
||||
name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
|
||||
)
|
||||
|
||||
|
||||
class SliceVariable(BaseListVariable):
|
||||
|
||||
Reference in New Issue
Block a user