Refactor how AOTAutograd backends are defined (#89736)

There was a lot of strangeness in how AOTAutograd backends were previously defined. This refactor replaces the strangeness with something simple and straightforward. The improvements:

- There is no longer a footgun aot_autograd "backend" which doesn't actually work. No more mistyping `torch._dynamo.optimize("aot_autograd")` when you meant "aot_eager".
- Deleted aot_print, because it's annoying and there are no uses of it anyway.
- Instead of having BOTH the backend Subgraph and AotAutogradStrategy, there is now only an aot_autograd function which takes the kwargs that configure AOTAutograd and gives you back a compiler function that runs AOTAutograd with those kwargs (see the sketch below). Easy.
- The primary downside is that we now eagerly populate all of the kwargs, which can get us into import cycle shenanigans. Some cycles I resolved directly (e.g., we no longer manually disable the forward function before passing it to aot_autograd; aot_autograd does that for us), but for the inductor decompositions I had to make the kwarg accept a lambda so the decomps can be populated lazily.

New code is 130 lines shorter!
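
To make the new shape concrete, here is a minimal sketch of the pattern, distilled from the new aot_autograd in the training diff below (the real compiler_fn also does the tiny-graph fallback, the mutation-safety check, and counter bookkeeping, all omitted here, so treat this as an illustration rather than the exact merged code):

    import torch
    from functorch.compile import aot_module_simplified, nop

    def aot_autograd(**kwargs):
        # Build a TorchDynamo backend from a set of AOTAutograd kwargs.
        def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
            # Decompositions may be passed as a thunk, so that merely importing
            # the module that defines a backend doesn't drag in inductor.
            if callable(kwargs.get("decompositions")):
                kwargs["decompositions"] = kwargs["decompositions"]()
            return aot_module_simplified(gm, **kwargs)
        return compiler_fn

    # Defining a backend is now just a call with the right kwargs...
    aot_eager = aot_autograd(fw_compiler=nop)
    # ...and the result is an ordinary compiler function for TorchDynamo:
    # opt_mod = torch._dynamo.optimize(aot_eager)(mod)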

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89736
Approved by: https://github.com/anjali411, https://github.com/albanD
Edward Z. Yang, 2022-11-28 14:57:42 +00:00, committed by PyTorch MergeBot
parent cf4969d9d6
commit b589e726d9
7 changed files with 112 additions and 393 deletions


@@ -15,7 +15,6 @@ import torch.nn as nn
 import torch.utils._pytree as pytree
 import torch.utils.dlpack
 from torch import Tensor
-from torch._dynamo import disable as disable_torchdynamo
 from torch._dynamo.utils import dynamo_timed
 from torch._subclasses import FakeTensorMode, CrossRefFakeMode
 from torch.fx import immutable_collections, Interpreter
@@ -1315,7 +1314,6 @@ def aot_dispatch_deduplicated_autograd(flat_fn, flat_args: List[Tensor], aot_con
         fw_metadata = _fw_metadata

         @staticmethod
-        @disable_torchdynamo
         def forward(ctx, *deduped_flat_tensor_args):
             # There is a pretty complicated calling convention around what the compiled fw returns.
@@ -1361,7 +1359,6 @@ def aot_dispatch_deduplicated_autograd(flat_fn, flat_args: List[Tensor], aot_con
             return tuple(fw_outs[0:num_forward_returns])

         @staticmethod
-        @disable_torchdynamo
         def backward(ctx, *all_flat_args):
             # Calling convention: we expect a grad_out passed to the backward:
             # - for every output of the fw that does *not* alias an input


@@ -1788,7 +1788,6 @@ class TestImports(TestCase):
             "torch.contrib.",  # something weird
             "torch.testing._internal.distributed.",  # just fails
             "torch.ao.pruning._experimental.",  # depends on pytorch_lightning, not user-facing
-            "torch.cuda._dynamo_graphs",  # depends on torchdynamo
         ]
         # See https://github.com/pytorch/pytorch/issues/77801
         if not sys.version_info >= (3, 9):


@@ -37,7 +37,6 @@ else:
 from . import config, convert_frame, skipfiles, utils
 from .exc import ResetRequired
 from .mutation_guard import install_generation_tagging_init
-from .optimizations.distributed import DDPOptimizer
 from .output_graph import CompilerFn
 from .types import DynamoCallback
 from .utils import compile_times
@@ -311,6 +310,8 @@ def catch_errors_wrapper(callback):
             ddp_module = DistributedDataParallel._get_active_ddp_module()
             if ddp_module:
                 with compile_lock:
+                    from .optimizations.distributed import DDPOptimizer
+
                     ddp_optimizer = DDPOptimizer(
                         bucket_bytes_cap=ddp_module.bucket_bytes_cap,
                         backend_compile_fn=callback._torchdynamo_orig_callable,


@@ -517,22 +517,6 @@ def cudagraphs_inner(model, inputs, copy_outputs=True):
     return run

-@create_backend
-def aot_autograd(subgraph, **kwargs):
-    def _wrapped_bw_compiler(*args, **kwargs):
-        # stop TorchDynamo from trying to compile our generated backwards pass
-        return disable(disable(bw_compiler)(*args, **kwargs))
-
-    bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
-    kwargs["bw_compiler"] = _wrapped_bw_compiler
-
-    from functorch.compile import aot_module_simplified
-
-    from .. import disable
-
-    return aot_module_simplified(subgraph.model, **kwargs)
-
 def tvm_compile(jit_mod, example_inputs, log_file=None, **kwargs):
     if jit_mod is None:
         return None


@@ -6,6 +6,15 @@ from functools import partial
 from importlib import import_module
 from typing import Set

+from functorch._src.compilers import debug_nop
+from functorch.compile import (
+    aot_module_simplified,
+    min_cut_rematerialization_partition,
+    nop,
+    ts_compile,
+)
+
 import torch
 from torch.fx import GraphModule
 from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
@@ -13,7 +22,7 @@ from torch.multiprocessing.reductions import StorageWeakRef
 from torch.nn import Module
 from torch.utils._pytree import tree_map

-from .. import config
+from .. import config, eval_frame
 from ..utils import clone_inputs, count_calls, counters
 from .analysis import has_mutation
 from .backends import BACKENDS
@@ -22,6 +31,62 @@ from .normalize import normalize_ir
 log = logging.getLogger(__name__)


+def aot_autograd(**kwargs):
+    def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
+        import functorch.compile
+
+        # Hack to get around circular import problems with aot_inductor_debug
+        if callable(kwargs.get("decompositions")):
+            kwargs["decompositions"] = kwargs["decompositions"]()
+
+        # TODO: stop monkeypatching here (without even cleaning up, UGH!)
+        functorch.compile.config.use_functionalize = True
+        functorch.compile.config.use_fake_tensor = True
+
+        force_compile_tiny_graphs = kwargs.pop("force_compile_tiny_graphs", False)
+        if count_calls(gm.graph) < 2 and not force_compile_tiny_graphs:
+            return gm  # no point for tiny graphs
+
+        counters["aot_autograd"]["total"] += 1
+        use_fallback = False
+
+        if not functorch.compile.config.use_functionalize and config.normalize_ir:
+            try:
+                gm = normalize_ir(gm, clone_inputs(example_inputs))
+            except Exception:
+                log.debug("TorchDynamo unable to remove mutation")
+                use_fallback = True
+
+        # NB: no clone here on example inputs
+        if not is_aot_autograd_safe_to_run(gm, example_inputs):
+            use_fallback = True
+
+        if use_fallback:
+            log.debug("Unable to use AOT Autograd because graph has mutation")
+            counters["aot_autograd"]["not_ok"] += 1
+            return gm
+
+        # OK attempt to compile
+
+        def _wrapped_bw_compiler(*args, **kwargs):
+            # stop TorchDynamo from trying to compile our generated backwards pass
+            return eval_frame.disable(eval_frame.disable(bw_compiler)(*args, **kwargs))
+
+        bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
+        kwargs["bw_compiler"] = _wrapped_bw_compiler
+
+        try:
+            cg = aot_module_simplified(gm, **kwargs)
+            counters["aot_autograd"]["ok"] += 1
+            return eval_frame.disable(cg)
+        except Exception:
+            counters["aot_autograd"]["not_ok"] += 1
+            raise
+
+    return compiler_fn
+
+
 def is_aot_autograd_safe_to_run(gm, example_inputs):
     """
     There are some known issues with Aot Autograd. This is a workaround to catch
@@ -86,107 +151,28 @@ def is_aot_autograd_safe_to_run(gm, example_inputs):
     return True


-class AotAutogradStrategy(object):
-    """Base class for backend strategies that use AOT Autograd"""
-
-    @classmethod
-    def compile_fn(cls, gm: torch.fx.GraphModule, example_inputs):
-        if count_calls(gm.graph) < 2:
-            return gm  # no point for tiny graphs
-        return cls(gm, example_inputs).verified_candidate()
-
-    def __init__(self, gm: torch.fx.GraphModule, example_inputs):
-        import functorch.compile
-
-        functorch.compile.config.use_functionalize = True
-        functorch.compile.config.use_fake_tensor = True
-
-        super(AotAutogradStrategy, self).__init__()
-        counters["aot_autograd"]["total"] += 1
-        self.use_fallback = False
-        self.original_example_inputs = example_inputs
-        self.gm = gm
-
-        if not functorch.compile.config.use_functionalize and config.normalize_ir:
-            try:
-                self.gm = normalize_ir(gm, self.example_inputs)
-            except Exception:
-                log.debug("TorchDynamo unable to remove mutation")
-                self.use_fallback = True
-                pass
-
-        if not is_aot_autograd_safe_to_run(gm, example_inputs):
-            self.use_fallback = True
-
-    @property
-    def example_inputs(self):
-        return clone_inputs(self.original_example_inputs)
-
-    def verified_candidate(self):
-        if self.use_fallback:
-            log.debug("Unable to use AOT Autograd because graph has mutation")
-            counters["aot_autograd"]["not_ok"] += 1
-            return self.gm
-        cg = self.candidate()
-        if cg is None:
-            counters["aot_autograd"]["not_ok"] += 1
-            raise RuntimeError("AOT Autograd failed to compile")
-        counters["aot_autograd"]["ok"] += 1
-        return cg
-
-    def candidate(self):
-        raise NotImplementedError()
-
-
-class AotNop(AotAutogradStrategy):
-    """Useful for debugging purpose"""
-
-    def candidate(self):
-        from functorch._src.compilers import debug_nop
-        from functorch.compile import nop
-
-        DEBUG = False
-        return BACKENDS["aot_autograd"](
-            self.gm, self.example_inputs, fw_compiler=debug_nop if DEBUG else nop
-        )  # type: ignore[call-arg]
-
-
-aot_eager = AotNop.compile_fn
-
-
-class AotTorchscript(AotAutogradStrategy):
-    """
-    AOT Autograd with torchscript backend. Default partitioner.
-    """
-
-    def candidate(self):
-        from functorch.compile import ts_compile
-
-        return BACKENDS["aot_autograd"](
-            self.gm, self.example_inputs, fw_compiler=ts_compile
-        )  # type: ignore[call-arg]
-
-
-aot_ts = AotTorchscript.compile_fn
-
-# Global counter to differentiate between different graphs.
-graph_idx = 0
-
-
-class AotPrint(AotNop):
-    """Saves all the gm models so that we can run them separately"""
-
-    def candidate(self):
-        global graph_idx
-        module_idx = "module_" + str(graph_idx)
-        self.gm.to_folder(module_idx, "Bar")
-        for idx, x in enumerate(self.example_inputs):
-            torch.save(x, module_idx + "_tensor" + str(idx) + ".pt")
-        graph_idx += 1
-        return super(AotPrint, self).candidate()
-
-
-aot_print = AotPrint.compile_fn
+DEBUG = False
+
+# Useful for debugging purpose
+aot_eager = aot_autograd(fw_compiler=debug_nop if DEBUG else nop)
+
+# AOT Autograd with torchscript backend. Default partitioner.
+aot_ts = aot_autograd(fw_compiler=ts_compile)
+
+# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs
+# inductor problems.
+aot_inductor_debug = aot_autograd(
+    # these are taken from memory_efficient_fusion()
+    fw_compiler=nop,
+    bw_compiler=nop,
+    # NB: lambda here is to delay import of inductor
+    decompositions=lambda: import_module(
+        f"{config.inductor_import}.compile_fx"
+    ).select_decomp_table(),
+    partition_fn=functools.partial(
+        min_cut_rematerialization_partition, compiler="inductor"
+    ),
+)


 def mem_efficient_fusion_kwargs(use_decomps):
@@ -209,66 +195,15 @@ def mem_efficient_fusion_kwargs(use_decomps):
     return kwargs


-class AotMemEfficientFusion(AotAutogradStrategy):
-    """Use Min cut rematerilization and TorchScript+nvFuser with AOT Autograd"""
-
-    def candidate(self):
-        kwargs = mem_efficient_fusion_kwargs(use_decomps=True)
-        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)  # type: ignore[call-arg]
-
-
-class AotMemEfficientFusionNoDecomps(AotAutogradStrategy):
-    """Use Min cut rematerilization and TorchScript+nvFuser with AOT Autograd"""
-
-    def candidate(self):
-        kwargs = mem_efficient_fusion_kwargs(use_decomps=False)
-        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)  # type: ignore[call-arg]
-
-
-class AotInductorDebug(AotAutogradStrategy):
-    """
-    Uses TorchInductor Aot Autograd decopms and partitioner to isolate aot vs
-    inductor problems.
-    """
-
-    def candidate(self):
-        from functorch.compile import min_cut_rematerialization_partition, nop
-
-        decompositions = import_module(
-            f"{config.inductor_import}.compile_fx"
-        ).select_decomp_table()
-
-        kwargs = {
-            # these are taken from memory_efficient_fusion()
-            "fw_compiler": nop,
-            "bw_compiler": nop,
-            "decompositions": decompositions,
-            "partition_fn": functools.partial(
-                min_cut_rematerialization_partition, compiler="inductor"
-            ),
-        }
-        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)  # type: ignore[call-arg]
-
-
-aot_inductor_debug = AotInductorDebug.compile_fn
-
-
-class AOTMemEfficientFusionWithContext:
-    """Pass TorchScript+nvFuser context to TorchDynamo"""
-
-    def __init__(self, use_decomps=True):
-        self.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
-        self.use_decomps = use_decomps
-
-    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
-        if self.use_decomps:
-            return AotMemEfficientFusion.compile_fn(gm, example_inputs)
-        else:
-            return AotMemEfficientFusionNoDecomps.compile_fn(gm, example_inputs)
-
-
-aot_mem_efficient_fusion = AOTMemEfficientFusionWithContext(True)
-aot_mem_efficient_fusion_no_decomp = AOTMemEfficientFusionWithContext(False)
+# Use min cut rematerialization and TorchScript+nvFuser with AOT Autograd
+aot_mem_efficient_fusion = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))
+aot_mem_efficient_fusion_no_decomp = aot_autograd(
+    **mem_efficient_fusion_kwargs(use_decomps=False)
+)
+
+# Pass TorchScript+nvFuser context to TorchDynamo
+aot_mem_efficient_fusion.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
+aot_mem_efficient_fusion_no_decomp.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")


 def prims_executor(gm, inputs, *, executor):
@@ -332,27 +267,15 @@ def nvprims_fw_bw_partition_fn(joint_module, joint_inputs, *, num_fwd_outputs):
 def create_nvprims_backend(*, executor):
-    class NvPrims(AotAutogradStrategy):
-        def __init__(self, gm: torch.fx.GraphModule, example_inputs):
-            super(NvPrims, self).__init__(gm, example_inputs)
-            self.executor = executor
-
-        def candidate(self):
-            from torch._dynamo import disable
-
-            return BACKENDS["aot_autograd"](
-                self.gm,
-                self.example_inputs,
-                fw_compiler=partial(prims_executor, executor=self.executor),
-                bw_compiler=partial(prims_executor, executor=self.executor),
-                partition_fn=disable(nvprims_fw_bw_partition_fn),
-            )  # type: ignore[call-arg]
-
-    return NvPrims
+    return aot_autograd(
+        fw_compiler=partial(prims_executor, executor=executor),
+        bw_compiler=partial(prims_executor, executor=executor),
+        partition_fn=nvprims_fw_bw_partition_fn,
+    )


-aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser").compile_fn
-aot_nvprims_aten = create_nvprims_backend(executor="aten").compile_fn
+aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser")
+aot_nvprims_aten = create_nvprims_backend(executor="aten")


 def cloner(t):
@@ -476,33 +399,7 @@ def cudagraphs(model, inputs):
     return model


-def raw_aot_autograd_cudagraphs(model, inputs):
-    kwargs = {
-        # these are taken from memory_efficient_fusion()
-        "fw_compiler": cudagraphs,
-        "bw_compiler": cudagraphs,
-    }
-
-    def _wrapped_bw_compiler(*args, **kwargs):
-        # stop TorchDynamo from trying to compile our generated backwards pass
-        return disable(disable(bw_compiler)(*args, **kwargs))  # type: ignore[operator]
-
-    bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
-    kwargs["bw_compiler"] = _wrapped_bw_compiler
-
-    from functorch.compile import aot_module_simplified  # type: ignore[import]
-
-    from .. import disable
-
-    return aot_module_simplified(model, **kwargs)
-
-
-class AotAutogradCudaGraphs(AotAutogradStrategy):
-    def candidate(self):
-        return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)
-
-
-aot_cudagraphs = AotAutogradCudaGraphs.compile_fn
+aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)


 def create_aot_backends():
@@ -512,11 +409,6 @@ def create_aot_backends():
     # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
     BACKENDS["aot_eager"] = aot_eager

-    # aot_eager uses AOT Autograd backend with print compiler. It prints the
-    # graphs and also saves the graph modules that are sent to AOT Autograd.
-    # This is helpful for debugging.
-    BACKENDS["aot_print"] = aot_print
-
     # aot_ts uses torchscript backend. We can use this with both nnc and nvfuser
     # by using the relevant fuser with torch.jit.fuser(...)
     BACKENDS["aot_ts"] = aot_ts


@@ -27,7 +27,7 @@ from .virtualized import V
 log = logging.getLogger(__name__)
 ALIGNMENT = 16

-aot_autograd = dynamo_optimizations.backends.aot_autograd
+aot_autograd = dynamo_optimizations.training.aot_autograd
 normalize_ir = dynamo_optimizations.normalize.normalize_ir
 is_aot_autograd_safe_to_run = dynamo_optimizations.training.is_aot_autograd_safe_to_run
 count_calls = dynamo_utils.count_calls
@@ -394,12 +394,17 @@ def compile_fx(
     # in functorch/_src/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
     # once torchdynamo is merged into pytorch
     return aot_autograd(
-        model_,
-        example_inputs_,
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         decompositions=select_decomp_table(),
         partition_fn=functools.partial(
             min_cut_rematerialization_partition, compiler="inductor"
         ),
-    )
+        # A "tiny" graph can actually decompose into multiple
+        # operators (if it's a decomposition) and inductor can
+        # do a better job on it in this case
+        #
+        # Also, for some reason, test_comprehensive___rmatmul___cpu
+        # fails without forcing a compile lol.
+        force_compile_tiny_graphs=True,
+    )(model_, example_inputs_)


@@ -1,159 +0,0 @@
-import torch
-from torch.fx import GraphModule
-from torch.nn import Module
-from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
-from torch.multiprocessing.reductions import StorageWeakRef
-from torch.utils._pytree import tree_map
-import torch._dynamo  # type: ignore[import]
-from torch._dynamo.optimizations.training import AotAutogradStrategy  # type: ignore[import]
-
-import operator
-from collections import defaultdict
-from typing import Set, Dict, Any
-
-# TODO: maybe this should live in torch._dynamo instead
-__all__ = ['aot_autograd_cudagraphs']
-
-
-def cloner(t):
-    if isinstance(t, torch.Tensor):
-        return t.clone()
-    else:
-        return t
-
-
-class CudaGraphModule(Module):
-    gm: GraphModule
-    mutated_inputs: Set[int]
-
-    def __init__(self, gm, mutated_inputs):
-        super().__init__()
-        self.gm = gm
-        self.mutated_inputs = mutated_inputs
-
-    warmed_up = False
-
-    # these are all None or all filled
-    graph = None
-    static_inputs = None
-    static_outputs = None
-
-    # NB: we override __call__ as we don't need any nn.Module machinery
-    # and to reduce overhead
-    def __call__(self, *args):
-        # TODO: once we've recorded here, we'd like to replace the __call__
-        # implementation with compiled bytecode that copies into static, replays
-        # the cuda graph, then copies out. First condition is the hotpath,
-        # needs optimizing
-        if self.graph is not None:
-            assert len(args) == len(self.static_inputs)
-            for dst, src in zip(self.static_inputs, args):
-                dst.copy_(src)
-            self.graph.replay()
-            for i in self.mutated_inputs:
-                args[i].copy_(self.static_inputs[i])
-            return tree_map(cloner, self.static_outputs)
-
-        elif self.warmed_up:
-            # record
-            self.static_inputs = [x.clone() for x in args]
-            self.graph = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(self.graph):
-                self.static_outputs = self.gm(*self.static_inputs)
-            # NB: recording doesn't actually run the operations, so
-            # now we immediately replay the graph to serve up the result
-            self.graph.replay()
-            for i in self.mutated_inputs:
-                args[i].copy_(self.static_inputs[i])
-            return tree_map(cloner, self.static_outputs)
-
-        else:
-            # warmup
-            stream = torch.cuda.Stream()
-            stream.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(stream):
-                r = self.gm(*args)
-            torch.cuda.current_stream().wait_stream(stream)
-            self.warmed_up = True
-            return r
-
-
-# Interpreter versions of these passes can be found at
-# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23
-
-
-def find_input_mutations(g):
-    FK = 'fake_result'
-    inputs = defaultdict(set)
-    input_idx = 0
-    mutated_inputs = set()
-    for n in g.nodes:
-        if n.op == 'placeholder':
-            inputs[StorageWeakRef(n.meta[FK]._typed_storage())].add(input_idx)
-            input_idx += 1
-        elif n.op == 'call_function':
-            if n.target is operator.getitem:
-                continue
-            schema = n.target._schema
-            for i, arg in enumerate(schema.arguments):
-                if i < len(n.args):
-                    argument = n.args[i]
-                else:
-                    if arg.name not in n.kwargs:
-                        continue
-                    argument = n.kwargs[arg.name]
-                mut_arg = False
-                if arg.alias_info:
-                    if arg.alias_info.is_write:
-                        mut_arg = True
-                if mut_arg:
-                    # TODO: not correct for args that contain tensors in a struct
-                    # like list
-                    mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK]._typed_storage())]
-        # TODO: error on unrecognized nodes
-    return mutated_inputs
-
-
-# Mutates input graph
-def apply_cuda_graphs(gm):
-    for n in gm.graph.nodes:
-        if n.op == 'call_module':
-            assert not n.kwargs
-            submod = gm.get_submodule(n.target)
-            gm.delete_submodule(n.target)
-            mutated_inputs = find_input_mutations(submod.graph)
-            gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
-    # NB: we didn't actually change the graph, no need for recompile
-
-
-def cudagraphs(model, inputs):
-    model = partition_cudagraphs(model, inputs)
-    apply_cuda_graphs(model)
-    return model
-
-
-def raw_aot_autograd_cudagraphs(model, inputs):
-    kwargs: Dict[str, Any] = {
-        # these are taken from memory_efficient_fusion()
-        "fw_compiler": cudagraphs,
-        "bw_compiler": cudagraphs,
-    }
-
-    def _wrapped_bw_compiler(*args, **kwargs):
-        # stop dynamo from trying to compile our generated backwards pass
-        return torch._dynamo.disable(bw_compiler(*args, **kwargs))  # type: ignore[operator]
-
-    bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
-    kwargs["bw_compiler"] = _wrapped_bw_compiler
-
-    from functorch.compile import aot_module_simplified  # type: ignore[import]
-
-    return aot_module_simplified(model, **kwargs)
-
-
-class AOTAutogradCudaGraphs(AotAutogradStrategy):
-    def candidate(self):
-        return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)
-
-
-aot_autograd_cudagraphs = AOTAutogradCudaGraphs.compile_fn