diff --git a/README.md b/README.md
index a794ea4..a563260 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,7 @@ QueryGym implements the following query reformulation methods:
| **MuGI** | Multi-granularity information expansion with adaptive concatenation | [Zhang et al., 2024](https://arxiv.org/abs/2401.06311) |
| **LameR** | Context-based passage synthesis using retrieved documents | [Mackie et al., 2023](https://arxiv.org/abs/2304.14233) |
| **CSQE** | Context-based sentence-level query expansion (KEQE + CSQE) | [Lee et al., 2024](https://arxiv.org/abs/2402.18031) |
+| **ThinkQE** | Multi-round reasoning-based query expansion with corpus feedback | [Lei et al., 2025](https://arxiv.org/abs/2506.09260) |
| **Query2E** | Query to entity/keyword expansion | [Jagerman et al., 2023](https://arxiv.org/abs/2305.03653)|
For detailed usage and parameters, see the [Methods Reference](https://querygym.readthedocs.io/en/latest/user-guide/methods-reference/).
diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md
index 9991c3f..ba9ea30 100644
--- a/docs/getting-started/quickstart.md
+++ b/docs/getting-started/quickstart.md
@@ -48,6 +48,7 @@ Available methods:
- `lamer` - Context-based passage synthesis
- `query2e` - Query to entity expansion
- `csqe` - Context-based sentence extraction
+- `thinkqe` - Multi-round reasoning-based query expansion
### 3. Reformulate Queries
@@ -137,7 +138,7 @@ See [Loading Datasets](../user-guide/datasets.md) for more details.
## Context-Based Reformulation
-Some methods (like `lamer`, `csqe`) use retrieved contexts:
+Some methods (like `lamer`, `csqe`, `thinkqe`) use retrieved contexts:
```python
import querygym as qg
diff --git a/docs/user-guide/methods-reference.md b/docs/user-guide/methods-reference.md
index 1aaf00a..7a40b36 100644
--- a/docs/user-guide/methods-reference.md
+++ b/docs/user-guide/methods-reference.md
@@ -15,6 +15,7 @@ Complete reference guide for all query reformulation methods in QueryGym, includ
- [LameR](#lamer)
- [Query2E](#query2e)
- [CSQE](#csqe)
+ - [ThinkQE](#thinkqe)
---
@@ -561,6 +562,68 @@ result = reformulator.reformulate(qg.QueryItem("q1", "quantum computing"))
---
+### ThinkQE
+
+**Method Name:** `"thinkqe"`
+**Requires Context:** Yes
+**Description:** Multi-round query expansion with retrieved passage feedback. Each round uses the original query plus newly retrieved passages to generate pseudo-passages, appends them to the retrieval query, and retrieves again.
+
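As a rough illustration, the round structure described above can be sketched as follows (a minimal sketch with assumed names — `search` and `generate` stand in for the retrieval and LLM calls; this is not the actual QueryGym implementation):

```python
def thinkqe_loop(query, search, generate, num_interaction=3, keep_passage_num=5):
    """Hypothetical sketch of a ThinkQE-style expansion loop."""
    expansions = []
    retrieval_query = query
    for _ in range(num_interaction):
        passages = search(retrieval_query)[:keep_passage_num]  # corpus feedback
        expansions += generate(query, passages)                # pseudo-passages
        retrieval_query = "\n".join([query] + expansions)      # expand, then retrieve again
    return expansions
```

Each round folds fresh corpus feedback into the next retrieval query, which is what distinguishes this loop from single-shot expansion methods.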
+#### Parameters
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `keep_passage_num` | int | `5` | Number of retrieved passages kept for prompting |
+| `gen_num` | int | `2` | Number of expansions generated per round |
+| `num_interaction` | int | `3` | Number of expansion rounds after baseline retrieval |
+| `accumulate` | bool | `True` | Accumulate all previous expansions into later rounds |
+| `use_passage_filter` | bool | `True` | Skip passages that were already retrieved two rounds earlier |
+| `repeat_weight` | float | `3` | Divisor controlling how often the original query is repeated in the final concatenation |
+| `search_k` | int | `keep_passage_num` | Retrieval depth for each round before filtering; use `1000` to mirror the original paper's runs |
+| `max_demo_len` | int | `None` | Optional word-count truncation length for each passage |
+| `no_thinking` | bool | `False` | Prefill a closing `</think>` tag to disable reasoning traces |
+| `searcher` | object | `None` | Pre-configured searcher instance (recommended) |
+| `searcher_type` | str | `"pyserini"` | Type of searcher to create |
+| `searcher_kwargs` | dict | `{}` | Keyword arguments for searcher initialization |
+| `index` | str | `None` | Pyserini index name (legacy format) |
+| `temperature` | float | `0.7` | Sampling temperature (via `llm_config`) |
+| `max_tokens` | int | `32768` | Maximum tokens per generation (via `llm_config`) |
+
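The `use_passage_filter` behavior can be pictured with a minimal sketch (assumed logic and names, not the actual QueryGym code): passages retrieved two rounds earlier are blacklisted so each round contributes fresh feedback.

```python
def filter_passages(current_hits, round_history, keep_passage_num=5):
    """Drop docids already seen two rounds ago, then keep the top-k remaining.

    current_hits: list of (docid, text) pairs, ranked by score.
    round_history: list of docid lists, one per previous round.
    """
    blacklist = set(round_history[-2]) if len(round_history) >= 2 else set()
    kept = [(docid, text) for docid, text in current_hits if docid not in blacklist]
    return kept[:keep_passage_num]
```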
+#### Usage Example
+
+```python
+import querygym as qg
+from pyserini.search.lucene import LuceneSearcher
+
+pyserini_searcher = LuceneSearcher.from_prebuilt_index("msmarco-v1-passage")
+searcher = qg.wrap_pyserini_searcher(pyserini_searcher, answer_key="contents")
+
+reformulator = qg.create_reformulator(
+ "thinkqe",
+ model="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+ params={
+ "searcher": searcher,
+ "keep_passage_num": 5,
+ "gen_num": 2,
+ "num_interaction": 3,
+ "accumulate": True,
+ "use_passage_filter": True,
+ "repeat_weight": 3,
+ "search_k": 1000,
+ "max_demo_len": 128,
+ },
+ llm_config={"temperature": 0.7, "max_tokens": 32768}
+)
+
+result = reformulator.reformulate_batch([qg.QueryItem("q1", "quantum computing")])[0]
+```
+
+#### Output Format
+
+- **Concatenation:** `(query × adaptive_repeat) + expansion_1 + expansion_2 + ...`, joined with newlines
+- **Metadata:** Includes `round_history`, `gen_num`, `keep_passage_num`, `accumulated_count`, `q_repeat`, and per-round raw response counts
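The concatenation rule can be illustrated with a small sketch (hypothetical helper — the actual `q_repeat` computation in QueryGym may differ): the query is repeated adaptively, scaled down by `repeat_weight`, so long expansions do not drown it out.

```python
def concat_expansions(query, expansions, repeat_weight=3):
    """Repeat the query proportionally to the total expansion length, then join."""
    total_words = sum(len(e.split()) for e in expansions)
    q_repeat = max(1, round(total_words / (len(query.split()) * repeat_weight)))
    return "\n".join([query] * q_repeat + expansions)
```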
+
+---
+
## Quick Reference Table
| Method | Requires Context | Key Parameters | Default LLM Config |
@@ -573,12 +636,13 @@ result = reformulator.reformulate(qg.QueryItem("q1", "quantum computing"))
| **LameR** | Yes | `retrieval_k`, `gen_passages`, `searcher` | temp=1.0, max_tokens=128 |
| **Query2E** | No | `mode`, `max_keywords`, `num_examples` (fs) | temp=0.3, max_tokens=256 |
| **CSQE** | Yes | `retrieval_k`, `gen_num`, `searcher` | temp=1.0, max_tokens=1024 |
+| **ThinkQE** | Yes | `keep_passage_num`, `gen_num`, `num_interaction`, `searcher` | temp=0.7, max_tokens=32768 |
---
## Tips and Best Practices
-1. **Context-Based Methods (LameR, CSQE):**
+1. **Context-Based Methods (LameR, CSQE, ThinkQE):**
- Always provide a `searcher` instance or configure `searcher_type`/`searcher_kwargs`
- Use `qg.wrap_pyserini_searcher()` for easy integration with Pyserini
- Set appropriate `retrieval_k` based on your needs (default: 10)
@@ -604,4 +668,3 @@ result = reformulator.reformulate(qg.QueryItem("q1", "quantum computing"))
- [API Reference](../api/methods.md) - Technical API documentation
- [Query Reformulation Guide](reformulation.md) - Usage tutorials
- [Examples](https://github.com/ls3-lab/QueryGym/tree/main/examples) - Complete workflow examples
-
diff --git a/docs/user-guide/reformulation.md b/docs/user-guide/reformulation.md
index 3f98cdc..bcfaf31 100644
--- a/docs/user-guide/reformulation.md
+++ b/docs/user-guide/reformulation.md
@@ -93,6 +93,26 @@ reformulator = qg.create_reformulator("csqe", model="gpt-4")
results = reformulator.reformulate_batch(queries, contexts=contexts)
```
+### ThinkQE
+
+Multi-round reasoning-based expansion with iterative corpus feedback.
+
+```python
+reformulator = qg.create_reformulator(
+ "thinkqe",
+ model="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+ params={
+ "searcher": searcher,
+ "num_interaction": 3,
+ "keep_passage_num": 5,
+ "gen_num": 2,
+ "accumulate": True,
+ "use_passage_filter": True,
+ "search_k": 1000,
+ },
+)
+```
+
## Method Comparison
| Method | Requires Context | Type | Best For |
@@ -105,6 +125,7 @@ results = reformulator.reformulate_batch(queries, contexts=contexts)
| lamer | Yes | Context synthesis | Re-ranking |
| query2e | No | Entity expansion | Entity queries |
| csqe | Yes | Sentence extraction | Precision-focused |
+| thinkqe | Yes | Iterative reasoning | Multi-round feedback |
## Custom Parameters
diff --git a/examples/querygym_pyserini/README.md b/examples/querygym_pyserini/README.md
index 54b7f7a..48b006b 100644
--- a/examples/querygym_pyserini/README.md
+++ b/examples/querygym_pyserini/README.md
@@ -127,7 +127,7 @@ python examples/querygym_pyserini/reformulate_queries.py \
```
**Options:**
-- `--method`: QueryGym method (genqr, genqr_ensemble, query2doc, qa_expand, mugi, lamer, query2e, csqe)
+- `--method`: QueryGym method (genqr, genqr_ensemble, query2doc, qa_expand, mugi, lamer, query2e, csqe, thinkqe)
- `--model`: LLM model name (e.g., qwen2.5:7b, llama3.1:8b, gpt-4, etc.)
- `--base-url`: LLM API endpoint (e.g., http://localhost:11434/v1)
- `--api-key`: LLM API key
@@ -525,4 +525,3 @@ python examples/querygym_pyserini/pipeline.py \
```
**Note:** The config file (`reformulation_config.yaml`) contains pre-configured settings for all methods, including complex method parameters. CLI arguments override config file values.
-
diff --git a/examples/querygym_pyserini/reformulate_queries.py b/examples/querygym_pyserini/reformulate_queries.py
index fa92bcf..f74d7eb 100755
--- a/examples/querygym_pyserini/reformulate_queries.py
+++ b/examples/querygym_pyserini/reformulate_queries.py
@@ -300,7 +300,7 @@ def main():
parser.add_argument(
'--method',
type=str,
- help='QueryGym reformulation method (genqr, genqr_ensemble, query2doc, etc.)'
+ help='QueryGym reformulation method (genqr, genqr_ensemble, query2doc, lamer, csqe, thinkqe, etc.)'
)
parser.add_argument(
'--model',
@@ -485,4 +485,3 @@ def main():
if __name__ == '__main__':
main()
-
diff --git a/examples/querygym_pyserini/reformulation_config.yaml b/examples/querygym_pyserini/reformulation_config.yaml
index e7e9770..dfa6019 100644
--- a/examples/querygym_pyserini/reformulation_config.yaml
+++ b/examples/querygym_pyserini/reformulation_config.yaml
@@ -118,6 +118,24 @@ methods:
gen_num: 2 # Number of expansions for both KEQE and CSQE (default: 2)
# Note: Searcher is automatically configured from dataset registry
+ # ThinkQE: Multi-round reasoning-based query expansion (requires retrieval)
+ thinkqe:
+ enabled: true
+ model: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
+ llm:
+ temperature: 0.7
+ max_tokens: 32768
+ params:
+ keep_passage_num: 5
+ gen_num: 2
+ num_interaction: 3
+ accumulate: true
+ use_passage_filter: true
+ repeat_weight: 3
+ max_demo_len: 128
+ search_k: 1000
+ # Note: Searcher is automatically configured from dataset registry
+
# Example: Multiple configurations for the same method
# You can define multiple variants with different parameters:
#
@@ -136,4 +154,3 @@ methods:
# temperature: 0.9
# params:
# n_generations: 7
-
diff --git a/querygym/__init__.py b/querygym/__init__.py
index fe1c2b2..f31fdd2 100644
--- a/querygym/__init__.py
+++ b/querygym/__init__.py
@@ -3,10 +3,10 @@
Simple usage:
import querygym as qg
-
+
# Create a reformulator
reformulator = qg.create_reformulator("genqr_ensemble")
-
+
# Reformulate queries
result = reformulator.reformulate(qg.QueryItem("q1", "what causes diabetes"))
"""
@@ -24,7 +24,11 @@
from .core.searcher import BaseSearcher, SearchHit, SearcherRegistry, create_searcher
# Searcher wrappers for user convenience
-from .core.searcher_wrappers import wrap_pyserini_searcher, wrap_pyterrier_retriever, wrap_custom_searcher
+from .core.searcher_wrappers import (
+ wrap_pyserini_searcher,
+ wrap_pyterrier_retriever,
+ wrap_custom_searcher,
+)
# Data loaders
from .data.dataloader import DataLoader, UnifiedQuerySource
@@ -42,6 +46,7 @@
LameR,
Query2E,
CSQE,
+ ThinkQE,
)
# High-level runner
@@ -52,6 +57,7 @@
# Import adapters to register them
from . import adapters
+
# Convenience factory function
def create_reformulator(
method_name: str,
@@ -59,11 +65,11 @@ def create_reformulator(
params: dict = None,
llm_config: dict = None,
prompt_bank_path: str = None,
- **kwargs
+ **kwargs,
):
"""
Create a reformulator instance with sensible defaults.
-
+
Args:
method_name: Name of the method (e.g., "genqr", "genqr_ensemble", "query2doc")
model: LLM model name (default: "gpt-4")
@@ -71,52 +77,52 @@ def create_reformulator(
llm_config: Additional LLM configuration (temperature, max_tokens, etc.)
prompt_bank_path: Path to prompt bank YAML (default: bundled prompt_bank.yaml)
**kwargs: Additional MethodConfig parameters (seed, retries)
-
+
Returns:
BaseReformulator instance
-
+
Example:
>>> import querygym as qg
>>> reformulator = qg.create_reformulator("genqr_ensemble", model="gpt-4")
>>> result = reformulator.reformulate(qg.QueryItem("q1", "what causes diabetes"))
"""
from pathlib import Path
-
+
if params is None:
params = {}
-
+
if llm_config is None:
llm_config = {}
-
+
# Set up LLM config
llm_cfg = {
"model": model,
"temperature": llm_config.get("temperature", 0.8),
"max_tokens": llm_config.get("max_tokens", 256),
- **{k: v for k, v in llm_config.items() if k not in ["temperature", "max_tokens"]}
+ **{k: v for k, v in llm_config.items() if k not in ["temperature", "max_tokens"]},
}
-
+
# Create config
config = MethodConfig(
name=method_name,
params=params,
llm=llm_cfg,
seed=kwargs.get("seed", 42),
- retries=kwargs.get("retries", 2)
+ retries=kwargs.get("retries", 2),
)
-
+
# Build LLM client
llm = build_llm(config)
-
+
# Load prompt bank
if prompt_bank_path is None:
prompt_bank_path = Path(__file__).parent / "prompt_bank.yaml"
pb = PromptBank(prompt_bank_path)
-
+
# Get method class and instantiate
if method_name not in METHODS:
raise ValueError(f"Unknown method: {method_name}. Available: {list(METHODS.keys())}")
-
+
MethodClass = METHODS[method_name]
return MethodClass(config, llm, pb)
@@ -124,15 +130,15 @@ def create_reformulator(
def load_queries(path: str, format: str = "tsv", **kwargs):
"""
Load queries from a local file.
-
+
Args:
path: Path to queries file
format: File format - "tsv" or "jsonl" (default: "tsv")
**kwargs: Additional parameters for DataLoader.load_queries()
-
+
Returns:
List of QueryItem objects
-
+
Example:
>>> queries = qg.load_queries("queries.tsv", format="tsv")
>>> queries = qg.load_queries("queries.jsonl", format="jsonl")
@@ -143,15 +149,15 @@ def load_queries(path: str, format: str = "tsv", **kwargs):
def load_qrels(path: str, format: str = "trec", **kwargs):
"""
Load qrels from a local file.
-
+
Args:
path: Path to qrels file
format: File format - "trec" (default: "trec")
**kwargs: Additional parameters for DataLoader.load_qrels()
-
+
Returns:
Dict mapping qid -> {docid -> relevance}
-
+
Example:
>>> qrels = qg.load_qrels("qrels.txt")
"""
@@ -161,14 +167,14 @@ def load_qrels(path: str, format: str = "trec", **kwargs):
def load_contexts(path: str, **kwargs):
"""
Load contexts from a JSONL file.
-
+
Args:
path: Path to contexts JSONL file
**kwargs: Additional parameters for DataLoader.load_contexts()
-
+
Returns:
Dict mapping qid -> list of context strings
-
+
Example:
>>> contexts = qg.load_contexts("contexts.jsonl")
"""
@@ -178,20 +184,20 @@ def load_contexts(path: str, **kwargs):
def load_examples(path: str, **kwargs):
"""
Load few-shot examples from a JSONL file.
-
+
Used for methods like Query2E that need (query, passage) pairs as demonstrations.
-
+
Args:
path: Path to examples JSONL file
**kwargs: Additional parameters for DataLoader.load_examples()
-
+
Returns:
List of dicts with 'query' and 'passage' keys
-
+
Example:
>>> examples = qg.load_examples("examples.jsonl")
>>> reformulator = qg.create_reformulator("query2e", params={"mode": "fs", "examples": examples})
-
+
JSONL format:
{"query": "how long is flea life cycle?", "passage": "The life cycle of a flea..."}
{"query": "cost of flooring?", "passage": "The cost of interior concrete..."}
@@ -202,33 +208,27 @@ def load_examples(path: str, **kwargs):
__all__ = [
# Version
"__version__",
-
# Core classes
"QueryItem",
"ReformulationResult",
"MethodConfig",
"BaseReformulator",
-
# LLM & Prompts
"OpenAICompatibleClient",
"PromptBank",
-
# Searcher interface
"BaseSearcher",
"SearchHit",
"SearcherRegistry",
"create_searcher",
-
# Searcher wrappers
"wrap_pyserini_searcher",
- "wrap_pyterrier_retriever",
+ "wrap_pyterrier_retriever",
"wrap_custom_searcher",
-
# Data
"DataLoader",
"UnifiedQuerySource", # Deprecated
"loaders",
-
# Methods
"GENQR",
"GenQREnsemble",
@@ -238,7 +238,7 @@ def load_examples(path: str, **kwargs):
"LameR",
"Query2E",
"CSQE",
-
+ "ThinkQE",
# High-level API
"run_method",
"build_llm",
@@ -247,10 +247,8 @@ def load_examples(path: str, **kwargs):
"load_qrels",
"load_contexts",
"load_examples",
-
# Retrieval
"Retriever",
-
# Registry
"METHODS",
"register_method",
diff --git a/querygym/__main__.py b/querygym/__main__.py
index 4a0fdaf..f4b201c 100644
--- a/querygym/__main__.py
+++ b/querygym/__main__.py
@@ -1,6 +1,6 @@
"""Entry point for running queryGym as a module: python -m queryGym"""
+
from .cli import app
if __name__ == "__main__":
app()
-
diff --git a/querygym/adapters/__init__.py b/querygym/adapters/__init__.py
index e9cf8bf..da4ea89 100644
--- a/querygym/adapters/__init__.py
+++ b/querygym/adapters/__init__.py
@@ -13,4 +13,3 @@
"PyseriniSearcher",
"PyTerrierSearcher",
]
-
diff --git a/querygym/adapters/pyserini_adapter.py b/querygym/adapters/pyserini_adapter.py
index 101c4b1..104bca1 100644
--- a/querygym/adapters/pyserini_adapter.py
+++ b/querygym/adapters/pyserini_adapter.py
@@ -15,20 +15,28 @@
class PyseriniSearcher(BaseSearcher):
"""
Pyserini adapter implementing the BaseSearcher interface.
-
+
This adapter wraps Pyserini's LuceneSearcher and LuceneImpactSearcher
to provide a standardized interface for querygym.
"""
-
- def __init__(self, index: str, searcher_type: str = "bm25",
- encoder: Optional[str] = None, min_idf: float = 0,
- k1: Optional[float] = None, b: Optional[float] = None,
- rm3: bool = False, rocchio: bool = False,
- rocchio_use_negative: bool = False,
- answer_key: str = "contents", **kwargs):
+
+ def __init__(
+ self,
+ index: str,
+ searcher_type: str = "bm25",
+ encoder: Optional[str] = None,
+ min_idf: float = 0,
+ k1: Optional[float] = None,
+ b: Optional[float] = None,
+ rm3: bool = False,
+ rocchio: bool = False,
+ rocchio_use_negative: bool = False,
+ answer_key: str = "contents",
+ **kwargs,
+ ):
"""
Initialize Pyserini searcher.
-
+
Args:
index: Path to Lucene index or prebuilt index name
searcher_type: Type of searcher ("bm25" or "impact")
@@ -45,8 +53,10 @@ def __init__(self, index: str, searcher_type: str = "bm25",
try:
from pyserini.search.lucene import LuceneSearcher, LuceneImpactSearcher
except ImportError:
- raise ImportError("Pyserini is required for PyseriniSearcher. Install with: pip install pyserini")
-
+ raise ImportError(
+ "Pyserini is required for PyseriniSearcher. Install with: pip install pyserini"
+ )
+
self.index = index
self.searcher_type = searcher_type
self.answer_key = answer_key
@@ -61,9 +71,9 @@ def __init__(self, index: str, searcher_type: str = "bm25",
"rm3": rm3,
"rocchio": rocchio,
"rocchio_use_negative": rocchio_use_negative,
- "answer_key": answer_key
+ "answer_key": answer_key,
}
-
+
# Initialize searcher based on type
if searcher_type == "impact":
if os.path.exists(index):
@@ -75,113 +85,108 @@ def __init__(self, index: str, searcher_type: str = "bm25",
self.searcher = LuceneSearcher(index)
else:
self.searcher = LuceneSearcher.from_prebuilt_index(index)
-
+
# Set BM25 parameters if specified
if k1 is not None and b is not None:
self.searcher.set_bm25(k1, b)
else:
# Auto-set based on known indices
self._set_auto_bm25_params()
-
+
# Add RM3 if requested
if rm3:
self.searcher.set_rm3()
-
+
# Add Rocchio if requested
if rocchio:
if rocchio_use_negative:
self.searcher.set_rocchio(gamma=0.15, use_negative=True)
else:
self.searcher.set_rocchio()
-
+
def _set_auto_bm25_params(self):
"""Set BM25 parameters based on index type."""
- if 'msmarco-passage' in self.index:
+ if "msmarco-passage" in self.index:
self.searcher.set_bm25(0.82, 0.68)
- elif 'msmarco-doc' in self.index:
+ elif "msmarco-doc" in self.index:
self.searcher.set_bm25(4.46, 0.82)
-
+
def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]:
"""Search for documents using a single query."""
hits = self.searcher.search(query, k)
return self._process_hits(hits)
-
- def batch_search(self, queries: List[str], k: int = 10,
- num_threads: int = 1, **kwargs) -> List[List[SearchHit]]:
+
+ def batch_search(
+ self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs
+ ) -> List[List[SearchHit]]:
"""Search for documents using multiple queries in batch."""
pseudo_batch_topic_ids = [str(idx) for idx, _ in enumerate(queries)]
- results = self.searcher.batch_search(
- queries, pseudo_batch_topic_ids, k, num_threads
- )
-
+ results = self.searcher.batch_search(queries, pseudo_batch_topic_ids, k, num_threads)
+
# Convert results to list in original order
batch_results = [results[id_] for id_ in pseudo_batch_topic_ids]
-
+
# Process each result
processed_results = []
for hits in batch_results:
processed_results.append(self._process_hits(hits))
-
+
return processed_results
-
+
def _process_hits(self, hits) -> List[SearchHit]:
"""Process hits and extract content."""
search_hits = []
-
+
for hit in hits:
try:
# Extract content using hit.lucene_document.get('raw')
- raw_content = hit.lucene_document.get('raw')
+ raw_content = hit.lucene_document.get("raw")
if raw_content:
raw_df = json.loads(raw_content)
text_list = [raw_df[k] for k in self.answer_key.split("|") if raw_df.get(k)]
content = "\t".join(text_list)
else:
# Fallback: try to get content from other fields
- content = getattr(hit, 'contents', '') or getattr(hit, 'text', '') or str(hit.docid)
-
+ content = (
+ getattr(hit, "contents", "") or getattr(hit, "text", "") or str(hit.docid)
+ )
+
search_hit = SearchHit(
docid=hit.docid,
score=hit.score,
content=content,
- metadata={
- "raw_content": raw_content,
- "answer_key": self.answer_key
- }
+ metadata={"raw_content": raw_content, "answer_key": self.answer_key},
)
search_hits.append(search_hit)
-
+
except (json.JSONDecodeError, KeyError, AttributeError) as e:
# Fallback for malformed or missing content
search_hit = SearchHit(
- docid=getattr(hit, 'docid', 'unknown'),
- score=getattr(hit, 'score', 0.0),
- content=str(getattr(hit, 'docid', 'unknown')),
- metadata={
- "error": str(e),
- "fallback": True
- }
+ docid=getattr(hit, "docid", "unknown"),
+ score=getattr(hit, "score", 0.0),
+ content=str(getattr(hit, "docid", "unknown")),
+ metadata={"error": str(e), "fallback": True},
)
search_hits.append(search_hit)
-
+
return search_hits
-
+
def get_searcher_info(self) -> Dict[str, Any]:
"""Get information about the searcher configuration."""
return self._searcher_info.copy()
-
+
def configure(self, **kwargs) -> None:
"""Configure searcher parameters."""
# Update searcher info with new configuration
self._searcher_info.update(kwargs)
-
+
# Apply configuration if possible
if "k1" in kwargs and "b" in kwargs:
self.searcher.set_bm25(kwargs["k1"], kwargs["b"])
-
+
if "rm3" in kwargs and kwargs["rm3"]:
self.searcher.set_rm3()
-
+
if "rocchio" in kwargs and kwargs["rocchio"]:
if kwargs.get("rocchio_use_negative", False):
self.searcher.set_rocchio(gamma=0.15, use_negative=True)
diff --git a/querygym/adapters/pyterrier_adapter.py b/querygym/adapters/pyterrier_adapter.py
index 5d4a7d6..6d86c17 100644
--- a/querygym/adapters/pyterrier_adapter.py
+++ b/querygym/adapters/pyterrier_adapter.py
@@ -14,17 +14,21 @@
class PyTerrierSearcher(BaseSearcher):
"""
PyTerrier adapter implementing the BaseSearcher interface.
-
+
This adapter wraps PyTerrier's search functionality to provide
a standardized interface for querygym.
"""
-
- def __init__(self, index_path: Optional[str] = None,
- index_name: Optional[str] = None,
- searcher_type: str = "bm25", **kwargs):
+
+ def __init__(
+ self,
+ index_path: Optional[str] = None,
+ index_name: Optional[str] = None,
+ searcher_type: str = "bm25",
+ **kwargs,
+ ):
"""
Initialize PyTerrier searcher.
-
+
Args:
index_path: Path to PyTerrier index
index_name: Name of prebuilt PyTerrier index
@@ -34,20 +38,22 @@ def __init__(self, index_path: Optional[str] = None,
try:
import pyterrier as pt
except ImportError:
- raise ImportError("PyTerrier is required for PyTerrierSearcher. Install with: pip install python-terrier")
-
+ raise ImportError(
+ "PyTerrier is required for PyTerrierSearcher. Install with: pip install python-terrier"
+ )
+
self.searcher_type = searcher_type
self._searcher_info = {
"name": "PyTerrierSearcher",
"type": searcher_type,
"index_path": index_path,
- "index_name": index_name
+ "index_name": index_name,
}
-
+
# Initialize PyTerrier if not already done
if not pt.started():
pt.init()
-
+
# Load index
if index_path:
self.index = pt.IndexRef.of(index_path)
@@ -59,14 +65,14 @@ def __init__(self, index_path: Optional[str] = None,
raise ValueError(f"Could not load prebuilt index '{index_name}': {e}")
else:
raise ValueError("Either index_path or index_name must be provided")
-
+
# Create searcher based on type
self.searcher = self._create_searcher(searcher_type, **kwargs)
-
+
def _create_searcher(self, searcher_type: str, **kwargs):
"""Create the appropriate searcher based on type."""
import pyterrier as pt
-
+
# Create a proper PyTerrier Retriever, not a pipeline component
if searcher_type == "bm25":
return pt.terrier.Retriever(self.index, wmodel="BM25", **kwargs)
@@ -81,94 +87,94 @@ def _create_searcher(self, searcher_type: str, **kwargs):
else:
# Default to BM25
return pt.terrier.Retriever(self.index, wmodel="BM25", **kwargs)
-
+
def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]:
"""Search for documents using a single query."""
# Use PyTerrier's search method which returns a DataFrame
search_results = self.searcher.search(query)
-
+
# Limit to top-k results
search_results = search_results.head(k)
-
+
return self._process_results(search_results)
-
- def batch_search(self, queries: List[str], k: int = 10,
- num_threads: int = 1, **kwargs) -> List[List[SearchHit]]:
+
+ def batch_search(
+ self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs
+ ) -> List[List[SearchHit]]:
"""Search for documents using multiple queries in batch."""
# Create query DataFrame for batch processing
- query_df = pd.DataFrame({
- 'qid': [str(i) for i in range(len(queries))],
- 'query': queries
- })
-
+ query_df = pd.DataFrame({"qid": [str(i) for i in range(len(queries))], "query": queries})
+
# Use PyTerrier's transform method for batch processing
search_results = self.searcher.transform(query_df)
-
+
# Group results by query ID
batch_results = []
for qid in range(len(queries)):
- query_results = search_results[search_results['qid'] == str(qid)].head(k)
+ query_results = search_results[search_results["qid"] == str(qid)].head(k)
batch_results.append(self._process_results(query_results))
-
+
return batch_results
-
+
def _process_results(self, results_df: pd.DataFrame) -> List[SearchHit]:
"""Process PyTerrier results into SearchHit objects."""
search_hits = []
-
+
if results_df.empty:
return search_hits
-
+
for _, row in results_df.iterrows():
try:
# Handle different possible column names
- docid = str(row.get('docno', row.get('docid', '')))
- score = float(row.get('score', 0.0))
-
+ docid = str(row.get("docno", row.get("docid", "")))
+ score = float(row.get("score", 0.0))
+
# Try different content field names
- content = (row.get('text', '') or
- row.get('content', '') or
- row.get('body', '') or
- str(docid))
-
+ content = (
+ row.get("text", "")
+ or row.get("content", "")
+ or row.get("body", "")
+ or str(docid)
+ )
+
search_hit = SearchHit(
docid=docid,
score=score,
content=str(content),
metadata={
- 'qid': str(row.get('qid', '')),
- 'docno': str(row.get('docno', '')),
- 'searcher_type': self.searcher_type,
- 'available_columns': list(row.index)
- }
+ "qid": str(row.get("qid", "")),
+ "docno": str(row.get("docno", "")),
+ "searcher_type": self.searcher_type,
+ "available_columns": list(row.index),
+ },
)
search_hits.append(search_hit)
-
+
except Exception as e:
# Fallback for malformed data
search_hit = SearchHit(
- docid=str(row.get('docno', 'unknown')),
+ docid=str(row.get("docno", "unknown")),
score=0.0,
- content=str(row.get('docno', 'unknown')),
+ content=str(row.get("docno", "unknown")),
metadata={
- 'error': str(e),
- 'fallback': True,
- 'searcher_type': self.searcher_type
- }
+ "error": str(e),
+ "fallback": True,
+ "searcher_type": self.searcher_type,
+ },
)
search_hits.append(search_hit)
-
+
return search_hits
-
+
def get_searcher_info(self) -> Dict[str, Any]:
"""Get information about the searcher configuration."""
return self._searcher_info.copy()
-
+
def configure(self, **kwargs) -> None:
"""Configure searcher parameters."""
# Update searcher info
self._searcher_info.update(kwargs)
-
+
# Recreate searcher with new parameters
self.searcher = self._create_searcher(self.searcher_type, **kwargs)
diff --git a/querygym/cli.py b/querygym/cli.py
index d9a47db..2cc98b8 100644
--- a/querygym/cli.py
+++ b/querygym/cli.py
@@ -7,10 +7,18 @@
from .core.runner import run_method
from .core.prompts import PromptBank
from .data.dataloader import UnifiedQuerySource
+
app = typer.Typer(help="querygym Toolkit CLI")
-def build_script_lines(index_path:str, topics:str, run:str, qrels:Optional[str]=None,
- bm25:bool=True, extra:str="") -> list[str]:
+
+def build_script_lines(
+ index_path: str,
+ topics: str,
+ run: str,
+ qrels: Optional[str] = None,
+ bm25: bool = True,
+ extra: str = "",
+) -> list[str]:
lines = [
"#!/usr/bin/env bash",
"set -euo pipefail",
@@ -28,53 +36,78 @@ def build_script_lines(index_path:str, topics:str, run:str, qrels:Optional[str]=
if qrels:
lines += [
"# trec_eval",
- f"trec_eval -m map -m P.10 -m ndcg_cut.10 {qrels} {run} | tee {run}.eval.txt"
+ f"trec_eval -m map -m P.10 -m ndcg_cut.10 {qrels} {run} | tee {run}.eval.txt",
]
return lines
+
@app.command()
-def run(method: str = typer.Option(...),
- queries_tsv: Path = typer.Option(..., "--queries-tsv", exists=True),
- output_tsv: Path = typer.Option(..., "--output-tsv"),
- cfg_path: Optional[Path] = typer.Option(None, "--cfg-path"),
- prompt_bank: Path = typer.Option(Path(__file__).with_name("prompt_bank.yaml"), "--prompt-bank"),
- ctx_jsonl: Optional[Path] = typer.Option(None, "--ctx-jsonl", help="Optional contexts JSONL"),
- output_format: str = typer.Option("both", "--output-format", help="Output format: 'concat', 'plain', or 'both' (default: both)"),
- parallel: Optional[bool] = typer.Option(None, "--parallel", help="Enable parallel generation for methods like MuGI"),
- mode: Optional[str] = typer.Option(None, "--mode", help="Mode for methods: 'zs' (zero-shot) or 'fs' (few-shot)"),
- # Few-shot data loading options
- dataset_type: Optional[str] = typer.Option(None, "--dataset-type", help="Dataset type: 'msmarco', 'beir', or 'generic'"),
- collection_path: Optional[Path] = typer.Option(None, "--collection-path", help="Path to collection file (TSV for MS MARCO/generic)"),
- train_queries_path: Optional[Path] = typer.Option(None, "--train-queries-path", help="Path to training queries file"),
- train_qrels_path: Optional[Path] = typer.Option(None, "--train-qrels-path", help="Path to training qrels file"),
- beir_data_dir: Optional[Path] = typer.Option(None, "--beir-data-dir", help="Path to BEIR dataset directory"),
- train_split: Optional[str] = typer.Option(None, "--train-split", help="BEIR train split: 'train' or 'dev'"),
- num_examples: Optional[int] = typer.Option(None, "--num-examples", help="Number of few-shot examples"),
+def run(
+ method: str = typer.Option(...),
+ queries_tsv: Path = typer.Option(..., "--queries-tsv", exists=True),
+ output_tsv: Path = typer.Option(..., "--output-tsv"),
+ cfg_path: Optional[Path] = typer.Option(None, "--cfg-path"),
+ prompt_bank: Path = typer.Option(Path(__file__).with_name("prompt_bank.yaml"), "--prompt-bank"),
+ ctx_jsonl: Optional[Path] = typer.Option(None, "--ctx-jsonl", help="Optional contexts JSONL"),
+ output_format: str = typer.Option(
+ "both",
+ "--output-format",
+ help="Output format: 'concat', 'plain', or 'both' (default: both)",
+ ),
+ parallel: Optional[bool] = typer.Option(
+ None, "--parallel", help="Enable parallel generation for methods like MuGI"
+ ),
+ mode: Optional[str] = typer.Option(
+ None, "--mode", help="Mode for methods: 'zs' (zero-shot) or 'fs' (few-shot)"
+ ),
+ # Few-shot data loading options
+ dataset_type: Optional[str] = typer.Option(
+ None, "--dataset-type", help="Dataset type: 'msmarco', 'beir', or 'generic'"
+ ),
+ collection_path: Optional[Path] = typer.Option(
+ None, "--collection-path", help="Path to collection file (TSV for MS MARCO/generic)"
+ ),
+ train_queries_path: Optional[Path] = typer.Option(
+ None, "--train-queries-path", help="Path to training queries file"
+ ),
+ train_qrels_path: Optional[Path] = typer.Option(
+ None, "--train-qrels-path", help="Path to training qrels file"
+ ),
+ beir_data_dir: Optional[Path] = typer.Option(
+ None, "--beir-data-dir", help="Path to BEIR dataset directory"
+ ),
+ train_split: Optional[str] = typer.Option(
+ None, "--train-split", help="BEIR train split: 'train' or 'dev'"
+ ),
+ num_examples: Optional[int] = typer.Option(
+ None, "--num-examples", help="Number of few-shot examples"
+ ),
):
import yaml
import os
import re
-
+
def expand_env_vars(text):
"""Expand environment variables in YAML content"""
+
def replace_env_var(match):
var_expr = match.group(1)
- if ':-' in var_expr:
- var_name, default_value = var_expr.split(':-', 1)
+ if ":-" in var_expr:
+ var_name, default_value = var_expr.split(":-", 1)
return os.getenv(var_name, default_value)
else:
- return os.getenv(var_expr, '')
-
- return re.sub(r'\$\{([^}]+)\}', replace_env_var, text)
-
+ return os.getenv(var_expr, "")
+
+ return re.sub(r"\$\{([^}]+)\}", replace_env_var, text)
+
# Default to defaults.yaml if no config path provided
if cfg_path is None:
cfg_path = Path(__file__).parent / "config" / "defaults.yaml"
-
+
yaml_content = cfg_path.read_text()
expanded_content = expand_env_vars(yaml_content)
cfg = yaml.safe_load(expanded_content)
-
+
# Override parameters if provided via CLI
params = cfg.get("params", {})
if parallel is not None:
@@ -95,7 +128,7 @@ def replace_env_var(match):
params["train_split"] = train_split
if num_examples is not None:
params["num_examples"] = num_examples
-
+
# Handle context/examples loading based on method
ctx_map = None
if ctx_jsonl:
@@ -104,6 +137,7 @@ def replace_env_var(match):
# Auto-set mode to "fs" if not explicitly set to "zs"
if mode not in ["zs", "zeroshot"]:
from .data.dataloader import DataLoader
+
examples = DataLoader.load_examples(ctx_jsonl)
params["examples"] = examples
# Auto-set mode to fs if not already set
@@ -116,33 +150,46 @@ def replace_env_var(match):
else:
# For other methods (LameR, CSQE): load as per-query contexts
from .data.dataloader import UnifiedContextSource
+
ctx_src = UnifiedContextSource(mode="file", path=ctx_jsonl)
- ctx_map = ctx_src.load(list(UnifiedQuerySource(backend="local", format="tsv", path=queries_tsv).iter()))
-
- mc = MethodConfig(name=method, params=params, llm=cfg["llm"],
- seed=cfg.get("seed",42), retries=cfg.get("retries",2))
+ ctx_map = ctx_src.load(
+ list(UnifiedQuerySource(backend="local", format="tsv", path=queries_tsv).iter())
+ )
+
+ mc = MethodConfig(
+ name=method,
+ params=params,
+ llm=cfg["llm"],
+ seed=cfg.get("seed", 42),
+ retries=cfg.get("retries", 2),
+ )
src = UnifiedQuerySource(backend="local", format="tsv", path=queries_tsv)
queries = list(src.iter())
-
+
# Pass ctx_map as-is (None if not provided, dict if provided)
- results = run_method(method_name=method, cfg=mc, queries=queries,
- prompt_bank_path=str(prompt_bank), ctx_map=ctx_map)
-
+ results = run_method(
+ method_name=method,
+ cfg=mc,
+ queries=queries,
+ prompt_bank_path=str(prompt_bank),
+ ctx_map=ctx_map,
+ )
+
# Show progress summary
typer.echo(f"Processed {len(results)} queries with {method}")
-
+
def write_concat_format(output_path):
"""Write concatenated format to file"""
with open(output_path, "w", encoding="utf-8") as f:
- w = csv.writer(f, delimiter="\t", quoting=csv.QUOTE_NONE, escapechar='\\')
+ w = csv.writer(f, delimiter="\t", quoting=csv.QUOTE_NONE, escapechar="\\")
for r in results:
w.writerow([r.qid, r.reformulated])
-
+
def write_plain_format(output_path):
"""Write plain format to file"""
with open(output_path, "w", encoding="utf-8") as f:
- w = csv.writer(f, delimiter="\t", quoting=csv.QUOTE_NONE, escapechar='\\')
-
+ w = csv.writer(f, delimiter="\t", quoting=csv.QUOTE_NONE, escapechar="\\")
+
# Write header row based on method
if method == "lamer":
# LameR: qid \t passage_1 \t passage_2 \t ... \t passage_n
@@ -157,11 +204,13 @@ def write_plain_format(output_path):
reformulated = results[0].reformulated
parts = reformulated.replace(original, "").strip().split()
num_passages = len(parts)
-
+
header = ["qid"] + [f"passage_{i+1}" for i in range(num_passages)]
w.writerow(header)
else:
- w.writerow(["qid", "passage_1", "passage_2", "passage_3", "passage_4", "passage_5"])
+ w.writerow(
+ ["qid", "passage_1", "passage_2", "passage_3", "passage_4", "passage_5"]
+ )
elif method == "query2doc":
# Query2Doc: qid \t passage
w.writerow(["qid", "passage"])
@@ -177,7 +226,16 @@ def write_plain_format(output_path):
header = ["qid"] + [f"reformulation_{i+1}" for i in range(num_reformulations)]
w.writerow(header)
else:
- w.writerow(["qid", "reformulation_1", "reformulation_2", "reformulation_3", "reformulation_4", "reformulation_5"])
+ w.writerow(
+ [
+ "qid",
+ "reformulation_1",
+ "reformulation_2",
+ "reformulation_3",
+ "reformulation_4",
+ "reformulation_5",
+ ]
+ )
elif method == "qa_expand":
# QA-Expand: qid \t final_refined_query
w.writerow(["qid", "refined_query"])
@@ -196,7 +254,21 @@ def write_plain_format(output_path):
header = ["qid"] + [f"keyword_{i+1}" for i in range(num_keywords)]
w.writerow(header)
else:
- w.writerow(["qid", "keyword_1", "keyword_2", "keyword_3", "keyword_4", "keyword_5", "keyword_6", "keyword_7", "keyword_8", "keyword_9", "keyword_10"])
+ w.writerow(
+ [
+ "qid",
+ "keyword_1",
+ "keyword_2",
+ "keyword_3",
+ "keyword_4",
+ "keyword_5",
+ "keyword_6",
+ "keyword_7",
+ "keyword_8",
+ "keyword_9",
+ "keyword_10",
+ ]
+ )
elif method == "query2e":
# Query2E: qid \t keyword_1 \t keyword_2 \t ... \t keyword_n
# Determine number of keywords from first result
@@ -213,7 +285,7 @@ def write_plain_format(output_path):
else:
# Default fallback
w.writerow(["qid", "generated_content"])
-
+
# Write data rows
for r in results:
if method == "lamer":
@@ -284,33 +356,34 @@ def write_plain_format(output_path):
plain_output = r.reformulated.replace(r.original, "").strip()
cleaned_output = clean_text(plain_output)
w.writerow([r.qid, cleaned_output])
-
+
def clean_text(text):
"""Clean text by removing newlines, extra whitespace, and other formatting issues"""
if not text:
return ""
-
+
# Convert to string if not already
text = str(text)
-
+
# Remove newlines and carriage returns
- text = text.replace('\n', ' ').replace('\r', ' ')
-
+ text = text.replace("\n", " ").replace("\r", " ")
+
# Remove backslashes that might be escape characters
- text = text.replace('\\', ' ')
-
+ text = text.replace("\\", " ")
+
# Replace multiple spaces with single space
import re
- text = re.sub(r'\s+', ' ', text)
-
+
+ text = re.sub(r"\s+", " ", text)
+
# Strip leading/trailing whitespace
text = text.strip()
-
+
# Remove quotes from beginning and end
text = text.strip('"').strip("'")
-
+
return text
-
+
# Generate output files based on format
if output_format == "concat":
write_concat_format(output_tsv)
@@ -323,71 +396,101 @@ def clean_text(text):
output_dir = output_tsv.parent
output_stem = output_tsv.stem
output_suffix = output_tsv.suffix
-
+
concat_file = output_dir / f"{output_stem}_concat{output_suffix}"
plain_file = output_dir / f"{output_stem}_plain{output_suffix}"
-
+
write_concat_format(concat_file)
write_plain_format(plain_file)
-
+
typer.echo(f"Wrote both formats:")
typer.echo(f" Concat: {concat_file}")
typer.echo(f" Plain: {plain_file}")
else:
- raise ValueError(f"Invalid output_format: {output_format}. Must be 'concat', 'plain', or 'both'")
+ raise ValueError(
+ f"Invalid output_format: {output_format}. Must be 'concat', 'plain', or 'both'"
+ )
+
@app.command("data-to-tsv")
-def data_to_tsv(backend: str = typer.Option(...),
- source: str = typer.Option(...),
- out: Path = typer.Option(..., "--out"),
- path: Optional[Path] = typer.Option(None, help="Local TSV/JSONL path"),
- format: Optional[str] = typer.Option(None, help="tsv|jsonl"),
- tsv_qid_col: int = 0, tsv_query_col: int = 1,
- jsonl_qid_key: str = "qid", jsonl_query_key: str = "query",
- msmarco_queries_tsv: Optional[Path] = None,
- hf_name: Optional[str] = None, hf_config: Optional[str] = None,
- split: str = "dev", hf_qid_key: str = "query_id", hf_query_key: str = "query",
- beir_root: Optional[Path] = None, beir_name: Optional[str] = None,
+def data_to_tsv(
+ backend: str = typer.Option(...),
+ source: str = typer.Option(...),
+ out: Path = typer.Option(..., "--out"),
+ path: Optional[Path] = typer.Option(None, help="Local TSV/JSONL path"),
+ format: Optional[str] = typer.Option(None, help="tsv|jsonl"),
+ tsv_qid_col: int = 0,
+ tsv_query_col: int = 1,
+ jsonl_qid_key: str = "qid",
+ jsonl_query_key: str = "query",
+ msmarco_queries_tsv: Optional[Path] = None,
+ hf_name: Optional[str] = None,
+ hf_config: Optional[str] = None,
+ split: str = "dev",
+ hf_qid_key: str = "query_id",
+ hf_query_key: str = "query",
+ beir_root: Optional[Path] = None,
+ beir_name: Optional[str] = None,
):
src = UnifiedQuerySource(
- backend=backend if backend in ("local","msmarco","beir") else None,
- source=source if source in ("file","hf","beir") else None,
- path=path, format=format, tsv_qid_col=tsv_qid_col, tsv_query_col=tsv_query_col,
- jsonl_qid_key=jsonl_qid_key, jsonl_query_key=jsonl_query_key,
+ backend=backend if backend in ("local", "msmarco", "beir") else None,
+ source=source if source in ("file", "hf", "beir") else None,
+ path=path,
+ format=format,
+ tsv_qid_col=tsv_qid_col,
+ tsv_query_col=tsv_query_col,
+ jsonl_qid_key=jsonl_qid_key,
+ jsonl_query_key=jsonl_query_key,
msmarco_queries_tsv=msmarco_queries_tsv,
- hf_name=hf_name, hf_config=hf_config, split=split,
- hf_qid_key=hf_qid_key, hf_query_key=hf_query_key,
- beir_root=beir_root, beir_name=beir_name,
+ hf_name=hf_name,
+ hf_config=hf_config,
+ split=split,
+ hf_qid_key=hf_qid_key,
+ hf_query_key=hf_query_key,
+ beir_root=beir_root,
+ beir_name=beir_name,
)
UnifiedQuerySource.export_to_tsv(src.iter(), out)
typer.echo(f"Wrote {out}")
+
@app.command("prompts-list")
def prompts_list(prompt_bank: Path = typer.Option(Path(__file__).with_name("prompt_bank.yaml"))):
pb = PromptBank(prompt_bank)
for pid in pb.list():
typer.echo(pid)
+
@app.command("prompts-show")
-def prompts_show(prompt_id: str,
- prompt_bank: Path = typer.Option(Path(__file__).with_name("prompt_bank.yaml"))):
+def prompts_show(
+ prompt_id: str, prompt_bank: Path = typer.Option(Path(__file__).with_name("prompt_bank.yaml"))
+):
pb = PromptBank(prompt_bank)
meta = pb.get_meta(prompt_id)
typer.echo(json.dumps(meta, indent=2))
+
@app.command("script-gen")
-def script_gen(index_path: Path = typer.Option(..., exists=False),
- topics: Path = typer.Option(..., help="TSV queries"),
- run: Path = typer.Option(..., help="Output run file"),
- output_bash: Path = typer.Option(Path("run_retrieval.sh")),
- qrels: Optional[Path] = None,
- bm25: bool = True,
- extra: str = typer.Option("", help="Extra Pyserini CLI flags, concatenated")
+def script_gen(
+ index_path: Path = typer.Option(..., exists=False),
+ topics: Path = typer.Option(..., help="TSV queries"),
+ run: Path = typer.Option(..., help="Output run file"),
+ output_bash: Path = typer.Option(Path("run_retrieval.sh")),
+ qrels: Optional[Path] = None,
+ bm25: bool = True,
+ extra: str = typer.Option("", help="Extra Pyserini CLI flags, concatenated"),
):
- lines = build_script_lines(str(index_path), str(topics), str(run),
- str(qrels) if qrels else None, bm25=bm25, extra=extra)
+ lines = build_script_lines(
+ str(index_path),
+ str(topics),
+ str(run),
+ str(qrels) if qrels else None,
+ bm25=bm25,
+ extra=extra,
+ )
output_bash.write_text("\n".join(lines))
typer.echo(f"Generated {output_bash}")
+
if __name__ == "__main__":
app()
diff --git a/querygym/core/base.py b/querygym/core/base.py
index 5de92dc..f65f245 100644
--- a/querygym/core/base.py
+++ b/querygym/core/base.py
@@ -3,30 +3,33 @@
from typing import Dict, List, Optional, Any
from tqdm import tqdm
+
@dataclass
class QueryItem:
"""Represents a query with its ID and text.
-
+
Attributes:
qid: Unique query identifier
text: Query text content
-
+
Example:
>>> query = QueryItem(qid="q1", text="what causes diabetes")
"""
+
qid: str
text: str
+
@dataclass
class ReformulationResult:
"""Result of a query reformulation operation.
-
+
Attributes:
qid: Query identifier
original: Original query text
reformulated: Reformulated query text
metadata: Additional metadata about the reformulation
-
+
Example:
>>> result = ReformulationResult(
... qid="q1",
@@ -34,22 +37,24 @@ class ReformulationResult:
... reformulated="diabetes causes symptoms treatment"
... )
"""
+
qid: str
original: str
reformulated: str
metadata: Dict[str, Any] = field(default_factory=dict)
+
@dataclass
class MethodConfig:
"""Configuration for a reformulation method.
-
+
Attributes:
name: Method name (e.g., "genqr", "query2doc")
params: Method-specific parameters
llm: LLM configuration (model, temperature, max_tokens, etc.)
seed: Random seed for reproducibility (default: 42)
retries: Number of retries on LLM failure (default: 2)
-
+
Example:
>>> config = MethodConfig(
... name="genqr_ensemble",
@@ -58,19 +63,21 @@ class MethodConfig:
... seed=42
... )
"""
+
name: str
params: Dict[str, Any]
llm: Dict[str, Any]
seed: int = 42
retries: int = 2
+
class BaseReformulator:
"""Base class for all query reformulation methods.
-
+
All reformulation methods inherit from this class and implement
the `reformulate` method. This class provides common functionality
for concatenating results and batch processing.
-
+
Attributes:
VERSION: Method version string
REQUIRES_CONTEXT: Whether method requires retrieved contexts
@@ -80,16 +87,17 @@ class BaseReformulator:
llm: LLM client instance
prompts: Prompt bank instance
"""
+
VERSION = "0.1"
REQUIRES_CONTEXT = False
-
+
# Concatenation strategy configuration
CONCATENATION_STRATEGY = "query_repeat_plus_generated" # Default strategy
DEFAULT_QUERY_REPEATS = 5 # Default number of query repetitions
def __init__(self, cfg: MethodConfig, llm_client, prompt_resolver):
"""Initialize the reformulator.
-
+
Args:
cfg: Method configuration
llm_client: LLM client for generating reformulations
@@ -99,63 +107,69 @@ def __init__(self, cfg: MethodConfig, llm_client, prompt_resolver):
self.llm = llm_client
self.prompts = prompt_resolver
- def reformulate(self, q: QueryItem, contexts: Optional[List[str]] = None) -> ReformulationResult:
+ def reformulate(
+ self, q: QueryItem, contexts: Optional[List[str]] = None
+ ) -> ReformulationResult:
"""Reformulate a single query.
-
+
Args:
q: Query to reformulate
contexts: Optional retrieved contexts (required for some methods)
-
+
Returns:
Reformulation result
-
+
Raises:
NotImplementedError: Must be implemented by subclasses
"""
raise NotImplementedError
- def concatenate_result(self, original_query: str, generated_content: str | List[str],
- query_repeats: Optional[int] = None,
- content_separator: str = " ") -> str:
+ def concatenate_result(
+ self,
+ original_query: str,
+ generated_content: str | List[str],
+ query_repeats: Optional[int] = None,
+ content_separator: str = " ",
+ ) -> str:
"""
Concatenate the original query with generated content based on the method's strategy.
Matches exact patterns from baseline implementations.
-
+
Args:
original_query: The original query text
generated_content: The content generated by the LLM (string or list of strings)
query_repeats: Number of times to repeat the original query (overrides default)
content_separator: Separator to use between multiple content pieces
-
+
Returns:
Concatenated reformulated query
"""
if query_repeats is None:
query_repeats = self.cfg.params.get("repeat_query_weight", self.DEFAULT_QUERY_REPEATS)
-
+
strategy = self.cfg.params.get("concatenation_strategy", self.CONCATENATION_STRATEGY)
-
+
# Handle multiple content pieces
if isinstance(generated_content, list):
generated_text = content_separator.join(generated_content)
else:
generated_text = generated_content
-
+
if strategy == "query_repeat_plus_generated":
# Pattern: (Q × N) + generated_content
# Used by: GenQR, GenQREnsemble, QA-Expand
result = " ".join([original_query] * query_repeats + [generated_text])
-
+
elif strategy == "query_plus_generated":
# Pattern: Q + generated_content (single repetition)
# Used by: Query2E
result = " ".join([original_query, generated_text])
-
+
elif strategy == "generated_only":
# Pattern: Only generated content (no original query)
# Used by: Query2Doc
result = generated_text
-
+
elif strategy == "adaptive_query_repeat_plus_generated":
# Pattern: (query + ' ') * repetition_times + generated_content
# Uses adaptive repetition based on content length
@@ -164,7 +178,7 @@ def concatenate_result(self, original_query: str, generated_content: str | List[
repetition_times = (len(generated_text) // len(original_query)) // adaptive_times
repetition_times = max(1, repetition_times) # At least 1 repetition
result = (original_query + " ") * repetition_times + generated_text
-
+
elif strategy == "interleaved_query_content":
# Pattern: q + a1 + q + a2 + q + a3 + ... (interleaving)
# Used by: LameR
@@ -177,32 +191,34 @@ def concatenate_result(self, original_query: str, generated_content: str | List[
else:
# Single content: query + content
result = " ".join([original_query, generated_text])
-
+
elif strategy == "generated_plus_query_repeat":
# Pattern: generated_content + (Q × N)
# Used by: Future methods that want generated content first
result = " ".join([generated_text] + [original_query] * query_repeats)
-
+
elif strategy == "query_sandwich":
# Pattern: Q + generated_content + Q
# Used by: Future methods that want query wrapping
result = " ".join([original_query, generated_text, original_query])
-
+
else:
# Default fallback to query_repeat_plus_generated pattern
result = " ".join([original_query] * query_repeats + [generated_text])
-
+
# Clean up any newlines and quotes in the final result (for non-newline strategies)
- result = result.replace('\n', ' ').replace('\r', ' ').strip()
+ result = result.replace("\n", " ").replace("\r", " ").strip()
# Remove quotes from beginning and end
result = result.strip('"').strip("'")
return result
- def retrieve_contexts_batch(self, queries: List[QueryItem], retrieval_params: Optional[Dict[str, Any]] = None) -> Dict[str, List[str]]:
+ def retrieve_contexts_batch(
+ self, queries: List[QueryItem], retrieval_params: Optional[Dict[str, Any]] = None
+ ) -> Dict[str, List[str]]:
"""Retrieve contexts for a batch of queries via any searcher implementing BaseSearcher.
-
+
This method supports both searcher adapters (via registry) and wrapped searchers:
-
+
**Using Adapters (via searcher_type):**
retrieval_params = {
"searcher_type": "pyserini", # or "pyterrier"
@@ -210,28 +226,28 @@ def retrieve_contexts_batch(self, queries: List[QueryItem], retrieval_params: Op
"k": 10,
"threads": 16,
}
-
+
**Using Wrapped Searchers:**
# Wrap your existing searcher
from pyserini.search.lucene import LuceneSearcher
my_searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage')
wrapped = wrap_pyserini_searcher(my_searcher)
-
+
retrieval_params = {
"searcher": wrapped, # Pass wrapped searcher directly
"k": 10,
"threads": 16,
}
-
+
**Using Adapter Instances Directly:**
    from querygym import create_searcher
searcher = create_searcher("pyserini", index="msmarco-v1-passage")
-
+
retrieval_params = {
"searcher": searcher, # Pass adapter instance directly
"k": 10,
}
-
+
Args:
queries: List of QueryItem objects to retrieve contexts for
retrieval_params: Method-specific retrieval parameters. Can include:
@@ -240,85 +256,90 @@ def retrieve_contexts_batch(self, queries: List[QueryItem], retrieval_params: Op
- searcher_kwargs: Keyword arguments to pass to searcher constructor
- k: Number of documents to retrieve per query (default: 10)
- threads: Number of threads for batch search (default: 16)
-
+
Returns:
Dictionary mapping query IDs to lists of context strings
"""
if not retrieval_params or not queries:
return {}
-
+
# Extract retrieval configuration
k = retrieval_params.get("k", 10)
threads = retrieval_params.get("threads", 16)
-
+
# Get or create searcher
searcher = retrieval_params.get("searcher")
-
+
if searcher is None:
# Create searcher using registry
searcher_type = retrieval_params.get("searcher_type", "pyserini")
searcher_kwargs = retrieval_params.get("searcher_kwargs", {})
-
+
# Lazy import to avoid hard dependency at load time
try:
from .searcher import create_searcher
+
searcher = create_searcher(searcher_type, **searcher_kwargs)
except Exception as e:
# If searcher creation fails, return empty dict
# This allows methods to work without retrieval if contexts are pre-provided
return {}
-
+
# Ensure searcher implements BaseSearcher interface
from .searcher import BaseSearcher
+
if not isinstance(searcher, BaseSearcher):
raise ValueError(
- f"Searcher must implement BaseSearcher interface. "
- f"Got type: {type(searcher)}"
+            f"Searcher must implement BaseSearcher interface. Got type: {type(searcher)}"
)
-
+
# Extract query texts and IDs
query_texts = [q.text for q in queries]
query_ids = [q.qid for q in queries]
-
+
# Perform batch retrieval using searcher's batch_search method
batch_results = searcher.batch_search(query_texts, k=k, num_threads=threads)
-
+
# Convert SearchHit results to dictionary format (qid -> list of context strings)
contexts = {}
for qid, search_hits in zip(query_ids, batch_results):
# Extract content from each SearchHit
contexts[qid] = [hit.content for hit in search_hits]
-
+
return contexts
- def retrieve_contexts_if_needed(self, q: QueryItem, retrieval_params: Optional[Dict[str, Any]] = None) -> List[str]:
+ def retrieve_contexts_if_needed(
+ self, q: QueryItem, retrieval_params: Optional[Dict[str, Any]] = None
+ ) -> List[str]:
"""Retrieve contexts for a single query (fallback method).
-
+
Args:
q: QueryItem to retrieve contexts for
retrieval_params: Method-specific retrieval parameters
-
+
Returns:
List of context strings
"""
contexts = self.retrieve_contexts_batch([q], retrieval_params)
return contexts.get(q.qid, [])
- def reformulate_batch(self, queries: List[QueryItem], ctx_map: Optional[Dict[str, List[str]]] = None) -> List[ReformulationResult]:
+ def reformulate_batch(
+ self, queries: List[QueryItem], ctx_map: Optional[Dict[str, List[str]]] = None
+ ) -> List[ReformulationResult]:
"""Reformulate multiple queries in batch.
-
+
Processes a list of queries and returns reformulation results.
Shows a progress bar during processing. If the method requires
contexts and none are provided, attempts to retrieve them automatically.
-
+
Args:
queries: List of queries to reformulate
ctx_map: Optional mapping of query IDs to context lists.
If None and method requires contexts, will attempt retrieval.
-
+
Returns:
List of reformulation results, one per input query
-
+
Example:
>>> queries = [QueryItem("q1", "diabetes"), QueryItem("q2", "cancer")]
>>> results = reformulator.reformulate_batch(queries)
@@ -326,31 +347,31 @@ def reformulate_batch(self, queries: List[QueryItem], ctx_map: Optional[Dict[str
... print(f"{r.qid}: {r.reformulated}")
"""
# If no contexts provided and method requires context, do batch retrieval
- if ctx_map is None and hasattr(self, 'REQUIRES_CONTEXT') and self.REQUIRES_CONTEXT:
+ if ctx_map is None and hasattr(self, "REQUIRES_CONTEXT") and self.REQUIRES_CONTEXT:
retrieval_params = self._get_retrieval_params()
if retrieval_params:
ctx_map = self.retrieve_contexts_batch(queries, retrieval_params)
-
+
out: List[ReformulationResult] = []
-
+
# Import tqdm for progress bar
try:
progress_bar = tqdm(queries, desc=f"Reformulating with {self.cfg.name}", unit="query")
except ImportError:
# Fallback if tqdm is not available
progress_bar = queries
-
+
for q in progress_bar:
ctx = (ctx_map or {}).get(q.qid)
result = self.reformulate(q, ctx)
out.append(result)
-
+
# Update progress bar description with current query info
- if hasattr(progress_bar, 'set_description'):
+ if hasattr(progress_bar, "set_description"):
progress_bar.set_description(f"Reformulating with {self.cfg.name} (QID: {q.qid})")
-
+
return out
-
+
def _get_retrieval_params(self) -> Optional[Dict[str, Any]]:
"""Get retrieval parameters for this method. Override in subclasses."""
return None
diff --git a/querygym/core/llm.py b/querygym/core/llm.py
index f99543c..f64f3ab 100644
--- a/querygym/core/llm.py
+++ b/querygym/core/llm.py
@@ -3,16 +3,17 @@
from openai import OpenAI
from typing import List, Dict, Any
+
class OpenAICompatibleClient:
"""Client for OpenAI-compatible chat completion APIs.
-
+
Supports any API that implements the OpenAI chat completions format,
including OpenAI, Azure OpenAI, local models via vLLM, etc.
-
+
Attributes:
client: OpenAI client instance
model: Model identifier to use for completions
-
+
Example:
>>> client = OpenAICompatibleClient(
... model="gpt-4",
@@ -24,30 +25,32 @@ class OpenAICompatibleClient:
... {"role": "user", "content": "Hello!"}
... ])
"""
-
+
def __init__(self, model: str, base_url: str | None = None, api_key: str | None = None):
"""Initialize the LLM client.
-
+
Args:
model: Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
base_url: Optional base URL for the API. If None, uses OPENAI_BASE_URL
environment variable or OpenAI's default.
api_key: Optional API key. If None, uses OPENAI_API_KEY environment variable.
"""
- self.client = OpenAI(base_url=base_url or os.getenv("OPENAI_BASE_URL", None),
- api_key=api_key or os.getenv("OPENAI_API_KEY"))
+ self.client = OpenAI(
+ base_url=base_url or os.getenv("OPENAI_BASE_URL", None),
+ api_key=api_key or os.getenv("OPENAI_API_KEY"),
+ )
self.model = model
def chat(self, messages: List[Dict[str, str]], **kw: Any) -> str:
"""Send a chat completion request.
-
+
Args:
messages: List of message dicts with 'role' and 'content' keys
**kw: Additional parameters (temperature, max_tokens, etc.)
-
+
Returns:
Generated text response
-
+
Example:
>>> response = client.chat(
... messages=[{"role": "user", "content": "Hello"}],
diff --git a/querygym/core/prompts.py b/querygym/core/prompts.py
index 3bdcde0..3099f49 100644
--- a/querygym/core/prompts.py
+++ b/querygym/core/prompts.py
@@ -4,10 +4,11 @@
from pathlib import Path
from typing import Dict, List, Any
+
@dataclass
class PromptSpec:
"""Specification for a single prompt template.
-
+
Attributes:
id: Unique prompt identifier
method_family: Method family this prompt belongs to
@@ -15,34 +16,36 @@ class PromptSpec:
template: Template dict with 'system', 'user', 'assistant' keys
meta: Additional metadata (authors, license, etc.)
"""
+
id: str
method_family: str
version: int
template: Dict[str, str]
meta: Dict[str, Any]
+
class PromptBank:
"""Manages prompt templates from a YAML file.
-
+
Loads and provides access to prompt templates with metadata.
Supports rendering templates with variable substitution.
-
+
Attributes:
_by_id: Internal dict mapping prompt IDs to PromptSpec objects
-
+
Example:
>>> from pathlib import Path
>>> pb = PromptBank(Path("querygym/prompt_bank.yaml"))
>>> messages = pb.render("genqr_keywords", query="diabetes")
>>> print(messages)
"""
-
+
def __init__(self, path: str | Path):
"""Initialize the prompt bank from a YAML file.
-
+
Args:
path: Path to the prompt bank YAML file
-
+
Raises:
FileNotFoundError: If the YAML file doesn't exist
yaml.YAMLError: If the YAML is malformed
@@ -52,54 +55,58 @@ def __init__(self, path: str | Path):
for x in items:
self._by_id[x["id"]] = PromptSpec(
id=x["id"],
- method_family=x.get("method_family",""),
- version=x.get("version",1),
+ method_family=x.get("method_family", ""),
+ version=x.get("version", 1),
template=x["template"],
- meta={k:v for k,v in x.items() if k not in ["id","method_family","version","template"]}
+ meta={
+ k: v
+ for k, v in x.items()
+ if k not in ["id", "method_family", "version", "template"]
+ },
)
def render(self, prompt_id: str, **vars) -> List[Dict[str, str]]:
"""Render a prompt template with variable substitution.
-
+
Args:
prompt_id: ID of the prompt to render
**vars: Variables to substitute in the template (e.g., query="text")
-
+
Returns:
List of message dicts with 'role' and 'content' keys,
ready for LLM chat completion
-
+
Raises:
KeyError: If prompt_id doesn't exist
KeyError: If required template variables are missing
-
+
Example:
>>> messages = pb.render("genqr_keywords", query="diabetes")
>>> # Returns: [{"role": "system", "content": "..."}, ...]
"""
spec = self._by_id[prompt_id]
messages = []
-
+
# System message (if present)
if "system" in spec.template:
- sys = spec.template.get("system","").format(**vars)
- messages.append({"role":"system","content":sys})
-
+ sys = spec.template.get("system", "").format(**vars)
+ messages.append({"role": "system", "content": sys})
+
# User message (if present)
if "user" in spec.template:
- usr = spec.template.get("user","").format(**vars)
- messages.append({"role":"user","content":usr})
-
+ usr = spec.template.get("user", "").format(**vars)
+ messages.append({"role": "user", "content": usr})
+
# Assistant message (if present) - for priming responses
if "assistant" in spec.template:
- asst = spec.template.get("assistant","").format(**vars)
- messages.append({"role":"assistant","content":asst})
-
+ asst = spec.template.get("assistant", "").format(**vars)
+ messages.append({"role": "assistant", "content": asst})
+
return messages
def list(self) -> List[str]:
"""List all available prompt IDs.
-
+
Returns:
List of prompt IDs
"""
@@ -107,27 +114,27 @@ def list(self) -> List[str]:
def get(self, prompt_id: str) -> PromptSpec:
"""Get a prompt specification by ID.
-
+
Args:
prompt_id: ID of the prompt
-
+
Returns:
PromptSpec object
-
+
Raises:
KeyError: If prompt_id doesn't exist
"""
return self._by_id[prompt_id]
-
+
def get_meta(self, prompt_id: str) -> Dict[str, Any]:
"""Get metadata for a prompt.
-
+
Args:
prompt_id: ID of the prompt
-
+
Returns:
Metadata dict (authors, license, etc.)
-
+
Raises:
KeyError: If prompt_id doesn't exist
"""
diff --git a/querygym/core/registry.py b/querygym/core/registry.py
index 187fe6a..1ef2b8a 100644
--- a/querygym/core/registry.py
+++ b/querygym/core/registry.py
@@ -4,8 +4,10 @@
METHODS: Dict[str, Type[BaseReformulator]] = {}
+
def register_method(name: str):
def deco(cls: Type[BaseReformulator]):
METHODS[name] = cls
return cls
+
return deco
diff --git a/querygym/core/runner.py b/querygym/core/runner.py
index 38f02df..88abba7 100644
--- a/querygym/core/runner.py
+++ b/querygym/core/runner.py
@@ -13,15 +13,24 @@
from ..methods.mugi import MuGI
from ..methods.lamer import LameR
from ..methods.query2e import Query2E
+from ..methods.csqe import CSQE
+from ..methods.thinkqe import ThinkQE
+
def build_llm(cfg: MethodConfig):
llm_cfg = cfg.llm
- return OpenAICompatibleClient(model=llm_cfg["model"],
- base_url=llm_cfg.get("base_url"),
- api_key=llm_cfg.get("api_key"))
+ return OpenAICompatibleClient(
+ model=llm_cfg["model"], base_url=llm_cfg.get("base_url"), api_key=llm_cfg.get("api_key")
+ )
+
-def run_method(method_name: str, cfg: MethodConfig, queries: List[QueryItem],
- prompt_bank_path: str, ctx_map: Optional[Dict[str, List[str]]] = None):
+def run_method(
+ method_name: str,
+ cfg: MethodConfig,
+ queries: List[QueryItem],
+ prompt_bank_path: str,
+ ctx_map: Optional[Dict[str, List[str]]] = None,
+):
Method = METHODS[method_name]
llm = build_llm(cfg)
pb = PromptBank(prompt_bank_path)
diff --git a/querygym/core/searcher.py b/querygym/core/searcher.py
index 2be0e7c..47c118a 100644
--- a/querygym/core/searcher.py
+++ b/querygym/core/searcher.py
@@ -14,11 +14,12 @@
@dataclass
class SearchHit:
"""Standardized search hit result."""
+
docid: str
score: float
content: str
metadata: Dict[str, Any] = None
-
+
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
@@ -27,57 +28,58 @@ def __post_init__(self):
class BaseSearcher(ABC):
"""
Abstract base class for searcher implementations.
-
+
Any searcher library (Pyserini, PyTerrier, etc.) can implement this interface
to be compatible with querygym's retrieval framework.
"""
-
+
@abstractmethod
def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]:
"""
Search for documents using a single query.
-
+
Args:
query: The search query string
k: Number of documents to retrieve
**kwargs: Additional search parameters specific to the searcher
-
+
Returns:
List of SearchHit objects ordered by relevance score (highest first)
"""
pass
-
+
@abstractmethod
- def batch_search(self, queries: List[str], k: int = 10,
- num_threads: int = 1, **kwargs) -> List[List[SearchHit]]:
+ def batch_search(
+ self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs
+ ) -> List[List[SearchHit]]:
"""
Search for documents using multiple queries in batch.
-
+
Args:
queries: List of search query strings
k: Number of documents to retrieve per query
num_threads: Number of threads for parallel processing
**kwargs: Additional search parameters specific to the searcher
-
+
Returns:
List of lists of SearchHit objects, one list per query
"""
pass
-
+
@abstractmethod
def get_searcher_info(self) -> Dict[str, Any]:
"""
Get information about the searcher configuration.
-
+
Returns:
Dictionary containing searcher metadata (name, version, parameters, etc.)
"""
pass
-
+
def configure(self, **kwargs) -> None:
"""
Configure searcher parameters.
-
+
Args:
**kwargs: Configuration parameters specific to the searcher
"""
@@ -88,25 +90,25 @@ def configure(self, **kwargs) -> None:
class SearcherRegistry:
"""Registry for managing different searcher implementations."""
-
+
_searchers: Dict[str, type] = {}
-
+
@classmethod
def register(cls, name: str, searcher_class: type) -> None:
"""Register a searcher implementation."""
if not issubclass(searcher_class, BaseSearcher):
raise ValueError(f"Searcher class must inherit from BaseSearcher")
cls._searchers[name] = searcher_class
-
+
@classmethod
def get_searcher(cls, name: str, **kwargs) -> BaseSearcher:
"""Get a searcher instance by name."""
if name not in cls._searchers:
raise ValueError(f"Unknown searcher: {name}. Available: {list(cls._searchers.keys())}")
-
+
searcher_class = cls._searchers[name]
return searcher_class(**kwargs)
-
+
@classmethod
def list_searchers(cls) -> List[str]:
"""List all registered searcher names."""
@@ -117,11 +119,11 @@ def list_searchers(cls) -> List[str]:
def create_searcher(searcher_type: str, **kwargs) -> BaseSearcher:
"""
Create a searcher instance.
-
+
Args:
searcher_type: Type of searcher to create
**kwargs: Arguments to pass to the searcher constructor
-
+
Returns:
BaseSearcher instance
"""
diff --git a/querygym/core/searcher_wrappers.py b/querygym/core/searcher_wrappers.py
index e24d34d..863a083 100644
--- a/querygym/core/searcher_wrappers.py
+++ b/querygym/core/searcher_wrappers.py
@@ -16,17 +16,17 @@
def wrap_pyserini_searcher(pyserini_searcher, answer_key: str = "contents") -> BaseSearcher:
"""
Wrap a user's Pyserini searcher for use with querygym.
-
+
This function works standalone - it doesn't import Pyserini at the module level.
The user must provide their own Pyserini searcher instance.
-
+
Args:
pyserini_searcher: User's LuceneSearcher or LuceneImpactSearcher instance
answer_key: Field name(s) to extract content from (pipe-separated)
-
+
Returns:
BaseSearcher instance that can be used with querygym
-
+
Example:
>>> from pyserini.search.lucene import LuceneSearcher
>>> lucene_searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage')
@@ -34,46 +34,52 @@ def wrap_pyserini_searcher(pyserini_searcher, answer_key: str = "contents") -> B
>>> wrapped_searcher = wrap_pyserini_searcher(lucene_searcher)
>>> retriever = qg.Retriever(searcher=wrapped_searcher)
"""
-
+
class PyseriniWrapper(BaseSearcher):
def __init__(self, searcher, answer_key):
# Validate that the searcher has the expected methods
- if not hasattr(searcher, 'search') or not hasattr(searcher, 'batch_search'):
+ if not hasattr(searcher, "search") or not hasattr(searcher, "batch_search"):
raise ValueError("Provided searcher must have 'search' and 'batch_search' methods")
-
+
self.searcher = searcher
self.answer_key = answer_key
self._searcher_info = {
"name": "UserPyseriniWrapper",
"type": "user_pyserini",
"answer_key": answer_key,
- "searcher_class": type(searcher).__name__
+ "searcher_class": type(searcher).__name__,
}
-
+
def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]:
hits = self.searcher.search(query, k)
return self._process_hits(hits)
-
- def batch_search(self, queries: List[str], k: int = 10,
- num_threads: int = 1, **kwargs) -> List[List[SearchHit]]:
+
+ def batch_search(
+ self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs
+ ) -> List[List[SearchHit]]:
pseudo_batch_topic_ids = [str(idx) for idx, _ in enumerate(queries)]
results = self.searcher.batch_search(queries, pseudo_batch_topic_ids, k, num_threads)
batch_results = [results[id_] for id_ in pseudo_batch_topic_ids]
return [self._process_hits(hits) for hits in batch_results]
-
+
def _process_hits(self, hits) -> List[SearchHit]:
search_hits = []
for hit in hits:
try:
- raw_content = hit.lucene_document.get('raw')
+ raw_content = hit.lucene_document.get("raw")
if raw_content:
import json
+
raw_df = json.loads(raw_content)
text_list = [raw_df[k] for k in self.answer_key.split("|") if raw_df.get(k)]
content = "\t".join(text_list)
else:
- content = getattr(hit, 'contents', '') or getattr(hit, 'text', '') or str(hit.docid)
-
+ content = (
+ getattr(hit, "contents", "")
+ or getattr(hit, "text", "")
+ or str(hit.docid)
+ )
+
search_hit = SearchHit(
docid=hit.docid,
score=hit.score,
@@ -81,45 +87,41 @@ def _process_hits(self, hits) -> List[SearchHit]:
metadata={
"user_defined": True,
"searcher_class": type(self.searcher).__name__,
- "answer_key": self.answer_key
- }
+ "answer_key": self.answer_key,
+ },
)
search_hits.append(search_hit)
except Exception as e:
search_hit = SearchHit(
- docid=getattr(hit, 'docid', 'unknown'),
- score=getattr(hit, 'score', 0.0),
- content=str(getattr(hit, 'docid', 'unknown')),
- metadata={
- "error": str(e),
- "user_defined": True,
- "fallback": True
- }
+ docid=getattr(hit, "docid", "unknown"),
+ score=getattr(hit, "score", 0.0),
+ content=str(getattr(hit, "docid", "unknown")),
+ metadata={"error": str(e), "user_defined": True, "fallback": True},
)
search_hits.append(search_hit)
return search_hits
-
+
def get_searcher_info(self) -> Dict[str, Any]:
return self._searcher_info.copy()
-
+
    return PyseriniWrapper(pyserini_searcher, answer_key)


def wrap_pyterrier_retriever(pyterrier_retriever, index, text_field: str = "text") -> BaseSearcher:
"""
Wrap a user's PyTerrier retriever for use with querygym.
-
+
This function works standalone - it doesn't import PyTerrier at the module level.
The user must provide their own PyTerrier retriever and index instances.
-
+
Args:
pyterrier_retriever: User's PyTerrier retriever instance
index: PyTerrier index reference
text_field: Field name containing document text
-
+
Returns:
BaseSearcher instance that can be used with querygym
-
+
Example:
>>> import pyterrier as pt
>>> pt.init()
@@ -127,13 +129,13 @@ def wrap_pyterrier_retriever(pyterrier_retriever, index, text_field: str = "text
>>> wrapped_searcher = wrap_pyterrier_retriever(BM25_r, index)
>>> retriever = qg.Retriever(searcher=wrapped_searcher)
"""
-
+
class PyTerrierWrapper(BaseSearcher):
def __init__(self, retriever, index, text_field):
# Validate that the retriever has the expected methods
- if not hasattr(retriever, 'search') or not hasattr(retriever, 'transform'):
+ if not hasattr(retriever, "search") or not hasattr(retriever, "transform"):
raise ValueError("Provided retriever must have 'search' and 'transform' methods")
-
+
self.retriever = retriever
self.index = index
self.text_field = text_field
@@ -141,37 +143,37 @@ def __init__(self, retriever, index, text_field):
"name": "UserPyTerrierWrapper",
"type": "user_pyterrier",
"text_field": text_field,
- "retriever_class": type(retriever).__name__
+ "retriever_class": type(retriever).__name__,
}
-
+
def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]:
"""Search for documents using a single query."""
try:
# Use PyTerrier's search method which returns a DataFrame
results_df = self.retriever.search(query)
-
+
# Convert to SearchHit objects
search_hits = []
for _, row in results_df.head(k).iterrows():
# Extract content from the text field
- content = str(row.get(self.text_field, row.get('docno', '')))
-
+ content = str(row.get(self.text_field, row.get("docno", "")))
+
search_hit = SearchHit(
- docid=str(row.get('docno', row.get('docid', ''))),
- score=float(row.get('score', 0.0)),
+ docid=str(row.get("docno", row.get("docid", ""))),
+ score=float(row.get("score", 0.0)),
content=content,
metadata={
- 'user_defined': True,
- 'retriever_class': type(self.retriever).__name__,
- 'text_field': self.text_field,
- 'qid': str(row.get('qid', '')),
- 'rank': int(row.get('rank', 0))
- }
+ "user_defined": True,
+ "retriever_class": type(self.retriever).__name__,
+ "text_field": self.text_field,
+ "qid": str(row.get("qid", "")),
+ "rank": int(row.get("rank", 0)),
+ },
)
search_hits.append(search_hit)
-
+
return search_hits
-
+
except Exception as e:
# Fallback: return mock results if anything fails
search_hits = []
@@ -181,53 +183,56 @@ def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]:
score=1.0 - (i * 0.1),
content=f"Content for query: {query}",
metadata={
- 'user_defined': True,
- 'retriever_class': type(self.retriever).__name__,
- 'text_field': self.text_field,
- 'error': str(e),
- 'fallback': True
- }
+ "user_defined": True,
+ "retriever_class": type(self.retriever).__name__,
+ "text_field": self.text_field,
+ "error": str(e),
+ "fallback": True,
+ },
)
search_hits.append(search_hit)
return search_hits
-
- def batch_search(self, queries: List[str], k: int = 10,
- num_threads: int = 1, **kwargs) -> List[List[SearchHit]]:
+
+ def batch_search(
+ self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs
+ ) -> List[List[SearchHit]]:
"""Search for documents using multiple queries in batch."""
batch_results = []
-
+
for query in queries:
# Use the single search method for each query
results = self.search(query, k)
batch_results.append(results)
-
+
return batch_results
-
+
def get_searcher_info(self) -> Dict[str, Any]:
return self._searcher_info.copy()
-
+
    return PyTerrierWrapper(pyterrier_retriever, index, text_field)


-def wrap_custom_searcher(search_func, batch_search_func=None, searcher_name: str = "CustomSearcher") -> BaseSearcher:
+def wrap_custom_searcher(
+ search_func, batch_search_func=None, searcher_name: str = "CustomSearcher"
+) -> BaseSearcher:
"""
Wrap user's custom search functions for use with queryGym.
-
+
Args:
search_func: Function that takes (query: str, k: int) and returns list of (docid, score, content) tuples
batch_search_func: Optional function that takes (queries: List[str], k: int) and returns List[List[tuple]]
searcher_name: Name for the searcher
-
+
Returns:
BaseSearcher instance that can be used with querygym
-
+
Example:
>>> def my_search(query, k):
... return [("doc1", 0.9, "content1"), ("doc2", 0.8, "content2")]
>>> wrapped_searcher = wrap_custom_searcher(my_search)
>>> retriever = qg.Retriever(searcher=wrapped_searcher)
"""
-
+
class CustomWrapper(BaseSearcher):
def __init__(self, search_func, batch_search_func, searcher_name):
self.search_func = search_func
@@ -236,13 +241,13 @@ def __init__(self, search_func, batch_search_func, searcher_name):
self._searcher_info = {
"name": searcher_name,
"type": "user_custom",
- "has_batch_search": batch_search_func is not None
+ "has_batch_search": batch_search_func is not None,
}
-
+
def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]:
results = self.search_func(query, k)
search_hits = []
-
+
for result in results:
if len(result) >= 3:
docid, score, content = result[0], result[1], result[2]
@@ -250,27 +255,28 @@ def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]:
docid, score, content = result[0], result[1], str(result[0])
else:
docid, score, content = str(result[0]), 0.0, str(result[0])
-
+
search_hit = SearchHit(
docid=str(docid),
score=float(score),
content=str(content),
- metadata={"user_defined": True, "custom_searcher": True}
+ metadata={"user_defined": True, "custom_searcher": True},
)
search_hits.append(search_hit)
-
+
return search_hits
-
- def batch_search(self, queries: List[str], k: int = 10,
- num_threads: int = 1, **kwargs) -> List[List[SearchHit]]:
+
+ def batch_search(
+ self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs
+ ) -> List[List[SearchHit]]:
if self.batch_search_func:
batch_results = self.batch_search_func(queries, k)
return [self.search(query, k) for query, results in zip(queries, batch_results)]
else:
# Fallback to individual searches
return [self.search(query, k) for query in queries]
-
+
def get_searcher_info(self) -> Dict[str, Any]:
return self._searcher_info.copy()
-
+
return CustomWrapper(search_func, batch_search_func, searcher_name)
diff --git a/querygym/core/utils.py b/querygym/core/utils.py
index 452610d..c339c1f 100644
--- a/querygym/core/utils.py
+++ b/querygym/core/utils.py
@@ -2,16 +2,20 @@
import random
from typing import Optional
+
def seed_everything(seed: Optional[int] = 42):
- if seed is None: return
+ if seed is None:
+ return
random.seed(seed)
try:
import numpy as np
+
np.random.seed(seed)
except Exception:
pass
try:
import torch
+
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
diff --git a/querygym/data/dataloader.py b/querygym/data/dataloader.py
index 3caeae5..8331575 100644
--- a/querygym/data/dataloader.py
+++ b/querygym/data/dataloader.py
@@ -16,7 +16,7 @@
class DataLoader:
"""Core data loader for local files only."""
-
+
@staticmethod
def load_queries(
path: Union[str, Path],
@@ -24,11 +24,11 @@ def load_queries(
qid_col: int = 0,
query_col: int = 1,
qid_key: str = "qid",
- query_key: str = "query"
+ query_key: str = "query",
) -> List[QueryItem]:
"""
Load queries from a local file.
-
+
Args:
path: Path to queries file
format: File format - "tsv" or "jsonl"
@@ -36,162 +36,147 @@ def load_queries(
query_col: Column index for query text (TSV only)
qid_key: JSON key for query ID (JSONL only)
query_key: JSON key for query text (JSONL only)
-
+
Returns:
List of QueryItem objects
-
+
Example:
>>> queries = DataLoader.load_queries("queries.tsv", format="tsv")
>>> queries = DataLoader.load_queries("queries.jsonl", format="jsonl")
"""
path = Path(path)
-
+
if not path.exists():
raise FileNotFoundError(f"Query file not found: {path}")
-
+
if format == "tsv":
return DataLoader._load_queries_tsv(path, qid_col, query_col)
elif format == "jsonl":
return DataLoader._load_queries_jsonl(path, qid_key, query_key)
else:
raise ValueError(f"Unsupported format: {format}. Use 'tsv' or 'jsonl'")
-
+
@staticmethod
- def _load_queries_tsv(
- path: Path,
- qid_col: int,
- query_col: int
- ) -> List[QueryItem]:
+ def _load_queries_tsv(path: Path, qid_col: int, query_col: int) -> List[QueryItem]:
"""Load queries from TSV file."""
queries = []
warned_empty = False
warned_malformed = False
-
+
with open(path, "r", encoding="utf-8") as f:
reader = csv.reader(f, delimiter="\t")
for line_num, row in enumerate(reader, 1):
# Skip malformed rows
if len(row) <= max(qid_col, query_col):
if not warned_malformed:
- warnings.warn(
- f"Skipping malformed rows in {path} (not enough columns)"
- )
+ warnings.warn(f"Skipping malformed rows in {path} (not enough columns)")
warned_malformed = True
continue
-
+
qid = str(row[qid_col]).strip()
query_text = str(row[query_col]).strip()
-
+
# Skip empty queries
if not query_text:
if not warned_empty:
warnings.warn(f"Skipping empty queries in {path}")
warned_empty = True
continue
-
+
queries.append(QueryItem(qid=qid, text=query_text))
-
+
if not queries:
raise ValueError(f"No valid queries found in {path}")
-
+
return queries
-
+
@staticmethod
- def _load_queries_jsonl(
- path: Path,
- qid_key: str,
- query_key: str
- ) -> List[QueryItem]:
+ def _load_queries_jsonl(path: Path, qid_key: str, query_key: str) -> List[QueryItem]:
"""Load queries from JSONL file."""
queries = []
warned_empty = False
warned_missing_keys = False
-
+
with open(path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if not line.strip():
continue
-
+
try:
obj = json.loads(line)
except json.JSONDecodeError as e:
warnings.warn(f"Invalid JSON at line {line_num} in {path}: {e}")
continue
-
+
# Check for required keys
if qid_key not in obj or query_key not in obj:
if not warned_missing_keys:
- warnings.warn(
- f"Missing keys '{qid_key}' or '{query_key}' in {path}"
- )
+ warnings.warn(f"Missing keys '{qid_key}' or '{query_key}' in {path}")
warned_missing_keys = True
continue
-
+
qid = str(obj[qid_key])
query_text = str(obj[query_key]).strip()
-
+
# Skip empty queries
if not query_text:
if not warned_empty:
warnings.warn(f"Skipping empty queries in {path}")
warned_empty = True
continue
-
+
queries.append(QueryItem(qid=qid, text=query_text))
-
+
if not queries:
raise ValueError(f"No valid queries found in {path}")
-
+
return queries
-
+
@staticmethod
- def load_qrels(
- path: Union[str, Path],
- format: str = "trec"
- ) -> Dict[str, Dict[str, int]]:
+ def load_qrels(path: Union[str, Path], format: str = "trec") -> Dict[str, Dict[str, int]]:
"""
Load qrels (relevance judgments) from a local file.
-
+
Args:
path: Path to qrels file
format: File format - "trec" (standard TREC format)
-
+
Returns:
Dict mapping qid -> {docid -> relevance}
-
+
Example:
>>> qrels = DataLoader.load_qrels("qrels.txt", format="trec")
>>> qrels["q1"]["doc123"] # relevance score for doc123 in query q1
"""
path = Path(path)
-
+
if not path.exists():
raise FileNotFoundError(f"Qrels file not found: {path}")
-
+
if format == "trec":
return DataLoader._load_qrels_trec(path)
else:
raise ValueError(f"Unsupported qrels format: {format}. Use 'trec'")
-
+
@staticmethod
def _load_qrels_trec(path: Path) -> Dict[str, Dict[str, int]]:
"""Load qrels from TREC format file."""
qrels: Dict[str, Dict[str, int]] = {}
warned_malformed = False
-
+
with open(path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
-
+
parts = line.split()
if len(parts) < 4:
if not warned_malformed:
warnings.warn(f"Skipping malformed lines in {path}")
warned_malformed = True
continue
-
+
qid = parts[0]
docid = parts[2]
try:
@@ -201,164 +186,149 @@ def _load_qrels_trec(path: Path) -> Dict[str, Dict[str, int]]:
warnings.warn(f"Invalid relevance score in {path}")
warned_malformed = True
continue
-
+
if qid not in qrels:
qrels[qid] = {}
qrels[qid][docid] = relevance
-
+
if not qrels:
raise ValueError(f"No valid qrels found in {path}")
-
+
return qrels
-
+
@staticmethod
def load_contexts(
- path: Union[str, Path],
- qid_key: str = "qid",
- contexts_key: str = "contexts"
+ path: Union[str, Path], qid_key: str = "qid", contexts_key: str = "contexts"
) -> Dict[str, List[str]]:
"""
Load pre-retrieved contexts from a JSONL file.
-
+
Args:
path: Path to contexts JSONL file
qid_key: JSON key for query ID
contexts_key: JSON key for contexts list
-
+
Returns:
Dict mapping qid -> list of context strings
-
+
Example:
>>> contexts = DataLoader.load_contexts("contexts.jsonl")
>>> contexts["q1"] # List of context strings for query q1
"""
path = Path(path)
-
+
if not path.exists():
raise FileNotFoundError(f"Contexts file not found: {path}")
-
+
contexts: Dict[str, List[str]] = {}
warned_missing_keys = False
-
+
with open(path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if not line.strip():
continue
-
+
try:
obj = json.loads(line)
except json.JSONDecodeError as e:
warnings.warn(f"Invalid JSON at line {line_num} in {path}: {e}")
continue
-
+
# Check for required keys
if qid_key not in obj or contexts_key not in obj:
if not warned_missing_keys:
- warnings.warn(
- f"Missing keys '{qid_key}' or '{contexts_key}' in {path}"
- )
+ warnings.warn(f"Missing keys '{qid_key}' or '{contexts_key}' in {path}")
warned_missing_keys = True
continue
-
+
qid = str(obj[qid_key])
ctx_list = obj[contexts_key]
-
+
if not isinstance(ctx_list, list):
warnings.warn(f"Contexts for {qid} is not a list, skipping")
continue
-
+
contexts[qid] = [str(ctx) for ctx in ctx_list]
-
+
if not contexts:
raise ValueError(f"No valid contexts found in {path}")
-
+
return contexts
-
+
@staticmethod
def load_examples(
- path: Union[str, Path],
- query_key: str = "query",
- passage_key: str = "passage"
+ path: Union[str, Path], query_key: str = "query", passage_key: str = "passage"
) -> List[Dict[str, str]]:
"""
Load few-shot examples from a JSONL file.
-
+
Used by methods like Query2E that need (query, passage) pairs as demonstrations.
-
+
Args:
path: Path to examples JSONL file
query_key: JSON key for query text
passage_key: JSON key for passage text
-
+
Returns:
List of dicts with 'query' and 'passage' keys
-
+
Example:
>>> examples = DataLoader.load_examples("examples.jsonl")
>>> # Returns: [{"query": "...", "passage": "..."}, ...]
-
+
JSONL format:
{"query": "how long is flea life cycle?", "passage": "The life cycle of a flea..."}
{"query": "cost of flooring?", "passage": "The cost of interior concrete..."}
"""
path = Path(path)
-
+
if not path.exists():
raise FileNotFoundError(f"Examples file not found: {path}")
-
+
examples: List[Dict[str, str]] = []
warned_missing_keys = False
-
+
with open(path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if not line.strip():
continue
-
+
try:
obj = json.loads(line)
except json.JSONDecodeError as e:
warnings.warn(f"Invalid JSON at line {line_num} in {path}: {e}")
continue
-
+
# Check for required keys
if query_key not in obj or passage_key not in obj:
if not warned_missing_keys:
- warnings.warn(
- f"Missing keys '{query_key}' or '{passage_key}' in {path}"
- )
+ warnings.warn(f"Missing keys '{query_key}' or '{passage_key}' in {path}")
warned_missing_keys = True
continue
-
- examples.append({
- "query": str(obj[query_key]),
- "passage": str(obj[passage_key])
- })
-
+
+ examples.append({"query": str(obj[query_key]), "passage": str(obj[passage_key])})
+
if not examples:
raise ValueError(f"No valid examples found in {path}")
-
+
return examples
-
+
@staticmethod
- def save_queries(
- queries: List[QueryItem],
- path: Union[str, Path],
- format: str = "tsv"
- ) -> None:
+ def save_queries(queries: List[QueryItem], path: Union[str, Path], format: str = "tsv") -> None:
"""
Save queries to a file.
-
+
Args:
queries: List of QueryItem objects
path: Output file path
format: Output format - "tsv" or "jsonl"
-
+
Example:
>>> DataLoader.save_queries(queries, "output.tsv", format="tsv")
"""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
-
+
if format == "tsv":
with open(path, "w", encoding="utf-8") as f:
writer = csv.writer(f, delimiter="\t")
@@ -376,27 +346,27 @@ def save_queries(
# Backward compatibility: Keep UnifiedQuerySource for now with deprecation warning
class UnifiedQuerySource:
"""Deprecated: Use DataLoader.load_queries() instead."""
-
+
def __init__(self, backend: str = "local", **kwargs):
warnings.warn(
"UnifiedQuerySource is deprecated. Use DataLoader.load_queries() instead.",
DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
)
-
+
if backend != "local":
raise ValueError(
f"Backend '{backend}' is no longer supported. "
"Use querygym.loaders module for BEIR/MS MARCO helpers."
)
-
+
self.path = kwargs.get("path")
self.format = kwargs.get("format", "tsv")
self.qid_col = kwargs.get("tsv_qid_col", 0)
self.query_col = kwargs.get("tsv_query_col", 1)
self.qid_key = kwargs.get("jsonl_qid_key", "qid")
self.query_key = kwargs.get("jsonl_query_key", "query")
-
+
def iter(self):
"""Iterate over queries."""
queries = DataLoader.load_queries(
@@ -405,10 +375,10 @@ def iter(self):
qid_col=self.qid_col,
query_col=self.query_col,
qid_key=self.qid_key,
- query_key=self.query_key
+ query_key=self.query_key,
)
return iter(queries)
-
+
@staticmethod
def export_to_tsv(items, out_path):
"""Export queries to TSV."""
@@ -418,28 +388,24 @@ def export_to_tsv(items, out_path):
# Backward compatibility: Keep UnifiedContextSource for now with deprecation warning
class UnifiedContextSource:
"""Deprecated: Use DataLoader.load_contexts() instead."""
-
+
def __init__(self, mode: str = "file", **kwargs):
warnings.warn(
"UnifiedContextSource is deprecated. Use DataLoader.load_contexts() instead.",
DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
)
-
+
if mode != "file":
raise ValueError(
f"Mode '{mode}' is no longer supported. "
"Load contexts from file or use retrieval separately."
)
-
+
self.path = kwargs.get("path")
self.qid_key = kwargs.get("qid_key", "qid")
self.ctx_key = kwargs.get("ctx_key", "contexts")
-
+
def load(self, queries: List[QueryItem]) -> Dict[str, List[str]]:
"""Load contexts from file."""
- return DataLoader.load_contexts(
- self.path,
- qid_key=self.qid_key,
- contexts_key=self.ctx_key
- )
+ return DataLoader.load_contexts(self.path, qid_key=self.qid_key, contexts_key=self.ctx_key)
diff --git a/querygym/loaders/beir.py b/querygym/loaders/beir.py
index 629a11c..9974ad4 100644
--- a/querygym/loaders/beir.py
+++ b/querygym/loaders/beir.py
@@ -18,117 +18,104 @@
from ..data.dataloader import DataLoader


-def load_queries(
- beir_data_dir: Union[str, Path],
- split: str = "test"
-) -> List[QueryItem]:
+def load_queries(beir_data_dir: Union[str, Path], split: str = "test") -> List[QueryItem]:
"""
Load queries from a BEIR dataset directory.
-
+
BEIR datasets have a queries.jsonl file with format:
{"_id": "query_id", "text": "query text", ...}
-
+
Args:
beir_data_dir: Path to BEIR dataset directory (contains queries.jsonl)
split: Dataset split (not used for queries, kept for API consistency)
-
+
Returns:
List of QueryItem objects
-
+
Example:
>>> from querygym.datasets import beir
>>> queries = beir.load_queries("./data/nfcorpus")
"""
beir_data_dir = Path(beir_data_dir)
queries_file = beir_data_dir / "queries.jsonl"
-
+
if not queries_file.exists():
raise FileNotFoundError(
f"BEIR queries file not found: {queries_file}\n"
f"Expected BEIR directory structure with queries.jsonl"
)
-
+
# BEIR uses "_id" and "text" as keys
- return DataLoader.load_queries(
- queries_file,
- format="jsonl",
- qid_key="_id",
- query_key="text"
- )
-
-
-def load_qrels(
- beir_data_dir: Union[str, Path],
- split: str = "test"
-) -> Dict[str, Dict[str, int]]:
+ return DataLoader.load_queries(queries_file, format="jsonl", qid_key="_id", query_key="text")
+
+
+def load_qrels(beir_data_dir: Union[str, Path], split: str = "test") -> Dict[str, Dict[str, int]]:
"""
Load qrels from a BEIR dataset directory.
-
+
BEIR qrels are in TSV format: query-id corpus-id score
Located at: qrels/{split}.tsv
-
+
Args:
beir_data_dir: Path to BEIR dataset directory
split: Dataset split ("train", "dev", "test")
-
+
Returns:
Dict mapping qid -> {docid -> relevance}
-
+
Example:
>>> from querygym.datasets import beir
>>> qrels = beir.load_qrels("./data/nfcorpus", split="test")
"""
beir_data_dir = Path(beir_data_dir)
qrels_file = beir_data_dir / "qrels" / f"{split}.tsv"
-
+
if not qrels_file.exists():
raise FileNotFoundError(
f"BEIR qrels file not found: {qrels_file}\n"
f"Expected: {beir_data_dir}/qrels/{split}.tsv"
)
-
+
# BEIR qrels format: query-id \t corpus-id \t score
# We need to convert to standard TREC format (add iteration column)
qrels: Dict[str, Dict[str, int]] = {}
-
+
with open(qrels_file, "r", encoding="utf-8") as f:
for line in f:
parts = line.strip().split("\t")
if len(parts) < 3:
continue
-
+
qid = parts[0]
docid = parts[1]
try:
relevance = int(parts[2])
except ValueError:
continue
-
+
if qid not in qrels:
qrels[qid] = {}
qrels[qid][docid] = relevance
-
+
if not qrels:
raise ValueError(f"No valid qrels found in {qrels_file}")
-
+
    return qrels


-def load_corpus(
- beir_data_dir: Union[str, Path]
-) -> Dict[str, Dict[str, str]]:
+def load_corpus(beir_data_dir: Union[str, Path]) -> Dict[str, Dict[str, str]]:
"""
Load corpus from a BEIR dataset directory.
-
+
BEIR corpus is in JSONL format with:
{"_id": "doc_id", "title": "...", "text": "...", ...}
-
+
Args:
beir_data_dir: Path to BEIR dataset directory
-
+
Returns:
Dict mapping doc_id -> {"title": ..., "text": ...}
-
+
Example:
>>> from querygym.datasets import beir
>>> corpus = beir.load_corpus("./data/nfcorpus")
@@ -136,35 +123,32 @@ def load_corpus(
"""
beir_data_dir = Path(beir_data_dir)
corpus_file = beir_data_dir / "corpus.jsonl"
-
+
if not corpus_file.exists():
raise FileNotFoundError(
f"BEIR corpus file not found: {corpus_file}\n"
f"Expected BEIR directory structure with corpus.jsonl"
)
-
+
corpus: Dict[str, Dict[str, str]] = {}
-
+
with open(corpus_file, "r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
-
+
try:
doc = json.loads(line)
except json.JSONDecodeError:
continue
-
+
doc_id = doc.get("_id")
if not doc_id:
continue
-
- corpus[str(doc_id)] = {
- "title": doc.get("title", ""),
- "text": doc.get("text", "")
- }
-
+
+ corpus[str(doc_id)] = {"title": doc.get("title", ""), "text": doc.get("text", "")}
+
if not corpus:
raise ValueError(f"No valid documents found in {corpus_file}")
-
+
return corpus
diff --git a/querygym/loaders/msmarco.py b/querygym/loaders/msmarco.py
index 45b4b22..86301fc 100644
--- a/querygym/loaders/msmarco.py
+++ b/querygym/loaders/msmarco.py
@@ -18,130 +18,119 @@
from ..data.dataloader import DataLoader


-def load_queries(
- queries_tsv: Union[str, Path]
-) -> List[QueryItem]:
+def load_queries(queries_tsv: Union[str, Path]) -> List[QueryItem]:
"""
Load queries from MS MARCO queries TSV file.
-
+
MS MARCO queries format: qid \\t query_text
-
+
Args:
queries_tsv: Path to MS MARCO queries.tsv file
-
+
Returns:
List of QueryItem objects
-
+
Example:
>>> from querygym.datasets import msmarco
>>> queries = msmarco.load_queries("./data/msmarco_queries.tsv")
"""
queries_tsv = Path(queries_tsv)
-
+
if not queries_tsv.exists():
raise FileNotFoundError(f"MS MARCO queries file not found: {queries_tsv}")
-
+
# MS MARCO uses standard TSV: qid \t query
- return DataLoader.load_queries(
- queries_tsv,
- format="tsv",
- qid_col=0,
- query_col=1
- )
+    return DataLoader.load_queries(queries_tsv, format="tsv", qid_col=0, query_col=1)


-def load_qrels(
- qrels_tsv: Union[str, Path]
-) -> Dict[str, Dict[str, int]]:
+def load_qrels(qrels_tsv: Union[str, Path]) -> Dict[str, Dict[str, int]]:
"""
Load qrels from MS MARCO qrels file.
-
+
MS MARCO qrels format can be either:
- TSV: qid \\t 0 \\t docid \\t relevance (TREC format)
- Or: qid \\t docid \\t relevance (simplified)
-
+
Args:
qrels_tsv: Path to MS MARCO qrels file
-
+
Returns:
Dict mapping qid -> {docid -> relevance}
-
+
Example:
>>> from querygym.datasets import msmarco
>>> qrels = msmarco.load_qrels("./data/msmarco_qrels.tsv")
"""
qrels_tsv = Path(qrels_tsv)
-
+
if not qrels_tsv.exists():
raise FileNotFoundError(f"MS MARCO qrels file not found: {qrels_tsv}")
-
+
# Try standard TREC format first
try:
return DataLoader.load_qrels(qrels_tsv, format="trec")
except ValueError:
# Fall back to simplified format (qid \t docid \t relevance)
qrels: Dict[str, Dict[str, int]] = {}
-
+
with open(qrels_tsv, "r", encoding="utf-8") as f:
for line in f:
parts = line.strip().split("\t")
if len(parts) < 3:
continue
-
+
qid = parts[0]
docid = parts[1]
try:
relevance = int(parts[2])
except ValueError:
continue
-
+
if qid not in qrels:
qrels[qid] = {}
qrels[qid][docid] = relevance
-
+
if not qrels:
raise ValueError(f"No valid qrels found in {qrels_tsv}")
-
+
    return qrels


-def load_collection(
- collection_tsv: Union[str, Path]
-) -> Dict[str, str]:
+def load_collection(collection_tsv: Union[str, Path]) -> Dict[str, str]:
"""
Load MS MARCO passage/document collection.
-
+
MS MARCO collection format: docid \\t text
-
+
Args:
collection_tsv: Path to MS MARCO collection.tsv file
-
+
Returns:
Dict mapping docid -> text
-
+
Example:
>>> from querygym.datasets import msmarco
>>> collection = msmarco.load_collection("./data/collection.tsv")
>>> collection["doc123"]
"""
collection_tsv = Path(collection_tsv)
-
+
if not collection_tsv.exists():
raise FileNotFoundError(f"MS MARCO collection file not found: {collection_tsv}")
-
+
collection: Dict[str, str] = {}
-
+
with open(collection_tsv, "r", encoding="utf-8") as f:
for line in f:
parts = line.strip().split("\t")
if len(parts) < 2:
continue
-
+
docid = parts[0]
text = parts[1]
collection[docid] = text
-
+
if not collection:
raise ValueError(f"No valid documents found in {collection_tsv}")
-
+
return collection
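The simplified-format fallback in `load_qrels` above can be sketched as a standalone helper (hypothetical name, not part of QueryGym) to show how malformed rows are skipped:

```python
def parse_simple_qrels(lines):
    """Parse simplified qrels rows (qid \t docid \t relevance) into {qid: {docid: rel}}."""
    qrels = {}
    for line in lines:
        parts = line.strip().split("\t")
        if len(parts) < 3:
            continue  # skip malformed rows
        try:
            relevance = int(parts[2])
        except ValueError:
            continue  # skip rows with a non-integer relevance column
        qrels.setdefault(parts[0], {})[parts[1]] = relevance
    return qrels

rows = ["q1\td7\t1", "q1\td9\t0", "not-a-qrel-row", "q2\td3\t2"]
print(parse_simple_qrels(rows))
# → {'q1': {'d7': 1, 'd9': 0}, 'q2': {'d3': 2}}
```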
diff --git a/querygym/methods/__init__.py b/querygym/methods/__init__.py
index ba9442e..363faaf 100644
--- a/querygym/methods/__init__.py
+++ b/querygym/methods/__init__.py
@@ -6,3 +6,4 @@
from .lamer import LameR
from .query2e import Query2E
from .csqe import CSQE
+from .thinkqe import ThinkQE
diff --git a/querygym/methods/csqe.py b/querygym/methods/csqe.py
index a4cec58..cce80af 100644
--- a/querygym/methods/csqe.py
+++ b/querygym/methods/csqe.py
@@ -4,20 +4,22 @@
from typing import List, Optional, Dict, Any
import re
+
@register_method("csqe")
class CSQE(BaseReformulator):
"""
CSQE (Context-based Sentence-level Query Expansion) method.
-
+
Combines KEQE (knowledge-based) and CSQE (context-based) expansions:
1. Generates N KEQE passages from LLM knowledge (no contexts) - default N=2
2. Generates N CSQE responses with retrieved contexts (few-shot) - default N=2
3. Extracts key sentences from CSQE responses
4. Concatenates: query × N + keqe_passages + csqe_sentences (newline-separated, lowercased)
-
+
Default: N=2 for both, totaling 4 generations (2 KEQE + 2 CSQE).
Query is repeated N times (equal to number of KEQE expansions).
"""
+
VERSION = "1.0"
REQUIRES_CONTEXT = True # Needs retrieved contexts for CSQE expansion
@@ -30,9 +32,11 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
retrieval_k = int(self.cfg.params.get("retrieval_k", 10))
prompt_ctxs = ctxs[:retrieval_k]
contexts_blob = "\n".join([f"{i+1}. {psg}" for i, psg in enumerate(prompt_ctxs)])
-
+
# Get generation parameters - N=2 for both by default (paper specification)
- n_expansions = int(self.cfg.params.get("gen_num", 2)) # Number of expansions for both KEQE and CSQE (default: 2)
+ n_expansions = int(
+ self.cfg.params.get("gen_num", 2)
+ ) # Number of expansions for both KEQE and CSQE (default: 2)
max_tokens = int(self.cfg.llm.get("max_tokens", 1024))
temperature = float(self.cfg.llm.get("temperature", 1.0))
@@ -44,7 +48,7 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
messages=msgs_keqe,
n=n_expansions,
temperature=temperature,
- max_tokens=max_tokens
+ max_tokens=max_tokens,
)
keqe_passages = [
choice.message.content.strip().strip('"').strip("'") or ""
@@ -59,12 +63,9 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
messages=msgs_csqe,
n=n_expansions,
temperature=temperature,
- max_tokens=max_tokens
+ max_tokens=max_tokens,
)
- csqe_responses = [
- choice.message.content or ""
- for choice in resp_csqe.choices
- ]
+ csqe_responses = [choice.message.content or "" for choice in resp_csqe.choices]
# Step 3: Extract key sentences from CSQE responses
# Pattern matches sentences in double quotes (from new_baselines.py)
csqe_sentences: List[str] = []
@@ -97,28 +98,30 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
def _extract_key_sentences(self, response: str) -> str:
"""
Extract key sentences from CSQE response.
-
+
Primary: Uses regex to find sentences in double quotes (expected format).
Fallback: If no quotes found, extracts content after numbered document markers (1., 2., etc.)
"""
# Primary strategy: Extract sentences in double quotes
pattern = r'"([^"]*)"'
sentences = re.findall(pattern, response)
-
+
if sentences:
# Found quoted sentences - join and return
joint_sentence = " ".join(sentences)
return joint_sentence
-
+
# Fallback: LLM didn't use quotes, extract from numbered document format
# Pattern: Look for content after "1.", "2.", etc. (with optional trailing colon)
# Remove "Relevant Documents:" header first
- cleaned = re.sub(r'^Relevant Documents?:?\s*\n?', '', response, flags=re.IGNORECASE | re.MULTILINE)
-
+ cleaned = re.sub(
+ r"^Relevant Documents?:?\s*\n?", "", response, flags=re.IGNORECASE | re.MULTILINE
+ )
+
# Extract content after numbered markers (1., 2., etc.)
- doc_pattern = r'\d+[\.:]\s*(.+?)(?=\d+[\.:]|$)'
+ doc_pattern = r"\d+[\.:]\s*(.+?)(?=\d+[\.:]|$)"
doc_content = re.findall(doc_pattern, cleaned, re.DOTALL)
-
+
if doc_content:
# Clean up each extracted piece
extracted = []
@@ -129,13 +132,13 @@ def _extract_key_sentences(self, response: str) -> str:
extracted.append(cleaned_content)
if extracted:
return " ".join(extracted)
-
+
# If still nothing found, return empty string
return ""
def _get_retrieval_params(self) -> Optional[Dict[str, Any]]:
"""Get CSQE-specific retrieval parameters for batch retrieval.
-
+
Returns retrieval parameters compatible with BaseSearcher interface.
Supports both searcher_type/searcher_kwargs format and legacy format.
"""
@@ -147,53 +150,53 @@ def _get_retrieval_params(self) -> Optional[Dict[str, Any]]:
"k": int(self.cfg.params.get("retrieval_k", 10)),
"threads": int(self.cfg.params.get("threads", 16)),
}
-
+
# Check if using new searcher interface format with searcher_type
if "searcher_type" in self.cfg.params:
# New format: use searcher_type and searcher_kwargs
searcher_type = self.cfg.params.get("searcher_type", "pyserini")
searcher_kwargs = self.cfg.params.get("searcher_kwargs", {})
-
+
# If index is provided in params but not in searcher_kwargs, add it
if "index" in self.cfg.params and "index" not in searcher_kwargs:
searcher_kwargs["index"] = self.cfg.params["index"]
-
+
return {
"searcher_type": searcher_type,
"searcher_kwargs": searcher_kwargs,
"k": int(self.cfg.params.get("retrieval_k", 10)),
"threads": int(self.cfg.params.get("threads", 16)),
}
-
+
# Legacy format: convert old-style params to new format
index = self.cfg.params.get("index")
if not index:
return None
-
+
# Build searcher_kwargs from legacy params
searcher_kwargs = {
"index": index,
"searcher_type": "impact" if self.cfg.params.get("impact", False) else "bm25",
"answer_key": self.cfg.params.get("answer_key", "contents"),
}
-
+
# Add BM25 parameters if specified
k1 = self.cfg.params.get("k1")
b = self.cfg.params.get("b")
if k1 is not None and b is not None:
searcher_kwargs["k1"] = k1
searcher_kwargs["b"] = b
-
+
# Add RM3 if requested
if self.cfg.params.get("rm3", False):
searcher_kwargs["rm3"] = True
-
+
# Add Rocchio if requested
if self.cfg.params.get("rocchio", False):
searcher_kwargs["rocchio"] = True
if self.cfg.params.get("rocchio_use_negative", False):
searcher_kwargs["rocchio_use_negative"] = True
-
+
return {
"searcher_type": "pyserini", # Default to pyserini for legacy format
"searcher_kwargs": searcher_kwargs,
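The two-stage extraction in `_extract_key_sentences` (quoted sentences first, numbered-document fallback second) can be exercised in isolation; this sketch reuses the same regexes as the method above:

```python
import re


def extract_key_sentences(response: str) -> str:
    # Primary strategy: sentences wrapped in double quotes
    quoted = re.findall(r'"([^"]*)"', response)
    if quoted:
        return " ".join(quoted)
    # Fallback: content after numbered markers ("1.", "2:", ...),
    # after dropping an optional "Relevant Documents:" header
    cleaned = re.sub(r"^Relevant Documents?:?\s*\n?", "", response,
                     flags=re.IGNORECASE | re.MULTILINE)
    docs = re.findall(r"\d+[\.:]\s*(.+?)(?=\d+[\.:]|$)", cleaned, re.DOTALL)
    return " ".join(d.strip() for d in docs if d.strip())


print(extract_key_sentences('Key facts: "solar power is renewable" and "panels last decades".'))
# → solar power is renewable panels last decades
```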
diff --git a/querygym/methods/genqr.py b/querygym/methods/genqr.py
index 30a1fff..6c947f4 100644
--- a/querygym/methods/genqr.py
+++ b/querygym/methods/genqr.py
@@ -3,15 +3,17 @@
from ..core.registry import register_method
from typing import List
+
@register_method("genqr")
class GENQR(BaseReformulator):
"""
GenQR: Generate reformulations N times and concatenate.
-
+
Pipeline:
1. Call LLM N times (default 5) to generate reformulations
2. Concatenate: query + reformulation1 + reformulation2 + ... + reformulationN
"""
+
VERSION = "1.0"
def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
@@ -19,26 +21,23 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
n_generations = int(self.cfg.params.get("n_generations", 5))
temperature = float(self.cfg.llm.get("temperature", 0.8))
max_tokens = int(self.cfg.llm.get("max_tokens", 256))
-
+
# Generate reformulations N times (5 by default)
reformulations: List[str] = []
-
+
for _ in range(n_generations):
msgs = self.prompts.render("genqr.keywords.v1", query=q.text)
out = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens)
# Clean up the reformulation (remove quotes, extra whitespace)
reformulation = out.strip().strip('"').strip("'").replace("\n", " ")
reformulations.append(reformulation)
-
+
# Concatenate: query + reformulation1 + reformulation2 + ... + reformulationN
reformulated = q.text + " " + " ".join(reformulations)
-
+
return ReformulationResult(
- q.qid,
- q.text,
+ q.qid,
+ q.text,
reformulated,
- metadata={
- "n_generations": n_generations,
- "reformulations": reformulations
- }
+ metadata={"n_generations": n_generations, "reformulations": reformulations},
)
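GenQR's concatenation step is simple enough to sketch directly; the reformulation strings below are hypothetical stand-ins for LLM output:

```python
def concat_genqr(query: str, reformulations: list[str]) -> str:
    # GenQR expansion: original query followed by all N cleaned reformulations
    cleaned = [r.strip().strip('"').strip("'").replace("\n", " ") for r in reformulations]
    return query + " " + " ".join(cleaned)


print(concat_genqr("solar energy", ['"solar power photovoltaic"', "sun electricity panels"]))
# → solar energy solar power photovoltaic sun electricity panels
```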
diff --git a/querygym/methods/genqr_ensemble.py b/querygym/methods/genqr_ensemble.py
index 1883bd4..78aa2e8 100644
--- a/querygym/methods/genqr_ensemble.py
+++ b/querygym/methods/genqr_ensemble.py
@@ -17,28 +17,30 @@
"genqr_ensemble.inst10.v1",
]
+
@register_method("genqr_ensemble")
class GenQREnsemble(BaseReformulator):
"""
GenQREnsemble: Generative Query Reformulation with Ensemble of Instructions.
-
+
Uses 10 instruction variants to generate diverse keyword expansions.
Each variant generates keywords independently, then all are merged.
-
+
Pipeline:
1. For each of 10 instruction variants, generate keywords (10 LLM calls)
2. Parse and merge all keyword lists
3. Expand query: (Q × 5) + K1 + K2 + ... + K10
-
+
Sampling Parameters (from paper):
- temperature: 0.92 (nucleus sampling)
- max_tokens: 256
- parallel: false (can be enabled for concurrent generation)
-
+
Config Parameters:
- variant_ids: List of prompt IDs (defaults to all 10 variants)
- parallel: Enable parallel generation of all variants (default: false)
"""
+
VERSION = "1.0"
CONCATENATION_STRATEGY = "query_repeat_plus_generated" # (Q × 5) + K1 + K2 + ... + K10
DEFAULT_QUERY_REPEATS = 5 # GenQREnsemble uses 5 repetitions
@@ -46,34 +48,34 @@ class GenQREnsemble(BaseReformulator):
def _parse_keywords(self, keywords_text: str) -> List[str]:
"""
Parse keywords from LLM output (comma-separated or bullet list).
-
+
Args:
keywords_text: Raw LLM output
-
+
Returns:
List of individual keywords
"""
keywords = []
-
+
# Check if it's a bullet list (contains dashes/newlines)
- if '\n' in keywords_text or keywords_text.strip().startswith('-'):
+ if "\n" in keywords_text or keywords_text.strip().startswith("-"):
# Split by newlines and extract after dashes
- for line in keywords_text.split('\n'):
+ for line in keywords_text.split("\n"):
line = line.strip()
# Remove leading dash/bullet
- if line.startswith('-'):
+ if line.startswith("-"):
line = line[1:].strip()
- elif line.startswith('•'):
+ elif line.startswith("•"):
line = line[1:].strip()
- elif line.startswith('*'):
+ elif line.startswith("*"):
line = line[1:].strip()
-
+
if line:
keywords.append(line)
else:
# Split by comma (standard format)
- keywords = [k.strip() for k in keywords_text.split(',') if k.strip()]
-
+ keywords = [k.strip() for k in keywords_text.split(",") if k.strip()]
+
return keywords
def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
@@ -82,24 +84,24 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
temperature = float(self.cfg.llm.get("temperature", 0.92))
max_tokens = int(self.cfg.llm.get("max_tokens", 256))
parallel = bool(self.cfg.params.get("parallel", False))
-
+
# Get variant IDs (configurable, defaults to all 10)
variant_ids = self.cfg.params.get("variant_ids", VARIANT_IDS)
-
+
# Generate keywords for each instruction variant
all_keywords = []
variant_outputs = {}
-
+
if parallel:
# Parallel generation using ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed
-
+
def generate_keywords_for_variant(idx: int, prompt_id: str):
msgs = self.prompts.render(prompt_id, query=q.text)
output = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens)
keywords = self._parse_keywords(output)
return idx, prompt_id, output, keywords
-
+
# Generate all variants in parallel
with ThreadPoolExecutor(max_workers=len(variant_ids)) as executor:
futures = [
@@ -113,7 +115,7 @@ def generate_keywords_for_variant(idx: int, prompt_id: str):
"prompt_id": prompt_id,
"raw_output": output,
"keywords": keywords,
- "count": len(keywords)
+ "count": len(keywords),
}
else:
# Sequential generation (default)
@@ -121,28 +123,28 @@ def generate_keywords_for_variant(idx: int, prompt_id: str):
# Render prompt and generate keywords
msgs = self.prompts.render(prompt_id, query=q.text)
output = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens)
-
+
# Parse keywords from output
keywords = self._parse_keywords(output)
all_keywords.extend(keywords)
-
+
# Store per-variant output for metadata
variant_outputs[f"variant_{i}"] = {
"prompt_id": prompt_id,
"raw_output": output,
"keywords": keywords,
- "count": len(keywords)
+ "count": len(keywords),
}
-
+
# Merge all keywords into single string
generated_content = " ".join(all_keywords)
-
+
# Use base class concatenation: (Q × 5) + generated_content
reformulated = self.concatenate_result(q.text, generated_content)
-
+
return ReformulationResult(
- q.qid,
- q.text,
+ q.qid,
+ q.text,
reformulated,
metadata={
"num_variants": len(variant_ids),
@@ -151,6 +153,6 @@ def generate_keywords_for_variant(idx: int, prompt_id: str):
"variant_outputs": variant_outputs,
"temperature": temperature,
"max_tokens": max_tokens,
- "parallel": parallel
- }
+ "parallel": parallel,
+ },
)
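The bullet-vs-comma parsing in `_parse_keywords` can be sketched standalone (same branching as the method above, with the three leading-marker cases folded into one check):

```python
def parse_keywords(text: str) -> list[str]:
    keywords = []
    if "\n" in text or text.strip().startswith("-"):
        # Bullet-list output: strip a leading -, •, or * from each line
        for line in text.split("\n"):
            line = line.strip()
            if line[:1] in ("-", "•", "*"):
                line = line[1:].strip()
            if line:
                keywords.append(line)
    else:
        # Comma-separated output (standard format)
        keywords = [k.strip() for k in text.split(",") if k.strip()]
    return keywords


print(parse_keywords("- solar\n• photovoltaic\n* renewable"))
# → ['solar', 'photovoltaic', 'renewable']
```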
diff --git a/querygym/methods/lamer.py b/querygym/methods/lamer.py
index e9afc65..3f4d74b 100644
--- a/querygym/methods/lamer.py
+++ b/querygym/methods/lamer.py
@@ -3,11 +3,14 @@
from ..core.registry import register_method
from typing import List, Optional, Dict, Any
+
@register_method("lamer")
class LameR(BaseReformulator):
VERSION = "0.1"
REQUIRES_CONTEXT = True
- CONCATENATION_STRATEGY = "interleaved_query_content" # q + a1 + q + a2 + q + a3 + q + a4 + q + a5
+ CONCATENATION_STRATEGY = (
+ "interleaved_query_content" # q + a1 + q + a2 + q + a3 + q + a4 + q + a5
+ )
def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
# Use provided contexts (from batch retrieval)
@@ -48,7 +51,7 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
def _get_retrieval_params(self) -> Optional[Dict[str, Any]]:
"""Get LameR-specific retrieval parameters for batch retrieval.
-
+
Returns retrieval parameters compatible with BaseSearcher interface.
Supports both searcher_type/searcher_kwargs format and legacy format.
"""
@@ -60,53 +63,53 @@ def _get_retrieval_params(self) -> Optional[Dict[str, Any]]:
"k": int(self.cfg.params.get("retrieval_k", 10)),
"threads": int(self.cfg.params.get("threads", 16)),
}
-
+
# Check if using new searcher interface format with searcher_type
if "searcher_type" in self.cfg.params:
# New format: use searcher_type and searcher_kwargs
searcher_type = self.cfg.params.get("searcher_type", "pyserini")
searcher_kwargs = self.cfg.params.get("searcher_kwargs", {})
-
+
# If index is provided in params but not in searcher_kwargs, add it
if "index" in self.cfg.params and "index" not in searcher_kwargs:
searcher_kwargs["index"] = self.cfg.params["index"]
-
+
return {
"searcher_type": searcher_type,
"searcher_kwargs": searcher_kwargs,
"k": int(self.cfg.params.get("retrieval_k", 10)),
"threads": int(self.cfg.params.get("threads", 16)),
}
-
+
# Legacy format: convert old-style params to new format
index = self.cfg.params.get("index")
if not index:
return None
-
+
# Build searcher_kwargs from legacy params
searcher_kwargs = {
"index": index,
"searcher_type": "impact" if self.cfg.params.get("impact", False) else "bm25",
"answer_key": self.cfg.params.get("answer_key", "contents"),
}
-
+
# Add BM25 parameters if specified
k1 = self.cfg.params.get("k1")
b = self.cfg.params.get("b")
if k1 is not None and b is not None:
searcher_kwargs["k1"] = k1
searcher_kwargs["b"] = b
-
+
# Add RM3 if requested
if self.cfg.params.get("rm3", False):
searcher_kwargs["rm3"] = True
-
+
# Add Rocchio if requested
if self.cfg.params.get("rocchio", False):
searcher_kwargs["rocchio"] = True
if self.cfg.params.get("rocchio_use_negative", False):
searcher_kwargs["rocchio_use_negative"] = True
-
+
return {
"searcher_type": "pyserini", # Default to pyserini for legacy format
"searcher_kwargs": searcher_kwargs,
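LameR's `interleaved_query_content` strategy (`q + a1 + q + a2 + ...`) can be sketched as follows; the body of `reformulate` is not shown in this diff, so the space separator is an assumption:

```python
def interleave(query: str, passages: list[str]) -> str:
    # LameR-style interleaving: the query is repeated before each generated passage
    parts = []
    for passage in passages:
        parts.append(query)
        parts.append(passage.strip())
    return " ".join(parts)


print(interleave("solar energy", ["Passage one.", "Passage two."]))
# → solar energy Passage one. solar energy Passage two.
```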
diff --git a/querygym/methods/mugi.py b/querygym/methods/mugi.py
index e8b8a7e..82b5996 100644
--- a/querygym/methods/mugi.py
+++ b/querygym/methods/mugi.py
@@ -3,14 +3,15 @@
from ..core.base import BaseReformulator, QueryItem, ReformulationResult
from ..core.registry import register_method
+
@register_method("mugi")
class MuGI(BaseReformulator):
"""
MuGI (Multi-Text Generation Integration) Method
-
+
Generates multiple diverse pseudo-documents per query and uses adaptive
concatenation to balance term frequencies between short queries and long pseudo-docs.
-
+
Key Parameters:
- num_docs: Number of pseudo-documents to generate per query (default: 5)
- adaptive_times: Divisor for adaptive repetition ratio (default: 6)
@@ -19,11 +20,12 @@ class MuGI(BaseReformulator):
- mode: "zs" for zero-shot or "fs" for few-shot (default: "zs")
- prompt_id: (Advanced) Direct prompt ID override ("mugi.zeroshot.v1" or "mugi.fewshot.v1")
- parallel: If True, generate all pseudo-docs in parallel; if False, sequential (default: False)
-
+
Formula:
        repetition_times = max(1, (len(all_pseudo_docs) // len(query)) // adaptive_times)
enhanced_query = (query + ' ') * repetition_times + all_pseudo_docs
"""
+
VERSION = "1.0"
REQUIRES_CONTEXT = False
CONCATENATION_STRATEGY = "adaptive_query_repeat_plus_generated"
@@ -33,12 +35,14 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
num_docs = int(self.cfg.params.get("num_docs", 5))
adaptive_times = int(self.cfg.params.get("adaptive_times", 6))
max_tokens = int(self.cfg.params.get("max_tokens", self.cfg.llm.get("max_tokens", 1024)))
- temperature = float(self.cfg.params.get("temperature", self.cfg.llm.get("temperature", 1.0)))
+ temperature = float(
+ self.cfg.params.get("temperature", self.cfg.llm.get("temperature", 1.0))
+ )
parallel = bool(self.cfg.params.get("parallel", False))
-
+
# Mode selection: "zs" (zero-shot) or "fs" (few-shot)
mode = str(self.cfg.params.get("mode", "zs"))
-
+
# Map mode to prompt_id (or allow direct prompt_id override for advanced users)
if "prompt_id" in self.cfg.params:
# Advanced: allow direct prompt ID specification
@@ -50,20 +54,22 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
elif mode in ["zs", "zeroshot"]:
prompt_id = "mugi.zeroshot.v1"
else:
- raise ValueError(f"Invalid mode '{mode}' for MuGI. Must be 'zs' (zero-shot) or 'fs' (few-shot).")
-
+ raise ValueError(
+ f"Invalid mode '{mode}' for MuGI. Must be 'zs' (zero-shot) or 'fs' (few-shot)."
+ )
+
# Generate multiple pseudo-documents for diversity and coverage
pseudo_docs: List[str] = []
-
+
if parallel:
# Parallel generation using ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed
-
+
def generate_single_doc(doc_idx: int) -> str:
msgs = self.prompts.render(prompt_id, query=q.text)
pseudo_doc = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens)
return pseudo_doc.strip().strip('"').strip("'")
-
+
# Generate all pseudo-docs in parallel
with ThreadPoolExecutor(max_workers=num_docs) as executor:
futures = [executor.submit(generate_single_doc, i) for i in range(num_docs)]
@@ -74,25 +80,21 @@ def generate_single_doc(doc_idx: int) -> str:
for i in range(num_docs):
# Render prompt (mugi.zeroshot.v1 or mugi.fewshot.v1)
msgs = self.prompts.render(prompt_id, query=q.text)
-
+
# Generate pseudo-document with high temperature for diversity
- pseudo_doc = self.llm.chat(
- msgs,
- temperature=temperature,
- max_tokens=max_tokens
- )
-
+ pseudo_doc = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens)
+
# Clean up the generated pseudo-document
pseudo_doc = pseudo_doc.strip().strip('"').strip("'")
pseudo_docs.append(pseudo_doc)
-
+
# Join all pseudo-documents into single text
- all_pseudo_docs = ' '.join(pseudo_docs)
-
+ all_pseudo_docs = " ".join(pseudo_docs)
+
# Use base class concatenation with adaptive_query_repeat_plus_generated strategy
# This automatically handles the MuGI formula and cleaning
reformulated = self.concatenate_result(q.text, all_pseudo_docs)
-
+
# Calculate metrics for metadata (after concatenation)
query_len = len(q.text)
docs_len = len(all_pseudo_docs)
@@ -100,7 +102,7 @@ def generate_single_doc(doc_idx: int) -> str:
repetition_times = max(1, (docs_len // query_len) // adaptive_times)
else:
repetition_times = 1
-
+
# Return result with metadata
return ReformulationResult(
q.qid,
@@ -122,5 +124,5 @@ def generate_single_doc(doc_idx: int) -> str:
"pseudo_doc_3": pseudo_docs[2] if len(pseudo_docs) > 2 else "",
"pseudo_doc_4": pseudo_docs[3] if len(pseudo_docs) > 3 else "",
"pseudo_doc_5": pseudo_docs[4] if len(pseudo_docs) > 4 else "",
- }
+ },
)
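MuGI's adaptive repetition formula can be checked in isolation; this sketch mirrors the clamped formula and concatenation used in `reformulate` above:

```python
def mugi_concat(query: str, pseudo_docs: list[str], adaptive_times: int = 6) -> str:
    all_docs = " ".join(pseudo_docs)
    # Repeat the query proportionally to the length ratio, at least once,
    # to balance term frequencies against the long pseudo-documents
    if len(query) > 0:
        repetition_times = max(1, (len(all_docs) // len(query)) // adaptive_times)
    else:
        repetition_times = 1
    return (query + " ") * repetition_times + all_docs


# Short query vs. 61 chars of pseudo-docs: 61 // 5 // 6 = 2 repetitions
print(mugi_concat("abcde", ["x" * 30, "y" * 30]))
```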
diff --git a/querygym/methods/qa_expand.py b/querygym/methods/qa_expand.py
index 5061bec..7a97a00 100644
--- a/querygym/methods/qa_expand.py
+++ b/querygym/methods/qa_expand.py
@@ -4,6 +4,7 @@
from ..core.base import BaseReformulator, QueryItem, ReformulationResult
from ..core.registry import register_method
+
def parse_llm_json(text: str) -> dict:
"""
Robust JSON parser for LLM outputs that handles common formatting issues:
@@ -16,83 +17,84 @@ def parse_llm_json(text: str) -> dict:
"""
if not text or not text.strip():
return {}
-
+
original_text = text
text = text.strip()
-
+
# Remove markdown code blocks
if text.startswith("```"):
- first_newline = text.find('\n')
+ first_newline = text.find("\n")
if first_newline != -1:
- text = text[first_newline+1:]
+ text = text[first_newline + 1 :]
if text.endswith("```"):
text = text[:-3]
text = text.strip()
-
+
# Unescape common escape sequences that appear in raw output
- text = text.replace('\\"', '"').replace('\\n', '\n').replace('\\t', '\t')
-
+ text = text.replace('\\"', '"').replace("\\n", "\n").replace("\\t", "\t")
+
# Try standard JSON parsing first
try:
return json.loads(text)
except json.JSONDecodeError:
pass
-
+
# Remove trailing commas before closing braces/brackets
- cleaned = re.sub(r',(\s*[}\]])', r'\1', text)
+ cleaned = re.sub(r",(\s*[}\]])", r"\1", text)
try:
return json.loads(cleaned)
except json.JSONDecodeError:
pass
-
+
# Replace single quotes with double quotes
try:
return json.loads(text.replace("'", '"'))
except json.JSONDecodeError:
pass
-
+
# Try to fix incomplete JSON by adding missing closing quotes and braces
- if '{' in text:
+ if "{" in text:
# Find the last complete value and try to close it
attempt = text.rstrip()
-
+
# If ends mid-string (no closing quote), add closing quote
if attempt.count('"') % 2 == 1:
attempt += '"'
-
+
# Add missing closing brace if needed
- if not attempt.endswith('}'):
- attempt += '}'
-
+ if not attempt.endswith("}"):
+ attempt += "}"
+
try:
return json.loads(attempt)
except json.JSONDecodeError:
pass
-
+
# Fallback: Extract key-value pairs using regex
# Looks for patterns like "key": "value" or 'key': 'value'
result = {}
-
+
# Pattern to match JSON key-value pairs
patterns = [
r'["\'](\w+)["\']\s*:\s*["\']([^"\']*)["\']', # Standard JSON
- r'(\w+)\s*:\s*["\']([^"\']*)["\']', # Unquoted keys
- r'["\'](\w+)["\']\s*:\s*([^,}\]]+)', # Values without quotes
+ r'(\w+)\s*:\s*["\']([^"\']*)["\']', # Unquoted keys
+ r'["\'](\w+)["\']\s*:\s*([^,}\]]+)', # Values without quotes
]
-
+
for pattern in patterns:
matches = re.findall(pattern, text, re.DOTALL)
for key, value in matches:
if key and value:
result[key] = value.strip()
-
+
# If we found any key-value pairs, return them
if result:
return result
-
+
# Absolute fallback: return empty dict
return {}
+
@register_method("qa_expand")
class QAExpand(BaseReformulator):
"""
@@ -102,6 +104,7 @@ class QAExpand(BaseReformulator):
3. Filter and refine answers based on relevance to query (LLM call #3)
4. Expand query: (Q × 3) + refined_answers
"""
+
VERSION = "1.0"
REQUIRES_CONTEXT = False # No contexts needed - pure generative expansion
CONCATENATION_STRATEGY = "query_repeat_plus_generated" # (Q × 3) + refined_answers
@@ -110,82 +113,86 @@ class QAExpand(BaseReformulator):
def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
# K is fixed at 3 for QA-Expand
k = 3
-
+
# Get configurable parameters
max_tokens = self.cfg.params.get("max_tokens", self.cfg.llm.get("max_tokens", 256))
default_temp = self.cfg.llm.get("temperature", 0.8)
-
+
# Temperature settings (can override per-step, defaults to global temperature)
temp_subq = self.cfg.params.get("temperature_subq", default_temp)
temp_answer = self.cfg.params.get("temperature_answer", default_temp)
temp_refine = self.cfg.params.get("temperature_refine", default_temp)
-
+
# Prompt IDs (customizable for different datasets)
prompt_subq = self.cfg.params.get("prompt_subq", "qa_expand.subq.v1")
prompt_answer = self.cfg.params.get("prompt_answer", "qa_expand.answer.v1")
prompt_refine = self.cfg.params.get("prompt_refine", "qa_expand.refine.v1")
-
+
# Step 1: Generate 3 sub-questions (LLM call #1)
# Input: {query}
# Output: Raw text with 3 questions
msgs = self.prompts.render(prompt_subq, query=q.text)
subqs_raw = self.llm.chat(msgs, temperature=temp_subq, max_tokens=max_tokens)
-
+
# Format questions as JSON for next step
# Try to parse if LLM already returned JSON, otherwise split by newlines
questions_dict = parse_llm_json(subqs_raw)
-
- if questions_dict and any(k.startswith('question') for k in questions_dict.keys()):
+
+ if questions_dict and any(k.startswith("question") for k in questions_dict.keys()):
# LLM returned JSON format already
questions_json = json.dumps(questions_dict)
else:
# Parse from text - split by newlines and clean up
- lines = [line.strip() for line in subqs_raw.split('\n') if line.strip()]
+ lines = [line.strip() for line in subqs_raw.split("\n") if line.strip()]
# Remove common prefixes like "1.", "Q1:", "-", etc.
cleaned_lines = []
for line in lines:
- line = re.sub(r'^[\d\.\-\*\)]+\s*', '', line) # Remove "1.", "1)", "-", "*"
- line = re.sub(r'^[Qq]\d+[\:\.]?\s*', '', line) # Remove "Q1:", "q1."
+ line = re.sub(r"^[\d\.\-\*\)]+\s*", "", line) # Remove "1.", "1)", "-", "*"
+ line = re.sub(r"^[Qq]\d+[\:\.]?\s*", "", line) # Remove "Q1:", "q1."
if line:
cleaned_lines.append(line)
-
+
# Take first 3 questions
- questions_json = json.dumps({
- "question1": cleaned_lines[0] if len(cleaned_lines) > 0 else "",
- "question2": cleaned_lines[1] if len(cleaned_lines) > 1 else "",
- "question3": cleaned_lines[2] if len(cleaned_lines) > 2 else ""
- })
-
+ questions_json = json.dumps(
+ {
+ "question1": cleaned_lines[0] if len(cleaned_lines) > 0 else "",
+ "question2": cleaned_lines[1] if len(cleaned_lines) > 1 else "",
+ "question3": cleaned_lines[2] if len(cleaned_lines) > 2 else "",
+ }
+ )
+
# Step 2: Generate 3 pseudo-answers from LLM's internal knowledge (LLM call #2)
# Input: {questions} (JSON format)
# Output: JSON with {"answer1": "...", "answer2": "...", "answer3": "..."}
# NOTE: No contexts used - LLM generates answers from its own knowledge
msgs = self.prompts.render(prompt_answer, questions=questions_json)
answers_json = self.llm.chat(msgs, temperature=temp_answer, max_tokens=max_tokens)
-
+
# Step 3: Filter and refine answers (LLM call #3)
# Input: {query}, {answers} (JSON format)
# Output: Refined JSON with relevant answers only
msgs = self.prompts.render(prompt_refine, query=q.text, answers=answers_json)
refined_answers_json = self.llm.chat(msgs, temperature=temp_refine, max_tokens=max_tokens)
-
+
# Parse refined JSON and extract answer values for concatenation
refined_dict = parse_llm_json(refined_answers_json)
if refined_dict:
# Extract answer values in order (answer1, answer2, answer3)
- answer_values = [str(v) for k, v in sorted(refined_dict.items()) if v and str(v).strip()]
+ answer_values = [
+ str(v) for k, v in sorted(refined_dict.items()) if v and str(v).strip()
+ ]
refined_text = " ".join(answer_values)
else:
# Fallback if JSON parsing fails - use raw text
refined_text = refined_answers_json
-
+
# Step 4: Expand query: (Q × 3) + refined_answers
reformulated = self.concatenate_result(q.text, refined_text, query_repeats=k)
-
+
return ReformulationResult(
- q.qid,
- q.text,
- reformulated,
+ q.qid,
+ q.text,
+ reformulated,
metadata={
"subquestions_raw": subqs_raw,
"questions_json": questions_json,
@@ -195,7 +202,7 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
"prompts_used": {
"subq": prompt_subq,
"answer": prompt_answer,
- "refine": prompt_refine
- }
- }
+ "refine": prompt_refine,
+ },
+ },
)
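Two of `parse_llm_json`'s repair strategies (stripping a markdown code fence, dropping trailing commas) can be sketched as a minimal lenient parser; the full method above also handles single quotes, truncated objects, and a regex fallback:

```python
import json
import re


def lenient_json(text: str) -> dict:
    text = text.strip()
    # Strip a surrounding markdown code fence, if present
    if text.startswith("```"):
        text = text.split("\n", 1)[-1]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Drop trailing commas before } or ] and retry
        return json.loads(re.sub(r",(\s*[}\]])", r"\1", text))


print(lenient_json('```json\n{"answer1": "solar", "answer2": "wind",}\n```'))
# → {'answer1': 'solar', 'answer2': 'wind'}
```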
diff --git a/querygym/methods/query2doc.py b/querygym/methods/query2doc.py
index ab6820c..c826de5 100644
--- a/querygym/methods/query2doc.py
+++ b/querygym/methods/query2doc.py
@@ -6,88 +6,90 @@
import os
import csv
+
@register_method("query2doc")
class Query2Doc(BaseReformulator):
"""
Query2Doc: Generate pseudo-documents from LLM knowledge.
-
+
Modes:
- "zs" (zero-shot): Generate passage directly [default]
- "cot" (chain-of-thought): Zero-shot with reasoning
- "fs", "fewshot", or "few-shot" (few-shot): Uses training examples from any dataset (dynamic)
-
+
Few-Shot Config (via params or env vars):
- dataset_type: "msmarco", "beir", or "generic" (uses appropriate loader)
- num_examples: Number of few-shot examples (default: 4)
-
+
For MS MARCO datasets (dataset_type="msmarco"):
- collection_path / COLLECTION_PATH: Path to collection.tsv
- train_queries_path / TRAIN_QUERIES_PATH: Path to queries.tsv
- train_qrels_path / TRAIN_QRELS_PATH: Path to qrels file
-
+
For BEIR datasets (dataset_type="beir"):
- beir_data_dir / BEIR_DATA_DIR: Path to BEIR dataset directory
- train_split: "train" or "dev" (default: "train")
-
+
For generic datasets (dataset_type="generic" or omitted):
- collection_path: TSV file (docid \t text)
- train_queries_path: TSV file (qid \t query)
- train_qrels_path: TREC format (qid 0 docid relevance)
-
+
Note: MS MARCO env vars (MSMARCO_COLLECTION, etc.) are supported for backward compatibility.
"""
+
VERSION = "1.0"
CONCATENATION_STRATEGY = "query_repeat_plus_generated"
DEFAULT_QUERY_REPEATS = 5
-
+
def __init__(self, cfg, llm_client, prompt_resolver):
super().__init__(cfg, llm_client, prompt_resolver)
- self._fewshot_data = None
-
+ self._fewshot_data = None
+
def _load_fewshot_data(self):
"""Lazy load training data for few-shot mode (supports MS MARCO, BEIR, or generic datasets)."""
if self._fewshot_data is not None:
return self._fewshot_data
-
+
try:
# Get dataset type (msmarco, beir, or generic)
dataset_type = self.cfg.params.get("dataset_type", "").lower()
-
+
# Get paths from config params or environment variables
collection_path = (
- self.cfg.params.get("collection_path") or
- self.cfg.params.get("msmarco_collection") or
- os.getenv("COLLECTION_PATH") or
- os.getenv("MSMARCO_COLLECTION")
+ self.cfg.params.get("collection_path")
+ or self.cfg.params.get("msmarco_collection")
+ or os.getenv("COLLECTION_PATH")
+ or os.getenv("MSMARCO_COLLECTION")
)
train_queries_path = (
- self.cfg.params.get("train_queries_path") or
- self.cfg.params.get("msmarco_train_queries") or
- os.getenv("TRAIN_QUERIES_PATH") or
- os.getenv("MSMARCO_TRAIN_QUERIES")
+ self.cfg.params.get("train_queries_path")
+ or self.cfg.params.get("msmarco_train_queries")
+ or os.getenv("TRAIN_QUERIES_PATH")
+ or os.getenv("MSMARCO_TRAIN_QUERIES")
)
train_qrels_path = (
- self.cfg.params.get("train_qrels_path") or
- self.cfg.params.get("msmarco_train_qrels") or
- os.getenv("TRAIN_QRELS_PATH") or
- os.getenv("MSMARCO_TRAIN_QRELS")
+ self.cfg.params.get("train_qrels_path")
+ or self.cfg.params.get("msmarco_train_qrels")
+ or os.getenv("TRAIN_QRELS_PATH")
+ or os.getenv("MSMARCO_TRAIN_QRELS")
)
-
+
# For BEIR, collection_path is actually the BEIR data directory
if dataset_type == "beir":
beir_data_dir = (
- self.cfg.params.get("beir_data_dir") or
- collection_path or
- os.getenv("BEIR_DATA_DIR")
+ self.cfg.params.get("beir_data_dir")
+ or collection_path
+ or os.getenv("BEIR_DATA_DIR")
)
train_split = self.cfg.params.get("train_split", "train")
-
+
if not beir_data_dir:
raise RuntimeError(
"Few-shot mode with BEIR requires beir_data_dir (via config or env var):\n"
" - beir_data_dir / BEIR_DATA_DIR: Path to BEIR dataset directory"
)
-
+
elif not all([collection_path, train_queries_path, train_qrels_path]):
raise RuntimeError(
"Few-shot mode requires training data paths (via config params or env vars):\n"
@@ -102,10 +104,11 @@ def _load_fewshot_data(self):
"\nFor backward compatibility, MS MARCO env vars are also supported:\n"
" - MSMARCO_COLLECTION, MSMARCO_TRAIN_QUERIES, MSMARCO_TRAIN_QRELS"
)
-
+
# Load data using appropriate loader based on dataset type
if dataset_type == "beir":
from ..loaders import beir
+
corpus = beir.load_corpus(beir_data_dir)
# Convert BEIR corpus format (dict with title/text) to simple text
collection = {}
@@ -114,47 +117,48 @@ def _load_fewshot_data(self):
title = doc_dict.get("title", "").strip()
text = doc_dict.get("text", "").strip()
collection[docid] = f"{title} {text}".strip() if title else text
-
+
train_queries_list = beir.load_queries(beir_data_dir)
train_queries_dict = {q.qid: q.text for q in train_queries_list}
train_qrels = beir.load_qrels(beir_data_dir, split=train_split)
-
+
elif dataset_type == "msmarco":
from ..loaders import msmarco
+
collection = msmarco.load_collection(collection_path)
train_queries_list = msmarco.load_queries(train_queries_path)
train_queries_dict = {q.qid: q.text for q in train_queries_list}
train_qrels = msmarco.load_qrels(train_qrels_path)
-
+
else:
# Generic: Load using DataLoader for maximum flexibility
from ..data.dataloader import DataLoader
-
+
# Load collection (TSV format: docid \t text)
collection = {}
- with open(collection_path, 'r', encoding='utf-8') as f:
- reader = csv.reader(f, delimiter='\t')
+ with open(collection_path, "r", encoding="utf-8") as f:
+ reader = csv.reader(f, delimiter="\t")
for row in reader:
if len(row) >= 2:
docid, text = row[0], row[1]
collection[docid] = text
-
+
# Load queries and qrels
train_queries_list = DataLoader.load_queries(train_queries_path, format="tsv")
train_queries_dict = {q.qid: q.text for q in train_queries_list}
train_qrels = DataLoader.load_qrels(train_qrels_path, format="trec")
-
+
self._fewshot_data = {
"collection": collection,
"train_queries": train_queries_dict,
- "train_qrels": train_qrels
+ "train_qrels": train_qrels,
}
-
+
return self._fewshot_data
-
+
except Exception as e:
raise RuntimeError(f"Failed to load few-shot training data: {e}")
-
+
def _select_few_shot_examples(self, num_examples: int = 4) -> List[Tuple[str, str]]:
"""Randomly sample relevant query-passage pairs from training data."""
try:
@@ -162,43 +166,43 @@ def _select_few_shot_examples(self, num_examples: int = 4) -> List[Tuple[str, st
collection = data["collection"]
train_queries = data["train_queries"]
train_qrels = data["train_qrels"]
-
+
# Collect all relevant query-doc pairs
relevant_pairs = []
for qid, doc_rels in train_qrels.items():
for docid, relevance in doc_rels.items():
if relevance > 0: # Only relevant
relevant_pairs.append((qid, docid))
-
+
if not relevant_pairs:
raise RuntimeError("No relevant query-document pairs found")
-
+
# Sample pairs and fetch texts
sample_size = min(num_examples * 10, len(relevant_pairs))
sampled_pairs = random.sample(relevant_pairs, sample_size)
-
+
examples = []
for qid, docid in sampled_pairs:
if len(examples) >= num_examples:
break
-
+
query_text = train_queries.get(qid)
doc_text = collection.get(docid)
-
+
if query_text and doc_text:
examples.append((query_text, doc_text))
-
+
if not examples:
raise RuntimeError("Could not find valid examples with matching IDs")
-
+
if len(examples) < num_examples:
print(f"Warning: Only found {len(examples)}/{num_examples} examples")
-
+
return examples
-
+
except Exception as e:
raise RuntimeError(f"Failed to select few-shot examples: {e}")
-
+
def _format_examples(self, examples: List[Tuple[str, str]]) -> str:
"""Format examples for prompt template."""
examples_text = ""
@@ -215,9 +219,9 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
mode = "fs"
temperature = float(self.cfg.llm.get("temperature", 0.7))
max_tokens = int(self.cfg.llm.get("max_tokens", 256))
-
+
metadata = {"mode": mode}
-
+
try:
# Select prompt based on mode
if mode == "fs":
@@ -225,8 +229,10 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
num_examples = int(self.cfg.params.get("num_examples", 4))
examples = self._select_few_shot_examples(num_examples)
examples_text = self._format_examples(examples)
-
- msgs = self.prompts.render("query2doc.fewshot.v1", query=q.text, examples=examples_text)
+
+ msgs = self.prompts.render(
+ "query2doc.fewshot.v1", query=q.text, examples=examples_text
+ )
metadata["prompt_id"] = "query2doc.fewshot.v1"
metadata["num_examples"] = len(examples)
else:
@@ -234,30 +240,27 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
prompt_id = "query2doc.cot.v1" if mode == "cot" else "query2doc.zeroshot.v1"
msgs = self.prompts.render(prompt_id, query=q.text)
metadata["prompt_id"] = prompt_id
-
+
# Generate pseudo-document using LLM
pseudo_doc = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens)
-
+
if not pseudo_doc or not pseudo_doc.strip():
raise RuntimeError("LLM returned empty response")
-
+
# Use base class concatenation (query_repeat_plus_generated)
reformulated = self.concatenate_result(q.text, pseudo_doc)
-
- metadata.update({
- "pseudo_doc": pseudo_doc,
- "temperature": temperature,
- "max_tokens": max_tokens
- })
-
+
+ metadata.update(
+ {"pseudo_doc": pseudo_doc, "temperature": temperature, "max_tokens": max_tokens}
+ )
+
return ReformulationResult(q.qid, q.text, reformulated, metadata=metadata)
-
+
except Exception as e:
# Graceful error handling - fallback to original query
error_msg = f"Query2Doc failed: {e}"
print(f"Error for qid={q.qid}: {error_msg}")
-
+
return ReformulationResult(
- q.qid, q.text, q.text,
- metadata={"mode": mode, "error": error_msg, "fallback": True}
+ q.qid, q.text, q.text, metadata={"mode": mode, "error": error_msg, "fallback": True}
)
diff --git a/querygym/methods/query2e.py b/querygym/methods/query2e.py
index 5deee01..c2019ee 100644
--- a/querygym/methods/query2e.py
+++ b/querygym/methods/query2e.py
@@ -6,42 +6,44 @@
import os
import csv
+
@register_method("query2e")
class Query2E(BaseReformulator):
"""
Query2E: Query to keyword expansion.
-
+
Modes:
- zs (zero-shot): Simple keyword generation
- fs (few-shot): Uses training examples for keyword generation
-
+
Formula: (query × 5) + keywords
-
+
Few-Shot Examples (3 ways to provide):
1. Via --ctx-jsonl CLI flag (JSONL file with {"query": "...", "passage": "..."} per line)
2. Via params["examples"] (list of {"query": "...", "passage": "..."} dicts)
3. Auto-generate from training data (if no examples provided)
-
+
Few-Shot Auto-Generation Config (via params or env vars):
- dataset_type: "msmarco", "beir", or "generic" (uses appropriate loader)
- num_examples: Number of few-shot examples (default: 4)
-
+
For MS MARCO datasets (dataset_type="msmarco"):
- collection_path / COLLECTION_PATH: Path to collection.tsv
- train_queries_path / TRAIN_QUERIES_PATH: Path to queries.tsv
- train_qrels_path / TRAIN_QRELS_PATH: Path to qrels file
-
+
For BEIR datasets (dataset_type="beir"):
- beir_data_dir / BEIR_DATA_DIR: Path to BEIR dataset directory
- train_split: "train" or "dev" (default: "train")
-
+
For generic datasets (dataset_type="generic" or omitted):
- collection_path: TSV file (docid \t text)
- train_queries_path: TSV file (qid \t query)
- train_qrels_path: TREC format (qid 0 docid relevance)
-
+
Note: MS MARCO env vars (MSMARCO_COLLECTION, etc.) are supported for backward compatibility.
"""
+
VERSION = "1.0"
CONCATENATION_STRATEGY = "query_repeat_plus_generated"
DEFAULT_QUERY_REPEATS = 5
@@ -51,51 +53,51 @@ def __init__(self, cfg, llm_client, prompt_resolver):
self._fewshot_data = None
# User-provided examples (set via set_examples() or params["examples"])
self._provided_examples = cfg.params.get("examples", None)
-
+
def _load_fewshot_data(self):
"""Lazy load training data for few-shot mode (supports MS MARCO, BEIR, or generic datasets)."""
if self._fewshot_data is not None:
return self._fewshot_data
-
+
try:
# Get dataset type (msmarco, beir, or generic)
dataset_type = self.cfg.params.get("dataset_type", "").lower()
-
+
# Get paths from config params or environment variables
collection_path = (
- self.cfg.params.get("collection_path") or
- self.cfg.params.get("msmarco_collection") or
- os.getenv("COLLECTION_PATH") or
- os.getenv("MSMARCO_COLLECTION")
+ self.cfg.params.get("collection_path")
+ or self.cfg.params.get("msmarco_collection")
+ or os.getenv("COLLECTION_PATH")
+ or os.getenv("MSMARCO_COLLECTION")
)
train_queries_path = (
- self.cfg.params.get("train_queries_path") or
- self.cfg.params.get("msmarco_train_queries") or
- os.getenv("TRAIN_QUERIES_PATH") or
- os.getenv("MSMARCO_TRAIN_QUERIES")
+ self.cfg.params.get("train_queries_path")
+ or self.cfg.params.get("msmarco_train_queries")
+ or os.getenv("TRAIN_QUERIES_PATH")
+ or os.getenv("MSMARCO_TRAIN_QUERIES")
)
train_qrels_path = (
- self.cfg.params.get("train_qrels_path") or
- self.cfg.params.get("msmarco_train_qrels") or
- os.getenv("TRAIN_QRELS_PATH") or
- os.getenv("MSMARCO_TRAIN_QRELS")
+ self.cfg.params.get("train_qrels_path")
+ or self.cfg.params.get("msmarco_train_qrels")
+ or os.getenv("TRAIN_QRELS_PATH")
+ or os.getenv("MSMARCO_TRAIN_QRELS")
)
-
+
# For BEIR, collection_path is actually the BEIR data directory
if dataset_type == "beir":
beir_data_dir = (
- self.cfg.params.get("beir_data_dir") or
- collection_path or
- os.getenv("BEIR_DATA_DIR")
+ self.cfg.params.get("beir_data_dir")
+ or collection_path
+ or os.getenv("BEIR_DATA_DIR")
)
train_split = self.cfg.params.get("train_split", "train")
-
+
if not beir_data_dir:
raise RuntimeError(
"Few-shot mode with BEIR requires beir_data_dir (via config or env var):\n"
" - beir_data_dir / BEIR_DATA_DIR: Path to BEIR dataset directory"
)
-
+
elif not all([collection_path, train_queries_path, train_qrels_path]):
raise RuntimeError(
"Few-shot mode requires training data paths (via config params or env vars):\n"
@@ -110,10 +112,11 @@ def _load_fewshot_data(self):
"\nFor backward compatibility, MS MARCO env vars are also supported:\n"
" - MSMARCO_COLLECTION, MSMARCO_TRAIN_QUERIES, MSMARCO_TRAIN_QRELS"
)
-
+
# Load data using appropriate loader based on dataset type
if dataset_type == "beir":
from ..loaders import beir
+
corpus = beir.load_corpus(beir_data_dir)
# Convert BEIR corpus format (dict with title/text) to simple text
collection = {}
@@ -122,47 +125,48 @@ def _load_fewshot_data(self):
title = doc_dict.get("title", "").strip()
text = doc_dict.get("text", "").strip()
collection[docid] = f"{title} {text}".strip() if title else text
-
+
train_queries_list = beir.load_queries(beir_data_dir)
train_queries_dict = {q.qid: q.text for q in train_queries_list}
train_qrels = beir.load_qrels(beir_data_dir, split=train_split)
-
+
elif dataset_type == "msmarco":
from ..loaders import msmarco
+
collection = msmarco.load_collection(collection_path)
train_queries_list = msmarco.load_queries(train_queries_path)
train_queries_dict = {q.qid: q.text for q in train_queries_list}
train_qrels = msmarco.load_qrels(train_qrels_path)
-
+
else:
# Generic: Load using DataLoader for maximum flexibility
from ..data.dataloader import DataLoader
-
+
# Load collection (TSV format: docid \t text)
collection = {}
- with open(collection_path, 'r', encoding='utf-8') as f:
- reader = csv.reader(f, delimiter='\t')
+ with open(collection_path, "r", encoding="utf-8") as f:
+ reader = csv.reader(f, delimiter="\t")
for row in reader:
if len(row) >= 2:
docid, text = row[0], row[1]
collection[docid] = text
-
+
# Load queries and qrels
train_queries_list = DataLoader.load_queries(train_queries_path, format="tsv")
train_queries_dict = {q.qid: q.text for q in train_queries_list}
train_qrels = DataLoader.load_qrels(train_qrels_path, format="trec")
-
+
self._fewshot_data = {
"collection": collection,
"train_queries": train_queries_dict,
- "train_qrels": train_qrels
+ "train_qrels": train_qrels,
}
-
+
return self._fewshot_data
-
+
except Exception as e:
raise RuntimeError(f"Failed to load few-shot training data: {e}")
-
+
def _select_few_shot_examples(self, num_examples: int = 4) -> List[Tuple[str, str]]:
"""Randomly sample relevant query-document pairs from training data."""
try:
@@ -170,43 +174,43 @@ def _select_few_shot_examples(self, num_examples: int = 4) -> List[Tuple[str, st
collection = data["collection"]
train_queries = data["train_queries"]
train_qrels = data["train_qrels"]
-
+
# Collect all relevant query-doc pairs
relevant_pairs = []
for qid, doc_rels in train_qrels.items():
for docid, relevance in doc_rels.items():
if relevance > 0: # Only relevant
relevant_pairs.append((qid, docid))
-
+
if not relevant_pairs:
raise RuntimeError("No relevant query-document pairs found")
-
+
# Sample pairs and fetch texts
sample_size = min(num_examples * 10, len(relevant_pairs))
sampled_pairs = random.sample(relevant_pairs, sample_size)
-
+
examples = []
for qid, docid in sampled_pairs:
if len(examples) >= num_examples:
break
-
+
query_text = train_queries.get(qid)
doc_text = collection.get(docid)
-
+
if query_text and doc_text:
examples.append((query_text, doc_text))
-
+
if not examples:
raise RuntimeError("Could not find valid examples with matching IDs")
-
+
if len(examples) < num_examples:
print(f"Warning: Only found {len(examples)}/{num_examples} examples")
-
+
return examples
-
+
except Exception as e:
raise RuntimeError(f"Failed to select few-shot examples: {e}")
-
+
def _format_examples(self, examples: List[Tuple[str, str]]) -> str:
"""Format examples for prompt template (query -> keywords extraction)."""
examples_text = ""
@@ -215,11 +219,11 @@ def _format_examples(self, examples: List[Tuple[str, str]]) -> str:
# In practice, you might want more sophisticated keyword extraction
words = passage.lower().split()
# Take first 5-7 meaningful words as keywords (simple heuristic)
- keywords = [w.strip('.,!?;:') for w in words[:7] if len(w) > 3]
+ keywords = [w.strip(".,!?;:") for w in words[:7] if len(w) > 3]
keywords_str = ", ".join(keywords[:5]) if keywords else "keywords, terms, phrases"
-
+
examples_text += f"Query: {query}\nKeywords: {keywords_str}\n"
-
+
return examples_text
def _parse_keywords(self, raw_output: str) -> List[str]:
@@ -229,41 +233,46 @@ def _parse_keywords(self, raw_output: str) -> List[str]:
- Bullet points: "- keyword1\n- keyword2"
- Numbered lists: "1. keyword1\n2. keyword2"
- Mixed formats
-
+
Returns:
List of cleaned keyword strings
"""
import re
-
+
if not raw_output or not raw_output.strip():
return []
-
+
# Normalize the text
text = raw_output.strip()
-
+
# Remove common prefixes like "Keywords:", "Here are the keywords:", etc.
- text = re.sub(r'^(keywords|here are|the keywords|list of keywords)[:\s]*', '', text, flags=re.IGNORECASE)
-
+ text = re.sub(
+ r"^(keywords|here are|the keywords|list of keywords)[:\s]*",
+ "",
+ text,
+ flags=re.IGNORECASE,
+ )
+
keywords = []
-
+
# Check if it's a bullet/numbered list format
- if '\n' in text or text.strip().startswith('-') or re.match(r'^\d+\.', text.strip()):
+ if "\n" in text or text.strip().startswith("-") or re.match(r"^\d+\.", text.strip()):
# Split by newlines
- lines = text.split('\n')
+ lines = text.split("\n")
for line in lines:
line = line.strip()
if not line:
continue
-
+
# Remove bullet points: -, *, •, ▪, etc.
- line = re.sub(r'^[\-\*•▪→►]\s*', '', line)
-
+ line = re.sub(r"^[\-\*•▪→►]\s*", "", line)
+
# Remove numbered prefixes: 1., 1), (1), etc.
- line = re.sub(r'^[\(\[]?\d+[\)\]\.:\-]?\s*', '', line)
-
+ line = re.sub(r"^[\(\[]?\d+[\)\]\.:\-]?\s*", "", line)
+
# If line contains commas, it might be multiple keywords
- if ',' in line:
- for part in line.split(','):
+ if "," in line:
+ for part in line.split(","):
part = part.strip()
if part:
keywords.append(part)
@@ -271,33 +280,33 @@ def _parse_keywords(self, raw_output: str) -> List[str]:
keywords.append(line)
else:
# Assume comma-separated format
- for part in text.split(','):
+ for part in text.split(","):
part = part.strip()
if part:
keywords.append(part)
-
+
# Clean each keyword
cleaned_keywords = []
for kw in keywords:
# Remove quotes
- kw = kw.strip('"\'`')
+ kw = kw.strip("\"'`")
# Remove trailing punctuation
- kw = kw.rstrip('.,;:!?')
+ kw = kw.rstrip(".,;:!?")
# Remove leading/trailing whitespace
kw = kw.strip()
# Skip empty or very short keywords
if kw and len(kw) > 1:
cleaned_keywords.append(kw)
-
+
return cleaned_keywords
def set_examples(self, examples: List[Tuple[str, str]]) -> None:
"""
Set user-provided examples for few-shot mode.
-
+
Args:
examples: List of (query, passage) tuples or list of dicts with 'query' and 'passage' keys
-
+
Example:
>>> reformulator.set_examples([
... ("how long is flea life cycle?", "The life cycle of a flea..."),
@@ -315,15 +324,15 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
mode = str(self.cfg.params.get("mode", "zs")) # Default to zero-shot
temperature = float(self.cfg.llm.get("temperature", 0.3))
max_tokens = int(self.cfg.llm.get("max_tokens", 256))
-
+
metadata = {"mode": mode}
-
+
try:
# Select prompt based on mode
if mode in ["fs", "fewshot"]:
# Few-shot: use provided examples or auto-generate from training data
num_examples = int(self.cfg.params.get("num_examples", 4))
-
+
if self._provided_examples:
# Use user-provided examples
# Convert dict format to tuple format if needed
@@ -338,9 +347,9 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
# Auto-generate from training data (original behavior)
examples = self._select_few_shot_examples(num_examples)
metadata["examples_source"] = "auto_generated"
-
+
examples_text = self._format_examples(examples)
-
+
msgs = self.prompts.render("q2e.fs.v1", query=q.text, examples=examples_text)
metadata["prompt_id"] = "q2e.fs.v1"
metadata["num_examples"] = len(examples)
@@ -349,43 +358,37 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
msgs = self.prompts.render(prompt_id, query=q.text)
metadata["prompt_id"] = prompt_id
else:
- raise ValueError(f"Invalid mode '{mode}' for Query2E. Must be 'zs' (zero-shot) or 'fs' (few-shot).")
-
+ raise ValueError(
+ f"Invalid mode '{mode}' for Query2E. Must be 'zs' (zero-shot) or 'fs' (few-shot)."
+ )
+
# Generate keywords
out = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens)
-
+
# Parse keywords using robust parser
terms = self._parse_keywords(out)
-
+
# Limit to max 20 keywords
max_keywords = int(self.cfg.params.get("max_keywords", 20))
if len(terms) > max_keywords:
terms = terms[:max_keywords]
-
+
generated_content = " ".join(terms)
-
+
# Concatenate: query + keywords
reformulated = self.concatenate_result(q.text, generated_content)
-
- metadata.update({
- "keywords": terms,
- "temperature": temperature,
- "max_tokens": max_tokens
- })
-
- return ReformulationResult(
- q.qid,
- q.text,
- reformulated,
- metadata=metadata
+
+ metadata.update(
+ {"keywords": terms, "temperature": temperature, "max_tokens": max_tokens}
)
-
+
+ return ReformulationResult(q.qid, q.text, reformulated, metadata=metadata)
+
except Exception as e:
# Graceful error handling - fallback to original query
error_msg = f"Query2E failed: {e}"
print(f"Error for qid={q.qid}: {error_msg}")
-
+
return ReformulationResult(
- q.qid, q.text, q.text,
- metadata={"mode": mode, "error": error_msg, "fallback": True}
+ q.qid, q.text, q.text, metadata={"mode": mode, "error": error_msg, "fallback": True}
)
diff --git a/querygym/methods/thinkqe.py b/querygym/methods/thinkqe.py
new file mode 100644
index 0000000..4762117
--- /dev/null
+++ b/querygym/methods/thinkqe.py
@@ -0,0 +1,499 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+from ..core.base import BaseReformulator, QueryItem, ReformulationResult
+from ..core.registry import register_method
+
+
+@register_method("thinkqe")
+class ThinkQE(BaseReformulator):
+ """ThinkQE: query expansion via iterative corpus feedback.
+
+ This implementation follows the original `archive/Think_QE/thinkqe.py`
+ loop while fitting QueryGym's `BaseReformulator` interface:
+
+ 1. Round 0 retrieves passages with the original query.
+ 2. Each later round prompts the LLM with the original query plus the
+ previous round's retrieved passages.
+ 3. The generated answering passage(s) are appended to the query.
+ 4. The updated query is used to retrieve again.
+
+    The method strips any `<think>...</think>` trace and keeps only the
+ answer passage for retrieval.
+ """
+
+ VERSION = "1.0"
+ REQUIRES_CONTEXT = True
+
+ def __init__(self, cfg, llm_client, prompt_resolver):
+ super().__init__(cfg, llm_client, prompt_resolver)
+ self._resolved_searcher = None
+ self._searcher_initialized = False
+
+ @staticmethod
+ def _extract_answer(text: str) -> str:
+ """Return only the answer text after the thinking trace."""
+        if "</think>\n" in text:
+            return text.split("</think>\n")[-1].strip()
+        if "</think>" in text:
+            return text.split("</think>")[-1].strip()
+ return text.strip()
+
+ @staticmethod
+ def _truncate_passage(passage: str, max_words: int) -> str:
+ """Trim a passage by whitespace tokens."""
+ words = passage.split()
+ if len(words) <= max_words:
+ return passage
+ return " ".join(words[:max_words])
+
+ @staticmethod
+ def _build_query(
+ original_query: str,
+ expansions: List[str],
+ repeat_weight: float,
+ lowercase: bool,
+ ) -> Tuple[str, int]:
+ """Build the retrieval query using ThinkQE's repetition heuristic."""
+ query_words = len(original_query.split())
+ expansion_words = len("\n".join(expansions).split())
+
+ if query_words > 0 and repeat_weight > 0:
+ q_repeat = max(1, int(expansion_words / (query_words * repeat_weight)))
+ else:
+ q_repeat = 1
+
+ reformulated = "\n".join([original_query] * q_repeat + expansions)
+ if lowercase:
+ reformulated = reformulated.lower()
+ return reformulated, q_repeat
+
+ @staticmethod
+ def _filter_passages(
+ passages: List[str],
+ keep_k: int,
+ seen: Set[str],
+ prev_prev_top: Set[str],
+ ) -> Tuple[List[str], Set[str]]:
+ """Apply the archive-style novelty filter over retrieved passages."""
+ filtered: List[str] = []
+ for passage in passages:
+ if passage in seen:
+ continue
+ if passage in prev_prev_top:
+ seen.add(passage)
+ continue
+ filtered.append(passage)
+ if len(filtered) >= keep_k:
+ break
+ return filtered, seen
+
+ def _get_prompt_id(self) -> str:
+ return str(self.cfg.params.get("prompt_id", "thinkqe.v1"))
+
+ def _get_keep_passage_num(self) -> int:
+ value = self.cfg.params.get(
+ "keep_passage_num",
+ self.cfg.params.get("retrieval_k", 5),
+ )
+ return int(value)
+
+ def _get_gen_num(self) -> int:
+ value = self.cfg.params.get(
+ "gen_num",
+ self.cfg.params.get("n_generations", 2),
+ )
+ return int(value)
+
+ def _get_repeat_weight(self) -> float:
+ value = self.cfg.params.get(
+ "repeat_weight",
+            self.cfg.params.get("reqeat_weight", 3),  # misspelled legacy key, kept for compatibility
+ )
+ return float(value)
+
+ def _get_search_k(self) -> int:
+ keep_passage_num = self._get_keep_passage_num()
+ value = self.cfg.params.get(
+ "search_k",
+ self.cfg.params.get("hits", keep_passage_num),
+ )
+ return max(int(value), keep_passage_num)
+
+ def _generate_expansions(
+ self,
+ query_text: str,
+ contexts: List[str],
+ gen_num: int,
+ max_demo_len: Optional[int],
+ temperature: float,
+ max_tokens: int,
+ no_thinking: bool,
+ ) -> Tuple[List[str], List[str]]:
+ """Generate `gen_num` expansions using repeated chat calls."""
+ top_passages = contexts
+ if max_demo_len is not None:
+ top_passages = [
+ self._truncate_passage(passage, int(max_demo_len)) for passage in top_passages
+ ]
+
+ contexts_blob = "\n".join(
+ f"{index + 1}. {passage}" for index, passage in enumerate(top_passages)
+ )
+ messages = self.prompts.render(
+ self._get_prompt_id(),
+ query=query_text,
+ contexts=contexts_blob,
+ )
+
+ if no_thinking:
+ messages = list(messages)
+ messages.append(
+ {
+ "role": "assistant",
+                    "content": "Okay, I think I have finished thinking.\n</think>\n\n",
+ }
+ )
+
+ raw_responses: List[str] = []
+ for _ in range(gen_num):
+ response = self.llm.chat(
+ messages,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ raw_responses.append(response)
+
+ expansions = [self._extract_answer(response) for response in raw_responses]
+ return expansions, raw_responses
+
+ def _build_searcher(self):
+ """Resolve a searcher from config if multi-round retrieval is enabled."""
+ if self._searcher_initialized:
+ return self._resolved_searcher
+
+ searcher = self.cfg.params.get("searcher")
+ if searcher is not None:
+ self._resolved_searcher = searcher
+ self._searcher_initialized = True
+ return self._resolved_searcher
+
+ if "searcher_type" in self.cfg.params:
+ from ..core.searcher import create_searcher
+
+ self._resolved_searcher = create_searcher(
+ self.cfg.params["searcher_type"],
+ **dict(self.cfg.params.get("searcher_kwargs", {})),
+ )
+ self._searcher_initialized = True
+ return self._resolved_searcher
+
+ index = self.cfg.params.get("index")
+ if not index:
+ self._searcher_initialized = True
+ return None
+
+ from ..core.searcher import create_searcher
+
+ searcher_kwargs: Dict[str, Any] = {
+ "index": index,
+ "answer_key": self.cfg.params.get("answer_key", "contents"),
+ }
+
+ k1 = self.cfg.params.get("k1")
+ b = self.cfg.params.get("b")
+ if k1 is not None and b is not None:
+ searcher_kwargs["k1"] = k1
+ searcher_kwargs["b"] = b
+
+ self._resolved_searcher = create_searcher("pyserini", **searcher_kwargs)
+ self._searcher_initialized = True
+ return self._resolved_searcher
+
+ def _build_retriever(self):
+ """Create a callable that retrieves passage text for a query string."""
+ searcher = self._build_searcher()
+ if searcher is None:
+ return None
+
+ search_k = self._get_search_k()
+
+ def _retriever(query_text: str) -> List[str]:
+ hits = searcher.search(query_text, k=search_k)
+ return [hit.content for hit in hits]
+
+ return _retriever
+
+ def _multi_round_enabled(self) -> bool:
+ """Return True when ThinkQE can run its full iterative loop."""
+ num_interaction = int(self.cfg.params.get("num_interaction", 3))
+ return num_interaction > 0 and self._build_searcher() is not None
+
+ @staticmethod
+ def _summarize_round(result: ReformulationResult) -> Dict[str, Any]:
+ """Create a compact per-round metadata summary."""
+ meta = dict(result.metadata or {})
+ raw_responses = meta.pop("raw_responses", None)
+ summary = {
+ "round": meta.pop("round", None),
+ "type": meta.pop("type", "expanded"),
+ "reformulated": result.reformulated,
+ "metadata": meta,
+ }
+ if raw_responses is not None:
+ summary["raw_response_count"] = len(raw_responses)
+ return summary
+
+ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
+ """Run single-step ThinkQE using provided contexts."""
+ ctxs: List[str] = contexts or []
+
+ keep_passage_num = self._get_keep_passage_num()
+ gen_num = self._get_gen_num()
+ repeat_weight = self._get_repeat_weight()
+ max_demo_len = self.cfg.params.get("max_demo_len")
+ lowercase = bool(self.cfg.params.get("lowercase", True))
+ no_thinking = bool(self.cfg.params.get("no_thinking", False))
+ temperature = float(self.cfg.llm.get("temperature", 0.7))
+ max_tokens = int(self.cfg.llm.get("max_tokens", 32768))
+
+ top_passages = ctxs[:keep_passage_num]
+ expansions, raw_responses = self._generate_expansions(
+ q.text,
+ top_passages,
+ gen_num,
+ max_demo_len,
+ temperature,
+ max_tokens,
+ no_thinking,
+ )
+ reformulated, q_repeat = self._build_query(
+ q.text,
+ expansions,
+ repeat_weight,
+ lowercase,
+ )
+
+ return ReformulationResult(
+ q.qid,
+ q.text,
+ reformulated,
+ metadata={
+ "prompt_id": self._get_prompt_id(),
+ "gen_num": gen_num,
+ "q_repeat": q_repeat,
+ "repeat_weight": repeat_weight,
+ "keep_passage_num": keep_passage_num,
+ "used_ctx": len(top_passages),
+ "expansions": expansions,
+ "raw_responses": raw_responses,
+ },
+ )
+
+ def reformulate_multi_round(
+ self,
+ q: QueryItem,
+ retriever,
+ *,
+ num_interaction: Optional[int] = None,
+ accumulate: Optional[bool] = None,
+ use_passage_filter: Optional[bool] = None,
+ ) -> List[ReformulationResult]:
+ """Run the full ThinkQE loop and return per-round results."""
+ if num_interaction is None:
+ num_interaction = int(self.cfg.params.get("num_interaction", 3))
+ if accumulate is None:
+ accumulate = bool(self.cfg.params.get("accumulate", True))
+ if use_passage_filter is None:
+ use_passage_filter = bool(self.cfg.params.get("use_passage_filter", True))
+
+ keep_passage_num = self._get_keep_passage_num()
+ gen_num = self._get_gen_num()
+ repeat_weight = self._get_repeat_weight()
+ max_demo_len = self.cfg.params.get("max_demo_len")
+ lowercase = bool(self.cfg.params.get("lowercase", True))
+ no_thinking = bool(self.cfg.params.get("no_thinking", False))
+ temperature = float(self.cfg.llm.get("temperature", 0.7))
+ max_tokens = int(self.cfg.llm.get("max_tokens", 32768))
+
+ accumulated_expansions: List[str] = []
+ seen_passages: Set[str] = set()
+ last_top_k: Set[str] = set()
+ prev_prev_top_k: Set[str] = set()
+
+ original_query_text = q.text
+ current_query_text = q.text
+ results: List[ReformulationResult] = []
+
+ last_top_passages = retriever(current_query_text)
+ last_top_k = set(last_top_passages[:keep_passage_num])
+ results.append(
+ ReformulationResult(
+ q.qid,
+ q.text,
+ current_query_text,
+ metadata={
+ "round": 0,
+ "type": "baseline",
+ "retrieved_count": len(last_top_passages),
+ },
+ )
+ )
+
+ for ridx in range(1, num_interaction + 1):
+ all_passages = last_top_passages
+ if use_passage_filter:
+ top_passages, seen_passages = self._filter_passages(
+ all_passages,
+ keep_passage_num,
+ seen_passages,
+ prev_prev_top_k,
+ )
+ else:
+ top_passages = all_passages[:keep_passage_num]
+
+ expansions, raw_responses = self._generate_expansions(
+ original_query_text,
+ top_passages,
+ gen_num,
+ max_demo_len,
+ temperature,
+ max_tokens,
+ no_thinking,
+ )
+
+ if accumulate:
+ accumulated_expansions.extend(expansions)
+ active_expansions = list(accumulated_expansions)
+ else:
+ active_expansions = expansions
+
+ reformulated, q_repeat = self._build_query(
+ original_query_text,
+ active_expansions,
+ repeat_weight,
+ lowercase,
+ )
+ current_query_text = reformulated
+ retrieved_passages = retriever(current_query_text)
+
+ prev_prev_top_k = last_top_k
+ last_top_k = set(retrieved_passages[:keep_passage_num])
+ last_top_passages = retrieved_passages
+
+ results.append(
+ ReformulationResult(
+ q.qid,
+ q.text,
+ reformulated,
+ metadata={
+ "round": ridx,
+ "prompt_id": self._get_prompt_id(),
+ "gen_num": gen_num,
+ "q_repeat": q_repeat,
+ "repeat_weight": repeat_weight,
+ "keep_passage_num": keep_passage_num,
+ "used_ctx": len(top_passages),
+ "search_k": self._get_search_k(),
+ "accumulate": accumulate,
+ "accumulated_count": len(active_expansions),
+ "use_passage_filter": use_passage_filter,
+ "expansions": expansions,
+ "raw_responses": raw_responses,
+ "retrieved_count": len(retrieved_passages),
+ },
+ )
+ )
+
+ return results
+
+ def reformulate_with_round_history(
+ self,
+ q: QueryItem,
+ contexts=None,
+ ) -> Tuple[ReformulationResult, List[ReformulationResult]]:
+ """Return the final result together with all round outputs."""
+ retriever = self._build_retriever()
+ if int(self.cfg.params.get("num_interaction", 3)) > 0 and retriever is not None:
+ round_results = self.reformulate_multi_round(q, retriever)
+ final_result = round_results[-1]
+ final_metadata = dict(final_result.metadata or {})
+ final_metadata["round_history"] = [
+ self._summarize_round(result) for result in round_results
+ ]
+ final_result = ReformulationResult(
+ final_result.qid,
+ final_result.original,
+ final_result.reformulated,
+ metadata=final_metadata,
+ )
+ return final_result, round_results
+
+ single_result = self.reformulate(q, contexts)
+ single_metadata = dict(single_result.metadata or {})
+ single_metadata["round_history"] = [self._summarize_round(single_result)]
+ single_result = ReformulationResult(
+ single_result.qid,
+ single_result.original,
+ single_result.reformulated,
+ metadata=single_metadata,
+ )
+ return single_result, [single_result]
+
+ def reformulate_batch(
+ self,
+ queries: List[QueryItem],
+ ctx_map=None,
+ ) -> List[ReformulationResult]:
+ """Auto-dispatch to the multi-round loop when retrieval is available."""
+ if self._multi_round_enabled():
+ results: List[ReformulationResult] = []
+ for query in queries:
+ final_result, _ = self.reformulate_with_round_history(query)
+ results.append(final_result)
+ return results
+ return super().reformulate_batch(queries, ctx_map)
+
+ def _get_retrieval_params(self) -> Optional[Dict[str, Any]]:
+ """Return retrieval settings for single-step batch reformulation."""
+ keep_passage_num = self._get_keep_passage_num()
+
+ if "searcher" in self.cfg.params:
+ return {
+ "searcher": self.cfg.params["searcher"],
+ "k": keep_passage_num,
+ "threads": int(self.cfg.params.get("threads", 16)),
+ }
+
+ if "searcher_type" in self.cfg.params:
+ searcher_kwargs = dict(self.cfg.params.get("searcher_kwargs", {}))
+ if "index" in self.cfg.params and "index" not in searcher_kwargs:
+ searcher_kwargs["index"] = self.cfg.params["index"]
+ return {
+ "searcher_type": self.cfg.params["searcher_type"],
+ "searcher_kwargs": searcher_kwargs,
+ "k": keep_passage_num,
+ "threads": int(self.cfg.params.get("threads", 16)),
+ }
+
+ index = self.cfg.params.get("index")
+ if not index:
+ return None
+
+ searcher_kwargs: Dict[str, Any] = {
+ "index": index,
+ "answer_key": self.cfg.params.get("answer_key", "contents"),
+ }
+ k1 = self.cfg.params.get("k1")
+ b = self.cfg.params.get("b")
+ if k1 is not None and b is not None:
+ searcher_kwargs["k1"] = k1
+ searcher_kwargs["b"] = b
+
+ return {
+ "searcher_type": "pyserini",
+ "searcher_kwargs": searcher_kwargs,
+ "k": keep_passage_num,
+ "threads": int(self.cfg.params.get("threads", 16)),
+ }
diff --git a/querygym/prompt_bank.yaml b/querygym/prompt_bank.yaml
index ce3dff7..6fe685e 100644
--- a/querygym/prompt_bank.yaml
+++ b/querygym/prompt_bank.yaml
@@ -397,6 +397,24 @@
notes: |
LameR variant for news article retrieval. Tailored for TREC-News dataset with relevant news passages.
+- id: thinkqe.v1
+ method_family: thinkqe
+ version: 1
+ introduced_by: "ThinkQE: Query Expansion via an Evolving Thinking Process"
+ license: "CC-BY-4.0"
+ authors: ["Haison Le", "QueryGym Team"]
+ tags: ["context-based", "thinking", "passage-synthesis", "reasoning"]
+ template:
+ user: |
+ Given a question "{query}" and its possible answering passages (most of these passages are wrong) enumerated as:
+ {contexts}
+
+ please write a correct answering passage. Use your own knowledge, not just the example passages!
+ notes: |
+ ThinkQE uses retrieved passages as evidence, strips any `<think>...</think>`
+ reasoning trace, and keeps only the final generated answering passage for
+ query expansion.
+
- id: keqe.v1
method_family: keqe
version: 1
@@ -483,4 +501,4 @@
Query: {query}
Keywords:
notes: |
- Few-shot Q2E reformulation prompt
\ No newline at end of file
+ Few-shot Q2E reformulation prompt
diff --git a/tests/test_cli_script_gen.py b/tests/test_cli_script_gen.py
index f28fcd2..a4c8c8c 100644
--- a/tests/test_cli_script_gen.py
+++ b/tests/test_cli_script_gen.py
@@ -1,5 +1,8 @@
from querygym.cli import build_script_lines
+
def test_script_lines():
- lines = build_script_lines(index_path="/idx", topics="/q.tsv", run="/run.txt", qrels="/qrels.txt")
+ lines = build_script_lines(
+ index_path="/idx", topics="/q.tsv", run="/run.txt", qrels="/qrels.txt"
+ )
assert any("--index /idx" in ln for ln in lines)
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 902319a..e32a88a 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -2,6 +2,7 @@
from pathlib import Path
import tempfile
+
def test_local_tsv_loading():
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".tsv") as f:
f.write("1\tapple\n2\tbanana\n")
diff --git a/tests/test_method_thinkqe.py b/tests/test_method_thinkqe.py
new file mode 100644
index 0000000..c6f3a3c
--- /dev/null
+++ b/tests/test_method_thinkqe.py
@@ -0,0 +1,114 @@
+from pathlib import Path
+
+from querygym.core.base import MethodConfig, QueryItem
+from querygym.core.prompts import PromptBank
+from querygym.core.searcher import BaseSearcher, SearchHit
+from querygym.methods.thinkqe import ThinkQE
+
+
+class DummyLLM:
+ def __init__(self, responses):
+ self.responses = list(responses)
+ self.messages = []
+
+ def chat(self, messages, **kwargs):
+ self.messages.append(messages)
+ return self.responses.pop(0)
+
+
+class DummySearcher(BaseSearcher):
+ def __init__(self):
+ self.queries = []
+
+ def search(self, query: str, k: int = 10, **kwargs):
+ self.queries.append(query)
+ if "expansion one" in query:
+ passages = [
+ "Fresh support passage one",
+ "Fresh support passage two",
+ "Initial evidence passage one",
+ ]
+ else:
+ passages = [
+ "Initial evidence passage one",
+ "Initial evidence passage two",
+ "Initial evidence passage three",
+ ]
+ return [
+ SearchHit(docid=f"d{idx}", score=1.0 / (idx + 1), content=passage)
+ for idx, passage in enumerate(passages[:k])
+ ]
+
+ def batch_search(self, queries, k: int = 10, num_threads: int = 1, **kwargs):
+ return [self.search(query, k=k, **kwargs) for query in queries]
+
+ def get_searcher_info(self):
+ return {"name": "dummy"}
+
+
+def _prompt_bank():
+ return PromptBank(Path(__file__).parents[1] / "querygym" / "prompt_bank.yaml")
+
+
+def test_thinkqe_single_round_strips_thinking_trace():
+ cfg = MethodConfig(
+ name="thinkqe",
+ params={"gen_num": 2, "repeat_weight": 3, "keep_passage_num": 2},
+ llm={"model": "dummy", "temperature": 0.7, "max_tokens": 256},
+ )
+ llm = DummyLLM(
+ [
+ "\nreasoning\n\nAnswer Passage One",
+ "\nmore reasoning\n\nAnswer Passage Two",
+ ]
+ )
+ method = ThinkQE(cfg, llm, _prompt_bank())
+
+ result = method.reformulate(
+ QueryItem("q1", "test query"),
+ ["Context one", "Context two", "Context three"],
+ )
+
+ assert result.metadata["expansions"] == [
+ "Answer Passage One",
+ "Answer Passage Two",
+ ]
+ assert result.metadata["used_ctx"] == 2
+ assert "answer passage one" in result.reformulated
+ assert "answer passage two" in result.reformulated
+
+
+def test_thinkqe_multi_round_uses_searcher_and_records_history():
+ cfg = MethodConfig(
+ name="thinkqe",
+ params={
+ "searcher": DummySearcher(),
+ "num_interaction": 2,
+ "accumulate": True,
+ "use_passage_filter": True,
+ "gen_num": 1,
+ "keep_passage_num": 2,
+ "search_k": 3,
+ },
+ llm={"model": "dummy", "temperature": 0.7, "max_tokens": 256},
+ )
+ llm = DummyLLM(
+ [
+ "\ntrace\n\nExpansion One",
+ "\ntrace\n\nExpansion Two",
+ ]
+ )
+ method = ThinkQE(cfg, llm, _prompt_bank())
+
+ results = method.reformulate_batch([QueryItem("q1", "original query")])
+ result = results[0]
+ round_history = result.metadata["round_history"]
+
+ assert len(round_history) == 3
+ assert round_history[0]["metadata"]["retrieved_count"] == 3
+ assert round_history[-1]["metadata"]["accumulated_count"] == 2
+ assert "expansion one" in result.reformulated
+ assert "expansion two" in result.reformulated
+ assert "original query" in llm.messages[0][0]["content"]
+ assert "Initial evidence passage one" in llm.messages[0][0]["content"]
+ assert "Fresh support passage one" in llm.messages[1][0]["content"]
diff --git a/tests/test_methods_genqr.py b/tests/test_methods_genqr.py
index 078f5a7..c34616e 100644
--- a/tests/test_methods_genqr.py
+++ b/tests/test_methods_genqr.py
@@ -3,12 +3,16 @@
from querygym.methods.genqr_ensemble import GenQREnsemble
from pathlib import Path
+
class DummyLLM:
def chat(self, messages, **kwargs):
return "bertopic"
+
def test_genqr_ensemble():
- cfg = MethodConfig(name="genqr_ensemble", params={"repeat_query_weight":2}, llm={"model":"dummy"})
+ cfg = MethodConfig(
+ name="genqr_ensemble", params={"repeat_query_weight": 2}, llm={"model": "dummy"}
+ )
llm = DummyLLM()
pb = PromptBank(Path(__file__).parents[1] / "querygym" / "prompt_bank.yaml")
meth = GenQREnsemble(cfg, llm, pb)
diff --git a/tests/test_prompts.py b/tests/test_prompts.py
index 7cb9681..004f242 100644
--- a/tests/test_prompts.py
+++ b/tests/test_prompts.py
@@ -1,6 +1,7 @@
from querygym.core.prompts import PromptBank
from pathlib import Path
+
def test_prompt_bank_loads():
pb = PromptBank(Path(__file__).parents[1] / "querygym" / "prompt_bank.yaml")
assert len(pb.list()) > 0