Opt model save and load (#126374)
## Save & load support for OptimizedModule

[Issue description](https://github.com/pytorch/pytorch/pull/101651)

English is not my native language; please excuse typing errors.

This PR is based on commit b9588101c4d3411b107fdc860acfa8a72c642f91. I will resolve the merge conflicts later.

### Test results for test/dynamo

Conclusion: as far as I can see, the tests behave the same as before.

Environment (CPU only):

platform linux -- Python 3.10.14, pytest-7.3.2, pluggy-1.5.0
configfile: pytest.ini
plugins: anyio-3.7.1, cpp-2.3.0, flakefinder-1.1.0, xdist-3.3.1, xdoctest-1.1.0, metadata-3.1.1, html-4.1.1, hypothesis-5.35.1, rerunfailures-14.0

#### Before this PR
[before](https://github.com/pytorch/pytorch/files/15329370/before.md)

#### After this PR
[after](https://github.com/pytorch/pytorch/files/15329376/after.md)

### Changes

1. Add `test_save_and_load` tests to test/dynamo/test_modules.py, both with and without `backend="inductor"`.
2. Add a `__reduce__` method to `OptimizedModule` and to the derived classes of `_TorchDynamoContext`, so that they can be pickled and unpickled (see the sketch below).
3. Turn the function wrappers into wrapper classes (`convert_frame_assert`, `convert_frame`, and `catch_errors_wrapper` in torch/_dynamo/convert_frame.py, and `wrap_backend_debug` in torch/_dynamo/repro/after_dynamo.py), since pickle cannot serialize closures but can serialize class instances.
4. Change `self.output.compiler_fn` to `innermost_fn(self.output.compiler_fn)` in torch/_dynamo/symbolic_convert.py, to retrieve the original `compiler_fn` and avoid the "compiler_fn is not eager" condition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126374
Approved by: https://github.com/msaroufim, https://github.com/jansel
Committed by PyTorch MergeBot (parent 9a8ab778d3, commit c3949b20a1)
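To make changes 2 and 3 concrete, here is a minimal, self-contained sketch of the pickling pattern involved; `make_closure_wrapper` and `CallableWrapper` are illustrative names, not the actual torch._dynamo classes. Pickle cannot serialize a closure-based wrapper, but an equivalent wrapper class can describe its own reconstruction through `__reduce__`:

```python
import pickle


def make_closure_wrapper(fn):
    # Closure-based wrapper: `inner` only exists inside this call frame,
    # so pickling it fails with "Can't pickle local object ...".
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner


class CallableWrapper:
    # Class-based wrapper: same call behavior, but picklable because
    # __reduce__ tells pickle how to rebuild the instance.
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __reduce__(self):
        # Rebuild as CallableWrapper(self.fn) when unpickling.
        return (self.__class__, (self.fn,))


try:
    pickle.dumps(make_closure_wrapper(len))
except (pickle.PicklingError, AttributeError):
    pass  # closures cannot be pickled

restored = pickle.loads(pickle.dumps(CallableWrapper(len)))
assert restored([1, 2, 3]) == 3  # class instance survives the round trip
```

This is the motivation for converting `convert_frame` and the other wrappers from decorator-style closures into classes: the instance carries the wrapped function as an attribute, and `__reduce__` recreates it on load.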
test/dynamo/test_modules.py

@@ -3,6 +3,8 @@
 import collections
 import copy
 import itertools
+import os
+import tempfile
 import traceback
 import types
 import unittest
@@ -16,6 +18,7 @@ import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch.nn.functional as F
+from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.eval_frame import unsupported
 from torch._dynamo.mutation_guard import GenerationTracker
 from torch._dynamo.testing import expectedFailureDynamic, same
@@ -2739,6 +2742,49 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
         self.assertEqual(test_functions._variable, 1)
         self.assertEqual(res, 3 * torch.ones(10))
 
+    @unittest.skipIf(
+        "inductor" not in torch._dynamo.list_backends(),
+        "inductor backend is not available",
+    )
+    def test_save_and_load_inductor(self):
+        mod = MockModule()
+        opt_mod = torch.compile(mod, backend="inductor")
+        inp = torch.randn(10, 10)
+        opt_mod(inp)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
+            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+        loaded_model(inp)
+        self.assertTrue(same_two_models(loaded_model, mod, [inp]))
+        self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))
+
+        torch._dynamo.reset()  # force recompiles
+        torch._inductor.metrics.generated_kernel_count = 0
+        loaded_model(inp)
+        self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0)
+
+    def test_save_and_load_all_backends(self):
+        mod = MockModule()
+        inp = torch.randn(10, 10)
+        for backend in torch._dynamo.list_backends():
+            try:
+                opt_mod = torch.compile(mod, backend=backend)
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
+                    loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+                torch._dynamo.reset()  # force recompiles
+                torch._inductor.metrics.generated_kernel_count = 0
+                opt_mod(inp)
+                opt_success = torch._inductor.metrics.generated_kernel_count == 0
+                torch._dynamo.reset()  # force recompiles
+                torch._inductor.metrics.generated_kernel_count = 0
+                loaded_model(inp)
+                loaded_success = torch._inductor.metrics.generated_kernel_count == 0
+                self.assertEqual(opt_success, loaded_success)
+            except torch._dynamo.exc.BackendCompilerFailed:
+                pass
+
     def test_monkeypatching_forward(self):
         class FakeModule(torch.nn.Module):
             def forward(self, x):
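For reference, the user-visible round trip exercised by the new tests above looks roughly like this (a sketch, not part of the diff; the compiled artifacts themselves are not saved, so the loaded module recompiles on first call, which is what the `generated_kernel_count` assertion checks):

```python
import torch

mod = torch.nn.Linear(10, 10)
opt_mod = torch.compile(mod, backend="inductor")  # returns an OptimizedModule
inp = torch.randn(10, 10)
opt_mod(inp)  # trigger compilation once

torch.save(opt_mod, "model.pt")
# Note: recent PyTorch releases default torch.load to weights_only=True,
# which rejects full-module pickles; pass weights_only=False there.
loaded = torch.load("model.pt")
loaded(inp)  # the loaded OptimizedModule recompiles and runs
```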