diff --git a/code_tests/integration_tests/test_ai_models/test_exa_model.py b/code_tests/integration_tests/test_ai_models/test_exa_model.py index b0f201dc..58e5ea67 100644 --- a/code_tests/integration_tests/test_ai_models/test_exa_model.py +++ b/code_tests/integration_tests/test_ai_models/test_exa_model.py @@ -70,6 +70,35 @@ async def test_invoke_for_highlights_in_relevance_order(mocker: Mock) -> None: ), f"Highlights not in descending order at index {i}" +async def test_invoke_for_highlights_with_missing_scores(mocker: Mock) -> None: + mock_return_value = [ + ExaSource( + original_query="test query", + auto_prompt_string=None, + title="Test Title", + url="https://example.com", + text="Test text", + author=None, + published_date=None, + score=0.9, + highlights=["Highlight A", "Highlight B"], + highlight_scores=[], + ), + ] + AiModelMockManager.mock_ai_model_direct_call_with_value( + mocker, ExaSearcher, mock_return_value + ) + + searcher = ExaSearcher() + cheap_input = searcher._get_cheap_input_for_invoke() + result = await searcher.invoke_for_highlights_in_relevance_order(cheap_input) + + assert len(result) == 2 + for quote in result: + assert isinstance(quote, ExaHighlightQuote) + assert quote.score == 1.0 + + async def test_general_invoke() -> None: num_results = 2 model = ExaSearcher( diff --git a/forecasting_tools/ai_models/exa_searcher.py b/forecasting_tools/ai_models/exa_searcher.py index 057eab4c..d696a2d5 100644 --- a/forecasting_tools/ai_models/exa_searcher.py +++ b/forecasting_tools/ai_models/exa_searcher.py @@ -144,7 +144,8 @@ async def invoke_for_highlights_in_relevance_order( sources = await self.invoke(search_query_or_strategy) all_highlights = [] for source in sources: - for highlight, score in zip(source.highlights, source.highlight_scores): + scores = source.highlight_scores or [1.0] * len(source.highlights) + for highlight, score in zip(source.highlights, scores): all_highlights.append( ExaHighlightQuote( highlight_text=highlight, score=score, source=source