mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Pytorch Edge] Remove methods_to_optimize arg (#57045)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57045 Went back and adjusted the previous optimizations to just be applied to every function. Cleaned up api to match. ghstack-source-id: 127214412 ghstack-source-id: 127536155 Test Plan: unit test Reviewed By: kimishpatel Differential Revision: D27950859 fbshipit-source-id: 214e83d5a19b452747fe223615815c10fa4aee58
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7b160e29a4
commit
60a5ebfac2
@ -105,7 +105,7 @@ class TestOptimizer(TestCase):
|
||||
initial_result = scripted_model(input_data)
|
||||
initial_foo_result = scripted_model.foo(input_data)
|
||||
|
||||
optimized_scripted_model = optimize_for_mobile(scripted_model, methods_to_optimize=['foo'])
|
||||
optimized_scripted_model = optimize_for_mobile(scripted_model, preserved_methods=['foo'])
|
||||
optimized_result = optimized_scripted_model(input_data)
|
||||
optimized_foo_result = optimized_scripted_model.foo(input_data)
|
||||
|
||||
@ -225,7 +225,7 @@ class TestOptimizer(TestCase):
|
||||
m.eval()
|
||||
initial_result = m.foo(input_data)
|
||||
|
||||
optimized_scripted_model = optimize_for_mobile(m, methods_to_optimize=['foo'])
|
||||
optimized_scripted_model = optimize_for_mobile(m, preserved_methods=['foo'])
|
||||
optimized_result = optimized_scripted_model.foo(input_data)
|
||||
|
||||
FileCheck().check_not("dropout.__") \
|
||||
@ -254,7 +254,7 @@ class TestOptimizer(TestCase):
|
||||
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
|
||||
.run(bn_no_forward_scripted_module.foo.graph)
|
||||
|
||||
bn_fold_no_foward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, methods_to_optimize=['foo'])
|
||||
bn_fold_no_foward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo'])
|
||||
self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_foward_scripted_module)), 1)
|
||||
bn_input = torch.rand(1, 1, 6, 6)
|
||||
torch.testing.assert_allclose(
|
||||
|
||||
Reference in New Issue
Block a user