[FR] Make pg_name unique, show P2P collective status and fix bugs when running the script as a command (#134780)

Fixes a bunch of bugs in the script when running with the generated command and 3D parallelism.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134780
Approved by: https://github.com/c-p-i-o
ghstack dependencies: #134528
This commit is contained in:
fduwjj
2024-08-29 20:25:14 -07:00
committed by PyTorch MergeBot
parent 15f5a4858b
commit 1993a2aa9e
7 changed files with 118 additions and 59 deletions

View File

@ -46,13 +46,16 @@ class FlightRecorderEventTest(TestCase):
"all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
membership = {"0": {0, 1}}
self.assertEqual(match_one_event(e1, e1, membership), MatchState.FULLY_MATCHED)
self.assertEqual(
match_one_event(e1, e1, membership, "0"), MatchState.FULLY_MATCHED
)
e2 = create_one_event(
"all_gather", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
self.assertEqual(
match_one_event(e1, e2, membership), MatchState.COLLECTIVE_TYPE_MISMATCH
match_one_event(e1, e2, membership, "0"),
MatchState.COLLECTIVE_TYPE_MISMATCH,
)
e3 = create_one_event(
@ -61,34 +64,35 @@ class FlightRecorderEventTest(TestCase):
e4 = create_one_event(
"all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED)
self.assertEqual(match_one_event(e3, e4, membership, "0"), MatchState.UNDECIDED)
e5 = create_one_event(
"all_reduce", ("0", "default"), [[5, 4]], [[4, 4]], "scheduled", 1, 1
)
self.assertEqual(
match_one_event(e1, e5, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
match_one_event(e1, e5, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
)
e6 = create_one_event(
"all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 1, 2
)
self.assertEqual(
match_one_event(e1, e6, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
match_one_event(e1, e6, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
)
e7 = create_one_event(
"all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 2
)
self.assertEqual(
match_one_event(e7, e7, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
match_one_event(e7, e7, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
)
e9 = create_one_event(
"all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "completed", 1
)
self.assertEqual(
match_one_event(e1, e9, membership), MatchState.COLLECTIVE_STATE_MISMATCH
match_one_event(e1, e9, membership, "0"),
MatchState.COLLECTIVE_STATE_MISMATCH,
)
e10 = create_one_event(
@ -101,7 +105,8 @@ class FlightRecorderEventTest(TestCase):
output_dtypes="float16",
)
self.assertEqual(
match_one_event(e10, e9, membership), MatchState.COLLECTIVE_DTYPE_MISMATCH
match_one_event(e10, e9, membership, "0"),
MatchState.COLLECTIVE_DTYPE_MISMATCH,
)

View File

@ -49,7 +49,13 @@ Flat DB builder
def build_groups_memberships(
pg_config: Any,
) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Set[Any]]]:
) -> Tuple[
List[Group],
Dict[Any, Group],
List[Membership],
Dict[str, Set[Any]],
Dict[Tuple[str, int], str],
]:
"""
pg_config: {
global_rank: {
@ -71,6 +77,7 @@ def build_groups_memberships(
`_groups`: a dict that is indexed by pg_guid with Group namedtuple as value.
`memberships`: a membership table where each row is a Membership namedtuple.
`_memberships`: a dict that is indexed by pg_guid with set of ranks (int) as value.
`_pg_guids`: a dict that is indexed by (pg_uid, global_rank) with pg_guid as value.
"""
# flat lists for return
groups = []
@ -79,10 +86,16 @@ def build_groups_memberships(
# dicts for faster cross-rank validation
_groups = {}
_memberships = {}
_pg_guids = {}
for global_rank in pg_config:
for pg_guid in pg_config[global_rank]:
desc = pg_config[global_rank][pg_guid]["desc"]
ranks = ast.literal_eval(pg_config[global_rank][pg_guid]["ranks"])
for pg_uid in pg_config[global_rank]:
desc = pg_config[global_rank][pg_uid]["desc"]
ranks = ast.literal_eval(pg_config[global_rank][pg_uid]["ranks"])
# With the adoption of the split_group API, we can have multiple PGs with the same pg_uid (PG Name)
# So we need to add the hash of all its ranks within the PG as well.
# Also guid must be a string because `_process_group_name` returns a string.
pg_guid = pg_uid + str(hash(frozenset(ranks)))
_pg_guids[(pg_uid, global_rank)] = pg_guid
if isinstance(ranks, str):
# TODO Bug in FR data format? ranks is '[0, 1,...]'
ranks = eval(ranks)
@ -97,18 +110,18 @@ def build_groups_memberships(
# validation across ranks
assert (
_groups[pg_guid].desc == desc
), f"mismatch in desc {_groups[pg_guid].desc} vs {desc}"
), f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}"
assert _memberships[pg_guid] == set(
ranks
), f"mismatch in membership {_memberships[pg_guid]} vs {set(ranks)}"
return groups, _groups, memberships, _memberships
), f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}"
return groups, _groups, memberships, _memberships, _pg_guids
def build_nccl_call(
entry: Dict[Any, Any],
id: int,
collective_id: Any,
group_id: int,
group_id: str,
global_rank: Any,
) -> NCCLCall:
return NCCLCall(
@ -126,6 +139,7 @@ def build_collectives(
all_entries: Dict[int, List[Dict[str, Any]]],
_groups: Dict[str, Group],
_memberships: Dict[str, Set[Any]],
_pg_guids: Dict[Tuple[str, int], str],
) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]:
"""
groups, memberships are the non-flat dicts that are indexable
@ -186,7 +200,14 @@ def build_collectives(
entries = all_entries[first_rank]
pg_name, desc = entries[0]["process_group"]
profiling_name = entries[0]["profiling_name"]
pg_name = _pg_guids[(pg_name, first_rank)]
collective_seq_id = entries[0]["collective_seq_id"]
print(
"collective_seq_id ",
collective_seq_id,
" p2p_seq_id ",
entries[0]["p2p_seq_id"],
)
record_id = entries[0]["record_id"]
input_sizes = entries[0]["input_sizes"]
output_sizes = entries[0]["output_sizes"]
@ -198,7 +219,7 @@ def build_collectives(
found_ranks = set()
found_idx = {}
if find_coalesced_group(pg_name, entries):
if find_coalesced_group(pg_name, entries, _pg_guids, first_rank):
expected_ranks.add(first_rank)
done_ranks = set()
all_coalesced_entries = {}
@ -206,13 +227,13 @@ def build_collectives(
curr = expected_ranks.pop()
done_ranks.add(curr)
grp = (
find_coalesced_group(pg_name, all_entries[curr]) # type: ignore[index]
find_coalesced_group(pg_name, all_entries[curr], _pg_guids, curr) # type: ignore[index]
if curr in all_entries # type: ignore[comparison-overlap]
else []
)
all_coalesced_entries[curr] = grp
for index, entry in grp:
op = Op(entry, _memberships)
op = Op(entry, _memberships, pg_name)
peer = None
if op.type == "send":
assert op._src_g == curr, (op._src_g, curr)
@ -228,6 +249,7 @@ def build_collectives(
group_size=_groups[pg_name].size,
groups=_groups,
memberships=_memberships,
_pg_guids=_pg_guids,
)
if match and mismatch[pg_name] == 0:
@ -256,10 +278,13 @@ def build_collectives(
# step over ops from other PGs
# only check match state when seq_id matches
if (
e["process_group"] == (pg_name, desc)
_pg_guids[(e["process_group"][0], o)] == pg_name
and e["process_group"][1] == desc
and e["collective_seq_id"] == collective_seq_id
):
match_state = match_one_event(entries[0], e, _memberships)
match_state = match_one_event(
entries[0], e, _memberships, pg_name
)
if (
match_state
in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED]
@ -403,17 +428,19 @@ def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Da
entries = sort_trace_from_beginning(entries)
# flattened database
groups, _groups, memberships, _memberships = build_groups_memberships(pg_config)
groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships(
pg_config
)
print("built groups, memberships")
check_no_missing_dump_files(entries, memberships)
if args.just_print_entries:
just_print_entries(entries, _groups, _memberships)
just_print_entries(entries, _groups, _memberships, _pg_guids)
sys.exit(0)
tracebacks, collectives, nccl_calls = build_collectives(
entries, _groups, _memberships
entries, _groups, _memberships, _pg_guids
)
print("built collectives, nccl_calls")
if args.verbose:

View File

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import argparse
from typing import Optional, Sequence
class JobConfig:
@ -30,5 +31,7 @@ class JobConfig:
self.parser.add_argument("-j", "--just_print_entries", action="store_true")
self.parser.add_argument("-v", "--verbose", action="store_true")
def parse_args(self: "JobConfig") -> argparse.Namespace:
return self.parser.parse_args()
def parse_args(
self: "JobConfig", args: Optional[Sequence[str]]
) -> argparse.Namespace:
return self.parser.parse_args(args)

View File

@ -41,9 +41,7 @@ def read_dir(prefix: str, folder: str) -> Dict[str, Dict[str, Any]]:
t0 = time.time()
for root, _, files in os.walk(folder):
for f in files:
ta = time.time()
details[f] = read_dump(prefix, os.path.join(root, f))
tb = time.time()
# print(f"read file {f} in {tb - ta}s")
print(f"loaded {len(files)} files in {tb - t0}s")
return details

View File

@ -68,13 +68,13 @@ NCCLOp:
class Group(NamedTuple):
id: int
id: str
desc: str
size: int
class Membership(NamedTuple):
group_id: Ref[Group]
group_id: str
global_rank: int
@ -85,13 +85,13 @@ class Traceback(NamedTuple):
class Collective(NamedTuple):
id: int
group_id: Ref[Group]
group_id: str
class NCCLCall(NamedTuple):
id: int
collective_id: Ref[Collective]
group_id: Ref[Group]
group_id: str
global_rank: int # technically Ref[Process] once we have it
traceback_id: Ref[Traceback]
collective_type: str
@ -188,7 +188,9 @@ class Op:
nccl:recv 3<-0
"""
def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]):
def __init__(
self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]], pg_name: str
):
profiling_name = event["profiling_name"]
nccl, name = profiling_name.split(":")
assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'"
@ -209,7 +211,7 @@ class Op:
self._dst, self._src = int(d), int(s)
else:
self._src, self._dst = -1, -1
pg_name, pg_desc = event["process_group"]
_, pg_desc = event["process_group"]
self._init_global_src_dst(memberships[pg_name])
self.pg_size = len(memberships[pg_name])
if type in P2P | COLLECTIVES:
@ -239,10 +241,8 @@ class Op:
def __repr__(self) -> str:
if self.type in P2P:
return (
f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes})"
)
return f"{self.type}(input_sizes={self.input_sizes}, {self.state})"
return f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes}, state={self.state})"
return f"{self.type}(input_sizes={self.input_sizes}, state={self.state})"
def match(self, other: "Op") -> MatchState:
# TODO: I think this can validly not match,

