mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] fix NamedTupleVariable cloning (#158190)
FIXES https://github.com/pytorch/pytorch/issues/157945 ## Explanation 1. Some VTs add additional attrs e.g. NamedTupleVariable has "dynamic_attributes"a0308edb6c/torch/_dynamo/variables/lists.py (L1048-L1051)
2. VT.clone passes everything by dict, includes "dynamic_attributes"a0308edb6c/torch/_dynamo/variables/base.py (L255-L259)
3. Non-handled args become kwargs in VT's `__init__`, `super().__init__()` passes kwargs to Base VTa0308edb6c/torch/_dynamo/variables/lists.py (L1048-L1051)
4. Base VT's `__init__` gets unexpected "dynamic_attributes" kwarga0308edb6c/torch/_dynamo/variables/base.py (L609-L613)
You could also let Base VT's `__init__` ignore additional kwargs, but that seemed a bit too permissive, and I don't think many VT's add these derived class only attrs. ## After fix ```python ===== __compiled_fn_1_7f9541ed_e166_43fe_8322_c5225ce4207f ===== /home/xmfan/core/miniconda3/envs/0712/lib/python3.12/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[4, 8, 6][48, 6, 1]cpu"): l_x_ = L_x_ # File: /home/xmfan/core/a/torchtitan/wtf.py:10 in forward, code: U, S = torch.linalg.svd(x)[:2] linalg_svd = torch._C._linalg.linalg_svd(l_x_); l_x_ = None U: "f32[4, 8, 8][64, 1, 8]cpu" = linalg_svd[0] S: "f32[4, 6][6, 1]cpu" = linalg_svd[1]; linalg_svd = None # File: /home/xmfan/core/a/torchtitan/wtf.py:11 in forward, code: reduced = U[:, :, :self.k] @ torch.diag_embed(S[:, :self.k]) getitem_3: "f32[4, 8, 5][64, 1, 8]cpu" = U[(slice(None, None, None), slice(None, None, None), slice(None, 5, None))]; U = None getitem_4: "f32[4, 5][6, 1]cpu" = S[(slice(None, None, None), slice(None, 5, None))]; S = None diag_embed: "f32[4, 5, 5][25, 5, 1]cpu" = torch.diag_embed(getitem_4); getitem_4 = None reduced: "f32[4, 8, 5][40, 5, 1]cpu" = getitem_3 @ diag_embed; getitem_3 = diag_embed = None return (reduced,) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158190 Approved by: https://github.com/StrongerXi
This commit is contained in:
committed by
PyTorch MergeBot
parent
08799217ae
commit
7cf31b4a42
@ -7569,6 +7569,25 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
||||
with mock.patch("torch.cuda.is_initialized", lambda: False):
|
||||
self.assertEqual(f(inp), inp + 2)
|
||||
|
||||
def test_named_tuple_vt_clone(self):
|
||||
# https://github.com/pytorch/pytorch/issues/157945
|
||||
class SVDCompressor(nn.Module):
|
||||
def __init__(self, k=10):
|
||||
super().__init__()
|
||||
self.k = k
|
||||
|
||||
def forward(self, x):
|
||||
U, S = torch.linalg.svd(x)[:2]
|
||||
reduced = U[:, :, : self.k] @ torch.diag_embed(S[:, : self.k])
|
||||
return reduced
|
||||
|
||||
input = torch.randn(4, 8, 6)
|
||||
model = SVDCompressor(k=5)
|
||||
|
||||
out1 = model(input.clone())
|
||||
out2 = torch.compile(model, backend="eager")(input.clone())
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ReproTests)
|
||||
|
||||
|
@ -1045,10 +1045,10 @@ class NamedTupleVariable(TupleVariable):
|
||||
*TupleVariable._nonvar_fields,
|
||||
}
|
||||
|
||||
def __init__(self, items, tuple_cls, **kwargs) -> None:
|
||||
def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
self.tuple_cls = tuple_cls
|
||||
self.dynamic_attributes = {}
|
||||
self.dynamic_attributes = {} if not dynamic_attributes else dynamic_attributes
|
||||
|
||||
def is_namedtuple(self):
|
||||
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
|
||||
|
Reference in New Issue
Block a user