[precompile] Detect source code changes for save/load. (#156432)

Go through all Dynamo-traced functions and compute a checksum of the source code for each of them. When loading a precompiled package back into memory, we always verify these checksums and refuse to load when
source code changes are detected.

Differential Revision: [D76987123](https://our.internmc.facebook.com/intern/diff/D76987123/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156432
Approved by: https://github.com/jansel, https://github.com/jamesjwu
This commit is contained in:
zhxchen17
2025-06-30 09:08:14 -07:00
committed by PyTorch MergeBot
parent d3efd73234
commit f096820d0f
3 changed files with 151 additions and 4 deletions

View File

@ -1,6 +1,9 @@
# Owner(s): ["module: dynamo"]
import importlib
import os
import sys
import tempfile
import unittest
import torch
@ -185,6 +188,76 @@ class TestPackage(torch._inductor.test_case.TestCase):
):
compiled_fn(*args2)
def test_file_change(self):
    """
    Check that a saved package refuses to load once an inlined helper changes.

    A helper module is imported from a temporary file, a function calling it
    is compiled and saved, and then the same module name is re-imported from
    a file with different source. Loading must fail with "Source code changes
    detected"; after re-importing the original source, loading must succeed.
    """
    ctx = DiskDynamoStore()
    helper_name = "torch.test_package_helper"

    # import_from_path registers the helper in sys.modules; drop it again so
    # no entry pointing into the deleted temp directory outlives this test.
    self.addCleanup(sys.modules.pop, helper_name, None)

    def import_from_path(module_name, file_path):
        # A bare "import importlib" does not guarantee that the "util"
        # submodule is loaded, so import it explicitly before using it.
        import importlib.util

        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module

    mock_module_add_original = """
def add(x, y):
    return x + y
"""

    mock_module_add_modified = """
def add(x, y):
    return x - y
"""
    with tempfile.TemporaryDirectory() as tmp_dir:
        mock_module_add_original_path = os.path.join(
            tmp_dir, "mock_module_add_original.py"
        )
        mock_module_add_modified_path = os.path.join(
            tmp_dir, "mock_module_add_modified.py"
        )
        with open(mock_module_add_original_path, "w") as f:
            f.write(mock_module_add_original)
        with open(mock_module_add_modified_path, "w") as f:
            f.write(mock_module_add_modified)

        module = import_from_path(helper_name, mock_module_add_original_path)

        def fn(x):
            return module.add(x, 1)

        args = (torch.randn(3, 2),)

        def guard_filter_fn(guards):
            # One boolean per guard: False for closure/function identity guards.
            return [
                guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                for guard in guards
            ]

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(
            backend="eager", package=package, guard_filter_fn=guard_filter_fn
        )(fn)
        compiled_fn(*args)
        for backend_id, backend in package.cached_backends.items():
            ctx.record_eager_backend(backend_id, backend)
        ctx.save_package(package, self.path())

        # Same module name, different source: loading must be rejected.
        module = import_from_path(helper_name, mock_module_add_modified_path)
        with self.assertRaisesRegex(RuntimeError, "Source code changes detected"):
            ctx.load_package(fn, self.path())

        # Original source restored: loading must succeed again.
        module = import_from_path(helper_name, mock_module_add_original_path)
        ctx.load_package(fn, self.path())
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -976,6 +976,7 @@ def _compile(
if package is not None:
assert check_fn.guards_state is not None
package.add_guarded_code(check_fn.guards_state, out_code)
package.add_inlined_source(output.tracing_context.traced_code)
compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
annotation_str = "Torch-Compiled Region: " + compile_id_str

View File

@ -14,6 +14,7 @@ import dataclasses
import functools
import hashlib
import importlib
import inspect
import logging
import os
import pickle
@ -96,6 +97,14 @@ _BackendId = NewType("_BackendId", str) # __compiled_fn
_FunctionId = NewType("_FunctionId", str) # __resume_at
@dataclasses.dataclass(frozen=True)
class InlinedSource:
    """
    Immutable, hashable record identifying a span of source lines in a module
    together with a checksum of that span's text.
    """

    # Name of the module, as accepted by importlib.import_module.
    module: str
    # 1-based number of the first line of the span.
    firstlineno: int
    # 1-based number of the line just past the end of the span.
    lastlineno: int
    # Hex digest of the span's source text.
    checksum: str
@dataclasses.dataclass
class _DynamoCodeCacheEntry:
"""
@ -124,6 +133,7 @@ class _DynamoCodeCacheEntry:
@dataclasses.dataclass
class _DynamoCacheEntry:
codes: list[_DynamoCodeCacheEntry]
inlined_sources: set[InlinedSource]
python_version: str = platform.python_version()
torch_version: str = torch.__version__
@ -142,6 +152,22 @@ class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]):
return pickle.loads(self.content)
def _hash_source(source: str) -> str:
sha256_hash = hashlib.sha256()
sha256_hash.update(source.encode())
return sha256_hash.hexdigest()
def _get_sourcelines(
m: types.ModuleType, firstlineno: int, lastlineno: int
) -> list[str]:
return inspect.getsourcelines(m)[0][firstlineno - 1 : lastlineno - 1]
def _hash_sourcelines(m: types.ModuleType, firstlineno: int, lastlineno: int) -> str:
    """
    Hash the text of the lines selected from module ``m`` by
    ``_get_sourcelines(m, firstlineno, lastlineno)``.
    """
    selected_lines = _get_sourcelines(m, firstlineno, lastlineno)
    joined_text = "".join(selected_lines)
    return _hash_source(joined_text)
class CompilePackage:
"""
CompilePackage is considered a low level component and should not be directly exposed to
@ -155,7 +181,12 @@ class CompilePackage:
updates with compiled functions and resume functions.
"""
def __init__(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
def __init__(
self,
fn: Any,
dynamo: Optional[_DynamoCacheEntry] = None,
ignore_inlined_sources: bool = False,
) -> None:
self._innermost_fn = None
self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {}
@ -164,14 +195,22 @@ class CompilePackage:
# For debugging/testing purpose only.
self._cached_backends: dict[_BackendId, Any] = {}
self._inlined_sources: set[InlinedSource] = set()
self._resume_codes: set[types.CodeType] = set()
self._initialize(fn, dynamo)
self._initialize(fn, dynamo, ignore_inlined_sources)
self.uninstall()
self.validate()
def _initialize(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
def _initialize(
self,
fn: Any,
dynamo: Optional[_DynamoCacheEntry] = None,
ignore_inlined_sources: bool = False,
) -> None:
from .eval_frame import innermost_fn
self._inlined_sources = set()
self._innermost_fn = innermost_fn(fn)
assert self._innermost_fn is not None
if dynamo is not None:
@ -184,6 +223,16 @@ class CompilePackage:
raise RuntimeError(
f"Compile package was created with a different PyTorch version: {dynamo.torch_version}"
)
if not ignore_inlined_sources:
for code in dynamo.inlined_sources:
m = importlib.import_module(code.module)
checksum = _hash_sourcelines(m, code.firstlineno, code.lastlineno)
if checksum != code.checksum:
raise RuntimeError(
f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})"
)
self._inlined_sources = dynamo.inlined_sources
main, *codes = dynamo.codes
self._codes = {self._innermost_fn.__code__: main}
@ -252,6 +301,27 @@ class CompilePackage:
)
self._current_entry.guarded_codes.append(guarded_code_entry)
def add_inlined_source(self, sources: list[types.CodeType]) -> None:
    """
    Record an ``InlinedSource`` entry (module name, line span, checksum) for
    each code object in ``sources``.

    Code objects previously registered via ``add_resume_function`` are
    skipped, as are code objects for which ``inspect.getmodule`` cannot find
    a module.
    """
    for code in sources:
        if code in self._resume_codes:
            continue
        module = inspect.getmodule(code)
        if module is None:
            continue
        # Resolve the source once; inspect.getsource() is just the joined
        # result of inspect.getsourcelines(), so a second lookup is redundant.
        source_lines, _ = inspect.getsourcelines(code)
        source = "".join(source_lines)
        firstlineno = code.co_firstlineno
        # Exclusive end line, matching the slicing done by _get_sourcelines.
        lastlineno = firstlineno + len(source_lines)
        # Sanity check: the module-level view of this span must yield the
        # same text, since that is what gets re-hashed at load time.
        assert source == "".join(
            _get_sourcelines(module, firstlineno, lastlineno)
        )
        self._inlined_sources.add(
            InlinedSource(
                module=module.__name__,
                firstlineno=firstlineno,
                lastlineno=lastlineno,
                checksum=_hash_source(source),
            )
        )
def add_resume_function(
self,
python_code: types.CodeType,
@ -261,6 +331,7 @@ class CompilePackage:
self._add_function(
python_code, python_module, _FunctionId(name) if name else None
)
self._resume_codes.add(python_code)
def add_import_source(self, alias: str, module_name: str) -> None:
assert self._current_entry is not None
@ -345,7 +416,9 @@ class CompilePackage:
def cache_entry(self) -> _DynamoCacheEntry:
    """
    Validate the package and return a ``_DynamoCacheEntry`` holding its code
    entries and the recorded inlined-source checksums.
    """
    self.validate()
    # The superseded ``return _DynamoCacheEntry(codes=...)`` is dropped: it
    # ran first, so the inlined sources never reached the cache entry.
    # Both containers are copied so the entry is a snapshot of the package.
    return _DynamoCacheEntry(
        codes=list(self._codes.values()),
        inlined_sources=set(self._inlined_sources),
    )
@CacheArtifactFactory.register