mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[async_tp] Support ag+mm with gather_dim lastdim of mat_A (#163068)
Adding ag+mm support for the case, when gather_dim is last dim of matmul (reduction dim). When we decompose matmul by reduction dimension we result in partials that needs additional reduction, we allocate memory for accumulator. Decomposition should not produce small (thin) mms that can not efficiently load the GPU. Limiting for minimal size of the shard 1024 (found empirically by testing in torchtitan). scaled_mm is not supported yet for this case. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163068 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
d795fb225a
commit
585b9dbb5e
@ -294,7 +294,7 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
|
||||
)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@parametrize("gather_dim", [0, 1])
|
||||
@parametrize("gather_dim", [0, 1, 2])
|
||||
def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
|
||||
self._init_process()
|
||||
|
||||
@ -306,7 +306,10 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||
rank = self.rank
|
||||
|
||||
torch.manual_seed(42 + rank)
|
||||
A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
|
||||
A_shard_shape = [BATCH, M, K]
|
||||
A_shard_shape[gather_dim] //= self.world_size
|
||||
|
||||
A_shard = torch.rand(A_shard_shape, device="cuda")
|
||||
Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]
|
||||
|
||||
ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
|
||||
@ -523,7 +526,7 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||
BATCH = 8
|
||||
M = 64
|
||||
N = 16
|
||||
K = 32
|
||||
K = 1024
|
||||
group = dist.group.WORLD
|
||||
rank = self.rank
|
||||
|
||||
|
@ -27,6 +27,10 @@ aten = torch.ops.aten
|
||||
patterns = PatternMatcherPass()
|
||||
|
||||
|
||||
def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
|
||||
return dim == t.ndim - 1 or dim == -1
|
||||
|
||||
|
||||
def _is_backward(graph: torch.fx.Graph) -> bool:
|
||||
placeholders = []
|
||||
for node in graph.nodes:
|
||||
@ -645,9 +649,17 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
|
||||
if not is_symm_mem_enabled_for_group(group_name):
|
||||
return
|
||||
|
||||
if gather_dim >= len(_get_tensor(shard_node).shape) - 1:
|
||||
# Decomposing the matmul on the K dimension is not supported
|
||||
return
|
||||
filter_matmul = None
|
||||
if _is_last_dim(_get_tensor(shard_node), gather_dim):
|
||||
# Decomposed mms should not be too small
|
||||
if _get_tensor(shard_node).shape[-1] < 1024:
|
||||
return
|
||||
|
||||
# scaled_mm is not supported yet for last dim
|
||||
def _filter_out_scaled_matmul(matmul: _Matmul):
|
||||
return not isinstance(matmul, _ScaledMatmul)
|
||||
|
||||
filter_matmul = _filter_out_scaled_matmul
|
||||
|
||||
# Find consumer matmuls
|
||||
matmuls = _find_consumer_matmuls(ag_res_node)
|
||||
@ -663,18 +675,29 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
|
||||
if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
|
||||
return
|
||||
|
||||
if _is_last_dim(_get_tensor(shard_node), gather_dim) and len(
|
||||
all_gather.res_node.users
|
||||
) > len(matmuls):
|
||||
# The result of ag-split-cat is used not only in matmuls.
|
||||
# Then it has to be materialized, which can have overhead.
|
||||
return
|
||||
|
||||
if filter_matmul and not filter_matmul(matmuls[0]):
|
||||
return
|
||||
|
||||
# Fuse the all_gather_tensor with the eligible matmuls
|
||||
graph = ag_node.graph
|
||||
with graph.inserting_before(ag_node):
|
||||
if "val" in shard_node.meta:
|
||||
restrided = restride_A_shard_for_fused_all_gather_matmul(
|
||||
_get_tensor(shard_node),
|
||||
gather_dim,
|
||||
)
|
||||
shard_node = graph.call_function(
|
||||
inductor_prims.force_stride_order,
|
||||
args=(shard_node, restrided.stride()),
|
||||
)
|
||||
if not _is_last_dim(_get_tensor(shard_node), gather_dim):
|
||||
if "val" in shard_node.meta:
|
||||
restrided = restride_A_shard_for_fused_all_gather_matmul(
|
||||
_get_tensor(shard_node),
|
||||
gather_dim,
|
||||
)
|
||||
shard_node = graph.call_function(
|
||||
inductor_prims.force_stride_order,
|
||||
args=(shard_node, restrided.stride()),
|
||||
)
|
||||
|
||||
fused_node = _insert_fused_all_gather_matmul(
|
||||
graph, matmuls, shard_node, gather_dim, group_name
|
||||
@ -881,7 +904,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
|
||||
return
|
||||
|
||||
filter_matmul = None
|
||||
if orig_scatter_dim == _get_tensor(input_node).ndim - 1:
|
||||
if _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
|
||||
# scaled_mm is not supported yet for last dim mm+rs
|
||||
def _filter_out_scaled_matmul(matmul: _Matmul):
|
||||
return not isinstance(matmul, _ScaledMatmul)
|
||||
|
@ -524,6 +524,19 @@ def _fused_all_gather_matmul_impl(
|
||||
|
||||
group = c10d._resolve_process_group(group_name)
|
||||
|
||||
if gather_dim == A_shard.ndim - 1 or gather_dim == -1:
|
||||
return _fused_all_gather_matmul_last_gather_dim_impl(
|
||||
mm_out_op,
|
||||
A_shard,
|
||||
Bs,
|
||||
A_scale,
|
||||
kwargs_list,
|
||||
out_dtypes,
|
||||
gather_dim,
|
||||
group_name,
|
||||
return_A,
|
||||
)
|
||||
|
||||
# Move the gather_dim to the front and flatten the tensor into a 2D matrix.
|
||||
# The flattened tensor doesn't need to be contiguous (for computation
|
||||
# efficiency), as _pipelined_all_gather_and_consume guarantees that shards
|
||||
@ -624,6 +637,140 @@ def _fused_all_gather_matmul_impl(
|
||||
return A, [unflatten(output) for output in outputs]
|
||||
|
||||
|
||||
def _pipelined_all_gather_and_consume_last_dim(
|
||||
shard: torch.Tensor,
|
||||
shard_consumer: Callable[[torch.Tensor, int], None],
|
||||
ag_out: torch.Tensor,
|
||||
group_name: str,
|
||||
ag_out_needed: bool = True,
|
||||
) -> None:
|
||||
p2p_workspace_size_req = 0
|
||||
p2p_workspace_size_req = shard.numel() * shard.element_size()
|
||||
symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req)
|
||||
group_size = symm_mem.world_size
|
||||
rank = symm_mem.rank
|
||||
|
||||
symm_mem.barrier(channel=0)
|
||||
backend_stream = _get_backend_stream()
|
||||
backend_stream.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
def copy_shard(dst: torch.Tensor, src: torch.Tensor) -> None:
|
||||
dst.copy_(src)
|
||||
|
||||
def get_p2p_buf(remote_rank: int) -> torch.Tensor:
|
||||
buf = symm_mem.get_buffer(
|
||||
remote_rank,
|
||||
shard.shape,
|
||||
shard.dtype,
|
||||
)
|
||||
return buf
|
||||
|
||||
local_p2p_buf = get_p2p_buf(rank)
|
||||
|
||||
shards = ag_out.chunk(group_size)
|
||||
|
||||
copy_shard(dst=local_p2p_buf, src=shard)
|
||||
symm_mem.barrier(channel=1)
|
||||
backend_stream.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
# At this point, all ranks have copied their local shard to
|
||||
# their local p2p buffer. Each rank can now copy and consume
|
||||
# remote shards.
|
||||
shard_consumer(shard, rank)
|
||||
|
||||
for step in range(1, group_size):
|
||||
if step % 2 == 0:
|
||||
stream = torch.cuda.current_stream()
|
||||
else:
|
||||
stream = backend_stream
|
||||
remote_rank = (step + rank) % group_size
|
||||
remote_p2p_buf = get_p2p_buf(remote_rank)
|
||||
with stream:
|
||||
copy_shard(dst=shards[remote_rank], src=remote_p2p_buf)
|
||||
shard_consumer(shards[remote_rank], remote_rank)
|
||||
|
||||
if ag_out_needed:
|
||||
# Copy from input to the all-gather output. Opportunistically overlap
|
||||
# it with the last shard_consumer.
|
||||
if group_size % 2 == 0:
|
||||
stream = torch.cuda.current_stream()
|
||||
else:
|
||||
stream = backend_stream
|
||||
with stream:
|
||||
copy_shard(dst=shards[rank], src=shard)
|
||||
|
||||
torch.cuda.current_stream().wait_stream(backend_stream)
|
||||
symm_mem.barrier(channel=0)
|
||||
|
||||
|
||||
def _fused_all_gather_matmul_last_gather_dim_impl(
|
||||
mm_out_op: torch._ops.OpOverload,
|
||||
A_shard: torch.Tensor,
|
||||
Bs: list[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
kwargs_list: list[dict[str, Any]],
|
||||
out_dtypes: list[torch.dtype | None],
|
||||
gather_dim: int,
|
||||
group_name: str,
|
||||
return_A: bool,
|
||||
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
|
||||
group = c10d._resolve_process_group(group_name)
|
||||
group_size = group.size()
|
||||
|
||||
B_shards = [B.chunk(group.size()) for B in Bs]
|
||||
|
||||
leading_dims = list(A_shard.shape[:-1])
|
||||
A_shard_flat = A_shard.flatten(0, -2)
|
||||
|
||||
def unflatten(t: torch.Tensor) -> torch.Tensor:
|
||||
return t.view(*leading_dims, -1)
|
||||
|
||||
A_flat_out = A_shard_flat.new_empty(
|
||||
A_shard_flat.shape[0] * group.size(),
|
||||
A_shard_flat.shape[1],
|
||||
)
|
||||
|
||||
outputs = [
|
||||
torch.empty(
|
||||
(A_shard_flat.shape[0], B.shape[1]),
|
||||
dtype=out_dtype or B.dtype,
|
||||
device=A_shard.device,
|
||||
)
|
||||
for B, out_dtype in zip(Bs, out_dtypes)
|
||||
]
|
||||
|
||||
first = True
|
||||
events = [torch.cuda.Event() for _ in outputs]
|
||||
|
||||
def default_consumer(shard: torch.Tensor, rank: int) -> None:
|
||||
nonlocal first
|
||||
for out, event, B_shard, kwargs in zip(outputs, events, B_shards, kwargs_list):
|
||||
event.wait()
|
||||
if first:
|
||||
torch.ops.aten.mm.out(shard, B_shard[rank], **kwargs, out=out)
|
||||
else:
|
||||
out.addmm_(shard, B_shard[rank])
|
||||
event.record()
|
||||
|
||||
first = False
|
||||
|
||||
_pipelined_all_gather_and_consume_last_dim(
|
||||
A_shard_flat,
|
||||
default_consumer,
|
||||
A_flat_out,
|
||||
group_name,
|
||||
return_A,
|
||||
)
|
||||
ret_A = None
|
||||
if return_A:
|
||||
# This path is inefficient and will be filtered out at passes stage
|
||||
# Added only for completeness.
|
||||
A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
|
||||
ret_A = unflatten(A_split_cat_out_flat)
|
||||
|
||||
return ret_A, [unflatten(output) for output in outputs]
|
||||
|
||||
|
||||
@torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
|
||||
def _fused_all_gather_matmul_fallback(
|
||||
A_shard: torch.Tensor,
|
||||
@ -638,6 +785,15 @@ def _fused_all_gather_matmul_fallback(
|
||||
A_shard.contiguous(), group_size, group_name
|
||||
)
|
||||
A = torch.ops._c10d_functional.wait_tensor(A)
|
||||
if gather_dim == A.ndim - 1 or gather_dim == -1:
|
||||
A_splits = A.chunk(group_size)
|
||||
A_mm = torch.cat(A_splits, dim=-1)
|
||||
res = [torch.matmul(A_mm, B) for B in Bs]
|
||||
if return_A:
|
||||
return A_mm, res
|
||||
else:
|
||||
return None, res
|
||||
|
||||
A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
|
||||
res = [torch.matmul(A, B).movedim(0, gather_dim) for B in Bs]
|
||||
if return_A:
|
||||
|
Reference in New Issue
Block a user