mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
NamedTuple: Allow side effects for dynamic attributes (#161645)
I confirmed that the tracing was correct i.e. NamedTupleVariable had the correct dynamic attribute added to it. The problem was that NamedTupleVariable was always marked as immutable. This does not reflect the behavior of namedtuple. Subclasses of namedtuple may be mutable, so when a NamedTupleVariable is derived from a subclass that is mutable, I made NamedTupleVariable mutable as well. Then side_effects correctly updates the returned object. Fixes #161610 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161645 Approved by: https://github.com/anijain2305, https://github.com/StrongerXi
This commit is contained in:
committed by
PyTorch MergeBot
parent
8508651477
commit
86d34a43f5
@ -1767,6 +1767,52 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
||||
out = f(MyTuple(a, b))
|
||||
self.assertTrue(same(a + 1, out))
|
||||
|
||||
def test_namedtuple_source_dynamic_attributes(self):
|
||||
class MyNamedTuple(typing.NamedTuple):
|
||||
a: torch.Tensor
|
||||
b: torch.Tensor
|
||||
|
||||
class MyNamedTupleSubclass(MyNamedTuple):
|
||||
pass
|
||||
|
||||
@torch.compile(fullgraph=True, backend="eager")
|
||||
def f(tup):
|
||||
c = torch.tensor(3.0)
|
||||
tup.c = c # Add dynamic attribute
|
||||
return tup
|
||||
|
||||
extended_tup = MyNamedTupleSubclass(a=torch.tensor([1.0]), b=torch.tensor(2.0))
|
||||
result = f(extended_tup)
|
||||
# Verify the tuple has the expected structure
|
||||
self.assertEqual(result.a, torch.tensor([1.0]))
|
||||
self.assertEqual(result.b, torch.tensor(2.0))
|
||||
self.assertTrue(hasattr(result, "c"))
|
||||
self.assertEqual(result.c, torch.tensor(3.0))
|
||||
|
||||
def test_namedtuple_sourceless_dynamic_attributes(self):
|
||||
class MyNamedTuple(typing.NamedTuple):
|
||||
a: torch.Tensor
|
||||
b: torch.Tensor
|
||||
|
||||
class MyNamedTupleSubclass(MyNamedTuple):
|
||||
pass
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def f():
|
||||
# Create namedtuple inside function (sourceless)
|
||||
tup = MyNamedTupleSubclass(a=torch.tensor([1.0]), b=torch.tensor(2.0))
|
||||
# Add dynamic attribute
|
||||
tup.c = torch.tensor(3.0)
|
||||
return tup
|
||||
|
||||
result = f()
|
||||
# Verify the tuple has the expected structure
|
||||
self.assertEqual(result.a, torch.tensor([1.0]))
|
||||
self.assertEqual(result.b, torch.tensor(2.0))
|
||||
# Verify the dynamic attribute is preserved
|
||||
self.assertTrue(hasattr(result, "c"))
|
||||
self.assertEqual(result.c, torch.tensor(3.0))
|
||||
|
||||
def test_structseq1(self):
|
||||
def fn(x, y):
|
||||
return torch.return_types.max((x, y))
|
||||
|
@ -708,7 +708,7 @@ class VariableBuilder:
|
||||
result = NamedTupleVariable(
|
||||
output, tuple_cls=type(value), source=self.source
|
||||
)
|
||||
return result
|
||||
return self.tx.output.side_effects.track_object_existing(value, result)
|
||||
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
all_const = all(ConstantVariable.is_literal(k) for k in value.keys())
|
||||
|
@ -26,7 +26,11 @@ import torch
|
||||
import torch.fx
|
||||
|
||||
from .. import graph_break_hints, polyfills, variables
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
from ..bytecode_transformation import (
|
||||
create_call_function,
|
||||
create_instruction,
|
||||
create_rot_n,
|
||||
)
|
||||
from ..exc import raise_observed_exception, unimplemented_v2
|
||||
from ..source import AttrSource, NamedTupleFieldsSource
|
||||
from ..utils import (
|
||||
@ -1173,9 +1177,24 @@ class NamedTupleVariable(TupleVariable):
|
||||
def as_python_constant(self):
|
||||
if self.is_structseq():
|
||||
# StructSequenceType(iterable)
|
||||
return self.python_type()([x.as_python_constant() for x in self.items])
|
||||
# NamedTupleType(*iterable)
|
||||
return self.python_type()(*[x.as_python_constant() for x in self.items])
|
||||
result = self.python_type()([x.as_python_constant() for x in self.items])
|
||||
else:
|
||||
# NamedTupleType(*iterable)
|
||||
result = self.python_type()(*[x.as_python_constant() for x in self.items])
|
||||
|
||||
# Apply dynamic attributes if any were set
|
||||
if self.dynamic_attributes:
|
||||
for attr_name, attr_value in self.dynamic_attributes.items():
|
||||
# Convert VariableTracker to Python constant if needed
|
||||
if hasattr(attr_value, "as_python_constant"):
|
||||
python_value = attr_value.as_python_constant()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Can not convert dynamic attribute without python constant value to python constant."
|
||||
)
|
||||
setattr(result, attr_name, python_value)
|
||||
|
||||
return result
|
||||
|
||||
def as_proxy(self):
|
||||
assert self.python_type() is not SizeVariable
|
||||
@ -1186,6 +1205,7 @@ class NamedTupleVariable(TupleVariable):
|
||||
return self.python_type()(*self._as_proxy())
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# Always reconstruct the NamedTuple normally first
|
||||
# Constructors:
|
||||
# StructSequenceType(iterable)
|
||||
# NamedTupleType(*iterable)
|
||||
@ -1204,6 +1224,12 @@ class NamedTupleVariable(TupleVariable):
|
||||
+ create_call_function(1, False)
|
||||
)
|
||||
|
||||
for name, value in self.dynamic_attributes.items():
|
||||
codegen.dup_top()
|
||||
codegen(value)
|
||||
codegen.extend_output(create_rot_n(2))
|
||||
codegen.store_attr(name)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
@ -1227,6 +1253,8 @@ class NamedTupleVariable(TupleVariable):
|
||||
raise_observed_exception(AttributeError, tx)
|
||||
# Subclass of namedtuple type can have dynamic attributes
|
||||
tx.output.side_effects.mutation(self)
|
||||
if self.source:
|
||||
tx.output.side_effects.store_attr(self, attr, value)
|
||||
self.dynamic_attributes[attr] = value
|
||||
return ConstantVariable.create(None)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
@ -716,7 +716,12 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
|
||||
assert all(x is not None for x in items)
|
||||
|
||||
return variables.NamedTupleVariable(items, self.value)
|
||||
# Modify mutability of namedtuple for sourcelesss instantiations.
|
||||
from .base import AttributeMutationNew
|
||||
|
||||
return variables.NamedTupleVariable(
|
||||
items, self.value, mutation_type=AttributeMutationNew()
|
||||
)
|
||||
elif self.value is torch.Size:
|
||||
# This simulates `THPSize_pynew`, the C impl for `Size.__new__`.
|
||||
tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs)
|
||||
|
Reference in New Issue
Block a user