[FR] Make pg_name unique, show P2P collective status and fix bugs when running the script as command (#134780)
Fixes a number of bugs in the script when it is run via the generated command and with 3D parallelism. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134780 Approved by: https://github.com/c-p-i-o ghstack dependencies: #134528
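The heart of the pg_name fix is easy to show in isolation: with the split_group API, several process groups can share the same rank-local name (pg_uid), so the script derives a unique guid by appending a hash of the group's rank set. A minimal standalone sketch of that derivation; make_pg_guid is an illustrative helper name, not a function in the script:

# Sketch of the uniquification this PR applies to process-group names.
from typing import List


def make_pg_guid(pg_uid: str, ranks: List[int]) -> str:
    # Groups created via split_group can share pg_uid ("PG Name"), but their
    # rank sets differ, so hashing the frozenset of ranks yields a guid that
    # is unique per (name, membership) pair. Mirrors the diff's
    # `pg_guid = pg_uid + str(hash(frozenset(ranks)))`.
    return pg_uid + str(hash(frozenset(ranks)))


assert make_pg_guid("1", [0, 1]) != make_pg_guid("1", [2, 3])  # same name, distinct guids
assert make_pg_guid("1", [0, 1]) == make_pg_guid("1", [1, 0])  # rank order does not matter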
@@ -46,13 +46,16 @@ class FlightRecorderEventTest(TestCase):
             "all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
         )
         membership = {"0": {0, 1}}
-        self.assertEqual(match_one_event(e1, e1, membership), MatchState.FULLY_MATCHED)
+        self.assertEqual(
+            match_one_event(e1, e1, membership, "0"), MatchState.FULLY_MATCHED
+        )
 
         e2 = create_one_event(
             "all_gather", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
         )
         self.assertEqual(
-            match_one_event(e1, e2, membership), MatchState.COLLECTIVE_TYPE_MISMATCH
+            match_one_event(e1, e2, membership, "0"),
+            MatchState.COLLECTIVE_TYPE_MISMATCH,
         )
 
         e3 = create_one_event(
@@ -61,34 +64,35 @@ class FlightRecorderEventTest(TestCase):
         e4 = create_one_event(
             "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
         )
-        self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED)
+        self.assertEqual(match_one_event(e3, e4, membership, "0"), MatchState.UNDECIDED)
 
         e5 = create_one_event(
             "all_reduce", ("0", "default"), [[5, 4]], [[4, 4]], "scheduled", 1, 1
         )
         self.assertEqual(
-            match_one_event(e1, e5, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
+            match_one_event(e1, e5, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
         )
 
         e6 = create_one_event(
             "all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 1, 2
         )
         self.assertEqual(
-            match_one_event(e1, e6, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
+            match_one_event(e1, e6, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
         )
 
         e7 = create_one_event(
             "all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 2
         )
         self.assertEqual(
-            match_one_event(e7, e7, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
+            match_one_event(e7, e7, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
         )
 
         e9 = create_one_event(
             "all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "completed", 1
         )
         self.assertEqual(
-            match_one_event(e1, e9, membership), MatchState.COLLECTIVE_STATE_MISMATCH
+            match_one_event(e1, e9, membership, "0"),
+            MatchState.COLLECTIVE_STATE_MISMATCH,
         )
 
         e10 = create_one_event(
@@ -101,7 +105,8 @@ class FlightRecorderEventTest(TestCase):
             output_dtypes="float16",
         )
         self.assertEqual(
-            match_one_event(e10, e9, membership), MatchState.COLLECTIVE_DTYPE_MISMATCH
+            match_one_event(e10, e9, membership, "0"),
+            MatchState.COLLECTIVE_DTYPE_MISMATCH,
         )

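The test updates above all make the same change: match_one_event now takes the resolved process-group name as a fourth argument. A usage sketch, assuming it runs inside this test module (where the create_one_event helper is defined) in a PyTorch checkout where the tools.flight_recorder package is importable; both are assumptions about the layout:

from tools.flight_recorder.components.types import MatchState
from tools.flight_recorder.components.utils import match_one_event

e1 = create_one_event(  # helper defined in this test file
    "all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
membership = {"0": {0, 1}}
# The trailing "0" is the resolved (unique) pg name of the group under test.
assert match_one_event(e1, e1, membership, "0") == MatchState.FULLY_MATCHED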
@@ -49,7 +49,13 @@ Flat DB builder
 
 def build_groups_memberships(
     pg_config: Any,
-) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Set[Any]]]:
+) -> Tuple[
+    List[Group],
+    Dict[Any, Group],
+    List[Membership],
+    Dict[str, Set[Any]],
+    Dict[Tuple[str, int], str],
+]:
     """
     pg_config: {
         global_rank: {
@@ -71,6 +77,7 @@ def build_groups_memberships(
     `_groups`: a dict that is indexed by pg_guid with Group namedtuple as value.
     `memberships`: a membership table where each row is a Membership namedtuple.
     `_memberships`: a dict that is indexed by pg_guid with set of ranks (int) as value.
+    `_pg_guids`: a dict that is indexed by (pg_uid, global_rank) with pg_guid as value.
     """
     # flat lists for return
     groups = []
@@ -79,10 +86,16 @@ def build_groups_memberships(
     # dicts for faster cross-rank validation
     _groups = {}
     _memberships = {}
+    _pg_guids = {}
     for global_rank in pg_config:
-        for pg_guid in pg_config[global_rank]:
-            desc = pg_config[global_rank][pg_guid]["desc"]
-            ranks = ast.literal_eval(pg_config[global_rank][pg_guid]["ranks"])
+        for pg_uid in pg_config[global_rank]:
+            desc = pg_config[global_rank][pg_uid]["desc"]
+            ranks = ast.literal_eval(pg_config[global_rank][pg_uid]["ranks"])
+            # With the adoption of the split_group API, we can have multiple PGs with the same pg_guid (PG Name)
+            # So we need to add the hash of all its ranks within the PG as well.
+            # Also guid must be a string because `_process_group_name` returns a string.
+            pg_guid = pg_uid + str(hash(frozenset(ranks)))
+            _pg_guids[(pg_uid, global_rank)] = pg_guid
             if isinstance(ranks, str):
                 # TODO Bug in FR data format? ranks is '[0, 1,...]'
                 ranks = eval(ranks)
@@ -97,18 +110,18 @@ def build_groups_memberships(
             # validation across ranks
             assert (
                 _groups[pg_guid].desc == desc
-            ), f"mismatch in desc {_groups[pg_guid].desc} vs {desc}"
+            ), f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}"
             assert _memberships[pg_guid] == set(
                 ranks
-            ), f"mismatch in membership {_memberships[pg_guid]} vs {set(ranks)}"
-    return groups, _groups, memberships, _memberships
+            ), f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}"
+    return groups, _groups, memberships, _memberships, _pg_guids
 
 
 def build_nccl_call(
     entry: Dict[Any, Any],
     id: int,
     collective_id: Any,
-    group_id: int,
+    group_id: str,
     global_rank: Any,
 ) -> NCCLCall:
     return NCCLCall(
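For readers skimming the diff, here is a self-contained sketch of the _pg_guids table built above. The toy pg_config mimics the dump layout described in the docstring (global_rank -> pg_uid -> {"desc": ..., "ranks": ...}); the sample values and exact "ranks" string format are assumptions for illustration:

import ast
from typing import Dict, Tuple

pg_config = {
    0: {"0": {"desc": "default_pg", "ranks": "[0, 1]"}},
    1: {"0": {"desc": "default_pg", "ranks": "[0, 1]"}},
}

_pg_guids: Dict[Tuple[str, int], str] = {}
for global_rank in pg_config:
    for pg_uid in pg_config[global_rank]:
        ranks = ast.literal_eval(pg_config[global_rank][pg_uid]["ranks"])
        # Keyed per (pg_uid, rank) so later stages can translate a rank-local
        # name into the globally unique one.
        _pg_guids[(pg_uid, global_rank)] = pg_uid + str(hash(frozenset(ranks)))

# Both ranks resolve the local name "0" to the same unique guid.
assert _pg_guids[("0", 0)] == _pg_guids[("0", 1)]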
@@ -126,6 +139,7 @@ def build_collectives(
     all_entries: Dict[int, List[Dict[str, Any]]],
     _groups: Dict[str, Group],
     _memberships: Dict[str, Set[Any]],
+    _pg_guids: Dict[Tuple[str, int], str],
 ) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]:
     """
     groups, memberships are the non-flat dicts that are indexable
@@ -186,7 +200,14 @@ def build_collectives(
         entries = all_entries[first_rank]
         pg_name, desc = entries[0]["process_group"]
         profiling_name = entries[0]["profiling_name"]
+        pg_name = _pg_guids[(pg_name, first_rank)]
         collective_seq_id = entries[0]["collective_seq_id"]
+        print(
+            "collective_seq_id ",
+            collective_seq_id,
+            " p2p_seq_id ",
+            entries[0]["p2p_seq_id"],
+        )
         record_id = entries[0]["record_id"]
         input_sizes = entries[0]["input_sizes"]
         output_sizes = entries[0]["output_sizes"]
@@ -198,7 +219,7 @@ def build_collectives(
         found_ranks = set()
         found_idx = {}
 
-        if find_coalesced_group(pg_name, entries):
+        if find_coalesced_group(pg_name, entries, _pg_guids, first_rank):
             expected_ranks.add(first_rank)
             done_ranks = set()
             all_coalesced_entries = {}
@@ -206,13 +227,13 @@ def build_collectives(
                 curr = expected_ranks.pop()
                 done_ranks.add(curr)
                 grp = (
-                    find_coalesced_group(pg_name, all_entries[curr])  # type: ignore[index]
+                    find_coalesced_group(pg_name, all_entries[curr], _pg_guids, curr)  # type: ignore[index]
                     if curr in all_entries  # type: ignore[comparison-overlap]
                     else []
                 )
                 all_coalesced_entries[curr] = grp
                 for index, entry in grp:
-                    op = Op(entry, _memberships)
+                    op = Op(entry, _memberships, pg_name)
                     peer = None
                     if op.type == "send":
                         assert op._src_g == curr, (op._src_g, curr)
@@ -228,6 +249,7 @@ def build_collectives(
                 group_size=_groups[pg_name].size,
                 groups=_groups,
                 memberships=_memberships,
+                _pg_guids=_pg_guids,
             )
 
             if match and mismatch[pg_name] == 0:
@@ -256,10 +278,13 @@ def build_collectives(
                     # step over ops from other PGs
                     # only check match state when seq_id matches
                     if (
-                        e["process_group"] == (pg_name, desc)
+                        _pg_guids[(e["process_group"][0], o)] == pg_name
+                        and e["process_group"][1] == desc
                         and e["collective_seq_id"] == collective_seq_id
                     ):
-                        match_state = match_one_event(entries[0], e, _memberships)
+                        match_state = match_one_event(
+                            entries[0], e, _memberships, pg_name
+                        )
                         if (
                             match_state
                             in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED]
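The condition above no longer compares an event's rank-local pg name directly: it first translates it through _pg_guids for the rank (o) that produced the event, so events from groups that merely share a name can no longer match. A self-contained illustration with invented values:

from typing import Any, Dict, Tuple

# (rank-local pg_uid, global_rank) -> globally unique pg_guid; values invented.
_pg_guids: Dict[Tuple[str, int], str] = {("0", 0): "0-abc", ("0", 1): "0-abc"}
pg_name = "0-abc"  # resolved name of the collective being matched

e: Dict[str, Any] = {"process_group": ("0", "default"), "collective_seq_id": 1}
o = 1  # rank whose trace produced e

# Mirrors the check `_pg_guids[(e["process_group"][0], o)] == pg_name`.
assert _pg_guids[(e["process_group"][0], o)] == pg_name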
@@ -403,17 +428,19 @@ def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Da
     entries = sort_trace_from_beginning(entries)
 
     # flattened database
-    groups, _groups, memberships, _memberships = build_groups_memberships(pg_config)
+    groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships(
+        pg_config
+    )
     print("built groups, memberships")
 
     check_no_missing_dump_files(entries, memberships)
 
     if args.just_print_entries:
-        just_print_entries(entries, _groups, _memberships)
+        just_print_entries(entries, _groups, _memberships, _pg_guids)
         sys.exit(0)
 
     tracebacks, collectives, nccl_calls = build_collectives(
-        entries, _groups, _memberships
+        entries, _groups, _memberships, _pg_guids
     )
     print("built collectives, nccl_calls")
     if args.verbose:

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+from typing import Optional, Sequence
 
 
 class JobConfig:
@@ -30,5 +31,7 @@ class JobConfig:
         self.parser.add_argument("-j", "--just_print_entries", action="store_true")
         self.parser.add_argument("-v", "--verbose", action="store_true")
 
-    def parse_args(self: "JobConfig") -> argparse.Namespace:
-        return self.parser.parse_args()
+    def parse_args(
+        self: "JobConfig", args: Optional[Sequence[str]]
+    ) -> argparse.Namespace:
+        return self.parser.parse_args(args)

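A usage sketch of the new optional-args parse, assuming a checkout where tools.flight_recorder.components.config_manager is importable; the -d/--dir flag is inferred from the fr_trace.py usage string and the dump path is a placeholder:

from tools.flight_recorder.components.config_manager import JobConfig

config = JobConfig()
# An explicit argv list can now be supplied (e.g. from another script or a
# test); passing None falls back to sys.argv as before.
args = config.parse_args(["-d", "/tmp/fr_dumps", "--verbose"])
print(args.dir, args.verbose)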
@@ -41,9 +41,7 @@ def read_dir(prefix: str, folder: str) -> Dict[str, Dict[str, Any]]:
     t0 = time.time()
     for root, _, files in os.walk(folder):
         for f in files:
-            ta = time.time()
             details[f] = read_dump(prefix, os.path.join(root, f))
             tb = time.time()
-            # print(f"read file {f} in {tb - ta}s")
     print(f"loaded {len(files)} files in {tb - t0}s")
     return details

@@ -68,13 +68,13 @@ NCCLOp:
 
 
 class Group(NamedTuple):
-    id: int
+    id: str
     desc: str
     size: int
 
 
 class Membership(NamedTuple):
-    group_id: Ref[Group]
+    group_id: str
     global_rank: int
 
 
@@ -85,13 +85,13 @@ class Traceback(NamedTuple):
 
 class Collective(NamedTuple):
     id: int
-    group_id: Ref[Group]
+    group_id: str
 
 
 class NCCLCall(NamedTuple):
     id: int
     collective_id: Ref[Collective]
-    group_id: Ref[Group]
+    group_id: str
     global_rank: int  # technically Ref[Process] once we have it
     traceback_id: Ref[Traceback]
     collective_type: str
@@ -188,7 +188,9 @@ class Op:
     nccl:recv 3<-0
     """
 
-    def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]):
+    def __init__(
+        self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]], pg_name: str
+    ):
         profiling_name = event["profiling_name"]
         nccl, name = profiling_name.split(":")
         assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'"
@@ -209,7 +211,7 @@ class Op:
             self._dst, self._src = int(d), int(s)
         else:
             self._src, self._dst = -1, -1
-        pg_name, pg_desc = event["process_group"]
+        _, pg_desc = event["process_group"]
         self._init_global_src_dst(memberships[pg_name])
         self.pg_size = len(memberships[pg_name])
         if type in P2P | COLLECTIVES:
@@ -239,10 +241,8 @@ class Op:
 
     def __repr__(self) -> str:
         if self.type in P2P:
-            return (
-                f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes})"
-            )
-        return f"{self.type}(input_sizes={self.input_sizes}, {self.state})"
+            return f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes}, state={self.state})"
+        return f"{self.type}(input_sizes={self.input_sizes}, state={self.state})"
 
     def match(self, other: "Op") -> MatchState:
         # TODO: I think this can validly not match,

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-from typing import Any, Dict, List, Set, Tuple  # type: ignore[attr-defined]
+from typing import Any, Dict, List, Set, Tuple
 
 from tools.flight_recorder.components.types import (
     Group,
@@ -40,9 +40,10 @@ def match_one_event(
     event_a: Dict[Any, Any],
     event_b: Dict[Any, Any],
     memberships: Dict[str, Set[Any]],
+    pg_name: str,
 ) -> MatchState:
-    op_a = Op(event_a, memberships)
-    op_b = Op(event_b, memberships)
+    op_a = Op(event_a, memberships, pg_name)
+    op_b = Op(event_b, memberships, pg_name)
     return op_a.match(op_b)
 
 
@@ -51,6 +52,7 @@ def match_coalesced_groups(
     group_size: int,
     groups: Dict[str, Group],
     memberships: Dict[str, Set[Any]],
+    _pg_guids: Dict[Tuple[str, int], str],
 ) -> bool:
     """
     all_rank_events: {
@@ -76,13 +78,22 @@ def match_coalesced_groups(
     rank1 [recv:0 (1000B), recv:0 (100B)] —> not okay
     """
     all_ops = {
-        rank: [Op(e, memberships) for i, e in all_rank_events[rank]]
+        rank: [
+            Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
+            for i, e in all_rank_events[rank]
+        ]
         for rank in all_rank_events
     }
 
-    def visualize_ops(match: bool) -> None:
+    def visualize_ops(
+        match: bool,
+        _pg_guids: Dict[Tuple[str, int], str],
+    ) -> None:
         all_ops = {
-            rank: [Op(e, memberships) for i, e in all_rank_events[rank]]
+            rank: [
+                Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
+                for i, e in all_rank_events[rank]
+            ]
             for rank in all_rank_events
         }
@@ -94,8 +105,14 @@ def match_coalesced_groups(
             progress = False
             for r in all_ops:
                 if len(all_ops[r]) > i:
-                    _, event = all_rank_events[r][i]
-                    row.append(Op(event, memberships))
+                    rank, event = all_rank_events[r][i]
+                    row.append(
+                        Op(
+                            event,
+                            memberships,
+                            _pg_guids[(event["process_group"][0], rank)],
+                        )
+                    )
                     progress = True
                 else:
                     row.append(None)  # type: ignore[arg-type]
@@ -144,10 +161,10 @@ def match_coalesced_groups(
             my_ops.pop(0)
             peer_ops.pop(match_idx)
         else:
-            visualize_ops(False)
+            visualize_ops(False, _pg_guids)
             return False
 
-    visualize_ops(True)
+    visualize_ops(True, _pg_guids)
     return True
 
 
@@ -161,21 +178,27 @@ def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int
 
 
 def find_coalesced_group(
-    pg_name: str, entries: List[Dict[str, Any]]
+    pg_name: str,
+    entries: List[Dict[str, Any]],
+    _pg_guids: Dict[Tuple[str, int], str],
+    rank: int,
 ) -> List[Tuple[int, Dict[str, Any]]]:
     """Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones,
     build an return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id
     TODO: handle p2p_seq_id v/s collective_seq_id separately here.
     """
     found = []
     collective_seq_id = None
     for i, e in enumerate(entries):
-        if e["process_group"][0] != pg_name:
+        if _pg_guids[(e["process_group"][0], rank)] != pg_name:
             continue
         elif collective_seq_id is None:
-            collective_seq_id = e["collective_seq_id"]
+            collective_seq_id = (
+                e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"]
+            )
             found.append((i, e))
-        elif e["collective_seq_id"] == collective_seq_id:
+        elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id:
             found.append((i, e))
+        elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id:
+            found.append((i, e))
         else:
             break
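This hunk is where the P2P status fix lands: coalesced P2P entries are grouped by p2p_seq_id, everything else by collective_seq_id. A self-contained sketch of that selection, using invented stand-ins for flight-recorder entries:

from typing import Any, Dict, List

def seq_id(e: Dict[str, Any]) -> int:
    # P2P entries advance p2p_seq_id; collectives advance collective_seq_id.
    return e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"]

entries: List[Dict[str, Any]] = [
    {"is_p2p": True, "p2p_seq_id": 7, "collective_seq_id": 0},
    {"is_p2p": True, "p2p_seq_id": 7, "collective_seq_id": 0},
    {"is_p2p": False, "p2p_seq_id": 0, "collective_seq_id": 7},
]

first = entries[0]
group = [
    e for e in entries
    if e["is_p2p"] == first["is_p2p"] and seq_id(e) == seq_id(first)
]
# Only the two P2P entries coalesce; the collective whose id happens to be
# equal is excluded because is_p2p routes it to a different counter.
assert len(group) == 2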
@@ -190,6 +213,7 @@ def just_print_entries(
     all_entries: Dict[int, List[Dict[str, Any]]],
     _groups: Dict[str, Group],
     _memberships: Dict[str, Set[Any]],
+    _pg_guids: Dict[Tuple[str, int], str],
 ) -> None:
     rows = []
     ranks = sorted(all_entries.keys())
@@ -203,7 +227,8 @@ def just_print_entries(
                 row.append("")
             else:
                 entry = all_entries[rank].pop(0)
-                row.append(str(Op(entry, _memberships)))
+                pg_name = _pg_guids[(entry["process_group"][0], rank)]
+                row.append(str(Op(entry, _memberships, pg_name)))
                 progress = True
         if progress:
             rows.append(row)

@@ -28,8 +28,8 @@ python fr_trace.py -d <dump dir containing trace files> [-o <output file>]
 - This script is versioned so that we can ensure our future changes to flight recorder are backwards compatible.
 """
 
-import argparse
 import pickle
+from typing import Optional, Sequence
 
 from tools.flight_recorder.components.builder import build_db
 from tools.flight_recorder.components.config_manager import JobConfig
@@ -37,7 +37,9 @@ from tools.flight_recorder.components.loader import read_dir
 from tools.flight_recorder.components.types import types
 
 
-def main(args: argparse.Namespace) -> None:
+def main(args: Optional[Sequence[str]] = None) -> None:
+    config = JobConfig()
+    args = config.parse_args(args)
     details = read_dir(args.prefix, args.dir)
     db = build_db(details, args)
     if args.output:
@@ -46,5 +48,4 @@ def main(args: argparse.Namespace) -> None:
 
 
 if __name__ == "__main__":
-    config = JobConfig()
-    main(config.parse_args())
+    main()
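With main() now parsing its own arguments, the analyzer can be driven programmatically as well as from the command line, which is what makes running it "as a command" work. A usage sketch, assuming a repo-root working directory; the dump directory is a placeholder:

from tools.flight_recorder.fr_trace import main

# Equivalent to: python fr_trace.py -d /tmp/fr_dumps
main(["-d", "/tmp/fr_dumps"])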