mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
all_gather_bucketing fx pass (#157396)
Porting passes to bucket all_gathers The main logic of the pass is done via 1. Searching for all all_gathers from the buckets Copying tests from @wconstab PR to test compatibility with reordering. Test checks only compatibility, as because of (3) the joint all_gather will be scheduled already as early as possible and no space for reordering. Pass changes: Using mutation ops to match performance of fsdp, in future the perfect scenario will be to have only functional graph, that inductor does all memory optimizations on its own without mutable ops. Inductor changes: Adding foreach_copy_ lowering Pull Request resolved: https://github.com/pytorch/pytorch/pull/157396 Approved by: https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
19ae5afdaa
commit
7b392bac13
@ -25,6 +25,7 @@ from torch._inductor.scheduler import BaseSchedulerNode
|
||||
from torch._inductor.utils import run_and_get_triton_code
|
||||
from torch.distributed.distributed_c10d import GroupMember
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_cuda import SM80OrLater
|
||||
from torch.testing._internal.common_distributed import (
|
||||
_dynamo_dist_per_rank_init,
|
||||
DynamoDistributedMultiProcTestCase,
|
||||
@ -1503,6 +1504,179 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||
self.assertEqual(stats.limiting_factor, "data dependency")
|
||||
self.assertEqual(stats.moves, 0)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@unittest.skipIf(not SM80OrLater, "bfloat16")
|
||||
def test_all_gather_bucket(self):
|
||||
def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
|
||||
# do some unrelated matmuls
|
||||
y = torch.mm(x, w)
|
||||
|
||||
# cast the inputs
|
||||
ag_0_cast = ag_0.to(torch.bfloat16)
|
||||
ag_1_cast = ag_1.to(torch.bfloat16)
|
||||
|
||||
# allgather
|
||||
group_name = (
|
||||
torch.distributed.distributed_c10d._get_default_group().group_name
|
||||
)
|
||||
ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
ag_0_cast, group_size, group_name
|
||||
)
|
||||
ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
ag_1_cast, group_size, group_name
|
||||
)
|
||||
|
||||
# wait op
|
||||
ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out)
|
||||
ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out)
|
||||
|
||||
return y, ag_0_out, ag_1_out
|
||||
|
||||
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
|
||||
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||
ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||
inputs = [x, w, ag_0, ag_1]
|
||||
|
||||
with torch._inductor.config.patch(
|
||||
{
|
||||
"bucket_all_gathers_fx": "fsdp",
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
}
|
||||
):
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
|
||||
# NOTE: The first return value should be the output of the first wait_tensor.
|
||||
# We want to make sure no unneccessary copy is made.
|
||||
(FileCheck().check("all_gather_into_tensor_out").run(code))
|
||||
out = compiled(*inputs, **self.get_world_trs())
|
||||
correct = func(*inputs, **self.get_world_trs())
|
||||
assert same(out, correct), f"{out} va {correct}"
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@unittest.skipIf(not SM80OrLater, "bfloat16")
|
||||
def test_reorder_peak_memory_bucketed(self):
|
||||
"""
|
||||
Simulate the case where a bucketing pass ran and grouped several inputs into one bucketed allgather.
|
||||
Ensure the whole bucketed group including copy-ops get moved together rather than the copy ops preventing the
|
||||
comm from moving due to data dependency.
|
||||
"""
|
||||
|
||||
def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
|
||||
# do some unrelated matmuls
|
||||
y = torch.mm(x, w)
|
||||
|
||||
# cast the inputs
|
||||
ag_0_cast = ag_0.to(torch.bfloat16)
|
||||
ag_1_cast = ag_1.to(torch.bfloat16)
|
||||
|
||||
# allgather
|
||||
group_name = (
|
||||
torch.distributed.distributed_c10d._get_default_group().group_name
|
||||
)
|
||||
ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
ag_0_cast, group_size, group_name
|
||||
)
|
||||
ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
ag_1_cast, group_size, group_name
|
||||
)
|
||||
|
||||
# wait op
|
||||
ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out)
|
||||
ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out)
|
||||
|
||||
return y, ag_0_out, ag_1_out
|
||||
|
||||
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
|
||||
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||
ag_1 = torch.ones(512, device="cuda", dtype=torch.float32)
|
||||
inputs = [x, w, ag_0, ag_1]
|
||||
|
||||
# get stats directly from the internal helper without affecting the real pass's signature
|
||||
node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None
|
||||
|
||||
def _reorder_communication_preserving_peak_memory(
|
||||
snodes: list[BaseSchedulerNode],
|
||||
) -> list[BaseSchedulerNode]:
|
||||
nonlocal node_stats
|
||||
(
|
||||
reordered_snodes,
|
||||
node_stats,
|
||||
) = _reorder_communication_preserving_peak_memory_internal(snodes)
|
||||
return reordered_snodes
|
||||
|
||||
with torch._inductor.config.patch(
|
||||
{
|
||||
"bucket_all_gathers_fx": "all",
|
||||
"reorder_for_compute_comm_overlap": True,
|
||||
"reorder_for_compute_comm_overlap_passes": [
|
||||
"sink_waits",
|
||||
# same as reorder_communication_preserving_peak_memory but returns debug info structures directly
|
||||
_reorder_communication_preserving_peak_memory,
|
||||
],
|
||||
}
|
||||
):
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
|
||||
# NOTE: The first return value should be the output of the first wait_tensor.
|
||||
# We want to make sure no unneccessary copy is made.
|
||||
(FileCheck().check("all_gather_into_tensor_out").run(code))
|
||||
out = compiled(*inputs, **self.get_world_trs())
|
||||
correct = func(*inputs, **self.get_world_trs())
|
||||
assert same(out, correct), f"{out} va {correct}"
|
||||
assert node_stats is not None
|
||||
self.assertTrue(isinstance(node_stats, dict))
|
||||
self.assertEqual(len(node_stats), 1)
|
||||
|
||||
# TODO: Debug why reordering does not move collective after bucketing
|
||||
# for stats in node_stats.values():
|
||||
# self.assertEqual(stats.initial_exposed, 0)
|
||||
def _reorder_communication_preserving_peak_memory(
|
||||
snodes: list[BaseSchedulerNode],
|
||||
) -> list[BaseSchedulerNode]:
|
||||
nonlocal node_stats
|
||||
(
|
||||
reordered_snodes,
|
||||
node_stats,
|
||||
) = _reorder_communication_preserving_peak_memory_internal(snodes)
|
||||
return reordered_snodes
|
||||
|
||||
with torch._inductor.config.patch(
|
||||
{
|
||||
"reorder_for_compute_comm_overlap": True,
|
||||
"reorder_for_compute_comm_overlap_passes": [
|
||||
"sink_waits",
|
||||
# same as reorder_communication_preserving_peak_memory but returns debug info structures directly
|
||||
_reorder_communication_preserving_peak_memory,
|
||||
],
|
||||
}
|
||||
):
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
|
||||
# NOTE: The first return value should be the output of the first wait_tensor.
|
||||
# We want to make sure no unneccessary copy is made.
|
||||
(
|
||||
FileCheck()
|
||||
.check("all_gather")
|
||||
.check("wait")
|
||||
.check("all_gather")
|
||||
.check("wait")
|
||||
.run(code)
|
||||
)
|
||||
out = compiled(*inputs, **self.get_world_trs())
|
||||
correct = func(*inputs, **self.get_world_trs())
|
||||
assert same(out, correct), f"{out} va {correct}"
|
||||
|
||||
# TODO make the test case more interesting and validate the actual desired behavior
|
||||
assert node_stats is not None
|
||||
self.assertTrue(isinstance(node_stats, dict))
|
||||
self.assertEqual(len(node_stats), 2)
|
||||
# for stats in node_stats.values():
|
||||
# self.assertEqual(stats.moves, 0)
|
||||
# self.assertEqual(stats.limiting_factor, "data dependency")
|
||||
# self.assertEqual(stats.moves, 3)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
def test_reorder_respects_wait_dep(self):
|
||||
"""
|
||||
|
@ -384,6 +384,10 @@ reorder_prefetch_limit: Optional[int] = None
|
||||
# enable operator reordering for peak memory optimization
|
||||
reorder_for_peak_memory = True
|
||||
|
||||
bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none"
|
||||
# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
|
||||
bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None
|
||||
|
||||
# runtime estimation function for ops
|
||||
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
|
||||
estimate_op_runtime = "default"
|
||||
|
432
torch/_inductor/fx_passes/bucketing.py
Normal file
432
torch/_inductor/fx_passes/bucketing.py
Normal file
@ -0,0 +1,432 @@
|
||||
import logging
|
||||
import operator
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def bucket_size_determinator(bucket_id: int) -> float:
|
||||
"""
|
||||
Determine the size of a bucket based on its ID.
|
||||
|
||||
Args:
|
||||
bucket_id (int): The ID of the bucket.
|
||||
|
||||
Returns:
|
||||
float: The size of the bucket.
|
||||
"""
|
||||
return 2000.0
|
||||
|
||||
|
||||
def bucket_all_gather(
|
||||
gm: torch.fx.GraphModule, all_gather_bucket_cap_mb_callback: Callable[[int], float]
|
||||
) -> None:
|
||||
ag_buckets = bucket_all_gather_by_mb(gm, all_gather_bucket_cap_mb_callback)
|
||||
if len(ag_buckets) == 0:
|
||||
return
|
||||
merge_all_gather(gm, ag_buckets)
|
||||
|
||||
|
||||
def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: # type: ignore[arg-type]
|
||||
return (
|
||||
node.op == "call_function"
|
||||
and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
|
||||
)
|
||||
|
||||
|
||||
def is_wait_tensor(node: torch.fx.Node) -> bool:
|
||||
return (
|
||||
node.op == "call_function"
|
||||
and node.target == torch.ops._c10d_functional.wait_tensor.default
|
||||
)
|
||||
|
||||
|
||||
def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
|
||||
return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def bucket_all_gather_by_mb(
|
||||
gm: torch.fx.GraphModule,
|
||||
all_gather_bucket_cap_mb_callback: Callable[[int], float],
|
||||
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
|
||||
) -> list[list[torch.fx.Node]]:
|
||||
"""
|
||||
Identifies all all_gather nodes and groups them into buckets based on size limit `all_gather_bucket_cap_mb_callback`.
|
||||
|
||||
|
||||
Returns a list of buckets, where each bucket is a list of all_gather nodes.
|
||||
"""
|
||||
|
||||
node_list = gm.graph.nodes
|
||||
|
||||
# Prerequisite: Check if there is any all_gather node
|
||||
found_all_gather = False
|
||||
for node in node_list:
|
||||
if is_all_gather_into_tensor(node):
|
||||
found_all_gather = True
|
||||
break
|
||||
if not found_all_gather:
|
||||
return []
|
||||
|
||||
ag_nodes: list[torch.fx.Node] = []
|
||||
|
||||
# Step 1: Find all all_gather nodes
|
||||
for node in node_list:
|
||||
if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
|
||||
if (filter_wait_node is None) or filter_wait_node(node):
|
||||
ag_node = node.args[0]
|
||||
ag_nodes.append(ag_node)
|
||||
|
||||
# Step 2: Put all_gather nodes into buckets
|
||||
ag_buckets: list[list[torch.fx.Node]] = []
|
||||
cur_bucket: list[torch.fx.Node] = []
|
||||
cur_bucket_size_bytes: int = 0
|
||||
cur_bucket_id: int = 0
|
||||
# Convert MiB to bytes
|
||||
all_gather_bucket_size_bytes = int(
|
||||
all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
|
||||
)
|
||||
for ag_node in ag_nodes:
|
||||
assert is_all_gather_into_tensor(ag_node)
|
||||
assert "val" in ag_node.meta
|
||||
ag_output_size_bytes = (
|
||||
ag_node.meta["val"].numel()
|
||||
* torch.finfo(ag_node.meta["val"].dtype).bits
|
||||
// 8
|
||||
)
|
||||
if (
|
||||
cur_bucket_size_bytes + ag_output_size_bytes > all_gather_bucket_size_bytes
|
||||
and cur_bucket
|
||||
):
|
||||
# Current bucket is full, create new bucket
|
||||
ag_buckets.append(cur_bucket)
|
||||
cur_bucket = []
|
||||
cur_bucket_size_bytes = 0
|
||||
cur_bucket_id += 1
|
||||
cur_bucket_size_bytes += ag_output_size_bytes
|
||||
cur_bucket.append(ag_node)
|
||||
if cur_bucket:
|
||||
# add remaining nodes in the last bucket
|
||||
ag_buckets.append(cur_bucket)
|
||||
|
||||
return ag_buckets
|
||||
|
||||
|
||||
def node_copy( # type: ignore[no-untyped-def]
|
||||
env,
|
||||
new_graph,
|
||||
node: torch.fx.Node,
|
||||
arg_transform: Callable[[torch.fx.Node], torch.fx.node.Argument],
|
||||
) -> torch.fx.Node:
|
||||
if node not in env:
|
||||
new_node = new_graph.node_copy(node, arg_transform=arg_transform)
|
||||
env[node] = new_node
|
||||
else:
|
||||
new_node = env[node]
|
||||
return new_node
|
||||
|
||||
|
||||
def new_graph_call_function( # type: ignore[no-untyped-def]
|
||||
new_graph,
|
||||
target: Callable[..., Any],
|
||||
args: Optional[tuple[torch.fx.node.Argument, ...]] = None,
|
||||
kwargs: Optional[dict[str, torch.fx.node.Argument]] = None,
|
||||
type_expr: Optional[Any] = None,
|
||||
) -> torch.fx.Node:
|
||||
from torch.utils._pytree import tree_map_only
|
||||
|
||||
new_node = new_graph.call_function(target, args, kwargs)
|
||||
args_val = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], args)
|
||||
kwargs_val = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], kwargs)
|
||||
with V.fake_mode, enable_python_dispatcher():
|
||||
new_fake_tensor = target(*args_val, **kwargs_val)
|
||||
new_node.meta["val"] = new_fake_tensor
|
||||
return new_node
|
||||
|
||||
|
||||
def env_lookup( # type: ignore[no-untyped-def]
|
||||
env, x: torch.fx.Node, node_user: Union[torch.fx.Node, str]
|
||||
) -> torch.fx.Node:
|
||||
assert x in env, (
|
||||
f"Dependent node {x} not in env when creating downstream node {node_user}"
|
||||
)
|
||||
return env[x]
|
||||
|
||||
|
||||
def merge_all_gather(
|
||||
gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]]
|
||||
) -> None:
|
||||
"""
|
||||
Transforms the graph to use bucketed all_gather operations based on `ag_buckets`.
|
||||
"""
|
||||
assert len(ag_buckets) > 0
|
||||
|
||||
ag_nodes: list[torch.fx.Node] = []
|
||||
cast_nodes: list[torch.fx.Node] = []
|
||||
ag_node_to_wait_node: dict[torch.fx.Node, torch.fx.Node] = {}
|
||||
ag_node_to_bucket_id = {}
|
||||
cast_node_to_bucket_id = {}
|
||||
|
||||
# Map nodes to buckets and identify wait nodes
|
||||
for bucket_id, bucket in enumerate(ag_buckets):
|
||||
for ag_node in bucket:
|
||||
assert len(ag_node.users) == 1, (
|
||||
f"Expect only one user for {ag_node}, but got {ag_node.users}"
|
||||
)
|
||||
wait_node = next(iter(ag_node.users))
|
||||
ag_node_to_wait_node[ag_node] = wait_node
|
||||
ag_nodes.append(ag_node)
|
||||
ag_node_to_bucket_id[ag_node] = bucket_id
|
||||
if (
|
||||
ag_node.args[0].op == "call_function" # type: ignore[union-attr]
|
||||
and ag_node.args[0].target # type: ignore[union-attr]
|
||||
== torch.ops.prims.convert_element_type.default
|
||||
):
|
||||
cast_nodes.append(ag_node.args[0]) # type: ignore[arg-type]
|
||||
cast_node_to_bucket_id[ag_node.args[0]] = bucket_id # type: ignore[arg-type]
|
||||
|
||||
# Step 3: Create new (bucketed) all_gather nodes
|
||||
bucket_id_to_bucketed_op_info = {}
|
||||
bucket_id_is_scheduled = {}
|
||||
cast_bucket_id_is_scheduled = {}
|
||||
_, group_size, group_name = next(iter(ag_node_to_wait_node.keys())).args
|
||||
for bucket_id, ag_bucket in enumerate(ag_buckets):
|
||||
ag_input_nodes = []
|
||||
wait_nodes = []
|
||||
for ag_node in ag_bucket:
|
||||
assert (
|
||||
ag_node in ag_node_to_wait_node
|
||||
and ag_node.args[1] == group_size
|
||||
and ag_node.args[2] == group_name
|
||||
)
|
||||
ag_input_nodes.append(ag_node.args[0])
|
||||
wait_nodes.append(ag_node_to_wait_node[ag_node])
|
||||
bucket_id_to_bucketed_op_info[bucket_id] = (
|
||||
ag_input_nodes,
|
||||
group_size,
|
||||
group_name,
|
||||
wait_nodes,
|
||||
)
|
||||
|
||||
ag_wait_nodes = list(ag_node_to_wait_node.values())
|
||||
ag_and_wait_nodes = OrderedSet(ag_nodes + ag_wait_nodes)
|
||||
cast_nodes = OrderedSet(cast_nodes)
|
||||
new_graph: torch.fx.Graph = torch.fx.Graph()
|
||||
env: dict[torch.fx.Node, torch.fx.Node] = {}
|
||||
|
||||
node_list = gm.graph.nodes
|
||||
for node in node_list:
|
||||
if node not in ag_and_wait_nodes and node not in cast_nodes:
|
||||
# not cast-before-all_gather, all_gather or its wait_tensor - schedule it normally
|
||||
node_copy(env, new_graph, node, lambda x: env_lookup(env, x, node))
|
||||
elif node in cast_nodes:
|
||||
# batch cast nodes together into one foreach_copy node
|
||||
assert node in cast_node_to_bucket_id
|
||||
bucket_id = cast_node_to_bucket_id[node]
|
||||
if bucket_id not in cast_bucket_id_is_scheduled:
|
||||
ag_input_nodes, group_size, group_name, orig_wait_nodes = (
|
||||
bucket_id_to_bucketed_op_info[bucket_id]
|
||||
)
|
||||
# device = ag_input_nodes[0].meta["val"].device
|
||||
# rank = device.index
|
||||
# dtype = ag_input_nodes[0].meta["val"].dtype
|
||||
if all(
|
||||
n.op == "call_function" # type: ignore[union-attr]
|
||||
and n.target == torch.ops.prims.convert_element_type.default # type: ignore[union-attr]
|
||||
for n in ag_input_nodes
|
||||
):
|
||||
param_all_gather_inputs = [
|
||||
new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten.empty.memory_format,
|
||||
(n.meta["val"].shape,), # type: ignore[union-attr]
|
||||
{
|
||||
"dtype": n.args[1], # type: ignore[union-attr]
|
||||
"device": n.meta["val"].device, # type: ignore[union-attr]
|
||||
"pin_memory": False,
|
||||
},
|
||||
)
|
||||
for n in ag_input_nodes
|
||||
]
|
||||
for pp, n in zip(param_all_gather_inputs, ag_input_nodes):
|
||||
pp.meta = n.meta.copy() # type: ignore[union-attr]
|
||||
|
||||
cast_input_nodes = [env[n.args[0]] for n in ag_input_nodes] # type: ignore[union-attr, index]
|
||||
foreach_copy = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten._foreach_copy.default,
|
||||
(param_all_gather_inputs, cast_input_nodes),
|
||||
{},
|
||||
)
|
||||
foreach_copy.meta["val"] = [n.meta["val"] for n in ag_input_nodes] # type: ignore[union-attr]
|
||||
getitems = [
|
||||
new_graph_call_function(
|
||||
new_graph,
|
||||
operator.getitem,
|
||||
(foreach_copy, i),
|
||||
{},
|
||||
)
|
||||
for i in range(len(ag_input_nodes))
|
||||
]
|
||||
|
||||
for new_n, old_n in zip(getitems, ag_input_nodes):
|
||||
env[old_n] = new_n # type: ignore[index] # noqa: PERF403
|
||||
else:
|
||||
param_all_gather_inputs_orig = [
|
||||
node_copy(
|
||||
env,
|
||||
new_graph,
|
||||
ag_input_node, # type: ignore[arg-type]
|
||||
lambda x: env_lookup(env, x, ag_input_node), # type: ignore[arg-type]
|
||||
)
|
||||
for ag_input_node in ag_input_nodes
|
||||
]
|
||||
cast_bucket_id_is_scheduled[bucket_id] = True
|
||||
else:
|
||||
continue
|
||||
elif node in ag_node_to_wait_node:
|
||||
assert node in ag_node_to_bucket_id
|
||||
bucket_id = ag_node_to_bucket_id[node]
|
||||
if bucket_id not in bucket_id_is_scheduled:
|
||||
ag_input_nodes, group_size, group_name, orig_wait_nodes = (
|
||||
bucket_id_to_bucketed_op_info[bucket_id]
|
||||
)
|
||||
device = ag_input_nodes[0].meta["val"].device # type: ignore[union-attr]
|
||||
rank = device.index
|
||||
dtype = ag_input_nodes[0].meta["val"].dtype # type: ignore[union-attr]
|
||||
# TODO: if we want to support mixed dtype in the same bucket,
|
||||
# we need to first view all all_gather inputs as uint8 (common denominator),
|
||||
# then do the all_gather, then view the output back to the original dtype.
|
||||
# Look at FSDP2 to see how to do this.
|
||||
assert all(n.meta["val"].dtype == dtype for n in ag_input_nodes), ( # type: ignore[union-attr]
|
||||
"All all_gather inputs in the same bucket must have the same dtype"
|
||||
)
|
||||
# must schedule all the all_gather input nodes first, before the bucketed all_gather node
|
||||
param_all_gather_inputs_orig = [
|
||||
node_copy(
|
||||
env,
|
||||
new_graph,
|
||||
ag_input_node, # type: ignore[arg-type]
|
||||
lambda x: env_lookup(env, x, ag_input_node), # type: ignore[arg-type]
|
||||
)
|
||||
for ag_input_node in ag_input_nodes
|
||||
]
|
||||
# schedule the bucketed all_gather node
|
||||
param_all_gather_inputs_flattened = [
|
||||
new_graph_call_function(
|
||||
new_graph, torch.ops.aten.reshape.default, (n, [-1]), {}
|
||||
)
|
||||
for n in param_all_gather_inputs_orig
|
||||
]
|
||||
inp_split_sizes = [
|
||||
n.meta["val"].numel() for n in param_all_gather_inputs_orig
|
||||
]
|
||||
param_all_gather_outputs = [
|
||||
new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten.empty.memory_format,
|
||||
([n.meta["val"].numel() * group_size],),
|
||||
{
|
||||
"dtype": n.meta["val"].dtype,
|
||||
"device": n.meta["val"].device,
|
||||
"pin_memory": False,
|
||||
},
|
||||
)
|
||||
for n in param_all_gather_inputs_orig
|
||||
]
|
||||
# TODO: This assumes dim-0 sharding.
|
||||
# If we need to support sharding on another dim, we should look at how FSDP2 does it
|
||||
# (e.g. search for `shard_dim` in FSDP2 codebase)
|
||||
param_all_gather_outputs_shape_orig = [
|
||||
(n.meta["val"].shape[0] * group_size,) + n.meta["val"].shape[1:]
|
||||
for n in param_all_gather_inputs_orig
|
||||
]
|
||||
all_gather_input_numel = sum(inp_split_sizes)
|
||||
|
||||
all_gather_output = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten.empty.memory_format,
|
||||
([all_gather_input_numel * group_size],),
|
||||
{
|
||||
"dtype": dtype,
|
||||
"device": device,
|
||||
"pin_memory": False,
|
||||
},
|
||||
)
|
||||
all_gather_copy_in = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.fsdp.all_gather_copy_in.default,
|
||||
(
|
||||
param_all_gather_inputs_flattened,
|
||||
all_gather_output,
|
||||
inp_split_sizes,
|
||||
all_gather_input_numel,
|
||||
rank,
|
||||
),
|
||||
{},
|
||||
)
|
||||
all_gather_input = new_graph_call_function(
|
||||
new_graph,
|
||||
operator.getitem,
|
||||
(all_gather_copy_in, 0),
|
||||
{},
|
||||
)
|
||||
all_gather_into_tensor_out = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
|
||||
(all_gather_input, group_size, group_name),
|
||||
{"out": all_gather_output},
|
||||
)
|
||||
wait_tensor = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops._c10d_functional.wait_tensor.default,
|
||||
(all_gather_into_tensor_out,),
|
||||
{},
|
||||
)
|
||||
all_gather_output_reshaped = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten.reshape.default,
|
||||
(wait_tensor, [group_size, -1]),
|
||||
{},
|
||||
)
|
||||
outs_flattened = [
|
||||
new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten.reshape.default,
|
||||
(n, [group_size, -1]),
|
||||
{},
|
||||
)
|
||||
for n in param_all_gather_outputs
|
||||
]
|
||||
split_with_sizes_copy = new_graph_call_function( # noqa: F841
|
||||
new_graph,
|
||||
torch.ops.fsdp.split_with_sizes_copy.default,
|
||||
(all_gather_output_reshaped, inp_split_sizes),
|
||||
{"dim": 1, "out": outs_flattened},
|
||||
)
|
||||
outs = [
|
||||
new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten.reshape.default,
|
||||
(n, orig_shape),
|
||||
{},
|
||||
)
|
||||
for n, orig_shape in zip(
|
||||
outs_flattened, param_all_gather_outputs_shape_orig
|
||||
)
|
||||
]
|
||||
assert len(orig_wait_nodes) == len(outs)
|
||||
assert len(orig_wait_nodes) > 0
|
||||
for out, orig_wait_node in zip(outs, orig_wait_nodes):
|
||||
env[orig_wait_node] = out # noqa: PERF403
|
||||
bucket_id_is_scheduled[bucket_id] = True
|
||||
else:
|
||||
continue
|
||||
gm.graph = new_graph
|
68
torch/_inductor/fx_passes/fsdp.py
Normal file
68
torch/_inductor/fx_passes/fsdp.py
Normal file
@ -0,0 +1,68 @@
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch._inductor.fx_passes.bucketing import (
|
||||
bucket_all_gather_by_mb,
|
||||
merge_all_gather,
|
||||
)
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def is_graph_input(node: torch.fx.Node) -> bool:
|
||||
return node.op == "placeholder"
|
||||
|
||||
|
||||
def is_fsdp_all_gather_wait(wait: torch.fx.Node) -> bool:
|
||||
# Assume all_gather_into_tensor input is either graph input
|
||||
# or dtype conversion of graph input
|
||||
ag_node = wait.args[0] # type: ignore[arg-type, union-attr]
|
||||
return (
|
||||
is_graph_input(ag_node.args[0]) # type: ignore[arg-type, union-attr]
|
||||
or ( # type: ignore[arg-type, union-attr]
|
||||
ag_node.args[0].op == "call_function" # type: ignore[arg-type, union-attr]
|
||||
and ag_node.args[0].target # type: ignore[arg-type, union-attr]
|
||||
== torch.ops.prims.convert_element_type.default # type: ignore[arg-type, union-attr]
|
||||
and is_graph_input(ag_node.args[0].args[0]) # type: ignore[arg-type, union-attr]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def bucket_fsdp_all_gather(
|
||||
gm: torch.fx.GraphModule, all_gather_bucket_cap_mb_callback: Callable[[int], float]
|
||||
) -> None:
|
||||
"""
|
||||
Bucketing pass for SimpleFSDP all_gather ops.
|
||||
|
||||
Attributes:
|
||||
gm (torch.fx.GraphModule): Graph module of the graph.
|
||||
all_gather_bucket_cap_mb_callback (Callable[[int], float]): callback function that
|
||||
takes in bucket id and returns size of a bucket in megabytes.
|
||||
|
||||
Usage:
|
||||
```
|
||||
from torch._inductor.fx_passes.bucketing import (
|
||||
bucket_all_gather,
|
||||
bucket_size_determinator,
|
||||
)
|
||||
|
||||
|
||||
def _bucket_all_gather(graph):
|
||||
return bucket_all_gather(graph.owning_module, bucket_size_determinator)
|
||||
|
||||
|
||||
torch._inductor.config.post_grad_custom_post_pass = _bucket_all_gather
|
||||
```
|
||||
"""
|
||||
|
||||
ag_buckets = bucket_all_gather_by_mb(
|
||||
gm,
|
||||
all_gather_bucket_cap_mb_callback,
|
||||
filter_wait_node=is_fsdp_all_gather_wait,
|
||||
)
|
||||
if len(ag_buckets) == 0:
|
||||
return
|
||||
merge_all_gather(gm, ag_buckets)
|
@ -219,6 +219,28 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
||||
decompose_map_to_while_loop
|
||||
)
|
||||
|
||||
# Fx all_gather bucketing introduces mutation op
|
||||
# Keeping it in the end to keep invariant of functional graph for previous passes.
|
||||
if config.bucket_all_gathers_fx != "none":
|
||||
from torch._inductor.fx_passes.bucketing import (
|
||||
bucket_all_gather,
|
||||
bucket_size_determinator,
|
||||
)
|
||||
from torch._inductor.fx_passes.fsdp import bucket_fsdp_all_gather
|
||||
|
||||
p = (
|
||||
bucket_fsdp_all_gather
|
||||
if config.bucket_all_gathers_fx == "fsdp"
|
||||
else bucket_all_gather
|
||||
)
|
||||
d = (
|
||||
config.bucket_all_gathers_fx_bucket_size_determinator
|
||||
or bucket_size_determinator
|
||||
)
|
||||
GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass(
|
||||
lambda graph: p(graph.owning_module, d)
|
||||
)
|
||||
|
||||
gm.recompile()
|
||||
gm.graph.lint()
|
||||
|
||||
|
Reference in New Issue
Block a user