Files
pytorch/torch/distributed/pipelining/_schedule_visualizer.py
Maggie Moss 7457d139c5 Add pyrefly suppressions to torch/distributed (7/n) (#165002)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

One more PR after this one.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002
Approved by: https://github.com/oulgen
2025-10-09 04:08:25 +00:00

438 lines
16 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates
"""
This visualizer requires matplotlib to be installed.
Example usage:
ops = get_schedule_ops("InterleavedZeroBubble", 4, 8)
visualize_schedule(ops, "test.png")
"""
import collections
from typing import NamedTuple, Optional, Union
from unittest import mock
from torch.distributed.pipelining.schedules import (
_Action,
_ComputationType,
_PipelineSchedule,
_PipelineScheduleRuntime,
get_schedule_class,
PipelineScheduleMulti,
PipelineScheduleSingle,
)
from torch.distributed.pipelining.stage import PipelineStage
class OpKey(NamedTuple):
stage_index: int
computation_type: _ComputationType
microbatch_index: int
def get_schedule_ops(
    schedule: Union[str, type[_PipelineSchedule]],
    pp_degree: int,
    num_microbatches: int,
    num_stages_per_rank: Optional[int] = None,
    add_spacing: bool = False,
    with_comms: bool = False,
) -> list[list[Optional[_Action]]]:
    """
    Return every action of a schedule as one list per rank.

    The schedule may be given either as a name understood by
    get_schedule_class() or as a _PipelineSchedule subclass. Pipeline stages
    are mocked out, so no real model or process group is required.

    Args:
        schedule: Schedule name or _PipelineSchedule subclass.
        pp_degree: Number of pipeline-parallel ranks.
        num_microbatches: Number of microbatches to schedule.
        num_stages_per_rank: Stages per rank; defaults to 1 for single-stage
            schedules and 2 for multi-stage ones.
        add_spacing: Re-space actions by cross-rank dependencies
            (mutually exclusive with with_comms).
        with_comms: Include communication actions in the output.

    Returns:
        One list of actions (possibly containing None gaps) per rank.

    Raises:
        ValueError: If both add_spacing and with_comms are requested, or the
            schedule cannot be resolved to a known schedule class.
    """
    if add_spacing and with_comms:
        raise ValueError("Cannot add spacing and view comms at the same time")

    # Resolve the schedule class from a name, or accept a class directly.
    if isinstance(schedule, str):
        resolved_class = get_schedule_class(schedule)
    elif issubclass(schedule, _PipelineSchedule):
        resolved_class = schedule
    else:
        raise ValueError(f"Invalid schedule: {schedule}")

    # A mocked PipelineStage stands in for real stages; only the attributes
    # read by the schedule constructors are filled in.
    stage_mock = mock.create_autospec(PipelineStage, instance=True)
    stage_mock.group_rank = 0
    stage_mock.group_size = pp_degree
    stage_mock.submod = None

    # Shape the stage argument (single mock vs. list of mocks) and validate
    # num_stages_per_rank for the schedule family.
    if issubclass(resolved_class, PipelineScheduleSingle):
        if num_stages_per_rank is None:
            num_stages_per_rank = 1
        assert num_stages_per_rank == 1
        stages = stage_mock
        stages.num_stages = num_stages_per_rank * pp_degree
    elif issubclass(resolved_class, PipelineScheduleMulti):
        if num_stages_per_rank is None:
            num_stages_per_rank = 2
        assert num_stages_per_rank >= 2
        # The same mock object is deliberately reused for every local stage.
        stages = [stage_mock] * num_stages_per_rank
        for stage in stages:
            stage.num_stages = num_stages_per_rank * pp_degree
    else:
        raise ValueError(f"Invalid schedule: {resolved_class}")

    # Instantiate the schedule with the mocked stages.
    # pyrefly: ignore # bad-instantiation, bad-argument-type
    schedule_instance = resolved_class(stages, num_microbatches)
    assert schedule_instance.pipeline_order is not None

    # Collect per-rank actions, optionally with communication ops included.
    all_actions: list[list[Optional[_Action]]]
    if with_comms:
        runtime = _PipelineScheduleRuntime(stages, num_microbatches)
        runtime._prepare_schedule_with_comms(schedule_instance.pipeline_order)
        all_actions = [
            list(runtime.pipeline_order_with_comms[rank])
            for rank in range(pp_degree)
        ]
    else:
        all_actions = [
            schedule_instance.pipeline_order[rank] for rank in range(pp_degree)
        ]

    if add_spacing:
        # Drop the None placeholders first, then re-space by dependencies.
        # TODO: later we can change this at the schedule creation level to not use Nones
        compacted = [
            [action for action in rank_actions if action is not None]
            for rank_actions in all_actions
        ]
        all_actions = add_schedule_op_spacing(compacted)

    return all_actions
class _ComputationTypeVisual:
def __init__(
self,
color: str,
text: str = "",
width: int = 1,
):
self.color = color
self.width = width
self.text = text
# Update the mapping to use _ComputationTypeVisual instances
# Maps each pipeline computation type to its plot color, legend label, and
# cell width. Widths encode relative duration (a full backward spans two
# cells, an overlapped F+B spans three) and are also used as the per-action
# "compute time" by add_schedule_op_spacing.
action_type_to_color_mapping = {
    _ComputationType.FORWARD: _ComputationTypeVisual("blue", "Forward"),
    _ComputationType.BACKWARD_INPUT: _ComputationTypeVisual("teal", "Backward Input"),
    _ComputationType.BACKWARD_WEIGHT: _ComputationTypeVisual(
        "green", "Backward Weight"
    ),
    _ComputationType.FULL_BACKWARD: _ComputationTypeVisual(
        "orange", "Full Backward", 2
    ),
    _ComputationType.OVERLAP_F_B: _ComputationTypeVisual("purple", "Overlap F+B", 3),
}
def add_schedule_op_spacing(
    schedule: list[list[Optional[_Action]]],
) -> list[list[Optional[_Action]]]:
    """
    Add spacing to the schedule based on dependencies between ranks.

    Before adding an operation to the list, this function checks if there are
    dependencies from other ranks. If there are dependencies (other ranks have
    not finished processing the required microbatch), it adds None instead.
    For example, Forward microbatch 0 on rank 1 depends on rank 0 processing
    Forward microbatch 0 first.

    Args:
        schedule: The original schedule as a list of lists where each inner list
            represents a rank and each element represents an action.

    Returns:
        A new schedule with proper spacing based on dependencies.

    Raises:
        RuntimeError: If no progress can be made for longer than the widest
            action's duration (unsatisfiable dependency / deadlock), or an
            unknown computation type is encountered.
    """
    if not schedule:
        return schedule

    # Total number of pipeline stages, inferred from the largest stage index.
    num_stages = (
        max(
            action.stage_index
            for rank_actions in schedule
            for action in rank_actions
            if action is not None
        )
        + 1
    )
    num_ranks = len(schedule)
    spaced_schedule: list[list[Optional[_Action]]] = [[] for _ in range(num_ranks)]
    rank_ops = [collections.deque(ops) for ops in schedule]

    # Track completion times: (stage_index, action_type, microbatch_index) -> completion_time
    scheduled_ops: dict[OpKey, int] = {}

    def is_dependency_ready(dependency_key: OpKey, timestep: int) -> bool:
        """Check if a dependency operation has completed by the given timestep."""
        return (
            dependency_key in scheduled_ops
            and timestep >= scheduled_ops[dependency_key]
        )

    def get_dependencies(action: _Action) -> list[OpKey]:
        """Get the list of dependencies for an action."""
        stage_idx = action.stage_index
        comp_type = action.computation_type
        mb_idx = action.microbatch_index
        # Ensure mb_idx is not None for dependency tracking
        assert mb_idx is not None, f"Action {action} has None microbatch_index"

        # First stage forward has no dependencies
        if stage_idx == 0 and comp_type == _ComputationType.FORWARD:
            return []
        # Last stage backward depends on forward from previous stage
        if stage_idx == num_stages - 1 and comp_type in (
            _ComputationType.FULL_BACKWARD,
            _ComputationType.BACKWARD_INPUT,
        ):
            return [OpKey(stage_idx - 1, _ComputationType.FORWARD, mb_idx)]
        # Forward depends on previous stage forward
        if comp_type == _ComputationType.FORWARD:
            return [OpKey(stage_idx - 1, _ComputationType.FORWARD, mb_idx)]
        # Backward depends on next stage backward (full or input-only variant)
        if comp_type in (
            _ComputationType.FULL_BACKWARD,
            _ComputationType.BACKWARD_INPUT,
        ):
            return [
                OpKey(stage_idx + 1, _ComputationType.FULL_BACKWARD, mb_idx),
                OpKey(stage_idx + 1, _ComputationType.BACKWARD_INPUT, mb_idx),
            ]
        # Weight backward depends on input backward of the same stage
        if comp_type == _ComputationType.BACKWARD_WEIGHT:
            return [OpKey(stage_idx, _ComputationType.BACKWARD_INPUT, mb_idx)]
        raise RuntimeError(f"Unknown computation type: {comp_type}")

    def is_action_ready(action: _Action, timestep: int) -> bool:
        """Check if an action is ready to be scheduled at the given timestep."""
        # For OR dependencies (like backward), check if any dependency is satisfied
        if action.computation_type in (
            _ComputationType.FULL_BACKWARD,
            _ComputationType.BACKWARD_INPUT,
            _ComputationType.BACKWARD_WEIGHT,
        ):
            dependencies = get_dependencies(action)
            return any(is_dependency_ready(dep, timestep) for dep in dependencies)
        # For AND dependencies, all must be satisfied
        elif action.computation_type == _ComputationType.FORWARD:
            dependencies = get_dependencies(action)
            return all(is_dependency_ready(dep, timestep) for dep in dependencies)
        elif action.computation_type == _ComputationType.OVERLAP_F_B:
            # An overlapped action is ready only when every sub-action is ready
            assert action.sub_actions is not None, (
                f"OVERLAP_F_B action {action} has None sub_actions"
            )
            return all(
                is_action_ready(sub_action, timestep)
                for sub_action in action.sub_actions
            )
        else:
            raise RuntimeError(f"Unknown computation type: {action.computation_type}")

    def schedule_action(action: _Action, rank: int, timestep: int) -> int:
        """Schedule an action and return its completion time."""
        spaced_schedule[rank].append(action)
        comp_type = action.computation_type
        comp_time = action_type_to_color_mapping[comp_type].width
        completion_time = timestep + comp_time
        if comp_type == _ComputationType.OVERLAP_F_B:
            # For overlap actions, schedule each sub-action with cumulative timing
            assert action.sub_actions is not None, (
                f"OVERLAP_F_B action {action} has None sub_actions"
            )
            cumulative_time = 0
            for sub_action in action.sub_actions:
                assert sub_action.microbatch_index is not None, (
                    f"Sub-action {sub_action} has None microbatch_index"
                )
                sub_comp_time = action_type_to_color_mapping[
                    sub_action.computation_type
                ].width
                cumulative_time += sub_comp_time
                scheduled_ops[
                    OpKey(
                        sub_action.stage_index,
                        sub_action.computation_type,
                        sub_action.microbatch_index,
                    )
                ] = timestep + cumulative_time
        else:
            assert action.microbatch_index is not None, (
                f"Action {action} has None microbatch_index"
            )
            scheduled_ops[
                OpKey(action.stage_index, comp_type, action.microbatch_index)
            ] = completion_time
        return completion_time

    # Main scheduling loop
    current_timestep = 0
    timesteps_without_progress = 0
    rank_completion_times = dict.fromkeys(range(num_ranks), 0)
    # NOTE: drained queues are kept in place (rather than filtered out of
    # rank_ops) so the list index always equals the rank number; removing
    # entries would shift later ranks' operations and completion times onto
    # the wrong rank.
    while any(rank_ops):
        # Schedule every rank whose head-of-queue action is ready this timestep
        for rank, op_queue in enumerate(rank_ops):
            if not op_queue:
                continue
            action = op_queue[0]
            if action is None:
                # Pre-existing gaps in the input schedule pass through as-is
                spaced_schedule[rank].append(None)
                op_queue.popleft()
                timesteps_without_progress = 0
            elif current_timestep >= rank_completion_times[rank] and is_action_ready(
                action, current_timestep
            ):
                rank_completion_times[rank] = schedule_action(
                    action, rank, current_timestep
                )
                op_queue.popleft()
                timesteps_without_progress = 0
        # Add None for ranks that are idle (waiting on dependencies)
        for rank in range(num_ranks):
            if current_timestep >= rank_completion_times[rank]:
                spaced_schedule[rank].append(None)
        current_timestep += 1
        timesteps_without_progress += 1
        # If nothing has been scheduled for longer than the widest action takes
        # to run, no dependency can ever become satisfied - bail out.
        if timesteps_without_progress > max(
            visual.width for visual in action_type_to_color_mapping.values()
        ):
            raise RuntimeError("No progress made in scheduling - possible deadlock")
    return spaced_schedule
def visualize_schedule(
    schedule: list[list[Optional[_Action]]],
    filename: Optional[str] = None,
) -> None:
    """
    Render a pipeline schedule as a matplotlib timeline.

    Each rank occupies one row; every action is drawn as a colored rectangle
    whose width reflects the relative cost of its computation type (see
    action_type_to_color_mapping). None entries leave a one-cell gap, so
    pre-spaced schedules line up across ranks. Requires matplotlib.

    Args:
        schedule: List of per-rank action lists to draw.
        filename: If provided, the figure is written to this path; otherwise
            it is shown interactively.
    """
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    # Any font available on the system works here.
    plt.rcParams["font.family"] = (
        "DejaVu Sans"
    )
    num_ranks = len(schedule)
    max_actions = max(len(rank) for rank in schedule)
    # Extra margin in the figure leaves room for the legend.
    fig, ax = plt.subplots(figsize=(max_actions + 2, num_ranks + 2))
    # Font size scales with the smaller figure dimension.
    font_size = min(max_actions, num_ranks) + 4
    used_computation = set()
    max_draw_position = -1

    for row, actions in enumerate(schedule):
        x = 0  # current drawing offset within this rank's row
        y = num_ranks - row - 1  # rank 0 is drawn at the top
        for action in actions:
            if action is None:
                x += 1  # gap cell
                continue
            visual = action_type_to_color_mapping.get(
                action.computation_type, _ComputationTypeVisual("black")
            )
            used_computation.add(action.computation_type)
            # Compound (overlapped) actions get a thicker outline.
            border = 2 if action.sub_actions is not None else 1
            ax.add_patch(
                Rectangle(
                    (x, y),
                    visual.width,
                    1,
                    facecolor=visual.color,
                    edgecolor="black",
                    linewidth=border,
                )
            )
            # Label centered inside the rectangle.
            ax.text(
                x + visual.width / 2,
                y + 0.5,
                str(action),
                ha="center",
                va="center",
                fontsize=font_size,
                color="white",
                weight="normal",
            )
            x += visual.width
        max_draw_position = max(max_draw_position, x)

    ax.set_xlim(-0.5, max_draw_position + 1)
    ax.set_ylim(-0.5, num_ranks + 0.5)  # extra space above the top row
    # Center each rank label vertically on its row.
    ax.set_yticks([num_ranks - r - 0.5 for r in range(num_ranks)])
    ax.set_yticklabels([f"Rank {i}" for i in range(num_ranks)], fontsize=font_size)
    ax.set_xticklabels([])
    ax.grid(False)

    # Legend only lists computation types that were actually drawn.
    handles = [
        Rectangle(
            (0, 0),
            1,
            1,
            facecolor=action_type_to_color_mapping[comp_type].color,
            edgecolor="black",
            label=action_type_to_color_mapping[comp_type].text,
        )
        for comp_type in used_computation
    ]
    ax.legend(handles=handles, loc="upper right", fontsize=font_size)

    if filename:
        plt.savefig(filename, bbox_inches="tight")
    else:
        plt.show()