@@ -236,48 +236,43 @@ def _load_zilliz():
236236 from transformers import AutoModel
237237
238238 model_name = "zilliz/semantic-highlight-bilingual-v1"
239- model = AutoModel .from_pretrained (model_name , trust_remote_code = True , torch_dtype = torch .float16 )
239+ model = AutoModel .from_pretrained (model_name , trust_remote_code = True , dtype = torch .float16 )
240240 device = "cuda" if torch .cuda .is_available () else "cpu"
241241 model = model .to (device )
242242 model .eval ()
243243 return model
244244
245245
def baseline_zilliz(model, task: str, tool_output: str, threshold: float = 0.5) -> list[str]:
    """Run Zilliz semantic-highlight via get_raw_predictions().

    Uses the low-level API to avoid the broken process() path in
    transformers 5.2 (build_inputs_with_special_tokens removed).
    Each non-blank line of ``tool_output`` is passed as a separate context,
    and the per-token pruning probabilities are averaged per line.

    Args:
        model: Object exposing ``get_raw_predictions(query=..., contexts=...)``.
            The result must provide ``context_ranges`` (one ``(start, end)``
            pair per context) and ``pruning_probs`` (sliceable by those pairs).
        task: Query text the lines are scored against.
        tool_output: Newline-separated text to filter.
        threshold: Minimum mean probability a line needs in order to be kept.

    Returns:
        The non-blank lines (unstripped, in original order) whose mean
        probability is at least ``threshold``. Returns an empty list, without
        calling the model, when every line is blank.
    """
    import torch

    # Blank / whitespace-only lines carry no content to score, so they are
    # dropped before the model is called.
    non_empty = [line for line in tool_output.split("\n") if line.strip()]
    if not non_empty:
        return []

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        raw = model.get_raw_predictions(query=task, contexts=non_empty)

    kept = []
    # zip() stops at the shorter sequence, so lines for which the model
    # returned no context range are left out rather than raising.
    for line, (start, end) in zip(non_empty, raw.context_ranges):
        segment = raw.pruning_probs[start:end]
        # NOTE(review): assumes pruning_probs is a NumPy-style array exposing
        # an integer ``.size`` and ``.mean()`` -- confirm against the model code.
        # An empty range scores 0.0 instead of taking the mean of nothing.
        score = float(segment.mean()) if segment.size > 0 else 0.0
        if score >= threshold:
            kept.append(line)

    return kept
283278
0 commit comments