[2/N] Refactor FR script - add a loader module (#133929)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133929
Approved by: https://github.com/c-p-i-o
ghstack dependencies: #133927
This commit is contained in:
fduwjj
2024-08-19 16:42:59 -07:00
committed by PyTorch MergeBot
parent 2bd02e0c82
commit 36376efd06
2 changed files with 50 additions and 43 deletions

View File

@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import gc
import os
import pickle
import time
from typing import Any, Dict, List, Union
def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]:
basename = os.path.basename(filename)
assert (
basename.find(prefix) == 0
), f"args.prefix ({prefix}) must match the beginning of each filename ({basename})"
rank = int(basename[len(prefix) :])
host_name = f"host_rank{rank}"
with open(filename, "rb") as infile:
dump = pickle.load(infile)
entries = dump["entries"]
version = dump["version"]
pg_config = dump["pg_config"]
return {
"host_name": host_name,
"rank": rank,
"entries": entries,
"version": version,
"pg_config": pg_config,
}
def read_dir(prefix: str, folder: str) -> Dict[str, Dict[str, Any]]:
gc.disable()
details = {}
t0 = time.time()
for root, _, files in os.walk(folder):
for f in files:
ta = time.time()
details[f] = read_dump(prefix, os.path.join(root, f))
tb = time.time()
# print(f"read file {f} in {tb - ta}s")
print(f"loaded {len(files)} files in {tb - t0}s")
return details

View File

@ -30,7 +30,6 @@ python fr_trace.py -d <dump dir containing trace files> [-o <output file>]
import argparse
import ast
import math
import os
import pickle
import sys
from enum import Enum
@ -45,10 +44,10 @@ from typing import ( # type: ignore[attr-defined]
Tuple,
Type,
TypeVar,
Union,
)
from tools.flight_recorder.components.config_manager import JobConfig
from tools.flight_recorder.components.loader import read_dir
try:
@ -926,47 +925,6 @@ def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Da
return db
def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]:
basename = os.path.basename(filename)
assert (
basename.find(prefix) == 0
), f"args.prefix ({prefix}) must match the beginning of each filename ({basename})"
rank = int(basename[len(prefix) :])
host_name = f"host_rank{rank}"
with open(filename, "rb") as infile:
dump = pickle.load(infile)
entries = dump["entries"]
version = dump["version"]
pg_config = dump["pg_config"]
return {
"host_name": host_name,
"rank": rank,
"entries": entries,
"version": version,
"pg_config": pg_config,
}
def read_dir(prefix: str, folder: str) -> Dict[Any, Any]: # TODO; fix types
import gc
import time
gc.disable()
details = {}
t0 = time.time()
for root, _, files in os.walk(folder):
for f in files:
ta = time.time()
details[f] = read_dump(prefix, os.path.join(root, f))
tb = time.time()
# print(f"read file {f} in {tb - ta}s")
print(f"loaded {len(files)} files in {tb - t0}s")
return details
def main(args: argparse.Namespace) -> None:
details = read_dir(args.prefix, args.dir)
db = build_db(details, args)