Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[PP] Customize pipeline's submod name (#164037)
Change the PP submodules' names from `submod_i` to `submod_pp_i` to distinguish them from the submodules created by HOPs (higher-order ops).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164037
Approved by: https://github.com/H-Huang
ghstack dependencies: #164045, #164035
@@ -35,6 +35,16 @@ logger = logging.getLogger(__name__)
 # 2. Add parameter movement to split_module
 
 
+PP_SUBMOD_PREFIX = "submod_pp"
+
+
+def get_submod_name(stage_idx: int):
+    """Returns the name of the submod for a given stage index.
+    For example, "submod_pp_0", "submod_pp_1", etc.
+    """
+    return "_".join([PP_SUBMOD_PREFIX, str(stage_idx)])
+
+
 def _find_loss_from_output_and_spec(output_val, spec_val):
     if spec_val is False:
         return None
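For reference, the new helper simply composes the stage-submodule name from the prefix. A minimal, standalone sketch of that behavior (illustrative only, copied from the code added in the hunk above):

```python
# Illustrative sketch of the naming helper added above.
PP_SUBMOD_PREFIX = "submod_pp"

def get_submod_name(stage_idx: int) -> str:
    # Joins the prefix and the stage index, e.g. stage 3 -> "submod_pp_3".
    return "_".join([PP_SUBMOD_PREFIX, str(stage_idx)])

assert get_submod_name(0) == "submod_pp_0"
assert get_submod_name(12) == "submod_pp_12"
```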
@@ -593,7 +603,7 @@ class Pipe(torch.nn.Module):
         i = 0
         while True:
             try:
-                name = f"submod_{i}"
+                name = get_submod_name(i)
                 submod = getattr(self.split_gm, name)
                 submod.__class__.__reduce__ = _direct_serialization_reduce
                 i += 1
@@ -639,15 +649,17 @@ class Pipe(torch.nn.Module):
         """
         if stage_idx < 0 or stage_idx >= self.num_stages:
             raise ValueError(f"Invalid stage index {stage_idx}!")
-        return getattr(self.split_gm, f"submod_{stage_idx}")
+
+        submod_name = get_submod_name(stage_idx)
+        return getattr(self.split_gm, submod_name)
 
     @staticmethod
     def _number_and_count_forward_stages(gm: fx.GraphModule):
         num_stages = 0
         found_idxs: dict[int, None] = {}
         for node in gm.graph.nodes:
-            if node.op == "call_module" and node.target.startswith("submod_"):
-                node.meta["stage_idx"] = int(node.target[len("submod_") :])
+            if node.op == "call_module" and node.target.startswith(PP_SUBMOD_PREFIX):
+                node.meta["stage_idx"] = int(node.target[len(PP_SUBMOD_PREFIX) + 1 :])
                 found_idxs.setdefault(node.meta["stage_idx"])
                 num_stages += 1
 
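Under the new scheme, the stage index is parsed back out of a `call_module` target by stripping the prefix plus the joining underscore, as in the updated `_number_and_count_forward_stages` above. A hedged sketch of that parsing step (`stage_idx_from_target` is a hypothetical helper, not part of the PR):

```python
PP_SUBMOD_PREFIX = "submod_pp"

def stage_idx_from_target(target: str) -> int:
    # Targets now look like "submod_pp_<i>"; skip the prefix and the "_" that follows it.
    assert target.startswith(PP_SUBMOD_PREFIX)
    return int(target[len(PP_SUBMOD_PREFIX) + 1 :])

assert stage_idx_from_target("submod_pp_3") == 3
```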
@@ -729,7 +741,7 @@ class Pipe(torch.nn.Module):
 
         # TODO: what does split do with module invocations? does it move the modules
         # into the submodules?
-        split = split_module(traced, mod, split_callback)  # type: ignore[arg-type]
+        split = split_module(traced, mod, split_callback, partition_affix="pp")  # type: ignore[arg-type]
         # a (custom) tracer can produce dead code like orphan get_attr nodes
         split.graph.eliminate_dead_code()
 
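The `partition_affix="pp"` keyword comes from the ghstack dependencies (#164045, #164035); assuming it inserts the affix into `split_module`'s partition names, the split graph's children are registered as `submod_pp_<i>` rather than `submod_<i>`. A rough usage sketch under that assumption (the toy module, callback, and expected names are illustrative, not code from the PR):

```python
import torch
from torch import fx
from torch.fx.passes.split_module import split_module

class TwoStage(torch.nn.Module):
    def forward(self, x):
        x = torch.relu(x)
        return x * 2

mod = TwoStage()
traced = fx.symbolic_trace(mod)

def split_callback(node: fx.Node) -> int:
    # Put relu into partition 0 and the remaining compute into partition 1.
    return 0 if node.name == "relu" else 1

# partition_affix is assumed to be available once this PR stack lands.
split = split_module(traced, mod, split_callback, partition_affix="pp")
print([name for name, _ in split.named_children()])
# Expected with the new naming: ['submod_pp_0', 'submod_pp_1']
```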