Opt model save and load (#126374)

## Save & load support for OptimizedModule

[Issue Description](https://github.com/pytorch/pytorch/pull/101651)

English is not my native language; please excuse typing errors.

This PR is based on commit b9588101c4d3411b107fdc860acfa8a72c642f91.\
I'll resolve the merge conflicts later.
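
For context, here is a minimal sketch of the round trip this PR is meant to support (MyModule is a placeholder module for illustration, not part of the PR):

```python
import os
import tempfile

import torch


class MyModule(torch.nn.Module):  # placeholder module for illustration
    def forward(self, x):
        return torch.relu(x)


mod = MyModule()
opt_mod = torch.compile(mod)  # returns an OptimizedModule
opt_mod(torch.randn(10, 10))

# Before this PR, pickling an OptimizedModule failed; with it, the compiled
# module can be saved and loaded like a plain nn.Module.
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "model.pt")
    torch.save(opt_mod, path)
    loaded = torch.load(path)
loaded(torch.randn(10, 10))
```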

### Test results for test/dynamo

Conclusion:\
As far as I can tell, the dynamo test suite behaves the same as before this change.

ENV (CPU only):\
platform linux -- Python 3.10.14, pytest-7.3.2, pluggy-1.5.0\
configfile: pytest.ini\
plugins: anyio-3.7.1, cpp-2.3.0, flakefinder-1.1.0, xdist-3.3.1, xdoctest-1.1.0, metadata-3.1.1, html-4.1.1, hypothesis-5.35.1, rerunfailures-14.0

#### Before this PR:

[before](https://github.com/pytorch/pytorch/files/15329370/before.md)

#### After this PR:

[after](https://github.com/pytorch/pytorch/files/15329376/after.md)

### Changes

1. Add test_save_and_load tests to test/dynamo/test_modules.py, with and without backend="inductor" (shown in the diff below).
2. Add a \_\_reduce\_\_ method to OptimizedModule and to the derived classes of _TorchDynamoContext so that they can be pickled and unpickled (see the sketch after this list).
3. Turn the function wrappers into wrapper classes, since closures cannot be pickled (convert_frame_assert, convert_frame, and catch_errors_wrapper in torch/_dynamo/convert_frame.py, plus wrap_backend_debug in torch/_dynamo/repro/after_dynamo.py); see the sketch after this list.
4. Replace self.output.compiler_fn with innermost_fn(self.output.compiler_fn) in torch/_dynamo/symbolic_convert.py to obtain the original compiler_fn and avoid the "compiler_fn is not eager" condition.
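
To illustrate points 2 and 3, here is a minimal sketch of the pickling mechanism, with hypothetical names (CompileWrapper and its arguments are illustrations, not the actual PR code): a closure cannot be pickled, but a wrapper class that remembers its constructor arguments can rebuild itself through \_\_reduce\_\_.

```python
import pickle


class CompileWrapper:
    """Hypothetical stand-in for the closure-based wrappers turned into classes."""

    def __init__(self, compiler_fn, one_graph=False):
        self.compiler_fn = compiler_fn
        self.one_graph = one_graph

    def __call__(self, frame):
        # The real wrapper would compile `frame` with self.compiler_fn here.
        raise NotImplementedError

    def __reduce__(self):
        # Tell pickle how to rebuild this object on load: call the class
        # again with the same constructor arguments.
        return (self.__class__, (self.compiler_fn, self.one_graph))


# A closure-based wrapper would raise at pickle time; the class round-trips
# as long as its constructor arguments are themselves picklable.
wrapped = CompileWrapper(compiler_fn=None, one_graph=True)
restored = pickle.loads(pickle.dumps(wrapped))
assert isinstance(restored, CompileWrapper) and restored.one_graph
```

The one_graph flag is only an example of per-wrapper state; the real classes carry whatever their original closures captured.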

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126374
Approved by: https://github.com/msaroufim, https://github.com/jansel
Author: weiyusheng\
Date: 2024-06-05 13:01:16 +00:00\
Committed by: PyTorch MergeBot\
Commit: c3949b20a1 (parent: 9a8ab778d3)\
5 changed files with 203 additions and 75 deletions


```diff
--- a/test/dynamo/test_modules.py
+++ b/test/dynamo/test_modules.py
@@ -3,6 +3,8 @@
 import collections
 import copy
 import itertools
+import os
+import tempfile
 import traceback
 import types
 import unittest
@@ -16,6 +18,7 @@ import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch.nn.functional as F
+from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.eval_frame import unsupported
 from torch._dynamo.mutation_guard import GenerationTracker
 from torch._dynamo.testing import expectedFailureDynamic, same
@@ -2739,6 +2742,49 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
         self.assertEqual(test_functions._variable, 1)
         self.assertEqual(res, 3 * torch.ones(10))
 
+    @unittest.skipIf(
+        "inductor" not in torch._dynamo.list_backends(),
+        "inductor backend is not available",
+    )
+    def test_save_and_load_inductor(self):
+        mod = MockModule()
+        opt_mod = torch.compile(mod, backend="inductor")
+        inp = torch.randn(10, 10)
+        opt_mod(inp)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
+            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+        loaded_model(inp)
+        self.assertTrue(same_two_models(loaded_model, mod, [inp]))
+        self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))
+
+        torch._dynamo.reset()  # force recompiles
+        torch._inductor.metrics.generated_kernel_count = 0
+        loaded_model(inp)
+        self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0)
+
+    def test_save_and_load_all_backends(self):
+        mod = MockModule()
+        inp = torch.randn(10, 10)
+        for backend in torch._dynamo.list_backends():
+            try:
+                opt_mod = torch.compile(mod, backend=backend)
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
+                    loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+                torch._dynamo.reset()  # force recompiles
+                torch._inductor.metrics.generated_kernel_count = 0
+                opt_mod(inp)
+                opt_success = torch._inductor.metrics.generated_kernel_count == 0
+                torch._dynamo.reset()  # force recompiles
+                torch._inductor.metrics.generated_kernel_count = 0
+                loaded_model(inp)
+                loaded_success = torch._inductor.metrics.generated_kernel_count == 0
+                self.assertEqual(opt_success, loaded_success)
+            except torch._dynamo.exc.BackendCompilerFailed:
+                pass
+
     def test_monkeypatching_forward(self):
         class FakeModule(torch.nn.Module):
             def forward(self, x):
```