mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Precompile] Various small bugfixes, add CachingPrecompile to torchbench (#158847)
This PR addresses a few small bugfixes needed to make NanoGPT inference work, and also adds a new `--caching-precompile` argument to torchbench. With `--caching-precompile`, after every benchmark we save precompile artifacts to DynamoCache, allowing us to test caching precompile on all existing benchmarks. The following bugfixes are in this PR to make all of this work: - Fix global variables being pruned with DUPLICATE_INPUT guards. DUPLICATE_INPUT guards have additional vars from the second input, which we track with additional_local_vars, but we never tracked additional global variables. This fixes the issue. (See torch/_dynamo/guards.py changes) - Return None from PRecompileContext.serialize() if no new dynamo compiles occurred. There's no reason to save artifacts (i.e. autotuning artifacts, etc) if no dynamo_compile occurred, so we return None early. We may later want to support editing existing dynamo artifacts as a TODO, but that's upcoming. - log `dynamo_start` on CompilePackage.load: This is only needed so that tlparse doesn't ignore TORCH_TRACE logs generated when caching precompile hits. If there are no actual compiles, we never log a "dynamo_start" entry, which makes internal tlparse ignore the TORCH_TRACE file. ## Test Plan After this PR, the following now works: ``` TORCH_LOGS=dynamo tlp python benchmarks/dynamo/torchbench.py --only nanogpt --performance --inference --backend inductor --caching-precompile --warm-start-latency ``` tlparse result (internal): Cold Start (6 seconds): https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpAWe0zD/dedicated_log_torch_trace_vk9nkp4m.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 Warm Start (~1 s): https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpAWe0zD/dedicated_log_torch_trace_5l4iwrpm.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 The 1 second of warm start here can be improved: the costs here are mostly in starting up workers and triton and initializing CUDA, a lot of which should not be included in the compile time cost in real world scenarios where these are already loaded before training begins. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158847 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
a3025e17b2
commit
f55c5d085e
@ -3269,6 +3269,12 @@ def parse_args(args=None):
|
||||
instead of deleting it and creating a new one.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--caching-precompile",
|
||||
action="store_true",
|
||||
help="Enables caching precompile, serializing artifacts to DynamoCache between runs",
|
||||
)
|
||||
|
||||
group_latency = parser.add_mutually_exclusive_group()
|
||||
group_latency.add_argument(
|
||||
"--cold-start-latency",
|
||||
@ -3419,6 +3425,29 @@ def parse_args(args=None):
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def process_caching_precompile():
|
||||
"""
|
||||
After every process_entry, save precompile artifacts to DynamoCache
|
||||
"""
|
||||
assert torch._dynamo.config.caching_precompile, (
|
||||
"Caching precompile should be enabled with --caching-precompile"
|
||||
)
|
||||
from torch._dynamo.precompile_context import PrecompileContext
|
||||
|
||||
# Serialize all callables, clear PrecompileContext
|
||||
# TODO: put this under torch.compiler API once ready
|
||||
serialized = PrecompileContext.serialize()
|
||||
PrecompileContext.clear()
|
||||
if serialized is not None:
|
||||
artifacts, info = serialized
|
||||
print(
|
||||
f"Saving {len(info.precompile_dynamo_artifacts)} Precompile Artifact(s)..."
|
||||
)
|
||||
results = PrecompileContext.deserialize(artifacts)
|
||||
assert results is not None
|
||||
PrecompileContext.populate_caches(results)
|
||||
|
||||
|
||||
def process_entry(rank, runner, original_dir, args):
|
||||
args.rank = rank
|
||||
with maybe_init_distributed(
|
||||
@ -3427,7 +3456,10 @@ def process_entry(rank, runner, original_dir, args):
|
||||
world_size=args.world_size,
|
||||
port=args.distributed_master_port,
|
||||
):
|
||||
return run(runner, args, original_dir)
|
||||
result = run(runner, args, original_dir)
|
||||
if args.caching_precompile:
|
||||
process_caching_precompile()
|
||||
return result
|
||||
|
||||
|
||||
def maybe_fresh_cache(args):
|
||||
@ -3463,6 +3495,10 @@ def main(runner, original_dir=None, args=None):
|
||||
)
|
||||
|
||||
with maybe_fresh_cache(args):
|
||||
if args.caching_precompile:
|
||||
os.environ["TORCH_CACHING_PRECOMPILE"] = "1"
|
||||
torch._dynamo.config.caching_precompile = True
|
||||
|
||||
args.init_distributed = args.only and args.multiprocess
|
||||
if args.init_distributed:
|
||||
# NB: Do NOT query device count before CUDA initialization; we're
|
||||
|
@ -4,7 +4,7 @@ import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.test_case
|
||||
import torch._functorch
|
||||
from torch._dynamo.precompile_context import PrecompileContext
|
||||
from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
|
||||
from torch._functorch import config as functorch_config
|
||||
from torch._functorch._aot_autograd.autograd_cache import (
|
||||
BundledAOTAutogradCacheArtifact,
|
||||
@ -14,8 +14,8 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton
|
||||
|
||||
|
||||
@functorch_config.patch({"enable_autograd_cache": True})
|
||||
@functorch_config.patch(
|
||||
{"bundled_autograd_cache": True}
|
||||
@torch._dynamo.config.patch(
|
||||
{"caching_precompile": True}
|
||||
) # Requires bundledaotautograd cache for now
|
||||
class PrecompileContextTests(InductorTestCase):
|
||||
def setUp(self):
|
||||
@ -41,8 +41,7 @@ class PrecompileContextTests(InductorTestCase):
|
||||
x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
|
||||
result = compiled_fn(x)
|
||||
result.sum().backward()
|
||||
# Check that PrecompileContext._new_cache_artifacts_by_key has length 1
|
||||
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
|
||||
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 2)
|
||||
|
||||
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
|
||||
result = PrecompileContext.serialize()
|
||||
@ -77,14 +76,11 @@ class PrecompileContextTests(InductorTestCase):
|
||||
x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
|
||||
result = compiled_fn(x)
|
||||
result.sum().backward()
|
||||
# Check that PrecompileContext._new_cache_artifacts_by_key has length 1
|
||||
# TODO: the key right now is the AOTAutogradCacheKey, but will be backend_id once
|
||||
# we have torch._dynamo.package implemented
|
||||
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
|
||||
key = next(iter(PrecompileContext._new_cache_artifacts_by_key.keys()))
|
||||
result = PrecompileContext.serialize_artifact_by_key(key)
|
||||
assert isinstance(result, BundledAOTAutogradCacheArtifact)
|
||||
self.assertEqual(result.key, key)
|
||||
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 2)
|
||||
for key in PrecompileContext._new_cache_artifacts_by_key.keys():
|
||||
result = PrecompileContext.serialize_artifact_by_key(key)
|
||||
assert isinstance(result, PrecompileCacheArtifact)
|
||||
self.assertEqual(result.key, key)
|
||||
|
||||
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
|
||||
result = PrecompileContext.serialize()
|
||||
|
@ -549,7 +549,7 @@ fake_tensor_disable_inference_mode = True
|
||||
|
||||
# Experimental feature for running automatic caching precompile.
|
||||
# Enables automatic DynamoCache save/load
|
||||
caching_precompile = False
|
||||
caching_precompile = os.environ.get("TORCH_CACHING_PRECOMPILE", "0") == "1"
|
||||
|
||||
# Enables the Compiled Autograd engine to trace autograd calls made under torch.compile().
|
||||
# Note: AOTAutograd will still trace and partition an AOT backward graph local to that
|
||||
|
@ -225,6 +225,31 @@ def fx_forward_from_src_skip_result(
|
||||
return result
|
||||
|
||||
|
||||
def log_dynamo_start(code: CodeType, skip: int = 0) -> None:
|
||||
convert_frame_intern = structured.intern_string(__file__)
|
||||
# Initialize the ChromiumEventLogger on start
|
||||
torch._logging.trace_structured(
|
||||
"dynamo_start",
|
||||
lambda: {
|
||||
"stack": list(
|
||||
itertools.takewhile(
|
||||
lambda f: f["filename"] != convert_frame_intern,
|
||||
structured.from_traceback(
|
||||
CapturedTraceback.extract(skip=4 + skip).summary()
|
||||
),
|
||||
)
|
||||
)
|
||||
+ [
|
||||
{
|
||||
"line": code.co_firstlineno,
|
||||
"name": code.co_name,
|
||||
"filename": structured.intern_string(code.co_filename),
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
"""
|
||||
Context manager to:
|
||||
@ -1135,28 +1160,7 @@ def _compile(
|
||||
# # 2 extra here
|
||||
# torch/_logging/_internal.py:1064 in trace_structured
|
||||
# torch/_dynamo/convert_frame.py:780 in <lambda>
|
||||
convert_frame_intern = structured.intern_string(__file__)
|
||||
# Initialize the ChromiumEventLogger on start
|
||||
torch._logging.trace_structured(
|
||||
"dynamo_start",
|
||||
lambda: {
|
||||
"stack": list(
|
||||
itertools.takewhile(
|
||||
lambda f: f["filename"] != convert_frame_intern,
|
||||
structured.from_traceback(
|
||||
CapturedTraceback.extract(skip=4 + skip).summary()
|
||||
),
|
||||
)
|
||||
)
|
||||
+ [
|
||||
{
|
||||
"line": code.co_firstlineno,
|
||||
"name": code.co_name,
|
||||
"filename": structured.intern_string(code.co_filename),
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
log_dynamo_start(code, skip)
|
||||
start_time_ns = time.time_ns()
|
||||
fail_type: Optional[str] = None
|
||||
fail_reason: Optional[str] = None
|
||||
@ -1588,9 +1592,10 @@ class CatchErrorsWrapper:
|
||||
|
||||
with compile_lock, _disable_current_modes():
|
||||
# skip=1: skip this frame
|
||||
return self._torchdynamo_orig_backend(
|
||||
result = self._torchdynamo_orig_backend(
|
||||
frame, cache_entry, self.hooks, frame_state, skip=1
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def catch_errors_wrapper(
|
||||
|
@ -679,8 +679,7 @@ class _TorchDynamoContext:
|
||||
|
||||
# If self._package is lazily initialized, we should check the dynamo cache now
|
||||
if config.caching_precompile:
|
||||
assert self._package is not None
|
||||
if not self._package.is_initialized():
|
||||
if self._package is not None and not self._package.is_initialized():
|
||||
result = DynamoCache.load(fn)
|
||||
if result is None:
|
||||
# Create a fresh CompilePackage
|
||||
|
@ -1970,6 +1970,8 @@ class GuardBuilder(GuardBuilderBase):
|
||||
if self.serialization_mode == "save":
|
||||
if name := get_local_source_name(source_b):
|
||||
self.check_fn_manager.additional_used_local_vars.add(name)
|
||||
if name := get_global_source_name(source_b):
|
||||
self.check_fn_manager.additional_used_global_vars.add(name)
|
||||
|
||||
ref_a = self.arg_ref(guard)
|
||||
ref_b = self.arg_ref(source_b.name())
|
||||
@ -2849,6 +2851,7 @@ class CheckFunctionManager:
|
||||
self.guards_serialization_mode = guards_serialization_mode
|
||||
self.used_builtin_vars: OrderedSet[str] = OrderedSet()
|
||||
self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
|
||||
self.additional_used_global_vars: OrderedSet[str] = OrderedSet()
|
||||
if runtime_global_scope:
|
||||
assert self.guards_serialization_mode == "load"
|
||||
self.runtime_global_scope = runtime_global_scope
|
||||
@ -3039,7 +3042,7 @@ class CheckFunctionManager:
|
||||
global_scope_state = {
|
||||
k: v
|
||||
for k, v in output_graph_guards_state.global_scope.items()
|
||||
if k in used_global_vars
|
||||
if k in used_global_vars or k in self.additional_used_global_vars
|
||||
}
|
||||
global_scope_state[builtins_dict_name] = {
|
||||
k: v
|
||||
|
@ -380,7 +380,7 @@ class CompilePackage:
|
||||
3. Install the precompiled cache entries to ExtraStates on the code object.
|
||||
"""
|
||||
from torch._C._dynamo.eval_frame import _load_precompile_entry
|
||||
from torch._dynamo.convert_frame import get_compile_id
|
||||
from torch._dynamo.convert_frame import get_compile_id, log_dynamo_start
|
||||
from torch._guards import compile_context, CompileContext
|
||||
|
||||
from .output_graph import get_builtins_dict
|
||||
@ -394,6 +394,7 @@ class CompilePackage:
|
||||
# collapsed into 0/0, 1/0 on warm.
|
||||
increment_frame()
|
||||
compile_id = get_compile_id(frame_state={})
|
||||
log_dynamo_start(code)
|
||||
with (
|
||||
compile_context(CompileContext(compile_id)),
|
||||
dynamo_timed(
|
||||
|
@ -141,6 +141,9 @@ class PrecompileContext(CacheArtifactManager):
|
||||
@classmethod
|
||||
def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
|
||||
cls._save_artifacts_by_type()
|
||||
# No need to serialize if there are no new dynamo compiles
|
||||
if "precompile_dynamo" not in cls._new_cache_artifacts:
|
||||
return None
|
||||
return super().serialize()
|
||||
|
||||
@staticmethod
|
||||
|
Reference in New Issue
Block a user