NamedTuple: Allow side effects for dynamic attributes (#161645)

I confirmed that the tracing was correct i.e. NamedTupleVariable had the correct dynamic attribute added to it.

The problem was that NamedTupleVariable was always marked as immutable. This does not reflect the behavior of namedtuple.

Subclasses of namedtuple may be mutable, so when a NamedTupleVariable is derived from a subclass that is mutable, I made NamedTupleVariable mutable as well. Then side_effects correctly updates the returned object.

Fixes #161610

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161645
Approved by: https://github.com/anijain2305, https://github.com/StrongerXi
This commit is contained in:
morrison-turnansky
2025-09-09 19:41:59 +00:00
committed by PyTorch MergeBot
parent 8508651477
commit 86d34a43f5
4 changed files with 85 additions and 6 deletions

View File

@ -1767,6 +1767,52 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
out = f(MyTuple(a, b))
self.assertTrue(same(a + 1, out))
def test_namedtuple_source_dynamic_attributes(self):
class MyNamedTuple(typing.NamedTuple):
a: torch.Tensor
b: torch.Tensor
class MyNamedTupleSubclass(MyNamedTuple):
pass
@torch.compile(fullgraph=True, backend="eager")
def f(tup):
c = torch.tensor(3.0)
tup.c = c # Add dynamic attribute
return tup
extended_tup = MyNamedTupleSubclass(a=torch.tensor([1.0]), b=torch.tensor(2.0))
result = f(extended_tup)
# Verify the tuple has the expected structure
self.assertEqual(result.a, torch.tensor([1.0]))
self.assertEqual(result.b, torch.tensor(2.0))
self.assertTrue(hasattr(result, "c"))
self.assertEqual(result.c, torch.tensor(3.0))
def test_namedtuple_sourceless_dynamic_attributes(self):
class MyNamedTuple(typing.NamedTuple):
a: torch.Tensor
b: torch.Tensor
class MyNamedTupleSubclass(MyNamedTuple):
pass
@torch.compile(backend="eager")
def f():
# Create namedtuple inside function (sourceless)
tup = MyNamedTupleSubclass(a=torch.tensor([1.0]), b=torch.tensor(2.0))
# Add dynamic attribute
tup.c = torch.tensor(3.0)
return tup
result = f()
# Verify the tuple has the expected structure
self.assertEqual(result.a, torch.tensor([1.0]))
self.assertEqual(result.b, torch.tensor(2.0))
# Verify the dynamic attribute is preserved
self.assertTrue(hasattr(result, "c"))
self.assertEqual(result.c, torch.tensor(3.0))
def test_structseq1(self):
def fn(x, y):
return torch.return_types.max((x, y))

View File

@ -708,7 +708,7 @@ class VariableBuilder:
result = NamedTupleVariable(
output, tuple_cls=type(value), source=self.source
)
return result
return self.tx.output.side_effects.track_object_existing(value, result)
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
self.install_guards(GuardBuilder.TYPE_MATCH)
all_const = all(ConstantVariable.is_literal(k) for k in value.keys())

View File

@ -26,7 +26,11 @@ import torch
import torch.fx
from .. import graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..bytecode_transformation import (
create_call_function,
create_instruction,
create_rot_n,
)
from ..exc import raise_observed_exception, unimplemented_v2
from ..source import AttrSource, NamedTupleFieldsSource
from ..utils import (
@ -1173,9 +1177,24 @@ class NamedTupleVariable(TupleVariable):
def as_python_constant(self):
if self.is_structseq():
# StructSequenceType(iterable)
return self.python_type()([x.as_python_constant() for x in self.items])
# NamedTupleType(*iterable)
return self.python_type()(*[x.as_python_constant() for x in self.items])
result = self.python_type()([x.as_python_constant() for x in self.items])
else:
# NamedTupleType(*iterable)
result = self.python_type()(*[x.as_python_constant() for x in self.items])
# Apply dynamic attributes if any were set
if self.dynamic_attributes:
for attr_name, attr_value in self.dynamic_attributes.items():
# Convert VariableTracker to Python constant if needed
if hasattr(attr_value, "as_python_constant"):
python_value = attr_value.as_python_constant()
else:
raise NotImplementedError(
"Can not convert dynamic attribute without python constant value to python constant."
)
setattr(result, attr_name, python_value)
return result
def as_proxy(self):
assert self.python_type() is not SizeVariable
@ -1186,6 +1205,7 @@ class NamedTupleVariable(TupleVariable):
return self.python_type()(*self._as_proxy())
def reconstruct(self, codegen: "PyCodegen") -> None:
# Always reconstruct the NamedTuple normally first
# Constructors:
# StructSequenceType(iterable)
# NamedTupleType(*iterable)
@ -1204,6 +1224,12 @@ class NamedTupleVariable(TupleVariable):
+ create_call_function(1, False)
)
for name, value in self.dynamic_attributes.items():
codegen.dup_top()
codegen(value)
codegen.extend_output(create_rot_n(2))
codegen.store_attr(name)
def call_method(
self,
tx,
@ -1227,6 +1253,8 @@ class NamedTupleVariable(TupleVariable):
raise_observed_exception(AttributeError, tx)
# Subclass of namedtuple type can have dynamic attributes
tx.output.side_effects.mutation(self)
if self.source:
tx.output.side_effects.store_attr(self, attr, value)
self.dynamic_attributes[attr] = value
return ConstantVariable.create(None)
return super().call_method(tx, name, args, kwargs)

View File

@ -716,7 +716,12 @@ class UserDefinedClassVariable(UserDefinedVariable):
assert all(x is not None for x in items)
return variables.NamedTupleVariable(items, self.value)
# Modify mutability of namedtuple for sourcelesss instantiations.
from .base import AttributeMutationNew
return variables.NamedTupleVariable(
items, self.value, mutation_type=AttributeMutationNew()
)
elif self.value is torch.Size:
# This simulates `THPSize_pynew`, the C impl for `Size.__new__`.
tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs)