[PP] Customize pipeline's submod name (#164037)

Changing PP submodules' name from `submod_i` to `submod_pp_i` to distinguish from the submodule created by HOP.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164037
Approved by: https://github.com/H-Huang
ghstack dependencies: #164045, #164035
This commit is contained in:
Ke Wen
2025-09-28 13:29:46 -07:00
committed by PyTorch MergeBot
parent d58f7c3ad1
commit 704cd771f6

View File

@ -35,6 +35,16 @@ logger = logging.getLogger(__name__)
# 2. Add parameter movement to split_module
PP_SUBMOD_PREFIX = "submod_pp"
def get_submod_name(stage_idx: int):
"""Returns the name of the submod for a given stage index.
For example, "submod_pp_0", "submod_pp_1", etc.
"""
return "_".join([PP_SUBMOD_PREFIX, str(stage_idx)])
def _find_loss_from_output_and_spec(output_val, spec_val):
if spec_val is False:
return None
@ -593,7 +603,7 @@ class Pipe(torch.nn.Module):
i = 0
while True:
try:
name = f"submod_{i}"
name = get_submod_name(i)
submod = getattr(self.split_gm, name)
submod.__class__.__reduce__ = _direct_serialization_reduce
i += 1
@ -639,15 +649,17 @@ class Pipe(torch.nn.Module):
"""
if stage_idx < 0 or stage_idx >= self.num_stages:
raise ValueError(f"Invalid stage index {stage_idx}!")
return getattr(self.split_gm, f"submod_{stage_idx}")
submod_name = get_submod_name(stage_idx)
return getattr(self.split_gm, submod_name)
@staticmethod
def _number_and_count_forward_stages(gm: fx.GraphModule):
num_stages = 0
found_idxs: dict[int, None] = {}
for node in gm.graph.nodes:
if node.op == "call_module" and node.target.startswith("submod_"):
node.meta["stage_idx"] = int(node.target[len("submod_") :])
if node.op == "call_module" and node.target.startswith(PP_SUBMOD_PREFIX):
node.meta["stage_idx"] = int(node.target[len(PP_SUBMOD_PREFIX) + 1 :])
found_idxs.setdefault(node.meta["stage_idx"])
num_stages += 1
@ -729,7 +741,7 @@ class Pipe(torch.nn.Module):
# TODO: what does split do with module invocations? does it move the modules
# into the submodules?
split = split_module(traced, mod, split_callback) # type: ignore[arg-type]
split = split_module(traced, mod, split_callback, partition_affix="pp") # type: ignore[arg-type]
# a (custom) tracer can produce dead code like orphan get_attr nodes
split.graph.eliminate_dead_code()