Files
Tianhao Huang 14c7358c64 Enable fr_trace to read local traces from multiple hosts. (#159490)
Summary: For training jobs, particularly those from GenAI, NCCL trace dumps are named in the format `<hostname>.pci3_rank_<rank>`. In multi-node training jobs the hostname varies across traces, which the current prefix-matching logic cannot handle.

Test Plan:
Create a local folder `dumps` and several empty files: `host0.pci3_rank_0`, `host0.pci3_rank_1`, `host1.pci3_rank_0`, `host1.pci3_rank_1` inside it. Then run
```
buck2 run fbcode//caffe2/fb/flight_recorder:fr_trace -- trace_dir dumps
```

Before this diff, fr_trace could not locate any trace files and failed with the following assertion error:
```
AssertionError: no files loaded from /home/tianhaoh/dumps with prefix pci3_rank_
```

After this diff, fr_trace locates the trace files and then fails with exceptions like the following
```
    dump = pickle.load(infile)
           ^^^^^^^^^^^^^^^^^^^
EOFError: Ran out of input
```
(since the trace files are fake and empty).

Rollback Plan:

Differential Revision: D79224727

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159490
Approved by: https://github.com/fduwjj
2025-08-06 03:15:34 +00:00

93 lines
2.8 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import gc
import os
import pickle
import re
import time
from collections import defaultdict
from typing import Any, Union
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
# Module-level logger used by the loader functions in this file (e.g. to report
# the inferred prefix and how long loading took).
logger: FlightRecorderLogger = FlightRecorderLogger()
def read_dump(prefix: str, filename: str) -> dict[str, Union[str, int, list[Any]]]:
    """Unpickle a single trace dump and pick out the fields the analyzer uses.

    Args:
        prefix: the leading part of the file's basename; everything after it
            is parsed as the integer rank.
        filename: path of the pickled dump to load.

    Returns:
        A dict with ``host_name`` (synthesized as ``host_rank<rank>``, not
        taken from the file name), ``rank``, and the dump's ``entries``,
        ``version`` and ``pg_config`` values.

    Raises:
        ValueError: if the text after ``prefix`` cannot be parsed by ``int``.
        KeyError: if the dump lacks one of the expected keys.

    NOTE: ``pickle.load`` can execute arbitrary code; only load trusted dumps.
    """
    rank_text = os.path.basename(filename)[len(prefix) :]
    rank = int(rank_text)
    with open(filename, "rb") as dump_file:
        payload = pickle.load(dump_file)
    # Keys are looked up in the same order as before, so a missing key raises
    # the same KeyError.
    return {
        "host_name": f"host_rank{rank}",
        "rank": rank,
        "entries": payload["entries"],
        "version": payload["version"],
        "pg_config": payload["pg_config"],
    }
exp = re.compile(r"([\w\-\_]*?)(\d+)$")
def _determine_prefix(files: list[str]) -> str:
"""If the user doesn't specify a prefix, but does pass a dir full of similarly-prefixed files, we should be able to
infer the common prefix most of the time. But if we can't confidently infer, just fall back to requiring the user
to specify it
"""
possible_prefixes: defaultdict[str, set[int]] = defaultdict(set)
for f in files:
m = exp.search(f)
if m:
p, r = m.groups()
possible_prefixes[p].add(int(r))
if len(possible_prefixes) == 1:
prefix = next(iter(possible_prefixes))
logger.debug("Inferred common prefix %s", prefix)
return prefix
else:
raise ValueError(
"Unable to automatically determine the common prefix for the trace file names. "
"Please specify --prefix argument manually"
)
def read_dir(args: argparse.Namespace) -> tuple[dict[str, dict[str, Any]], str]:
    """Load every trace dump found under ``args.trace_dir``.

    The directory tree is walked recursively with ``os.walk``. A file is loaded
    when its name contains ``prefix``. Everything in the name up to and
    including the first occurrence of the prefix is passed to ``read_dump`` as
    the file's prefix, so names such as ``<hostname>.pci3_rank_<rank>`` from
    different hosts all match. ``read_dump`` parses the remainder as the rank.

    If ``args.prefix`` is None, it is inferred with ``_determine_prefix`` from
    the first visited directory that actually contains files.

    Args:
        args: parsed CLI namespace; only ``trace_dir`` and ``prefix`` are read.

    Returns:
        A tuple ``(details, version)``. ``details`` maps each loaded file's
        basename to the dict returned by ``read_dump``. ``version`` is
        ``str()`` of the ``"version"`` field of the first dump loaded, or
        ``""`` if that field was falsy.

    Raises:
        AssertionError: if ``trace_dir`` is not a directory, or if no file
            matched the prefix.
        ValueError: propagated from ``_determine_prefix`` when the prefix
            cannot be inferred.
    """
    # NOTE(review): GC is disabled here and never re-enabled in this function,
    # presumably to speed up unpickling large dumps. Confirm whether callers
    # rely on it staying off.
    gc.disable()
    prefix = args.prefix
    details: dict[str, dict[str, Any]] = {}
    t0 = time.time()
    version = ""
    filecount = 0
    assert os.path.isdir(args.trace_dir), f"folder {args.trace_dir} does not exist"
    for root, _, files in os.walk(args.trace_dir):
        # os.walk yields the top directory first. If it holds only
        # subdirectories, inferring from its empty file list would always
        # raise ValueError, so wait for a directory that has files.
        if prefix is None and files:
            prefix = _determine_prefix(files)
        for f in files:
            if (offset := f.find(prefix)) == -1:
                continue
            # Keyed by basename: same-named files in different subdirectories
            # overwrite each other, although filecount counts both.
            details[f] = read_dump(f[:offset] + prefix, os.path.join(root, f))
            filecount += 1
            if not version:
                version = str(details[f]["version"])
    tb = time.time()
    assert len(details) > 0, (
        f"no files loaded from {args.trace_dir} with prefix {prefix}"
    )
    logger.debug("loaded %s files in %ss", filecount, tb - t0)
    return details, version