from typing import Any, List

import torch
import torch.fx.traceback as fx_traceback

# Explicit import: fx.passes.split_module is used below and is not guaranteed
# to be imported by torch.fx itself.
import torch.fx.passes.split_module
from torch import fx
from torch.fx.node import Node


def args_str(args):
    # a debug helper
    if torch.is_tensor(args):
        return f"T[{args.shape}]"
    elif isinstance(args, tuple):
        return f"tuple({', '.join([args_str(x) for x in args])})"
    elif isinstance(args, list):
        return f"list({', '.join([args_str(x) for x in args])})"
    else:
        return str(args)


class DDPOptimizer:
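    """
    Splits a dynamo-captured fx graph into chunks whose accumulated parameter
    size approximates DDP's gradient-bucket cap, then compiles each chunk with
    the wrapped backend compiler. Breaking the graph at bucket boundaries is
    meant to let DDP's bucketed allreduce overlap with backward compute of the
    compiled pieces, which a single opaque compiled function would prevent.
    """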

    def __init__(
        self,
        bucket_bytes_cap: int,
        parameters_to_ignore: List[str],
        backend_compile_fn,
        debug=False,
    ):
        self.bucket_bytes_cap = bucket_bytes_cap
        self.parameters_to_ignore = parameters_to_ignore
        self.backend_compile_fn = backend_compile_fn
        self.debug = debug

    def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
        """
        Split gm at DDP bucket boundaries, compile each submodule with the
        user-provided backend, and return the recombined GraphModule.

        TODO:
        - handle params_and_buffers_to_ignore
        - handle kwargs
        """
        # 1: compute the partition map according to DDP bucket logic
        bucket_bytes = 0
        bucket_actual_sizes = []
        node_splits = [[]]
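        # Walk the graph from the end: DDP fills its gradient buckets starting
        # (roughly) from the parameters used last in forward, so the current
        # bucket is node_splits[0] and new buckets are pushed onto the front.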
        for node in reversed(gm.graph.nodes):
            if node.op == "output" or node.op == "placeholder":
                continue

            if bucket_bytes >= self.bucket_bytes_cap:
                bucket_actual_sizes.insert(0, bucket_bytes)
                bucket_bytes = 0
                node_splits.insert(0, [])

            elif node.op == "call_module":
                target = gm.get_submodule(node.target)
                params_size_b = sum(
                    [
                        p.storage().nbytes()
                        for p in target.parameters()
                        if p.requires_grad
                    ]
                )
                bucket_bytes += params_size_b
                # print(f"accumulated {params_size_b} b from {node}")
            elif node.op == "get_attr":
                maybe_param = getattr(gm, node.target)
                if maybe_param.requires_grad:
                    bucket_bytes += maybe_param.storage().nbytes()
            else:
                # TODO(whc) confirm this:
                # (e.g. call_method, call_function aren't expected to 'have' parameters)
                pass

            node_splits[0].append(node)

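        # Everything fit in a single bucket: nothing to split, so hand the
        # whole graph to the backend compiler unchanged.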
        if len(node_splits) == 1:
            if self.debug:
                print(
                    "DDPOptimizer did not split graphs."
                    f" Accumulated {bucket_bytes} bytes, and bucket cap is {self.bucket_bytes_cap}"
                )
            return self.backend_compile_fn(gm, example_inputs)

        if len(bucket_actual_sizes) < len(node_splits):
            bucket_actual_sizes.insert(0, bucket_bytes)

        if self.debug:
            print(
                f"DDPOptimizer used bucket cap {self.bucket_bytes_cap}"
                f" and split graphs into parameter sizes {', '.join([str(b) for b in bucket_actual_sizes])}"
            )

        # 2: partition the graphmodule according to bucket capacity
        partition_map = {}
        for p, nodes in enumerate(node_splits):
            for node in nodes:
                partition_map[node] = p

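        # split_module groups nodes by the integer returned from the callback
        # and produces a new GraphModule whose submodules each hold one
        # bucket's worth of nodes.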
        split_gm = fx.passes.split_module.split_module(
            gm, None, lambda node: partition_map[node]
        )
        if self.debug:
            with open("debug_ddp_optimizer.log", "w") as dump_file:
                dump_file.write("---orig graph---")
                dump_file.write(str(gm.graph))
                dump_file.write("\n---split graph---")
                dump_file.write(str(split_gm.graph))

        # 3: compile each of the partitioned submodules using the user-provided compiler
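        # The Interpreter below replays split_gm on the example inputs, so each
        # submodule is compiled lazily, the first time its call_module node runs,
        # with real tensors of the correct shapes as its example inputs.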
        class SubmodCompiler(torch.fx.interpreter.Interpreter):
            def __init__(self, module, compiler, debug=False):
                super().__init__(module)
                self.compiler = compiler
                self.debug = debug

            def compile_submod(self, submod, args, kwargs):
                """
                Compile the submodule, using a wrapper to make sure its output
                is always a tuple, which is required by AotAutograd based compilers.
                """
                assert len(kwargs) == 0, "We assume only args for these modules"

                class WrapperModule(torch.nn.Module):
                    def __init__(self, compiled_submod, unwrap_singleton_tuple):
                        super().__init__()
                        self.compiled_submod = compiled_submod
                        self.unwrap_singleton_tuple = unwrap_singleton_tuple

                    def forward(self, *args):
                        x = self.compiled_submod(*args)
                        # TODO(whc)
                        # for some reason the isinstance check is necessary if I split one node per submod
                        # - even though I supposedly wrapped the output in a tuple in those cases, the real
                        # compiled module was still returning a tensor
                        if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                            return x[0]
                        return x

                unwrap_singleton_tuple = False
                for sn in submod.graph.nodes:
                    if sn.op == "output":
                        if not isinstance(sn.args[0], tuple):
                            unwrap_singleton_tuple = True
                            sn.args = (sn.args,)
                submod.recompile()

                wrapper = WrapperModule(
                    self.compiler(submod, args),
                    unwrap_singleton_tuple,
                )
                return wrapper

            def run_node(self, n: Node) -> Any:
                with fx_traceback.append_stack_trace(n.stack_trace):
                    args, kwargs = self.fetch_args_kwargs_from_env(n)
                    if self.debug:
                        print(f"run_node {n.op}, {n.target} got args {args_str(args)}")
                    assert isinstance(args, tuple)
                    assert isinstance(kwargs, dict)

                    # modify the currently running FX graph
                    # maybe this isn't sound in general, but only changing the target of a node might be ok?
                    if n.op == "call_module":
                        submod = self.fetch_attr(n.target)
                        if self.debug:
                            with open("debug_ddp_optimizer.log", "a") as dump_file:
                                dump_file.write(f"\n---{n.target} graph---")
                                dump_file.write(str(submod.graph))
                        compiled_submod = self.compile_submod(submod, args, kwargs)
                        self.module.delete_submodule(n.target)
                        n.target = "compiled_" + n.target
                        self.module.add_submodule(n.target, compiled_submod)

                    # then we execute the modified node using the usual logic
                    return getattr(self, n.op)(n.target, args, kwargs)

        submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, self.debug)
        submod_compiler.run(*example_inputs)
        split_gm.recompile()

        if self.debug:
            with open("debug_ddp_optimizer.log", "a") as dump_file:
                dump_file.write("\n---final graph---")
                dump_file.write(str(split_gm.graph))
        return split_gm
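

# Usage sketch (illustrative only; every name other than DDPOptimizer is an
# assumption, not something this file defines): a caller could build a
# DDPOptimizer from the wrapped DistributedDataParallel module's bucket size
# and hand compile_fn to dynamo as the backend, along these lines:
#
#   ddp_optimizer = DDPOptimizer(
#       bucket_bytes_cap=ddp_model.bucket_bytes_cap,
#       parameters_to_ignore=[],
#       backend_compile_fn=my_backend,  # hypothetical backend callable
#   )
#   opt_forward = torchdynamo.optimize(ddp_optimizer.compile_fn)(model.forward)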