mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 One more PR after this one. Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: delete lines in the pyrefly.toml file from the project-excludes field step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199 after: INFO 0 errors (6,884 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002 Approved by: https://github.com/oulgen
438 lines
16 KiB
Python
438 lines
16 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
|
|
"""
|
|
This visualizer requires matplotlib to be installed.
|
|
|
|
Example usage:
|
|
|
|
ops = get_schedule_ops("InterleavedZeroBubble", 4, 8)
|
|
visualize_schedule(ops, "test.png")
|
|
"""
|
|
|
|
import collections
import logging
from typing import NamedTuple, Optional, Union
from unittest import mock

from torch.distributed.pipelining.schedules import (
    _Action,
    _ComputationType,
    _PipelineSchedule,
    _PipelineScheduleRuntime,
    get_schedule_class,
    PipelineScheduleMulti,
    PipelineScheduleSingle,
)
from torch.distributed.pipelining.stage import PipelineStage
|
|
|
|
|
|
class OpKey(NamedTuple):
    """Hashable identity of a single computation in a pipeline schedule.

    Being a tuple, it can be used as a dictionary key, e.g. to map an
    operation to the timestep at which it completes.
    """

    # Index of the pipeline stage the computation belongs to.
    stage_index: int
    # Kind of computation (a member of ``_ComputationType``).
    computation_type: _ComputationType
    # Index of the microbatch the computation processes.
    microbatch_index: int
|
|
|
|
|
|
def get_schedule_ops(
    schedule: Union[str, type[_PipelineSchedule]],
    pp_degree: int,
    num_microbatches: int,
    num_stages_per_rank: Optional[int] = None,
    add_spacing: bool = False,
    with_comms: bool = False,
) -> list[list[Optional[_Action]]]:
    """
    Get all actions for a given schedule, pp_degree, and num_microbatches.

    The actions are returned in a list of lists where each inner list
    represents a rank and each element in the inner list represents an action.

    The schedule can be specified as a string, which is passed into
    get_schedule_class(), or as a _PipelineSchedule subclass (the class itself,
    not an instance). The schedule is instantiated on a mocked PipelineStage.

    Args:
        schedule: Schedule name or schedule class. The resolved class must
            derive from PipelineScheduleSingle or PipelineScheduleMulti.
        pp_degree: Pipeline-parallel group size; one inner list is returned
            per rank in ``range(pp_degree)``.
        num_microbatches: Number of microbatches passed to the schedule.
        num_stages_per_rank: Stages per rank. Defaults to 1 for single-stage
            schedules (must be exactly 1) and to 2 for multi-stage schedules
            (must be at least 2).
        add_spacing: If True, drop the ``None`` placeholders produced by the
            schedule and re-space the actions with add_schedule_op_spacing().
        with_comms: If True, return the actions produced by
            ``_PipelineScheduleRuntime._prepare_schedule_with_comms`` instead
            of the plain pipeline order.

    Returns:
        One list of actions (or ``None`` placeholders) per rank. The inner
        lists are copies, not the schedule's own lists.

    Raises:
        ValueError: If ``add_spacing`` and ``with_comms`` are both set, or if
            ``schedule`` does not resolve to a supported schedule class.
        AssertionError: If ``num_stages_per_rank`` is invalid for the schedule
            kind, or the schedule exposes no ``pipeline_order``.
    """
    if add_spacing and with_comms:
        raise ValueError("Cannot add spacing and view comms at the same time")

    if isinstance(schedule, str):
        schedule_class = get_schedule_class(schedule)
    # Guard with isinstance(..., type): issubclass() raises TypeError for
    # non-class arguments, which would bypass the ValueError below.
    elif isinstance(schedule, type) and issubclass(schedule, _PipelineSchedule):
        schedule_class = schedule
    else:
        raise ValueError(f"Invalid schedule: {schedule}")

    # Create a mock of the PipelineStage class
    mock_pipeline_stage = mock.create_autospec(PipelineStage, instance=True)
    # Set the attributes the schedule reads from a stage
    mock_pipeline_stage.group_rank = 0
    mock_pipeline_stage.group_size = pp_degree
    mock_pipeline_stage.submod = None

    # Check num_stages_per_rank is valid
    if issubclass(schedule_class, PipelineScheduleSingle):
        if num_stages_per_rank is None:
            num_stages_per_rank = 1
        assert num_stages_per_rank == 1, (
            f"Single-stage schedules require num_stages_per_rank == 1, "
            f"got {num_stages_per_rank}"
        )
        stages = mock_pipeline_stage
        stages.num_stages = num_stages_per_rank * pp_degree
    elif issubclass(schedule_class, PipelineScheduleMulti):
        if num_stages_per_rank is None:
            num_stages_per_rank = 2
        assert num_stages_per_rank >= 2, (
            f"Multi-stage schedules require num_stages_per_rank >= 2, "
            f"got {num_stages_per_rank}"
        )
        # Every entry of the list is the same mock object.
        stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)]
        for stage in stages:
            stage.num_stages = num_stages_per_rank * pp_degree
    else:
        raise ValueError(f"Invalid schedule: {schedule_class}")

    # Instantiate the schedule class
    # pyrefly: ignore # bad-instantiation, bad-argument-type
    schedule_instance = schedule_class(stages, num_microbatches)
    assert schedule_instance.pipeline_order is not None, (
        f"{schedule_class} did not produce a pipeline_order"
    )

    # Convert to List[List[_Action]]
    all_actions: list[list[Optional[_Action]]] = []
    if with_comms:
        # NOTE(review): for single-stage schedules `stages` is one mock rather
        # than a list -- confirm _PipelineScheduleRuntime accepts that.
        runtime = _PipelineScheduleRuntime(stages, num_microbatches)
        runtime._prepare_schedule_with_comms(schedule_instance.pipeline_order)
        for rank in range(pp_degree):
            all_actions.append(list(runtime.pipeline_order_with_comms[rank]))
    else:
        for rank in range(pp_degree):
            # Copy so the result never aliases the schedule's own lists,
            # matching the with_comms branch.
            all_actions.append(list(schedule_instance.pipeline_order[rank]))

    # Add spacing
    if add_spacing:
        # remove all Nones, then respace
        # TODO: later we can change this at the schedule creation level to not use Nones
        all_actions = [
            [action for action in rank if action is not None] for rank in all_actions
        ]
        all_actions = add_schedule_op_spacing(all_actions)

    # Return the pipeline order
    return all_actions
