Updates the validation code and the visualizer to support a new computation type, OVERLAP_F_B, that will be used in DualPipeV (see https://github.com/pytorch/pytorch/pull/159591).
The IR (one line per rank; each action is written as `{stage}{computation}{microbatch}`, so `0F7` is stage 0 running forward on microbatch 7) looks like:
```
[0F0, 0F1, 0F2, 0F3, 0F4, 0F5, 0F6, 7F0, 7I0, 7W0, 7F1, 7I1, 7W1, 7F2, 7I2, 7W2, 7F3, (0F7;7B3)OVERLAP_F_B, (7F4;0B0)OVERLAP_F_B, (0F8;7B4)OVERLAP_F_B, (7F5;0B1)OVERLAP_F_B, (0F9;7B5)OVERLAP_F_B, (7F6;0B2)OVERLAP_F_B, 7B6, (7F7;0B3)OVERLAP_F_B, 7B7, (7F8;0B4)OVERLAP_F_B, 7B8, (7F9;0B5)OVERLAP_F_B, 7B9, 0I6, 0W6, 0I7, 0W7, 0I8, 0W8, 0I9, 0W9]
[1F0, 1F1, 1F2, 1F3, 1F4, 6F0, 1F5, 6F1, 6I0, 6W0, 6F2, 6I1, 6W1, 6F3, (1F6;6B2)OVERLAP_F_B, (6F4;1B0)OVERLAP_F_B, (1F7;6B3)OVERLAP_F_B, (6F5;1B1)OVERLAP_F_B, (1F8;6B4)OVERLAP_F_B, (6F6;1B2)OVERLAP_F_B, (1F9;6B5)OVERLAP_F_B, (6F7;1B3)OVERLAP_F_B, 6B6, (6F8;1B4)OVERLAP_F_B, 6B7, (6F9;1B5)OVERLAP_F_B, 6B8, 1B6, 6I9, 1I7, 6W9, 1I8, 1W7, 1I9, 1W8, 1W9]
[2F0, 2F1, 2F2, 5F0, 2F3, 5F1, 2F4, 5F2, 5I0, 5W0, 5F3, (2F5;5B1)OVERLAP_F_B, (5F4;2B0)OVERLAP_F_B, (2F6;5B2)OVERLAP_F_B, (5F5;2B1)OVERLAP_F_B, (2F7;5B3)OVERLAP_F_B, (5F6;2B2)OVERLAP_F_B, (2F8;5B4)OVERLAP_F_B, (5F7;2B3)OVERLAP_F_B, (2F9;5B5)OVERLAP_F_B, (5F8;2B4)OVERLAP_F_B, 5B6, (5F9;2B5)OVERLAP_F_B, 5B7, 2B6, 5B8, 2I7, 5I9, 2I8, 2W7, 2I9, 5W9, 2W8, 2W9]
[3F0, 4F0, 3F1, 4F1, 3F2, 4F2, 3F3, 4F3, 3F4, 4B0, (4F4;3B0)OVERLAP_F_B, (3F5;4B1)OVERLAP_F_B, (4F5;3B1)OVERLAP_F_B, (3F6;4B2)OVERLAP_F_B, (4F6;3B2)OVERLAP_F_B, (3F7;4B3)OVERLAP_F_B, (4F7;3B3)OVERLAP_F_B, (3F8;4B4)OVERLAP_F_B, (4F8;3B4)OVERLAP_F_B, (3F9;4B5)OVERLAP_F_B, (4F9;3B5)OVERLAP_F_B, 4B6, 3B6, 4B7, 3B7, 4I8, 3I8, 4I9, 3I9, 4W8, 3W8, 4W9, 3W9]
```
In this PR, schedule execution simply treats an OVERLAP_F_B action as two separate operations, F followed by B, so there is no actual overlap yet. The next step is to let users plug in a custom function that defines what this combined operation does. A rough sketch of the current unpacking follows the linked snippet below.
814629043a/torch/distributed/pipelining/schedules.py (L1205-L1216)
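As a rough illustration of the current behavior (a minimal sketch, not the actual `schedules.py` implementation; the `Action` class and the `run_forward`/`run_backward` callables are hypothetical stand-ins):
```
# Hypothetical sketch: an OVERLAP_F_B action is unpacked into its F and B
# sub-actions, which run back to back with no actual overlap.
from dataclasses import dataclass


@dataclass
class Action:
    stage_index: int
    computation_type: str  # "F", "B", or "OVERLAP_F_B"
    microbatch_index: int
    sub_actions: tuple = ()  # the (F, B) pair for an OVERLAP_F_B action


def execute_action(action, run_forward, run_backward):
    if action.computation_type == "OVERLAP_F_B":
        # Treat the overlapped action as two separate operations for now.
        for sub in action.sub_actions:
            execute_action(sub, run_forward, run_backward)
    elif action.computation_type == "F":
        run_forward(action.stage_index, action.microbatch_index)
    elif action.computation_type == "B":
        run_backward(action.stage_index, action.microbatch_index)
    else:
        raise ValueError(f"unknown computation type: {action.computation_type}")


# (0F7;7B3)OVERLAP_F_B would then be modeled as:
overlap = Action(0, "OVERLAP_F_B", 7, sub_actions=(Action(0, "F", 7), Action(7, "B", 3)))
```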
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158978
Approved by: https://github.com/wconstab
161 lines
4.6 KiB
Python
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import Union

import torch
from torch import fx


logger = logging.getLogger(__name__)


def flatten_args_detach(args):
    """
    Flatten the args into a list form and detach the tensors from the
    computational graph.
    """
    flat_detached_args = []

    def extract_tensor_args(a):
        nonlocal flat_detached_args
        if isinstance(a, torch.Tensor):
            val = a.detach().requires_grad_(a.requires_grad)
            flat_detached_args.append(val)
            return val
        else:
            flat_detached_args.append(a)
            return a

    new_args = fx.node.map_aggregate(
        args,
        extract_tensor_args,
    )

    return new_args, flat_detached_args
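

# Example (illustrative):
#   t = torch.randn(2, requires_grad=True)
#   new_args, flat = flatten_args_detach((t, {"k": 3}))
#   # new_args mirrors the input structure; its tensor leaf is detached from
#   # the autograd graph but keeps requires_grad=True, and flat == [new_args[0], 3]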


def flatten_args(args):
    """
    Flatten the args into a list form.
    """
    flat_args = []

    def extract_tensor_args(a):
        nonlocal flat_args
        flat_args.append(a)
        return a

    fx.node.map_aggregate(
        args,
        extract_tensor_args,
    )

    return flat_args


class PipeliningShapeError(RuntimeError):
    """Shape mismatch between configured and runtime values."""


def validate_tensor_metadata(desc, expected, given):
    if not expected.shape == given.shape:
        raise PipeliningShapeError(
            f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
        )
    if not expected.dtype == given.dtype:
        raise PipeliningShapeError(
            f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
        )
    if not expected.stride() == given.stride():
        raise PipeliningShapeError(
            f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
        )


def validate_tensors_metadata(
    desc,
    expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
    actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
):
    if len(expected_tensors) != len(actual_tensors):
        raise PipeliningShapeError(
            f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
        )
    for i in range(len(expected_tensors)):
        validate_tensor_metadata(
            f"{desc}: value {i}", expected_tensors[i], actual_tensors[i]
        )
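

# Example (illustrative): a shape mismatch surfaces as a PipeliningShapeError
# naming the offending value:
#   validate_tensors_metadata("stage 0 inputs", [torch.empty(2, 3)], [torch.empty(3, 2)])
#   # PipeliningShapeError: stage 0 inputs: value 0 has a shape mismatch:
#   # expected torch.Size([2, 3]) actual torch.Size([3, 2])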


def generate_stage_to_rank_mapping(
    pp_size: int, num_stages: int, style: str = "loop"
) -> dict[int, int]:
    """
    Compute the stage id to rank mapping for either a looped or V-style schedule.

    Most commonly num_stages == pp_size * 2, but this function can be used to
    compute the mapping for any number of stages per rank.
    """
    mapping = {}
    if style == "loop":
        for stage_index in range(num_stages):
            mapping[stage_index] = stage_index % pp_size
    elif style == "v":
        if num_stages % pp_size != 0:
            raise ValueError(
                f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules"
            )

        rank_index = 0
        for stage_index in range(num_stages):
            mapping[stage_index] = rank_index
            # don't change ranks at the border stages (to keep the V shape)
            if (stage_index + 1) % pp_size == 0:
                continue
            if (stage_index // pp_size) % 2 == 0:
                rank_index += 1
            else:
                rank_index -= 1
    else:
        raise ValueError(f"Style {style} is not supported.")
    return mapping
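

# Example (illustrative), for pp_size=4 and num_stages=8:
#   generate_stage_to_rank_mapping(4, 8, style="loop")
#   # {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3}
#   generate_stage_to_rank_mapping(4, 8, style="v")
#   # {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1, 7: 0}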


def generate_rank_to_stage_mapping(
    pp_size: int, num_stages: int, style: str = "loop"
) -> dict[int, list[int]]:
    """
    Compute the rank to stage id mapping for either a looped or V-style schedule.

    This function inverts the stage_to_rank_mapping to get which stages are assigned to each rank.

    Returns a dictionary mapping rank -> list of stage indices assigned to that rank.
    """
    stage_to_rank = generate_stage_to_rank_mapping(pp_size, num_stages, style)

    # Invert the mapping: rank -> list of stages
    rank_to_stages: dict[int, list[int]] = {}
    for stage_id, rank in stage_to_rank.items():
        if rank not in rank_to_stages:
            rank_to_stages[rank] = []
        rank_to_stages[rank].append(stage_id)

    # Sort the stage lists for each rank to ensure consistent ordering
    for stages in rank_to_stages.values():
        stages.sort()

    return rank_to_stages
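

# Example (illustrative): inverting the V-style mapping above,
#   generate_rank_to_stage_mapping(4, 8, style="v")
#   # {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}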


@dataclass
class PipeInfo:
    """
    Captures information for a pipeline (`Pipe` object).
    """

    graph: fx.Graph
    num_stages: int
    has_loss_and_backward: bool