[ModularChecker] QOL for the modular checker (#41361)

* update

* fancy table fancy prints

* download to cache folder, never need it ever again

* style

* update based on review
Arthur
2025-10-06 12:52:10 +02:00
committed by GitHub
parent 9db58abd6e
commit 0452f28544


@@ -103,20 +103,34 @@ import json
import logging
import os
import re
from datetime import datetime
from functools import cache
from pathlib import Path
import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub import logging as huggingface_hub_logging
from safetensors.numpy import load_file as safetensors_load
from safetensors.numpy import save_file as safetensors_save
from tqdm import tqdm
import transformers
from transformers import AutoModel, AutoTokenizer
from transformers.utils import logging as transformers_logging
# ANSI color codes for CLI output styling
ANSI_RESET = "\033[0m"
ANSI_BOLD = "\033[1m"
ANSI_HEADER = "\033[1;36m"
ANSI_SECTION = "\033[1;35m"
ANSI_ROW = "\033[0;37m"
ANSI_HIGHLIGHT_TOP = "\033[1;32m"
ANSI_HIGHLIGHT_OLD = "\033[1;33m"
ANSI_HIGHLIGHT_CANDIDATE = "\033[1;34m"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
@@ -238,34 +252,40 @@ class CodeSimilarityAnalyzer:
self.models_root = MODELS_ROOT
self.hub_dataset = hub_dataset
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.dtype = "auto"
self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
self.model = (
AutoModel.from_pretrained(
EMBEDDING_MODEL,
torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32,
)
.eval()
.to(self.device)
)
self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval()
self.device = self.model.device
self.index_dir: Path | None = None
# ---------- HUB IO ----------
def _resolve_index_path(self, filename: str) -> Path:
if self.index_dir is None:
return Path(filename)
return self.index_dir / filename
def ensure_local_index(self) -> None:
"""Download index files from Hub if they don't exist locally."""
have_all = Path(EMBEDDINGS_PATH).exists() and Path(INDEX_MAP_PATH).exists() and Path(TOKENS_PATH).exists()
if have_all:
"""Ensure index files are available locally, preferring Hub cache snapshots."""
if self.index_dir is not None and all(
(self.index_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH)
):
return
logging.info(f"downloading index from hub: {self.hub_dataset}")
for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH):
hf_hub_download(
repo_id=self.hub_dataset,
filename=fname,
repo_type="dataset",
local_dir=".",
local_dir_use_symlinks=False,
)
workspace_dir = Path.cwd()
if all((workspace_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH)):
self.index_dir = workspace_dir
return
logging.info(f"downloading index from hub cache: {self.hub_dataset}")
snapshot_path = snapshot_download(repo_id=self.hub_dataset, repo_type="dataset")
snapshot_dir = Path(snapshot_path)
missing = [
fname for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH) if not (snapshot_dir / fname).exists()
]
if missing:
raise FileNotFoundError("Missing expected files in Hub snapshot: " + ", ".join(missing))
self.index_dir = snapshot_dir
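Unlike the earlier per-file `hf_hub_download` into the working directory, `snapshot_download` resolves to the shared Hub cache, so repeat runs reuse the local snapshot instead of re-downloading. A minimal sketch of that behavior (the dataset id is hypothetical):

from huggingface_hub import snapshot_download

# The first call downloads into ~/.cache/huggingface/hub; later calls return
# the same snapshot path without hitting the network.
path = snapshot_download(repo_id="some-user/code-sim-index", repo_type="dataset")
print(path)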
def push_index_to_hub(self) -> None:
"""Upload index files to the Hub dataset repository."""
@@ -284,7 +304,7 @@ class CodeSimilarityAnalyzer:
def _extract_definitions(
self, file_path: Path, relative_to: Path | None = None, model_hint: str | None = None
) -> tuple[dict[str, str], dict[str, str], dict[str, list[str]]]:
) -> tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]:
"""
Extract class and function definitions from a Python file.
@@ -294,14 +314,16 @@ class CodeSimilarityAnalyzer:
model_hint (`str` or `None`): Model name hint for sanitization.
Returns:
`tuple[dict[str, str], dict[str, str], dict[str, list[str]]]`: A tuple containing:
`tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]`: A tuple containing:
- definitions_raw: Mapping of identifiers to raw source code
- definitions_sanitized: Mapping of identifiers to sanitized source code
- definitions_tokens: Mapping of identifiers to sorted token lists
- definitions_kind: Mapping of identifiers to either "class" or "function"
"""
definitions_raw = {}
definitions_sanitized = {}
definitions_tokens = {}
definitions_kind = {}
source = file_path.read_text(encoding="utf-8")
lines = source.splitlines()
tree = ast.parse(source)
@@ -322,7 +344,11 @@ class CodeSimilarityAnalyzer:
sanitized = _sanitize_for_embedding(segment, model_hint, node.name)
definitions_sanitized[identifier] = sanitized
definitions_tokens[identifier] = sorted(_tokenize(sanitized))
return definitions_raw, definitions_sanitized, definitions_tokens
if isinstance(node, ast.ClassDef):
definitions_kind[identifier] = "class"
else:
definitions_kind[identifier] = "function"
return definitions_raw, definitions_sanitized, definitions_tokens, definitions_kind
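The kind map follows directly from the AST node type; a self-contained sketch of the same classification over a hypothetical source string:

import ast

source = "class VitsAttention:\n    pass\n\ndef rotate_half(x):\n    return x"
kinds = {
    node.name: "class" if isinstance(node, ast.ClassDef) else "function"
    for node in ast.iter_child_nodes(ast.parse(source))
    if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef))
}
print(kinds)  # {'VitsAttention': 'class', 'rotate_half': 'function'}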
def _infer_model_from_relative_path(self, relative_path: Path) -> str | None:
try:
@@ -400,9 +426,12 @@ class CodeSimilarityAnalyzer:
for file_path in tqdm(files, desc="parse", leave=False):
model_hint = self._infer_model_from_relative_path(file_path)
definitions_raw, definitions_sanitized, definitions_tokens = self._extract_definitions(
file_path, self.models_root, model_hint
)
(
_,
definitions_sanitized,
definitions_tokens,
_,
) = self._extract_definitions(file_path, self.models_root, model_hint)
for identifier in definitions_sanitized.keys():
identifiers.append(identifier)
sanitized_sources.append(definitions_sanitized[identifier])
@@ -418,6 +447,8 @@ class CodeSimilarityAnalyzer:
with open(TOKENS_PATH, "w", encoding="utf-8") as file:
json.dump(tokens_map, file)
self.index_dir = Path.cwd()
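The build step persists the embedding matrix via safetensors next to the JSON maps; a minimal round-trip sketch (the filename is hypothetical):

import numpy as np
from safetensors.numpy import load_file, save_file

embeddings = np.random.rand(4, 8).astype(np.float32)
save_file({"embeddings": embeddings}, "demo_index.safetensors")
restored = load_file("demo_index.safetensors")["embeddings"]
assert np.allclose(embeddings, restored)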
def _topk_embedding(
self,
query_embedding_row: np.ndarray,
@@ -439,7 +470,7 @@ class CodeSimilarityAnalyzer:
continue
if self_model_normalized and _normalize(parent_model) == self_model_normalized:
continue
output.append((f"{parent_model}::{match_name}", float(similarities[match_id])))
output.append((identifier, float(similarities[match_id])))
if len(output) >= k:
break
return output
@@ -480,12 +511,12 @@ class CodeSimilarityAnalyzer:
continue
score = len(query_tokens & tokens) / len(query_tokens | tokens)
if score > 0:
scores.append((f"{parent_model}::{match_name}", score))
scores.append((identifier, score))
scores.sort(key=lambda x: x[1], reverse=True)
return scores[:k]
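The score is plain Jaccard similarity, intersection size over union size of the two token sets. A worked example with hypothetical tokens:

query_tokens = {"hidden", "states", "attention", "mask"}
tokens = {"hidden", "states", "dropout"}
print(len(query_tokens & tokens) / len(query_tokens | tokens))  # 2 shared / 5 total = 0.4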
def analyze_file(
self, modeling_file: Path, top_k_per_item: int = 5, allow_hub_fallback: bool = True
        self, modeling_file: Path, top_k_per_item: int = 5, allow_hub_fallback: bool = True, use_jaccard: bool = False
) -> dict[str, dict[str, list]]:
"""
Analyze a modeling file and find similar code definitions in the index.
@@ -502,16 +533,18 @@ class CodeSimilarityAnalyzer:
if allow_hub_fallback:
self.ensure_local_index()
base = safetensors_load(EMBEDDINGS_PATH)
base = safetensors_load(str(self._resolve_index_path(EMBEDDINGS_PATH)))
base_embeddings = base["embeddings"]
with open(INDEX_MAP_PATH, "r", encoding="utf-8") as file:
with open(self._resolve_index_path(INDEX_MAP_PATH), "r", encoding="utf-8") as file:
identifier_map = {int(key): value for key, value in json.load(file).items()}
identifiers = [identifier_map[i] for i in range(len(identifier_map))]
with open(TOKENS_PATH, "r", encoding="utf-8") as file:
with open(self._resolve_index_path(TOKENS_PATH), "r", encoding="utf-8") as file:
tokens_map = json.load(file)
self_model = self._infer_query_model_name(modeling_file)
definitions_raw, definitions_sanitized, _ = self._extract_definitions(modeling_file, None, self_model)
definitions_raw, definitions_sanitized, _, definitions_kind = self._extract_definitions(
modeling_file, None, self_model
)
query_identifiers = list(definitions_raw.keys())
query_sources_sanitized = [definitions_sanitized[key] for key in query_identifiers]
query_tokens_list = [set(_tokenize(source)) for source in query_sources_sanitized]
@@ -528,28 +561,146 @@ class CodeSimilarityAnalyzer:
embedding_top = self._topk_embedding(
query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item
)
jaccard_top = self._topk_jaccard(
query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item
)
embedding_set = {identifier for identifier, _ in embedding_top}
jaccard_set = {identifier for identifier, _ in jaccard_top}
intersection = list(embedding_set & jaccard_set)
output[query_name] = {"embedding": embedding_top, "jaccard": jaccard_top, "intersection": intersection}
kind = definitions_kind.get(query_identifier, "function")
entry = {"kind": kind, "embedding": embedding_top}
if use_jaccard:
jaccard_top = self._topk_jaccard(
query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item
)
jaccard_set = {identifier for identifier, _ in jaccard_top}
intersection = set(embedding_set & jaccard_set)
entry.update({"jaccard": jaccard_top, "intersection": intersection})
output[query_name] = entry
return output
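A sketch of how a caller consumes the new return shape; the `jaccard` and `intersection` keys only exist when `use_jaccard=True` (the repo id and path are hypothetical):

analyzer = CodeSimilarityAnalyzer(hub_dataset="some-user/code-sim-index")
results = analyzer.analyze_file(
    Path("src/transformers/models/vits/modeling_vits.py"), use_jaccard=True
)
for name, entry in results.items():
    print(name, entry["kind"], entry["embedding"][:1], entry.get("intersection", set()))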
_RELEASE_RE = re.compile(
r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE
)
def build_date_data() -> dict[str, str]:
"""
    Scan the Markdown files under docs/source/en/model_doc and build {model_id: date_released}.
- model_id is the filename without extension (e.g., "llama" for "llama.md")
- date_released is the first YYYY-MM-DD matched after "...was released on ..."
- Ignores non-*.md files and directories.
Returns:
dict[str, str]: mapping of model_id -> ISO date string (YYYY-MM-DD).
Files without a match are simply omitted.
"""
root_dir = transformers.__file__.split("src/transformers")[0]
root = Path(root_dir).joinpath("docs/source/en/model_doc")
result: dict[str, str] = {}
for md_path in root.glob("*.md"):
        try:
            text = md_path.read_text(encoding="utf-8", errors="ignore")
        except Exception:
            # Skip unreadable files quietly
            logging.info(f"Failed to read md for {md_path}")
            continue
m = _RELEASE_RE.search(text)
if m:
model_id = md_path.stem # e.g., "llama" from "llama.md"
result[model_id] = m.group(1)
return result
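A quick check of what the pattern accepts (hypothetical doc snippets):

print(_RELEASE_RE.search("This model was released on 2023-07-18.").group(1))  # 2023-07-18
print(_RELEASE_RE.search("**The model was released on 2024-01-05**").group(1))  # 2024-01-05
print(_RELEASE_RE.search("released in 2024"))  # None: needs "was released on YYYY-MM-DD"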
def _format_table(headers: list[str], rows: list[tuple[str, ...] | None], row_styles: list[str] | None = None) -> str:
if not rows:
return f"{ANSI_ROW}(no matches){ANSI_RESET}"
widths = [len(header) for header in headers]
for row in rows:
if row is None:
continue
for idx, cell in enumerate(row):
widths[idx] = max(widths[idx], len(cell))
header_line = " | ".join(header.ljust(widths[idx]) for idx, header in enumerate(headers))
divider = "-+-".join("-" * widths[idx] for idx in range(len(headers)))
total_width = sum(widths) + 3 * (len(headers) - 1)
styled_rows = []
style_idx = 0
for row in rows:
if row is None:
styled_rows.append(f"{ANSI_SECTION}{'-' * total_width}{ANSI_RESET}")
continue
line = " | ".join(cell.ljust(widths[col_idx]) for col_idx, cell in enumerate(row))
style = ANSI_ROW
if row_styles and style_idx < len(row_styles) and row_styles[style_idx]:
style = row_styles[style_idx]
styled_rows.append(f"{style}{line}{ANSI_RESET}")
style_idx += 1
return "\n".join([f"{ANSI_SECTION}{header_line}{ANSI_RESET}", divider] + styled_rows)
def _parse_release_date(value: str) -> datetime | None:
"""Return a datetime parsed from YYYY-MM-DD strings, otherwise None."""
try:
return datetime.strptime(value, "%Y-%m-%d")
except (TypeError, ValueError):
return None
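For example, valid ISO dates parse and anything else, including the "unknown release date" placeholder, yields None:

print(_parse_release_date("2023-02-27"))            # 2023-02-27 00:00:00
print(_parse_release_date("unknown release date"))  # None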
@cache
def _load_definition_line_map(relative_path: str) -> dict[str, int]:
"""Return {definition_name: line_number} for top-level definitions in the given file."""
file_path = MODELS_ROOT / relative_path
try:
source = file_path.read_text(encoding="utf-8")
except (FileNotFoundError, OSError):
return {} # gracefully keep going
try:
tree = ast.parse(source)
except SyntaxError:
return {}
line_map: dict[str, int] = {}
for node in ast.iter_child_nodes(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
line_map[node.name] = getattr(node, "lineno", None) or 1
return line_map
def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, str]:
"""Return full path and formatted line number string for the given definition."""
full_path = MODELS_ROOT / relative_path
line = _load_definition_line_map(relative_path).get(definition)
line_str = str(line) if line is not None else "?"
return str(full_path), line_str
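For example, assuming a hypothetical indexed file:

full_path, line = _resolve_definition_location("llama/modeling_llama.py", "LlamaAttention")
print(f"{full_path}:{line}")  # line is "?" when the definition cannot be located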
def _colorize_heading(text: str) -> str:
return f"{ANSI_HEADER}{ANSI_BOLD}{text}{ANSI_RESET}"
def main():
"""CLI entry point for the modular model detector."""
logging.basicConfig(level=logging.INFO, format="%(message)s")
parser = argparse.ArgumentParser(prog="hf-code-sim")
parser.add_argument("--build", action="store_true")
parser.add_argument("--modeling-file", type=str)
parser.add_argument("--modeling-file", type=str, help='You can just specify "vits" if you are lazy like me.')
parser.add_argument(
"--push-new-index", action="store_true", help="After --build, push index files to a Hub dataset."
)
parser.add_argument(
"--hub-dataset", type=str, default=HUB_DATASET_DEFAULT, help="Hub dataset repo id to pull/push the index."
)
parser.add_argument("--use_jaccard", type=bool, default=False, help="Whether or not to use jaccard index")
args = parser.parse_args()
analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset)
@@ -563,20 +714,198 @@ def main():
if not args.modeling_file:
raise SystemExit("Provide --modeling-file or use --build")
results = analyzer.analyze_file(Path(args.modeling_file), top_k_per_item=5, allow_hub_fallback=True)
modeling_filename = Path(args.modeling_file).name
dates = build_date_data()
modeling_file = args.modeling_file
if os.sep not in modeling_file:
modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py")
results = analyzer.analyze_file(
Path(modeling_file), top_k_per_item=5, allow_hub_fallback=True, use_jaccard=args.use_jaccard
)
modeling_filename = Path(modeling_file).name
release_key = modeling_filename.split("modeling_")[-1][:-3]
release_date = dates.get(release_key, "unknown release date")
aggregate_scores: dict[str, float] = {}
for data in results.values():
for identifier, score in data.get("embedding", []):
try:
relative_path, _ = identifier.split(":", 1)
except ValueError:
continue
aggregate_scores[relative_path] = aggregate_scores.get(relative_path, 0.0) + score
best_candidate_path: str | None = None
if aggregate_scores:
best_candidate_path = max(aggregate_scores.items(), key=lambda item: item[1])[0]
best_model = Path(best_candidate_path).parts[0] if Path(best_candidate_path).parts else "?"
best_release = dates.get(best_model, "unknown release date")
logging.info(
f"{ANSI_HIGHLIGHT_CANDIDATE}Closest overall candidate: {MODELS_ROOT / best_candidate_path}"
f" (release: {best_release}, total score: {aggregate_scores[best_candidate_path]:.4f}){ANSI_RESET}"
)
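    # Worked sketch of the aggregation above (hypothetical identifiers/scores):
    # ("llama/modeling_llama.py:LlamaMLP", 0.97) and
    # ("llama/modeling_llama.py:LlamaAttention", 0.92) sum to 1.89, beating
    # ("mistral/modeling_mistral.py:MistralMLP", 0.95), so llama/modeling_llama.py
    # wins as the closest overall candidate.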
grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []}
for query_name, data in results.items():
logging.info(f"{modeling_filename}::{query_name}:")
logging.info(" embedding:")
for identifier, score in data["embedding"]:
logging.info(f" {identifier} ({score:.4f})")
logging.info(" jaccard:")
for identifier, score in data["jaccard"]:
logging.info(f" {identifier} ({score:.4f})")
logging.info(" intersection:")
for identifier in data["intersection"]:
logging.info(f" {identifier}")
logging.info("")
kind = data.get("kind", "function")
grouped.setdefault(kind, []).append((query_name, data))
section_titles = [("class", "Classes"), ("function", "Functions")]
legend_shown = False
for kind, title in section_titles:
entries = grouped.get(kind, [])
if not entries:
continue
metrics_present: set[str] = set()
for _, data in entries:
if data.get("embedding"):
metrics_present.add("embedding")
if args.use_jaccard:
if data.get("jaccard"):
metrics_present.add("jaccard")
if data.get("intersection"):
metrics_present.add("intersection")
include_metric_column = bool(metrics_present - {"embedding"})
headers = ["Symbol", "Path", "Score", "Release"]
if include_metric_column:
headers = ["Symbol", "Metric", "Path", "Score", "Release"]
table_rows: list[tuple[str, ...] | None] = []
row_styles: list[str] = []
has_metric_rows = False
logging.info(_colorize_heading(title))
for query_name, data in entries:
if table_rows:
table_rows.append(None)
            symbol_label = query_name
symbol_row = (symbol_label,) + ("",) * (len(headers) - 1)
table_rows.append(symbol_row)
row_styles.append(ANSI_BOLD)
embedding_details: list[tuple[str, str, str, float, str]] = []
embedding_style_indices: list[int] = []
for identifier, score in data.get("embedding", []):
try:
relative_path, match_name = identifier.split(":", 1)
except ValueError:
continue
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
match_release = dates.get(model_id, "unknown release date")
full_path, line = _resolve_definition_location(relative_path, match_name)
display_path = f"{full_path}:{line} ({match_name})"
if include_metric_column:
row = ("", "embedding", display_path, f"{score:.4f}", match_release)
else:
row = ("", display_path, f"{score:.4f}", match_release)
table_rows.append(row)
row_styles.append(ANSI_ROW)
embedding_style_indices.append(len(row_styles) - 1)
embedding_details.append((relative_path, model_id, match_name, score, match_release))
has_metric_rows = True
if embedding_details:
highest_score = None
highest_idx = None
for idx, (_, _, _, score, _) in enumerate(embedding_details):
if highest_score is None or score > highest_score:
highest_score = score
highest_idx = idx
if highest_idx is not None:
row_styles[embedding_style_indices[highest_idx]] = ANSI_HIGHLIGHT_TOP
if highest_score is not None:
oldest_idx = None
oldest_date = None
for idx, (_, model_id, _, score, release_value) in enumerate(embedding_details):
if highest_score - score > 0.1:
continue
parsed = _parse_release_date(release_value)
if parsed is None:
continue
if oldest_date is None or parsed < oldest_date:
oldest_date = parsed
oldest_idx = idx
if (
oldest_idx is not None
and row_styles[embedding_style_indices[oldest_idx]] != ANSI_HIGHLIGHT_TOP
):
row_styles[embedding_style_indices[oldest_idx]] = ANSI_HIGHLIGHT_OLD
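                # Worked sketch of this rule (hypothetical scores/releases): with
                # top score 0.98 (llama, 2023-02-27), 0.91 (gpt2, 2019-02-14) and
                # 0.85 (bert, 2018-10-11), only llama and gpt2 fall within 0.1 of
                # the top, and gpt2's older release gets ANSI_HIGHLIGHT_OLD while
                # llama keeps ANSI_HIGHLIGHT_TOP.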
if best_candidate_path is not None:
for idx, (relative_path, _, _, _, _) in enumerate(embedding_details):
style_position = embedding_style_indices[idx]
if row_styles[style_position] != ANSI_ROW:
continue
if relative_path == best_candidate_path:
row_styles[style_position] = ANSI_HIGHLIGHT_CANDIDATE
if args.use_jaccard:
for identifier, score in data.get("jaccard", []):
try:
relative_path, match_name = identifier.split(":", 1)
except ValueError:
continue
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
match_release = dates.get(model_id, "unknown release date")
full_path, line = _resolve_definition_location(relative_path, match_name)
display_path = f"{full_path}:{line} ({match_name})"
if include_metric_column:
row = ("", "jaccard", display_path, f"{score:.4f}", match_release)
else:
row = ("", display_path, f"{score:.4f}", match_release)
table_rows.append(row)
row_styles.append(ANSI_ROW)
has_metric_rows = True
if best_candidate_path == relative_path:
row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE
for identifier in sorted(data.get("intersection", [])):
try:
relative_path, match_name = identifier.split(":", 1)
except ValueError:
continue
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
match_release = dates.get(model_id, "unknown release date")
full_path, line = _resolve_definition_location(relative_path, match_name)
display_path = f"{full_path}:{line} ({match_name})"
if include_metric_column:
row = ("", "intersection", display_path, "--", match_release)
else:
row = ("", display_path, "--", match_release)
table_rows.append(row)
row_styles.append(ANSI_ROW)
has_metric_rows = True
if best_candidate_path == relative_path:
row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE
if table_rows:
if not legend_shown and has_metric_rows:
logging.info(
"Legend: "
f"{ANSI_HIGHLIGHT_TOP}highest match{ANSI_RESET}, "
f"{ANSI_HIGHLIGHT_OLD}oldest within 0.1{ANSI_RESET}, "
f"{ANSI_HIGHLIGHT_CANDIDATE}closest overall candidate{ANSI_RESET}"
)
legend_shown = True
logging.info(_format_table(headers, table_rows, row_styles))
logging.info("")
if __name__ == "__main__":