Compare commits

...

1 Commits

Author SHA1 Message Date
4f6a767b3c transform fr traces for ft
Summary:
- the ranks in the default pg config are local ranks
- however fr trace analysis requires them to be global ranks
- so, when the `--transform-ft` CLI flag is passed, we transform the local ranks to global ranks before the analysis kicks in
2025-10-29 10:54:55 -07:00
3 changed files with 31 additions and 1 deletions

View File

@ -374,6 +374,22 @@ def build_collectives(
return tracebacks, collectives, nccl_calls
def transform_ft(
    details: dict[str, dict[str, Any]], group_world_size: int
) -> dict[str, dict[str, Any]]:
    """Rewrite each dump's ``default_pg`` ranks by adding its replica-group offset.

    For every dump, the offset is the first rank of the replica group that
    the dump's own ``rank`` falls into, i.e.
    ``(rank // group_world_size) * group_world_size``. That offset is added
    to every entry of the ``ranks`` list of each process-group config whose
    ``desc`` is ``"default_pg"``; other process groups are left untouched.

    Args:
        details: Mapping of dump key to dump. Each dump must provide an
            integer ``"rank"`` and a ``"pg_config"`` mapping whose values
            carry ``"desc"`` and ``"ranks"``, where ``"ranks"`` is the
            string form of a list of ints (e.g. ``"[0, 1, 2, 3]"``).
        group_world_size: Number of ranks in one replica group. Must be a
            positive integer.

    Returns:
        The same ``details`` object, updated in place.

    Raises:
        ValueError: If ``group_world_size`` is not positive, or if a
            ``ranks`` string is not a plain Python literal.
        SyntaxError: If a ``ranks`` string cannot be parsed at all.
    """
    # Function-scope import so the block does not depend on the module's
    # import list.
    import ast

    if group_world_size <= 0:
        raise ValueError(
            f"group_world_size must be a positive integer, got {group_world_size}"
        )

    for dump in details.values():
        # Offset of this dump's replica group; invariant across its PGs.
        first_rank = (dump["rank"] // group_world_size) * group_world_size
        for pg_config in dump["pg_config"].values():
            if pg_config["desc"] != "default_pg":
                continue
            # literal_eval only accepts literals, so a crafted or corrupted
            # trace file cannot execute code the way eval() would allow.
            local_ranks = ast.literal_eval(pg_config["ranks"])
            # Keep the stringified-list format the ranks were stored in.
            pg_config["ranks"] = str([r + first_rank for r in local_ranks])
    return details
def build_db(
details: dict[str, dict[str, Any]], args: argparse.Namespace, version: str
) -> Database:

View File

@ -74,6 +74,17 @@ class JobConfig:
default=10,
help="Maximum number of mismatches we print (from earliest).",
)
self.parser.add_argument(
"--transform-ft",
action="store_true",
help="Transform PG config to use global ranks to analyze traces produced by torchft",
)
self.parser.add_argument(
"--group-world-size",
type=int,
default=None,
help="The number of ranks in 1 torchft replica group. Must be specified if --transform-ft is True",
)
def parse_args(
self: "JobConfig", args: Optional[Sequence[str]]

View File

@ -32,7 +32,7 @@ import pickle
from collections.abc import Sequence
from typing import Optional
from tools.flight_recorder.components.builder import build_db
from tools.flight_recorder.components.builder import build_db, transform_ft
from tools.flight_recorder.components.config_manager import JobConfig
from tools.flight_recorder.components.loader import read_dir
from tools.flight_recorder.components.types import types
@ -46,6 +46,9 @@ def main(args: Optional[Sequence[str]] = None) -> None:
assert args.trace_dir, "Trace directory trace_dir is required"
# pyrefly: ignore [bad-argument-type]
details, version = read_dir(args)
if args.transform_ft:
assert args.group_world_size, "World size is required for transform_ft"
details = transform_ft(details, args.group_world_size)
# pyrefly: ignore [bad-argument-type]
db = build_db(details, args, version)
# pyrefly: ignore [missing-attribute]