[async_tp] Support ag+mm with gather_dim lastdim of mat_A (#163068)

Adding ag+mm support for the case when gather_dim is the last dim of the matmul (the reduction dim).

When we decompose the matmul along the reduction dimension, we end up with partial results that need an additional reduction, so we allocate memory for an accumulator.

Decomposition should not produce small (thin) mms that cannot load the GPU efficiently, so the minimal shard size is limited to 1024 (found empirically by testing in torchtitan).

scaled_mm is not supported yet for this case.
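
For intuition, here is a minimal standalone sketch (illustrative only, not the code in this PR) of the identity the decomposition relies on: splitting mat_A along the reduction dim K and pairing it with matching K-slices of mat_B yields partial products that are reduced into a single accumulator. The world_size/M/K/N values are made up; the 1024 check mirrors the minimal shard size above.

```python
# Sketch of decomposing a matmul along the reduction dim K (what ag+mm with
# gather_dim == last dim of mat_A amounts to). Not the PR's implementation.
import torch

world_size = 4
M, K, N = 64, 4096, 16  # example shapes; K // world_size == 1024
A = torch.rand(M, K, dtype=torch.float64)
B = torch.rand(K, N, dtype=torch.float64)

A_shards = A.chunk(world_size, dim=-1)  # what each rank holds before the all-gather
B_shards = B.chunk(world_size, dim=0)   # matching K-slices of mat_B

# Thin shards cannot load the GPU efficiently; the pass bails out below 1024.
assert A_shards[0].shape[-1] >= 1024

# Each shard pair yields a partial product; all partials are reduced into one accumulator.
acc = torch.zeros(M, N, dtype=torch.float64)
for A_r, B_r in zip(A_shards, B_shards):
    acc.addmm_(A_r, B_r)  # acc += A_r @ B_r

torch.testing.assert_close(acc, A @ B)  # decomposition along K reproduces the full mm
```

The fused path in the diff overlaps each partial mm with the P2P copy of the next shard and accumulates via addmm_, while the fallback concatenates the gathered shards along the last dim and runs a single mm per B.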

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163068
Approved by: https://github.com/ngimel
IvanKobzarev
2025-10-15 03:26:11 -07:00
committed by PyTorch MergeBot
parent d795fb225a
commit 585b9dbb5e
3 changed files with 198 additions and 16 deletions


@@ -294,7 +294,7 @@ class AsyncTPTest(MultiProcContinuousTest):
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    @parametrize("gather_dim", [0, 1, 2])
    def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
        self._init_process()
@@ -306,7 +306,10 @@ class AsyncTPTest(MultiProcContinuousTest):
        rank = self.rank
        torch.manual_seed(42 + rank)
        A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
        A_shard_shape = [BATCH, M, K]
        A_shard_shape[gather_dim] //= self.world_size
        A_shard = torch.rand(A_shard_shape, device="cuda")
        Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]
        ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
@@ -523,7 +526,7 @@ class AsyncTPTest(MultiProcContinuousTest):
        BATCH = 8
        M = 64
        N = 16
        K = 32
        K = 1024
        group = dist.group.WORLD
        rank = self.rank


@@ -27,6 +27,10 @@ aten = torch.ops.aten
patterns = PatternMatcherPass()


def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
    return dim == t.ndim - 1 or dim == -1


def _is_backward(graph: torch.fx.Graph) -> bool:
    placeholders = []
    for node in graph.nodes:
@@ -645,9 +649,17 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
    if not is_symm_mem_enabled_for_group(group_name):
        return

    if gather_dim >= len(_get_tensor(shard_node).shape) - 1:
        # Decomposing the matmul on the K dimension is not supported
        return

    filter_matmul = None
    if _is_last_dim(_get_tensor(shard_node), gather_dim):
        # Decomposed mms should not be too small
        if _get_tensor(shard_node).shape[-1] < 1024:
            return

        # scaled_mm is not supported yet for last dim
        def _filter_out_scaled_matmul(matmul: _Matmul):
            return not isinstance(matmul, _ScaledMatmul)

        filter_matmul = _filter_out_scaled_matmul

    # Find consumer matmuls
    matmuls = _find_consumer_matmuls(ag_res_node)
@@ -663,18 +675,29 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
    if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
        return

    if _is_last_dim(_get_tensor(shard_node), gather_dim) and len(
        all_gather.res_node.users
    ) > len(matmuls):
        # The result of ag-split-cat is used not only in matmuls.
        # Then it has to be materialized, which can have overhead.
        return

    if filter_matmul and not filter_matmul(matmuls[0]):
        return

    # Fuse the all_gather_tensor with the eligible matmuls
    graph = ag_node.graph
    with graph.inserting_before(ag_node):
        if "val" in shard_node.meta:
            restrided = restride_A_shard_for_fused_all_gather_matmul(
                _get_tensor(shard_node),
                gather_dim,
            )
            shard_node = graph.call_function(
                inductor_prims.force_stride_order,
                args=(shard_node, restrided.stride()),
            )
        if not _is_last_dim(_get_tensor(shard_node), gather_dim):
            if "val" in shard_node.meta:
                restrided = restride_A_shard_for_fused_all_gather_matmul(
                    _get_tensor(shard_node),
                    gather_dim,
                )
                shard_node = graph.call_function(
                    inductor_prims.force_stride_order,
                    args=(shard_node, restrided.stride()),
                )

        fused_node = _insert_fused_all_gather_matmul(
            graph, matmuls, shard_node, gather_dim, group_name
@@ -881,7 +904,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
        return

    filter_matmul = None
    if orig_scatter_dim == _get_tensor(input_node).ndim - 1:
    if _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
        # scaled_mm is not supported yet for last dim mm+rs
        def _filter_out_scaled_matmul(matmul: _Matmul):
            return not isinstance(matmul, _ScaledMatmul)


@@ -524,6 +524,19 @@ def _fused_all_gather_matmul_impl(
    group = c10d._resolve_process_group(group_name)

    if gather_dim == A_shard.ndim - 1 or gather_dim == -1:
        return _fused_all_gather_matmul_last_gather_dim_impl(
            mm_out_op,
            A_shard,
            Bs,
            A_scale,
            kwargs_list,
            out_dtypes,
            gather_dim,
            group_name,
            return_A,
        )

    # Move the gather_dim to the front and flatten the tensor into a 2D matrix.
    # The flattened tensor doesn't need to be contiguous (for computation
    # efficiency), as _pipelined_all_gather_and_consume guarantees that shards
@@ -624,6 +637,140 @@ def _fused_all_gather_matmul_impl(
    return A, [unflatten(output) for output in outputs]


def _pipelined_all_gather_and_consume_last_dim(
    shard: torch.Tensor,
    shard_consumer: Callable[[torch.Tensor, int], None],
    ag_out: torch.Tensor,
    group_name: str,
    ag_out_needed: bool = True,
) -> None:
    p2p_workspace_size_req = 0
    p2p_workspace_size_req = shard.numel() * shard.element_size()
    symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req)
    group_size = symm_mem.world_size
    rank = symm_mem.rank

    symm_mem.barrier(channel=0)
    backend_stream = _get_backend_stream()
    backend_stream.wait_stream(torch.cuda.current_stream())

    def copy_shard(dst: torch.Tensor, src: torch.Tensor) -> None:
        dst.copy_(src)

    def get_p2p_buf(remote_rank: int) -> torch.Tensor:
        buf = symm_mem.get_buffer(
            remote_rank,
            shard.shape,
            shard.dtype,
        )
        return buf

    local_p2p_buf = get_p2p_buf(rank)
    shards = ag_out.chunk(group_size)

    copy_shard(dst=local_p2p_buf, src=shard)
    symm_mem.barrier(channel=1)
    backend_stream.wait_stream(torch.cuda.current_stream())

    # At this point, all ranks have copied their local shard to
    # their local p2p buffer. Each rank can now copy and consume
    # remote shards.

    shard_consumer(shard, rank)

    for step in range(1, group_size):
        if step % 2 == 0:
            stream = torch.cuda.current_stream()
        else:
            stream = backend_stream
        remote_rank = (step + rank) % group_size
        remote_p2p_buf = get_p2p_buf(remote_rank)
        with stream:
            copy_shard(dst=shards[remote_rank], src=remote_p2p_buf)
            shard_consumer(shards[remote_rank], remote_rank)

    if ag_out_needed:
        # Copy from input to the all-gather output. Opportunistically overlap
        # it with the last shard_consumer.
        if group_size % 2 == 0:
            stream = torch.cuda.current_stream()
        else:
            stream = backend_stream
        with stream:
            copy_shard(dst=shards[rank], src=shard)

    torch.cuda.current_stream().wait_stream(backend_stream)
    symm_mem.barrier(channel=0)


def _fused_all_gather_matmul_last_gather_dim_impl(
    mm_out_op: torch._ops.OpOverload,
    A_shard: torch.Tensor,
    Bs: list[torch.Tensor],
    A_scale: torch.Tensor | None,
    kwargs_list: list[dict[str, Any]],
    out_dtypes: list[torch.dtype | None],
    gather_dim: int,
    group_name: str,
    return_A: bool,
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
    group = c10d._resolve_process_group(group_name)
    group_size = group.size()
    B_shards = [B.chunk(group.size()) for B in Bs]

    leading_dims = list(A_shard.shape[:-1])
    A_shard_flat = A_shard.flatten(0, -2)

    def unflatten(t: torch.Tensor) -> torch.Tensor:
        return t.view(*leading_dims, -1)

    A_flat_out = A_shard_flat.new_empty(
        A_shard_flat.shape[0] * group.size(),
        A_shard_flat.shape[1],
    )

    outputs = [
        torch.empty(
            (A_shard_flat.shape[0], B.shape[1]),
            dtype=out_dtype or B.dtype,
            device=A_shard.device,
        )
        for B, out_dtype in zip(Bs, out_dtypes)
    ]

    first = True
    events = [torch.cuda.Event() for _ in outputs]

    def default_consumer(shard: torch.Tensor, rank: int) -> None:
        nonlocal first
        for out, event, B_shard, kwargs in zip(outputs, events, B_shards, kwargs_list):
            event.wait()
            if first:
                torch.ops.aten.mm.out(shard, B_shard[rank], **kwargs, out=out)
            else:
                out.addmm_(shard, B_shard[rank])
            event.record()
        first = False

    _pipelined_all_gather_and_consume_last_dim(
        A_shard_flat,
        default_consumer,
        A_flat_out,
        group_name,
        return_A,
    )

    ret_A = None
    if return_A:
        # This path is inefficient and will be filtered out at passes stage
        # Added only for completeness.
        A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
        ret_A = unflatten(A_split_cat_out_flat)

    return ret_A, [unflatten(output) for output in outputs]


@torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
def _fused_all_gather_matmul_fallback(
    A_shard: torch.Tensor,
@@ -638,6 +785,15 @@ def _fused_all_gather_matmul_fallback(
        A_shard.contiguous(), group_size, group_name
    )
    A = torch.ops._c10d_functional.wait_tensor(A)

    if gather_dim == A.ndim - 1 or gather_dim == -1:
        A_splits = A.chunk(group_size)
        A_mm = torch.cat(A_splits, dim=-1)
        res = [torch.matmul(A_mm, B) for B in Bs]
        if return_A:
            return A_mm, res
        else:
            return None, res

    A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
    res = [torch.matmul(A, B).movedim(0, gather_dim) for B in Bs]
    if return_A: