[dynamo] fix NamedTupleVariable cloning (#158190)

FIXES https://github.com/pytorch/pytorch/issues/157945

## Explanation
1. Some VTs add additional attrs e.g. NamedTupleVariable has "dynamic_attributes"
a0308edb6c/torch/_dynamo/variables/lists.py (L1048-L1051)

2. VT.clone passes everything by dict, includes "dynamic_attributes"
a0308edb6c/torch/_dynamo/variables/base.py (L255-L259)

3. Non-handled args become kwargs in VT's `__init__`, `super().__init__()` passes kwargs to Base VT
a0308edb6c/torch/_dynamo/variables/lists.py (L1048-L1051)

4. Base VT's `__init__` gets unexpected "dynamic_attributes" kwarg
a0308edb6c/torch/_dynamo/variables/base.py (L609-L613)

You could also let Base VT's `__init__` ignore additional kwargs, but that seemed a bit too permissive, and I don't think many VT's add these derived class only attrs.

## After fix

```python
 ===== __compiled_fn_1_7f9541ed_e166_43fe_8322_c5225ce4207f =====
 /home/xmfan/core/miniconda3/envs/0712/lib/python3.12/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[4, 8, 6][48, 6, 1]cpu"):
        l_x_ = L_x_

         # File: /home/xmfan/core/a/torchtitan/wtf.py:10 in forward, code: U, S = torch.linalg.svd(x)[:2]
        linalg_svd = torch._C._linalg.linalg_svd(l_x_);  l_x_ = None
        U: "f32[4, 8, 8][64, 1, 8]cpu" = linalg_svd[0]
        S: "f32[4, 6][6, 1]cpu" = linalg_svd[1];  linalg_svd = None

         # File: /home/xmfan/core/a/torchtitan/wtf.py:11 in forward, code: reduced = U[:, :, :self.k] @ torch.diag_embed(S[:, :self.k])
        getitem_3: "f32[4, 8, 5][64, 1, 8]cpu" = U[(slice(None, None, None), slice(None, None, None), slice(None, 5, None))];  U = None
        getitem_4: "f32[4, 5][6, 1]cpu" = S[(slice(None, None, None), slice(None, 5, None))];  S = None
        diag_embed: "f32[4, 5, 5][25, 5, 1]cpu" = torch.diag_embed(getitem_4);  getitem_4 = None
        reduced: "f32[4, 8, 5][40, 5, 1]cpu" = getitem_3 @ diag_embed;  getitem_3 = diag_embed = None
        return (reduced,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158190
Approved by: https://github.com/StrongerXi
This commit is contained in:
Simon Fan
2025-07-12 10:23:42 -07:00
committed by PyTorch MergeBot
parent 08799217ae
commit 7cf31b4a42
2 changed files with 21 additions and 2 deletions

View File

@ -7569,6 +7569,25 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
with mock.patch("torch.cuda.is_initialized", lambda: False):
self.assertEqual(f(inp), inp + 2)
def test_named_tuple_vt_clone(self):
# https://github.com/pytorch/pytorch/issues/157945
class SVDCompressor(nn.Module):
def __init__(self, k=10):
super().__init__()
self.k = k
def forward(self, x):
U, S = torch.linalg.svd(x)[:2]
reduced = U[:, :, : self.k] @ torch.diag_embed(S[:, : self.k])
return reduced
input = torch.randn(4, 8, 6)
model = SVDCompressor(k=5)
out1 = model(input.clone())
out2 = torch.compile(model, backend="eager")(input.clone())
self.assertEqual(out1, out2)
instantiate_parametrized_tests(ReproTests)

View File

@ -1045,10 +1045,10 @@ class NamedTupleVariable(TupleVariable):
*TupleVariable._nonvar_fields,
}
def __init__(self, items, tuple_cls, **kwargs) -> None:
def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
super().__init__(items, **kwargs)
self.tuple_cls = tuple_cls
self.dynamic_attributes = {}
self.dynamic_attributes = {} if not dynamic_attributes else dynamic_attributes
def is_namedtuple(self):
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(