[pytorch] Replace "blacklist" in test/test_mobile_optimizer.py (#45512)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45512

This diff addresses https://github.com/pytorch/pytorch/issues/41443.
It is a clone of D23205313 which could not be imported from GitHub
for strange reasons.

Test Plan: Continuous integration.

Reviewed By: AshkanAliabadi

Differential Revision: D23967322

fbshipit-source-id: 744eb92de7cb5f0bc9540ed6a994f9e6dce8919a
This commit is contained in:
Meghan Lele
2020-09-30 10:41:05 -07:00
committed by Facebook GitHub Bot
parent a245dd4317
commit ce9df084d5

View File

@@ -100,8 +100,8 @@ class TestOptimizer(unittest.TestCase):
 torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
-optimization_blacklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
-optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blacklist_no_prepack)
+optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
+optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
 optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)
 FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
@@ -118,14 +118,14 @@ class TestOptimizer(unittest.TestCase):
 FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
 .run(str(get_forward(bn_scripted_module._c).graph))
-optimization_blacklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
-bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blacklist_no_prepack)
+optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
+bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
 self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
 bn_input = torch.rand(1, 1, 6, 6)
 torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)
-optimization_blacklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
-no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blacklist_no_fold_bn)
+optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
+no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
 FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
 .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
 bn_input = torch.rand(1, 1, 6, 6)