[test_fx_const_fold] Remove dependencies on acc_* (#72810)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72810

Test Plan: CI

Reviewed By: hl475

Differential Revision: D34220004

fbshipit-source-id: c58e287cb140411dcb5a6795c179004612e4016c
(cherry picked from commit 0f7c99f00498f224c60b7d5ecd2c3d902d5d6785)
This commit is contained in:
Jordan Fix
2022-02-14 12:33:10 -08:00
committed by PyTorch MergeBot
parent 8e8c15cf6e
commit 454e2ec7bc

View File

@ -5,7 +5,6 @@ import operator
import torch
import torch.fx
from torch.fx.experimental import const_fold
from torch.fx.experimental.fx_acc import acc_tracer, acc_ops
from torch.testing._internal.common_utils import TestCase
@ -610,14 +609,14 @@ class TestConstFold(TestCase):
mod = ConstFoldTestModule()
in_x = torch.randn(2, 4)
gm = acc_tracer.trace(mod, in_x)
gm = torch.fx.symbolic_trace(mod)
def skip_folding_quant_dequant(node: torch.fx.Node):
if node.target != acc_ops.quantize_per_tensor:
if node.target != torch.quantize_per_tensor:
return False
# If quantize_per_node -> dequantize, then skip folding.
for user in node.users:
if user.target == acc_ops.dequantize:
if user.target == torch.dequantize:
return True
return False