mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
[ModularChecker
] QOL for the modular checker (#41361)
* update * fancy table fancy prints * download to cache folder, never need it everagain * stule * update based on review
This commit is contained in:
@ -103,20 +103,34 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub import logging as huggingface_hub_logging
|
||||
from safetensors.numpy import load_file as safetensors_load
|
||||
from safetensors.numpy import save_file as safetensors_save
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from transformers.utils import logging as transformers_logging
|
||||
|
||||
|
||||
# ANSI color codes for CLI output styling
|
||||
ANSI_RESET = "\033[0m"
|
||||
ANSI_BOLD = "\033[1m"
|
||||
ANSI_HEADER = "\033[1;36m"
|
||||
ANSI_SECTION = "\033[1;35m"
|
||||
ANSI_ROW = "\033[0;37m"
|
||||
ANSI_HIGHLIGHT_TOP = "\033[1;32m"
|
||||
ANSI_HIGHLIGHT_OLD = "\033[1;33m"
|
||||
ANSI_HIGHLIGHT_CANDIDATE = "\033[1;34m"
|
||||
|
||||
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
@ -238,34 +252,40 @@ class CodeSimilarityAnalyzer:
|
||||
|
||||
self.models_root = MODELS_ROOT
|
||||
self.hub_dataset = hub_dataset
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.dtype = "auto"
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
|
||||
self.model = (
|
||||
AutoModel.from_pretrained(
|
||||
EMBEDDING_MODEL,
|
||||
torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32,
|
||||
)
|
||||
.eval()
|
||||
.to(self.device)
|
||||
)
|
||||
self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval()
|
||||
|
||||
self.device = self.model.device
|
||||
self.index_dir: Path | None = None
|
||||
|
||||
# ---------- HUB IO ----------
|
||||
|
||||
def _resolve_index_path(self, filename: str) -> Path:
|
||||
if self.index_dir is None:
|
||||
return Path(filename)
|
||||
return self.index_dir / filename
|
||||
|
||||
def ensure_local_index(self) -> None:
|
||||
"""Download index files from Hub if they don't exist locally."""
|
||||
have_all = Path(EMBEDDINGS_PATH).exists() and Path(INDEX_MAP_PATH).exists() and Path(TOKENS_PATH).exists()
|
||||
if have_all:
|
||||
"""Ensure index files are available locally, preferring Hub cache snapshots."""
|
||||
if self.index_dir is not None and all(
|
||||
(self.index_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH)
|
||||
):
|
||||
return
|
||||
logging.info(f"downloading index from hub: {self.hub_dataset}")
|
||||
for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH):
|
||||
hf_hub_download(
|
||||
repo_id=self.hub_dataset,
|
||||
filename=fname,
|
||||
repo_type="dataset",
|
||||
local_dir=".",
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
workspace_dir = Path.cwd()
|
||||
if all((workspace_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH)):
|
||||
self.index_dir = workspace_dir
|
||||
return
|
||||
|
||||
logging.info(f"downloading index from hub cache: {self.hub_dataset}")
|
||||
snapshot_path = snapshot_download(repo_id=self.hub_dataset, repo_type="dataset")
|
||||
snapshot_dir = Path(snapshot_path)
|
||||
missing = [
|
||||
fname for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH) if not (snapshot_dir / fname).exists()
|
||||
]
|
||||
if missing:
|
||||
raise FileNotFoundError("Missing expected files in Hub snapshot: " + ", ".join(missing))
|
||||
self.index_dir = snapshot_dir
|
||||
|
||||
def push_index_to_hub(self) -> None:
|
||||
"""Upload index files to the Hub dataset repository."""
|
||||
@ -284,7 +304,7 @@ class CodeSimilarityAnalyzer:
|
||||
|
||||
def _extract_definitions(
|
||||
self, file_path: Path, relative_to: Path | None = None, model_hint: str | None = None
|
||||
) -> tuple[dict[str, str], dict[str, str], dict[str, list[str]]]:
|
||||
) -> tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]:
|
||||
"""
|
||||
Extract class and function definitions from a Python file.
|
||||
|
||||
@ -294,14 +314,16 @@ class CodeSimilarityAnalyzer:
|
||||
model_hint (`str` or `None`): Model name hint for sanitization.
|
||||
|
||||
Returns:
|
||||
`tuple[dict[str, str], dict[str, str], dict[str, list[str]]]`: A tuple containing:
|
||||
`tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]`: A tuple containing:
|
||||
- definitions_raw: Mapping of identifiers to raw source code
|
||||
- definitions_sanitized: Mapping of identifiers to sanitized source code
|
||||
- definitions_tokens: Mapping of identifiers to sorted token lists
|
||||
- definitions_kind: Mapping of identifiers to either "class" or "function"
|
||||
"""
|
||||
definitions_raw = {}
|
||||
definitions_sanitized = {}
|
||||
definitions_tokens = {}
|
||||
definitions_kind = {}
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
lines = source.splitlines()
|
||||
tree = ast.parse(source)
|
||||
@ -322,7 +344,11 @@ class CodeSimilarityAnalyzer:
|
||||
sanitized = _sanitize_for_embedding(segment, model_hint, node.name)
|
||||
definitions_sanitized[identifier] = sanitized
|
||||
definitions_tokens[identifier] = sorted(_tokenize(sanitized))
|
||||
return definitions_raw, definitions_sanitized, definitions_tokens
|
||||
if isinstance(node, ast.ClassDef):
|
||||
definitions_kind[identifier] = "class"
|
||||
else:
|
||||
definitions_kind[identifier] = "function"
|
||||
return definitions_raw, definitions_sanitized, definitions_tokens, definitions_kind
|
||||
|
||||
def _infer_model_from_relative_path(self, relative_path: Path) -> str | None:
|
||||
try:
|
||||
@ -400,9 +426,12 @@ class CodeSimilarityAnalyzer:
|
||||
|
||||
for file_path in tqdm(files, desc="parse", leave=False):
|
||||
model_hint = self._infer_model_from_relative_path(file_path)
|
||||
definitions_raw, definitions_sanitized, definitions_tokens = self._extract_definitions(
|
||||
file_path, self.models_root, model_hint
|
||||
)
|
||||
(
|
||||
_,
|
||||
definitions_sanitized,
|
||||
definitions_tokens,
|
||||
_,
|
||||
) = self._extract_definitions(file_path, self.models_root, model_hint)
|
||||
for identifier in definitions_sanitized.keys():
|
||||
identifiers.append(identifier)
|
||||
sanitized_sources.append(definitions_sanitized[identifier])
|
||||
@ -418,6 +447,8 @@ class CodeSimilarityAnalyzer:
|
||||
with open(TOKENS_PATH, "w", encoding="utf-8") as file:
|
||||
json.dump(tokens_map, file)
|
||||
|
||||
self.index_dir = Path.cwd()
|
||||
|
||||
def _topk_embedding(
|
||||
self,
|
||||
query_embedding_row: np.ndarray,
|
||||
@ -439,7 +470,7 @@ class CodeSimilarityAnalyzer:
|
||||
continue
|
||||
if self_model_normalized and _normalize(parent_model) == self_model_normalized:
|
||||
continue
|
||||
output.append((f"{parent_model}::{match_name}", float(similarities[match_id])))
|
||||
output.append((identifier, float(similarities[match_id])))
|
||||
if len(output) >= k:
|
||||
break
|
||||
return output
|
||||
@ -480,12 +511,12 @@ class CodeSimilarityAnalyzer:
|
||||
continue
|
||||
score = len(query_tokens & tokens) / len(query_tokens | tokens)
|
||||
if score > 0:
|
||||
scores.append((f"{parent_model}::{match_name}", score))
|
||||
scores.append((identifier, score))
|
||||
scores.sort(key=lambda x: x[1], reverse=True)
|
||||
return scores[:k]
|
||||
|
||||
def analyze_file(
|
||||
self, modeling_file: Path, top_k_per_item: int = 5, allow_hub_fallback: bool = True
|
||||
self, modeling_file: Path, top_k_per_item: int = 5, allow_hub_fallback: bool = True, use_jaccard=False
|
||||
) -> dict[str, dict[str, list]]:
|
||||
"""
|
||||
Analyze a modeling file and find similar code definitions in the index.
|
||||
@ -502,16 +533,18 @@ class CodeSimilarityAnalyzer:
|
||||
if allow_hub_fallback:
|
||||
self.ensure_local_index()
|
||||
|
||||
base = safetensors_load(EMBEDDINGS_PATH)
|
||||
base = safetensors_load(str(self._resolve_index_path(EMBEDDINGS_PATH)))
|
||||
base_embeddings = base["embeddings"]
|
||||
with open(INDEX_MAP_PATH, "r", encoding="utf-8") as file:
|
||||
with open(self._resolve_index_path(INDEX_MAP_PATH), "r", encoding="utf-8") as file:
|
||||
identifier_map = {int(key): value for key, value in json.load(file).items()}
|
||||
identifiers = [identifier_map[i] for i in range(len(identifier_map))]
|
||||
with open(TOKENS_PATH, "r", encoding="utf-8") as file:
|
||||
with open(self._resolve_index_path(TOKENS_PATH), "r", encoding="utf-8") as file:
|
||||
tokens_map = json.load(file)
|
||||
|
||||
self_model = self._infer_query_model_name(modeling_file)
|
||||
definitions_raw, definitions_sanitized, _ = self._extract_definitions(modeling_file, None, self_model)
|
||||
definitions_raw, definitions_sanitized, _, definitions_kind = self._extract_definitions(
|
||||
modeling_file, None, self_model
|
||||
)
|
||||
query_identifiers = list(definitions_raw.keys())
|
||||
query_sources_sanitized = [definitions_sanitized[key] for key in query_identifiers]
|
||||
query_tokens_list = [set(_tokenize(source)) for source in query_sources_sanitized]
|
||||
@ -528,28 +561,146 @@ class CodeSimilarityAnalyzer:
|
||||
embedding_top = self._topk_embedding(
|
||||
query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item
|
||||
)
|
||||
jaccard_top = self._topk_jaccard(
|
||||
query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item
|
||||
)
|
||||
embedding_set = {identifier for identifier, _ in embedding_top}
|
||||
jaccard_set = {identifier for identifier, _ in jaccard_top}
|
||||
intersection = list(embedding_set & jaccard_set)
|
||||
output[query_name] = {"embedding": embedding_top, "jaccard": jaccard_top, "intersection": intersection}
|
||||
kind = definitions_kind.get(query_identifier, "function")
|
||||
entry = {"kind": kind, "embedding": embedding_top}
|
||||
if use_jaccard:
|
||||
jaccard_top = self._topk_jaccard(
|
||||
query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item
|
||||
)
|
||||
jaccard_set = {identifier for identifier, _ in jaccard_top}
|
||||
intersection = set(embedding_set & jaccard_set)
|
||||
|
||||
entry.update({"jaccard": jaccard_top, "intersection": intersection})
|
||||
output[query_name] = entry
|
||||
return output
|
||||
|
||||
|
||||
_RELEASE_RE = re.compile(
|
||||
r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
def build_date_data() -> dict[str, str]:
|
||||
"""
|
||||
Scan Markdown files in `root_dir` and build {model_id: date_released}.
|
||||
|
||||
- model_id is the filename without extension (e.g., "llama" for "llama.md")
|
||||
- date_released is the first YYYY-MM-DD matched after "...was released on ..."
|
||||
- Ignores non-*.md files and directories.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: mapping of model_id -> ISO date string (YYYY-MM-DD).
|
||||
Files without a match are simply omitted.
|
||||
"""
|
||||
|
||||
root_dir = transformers.__file__.split("src/transformers")[0]
|
||||
root = Path(root_dir).joinpath("docs/source/en/model_doc")
|
||||
result: dict[str, str] = {}
|
||||
|
||||
for md_path in root.glob("*.md"):
|
||||
try:
|
||||
text = md_path.read_text(encoding="utf-8", errors="ignore")
|
||||
except Exception:
|
||||
# Skip unreadable files quietly
|
||||
logging.info(f"Failed to read md for {md_path}")
|
||||
|
||||
m = _RELEASE_RE.search(text)
|
||||
if m:
|
||||
model_id = md_path.stem # e.g., "llama" from "llama.md"
|
||||
result[model_id] = m.group(1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _format_table(headers: list[str], rows: list[tuple[str, ...] | None], row_styles: list[str] | None = None) -> str:
|
||||
if not rows:
|
||||
return f"{ANSI_ROW}(no matches){ANSI_RESET}"
|
||||
|
||||
widths = [len(header) for header in headers]
|
||||
for row in rows:
|
||||
if row is None:
|
||||
continue
|
||||
for idx, cell in enumerate(row):
|
||||
widths[idx] = max(widths[idx], len(cell))
|
||||
|
||||
header_line = " | ".join(header.ljust(widths[idx]) for idx, header in enumerate(headers))
|
||||
divider = "-+-".join("-" * widths[idx] for idx in range(len(headers)))
|
||||
total_width = sum(widths) + 3 * (len(headers) - 1)
|
||||
|
||||
styled_rows = []
|
||||
style_idx = 0
|
||||
for row in rows:
|
||||
if row is None:
|
||||
styled_rows.append(f"{ANSI_SECTION}{'-' * total_width}{ANSI_RESET}")
|
||||
continue
|
||||
|
||||
line = " | ".join(cell.ljust(widths[col_idx]) for col_idx, cell in enumerate(row))
|
||||
style = ANSI_ROW
|
||||
if row_styles and style_idx < len(row_styles) and row_styles[style_idx]:
|
||||
style = row_styles[style_idx]
|
||||
styled_rows.append(f"{style}{line}{ANSI_RESET}")
|
||||
style_idx += 1
|
||||
|
||||
return "\n".join([f"{ANSI_SECTION}{header_line}{ANSI_RESET}", divider] + styled_rows)
|
||||
|
||||
|
||||
def _parse_release_date(value: str) -> datetime | None:
|
||||
"""Return a datetime parsed from YYYY-MM-DD strings, otherwise None."""
|
||||
try:
|
||||
return datetime.strptime(value, "%Y-%m-%d")
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
@cache
|
||||
def _load_definition_line_map(relative_path: str) -> dict[str, int]:
|
||||
"""Return {definition_name: line_number} for top-level definitions in the given file."""
|
||||
file_path = MODELS_ROOT / relative_path
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
except (FileNotFoundError, OSError):
|
||||
return {} # gracefully keep going
|
||||
|
||||
try:
|
||||
tree = ast.parse(source)
|
||||
except SyntaxError:
|
||||
return {}
|
||||
|
||||
line_map: dict[str, int] = {}
|
||||
for node in ast.iter_child_nodes(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
line_map[node.name] = getattr(node, "lineno", None) or 1
|
||||
elif isinstance(node, ast.Assign):
|
||||
continue
|
||||
return line_map
|
||||
|
||||
|
||||
def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, str]:
|
||||
"""Return full path and formatted line number string for the given definition."""
|
||||
full_path = MODELS_ROOT / relative_path
|
||||
line = _load_definition_line_map(relative_path).get(definition)
|
||||
line_str = str(line) if line is not None else "?"
|
||||
return str(full_path), line_str
|
||||
|
||||
|
||||
def _colorize_heading(text: str) -> str:
|
||||
return f"{ANSI_HEADER}{ANSI_BOLD}{text}{ANSI_RESET}"
|
||||
|
||||
|
||||
def main():
|
||||
"""CLI entry point for the modular model detector."""
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
parser = argparse.ArgumentParser(prog="hf-code-sim")
|
||||
parser.add_argument("--build", action="store_true")
|
||||
parser.add_argument("--modeling-file", type=str)
|
||||
parser.add_argument("--modeling-file", type=str, help='You can just specify "vits" if you are lazy like me.')
|
||||
parser.add_argument(
|
||||
"--push-new-index", action="store_true", help="After --build, push index files to a Hub dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub-dataset", type=str, default=HUB_DATASET_DEFAULT, help="Hub dataset repo id to pull/push the index."
|
||||
)
|
||||
parser.add_argument("--use_jaccard", type=bool, default=False, help="Whether or not to use jaccard index")
|
||||
args = parser.parse_args()
|
||||
|
||||
analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset)
|
||||
@ -563,20 +714,198 @@ def main():
|
||||
if not args.modeling_file:
|
||||
raise SystemExit("Provide --modeling-file or use --build")
|
||||
|
||||
results = analyzer.analyze_file(Path(args.modeling_file), top_k_per_item=5, allow_hub_fallback=True)
|
||||
modeling_filename = Path(args.modeling_file).name
|
||||
dates = build_date_data()
|
||||
modeling_file = args.modeling_file
|
||||
if os.sep not in modeling_file:
|
||||
modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py")
|
||||
|
||||
results = analyzer.analyze_file(
|
||||
Path(modeling_file), top_k_per_item=5, allow_hub_fallback=True, use_jaccard=args.use_jaccard
|
||||
)
|
||||
modeling_filename = Path(modeling_file).name
|
||||
release_key = modeling_filename.split("modeling_")[-1][:-3]
|
||||
release_date = dates.get(release_key, "unknown release date")
|
||||
|
||||
aggregate_scores: dict[str, float] = {}
|
||||
for data in results.values():
|
||||
for identifier, score in data.get("embedding", []):
|
||||
try:
|
||||
relative_path, _ = identifier.split(":", 1)
|
||||
except ValueError:
|
||||
continue
|
||||
aggregate_scores[relative_path] = aggregate_scores.get(relative_path, 0.0) + score
|
||||
|
||||
best_candidate_path: str | None = None
|
||||
if aggregate_scores:
|
||||
best_candidate_path = max(aggregate_scores.items(), key=lambda item: item[1])[0]
|
||||
best_model = Path(best_candidate_path).parts[0] if Path(best_candidate_path).parts else "?"
|
||||
best_release = dates.get(best_model, "unknown release date")
|
||||
logging.info(
|
||||
f"{ANSI_HIGHLIGHT_CANDIDATE}Closest overall candidate: {MODELS_ROOT / best_candidate_path}"
|
||||
f" (release: {best_release}, total score: {aggregate_scores[best_candidate_path]:.4f}){ANSI_RESET}"
|
||||
)
|
||||
|
||||
grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []}
|
||||
for query_name, data in results.items():
|
||||
logging.info(f"{modeling_filename}::{query_name}:")
|
||||
logging.info(" embedding:")
|
||||
for identifier, score in data["embedding"]:
|
||||
logging.info(f" {identifier} ({score:.4f})")
|
||||
logging.info(" jaccard:")
|
||||
for identifier, score in data["jaccard"]:
|
||||
logging.info(f" {identifier} ({score:.4f})")
|
||||
logging.info(" intersection:")
|
||||
for identifier in data["intersection"]:
|
||||
logging.info(f" {identifier}")
|
||||
logging.info("")
|
||||
kind = data.get("kind", "function")
|
||||
grouped.setdefault(kind, []).append((query_name, data))
|
||||
|
||||
section_titles = [("class", "Classes"), ("function", "Functions")]
|
||||
legend_shown = False
|
||||
for kind, title in section_titles:
|
||||
entries = grouped.get(kind, [])
|
||||
if not entries:
|
||||
continue
|
||||
|
||||
metrics_present: set[str] = set()
|
||||
for _, data in entries:
|
||||
if data.get("embedding"):
|
||||
metrics_present.add("embedding")
|
||||
if args.use_jaccard:
|
||||
if data.get("jaccard"):
|
||||
metrics_present.add("jaccard")
|
||||
if data.get("intersection"):
|
||||
metrics_present.add("intersection")
|
||||
|
||||
include_metric_column = bool(metrics_present - {"embedding"})
|
||||
headers = ["Symbol", "Path", "Score", "Release"]
|
||||
if include_metric_column:
|
||||
headers = ["Symbol", "Metric", "Path", "Score", "Release"]
|
||||
|
||||
table_rows: list[tuple[str, ...] | None] = []
|
||||
row_styles: list[str] = []
|
||||
has_metric_rows = False
|
||||
|
||||
logging.info(_colorize_heading(title))
|
||||
|
||||
for query_name, data in entries:
|
||||
if table_rows:
|
||||
table_rows.append(None)
|
||||
|
||||
symbol_label = query_name
|
||||
if release_date:
|
||||
symbol_label = f"{symbol_label}"
|
||||
|
||||
symbol_row = (symbol_label,) + ("",) * (len(headers) - 1)
|
||||
table_rows.append(symbol_row)
|
||||
row_styles.append(ANSI_BOLD)
|
||||
|
||||
embedding_details: list[tuple[str, str, str, float, str]] = []
|
||||
embedding_style_indices: list[int] = []
|
||||
|
||||
for identifier, score in data.get("embedding", []):
|
||||
try:
|
||||
relative_path, match_name = identifier.split(":", 1)
|
||||
except ValueError:
|
||||
continue
|
||||
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
|
||||
match_release = dates.get(model_id, "unknown release date")
|
||||
full_path, line = _resolve_definition_location(relative_path, match_name)
|
||||
display_path = f"{full_path}:{line} ({match_name})"
|
||||
|
||||
if include_metric_column:
|
||||
row = ("", "embedding", display_path, f"{score:.4f}", match_release)
|
||||
else:
|
||||
row = ("", display_path, f"{score:.4f}", match_release)
|
||||
|
||||
table_rows.append(row)
|
||||
row_styles.append(ANSI_ROW)
|
||||
embedding_style_indices.append(len(row_styles) - 1)
|
||||
embedding_details.append((relative_path, model_id, match_name, score, match_release))
|
||||
has_metric_rows = True
|
||||
|
||||
if embedding_details:
|
||||
highest_score = None
|
||||
highest_idx = None
|
||||
for idx, (_, _, _, score, _) in enumerate(embedding_details):
|
||||
if highest_score is None or score > highest_score:
|
||||
highest_score = score
|
||||
highest_idx = idx
|
||||
|
||||
if highest_idx is not None:
|
||||
row_styles[embedding_style_indices[highest_idx]] = ANSI_HIGHLIGHT_TOP
|
||||
|
||||
if highest_score is not None:
|
||||
oldest_idx = None
|
||||
oldest_date = None
|
||||
for idx, (_, model_id, _, score, release_value) in enumerate(embedding_details):
|
||||
if highest_score - score > 0.1:
|
||||
continue
|
||||
parsed = _parse_release_date(release_value)
|
||||
if parsed is None:
|
||||
continue
|
||||
if oldest_date is None or parsed < oldest_date:
|
||||
oldest_date = parsed
|
||||
oldest_idx = idx
|
||||
if (
|
||||
oldest_idx is not None
|
||||
and row_styles[embedding_style_indices[oldest_idx]] != ANSI_HIGHLIGHT_TOP
|
||||
):
|
||||
row_styles[embedding_style_indices[oldest_idx]] = ANSI_HIGHLIGHT_OLD
|
||||
|
||||
if best_candidate_path is not None:
|
||||
for idx, (relative_path, _, _, _, _) in enumerate(embedding_details):
|
||||
style_position = embedding_style_indices[idx]
|
||||
if row_styles[style_position] != ANSI_ROW:
|
||||
continue
|
||||
if relative_path == best_candidate_path:
|
||||
row_styles[style_position] = ANSI_HIGHLIGHT_CANDIDATE
|
||||
|
||||
if args.use_jaccard:
|
||||
for identifier, score in data.get("jaccard", []):
|
||||
try:
|
||||
relative_path, match_name = identifier.split(":", 1)
|
||||
except ValueError:
|
||||
continue
|
||||
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
|
||||
match_release = dates.get(model_id, "unknown release date")
|
||||
full_path, line = _resolve_definition_location(relative_path, match_name)
|
||||
display_path = f"{full_path}:{line} ({match_name})"
|
||||
|
||||
if include_metric_column:
|
||||
row = ("", "jaccard", display_path, f"{score:.4f}", match_release)
|
||||
else:
|
||||
row = ("", display_path, f"{score:.4f}", match_release)
|
||||
|
||||
table_rows.append(row)
|
||||
row_styles.append(ANSI_ROW)
|
||||
has_metric_rows = True
|
||||
if best_candidate_path == relative_path:
|
||||
row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE
|
||||
|
||||
for identifier in sorted(data.get("intersection", [])):
|
||||
try:
|
||||
relative_path, match_name = identifier.split(":", 1)
|
||||
except ValueError:
|
||||
continue
|
||||
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
|
||||
match_release = dates.get(model_id, "unknown release date")
|
||||
full_path, line = _resolve_definition_location(relative_path, match_name)
|
||||
display_path = f"{full_path}:{line} ({match_name})"
|
||||
|
||||
if include_metric_column:
|
||||
row = ("", "intersection", display_path, "--", match_release)
|
||||
else:
|
||||
row = ("", display_path, "--", match_release)
|
||||
|
||||
table_rows.append(row)
|
||||
row_styles.append(ANSI_ROW)
|
||||
has_metric_rows = True
|
||||
if best_candidate_path == relative_path:
|
||||
row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE
|
||||
|
||||
if table_rows:
|
||||
if not legend_shown and has_metric_rows:
|
||||
logging.info(
|
||||
"Legend: "
|
||||
f"{ANSI_HIGHLIGHT_TOP}highest match{ANSI_RESET}, "
|
||||
f"{ANSI_HIGHLIGHT_OLD}oldest within 0.1{ANSI_RESET}, "
|
||||
f"{ANSI_HIGHLIGHT_CANDIDATE}closest overall candidate{ANSI_RESET}"
|
||||
)
|
||||
legend_shown = True
|
||||
|
||||
logging.info(_format_table(headers, table_rows, row_styles))
|
||||
logging.info("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user