|
|
|
|
|
|
class _ComputationTypeVisual:
    """Drawing attributes associated with one computation type.

    Attributes:
        color: Face colour of the rectangle drawn for the operation.
        text: Human-readable label (empty by default).
        width: Width of the operation in schedule slots (1 by default).
    """

    def __init__(
        self,
        color: str,
        text: str = "",
        width: int = 1,
    ) -> None:
        self.color = color
        self.width = width
        self.text = text

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(color={self.color!r}, "
            f"text={self.text!r}, width={self.width!r})"
        )
|
|
|
|
|
|
# Visual attributes (colour, legend label, width in schedule slots) for each
# computation type. Entries that do not pass a width use the default of 1;
# FULL_BACKWARD is 2 slots wide and OVERLAP_F_B is 3.
action_type_to_color_mapping: dict[_ComputationType, _ComputationTypeVisual] = {
    _ComputationType.FORWARD: _ComputationTypeVisual("blue", "Forward"),
    _ComputationType.BACKWARD_INPUT: _ComputationTypeVisual("teal", "Backward Input"),
    _ComputationType.BACKWARD_WEIGHT: _ComputationTypeVisual(
        "green", "Backward Weight"
    ),
    _ComputationType.FULL_BACKWARD: _ComputationTypeVisual(
        "orange", "Full Backward", 2
    ),
    _ComputationType.OVERLAP_F_B: _ComputationTypeVisual("purple", "Overlap F+B", 3),
}
|
|
|
|
|
|
def add_schedule_op_spacing(
    schedule: list[list[Optional[_Action]]],
) -> list[list[Optional[_Action]]]:
    """
    Add spacing to the schedule based on dependencies between ranks.

    Before adding an operation to the list, this function checks if there are
    dependencies from other ranks. If there are dependencies (other ranks have
    not finished processing the required microbatch), it adds None instead.

    For example, Forward microbatch 0 on rank 1 depends on rank 0 processing
    Forward microbatch 0 first.

    The schedule is simulated one timestep at a time. Each rank executes its
    actions in order; an action occupies its rank for as many timesteps as the
    ``width`` registered for its computation type in
    ``action_type_to_color_mapping``. At the end of every timestep a ``None``
    is appended for each rank that is not busy, which includes ranks that have
    already run out of actions. ``None`` entries in the input are copied
    through without any dependency check.

    Args:
        schedule: The original schedule as a list of lists where each inner list
            represents a rank and each element represents an action.

    Returns:
        A new schedule with proper spacing based on dependencies. An empty
        ``schedule`` is returned unchanged.

    Raises:
        RuntimeError: If an action has an unknown computation type, or if no
            action could be scheduled for more timesteps than the widest
            operation takes (possible deadlock).
        AssertionError: If an action has no microbatch index, or an
            OVERLAP_F_B action has no sub-actions.
    """
    if not schedule:
        return schedule

    logger = logging.getLogger(__name__)

    # ``default=-1`` keeps a schedule without any real action (only empty
    # ranks or ``None`` placeholders) from raising inside ``max``.
    num_stages = (
        max(
            (
                action.stage_index
                for rank_actions in schedule
                for action in rank_actions
                if action is not None
            ),
            default=-1,
        )
        + 1
    )

    num_ranks = len(schedule)
    spaced_schedule: list[list[Optional[_Action]]] = [[] for _ in range(num_ranks)]
    # One queue per rank. The list is never filtered so that a queue's index
    # is always its rank.
    rank_ops = [collections.deque(ops) for ops in schedule]

    # Track completion times: (stage_index, action_type, microbatch_index) -> completion_time
    scheduled_ops: dict[OpKey, int] = {}

    backward_input_types = (
        _ComputationType.FULL_BACKWARD,
        _ComputationType.BACKWARD_INPUT,
    )

    def is_dependency_ready(dependency_key: OpKey, timestep: int) -> bool:
        """Check if a dependency operation has completed by the given timestep."""
        return (
            dependency_key in scheduled_ops
            and timestep >= scheduled_ops[dependency_key]
        )

    def get_dependencies(action: _Action) -> list[OpKey]:
        """Get the list of dependencies for an action."""
        stage_idx = action.stage_index
        comp_type = action.computation_type
        mb_idx = action.microbatch_index

        # Ensure mb_idx is not None for dependency tracking
        assert mb_idx is not None, f"Action {action} has None microbatch_index"

        # First stage forward has no dependencies
        if stage_idx == 0 and comp_type == _ComputationType.FORWARD:
            return []

        # Last stage backward depends on forward from previous stage
        # NOTE(review): with a single stage this refers to stage -1, which no
        # action records as complete -- confirm single-stage input is unsupported.
        if stage_idx == num_stages - 1 and comp_type in backward_input_types:
            return [OpKey(stage_idx - 1, _ComputationType.FORWARD, mb_idx)]

        # Forward depends on previous stage forward
        if comp_type == _ComputationType.FORWARD:
            return [OpKey(stage_idx - 1, _ComputationType.FORWARD, mb_idx)]

        # Backward depends on next stage backward (in either of its forms)
        if comp_type in backward_input_types:
            return [
                OpKey(stage_idx + 1, _ComputationType.FULL_BACKWARD, mb_idx),
                OpKey(stage_idx + 1, _ComputationType.BACKWARD_INPUT, mb_idx),
            ]

        # Weight backward depends on input backward
        if comp_type == _ComputationType.BACKWARD_WEIGHT:
            return [OpKey(stage_idx, _ComputationType.BACKWARD_INPUT, mb_idx)]

        raise RuntimeError(f"Unknown computation type: {comp_type}")

    def is_action_ready(action: _Action, timestep: int) -> bool:
        """Check if an action is ready to be scheduled at the given timestep."""
        comp_type = action.computation_type
        # For OR dependencies (like backward), check if any dependency is satisfied
        if comp_type in (*backward_input_types, _ComputationType.BACKWARD_WEIGHT):
            return any(
                is_dependency_ready(dep, timestep) for dep in get_dependencies(action)
            )
        # For AND dependencies, all must be satisfied
        if comp_type == _ComputationType.FORWARD:
            return all(
                is_dependency_ready(dep, timestep) for dep in get_dependencies(action)
            )
        # An overlapped action is ready once every sub-action is ready
        if comp_type == _ComputationType.OVERLAP_F_B:
            assert action.sub_actions is not None, (
                f"OVERLAP_F_B action {action} has None sub_actions"
            )
            return all(
                is_action_ready(sub_action, timestep)
                for sub_action in action.sub_actions
            )
        raise RuntimeError(f"Unknown computation type: {comp_type}")

    def schedule_action(action: _Action, rank: int, timestep: int) -> int:
        """Schedule an action and return completion time."""
        spaced_schedule[rank].append(action)
        comp_type = action.computation_type
        completion_time = timestep + action_type_to_color_mapping[comp_type].width

        if comp_type == _ComputationType.OVERLAP_F_B:
            # For overlap actions, schedule each sub-action with cumulative timing
            assert action.sub_actions is not None, (
                f"OVERLAP_F_B action {action} has None sub_actions"
            )
            cumulative_time = 0
            for sub_action in action.sub_actions:
                assert sub_action.microbatch_index is not None, (
                    f"Sub-action {sub_action} has None microbatch_index"
                )
                cumulative_time += action_type_to_color_mapping[
                    sub_action.computation_type
                ].width
                sub_key = OpKey(
                    sub_action.stage_index,
                    sub_action.computation_type,
                    sub_action.microbatch_index,
                )
                scheduled_ops[sub_key] = timestep + cumulative_time
        else:
            assert action.microbatch_index is not None, (
                f"Action {action} has None microbatch_index"
            )
            scheduled_ops[
                OpKey(action.stage_index, comp_type, action.microbatch_index)
            ] = completion_time

        return completion_time

    # A rank can legitimately stay idle for as long as the widest operation
    # runs; anything longer without progress is treated as a deadlock.
    max_op_width = max(
        visual.width for visual in action_type_to_color_mapping.values()
    )

    # Main scheduling loop
    current_timestep = 0
    timesteps_without_progress = 0
    rank_completion_times = dict.fromkeys(range(num_ranks), 0)
    while any(rank_ops):
        logger.debug("Current timestep: %s", current_timestep)
        # Try to start the next operation of every rank that still has work
        for rank, op_queue in enumerate(rank_ops):
            if not op_queue:
                continue

            action = op_queue[0]
            logger.debug("Rank: %s, action=%r", rank, action)
            if action is None:
                spaced_schedule[rank].append(None)
                op_queue.popleft()
                timesteps_without_progress = 0
            elif current_timestep >= rank_completion_times[rank] and is_action_ready(
                action, current_timestep
            ):
                rank_completion_times[rank] = schedule_action(
                    action, rank, current_timestep
                )
                op_queue.popleft()
                timesteps_without_progress = 0

        # Add None for ranks that are waiting
        for rank in range(num_ranks):
            if current_timestep >= rank_completion_times[rank]:
                spaced_schedule[rank].append(None)

        current_timestep += 1
        timesteps_without_progress += 1

        if timesteps_without_progress > max_op_width:
            raise RuntimeError("No progress made in scheduling - possible deadlock")

    return spaced_schedule
