Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[const_fold] Fix call_module const folding (#68614)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68614

We need to copy modules over to the `split` graph during const folding. We were previously only doing so from the non-constant submod, but we need to do this for the constant one as well in case some `call_module` is const folded.

Test Plan: Added unit test

Reviewed By: wushirong, 842974287

Differential Revision: D32543289

fbshipit-source-id: 80d1d0ce2c18a665b00e1343d6c55d939390ab10
Committed by: Facebook GitHub Bot
Parent: 39747dc456
Commit: 68d8ab0cc6
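The case this change covers is a module whose constant subgraph itself contains a call_module node, e.g. a Linear applied to a Parameter. A minimal sketch of that scenario, mirroring the unit test added below (shapes and names are illustrative):

    import torch
    from torch.fx.experimental import const_fold

    class ConstFoldTestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # self.lin(self.lin_input) depends only on constants, so the
            # call_module node for self.lin ends up in the constant subgraph.
            self.lin_input = torch.nn.Parameter(torch.randn(4, 4))
            self.lin = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.lin(self.lin_input) + x

    mod = ConstFoldTestModule()
    # split_const_subgraphs must copy call_module targets from both the constant
    # and non-constant submodules onto the split graph for this to work.
    mod_folded = const_fold.split_const_subgraphs(mod)

    inp = torch.randn(4, 4)
    # Folded and unfolded modules should produce the same result.
    assert torch.equal(mod_folded(inp), mod(inp))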
@@ -601,7 +601,9 @@ class TestConstFold(TestCase):
                 self.relu = torch.nn.ReLU()
 
             def forward(self, x):
-                quant_weight = torch.quantize_per_tensor(self.weight, 0.5, 3, torch.quint8)
+                quant_weight = torch.quantize_per_tensor(
+                    self.weight, 0.5, 3, torch.quint8
+                )
                 dequant_weight = torch.dequantize(quant_weight)
                 output = torch.nn.functional.linear(x, dequant_weight, self.bias)
                 return self.relu(output)

@@ -630,3 +632,25 @@ class TestConstFold(TestCase):
         fold_result = gm_folded(in_x)
         base_result = mod(in_x)
         self.assertTrue(torch.equal(fold_result, base_result))
+
+    def test_fold_module(self):
+        r"""
+        Perform constant folding with a call_module node.
+        """
+
+        class ConstFoldTestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin_input = torch.nn.Parameter(torch.randn(4, 4))
+                self.lin = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                return self.lin(self.lin_input) + x
+
+        mod = ConstFoldTestModule()
+        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
+        self._verify_const_fold_mod(mod_folded)
+
+        # Now run both folded and non-folded to check results equal.
+        inp = torch.randn(4, 4)
+        self.assertTrue(torch.equal(mod_folded(inp), mod(inp)))

@@ -171,6 +171,9 @@ def split_const_subgraphs(
     for node in non_const_gm.graph.nodes:
         if node.op == "call_module":
             setattr(split, node.target, getattr(non_const_gm, node.target))
+    for node in const_gm.graph.nodes:
+        if node.op == "call_module":
+            setattr(split, node.target, getattr(const_gm, node.target))
 
     # split_module currently does not use get_attrs for attrs. Instead it passes
     # them in as args from the parent module, which used get_attrs. Here we set