mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
This PR introduces *DeepCompile*, a new feature that efficiently integrates compiler optimizations with other DeepSpeed features. DeepCompile utilizes torch's dynamo to capture the computation graph and modifies it to incorporate DeepSpeed’s optimizations seamlessly. Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements such as proactive prefetching and selective unsharding to improve performance. (More details will be added later.) --------- Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com> Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com> Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
215 lines
8.7 KiB
Python
215 lines
8.7 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# DeepSpeed Team
|
|
|
|
import torch
|
|
|
|
try:
|
|
import torch.utils._pytree as pytree
|
|
from torch._functorch.aot_autograd import create_aot_dispatcher_function
|
|
from torch._inductor.lowering import register_lowering, fallbacks, add_needs_realized_inputs
|
|
from torch._inductor.ir import TensorBox, FallbackKernel, Layout, IRNode
|
|
from torch._inductor.virtualized import V
|
|
from torch._inductor.scheduler import Scheduler
|
|
|
|
original_create_aot_dispatcher_function = create_aot_dispatcher_function
|
|
except ImportError:
|
|
pass
|
|
|
|
from .util import get_input_nodes
|
|
from .graph_param import DSGraphParamManager
|
|
|
|
|
|
def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool):
    """Wrap ``original_compiler`` so that ``dc_compiler`` sees the graph first.

    Args:
        original_compiler: Callable invoked as ``original_compiler(gm, inputs)``
            after ``dc_compiler`` has run.
        dc_compiler: Callable invoked as ``dc_compiler(gm, fake_inputs)``. When
            it returns ``None`` the wrapper returns ``None`` and
            ``original_compiler`` is not called.
        z3_partition: When true, every input whose placeholder name is a
            parameter name is replaced by a zero-element fake tensor before
            ``original_compiler`` is called.
        graph_id: Key used to index ``graph_param_manager``.
        graph_param_manager: Indexable by ``graph_id``; the entry must expose
            ``param_names`` and ``get_bwd_mapping(graph)``.
        bwd: Use the nodes returned by ``get_bwd_mapping`` instead of
            ``param_names`` to decide which inputs are parameters.

    Returns:
        The wrapping function ``wrapped_compiler(gm, fake_inputs)``.
    """

    def wrapped_compiler(gm, fake_inputs):
        mod_graph = dc_compiler(gm, fake_inputs)

        # For symint case
        if mod_graph is None:
            return None

        if not z3_partition:
            return original_compiler(gm, fake_inputs)

        # Inductor validates input size estimated by the first trace, where ds tensor is materialized.
        # We need to patch the input tensors to avoid the validation error.
        # Imported once per compilation instead of once per parameter input.
        from torch._subclasses.fake_tensor import is_fake
        from torch._dynamo.utils import to_fake_tensor

        # A frozenset gives constant-time membership tests in the loop below.
        if bwd:
            param_nodes_bw, _ = graph_param_manager[graph_id].get_bwd_mapping(gm.graph)
            param_names = frozenset(n.name for n in param_nodes_bw)
        else:
            param_names = frozenset(graph_param_manager[graph_id].param_names)

        patched_inputs = []
        for in_node, in_v in zip(get_input_nodes(gm.graph), fake_inputs):
            if in_node.name not in param_names:
                patched_inputs.append(in_v)
                continue

            assert is_fake(in_v), f"Input {in_v} should be fake tensor"
            # Keep dtype/device, drop the size: a zero-element stand-in for the parameter.
            empty_param = torch.empty([0], dtype=in_v.dtype, device=in_v.device)
            patched_inputs.append(to_fake_tensor(empty_param, in_v.fake_mode))

        return original_compiler(gm, tuple(patched_inputs))

    return wrapped_compiler
|
|
|
|
|
|
def wrap_partition_fn(partition_fn, real_inputs, param_indices):
    """Return a partition function that shrinks parameter placeholders' metadata.

    The returned callable forwards all arguments to ``partition_fn``, builds a
    ``DSGraphParamManager`` from the forward graph, and then, in both the
    forward and the backward graph, overwrites ``meta["val"]`` of every
    placeholder whose name is in the manager's ``param_names`` with a
    zero-element tensor of the same dtype and device.

    Args:
        partition_fn: Callable returning a ``(fw_module, bw_module)`` pair.
        real_inputs: Passed through to ``DSGraphParamManager``.
        param_indices: Passed through to ``DSGraphParamManager``.

    Returns:
        ``wrapped_partition_fn(*args, **kwargs)``, which returns the same
        ``(fw_module, bw_module)`` pair after patching the metadata in place.
    """

    def wrapped_partition_fn(*args, **kwargs):
        fw_module, bw_module = partition_fn(*args, **kwargs)

        # get parameter names
        pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)

        # Forward graph first, then backward graph.
        for module in (fw_module, bw_module):
            for node in module.graph.nodes:
                if node.op != "placeholder" or node.name not in pm.param_names:
                    continue
                node.meta["val"] = torch.empty([0],
                                               dtype=node.meta["val"].dtype,
                                               device=node.meta["val"].device)

        return fw_module, bw_module

    return wrapped_partition_fn
