mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytorch] Replace "blacklist" in test/test_mobile_optimizer.py (#45512)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45512 This diff addresses https://github.com/pytorch/pytorch/issues/41443. It is a clone of D23205313 which could not be imported from GitHub for strange reasons. Test Plan: Continuous integration. Reviewed By: AshkanAliabadi Differential Revision: D23967322 fbshipit-source-id: 744eb92de7cb5f0bc9540ed6a994f9e6dce8919a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a245dd4317
commit
ce9df084d5
@ -100,8 +100,8 @@ class TestOptimizer(unittest.TestCase):
|
||||
torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
|
||||
|
||||
|
||||
optimization_blacklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
|
||||
optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blacklist_no_prepack)
|
||||
optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
|
||||
optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
|
||||
optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)
|
||||
|
||||
FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
|
||||
@ -118,14 +118,14 @@ class TestOptimizer(unittest.TestCase):
|
||||
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
|
||||
.run(str(get_forward(bn_scripted_module._c).graph))
|
||||
|
||||
optimization_blacklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
|
||||
bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blacklist_no_prepack)
|
||||
optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
|
||||
bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
|
||||
self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
|
||||
bn_input = torch.rand(1, 1, 6, 6)
|
||||
torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)
|
||||
|
||||
optimization_blacklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
|
||||
no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blacklist_no_fold_bn)
|
||||
optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
|
||||
no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
|
||||
FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
|
||||
.run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
|
||||
bn_input = torch.rand(1, 1, 6, 6)
|
||||
|
Reference in New Issue
Block a user