[Precompile] Various small bugfixes, add CachingPrecompile to torchbench (#158847)

This PR contains a few small bugfixes needed to make NanoGPT inference work, and adds a new `--caching-precompile` argument to torchbench. With `--caching-precompile`, we save precompile artifacts to DynamoCache after every benchmark, allowing us to exercise caching precompile on all existing benchmarks.
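
Concretely, the hook added to torchbench does roughly the following after each benchmark run (a sketch mirroring the `process_caching_precompile` helper in this PR; the wrapper name below is illustrative, and these `PrecompileContext` APIs are internal and may later move under a `torch.compiler` entry point):
```
# Sketch of the per-run save step performed when --caching-precompile is set.
from torch._dynamo.precompile_context import PrecompileContext

def save_precompile_artifacts() -> None:  # illustrative name
    serialized = PrecompileContext.serialize()
    PrecompileContext.clear()
    if serialized is None:
        # No new dynamo compiles happened this run, so there is nothing to save.
        return
    artifacts, info = serialized
    print(f"Saving {len(info.precompile_dynamo_artifacts)} precompile artifact(s)")
    # Round-trip the serialized bytes and repopulate the caches so that a
    # subsequent warm run can hit DynamoCache instead of recompiling.
    results = PrecompileContext.deserialize(artifacts)
    assert results is not None
    PrecompileContext.populate_caches(results)
```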

The following bugfixes are included in this PR to make all of this work:
- Fix global variables being pruned from the serialized guard scope when DUPLICATE_INPUT guards reference them. A DUPLICATE_INPUT guard pulls in extra variables from its second input; we already tracked the local ones via `additional_used_local_vars`, but never tracked the global ones. This PR adds `additional_used_global_vars` to fix that (see the torch/_dynamo/guards.py changes and the repro sketch after this list).
- Return None from `PrecompileContext.serialize()` if no new Dynamo compiles occurred. There is no reason to save artifacts (e.g. autotuning artifacts) when no Dynamo compile happened, so we return None early. Supporting edits to existing Dynamo artifacts is left as a follow-up TODO.
- Log `dynamo_start` in `CompilePackage.load`. This is only needed so that tlparse does not ignore TORCH_TRACE logs generated on caching precompile hits: if there are no actual compiles, we never log a `dynamo_start` entry, and internal tlparse then ignores the TORCH_TRACE file.
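
For reference, a hypothetical minimal repro of the DUPLICATE_INPUT case (not taken from this PR): when the same tensor reaches a compiled function both as a local argument and as a global, Dynamo dedupes the two graph inputs and installs a DUPLICATE_INPUT guard on the aliasing, so the serialized guard scope must keep the global's name around.
```
# Hypothetical repro sketch; assumes aliasing a local arg with a global tensor
# produces a DUPLICATE_INPUT guard, per the description above.
import torch

GLOBAL_X = torch.randn(4)

@torch.compile(backend="eager")
def f(a):
    # At call time `a is GLOBAL_X`, so the guard for this frame references both
    # the local name `a` and the global name `GLOBAL_X`; serializing it must not
    # prune `GLOBAL_X` from the saved global scope.
    return a + GLOBAL_X

f(GLOBAL_X)
```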

## Test Plan

After this PR, the following now works:
```
TORCH_LOGS=dynamo tlp python benchmarks/dynamo/torchbench.py --only nanogpt --performance --inference --backend inductor --caching-precompile --warm-start-latency
```
tlparse results (internal):
Cold Start (6 seconds):
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpAWe0zD/dedicated_log_torch_trace_vk9nkp4m.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Warm Start (~1 second):
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpAWe0zD/dedicated_log_torch_trace_5l4iwrpm.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

The ~1 second of warm start can be improved further: the cost is mostly spent starting up workers, initializing Triton, and initializing CUDA, much of which should not count toward compile time in real-world scenarios where these are already loaded before training begins.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158847
Approved by: https://github.com/zhxchen17
Author: James Wu
Date: 2025-07-23 19:50:31 -07:00
Committed by: PyTorch MergeBot
Commit: f55c5d085e (parent a3025e17b2)
8 changed files with 85 additions and 42 deletions


@ -3269,6 +3269,12 @@ def parse_args(args=None):
instead of deleting it and creating a new one.",
)
parser.add_argument(
"--caching-precompile",
action="store_true",
help="Enables caching precompile, serializing artifacts to DynamoCache between runs",
)
group_latency = parser.add_mutually_exclusive_group()
group_latency.add_argument(
"--cold-start-latency",
@ -3419,6 +3425,29 @@ def parse_args(args=None):
return parser.parse_args(args)
def process_caching_precompile():
"""
After every process_entry, save precompile artifacts to DynamoCache
"""
assert torch._dynamo.config.caching_precompile, (
"Caching precompile should be enabled with --caching-precompile"
)
from torch._dynamo.precompile_context import PrecompileContext
# Serialize all callables, clear PrecompileContext
# TODO: put this under torch.compiler API once ready
serialized = PrecompileContext.serialize()
PrecompileContext.clear()
if serialized is not None:
artifacts, info = serialized
print(
f"Saving {len(info.precompile_dynamo_artifacts)} Precompile Artifact(s)..."
)
results = PrecompileContext.deserialize(artifacts)
assert results is not None
PrecompileContext.populate_caches(results)
def process_entry(rank, runner, original_dir, args):
args.rank = rank
with maybe_init_distributed(
@ -3427,7 +3456,10 @@ def process_entry(rank, runner, original_dir, args):
world_size=args.world_size,
port=args.distributed_master_port,
):
return run(runner, args, original_dir)
result = run(runner, args, original_dir)
if args.caching_precompile:
process_caching_precompile()
return result
def maybe_fresh_cache(args):
@ -3463,6 +3495,10 @@ def main(runner, original_dir=None, args=None):
)
with maybe_fresh_cache(args):
if args.caching_precompile:
os.environ["TORCH_CACHING_PRECOMPILE"] = "1"
torch._dynamo.config.caching_precompile = True
args.init_distributed = args.only and args.multiprocess
if args.init_distributed:
# NB: Do NOT query device count before CUDA initialization; we're


@ -4,7 +4,7 @@ import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._functorch
from torch._dynamo.precompile_context import PrecompileContext
from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import (
BundledAOTAutogradCacheArtifact,
@ -14,8 +14,8 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch(
{"bundled_autograd_cache": True}
@torch._dynamo.config.patch(
{"caching_precompile": True}
) # Requires bundledaotautograd cache for now
class PrecompileContextTests(InductorTestCase):
def setUp(self):
@ -41,8 +41,7 @@ class PrecompileContextTests(InductorTestCase):
x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
result = compiled_fn(x)
result.sum().backward()
# Check that PrecompileContext._new_cache_artifacts_by_key has length 1
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 2)
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
result = PrecompileContext.serialize()
@ -77,14 +76,11 @@ class PrecompileContextTests(InductorTestCase):
x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
result = compiled_fn(x)
result.sum().backward()
# Check that PrecompileContext._new_cache_artifacts_by_key has length 1
# TODO: the key right now is the AOTAutogradCacheKey, but will be backend_id once
# we have torch._dynamo.package implemented
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
key = next(iter(PrecompileContext._new_cache_artifacts_by_key.keys()))
result = PrecompileContext.serialize_artifact_by_key(key)
assert isinstance(result, BundledAOTAutogradCacheArtifact)
self.assertEqual(result.key, key)
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 2)
for key in PrecompileContext._new_cache_artifacts_by_key.keys():
result = PrecompileContext.serialize_artifact_by_key(key)
assert isinstance(result, PrecompileCacheArtifact)
self.assertEqual(result.key, key)
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
result = PrecompileContext.serialize()


@ -549,7 +549,7 @@ fake_tensor_disable_inference_mode = True
# Experimental feature for running automatic caching precompile.
# Enables automatic DynamoCache save/load
caching_precompile = False
caching_precompile = os.environ.get("TORCH_CACHING_PRECOMPILE", "0") == "1"
# Enables the Compiled Autograd engine to trace autograd calls made under torch.compile().
# Note: AOTAutograd will still trace and partition an AOT backward graph local to that


@ -225,6 +225,31 @@ def fx_forward_from_src_skip_result(
return result
def log_dynamo_start(code: CodeType, skip: int = 0) -> None:
convert_frame_intern = structured.intern_string(__file__)
# Initialize the ChromiumEventLogger on start
torch._logging.trace_structured(
"dynamo_start",
lambda: {
"stack": list(
itertools.takewhile(
lambda f: f["filename"] != convert_frame_intern,
structured.from_traceback(
CapturedTraceback.extract(skip=4 + skip).summary()
),
)
)
+ [
{
"line": code.co_firstlineno,
"name": code.co_name,
"filename": structured.intern_string(code.co_filename),
}
]
},
)
def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
"""
Context manager to:
@ -1135,28 +1160,7 @@ def _compile(
# # 2 extra here
# torch/_logging/_internal.py:1064 in trace_structured
# torch/_dynamo/convert_frame.py:780 in <lambda>
convert_frame_intern = structured.intern_string(__file__)
# Initialize the ChromiumEventLogger on start
torch._logging.trace_structured(
"dynamo_start",
lambda: {
"stack": list(
itertools.takewhile(
lambda f: f["filename"] != convert_frame_intern,
structured.from_traceback(
CapturedTraceback.extract(skip=4 + skip).summary()
),
)
)
+ [
{
"line": code.co_firstlineno,
"name": code.co_name,
"filename": structured.intern_string(code.co_filename),
}
]
},
)
log_dynamo_start(code, skip)
start_time_ns = time.time_ns()
fail_type: Optional[str] = None
fail_reason: Optional[str] = None
@ -1588,9 +1592,10 @@ class CatchErrorsWrapper:
with compile_lock, _disable_current_modes():
# skip=1: skip this frame
return self._torchdynamo_orig_backend(
result = self._torchdynamo_orig_backend(
frame, cache_entry, self.hooks, frame_state, skip=1
)
return result
def catch_errors_wrapper(


@ -679,8 +679,7 @@ class _TorchDynamoContext:
# If self._package is lazily initialized, we should check the dynamo cache now
if config.caching_precompile:
assert self._package is not None
if not self._package.is_initialized():
if self._package is not None and not self._package.is_initialized():
result = DynamoCache.load(fn)
if result is None:
# Create a fresh CompilePackage


@ -1970,6 +1970,8 @@ class GuardBuilder(GuardBuilderBase):
if self.serialization_mode == "save":
if name := get_local_source_name(source_b):
self.check_fn_manager.additional_used_local_vars.add(name)
if name := get_global_source_name(source_b):
self.check_fn_manager.additional_used_global_vars.add(name)
ref_a = self.arg_ref(guard)
ref_b = self.arg_ref(source_b.name())
@ -2849,6 +2851,7 @@ class CheckFunctionManager:
self.guards_serialization_mode = guards_serialization_mode
self.used_builtin_vars: OrderedSet[str] = OrderedSet()
self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
self.additional_used_global_vars: OrderedSet[str] = OrderedSet()
if runtime_global_scope:
assert self.guards_serialization_mode == "load"
self.runtime_global_scope = runtime_global_scope
@ -3039,7 +3042,7 @@ class CheckFunctionManager:
global_scope_state = {
k: v
for k, v in output_graph_guards_state.global_scope.items()
if k in used_global_vars
if k in used_global_vars or k in self.additional_used_global_vars
}
global_scope_state[builtins_dict_name] = {
k: v


@ -380,7 +380,7 @@ class CompilePackage:
3. Install the precompiled cache entries to ExtraStates on the code object.
"""
from torch._C._dynamo.eval_frame import _load_precompile_entry
from torch._dynamo.convert_frame import get_compile_id
from torch._dynamo.convert_frame import get_compile_id, log_dynamo_start
from torch._guards import compile_context, CompileContext
from .output_graph import get_builtins_dict
@ -394,6 +394,7 @@ class CompilePackage:
# collapsed into 0/0, 1/0 on warm.
increment_frame()
compile_id = get_compile_id(frame_state={})
log_dynamo_start(code)
with (
compile_context(CompileContext(compile_id)),
dynamo_timed(


@ -141,6 +141,9 @@ class PrecompileContext(CacheArtifactManager):
@classmethod
def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
cls._save_artifacts_by_type()
# No need to serialize if there are no new dynamo compiles
if "precompile_dynamo" not in cls._new_cache_artifacts:
return None
return super().serialize()
@staticmethod