mirror of https://github.com/pytorch/pytorch.git
Refactor how AOTAutograd backends are defined (#89736)
There was a lot of strangeness in how AOTAutograd backends were previously defined. This refactor replaces the strangeness with something simple and straightforward. The improvements:

- There is no longer a footgun aot_autograd "backend" which doesn't actually work. No more mistyping `torch._dynamo.optimize("aot_autograd")` when you meant "aot_eager".
- Deleted aot_print because it's annoying and anyway there are no uses of it.
- Instead of having BOTH the backend Subgraph and AotAutogradStrategy, there is now only an aot_autograd function which takes the kwargs to configure AOTAutograd, and then gives you a compiler function that does AOTAutograd given those kwargs. Easy.
- The primary downside is that we are now eagerly populating all of the kwargs, and that can get us into import cycle shenanigans. Some cycles I resolved directly (e.g., we now no longer manually disable the forward function before passing it to aot_autograd; aot_autograd does it for us), but for getting inductor decompositions I had to make it take a lambda so I could lazily populate the decomps later.

New code is 130 lines shorter!

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89736
Approved by: https://github.com/anjali411, https://github.com/albanD
committed by PyTorch MergeBot
parent cf4969d9d6
commit b589e726d9
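The sketch below illustrates the pattern this refactor introduces, using only names that appear in the diff that follows (aot_autograd defined in torch._dynamo.optimizations.training; nop and ts_compile from functorch.compile). It is an illustrative example rather than part of the commit; module paths reflect the source tree at the time of this PR and may have moved since, and the tiny function f is only for demonstration.

# Illustrative sketch (not from the commit): after this refactor, a backend is just
# the aot_autograd factory applied to the kwargs that configure AOTAutograd.
import torch
import torch._dynamo
from functorch.compile import nop, ts_compile
from torch._dynamo.optimizations.training import aot_autograd

# If bw_compiler is omitted, the returned compiler reuses fw_compiler for the backward graph.
aot_eager = aot_autograd(fw_compiler=nop)      # trace fw/bw graphs and run them eagerly
aot_ts = aot_autograd(fw_compiler=ts_compile)  # lower the traced graphs through TorchScript

# aot_autograd(**kwargs) returns an ordinary compiler function (gm, example_inputs) -> callable,
# so it plugs straight into torch._dynamo.optimize as a callable backend.
@torch._dynamo.optimize(aot_eager)
def f(x):
    return torch.sin(x) + x

print(f(torch.randn(4)))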
@@ -15,7 +15,6 @@ import torch.nn as nn
 import torch.utils._pytree as pytree
 import torch.utils.dlpack
 from torch import Tensor
-from torch._dynamo import disable as disable_torchdynamo
 from torch._dynamo.utils import dynamo_timed
 from torch._subclasses import FakeTensorMode, CrossRefFakeMode
 from torch.fx import immutable_collections, Interpreter
@@ -1315,7 +1314,6 @@ def aot_dispatch_deduplicated_autograd(flat_fn, flat_args: List[Tensor], aot_con
         fw_metadata = _fw_metadata

         @staticmethod
-        @disable_torchdynamo
         def forward(ctx, *deduped_flat_tensor_args):

             # There is a pretty complicated calling convention around what the compiled fw returns.
@@ -1361,7 +1359,6 @@ def aot_dispatch_deduplicated_autograd(flat_fn, flat_args: List[Tensor], aot_con
             return tuple(fw_outs[0:num_forward_returns])

         @staticmethod
-        @disable_torchdynamo
         def backward(ctx, *all_flat_args):
             # Calling convention: we expect a grad_out passed to the backward:
             # - for every output of the fw that does *not* alias an input
@@ -1788,7 +1788,6 @@ class TestImports(TestCase):
             "torch.contrib.",  # something weird
             "torch.testing._internal.distributed.",  # just fails
             "torch.ao.pruning._experimental.",  # depends on pytorch_lightning, not user-facing
-            "torch.cuda._dynamo_graphs",  # depends on torchdynamo
         ]
         # See https://github.com/pytorch/pytorch/issues/77801
         if not sys.version_info >= (3, 9):
@@ -37,7 +37,6 @@ else:
 from . import config, convert_frame, skipfiles, utils
 from .exc import ResetRequired
 from .mutation_guard import install_generation_tagging_init
-from .optimizations.distributed import DDPOptimizer
 from .output_graph import CompilerFn
 from .types import DynamoCallback
 from .utils import compile_times
@@ -311,6 +310,8 @@ def catch_errors_wrapper(callback):
             ddp_module = DistributedDataParallel._get_active_ddp_module()
             if ddp_module:
                 with compile_lock:
+                    from .optimizations.distributed import DDPOptimizer
+
                     ddp_optimizer = DDPOptimizer(
                         bucket_bytes_cap=ddp_module.bucket_bytes_cap,
                         backend_compile_fn=callback._torchdynamo_orig_callable,
@@ -517,22 +517,6 @@ def cudagraphs_inner(model, inputs, copy_outputs=True):
     return run


-@create_backend
-def aot_autograd(subgraph, **kwargs):
-    def _wrapped_bw_compiler(*args, **kwargs):
-        # stop TorchDynamo from trying to compile our generated backwards pass
-        return disable(disable(bw_compiler)(*args, **kwargs))
-
-    bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
-    kwargs["bw_compiler"] = _wrapped_bw_compiler
-
-    from functorch.compile import aot_module_simplified
-
-    from .. import disable
-
-    return aot_module_simplified(subgraph.model, **kwargs)
-
-
 def tvm_compile(jit_mod, example_inputs, log_file=None, **kwargs):
     if jit_mod is None:
         return None
@@ -6,6 +6,15 @@ from functools import partial
 from importlib import import_module
 from typing import Set

+from functorch._src.compilers import debug_nop
+
+from functorch.compile import (
+    aot_module_simplified,
+    min_cut_rematerialization_partition,
+    nop,
+    ts_compile,
+)
+
 import torch
 from torch.fx import GraphModule
 from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
@@ -13,7 +22,7 @@ from torch.multiprocessing.reductions import StorageWeakRef
 from torch.nn import Module
 from torch.utils._pytree import tree_map

-from .. import config
+from .. import config, eval_frame
 from ..utils import clone_inputs, count_calls, counters
 from .analysis import has_mutation
 from .backends import BACKENDS
@@ -22,6 +31,62 @@ from .normalize import normalize_ir
 log = logging.getLogger(__name__)


+def aot_autograd(**kwargs):
+    def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
+        import functorch.compile
+
+        # Hack to get around circular import problems with aot_inductor_debug
+        if callable(kwargs.get("decompositions")):
+            kwargs["decompositions"] = kwargs["decompositions"]()
+
+        # TODO: stop monkeypatching here (without even cleaning up, UGH!)
+        functorch.compile.config.use_functionalize = True
+        functorch.compile.config.use_fake_tensor = True
+
+        force_compile_tiny_graphs = kwargs.pop("force_compile_tiny_graphs", False)
+
+        if count_calls(gm.graph) < 2 and not force_compile_tiny_graphs:
+            return gm  # no point for tiny graphs
+
+        counters["aot_autograd"]["total"] += 1
+        use_fallback = False
+
+        if not functorch.compile.config.use_functionalize and config.normalize_ir:
+            try:
+                gm = normalize_ir(gm, clone_inputs(example_inputs))
+            except Exception:
+                log.debug("TorchDynamo unable to remove mutation")
+                use_fallback = True
+
+        # NB: no clone here on example inputs
+        if not is_aot_autograd_safe_to_run(gm, example_inputs):
+            use_fallback = True
+
+        if use_fallback:
+            log.debug("Unable to use AOT Autograd because graph has mutation")
+            counters["aot_autograd"]["not_ok"] += 1
+            return gm
+
+        # OK attempt to compile
+
+        def _wrapped_bw_compiler(*args, **kwargs):
+            # stop TorchDynamo from trying to compile our generated backwards pass
+            return eval_frame.disable(eval_frame.disable(bw_compiler)(*args, **kwargs))
+
+        bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
+        kwargs["bw_compiler"] = _wrapped_bw_compiler
+
+        try:
+            cg = aot_module_simplified(gm, **kwargs)
+            counters["aot_autograd"]["ok"] += 1
+            return eval_frame.disable(cg)
+        except Exception:
+            counters["aot_autograd"]["not_ok"] += 1
+            raise
+
+    return compiler_fn
+
+
 def is_aot_autograd_safe_to_run(gm, example_inputs):
     """
     There are some known issues with Aot Autograd. This is a workaround to catch
@@ -86,107 +151,28 @@ def is_aot_autograd_safe_to_run(gm, example_inputs):
     return True


-class AotAutogradStrategy(object):
-    """Base class for backend strategies that use AOT Autograd"""
-
-    @classmethod
-    def compile_fn(cls, gm: torch.fx.GraphModule, example_inputs):
-        if count_calls(gm.graph) < 2:
-            return gm  # no point for tiny graphs
-        return cls(gm, example_inputs).verified_candidate()
-
-    def __init__(self, gm: torch.fx.GraphModule, example_inputs):
-        import functorch.compile
-
-        functorch.compile.config.use_functionalize = True
-        functorch.compile.config.use_fake_tensor = True
-
-        super(AotAutogradStrategy, self).__init__()
-        counters["aot_autograd"]["total"] += 1
-        self.use_fallback = False
-        self.original_example_inputs = example_inputs
-        self.gm = gm
-
-        if not functorch.compile.config.use_functionalize and config.normalize_ir:
-            try:
-                self.gm = normalize_ir(gm, self.example_inputs)
-            except Exception:
-                log.debug("TorchDynamo unable to remove mutation")
-                self.use_fallback = True
-                pass
-
-        if not is_aot_autograd_safe_to_run(gm, example_inputs):
-            self.use_fallback = True
-
-    @property
-    def example_inputs(self):
-        return clone_inputs(self.original_example_inputs)
-
-    def verified_candidate(self):
-        if self.use_fallback:
-            log.debug("Unable to use AOT Autograd because graph has mutation")
-            counters["aot_autograd"]["not_ok"] += 1
-            return self.gm
-        cg = self.candidate()
-        if cg is None:
-            counters["aot_autograd"]["not_ok"] += 1
-            raise RuntimeError("AOT Autograd failed to compile")
-        counters["aot_autograd"]["ok"] += 1
-        return cg
-
-    def candidate(self):
-        raise NotImplementedError()
-
-
-class AotNop(AotAutogradStrategy):
-    """Useful for debugging purpose"""
-
-    def candidate(self):
-        from functorch._src.compilers import debug_nop
-        from functorch.compile import nop
-
-        DEBUG = False
-        return BACKENDS["aot_autograd"](
-            self.gm, self.example_inputs, fw_compiler=debug_nop if DEBUG else nop
-        )  # type: ignore[call-arg]
-
-
-aot_eager = AotNop.compile_fn
-
-
-class AotTorchscript(AotAutogradStrategy):
-    """
-    AOT Autograd with torchscript backend. Default partitioner.
-    """
-
-    def candidate(self):
-        from functorch.compile import ts_compile
-
-        return BACKENDS["aot_autograd"](
-            self.gm, self.example_inputs, fw_compiler=ts_compile
-        )  # type: ignore[call-arg]
-
-
-aot_ts = AotTorchscript.compile_fn
-
-# Global counter to differentiate between different graphs.
-graph_idx = 0
-
-
-class AotPrint(AotNop):
-    """Saves all the gm models so that we can run them separately"""
-
-    def candidate(self):
-        global graph_idx
-        module_idx = "module_" + str(graph_idx)
-        self.gm.to_folder(module_idx, "Bar")
-        for idx, x in enumerate(self.example_inputs):
-            torch.save(x, module_idx + "_tensor" + str(idx) + ".pt")
-        graph_idx += 1
-        return super(AotPrint, self).candidate()
-
-
-aot_print = AotPrint.compile_fn
+DEBUG = False
+
+# Useful for debugging purpose
+aot_eager = aot_autograd(fw_compiler=debug_nop if DEBUG else nop)
+
+# AOT Autograd with torchscript backend. Default partitioner.
+aot_ts = aot_autograd(fw_compiler=ts_compile)
+
+# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs
+# inductor problems.
+aot_inductor_debug = aot_autograd(
+    # these are taken from memory_efficient_fusion()
+    fw_compiler=nop,
+    bw_compiler=nop,
+    # NB: lambda here is to delay import of inductor
+    decompositions=lambda: import_module(
+        f"{config.inductor_import}.compile_fx"
+    ).select_decomp_table(),
+    partition_fn=functools.partial(
+        min_cut_rematerialization_partition, compiler="inductor"
+    ),
+)


 def mem_efficient_fusion_kwargs(use_decomps):
@@ -209,66 +195,15 @@ def mem_efficient_fusion_kwargs(use_decomps):
     return kwargs


-class AotMemEfficientFusion(AotAutogradStrategy):
-    """Use Min cut rematerilization and TorchScript+nvFuser with AOT Autograd"""
-
-    def candidate(self):
-        kwargs = mem_efficient_fusion_kwargs(use_decomps=True)
-        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)  # type: ignore[call-arg]
-
-
-class AotMemEfficientFusionNoDecomps(AotAutogradStrategy):
-    """Use Min cut rematerilization and TorchScript+nvFuser with AOT Autograd"""
-
-    def candidate(self):
-        kwargs = mem_efficient_fusion_kwargs(use_decomps=False)
-        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)  # type: ignore[call-arg]
-
-
-class AotInductorDebug(AotAutogradStrategy):
-    """
-    Uses TorchInductor Aot Autograd decopms and partitioner to isolate aot vs
-    inductor problems.
-    """
-
-    def candidate(self):
-        from functorch.compile import min_cut_rematerialization_partition, nop
-
-        decompositions = import_module(
-            f"{config.inductor_import}.compile_fx"
-        ).select_decomp_table()
-
-        kwargs = {
-            # these are taken from memory_efficient_fusion()
-            "fw_compiler": nop,
-            "bw_compiler": nop,
-            "decompositions": decompositions,
-            "partition_fn": functools.partial(
-                min_cut_rematerialization_partition, compiler="inductor"
-            ),
-        }
-        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)  # type: ignore[call-arg]
-
-
-aot_inductor_debug = AotInductorDebug.compile_fn
-
-
-class AOTMemEfficientFusionWithContext:
-    """Pass TorchScript+nvFuser context to TorchDynamo"""
-
-    def __init__(self, use_decomps=True):
-        self.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
-        self.use_decomps = use_decomps
-
-    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
-        if self.use_decomps:
-            return AotMemEfficientFusion.compile_fn(gm, example_inputs)
-        else:
-            return AotMemEfficientFusionNoDecomps.compile_fn(gm, example_inputs)
-
-
-aot_mem_efficient_fusion = AOTMemEfficientFusionWithContext(True)
-aot_mem_efficient_fusion_no_decomp = AOTMemEfficientFusionWithContext(False)
+# Use min cut rematerialization and TorchScript+nvFuser with AOT Autograd
+aot_mem_efficient_fusion = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))
+aot_mem_efficient_fusion_no_decomp = aot_autograd(
+    **mem_efficient_fusion_kwargs(use_decomps=False)
+)
+
+# Pass TorchScript+nvFuser context to TorchDynamo
+aot_mem_efficient_fusion.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
+aot_mem_efficient_fusion_no_decomp.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")


 def prims_executor(gm, inputs, *, executor):
@@ -332,27 +267,15 @@ def nvprims_fw_bw_partition_fn(joint_module, joint_inputs, *, num_fwd_outputs):


 def create_nvprims_backend(*, executor):
-    class NvPrims(AotAutogradStrategy):
-        def __init__(self, gm: torch.fx.GraphModule, example_inputs):
-            super(NvPrims, self).__init__(gm, example_inputs)
-            self.executor = executor
-
-        def candidate(self):
-            from torch._dynamo import disable
-
-            return BACKENDS["aot_autograd"](
-                self.gm,
-                self.example_inputs,
-                fw_compiler=partial(prims_executor, executor=self.executor),
-                bw_compiler=partial(prims_executor, executor=self.executor),
-                partition_fn=disable(nvprims_fw_bw_partition_fn),
-            )  # type: ignore[call-arg]
-
-    return NvPrims
-
-
-aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser").compile_fn
-aot_nvprims_aten = create_nvprims_backend(executor="aten").compile_fn
+    return aot_autograd(
+        fw_compiler=partial(prims_executor, executor=executor),
+        bw_compiler=partial(prims_executor, executor=executor),
+        partition_fn=nvprims_fw_bw_partition_fn,
+    )
+
+
+aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser")
+aot_nvprims_aten = create_nvprims_backend(executor="aten")


 def cloner(t):
@@ -476,33 +399,7 @@ def cudagraphs(model, inputs):
     return model


-def raw_aot_autograd_cudagraphs(model, inputs):
-    kwargs = {
-        # these are taken from memory_efficient_fusion()
-        "fw_compiler": cudagraphs,
-        "bw_compiler": cudagraphs,
-    }
-
-    def _wrapped_bw_compiler(*args, **kwargs):
-        # stop TorchDynamo from trying to compile our generated backwards pass
-        return disable(disable(bw_compiler)(*args, **kwargs))  # type: ignore[operator]
-
-    bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
-    kwargs["bw_compiler"] = _wrapped_bw_compiler
-
-    from functorch.compile import aot_module_simplified  # type: ignore[import]
-
-    from .. import disable
-
-    return aot_module_simplified(model, **kwargs)
-
-
-class AotAutogradCudaGraphs(AotAutogradStrategy):
-    def candidate(self):
-        return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)
-
-
-aot_cudagraphs = AotAutogradCudaGraphs.compile_fn
+aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)


 def create_aot_backends():
@@ -512,11 +409,6 @@ def create_aot_backends():
     # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
     BACKENDS["aot_eager"] = aot_eager

-    # aot_eager uses AOT Autograd backend with print compiler. It prints the
-    # graphs and also saves the graph modules that are sent to AOT Autograd.
-    # This is helpful for debugging.
-    BACKENDS["aot_print"] = aot_print
-
     # aot_ts uses torchscript backend. We can use this with both nnc and nvfuser
     # by using the relevant fuser with torch.jit.fuser(...)
     BACKENDS["aot_ts"] = aot_ts
@@ -27,7 +27,7 @@ from .virtualized import V
 log = logging.getLogger(__name__)
 ALIGNMENT = 16

-aot_autograd = dynamo_optimizations.backends.aot_autograd
+aot_autograd = dynamo_optimizations.training.aot_autograd
 normalize_ir = dynamo_optimizations.normalize.normalize_ir
 is_aot_autograd_safe_to_run = dynamo_optimizations.training.is_aot_autograd_safe_to_run
 count_calls = dynamo_utils.count_calls
@@ -394,12 +394,17 @@ def compile_fx(
     # in functorch/_src/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
     # once torchdynamo is merged into pytorch
     return aot_autograd(
-        model_,
-        example_inputs_,
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         decompositions=select_decomp_table(),
         partition_fn=functools.partial(
             min_cut_rematerialization_partition, compiler="inductor"
         ),
-    )
+        # A "tiny" graph can actually decompose into multiple
+        # operators (if it's a decomposition) and inductor can
+        # do a better job on it in this case
+        #
+        # Also, for some reason, test_comprehensive___rmatmul___cpu
+        # fails without forcing a compile lol.
+        force_compile_tiny_graphs=True,
+    )(model_, example_inputs_)
@@ -1,159 +0,0 @@
-import torch
-from torch.fx import GraphModule
-from torch.nn import Module
-from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
-from torch.multiprocessing.reductions import StorageWeakRef
-from torch.utils._pytree import tree_map
-import torch._dynamo  # type: ignore[import]
-from torch._dynamo.optimizations.training import AotAutogradStrategy  # type: ignore[import]
-
-import operator
-from collections import defaultdict
-from typing import Set, Dict, Any
-
-# TODO: maybe this should live in torch._dynamo instead
-
-__all__ = ['aot_autograd_cudagraphs']
-
-def cloner(t):
-    if isinstance(t, torch.Tensor):
-        return t.clone()
-    else:
-        return t
-
-
-class CudaGraphModule(Module):
-    gm: GraphModule
-    mutated_inputs: Set[int]
-
-    def __init__(self, gm, mutated_inputs):
-        super().__init__()
-        self.gm = gm
-        self.mutated_inputs = mutated_inputs
-
-    warmed_up = False
-
-    # these are all None or all filled
-    graph = None
-    static_inputs = None
-    static_outputs = None
-
-    # NB: we override __call__ as we don't need any nn.Module machinery
-    # and to reduce overhead
-    def __call__(self, *args):
-        # TODO: once we've recorded here, we'd like to replace the __call__
-        # implementation with compiled bytecode that copies into static, replays
-        # the cuda graph, then copies out. First condition is the hotpath,
-        # needs optimizing
-        if self.graph is not None:
-            assert len(args) == len(self.static_inputs)
-            for dst, src in zip(self.static_inputs, args):
-                dst.copy_(src)
-            self.graph.replay()
-            for i in self.mutated_inputs:
-                args[i].copy_(self.static_inputs[i])
-            return tree_map(cloner, self.static_outputs)
-
-        elif self.warmed_up:
-            # record
-            self.static_inputs = [x.clone() for x in args]
-            self.graph = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(self.graph):
-                self.static_outputs = self.gm(*self.static_inputs)
-            # NB: recording doesn't actually run the operations, so
-            # now we immediately replay the graph to serve up the result
-            self.graph.replay()
-            for i in self.mutated_inputs:
-                args[i].copy_(self.static_inputs[i])
-            return tree_map(cloner, self.static_outputs)
-
-        else:
-            # warmup
-            stream = torch.cuda.Stream()
-            stream.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(stream):
-                r = self.gm(*args)
-            torch.cuda.current_stream().wait_stream(stream)
-            self.warmed_up = True
-            return r
-
-
-# Interpreter versions of these passes can be found at
-# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23
-
-
-def find_input_mutations(g):
-    FK = 'fake_result'
-    inputs = defaultdict(set)
-    input_idx = 0
-    mutated_inputs = set()
-    for n in g.nodes:
-        if n.op == 'placeholder':
-            inputs[StorageWeakRef(n.meta[FK]._typed_storage())].add(input_idx)
-            input_idx += 1
-        elif n.op == 'call_function':
-            if n.target is operator.getitem:
-                continue
-            schema = n.target._schema
-            for i, arg in enumerate(schema.arguments):
-                if i < len(n.args):
-                    argument = n.args[i]
-                else:
-                    if arg.name not in n.kwargs:
-                        continue
-                    argument = n.kwargs[arg.name]
-                mut_arg = False
-                if arg.alias_info:
-                    if arg.alias_info.is_write:
-                        mut_arg = True
-                if mut_arg:
-                    # TODO: not correct for args that contain tensors in a struct
-                    # like list
-                    mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK]._typed_storage())]
-        # TODO: error on unrecognized nodes
-    return mutated_inputs
-
-
-# Mutates input graph
-def apply_cuda_graphs(gm):
-    for n in gm.graph.nodes:
-        if n.op == 'call_module':
-            assert not n.kwargs
-            submod = gm.get_submodule(n.target)
-            gm.delete_submodule(n.target)
-            mutated_inputs = find_input_mutations(submod.graph)
-            gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
-    # NB: we didn't actually change the graph, no need for recompile
-
-
-def cudagraphs(model, inputs):
-    model = partition_cudagraphs(model, inputs)
-    apply_cuda_graphs(model)
-    return model
-
-
-def raw_aot_autograd_cudagraphs(model, inputs):
-    kwargs: Dict[str, Any] = {
-        # these are taken from memory_efficient_fusion()
-        "fw_compiler": cudagraphs,
-        "bw_compiler": cudagraphs,
-    }
-
-    def _wrapped_bw_compiler(*args, **kwargs):
-        # stop dynamo from trying to compile our generated backwards pass
-        return torch._dynamo.disable(bw_compiler(*args, **kwargs))  # type: ignore[operator]
-
-    bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
-    kwargs["bw_compiler"] = _wrapped_bw_compiler
-
-    from functorch.compile import aot_module_simplified  # type: ignore[import]
-
-    return aot_module_simplified(model, **kwargs)
-
-
-class AOTAutogradCudaGraphs(AotAutogradStrategy):
-    def candidate(self):
-        return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)
-
-
-aot_autograd_cudagraphs = AOTAutogradCudaGraphs.compile_fn