|
|
|
|
|
|
def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs,
|
|
param_indices, param_manager):
|
|
|
|
from torch._dynamo.backends.common import AotAutograd
|
|
import functools
|
|
|
|
def patch_aotautograd():
|
|
# Unpatch if it was already patched
|
|
if hasattr(AotAutograd, "__original_init"):
|
|
AotAutograd.__init__ = AotAutograd.__original_init
|
|
|
|
original_init = AotAutograd.__init__
|
|
|
|
@functools.wraps(original_init)
|
|
def patched_init(self, **kwargs):
|
|
kwargs["fw_compiler"] = patch_compiler(kwargs["fw_compiler"],
|
|
make_fw_graph,
|
|
z3_partition,
|
|
graph_id,
|
|
param_manager,
|
|
bwd=False)
|
|
kwargs["bw_compiler"] = patch_compiler(kwargs["bw_compiler"],
|
|
make_bw_graph,
|
|
z3_partition,
|
|
graph_id,
|
|
param_manager,
|
|
bwd=True)
|
|
kwargs["inference_compiler"] = kwargs["fw_compiler"]
|
|
|
|
if z3_partition:
|
|
kwargs["partition_fn"] = wrap_partition_fn(kwargs["partition_fn"], real_inputs, param_indices)
|
|
|
|
original_init(self, **kwargs)
|
|
|
|
AotAutograd.__original_init = original_init
|
|
AotAutograd.__init__ = patched_init
|
|
|
|
patch_aotautograd()
|
|
|
|
|
|
def register_custom_ops():
    """Register Inductor fallback lowerings for the ``torch.ops.dc`` custom ops.

    Lowerings are registered for ``allgather_param``, ``wait_allgather``,
    ``release_param``, ``reduce_grad`` and ``free_tensors``. Depending on the
    per-op flags, the names of their input and/or output buffers are added to
    ``V.graph.never_reuse_buffers``.

    The function also replaces ``Scheduler.dead_node_elimination`` with a no-op,
    guarded by the ``Scheduler.is_dc_patched`` flag so it is done only once.
    """
    import re

    # Extracts ``buf`` from a generated ``reinterpret_tensor(buf, ...)`` argument string.
    reinterpret_arg_pattern = re.compile(r"reinterpret_tensor\((\w+),")

    def fallback_handler_no_reuse(kernel,
                                  never_reuse_input,
                                  never_reuse_output,
                                  force_free_input,
                                  add_to_fallback_set=True):
        """Build the lowering handler for ``kernel`` with the given buffer-reuse policy."""
        if add_to_fallback_set:
            fallbacks.add(kernel)

        # The kernel class depends only on the flags above, so it is created once per
        # registered op rather than on every lowering call.
        class CustomDCKernel(FallbackKernel):

            def __init__(self, op, *args, **kwargs):
                super().__init__(op, *args, **kwargs)

                def add_to_never_reuse(x):
                    if isinstance(x, IRNode):
                        assert hasattr(x, "get_name"), f"x doesn't have get_name {x.__class__}"
                        V.graph.never_reuse_buffers.add(x.get_name())

                if never_reuse_input:
                    pytree.tree_map(add_to_never_reuse, args)

            def get_var_name_for_arg(self, arg: str):
                """Return the variable name behind a generated argument string, or ``None``."""
                if arg.isidentifier():
                    return arg

                match = reinterpret_arg_pattern.match(arg)
                return match.group(1) if match else None

            def codegen(self, wrapper):
                if not force_free_input:
                    return super().codegen(wrapper)

                self.codegen_comment(wrapper)
                args = [*self.codegen_args(), *self.codegen_kwargs()]

                V.graph.wrapper_code.generate_fallback_kernel(self, args)
                if isinstance(self.layout, Layout):
                    self.codegen_size_asserts(wrapper)

                # Emit ``<name> = None`` for the first argument right after the kernel call.
                # NOTE(review): presumably this drops the generated code's reference to the
                # input so its memory can be freed early -- confirm.
                var_name = self.get_var_name_for_arg(args[0])
                if var_name:
                    wrapper.writeline(f"{var_name} = None")

                self.codegen_unbacked_symbol_defs(wrapper)

        kernel_cls = CustomDCKernel if force_free_input else FallbackKernel

        def wrap_tensors(x):
            out = TensorBox.create(x) if isinstance(x, IRNode) else x
            if out is not None and never_reuse_output:
                V.graph.never_reuse_buffers.add(out.get_name())
            return out

        def handler(*args, **kwargs):
            return pytree.tree_map(wrap_tensors, kernel_cls.create(kernel, *args, **kwargs))

        return handler

    def register_fallback_no_reuse(op_overload,
                                   never_reuse_input=False,
                                   never_reuse_output=False,
                                   force_free_input=False):
        """Register ``op_overload`` as a realized-input fallback lowering with the given policy."""
        add_needs_realized_inputs(op_overload)
        return register_lowering(op_overload, type_promotion_kind=None)(fallback_handler_no_reuse(
            op_overload,
            never_reuse_input=never_reuse_input,
            never_reuse_output=never_reuse_output,
            force_free_input=force_free_input))

    # Inductor tries to reuse output buffer when possible. We need to disable this behavior for some custom ops.
    # -> It seems that memory region is still reused in some cases. So we clone the inputs for some ops.
    # NOTE(review): no cloning happens in this function; it presumably lives in the op implementations -- confirm.
    register_fallback_no_reuse(torch.ops.dc.allgather_param.default, never_reuse_input=False, never_reuse_output=True)
    register_fallback_no_reuse(torch.ops.dc.wait_allgather.default, never_reuse_input=True, never_reuse_output=True)
    register_fallback_no_reuse(torch.ops.dc.release_param.default, never_reuse_input=True, never_reuse_output=False)
    register_fallback_no_reuse(torch.ops.dc.reduce_grad.default,
                               never_reuse_input=True,
                               never_reuse_output=True,
                               force_free_input=True)
    register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True)

    # Patch the scheduler only once, even if this function is called repeatedly.
    if not getattr(Scheduler, "is_dc_patched", False):
        Scheduler.is_dc_patched = True
        Scheduler.dead_node_elimination = lambda _: None
|