mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
The [schedule visualizer](https://github.com/pytorch/pytorch/blob/main/torch/distributed/pipelining/_schedule_visualizer.py) relies on `self.pipeline_order` being populated. The `_PipelineScheduleRuntime` also depends on it to run the IR. The single-stage schedules do not implement this, so this PR adds it. It also fixes a bug in the schedule visualizer. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155935 Approved by: https://github.com/wconstab
190 lines
6.8 KiB
Python
190 lines
6.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
|
|
"""
|
|
This visualizer requires matplotlib to be installed.
|
|
|
|
Example usage:
|
|
|
|
ops = get_schedule_ops("InterleavedZeroBubble", 4, 8)
|
|
visualize_schedule(ops, "test.png")
|
|
"""
|
|
|
|
from typing import Optional, Union
|
|
from unittest import mock
|
|
|
|
from torch.distributed.pipelining.schedules import (
|
|
_Action,
|
|
_ComputationType,
|
|
_PipelineSchedule,
|
|
get_schedule_class,
|
|
PipelineScheduleMulti,
|
|
PipelineScheduleSingle,
|
|
)
|
|
from torch.distributed.pipelining.stage import PipelineStage
|
|
|
|
|
|
def get_schedule_ops(
|
|
schedule: Union[str, _PipelineSchedule],
|
|
pp_degree: int,
|
|
num_microbatches: int,
|
|
num_stages_per_rank: Optional[int] = None,
|
|
) -> list[list[Optional[_Action]]]:
|
|
"""
|
|
Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists
|
|
where each inner list represents a rank and each element in the inner list represents an action.
|
|
|
|
The schedule can be specified as a string which is passed into get_schedule_class() or a _PipelineSchedule instance.
|
|
"""
|
|
|
|
if isinstance(schedule, str):
|
|
schedule_class = get_schedule_class(schedule)
|
|
elif type(schedule) == _PipelineSchedule:
|
|
schedule_class = schedule
|
|
else:
|
|
raise ValueError(f"Invalid schedule: {schedule}")
|
|
|
|
# Create a mock of the PipelineStage class
|
|
mock_pipeline_stage = mock.create_autospec(PipelineStage, instance=True)
|
|
# Set the return values for group_rank and group_size methods
|
|
mock_pipeline_stage.group_rank = 0
|
|
mock_pipeline_stage.group_size = pp_degree
|
|
mock_pipeline_stage.submod = None
|
|
|
|
# Check num_stages_per_rank is valid
|
|
if issubclass(schedule_class, PipelineScheduleSingle):
|
|
if num_stages_per_rank is None:
|
|
num_stages_per_rank = 1
|
|
assert num_stages_per_rank == 1
|
|
stages = mock_pipeline_stage
|
|
stages.num_stages = num_stages_per_rank * pp_degree
|
|
elif issubclass(schedule_class, PipelineScheduleMulti):
|
|
if num_stages_per_rank is None:
|
|
num_stages_per_rank = 2
|
|
assert num_stages_per_rank >= 2
|
|
stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)]
|
|
for stage in stages:
|
|
stage.num_stages = num_stages_per_rank * pp_degree
|
|
|
|
else:
|
|
raise ValueError(f"Invalid schedule: {schedule_class}")
|
|
|
|
# Instantiate the schedule class
|
|
schedule_instance = schedule_class(stages, num_microbatches)
|
|
|
|
# Convert to List[List[_Action]]
|
|
all_actions = []
|
|
for rank in range(pp_degree):
|
|
all_actions.append(schedule_instance.pipeline_order[rank])
|
|
|
|
# Return the pipeline order
|
|
return all_actions
|
|
|
|
|
|
class _ComputationTypeColor:
    """Drawing style for one computation type in the schedule plot.

    Attributes (taken directly from the constructor arguments):
        color: Matplotlib face colour of the rectangle drawn for the action.
        text: Label used for this computation type in the legend.
        width: Horizontal extent of the rectangle, in schedule slots.
    """

    def __init__(
        self,
        color: str,
        text: str = "",
        width: int = 1,
    ):
        self.color, self.width, self.text = color, width, text