View File

@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
import math
from typing import Any, Dict, List, Set, Tuple # type: ignore[attr-defined]
from typing import Any, Dict, List, Set, Tuple
from tools.flight_recorder.components.types import (
Group,
@ -40,9 +40,10 @@ def match_one_event(
event_a: Dict[Any, Any],
event_b: Dict[Any, Any],
memberships: Dict[str, Set[Any]],
pg_name: str,
) -> MatchState:
op_a = Op(event_a, memberships)
op_b = Op(event_b, memberships)
op_a = Op(event_a, memberships, pg_name)
op_b = Op(event_b, memberships, pg_name)
return op_a.match(op_b)
@ -51,6 +52,7 @@ def match_coalesced_groups(
group_size: int,
groups: Dict[str, Group],
memberships: Dict[str, Set[Any]],
_pg_guids: Dict[Tuple[str, int], str],
) -> bool:
"""
all_rank_events: {
@ -76,13 +78,22 @@ def match_coalesced_groups(
rank1 [recv:0 (1000B), recv:0 (100B)] —> not okay
"""
all_ops = {
rank: [Op(e, memberships) for i, e in all_rank_events[rank]]
rank: [
Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
for i, e in all_rank_events[rank]
]
for rank in all_rank_events
}
def visualize_ops(match: bool) -> None:
def visualize_ops(
match: bool,
_pg_guids: Dict[Tuple[str, int], str],
) -> None:
all_ops = {
rank: [Op(e, memberships) for i, e in all_rank_events[rank]]
rank: [
Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
for i, e in all_rank_events[rank]
]
for rank in all_rank_events
}
@ -94,8 +105,14 @@ def match_coalesced_groups(
progress = False
for r in all_ops:
if len(all_ops[r]) > i:
_, event = all_rank_events[r][i]
row.append(Op(event, memberships))
rank, event = all_rank_events[r][i]
row.append(
Op(
event,
memberships,
_pg_guids[(event["process_group"][0], rank)],
)
)
progress = True
else:
row.append(None) # type: ignore[arg-type]
@ -144,10 +161,10 @@ def match_coalesced_groups(
my_ops.pop(0)
peer_ops.pop(match_idx)
else:
visualize_ops(False)
visualize_ops(False, _pg_guids)
return False
visualize_ops(True)
visualize_ops(True, _pg_guids)
return True
@ -161,21 +178,27 @@ def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int
def find_coalesced_group(
pg_name: str, entries: List[Dict[str, Any]]
pg_name: str,
entries: List[Dict[str, Any]],
_pg_guids: Dict[Tuple[str, int], str],
rank: int,
) -> List[Tuple[int, Dict[str, Any]]]:
"""Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones,
build and return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id
TODO: handle p2p_seq_id v/s collective_seq_id separately here.
"""
found = []
collective_seq_id = None
for i, e in enumerate(entries):
if e["process_group"][0] != pg_name:
if _pg_guids[(e["process_group"][0], rank)] != pg_name:
continue
elif collective_seq_id is None:
collective_seq_id = e["collective_seq_id"]
collective_seq_id = (
e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"]
)
found.append((i, e))
elif e["collective_seq_id"] == collective_seq_id:
elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id:
found.append((i, e))
elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id:
found.append((i, e))
else:
break
@ -190,6 +213,7 @@ def just_print_entries(
all_entries: Dict[int, List[Dict[str, Any]]],
_groups: Dict[str, Group],
_memberships: Dict[str, Set[Any]],
_pg_guids: Dict[Tuple[str, int], str],
) -> None:
rows = []
ranks = sorted(all_entries.keys())
@ -203,7 +227,8 @@ def just_print_entries(
row.append("")
else:
entry = all_entries[rank].pop(0)
row.append(str(Op(entry, _memberships)))
pg_name = _pg_guids[(entry["process_group"][0], rank)]
row.append(str(Op(entry, _memberships, pg_name)))
progress = True
if progress:
rows.append(row)

View File

@ -28,8 +28,8 @@ python fr_trace.py -d <dump dir containing trace files> [-o <output file>]
- This script is versioned so that we can ensure our future changes to flight recorder are backwards compatible.
"""
import argparse
import pickle
from typing import Optional, Sequence
from tools.flight_recorder.components.builder import build_db
from tools.flight_recorder.components.config_manager import JobConfig
@ -37,7 +37,9 @@ from tools.flight_recorder.components.loader import read_dir
from tools.flight_recorder.components.types import types
def main(args: argparse.Namespace) -> None:
def main(args: Optional[Sequence[str]] = None) -> None:
config = JobConfig()
args = config.parse_args(args)
details = read_dir(args.prefix, args.dir)
db = build_db(details, args)
if args.output:
@ -46,5 +48,4 @@ def main(args: argparse.Namespace) -> None:
if __name__ == "__main__":
config = JobConfig()
main(config.parse_args())
main()