[torch.export] Removing unused constants - add support for corner case (#165205)

Summary: In some cases an unused constant has only one level of child node and no second level of child node. Those constants should be removed too. The added test case covers the scenario where this happens.

Test Plan:
```
buck test mode/opt caffe2/test:test_export -- 'test_unused_constant'
```

https://www.internalfb.com/intern/testinfra/testrun/15481123837456594

Differential Revision: D84398413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165205
Approved by: https://github.com/angelayi
This commit is contained in:
Malay Bag
2025-10-14 20:26:24 +00:00
committed by PyTorch MergeBot
parent 2b4ef6b4d6
commit 058782c6ab
2 changed files with 23 additions and 0 deletions

View File

@ -1628,6 +1628,24 @@ graph():
ep = export(M(), (torch.ones(3),))
self.assertEqual(len(ep.constants), 0)
class M(torch.nn.Module):
def __init__(self, num_features: int = 1) -> None:
super().__init__()
self.num_features = num_features
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
res = [torch.Tensor([])] * self.num_features
for i in range(self.num_features):
res[i] = x * (i + 1)
return res
inp = torch.ones(3)
ep = export(M(), (inp,))
self.assertEqual(len(ep.constants), 0)
unf = unflatten(ep)
self.assertTrue(torch.allclose(M()(inp)[0], unf(inp)[0]))
def test_unbacked_bincount(self):
class Foo(torch.nn.Module):
def forward(self, xs):

View File

@ -142,6 +142,10 @@ def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]:
if len(lift_fresh_node.users) > 1:
return None
# Case 1: lift node is not used anywhere
if len(lift_fresh_node.users) == 0:
return [lift_fresh_node, node]
detach_node = next(iter(lift_fresh_node.users.keys()))
if not (
detach_node.op == "call_function"
@ -156,6 +160,7 @@ def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]:
if len(detach_node.users) > 0:
return None
else:
# Case 2: Lift node's child is not used anywhere
return [detach_node, lift_fresh_node, node]