mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Fix invalid f-strings detected by ruff. --------- Signed-off-by: cyy <cyyever@outlook.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com> Co-authored-by: Michael Wyatt <michael.wyatt@snowflake.com>
140 lines
4.6 KiB
Python
140 lines
4.6 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# DeepSpeed Team
|
|
|
|
from typing import Any, Callable, Dict, List, Optional
from collections import defaultdict

import torch
from torch.fx import Node, Graph

from .util import get_last_uses
|
|
|
|
|
|
def get_output_node(graph: Graph):
    """Return the ``output`` node of *graph*.

    The node is identified by ``op == "output"``. Matching on ``target``
    alone is not enough: a placeholder's target is its argument name, so an
    input called ``output`` (or a ``call_module``/``get_attr`` node whose
    target string is ``"output"``) would be returned instead.

    Raises:
        ValueError: If *graph* contains no output node.
    """
    # FX normally places the output node last, so scanning backwards finds it
    # immediately; the op check still makes the result position-independent.
    for v in reversed(graph.nodes):
        if v.op == "output":
            return v
    raise ValueError("No output node found")
|
|
|
|
|
|
def move_primals_to_head(graph: Graph):
    """Build a new graph in which all placeholder (primal) nodes come first.

    Placeholders keep their relative order and are followed by every other
    node, also in its original relative order. Each node is copied into a
    fresh ``torch.fx.Graph`` via ``node_copy``, with arguments remapped by
    node name to the already-copied nodes. The result is linted before it is
    returned; *graph* itself is left unchanged.
    """
    placeholders: List[Node] = []
    rest: List[Node] = []
    for n in graph.nodes:
        if n.op == "placeholder":
            placeholders.append(n)
        else:
            rest.append(n)

    reordered = Graph()
    copied_by_name = {}

    def _remap(arg: Node) -> Node:
        # Inputs always appear earlier in the new ordering, so they are already copied.
        return copied_by_name[arg.name]

    for old_node in placeholders + rest:
        copied_by_name[old_node.name] = reordered.node_copy(old_node, _remap)

    reordered.lint()
    return reordered
|
|
|
|
|
|
def add_args_process(graph: Graph,
                     node: Node,
                     fn: Callable[..., Any],
                     extra_args: Optional[List[int]] = None,
                     name=None,
                     meta: Optional[Dict[str, Any]] = None) -> List[Node]:
    """Wrap every distinct ``Node`` positional arg of *node* with ``fn(arg, *extra_args)``.

    For each distinct ``torch.fx.Node`` in ``node.args``, a ``call_function``
    node targeting *fn* is inserted immediately before *node*. *node* is then
    rewired to consume the new node instead of the original argument.
    Non-``Node`` positional args and keyword args are left untouched.

    Args:
        graph: Graph that owns *node*; new nodes are inserted into it.
        node: Node whose inputs are wrapped.
        fn: Target of each inserted ``call_function`` node.
        extra_args: Extra positional args appended after the wrapped input.
        name: Requested name for the inserted nodes (FX uniquifies repeats).
        meta: Entries copied into the ``meta`` dict of every inserted node.

    Returns:
        The inserted nodes, in order of each input's first appearance in ``node.args``.
    """
    extra = tuple(extra_args) if extra_args is not None else ()
    meta = meta if meta is not None else {}

    new_nodes = []
    with graph.inserting_before(node):
        # De-duplicate while preserving order: replace_input_with rewires every
        # occurrence of an arg, so a repeated arg (e.g. add(x, x)) must be wrapped
        # only once, otherwise the second wrapper would be left without users.
        target_args = list(dict.fromkeys(arg for arg in node.args if isinstance(arg, Node)))

        for arg in target_args:
            new_node = graph.create_node('call_function', fn, (arg, ) + extra, name=name)
            new_node.meta.update(meta)
            node.replace_input_with(arg, new_node)
            new_nodes.append(new_node)

    return new_nodes
