# Requires: pip install sentence-transformers
from sentence_transformers import CrossEncoder
from typing import List, Dict
class CrossEncoderReranker:
    """Re-rank retrieved candidates by scoring (query, document) pairs with a cross-encoder."""

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """
        Load the cross-encoder used for scoring.

        The default is a lightweight cross-encoder (CPU-capable). Swap for larger
        models if you have GPU budget, e.g., "BAAI/bge-reranker-large" for higher
        precision at higher latency.

        Args:
            model_name: Model identifier passed straight to ``CrossEncoder``.
        """
        self.model = CrossEncoder(model_name)

    def rerank(
        self,
        query: str,
        candidates: List[Dict],  # [{document: str, score: float, rank: int}, ...]
        top_k: int = 5
    ) -> List[Dict]:
        """
        Score (query, document) pairs and return the top_k by cross-encoder score.

        Args:
            query: Query text paired with every candidate document.
            candidates: Candidate dicts; each must contain a ``"document"`` key.
                The input dicts are not mutated (each one is shallow-copied).
            top_k: Maximum number of results to return. Values <= 0 yield ``[]``.

        Returns:
            Up to ``top_k`` shallow copies of the candidates, ordered by descending
            ``"rerank_score"``, each with ``"rerank_score"`` (float) and a 1-based
            ``"rerank_rank"`` added.

        Raises:
            KeyError: If a candidate has no ``"document"`` key.
        """
        # Nothing to score or nothing requested: skip inference entirely. This
        # also keeps an empty batch away from the model and stops a negative
        # top_k from being interpreted as a "drop from the end" slice.
        if not candidates or top_k <= 0:
            return []

        pairs = [(query, c["document"]) for c in candidates]
        ce_scores = self.model.predict(pairs).tolist()

        # Copy each candidate so the caller's dicts stay untouched.
        rescored = [
            {**c, "rerank_score": float(s)}
            for c, s in zip(candidates, ce_scores)
        ]
        # Stable sort: candidates with equal scores keep their incoming order.
        rescored.sort(key=lambda x: x["rerank_score"], reverse=True)

        top = rescored[:top_k]
        for position, item in enumerate(top, start=1):
            item["rerank_rank"] = position
        return top