This makes good on Chillee's CR comment at
af30d351cc (r843315222),
which was never addressed in the original PR.
There is no logic change; we just unpack the args/kwargs at the top
level and remove the inner function indirection.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89666
Approved by: https://github.com/voznesenskym
160 lines · 5.4 KiB · Python
import torch

from torch.fx import GraphModule
from torch.nn import Module
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._pytree import tree_map
import torch._dynamo  # type: ignore[import]
from torch._dynamo.optimizations.training import AotAutogradStrategy  # type: ignore[import]

import operator
from collections import defaultdict
from typing import Set, Dict, Any

# TODO: maybe this should live in torch._dynamo instead

__all__ = ['aot_autograd_cudagraphs']


def cloner(t):
    if isinstance(t, torch.Tensor):
        return t.clone()
    else:
        return t


class CudaGraphModule(Module):
    gm: GraphModule
    mutated_inputs: Set[int]

    def __init__(self, gm, mutated_inputs):
        super().__init__()
        self.gm = gm
        self.mutated_inputs = mutated_inputs

    warmed_up = False

    # these are all None or all filled
    graph = None
    static_inputs = None
    static_outputs = None

    # NB: we override __call__ as we don't need any nn.Module machinery
    # and to reduce overhead
    def __call__(self, *args):
        # TODO: once we've recorded here, we'd like to replace the __call__
        # implementation with compiled bytecode that copies into static, replays
        # the cuda graph, then copies out. First condition is the hotpath,
        # needs optimizing
        if self.graph is not None:
            assert len(args) == len(self.static_inputs)
            for dst, src in zip(self.static_inputs, args):
                dst.copy_(src)
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        elif self.warmed_up:
            # record
            self.static_inputs = [x.clone() for x in args]
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_outputs = self.gm(*self.static_inputs)
            # NB: recording doesn't actually run the operations, so
            # now we immediately replay the graph to serve up the result
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        else:
            # warmup
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                r = self.gm(*args)
            torch.cuda.current_stream().wait_stream(stream)
            self.warmed_up = True
            return r


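# Rough usage sketch (assumes a CUDA-only GraphModule `gm` and fixed-shape CUDA
# tensor args `xs`; names are placeholders, not part of this module). The
# module goes through three phases across calls:
#
#   cg = CudaGraphModule(gm, mutated_inputs=set())
#   out = cg(*xs)   # 1st call: eager warmup run on a side stream
#   out = cg(*xs)   # 2nd call: records a torch.cuda.CUDAGraph, then replays it
#   out = cg(*xs)   # later calls: copy args into static inputs, replay, clone outputs
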
# Interpreter versions of these passes can be found at
# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23


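# Given an FX graph whose nodes carry FakeTensor results in
# n.meta['fake_result'], map each placeholder's storage to its input index,
# then return the set of input indices whose storage is written to by any
# call_function argument whose schema marks it as is_write.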
def find_input_mutations(g):
    FK = 'fake_result'
    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == 'placeholder':
            inputs[StorageWeakRef(n.meta[FK]._typed_storage())].add(input_idx)
            input_idx += 1
        elif n.op == 'call_function':
            if n.target is operator.getitem:
                continue
            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK]._typed_storage())]
        # TODO: error on unrecognized nodes
    return mutated_inputs


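# Replace every call_module node produced by partition_cudagraphs with a
# CudaGraphModule that wraps the corresponding submodule together with the
# set of input indices that submodule mutates.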
# Mutates input graph
def apply_cuda_graphs(gm):
    for n in gm.graph.nodes:
        if n.op == 'call_module':
            assert not n.kwargs
            submod = gm.get_submodule(n.target)
            gm.delete_submodule(n.target)
            mutated_inputs = find_input_mutations(submod.graph)
            gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
    # NB: we didn't actually change the graph, no need for recompile


def cudagraphs(model, inputs):
    model = partition_cudagraphs(model, inputs)
    apply_cuda_graphs(model)
    return model


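# Build an AOT Autograd configuration that runs the cudagraphs pass as both
# the forward and backward compiler; the backward compiler is wrapped in
# torch._dynamo.disable so dynamo does not try to trace the generated
# backward pass.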
def raw_aot_autograd_cudagraphs(model, inputs):
    kwargs: Dict[str, Any] = {
        # these are taken from memory_efficient_fusion()
        "fw_compiler": cudagraphs,
        "bw_compiler": cudagraphs,
    }

    def _wrapped_bw_compiler(*args, **kwargs):
        # stop dynamo from trying to compile our generated backwards pass
        return torch._dynamo.disable(bw_compiler(*args, **kwargs))  # type: ignore[operator]

    bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
    kwargs["bw_compiler"] = _wrapped_bw_compiler

    from functorch.compile import aot_module_simplified  # type: ignore[import]

    return aot_module_simplified(model, **kwargs)


class AOTAutogradCudaGraphs(AotAutogradStrategy):
    def candidate(self):
        return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)


aot_autograd_cudagraphs = AOTAutogradCudaGraphs.compile_fn
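
# Rough usage sketch (an assumption, not part of the original module): the
# resulting compile_fn can be handed to dynamo as a backend, e.g.
#
#   import torch._dynamo
#   opt_model = torch._dynamo.optimize(aot_autograd_cudagraphs)(model)
#   out = opt_model(inp)
#
# where `model` and `inp` are hypothetical placeholders for a CUDA nn.Module
# and fixed-shape CUDA example inputs.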