Files
pytorch/torch/distributed/pipelining/_utils.py
Howard Huang 5e8b95605f [PP] Support OVERLAP_F_B computation type (#158978)
Some changes to validation code and visualizer to support a new computation type that will be used in DualPipeV (see https://github.com/pytorch/pytorch/pull/159591)

The IR looks like:

```
[0F0, 0F1, 0F2, 0F3, 0F4, 0F5, 0F6, 7F0, 7I0, 7W0, 7F1, 7I1, 7W1, 7F2, 7I2, 7W2, 7F3, (0F7;7B3)OVERLAP_F_B, (7F4;0B0)OVERLAP_F_B, (0F8;7B4)OVERLAP_F_B, (7F5;0B1)OVERLAP_F_B, (0F9;7B5)OVERLAP_F_B, (7F6;0B2)OVERLAP_F_B, 7B6, (7F7;0B3)OVERLAP_F_B, 7B7, (7F8;0B4)OVERLAP_F_B, 7B8, (7F9;0B5)OVERLAP_F_B, 7B9, 0I6, 0W6, 0I7, 0W7, 0I8, 0W8, 0I9, 0W9]
[1F0, 1F1, 1F2, 1F3, 1F4, 6F0, 1F5, 6F1, 6I0, 6W0, 6F2, 6I1, 6W1, 6F3, (1F6;6B2)OVERLAP_F_B, (6F4;1B0)OVERLAP_F_B, (1F7;6B3)OVERLAP_F_B, (6F5;1B1)OVERLAP_F_B, (1F8;6B4)OVERLAP_F_B, (6F6;1B2)OVERLAP_F_B, (1F9;6B5)OVERLAP_F_B, (6F7;1B3)OVERLAP_F_B, 6B6, (6F8;1B4)OVERLAP_F_B, 6B7, (6F9;1B5)OVERLAP_F_B, 6B8, 1B6, 6I9, 1I7, 6W9, 1I8, 1W7, 1I9, 1W8, 1W9]
[2F0, 2F1, 2F2, 5F0, 2F3, 5F1, 2F4, 5F2, 5I0, 5W0, 5F3, (2F5;5B1)OVERLAP_F_B, (5F4;2B0)OVERLAP_F_B, (2F6;5B2)OVERLAP_F_B, (5F5;2B1)OVERLAP_F_B, (2F7;5B3)OVERLAP_F_B, (5F6;2B2)OVERLAP_F_B, (2F8;5B4)OVERLAP_F_B, (5F7;2B3)OVERLAP_F_B, (2F9;5B5)OVERLAP_F_B, (5F8;2B4)OVERLAP_F_B, 5B6, (5F9;2B5)OVERLAP_F_B, 5B7, 2B6, 5B8, 2I7, 5I9, 2I8, 2W7, 2I9, 5W9, 2W8, 2W9]
[3F0, 4F0, 3F1, 4F1, 3F2, 4F2, 3F3, 4F3, 3F4, 4B0, (4F4;3B0)OVERLAP_F_B, (3F5;4B1)OVERLAP_F_B, (4F5;3B1)OVERLAP_F_B, (3F6;4B2)OVERLAP_F_B, (4F6;3B2)OVERLAP_F_B, (3F7;4B3)OVERLAP_F_B, (4F7;3B3)OVERLAP_F_B, (3F8;4B4)OVERLAP_F_B, (4F8;3B4)OVERLAP_F_B, (3F9;4B5)OVERLAP_F_B, (4F9;3B5)OVERLAP_F_B, 4B6, 3B6, 4B7, 3B7, 4I8, 3I8, 4I9, 3I9, 4W8, 3W8, 4W9, 3W9]
```

In this PR, schedule execution simply treats each OVERLAP_F_B action as two separate operations, an F and a B, so there is no actual overlap yet. The next step is to let users plug in a custom function that defines what this operation does.

814629043a/torch/distributed/pipelining/schedules.py (L1205-L1216)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158978
Approved by: https://github.com/wconstab
2025-08-01 20:22:30 +00:00

161 lines
4.6 KiB
Python

# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import Union
import torch
from torch import fx
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def flatten_args_detach(args):
    """
    Detach every tensor found in ``args`` and collect all leaves in a flat list.

    ``args`` is traversed with ``fx.node.map_aggregate``. Each tensor leaf is
    replaced by a detached copy that keeps the original ``requires_grad``
    setting; every other leaf is passed through untouched.

    Returns:
        A ``(new_args, flat_detached_args)`` pair: ``new_args`` is the result of
        the traversal with tensors swapped for their detached versions, and
        ``flat_detached_args`` lists the very same leaf objects in visiting order.
    """
    collected = []

    def _detach_and_record(leaf):
        if isinstance(leaf, torch.Tensor):
            # Cut the autograd history but preserve whether grad is required.
            leaf = leaf.detach().requires_grad_(leaf.requires_grad)
        collected.append(leaf)
        return leaf

    detached_tree = fx.node.map_aggregate(args, _detach_and_record)
    return detached_tree, collected
def flatten_args(args):
    """
    Collect the leaves of ``args`` into a flat list.

    ``args`` is traversed with ``fx.node.map_aggregate``; each leaf is appended
    to the result unchanged, in visiting order.
    """
    leaves = []

    def _record(leaf):
        leaves.append(leaf)
        return leaf

    # Only the side effect on ``leaves`` matters; the mapped structure is discarded.
    fx.node.map_aggregate(args, _record)
    return leaves
class PipeliningShapeError(RuntimeError):
    """Raised when runtime values do not match the configured tensor metadata."""
def validate_tensor_metadata(desc, expected, given):
    """
    Check that ``given`` matches ``expected`` in shape, dtype and stride.

    The checks run in that order and the first mismatch raises.

    Args:
        desc: label prepended to the error message.
        expected: object providing the reference ``shape``, ``dtype`` and ``stride()``.
        given: object whose ``shape``, ``dtype`` and ``stride()`` are compared.

    Raises:
        PipeliningShapeError: on the first attribute that differs.
    """
    if expected.shape != given.shape:
        raise PipeliningShapeError(
            f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
        )

    if expected.dtype != given.dtype:
        raise PipeliningShapeError(
            f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
        )

    want_stride = expected.stride()
    got_stride = given.stride()
    if want_stride != got_stride:
        raise PipeliningShapeError(
            f"{desc} has a stride mismatch: expected {want_stride} actual {got_stride}"
        )
def validate_tensors_metadata(
    desc,
    expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
    actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
):
    """
    Validate two equally long tensor sequences element by element.

    Args:
        desc: label prepended to every error message.
        expected_tensors: reference tensors.
        actual_tensors: tensors to check against the reference.

    Raises:
        PipeliningShapeError: if the sequences differ in length, or when
            ``validate_tensor_metadata`` rejects a pair.
    """
    num_expected = len(expected_tensors)
    num_actual = len(actual_tensors)
    if num_expected != num_actual:
        raise PipeliningShapeError(
            f"{desc}: Number of values ({num_actual}) does not match expected number ({num_expected})"
        )

    # Lengths are equal here, so zip visits every pair.
    for idx, (want, got) in enumerate(zip(expected_tensors, actual_tensors)):
        validate_tensor_metadata(f"{desc}: value {idx}", want, got)
def generate_stage_to_rank_mapping(
    pp_size: int, num_stages: int, style: str = "loop"
) -> dict[int, int]:
    """
    Compute the stage id to rank mapping for either a looped or V-style schedule.

    Most commonly num_stages == pp_size * 2, but this function can be used to
    compute the mapping for any number of stages per rank.

    Args:
        pp_size: number of pipeline ranks; must be at least 1.
        num_stages: total number of stages to assign.
        style: ``"loop"`` assigns stages round-robin (``stage % pp_size``);
            ``"v"`` sweeps the ranks upwards, then downwards, and so on, so the
            two consecutive stages on either side of a sweep boundary share a rank.

    Returns:
        A dict mapping every stage index in ``range(num_stages)`` to a rank.

    Raises:
        ValueError: if ``pp_size`` is smaller than 1, if ``style`` is neither
            ``"loop"`` nor ``"v"``, or if ``style`` is ``"v"`` and ``num_stages``
            is not a multiple of ``pp_size``.
    """
    # Reject non-positive sizes up front: 0 would surface as an opaque
    # ZeroDivisionError below and negative values would yield meaningless ranks.
    if pp_size < 1:
        raise ValueError(f"pp_size {pp_size} must be a positive integer")

    if style == "loop":
        return {stage_index: stage_index % pp_size for stage_index in range(num_stages)}

    if style != "v":
        raise ValueError(f"Style {style} is not supported.")

    if num_stages % pp_size != 0:
        raise ValueError(
            f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules"
        )

    mapping = {}
    for stage_index in range(num_stages):
        # ``sweep`` counts completed passes over the ranks; ``offset`` is the
        # position within the current pass.
        sweep, offset = divmod(stage_index, pp_size)
        # Even sweeps climb from rank 0; odd sweeps descend from the last rank.
        # Don't change rank across a sweep border (this keeps the V shape).
        mapping[stage_index] = offset if sweep % 2 == 0 else pp_size - 1 - offset
    return mapping
def generate_rank_to_stage_mapping(
    pp_size: int, num_stages: int, style: str = "loop"
) -> dict[int, list[int]]:
    """
    Compute which stage ids live on each rank for a looped or V-style schedule.

    The result is the inverse of ``generate_stage_to_rank_mapping`` called with
    the same arguments: a dictionary mapping rank -> sorted list of the stage
    indices assigned to that rank. Only ranks that own at least one stage appear.
    """
    rank_to_stages: dict[int, list[int]] = {}
    forward_mapping = generate_stage_to_rank_mapping(pp_size, num_stages, style)

    # Group stage ids under the rank they are mapped to.
    for stage_id, rank in forward_mapping.items():
        rank_to_stages.setdefault(rank, []).append(stage_id)

    # Sort each group so the ordering is consistent.
    for owned_stages in rank_to_stages.values():
        owned_stages.sort()

    return rank_to_stages
@dataclass
class PipeInfo:
    """
    Record of information describing a pipeline (`Pipe` object).
    """

    # fx graph associated with the pipeline
    graph: fx.Graph
    # number of stages -- presumably the stage count of the pipeline; confirm with callers
    num_stages: int
    # NOTE(review): looks like a flag for whether loss and backward are part of the pipeline -- confirm
    has_loss_and_backward: bool