-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbaselines.py
More file actions
129 lines (97 loc) · 3.76 KB
/
baselines.py
File metadata and controls
129 lines (97 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import time
from typing import Dict, List, Optional

from src.core.retrieval_engine import RetrievalEngine
from src.core.answer_generation import AnswerGenerator
from src.llm.llm_client import llm_client
class BaselineRunner:
    """Runs baseline methods for comparison.

    Three baselines are exposed: an LLM-only run with no retrieval, a static
    RAG run that always retrieves, and a ReAct-style run with an
    always-retrieve policy. Every baseline returns a dictionary with the same
    keys: ``answer``, ``evidence``, ``retrieval_calls``, ``latency`` and
    ``method``.
    """

    def __init__(self, documents: Optional[List[str]] = None):
        """
        Initialize baseline runner.

        Args:
            documents: Documents for retrieval. If None or empty, no
                retrieval engine is created and only the LLM-only baseline
                can be run.
        """
        # Truthiness is intentional here: an empty list is treated like None,
        # so RetrievalEngine is never constructed over zero documents.
        self.retrieval_engine = RetrievalEngine(documents) if documents else None
        self.answer_generator = AnswerGenerator()

    def _retrieve_then_answer(self, query: str, method: str, label: str) -> Dict:
        """
        Retrieve once for the query, then generate an answer from the result.

        Shared implementation of the retrieval-based baselines.

        Args:
            query: User query
            method: Value stored under the "method" key of the result
            label: Baseline name used in the error message

        Returns:
            Dictionary with answer and metadata

        Raises:
            ValueError: If no retrieval engine is configured
        """
        # Compare against None explicitly so that an engine object which
        # happens to be falsy (e.g. defines __len__/__bool__) is still used.
        if self.retrieval_engine is None:
            raise ValueError(f"Retrieval engine required for {label} baseline")

        # perf_counter is monotonic, so the measured interval cannot be
        # skewed (or go negative) when the system clock is adjusted.
        started = time.perf_counter()
        passages = self.retrieval_engine.retrieve(query)
        answer = self.answer_generator.generate_answer(query, context=passages)
        latency = time.perf_counter() - started

        return {
            "answer": answer,
            "evidence": passages,
            "retrieval_calls": 1,
            "latency": latency,
            "method": method,
        }

    def run_llm_only(self, query: str) -> Dict:
        """
        Run LLM-only baseline (no retrieval).

        Args:
            query: User query

        Returns:
            Dictionary with answer and metadata; "evidence" is an empty list
            and "retrieval_calls" is 0
        """
        started = time.perf_counter()
        answer = self.answer_generator.generate_answer(query, context=None)
        latency = time.perf_counter() - started

        return {
            "answer": answer,
            "evidence": [],
            "retrieval_calls": 0,
            "latency": latency,
            "method": "llm_only",
        }

    def run_static_rag(self, query: str) -> Dict:
        """
        Run static RAG baseline (always retrieve).

        Args:
            query: User query

        Returns:
            Dictionary with answer and metadata

        Raises:
            ValueError: If no retrieval engine is configured
        """
        return self._retrieve_then_answer(query, "static_rag", "static RAG")

    def run_react_always_retrieve(self, query: str) -> Dict:
        """
        Run ReAct baseline with always-retrieve policy.

        Args:
            query: User query

        Returns:
            Dictionary with answer and metadata

        Raises:
            ValueError: If no retrieval engine is configured
        """
        # Simple ReAct loop: a single retrieve step followed by generation,
        # so exactly one retrieval call is reported.
        return self._retrieve_then_answer(query, "react_always_retrieve", "ReAct")

    def run_all_baselines(self, query: str) -> Dict[str, Dict]:
        """
        Run all baseline methods.

        Args:
            query: User query

        Returns:
            Dictionary mapping baseline name to results. The retrieval-based
            baselines are included only when a retrieval engine is configured.
        """
        # LLM-only (always available)
        results = {"llm_only": self.run_llm_only(query)}

        # Retrieval-based baselines (if retrieval engine available)
        if self.retrieval_engine is not None:
            results["static_rag"] = self.run_static_rag(query)
            results["react_always_retrieve"] = self.run_react_always_retrieve(query)

        return results