[test_fx_const_fold] Remove dependencies on acc_* (#72810)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72810

Test Plan: CI

Reviewed By: hl475

Differential Revision: D34220004

fbshipit-source-id: c58e287cb140411dcb5a6795c179004612e4016c
(cherry picked from commit 0f7c99f00498f224c60b7d5ecd2c3d902d5d6785)
Committed by: PyTorch MergeBot
Parent: 8e8c15cf6e
Commit: 454e2ec7bc
@@ -5,7 +5,6 @@ import operator
 import torch
 import torch.fx
 from torch.fx.experimental import const_fold
-from torch.fx.experimental.fx_acc import acc_tracer, acc_ops
 from torch.testing._internal.common_utils import TestCase


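The removed line is the only fx_acc import in this file: acc_tracer and acc_ops came from the accelerator-oriented torch.fx.experimental.fx_acc package, and after this commit the test relies on core FX tracing alone. As a rough sketch of the surviving tracing path (the module name here is hypothetical, not from the test file):

import torch
import torch.fx

class TinyModule(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return x + 1

# Core FX symbolic tracing; no acc_tracer needed, and node targets are
# plain torch/operator callables rather than acc_ops.* functions.
gm = torch.fx.symbolic_trace(TinyModule())
print(gm.graph)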
@@ -610,14 +609,14 @@ class TestConstFold(TestCase):

         mod = ConstFoldTestModule()
         in_x = torch.randn(2, 4)
-        gm = acc_tracer.trace(mod, in_x)
+        gm = torch.fx.symbolic_trace(mod)

         def skip_folding_quant_dequant(node: torch.fx.Node):
-            if node.target != acc_ops.quantize_per_tensor:
+            if node.target != torch.quantize_per_tensor:
                 return False
             # If quantize_per_node -> dequantize, then skip folding.
             for user in node.users:
-                if user.target == acc_ops.dequantize:
+                if user.target == torch.dequantize:
                     return True
             return False

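The hunk shows the skip predicate but not its call site. In this test the predicate is presumably passed to const_fold.split_const_subgraphs via its skip_folding_node_fn argument; the sketch below assumes that keyword and uses a hypothetical QuantDequantModule standing in for the real ConstFoldTestModule:

import torch
import torch.fx
from torch.fx.experimental import const_fold

class QuantDequantModule(torch.nn.Module):
    # Hypothetical stand-in: quantizes a constant-derived tensor and
    # immediately dequantizes it, the pattern the predicate skips.
    def __init__(self):
        super().__init__()
        self.attr = torch.nn.Parameter(torch.randn(2, 4))

    def forward(self, x):
        quant = torch.quantize_per_tensor(self.attr + self.attr, 0.1, 0, torch.quint8)
        dequant = torch.dequantize(quant)
        return x + dequant

def skip_folding_quant_dequant(node: torch.fx.Node):
    if node.target != torch.quantize_per_tensor:
        return False
    for user in node.users:
        if user.target == torch.dequantize:
            return True
    return False

mod = QuantDequantModule()
gm = torch.fx.symbolic_trace(mod)
gm_folded = const_fold.split_const_subgraphs(
    gm, skip_folding_node_fn=skip_folding_quant_dequant
)
gm_folded.run_folding()

in_x = torch.randn(2, 4)
torch.testing.assert_close(gm_folded(in_x), mod(in_x))

The predicate returns True only for quantize_per_tensor nodes that feed a dequantize, so a quant->dequant pair stays in the runtime graph while other constant expressions still fold away.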