diff --git a/tools/flight_recorder/fr_trace.py b/tools/flight_recorder/fr_trace.py
new file mode 100644
index 000000000000..dc05ea67fcbc
--- /dev/null
+++ b/tools/flight_recorder/fr_trace.py
@@ -0,0 +1,828 @@
+#!/usr/bin/env python3
+"""Flight Recorder Trace Analyzer
+
+This script primarily merges data from the individual flight recorder buffers of each rank in a
+PyTorch Distributed program into a flattened database format that can be used for further analysis.
+
+However, as part of the merging process, it is necessary to perform some analysis in order to match operators
+on one rank with corresponding operators on other ranks and register them as one 'collective' entry. During this
+process, a significant amount of useful information can already be extracted, such as where the first mismatch occurs
+in cases of desync (when not all ranks issue a compatible collective in a particular process group).
+
+
+Not Yet Implemented
+- TODO: tracebacks aren't implemented
+
+Known Issues
+- Flight Recorder buffer sequence_id information is not sufficient to match collectives and coalesced collectives
+  unless we have the trace data from the beginning of the program. To enable confident analysis of trace buffers that
+  do not start from zero (and to simplify the script's matching logic) we need to add more information to the recorder.
+- Currently, the script omits checking the 'status' of collectives. We can look for the first 'non completed'
+  collective easily enough and report that.
+
+Usage
+python fr_trace.py -d <directory with dumps> [-o <output file>]
+
+- Omitting the optional output file will still print analysis information to stdout.
+- The output file is a pickle of the flat DB, which may change in format in the future.
+"""
+
+import argparse
+import ast
+import os
+import pickle
+import sys
+from typing import (  # type: ignore[attr-defined]
+    _eval_type,
+    Any,
+    Dict,
+    Generic,
+    List,
+    NamedTuple,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
+
+from tabulate import tabulate  # type: ignore[import-untyped]
+
+
+T = TypeVar("T", bound=NamedTuple)
+
+
+class Ref(Generic[T]):
+    pass
+
+
+class TypeInfo(NamedTuple):
+    name: str
+    fields: List[Tuple[str, Type]]  # type: ignore[type-arg]
+
+    @classmethod
+    def from_type(cls, c: T) -> "TypeInfo":
+        if hasattr(c, "__name__"):
+            name = c.__name__
+        else:
+            name = str(c)
+        return cls(
+            name,
+            [(f, _eval_type(c.__annotations__[f], globals(), {})) for f in c._fields],
+        )
+
+
+"""
+Schema for flat DB
+
+TODO schemas not yet implemented
+# threads as recorded at termination of process
+Threads
+  id: int
+  traceback_id: int
+  process_id: int
+
+Process:
+  id: int  # Same as world group's RANK
+  pid: int
+  hostname: str
+
+NCCLOp:
+  # nccl op implementation details (sends/recv)
+  id: int
+  nccl_call_id: int
+
+"""
+
+
+class Group(NamedTuple):
+    id: int
+    desc: str
+    size: int
+
+
+class Membership(NamedTuple):
+    group_id: Ref[Group]
+    global_rank: int
+
+
+class Traceback(NamedTuple):
+    id: int
+    frames: str
+
+
+class Collective(NamedTuple):
+    id: int
+    group_id: Ref[Group]
+
+
+class NCCLCall(NamedTuple):
+    id: int
+    collective_id: Ref[Collective]
+    group_id: Ref[Group]
+    global_rank: int  # technically Ref[Process] once we have it
+    traceback_id: Ref[Traceback]
+    collective_type: str
+    sizes: List[List[int]]
+
+
+class Database(NamedTuple):
+    groups: List[Group]
+    memberships: List[Membership]
+    tracebacks: List[Traceback]
+    collectives: List[Collective]
+    ncclcalls: List[NCCLCall]
+
+
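+# Illustrative example (hypothetical values): a two-rank job with a single
+# default process group would flatten to rows like
+#   Group(id="0", desc="default_pg", size=2)
+#   Membership(group_id="0", global_rank=0)
+#   Membership(group_id="0", global_rank=1)
+
+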
+types = [
+    TypeInfo.from_type(t)  # type: ignore[type-var]
+    for t in globals().values()
+    if (
+        isinstance(t, type)
+        and issubclass(t, tuple)
+        and hasattr(t, "_fields")
+        and t is not TypeInfo
+    )
+]
+
+"""
+Stacktrace cache
+TODO
+"""
+
+
+"""
+Collective Matching logic
+"""
+COLLECTIVES = {
+    "broadcast",
+    "all_gather",
+    "all_reduce",
+    "_all_gather_base",
+    "all_gather_into_tensor_coalesced",
+    "reduce_scatter_tensor_coalesced",
+    "_reduce_scatter_base",
+    "gather",
+    "scatter",
+}
+
+P2P = {
+    "send",
+    "recv",
+}
+
+
+class Op:
+    """Parses relevant info about an operation out of an 'event' dict
+
+    examples of supported `profiling_name`s:
+        nccl:broadcast
+        nccl:send 1->2
+        nccl:recv 3<-0
+    """
+
+    def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[int]]):
+        profiling_name = event["profiling_name"]
+        nccl, name = profiling_name.split(":")
+        assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'"
+        parts = name.split(" ")
+        type = parts[0]
+        meta = parts[1] if len(parts) == 2 else None
+        self.state = event["state"]
+
+        self.pg_name, _ = event["process_group"]
+
+        assert type in COLLECTIVES | P2P | {
+            "coalesced"
+        }, f"{type} is not a supported operation"
+        self.type = type
+        if type == "send":
+            s, d = meta.split("->")
+            self._src, self._dst = int(s), int(d)
+        elif type == "recv":
+            d, s = meta.split("<-")
+            self._dst, self._src = int(d), int(s)
+        else:
+            # collectives and 'coalesced' markers have no src/dst
+            self._src, self._dst = None, None
+        self._init_global_src_dst(memberships[self.pg_name])
+
+        if type in P2P | COLLECTIVES:
+            self.input_sizes = event["input_sizes"]
+            self.output_sizes = event["output_sizes"]
+        else:
+            self.input_sizes, self.output_sizes = None, None
+        self.collective_seq_id = event["collective_seq_id"]
+        self.p2p_seq_id = event["p2p_seq_id"]
+
+    def _init_global_src_dst(self, pg_ranks: Set[int]) -> None:
+        # map local (per-PG) ranks to global ranks via the sorted membership list
+        sorted_ranks = sorted(pg_ranks)
+        self._src_g = sorted_ranks[self._src] if self._src is not None else None
+        self._dst_g = sorted_ranks[self._dst] if self._dst is not None else None
+
+    @property
+    def src(self) -> int:
+        assert self.type in P2P, "can't get src of non-p2p op"
+        return self._src
+
+    @property
+    def dst(self) -> int:
+        assert self.type in P2P, "can't get dst of non-p2p op"
+        return self._dst
+
+    def __repr__(self) -> str:
+        if self.type in P2P:
+            return (
+                f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes})"
+            )
+        return f"{self.type}(input_sizes={self.input_sizes}, {self.state})"
+
+    def match(self, other: "Op") -> bool:
+        # TODO: I think this can validly not match,
+        # e.g. if one PG was used for p2p ops between only some of the peers?
+        # if self.seq_id != other.seq_id:
+        #     return False
+
+        if self.type == "send":
+            return bool(
+                other.type == "recv"
+                and self.src == other.src
+                and self.dst == other.dst
+                and self.input_sizes == other.output_sizes
+            )
+        elif self.type == "recv":
+            return bool(
+                other.type == "send"
+                and self.src == other.src
+                and self.dst == other.dst
+                and self.output_sizes == other.input_sizes
+            )
+        elif self.type in COLLECTIVES:
+            return bool(
+                self.type == other.type and self.input_sizes == other.input_sizes
+            )
+            # TODO(whc) - output sizes don't have to match for e.g. gather; not sure if they ever have to match?
+            # and self.output_sizes == other.output_sizes)
+        elif self.type == "coalesced":
+            return bool(other.type == "coalesced")
+        return False
+
+
+def match_one_event(
+    event_a: Dict[Any, Any],
+    event_b: Dict[Any, Any],
+    memberships: Dict[str, Set[int]],
+) -> bool:
+    op_a = Op(event_a, memberships)
+    op_b = Op(event_b, memberships)
+    return op_a.match(op_b)
+
+
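+# Illustrative example (hypothetical events): a send and recv pair up when the
+# local src/dst ranks agree and the send's input_sizes equal the recv's
+# output_sizes:
+#   rank A: profiling_name="nccl:send 0->1", input_sizes=[[1024]]
+#   rank B: profiling_name="nccl:recv 1<-0", output_sizes=[[1024]]
+
+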
+def match_coalesced_groups(
+    all_rank_events: Dict[Any, Any],
+    group_size: int,
+    groups: Dict[str, Group],
+    memberships: Dict[str, Set[int]],
+) -> bool:
+    """
+    all_rank_events: {
+        rank: [
+            (idx, event_dict)
+        ]
+    }
+
+    Note: it is possible for event dicts in a coalesced group to be asymmetric.
+    e.g. the following event lists form a valid coalescing group
+        events0 [send:1]
+        events1 [recv:0, send:2]
+        events2 [recv:1]
+
+    Rule 1: all ops should find a match
+    Rule 2: relative ordering of sends and recvs in one event list can be arbitrary
+        e.g.
+        events1 [recv:0, send:2] -> okay
+        events1 [send:2, recv:0] -> also okay
+    Rule 3: sends to the same dest or recvs from the same src should be in a consistent order
+        e.g.
+        rank0 [send:1 (100B), send:1 (1000B)]
+        rank1 [recv:0 (1000B), recv:0 (100B)] -> not okay
+    """
+    all_ops = {
+        rank: [Op(e, memberships) for i, e in all_rank_events[rank]]
+        for rank in all_rank_events
+    }
+
+    def visualize_ops(match: bool) -> None:
+        # rebuild all_ops from the original events, since the outer lists are
+        # consumed during matching
+        all_ops = {
+            rank: [Op(e, memberships) for i, e in all_rank_events[rank]]
+            for rank in all_rank_events
+        }
+
+        i = 0
+        row = []
+        progress = True
+        table = []
+        while progress:
+            progress = False
+            for r in all_ops:
+                if len(all_ops[r]) > i:
+                    _, event = all_rank_events[r][i]
+                    row.append(Op(event, memberships))
+                    progress = True
+                else:
+                    row.append(None)  # type: ignore[arg-type]
+            if progress:
+                table.append(row)
+            row = []
+            i += 1
+        title = "Match" if match else "MISMATCH"
+        print(f"{title}\n", tabulate(table))
+
+    # TODO: can't verify seq_id because there might have been valid seq deltas between ranks even within a pg.
+    for op_list in all_ops.values():
+        if not op_list:
+            # print("TODO: not sure if it's valid for only some ranks in a PG to participate in a coalesced op?")
+            return False
+        assert op_list[-1].type == "coalesced"
+        op_list.pop(-1)
+
+    while all_ops:
+        first_rank = next(iter(all_ops))
+        my_ops = all_ops[first_rank]
+
+        if len(all_ops[first_rank]) == 0:
+            all_ops.pop(first_rank)
+            continue
+
+        # let's match the first collective! we need to know which ranks are involved, and ensure that this same
+        # collective is also the first one on those ranks within that group
+        op = my_ops[0]
+        match_idx = -1
+        if op.type in P2P:
+            dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
+            peer_ops = all_ops[dst_global_rank]
+            for i, other in enumerate(peer_ops):
+                if op.match(other):
+                    match_idx = i
+                    break
+                elif op.dst == other.src:
+                    # Rule 3
+                    break
+                else:
+                    # Rule 1
+                    continue
+        else:
+            raise NotImplementedError("coalesced collective ops")
+        if match_idx >= 0:
+            my_ops.pop(0)
+            peer_ops.pop(match_idx)
+        else:
+            visualize_ops(False)
+            return False
+
+    visualize_ops(True)
+    return True
+
+
+"""
+Flat DB builder
+"""
+
+
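+# Illustrative pg_config shape (hypothetical values); note that `ranks` can
+# arrive as a stringified list, which build_groups_memberships parses below:
+#   {
+#       0: {"0": {"desc": "default_pg", "ranks": "[0, 1]"}},
+#       1: {"0": {"desc": "default_pg", "ranks": "[0, 1]"}},
+#   }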
+def build_groups_memberships(
+    pg_config: Any,
+) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Set[int]]]:
+    """
+    pg_config: {
+        global_rank: {
+            pg_id: {
+                "desc": desc,
+                "ranks": ranks,
+            }
+        }
+    }
+
+    `pg_id` is a system-generated id, but depending on the mode of PG creation it could be a globally incrementing int
+    or a hash of the ranks. See `_process_group_name` in distributed_c10d.py.
+    `desc` is provided by the user (optionally) and should be 'meaningful' (e.g. TP/PP/DP group)
+    `ranks` is a list of the 'global ranks' that are members of the PG.
+
+    (pg_id, desc, ranks) tuples are appended lazily to the flight buffer when `getNCCLComm` is called on a PG and
+    the `enabled_` flag is true for that PG.
+    - the order of calling (init_process_group, new_group, etc) does not affect the order of the tuples in the list
+
+    Returns a groups table and a memberships table, where each row is a Group or Membership namedtuple,
+    along with dict versions of both keyed by pg_id for fast lookup and cross-rank validation.
+    """
+    # flat lists for return
+    groups = []
+    memberships = []
+
+    # dicts for faster cross-rank validation
+    _groups = {}
+    _memberships = {}
+    for global_rank in pg_config:
+        for pg_id in pg_config[global_rank]:
+            desc = pg_config[global_rank][pg_id]["desc"]
+            ranks = pg_config[global_rank][pg_id]["ranks"]
+            if isinstance(ranks, str):
+                # TODO Bug in FR data format? ranks is a stringified list '[0, 1,...]'
+                ranks = ast.literal_eval(ranks)
+
+            if pg_id not in _groups:
+                groups.append(Group(id=pg_id, desc=desc, size=len(ranks)))
+                for rank in ranks:
+                    memberships.append(Membership(group_id=pg_id, global_rank=rank))
+                _groups[pg_id] = groups[-1]
+                _memberships[pg_id] = set(ranks)
+            else:
+                # validation across ranks
+                assert (
+                    _groups[pg_id].desc == desc
+                ), f"mismatch in desc {_groups[pg_id].desc} vs {desc}"
+                assert _memberships[pg_id] == set(
+                    ranks
+                ), f"mismatch in membership {_memberships[pg_id]} vs {set(ranks)}"
+    return groups, _groups, memberships, _memberships
+
+
+def build_nccl_call(
+    entry: Dict[Any, Any],
+    id: int,
+    collective_id: Any,
+    group_id: int,
+    global_rank: Any,
+) -> NCCLCall:
+    return NCCLCall(
+        id=id,
+        collective_id=collective_id,
+        group_id=group_id,  # type: ignore[arg-type]
+        global_rank=global_rank,
+        traceback_id=0,  # type: ignore[arg-type]
+        collective_type=entry["profiling_name"],
+        sizes=entry["input_sizes"],
+    )
+
+
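+# Illustrative example (hypothetical entries): on one rank, two sends sharing
+# collective_seq_id=7 followed by the synthetic 'coalesced' entry form one
+# coalesced group that find_coalesced_group returns:
+#   nccl:send 0->1    collective_seq_id=7
+#   nccl:send 0->2    collective_seq_id=7
+#   nccl:coalesced    collective_seq_id=7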
+ """ + found = [] + collective_seq_id = None + for i, e in enumerate(entries): + if e["process_group"][0] != pg_name: + continue + elif collective_seq_id is None: + collective_seq_id = e["collective_seq_id"] + found.append((i, e)) + elif e["collective_seq_id"] == collective_seq_id: + found.append((i, e)) + else: + break + + if len(found) > 1: + assert found[-1][1]["profiling_name"] == "nccl:coalesced" + return found + return [] + + +def just_print_entries( + all_entries: Dict[int, List[Dict[str, Any]]], + _groups: Dict[str, Group], + _memberships: Dict[str, List[Membership]], +) -> None: + rows = [] + ranks = sorted(all_entries.keys()) + headers = [f"Rank {rank}" for rank in ranks] + progress = True + while progress: + progress = False + row = [] + for rank in ranks: + if len(all_entries[rank]) == 0: + row.append("") + else: + entry = all_entries[rank].pop(0) + row.append(str(Op(entry, _memberships))) + progress = True + if progress: + rows.append(row) + + print(tabulate(rows, headers=headers)) # type: ignore[operator] + + +def build_collectives( + all_entries: Dict[int, List[Dict[str, Any]]], + _groups: Dict[str, Group], + _memberships: Dict[str, List[Membership]], +) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]: + """ + groups, memberships are the non-flat dicts that are indexable + all_entries is a raw dict from the original dumps: + + all_entries: { + global_rank: [ + { + record_id: ordered id of the event in the trace buffer + pg_id: ProcessGroupNCCL::uid_ + *note: `pg_id` corresponds to nothing in groups table + process_group: (pg_name, desc) + *note: `pg_name`, `desc` corresponds to `pg_id`, `desc` in groups table + collective_seq_id: ordered id for collective operations and coalesced group operations + p2p_seq_id: ordered id for point-to-point operations + op_id: ordered id including individual ops inside coalescing group + profiling_name: descriptive name of the operation + 'time_created_ns', + 'input_sizes', + 'output_sizes', + 'state', + 'time_discovered_started_ns', + 'time_discovered_completed_ns', + 'retired', + 'frames', + } + ] + } + """ + tracebacks: List[Traceback] = [] + + collectives: List[Collective] = [] + nccl_calls: List[NCCLCall] = [] + + # once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect + # instead, just record the remaining ops as NCCLCalls + mismatch = {_groups[g].id: 0 for g in _groups} + MISMATCH_TAIL = 10 + """ + - it doesn't matter what order I put collectives/ncclops into their table. we can later on re-sort it by start time + - there could be multiple options for the "first" collective to pair up (rank 0,1 might do a bcast while rank 2,3 do a bcast) + - within a group, the first collective must be the same on all ranks in the group, then it can be marked as a + collective and removed + """ + while all_entries: + # we greedily match collectives, starting arbitrarily with the trace from the first rank + # later, if we exhaust the first rank, we continue with the next 'first rank' + rank_iter = iter(all_entries) + first_rank = next(rank_iter) + other_ranks = list(rank_iter) + + if len(all_entries[first_rank]) == 0: + all_entries.pop(first_rank) + continue + + # lets match the first collective! 
+def build_collectives(
+    all_entries: Dict[int, List[Dict[str, Any]]],
+    _groups: Dict[str, Group],
+    _memberships: Dict[str, Set[int]],
+) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]:
+    """
+    groups, memberships are the non-flat dicts that are indexable
+    all_entries is a raw dict from the original dumps:
+
+    all_entries: {
+        global_rank: [
+            {
+                record_id: ordered id of the event in the trace buffer
+                pg_id: ProcessGroupNCCL::uid_
+                    *note: `pg_id` corresponds to nothing in groups table
+                process_group: (pg_name, desc)
+                    *note: `pg_name`, `desc` correspond to `pg_id`, `desc` in groups table
+                collective_seq_id: ordered id for collective operations and coalesced group operations
+                p2p_seq_id: ordered id for point-to-point operations
+                op_id: ordered id including individual ops inside coalescing group
+                profiling_name: descriptive name of the operation
+                'time_created_ns',
+                'input_sizes',
+                'output_sizes',
+                'state',
+                'time_discovered_started_ns',
+                'time_discovered_completed_ns',
+                'retired',
+                'frames',
+            }
+        ]
+    }
+    """
+    tracebacks: List[Traceback] = []
+
+    collectives: List[Collective] = []
+    nccl_calls: List[NCCLCall] = []
+
+    # once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect
+    # instead, just record the remaining ops as NCCLCalls
+    mismatch = {_groups[g].id: 0 for g in _groups}
+    MISMATCH_TAIL = 10
+    """
+    - it doesn't matter what order I put collectives/ncclops into their table. we can later on re-sort it by start time
+    - there could be multiple options for the "first" collective to pair up (rank 0,1 might do a bcast while rank 2,3 do a bcast)
+    - within a group, the first collective must be the same on all ranks in the group, then it can be marked as a
+      collective and removed
+    """
+    while all_entries:
+        # we greedily match collectives, starting arbitrarily with the trace from the first rank
+        # later, if we exhaust the first rank, we continue with the next 'first rank'
+        rank_iter = iter(all_entries)
+        first_rank = next(rank_iter)
+        other_ranks = list(rank_iter)
+
+        if len(all_entries[first_rank]) == 0:
+            all_entries.pop(first_rank)
+            continue
+
+        # let's match the first collective! we need to know which ranks are involved, and ensure that this same
+        # collective is also the first one on those ranks within that group
+        entries = all_entries[first_rank]
+        pg_name, desc = entries[0]["process_group"]
+        profiling_name = entries[0]["profiling_name"]
+        expected_ranks = set(_memberships[pg_name])
+        found_ranks = {first_rank}
+        found_idx = {}
+
+        if find_coalesced_group(pg_name, entries):
+            expected_ranks.add(first_rank)
+            done_ranks = set()
+            all_coalesced_entries = {}
+            while expected_ranks:
+                curr = expected_ranks.pop()
+                done_ranks.add(curr)
+                grp = (
+                    find_coalesced_group(pg_name, all_entries[curr])  # type: ignore[index]
+                    if curr in all_entries  # type: ignore[comparison-overlap]
+                    else []
+                )
+                all_coalesced_entries[curr] = grp
+                for index, entry in grp:
+                    op = Op(entry, _memberships)
+                    peer = None
+                    if op.type == "send":
+                        assert op._src_g == curr, (op._src_g, curr)
+                        peer = op._dst_g
+                    elif op.type == "recv":
+                        assert op._dst_g == curr, (op._dst_g, curr)
+                        peer = op._src_g
+                    if peer and peer not in done_ranks:
+                        expected_ranks.add(peer)
+
+            match = match_coalesced_groups(
+                all_coalesced_entries,
+                group_size=_groups[pg_name].size,
+                groups=_groups,
+                memberships=_memberships,
+            )
+
+            if match and mismatch[pg_name] == 0:
+                collectives.append(Collective(id=len(collectives), group_id=pg_name))
+            else:
+                mismatch[pg_name] += 1
+
+            for r in all_coalesced_entries:
+                reversed_calls = []
+                for i, _ in reversed(all_coalesced_entries[r]):
+                    reversed_calls.append(
+                        build_nccl_call(
+                            all_entries[r].pop(i),  # type: ignore[index]
+                            id=len(nccl_calls),
+                            collective_id=collectives[-1].id if match else None,
+                            group_id=pg_name,
+                            global_rank=r,
+                        )
+                    )
+                nccl_calls.extend(reversed(reversed_calls))
+
+        else:
+            for o in expected_ranks.intersection(set(other_ranks)):
+                for i, e in enumerate(all_entries[o]):  # type: ignore[index]
+                    # step over ops from other PGs
+                    if e["process_group"] == (pg_name, desc):
+                        if (
+                            match_one_event(entries[0], e, _memberships)
+                            and mismatch[pg_name] == 0
+                        ):
+                            found_ranks.add(o)
+                            found_idx[o] = i
+                        else:
+                            # we found a mismatch. what do we do with that?
+                            mismatch[pg_name] += 1
+                            print(
+                                f"Mismatched collective on rank {o} for group {pg_name}:{desc} collective {profiling_name}"
+                            )
+                        break
+
+            # at this point there are 3 possibilities
+            # 1. we found a match on all the ranks that are members of the group
+            #    -> we create a Collective and remove the individual entries from their original lists
+            if found_ranks == expected_ranks and mismatch[pg_name] == 0:
+                collectives.append(Collective(id=len(collectives), group_id=pg_name))
+                for r in found_ranks:
+                    i = found_idx[r] if r != first_rank else 0
+                    nccl_calls.append(
+                        build_nccl_call(
+                            all_entries[r].pop(i),  # type: ignore[index]
+                            id=len(nccl_calls),
+                            collective_id=collectives[-1].id,
+                            group_id=pg_name,
+                            global_rank=r,
+                        )
+                    )
+
+            # 2. we found a partial match but some ranks are missing
+            # 3. we found no match
+            #    -> since it's not a complete collective, no entry goes into collectives but we still record a nccl call
+            #    TODO: should there be a way to mark 'mismatches'?
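+            # Note: cases 2 and 3 are handled identically below; only the first
+            # rank's entry is consumed, so the other ranks' entries remain and
+            # are reconsidered on later iterations of the outer while loop.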
+            else:
+                print("appending a non-matching collective")
+                nccl_calls.append(
+                    build_nccl_call(
+                        all_entries[first_rank].pop(0),
+                        id=len(nccl_calls),
+                        collective_id=None,
+                        group_id=pg_name,
+                        global_rank=first_rank,
+                    )
+                )
+
+        if mismatch[pg_name] > MISMATCH_TAIL:
+            print(f"Too many mismatches for process_group {pg_name}:{desc}, aborting")
+            sys.exit(-1)
+
+    return tracebacks, collectives, nccl_calls
+
+
+def check_no_missing_dump_files(
+    entries: Dict[int, Any], memberships: List[Membership]
+) -> None:
+    all_ranks = set()
+    for membership in memberships:
+        all_ranks.add(membership.global_rank)
+    dumps_ranks = set(entries.keys())
+    assert (
+        dumps_ranks == all_ranks
+    ), f"Missing dump files from ranks {all_ranks - dumps_ranks}"
+
+
+def check_version(versions: Dict[int, Any]) -> None:
+    for rank, version in versions.items():  # noqa: PERF102
+        major, minor = map(int, version.split("."))
+        # assert major == 2, f"Rank {rank} unsupported version {version}"
+        # assert minor >= 0, f"Rank {rank} unsupported version {version}"
+
+
+def check_trace_from_beginning(entries: Dict[int, Any]) -> bool:
+    for rank in entries:
+        first_record_id = entries[rank][0]["record_id"]
+        # TODO: add more sequence information such that analysis can proceed even without a complete buffer
+
+        # assert first_record_id == 0, f"Rank {rank} trace does not start at time 0 (first record is {first_record_id})."
+        if first_record_id != 0:
+            print(
+                f"Rank {rank} trace does not start at time 0 (first record is {first_record_id})."
+            )
+            return False
+    return True
+
+
+def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Database:
+    # temporary state used for building database
+    entries = {}
+    pg_config = {}
+    version = {}
+    for dump in details.values():
+        rank = dump["rank"]
+        entries[rank] = dump["entries"]
+        version[rank] = dump["version"]
+        pg_config[rank] = dump["pg_config"]
+
+    check_version(version)
+    check_trace_from_beginning(entries)
+
+    # flattened database
+    groups, _groups, memberships, _memberships = build_groups_memberships(pg_config)
+    print("built groups, memberships")
+
+    check_no_missing_dump_files(entries, memberships)
+
+    if args.just_print_entries:
+        just_print_entries(entries, _groups, _memberships)
+        sys.exit(0)
+
+    tracebacks, collectives, nccl_calls = build_collectives(
+        entries, _groups, _memberships
+    )
+    print("built collectives, nccl_calls")
+    if args.verbose:
+        print("Groups\n", tabulate(groups, headers=Group._fields))
+        print("Memberships\n", tabulate(memberships, headers=Membership._fields))
+        print("Collectives\n", tabulate(collectives, headers=Collective._fields))
+        print("NCCLCalls\n", tabulate(nccl_calls, headers=NCCLCall._fields))
+    db = Database(
+        tracebacks=tracebacks,
+        collectives=collectives,
+        ncclcalls=nccl_calls,
+        groups=groups,
+        memberships=memberships,
+    )
+    return db
+
+
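+# read_dump expects dump files named <prefix><rank>, e.g. with the default
+# prefix of "rank_" (hypothetical listing):
+#   dumps/rank_0, dumps/rank_1, ...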
"pg_config": pg_config, + } + + +def read_dir(prefix: str, folder: str) -> Dict[Any, Any]: # TODO; fix types + import gc + import time + + gc.disable() + details = {} + t0 = time.time() + for root, _, files in os.walk(folder): + for f in files: + ta = time.time() + details[f] = read_dump(prefix, os.path.join(root, f)) + tb = time.time() + # print(f"read file {f} in {tb - ta}s") + print(f"loaded {len(files)} files in {tb - t0}s") + return details + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("-d", "--dir", help="Directory with flight recorder dumps") + parser.add_argument("-o", "--output", default=None) + parser.add_argument( + "-p", + "--prefix", + help="prefix to strip such that rank can be extracted", + default="rank_", + ) + parser.add_argument("-j", "--just_print_entries", action="store_true") + parser.add_argument("-v", "--verbose", action="store_true") + args = parser.parse_args() + + details = read_dir(args.prefix, args.dir) + db = build_db(details, args) + if args.output: + with open(args.output, "wb") as f: + pickle.dump((types, db), f) + + +if __name__ == "__main__": + main()