|
|
|
|
|
|
def visualize_schedule(
    schedule: list[list[Optional[_Action]]],
    filename: Optional[str] = None,
) -> None:
    """
    Visualize the schedule using matplotlib.

    The schedule is a list of lists where each inner list represents a rank and
    each element in the inner list represents an action. The actions are
    represented as rectangles with different colors based on their computation
    type; ``None`` entries leave a one-slot gap. Rank 0 is drawn in the top
    row.

    Note that this sets the global ``font.family`` entry of matplotlib's
    rcParams to "DejaVu Sans".

    Args:
        schedule: The schedule to visualize.
        filename: The filename to save the plot to. If not provided, the plot will be displayed.

    Raises:
        ValueError: If ``schedule`` is empty.
    """

    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    if not schedule:
        raise ValueError("Cannot visualize an empty schedule")

    plt.rcParams["font.family"] = (
        "DejaVu Sans"  # or any other font available on your system
    )
    num_ranks = len(schedule)
    max_actions = max(len(rank) for rank in schedule)

    # Increase the figure size to provide more space for the legend
    fig, ax = plt.subplots(figsize=(max_actions + 2, num_ranks + 2))
    max_draw_position = -1
    # Calculate dynamic font size based on figure size
    font_size = min(max_actions, num_ranks) + 4
    # Computation types that were actually drawn; drives the legend
    used_computation = set()
    # Types missing from the mapping are drawn in black with width 1
    fallback_visual = _ComputationTypeVisual("black")

    for rank_idx, actions in enumerate(schedule):
        draw_position = 0  # Initialize drawing position for each rank
        row_bottom = num_ranks - rank_idx - 1  # rank 0 occupies the top row
        for action in actions:
            if action is None:
                draw_position += 1  # Leave a one-slot gap
                continue

            visual = action_type_to_color_mapping.get(
                action.computation_type, fallback_visual
            )
            used_computation.add(action.computation_type)
            width = visual.width

            # Thicker border for compound actions (those with sub_actions)
            linewidth = 2 if action.sub_actions is not None else 1

            # Draw the rectangle to represent the action duration
            ax.add_patch(
                Rectangle(
                    (draw_position, row_bottom),
                    width,
                    1,
                    facecolor=visual.color,
                    edgecolor="black",
                    linewidth=linewidth,
                )
            )

            # Draw the text centered within the rectangle
            ax.text(
                draw_position + width / 2,
                row_bottom + 0.5,
                str(action),
                ha="center",
                va="center",
                fontsize=font_size,
                color="white",
                weight="normal",
            )

            draw_position += width
        max_draw_position = max(max_draw_position, draw_position)

    ax.set_xlim(-0.5, max_draw_position + 1)
    ax.set_ylim(-0.5, num_ranks + 0.5)  # Add extra space at the top
    # Set y-ticks to be in the middle of each rank's row
    ax.set_yticks([num_ranks - rank_idx - 0.5 for rank_idx in range(num_ranks)])
    ax.set_yticklabels([f"Rank {i}" for i in range(num_ranks)], fontsize=font_size)
    ax.set_xticklabels([])

    # Remove grid lines
    ax.grid(False)

    # Add legend with larger font size. Iterating the mapping (rather than the
    # set) gives a stable order and skips types that have no legend entry.
    legend_elements = [
        Rectangle(
            (0, 0),
            1,
            1,
            facecolor=visual.color,
            edgecolor="black",
            label=visual.text,
        )
        for comp_type, visual in action_type_to_color_mapping.items()
        if comp_type in used_computation
    ]
    ax.legend(handles=legend_elements, loc="upper right", fontsize=font_size)

    # Save to file if filename is provided, otherwise display the plot
    if filename:
        plt.savefig(filename, bbox_inches="tight")
        # pyplot keeps figures alive until closed; release it once saved
        plt.close(fig)
    else:
        plt.show()
|