|
|
|
|
|
|
# Drawing style (colour, legend label, width) for each computation type the
# visualizer colours explicitly. Full backward is drawn two slots wide.
action_type_to_color_mapping = {
    _ComputationType.FORWARD: _ComputationTypeColor(
        color="blue",
        text="Forward",
    ),
    _ComputationType.BACKWARD_INPUT: _ComputationTypeColor(
        color="teal",
        text="Backward Input",
    ),
    _ComputationType.BACKWARD_WEIGHT: _ComputationTypeColor(
        color="green",
        text="Backward Weight",
    ),
    _ComputationType.FULL_BACKWARD: _ComputationTypeColor(
        color="orange",
        text="Full Backward",
        width=2,
    ),
}
|
|
|
|
|
|
def visualize_schedule(
    schedule: list[list[Optional[_Action]]], filename: Optional[str] = None
) -> None:
    """
    Visualize the schedule using matplotlib.

    The schedule is a list of lists where each inner list represents a rank and
    each element in the inner list represents an action. Each action is drawn as
    a rectangle coloured according to ``action_type_to_color_mapping``.
    Computation types missing from that mapping are drawn in black and are left
    out of the legend. ``None`` entries leave a one-slot gap in the rank's row.

    Args:
        schedule: Per-rank action lists, e.g. the output of ``get_schedule_ops``.
        filename: If provided, the plot is saved to this file and the figure is
            closed afterwards; otherwise the plot is shown with ``plt.show()``.

    Raises:
        ValueError: If ``schedule`` contains no ranks.

    Note:
        This sets ``plt.rcParams["font.family"]`` globally.
    """
    if not schedule:
        raise ValueError("schedule must contain at least one rank")

    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    plt.rcParams["font.family"] = (
        "DejaVu Sans"  # or any other font available on your system
    )
    num_ranks = len(schedule)
    max_actions = max(len(rank) for rank in schedule)

    # Increase the figure size to provide more space for the legend
    fig, ax = plt.subplots(figsize=(max_actions + 2, num_ranks + 2))
    max_draw_position = -1
    # Calculate dynamic font size based on figure size
    font_size = min(max_actions, num_ranks) + 4
    # Computation types that appear AND have an entry in the mapping; only these
    # get legend entries (indexing the mapping with an unmapped type would fail).
    used_computation = set()
    fallback_style = _ComputationTypeColor("black")
    for rank_idx, actions in enumerate(schedule):
        # Rank 0 is the top row: this rank occupies y in [row_y, row_y + 1].
        row_y = num_ranks - rank_idx - 1
        draw_position = 0  # Initialize drawing position for each rank
        for action in actions:
            if action is None:
                draw_position += 1  # Leave an empty slot
                continue
            comp_type = action.computation_type
            style = action_type_to_color_mapping.get(comp_type, fallback_style)
            if comp_type in action_type_to_color_mapping:
                used_computation.add(comp_type)
            width = style.width
            # Draw the rectangle to represent the action duration
            ax.add_patch(
                Rectangle(
                    (draw_position, row_y),
                    width,
                    1,
                    facecolor=style.color,
                    edgecolor="black",
                )
            )
            # Draw the text centered within the rectangle
            ax.text(
                draw_position + width / 2,
                row_y + 0.5,
                str(action),
                ha="center",
                va="center",
                fontsize=font_size,
                color="white",
            )
            # Increment the drawing position by the width of the current action
            draw_position += width
        max_draw_position = max(max_draw_position, draw_position)

    ax.set_xlim(-0.5, max_draw_position + 1)
    ax.set_ylim(-0.5, num_ranks + 0.5)  # Add extra space at the top
    # Set y-ticks to be in the middle of each rank's row
    ax.set_yticks([num_ranks - rank_idx - 0.5 for rank_idx in range(num_ranks)])
    ax.set_yticklabels([f"Rank {i}" for i in range(num_ranks)], fontsize=font_size)
    ax.set_xticklabels([])

    # Hide grid lines (x tick labels were cleared above)
    ax.grid(False)
    # Iterate the mapping rather than the set so the legend order is stable
    # across runs (enum hashing depends on the per-process string hash seed).
    legend_elements = [
        Rectangle(
            (0, 0),
            1,
            1,
            facecolor=style.color,
            edgecolor="black",
            label=style.text,
        )
        for comp_type, style in action_type_to_color_mapping.items()
        if comp_type in used_computation
    ]
    ax.legend(handles=legend_elements, loc="upper right", fontsize=font_size)

    # Save to file if filename is provided, otherwise display the plot
    if filename:
        fig.savefig(filename, bbox_inches="tight")
        # Close the figure so repeated calls do not accumulate open pyplot figures.
        plt.close(fig)
    else:
        plt.show()
|