Allow OpOverloadPackets as safe torch functions, sanitize dynamo gm before running aotdispatch with cache (#139785)

Summary:
This diff implements two changes to improve cache hit rates, found while testing AOTAutogradCache with internal cogwheel jobs:
- Allow torch functions that are OpOverloadPackets (not just OpOverloads) to be treated as safe; see the example after this list.
- When running with the cache, dynamo puts some fields into the input graph module passed to aotdispatch that are not stable between runs. We use a context manager to null these fields out so that they cannot affect the output of AOTAutograd, and we put the fields back onto the gm before returning from AOTAutogradCache.load().
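
For reference, a minimal sketch of the distinction (using the public aten namespace; the variable names are just for illustration): an OpOverloadPacket such as torch.ops.aten.sin bundles concrete OpOverloads such as torch.ops.aten.sin.default, and dynamo graphs may reference either form as a call_function target, so a packet should count as a safe torch function too.

    import torch

    # torch.ops.aten.sin is the overload packet; .default is one concrete overload of it.
    packet = torch.ops.aten.sin
    overload = torch.ops.aten.sin.default

    assert isinstance(packet, torch._ops.OpOverloadPacket)
    assert isinstance(overload, torch._ops.OpOverload)
    # A packet is not an OpOverload, so a check for OpOverload alone does not match it.
    assert not isinstance(packet, torch._ops.OpOverload)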

Test Plan:
New unit tests + running nanogpt with AOTAutogradCache.

Meta:

Run on a long-running job:
Cache miss:
 {F1953831996}

Cache hit:
 {F1953830872}

Servicelabs here:
https://www.internalfb.com/servicelab/experiment/4301352991/

Cache hit:
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/f660597709-TrainingApplication/attempt_0/version_0/rank_0/index.html

Cache miss:
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/f660569960-TrainingApplication/attempt_0/version_0/rank_0/index.html

We can see that with these changes, the AOTAutograd cache gets hits and saves compile time:
https://fburl.com/scuba/pt2_compile_events/ycddxstd

Differential Revision: D65436373

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139785
Approved by: https://github.com/bdhirsh
Author: James Wu
Date: 2024-11-06 16:34:00 +00:00
Committed by: PyTorch MergeBot
Parent: e05a096c49
Commit: dd6a5de00d

2 changed files with 161 additions and 98 deletions


@@ -14,6 +14,7 @@ from torch._functorch._aot_autograd.autograd_cache import (
     AOTAutogradCache,
     autograd_cache_key,
     BypassAOTAutogradCache,
+    sanitize_gm_for_cache,
 )
 from torch._functorch._aot_autograd.schemas import AOTConfig
 from torch._inductor import config as inductor_config
@@ -776,6 +777,31 @@ class AOTAutogradCachePicklerTests(torch._dynamo.test_case.TestCase):
         config = self.default_config()
         self.gen_cache_key(fn, config)

+    def test_sanitize_gm_for_cache(self):
+        def fn(x):
+            y = torch.sin(x)
+            z = torch.cos(x)
+            w = y + z
+            w.abs()
+            return w
+
+        _, fx_g, example_inputs = self._get_dynamo_output(fn, torch.ones(3))
+        fx_g.meta = {"foo": "bar"}
+        fx_g.compile_subgraph_reason = "Blah"
+        config = self.default_config()
+        with sanitize_gm_for_cache(fx_g):
+            c1 = autograd_cache_key(fx_g, example_inputs, config, {})
+        c3 = autograd_cache_key(fx_g, example_inputs, config, {})
+
+        fx_g.meta = {"foo": "baz"}
+        fx_g.compile_subgraph_reason = None
+        with sanitize_gm_for_cache(fx_g):
+            c2 = autograd_cache_key(fx_g, example_inputs, config, {})
+        c4 = autograd_cache_key(fx_g, example_inputs, config, {})
+
+        self.assertEqual(c1, c2)
+        self.assertNotEqual(c3, c4)
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests


@@ -5,6 +5,7 @@ Utils for caching the outputs of AOTAutograd
 from __future__ import annotations

 import base64
+import contextlib
 import functools
 import json
 import logging
@@ -125,7 +126,7 @@ def check_node_safe(node: Node):
         )

     def is_torch_function(target):
-        if isinstance(target, torch._ops.OpOverload):
+        if isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
             return True
         if is_public_torch_api(target):
             return True
@@ -324,7 +325,7 @@ class FXGraphCacheLoadable:
         torch._logging.trace_structured(
             "artifact",
             metadata_fn=lambda: {
-                "name": "fx_graph_cache_hash",
+                "name": "fx_graph_cache_hit",  # always a hit
                 "encoding": "json",
             },
             payload_fn=lambda: json.dumps(cache_info),
@@ -508,6 +509,36 @@ class AOTAutogradCacheEntry:
         return compiled_function


+@contextlib.contextmanager
+def sanitize_gm_for_cache(gm: torch.fx.GraphModule):
+    """
+    Clears a few fields in a dynamo supplied Graph Module that are not stable between graph inputs, but don't
+    affect inductor or aotdispatch correctness.
+
+    These fields **can** be used by code calling into aotdispatch (namely, dynamo), so we can't null them out completely.
+    To ensure that these fields are not accessed by inductor or aotdispatch, we clear them during AOTAutogradCache.load,
+    and then put them back before returning. This way, we generate a cache key based off of a canonical graph
+    without these fields, and also guarantee they aren't used to affect the cache's output.
+    """
+
+    IGNORED_FIELDS = (
+        "meta",  # metadata used by export
+        "compile_subgraph_reason",  # Used by dynamo only for logging, no change in inductor/autograd behavior
+        "_param_name_to_source",  # Encapsulated by aot_config.aot_autograd_arg_pos_to_source
+    )
+    saved_fields = {}
+    for field in IGNORED_FIELDS:
+        saved_fields[field] = getattr(gm, field, None)
+        # Clear the field
+        setattr(gm, field, None)
+    try:
+        yield
+    finally:
+        # Put the fields back after dispatch_and_compile is complete
+        for field, value in saved_fields.items():
+            setattr(gm, field, value)
+
+
 class AOTAutogradCache:
     """
     Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas
@@ -566,107 +597,113 @@ class AOTAutogradCache:
         Load a result from the cache, and reconstruct a runtime wrapper around the object
         """
         gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
-        compiled_fn = None
-        cache_info: Dict[str, Any] = {}
-        cache_key = None
-        debug_lines: List[str] = []
-        cache_event_time = time.time_ns()
-        cache_state = None
-        fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs}
-        try:
-            cache_key, debug_lines = autograd_cache_key(gm, args, aot_config, fx_config)
-            entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(
-                cache_key, local, remote
-            )
-            if entry is not None:
-                compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
-                log.info("AOTAutograd cache hit for key %s", cache_key)
-                counters["aot_autograd"]["autograd_cache_hit"] += 1
-                cache_state = "hit"
-                cache_event_time = time.time_ns()
-                forward_time_saved = entry.forward_time_taken_ns // 1e6
-                backward_time_saved = entry.backward_time_taken_ns // 1e6
-                cache_info.update(
-                    {
-                        "forward_time_saved_ms": forward_time_saved,
-                        "backward_time_saved_ms": backward_time_saved,
-                        "time_saved_ms": forward_time_saved + backward_time_saved,
-                    }
-                )
-                time_saved_ns = (
-                    entry.forward_time_taken_ns + entry.backward_time_taken_ns
-                )
-                # TODO: should we use the same field for remote cache time saved for both
-                # FXGraphCache and AOTAutogradCache?
-                # add_remote_cache_time_saved(time_saved_ns, is_backward=False)
-                if (
-                    ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
-                        time_saved_ns
-                    )
-                ) != 0:
-                    cache_info["ephemeral_timeout_increase"] = ephemeral_increase
-            if compiled_fn is None:
-                log.info("AOTAutograd cache miss for key %s", cache_key)
-                counters["aot_autograd"]["autograd_cache_miss"] += 1
-                cache_state = "miss"
-                cache_event_time = time.time_ns()
-        # Count missing the FXGraphCache as a miss not a bypass
-        except FXGraphCacheMiss as e:
-            counters["aot_autograd"]["autograd_cache_miss"] += 1
-            # Special counter when we pass autograd cache but
-            # fail when on inductor guards
-            counters["aot_autograd"]["autograd_cache_guard_miss"] += 1
-            if config.strict_autograd_cache:
-                raise e
-        except BypassAOTAutogradCache as e:
-            cache_key = None
-            counters["aot_autograd"]["autograd_cache_bypass"] += 1
-            cache_state = "bypass"
-            cache_event_time = time.time_ns()
-            cache_info["cache_bypass_reason"] = str(e)
-            if remote:
-                log_cache_bypass("bypass_aot_autograd", str(e))
-            if config.strict_autograd_cache:
-                raise e
-        if compiled_fn is None:
-            # Set the cache key so we can save a cache result later
-            if cache_key is not None:
-                aot_config.cache_info = AOTAutogradCacheInfo(cache_key, time.time_ns())
-            compiled_fn = dispatch_and_compile()
-
-        cache_info.update(
-            {
-                "key": cache_key,
-                "cache_state": cache_state,
-                "components": debug_lines,
-            }
-        )
-        chromium_log = get_chromium_event_logger()
-        chromium_log.log_instant_event(
-            f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info
-        )
-
-        chromium_log.add_event_data(
-            "backend_compile",
-            cache_state=cache_state,
-            cache_event_time=cache_event_time,
-            key=cache_info.get("key"),
-            components=cache_info.get("components"),
-            cache_bypass_reason=cache_info.get("cache_bypass_reason"),
-            remote_cache_enabled=remote,
-            local_cache_enabled=local,
-        )
-
-        torch._logging.trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "aotautograd_cache_hash",
-                "encoding": "json",
-            },
-            payload_fn=lambda: json.dumps(cache_info),
-        )
-        return compiled_fn
+        with sanitize_gm_for_cache(gm):
+            compiled_fn = None
+            cache_info: Dict[str, Any] = {}
+            cache_key = None
+            debug_lines: List[str] = []
+            cache_event_time = time.time_ns()
+            cache_state = None
+            fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs}
+            try:
+                cache_key, debug_lines = autograd_cache_key(
+                    gm, args, aot_config, fx_config
+                )
+                entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(
+                    cache_key, local, remote
+                )
+                if entry is not None:
+                    compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
+                    log.info("AOTAutograd cache hit for key %s", cache_key)
+                    counters["aot_autograd"]["autograd_cache_hit"] += 1
+                    cache_state = "hit"
+                    cache_event_time = time.time_ns()
+                    forward_time_saved = entry.forward_time_taken_ns // 1e6
+                    backward_time_saved = entry.backward_time_taken_ns // 1e6
+                    cache_info.update(
+                        {
+                            "forward_time_saved_ms": forward_time_saved,
+                            "backward_time_saved_ms": backward_time_saved,
+                            "time_saved_ms": forward_time_saved + backward_time_saved,
+                        }
+                    )
+                    time_saved_ns = (
+                        entry.forward_time_taken_ns + entry.backward_time_taken_ns
+                    )
+                    # TODO: should we use the same field for remote cache time saved for both
+                    # FXGraphCache and AOTAutogradCache?
+                    # add_remote_cache_time_saved(time_saved_ns, is_backward=False)
+                    if (
+                        ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
+                            time_saved_ns
+                        )
+                    ) != 0:
+                        cache_info["ephemeral_timeout_increase"] = ephemeral_increase
+                if compiled_fn is None:
+                    log.info("AOTAutograd cache miss for key %s", cache_key)
+                    counters["aot_autograd"]["autograd_cache_miss"] += 1
+                    cache_state = "miss"
+                    cache_event_time = time.time_ns()
+            # Count missing the FXGraphCache as a miss not a bypass
+            except FXGraphCacheMiss as e:
+                counters["aot_autograd"]["autograd_cache_miss"] += 1
+                # Special counter when we pass autograd cache but
+                # fail when on inductor guards
+                counters["aot_autograd"]["autograd_cache_guard_miss"] += 1
+                cache_state = "miss"
+                if config.strict_autograd_cache:
+                    raise e
+            except BypassAOTAutogradCache as e:
+                cache_key = None
+                counters["aot_autograd"]["autograd_cache_bypass"] += 1
+                cache_state = "bypass"
+                cache_event_time = time.time_ns()
+                cache_info["cache_bypass_reason"] = str(e)
+                if remote:
+                    log_cache_bypass("bypass_aot_autograd", str(e))
+                if config.strict_autograd_cache:
+                    raise e
+            if compiled_fn is None:
+                # Set the cache key so we can save a cache result later
+                if cache_key is not None:
+                    aot_config.cache_info = AOTAutogradCacheInfo(
+                        cache_key, time.time_ns()
+                    )
+                compiled_fn = dispatch_and_compile()
+
+            cache_info.update(
+                {
+                    "key": cache_key,
+                    "cache_state": cache_state,
+                    "components": debug_lines,
+                }
+            )
+            chromium_log = get_chromium_event_logger()
+            chromium_log.log_instant_event(
+                f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info
+            )
+
+            chromium_log.add_event_data(
+                "backend_compile",
+                cache_state=cache_state,
+                cache_event_time=cache_event_time,
+                key=cache_info.get("key"),
+                components=cache_info.get("components"),
+                cache_bypass_reason=cache_info.get("cache_bypass_reason"),
+                remote_cache_enabled=remote,
+                local_cache_enabled=local,
+            )
+
+            torch._logging.trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "aotautograd_cache_hash",
+                    "encoding": "json",
+                },
+                payload_fn=lambda: json.dumps(cache_info),
+            )
+            return compiled_fn

     @staticmethod
     def _get_tmp_dir() -> str: