Compare commits

...

2 Commits

Author SHA1 Message Date
98826fd37b [annotate] add annotate_fn function decorator 2025-10-17 09:23:55 -07:00
585b9dbb5e [async_tp] Support ag+mm with gather_dim lastdim of mat_A (#163068)
Adding ag+mm support for the case when gather_dim is the last dim of the matmul (the reduction dim).

When we decompose the matmul along the reduction dimension, we end up with partial results that need an additional reduction, so we allocate memory for an accumulator.

The decomposition should not produce small (thin) mms that cannot load the GPU efficiently, so the minimal shard size is limited to 1024 (found empirically by testing in torchtitan).

scaled_mm is not supported yet for this case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163068
Approved by: https://github.com/ngimel
2025-10-16 20:14:39 +00:00
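
As a quick single-process illustration of the decomposition described in the commit message above (shapes are hypothetical and chosen only for illustration): when the gather dim is the reduction dim K, the matmul decomposes into per-shard partial products that must be summed, which is why an accumulator is allocated.

import torch

# Hypothetical sizes; float64 only to keep the numerical check tight.
world_size, M, K, N = 4, 64, 4096, 16
A = torch.randn(M, K, dtype=torch.float64)
B = torch.randn(K, N, dtype=torch.float64)

# Shard A along its last dim (the reduction dim K) and B along its rows.
A_shards = A.chunk(world_size, dim=-1)
B_shards = B.chunk(world_size, dim=0)

# Each shard pair yields a partial (M, N) result; the partials need an
# extra reduction (the accumulator) to reproduce the full matmul.
out = torch.zeros(M, N, dtype=torch.float64)
for A_s, B_s in zip(A_shards, B_shards):
    out += A_s @ B_s

torch.testing.assert_close(out, A @ B)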
5 changed files with 280 additions and 19 deletions

View File

@@ -294,7 +294,7 @@ class AsyncTPTest(MultiProcContinuousTest):
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
@parametrize("gather_dim", [0, 1])
@parametrize("gather_dim", [0, 1, 2])
def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
self._init_process()
@@ -306,7 +306,10 @@ class AsyncTPTest(MultiProcContinuousTest):
rank = self.rank
torch.manual_seed(42 + rank)
A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
A_shard_shape = [BATCH, M, K]
A_shard_shape[gather_dim] //= self.world_size
A_shard = torch.rand(A_shard_shape, device="cuda")
Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]
ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
@@ -523,7 +526,7 @@ class AsyncTPTest(MultiProcContinuousTest):
BATCH = 8
M = 64
N = 16
K = 32
K = 1024
group = dist.group.WORLD
rank = self.rank

View File

@@ -57,7 +57,7 @@ def graph_capture(model, inputs, with_export):
with ExitStack() as stack:
joint_with_descriptors = aot_export_joint_with_descriptors(
stack,
model,
gm,
inputs,
)
return joint_with_descriptors.graph_module
@@ -922,6 +922,46 @@ class inner_f(torch.nn.Module):
in custom_metadata
)
def test_preserve_annotate_function(self):
"""Test basic annotate_fn usage"""
@fx_traceback.annotate_fn({"pp_stage": 1})
def example_function(x):
return x * x
class SimpleLinear(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 2)
def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
y = self.linear(x)
y = example_function(y)
return y - 1
inputs = (torch.randn(4, 3),)
model = SimpleLinear()
for with_export in [True, False]:
graph_module = graph_capture(model, inputs, with_export)
custom_metadata = fx_traceback._get_custom_metadata(graph_module)
self.assertExpectedInline(
str(custom_metadata),
"""\
('call_function', 't', {'pp_stage': 0})
('call_function', 'addmm', {'pp_stage': 0})
('call_function', 'mul', {'pp_stage': 1})
('call_function', 'mul_1', {'pp_stage': 1})
('call_function', 'mul_2', {'pp_stage': 1})
('call_function', 't_1', {'pp_stage': 0})
('call_function', 'mm', {'pp_stage': 0})
('call_function', 't_2', {'pp_stage': 0})
('call_function', 'sum_1', {'pp_stage': 0})
('call_function', 'view', {'pp_stage': 0})
('call_function', 't_3', {'pp_stage': 0})""",
)
if __name__ == "__main__":
run_tests()

View File

@@ -27,6 +27,10 @@ aten = torch.ops.aten
patterns = PatternMatcherPass()
def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
return dim == t.ndim - 1 or dim == -1
def _is_backward(graph: torch.fx.Graph) -> bool:
placeholders = []
for node in graph.nodes:
@@ -645,9 +649,17 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
if not is_symm_mem_enabled_for_group(group_name):
return
if gather_dim >= len(_get_tensor(shard_node).shape) - 1:
# Decomposing the matmul on the K dimension is not supported
return
filter_matmul = None
if _is_last_dim(_get_tensor(shard_node), gather_dim):
# Decomposed mms should not be too small
if _get_tensor(shard_node).shape[-1] < 1024:
return
# scaled_mm is not supported yet for last dim
def _filter_out_scaled_matmul(matmul: _Matmul):
return not isinstance(matmul, _ScaledMatmul)
filter_matmul = _filter_out_scaled_matmul
# Find consumer matmuls
matmuls = _find_consumer_matmuls(ag_res_node)
@@ -663,18 +675,29 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
return
if _is_last_dim(_get_tensor(shard_node), gather_dim) and len(
all_gather.res_node.users
) > len(matmuls):
# The result of ag-split-cat is used not only in matmuls.
# Then it has to be materialized, which can have overhead.
return
if filter_matmul and not filter_matmul(matmuls[0]):
return
# Fuse the all_gather_tensor with the eligible matmuls
graph = ag_node.graph
with graph.inserting_before(ag_node):
if "val" in shard_node.meta:
restrided = restride_A_shard_for_fused_all_gather_matmul(
_get_tensor(shard_node),
gather_dim,
)
shard_node = graph.call_function(
inductor_prims.force_stride_order,
args=(shard_node, restrided.stride()),
)
if not _is_last_dim(_get_tensor(shard_node), gather_dim):
if "val" in shard_node.meta:
restrided = restride_A_shard_for_fused_all_gather_matmul(
_get_tensor(shard_node),
gather_dim,
)
shard_node = graph.call_function(
inductor_prims.force_stride_order,
args=(shard_node, restrided.stride()),
)
fused_node = _insert_fused_all_gather_matmul(
graph, matmuls, shard_node, gather_dim, group_name
@@ -881,7 +904,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
return
filter_matmul = None
if orig_scatter_dim == _get_tensor(input_node).ndim - 1:
if _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
# scaled_mm is not supported yet for last dim mm+rs
def _filter_out_scaled_matmul(matmul: _Matmul):
return not isinstance(matmul, _ScaledMatmul)

View File

@@ -524,6 +524,19 @@ def _fused_all_gather_matmul_impl(
group = c10d._resolve_process_group(group_name)
if gather_dim == A_shard.ndim - 1 or gather_dim == -1:
return _fused_all_gather_matmul_last_gather_dim_impl(
mm_out_op,
A_shard,
Bs,
A_scale,
kwargs_list,
out_dtypes,
gather_dim,
group_name,
return_A,
)
# Move the gather_dim to the front and flatten the tensor into a 2D matrix.
# The flattened tensor doesn't need to be contiguous (for computation
# efficiency), as _pipelined_all_gather_and_consume guarantees that shards
@@ -624,6 +637,140 @@ def _fused_all_gather_matmul_impl(
return A, [unflatten(output) for output in outputs]
def _pipelined_all_gather_and_consume_last_dim(
shard: torch.Tensor,
shard_consumer: Callable[[torch.Tensor, int], None],
ag_out: torch.Tensor,
group_name: str,
ag_out_needed: bool = True,
) -> None:
p2p_workspace_size_req = 0
p2p_workspace_size_req = shard.numel() * shard.element_size()
symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req)
group_size = symm_mem.world_size
rank = symm_mem.rank
symm_mem.barrier(channel=0)
backend_stream = _get_backend_stream()
backend_stream.wait_stream(torch.cuda.current_stream())
def copy_shard(dst: torch.Tensor, src: torch.Tensor) -> None:
dst.copy_(src)
def get_p2p_buf(remote_rank: int) -> torch.Tensor:
buf = symm_mem.get_buffer(
remote_rank,
shard.shape,
shard.dtype,
)
return buf
local_p2p_buf = get_p2p_buf(rank)
shards = ag_out.chunk(group_size)
copy_shard(dst=local_p2p_buf, src=shard)
symm_mem.barrier(channel=1)
backend_stream.wait_stream(torch.cuda.current_stream())
# At this point, all ranks have copied their local shard to
# their local p2p buffer. Each rank can now copy and consume
# remote shards.
shard_consumer(shard, rank)
for step in range(1, group_size):
if step % 2 == 0:
stream = torch.cuda.current_stream()
else:
stream = backend_stream
remote_rank = (step + rank) % group_size
remote_p2p_buf = get_p2p_buf(remote_rank)
with stream:
copy_shard(dst=shards[remote_rank], src=remote_p2p_buf)
shard_consumer(shards[remote_rank], remote_rank)
if ag_out_needed:
# Copy from input to the all-gather output. Opportunistically overlap
# it with the last shard_consumer.
if group_size % 2 == 0:
stream = torch.cuda.current_stream()
else:
stream = backend_stream
with stream:
copy_shard(dst=shards[rank], src=shard)
torch.cuda.current_stream().wait_stream(backend_stream)
symm_mem.barrier(channel=0)
def _fused_all_gather_matmul_last_gather_dim_impl(
mm_out_op: torch._ops.OpOverload,
A_shard: torch.Tensor,
Bs: list[torch.Tensor],
A_scale: torch.Tensor | None,
kwargs_list: list[dict[str, Any]],
out_dtypes: list[torch.dtype | None],
gather_dim: int,
group_name: str,
return_A: bool,
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
group = c10d._resolve_process_group(group_name)
group_size = group.size()
B_shards = [B.chunk(group.size()) for B in Bs]
leading_dims = list(A_shard.shape[:-1])
A_shard_flat = A_shard.flatten(0, -2)
def unflatten(t: torch.Tensor) -> torch.Tensor:
return t.view(*leading_dims, -1)
A_flat_out = A_shard_flat.new_empty(
A_shard_flat.shape[0] * group.size(),
A_shard_flat.shape[1],
)
outputs = [
torch.empty(
(A_shard_flat.shape[0], B.shape[1]),
dtype=out_dtype or B.dtype,
device=A_shard.device,
)
for B, out_dtype in zip(Bs, out_dtypes)
]
first = True
events = [torch.cuda.Event() for _ in outputs]
def default_consumer(shard: torch.Tensor, rank: int) -> None:
nonlocal first
for out, event, B_shard, kwargs in zip(outputs, events, B_shards, kwargs_list):
event.wait()
if first:
torch.ops.aten.mm.out(shard, B_shard[rank], **kwargs, out=out)
else:
out.addmm_(shard, B_shard[rank])
event.record()
first = False
_pipelined_all_gather_and_consume_last_dim(
A_shard_flat,
default_consumer,
A_flat_out,
group_name,
return_A,
)
ret_A = None
if return_A:
# This path is inefficient and will be filtered out at passes stage
# Added only for completeness.
A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
ret_A = unflatten(A_split_cat_out_flat)
return ret_A, [unflatten(output) for output in outputs]
@torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
def _fused_all_gather_matmul_fallback(
A_shard: torch.Tensor,
@@ -638,6 +785,15 @@ def _fused_all_gather_matmul_fallback(
A_shard.contiguous(), group_size, group_name
)
A = torch.ops._c10d_functional.wait_tensor(A)
if gather_dim == A.ndim - 1 or gather_dim == -1:
A_splits = A.chunk(group_size)
A_mm = torch.cat(A_splits, dim=-1)
res = [torch.matmul(A_mm, B) for B in Bs]
if return_A:
return A_mm, res
else:
return None, res
A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
res = [torch.matmul(A, B).movedim(0, gather_dim) for B in Bs]
if return_A:

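As a standalone sketch of the last-gather-dim path above (single process, hypothetical shapes, no process group or symmetric memory): the fallback recovers the logically gathered matrix with the split-cat step, while the pipelined implementation accumulates per-shard partial matmuls the way default_consumer does (mm for the first shard, addmm_ afterwards).

import torch

group_size, rows, K_shard, N = 4, 64, 1024, 16

# Per-rank shards of A along the last (reduction) dim; B is the full (K, N) matrix.
A_shards = [torch.randn(rows, K_shard, dtype=torch.float64) for _ in range(group_size)]
B = torch.randn(group_size * K_shard, N, dtype=torch.float64)

# all_gather_into_tensor stacks shards along dim 0 of the flattened input ...
A_gathered = torch.cat(A_shards, dim=0)
# ... so the split-cat step recovers the last-dim-gathered matrix.
A_mm = torch.cat(A_gathered.chunk(group_size), dim=-1)

# Shard-wise accumulation, mirroring default_consumer: mm, then addmm_.
B_shards = B.chunk(group_size)
out = A_shards[0] @ B_shards[0]
for A_s, B_s in zip(A_shards[1:], B_shards[1:]):
    out.addmm_(A_s, B_s)

torch.testing.assert_close(out, A_mm @ B)
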
View File

@@ -18,6 +18,7 @@ log = logging.getLogger(__name__)
__all__ = [
"annotate",
"annotate_fn",
"preserve_node_meta",
"has_preserved_node_meta",
"set_stack_trace",
@@ -266,9 +267,10 @@ def annotate(annotation_dict: dict):
into the FX trace metadata.
Example:
After exiting the context, custom annotations are removed.
>>> with annotate({"source": "custom_pass", "tag": 42}):
... # compute here
# After exiting the context, custom annotations are removed.
... pass # Your computation here
"""
global current_meta
@@ -291,6 +293,43 @@ def annotate(annotation_dict: dict):
del current_meta["custom"]
@compatibility(is_backward_compatible=False)
def annotate_fn(annotation_dict: dict):
"""
A decorator that wraps a function with the annotate context manager.
Use this when you want to annotate an entire function instead of a specific code block.
Note:
This API is **not backward compatible** and may evolve in future releases.
Note:
This API is not compatible with fx.symbolic_trace or jit.trace. It's intended
to be used with PT2 family of tracers, e.g. torch.export and dynamo.
Args:
annotation_dict (dict): A dictionary of custom key-value pairs to inject
into the FX trace metadata for all operations in the function.
Example:
All operations in my_function will have {"pp_stage": 1} in their metadata.
>>> @annotate_fn({"pp_stage": 1})
... def my_function(x):
... return x + 1
"""
from functools import wraps
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
with annotate(annotation_dict):
return func(*args, **kwargs)
return wrapper
return decorator
@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
global current_meta
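
A minimal usage sketch of the new annotate_fn decorator. The key name "pp_stage", the module, and the shapes are illustrative; the sketch assumes a build that contains this commit (for annotate_fn) and follows the test above in expecting annotations to surface under node.meta["custom"] when traced with a PT2 tracer such as torch.export.

import torch
import torch.fx.traceback as fx_traceback
from torch import nn

@fx_traceback.annotate_fn({"pp_stage": 1})
def scale(x):
    return x * 2

class M(nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            x = x + 1
        return scale(x)

ep = torch.export.export(M(), (torch.randn(4),))
for node in ep.graph_module.graph.nodes:
    # Annotations recorded by annotate/annotate_fn are expected in node.meta["custom"].
    if node.op == "call_function":
        print(node.name, node.meta.get("custom"))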