mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[precompile] Detect source code changes for save/load. (#156432)
Go through all dynamo traced functions and compute checksum for them. While loading a precompilation back to memory, we will always check the checksum and refuse to load when source code changes are detected. Differential Revision: [D76987123](https://our.internmc.facebook.com/intern/diff/D76987123/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156432 Approved by: https://github.com/jansel, https://github.com/jamesjwu
This commit is contained in:
committed by
PyTorch MergeBot
parent
d3efd73234
commit
f096820d0f
@ -1,6 +1,9 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
@ -185,6 +188,76 @@ class TestPackage(torch._inductor.test_case.TestCase):
|
||||
):
|
||||
compiled_fn(*args2)
|
||||
|
||||
def test_file_change(self):
|
||||
ctx = DiskDynamoStore()
|
||||
|
||||
def import_from_path(module_name, file_path):
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
mock_module_add_original = """
|
||||
def add(x, y):
|
||||
return x + y
|
||||
"""
|
||||
|
||||
mock_module_add_modified = """
|
||||
def add(x, y):
|
||||
return x - y
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
mock_module_add_original_path = os.path.join(
|
||||
tmp_dir, "mock_module_add_original.py"
|
||||
)
|
||||
mock_module_add_modified_path = os.path.join(
|
||||
tmp_dir, "mock_module_add_modified.py"
|
||||
)
|
||||
with open(mock_module_add_original_path, "w") as f:
|
||||
f.write(mock_module_add_original)
|
||||
with open(mock_module_add_modified_path, "w") as f:
|
||||
f.write(mock_module_add_modified)
|
||||
|
||||
module = import_from_path(
|
||||
"torch.test_package_helper",
|
||||
mock_module_add_original_path,
|
||||
)
|
||||
|
||||
def fn(x):
|
||||
return module.add(x, 1)
|
||||
|
||||
args = (torch.randn(3, 2),)
|
||||
|
||||
def guard_filter_fn(guards):
|
||||
return [
|
||||
guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
|
||||
for guard in guards
|
||||
]
|
||||
|
||||
# Saving
|
||||
package = CompilePackage(fn)
|
||||
compiled_fn = torch._dynamo.optimize(
|
||||
backend="eager", package=package, guard_filter_fn=guard_filter_fn
|
||||
)(fn)
|
||||
compiled_fn(*args)
|
||||
for backend_id, backend in package.cached_backends.items():
|
||||
ctx.record_eager_backend(backend_id, backend)
|
||||
ctx.save_package(package, self.path())
|
||||
|
||||
module = import_from_path(
|
||||
"torch.test_package_helper",
|
||||
mock_module_add_modified_path,
|
||||
)
|
||||
with self.assertRaisesRegex(RuntimeError, "Source code changes detected"):
|
||||
ctx.load_package(fn, self.path())
|
||||
|
||||
module = import_from_path(
|
||||
"torch.test_package_helper",
|
||||
mock_module_add_original_path,
|
||||
)
|
||||
ctx.load_package(fn, self.path())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -976,6 +976,7 @@ def _compile(
|
||||
if package is not None:
|
||||
assert check_fn.guards_state is not None
|
||||
package.add_guarded_code(check_fn.guards_state, out_code)
|
||||
package.add_inlined_source(output.tracing_context.traced_code)
|
||||
|
||||
compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
|
||||
annotation_str = "Torch-Compiled Region: " + compile_id_str
|
||||
|
@ -14,6 +14,7 @@ import dataclasses
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
@ -96,6 +97,14 @@ _BackendId = NewType("_BackendId", str) # __compiled_fn
|
||||
_FunctionId = NewType("_FunctionId", str) # __resume_at
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class InlinedSource:
|
||||
module: str
|
||||
firstlineno: int
|
||||
lastlineno: int
|
||||
checksum: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DynamoCodeCacheEntry:
|
||||
"""
|
||||
@ -124,6 +133,7 @@ class _DynamoCodeCacheEntry:
|
||||
@dataclasses.dataclass
|
||||
class _DynamoCacheEntry:
|
||||
codes: list[_DynamoCodeCacheEntry]
|
||||
inlined_sources: set[InlinedSource]
|
||||
python_version: str = platform.python_version()
|
||||
torch_version: str = torch.__version__
|
||||
|
||||
@ -142,6 +152,22 @@ class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]):
|
||||
return pickle.loads(self.content)
|
||||
|
||||
|
||||
def _hash_source(source: str) -> str:
|
||||
sha256_hash = hashlib.sha256()
|
||||
sha256_hash.update(source.encode())
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def _get_sourcelines(
|
||||
m: types.ModuleType, firstlineno: int, lastlineno: int
|
||||
) -> list[str]:
|
||||
return inspect.getsourcelines(m)[0][firstlineno - 1 : lastlineno - 1]
|
||||
|
||||
|
||||
def _hash_sourcelines(m: types.ModuleType, firstlineno: int, lastlineno: int) -> str:
|
||||
return _hash_source("".join(_get_sourcelines(m, firstlineno, lastlineno)))
|
||||
|
||||
|
||||
class CompilePackage:
|
||||
"""
|
||||
CompilePackage is considered a low level component and should not be directly exposed to
|
||||
@ -155,7 +181,12 @@ class CompilePackage:
|
||||
updates with compiled functions and resume functions.
|
||||
"""
|
||||
|
||||
def __init__(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
fn: Any,
|
||||
dynamo: Optional[_DynamoCacheEntry] = None,
|
||||
ignore_inlined_sources: bool = False,
|
||||
) -> None:
|
||||
self._innermost_fn = None
|
||||
self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {}
|
||||
|
||||
@ -164,14 +195,22 @@ class CompilePackage:
|
||||
|
||||
# For debugging/testing purpose only.
|
||||
self._cached_backends: dict[_BackendId, Any] = {}
|
||||
self._inlined_sources: set[InlinedSource] = set()
|
||||
self._resume_codes: set[types.CodeType] = set()
|
||||
|
||||
self._initialize(fn, dynamo)
|
||||
self._initialize(fn, dynamo, ignore_inlined_sources)
|
||||
self.uninstall()
|
||||
self.validate()
|
||||
|
||||
def _initialize(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
|
||||
def _initialize(
|
||||
self,
|
||||
fn: Any,
|
||||
dynamo: Optional[_DynamoCacheEntry] = None,
|
||||
ignore_inlined_sources: bool = False,
|
||||
) -> None:
|
||||
from .eval_frame import innermost_fn
|
||||
|
||||
self._inlined_sources = set()
|
||||
self._innermost_fn = innermost_fn(fn)
|
||||
assert self._innermost_fn is not None
|
||||
if dynamo is not None:
|
||||
@ -184,6 +223,16 @@ class CompilePackage:
|
||||
raise RuntimeError(
|
||||
f"Compile package was created with a different PyTorch version: {dynamo.torch_version}"
|
||||
)
|
||||
if not ignore_inlined_sources:
|
||||
for code in dynamo.inlined_sources:
|
||||
m = importlib.import_module(code.module)
|
||||
checksum = _hash_sourcelines(m, code.firstlineno, code.lastlineno)
|
||||
if checksum != code.checksum:
|
||||
raise RuntimeError(
|
||||
f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})"
|
||||
)
|
||||
|
||||
self._inlined_sources = dynamo.inlined_sources
|
||||
|
||||
main, *codes = dynamo.codes
|
||||
self._codes = {self._innermost_fn.__code__: main}
|
||||
@ -252,6 +301,27 @@ class CompilePackage:
|
||||
)
|
||||
self._current_entry.guarded_codes.append(guarded_code_entry)
|
||||
|
||||
def add_inlined_source(self, sources: list[types.CodeType]) -> None:
|
||||
for code in sources:
|
||||
if code in self._resume_codes:
|
||||
continue
|
||||
module = inspect.getmodule(code)
|
||||
if module is None:
|
||||
continue
|
||||
source = inspect.getsource(code)
|
||||
lastlineno = code.co_firstlineno + len(inspect.getsourcelines(code)[0])
|
||||
assert source == "".join(
|
||||
_get_sourcelines(module, code.co_firstlineno, lastlineno)
|
||||
)
|
||||
self._inlined_sources.add(
|
||||
InlinedSource(
|
||||
module=module.__name__,
|
||||
firstlineno=code.co_firstlineno,
|
||||
lastlineno=lastlineno,
|
||||
checksum=_hash_source(source),
|
||||
)
|
||||
)
|
||||
|
||||
def add_resume_function(
|
||||
self,
|
||||
python_code: types.CodeType,
|
||||
@ -261,6 +331,7 @@ class CompilePackage:
|
||||
self._add_function(
|
||||
python_code, python_module, _FunctionId(name) if name else None
|
||||
)
|
||||
self._resume_codes.add(python_code)
|
||||
|
||||
def add_import_source(self, alias: str, module_name: str) -> None:
|
||||
assert self._current_entry is not None
|
||||
@ -345,7 +416,9 @@ class CompilePackage:
|
||||
|
||||
def cache_entry(self) -> _DynamoCacheEntry:
|
||||
self.validate()
|
||||
return _DynamoCacheEntry(codes=list(self._codes.values()))
|
||||
return _DynamoCacheEntry(
|
||||
codes=list(self._codes.values()), inlined_sources=self._inlined_sources
|
||||
)
|
||||
|
||||
|
||||
@CacheArtifactFactory.register
|
||||
|
Reference in New Issue
Block a user