diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py index 2f70e9b730d3..0d178e956c47 100644 --- a/test/fx/test_fx_const_fold.py +++ b/test/fx/test_fx_const_fold.py @@ -601,7 +601,9 @@ class TestConstFold(TestCase): self.relu = torch.nn.ReLU() def forward(self, x): - quant_weight = torch.quantize_per_tensor(self.weight, 0.5, 3, torch.quint8) + quant_weight = torch.quantize_per_tensor( + self.weight, 0.5, 3, torch.quint8 + ) dequant_weight = torch.dequantize(quant_weight) output = torch.nn.functional.linear(x, dequant_weight, self.bias) return self.relu(output) @@ -630,3 +632,25 @@ class TestConstFold(TestCase): fold_result = gm_folded(in_x) base_result = mod(in_x) self.assertTrue(torch.equal(fold_result, base_result)) + + def test_fold_module(self): + r""" + Perform constant folding with a call_module node. + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin_input = torch.nn.Parameter(torch.randn(4, 4)) + self.lin = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.lin(self.lin_input) + x + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + inp = torch.randn(4, 4) + self.assertTrue(torch.equal(mod_folded(inp), mod(inp))) diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index 554395f37f29..a7365ee668f6 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -171,6 +171,9 @@ def split_const_subgraphs( for node in non_const_gm.graph.nodes: if node.op == "call_module": setattr(split, node.target, getattr(non_const_gm, node.target)) + for node in const_gm.graph.nodes: + if node.op == "call_module": + setattr(split, node.target, getattr(const_gm, node.target)) # split_module currently does not use get_attrs for attrs. Instead it passes # them in as args from the parent module, which used get_attrs. Here we set