# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 🔴🔴🔴 THIS IS AN INTERNAL TOOL. It WILL interact with the hub and use significant local compute resources. Use at your own risk.

"""
|
|
Modular model detector: utilities for detecting code similarities between model implementations.
|
|
|
|
This module provides tools to analyze and detect similarities between different model implementations
|
|
in the transformers library. It uses both embedding-based and token-based (Jaccard) similarity metrics
|
|
to identify similar code patterns across different model definitions.
|
|
|
|
Its function is to identify which models can be _modular_-ized, meaning, which already existing classes are
|
|
present in the codebase and look very similar to the one we have.
|
|
|
|
Two scores are computed, one is a code embedding, and the other is a simple Jaccard bag-of-tokens index for overlap
|
|
of token sets. A score of 1.00 means the code is identical.
|
|
|
|
Usage:
|
|
|
|
```bash
|
|
cd transformers
|
|
|
|
# Use directly the util, it will download the index embedding from the hub. It will require some RAM/VRAM.
|
|
|
|
>>> python utils/modular_model_detector.py --modeling-file my_new_beit3_modeling_file.py
|
|
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 33.62it/s]
|
|
encoding 21 query definitions with Qwen/Qwen3-Embedding-4B (device=cuda, batch=16, max_length=4096)
|
|
stuff.py::Beit3ImageTextMatchingOutput:
|
|
embedding:
|
|
blip_2::Blip2ImageTextMatchingModelOutput (0.9994)
|
|
chinese_clip::ChineseCLIPOutput (0.9818)
|
|
owlvit::OwlViTOutput (0.9818)
|
|
aimv2::Aimv2Output (0.9818)
|
|
blip::BlipOutput (0.9818)
|
|
jaccard:
|
|
owlv2::Owlv2Output (0.9667)
|
|
metaclip_2::MetaClip2Output (0.9667)
|
|
altclip::AltCLIPOutput (0.9667)
|
|
owlvit::OwlViTOutput (0.9667)
|
|
blip::BlipOutput (0.9667)
|
|
intersection:
|
|
blip::BlipOutput
|
|
owlvit::OwlViTOutput
|
|
|
|
stuff.py::Beit3MLP:
|
|
embedding:
|
|
efficientloftr::EfficientLoFTRMLP (0.9718)
|
|
seggpt::SegGptMlp (0.9650)
|
|
mgp_str::MgpstrMlp (0.9646)
|
|
vitpose_backbone::VitPoseBackboneMLP (0.9640)
|
|
granitemoeshared::GraniteMoeSharedMLP (0.9633)
|
|
jaccard:
|
|
chinese_clip::ChineseCLIPTextSelfOutput (0.5294)
|
|
convbert::ConvBertSelfOutput (0.5294)
|
|
bert::BertSelfOutput (0.5294)
|
|
roformer::RoFormerSelfOutput (0.5294)
|
|
layoutlmv3::LayoutLMv3SelfOutput (0.5294)
|
|
intersection:
|
|
|
|
stuff.py::Beit3FeedForwardNetwork:
|
|
embedding:
|
|
prophetnet::ProphetNetFeedForward (0.9766)
|
|
dab_detr::DabDetrDecoderLayerFFN (0.9730)
|
|
kosmos2::Kosmos2TextFFN (0.9697)
|
|
kosmos2_5::Kosmos2_5TextFFN (0.9697)
|
|
parakeet::ParakeetEncoderFeedForward (0.9678)
|
|
jaccard:
|
|
groupvit::GroupViTMLP (0.4898)
|
|
convbert::ConvBertOutput (0.4600)
|
|
chinese_clip::ChineseCLIPTextOutput (0.4565)
|
|
bert::BertOutput (0.4565)
|
|
roformer::RoFormerOutput (0.4565)
|
|
intersection:
|
|
|
|
|
|
|
|
```
|
|
|
|
|
|
# If you wish to build the index first, you can run
|
|
|
|
python utils/modular_model_detector.py --build
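
# If you also want to push the freshly built index to a Hub dataset (the --hub-dataset flag defaults to
# hf-internal-testing/transformers_code_embeddings), replace <your-dataset> with a dataset repo you can write to:

python utils/modular_model_detector.py --build --push-new-index --hub-dataset <your-dataset>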

# You can also change the embedding model for a larger/smaller one.
"""

import argparse
import ast
import json
import logging
import os
import re
from pathlib import Path

import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub import logging as huggingface_hub_logging
from safetensors.numpy import load_file as safetensors_load
from safetensors.numpy import save_file as safetensors_save
from tqdm import tqdm

from transformers import AutoModel, AutoTokenizer
from transformers.utils import logging as transformers_logging


os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

MODELS_ROOT = Path("src/transformers/models")
EMBEDDINGS_PATH = "embeddings.safetensors"
INDEX_MAP_PATH = "code_index_map.json"
TOKENS_PATH = "code_index_tokens.json"
HUB_DATASET_DEFAULT = "hf-internal-testing/transformers_code_embeddings"

EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
BATCH_SIZE = 16
MAX_LENGTH = 4096

def _normalize(string: str | None) -> str:
    """
    Normalize a string by removing all non-alphanumeric characters and converting to lowercase.

    Args:
        string (`str` or `None`): The string to normalize.

    Returns:
        `str`: The normalized string, or empty string if input is None.
    """
    return re.sub(r"[^a-z0-9]+", "", string.lower()) if string else ""


def _strip_source_for_tokens(code: str) -> str:
    """
    Strip docstrings, comments, and import statements from source code.

    Args:
        code (`str`): The source code to strip.

    Returns:
        `str`: The stripped source code.
    """
    code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code)
    code = re.sub(r"#.*", "", code)
    return "\n".join(line for line in code.splitlines() if not re.match(r"\s*(from|import)\s+", line))


def _tokenize(code: str) -> set[str]:
    """
    Extract all Python identifiers from source code.

    Args:
        code (`str`): The source code to tokenize.

    Returns:
        `set[str]`: A set of all identifiers found in the code.
    """
    return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code))


def _leading_symbol_prefix(name: str) -> str:
    """
    Extract the leading prefix from a symbol name (e.g., 'Llama' from 'LlamaAttention').

    Args:
        name (`str`): The symbol name to extract prefix from.

    Returns:
        `str`: The leading prefix, or empty string if no match.
    """
    match = re.match(r"^([A-Z][a-z0-9]+)", name) or re.match(r"^([A-Za-z0-9]+)", name)
    return match.group(1) if match else ""

def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str:
    """
    Sanitize code for embedding by replacing model-specific identifiers with a generic placeholder.

    Args:
        code (`str`): The source code to sanitize.
        model_hint (`str` or `None`): Hint about the model name (e.g., 'llama').
        symbol_hint (`str` or `None`): Hint about the symbol name (e.g., 'LlamaAttention').

    Returns:
        `str`: The sanitized code with model-specific identifiers replaced by 'Model'.
    """
    base = _strip_source_for_tokens(code)
    variants = set()
    if model_hint:
        variants.add(model_hint)
        variants.add(model_hint.replace("_", ""))
        variants.add(re.sub(r"\d+", "", model_hint))
    if symbol_hint:
        prefix = _leading_symbol_prefix(symbol_hint)
        if prefix:
            variants.add(prefix)
            variants.add(prefix.replace("_", ""))
            variants.add(re.sub(r"\d+", "", prefix))
    variants |= {variant.lower() for variant in list(variants)}
    sanitized = base
    # replace the longest variants first so shorter substrings do not pre-empt them
    for variant in sorted({x for x in variants if len(x) >= 3}, key=len, reverse=True):
        sanitized = re.sub(re.escape(variant), "Model", sanitized, flags=re.IGNORECASE)
    return sanitized

class CodeSimilarityAnalyzer:
    """
    Analyzer for detecting code similarities between model implementations.

    This class uses embedding-based and token-based similarity metrics to identify similar
    code patterns across different model definitions in the transformers library.

    Args:
        hub_dataset (`str`): The Hub dataset repository ID containing the code embeddings index.
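
    Example (illustrative; the modeling file path below is just an example):

    ```python
    from pathlib import Path

    analyzer = CodeSimilarityAnalyzer(hub_dataset="hf-internal-testing/transformers_code_embeddings")
    results = analyzer.analyze_file(Path("src/transformers/models/llama/modeling_llama.py"))
    # each value holds the top "embedding" and "jaccard" matches plus their "intersection"
    ```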
    """

    def __init__(self, hub_dataset: str):
        for name in ("huggingface_hub", "httpx", "urllib3", "transformers"):
            logging.getLogger(name).setLevel(logging.ERROR)
        huggingface_hub_logging.set_verbosity_error()
        transformers_logging.set_verbosity_error()
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_grad_enabled(False)

        self.models_root = MODELS_ROOT
        self.hub_dataset = hub_dataset
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = "auto"
        self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
        self.model = (
            AutoModel.from_pretrained(
                EMBEDDING_MODEL,
                torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32,
            )
            .eval()
            .to(self.device)
        )

    # ---------- HUB IO ----------

    def ensure_local_index(self) -> None:
        """Download index files from Hub if they don't exist locally."""
        have_all = Path(EMBEDDINGS_PATH).exists() and Path(INDEX_MAP_PATH).exists() and Path(TOKENS_PATH).exists()
        if have_all:
            return
        logging.info(f"downloading index from hub: {self.hub_dataset}")
        for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH):
            hf_hub_download(
                repo_id=self.hub_dataset,
                filename=fname,
                repo_type="dataset",
                local_dir=".",
                local_dir_use_symlinks=False,
            )

    def push_index_to_hub(self) -> None:
        """Upload index files to the Hub dataset repository."""
        api = HfApi()
        api.create_repo(repo_id=self.hub_dataset, repo_type="dataset", exist_ok=True)
        for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH):
            logging.info(f"pushing {fname} -> {self.hub_dataset}")
            api.upload_file(
                path_or_fileobj=fname,
                path_in_repo=os.path.basename(fname),
                repo_id=self.hub_dataset,
                repo_type="dataset",
            )

    # ---------- parsing & encoding ----------

    def _extract_definitions(
        self, file_path: Path, relative_to: Path | None = None, model_hint: str | None = None
    ) -> tuple[dict[str, str], dict[str, str], dict[str, list[str]]]:
        """
        Extract class and function definitions from a Python file.

        Args:
            file_path (`Path`): Path to the Python file to parse.
            relative_to (`Path` or `None`): Base path for computing relative identifiers.
            model_hint (`str` or `None`): Model name hint for sanitization.

        Returns:
            `tuple[dict[str, str], dict[str, str], dict[str, list[str]]]`: A tuple containing:
            - definitions_raw: Mapping of identifiers to raw source code
            - definitions_sanitized: Mapping of identifiers to sanitized source code
            - definitions_tokens: Mapping of identifiers to sorted token lists
        """
        definitions_raw = {}
        definitions_sanitized = {}
        definitions_tokens = {}
        source = file_path.read_text(encoding="utf-8")
        lines = source.splitlines()
        tree = ast.parse(source)
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                segment = ast.get_source_segment(source, node)
                if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
                    start = max(0, node.lineno - 1)
                    end = node.end_lineno
                    segment = "\n".join(lines[start:end])
                if segment:
                    identifier = (
                        f"{file_path.relative_to(relative_to)}:{node.name}"
                        if relative_to
                        else f"{file_path.name}:{node.name}"
                    )
                    definitions_raw[identifier] = segment
                    sanitized = _sanitize_for_embedding(segment, model_hint, node.name)
                    definitions_sanitized[identifier] = sanitized
                    definitions_tokens[identifier] = sorted(_tokenize(sanitized))
        return definitions_raw, definitions_sanitized, definitions_tokens

    def _infer_model_from_relative_path(self, relative_path: Path) -> str | None:
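        """
        Infer the model directory name from a path under `src/transformers/models`.

        Args:
            relative_path (`Path`): Path to a file inside the models tree.

        Returns:
            `str` or `None`: The model folder name (e.g., 'llama'), or `None` if the path is not under the models root.
        """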
        try:
            relative = relative_path.resolve().relative_to(self.models_root.resolve())
            return relative.parts[0]
        except Exception:
            return None

    def _infer_query_model_name(self, modeling_file: Path) -> str | None:
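        """
        Infer the query model name, either from the file location or from a `modeling_<name>.py` filename.

        Args:
            modeling_file (`Path`): Path to the modeling file being analyzed.

        Returns:
            `str` or `None`: The inferred model name, or `None` if it cannot be determined.
        """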
        model = self._infer_model_from_relative_path(modeling_file)
        if model:
            return model
        stem = modeling_file.stem
        if stem.startswith("modeling_") and len(stem) > len("modeling_"):
            return stem[len("modeling_") :]
        return None

    def _encode_batch(self, texts: list[str]) -> np.ndarray:
        """
        Encode a batch of texts into normalized embeddings.

        Args:
            texts (`list[str]`): List of text strings to encode.

        Returns:
            `np.ndarray`: Normalized embeddings as a float32 numpy array.
        """
        encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
        encoded = {key: value.to(self.device) for key, value in encoded.items()}
        with (
            torch.autocast(device_type=self.device.type, dtype=self.dtype)
            if self.device.type == "cuda"
            else torch.no_grad()
        ):
            output = self.model(**encoded)
        if hasattr(output, "last_hidden_state"):
            # masked mean pooling over the sequence dimension, ignoring padding tokens
            embeddings = output.last_hidden_state
            mask = encoded["attention_mask"].unsqueeze(-1)
            embeddings = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-9)
        elif hasattr(output, "pooler_output"):
            embeddings = output.pooler_output
        else:
            embeddings = output[0].mean(dim=1)
        embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1)
        return embeddings.cpu().numpy().astype("float32")

    def encode(self, texts: list[str]) -> np.ndarray:
        """
        Encode a list of texts into embeddings, processing in batches.

        Args:
            texts (`list[str]`): List of text strings to encode.

        Returns:
            `np.ndarray`: Stacked embeddings for all texts.
        """
        output = []
        for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="encode", leave=False):
            output.append(self._encode_batch(texts[i : i + BATCH_SIZE]))
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
        return np.vstack(output) if output else np.zeros((0, 0), dtype="float32")

    # ---------- build & search ----------

    def build_index(self) -> None:
        """Build the code similarity index from all modeling files and save to disk."""
        logging.info("collecting files")
        files = list(self.models_root.rglob("modeling_*.py"))
        logging.info(f"parsing {len(files)} files")

        identifiers = []
        sanitized_sources = []
        tokens_map = {}

        for file_path in tqdm(files, desc="parse", leave=False):
            model_hint = self._infer_model_from_relative_path(file_path)
            definitions_raw, definitions_sanitized, definitions_tokens = self._extract_definitions(
                file_path, self.models_root, model_hint
            )
            for identifier in definitions_sanitized.keys():
                identifiers.append(identifier)
                sanitized_sources.append(definitions_sanitized[identifier])
                tokens_map[identifier] = definitions_tokens[identifier]

        logging.info(
            f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})"
        )
        embeddings = self.encode(sanitized_sources)
        safetensors_save({"embeddings": embeddings}, EMBEDDINGS_PATH)
        with open(INDEX_MAP_PATH, "w", encoding="utf-8") as file:
            json.dump({int(i): identifiers[i] for i in range(len(identifiers))}, file)
        with open(TOKENS_PATH, "w", encoding="utf-8") as file:
            json.dump(tokens_map, file)

    def _topk_embedding(
        self,
        query_embedding_row: np.ndarray,
        base_embeddings: np.ndarray,
        identifier_map: dict[int, str],
        self_model_normalized: str,
        self_name: str,
        k: int,
    ) -> list[tuple[str, float]]:
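        """
        Find top-k most similar definitions using cosine similarity between embeddings.

        Args:
            query_embedding_row (`np.ndarray`): Normalized embedding of the query definition.
            base_embeddings (`np.ndarray`): Normalized embeddings of all indexed definitions.
            identifier_map (`dict[int, str]`): Mapping of row indices to definition identifiers.
            self_model_normalized (`str`): Normalized name of the query model to exclude.
            self_name (`str`): Name of the query definition to exclude.
            k (`int`): Number of top results to return.

        Returns:
            `list[tuple[str, float]]`: List of (identifier, score) tuples.
        """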
        similarities = query_embedding_row @ base_embeddings.T
        # over-fetch a few extra candidates so that filtering out same-name / same-model matches still leaves k results
        indices = np.argpartition(-similarities, k + 32)[: k + 32]
        indices = indices[np.argsort(-similarities[indices])]
        output = []
        for match_id in indices:
            identifier = identifier_map[int(match_id)]
            parent_relative_path, match_name = identifier.split(":", 1)
            parent_model = Path(parent_relative_path).parts[0]
            if match_name == self_name:
                continue
            if self_model_normalized and _normalize(parent_model) == self_model_normalized:
                continue
            output.append((f"{parent_model}::{match_name}", float(similarities[match_id])))
            if len(output) >= k:
                break
        return output

    def _topk_jaccard(
        self,
        query_tokens: set[str],
        identifiers: list[str],
        tokens_map: dict[str, list[str]],
        self_model_normalized: str,
        self_name: str,
        k: int,
    ) -> list[tuple[str, float]]:
        """
        Find top-k most similar definitions using Jaccard similarity on token sets.

        Args:
            query_tokens (`set[str]`): Set of tokens from the query definition.
            identifiers (`list[str]`): List of all definition identifiers in the index.
            tokens_map (`dict[str, list[str]]`): Mapping of identifiers to their token lists.
            self_model_normalized (`str`): Normalized name of the query model to exclude.
            self_name (`str`): Name of the query definition to exclude.
            k (`int`): Number of top results to return.

        Returns:
            `list[tuple[str, float]]`: List of (identifier, score) tuples.
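
        Example:
            With `query_tokens = {"hidden", "dense", "forward"}` and an index entry whose tokens are
            `{"hidden", "dense", "dropout"}`, the score is `2 / 4 = 0.5` (intersection over union).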
        """
        scores = []
        for identifier in identifiers:
            parent_relative_path, match_name = identifier.split(":", 1)
            parent_model = Path(parent_relative_path).parts[0]
            if match_name == self_name:
                continue
            if self_model_normalized and _normalize(parent_model) == self_model_normalized:
                continue
            tokens = set(tokens_map.get(identifier, []))
            if not tokens or not query_tokens:
                continue
            score = len(query_tokens & tokens) / len(query_tokens | tokens)
            if score > 0:
                scores.append((f"{parent_model}::{match_name}", score))
        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:k]

    def analyze_file(
        self, modeling_file: Path, top_k_per_item: int = 5, allow_hub_fallback: bool = True
    ) -> dict[str, dict[str, list]]:
        """
        Analyze a modeling file and find similar code definitions in the index.

        Args:
            modeling_file (`Path`): Path to the modeling file to analyze.
            top_k_per_item (`int`, *optional*, defaults to 5): Number of top matches to return per definition.
            allow_hub_fallback (`bool`, *optional*, defaults to `True`): Whether to download index from Hub if not found locally.

        Returns:
            `dict[str, dict[str, list]]`: Dictionary mapping definition names to their similarity results.
            Each result contains 'embedding', 'jaccard', and 'intersection' keys.
        """
        if allow_hub_fallback:
            self.ensure_local_index()

        base = safetensors_load(EMBEDDINGS_PATH)
        base_embeddings = base["embeddings"]
        with open(INDEX_MAP_PATH, "r", encoding="utf-8") as file:
            identifier_map = {int(key): value for key, value in json.load(file).items()}
        identifiers = [identifier_map[i] for i in range(len(identifier_map))]
        with open(TOKENS_PATH, "r", encoding="utf-8") as file:
            tokens_map = json.load(file)

        self_model = self._infer_query_model_name(modeling_file)
        definitions_raw, definitions_sanitized, _ = self._extract_definitions(modeling_file, None, self_model)
        query_identifiers = list(definitions_raw.keys())
        query_sources_sanitized = [definitions_sanitized[key] for key in query_identifiers]
        query_tokens_list = [set(_tokenize(source)) for source in query_sources_sanitized]
        self_model_normalized = _normalize(self_model)

        logging.info(
            f"encoding {len(query_sources_sanitized)} query definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})"
        )
        query_embeddings = self.encode(query_sources_sanitized)

        output = {}
        for i, query_identifier in enumerate(query_identifiers):
            query_name = query_identifier.split(":")[-1]
            embedding_top = self._topk_embedding(
                query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item
            )
            jaccard_top = self._topk_jaccard(
                query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item
            )
            embedding_set = {identifier for identifier, _ in embedding_top}
            jaccard_set = {identifier for identifier, _ in jaccard_top}
            intersection = list(embedding_set & jaccard_set)
            output[query_name] = {"embedding": embedding_top, "jaccard": jaccard_top, "intersection": intersection}
        return output


def main():
    """CLI entry point for the modular model detector."""
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    parser = argparse.ArgumentParser(prog="hf-code-sim")
    parser.add_argument("--build", action="store_true")
    parser.add_argument("--modeling-file", type=str)
    parser.add_argument(
        "--push-new-index", action="store_true", help="After --build, push index files to a Hub dataset."
    )
    parser.add_argument(
        "--hub-dataset", type=str, default=HUB_DATASET_DEFAULT, help="Hub dataset repo id to pull/push the index."
    )
    args = parser.parse_args()

    analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset)

    if args.build:
        analyzer.build_index()
        if args.push_new_index:
            analyzer.push_index_to_hub()
        return

    if not args.modeling_file:
        raise SystemExit("Provide --modeling-file or use --build")

    results = analyzer.analyze_file(Path(args.modeling_file), top_k_per_item=5, allow_hub_fallback=True)
    modeling_filename = Path(args.modeling_file).name
    for query_name, data in results.items():
        logging.info(f"{modeling_filename}::{query_name}:")
        logging.info(" embedding:")
        for identifier, score in data["embedding"]:
            logging.info(f" {identifier} ({score:.4f})")
        logging.info(" jaccard:")
        for identifier, score in data["jaccard"]:
            logging.info(f" {identifier} ({score:.4f})")
        logging.info(" intersection:")
        for identifier in data["intersection"]:
            logging.info(f" {identifier}")
        logging.info("")


if __name__ == "__main__":
    main()