Allow OpOverloadPackets as safe torch functions, sanitize dynamo gm before running aotdispatch with cache (#139785)

Summary:
This diff makes two changes, found while testing AOTAutogradCache with internal cogwheel jobs, to improve cache hit rates:
- Allow torch functions that are OpOverloadPackets (not just OpOverloads) to be treated as safe torch functions; see the short illustration below this list.
- When running with the cache, dynamo puts some fields onto the input graph module for aotdispatch that are not stable between runs. We use a context manager to null these fields out so they cannot affect the output of AOTAutograd, and we put them back onto the gm before returning from AOTAutogradCache.load().
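
For background (not part of this diff): torch.ops.aten.add is an OpOverloadPacket (the bundle of all overloads of an op), while torch.ops.aten.add.Tensor is a single OpOverload. Both can appear as call_function targets in a dynamo-produced FX graph, which is why the safety check needs to accept either. A minimal sketch of the distinction:

    import torch

    packet = torch.ops.aten.add           # OpOverloadPacket: resolves to a concrete overload at call time
    overload = torch.ops.aten.add.Tensor  # OpOverload: one specific overload
    assert isinstance(packet, torch._ops.OpOverloadPacket)
    assert isinstance(overload, torch._ops.OpOverload)
    # With this change, call targets of either type are treated as safe torch functions
    # when AOTAutogradCache checks whether a graph is cacheable.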

Test Plan:
New unit tests + running nanogpt with AOTAutogradCache.

Meta:

Run on a long-running job.
Cache miss:
 {F1953831996}

Cache hit:
 {F1953830872}

Servicelab experiment here:
https://www.internalfb.com/servicelab/experiment/4301352991/

Cache hit:
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/f660597709-TrainingApplication/attempt_0/version_0/rank_0/index.html

Cache miss:
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/f660569960-TrainingApplication/attempt_0/version_0/rank_0/index.html

With these changes, the autograd cache hits and saves compile time:
https://fburl.com/scuba/pt2_compile_events/ycddxstd

Differential Revision: D65436373

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139785
Approved by: https://github.com/bdhirsh
Author: James Wu
Date: 2024-11-06 16:34:00 +00:00
Committed by: PyTorch MergeBot
Parent: e05a096c49
Commit: dd6a5de00d
2 changed files with 161 additions and 98 deletions


@@ -14,6 +14,7 @@ from torch._functorch._aot_autograd.autograd_cache import (
     AOTAutogradCache,
     autograd_cache_key,
     BypassAOTAutogradCache,
+    sanitize_gm_for_cache,
 )
 from torch._functorch._aot_autograd.schemas import AOTConfig
 from torch._inductor import config as inductor_config
@@ -776,6 +777,31 @@ class AOTAutogradCachePicklerTests(torch._dynamo.test_case.TestCase):
         config = self.default_config()
         self.gen_cache_key(fn, config)
 
+    def test_sanitize_gm_for_cache(self):
+        def fn(x):
+            y = torch.sin(x)
+            z = torch.cos(x)
+            w = y + z
+            w.abs()
+            return w
+
+        _, fx_g, example_inputs = self._get_dynamo_output(fn, torch.ones(3))
+        fx_g.meta = {"foo": "bar"}
+        fx_g.compile_subgraph_reason = "Blah"
+        config = self.default_config()
+        with sanitize_gm_for_cache(fx_g):
+            c1 = autograd_cache_key(fx_g, example_inputs, config, {})
+        c3 = autograd_cache_key(fx_g, example_inputs, config, {})
+
+        fx_g.meta = {"foo": "baz"}
+        fx_g.compile_subgraph_reason = None
+        with sanitize_gm_for_cache(fx_g):
+            c2 = autograd_cache_key(fx_g, example_inputs, config, {})
+        c4 = autograd_cache_key(fx_g, example_inputs, config, {})
+
+        self.assertEqual(c1, c2)
+        self.assertNotEqual(c3, c4)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests


@@ -5,6 +5,7 @@ Utils for caching the outputs of AOTAutograd
 from __future__ import annotations
 
 import base64
+import contextlib
 import functools
 import json
 import logging
@@ -125,7 +126,7 @@ def check_node_safe(node: Node):
         )
 
     def is_torch_function(target):
-        if isinstance(target, torch._ops.OpOverload):
+        if isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
             return True
         if is_public_torch_api(target):
             return True
@@ -324,7 +325,7 @@ class FXGraphCacheLoadable:
         torch._logging.trace_structured(
             "artifact",
             metadata_fn=lambda: {
-                "name": "fx_graph_cache_hash",
+                "name": "fx_graph_cache_hit",  # always a hit
                 "encoding": "json",
             },
             payload_fn=lambda: json.dumps(cache_info),
@@ -508,6 +509,36 @@ class AOTAutogradCacheEntry:
 
         return compiled_function
 
 
+@contextlib.contextmanager
+def sanitize_gm_for_cache(gm: torch.fx.GraphModule):
+    """
+    Clears a few fields in a dynamo supplied Graph Module that are not stable between graph inputs, but don't
+    affect inductor or aotdispatch correctness.
+
+    These fields **can** be used by code calling into aotdispatch (namely, dynamo), so we can't null them out completely.
+
+    To ensure that these fields are not accessed by inductor or aotdispatch, we clear them during AOTAutogradCache.load,
+    and then put them back before returning. This way, we generate a cache key based off of a canonical graph
+    without these fields, and also guarantee they aren't used to affect the cache's output.
+    """
+    IGNORED_FIELDS = (
+        "meta",  # metadata used by export
+        "compile_subgraph_reason",  # Used by dynamo only for logging, no change in inductor/autograd behavior
+        "_param_name_to_source",  # Encapsulated by aot_config.aot_autograd_arg_pos_to_source
+    )
+    saved_fields = {}
+    for field in IGNORED_FIELDS:
+        saved_fields[field] = getattr(gm, field, None)
+        # Clear the field
+        setattr(gm, field, None)
+    try:
+        yield
+    finally:
+        # Put the fields back after dispatch_and_compile is complete
+        for field, value in saved_fields.items():
+            setattr(gm, field, value)
+
+
 class AOTAutogradCache:
     """
     Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas
@@ -566,107 +597,113 @@ class AOTAutogradCache:
         Load a result from the cache, and reconstruct a runtime wrapper around the object
         """
         gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
-        compiled_fn = None
-        cache_info: Dict[str, Any] = {}
-        cache_key = None
-        debug_lines: List[str] = []
-        cache_event_time = time.time_ns()
-        cache_state = None
-        fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs}
-        try:
-            cache_key, debug_lines = autograd_cache_key(gm, args, aot_config, fx_config)
-            entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(
-                cache_key, local, remote
-            )
-            if entry is not None:
-                compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
-                log.info("AOTAutograd cache hit for key %s", cache_key)
-                counters["aot_autograd"]["autograd_cache_hit"] += 1
-                cache_state = "hit"
-                cache_event_time = time.time_ns()
-                forward_time_saved = entry.forward_time_taken_ns // 1e6
-                backward_time_saved = entry.backward_time_taken_ns // 1e6
-                cache_info.update(
-                    {
-                        "forward_time_saved_ms": forward_time_saved,
-                        "backward_time_saved_ms": backward_time_saved,
-                        "time_saved_ms": forward_time_saved + backward_time_saved,
-                    }
-                )
-                time_saved_ns = (
-                    entry.forward_time_taken_ns + entry.backward_time_taken_ns
-                )
-                # TODO: should we use the same field for remote cache time saved for both
-                # FXGraphCache and AOTAutogradCache?
-                # add_remote_cache_time_saved(time_saved_ns, is_backward=False)
-                if (
-                    ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
-                        time_saved_ns
-                    )
-                ) != 0:
-                    cache_info["ephemeral_timeout_increase"] = ephemeral_increase
-            if compiled_fn is None:
-                log.info("AOTAutograd cache miss for key %s", cache_key)
-                counters["aot_autograd"]["autograd_cache_miss"] += 1
-                cache_state = "miss"
-                cache_event_time = time.time_ns()
-        # Count missing the FXGraphCache as a miss not a bypass
-        except FXGraphCacheMiss as e:
-            counters["aot_autograd"]["autograd_cache_miss"] += 1
-            # Special counter when we pass autograd cache but
-            # fail when on inductor guards
-            counters["aot_autograd"]["autograd_cache_guard_miss"] += 1
-            if config.strict_autograd_cache:
-                raise e
-        except BypassAOTAutogradCache as e:
-            counters["aot_autograd"]["autograd_cache_bypass"] += 1
-            cache_state = "bypass"
-            cache_event_time = time.time_ns()
-            cache_info["cache_bypass_reason"] = str(e)
-            if remote:
-                log_cache_bypass("bypass_aot_autograd", str(e))
-            if config.strict_autograd_cache:
-                raise e
-        if compiled_fn is None:
-            # Set the cache key so we can save a cache result later
-            if cache_key is not None:
-                aot_config.cache_info = AOTAutogradCacheInfo(cache_key, time.time_ns())
-            compiled_fn = dispatch_and_compile()
-        cache_info.update(
-            {
-                "key": cache_key,
-                "cache_state": cache_state,
-                "components": debug_lines,
-            }
-        )
-        chromium_log = get_chromium_event_logger()
-        chromium_log.log_instant_event(
-            f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info
-        )
-        chromium_log.add_event_data(
-            "backend_compile",
-            cache_state=cache_state,
-            cache_event_time=cache_event_time,
-            key=cache_info.get("key"),
-            components=cache_info.get("components"),
-            cache_bypass_reason=cache_info.get("cache_bypass_reason"),
-            remote_cache_enabled=remote,
-            local_cache_enabled=local,
-        )
-        torch._logging.trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "aotautograd_cache_hash",
-                "encoding": "json",
-            },
-            payload_fn=lambda: json.dumps(cache_info),
-        )
-        return compiled_fn
+        with sanitize_gm_for_cache(gm):
+            compiled_fn = None
+            cache_info: Dict[str, Any] = {}
+            cache_key = None
+            debug_lines: List[str] = []
+            cache_event_time = time.time_ns()
+            cache_state = None
+            fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs}
+            try:
+                cache_key, debug_lines = autograd_cache_key(
+                    gm, args, aot_config, fx_config
+                )
+                entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(
+                    cache_key, local, remote
+                )
+                if entry is not None:
+                    compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
+                    log.info("AOTAutograd cache hit for key %s", cache_key)
+                    counters["aot_autograd"]["autograd_cache_hit"] += 1
+                    cache_state = "hit"
+                    cache_event_time = time.time_ns()
+                    forward_time_saved = entry.forward_time_taken_ns // 1e6
+                    backward_time_saved = entry.backward_time_taken_ns // 1e6
+                    cache_info.update(
+                        {
+                            "forward_time_saved_ms": forward_time_saved,
+                            "backward_time_saved_ms": backward_time_saved,
+                            "time_saved_ms": forward_time_saved + backward_time_saved,
+                        }
+                    )
+                    time_saved_ns = (
+                        entry.forward_time_taken_ns + entry.backward_time_taken_ns
+                    )
+                    # TODO: should we use the same field for remote cache time saved for both
+                    # FXGraphCache and AOTAutogradCache?
+                    # add_remote_cache_time_saved(time_saved_ns, is_backward=False)
+                    if (
+                        ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
+                            time_saved_ns
+                        )
+                    ) != 0:
+                        cache_info["ephemeral_timeout_increase"] = ephemeral_increase
+                if compiled_fn is None:
+                    log.info("AOTAutograd cache miss for key %s", cache_key)
+                    counters["aot_autograd"]["autograd_cache_miss"] += 1
+                    cache_state = "miss"
+                    cache_event_time = time.time_ns()
+            # Count missing the FXGraphCache as a miss not a bypass
+            except FXGraphCacheMiss as e:
+                counters["aot_autograd"]["autograd_cache_miss"] += 1
+                # Special counter when we pass autograd cache but
+                # fail when on inductor guards
+                counters["aot_autograd"]["autograd_cache_guard_miss"] += 1
+                cache_state = "miss"
+                if config.strict_autograd_cache:
+                    raise e
+            except BypassAOTAutogradCache as e:
+                cache_key = None
+                counters["aot_autograd"]["autograd_cache_bypass"] += 1
+                cache_state = "bypass"
+                cache_event_time = time.time_ns()
+                cache_info["cache_bypass_reason"] = str(e)
+                if remote:
+                    log_cache_bypass("bypass_aot_autograd", str(e))
+                if config.strict_autograd_cache:
+                    raise e
+            if compiled_fn is None:
+                # Set the cache key so we can save a cache result later
+                if cache_key is not None:
+                    aot_config.cache_info = AOTAutogradCacheInfo(
+                        cache_key, time.time_ns()
+                    )
+                compiled_fn = dispatch_and_compile()
+            cache_info.update(
+                {
+                    "key": cache_key,
+                    "cache_state": cache_state,
+                    "components": debug_lines,
+                }
+            )
+            chromium_log = get_chromium_event_logger()
+            chromium_log.log_instant_event(
+                f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info
+            )
+            chromium_log.add_event_data(
+                "backend_compile",
+                cache_state=cache_state,
+                cache_event_time=cache_event_time,
+                key=cache_info.get("key"),
+                components=cache_info.get("components"),
+                cache_bypass_reason=cache_info.get("cache_bypass_reason"),
+                remote_cache_enabled=remote,
+                local_cache_enabled=local,
+            )
+            torch._logging.trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "aotautograd_cache_hash",
+                    "encoding": "json",
+                },
+                payload_fn=lambda: json.dumps(cache_info),
+            )
+            return compiled_fn
 
     @staticmethod
     def _get_tmp_dir() -> str: