Allow OpOverloadPackets as safe torch functions, sanitize dynamo gm before running aotdispatch with cache (#139785)
Summary:
This diff implements two things to improve cache hit rates, found by testing AOTAutogradCache with internal cogwheel jobs:
- Allow torch functions that are OpOverloadPackets (not just OpOverloads) to count as safe.
- When running with the cache, dynamo puts some fields into the input graph module to aotdispatch that are not stable between runs. We use a context manager to null these out so that they can't affect the output of AOTAutograd, and then we put the fields back onto the gm before returning from AOTAutogradCache.load().

Test Plan:
New unit tests + running nanogpt with AOTAutogradCache.

Meta: Run on a long running job.
Cache miss: {F1953831996}
Cache hit: {F1953830872}
Servicelab experiment: https://www.internalfb.com/servicelab/experiment/4301352991/
Cache hit: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/f660597709-TrainingApplication/attempt_0/version_0/rank_0/index.html
Cache miss: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/f660569960-TrainingApplication/attempt_0/version_0/rank_0/index.html
With these changes, the autograd cache hits and saves compile time: https://fburl.com/scuba/pt2_compile_events/ycddxstd

Differential Revision: D65436373
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139785
Approved by: https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: e05a096c49
Commit: dd6a5de00d
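Before the diff: the first change is about which call targets the cache treats as safe torch functions. A minimal sketch of the OpOverload / OpOverloadPacket distinction the new isinstance check covers (aten.add is just an illustrative operator):

import torch

packet = torch.ops.aten.add            # an OpOverloadPacket: groups every overload of aten::add
overload = torch.ops.aten.add.Tensor   # an OpOverload: one specific overload

assert isinstance(packet, torch._ops.OpOverloadPacket)
assert isinstance(overload, torch._ops.OpOverload)

# Before this PR, only OpOverload targets passed the is_torch_function check in
# check_node_safe, so graphs whose call_function targets were packets were not
# treated as cache-safe, hurting AOTAutogradCache hit rates.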
@@ -14,6 +14,7 @@ from torch._functorch._aot_autograd.autograd_cache import (
     AOTAutogradCache,
     autograd_cache_key,
     BypassAOTAutogradCache,
+    sanitize_gm_for_cache,
 )
 from torch._functorch._aot_autograd.schemas import AOTConfig
 from torch._inductor import config as inductor_config
@@ -776,6 +777,31 @@ class AOTAutogradCachePicklerTests(torch._dynamo.test_case.TestCase):
         config = self.default_config()
         self.gen_cache_key(fn, config)

+    def test_sanitize_gm_for_cache(self):
+        def fn(x):
+            y = torch.sin(x)
+            z = torch.cos(x)
+            w = y + z
+            w.abs()
+            return w
+
+        _, fx_g, example_inputs = self._get_dynamo_output(fn, torch.ones(3))
+        fx_g.meta = {"foo": "bar"}
+        fx_g.compile_subgraph_reason = "Blah"
+        config = self.default_config()
+        with sanitize_gm_for_cache(fx_g):
+            c1 = autograd_cache_key(fx_g, example_inputs, config, {})
+        c3 = autograd_cache_key(fx_g, example_inputs, config, {})
+
+        fx_g.meta = {"foo": "baz"}
+        fx_g.compile_subgraph_reason = None
+        with sanitize_gm_for_cache(fx_g):
+            c2 = autograd_cache_key(fx_g, example_inputs, config, {})
+        c4 = autograd_cache_key(fx_g, example_inputs, config, {})
+
+        self.assertEqual(c1, c2)
+        self.assertNotEqual(c3, c4)
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -5,6 +5,7 @@ Utils for caching the outputs of AOTAutograd
 from __future__ import annotations

 import base64
+import contextlib
 import functools
 import json
 import logging
@@ -125,7 +126,7 @@ def check_node_safe(node: Node):
         )

     def is_torch_function(target):
-        if isinstance(target, torch._ops.OpOverload):
+        if isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
             return True
         if is_public_torch_api(target):
             return True
@@ -324,7 +325,7 @@ class FXGraphCacheLoadable:
         torch._logging.trace_structured(
             "artifact",
             metadata_fn=lambda: {
-                "name": "fx_graph_cache_hash",
+                "name": "fx_graph_cache_hit",  # always a hit
                 "encoding": "json",
             },
             payload_fn=lambda: json.dumps(cache_info),
@@ -508,6 +509,36 @@ class AOTAutogradCacheEntry:
         return compiled_function


+@contextlib.contextmanager
+def sanitize_gm_for_cache(gm: torch.fx.GraphModule):
+    """
+    Clears a few fields in a dynamo supplied Graph Module that are not stable between graph inputs, but don't
+    affect inductor or aotdispatch correctness.
+
+    These fields **can** be used by code calling into aotdispatch (namely, dynamo), so we can't null them out completely.
+
+    To ensure that these fields are not accessed by inductor or aotdispatch, we clear them during AOTAutogradCache.load,
+    and then put them back before returning. This way, we generate a cache key based off of a canonical graph
+    without these fields, and also guarantee they aren't used to affect the cache's output.
+    """
+    IGNORED_FIELDS = (
+        "meta",  # metadata used by export
+        "compile_subgraph_reason",  # Used by dynamo only for logging, no change in inductor/autograd behavior
+        "_param_name_to_source",  # Encapsulated by aot_config.aot_autograd_arg_pos_to_source
+    )
+    saved_fields = {}
+    for field in IGNORED_FIELDS:
+        saved_fields[field] = getattr(gm, field, None)
+        # Clear the field
+        setattr(gm, field, None)
+    try:
+        yield
+    finally:
+        # Put the fields back after dispatch_and_compile is complete
+        for field, value in saved_fields.items():
+            setattr(gm, field, value)
+
+
 class AOTAutogradCache:
     """
     Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas
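As a usage sketch of the context manager added above (the field values here are made up for illustration): inside the with block the ignored fields are nulled out, so anything hashed there sees a canonical graph, and the original values are restored on exit for dynamo's benefit.

import torch
from torch._functorch._aot_autograd.autograd_cache import sanitize_gm_for_cache

def f(x):
    return x + 1

gm = torch.fx.symbolic_trace(f)
gm.meta = {"dynamo_field": "varies between runs"}   # illustrative value
gm.compile_subgraph_reason = "example reason"       # illustrative value

with sanitize_gm_for_cache(gm):
    assert gm.meta is None                          # cleared while the cache key is computed
    assert gm.compile_subgraph_reason is None

assert gm.meta == {"dynamo_field": "varies between runs"}   # restored afterwards
assert gm.compile_subgraph_reason == "example reason"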
@@ -566,107 +597,113 @@ class AOTAutogradCache:
         Load a result from the cache, and reconstruct a runtime wrapper around the object
         """
         gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
-        compiled_fn = None
-        cache_info: Dict[str, Any] = {}
-        cache_key = None
-        debug_lines: List[str] = []
-        cache_event_time = time.time_ns()
-        cache_state = None
-        fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs}
-        try:
-            cache_key, debug_lines = autograd_cache_key(gm, args, aot_config, fx_config)
-            entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(
-                cache_key, local, remote
-            )
-            if entry is not None:
-                compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
-                log.info("AOTAutograd cache hit for key %s", cache_key)
-                counters["aot_autograd"]["autograd_cache_hit"] += 1
-                cache_state = "hit"
-                cache_event_time = time.time_ns()
-                forward_time_saved = entry.forward_time_taken_ns // 1e6
-                backward_time_saved = entry.backward_time_taken_ns // 1e6
-                cache_info.update(
-                    {
-                        "forward_time_saved_ms": forward_time_saved,
-                        "backward_time_saved_ms": backward_time_saved,
-                        "time_saved_ms": forward_time_saved + backward_time_saved,
-                    }
-                )
-                time_saved_ns = (
-                    entry.forward_time_taken_ns + entry.backward_time_taken_ns
-                )
-                # TODO: should we use the same field for remote cache time saved for both
-                # FXGraphCache and AOTAutogradCache?
-                # add_remote_cache_time_saved(time_saved_ns, is_backward=False)
-                if (
-                    ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
-                        time_saved_ns
-                    )
-                ) != 0:
-                    cache_info["ephemeral_timeout_increase"] = ephemeral_increase
-
-            if compiled_fn is None:
-                log.info("AOTAutograd cache miss for key %s", cache_key)
-                counters["aot_autograd"]["autograd_cache_miss"] += 1
-                cache_state = "miss"
-                cache_event_time = time.time_ns()
-        # Count missing the FXGraphCache as a miss not a bypass
-        except FXGraphCacheMiss as e:
-            counters["aot_autograd"]["autograd_cache_miss"] += 1
-            # Special counter when we pass autograd cache but
-            # fail when on inductor guards
-            counters["aot_autograd"]["autograd_cache_guard_miss"] += 1
-            if config.strict_autograd_cache:
-                raise e
-        except BypassAOTAutogradCache as e:
-            cache_key = None
-            counters["aot_autograd"]["autograd_cache_bypass"] += 1
-            cache_state = "bypass"
-            cache_event_time = time.time_ns()
-            cache_info["cache_bypass_reason"] = str(e)
-            if remote:
-                log_cache_bypass("bypass_aot_autograd", str(e))
-            if config.strict_autograd_cache:
-                raise e
-        if compiled_fn is None:
-            # Set the cache key so we can save a cache result later
-            if cache_key is not None:
-                aot_config.cache_info = AOTAutogradCacheInfo(cache_key, time.time_ns())
-            compiled_fn = dispatch_and_compile()
-
-        cache_info.update(
-            {
-                "key": cache_key,
-                "cache_state": cache_state,
-                "components": debug_lines,
-            }
-        )
-        chromium_log = get_chromium_event_logger()
-        chromium_log.log_instant_event(
-            f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info
-        )
-
-        chromium_log.add_event_data(
-            "backend_compile",
-            cache_state=cache_state,
-            cache_event_time=cache_event_time,
-            key=cache_info.get("key"),
-            components=cache_info.get("components"),
-            cache_bypass_reason=cache_info.get("cache_bypass_reason"),
-            remote_cache_enabled=remote,
-            local_cache_enabled=local,
-        )
-
-        torch._logging.trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "aotautograd_cache_hash",
-                "encoding": "json",
-            },
-            payload_fn=lambda: json.dumps(cache_info),
-        )
-        return compiled_fn
+        with sanitize_gm_for_cache(gm):
+            compiled_fn = None
+            cache_info: Dict[str, Any] = {}
+            cache_key = None
+            debug_lines: List[str] = []
+            cache_event_time = time.time_ns()
+            cache_state = None
+            fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs}
+            try:
+                cache_key, debug_lines = autograd_cache_key(
+                    gm, args, aot_config, fx_config
+                )
+                entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(
+                    cache_key, local, remote
+                )
+                if entry is not None:
+                    compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
+                    log.info("AOTAutograd cache hit for key %s", cache_key)
+                    counters["aot_autograd"]["autograd_cache_hit"] += 1
+                    cache_state = "hit"
+                    cache_event_time = time.time_ns()
+                    forward_time_saved = entry.forward_time_taken_ns // 1e6
+                    backward_time_saved = entry.backward_time_taken_ns // 1e6
+                    cache_info.update(
+                        {
+                            "forward_time_saved_ms": forward_time_saved,
+                            "backward_time_saved_ms": backward_time_saved,
+                            "time_saved_ms": forward_time_saved + backward_time_saved,
+                        }
+                    )
+                    time_saved_ns = (
+                        entry.forward_time_taken_ns + entry.backward_time_taken_ns
+                    )
+                    # TODO: should we use the same field for remote cache time saved for both
+                    # FXGraphCache and AOTAutogradCache?
+                    # add_remote_cache_time_saved(time_saved_ns, is_backward=False)
+                    if (
+                        ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
+                            time_saved_ns
+                        )
+                    ) != 0:
+                        cache_info["ephemeral_timeout_increase"] = ephemeral_increase
+
+                if compiled_fn is None:
+                    log.info("AOTAutograd cache miss for key %s", cache_key)
+                    counters["aot_autograd"]["autograd_cache_miss"] += 1
+                    cache_state = "miss"
+                    cache_event_time = time.time_ns()
+            # Count missing the FXGraphCache as a miss not a bypass
+            except FXGraphCacheMiss as e:
+                counters["aot_autograd"]["autograd_cache_miss"] += 1
+                # Special counter when we pass autograd cache but
+                # fail when on inductor guards
+                counters["aot_autograd"]["autograd_cache_guard_miss"] += 1
+                cache_state = "miss"
+                if config.strict_autograd_cache:
+                    raise e
+            except BypassAOTAutogradCache as e:
+                cache_key = None
+                counters["aot_autograd"]["autograd_cache_bypass"] += 1
+                cache_state = "bypass"
+                cache_event_time = time.time_ns()
+                cache_info["cache_bypass_reason"] = str(e)
+                if remote:
+                    log_cache_bypass("bypass_aot_autograd", str(e))
+                if config.strict_autograd_cache:
+                    raise e
+            if compiled_fn is None:
+                # Set the cache key so we can save a cache result later
+                if cache_key is not None:
+                    aot_config.cache_info = AOTAutogradCacheInfo(
+                        cache_key, time.time_ns()
+                    )
+                compiled_fn = dispatch_and_compile()
+
+            cache_info.update(
+                {
+                    "key": cache_key,
+                    "cache_state": cache_state,
+                    "components": debug_lines,
+                }
+            )
+            chromium_log = get_chromium_event_logger()
+            chromium_log.log_instant_event(
+                f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info
+            )
+
+            chromium_log.add_event_data(
+                "backend_compile",
+                cache_state=cache_state,
+                cache_event_time=cache_event_time,
+                key=cache_info.get("key"),
+                components=cache_info.get("components"),
+                cache_bypass_reason=cache_info.get("cache_bypass_reason"),
+                remote_cache_enabled=remote,
+                local_cache_enabled=local,
+            )
+
+            torch._logging.trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "aotautograd_cache_hash",
+                    "encoding": "json",
+                },
+                payload_fn=lambda: json.dumps(cache_info),
+            )
+            return compiled_fn

     @staticmethod
     def _get_tmp_dir() -> str:
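Finally, a hedged sketch of how the counters incremented in AOTAutogradCache.load() can be observed end to end. The config knobs named in the comment are assumptions about how the cache is enabled and may differ by PyTorch version; the counter names come straight from the diff above, and whether the warm run actually hits depends on the cache configuration.

import torch
from torch._dynamo.utils import counters

# Assumes AOTAutogradCache is turned on (e.g. torch._functorch.config.enable_autograd_cache
# together with torch._inductor.config.fx_graph_cache); treat these flag names as illustrative.

def fn(x):
    return torch.sin(x) + torch.cos(x)

x = torch.randn(8, requires_grad=True)

counters.clear()
torch.compile(fn)(x).sum().backward()   # cold run: expect autograd_cache_miss to increment
cold = dict(counters["aot_autograd"])

torch._dynamo.reset()                   # drop compiled code so the same graph recompiles
counters.clear()
torch.compile(fn)(x).sum().backward()   # warm run: expect autograd_cache_hit if caching is enabled
warm = dict(counters["aot_autograd"])

print("cold:", cold)
print("warm:", warm)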