|
|
|
|
|
|
def add_postprocess(graph: Graph,
                    node: Node,
                    fn: Callable[..., Any],
                    extra_args: Optional[List[int]] = None,
                    name=None,
                    meta: Optional[Dict[str, Any]] = None) -> Node:
    """Insert ``fn(node, *extra_args)`` right after *node* and reroute *node*'s users to it.

    Every node that consumed *node* before the call consumes the new node
    afterwards; the new node is then the only user of *node*.

    Args:
        graph: Graph that owns *node*; the new node is inserted into it.
        node: Node whose output is post-processed.
        fn: Target of the inserted ``call_function`` node.
        extra_args: Extra positional args appended after *node* (e.g. a ds_id).
        name: Requested name for the inserted node.
        meta: Entries copied into the new node's ``meta`` dict.

    Returns:
        The inserted node.
    """
    # https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py
    args = (node, *(extra_args if extra_args is not None else ()))

    with graph.inserting_after(node):
        # Snapshot the users before creating new_node (which becomes a user of
        # node itself); the copy also protects against replace_input_with
        # mutating node.users while we iterate.
        old_users = list(node.users)
        new_node = graph.create_node('call_function', fn, args, {}, name=name)
        for user in old_users:
            user.replace_input_with(node, new_node)

    if meta:
        new_node.meta.update(meta)

    return new_node
|
|
|
|
|
|
def _make_node_meta(node: Node, ds_id: int, comm: bool):
|
|
meta = {"param_name": node.name, "ds_id": ds_id, "comm": comm}
|
|
if "tensor_meta" in node.meta:
|
|
meta["tensor_meta"] = node.meta["tensor_meta"]
|
|
return meta
|
|
|
|
|
|
def add_free_activations(graph_id: int, graph: Graph, activation_node_names: List[str]):
    """Insert ``torch.ops.dc.free_tensors`` calls after the last use of selected activations.

    Activations are the placeholder nodes whose names are listed in
    *activation_node_names*. An activation that has a matching
    ``dc.reload_tensor`` / ``dc.wait_reload`` pair is represented by its
    ``wait_reload`` node instead. For every node that ``get_last_uses``
    reports as a last user of one or more such activations that also carry
    ``tensor_meta``, one ``free_tensors`` node is inserted right after it,
    taking the list of those activations. *graph* is modified in place.

    Args:
        graph_id: Not referenced by this function.
        graph: FX graph to modify.
        activation_node_names: Names of placeholder nodes treated as activations.
    """
    # NOTE(review): node_to_last_use presumably maps each node to the node that
    # uses it last; see get_last_uses in .util.
    node_to_last_use, _ = get_last_uses(graph)

    # Build the lookup once so the membership test per graph node is O(1).
    activation_names = set(activation_node_names)
    activation_nodes_set = {n for n in graph.nodes if n.op == "placeholder" and n.name in activation_names}

    offload_id_to_node = {}
    node_to_wait_reload = {}
    for node in graph.nodes:
        if node.target == torch.ops.dc.reload_tensor.default:
            # args[0] is the offloaded activation, args[2] its offload id.
            offload_id_to_node[node.args[2]] = node.args[0]
        elif node.target == torch.ops.dc.wait_reload.default:
            # Assumes the matching reload_tensor appears earlier in graph order
            # (otherwise this raises KeyError).
            offload_id = node.args[2]
            node_to_wait_reload[offload_id_to_node[offload_id]] = node

    # Reloaded activations are tracked through their wait_reload node.
    activation_nodes_set = {node_to_wait_reload.get(n, n) for n in activation_nodes_set}

    # Invert the mapping: last user -> every node whose last use it is.
    last_user_to_uses = defaultdict(list)
    for node, last_user in node_to_last_use.items():
        last_user_to_uses[last_user].append(node)

    def _should_free(node: Node) -> bool:
        # Only nodes with recorded tensor metadata are freed.
        return "tensor_meta" in getattr(node, "meta", {})

    def free_tensors(tensors: List[torch.Tensor]):
        # Pure-Python stand-in for torch.ops.dc.free_tensors, used only for debugging:
        # drops the storage of tensors with more than 10M elements.
        for a in tensors:
            if a.numel() > 10_000_000:
                a.data = torch.empty([0], device=a.device, dtype=a.dtype)

    for last_user, used_nodes in last_user_to_uses.items():
        activation_args = [an for an in used_nodes if an in activation_nodes_set and _should_free(an)]
        if not activation_args:
            continue

        node_name = f"free_activations_{[n.name for n in used_nodes]}"
        with graph.inserting_after(last_user):
            args = (activation_args, )
            graph.create_node('call_function', torch.ops.dc.free_tensors.default, args, {}, name=node_name)

            # Python version for debugging
            # graph.create_node('call_function', free_tensors, args, {}, name=node_name)
|