Compare commits

...

3 Commits

SHA1 Message Date
83349ae64d [async_tp] Base support ag-transpose-mm(mat_B) case
ghstack-source-id: edd51b9c46e46e8eca0c45e0ea53c1b26b375c01
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163069
2025-09-19 08:35:51 -07:00
bf08b164dc [async_tp] Support ag+mm with gather_dim lastdim of mat_A
ghstack-source-id: 8de8acdc31566643d4b8370f27006002b05cdd61
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163068
2025-09-16 04:42:16 -07:00
da0b6aea11 [async_tp] Support mm+rs with scatter_dim matmul K by sharding B
ghstack-source-id: dee5390f82c6899af543adc6b91b5954097077ad
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162794
2025-09-12 04:34:10 -07:00
3 changed files with 394 additions and 49 deletions

View File

@ -813,7 +813,10 @@ _fuse_ddp_communication_passes: list[Union[Callable[..., None], str]] = [
"schedule_comm_wait",
]
_micro_pipeline_tp: bool = False
_micro_pipeline_tp: bool = True
_micro_pipeline_tp_mm_rs_last_dim_enabled: bool = False
_micro_pipeline_tp_ag_mm_last_dim_enabled: bool = False
_micro_pipeline_tp_ag_transpose_mm_enabled: bool = False
class _collective:
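For context, a minimal usage sketch of the new knobs, assuming they are exposed on torch._inductor.config next to the existing _micro_pipeline_tp flag (the flag names are taken from the hunk above, not from released documentation):

```python
# Hedged sketch: opt into the async-TP fusions added by this stack.
# Assumes the flags live on torch._inductor.config, as the hunk above suggests.
import torch._inductor.config as inductor_config

inductor_config._micro_pipeline_tp = True
inductor_config._micro_pipeline_tp_mm_rs_last_dim_enabled = True
inductor_config._micro_pipeline_tp_ag_mm_last_dim_enabled = True
inductor_config._micro_pipeline_tp_ag_transpose_mm_enabled = True
```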

View File

@ -376,16 +376,22 @@ class _Matmul:
B_node: torch.fx.Node
pre_mm_reshape: Optional[torch.fx.Node]
post_mm_reshape: Optional[torch.fx.Node]
pre_mm_B_transpose: Optional[torch.fx.Node]
def __post_init__(self):
assert len(self.nodes) in (1, 3)
assert len(self.nodes) in (1, 2, 3)
if len(self.nodes) == 1:
assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default)
self.arg_ancestor_nodes = _find_ancestors(self.B_node)
elif len(self.nodes) == 2:
assert self.nodes[0].target == aten.permute.default
assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default)
self.arg_ancestor_nodes = _find_ancestors(self.A_node)
else:
assert self.nodes[0].target == aten.reshape.default
assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default)
assert self.nodes[2].target == aten.reshape.default
self.arg_ancestor_nodes = _find_ancestors(self.B_node)
self.arg_ancestor_nodes = _find_ancestors(self.B_node)
def replace_with(self, new_node: torch.fx.Node) -> None:
"""
@ -401,6 +407,15 @@ class _Matmul:
graph.erase_node(mm_node)
return
if len(self.nodes) == 2:
permute_node = self.nodes[0]
mm_node = self.nodes[1]
assert mm_node.target in (aten.mm.default, aten._scaled_mm.default)
mm_node.replace_all_uses_with(new_node)
graph.erase_node(mm_node)
if len(permute_node.users) == 0:
graph.erase_node(permute_node)
return
# An ND-matmul is reshape -> mm -> reshape sequence. We first replace
# the second reshape node with `new_node`. Then, we ensure that the
# original mm node in the sequence ends up with zero users by replacing
@ -443,6 +458,7 @@ class _Matmul:
# TODO: explore unifying the _Matmul and _ScaledMatmul approaches to handling reshapes.
pre_mm_reshape=None,
post_mm_reshape=None,
pre_mm_B_transpose=None,
)
@ -550,6 +566,24 @@ def _find_reshape_mm_reshape(node: torch.fx.Node) -> list[_Matmul]:
return matmuls
def _find_permute_mm(permute_node: torch.fx.Node) -> list[_Matmul]:
for mm_node in permute_node.users:
if mm_node.target != aten.mm.default:
continue
if permute_node == mm_node.args[1]:
return [
_Matmul(
nodes=[permute_node, mm_node],
A_node=cast("torch.fx.Node", mm_node.args[0]),
B_node=cast("torch.fx.Node", mm_node.args[1]),
pre_mm_reshape=None,
post_mm_reshape=None,
pre_mm_B_transpose=permute_node,
)
]
return []
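For illustration, a small hedged sketch of the aten-level pattern _find_permute_mm matches: an aten.permute.default node feeding aten.mm.default as its second argument. The function name and shapes below are made up for the example.

```python
# Sketch only: trace mm(mat_A, mat_B.permute(1, 0)) to see the permute -> mm
# shape that the matcher above looks for (the permute appears as args[1] of mm).
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def ag_transpose_mm(mat_A: torch.Tensor, mat_B: torch.Tensor) -> torch.Tensor:
    # In the real pass, mat_B would be the output of an all_gather_tensor.
    return torch.mm(mat_A, mat_B.permute(1, 0))

gm = make_fx(ag_transpose_mm)(torch.randn(8, 16), torch.randn(32, 16))
print(gm.graph)
# Expected: an aten.permute.default node whose user is aten.mm.default,
# with the permute passed as the mm's second argument.
```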
def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]:
"""
Find the matmuls that use `node` as the lhs argument.
@ -566,6 +600,23 @@ def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]:
elif user.target == aten._scaled_mm.default:
matmul = _ScaledMatmul.from_match([user])
matmuls.append(matmul)
elif (
config._micro_pipeline_tp_ag_transpose_mm_enabled
and user.target == aten.permute.default
and (user.args[1] == [1, 0] or user.args[1] == [0, 1])
):
permute_matmuls = _find_permute_mm(user)
if permute_matmuls:
# Only add the permute+mm match if no permute+mm match
# has been collected for this node yet.
has_permute_matmul = any(m.pre_mm_B_transpose for m in matmuls)
if not has_permute_matmul:
matmuls.extend(permute_matmuls)
return matmuls
@ -607,6 +658,55 @@ def _insert_fused_all_gather_matmul(
raise AssertionError(f"Unexpected matmul match type: {mm_type}")
def graph_call_function_transpose(graph, n):
return graph.call_function(
torch.ops.aten.permute.default,
args=(n, [1, 0]),
)
def graph_call_function_contiguous(graph, n):
return graph.call_function(
torch.ops.aten.clone.default,
args=(n,),
kwargs={"memory_format": torch.contiguous_format},
)
# mat_B = ag(shard)
# mat_B_t = mat_B.t()
# return mm(mat_A, mat_B_t)
# ->
# mat_A_t = mat_A.t()
# mat_B = ag(shard)
# res_mm_t = mm(mat_B, mat_A_t)
# return res_mm_t.t()  # the pass transposes the fused output back for callers
def _insert_fused_all_gather_transpose_matmul(
graph: torch.fx.Graph,
matmuls: list[_Matmul],
shard_node: torch.fx.Node,
gather_dim: int,
group_name: str,
) -> torch.fx.Node:
mm_types = OrderedSet(map(type, matmuls))
assert len(mm_types) == 1
mm_type = next(iter(mm_types))
if mm_type == _Matmul:
B_nodes = [
graph_call_function_transpose(graph, matmul.A_node) for matmul in matmuls
]
res_fused_mm = graph.call_function(
torch.ops.symm_mem.fused_all_gather_matmul.default,
args=(shard_node, B_nodes, gather_dim, group_name),
kwargs={"return_A": True},
)
return res_fused_mm
else:
raise AssertionError(f"Unexpected matmul match type: {mm_type}")
def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
"""
Fuse the pattern
@ -645,8 +745,10 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
if not is_symm_mem_enabled_for_group(group_name):
return
if gather_dim >= len(_get_tensor(shard_node).shape) - 1:
# Decomposing the matmul on the K dimension is not supported
if (
not config._micro_pipeline_tp_ag_mm_last_dim_enabled
and gather_dim == _get_tensor(shard_node).ndim - 1
):
return
# Find consumer matmuls
@ -666,45 +768,84 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
# Fuse the all_gather_tensor with the eligible matmuls
graph = ag_node.graph
with graph.inserting_before(ag_node):
if "val" in shard_node.meta:
restrided = restride_A_shard_for_fused_all_gather_matmul(
_get_tensor(shard_node),
gather_dim,
if matmuls[0].pre_mm_B_transpose is not None:
if "val" in shard_node.meta:
restrided = restride_A_shard_for_fused_all_gather_matmul(
_get_tensor(shard_node),
gather_dim,
)
shard_node = graph.call_function(
inductor_prims.force_stride_order,
args=(shard_node, restrided.stride()),
)
fused_node = _insert_fused_all_gather_transpose_matmul(
graph, matmuls, shard_node, gather_dim, group_name
)
shard_node = graph.call_function(
inductor_prims.force_stride_order,
args=(shard_node, restrided.stride()),
)
fused_node = _insert_fused_all_gather_matmul(
graph, matmuls, shard_node, gather_dim, group_name
)
new_ag_node = graph.call_function(
operator.getitem,
args=(fused_node, 0),
)
new_out_nodes = graph.call_function(
operator.getitem,
args=(fused_node, 1),
)
for idx, matmul in enumerate(matmuls):
new_out_node = graph.call_function(
new_ag_node = graph.call_function(
operator.getitem,
args=(new_out_nodes, idx),
args=(fused_node, 0),
)
fused_out_1 = graph.call_function(
operator.getitem,
args=(fused_node, 1),
)
new_out_nodes = [
graph.call_function(operator.getitem, args=(fused_out_1, i))
for i in range(len(matmuls))
]
new_out_nodes = [
graph_call_function_transpose(graph, out_node_t)
for out_node_t in new_out_nodes
]
# Clone the transposed outputs so each result is contiguous (i.e. matches the pre-pass strides).
new_out_nodes = [
graph_call_function_contiguous(graph, out_node)
for out_node in new_out_nodes
]
for matmul, new_out_node in zip(matmuls, new_out_nodes):
matmul.replace_with(new_out_node)
matmul.erase()
else:
if "val" in shard_node.meta:
restrided = restride_A_shard_for_fused_all_gather_matmul(
_get_tensor(shard_node),
gather_dim,
)
shard_node = graph.call_function(
inductor_prims.force_stride_order,
args=(shard_node, restrided.stride()),
)
fused_node = _insert_fused_all_gather_matmul(
graph, matmuls, shard_node, gather_dim, group_name
)
new_ag_node = graph.call_function(
operator.getitem,
args=(fused_node, 0),
)
new_out_nodes = graph.call_function(
operator.getitem,
args=(fused_node, 1),
)
matmul.replace_with(new_out_node)
matmul.erase()
all_gather.replace_with(new_ag_node)
all_gather.erase()
# If the new_ag_node has no users, we tell the fused op to not return
# it. This creates more optimization opportunities.
if len(new_ag_node.users) == 0:
graph.erase_node(new_ag_node)
kwargs = dict(fused_node.kwargs)
if "return_A" in kwargs:
kwargs["return_A"] = False
fused_node.kwargs = kwargs
for idx, matmul in enumerate(matmuls):
new_out_node = graph.call_function(
operator.getitem,
args=(new_out_nodes, idx),
)
matmul.replace_with(new_out_node)
matmul.erase()
all_gather.replace_with(new_ag_node)
all_gather.erase()
# If the new_ag_node has no users, we tell the fused op to not return
# it. This creates more optimization opportunities.
if len(new_ag_node.users) == 0:
graph.erase_node(new_ag_node)
kwargs = dict(fused_node.kwargs)
if "return_A" in kwargs:
kwargs["return_A"] = False
fused_node.kwargs = kwargs
# Raise ancestors of non-A args that are topologically ordered between
# ag_res_node and the matmul above fused_node.
@ -880,6 +1021,12 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
if not is_symm_mem_enabled_for_group(group_name):
return
if (
not config._micro_pipeline_tp_mm_rs_last_dim_enabled
and orig_scatter_dim == _get_tensor(input_node).ndim - 1
):
return
# Currently fused_matmul_reduce_scatter doesn't return the matmul result,
# so we can't apply the fusion if the matmul result is used by multiple
# users. This is not a fundamental limitation of the fused op and can be
@ -1072,8 +1219,8 @@ def micro_pipeline_tp_pass(graph: torch.fx.Graph):
"async TP found no matching all-gather/reduce-scatter patterns for fusion"
)
for all_gather in all_gathers:
fuse_all_gather_matmul(all_gather)
for reduce_scatter in reduce_scatters:
fuse_matmul_reduce_scatter(reduce_scatter)
for all_gather in all_gathers:
fuse_all_gather_matmul(all_gather)

View File

@ -526,6 +526,22 @@ def _fused_all_gather_matmul_impl(
group = c10d._resolve_process_group(group_name)
if gather_dim == A_shard.ndim - 1:
# Implementation for gathering along the last dimension of A_shard
# (the K dimension of the matmul). A_shard is split column-wise across ranks:
# A_shard: [A0, A1, ..., Ags]
return _fused_all_gather_matmul_last_gather_dim_impl(
mm_out_op,
A_shard,
Bs,
A_scale,
kwargs_list,
out_dtypes,
gather_dim,
group_name,
return_A,
)
# Move the gather_dim to the front and flatten the tensor into a 2D matrix.
# The flattened tensor doesn't need to be contiguous (for computation
# efficiency), as _pipelined_all_gather_and_consume guarantees that shards
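The dispatch above uses the usual blocked-matmul decomposition: when A is sharded along its last (K) dimension and B is chunked row-wise, the all-gathered matmul equals the sum of the per-shard partial products, which is what the consumer in the new implementation accumulates. A hedged, single-process illustration with arbitrary world size and shapes:

```python
# Sketch of the math behind the last-gather-dim path: allgather(A, dim=-1) @ B
# equals the sum of per-shard partials A_i @ B_i when B is chunked row-wise.
import torch

world_size = 4
A_shards = [torch.randn(8, 16, dtype=torch.float64) for _ in range(world_size)]  # per-rank K shards of A
B = torch.randn(16 * world_size, 32, dtype=torch.float64)
B_shards = B.chunk(world_size)                        # one row block of B per shard

full = torch.cat(A_shards, dim=-1) @ B                # what all-gather + mm computes
partial_sum = sum(A_i @ B_i for A_i, B_i in zip(A_shards, B_shards))

torch.testing.assert_close(full, partial_sum)
```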
@ -626,6 +642,147 @@ def _fused_all_gather_matmul_impl(
return A, [unflatten(output) for output in outputs]
def _pipelined_all_gather_and_consume_last_dim(
shard: torch.Tensor,
shard_consumer: Callable[[torch.Tensor, int], None],
ag_out: torch.Tensor,
group_name: str,
ag_out_needed: bool = True,
) -> None:
p2p_workspace_size_req = shard.numel() * shard.element_size()
symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req)
group_size = symm_mem.world_size
rank = symm_mem.rank
symm_mem.barrier(channel=0)
backend_stream = _get_backend_stream()
backend_stream.wait_stream(torch.cuda.current_stream())
def copy_shard(dst: torch.Tensor, src: torch.Tensor) -> None:
dst.copy_(src)
def get_p2p_buf(remote_rank: int) -> torch.Tensor:
buf = symm_mem.get_buffer(
remote_rank,
shard.shape,
shard.dtype,
)
return buf
local_p2p_buf = get_p2p_buf(rank)
# shards[i] => shard from rank i
shards = list(ag_out.chunk(group_size))
copy_shard(dst=local_p2p_buf, src=shard)
symm_mem.barrier(channel=1)
backend_stream.wait_stream(torch.cuda.current_stream())
# At this point, all ranks have copied their local shard to
# their local p2p buffer. Each rank can now copy and consume
# remote shards.
shard_consumer(shard, rank)
for step in range(1, group_size):
if step % 2 == 0:
stream = torch.cuda.current_stream()
else:
stream = backend_stream
remote_rank = (step + rank) % group_size
remote_p2p_buf = get_p2p_buf(remote_rank)
with stream:
copy_shard(dst=shards[remote_rank], src=remote_p2p_buf)
shard_consumer(shards[remote_rank], remote_rank)
if ag_out_needed:
# Copy from input to the all-gather output. Opportunistically overlap
# it with the last shard_consumer.
if group_size % 2 == 0:
stream = torch.cuda.current_stream()
else:
stream = backend_stream
with stream:
copy_shard(dst=shards[rank], src=shard)
torch.cuda.current_stream().wait_stream(backend_stream)
symm_mem.barrier(channel=0)
def _fused_all_gather_matmul_last_gather_dim_impl(
mm_out_op: torch._ops.OpOverload,
A_shard: torch.Tensor,
Bs: list[torch.Tensor],
A_scale: torch.Tensor | None,
kwargs_list: list[dict[str, Any]],
out_dtypes: list[torch.dtype | None],
gather_dim: int,
group_name: str,
return_A: bool,
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
assert gather_dim == A_shard.ndim - 1
group = c10d._resolve_process_group(group_name)
# A_shard is split column-wise across ranks: A_shard = [A0, A1, ..., Ags].
# Each partial product Ai @ Bi has the same shape as the final output.
B_shards = [B.chunk(group.size()) for B in Bs]
leading_dims = list(A_shard.shape[:-1])
A_shard_flat = A_shard.flatten(0, -2)
def unflatten(t: torch.Tensor) -> torch.Tensor:
return t.view(*leading_dims, -1)
A_out_leading_dims = list(A_shard.shape[:-1])
# The gathered shards are stacked along the leading dim, one per rank.
A_out_leading_dims[0] *= group.size()
def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
return t.view(*A_out_leading_dims, -1)
A_flat_out = A_shard_flat.new_empty(
A_shard_flat.shape[0] * group.size(),
A_shard_flat.shape[1],
)
# outputs act as accumulators for output_partials
outputs = [
torch.empty(
(A_shard_flat.shape[0], B.shape[1]),
dtype=out_dtype or B.dtype,
device=A_shard.device,
)
for B, out_dtype in zip(Bs, out_dtypes)
]
# Additional allocation for the partial outputs
# that are reduced into outputs.
output_partials = [torch.empty_like(out) for out in outputs]
first = True
def default_consumer(shard: torch.Tensor, rank: int) -> None:
nonlocal first
for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)):
out = outputs[idx] if first else output_partials[idx]
mm_out_op(shard, B_shards[idx][rank], **kwargs, out=out)
if not first:
outputs[idx] += output_partials[idx]
first = False
_pipelined_all_gather_and_consume_last_dim(
A_shard_flat,
default_consumer,
A_flat_out,
group_name,
return_A,
)
A = unflatten_A_out(A_flat_out) if return_A else None
return A, [unflatten(output) for output in outputs]
@torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
def _fused_all_gather_matmul_fallback(
A_shard: torch.Tensor,
@ -640,12 +797,22 @@ def _fused_all_gather_matmul_fallback(
A_shard.contiguous(), group_size, group_name
)
A = torch.ops._c10d_functional.wait_tensor(A)
A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
res = [torch.matmul(A, B).movedim(0, gather_dim) for B in Bs]
if return_A:
return A.movedim(0, gather_dim), res
if gather_dim == A.ndim - 1:
A_mm_shape = list(A_shard.shape)
A_mm_shape[-1] *= group_size
A_mm = A.new_empty(A_mm_shape)
res = [torch.matmul(A_mm, B) for B in Bs]
if return_A:
return A, res
else:
return None, res
else:
return None, res
A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
res = [torch.matmul(A, B).movedim(0, gather_dim) for B in Bs]
if return_A:
return A.movedim(0, gather_dim), res
else:
return None, res
@torch.library.impl(lib, "fused_all_gather_matmul", "CUDA")
@ -1070,11 +1237,39 @@ def _fused_matmul_reduce_scatter_impl(
reduce_fn = partial(torch.mean, dim=0)
else:
raise ValueError("reduce_op must be sum or avg")
group = c10d._resolve_process_group(group_name)
out_shape = [*A.shape[:-1], B.shape[1]]
out_shape[scatter_dim] //= group.size()
if scatter_dim == A.ndim - 1:
Bt = B.t()
Bt_shards = Bt.chunk(group.size())
x = A.flatten(0, -2)
def _chunk_producer(rank: int, out: torch.Tensor) -> None:
mm_out_op(A, Bt_shards[rank].t(), **kwargs, out=out)
leading_dims = [group.size()] + list(x.shape[:-1])
stacked_partials = x.new_empty(
x.shape[0], B.shape[1], dtype=out_dtype or A.dtype
)
_pipelined_produce_and_all2all(
_chunk_producer,
stacked_partials,
group_name,
)
# Ensure that the transpose and the reduction produce a contiguous result
# with a single reduction kernel.
stacked_partials_view = stacked_partials.view(*leading_dims, -1)
stacked_partials_view = stacked_partials_view.movedim(0, scatter_dim)
return reduce_fn(
stacked_partials_view,
dim=scatter_dim,
)
# Move the scatter_dim to the front and flatten the tensor into a 2D matrix
x = A.movedim(scatter_dim, 0)
leading_dims = [group.size()] + list(x.shape[:-1])
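The last-dim branch of _fused_matmul_reduce_scatter_impl shown above builds each destination rank's chunk by slicing B through its transpose: A @ B.t().chunk(world_size)[r].t() is exactly the r-th column block of A @ B, so after the all-to-all the reduction only needs to sum the blocks contributed by the different ranks (each rank supplying its own partial A in the distributed case). A hedged, single-process check of that per-chunk identity:

```python
# Sketch of the per-chunk identity used by the last-dim matmul + reduce-scatter path:
# slicing B via its transpose selects the r-th column block of A @ B.
import torch

world_size = 4
A = torch.randn(8, 16)
B = torch.randn(16, 32)

Bt_shards = B.t().chunk(world_size)   # each shard has shape (32 // world_size, 16)
full = A @ B
n_block = B.shape[1] // world_size

for r in range(world_size):
    block = A @ Bt_shards[r].t()      # what _chunk_producer computes for destination rank r
    torch.testing.assert_close(block, full[:, r * n_block:(r + 1) * n_block])
```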