mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor how AOTAutograd backends are defined (#89736)
There was a lot of strangeness in how AOTAutograd backends were previously defined. This refactor replaces the strangeness with something simple and straightforward. The improvements: - There is no longer a footgun aot_autograd "backend" which doesn't actually work. No more mistyping `torch._dynamo.optimize("aot_autograd")` when you meant "aot_eager" - Deleted aot_print because it's annoying and anyway there's no uses of it - Instead of having BOTH the backend Subgraph and AotAutogradStrategy, there is now only an aot_autograd function which takes the kwargs to configure AOTAutograd, and then gives you a compiler function that does AOTAutograd given those kwargs. Easy. - The primary downside is that we are now eagerly populating all of the kwargs, and that can get us into import cycle shenanigans. Some cycles I resolved directly (e.g., we now no longer manually disable the forward function before passing it to aot_autograd; aot_autograd it does it for us), but for getting inductor decompositions I had to make it take a lambda so I could lazily populate the decomps later. New code is 130 lines shorter! Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/89736 Approved by: https://github.com/anjali411, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
cf4969d9d6
commit
b589e726d9
@ -15,7 +15,6 @@ import torch.nn as nn
|
||||
import torch.utils._pytree as pytree
|
||||
import torch.utils.dlpack
|
||||
from torch import Tensor
|
||||
from torch._dynamo import disable as disable_torchdynamo
|
||||
from torch._dynamo.utils import dynamo_timed
|
||||
from torch._subclasses import FakeTensorMode, CrossRefFakeMode
|
||||
from torch.fx import immutable_collections, Interpreter
|
||||
@ -1315,7 +1314,6 @@ def aot_dispatch_deduplicated_autograd(flat_fn, flat_args: List[Tensor], aot_con
|
||||
fw_metadata = _fw_metadata
|
||||
|
||||
@staticmethod
|
||||
@disable_torchdynamo
|
||||
def forward(ctx, *deduped_flat_tensor_args):
|
||||
|
||||
# There is a pretty complicated calling convention around what the compiled fw returns.
|
||||
@ -1361,7 +1359,6 @@ def aot_dispatch_deduplicated_autograd(flat_fn, flat_args: List[Tensor], aot_con
|
||||
return tuple(fw_outs[0:num_forward_returns])
|
||||
|
||||
@staticmethod
|
||||
@disable_torchdynamo
|
||||
def backward(ctx, *all_flat_args):
|
||||
# Calling convention: we expect a grad_out passed to the backward:
|
||||
# - for every output of the fw that does *not* alias an input
|
||||
|
@ -1788,7 +1788,6 @@ class TestImports(TestCase):
|
||||
"torch.contrib.", # something weird
|
||||
"torch.testing._internal.distributed.", # just fails
|
||||
"torch.ao.pruning._experimental.", # depends on pytorch_lightning, not user-facing
|
||||
"torch.cuda._dynamo_graphs", # depends on torchdynamo
|
||||
]
|
||||
# See https://github.com/pytorch/pytorch/issues/77801
|
||||
if not sys.version_info >= (3, 9):
|
||||
|
@ -37,7 +37,6 @@ else:
|
||||
from . import config, convert_frame, skipfiles, utils
|
||||
from .exc import ResetRequired
|
||||
from .mutation_guard import install_generation_tagging_init
|
||||
from .optimizations.distributed import DDPOptimizer
|
||||
from .output_graph import CompilerFn
|
||||
from .types import DynamoCallback
|
||||
from .utils import compile_times
|
||||
@ -311,6 +310,8 @@ def catch_errors_wrapper(callback):
|
||||
ddp_module = DistributedDataParallel._get_active_ddp_module()
|
||||
if ddp_module:
|
||||
with compile_lock:
|
||||
from .optimizations.distributed import DDPOptimizer
|
||||
|
||||
ddp_optimizer = DDPOptimizer(
|
||||
bucket_bytes_cap=ddp_module.bucket_bytes_cap,
|
||||
backend_compile_fn=callback._torchdynamo_orig_callable,
|
||||
|
@ -517,22 +517,6 @@ def cudagraphs_inner(model, inputs, copy_outputs=True):
|
||||
return run
|
||||
|
||||
|
||||
@create_backend
|
||||
def aot_autograd(subgraph, **kwargs):
|
||||
def _wrapped_bw_compiler(*args, **kwargs):
|
||||
# stop TorchDynamo from trying to compile our generated backwards pass
|
||||
return disable(disable(bw_compiler)(*args, **kwargs))
|
||||
|
||||
bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
|
||||
kwargs["bw_compiler"] = _wrapped_bw_compiler
|
||||
|
||||
from functorch.compile import aot_module_simplified
|
||||
|
||||
from .. import disable
|
||||
|
||||
return aot_module_simplified(subgraph.model, **kwargs)
|
||||
|
||||
|
||||
def tvm_compile(jit_mod, example_inputs, log_file=None, **kwargs):
|
||||
if jit_mod is None:
|
||||
return None
|
||||
|
@ -6,6 +6,15 @@ from functools import partial
|
||||
from importlib import import_module
|
||||
from typing import Set
|
||||
|
||||
from functorch._src.compilers import debug_nop
|
||||
|
||||
from functorch.compile import (
|
||||
aot_module_simplified,
|
||||
min_cut_rematerialization_partition,
|
||||
nop,
|
||||
ts_compile,
|
||||
)
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
|
||||
@ -13,7 +22,7 @@ from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.nn import Module
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from .. import config
|
||||
from .. import config, eval_frame
|
||||
from ..utils import clone_inputs, count_calls, counters
|
||||
from .analysis import has_mutation
|
||||
from .backends import BACKENDS
|
||||
@ -22,6 +31,62 @@ from .normalize import normalize_ir
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def aot_autograd(**kwargs):
|
||||
def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
|
||||
import functorch.compile
|
||||
|
||||
# Hack to get around circular import problems with aot_inductor_debug
|
||||
if callable(kwargs.get("decompositions")):
|
||||
kwargs["decompositions"] = kwargs["decompositions"]()
|
||||
|
||||
# TODO: stop monkeypatching here (without even cleaning up, UGH!)
|
||||
functorch.compile.config.use_functionalize = True
|
||||
functorch.compile.config.use_fake_tensor = True
|
||||
|
||||
force_compile_tiny_graphs = kwargs.pop("force_compile_tiny_graphs", False)
|
||||
|
||||
if count_calls(gm.graph) < 2 and not force_compile_tiny_graphs:
|
||||
return gm # no point for tiny graphs
|
||||
|
||||
counters["aot_autograd"]["total"] += 1
|
||||
use_fallback = False
|
||||
|
||||
if not functorch.compile.config.use_functionalize and config.normalize_ir:
|
||||
try:
|
||||
gm = normalize_ir(gm, clone_inputs(example_inputs))
|
||||
except Exception:
|
||||
log.debug("TorchDynamo unable to remove mutation")
|
||||
use_fallback = True
|
||||
|
||||
# NB: no clone here on example inputs
|
||||
if not is_aot_autograd_safe_to_run(gm, example_inputs):
|
||||
use_fallback = True
|
||||
|
||||
if use_fallback:
|
||||
log.debug("Unable to use AOT Autograd because graph has mutation")
|
||||
counters["aot_autograd"]["not_ok"] += 1
|
||||
return gm
|
||||
|
||||
# OK attempt to compile
|
||||
|
||||
def _wrapped_bw_compiler(*args, **kwargs):
|
||||
# stop TorchDynamo from trying to compile our generated backwards pass
|
||||
return eval_frame.disable(eval_frame.disable(bw_compiler)(*args, **kwargs))
|
||||
|
||||
bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
|
||||
kwargs["bw_compiler"] = _wrapped_bw_compiler
|
||||
|
||||
try:
|
||||
cg = aot_module_simplified(gm, **kwargs)
|
||||
counters["aot_autograd"]["ok"] += 1
|
||||
return eval_frame.disable(cg)
|
||||
except Exception:
|
||||
counters["aot_autograd"]["not_ok"] += 1
|
||||
raise
|
||||
|
||||
return compiler_fn
|
||||
|
||||
|
||||
def is_aot_autograd_safe_to_run(gm, example_inputs):
|
||||
"""
|
||||
There are some known issues with Aot Autograd. This is a workaround to catch
|
||||
@ -86,107 +151,28 @@ def is_aot_autograd_safe_to_run(gm, example_inputs):
|
||||
return True
|
||||
|
||||
|
||||
class AotAutogradStrategy(object):
|
||||
"""Base class for backend strategies that use AOT Autograd"""
|
||||
DEBUG = False
|
||||
|
||||
@classmethod
|
||||
def compile_fn(cls, gm: torch.fx.GraphModule, example_inputs):
|
||||
if count_calls(gm.graph) < 2:
|
||||
return gm # no point for tiny graphs
|
||||
return cls(gm, example_inputs).verified_candidate()
|
||||
# Useful for debugging purpose
|
||||
aot_eager = aot_autograd(fw_compiler=debug_nop if DEBUG else nop)
|
||||
|
||||
def __init__(self, gm: torch.fx.GraphModule, example_inputs):
|
||||
import functorch.compile
|
||||
# AOT Autograd with torchscript backend. Default partitioner.
|
||||
aot_ts = aot_autograd(fw_compiler=ts_compile)
|
||||
|
||||
functorch.compile.config.use_functionalize = True
|
||||
functorch.compile.config.use_fake_tensor = True
|
||||
|
||||
super(AotAutogradStrategy, self).__init__()
|
||||
counters["aot_autograd"]["total"] += 1
|
||||
self.use_fallback = False
|
||||
self.original_example_inputs = example_inputs
|
||||
self.gm = gm
|
||||
|
||||
if not functorch.compile.config.use_functionalize and config.normalize_ir:
|
||||
try:
|
||||
self.gm = normalize_ir(gm, self.example_inputs)
|
||||
except Exception:
|
||||
log.debug("TorchDynamo unable to remove mutation")
|
||||
self.use_fallback = True
|
||||
pass
|
||||
|
||||
if not is_aot_autograd_safe_to_run(gm, example_inputs):
|
||||
self.use_fallback = True
|
||||
|
||||
@property
|
||||
def example_inputs(self):
|
||||
return clone_inputs(self.original_example_inputs)
|
||||
|
||||
def verified_candidate(self):
|
||||
if self.use_fallback:
|
||||
log.debug("Unable to use AOT Autograd because graph has mutation")
|
||||
counters["aot_autograd"]["not_ok"] += 1
|
||||
return self.gm
|
||||
cg = self.candidate()
|
||||
if cg is None:
|
||||
counters["aot_autograd"]["not_ok"] += 1
|
||||
raise RuntimeError("AOT Autograd failed to compile")
|
||||
counters["aot_autograd"]["ok"] += 1
|
||||
return cg
|
||||
|
||||
def candidate(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AotNop(AotAutogradStrategy):
|
||||
"""Useful for debugging purpose"""
|
||||
|
||||
def candidate(self):
|
||||
from functorch._src.compilers import debug_nop
|
||||
from functorch.compile import nop
|
||||
|
||||
DEBUG = False
|
||||
return BACKENDS["aot_autograd"](
|
||||
self.gm, self.example_inputs, fw_compiler=debug_nop if DEBUG else nop
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
|
||||
aot_eager = AotNop.compile_fn
|
||||
|
||||
|
||||
class AotTorchscript(AotAutogradStrategy):
|
||||
"""
|
||||
AOT Autograd with torchscript backend. Default partitioner.
|
||||
"""
|
||||
|
||||
def candidate(self):
|
||||
from functorch.compile import ts_compile
|
||||
|
||||
return BACKENDS["aot_autograd"](
|
||||
self.gm, self.example_inputs, fw_compiler=ts_compile
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
|
||||
aot_ts = AotTorchscript.compile_fn
|
||||
|
||||
# Global counter to differentiate between different graphs.
|
||||
graph_idx = 0
|
||||
|
||||
|
||||
class AotPrint(AotNop):
|
||||
"""Saves all the gm models so that we can run them separately"""
|
||||
|
||||
def candidate(self):
|
||||
global graph_idx
|
||||
module_idx = "module_" + str(graph_idx)
|
||||
self.gm.to_folder(module_idx, "Bar")
|
||||
for idx, x in enumerate(self.example_inputs):
|
||||
torch.save(x, module_idx + "_tensor" + str(idx) + ".pt")
|
||||
graph_idx += 1
|
||||
return super(AotPrint, self).candidate()
|
||||
|
||||
|
||||
aot_print = AotPrint.compile_fn
|
||||
# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs
|
||||
# inductor problems.
|
||||
aot_inductor_debug = aot_autograd(
|
||||
# these are taken from memory_efficient_fusion()
|
||||
fw_compiler=nop,
|
||||
bw_compiler=nop,
|
||||
# NB: lambda here is to delay import of inductor
|
||||
decompositions=lambda: import_module(
|
||||
f"{config.inductor_import}.compile_fx"
|
||||
).select_decomp_table(),
|
||||
partition_fn=functools.partial(
|
||||
min_cut_rematerialization_partition, compiler="inductor"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def mem_efficient_fusion_kwargs(use_decomps):
|
||||
@ -209,66 +195,15 @@ def mem_efficient_fusion_kwargs(use_decomps):
|
||||
return kwargs
|
||||
|
||||
|
||||
class AotMemEfficientFusion(AotAutogradStrategy):
|
||||
"""Use Min cut rematerilization and TorchScript+nvFuser with AOT Autograd"""
|
||||
# Use min cut rematerialization and TorchScript+nvFuser with AOT Autograd
|
||||
aot_mem_efficient_fusion = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))
|
||||
aot_mem_efficient_fusion_no_decomp = aot_autograd(
|
||||
**mem_efficient_fusion_kwargs(use_decomps=False)
|
||||
)
|
||||
|
||||
def candidate(self):
|
||||
kwargs = mem_efficient_fusion_kwargs(use_decomps=True)
|
||||
return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
|
||||
class AotMemEfficientFusionNoDecomps(AotAutogradStrategy):
|
||||
"""Use Min cut rematerilization and TorchScript+nvFuser with AOT Autograd"""
|
||||
|
||||
def candidate(self):
|
||||
kwargs = mem_efficient_fusion_kwargs(use_decomps=False)
|
||||
return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
|
||||
class AotInductorDebug(AotAutogradStrategy):
|
||||
"""
|
||||
Uses TorchInductor Aot Autograd decopms and partitioner to isolate aot vs
|
||||
inductor problems.
|
||||
"""
|
||||
|
||||
def candidate(self):
|
||||
from functorch.compile import min_cut_rematerialization_partition, nop
|
||||
|
||||
decompositions = import_module(
|
||||
f"{config.inductor_import}.compile_fx"
|
||||
).select_decomp_table()
|
||||
|
||||
kwargs = {
|
||||
# these are taken from memory_efficient_fusion()
|
||||
"fw_compiler": nop,
|
||||
"bw_compiler": nop,
|
||||
"decompositions": decompositions,
|
||||
"partition_fn": functools.partial(
|
||||
min_cut_rematerialization_partition, compiler="inductor"
|
||||
),
|
||||
}
|
||||
return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
|
||||
aot_inductor_debug = AotInductorDebug.compile_fn
|
||||
|
||||
|
||||
class AOTMemEfficientFusionWithContext:
|
||||
"""Pass TorchScript+nvFuser context to TorchDynamo"""
|
||||
|
||||
def __init__(self, use_decomps=True):
|
||||
self.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
|
||||
self.use_decomps = use_decomps
|
||||
|
||||
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
|
||||
if self.use_decomps:
|
||||
return AotMemEfficientFusion.compile_fn(gm, example_inputs)
|
||||
else:
|
||||
return AotMemEfficientFusionNoDecomps.compile_fn(gm, example_inputs)
|
||||
|
||||
|
||||
aot_mem_efficient_fusion = AOTMemEfficientFusionWithContext(True)
|
||||
aot_mem_efficient_fusion_no_decomp = AOTMemEfficientFusionWithContext(False)
|
||||
# Pass TorchScript+nvFuser context to TorchDynamo
|
||||
aot_mem_efficient_fusion.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
|
||||
aot_mem_efficient_fusion_no_decomp.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
|
||||
|
||||
|
||||
def prims_executor(gm, inputs, *, executor):
|
||||
@ -332,27 +267,15 @@ def nvprims_fw_bw_partition_fn(joint_module, joint_inputs, *, num_fwd_outputs):
|
||||
|
||||
|
||||
def create_nvprims_backend(*, executor):
|
||||
class NvPrims(AotAutogradStrategy):
|
||||
def __init__(self, gm: torch.fx.GraphModule, example_inputs):
|
||||
super(NvPrims, self).__init__(gm, example_inputs)
|
||||
self.executor = executor
|
||||
|
||||
def candidate(self):
|
||||
from torch._dynamo import disable
|
||||
|
||||
return BACKENDS["aot_autograd"](
|
||||
self.gm,
|
||||
self.example_inputs,
|
||||
fw_compiler=partial(prims_executor, executor=self.executor),
|
||||
bw_compiler=partial(prims_executor, executor=self.executor),
|
||||
partition_fn=disable(nvprims_fw_bw_partition_fn),
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
return NvPrims
|
||||
return aot_autograd(
|
||||
fw_compiler=partial(prims_executor, executor=executor),
|
||||
bw_compiler=partial(prims_executor, executor=executor),
|
||||
partition_fn=nvprims_fw_bw_partition_fn,
|
||||
)
|
||||
|
||||
|
||||
aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser").compile_fn
|
||||
aot_nvprims_aten = create_nvprims_backend(executor="aten").compile_fn
|
||||
aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser")
|
||||
aot_nvprims_aten = create_nvprims_backend(executor="aten")
|
||||
|
||||
|
||||
def cloner(t):
|
||||
@ -476,33 +399,7 @@ def cudagraphs(model, inputs):
|
||||
return model
|
||||
|
||||
|
||||
def raw_aot_autograd_cudagraphs(model, inputs):
|
||||
kwargs = {
|
||||
# these are taken from memory_efficient_fusion()
|
||||
"fw_compiler": cudagraphs,
|
||||
"bw_compiler": cudagraphs,
|
||||
}
|
||||
|
||||
def _wrapped_bw_compiler(*args, **kwargs):
|
||||
# stop TorchDynamo from trying to compile our generated backwards pass
|
||||
return disable(disable(bw_compiler)(*args, **kwargs)) # type: ignore[operator]
|
||||
|
||||
bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
|
||||
kwargs["bw_compiler"] = _wrapped_bw_compiler
|
||||
|
||||
from functorch.compile import aot_module_simplified # type: ignore[import]
|
||||
|
||||
from .. import disable
|
||||
|
||||
return aot_module_simplified(model, **kwargs)
|
||||
|
||||
|
||||
class AotAutogradCudaGraphs(AotAutogradStrategy):
|
||||
def candidate(self):
|
||||
return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)
|
||||
|
||||
|
||||
aot_cudagraphs = AotAutogradCudaGraphs.compile_fn
|
||||
aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)
|
||||
|
||||
|
||||
def create_aot_backends():
|
||||
@ -512,11 +409,6 @@ def create_aot_backends():
|
||||
# aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
|
||||
BACKENDS["aot_eager"] = aot_eager
|
||||
|
||||
# aot_eager uses AOT Autograd backend with print compiler. It prints the
|
||||
# graphs and also saves the graph modules that are sent to AOT Autograd.
|
||||
# This is helpful for debugging.
|
||||
BACKENDS["aot_print"] = aot_print
|
||||
|
||||
# aot_ts uses torchscript backend. We can use this with both nnc and nvfuser
|
||||
# by using the relevant fuser with torch.jit.fuser(...)
|
||||
BACKENDS["aot_ts"] = aot_ts
|
||||
|
@ -27,7 +27,7 @@ from .virtualized import V
|
||||
log = logging.getLogger(__name__)
|
||||
ALIGNMENT = 16
|
||||
|
||||
aot_autograd = dynamo_optimizations.backends.aot_autograd
|
||||
aot_autograd = dynamo_optimizations.training.aot_autograd
|
||||
normalize_ir = dynamo_optimizations.normalize.normalize_ir
|
||||
is_aot_autograd_safe_to_run = dynamo_optimizations.training.is_aot_autograd_safe_to_run
|
||||
count_calls = dynamo_utils.count_calls
|
||||
@ -394,12 +394,17 @@ def compile_fx(
|
||||
# in functorch/_src/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
|
||||
# once torchdynamo is merged into pytorch
|
||||
return aot_autograd(
|
||||
model_,
|
||||
example_inputs_,
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
decompositions=select_decomp_table(),
|
||||
partition_fn=functools.partial(
|
||||
min_cut_rematerialization_partition, compiler="inductor"
|
||||
),
|
||||
)
|
||||
# A "tiny" graph can actually decompose into multiple
|
||||
# operators (if it's a decomposition) and inductor can
|
||||
# do a better job on it in this case
|
||||
#
|
||||
# Also, for some reason, test_comprehensive___rmatmul___cpu
|
||||
# fails without forcing a compile lol.
|
||||
force_compile_tiny_graphs=True,
|
||||
)(model_, example_inputs_)
|
||||
|
@ -1,159 +0,0 @@
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
from torch.nn import Module
|
||||
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.utils._pytree import tree_map
|
||||
import torch._dynamo # type: ignore[import]
|
||||
from torch._dynamo.optimizations.training import AotAutogradStrategy # type: ignore[import]
|
||||
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from typing import Set, Dict, Any
|
||||
|
||||
# TODO: maybe this should live in torch._dynamo instead
|
||||
|
||||
__all__ = ['aot_autograd_cudagraphs']
|
||||
|
||||
def cloner(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
return t.clone()
|
||||
else:
|
||||
return t
|
||||
|
||||
|
||||
class CudaGraphModule(Module):
|
||||
gm: GraphModule
|
||||
mutated_inputs: Set[int]
|
||||
|
||||
def __init__(self, gm, mutated_inputs):
|
||||
super().__init__()
|
||||
self.gm = gm
|
||||
self.mutated_inputs = mutated_inputs
|
||||
|
||||
warmed_up = False
|
||||
|
||||
# these are all None or all filled
|
||||
graph = None
|
||||
static_inputs = None
|
||||
static_outputs = None
|
||||
|
||||
# NB: we override __call__ as we don't need any nn.Module machinery
|
||||
# and to reduce overhead
|
||||
def __call__(self, *args):
|
||||
# TODO: once we've recorded here, we'd like to replace the __call__
|
||||
# implementation with compiled bytecode that copies into static, replays
|
||||
# the cuda graph, then copies out. First condition is the hotpath,
|
||||
# needs optimizing
|
||||
if self.graph is not None:
|
||||
assert len(args) == len(self.static_inputs)
|
||||
for dst, src in zip(self.static_inputs, args):
|
||||
dst.copy_(src)
|
||||
self.graph.replay()
|
||||
for i in self.mutated_inputs:
|
||||
args[i].copy_(self.static_inputs[i])
|
||||
return tree_map(cloner, self.static_outputs)
|
||||
|
||||
elif self.warmed_up:
|
||||
# record
|
||||
self.static_inputs = [x.clone() for x in args]
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph):
|
||||
self.static_outputs = self.gm(*self.static_inputs)
|
||||
# NB: recording doesn't actually run the operations, so
|
||||
# now we immediately replay the graph to serve up the result
|
||||
self.graph.replay()
|
||||
for i in self.mutated_inputs:
|
||||
args[i].copy_(self.static_inputs[i])
|
||||
return tree_map(cloner, self.static_outputs)
|
||||
|
||||
else:
|
||||
# warmup
|
||||
stream = torch.cuda.Stream()
|
||||
stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(stream):
|
||||
r = self.gm(*args)
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
self.warmed_up = True
|
||||
return r
|
||||
|
||||
|
||||
# Interpreter versions of these passes can be found at
|
||||
# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23
|
||||
|
||||
|
||||
def find_input_mutations(g):
|
||||
FK = 'fake_result'
|
||||
inputs = defaultdict(set)
|
||||
input_idx = 0
|
||||
mutated_inputs = set()
|
||||
for n in g.nodes:
|
||||
if n.op == 'placeholder':
|
||||
inputs[StorageWeakRef(n.meta[FK]._typed_storage())].add(input_idx)
|
||||
input_idx += 1
|
||||
elif n.op == 'call_function':
|
||||
if n.target is operator.getitem:
|
||||
continue
|
||||
schema = n.target._schema
|
||||
for i, arg in enumerate(schema.arguments):
|
||||
if i < len(n.args):
|
||||
argument = n.args[i]
|
||||
else:
|
||||
if arg.name not in n.kwargs:
|
||||
continue
|
||||
argument = n.kwargs[arg.name]
|
||||
mut_arg = False
|
||||
if arg.alias_info:
|
||||
if arg.alias_info.is_write:
|
||||
mut_arg = True
|
||||
if mut_arg:
|
||||
# TODO: not correct for args that contain tensors in a struct
|
||||
# like list
|
||||
mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK]._typed_storage())]
|
||||
# TODO: error on unrecognized nodes
|
||||
return mutated_inputs
|
||||
|
||||
|
||||
# Mutates input graph
|
||||
def apply_cuda_graphs(gm):
|
||||
for n in gm.graph.nodes:
|
||||
if n.op == 'call_module':
|
||||
assert not n.kwargs
|
||||
submod = gm.get_submodule(n.target)
|
||||
gm.delete_submodule(n.target)
|
||||
mutated_inputs = find_input_mutations(submod.graph)
|
||||
gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
|
||||
# NB: we didn't actually change the graph, no need for recompile
|
||||
|
||||
|
||||
def cudagraphs(model, inputs):
|
||||
model = partition_cudagraphs(model, inputs)
|
||||
apply_cuda_graphs(model)
|
||||
return model
|
||||
|
||||
|
||||
def raw_aot_autograd_cudagraphs(model, inputs):
|
||||
kwargs: Dict[str, Any] = {
|
||||
# these are taken from memory_efficient_fusion()
|
||||
"fw_compiler": cudagraphs,
|
||||
"bw_compiler": cudagraphs,
|
||||
}
|
||||
|
||||
def _wrapped_bw_compiler(*args, **kwargs):
|
||||
# stop dynamo from trying to compile our generated backwards pass
|
||||
return torch._dynamo.disable(bw_compiler(*args, **kwargs)) # type: ignore[operator]
|
||||
|
||||
bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
|
||||
kwargs["bw_compiler"] = _wrapped_bw_compiler
|
||||
|
||||
from functorch.compile import aot_module_simplified # type: ignore[import]
|
||||
|
||||
return aot_module_simplified(model, **kwargs)
|
||||
|
||||
|
||||
class AOTAutogradCudaGraphs(AotAutogradStrategy):
|
||||
def candidate(self):
|
||||
return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)
|
||||
|
||||
|
||||
aot_autograd_cudagraphs = AOTAutogradCudaGraphs.compile_fn
|
Reference in New Issue
Block a user