mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[torch.export] Rmoving unused constants - add support for corner case (#165205)
Summary: In some cases unused constant had only one level of child node, no second level of child node. Those constants should be removed too. The added test case has the scenario where this scenario will happen. Test Plan: ``` buck test mode/opt caffe2/test:test_export -- 'test_unused_constant' ``` https://www.internalfb.com/intern/testinfra/testrun/15481123837456594 Differential Revision: D84398413 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165205 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
2b4ef6b4d6
commit
058782c6ab
@ -1628,6 +1628,24 @@ graph():
|
||||
ep = export(M(), (torch.ones(3),))
|
||||
self.assertEqual(len(ep.constants), 0)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, num_features: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
res = [torch.Tensor([])] * self.num_features
|
||||
for i in range(self.num_features):
|
||||
res[i] = x * (i + 1)
|
||||
return res
|
||||
|
||||
inp = torch.ones(3)
|
||||
ep = export(M(), (inp,))
|
||||
self.assertEqual(len(ep.constants), 0)
|
||||
|
||||
unf = unflatten(ep)
|
||||
self.assertTrue(torch.allclose(M()(inp)[0], unf(inp)[0]))
|
||||
|
||||
def test_unbacked_bincount(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, xs):
|
||||
|
@ -142,6 +142,10 @@ def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]:
|
||||
if len(lift_fresh_node.users) > 1:
|
||||
return None
|
||||
|
||||
# Case 1: lift node is not used anywhere
|
||||
if len(lift_fresh_node.users) == 0:
|
||||
return [lift_fresh_node, node]
|
||||
|
||||
detach_node = next(iter(lift_fresh_node.users.keys()))
|
||||
if not (
|
||||
detach_node.op == "call_function"
|
||||
@ -156,6 +160,7 @@ def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]:
|
||||
if len(detach_node.users) > 0:
|
||||
return None
|
||||
else:
|
||||
# Case 2: Lift node's child is not used anywhere
|
||||
return [detach_node, lift_fresh_node, node]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user