Skip to content

Commit 25571b3

Browse files
committed
Fix transformers bug
1 parent 60b6140 commit 25571b3

2 files changed

Lines changed: 24 additions & 28 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ Thumbs.db
3232
wandb/
3333
runs/
3434
checkpoints/
35+
output/

scripts/evaluate_baselines.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -236,48 +236,43 @@ def _load_zilliz():
236236
from transformers import AutoModel
237237

238238
model_name = "zilliz/semantic-highlight-bilingual-v1"
239-
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16)
239+
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, dtype=torch.float16)
240240
device = "cuda" if torch.cuda.is_available() else "cpu"
241241
model = model.to(device)
242242
model.eval()
243243
return model
244244

245245

246246
def baseline_zilliz(model, task: str, tool_output: str, threshold: float = 0.5) -> list[str]:
247-
"""Run Zilliz semantic-highlight via its process() API.
247+
"""Run Zilliz semantic-highlight via get_raw_predictions().
248248
249-
The model takes (question, context) and returns highlighted sentences
250-
with per-sentence probabilities.
249+
Uses the low-level API to avoid the broken process() path in
250+
transformers 5.2 (build_inputs_with_special_tokens removed).
251+
Each line is passed as a separate context, and per-token pruning
252+
probabilities are averaged per line.
251253
"""
254+
import torch
255+
252256
lines = tool_output.split("\n")
253-
if not lines:
257+
non_empty = [line for line in lines if line.strip()]
258+
if not non_empty:
254259
return []
255260

256-
result = model.process(
257-
question=task,
258-
context=tool_output,
259-
threshold=threshold,
260-
return_sentence_metrics=True,
261-
show_progress=False,
262-
)
261+
with torch.no_grad():
262+
raw = model.get_raw_predictions(query=task, contexts=non_empty)
263263

264-
highlighted = result.get("highlighted_sentences", [])
265-
if not highlighted:
266-
return []
267-
268-
# Map highlighted sentences back to original lines
269-
# (Zilliz works at sentence level, we need line level)
270-
highlighted_set = set(s.strip() for s in highlighted)
271264
kept = []
272-
for line in lines:
273-
stripped = line.strip()
274-
if not stripped:
275-
continue
276-
# Check if this line is contained in any highlighted sentence
277-
for hs in highlighted_set:
278-
if stripped in hs or hs in stripped:
279-
kept.append(line)
280-
break
265+
for i, line in enumerate(non_empty):
266+
if i >= len(raw.context_ranges):
267+
break
268+
start, end = raw.context_ranges[i]
269+
segment = raw.pruning_probs[start:end]
270+
if segment.size > 0:
271+
score = float(segment.mean())
272+
else:
273+
score = 0.0
274+
if score >= threshold:
275+
kept.append(line)
281276

282277
return kept
283278

0 commit comments

Comments (0)