From 2e4b4890dd3b84ae788f6b119659f0bb7989caee Mon Sep 17 00:00:00 2001 From: haisonle001 Date: Mon, 30 Mar 2026 20:16:43 -0400 Subject: [PATCH 1/2] Add ThinkQE reformulation method --- README.md | 1 + docs/getting-started/quickstart.md | 3 +- docs/user-guide/methods-reference.md | 67 ++- docs/user-guide/reformulation.md | 21 + examples/querygym_pyserini/README.md | 3 +- .../querygym_pyserini/reformulate_queries.py | 3 +- .../reformulation_config.yaml | 19 +- querygym/__init__.py | 80 ++- querygym/core/runner.py | 19 +- querygym/methods/__init__.py | 1 + querygym/methods/thinkqe.py | 499 ++++++++++++++++++ querygym/prompt_bank.yaml | 20 +- tests/test_method_thinkqe.py | 114 ++++ 13 files changed, 795 insertions(+), 55 deletions(-) create mode 100644 querygym/methods/thinkqe.py create mode 100644 tests/test_method_thinkqe.py diff --git a/README.md b/README.md index a794ea4..a563260 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ QueryGym implements the following query reformulation methods: | **MuGI** | Multi-granularity information expansion with adaptive concatenation | [Zhang et al., 2024](https://arxiv.org/abs/2401.06311) | | **LameR** | Context-based passage synthesis using retrieved documents | [Mackie et al., 2023](https://arxiv.org/abs/2304.14233) | | **CSQE** | Context-based sentence-level query expansion (KEQE + CSQE) | [Lee et al., 2024](https://arxiv.org/abs/2402.18031) | +| **ThinkQE** | Multi-round reasoning-based query expansion with corpus feedback | [Le et al., 2025](https://arxiv.org/abs/2506.09260) | | **Query2E** | Query to entity/keyword expansion | [Jagerman et al., 2023](https://arxiv.org/abs/2305.03653)| For detailed usage and parameters, see the [Methods Reference](https://querygym.readthedocs.io/en/latest/user-guide/methods-reference/). 
diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index 9991c3f..ba9ea30 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -48,6 +48,7 @@ Available methods: - `lamer` - Context-based passage synthesis - `query2e` - Query to entity expansion - `csqe` - Context-based sentence extraction +- `thinkqe` - Multi-round reasoning-based passage expansion ### 3. Reformulate Queries @@ -137,7 +138,7 @@ See [Loading Datasets](../user-guide/datasets.md) for more details. ## Context-Based Reformulation -Some methods (like `lamer`, `csqe`) use retrieved contexts: +Some methods (like `lamer`, `csqe`, `thinkqe`) use retrieved contexts: ```python import querygym as qg diff --git a/docs/user-guide/methods-reference.md b/docs/user-guide/methods-reference.md index 1aaf00a..7a40b36 100644 --- a/docs/user-guide/methods-reference.md +++ b/docs/user-guide/methods-reference.md @@ -15,6 +15,7 @@ Complete reference guide for all query reformulation methods in QueryGym, includ - [LameR](#lamer) - [Query2E](#query2e) - [CSQE](#csqe) + - [ThinkQE](#thinkqe) --- @@ -561,6 +562,68 @@ result = reformulator.reformulate(qg.QueryItem("q1", "quantum computing")) --- +### ThinkQE + +**Method Name:** `"thinkqe"` +**Requires Context:** Yes +**Description:** Multi-round query expansion with retrieved passage feedback. Each round uses the original query plus newly retrieved passages to generate pseudo-passages, appends them to the retrieval query, and retrieves again. 
+ +#### Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `keep_passage_num` | int | `5` | Number of retrieved passages kept for prompting | +| `gen_num` | int | `2` | Number of expansions generated per round | +| `num_interaction` | int | `3` | Number of expansion rounds after baseline retrieval | +| `accumulate` | bool | `True` | Accumulate all previous expansions into later rounds | +| `use_passage_filter` | bool | `True` | Blacklist passages repeated from two rounds ago | +| `repeat_weight` | float | `3` | Divisor for adaptive query repetition | +| `search_k` | int | `keep_passage_num` | Retrieval depth for each round before filtering; use `1000` to mirror the original archive runs | +| `max_demo_len` | int | `None` | Optional word truncation length for each passage | +| `no_thinking` | bool | `False` | Prefill a closing `` tag to disable reasoning traces | +| `searcher` | object | `None` | Pre-configured searcher instance (recommended) | +| `searcher_type` | str | `"pyserini"` | Type of searcher to create | +| `searcher_kwargs` | dict | `{}` | Keyword arguments for searcher initialization | +| `index` | str | `None` | Pyserini index name (legacy format) | +| `temperature` | float | `0.7` | Sampling temperature (via `llm_config`) | +| `max_tokens` | int | `32768` | Maximum tokens per generation (via `llm_config`) | + +#### Usage Example + +```python +import querygym as qg +from pyserini.search.lucene import LuceneSearcher + +pyserini_searcher = LuceneSearcher.from_prebuilt_index("msmarco-v1-passage") +searcher = qg.wrap_pyserini_searcher(pyserini_searcher, answer_key="contents") + +reformulator = qg.create_reformulator( + "thinkqe", + model="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", + params={ + "searcher": searcher, + "keep_passage_num": 5, + "gen_num": 2, + "num_interaction": 3, + "accumulate": True, + "use_passage_filter": True, + "repeat_weight": 3, + "search_k": 1000, + "max_demo_len": 128, + }, + 
llm_config={"temperature": 0.7, "max_tokens": 32768} +) + +result = reformulator.reformulate_batch([qg.QueryItem("q1", "quantum computing")])[0] +``` + +#### Output Format + +- **Concatenation:** `(query × adaptive_repeat) + expansion_1 + expansion_2 + ...` using newline joins +- **Metadata:** Includes `round_history`, `gen_num`, `keep_passage_num`, `accumulated_count`, `q_repeat`, and per-round raw response counts + +--- + ## Quick Reference Table | Method | Requires Context | Key Parameters | Default LLM Config | @@ -573,12 +636,13 @@ result = reformulator.reformulate(qg.QueryItem("q1", "quantum computing")) | **LameR** | Yes | `retrieval_k`, `gen_passages`, `searcher` | temp=1.0, max_tokens=128 | | **Query2E** | No | `mode`, `max_keywords`, `num_examples` (fs) | temp=0.3, max_tokens=256 | | **CSQE** | Yes | `retrieval_k`, `gen_num`, `searcher` | temp=1.0, max_tokens=1024 | +| **ThinkQE** | Yes | `keep_passage_num`, `gen_num`, `num_interaction`, `searcher` | temp=0.7, max_tokens=32768 | --- ## Tips and Best Practices -1. **Context-Based Methods (LameR, CSQE):** +1. 
**Context-Based Methods (LameR, CSQE, ThinkQE):** - Always provide a `searcher` instance or configure `searcher_type`/`searcher_kwargs` - Use `qg.wrap_pyserini_searcher()` for easy integration with Pyserini - Set appropriate `retrieval_k` based on your needs (default: 10) @@ -604,4 +668,3 @@ result = reformulator.reformulate(qg.QueryItem("q1", "quantum computing")) - [API Reference](../api/methods.md) - Technical API documentation - [Query Reformulation Guide](reformulation.md) - Usage tutorials - [Examples](https://github.com/ls3-lab/QueryGym/tree/main/examples) - Complete workflow examples - diff --git a/docs/user-guide/reformulation.md b/docs/user-guide/reformulation.md index 3f98cdc..bcfaf31 100644 --- a/docs/user-guide/reformulation.md +++ b/docs/user-guide/reformulation.md @@ -93,6 +93,26 @@ reformulator = qg.create_reformulator("csqe", model="gpt-4") results = reformulator.reformulate_batch(queries, contexts=contexts) ``` +### ThinkQE + +Multi-round reasoning-based expansion with iterative corpus feedback. 
+ +```python +reformulator = qg.create_reformulator( + "thinkqe", + model="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", + params={ + "searcher": searcher, + "num_interaction": 3, + "keep_passage_num": 5, + "gen_num": 2, + "accumulate": True, + "use_passage_filter": True, + "search_k": 1000, + }, +) +``` + ## Method Comparison | Method | Requires Context | Type | Best For | @@ -105,6 +125,7 @@ results = reformulator.reformulate_batch(queries, contexts=contexts) | lamer | Yes | Context synthesis | Re-ranking | | query2e | No | Entity expansion | Entity queries | | csqe | Yes | Sentence extraction | Precision-focused | +| thinkqe | Yes | Iterative reasoning | Multi-round feedback | ## Custom Parameters diff --git a/examples/querygym_pyserini/README.md b/examples/querygym_pyserini/README.md index 54b7f7a..48b006b 100644 --- a/examples/querygym_pyserini/README.md +++ b/examples/querygym_pyserini/README.md @@ -127,7 +127,7 @@ python examples/querygym_pyserini/reformulate_queries.py \ ``` **Options:** -- `--method`: QueryGym method (genqr, genqr_ensemble, query2doc, qa_expand, mugi, lamer, query2e, csqe) +- `--method`: QueryGym method (genqr, genqr_ensemble, query2doc, qa_expand, mugi, lamer, query2e, csqe, thinkqe) - `--model`: LLM model name (e.g., qwen2.5:7b, llama3.1:8b, gpt-4, etc.) - `--base-url`: LLM API endpoint (e.g., http://localhost:11434/v1) - `--api-key`: LLM API key @@ -525,4 +525,3 @@ python examples/querygym_pyserini/pipeline.py \ ``` **Note:** The config file (`reformulation_config.yaml`) contains pre-configured settings for all methods, including complex method parameters. CLI arguments override config file values. 
- diff --git a/examples/querygym_pyserini/reformulate_queries.py b/examples/querygym_pyserini/reformulate_queries.py index fa92bcf..f74d7eb 100755 --- a/examples/querygym_pyserini/reformulate_queries.py +++ b/examples/querygym_pyserini/reformulate_queries.py @@ -300,7 +300,7 @@ def main(): parser.add_argument( '--method', type=str, - help='QueryGym reformulation method (genqr, genqr_ensemble, query2doc, etc.)' + help='QueryGym reformulation method (genqr, genqr_ensemble, query2doc, lamer, csqe, thinkqe, etc.)' ) parser.add_argument( '--model', @@ -485,4 +485,3 @@ def main(): if __name__ == '__main__': main() - diff --git a/examples/querygym_pyserini/reformulation_config.yaml b/examples/querygym_pyserini/reformulation_config.yaml index e7e9770..dfa6019 100644 --- a/examples/querygym_pyserini/reformulation_config.yaml +++ b/examples/querygym_pyserini/reformulation_config.yaml @@ -118,6 +118,24 @@ methods: gen_num: 2 # Number of expansions for both KEQE and CSQE (default: 2) # Note: Searcher is automatically configured from dataset registry + # ThinkQE: Multi-round reasoning-based query expansion (requires retrieval) + thinkqe: + enabled: true + model: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B" + llm: + temperature: 0.7 + max_tokens: 32768 + params: + keep_passage_num: 5 + gen_num: 2 + num_interaction: 3 + accumulate: true + use_passage_filter: true + repeat_weight: 3 + max_demo_len: 128 + search_k: 1000 + # Note: Searcher is automatically configured from dataset registry + # Example: Multiple configurations for the same method # You can define multiple variants with different parameters: # @@ -136,4 +154,3 @@ methods: # temperature: 0.9 # params: # n_generations: 7 - diff --git a/querygym/__init__.py b/querygym/__init__.py index fe1c2b2..f31fdd2 100644 --- a/querygym/__init__.py +++ b/querygym/__init__.py @@ -3,10 +3,10 @@ Simple usage: import querygym as qg - + # Create a reformulator reformulator = qg.create_reformulator("genqr_ensemble") - + # Reformulate queries 
result = reformulator.reformulate(qg.QueryItem("q1", "what causes diabetes")) """ @@ -24,7 +24,11 @@ from .core.searcher import BaseSearcher, SearchHit, SearcherRegistry, create_searcher # Searcher wrappers for user convenience -from .core.searcher_wrappers import wrap_pyserini_searcher, wrap_pyterrier_retriever, wrap_custom_searcher +from .core.searcher_wrappers import ( + wrap_pyserini_searcher, + wrap_pyterrier_retriever, + wrap_custom_searcher, +) # Data loaders from .data.dataloader import DataLoader, UnifiedQuerySource @@ -42,6 +46,7 @@ LameR, Query2E, CSQE, + ThinkQE, ) # High-level runner @@ -52,6 +57,7 @@ # Import adapters to register them from . import adapters + # Convenience factory function def create_reformulator( method_name: str, @@ -59,11 +65,11 @@ def create_reformulator( params: dict = None, llm_config: dict = None, prompt_bank_path: str = None, - **kwargs + **kwargs, ): """ Create a reformulator instance with sensible defaults. - + Args: method_name: Name of the method (e.g., "genqr", "genqr_ensemble", "query2doc") model: LLM model name (default: "gpt-4") @@ -71,52 +77,52 @@ def create_reformulator( llm_config: Additional LLM configuration (temperature, max_tokens, etc.) 
prompt_bank_path: Path to prompt bank YAML (default: bundled prompt_bank.yaml) **kwargs: Additional MethodConfig parameters (seed, retries) - + Returns: BaseReformulator instance - + Example: >>> import querygym as qg >>> reformulator = qg.create_reformulator("genqr_ensemble", model="gpt-4") >>> result = reformulator.reformulate(qg.QueryItem("q1", "what causes diabetes")) """ from pathlib import Path - + if params is None: params = {} - + if llm_config is None: llm_config = {} - + # Set up LLM config llm_cfg = { "model": model, "temperature": llm_config.get("temperature", 0.8), "max_tokens": llm_config.get("max_tokens", 256), - **{k: v for k, v in llm_config.items() if k not in ["temperature", "max_tokens"]} + **{k: v for k, v in llm_config.items() if k not in ["temperature", "max_tokens"]}, } - + # Create config config = MethodConfig( name=method_name, params=params, llm=llm_cfg, seed=kwargs.get("seed", 42), - retries=kwargs.get("retries", 2) + retries=kwargs.get("retries", 2), ) - + # Build LLM client llm = build_llm(config) - + # Load prompt bank if prompt_bank_path is None: prompt_bank_path = Path(__file__).parent / "prompt_bank.yaml" pb = PromptBank(prompt_bank_path) - + # Get method class and instantiate if method_name not in METHODS: raise ValueError(f"Unknown method: {method_name}. Available: {list(METHODS.keys())}") - + MethodClass = METHODS[method_name] return MethodClass(config, llm, pb) @@ -124,15 +130,15 @@ def create_reformulator( def load_queries(path: str, format: str = "tsv", **kwargs): """ Load queries from a local file. 
- + Args: path: Path to queries file format: File format - "tsv" or "jsonl" (default: "tsv") **kwargs: Additional parameters for DataLoader.load_queries() - + Returns: List of QueryItem objects - + Example: >>> queries = qg.load_queries("queries.tsv", format="tsv") >>> queries = qg.load_queries("queries.jsonl", format="jsonl") @@ -143,15 +149,15 @@ def load_queries(path: str, format: str = "tsv", **kwargs): def load_qrels(path: str, format: str = "trec", **kwargs): """ Load qrels from a local file. - + Args: path: Path to qrels file format: File format - "trec" (default: "trec") **kwargs: Additional parameters for DataLoader.load_qrels() - + Returns: Dict mapping qid -> {docid -> relevance} - + Example: >>> qrels = qg.load_qrels("qrels.txt") """ @@ -161,14 +167,14 @@ def load_qrels(path: str, format: str = "trec", **kwargs): def load_contexts(path: str, **kwargs): """ Load contexts from a JSONL file. - + Args: path: Path to contexts JSONL file **kwargs: Additional parameters for DataLoader.load_contexts() - + Returns: Dict mapping qid -> list of context strings - + Example: >>> contexts = qg.load_contexts("contexts.jsonl") """ @@ -178,20 +184,20 @@ def load_contexts(path: str, **kwargs): def load_examples(path: str, **kwargs): """ Load few-shot examples from a JSONL file. - + Used for methods like Query2E that need (query, passage) pairs as demonstrations. 
- + Args: path: Path to examples JSONL file **kwargs: Additional parameters for DataLoader.load_examples() - + Returns: List of dicts with 'query' and 'passage' keys - + Example: >>> examples = qg.load_examples("examples.jsonl") >>> reformulator = qg.create_reformulator("query2e", params={"mode": "fs", "examples": examples}) - + JSONL format: {"query": "how long is flea life cycle?", "passage": "The life cycle of a flea..."} {"query": "cost of flooring?", "passage": "The cost of interior concrete..."} @@ -202,33 +208,27 @@ def load_examples(path: str, **kwargs): __all__ = [ # Version "__version__", - # Core classes "QueryItem", "ReformulationResult", "MethodConfig", "BaseReformulator", - # LLM & Prompts "OpenAICompatibleClient", "PromptBank", - # Searcher interface "BaseSearcher", "SearchHit", "SearcherRegistry", "create_searcher", - # Searcher wrappers "wrap_pyserini_searcher", - "wrap_pyterrier_retriever", + "wrap_pyterrier_retriever", "wrap_custom_searcher", - # Data "DataLoader", "UnifiedQuerySource", # Deprecated "loaders", - # Methods "GENQR", "GenQREnsemble", @@ -238,7 +238,7 @@ def load_examples(path: str, **kwargs): "LameR", "Query2E", "CSQE", - + "ThinkQE", # High-level API "run_method", "build_llm", @@ -247,10 +247,8 @@ def load_examples(path: str, **kwargs): "load_qrels", "load_contexts", "load_examples", - # Retrieval "Retriever", - # Registry "METHODS", "register_method", diff --git a/querygym/core/runner.py b/querygym/core/runner.py index 38f02df..88abba7 100644 --- a/querygym/core/runner.py +++ b/querygym/core/runner.py @@ -13,15 +13,24 @@ from ..methods.mugi import MuGI from ..methods.lamer import LameR from ..methods.query2e import Query2E +from ..methods.csqe import CSQE +from ..methods.thinkqe import ThinkQE + def build_llm(cfg: MethodConfig): llm_cfg = cfg.llm - return OpenAICompatibleClient(model=llm_cfg["model"], - base_url=llm_cfg.get("base_url"), - api_key=llm_cfg.get("api_key")) + return OpenAICompatibleClient( + model=llm_cfg["model"], 
@register_method("thinkqe")
class ThinkQE(BaseReformulator):
    """ThinkQE: query expansion via iterative corpus feedback.

    This implementation follows the original ``archive/Think_QE/thinkqe.py``
    loop while fitting QueryGym's ``BaseReformulator`` interface:

    1. Round 0 retrieves passages with the original query.
    2. Each later round prompts the LLM with the original query plus the
       previous round's retrieved passages.
    3. The generated answering passage(s) are appended to the query.
    4. The updated query is used to retrieve again.

    The method strips any ``<think>...</think>`` reasoning trace and keeps
    only the answer passage for retrieval.
    """

    VERSION = "1.0"
    REQUIRES_CONTEXT = True

    def __init__(self, cfg, llm_client, prompt_resolver):
        super().__init__(cfg, llm_client, prompt_resolver)
        # Searcher resolution is lazy and memoized. `_searcher_initialized`
        # distinguishes "not resolved yet" from "resolved to None" so we do
        # not retry an expensive (and failed) resolution on every call.
        self._resolved_searcher = None
        self._searcher_initialized = False

    @staticmethod
    def _extract_answer(text: str) -> str:
        """Return only the answer text after the ``</think>`` reasoning trace.

        Reasoning models (e.g. DeepSeek-R1 distills) emit their chain of
        thought inside ``<think>...</think>``; only the text after the
        closing tag is the answering passage used for expansion. If no
        closing tag is present the whole response is treated as the answer.
        """
        # Prefer the tag-plus-newline split so a trailing newline emitted
        # right after the tag is not left at the start of the answer.
        if "</think>\n" in text:
            return text.split("</think>\n")[-1].strip()
        if "</think>" in text:
            return text.split("</think>")[-1].strip()
        return text.strip()

    @staticmethod
    def _truncate_passage(passage: str, max_words: int) -> str:
        """Trim a passage to at most ``max_words`` whitespace tokens."""
        words = passage.split()
        if len(words) <= max_words:
            return passage
        return " ".join(words[:max_words])

    @staticmethod
    def _build_query(
        original_query: str,
        expansions: List[str],
        repeat_weight: float,
        lowercase: bool,
    ) -> Tuple[str, int]:
        """Build the retrieval query using ThinkQE's repetition heuristic.

        The original query is repeated proportionally to the total length of
        the expansions (divided by ``repeat_weight``) so that the expansion
        text does not drown out the query terms in bag-of-words retrieval.

        Returns:
            Tuple of (reformulated query string, query repetition count).
        """
        query_words = len(original_query.split())
        expansion_words = len("\n".join(expansions).split())

        if query_words > 0 and repeat_weight > 0:
            q_repeat = max(1, int(expansion_words / (query_words * repeat_weight)))
        else:
            q_repeat = 1

        reformulated = "\n".join([original_query] * q_repeat + expansions)
        if lowercase:
            reformulated = reformulated.lower()
        return reformulated, q_repeat

    @staticmethod
    def _filter_passages(
        passages: List[str],
        keep_k: int,
        seen: Set[str],
        prev_prev_top: Set[str],
    ) -> Tuple[List[str], Set[str]]:
        """Apply the archive-style novelty filter over retrieved passages.

        Passages already consumed (``seen``) are skipped; passages that also
        appeared in the top-k two rounds ago are blacklisted (added to
        ``seen``) so they are never used again. At most ``keep_k`` passages
        are kept.

        Returns:
            Tuple of (filtered passages, updated ``seen`` set).
        """
        filtered: List[str] = []
        for passage in passages:
            if passage in seen:
                continue
            if passage in prev_prev_top:
                seen.add(passage)
                continue
            filtered.append(passage)
            if len(filtered) >= keep_k:
                break
        return filtered, seen

    def _get_prompt_id(self) -> str:
        """Prompt-bank id used for expansion generation."""
        return str(self.cfg.params.get("prompt_id", "thinkqe.v1"))

    def _get_keep_passage_num(self) -> int:
        """Number of retrieved passages kept for prompting (``retrieval_k`` alias)."""
        value = self.cfg.params.get(
            "keep_passage_num",
            self.cfg.params.get("retrieval_k", 5),
        )
        return int(value)

    def _get_gen_num(self) -> int:
        """Number of expansions generated per round (``n_generations`` alias)."""
        value = self.cfg.params.get(
            "gen_num",
            self.cfg.params.get("n_generations", 2),
        )
        return int(value)

    def _get_repeat_weight(self) -> float:
        """Divisor for adaptive query repetition.

        Also accepts the original archive's misspelled ``reqeat_weight`` key
        for backward compatibility.
        """
        value = self.cfg.params.get(
            "repeat_weight",
            self.cfg.params.get("reqeat_weight", 3),
        )
        return float(value)

    def _get_search_k(self) -> int:
        """Retrieval depth per round; never smaller than ``keep_passage_num``."""
        keep_passage_num = self._get_keep_passage_num()
        value = self.cfg.params.get(
            "search_k",
            self.cfg.params.get("hits", keep_passage_num),
        )
        return max(int(value), keep_passage_num)

    def _generate_expansions(
        self,
        query_text: str,
        contexts: List[str],
        gen_num: int,
        max_demo_len: Optional[int],
        temperature: float,
        max_tokens: int,
        no_thinking: bool,
    ) -> Tuple[List[str], List[str]]:
        """Generate ``gen_num`` expansions using repeated chat calls.

        Returns:
            Tuple of (answer-only expansions, raw LLM responses).
        """
        top_passages = contexts
        if max_demo_len is not None:
            top_passages = [
                self._truncate_passage(passage, int(max_demo_len)) for passage in top_passages
            ]

        contexts_blob = "\n".join(
            f"{index + 1}. {passage}" for index, passage in enumerate(top_passages)
        )
        messages = self.prompts.render(
            self._get_prompt_id(),
            query=query_text,
            contexts=contexts_blob,
        )

        if no_thinking:
            messages = list(messages)
            # Prefill an already-closed <think> block so the model skips
            # emitting a reasoning trace and answers directly.
            # NOTE(review): exact prefill wording mirrors the archive run;
            # confirm against archive/Think_QE if templates change.
            messages.append(
                {
                    "role": "assistant",
                    "content": (
                        "<think>\nOkay, I think I have finished thinking."
                        "\n</think>\n\n"
                    ),
                }
            )

        raw_responses: List[str] = []
        for _ in range(gen_num):
            response = self.llm.chat(
                messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            raw_responses.append(response)

        expansions = [self._extract_answer(response) for response in raw_responses]
        return expansions, raw_responses

    def _build_searcher(self):
        """Resolve a searcher from config if multi-round retrieval is enabled.

        Resolution order: explicit ``searcher`` instance, then
        ``searcher_type``/``searcher_kwargs``, then a legacy Pyserini
        ``index`` name. Returns None when no searcher is configured.
        """
        if self._searcher_initialized:
            return self._resolved_searcher

        searcher = self.cfg.params.get("searcher")
        if searcher is not None:
            self._resolved_searcher = searcher
            self._searcher_initialized = True
            return self._resolved_searcher

        if "searcher_type" in self.cfg.params:
            from ..core.searcher import create_searcher

            self._resolved_searcher = create_searcher(
                self.cfg.params["searcher_type"],
                **dict(self.cfg.params.get("searcher_kwargs", {})),
            )
            self._searcher_initialized = True
            return self._resolved_searcher

        index = self.cfg.params.get("index")
        if not index:
            self._searcher_initialized = True
            return None

        from ..core.searcher import create_searcher

        searcher_kwargs: Dict[str, Any] = {
            "index": index,
            "answer_key": self.cfg.params.get("answer_key", "contents"),
        }

        # BM25 parameters are only forwarded when both are given.
        k1 = self.cfg.params.get("k1")
        b = self.cfg.params.get("b")
        if k1 is not None and b is not None:
            searcher_kwargs["k1"] = k1
            searcher_kwargs["b"] = b

        self._resolved_searcher = create_searcher("pyserini", **searcher_kwargs)
        self._searcher_initialized = True
        return self._resolved_searcher

    def _build_retriever(self):
        """Create a callable that retrieves passage text for a query string.

        Returns None when no searcher could be resolved.
        """
        searcher = self._build_searcher()
        if searcher is None:
            return None

        search_k = self._get_search_k()

        def _retriever(query_text: str) -> List[str]:
            hits = searcher.search(query_text, k=search_k)
            return [hit.content for hit in hits]

        return _retriever

    def _multi_round_enabled(self) -> bool:
        """Return True when ThinkQE can run its full iterative loop."""
        num_interaction = int(self.cfg.params.get("num_interaction", 3))
        return num_interaction > 0 and self._build_searcher() is not None

    @staticmethod
    def _summarize_round(result: ReformulationResult) -> Dict[str, Any]:
        """Create a compact per-round metadata summary.

        Raw LLM responses are replaced by a count to keep the round history
        small enough to serialize alongside the final result.
        """
        meta = dict(result.metadata or {})
        raw_responses = meta.pop("raw_responses", None)
        summary = {
            "round": meta.pop("round", None),
            "type": meta.pop("type", "expanded"),
            "reformulated": result.reformulated,
            "metadata": meta,
        }
        if raw_responses is not None:
            summary["raw_response_count"] = len(raw_responses)
        return summary

    def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult:
        """Run single-step ThinkQE using provided contexts.

        Used when no searcher is available for the iterative loop; the
        caller supplies the retrieved passages directly.
        """
        ctxs: List[str] = contexts or []

        keep_passage_num = self._get_keep_passage_num()
        gen_num = self._get_gen_num()
        repeat_weight = self._get_repeat_weight()
        max_demo_len = self.cfg.params.get("max_demo_len")
        lowercase = bool(self.cfg.params.get("lowercase", True))
        no_thinking = bool(self.cfg.params.get("no_thinking", False))
        temperature = float(self.cfg.llm.get("temperature", 0.7))
        max_tokens = int(self.cfg.llm.get("max_tokens", 32768))

        top_passages = ctxs[:keep_passage_num]
        expansions, raw_responses = self._generate_expansions(
            q.text,
            top_passages,
            gen_num,
            max_demo_len,
            temperature,
            max_tokens,
            no_thinking,
        )
        reformulated, q_repeat = self._build_query(
            q.text,
            expansions,
            repeat_weight,
            lowercase,
        )

        return ReformulationResult(
            q.qid,
            q.text,
            reformulated,
            metadata={
                "prompt_id": self._get_prompt_id(),
                "gen_num": gen_num,
                "q_repeat": q_repeat,
                "repeat_weight": repeat_weight,
                "keep_passage_num": keep_passage_num,
                "used_ctx": len(top_passages),
                "expansions": expansions,
                "raw_responses": raw_responses,
            },
        )

    def reformulate_multi_round(
        self,
        q: QueryItem,
        retriever,
        *,
        num_interaction: Optional[int] = None,
        accumulate: Optional[bool] = None,
        use_passage_filter: Optional[bool] = None,
    ) -> List[ReformulationResult]:
        """Run the full ThinkQE loop and return per-round results.

        Args:
            q: The query to expand.
            retriever: Callable mapping a query string to a list of passages.
            num_interaction: Number of expansion rounds after the baseline
                retrieval (defaults to config).
            accumulate: Whether earlier expansions are carried into later
                rounds (defaults to config).
            use_passage_filter: Whether to apply the two-rounds-ago novelty
                filter (defaults to config).

        Returns:
            One ``ReformulationResult`` per round, baseline round first.
        """
        if num_interaction is None:
            num_interaction = int(self.cfg.params.get("num_interaction", 3))
        if accumulate is None:
            accumulate = bool(self.cfg.params.get("accumulate", True))
        if use_passage_filter is None:
            use_passage_filter = bool(self.cfg.params.get("use_passage_filter", True))

        keep_passage_num = self._get_keep_passage_num()
        gen_num = self._get_gen_num()
        repeat_weight = self._get_repeat_weight()
        max_demo_len = self.cfg.params.get("max_demo_len")
        lowercase = bool(self.cfg.params.get("lowercase", True))
        no_thinking = bool(self.cfg.params.get("no_thinking", False))
        temperature = float(self.cfg.llm.get("temperature", 0.7))
        max_tokens = int(self.cfg.llm.get("max_tokens", 32768))

        accumulated_expansions: List[str] = []
        seen_passages: Set[str] = set()
        last_top_k: Set[str] = set()
        prev_prev_top_k: Set[str] = set()

        original_query_text = q.text
        current_query_text = q.text
        results: List[ReformulationResult] = []

        # Round 0: baseline retrieval with the unmodified query.
        last_top_passages = retriever(current_query_text)
        last_top_k = set(last_top_passages[:keep_passage_num])
        results.append(
            ReformulationResult(
                q.qid,
                q.text,
                current_query_text,
                metadata={
                    "round": 0,
                    "type": "baseline",
                    "retrieved_count": len(last_top_passages),
                },
            )
        )

        for ridx in range(1, num_interaction + 1):
            all_passages = last_top_passages
            if use_passage_filter:
                top_passages, seen_passages = self._filter_passages(
                    all_passages,
                    keep_passage_num,
                    seen_passages,
                    prev_prev_top_k,
                )
            else:
                top_passages = all_passages[:keep_passage_num]

            # Prompt always uses the ORIGINAL query, not the expanded one.
            expansions, raw_responses = self._generate_expansions(
                original_query_text,
                top_passages,
                gen_num,
                max_demo_len,
                temperature,
                max_tokens,
                no_thinking,
            )

            if accumulate:
                accumulated_expansions.extend(expansions)
                active_expansions = list(accumulated_expansions)
            else:
                active_expansions = expansions

            reformulated, q_repeat = self._build_query(
                original_query_text,
                active_expansions,
                repeat_weight,
                lowercase,
            )
            current_query_text = reformulated
            retrieved_passages = retriever(current_query_text)

            # Slide the two-round window used by the novelty filter.
            prev_prev_top_k = last_top_k
            last_top_k = set(retrieved_passages[:keep_passage_num])
            last_top_passages = retrieved_passages

            results.append(
                ReformulationResult(
                    q.qid,
                    q.text,
                    reformulated,
                    metadata={
                        "round": ridx,
                        "prompt_id": self._get_prompt_id(),
                        "gen_num": gen_num,
                        "q_repeat": q_repeat,
                        "repeat_weight": repeat_weight,
                        "keep_passage_num": keep_passage_num,
                        "used_ctx": len(top_passages),
                        "search_k": self._get_search_k(),
                        "accumulate": accumulate,
                        "accumulated_count": len(active_expansions),
                        "use_passage_filter": use_passage_filter,
                        "expansions": expansions,
                        "raw_responses": raw_responses,
                        "retrieved_count": len(retrieved_passages),
                    },
                )
            )

        return results

    def reformulate_with_round_history(
        self,
        q: QueryItem,
        contexts=None,
    ) -> Tuple[ReformulationResult, List[ReformulationResult]]:
        """Return the final result together with all round outputs.

        Falls back to single-step reformulation when the iterative loop is
        disabled or no searcher is configured.
        """
        retriever = self._build_retriever()
        if int(self.cfg.params.get("num_interaction", 3)) > 0 and retriever is not None:
            round_results = self.reformulate_multi_round(q, retriever)
            final_result = round_results[-1]
            final_metadata = dict(final_result.metadata or {})
            final_metadata["round_history"] = [
                self._summarize_round(result) for result in round_results
            ]
            final_result = ReformulationResult(
                final_result.qid,
                final_result.original,
                final_result.reformulated,
                metadata=final_metadata,
            )
            return final_result, round_results

        single_result = self.reformulate(q, contexts)
        single_metadata = dict(single_result.metadata or {})
        single_metadata["round_history"] = [self._summarize_round(single_result)]
        single_result = ReformulationResult(
            single_result.qid,
            single_result.original,
            single_result.reformulated,
            metadata=single_metadata,
        )
        return single_result, [single_result]

    def reformulate_batch(
        self,
        queries: List[QueryItem],
        ctx_map=None,
    ) -> List[ReformulationResult]:
        """Auto-dispatch to the multi-round loop when retrieval is available."""
        if self._multi_round_enabled():
            results: List[ReformulationResult] = []
            for query in queries:
                final_result, _ = self.reformulate_with_round_history(query)
                results.append(final_result)
            return results
        return super().reformulate_batch(queries, ctx_map)

    def _get_retrieval_params(self) -> Optional[Dict[str, Any]]:
        """Return retrieval settings for single-step batch reformulation.

        Mirrors `_build_searcher`'s resolution order but returns a settings
        dict (used by the base class's context-fetching path) instead of a
        live searcher instance. Returns None when nothing is configured.
        """
        keep_passage_num = self._get_keep_passage_num()

        if "searcher" in self.cfg.params:
            return {
                "searcher": self.cfg.params["searcher"],
                "k": keep_passage_num,
                "threads": int(self.cfg.params.get("threads", 16)),
            }

        if "searcher_type" in self.cfg.params:
            searcher_kwargs = dict(self.cfg.params.get("searcher_kwargs", {}))
            if "index" in self.cfg.params and "index" not in searcher_kwargs:
                searcher_kwargs["index"] = self.cfg.params["index"]
            return {
                "searcher_type": self.cfg.params["searcher_type"],
                "searcher_kwargs": searcher_kwargs,
                "k": keep_passage_num,
                "threads": int(self.cfg.params.get("threads", 16)),
            }

        index = self.cfg.params.get("index")
        if not index:
            return None

        searcher_kwargs: Dict[str, Any] = {
            "index": index,
            "answer_key": self.cfg.params.get("answer_key", "contents"),
        }
        k1 = self.cfg.params.get("k1")
        b = self.cfg.params.get("b")
        if k1 is not None and b is not None:
            searcher_kwargs["k1"] = k1
            searcher_kwargs["b"] = b

        return {
            "searcher_type": "pyserini",
            "searcher_kwargs": searcher_kwargs,
            "k": keep_passage_num,
            "threads": int(self.cfg.params.get("threads", 16)),
        }
+ - id: keqe.v1 method_family: keqe version: 1 @@ -483,4 +501,4 @@ Query: {query} Keywords: notes: | - Few-shot Q2E reformulation prompt \ No newline at end of file + Few-shot Q2E reformulation prompt diff --git a/tests/test_method_thinkqe.py b/tests/test_method_thinkqe.py new file mode 100644 index 0000000..c6f3a3c --- /dev/null +++ b/tests/test_method_thinkqe.py @@ -0,0 +1,114 @@ +from pathlib import Path + +from querygym.core.base import MethodConfig, QueryItem +from querygym.core.prompts import PromptBank +from querygym.core.searcher import BaseSearcher, SearchHit +from querygym.methods.thinkqe import ThinkQE + + +class DummyLLM: + def __init__(self, responses): + self.responses = list(responses) + self.messages = [] + + def chat(self, messages, **kwargs): + self.messages.append(messages) + return self.responses.pop(0) + + +class DummySearcher(BaseSearcher): + def __init__(self): + self.queries = [] + + def search(self, query: str, k: int = 10, **kwargs): + self.queries.append(query) + if "expansion one" in query: + passages = [ + "Fresh support passage one", + "Fresh support passage two", + "Initial evidence passage one", + ] + else: + passages = [ + "Initial evidence passage one", + "Initial evidence passage two", + "Initial evidence passage three", + ] + return [ + SearchHit(docid=f"d{idx}", score=1.0 / (idx + 1), content=passage) + for idx, passage in enumerate(passages[:k]) + ] + + def batch_search(self, queries, k: int = 10, num_threads: int = 1, **kwargs): + return [self.search(query, k=k, **kwargs) for query in queries] + + def get_searcher_info(self): + return {"name": "dummy"} + + +def _prompt_bank(): + return PromptBank(Path(__file__).parents[1] / "querygym" / "prompt_bank.yaml") + + +def test_thinkqe_single_round_strips_thinking_trace(): + cfg = MethodConfig( + name="thinkqe", + params={"gen_num": 2, "repeat_weight": 3, "keep_passage_num": 2}, + llm={"model": "dummy", "temperature": 0.7, "max_tokens": 256}, + ) + llm = DummyLLM( + [ + 
"\nreasoning\n\nAnswer Passage One", + "\nmore reasoning\n\nAnswer Passage Two", + ] + ) + method = ThinkQE(cfg, llm, _prompt_bank()) + + result = method.reformulate( + QueryItem("q1", "test query"), + ["Context one", "Context two", "Context three"], + ) + + assert result.metadata["expansions"] == [ + "Answer Passage One", + "Answer Passage Two", + ] + assert result.metadata["used_ctx"] == 2 + assert "answer passage one" in result.reformulated + assert "answer passage two" in result.reformulated + + +def test_thinkqe_multi_round_uses_searcher_and_records_history(): + cfg = MethodConfig( + name="thinkqe", + params={ + "searcher": DummySearcher(), + "num_interaction": 2, + "accumulate": True, + "use_passage_filter": True, + "gen_num": 1, + "keep_passage_num": 2, + "search_k": 3, + }, + llm={"model": "dummy", "temperature": 0.7, "max_tokens": 256}, + ) + llm = DummyLLM( + [ + "\ntrace\n\nExpansion One", + "\ntrace\n\nExpansion Two", + ] + ) + method = ThinkQE(cfg, llm, _prompt_bank()) + + results = method.reformulate_batch([QueryItem("q1", "original query")]) + result = results[0] + round_history = result.metadata["round_history"] + + assert len(round_history) == 3 + assert round_history[0]["metadata"]["retrieved_count"] == 3 + assert round_history[-1]["metadata"]["accumulated_count"] == 2 + assert "expansion one" in result.reformulated + assert "expansion two" in result.reformulated + assert "original query" in llm.messages[0][0]["content"] + assert "Initial evidence passage one" in llm.messages[0][0]["content"] + assert "Fresh support passage one" in llm.messages[1][0]["content"] From e38bee052b9e543352a950134a2756aab8666676 Mon Sep 17 00:00:00 2001 From: haisonle001 Date: Mon, 30 Mar 2026 20:18:25 -0400 Subject: [PATCH 2/2] Add ThinkQE reformulation method --- querygym/__main__.py | 2 +- querygym/adapters/__init__.py | 1 - querygym/adapters/pyserini_adapter.py | 109 +++++----- querygym/adapters/pyterrier_adapter.py | 116 +++++----- querygym/cli.py | 289 
+++++++++++++++++-------- querygym/core/base.py | 155 +++++++------ querygym/core/llm.py | 23 +- querygym/core/prompts.py | 71 +++--- querygym/core/registry.py | 2 + querygym/core/searcher.py | 44 ++-- querygym/core/searcher_wrappers.py | 164 +++++++------- querygym/core/utils.py | 6 +- querygym/data/dataloader.py | 218 ++++++++----------- querygym/loaders/beir.py | 88 +++----- querygym/loaders/msmarco.py | 73 +++---- querygym/methods/csqe.py | 59 ++--- querygym/methods/genqr.py | 21 +- querygym/methods/genqr_ensemble.py | 64 +++--- querygym/methods/lamer.py | 25 ++- querygym/methods/mugi.py | 50 +++-- querygym/methods/qa_expand.py | 111 +++++----- querygym/methods/query2doc.py | 145 +++++++------ querygym/methods/query2e.py | 215 +++++++++--------- tests/test_cli_script_gen.py | 5 +- tests/test_dataloader.py | 1 + tests/test_methods_genqr.py | 6 +- tests/test_prompts.py | 1 + 27 files changed, 1096 insertions(+), 968 deletions(-) diff --git a/querygym/__main__.py b/querygym/__main__.py index 4a0fdaf..f4b201c 100644 --- a/querygym/__main__.py +++ b/querygym/__main__.py @@ -1,6 +1,6 @@ """Entry point for running queryGym as a module: python -m queryGym""" + from .cli import app if __name__ == "__main__": app() - diff --git a/querygym/adapters/__init__.py b/querygym/adapters/__init__.py index e9cf8bf..da4ea89 100644 --- a/querygym/adapters/__init__.py +++ b/querygym/adapters/__init__.py @@ -13,4 +13,3 @@ "PyseriniSearcher", "PyTerrierSearcher", ] - diff --git a/querygym/adapters/pyserini_adapter.py b/querygym/adapters/pyserini_adapter.py index 101c4b1..104bca1 100644 --- a/querygym/adapters/pyserini_adapter.py +++ b/querygym/adapters/pyserini_adapter.py @@ -15,20 +15,28 @@ class PyseriniSearcher(BaseSearcher): """ Pyserini adapter implementing the BaseSearcher interface. - + This adapter wraps Pyserini's LuceneSearcher and LuceneImpactSearcher to provide a standardized interface for querygym. 
""" - - def __init__(self, index: str, searcher_type: str = "bm25", - encoder: Optional[str] = None, min_idf: float = 0, - k1: Optional[float] = None, b: Optional[float] = None, - rm3: bool = False, rocchio: bool = False, - rocchio_use_negative: bool = False, - answer_key: str = "contents", **kwargs): + + def __init__( + self, + index: str, + searcher_type: str = "bm25", + encoder: Optional[str] = None, + min_idf: float = 0, + k1: Optional[float] = None, + b: Optional[float] = None, + rm3: bool = False, + rocchio: bool = False, + rocchio_use_negative: bool = False, + answer_key: str = "contents", + **kwargs, + ): """ Initialize Pyserini searcher. - + Args: index: Path to Lucene index or prebuilt index name searcher_type: Type of searcher ("bm25" or "impact") @@ -45,8 +53,10 @@ def __init__(self, index: str, searcher_type: str = "bm25", try: from pyserini.search.lucene import LuceneSearcher, LuceneImpactSearcher except ImportError: - raise ImportError("Pyserini is required for PyseriniSearcher. Install with: pip install pyserini") - + raise ImportError( + "Pyserini is required for PyseriniSearcher. 
Install with: pip install pyserini" + ) + self.index = index self.searcher_type = searcher_type self.answer_key = answer_key @@ -61,9 +71,9 @@ def __init__(self, index: str, searcher_type: str = "bm25", "rm3": rm3, "rocchio": rocchio, "rocchio_use_negative": rocchio_use_negative, - "answer_key": answer_key + "answer_key": answer_key, } - + # Initialize searcher based on type if searcher_type == "impact": if os.path.exists(index): @@ -75,113 +85,108 @@ def __init__(self, index: str, searcher_type: str = "bm25", self.searcher = LuceneSearcher(index) else: self.searcher = LuceneSearcher.from_prebuilt_index(index) - + # Set BM25 parameters if specified if k1 is not None and b is not None: self.searcher.set_bm25(k1, b) else: # Auto-set based on known indices self._set_auto_bm25_params() - + # Add RM3 if requested if rm3: self.searcher.set_rm3() - + # Add Rocchio if requested if rocchio: if rocchio_use_negative: self.searcher.set_rocchio(gamma=0.15, use_negative=True) else: self.searcher.set_rocchio() - + def _set_auto_bm25_params(self): """Set BM25 parameters based on index type.""" - if 'msmarco-passage' in self.index: + if "msmarco-passage" in self.index: self.searcher.set_bm25(0.82, 0.68) - elif 'msmarco-doc' in self.index: + elif "msmarco-doc" in self.index: self.searcher.set_bm25(4.46, 0.82) - + def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]: """Search for documents using a single query.""" hits = self.searcher.search(query, k) return self._process_hits(hits) - - def batch_search(self, queries: List[str], k: int = 10, - num_threads: int = 1, **kwargs) -> List[List[SearchHit]]: + + def batch_search( + self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs + ) -> List[List[SearchHit]]: """Search for documents using multiple queries in batch.""" pseudo_batch_topic_ids = [str(idx) for idx, _ in enumerate(queries)] - results = self.searcher.batch_search( - queries, pseudo_batch_topic_ids, k, num_threads - ) - + results = 
self.searcher.batch_search(queries, pseudo_batch_topic_ids, k, num_threads) + # Convert results to list in original order batch_results = [results[id_] for id_ in pseudo_batch_topic_ids] - + # Process each result processed_results = [] for hits in batch_results: processed_results.append(self._process_hits(hits)) - + return processed_results - + def _process_hits(self, hits) -> List[SearchHit]: """Process hits and extract content.""" search_hits = [] - + for hit in hits: try: # Extract content using hit.lucene_document.get('raw') - raw_content = hit.lucene_document.get('raw') + raw_content = hit.lucene_document.get("raw") if raw_content: raw_df = json.loads(raw_content) text_list = [raw_df[k] for k in self.answer_key.split("|") if raw_df.get(k)] content = "\t".join(text_list) else: # Fallback: try to get content from other fields - content = getattr(hit, 'contents', '') or getattr(hit, 'text', '') or str(hit.docid) - + content = ( + getattr(hit, "contents", "") or getattr(hit, "text", "") or str(hit.docid) + ) + search_hit = SearchHit( docid=hit.docid, score=hit.score, content=content, - metadata={ - "raw_content": raw_content, - "answer_key": self.answer_key - } + metadata={"raw_content": raw_content, "answer_key": self.answer_key}, ) search_hits.append(search_hit) - + except (json.JSONDecodeError, KeyError, AttributeError) as e: # Fallback for malformed or missing content search_hit = SearchHit( - docid=getattr(hit, 'docid', 'unknown'), - score=getattr(hit, 'score', 0.0), - content=str(getattr(hit, 'docid', 'unknown')), - metadata={ - "error": str(e), - "fallback": True - } + docid=getattr(hit, "docid", "unknown"), + score=getattr(hit, "score", 0.0), + content=str(getattr(hit, "docid", "unknown")), + metadata={"error": str(e), "fallback": True}, ) search_hits.append(search_hit) - + return search_hits - + def get_searcher_info(self) -> Dict[str, Any]: """Get information about the searcher configuration.""" return self._searcher_info.copy() - + def configure(self, 
**kwargs) -> None: """Configure searcher parameters.""" # Update searcher info with new configuration self._searcher_info.update(kwargs) - + # Apply configuration if possible if "k1" in kwargs and "b" in kwargs: self.searcher.set_bm25(kwargs["k1"], kwargs["b"]) - + if "rm3" in kwargs and kwargs["rm3"]: self.searcher.set_rm3() - + if "rocchio" in kwargs and kwargs["rocchio"]: if kwargs.get("rocchio_use_negative", False): self.searcher.set_rocchio(gamma=0.15, use_negative=True) diff --git a/querygym/adapters/pyterrier_adapter.py b/querygym/adapters/pyterrier_adapter.py index 5d4a7d6..6d86c17 100644 --- a/querygym/adapters/pyterrier_adapter.py +++ b/querygym/adapters/pyterrier_adapter.py @@ -14,17 +14,21 @@ class PyTerrierSearcher(BaseSearcher): """ PyTerrier adapter implementing the BaseSearcher interface. - + This adapter wraps PyTerrier's search functionality to provide a standardized interface for querygym. """ - - def __init__(self, index_path: Optional[str] = None, - index_name: Optional[str] = None, - searcher_type: str = "bm25", **kwargs): + + def __init__( + self, + index_path: Optional[str] = None, + index_name: Optional[str] = None, + searcher_type: str = "bm25", + **kwargs, + ): """ Initialize PyTerrier searcher. - + Args: index_path: Path to PyTerrier index index_name: Name of prebuilt PyTerrier index @@ -34,20 +38,22 @@ def __init__(self, index_path: Optional[str] = None, try: import pyterrier as pt except ImportError: - raise ImportError("PyTerrier is required for PyTerrierSearcher. Install with: pip install python-terrier") - + raise ImportError( + "PyTerrier is required for PyTerrierSearcher. 
Install with: pip install python-terrier" + ) + self.searcher_type = searcher_type self._searcher_info = { "name": "PyTerrierSearcher", "type": searcher_type, "index_path": index_path, - "index_name": index_name + "index_name": index_name, } - + # Initialize PyTerrier if not already done if not pt.started(): pt.init() - + # Load index if index_path: self.index = pt.IndexRef.of(index_path) @@ -59,14 +65,14 @@ def __init__(self, index_path: Optional[str] = None, raise ValueError(f"Could not load prebuilt index '{index_name}': {e}") else: raise ValueError("Either index_path or index_name must be provided") - + # Create searcher based on type self.searcher = self._create_searcher(searcher_type, **kwargs) - + def _create_searcher(self, searcher_type: str, **kwargs): """Create the appropriate searcher based on type.""" import pyterrier as pt - + # Create a proper PyTerrier Retriever, not a pipeline component if searcher_type == "bm25": return pt.terrier.Retriever(self.index, wmodel="BM25", **kwargs) @@ -81,94 +87,94 @@ def _create_searcher(self, searcher_type: str, **kwargs): else: # Default to BM25 return pt.terrier.Retriever(self.index, wmodel="BM25", **kwargs) - + def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]: """Search for documents using a single query.""" # Use PyTerrier's search method which returns a DataFrame search_results = self.searcher.search(query) - + # Limit to top-k results search_results = search_results.head(k) - + return self._process_results(search_results) - - def batch_search(self, queries: List[str], k: int = 10, - num_threads: int = 1, **kwargs) -> List[List[SearchHit]]: + + def batch_search( + self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs + ) -> List[List[SearchHit]]: """Search for documents using multiple queries in batch.""" # Create query DataFrame for batch processing - query_df = pd.DataFrame({ - 'qid': [str(i) for i in range(len(queries))], - 'query': queries - }) - + query_df = 
pd.DataFrame({"qid": [str(i) for i in range(len(queries))], "query": queries}) + # Use PyTerrier's transform method for batch processing search_results = self.searcher.transform(query_df) - + # Group results by query ID batch_results = [] for qid in range(len(queries)): - query_results = search_results[search_results['qid'] == str(qid)].head(k) + query_results = search_results[search_results["qid"] == str(qid)].head(k) batch_results.append(self._process_results(query_results)) - + return batch_results - + def _process_results(self, results_df: pd.DataFrame) -> List[SearchHit]: """Process PyTerrier results into SearchHit objects.""" search_hits = [] - + if results_df.empty: return search_hits - + for _, row in results_df.iterrows(): try: # Handle different possible column names - docid = str(row.get('docno', row.get('docid', ''))) - score = float(row.get('score', 0.0)) - + docid = str(row.get("docno", row.get("docid", ""))) + score = float(row.get("score", 0.0)) + # Try different content field names - content = (row.get('text', '') or - row.get('content', '') or - row.get('body', '') or - str(docid)) - + content = ( + row.get("text", "") + or row.get("content", "") + or row.get("body", "") + or str(docid) + ) + search_hit = SearchHit( docid=docid, score=score, content=str(content), metadata={ - 'qid': str(row.get('qid', '')), - 'docno': str(row.get('docno', '')), - 'searcher_type': self.searcher_type, - 'available_columns': list(row.index) - } + "qid": str(row.get("qid", "")), + "docno": str(row.get("docno", "")), + "searcher_type": self.searcher_type, + "available_columns": list(row.index), + }, ) search_hits.append(search_hit) - + except Exception as e: # Fallback for malformed data search_hit = SearchHit( - docid=str(row.get('docno', 'unknown')), + docid=str(row.get("docno", "unknown")), score=0.0, - content=str(row.get('docno', 'unknown')), + content=str(row.get("docno", "unknown")), metadata={ - 'error': str(e), - 'fallback': True, - 'searcher_type': 
self.searcher_type - } + "error": str(e), + "fallback": True, + "searcher_type": self.searcher_type, + }, ) search_hits.append(search_hit) - + return search_hits - + def get_searcher_info(self) -> Dict[str, Any]: """Get information about the searcher configuration.""" return self._searcher_info.copy() - + def configure(self, **kwargs) -> None: """Configure searcher parameters.""" # Update searcher info self._searcher_info.update(kwargs) - + # Recreate searcher with new parameters self.searcher = self._create_searcher(self.searcher_type, **kwargs) diff --git a/querygym/cli.py b/querygym/cli.py index d9a47db..2cc98b8 100644 --- a/querygym/cli.py +++ b/querygym/cli.py @@ -7,10 +7,18 @@ from .core.runner import run_method from .core.prompts import PromptBank from .data.dataloader import UnifiedQuerySource + app = typer.Typer(help="querygym Toolkit CLI") -def build_script_lines(index_path:str, topics:str, run:str, qrels:Optional[str]=None, - bm25:bool=True, extra:str="") -> list[str]: + +def build_script_lines( + index_path: str, + topics: str, + run: str, + qrels: Optional[str] = None, + bm25: bool = True, + extra: str = "", +) -> list[str]: lines = [ "#!/usr/bin/env bash", "set -euo pipefail", @@ -28,53 +36,78 @@ def build_script_lines(index_path:str, topics:str, run:str, qrels:Optional[str]= if qrels: lines += [ "# trec_eval", - f"trec_eval -m map -m P.10 -m ndcg_cut.10 {qrels} {run} | tee {run}.eval.txt" + f"trec_eval -m map -m P.10 -m ndcg_cut.10 {qrels} {run} | tee {run}.eval.txt", ] return lines + @app.command() -def run(method: str = typer.Option(...), - queries_tsv: Path = typer.Option(..., "--queries-tsv", exists=True), - output_tsv: Path = typer.Option(..., "--output-tsv"), - cfg_path: Optional[Path] = typer.Option(None, "--cfg-path"), - prompt_bank: Path = typer.Option(Path(__file__).with_name("prompt_bank.yaml"), "--prompt-bank"), - ctx_jsonl: Optional[Path] = typer.Option(None, "--ctx-jsonl", help="Optional contexts JSONL"), - output_format: str = 
typer.Option("both", "--output-format", help="Output format: 'concat', 'plain', or 'both' (default: both)"), - parallel: Optional[bool] = typer.Option(None, "--parallel", help="Enable parallel generation for methods like MuGI"), - mode: Optional[str] = typer.Option(None, "--mode", help="Mode for methods: 'zs' (zero-shot) or 'fs' (few-shot)"), - # Few-shot data loading options - dataset_type: Optional[str] = typer.Option(None, "--dataset-type", help="Dataset type: 'msmarco', 'beir', or 'generic'"), - collection_path: Optional[Path] = typer.Option(None, "--collection-path", help="Path to collection file (TSV for MS MARCO/generic)"), - train_queries_path: Optional[Path] = typer.Option(None, "--train-queries-path", help="Path to training queries file"), - train_qrels_path: Optional[Path] = typer.Option(None, "--train-qrels-path", help="Path to training qrels file"), - beir_data_dir: Optional[Path] = typer.Option(None, "--beir-data-dir", help="Path to BEIR dataset directory"), - train_split: Optional[str] = typer.Option(None, "--train-split", help="BEIR train split: 'train' or 'dev'"), - num_examples: Optional[int] = typer.Option(None, "--num-examples", help="Number of few-shot examples"), +def run( + method: str = typer.Option(...), + queries_tsv: Path = typer.Option(..., "--queries-tsv", exists=True), + output_tsv: Path = typer.Option(..., "--output-tsv"), + cfg_path: Optional[Path] = typer.Option(None, "--cfg-path"), + prompt_bank: Path = typer.Option(Path(__file__).with_name("prompt_bank.yaml"), "--prompt-bank"), + ctx_jsonl: Optional[Path] = typer.Option(None, "--ctx-jsonl", help="Optional contexts JSONL"), + output_format: str = typer.Option( + "both", + "--output-format", + help="Output format: 'concat', 'plain', or 'both' (default: both)", + ), + parallel: Optional[bool] = typer.Option( + None, "--parallel", help="Enable parallel generation for methods like MuGI" + ), + mode: Optional[str] = typer.Option( + None, "--mode", help="Mode for methods: 'zs' 
(zero-shot) or 'fs' (few-shot)" + ), + # Few-shot data loading options + dataset_type: Optional[str] = typer.Option( + None, "--dataset-type", help="Dataset type: 'msmarco', 'beir', or 'generic'" + ), + collection_path: Optional[Path] = typer.Option( + None, "--collection-path", help="Path to collection file (TSV for MS MARCO/generic)" + ), + train_queries_path: Optional[Path] = typer.Option( + None, "--train-queries-path", help="Path to training queries file" + ), + train_qrels_path: Optional[Path] = typer.Option( + None, "--train-qrels-path", help="Path to training qrels file" + ), + beir_data_dir: Optional[Path] = typer.Option( + None, "--beir-data-dir", help="Path to BEIR dataset directory" + ), + train_split: Optional[str] = typer.Option( + None, "--train-split", help="BEIR train split: 'train' or 'dev'" + ), + num_examples: Optional[int] = typer.Option( + None, "--num-examples", help="Number of few-shot examples" + ), ): import yaml import os import re - + def expand_env_vars(text): """Expand environment variables in YAML content""" + def replace_env_var(match): var_expr = match.group(1) - if ':-' in var_expr: - var_name, default_value = var_expr.split(':-', 1) + if ":-" in var_expr: + var_name, default_value = var_expr.split(":-", 1) return os.getenv(var_name, default_value) else: - return os.getenv(var_expr, '') - - return re.sub(r'\$\{([^}]+)\}', replace_env_var, text) - + return os.getenv(var_expr, "") + + return re.sub(r"\$\{([^}]+)\}", replace_env_var, text) + # Default to defaults.yaml if no config path provided if cfg_path is None: cfg_path = Path(__file__).parent / "config" / "defaults.yaml" - + yaml_content = cfg_path.read_text() expanded_content = expand_env_vars(yaml_content) cfg = yaml.safe_load(expanded_content) - + # Override parameters if provided via CLI params = cfg.get("params", {}) if parallel is not None: @@ -95,7 +128,7 @@ def replace_env_var(match): params["train_split"] = train_split if num_examples is not None: params["num_examples"] 
= num_examples - + # Handle context/examples loading based on method ctx_map = None if ctx_jsonl: @@ -104,6 +137,7 @@ def replace_env_var(match): # Auto-set mode to "fs" if not explicitly set to "zs" if mode not in ["zs", "zeroshot"]: from .data.dataloader import DataLoader + examples = DataLoader.load_examples(ctx_jsonl) params["examples"] = examples # Auto-set mode to fs if not already set @@ -116,33 +150,46 @@ def replace_env_var(match): else: # For other methods (LameR, CSQE): load as per-query contexts from .data.dataloader import UnifiedContextSource + ctx_src = UnifiedContextSource(mode="file", path=ctx_jsonl) - ctx_map = ctx_src.load(list(UnifiedQuerySource(backend="local", format="tsv", path=queries_tsv).iter())) - - mc = MethodConfig(name=method, params=params, llm=cfg["llm"], - seed=cfg.get("seed",42), retries=cfg.get("retries",2)) + ctx_map = ctx_src.load( + list(UnifiedQuerySource(backend="local", format="tsv", path=queries_tsv).iter()) + ) + + mc = MethodConfig( + name=method, + params=params, + llm=cfg["llm"], + seed=cfg.get("seed", 42), + retries=cfg.get("retries", 2), + ) src = UnifiedQuerySource(backend="local", format="tsv", path=queries_tsv) queries = list(src.iter()) - + # Pass ctx_map as-is (None if not provided, dict if provided) - results = run_method(method_name=method, cfg=mc, queries=queries, - prompt_bank_path=str(prompt_bank), ctx_map=ctx_map) - + results = run_method( + method_name=method, + cfg=mc, + queries=queries, + prompt_bank_path=str(prompt_bank), + ctx_map=ctx_map, + ) + # Show progress summary typer.echo(f"Processed {len(results)} queries with {method}") - + def write_concat_format(output_path): """Write concatenated format to file""" with open(output_path, "w", encoding="utf-8") as f: - w = csv.writer(f, delimiter="\t", quoting=csv.QUOTE_NONE, escapechar='\\') + w = csv.writer(f, delimiter="\t", quoting=csv.QUOTE_NONE, escapechar="\\") for r in results: w.writerow([r.qid, r.reformulated]) - + def 
write_plain_format(output_path): """Write plain format to file""" with open(output_path, "w", encoding="utf-8") as f: - w = csv.writer(f, delimiter="\t", quoting=csv.QUOTE_NONE, escapechar='\\') - + w = csv.writer(f, delimiter="\t", quoting=csv.QUOTE_NONE, escapechar="\\") + # Write header row based on method if method == "lamer": # LameR: qid \t passage_1 \t passage_2 \t ... \t passage_n @@ -157,11 +204,13 @@ def write_plain_format(output_path): reformulated = results[0].reformulated parts = reformulated.replace(original, "").strip().split() num_passages = len(parts) - + header = ["qid"] + [f"passage_{i+1}" for i in range(num_passages)] w.writerow(header) else: - w.writerow(["qid", "passage_1", "passage_2", "passage_3", "passage_4", "passage_5"]) + w.writerow( + ["qid", "passage_1", "passage_2", "passage_3", "passage_4", "passage_5"] + ) elif method == "query2doc": # Query2Doc: qid \t passage w.writerow(["qid", "passage"]) @@ -177,7 +226,16 @@ def write_plain_format(output_path): header = ["qid"] + [f"reformulation_{i+1}" for i in range(num_reformulations)] w.writerow(header) else: - w.writerow(["qid", "reformulation_1", "reformulation_2", "reformulation_3", "reformulation_4", "reformulation_5"]) + w.writerow( + [ + "qid", + "reformulation_1", + "reformulation_2", + "reformulation_3", + "reformulation_4", + "reformulation_5", + ] + ) elif method == "qa_expand": # QA-Expand: qid \t final_refined_query w.writerow(["qid", "refined_query"]) @@ -196,7 +254,21 @@ def write_plain_format(output_path): header = ["qid"] + [f"keyword_{i+1}" for i in range(num_keywords)] w.writerow(header) else: - w.writerow(["qid", "keyword_1", "keyword_2", "keyword_3", "keyword_4", "keyword_5", "keyword_6", "keyword_7", "keyword_8", "keyword_9", "keyword_10"]) + w.writerow( + [ + "qid", + "keyword_1", + "keyword_2", + "keyword_3", + "keyword_4", + "keyword_5", + "keyword_6", + "keyword_7", + "keyword_8", + "keyword_9", + "keyword_10", + ] + ) elif method == "query2e": # Query2E: qid \t 
keyword_1 \t keyword_2 \t ... \t keyword_n # Determine number of keywords from first result @@ -213,7 +285,7 @@ def write_plain_format(output_path): else: # Default fallback w.writerow(["qid", "generated_content"]) - + # Write data rows for r in results: if method == "lamer": @@ -284,33 +356,34 @@ def write_plain_format(output_path): plain_output = r.reformulated.replace(r.original, "").strip() cleaned_output = clean_text(plain_output) w.writerow([r.qid, cleaned_output]) - + def clean_text(text): """Clean text by removing newlines, extra whitespace, and other formatting issues""" if not text: return "" - + # Convert to string if not already text = str(text) - + # Remove newlines and carriage returns - text = text.replace('\n', ' ').replace('\r', ' ') - + text = text.replace("\n", " ").replace("\r", " ") + # Remove backslashes that might be escape characters - text = text.replace('\\', ' ') - + text = text.replace("\\", " ") + # Replace multiple spaces with single space import re - text = re.sub(r'\s+', ' ', text) - + + text = re.sub(r"\s+", " ", text) + # Strip leading/trailing whitespace text = text.strip() - + # Remove quotes from beginning and end text = text.strip('"').strip("'") - + return text - + # Generate output files based on format if output_format == "concat": write_concat_format(output_tsv) @@ -323,71 +396,101 @@ def clean_text(text): output_dir = output_tsv.parent output_stem = output_tsv.stem output_suffix = output_tsv.suffix - + concat_file = output_dir / f"{output_stem}_concat{output_suffix}" plain_file = output_dir / f"{output_stem}_plain{output_suffix}" - + write_concat_format(concat_file) write_plain_format(plain_file) - + typer.echo(f"Wrote both formats:") typer.echo(f" Concat: {concat_file}") typer.echo(f" Plain: {plain_file}") else: - raise ValueError(f"Invalid output_format: {output_format}. Must be 'concat', 'plain', or 'both'") + raise ValueError( + f"Invalid output_format: {output_format}. 
Must be 'concat', 'plain', or 'both'" + ) + @app.command("data-to-tsv") -def data_to_tsv(backend: str = typer.Option(...), - source: str = typer.Option(...), - out: Path = typer.Option(..., "--out"), - path: Optional[Path] = typer.Option(None, help="Local TSV/JSONL path"), - format: Optional[str] = typer.Option(None, help="tsv|jsonl"), - tsv_qid_col: int = 0, tsv_query_col: int = 1, - jsonl_qid_key: str = "qid", jsonl_query_key: str = "query", - msmarco_queries_tsv: Optional[Path] = None, - hf_name: Optional[str] = None, hf_config: Optional[str] = None, - split: str = "dev", hf_qid_key: str = "query_id", hf_query_key: str = "query", - beir_root: Optional[Path] = None, beir_name: Optional[str] = None, +def data_to_tsv( + backend: str = typer.Option(...), + source: str = typer.Option(...), + out: Path = typer.Option(..., "--out"), + path: Optional[Path] = typer.Option(None, help="Local TSV/JSONL path"), + format: Optional[str] = typer.Option(None, help="tsv|jsonl"), + tsv_qid_col: int = 0, + tsv_query_col: int = 1, + jsonl_qid_key: str = "qid", + jsonl_query_key: str = "query", + msmarco_queries_tsv: Optional[Path] = None, + hf_name: Optional[str] = None, + hf_config: Optional[str] = None, + split: str = "dev", + hf_qid_key: str = "query_id", + hf_query_key: str = "query", + beir_root: Optional[Path] = None, + beir_name: Optional[str] = None, ): src = UnifiedQuerySource( - backend=backend if backend in ("local","msmarco","beir") else None, - source=source if source in ("file","hf","beir") else None, - path=path, format=format, tsv_qid_col=tsv_qid_col, tsv_query_col=tsv_query_col, - jsonl_qid_key=jsonl_qid_key, jsonl_query_key=jsonl_query_key, + backend=backend if backend in ("local", "msmarco", "beir") else None, + source=source if source in ("file", "hf", "beir") else None, + path=path, + format=format, + tsv_qid_col=tsv_qid_col, + tsv_query_col=tsv_query_col, + jsonl_qid_key=jsonl_qid_key, + jsonl_query_key=jsonl_query_key, msmarco_queries_tsv=msmarco_queries_tsv, 
- hf_name=hf_name, hf_config=hf_config, split=split, - hf_qid_key=hf_qid_key, hf_query_key=hf_query_key, - beir_root=beir_root, beir_name=beir_name, + hf_name=hf_name, + hf_config=hf_config, + split=split, + hf_qid_key=hf_qid_key, + hf_query_key=hf_query_key, + beir_root=beir_root, + beir_name=beir_name, ) UnifiedQuerySource.export_to_tsv(src.iter(), out) typer.echo(f"Wrote {out}") + @app.command("prompts-list") def prompts_list(prompt_bank: Path = typer.Option(Path(__file__).with_name("prompt_bank.yaml"))): pb = PromptBank(prompt_bank) for pid in pb.list(): typer.echo(pid) + @app.command("prompts-show") -def prompts_show(prompt_id: str, - prompt_bank: Path = typer.Option(Path(__file__).with_name("prompt_bank.yaml"))): +def prompts_show( + prompt_id: str, prompt_bank: Path = typer.Option(Path(__file__).with_name("prompt_bank.yaml")) +): pb = PromptBank(prompt_bank) meta = pb.get_meta(prompt_id) typer.echo(json.dumps(meta, indent=2)) + @app.command("script-gen") -def script_gen(index_path: Path = typer.Option(..., exists=False), - topics: Path = typer.Option(..., help="TSV queries"), - run: Path = typer.Option(..., help="Output run file"), - output_bash: Path = typer.Option(Path("run_retrieval.sh")), - qrels: Optional[Path] = None, - bm25: bool = True, - extra: str = typer.Option("", help="Extra Pyserini CLI flags, concatenated") +def script_gen( + index_path: Path = typer.Option(..., exists=False), + topics: Path = typer.Option(..., help="TSV queries"), + run: Path = typer.Option(..., help="Output run file"), + output_bash: Path = typer.Option(Path("run_retrieval.sh")), + qrels: Optional[Path] = None, + bm25: bool = True, + extra: str = typer.Option("", help="Extra Pyserini CLI flags, concatenated"), ): - lines = build_script_lines(str(index_path), str(topics), str(run), - str(qrels) if qrels else None, bm25=bm25, extra=extra) + lines = build_script_lines( + str(index_path), + str(topics), + str(run), + str(qrels) if qrels else None, + bm25=bm25, + extra=extra, + ) 
output_bash.write_text("\n".join(lines)) typer.echo(f"Generated {output_bash}") + if __name__ == "__main__": app() diff --git a/querygym/core/base.py b/querygym/core/base.py index 5de92dc..f65f245 100644 --- a/querygym/core/base.py +++ b/querygym/core/base.py @@ -3,30 +3,33 @@ from typing import Dict, List, Optional, Any from tqdm import tqdm + @dataclass class QueryItem: """Represents a query with its ID and text. - + Attributes: qid: Unique query identifier text: Query text content - + Example: >>> query = QueryItem(qid="q1", text="what causes diabetes") """ + qid: str text: str + @dataclass class ReformulationResult: """Result of a query reformulation operation. - + Attributes: qid: Query identifier original: Original query text reformulated: Reformulated query text metadata: Additional metadata about the reformulation - + Example: >>> result = ReformulationResult( ... qid="q1", @@ -34,22 +37,24 @@ class ReformulationResult: ... reformulated="diabetes causes symptoms treatment" ... ) """ + qid: str original: str reformulated: str metadata: Dict[str, Any] = field(default_factory=dict) + @dataclass class MethodConfig: """Configuration for a reformulation method. - + Attributes: name: Method name (e.g., "genqr", "query2doc") params: Method-specific parameters llm: LLM configuration (model, temperature, max_tokens, etc.) seed: Random seed for reproducibility (default: 42) retries: Number of retries on LLM failure (default: 2) - + Example: >>> config = MethodConfig( ... name="genqr_ensemble", @@ -58,19 +63,21 @@ class MethodConfig: ... seed=42 ... ) """ + name: str params: Dict[str, Any] llm: Dict[str, Any] seed: int = 42 retries: int = 2 + class BaseReformulator: """Base class for all query reformulation methods. - + All reformulation methods inherit from this class and implement the `reformulate` method. This class provides common functionality for concatenating results and batch processing. 
- + Attributes: VERSION: Method version string REQUIRES_CONTEXT: Whether method requires retrieved contexts @@ -80,16 +87,17 @@ class BaseReformulator: llm: LLM client instance prompts: Prompt bank instance """ + VERSION = "0.1" REQUIRES_CONTEXT = False - + # Concatenation strategy configuration CONCATENATION_STRATEGY = "query_repeat_plus_generated" # Default strategy DEFAULT_QUERY_REPEATS = 5 # Default number of query repetitions def __init__(self, cfg: MethodConfig, llm_client, prompt_resolver): """Initialize the reformulator. - + Args: cfg: Method configuration llm_client: LLM client for generating reformulations @@ -99,63 +107,69 @@ def __init__(self, cfg: MethodConfig, llm_client, prompt_resolver): self.llm = llm_client self.prompts = prompt_resolver - def reformulate(self, q: QueryItem, contexts: Optional[List[str]] = None) -> ReformulationResult: + def reformulate( + self, q: QueryItem, contexts: Optional[List[str]] = None + ) -> ReformulationResult: """Reformulate a single query. - + Args: q: Query to reformulate contexts: Optional retrieved contexts (required for some methods) - + Returns: Reformulation result - + Raises: NotImplementedError: Must be implemented by subclasses """ raise NotImplementedError - def concatenate_result(self, original_query: str, generated_content: str | List[str], - query_repeats: Optional[int] = None, - content_separator: str = " ") -> str: + def concatenate_result( + self, + original_query: str, + generated_content: str | List[str], + query_repeats: Optional[int] = None, + content_separator: str = " ", + ) -> str: """ Concatenate the original query with generated content based on the method's strategy. Matches exact patterns from baseline implementations. 
- + Args: original_query: The original query text generated_content: The content generated by the LLM (string or list of strings) query_repeats: Number of times to repeat the original query (overrides default) content_separator: Separator to use between multiple content pieces - + Returns: Concatenated reformulated query """ if query_repeats is None: query_repeats = self.cfg.params.get("repeat_query_weight", self.DEFAULT_QUERY_REPEATS) - + strategy = self.cfg.params.get("concatenation_strategy", self.CONCATENATION_STRATEGY) - + # Handle multiple content pieces if isinstance(generated_content, list): generated_text = content_separator.join(generated_content) else: generated_text = generated_content - + if strategy == "query_repeat_plus_generated": # Pattern: (Q × N) + generated_content # Used by: GenQR, GenQREnsemble, QA-Expand result = " ".join([original_query] * query_repeats + [generated_text]) - + elif strategy == "query_plus_generated": # Pattern: Q + generated_content (single repetition) # Used by: Query2E result = " ".join([original_query, generated_text]) - + elif strategy == "generated_only": # Pattern: Only generated content (no original query) # Used by: Query2Doc result = generated_text - + elif strategy == "adaptive_query_repeat_plus_generated": # Pattern: (query + ' ') * repetition_times + generated_content # Uses adaptive repetition based on content length @@ -164,7 +178,7 @@ def concatenate_result(self, original_query: str, generated_content: str | List[ repetition_times = (len(generated_text) // len(original_query)) // adaptive_times repetition_times = max(1, repetition_times) # At least 1 repetition result = (original_query + " ") * repetition_times + generated_text - + elif strategy == "interleaved_query_content": # Pattern: q + a1 + q + a2 + q + a3 + ... 
(interleaving) # Used by: LameR @@ -177,32 +191,34 @@ def concatenate_result(self, original_query: str, generated_content: str | List[ else: # Single content: query + content result = " ".join([original_query, generated_text]) - + elif strategy == "generated_plus_query_repeat": # Pattern: generated_content + (Q × N) # Used by: Future methods that want generated content first result = " ".join([generated_text] + [original_query] * query_repeats) - + elif strategy == "query_sandwich": # Pattern: Q + generated_content + Q # Used by: Future methods that want query wrapping result = " ".join([original_query, generated_text, original_query]) - + else: # Default fallback to query_repeat_plus_generated pattern result = " ".join([original_query] * query_repeats + [generated_text]) - + # Clean up any newlines and quotes in the final result (for non-newline strategies) - result = result.replace('\n', ' ').replace('\r', ' ').strip() + result = result.replace("\n", " ").replace("\r", " ").strip() # Remove quotes from beginning and end result = result.strip('"').strip("'") return result - def retrieve_contexts_batch(self, queries: List[QueryItem], retrieval_params: Optional[Dict[str, Any]] = None) -> Dict[str, List[str]]: + def retrieve_contexts_batch( + self, queries: List[QueryItem], retrieval_params: Optional[Dict[str, Any]] = None + ) -> Dict[str, List[str]]: """Retrieve contexts for a batch of queries via any searcher implementing BaseSearcher. 
- + This method supports both searcher adapters (via registry) and wrapped searchers: - + **Using Adapters (via searcher_type):** retrieval_params = { "searcher_type": "pyserini", # or "pyterrier" @@ -210,28 +226,28 @@ def retrieve_contexts_batch(self, queries: List[QueryItem], retrieval_params: Op "k": 10, "threads": 16, } - + **Using Wrapped Searchers:** # Wrap your existing searcher from pyserini.search.lucene import LuceneSearcher my_searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage') wrapped = wrap_pyserini_searcher(my_searcher) - + retrieval_params = { "searcher": wrapped, # Pass wrapped searcher directly "k": 10, "threads": 16, } - + **Using Adapter Instances Directly:** from queryGym import create_searcher searcher = create_searcher("pyserini", index="msmarco-v1-passage") - + retrieval_params = { "searcher": searcher, # Pass adapter instance directly "k": 10, } - + Args: queries: List of QueryItem objects to retrieve contexts for retrieval_params: Method-specific retrieval parameters. 
Can include: @@ -240,85 +256,90 @@ def retrieve_contexts_batch(self, queries: List[QueryItem], retrieval_params: Op - searcher_kwargs: Keyword arguments to pass to searcher constructor - k: Number of documents to retrieve per query (default: 10) - threads: Number of threads for batch search (default: 16) - + Returns: Dictionary mapping query IDs to lists of context strings """ if not retrieval_params or not queries: return {} - + # Extract retrieval configuration k = retrieval_params.get("k", 10) threads = retrieval_params.get("threads", 16) - + # Get or create searcher searcher = retrieval_params.get("searcher") - + if searcher is None: # Create searcher using registry searcher_type = retrieval_params.get("searcher_type", "pyserini") searcher_kwargs = retrieval_params.get("searcher_kwargs", {}) - + # Lazy import to avoid hard dependency at load time try: from .searcher import create_searcher + searcher = create_searcher(searcher_type, **searcher_kwargs) except Exception as e: # If searcher creation fails, return empty dict # This allows methods to work without retrieval if contexts are pre-provided return {} - + # Ensure searcher implements BaseSearcher interface from .searcher import BaseSearcher + if not isinstance(searcher, BaseSearcher): raise ValueError( - f"Searcher must implement BaseSearcher interface. " - f"Got type: {type(searcher)}" + f"Searcher must implement BaseSearcher interface. 
" f"Got type: {type(searcher)}" ) - + # Extract query texts and IDs query_texts = [q.text for q in queries] query_ids = [q.qid for q in queries] - + # Perform batch retrieval using searcher's batch_search method batch_results = searcher.batch_search(query_texts, k=k, num_threads=threads) - + # Convert SearchHit results to dictionary format (qid -> list of context strings) contexts = {} for qid, search_hits in zip(query_ids, batch_results): # Extract content from each SearchHit contexts[qid] = [hit.content for hit in search_hits] - + return contexts - def retrieve_contexts_if_needed(self, q: QueryItem, retrieval_params: Optional[Dict[str, Any]] = None) -> List[str]: + def retrieve_contexts_if_needed( + self, q: QueryItem, retrieval_params: Optional[Dict[str, Any]] = None + ) -> List[str]: """Retrieve contexts for a single query (fallback method). - + Args: q: QueryItem to retrieve contexts for retrieval_params: Method-specific retrieval parameters - + Returns: List of context strings """ contexts = self.retrieve_contexts_batch([q], retrieval_params) return contexts.get(q.qid, []) - def reformulate_batch(self, queries: List[QueryItem], ctx_map: Optional[Dict[str, List[str]]] = None) -> List[ReformulationResult]: + def reformulate_batch( + self, queries: List[QueryItem], ctx_map: Optional[Dict[str, List[str]]] = None + ) -> List[ReformulationResult]: """Reformulate multiple queries in batch. - + Processes a list of queries and returns reformulation results. Shows a progress bar during processing. If the method requires contexts and none are provided, attempts to retrieve them automatically. - + Args: queries: List of queries to reformulate ctx_map: Optional mapping of query IDs to context lists. If None and method requires contexts, will attempt retrieval. 
- + Returns: List of reformulation results, one per input query - + Example: >>> queries = [QueryItem("q1", "diabetes"), QueryItem("q2", "cancer")] >>> results = reformulator.reformulate_batch(queries) @@ -326,31 +347,31 @@ def reformulate_batch(self, queries: List[QueryItem], ctx_map: Optional[Dict[str ... print(f"{r.qid}: {r.reformulated}") """ # If no contexts provided and method requires context, do batch retrieval - if ctx_map is None and hasattr(self, 'REQUIRES_CONTEXT') and self.REQUIRES_CONTEXT: + if ctx_map is None and hasattr(self, "REQUIRES_CONTEXT") and self.REQUIRES_CONTEXT: retrieval_params = self._get_retrieval_params() if retrieval_params: ctx_map = self.retrieve_contexts_batch(queries, retrieval_params) - + out: List[ReformulationResult] = [] - + # Import tqdm for progress bar try: progress_bar = tqdm(queries, desc=f"Reformulating with {self.cfg.name}", unit="query") except ImportError: # Fallback if tqdm is not available progress_bar = queries - + for q in progress_bar: ctx = (ctx_map or {}).get(q.qid) result = self.reformulate(q, ctx) out.append(result) - + # Update progress bar description with current query info - if hasattr(progress_bar, 'set_description'): + if hasattr(progress_bar, "set_description"): progress_bar.set_description(f"Reformulating with {self.cfg.name} (QID: {q.qid})") - + return out - + def _get_retrieval_params(self) -> Optional[Dict[str, Any]]: """Get retrieval parameters for this method. Override in subclasses.""" return None diff --git a/querygym/core/llm.py b/querygym/core/llm.py index f99543c..f64f3ab 100644 --- a/querygym/core/llm.py +++ b/querygym/core/llm.py @@ -3,16 +3,17 @@ from openai import OpenAI from typing import List, Dict, Any + class OpenAICompatibleClient: """Client for OpenAI-compatible chat completion APIs. - + Supports any API that implements the OpenAI chat completions format, including OpenAI, Azure OpenAI, local models via vLLM, etc. 
- + Attributes: client: OpenAI client instance model: Model identifier to use for completions - + Example: >>> client = OpenAICompatibleClient( ... model="gpt-4", @@ -24,30 +25,32 @@ class OpenAICompatibleClient: ... {"role": "user", "content": "Hello!"} ... ]) """ - + def __init__(self, model: str, base_url: str | None = None, api_key: str | None = None): """Initialize the LLM client. - + Args: model: Model identifier (e.g., "gpt-4", "gpt-3.5-turbo") base_url: Optional base URL for the API. If None, uses OPENAI_BASE_URL environment variable or OpenAI's default. api_key: Optional API key. If None, uses OPENAI_API_KEY environment variable. """ - self.client = OpenAI(base_url=base_url or os.getenv("OPENAI_BASE_URL", None), - api_key=api_key or os.getenv("OPENAI_API_KEY")) + self.client = OpenAI( + base_url=base_url or os.getenv("OPENAI_BASE_URL", None), + api_key=api_key or os.getenv("OPENAI_API_KEY"), + ) self.model = model def chat(self, messages: List[Dict[str, str]], **kw: Any) -> str: """Send a chat completion request. - + Args: messages: List of message dicts with 'role' and 'content' keys **kw: Additional parameters (temperature, max_tokens, etc.) - + Returns: Generated text response - + Example: >>> response = client.chat( ... messages=[{"role": "user", "content": "Hello"}], diff --git a/querygym/core/prompts.py b/querygym/core/prompts.py index 3bdcde0..3099f49 100644 --- a/querygym/core/prompts.py +++ b/querygym/core/prompts.py @@ -4,10 +4,11 @@ from pathlib import Path from typing import Dict, List, Any + @dataclass class PromptSpec: """Specification for a single prompt template. - + Attributes: id: Unique prompt identifier method_family: Method family this prompt belongs to @@ -15,34 +16,36 @@ class PromptSpec: template: Template dict with 'system', 'user', 'assistant' keys meta: Additional metadata (authors, license, etc.) 
""" + id: str method_family: str version: int template: Dict[str, str] meta: Dict[str, Any] + class PromptBank: """Manages prompt templates from a YAML file. - + Loads and provides access to prompt templates with metadata. Supports rendering templates with variable substitution. - + Attributes: _by_id: Internal dict mapping prompt IDs to PromptSpec objects - + Example: >>> from pathlib import Path >>> pb = PromptBank(Path("querygym/prompt_bank.yaml")) >>> messages = pb.render("genqr_keywords", query="diabetes") >>> print(messages) """ - + def __init__(self, path: str | Path): """Initialize the prompt bank from a YAML file. - + Args: path: Path to the prompt bank YAML file - + Raises: FileNotFoundError: If the YAML file doesn't exist yaml.YAMLError: If the YAML is malformed @@ -52,54 +55,58 @@ def __init__(self, path: str | Path): for x in items: self._by_id[x["id"]] = PromptSpec( id=x["id"], - method_family=x.get("method_family",""), - version=x.get("version",1), + method_family=x.get("method_family", ""), + version=x.get("version", 1), template=x["template"], - meta={k:v for k,v in x.items() if k not in ["id","method_family","version","template"]} + meta={ + k: v + for k, v in x.items() + if k not in ["id", "method_family", "version", "template"] + }, ) def render(self, prompt_id: str, **vars) -> List[Dict[str, str]]: """Render a prompt template with variable substitution. - + Args: prompt_id: ID of the prompt to render **vars: Variables to substitute in the template (e.g., query="text") - + Returns: List of message dicts with 'role' and 'content' keys, ready for LLM chat completion - + Raises: KeyError: If prompt_id doesn't exist KeyError: If required template variables are missing - + Example: >>> messages = pb.render("genqr_keywords", query="diabetes") >>> # Returns: [{"role": "system", "content": "..."}, ...] 
""" spec = self._by_id[prompt_id] messages = [] - + # System message (if present) if "system" in spec.template: - sys = spec.template.get("system","").format(**vars) - messages.append({"role":"system","content":sys}) - + sys = spec.template.get("system", "").format(**vars) + messages.append({"role": "system", "content": sys}) + # User message (if present) if "user" in spec.template: - usr = spec.template.get("user","").format(**vars) - messages.append({"role":"user","content":usr}) - + usr = spec.template.get("user", "").format(**vars) + messages.append({"role": "user", "content": usr}) + # Assistant message (if present) - for priming responses if "assistant" in spec.template: - asst = spec.template.get("assistant","").format(**vars) - messages.append({"role":"assistant","content":asst}) - + asst = spec.template.get("assistant", "").format(**vars) + messages.append({"role": "assistant", "content": asst}) + return messages def list(self) -> List[str]: """List all available prompt IDs. - + Returns: List of prompt IDs """ @@ -107,27 +114,27 @@ def list(self) -> List[str]: def get(self, prompt_id: str) -> PromptSpec: """Get a prompt specification by ID. - + Args: prompt_id: ID of the prompt - + Returns: PromptSpec object - + Raises: KeyError: If prompt_id doesn't exist """ return self._by_id[prompt_id] - + def get_meta(self, prompt_id: str) -> Dict[str, Any]: """Get metadata for a prompt. - + Args: prompt_id: ID of the prompt - + Returns: Metadata dict (authors, license, etc.) 
- + Raises: KeyError: If prompt_id doesn't exist """ diff --git a/querygym/core/registry.py b/querygym/core/registry.py index 187fe6a..1ef2b8a 100644 --- a/querygym/core/registry.py +++ b/querygym/core/registry.py @@ -4,8 +4,10 @@ METHODS: Dict[str, Type[BaseReformulator]] = {} + def register_method(name: str): def deco(cls: Type[BaseReformulator]): METHODS[name] = cls return cls + return deco diff --git a/querygym/core/searcher.py b/querygym/core/searcher.py index 2be0e7c..47c118a 100644 --- a/querygym/core/searcher.py +++ b/querygym/core/searcher.py @@ -14,11 +14,12 @@ @dataclass class SearchHit: """Standardized search hit result.""" + docid: str score: float content: str metadata: Dict[str, Any] = None - + def __post_init__(self): if self.metadata is None: self.metadata = {} @@ -27,57 +28,58 @@ def __post_init__(self): class BaseSearcher(ABC): """ Abstract base class for searcher implementations. - + Any searcher library (Pyserini, PyTerrier, etc.) can implement this interface to be compatible with querygym's retrieval framework. """ - + @abstractmethod def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]: """ Search for documents using a single query. - + Args: query: The search query string k: Number of documents to retrieve **kwargs: Additional search parameters specific to the searcher - + Returns: List of SearchHit objects ordered by relevance score (highest first) """ pass - + @abstractmethod - def batch_search(self, queries: List[str], k: int = 10, - num_threads: int = 1, **kwargs) -> List[List[SearchHit]]: + def batch_search( + self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs + ) -> List[List[SearchHit]]: """ Search for documents using multiple queries in batch. 
- + Args: queries: List of search query strings k: Number of documents to retrieve per query num_threads: Number of threads for parallel processing **kwargs: Additional search parameters specific to the searcher - + Returns: List of lists of SearchHit objects, one list per query """ pass - + @abstractmethod def get_searcher_info(self) -> Dict[str, Any]: """ Get information about the searcher configuration. - + Returns: Dictionary containing searcher metadata (name, version, parameters, etc.) """ pass - + def configure(self, **kwargs) -> None: """ Configure searcher parameters. - + Args: **kwargs: Configuration parameters specific to the searcher """ @@ -88,25 +90,25 @@ def configure(self, **kwargs) -> None: class SearcherRegistry: """Registry for managing different searcher implementations.""" - + _searchers: Dict[str, type] = {} - + @classmethod def register(cls, name: str, searcher_class: type) -> None: """Register a searcher implementation.""" if not issubclass(searcher_class, BaseSearcher): raise ValueError(f"Searcher class must inherit from BaseSearcher") cls._searchers[name] = searcher_class - + @classmethod def get_searcher(cls, name: str, **kwargs) -> BaseSearcher: """Get a searcher instance by name.""" if name not in cls._searchers: raise ValueError(f"Unknown searcher: {name}. Available: {list(cls._searchers.keys())}") - + searcher_class = cls._searchers[name] return searcher_class(**kwargs) - + @classmethod def list_searchers(cls) -> List[str]: """List all registered searcher names.""" @@ -117,11 +119,11 @@ def list_searchers(cls) -> List[str]: def create_searcher(searcher_type: str, **kwargs) -> BaseSearcher: """ Create a searcher instance. 
- + Args: searcher_type: Type of searcher to create **kwargs: Arguments to pass to the searcher constructor - + Returns: BaseSearcher instance """ diff --git a/querygym/core/searcher_wrappers.py b/querygym/core/searcher_wrappers.py index e24d34d..863a083 100644 --- a/querygym/core/searcher_wrappers.py +++ b/querygym/core/searcher_wrappers.py @@ -16,17 +16,17 @@ def wrap_pyserini_searcher(pyserini_searcher, answer_key: str = "contents") -> BaseSearcher: """ Wrap a user's Pyserini searcher for use with querygym. - + This function works standalone - it doesn't import Pyserini at the module level. The user must provide their own Pyserini searcher instance. - + Args: pyserini_searcher: User's LuceneSearcher or LuceneImpactSearcher instance answer_key: Field name(s) to extract content from (pipe-separated) - + Returns: BaseSearcher instance that can be used with querygym - + Example: >>> from pyserini.search.lucene import LuceneSearcher >>> lucene_searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage') @@ -34,46 +34,52 @@ def wrap_pyserini_searcher(pyserini_searcher, answer_key: str = "contents") -> B >>> wrapped_searcher = wrap_pyserini_searcher(lucene_searcher) >>> retriever = qg.Retriever(searcher=wrapped_searcher) """ - + class PyseriniWrapper(BaseSearcher): def __init__(self, searcher, answer_key): # Validate that the searcher has the expected methods - if not hasattr(searcher, 'search') or not hasattr(searcher, 'batch_search'): + if not hasattr(searcher, "search") or not hasattr(searcher, "batch_search"): raise ValueError("Provided searcher must have 'search' and 'batch_search' methods") - + self.searcher = searcher self.answer_key = answer_key self._searcher_info = { "name": "UserPyseriniWrapper", "type": "user_pyserini", "answer_key": answer_key, - "searcher_class": type(searcher).__name__ + "searcher_class": type(searcher).__name__, } - + def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]: hits = self.searcher.search(query, k) 
return self._process_hits(hits) - - def batch_search(self, queries: List[str], k: int = 10, - num_threads: int = 1, **kwargs) -> List[List[SearchHit]]: + + def batch_search( + self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs + ) -> List[List[SearchHit]]: pseudo_batch_topic_ids = [str(idx) for idx, _ in enumerate(queries)] results = self.searcher.batch_search(queries, pseudo_batch_topic_ids, k, num_threads) batch_results = [results[id_] for id_ in pseudo_batch_topic_ids] return [self._process_hits(hits) for hits in batch_results] - + def _process_hits(self, hits) -> List[SearchHit]: search_hits = [] for hit in hits: try: - raw_content = hit.lucene_document.get('raw') + raw_content = hit.lucene_document.get("raw") if raw_content: import json + raw_df = json.loads(raw_content) text_list = [raw_df[k] for k in self.answer_key.split("|") if raw_df.get(k)] content = "\t".join(text_list) else: - content = getattr(hit, 'contents', '') or getattr(hit, 'text', '') or str(hit.docid) - + content = ( + getattr(hit, "contents", "") + or getattr(hit, "text", "") + or str(hit.docid) + ) + search_hit = SearchHit( docid=hit.docid, score=hit.score, @@ -81,45 +87,41 @@ def _process_hits(self, hits) -> List[SearchHit]: metadata={ "user_defined": True, "searcher_class": type(self.searcher).__name__, - "answer_key": self.answer_key - } + "answer_key": self.answer_key, + }, ) search_hits.append(search_hit) except Exception as e: search_hit = SearchHit( - docid=getattr(hit, 'docid', 'unknown'), - score=getattr(hit, 'score', 0.0), - content=str(getattr(hit, 'docid', 'unknown')), - metadata={ - "error": str(e), - "user_defined": True, - "fallback": True - } + docid=getattr(hit, "docid", "unknown"), + score=getattr(hit, "score", 0.0), + content=str(getattr(hit, "docid", "unknown")), + metadata={"error": str(e), "user_defined": True, "fallback": True}, ) search_hits.append(search_hit) return search_hits - + def get_searcher_info(self) -> Dict[str, Any]: return 
self._searcher_info.copy() - + return PyseriniWrapper(pyserini_searcher, answer_key) def wrap_pyterrier_retriever(pyterrier_retriever, index, text_field: str = "text") -> BaseSearcher: """ Wrap a user's PyTerrier retriever for use with querygym. - + This function works standalone - it doesn't import PyTerrier at the module level. The user must provide their own PyTerrier retriever and index instances. - + Args: pyterrier_retriever: User's PyTerrier retriever instance index: PyTerrier index reference text_field: Field name containing document text - + Returns: BaseSearcher instance that can be used with querygym - + Example: >>> import pyterrier as pt >>> pt.init() @@ -127,13 +129,13 @@ def wrap_pyterrier_retriever(pyterrier_retriever, index, text_field: str = "text >>> wrapped_searcher = wrap_pyterrier_retriever(BM25_r, index) >>> retriever = qg.Retriever(searcher=wrapped_searcher) """ - + class PyTerrierWrapper(BaseSearcher): def __init__(self, retriever, index, text_field): # Validate that the retriever has the expected methods - if not hasattr(retriever, 'search') or not hasattr(retriever, 'transform'): + if not hasattr(retriever, "search") or not hasattr(retriever, "transform"): raise ValueError("Provided retriever must have 'search' and 'transform' methods") - + self.retriever = retriever self.index = index self.text_field = text_field @@ -141,37 +143,37 @@ def __init__(self, retriever, index, text_field): "name": "UserPyTerrierWrapper", "type": "user_pyterrier", "text_field": text_field, - "retriever_class": type(retriever).__name__ + "retriever_class": type(retriever).__name__, } - + def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]: """Search for documents using a single query.""" try: # Use PyTerrier's search method which returns a DataFrame results_df = self.retriever.search(query) - + # Convert to SearchHit objects search_hits = [] for _, row in results_df.head(k).iterrows(): # Extract content from the text field - content = 
str(row.get(self.text_field, row.get('docno', ''))) - + content = str(row.get(self.text_field, row.get("docno", ""))) + search_hit = SearchHit( - docid=str(row.get('docno', row.get('docid', ''))), - score=float(row.get('score', 0.0)), + docid=str(row.get("docno", row.get("docid", ""))), + score=float(row.get("score", 0.0)), content=content, metadata={ - 'user_defined': True, - 'retriever_class': type(self.retriever).__name__, - 'text_field': self.text_field, - 'qid': str(row.get('qid', '')), - 'rank': int(row.get('rank', 0)) - } + "user_defined": True, + "retriever_class": type(self.retriever).__name__, + "text_field": self.text_field, + "qid": str(row.get("qid", "")), + "rank": int(row.get("rank", 0)), + }, ) search_hits.append(search_hit) - + return search_hits - + except Exception as e: # Fallback: return mock results if anything fails search_hits = [] @@ -181,53 +183,56 @@ def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]: score=1.0 - (i * 0.1), content=f"Content for query: {query}", metadata={ - 'user_defined': True, - 'retriever_class': type(self.retriever).__name__, - 'text_field': self.text_field, - 'error': str(e), - 'fallback': True - } + "user_defined": True, + "retriever_class": type(self.retriever).__name__, + "text_field": self.text_field, + "error": str(e), + "fallback": True, + }, ) search_hits.append(search_hit) return search_hits - - def batch_search(self, queries: List[str], k: int = 10, - num_threads: int = 1, **kwargs) -> List[List[SearchHit]]: + + def batch_search( + self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs + ) -> List[List[SearchHit]]: """Search for documents using multiple queries in batch.""" batch_results = [] - + for query in queries: # Use the single search method for each query results = self.search(query, k) batch_results.append(results) - + return batch_results - + def get_searcher_info(self) -> Dict[str, Any]: return self._searcher_info.copy() - + return 
PyTerrierWrapper(pyterrier_retriever, index, text_field) -def wrap_custom_searcher(search_func, batch_search_func=None, searcher_name: str = "CustomSearcher") -> BaseSearcher: +def wrap_custom_searcher( + search_func, batch_search_func=None, searcher_name: str = "CustomSearcher" +) -> BaseSearcher: """ Wrap user's custom search functions for use with queryGym. - + Args: search_func: Function that takes (query: str, k: int) and returns list of (docid, score, content) tuples batch_search_func: Optional function that takes (queries: List[str], k: int) and returns List[List[tuple]] searcher_name: Name for the searcher - + Returns: BaseSearcher instance that can be used with querygym - + Example: >>> def my_search(query, k): ... return [("doc1", 0.9, "content1"), ("doc2", 0.8, "content2")] >>> wrapped_searcher = wrap_custom_searcher(my_search) >>> retriever = qg.Retriever(searcher=wrapped_searcher) """ - + class CustomWrapper(BaseSearcher): def __init__(self, search_func, batch_search_func, searcher_name): self.search_func = search_func @@ -236,13 +241,13 @@ def __init__(self, search_func, batch_search_func, searcher_name): self._searcher_info = { "name": searcher_name, "type": "user_custom", - "has_batch_search": batch_search_func is not None + "has_batch_search": batch_search_func is not None, } - + def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]: results = self.search_func(query, k) search_hits = [] - + for result in results: if len(result) >= 3: docid, score, content = result[0], result[1], result[2] @@ -250,27 +255,28 @@ def search(self, query: str, k: int = 10, **kwargs) -> List[SearchHit]: docid, score, content = result[0], result[1], str(result[0]) else: docid, score, content = str(result[0]), 0.0, str(result[0]) - + search_hit = SearchHit( docid=str(docid), score=float(score), content=str(content), - metadata={"user_defined": True, "custom_searcher": True} + metadata={"user_defined": True, "custom_searcher": True}, ) 
search_hits.append(search_hit) - + return search_hits - - def batch_search(self, queries: List[str], k: int = 10, - num_threads: int = 1, **kwargs) -> List[List[SearchHit]]: + + def batch_search( + self, queries: List[str], k: int = 10, num_threads: int = 1, **kwargs + ) -> List[List[SearchHit]]: if self.batch_search_func: batch_results = self.batch_search_func(queries, k) return [self.search(query, k) for query, results in zip(queries, batch_results)] else: # Fallback to individual searches return [self.search(query, k) for query in queries] - + def get_searcher_info(self) -> Dict[str, Any]: return self._searcher_info.copy() - + return CustomWrapper(search_func, batch_search_func, searcher_name) diff --git a/querygym/core/utils.py b/querygym/core/utils.py index 452610d..c339c1f 100644 --- a/querygym/core/utils.py +++ b/querygym/core/utils.py @@ -2,16 +2,20 @@ import random from typing import Optional + def seed_everything(seed: Optional[int] = 42): - if seed is None: return + if seed is None: + return random.seed(seed) try: import numpy as np + np.random.seed(seed) except Exception: pass try: import torch + torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True diff --git a/querygym/data/dataloader.py b/querygym/data/dataloader.py index 3caeae5..8331575 100644 --- a/querygym/data/dataloader.py +++ b/querygym/data/dataloader.py @@ -16,7 +16,7 @@ class DataLoader: """Core data loader for local files only.""" - + @staticmethod def load_queries( path: Union[str, Path], @@ -24,11 +24,11 @@ def load_queries( qid_col: int = 0, query_col: int = 1, qid_key: str = "qid", - query_key: str = "query" + query_key: str = "query", ) -> List[QueryItem]: """ Load queries from a local file. 
- + Args: path: Path to queries file format: File format - "tsv" or "jsonl" @@ -36,162 +36,147 @@ def load_queries( query_col: Column index for query text (TSV only) qid_key: JSON key for query ID (JSONL only) query_key: JSON key for query text (JSONL only) - + Returns: List of QueryItem objects - + Example: >>> queries = DataLoader.load_queries("queries.tsv", format="tsv") >>> queries = DataLoader.load_queries("queries.jsonl", format="jsonl") """ path = Path(path) - + if not path.exists(): raise FileNotFoundError(f"Query file not found: {path}") - + if format == "tsv": return DataLoader._load_queries_tsv(path, qid_col, query_col) elif format == "jsonl": return DataLoader._load_queries_jsonl(path, qid_key, query_key) else: raise ValueError(f"Unsupported format: {format}. Use 'tsv' or 'jsonl'") - + @staticmethod - def _load_queries_tsv( - path: Path, - qid_col: int, - query_col: int - ) -> List[QueryItem]: + def _load_queries_tsv(path: Path, qid_col: int, query_col: int) -> List[QueryItem]: """Load queries from TSV file.""" queries = [] warned_empty = False warned_malformed = False - + with open(path, "r", encoding="utf-8") as f: reader = csv.reader(f, delimiter="\t") for line_num, row in enumerate(reader, 1): # Skip malformed rows if len(row) <= max(qid_col, query_col): if not warned_malformed: - warnings.warn( - f"Skipping malformed rows in {path} (not enough columns)" - ) + warnings.warn(f"Skipping malformed rows in {path} (not enough columns)") warned_malformed = True continue - + qid = str(row[qid_col]).strip() query_text = str(row[query_col]).strip() - + # Skip empty queries if not query_text: if not warned_empty: warnings.warn(f"Skipping empty queries in {path}") warned_empty = True continue - + queries.append(QueryItem(qid=qid, text=query_text)) - + if not queries: raise ValueError(f"No valid queries found in {path}") - + return queries - + @staticmethod - def _load_queries_jsonl( - path: Path, - qid_key: str, - query_key: str - ) -> List[QueryItem]: + def 
_load_queries_jsonl(path: Path, qid_key: str, query_key: str) -> List[QueryItem]: """Load queries from JSONL file.""" queries = [] warned_empty = False warned_missing_keys = False - + with open(path, "r", encoding="utf-8") as f: for line_num, line in enumerate(f, 1): if not line.strip(): continue - + try: obj = json.loads(line) except json.JSONDecodeError as e: warnings.warn(f"Invalid JSON at line {line_num} in {path}: {e}") continue - + # Check for required keys if qid_key not in obj or query_key not in obj: if not warned_missing_keys: - warnings.warn( - f"Missing keys '{qid_key}' or '{query_key}' in {path}" - ) + warnings.warn(f"Missing keys '{qid_key}' or '{query_key}' in {path}") warned_missing_keys = True continue - + qid = str(obj[qid_key]) query_text = str(obj[query_key]).strip() - + # Skip empty queries if not query_text: if not warned_empty: warnings.warn(f"Skipping empty queries in {path}") warned_empty = True continue - + queries.append(QueryItem(qid=qid, text=query_text)) - + if not queries: raise ValueError(f"No valid queries found in {path}") - + return queries - + @staticmethod - def load_qrels( - path: Union[str, Path], - format: str = "trec" - ) -> Dict[str, Dict[str, int]]: + def load_qrels(path: Union[str, Path], format: str = "trec") -> Dict[str, Dict[str, int]]: """ Load qrels (relevance judgments) from a local file. - + Args: path: Path to qrels file format: File format - "trec" (standard TREC format) - + Returns: Dict mapping qid -> {docid -> relevance} - + Example: >>> qrels = DataLoader.load_qrels("qrels.txt", format="trec") >>> qrels["q1"]["doc123"] # relevance score for doc123 in query q1 """ path = Path(path) - + if not path.exists(): raise FileNotFoundError(f"Qrels file not found: {path}") - + if format == "trec": return DataLoader._load_qrels_trec(path) else: raise ValueError(f"Unsupported qrels format: {format}. 
Use 'trec'") - + @staticmethod def _load_qrels_trec(path: Path) -> Dict[str, Dict[str, int]]: """Load qrels from TREC format file.""" qrels: Dict[str, Dict[str, int]] = {} warned_malformed = False - + with open(path, "r", encoding="utf-8") as f: for line_num, line in enumerate(f, 1): line = line.strip() if not line: continue - + parts = line.split() if len(parts) < 4: if not warned_malformed: warnings.warn(f"Skipping malformed lines in {path}") warned_malformed = True continue - + qid = parts[0] docid = parts[2] try: @@ -201,164 +186,149 @@ def _load_qrels_trec(path: Path) -> Dict[str, Dict[str, int]]: warnings.warn(f"Invalid relevance score in {path}") warned_malformed = True continue - + if qid not in qrels: qrels[qid] = {} qrels[qid][docid] = relevance - + if not qrels: raise ValueError(f"No valid qrels found in {path}") - + return qrels - + @staticmethod def load_contexts( - path: Union[str, Path], - qid_key: str = "qid", - contexts_key: str = "contexts" + path: Union[str, Path], qid_key: str = "qid", contexts_key: str = "contexts" ) -> Dict[str, List[str]]: """ Load pre-retrieved contexts from a JSONL file. 
- + Args: path: Path to contexts JSONL file qid_key: JSON key for query ID contexts_key: JSON key for contexts list - + Returns: Dict mapping qid -> list of context strings - + Example: >>> contexts = DataLoader.load_contexts("contexts.jsonl") >>> contexts["q1"] # List of context strings for query q1 """ path = Path(path) - + if not path.exists(): raise FileNotFoundError(f"Contexts file not found: {path}") - + contexts: Dict[str, List[str]] = {} warned_missing_keys = False - + with open(path, "r", encoding="utf-8") as f: for line_num, line in enumerate(f, 1): if not line.strip(): continue - + try: obj = json.loads(line) except json.JSONDecodeError as e: warnings.warn(f"Invalid JSON at line {line_num} in {path}: {e}") continue - + # Check for required keys if qid_key not in obj or contexts_key not in obj: if not warned_missing_keys: - warnings.warn( - f"Missing keys '{qid_key}' or '{contexts_key}' in {path}" - ) + warnings.warn(f"Missing keys '{qid_key}' or '{contexts_key}' in {path}") warned_missing_keys = True continue - + qid = str(obj[qid_key]) ctx_list = obj[contexts_key] - + if not isinstance(ctx_list, list): warnings.warn(f"Contexts for {qid} is not a list, skipping") continue - + contexts[qid] = [str(ctx) for ctx in ctx_list] - + if not contexts: raise ValueError(f"No valid contexts found in {path}") - + return contexts - + @staticmethod def load_examples( - path: Union[str, Path], - query_key: str = "query", - passage_key: str = "passage" + path: Union[str, Path], query_key: str = "query", passage_key: str = "passage" ) -> List[Dict[str, str]]: """ Load few-shot examples from a JSONL file. - + Used by methods like Query2E that need (query, passage) pairs as demonstrations. 
- + Args: path: Path to examples JSONL file query_key: JSON key for query text passage_key: JSON key for passage text - + Returns: List of dicts with 'query' and 'passage' keys - + Example: >>> examples = DataLoader.load_examples("examples.jsonl") >>> # Returns: [{"query": "...", "passage": "..."}, ...] - + JSONL format: {"query": "how long is flea life cycle?", "passage": "The life cycle of a flea..."} {"query": "cost of flooring?", "passage": "The cost of interior concrete..."} """ path = Path(path) - + if not path.exists(): raise FileNotFoundError(f"Examples file not found: {path}") - + examples: List[Dict[str, str]] = [] warned_missing_keys = False - + with open(path, "r", encoding="utf-8") as f: for line_num, line in enumerate(f, 1): if not line.strip(): continue - + try: obj = json.loads(line) except json.JSONDecodeError as e: warnings.warn(f"Invalid JSON at line {line_num} in {path}: {e}") continue - + # Check for required keys if query_key not in obj or passage_key not in obj: if not warned_missing_keys: - warnings.warn( - f"Missing keys '{query_key}' or '{passage_key}' in {path}" - ) + warnings.warn(f"Missing keys '{query_key}' or '{passage_key}' in {path}") warned_missing_keys = True continue - - examples.append({ - "query": str(obj[query_key]), - "passage": str(obj[passage_key]) - }) - + + examples.append({"query": str(obj[query_key]), "passage": str(obj[passage_key])}) + if not examples: raise ValueError(f"No valid examples found in {path}") - + return examples - + @staticmethod - def save_queries( - queries: List[QueryItem], - path: Union[str, Path], - format: str = "tsv" - ) -> None: + def save_queries(queries: List[QueryItem], path: Union[str, Path], format: str = "tsv") -> None: """ Save queries to a file. 
- + Args: queries: List of QueryItem objects path: Output file path format: Output format - "tsv" or "jsonl" - + Example: >>> DataLoader.save_queries(queries, "output.tsv", format="tsv") """ path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) - + if format == "tsv": with open(path, "w", encoding="utf-8") as f: writer = csv.writer(f, delimiter="\t") @@ -376,27 +346,27 @@ def save_queries( # Backward compatibility: Keep UnifiedQuerySource for now with deprecation warning class UnifiedQuerySource: """Deprecated: Use DataLoader.load_queries() instead.""" - + def __init__(self, backend: str = "local", **kwargs): warnings.warn( "UnifiedQuerySource is deprecated. Use DataLoader.load_queries() instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) - + if backend != "local": raise ValueError( f"Backend '{backend}' is no longer supported. " "Use querygym.loaders module for BEIR/MS MARCO helpers." ) - + self.path = kwargs.get("path") self.format = kwargs.get("format", "tsv") self.qid_col = kwargs.get("tsv_qid_col", 0) self.query_col = kwargs.get("tsv_query_col", 1) self.qid_key = kwargs.get("jsonl_qid_key", "qid") self.query_key = kwargs.get("jsonl_query_key", "query") - + def iter(self): """Iterate over queries.""" queries = DataLoader.load_queries( @@ -405,10 +375,10 @@ def iter(self): qid_col=self.qid_col, query_col=self.query_col, qid_key=self.qid_key, - query_key=self.query_key + query_key=self.query_key, ) return iter(queries) - + @staticmethod def export_to_tsv(items, out_path): """Export queries to TSV.""" @@ -418,28 +388,24 @@ def export_to_tsv(items, out_path): # Backward compatibility: Keep UnifiedContextSource for now with deprecation warning class UnifiedContextSource: """Deprecated: Use DataLoader.load_contexts() instead.""" - + def __init__(self, mode: str = "file", **kwargs): warnings.warn( "UnifiedContextSource is deprecated. 
Use DataLoader.load_contexts() instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) - + if mode != "file": raise ValueError( f"Mode '{mode}' is no longer supported. " "Load contexts from file or use retrieval separately." ) - + self.path = kwargs.get("path") self.qid_key = kwargs.get("qid_key", "qid") self.ctx_key = kwargs.get("ctx_key", "contexts") - + def load(self, queries: List[QueryItem]) -> Dict[str, List[str]]: """Load contexts from file.""" - return DataLoader.load_contexts( - self.path, - qid_key=self.qid_key, - contexts_key=self.ctx_key - ) + return DataLoader.load_contexts(self.path, qid_key=self.qid_key, contexts_key=self.ctx_key) diff --git a/querygym/loaders/beir.py b/querygym/loaders/beir.py index 629a11c..9974ad4 100644 --- a/querygym/loaders/beir.py +++ b/querygym/loaders/beir.py @@ -18,117 +18,104 @@ from ..data.dataloader import DataLoader -def load_queries( - beir_data_dir: Union[str, Path], - split: str = "test" -) -> List[QueryItem]: +def load_queries(beir_data_dir: Union[str, Path], split: str = "test") -> List[QueryItem]: """ Load queries from a BEIR dataset directory. 
- + BEIR datasets have a queries.jsonl file with format: {"_id": "query_id", "text": "query text", ...} - + Args: beir_data_dir: Path to BEIR dataset directory (contains queries.jsonl) split: Dataset split (not used for queries, kept for API consistency) - + Returns: List of QueryItem objects - + Example: >>> from querygym.datasets import beir >>> queries = beir.load_queries("./data/nfcorpus") """ beir_data_dir = Path(beir_data_dir) queries_file = beir_data_dir / "queries.jsonl" - + if not queries_file.exists(): raise FileNotFoundError( f"BEIR queries file not found: {queries_file}\n" f"Expected BEIR directory structure with queries.jsonl" ) - + # BEIR uses "_id" and "text" as keys - return DataLoader.load_queries( - queries_file, - format="jsonl", - qid_key="_id", - query_key="text" - ) - - -def load_qrels( - beir_data_dir: Union[str, Path], - split: str = "test" -) -> Dict[str, Dict[str, int]]: + return DataLoader.load_queries(queries_file, format="jsonl", qid_key="_id", query_key="text") + + +def load_qrels(beir_data_dir: Union[str, Path], split: str = "test") -> Dict[str, Dict[str, int]]: """ Load qrels from a BEIR dataset directory. 
- + BEIR qrels are in TSV format: query-id corpus-id score Located at: qrels/{split}.tsv - + Args: beir_data_dir: Path to BEIR dataset directory split: Dataset split ("train", "dev", "test") - + Returns: Dict mapping qid -> {docid -> relevance} - + Example: >>> from querygym.datasets import beir >>> qrels = beir.load_qrels("./data/nfcorpus", split="test") """ beir_data_dir = Path(beir_data_dir) qrels_file = beir_data_dir / "qrels" / f"{split}.tsv" - + if not qrels_file.exists(): raise FileNotFoundError( f"BEIR qrels file not found: {qrels_file}\n" f"Expected: {beir_data_dir}/qrels/{split}.tsv" ) - + # BEIR qrels format: query-id \t corpus-id \t score # We need to convert to standard TREC format (add iteration column) qrels: Dict[str, Dict[str, int]] = {} - + with open(qrels_file, "r", encoding="utf-8") as f: for line in f: parts = line.strip().split("\t") if len(parts) < 3: continue - + qid = parts[0] docid = parts[1] try: relevance = int(parts[2]) except ValueError: continue - + if qid not in qrels: qrels[qid] = {} qrels[qid][docid] = relevance - + if not qrels: raise ValueError(f"No valid qrels found in {qrels_file}") - + return qrels -def load_corpus( - beir_data_dir: Union[str, Path] -) -> Dict[str, Dict[str, str]]: +def load_corpus(beir_data_dir: Union[str, Path]) -> Dict[str, Dict[str, str]]: """ Load corpus from a BEIR dataset directory. 
- + BEIR corpus is in JSONL format with: {"_id": "doc_id", "title": "...", "text": "...", ...} - + Args: beir_data_dir: Path to BEIR dataset directory - + Returns: Dict mapping doc_id -> {"title": ..., "text": ...} - + Example: >>> from querygym.datasets import beir >>> corpus = beir.load_corpus("./data/nfcorpus") @@ -136,35 +123,32 @@ def load_corpus( """ beir_data_dir = Path(beir_data_dir) corpus_file = beir_data_dir / "corpus.jsonl" - + if not corpus_file.exists(): raise FileNotFoundError( f"BEIR corpus file not found: {corpus_file}\n" f"Expected BEIR directory structure with corpus.jsonl" ) - + corpus: Dict[str, Dict[str, str]] = {} - + with open(corpus_file, "r", encoding="utf-8") as f: for line in f: if not line.strip(): continue - + try: doc = json.loads(line) except json.JSONDecodeError: continue - + doc_id = doc.get("_id") if not doc_id: continue - - corpus[str(doc_id)] = { - "title": doc.get("title", ""), - "text": doc.get("text", "") - } - + + corpus[str(doc_id)] = {"title": doc.get("title", ""), "text": doc.get("text", "")} + if not corpus: raise ValueError(f"No valid documents found in {corpus_file}") - + return corpus diff --git a/querygym/loaders/msmarco.py b/querygym/loaders/msmarco.py index 45b4b22..86301fc 100644 --- a/querygym/loaders/msmarco.py +++ b/querygym/loaders/msmarco.py @@ -18,130 +18,119 @@ from ..data.dataloader import DataLoader -def load_queries( - queries_tsv: Union[str, Path] -) -> List[QueryItem]: +def load_queries(queries_tsv: Union[str, Path]) -> List[QueryItem]: """ Load queries from MS MARCO queries TSV file. 
- + MS MARCO queries format: qid \\t query_text - + Args: queries_tsv: Path to MS MARCO queries.tsv file - + Returns: List of QueryItem objects - + Example: >>> from querygym.datasets import msmarco >>> queries = msmarco.load_queries("./data/msmarco_queries.tsv") """ queries_tsv = Path(queries_tsv) - + if not queries_tsv.exists(): raise FileNotFoundError(f"MS MARCO queries file not found: {queries_tsv}") - + # MS MARCO uses standard TSV: qid \t query - return DataLoader.load_queries( - queries_tsv, - format="tsv", - qid_col=0, - query_col=1 - ) + return DataLoader.load_queries(queries_tsv, format="tsv", qid_col=0, query_col=1) -def load_qrels( - qrels_tsv: Union[str, Path] -) -> Dict[str, Dict[str, int]]: +def load_qrels(qrels_tsv: Union[str, Path]) -> Dict[str, Dict[str, int]]: """ Load qrels from MS MARCO qrels file. - + MS MARCO qrels format can be either: - TSV: qid \\t 0 \\t docid \\t relevance (TREC format) - Or: qid \\t docid \\t relevance (simplified) - + Args: qrels_tsv: Path to MS MARCO qrels file - + Returns: Dict mapping qid -> {docid -> relevance} - + Example: >>> from querygym.datasets import msmarco >>> qrels = msmarco.load_qrels("./data/msmarco_qrels.tsv") """ qrels_tsv = Path(qrels_tsv) - + if not qrels_tsv.exists(): raise FileNotFoundError(f"MS MARCO qrels file not found: {qrels_tsv}") - + # Try standard TREC format first try: return DataLoader.load_qrels(qrels_tsv, format="trec") except ValueError: # Fall back to simplified format (qid \t docid \t relevance) qrels: Dict[str, Dict[str, int]] = {} - + with open(qrels_tsv, "r", encoding="utf-8") as f: for line in f: parts = line.strip().split("\t") if len(parts) < 3: continue - + qid = parts[0] docid = parts[1] try: relevance = int(parts[2]) except ValueError: continue - + if qid not in qrels: qrels[qid] = {} qrels[qid][docid] = relevance - + if not qrels: raise ValueError(f"No valid qrels found in {qrels_tsv}") - + return qrels -def load_collection( - collection_tsv: Union[str, Path] -) -> 
Dict[str, str]: +def load_collection(collection_tsv: Union[str, Path]) -> Dict[str, str]: """ Load MS MARCO passage/document collection. - + MS MARCO collection format: docid \\t text - + Args: collection_tsv: Path to MS MARCO collection.tsv file - + Returns: Dict mapping docid -> text - + Example: >>> from querygym.datasets import msmarco >>> collection = msmarco.load_collection("./data/collection.tsv") >>> collection["doc123"] """ collection_tsv = Path(collection_tsv) - + if not collection_tsv.exists(): raise FileNotFoundError(f"MS MARCO collection file not found: {collection_tsv}") - + collection: Dict[str, str] = {} - + with open(collection_tsv, "r", encoding="utf-8") as f: for line in f: parts = line.strip().split("\t") if len(parts) < 2: continue - + docid = parts[0] text = parts[1] collection[docid] = text - + if not collection: raise ValueError(f"No valid documents found in {collection_tsv}") - + return collection diff --git a/querygym/methods/csqe.py b/querygym/methods/csqe.py index a4cec58..cce80af 100644 --- a/querygym/methods/csqe.py +++ b/querygym/methods/csqe.py @@ -4,20 +4,22 @@ from typing import List, Optional, Dict, Any import re + @register_method("csqe") class CSQE(BaseReformulator): """ CSQE (Context-based Sentence-level Query Expansion) method. - + Combines KEQE (knowledge-based) and CSQE (context-based) expansions: 1. Generates N KEQE passages from LLM knowledge (no contexts) - default N=2 2. Generates N CSQE responses with retrieved contexts (few-shot) - default N=2 3. Extracts key sentences from CSQE responses 4. Concatenates: query × N + keqe_passages + csqe_sentences (newline-separated, lowercased) - + Default: N=2 for both, totaling 4 generations (2 KEQE + 2 CSQE). Query is repeated N times (equal to number of KEQE expansions). 
""" + VERSION = "1.0" REQUIRES_CONTEXT = True # Needs retrieved contexts for CSQE expansion @@ -30,9 +32,11 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: retrieval_k = int(self.cfg.params.get("retrieval_k", 10)) prompt_ctxs = ctxs[:retrieval_k] contexts_blob = "\n".join([f"{i+1}. {psg}" for i, psg in enumerate(prompt_ctxs)]) - + # Get generation parameters - N=2 for both by default (paper specification) - n_expansions = int(self.cfg.params.get("gen_num", 2)) # Number of expansions for both KEQE and CSQE (default: 2) + n_expansions = int( + self.cfg.params.get("gen_num", 2) + ) # Number of expansions for both KEQE and CSQE (default: 2) max_tokens = int(self.cfg.llm.get("max_tokens", 1024)) temperature = float(self.cfg.llm.get("temperature", 1.0)) @@ -44,7 +48,7 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: messages=msgs_keqe, n=n_expansions, temperature=temperature, - max_tokens=max_tokens + max_tokens=max_tokens, ) keqe_passages = [ choice.message.content.strip().strip('"').strip("'") or "" @@ -59,12 +63,9 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: messages=msgs_csqe, n=n_expansions, temperature=temperature, - max_tokens=max_tokens + max_tokens=max_tokens, ) - csqe_responses = [ - choice.message.content or "" - for choice in resp_csqe.choices - ] + csqe_responses = [choice.message.content or "" for choice in resp_csqe.choices] # Step 3: Extract key sentences from CSQE responses # Pattern matches sentences in double quotes (from new_baselines.py) csqe_sentences: List[str] = [] @@ -97,28 +98,30 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: def _extract_key_sentences(self, response: str) -> str: """ Extract key sentences from CSQE response. - + Primary: Uses regex to find sentences in double quotes (expected format). Fallback: If no quotes found, extracts content after numbered document markers (1., 2., etc.) 
""" # Primary strategy: Extract sentences in double quotes pattern = r'"([^"]*)"' sentences = re.findall(pattern, response) - + if sentences: # Found quoted sentences - join and return joint_sentence = " ".join(sentences) return joint_sentence - + # Fallback: LLM didn't use quotes, extract from numbered document format # Pattern: Look for content after "1.", "2.", etc. (with optional trailing colon) # Remove "Relevant Documents:" header first - cleaned = re.sub(r'^Relevant Documents?:?\s*\n?', '', response, flags=re.IGNORECASE | re.MULTILINE) - + cleaned = re.sub( + r"^Relevant Documents?:?\s*\n?", "", response, flags=re.IGNORECASE | re.MULTILINE + ) + # Extract content after numbered markers (1., 2., etc.) - doc_pattern = r'\d+[\.:]\s*(.+?)(?=\d+[\.:]|$)' + doc_pattern = r"\d+[\.:]\s*(.+?)(?=\d+[\.:]|$)" doc_content = re.findall(doc_pattern, cleaned, re.DOTALL) - + if doc_content: # Clean up each extracted piece extracted = [] @@ -129,13 +132,13 @@ def _extract_key_sentences(self, response: str) -> str: extracted.append(cleaned_content) if extracted: return " ".join(extracted) - + # If still nothing found, return empty string return "" def _get_retrieval_params(self) -> Optional[Dict[str, Any]]: """Get CSQE-specific retrieval parameters for batch retrieval. - + Returns retrieval parameters compatible with BaseSearcher interface. Supports both searcher_type/searcher_kwargs format and legacy format. 
""" @@ -147,53 +150,53 @@ def _get_retrieval_params(self) -> Optional[Dict[str, Any]]: "k": int(self.cfg.params.get("retrieval_k", 10)), "threads": int(self.cfg.params.get("threads", 16)), } - + # Check if using new searcher interface format with searcher_type if "searcher_type" in self.cfg.params: # New format: use searcher_type and searcher_kwargs searcher_type = self.cfg.params.get("searcher_type", "pyserini") searcher_kwargs = self.cfg.params.get("searcher_kwargs", {}) - + # If index is provided in params but not in searcher_kwargs, add it if "index" in self.cfg.params and "index" not in searcher_kwargs: searcher_kwargs["index"] = self.cfg.params["index"] - + return { "searcher_type": searcher_type, "searcher_kwargs": searcher_kwargs, "k": int(self.cfg.params.get("retrieval_k", 10)), "threads": int(self.cfg.params.get("threads", 16)), } - + # Legacy format: convert old-style params to new format index = self.cfg.params.get("index") if not index: return None - + # Build searcher_kwargs from legacy params searcher_kwargs = { "index": index, "searcher_type": "impact" if self.cfg.params.get("impact", False) else "bm25", "answer_key": self.cfg.params.get("answer_key", "contents"), } - + # Add BM25 parameters if specified k1 = self.cfg.params.get("k1") b = self.cfg.params.get("b") if k1 is not None and b is not None: searcher_kwargs["k1"] = k1 searcher_kwargs["b"] = b - + # Add RM3 if requested if self.cfg.params.get("rm3", False): searcher_kwargs["rm3"] = True - + # Add Rocchio if requested if self.cfg.params.get("rocchio", False): searcher_kwargs["rocchio"] = True if self.cfg.params.get("rocchio_use_negative", False): searcher_kwargs["rocchio_use_negative"] = True - + return { "searcher_type": "pyserini", # Default to pyserini for legacy format "searcher_kwargs": searcher_kwargs, diff --git a/querygym/methods/genqr.py b/querygym/methods/genqr.py index 30a1fff..6c947f4 100644 --- a/querygym/methods/genqr.py +++ b/querygym/methods/genqr.py @@ -3,15 +3,17 @@ from 
..core.registry import register_method from typing import List + @register_method("genqr") class GENQR(BaseReformulator): """ GenQR: Generate reformulations N times and concatenate. - + Pipeline: 1. Call LLM N times (default 5) to generate reformulations 2. Concatenate: query + reformulation1 + reformulation2 + ... + reformulationN """ + VERSION = "1.0" def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: @@ -19,26 +21,23 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: n_generations = int(self.cfg.params.get("n_generations", 5)) temperature = float(self.cfg.llm.get("temperature", 0.8)) max_tokens = int(self.cfg.llm.get("max_tokens", 256)) - + # Generate reformulations N times (5 by default) reformulations: List[str] = [] - + for _ in range(n_generations): msgs = self.prompts.render("genqr.keywords.v1", query=q.text) out = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens) # Clean up the reformulation (remove quotes, extra whitespace) reformulation = out.strip().strip('"').strip("'").replace("\n", " ") reformulations.append(reformulation) - + # Concatenate: query + reformulation1 + reformulation2 + ... + reformulationN reformulated = q.text + " " + " ".join(reformulations) - + return ReformulationResult( - q.qid, - q.text, + q.qid, + q.text, reformulated, - metadata={ - "n_generations": n_generations, - "reformulations": reformulations - } + metadata={"n_generations": n_generations, "reformulations": reformulations}, ) diff --git a/querygym/methods/genqr_ensemble.py b/querygym/methods/genqr_ensemble.py index 1883bd4..78aa2e8 100644 --- a/querygym/methods/genqr_ensemble.py +++ b/querygym/methods/genqr_ensemble.py @@ -17,28 +17,30 @@ "genqr_ensemble.inst10.v1", ] + @register_method("genqr_ensemble") class GenQREnsemble(BaseReformulator): """ GenQREnsemble: Generative Query Reformulation with Ensemble of Instructions. - + Uses 10 instruction variants to generate diverse keyword expansions. 
Each variant generates keywords independently, then all are merged. - + Pipeline: 1. For each of 10 instruction variants, generate keywords (10 LLM calls) 2. Parse and merge all keyword lists 3. Expand query: (Q × 5) + K1 + K2 + ... + K10 - + Sampling Parameters (from paper): - temperature: 0.92 (nucleus sampling) - max_tokens: 256 - parallel: false (can be enabled for concurrent generation) - + Config Parameters: - variant_ids: List of prompt IDs (defaults to all 10 variants) - parallel: Enable parallel generation of all variants (default: false) """ + VERSION = "1.0" CONCATENATION_STRATEGY = "query_repeat_plus_generated" # (Q × 5) + K1 + K2 + ... + K10 DEFAULT_QUERY_REPEATS = 5 # GenQREnsemble uses 5 repetitions @@ -46,34 +48,34 @@ class GenQREnsemble(BaseReformulator): def _parse_keywords(self, keywords_text: str) -> List[str]: """ Parse keywords from LLM output (comma-separated or bullet list). - + Args: keywords_text: Raw LLM output - + Returns: List of individual keywords """ keywords = [] - + # Check if it's a bullet list (contains dashes/newlines) - if '\n' in keywords_text or keywords_text.strip().startswith('-'): + if "\n" in keywords_text or keywords_text.strip().startswith("-"): # Split by newlines and extract after dashes - for line in keywords_text.split('\n'): + for line in keywords_text.split("\n"): line = line.strip() # Remove leading dash/bullet - if line.startswith('-'): + if line.startswith("-"): line = line[1:].strip() - elif line.startswith('•'): + elif line.startswith("•"): line = line[1:].strip() - elif line.startswith('*'): + elif line.startswith("*"): line = line[1:].strip() - + if line: keywords.append(line) else: # Split by comma (standard format) - keywords = [k.strip() for k in keywords_text.split(',') if k.strip()] - + keywords = [k.strip() for k in keywords_text.split(",") if k.strip()] + return keywords def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: @@ -82,24 +84,24 @@ def reformulate(self, q: QueryItem, 
contexts=None) -> ReformulationResult: temperature = float(self.cfg.llm.get("temperature", 0.92)) max_tokens = int(self.cfg.llm.get("max_tokens", 256)) parallel = bool(self.cfg.params.get("parallel", False)) - + # Get variant IDs (configurable, defaults to all 10) variant_ids = self.cfg.params.get("variant_ids", VARIANT_IDS) - + # Generate keywords for each instruction variant all_keywords = [] variant_outputs = {} - + if parallel: # Parallel generation using ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor, as_completed - + def generate_keywords_for_variant(idx: int, prompt_id: str): msgs = self.prompts.render(prompt_id, query=q.text) output = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens) keywords = self._parse_keywords(output) return idx, prompt_id, output, keywords - + # Generate all variants in parallel with ThreadPoolExecutor(max_workers=len(variant_ids)) as executor: futures = [ @@ -113,7 +115,7 @@ def generate_keywords_for_variant(idx: int, prompt_id: str): "prompt_id": prompt_id, "raw_output": output, "keywords": keywords, - "count": len(keywords) + "count": len(keywords), } else: # Sequential generation (default) @@ -121,28 +123,28 @@ def generate_keywords_for_variant(idx: int, prompt_id: str): # Render prompt and generate keywords msgs = self.prompts.render(prompt_id, query=q.text) output = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens) - + # Parse keywords from output keywords = self._parse_keywords(output) all_keywords.extend(keywords) - + # Store per-variant output for metadata variant_outputs[f"variant_{i}"] = { "prompt_id": prompt_id, "raw_output": output, "keywords": keywords, - "count": len(keywords) + "count": len(keywords), } - + # Merge all keywords into single string generated_content = " ".join(all_keywords) - + # Use base class concatenation: (Q × 5) + generated_content reformulated = self.concatenate_result(q.text, generated_content) - + return ReformulationResult( - q.qid, - 
q.text, + q.qid, + q.text, reformulated, metadata={ "num_variants": len(variant_ids), @@ -151,6 +153,6 @@ def generate_keywords_for_variant(idx: int, prompt_id: str): "variant_outputs": variant_outputs, "temperature": temperature, "max_tokens": max_tokens, - "parallel": parallel - } + "parallel": parallel, + }, ) diff --git a/querygym/methods/lamer.py b/querygym/methods/lamer.py index e9afc65..3f4d74b 100644 --- a/querygym/methods/lamer.py +++ b/querygym/methods/lamer.py @@ -3,11 +3,14 @@ from ..core.registry import register_method from typing import List, Optional, Dict, Any + @register_method("lamer") class LameR(BaseReformulator): VERSION = "0.1" REQUIRES_CONTEXT = True - CONCATENATION_STRATEGY = "interleaved_query_content" # q + a1 + q + a2 + q + a3 + q + a4 + q + a5 + CONCATENATION_STRATEGY = ( + "interleaved_query_content" # q + a1 + q + a2 + q + a3 + q + a4 + q + a5 + ) def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: # Use provided contexts (from batch retrieval) @@ -48,7 +51,7 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: def _get_retrieval_params(self) -> Optional[Dict[str, Any]]: """Get LameR-specific retrieval parameters for batch retrieval. - + Returns retrieval parameters compatible with BaseSearcher interface. Supports both searcher_type/searcher_kwargs format and legacy format. 
""" @@ -60,53 +63,53 @@ def _get_retrieval_params(self) -> Optional[Dict[str, Any]]: "k": int(self.cfg.params.get("retrieval_k", 10)), "threads": int(self.cfg.params.get("threads", 16)), } - + # Check if using new searcher interface format with searcher_type if "searcher_type" in self.cfg.params: # New format: use searcher_type and searcher_kwargs searcher_type = self.cfg.params.get("searcher_type", "pyserini") searcher_kwargs = self.cfg.params.get("searcher_kwargs", {}) - + # If index is provided in params but not in searcher_kwargs, add it if "index" in self.cfg.params and "index" not in searcher_kwargs: searcher_kwargs["index"] = self.cfg.params["index"] - + return { "searcher_type": searcher_type, "searcher_kwargs": searcher_kwargs, "k": int(self.cfg.params.get("retrieval_k", 10)), "threads": int(self.cfg.params.get("threads", 16)), } - + # Legacy format: convert old-style params to new format index = self.cfg.params.get("index") if not index: return None - + # Build searcher_kwargs from legacy params searcher_kwargs = { "index": index, "searcher_type": "impact" if self.cfg.params.get("impact", False) else "bm25", "answer_key": self.cfg.params.get("answer_key", "contents"), } - + # Add BM25 parameters if specified k1 = self.cfg.params.get("k1") b = self.cfg.params.get("b") if k1 is not None and b is not None: searcher_kwargs["k1"] = k1 searcher_kwargs["b"] = b - + # Add RM3 if requested if self.cfg.params.get("rm3", False): searcher_kwargs["rm3"] = True - + # Add Rocchio if requested if self.cfg.params.get("rocchio", False): searcher_kwargs["rocchio"] = True if self.cfg.params.get("rocchio_use_negative", False): searcher_kwargs["rocchio_use_negative"] = True - + return { "searcher_type": "pyserini", # Default to pyserini for legacy format "searcher_kwargs": searcher_kwargs, diff --git a/querygym/methods/mugi.py b/querygym/methods/mugi.py index e8b8a7e..82b5996 100644 --- a/querygym/methods/mugi.py +++ b/querygym/methods/mugi.py @@ -3,14 +3,15 @@ from 
..core.base import BaseReformulator, QueryItem, ReformulationResult from ..core.registry import register_method + @register_method("mugi") class MuGI(BaseReformulator): """ MuGI (Multi-Text Generation Integration) Method - + Generates multiple diverse pseudo-documents per query and uses adaptive concatenation to balance term frequencies between short queries and long pseudo-docs. - + Key Parameters: - num_docs: Number of pseudo-documents to generate per query (default: 5) - adaptive_times: Divisor for adaptive repetition ratio (default: 5) @@ -19,11 +20,12 @@ class MuGI(BaseReformulator): - mode: "zs" for zero-shot or "fs" for few-shot (default: "zs") - prompt_id: (Advanced) Direct prompt ID override ("mugi.zeroshot.v1" or "mugi.fewshot.v1") - parallel: If True, generate all pseudo-docs in parallel; if False, sequential (default: False) - + Formula: repetition_times = (len(all_pseudo_docs) // len(query)) // adaptive_times enhanced_query = (query + ' ') * repetition_times + all_pseudo_docs """ + VERSION = "1.0" REQUIRES_CONTEXT = False CONCATENATION_STRATEGY = "adaptive_query_repeat_plus_generated" @@ -33,12 +35,14 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: num_docs = int(self.cfg.params.get("num_docs", 5)) adaptive_times = int(self.cfg.params.get("adaptive_times", 6)) max_tokens = int(self.cfg.params.get("max_tokens", self.cfg.llm.get("max_tokens", 1024))) - temperature = float(self.cfg.params.get("temperature", self.cfg.llm.get("temperature", 1.0))) + temperature = float( + self.cfg.params.get("temperature", self.cfg.llm.get("temperature", 1.0)) + ) parallel = bool(self.cfg.params.get("parallel", False)) - + # Mode selection: "zs" (zero-shot) or "fs" (few-shot) mode = str(self.cfg.params.get("mode", "zs")) - + # Map mode to prompt_id (or allow direct prompt_id override for advanced users) if "prompt_id" in self.cfg.params: # Advanced: allow direct prompt ID specification @@ -50,20 +54,22 @@ def reformulate(self, q: QueryItem, 
contexts=None) -> ReformulationResult: elif mode in ["zs", "zeroshot"]: prompt_id = "mugi.zeroshot.v1" else: - raise ValueError(f"Invalid mode '{mode}' for MuGI. Must be 'zs' (zero-shot) or 'fs' (few-shot).") - + raise ValueError( + f"Invalid mode '{mode}' for MuGI. Must be 'zs' (zero-shot) or 'fs' (few-shot)." + ) + # Generate multiple pseudo-documents for diversity and coverage pseudo_docs: List[str] = [] - + if parallel: # Parallel generation using ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor, as_completed - + def generate_single_doc(doc_idx: int) -> str: msgs = self.prompts.render(prompt_id, query=q.text) pseudo_doc = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens) return pseudo_doc.strip().strip('"').strip("'") - + # Generate all pseudo-docs in parallel with ThreadPoolExecutor(max_workers=num_docs) as executor: futures = [executor.submit(generate_single_doc, i) for i in range(num_docs)] @@ -74,25 +80,21 @@ def generate_single_doc(doc_idx: int) -> str: for i in range(num_docs): # Render prompt (mugi.zeroshot.v1 or mugi.fewshot.v1) msgs = self.prompts.render(prompt_id, query=q.text) - + # Generate pseudo-document with high temperature for diversity - pseudo_doc = self.llm.chat( - msgs, - temperature=temperature, - max_tokens=max_tokens - ) - + pseudo_doc = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens) + # Clean up the generated pseudo-document pseudo_doc = pseudo_doc.strip().strip('"').strip("'") pseudo_docs.append(pseudo_doc) - + # Join all pseudo-documents into single text - all_pseudo_docs = ' '.join(pseudo_docs) - + all_pseudo_docs = " ".join(pseudo_docs) + # Use base class concatenation with adaptive_query_repeat_plus_generated strategy # This automatically handles the MuGI formula and cleaning reformulated = self.concatenate_result(q.text, all_pseudo_docs) - + # Calculate metrics for metadata (after concatenation) query_len = len(q.text) docs_len = len(all_pseudo_docs) @@ -100,7 +102,7 @@ 
def generate_single_doc(doc_idx: int) -> str: repetition_times = max(1, (docs_len // query_len) // adaptive_times) else: repetition_times = 1 - + # Return result with metadata return ReformulationResult( q.qid, @@ -122,5 +124,5 @@ def generate_single_doc(doc_idx: int) -> str: "pseudo_doc_3": pseudo_docs[2] if len(pseudo_docs) > 2 else "", "pseudo_doc_4": pseudo_docs[3] if len(pseudo_docs) > 3 else "", "pseudo_doc_5": pseudo_docs[4] if len(pseudo_docs) > 4 else "", - } + }, ) diff --git a/querygym/methods/qa_expand.py b/querygym/methods/qa_expand.py index 5061bec..7a97a00 100644 --- a/querygym/methods/qa_expand.py +++ b/querygym/methods/qa_expand.py @@ -4,6 +4,7 @@ from ..core.base import BaseReformulator, QueryItem, ReformulationResult from ..core.registry import register_method + def parse_llm_json(text: str) -> dict: """ Robust JSON parser for LLM outputs that handles common formatting issues: @@ -16,83 +17,84 @@ def parse_llm_json(text: str) -> dict: """ if not text or not text.strip(): return {} - + original_text = text text = text.strip() - + # Remove markdown code blocks if text.startswith("```"): - first_newline = text.find('\n') + first_newline = text.find("\n") if first_newline != -1: - text = text[first_newline+1:] + text = text[first_newline + 1 :] if text.endswith("```"): text = text[:-3] text = text.strip() - + # Unescape common escape sequences that appear in raw output - text = text.replace('\\"', '"').replace('\\n', '\n').replace('\\t', '\t') - + text = text.replace('\\"', '"').replace("\\n", "\n").replace("\\t", "\t") + # Try standard JSON parsing first try: return json.loads(text) except json.JSONDecodeError: pass - + # Remove trailing commas before closing braces/brackets - cleaned = re.sub(r',(\s*[}\]])', r'\1', text) + cleaned = re.sub(r",(\s*[}\]])", r"\1", text) try: return json.loads(cleaned) except json.JSONDecodeError: pass - + # Replace single quotes with double quotes try: return json.loads(text.replace("'", '"')) except 
json.JSONDecodeError: pass - + # Try to fix incomplete JSON by adding missing closing quotes and braces - if '{' in text: + if "{" in text: # Find the last complete value and try to close it attempt = text.rstrip() - + # If ends mid-string (no closing quote), add closing quote if attempt.count('"') % 2 == 1: attempt += '"' - + # Add missing closing brace if needed - if not attempt.endswith('}'): - attempt += '}' - + if not attempt.endswith("}"): + attempt += "}" + try: return json.loads(attempt) except json.JSONDecodeError: pass - + # Fallback: Extract key-value pairs using regex # Looks for patterns like "key": "value" or 'key': 'value' result = {} - + # Pattern to match JSON key-value pairs patterns = [ r'["\'](\w+)["\']\s*:\s*["\']([^"\']*)["\']', # Standard JSON - r'(\w+)\s*:\s*["\']([^"\']*)["\']', # Unquoted keys - r'["\'](\w+)["\']\s*:\s*([^,}\]]+)', # Values without quotes + r'(\w+)\s*:\s*["\']([^"\']*)["\']', # Unquoted keys + r'["\'](\w+)["\']\s*:\s*([^,}\]]+)', # Values without quotes ] - + for pattern in patterns: matches = re.findall(pattern, text, re.DOTALL) for key, value in matches: if key and value: result[key] = value.strip() - + # If we found any key-value pairs, return them if result: return result - + # Absolute fallback: return empty dict return {} + @register_method("qa_expand") class QAExpand(BaseReformulator): """ @@ -102,6 +104,7 @@ class QAExpand(BaseReformulator): 3. Filter and refine answers based on relevance to query (LLM call #3) 4. 
Expand query: (Q × 3) + refined_answers """ + VERSION = "1.0" REQUIRES_CONTEXT = False # No contexts needed - pure generative expansion CONCATENATION_STRATEGY = "query_repeat_plus_generated" # (Q × 3) + refined_answers @@ -110,82 +113,86 @@ class QAExpand(BaseReformulator): def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: # K is fixed at 3 for QA-Expand k = 3 - + # Get configurable parameters max_tokens = self.cfg.params.get("max_tokens", self.cfg.llm.get("max_tokens", 256)) default_temp = self.cfg.llm.get("temperature", 0.8) - + # Temperature settings (can override per-step, defaults to global temperature) temp_subq = self.cfg.params.get("temperature_subq", default_temp) temp_answer = self.cfg.params.get("temperature_answer", default_temp) temp_refine = self.cfg.params.get("temperature_refine", default_temp) - + # Prompt IDs (customizable for different datasets) prompt_subq = self.cfg.params.get("prompt_subq", "qa_expand.subq.v1") prompt_answer = self.cfg.params.get("prompt_answer", "qa_expand.answer.v1") prompt_refine = self.cfg.params.get("prompt_refine", "qa_expand.refine.v1") - + # Step 1: Generate 3 sub-questions (LLM call #1) # Input: {query} # Output: Raw text with 3 questions msgs = self.prompts.render(prompt_subq, query=q.text) subqs_raw = self.llm.chat(msgs, temperature=temp_subq, max_tokens=max_tokens) - + # Format questions as JSON for next step # Try to parse if LLM already returned JSON, otherwise split by newlines questions_dict = parse_llm_json(subqs_raw) - - if questions_dict and any(k.startswith('question') for k in questions_dict.keys()): + + if questions_dict and any(k.startswith("question") for k in questions_dict.keys()): # LLM returned JSON format already questions_json = json.dumps(questions_dict) else: # Parse from text - split by newlines and clean up - lines = [line.strip() for line in subqs_raw.split('\n') if line.strip()] + lines = [line.strip() for line in subqs_raw.split("\n") if line.strip()] # Remove 
common prefixes like "1.", "Q1:", "-", etc. cleaned_lines = [] for line in lines: - line = re.sub(r'^[\d\.\-\*\)]+\s*', '', line) # Remove "1.", "1)", "-", "*" - line = re.sub(r'^[Qq]\d+[\:\.]?\s*', '', line) # Remove "Q1:", "q1." + line = re.sub(r"^[\d\.\-\*\)]+\s*", "", line) # Remove "1.", "1)", "-", "*" + line = re.sub(r"^[Qq]\d+[\:\.]?\s*", "", line) # Remove "Q1:", "q1." if line: cleaned_lines.append(line) - + # Take first 3 questions - questions_json = json.dumps({ - "question1": cleaned_lines[0] if len(cleaned_lines) > 0 else "", - "question2": cleaned_lines[1] if len(cleaned_lines) > 1 else "", - "question3": cleaned_lines[2] if len(cleaned_lines) > 2 else "" - }) - + questions_json = json.dumps( + { + "question1": cleaned_lines[0] if len(cleaned_lines) > 0 else "", + "question2": cleaned_lines[1] if len(cleaned_lines) > 1 else "", + "question3": cleaned_lines[2] if len(cleaned_lines) > 2 else "", + } + ) + # Step 2: Generate 3 pseudo-answers from LLM's internal knowledge (LLM call #2) # Input: {questions} (JSON format) # Output: JSON with {"answer1": "...", "answer2": "...", "answer3": "..."} # NOTE: No contexts used - LLM generates answers from its own knowledge msgs = self.prompts.render(prompt_answer, questions=questions_json) answers_json = self.llm.chat(msgs, temperature=temp_answer, max_tokens=max_tokens) - + # Step 3: Filter and refine answers (LLM call #3) # Input: {query}, {answers} (JSON format) # Output: Refined JSON with relevant answers only msgs = self.prompts.render(prompt_refine, query=q.text, answers=answers_json) refined_answers_json = self.llm.chat(msgs, temperature=temp_refine, max_tokens=max_tokens) - + # Parse refined JSON and extract answer values for concatenation refined_dict = parse_llm_json(refined_answers_json) if refined_dict: # Extract answer values in order (answer1, answer2, answer3) - answer_values = [str(v) for k, v in sorted(refined_dict.items()) if v and str(v).strip()] + answer_values = [ + str(v) for k, v in 
sorted(refined_dict.items()) if v and str(v).strip() + ] refined_text = " ".join(answer_values) else: # Fallback if JSON parsing fails - use raw text refined_text = refined_answers_json - + # Step 4: Expand query: (Q × 3) + refined_answers reformulated = self.concatenate_result(q.text, refined_text, query_repeats=k) - + return ReformulationResult( - q.qid, - q.text, - reformulated, + q.qid, + q.text, + reformulated, metadata={ "subquestions_raw": subqs_raw, "questions_json": questions_json, @@ -195,7 +202,7 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: "prompts_used": { "subq": prompt_subq, "answer": prompt_answer, - "refine": prompt_refine - } - } + "refine": prompt_refine, + }, + }, ) diff --git a/querygym/methods/query2doc.py b/querygym/methods/query2doc.py index ab6820c..c826de5 100644 --- a/querygym/methods/query2doc.py +++ b/querygym/methods/query2doc.py @@ -6,88 +6,90 @@ import os import csv + @register_method("query2doc") class Query2Doc(BaseReformulator): """ Query2Doc: Generate pseudo-documents from LLM knowledge. 
- + Modes: - "zs" (zero-shot): Generate passage directly [default] - "cot" (chain-of-thought): Zero-shot with reasoning - "fs", "fewshot", or "few-shot" (few-shot): Uses training examples from any dataset (dynamic) - + Few-Shot Config (via params or env vars): - dataset_type: "msmarco", "beir", or "generic" (uses appropriate loader) - num_examples: Number of few-shot examples (default: 4) - + For MS MARCO datasets (dataset_type="msmarco"): - collection_path / COLLECTION_PATH: Path to collection.tsv - train_queries_path / TRAIN_QUERIES_PATH: Path to queries.tsv - train_qrels_path / TRAIN_QRELS_PATH: Path to qrels file - + For BEIR datasets (dataset_type="beir"): - beir_data_dir / BEIR_DATA_DIR: Path to BEIR dataset directory - train_split: "train" or "dev" (default: "train") - + For generic datasets (dataset_type="generic" or omitted): - collection_path: TSV file (docid \t text) - train_queries_path: TSV file (qid \t query) - train_qrels_path: TREC format (qid 0 docid relevance) - + Note: MS MARCO env vars (MSMARCO_COLLECTION, etc.) are supported for backward compatibility. 
""" + VERSION = "1.0" CONCATENATION_STRATEGY = "query_repeat_plus_generated" DEFAULT_QUERY_REPEATS = 5 - + def __init__(self, cfg, llm_client, prompt_resolver): super().__init__(cfg, llm_client, prompt_resolver) - self._fewshot_data = None - + self._fewshot_data = None + def _load_fewshot_data(self): """Lazy load training data for few-shot mode (supports MS MARCO, BEIR, or generic datasets).""" if self._fewshot_data is not None: return self._fewshot_data - + try: # Get dataset type (msmarco, beir, or generic) dataset_type = self.cfg.params.get("dataset_type", "").lower() - + # Get paths from config params or environment variables collection_path = ( - self.cfg.params.get("collection_path") or - self.cfg.params.get("msmarco_collection") or - os.getenv("COLLECTION_PATH") or - os.getenv("MSMARCO_COLLECTION") + self.cfg.params.get("collection_path") + or self.cfg.params.get("msmarco_collection") + or os.getenv("COLLECTION_PATH") + or os.getenv("MSMARCO_COLLECTION") ) train_queries_path = ( - self.cfg.params.get("train_queries_path") or - self.cfg.params.get("msmarco_train_queries") or - os.getenv("TRAIN_QUERIES_PATH") or - os.getenv("MSMARCO_TRAIN_QUERIES") + self.cfg.params.get("train_queries_path") + or self.cfg.params.get("msmarco_train_queries") + or os.getenv("TRAIN_QUERIES_PATH") + or os.getenv("MSMARCO_TRAIN_QUERIES") ) train_qrels_path = ( - self.cfg.params.get("train_qrels_path") or - self.cfg.params.get("msmarco_train_qrels") or - os.getenv("TRAIN_QRELS_PATH") or - os.getenv("MSMARCO_TRAIN_QRELS") + self.cfg.params.get("train_qrels_path") + or self.cfg.params.get("msmarco_train_qrels") + or os.getenv("TRAIN_QRELS_PATH") + or os.getenv("MSMARCO_TRAIN_QRELS") ) - + # For BEIR, collection_path is actually the BEIR data directory if dataset_type == "beir": beir_data_dir = ( - self.cfg.params.get("beir_data_dir") or - collection_path or - os.getenv("BEIR_DATA_DIR") + self.cfg.params.get("beir_data_dir") + or collection_path + or os.getenv("BEIR_DATA_DIR") ) 
train_split = self.cfg.params.get("train_split", "train") - + if not beir_data_dir: raise RuntimeError( "Few-shot mode with BEIR requires beir_data_dir (via config or env var):\n" " - beir_data_dir / BEIR_DATA_DIR: Path to BEIR dataset directory" ) - + elif not all([collection_path, train_queries_path, train_qrels_path]): raise RuntimeError( "Few-shot mode requires training data paths (via config params or env vars):\n" @@ -102,10 +104,11 @@ def _load_fewshot_data(self): "\nFor backward compatibility, MS MARCO env vars are also supported:\n" " - MSMARCO_COLLECTION, MSMARCO_TRAIN_QUERIES, MSMARCO_TRAIN_QRELS" ) - + # Load data using appropriate loader based on dataset type if dataset_type == "beir": from ..loaders import beir + corpus = beir.load_corpus(beir_data_dir) # Convert BEIR corpus format (dict with title/text) to simple text collection = {} @@ -114,47 +117,48 @@ def _load_fewshot_data(self): title = doc_dict.get("title", "").strip() text = doc_dict.get("text", "").strip() collection[docid] = f"{title} {text}".strip() if title else text - + train_queries_list = beir.load_queries(beir_data_dir) train_queries_dict = {q.qid: q.text for q in train_queries_list} train_qrels = beir.load_qrels(beir_data_dir, split=train_split) - + elif dataset_type == "msmarco": from ..loaders import msmarco + collection = msmarco.load_collection(collection_path) train_queries_list = msmarco.load_queries(train_queries_path) train_queries_dict = {q.qid: q.text for q in train_queries_list} train_qrels = msmarco.load_qrels(train_qrels_path) - + else: # Generic: Load using DataLoader for maximum flexibility from ..data.dataloader import DataLoader - + # Load collection (TSV format: docid \t text) collection = {} - with open(collection_path, 'r', encoding='utf-8') as f: - reader = csv.reader(f, delimiter='\t') + with open(collection_path, "r", encoding="utf-8") as f: + reader = csv.reader(f, delimiter="\t") for row in reader: if len(row) >= 2: docid, text = row[0], row[1] 
collection[docid] = text - + # Load queries and qrels train_queries_list = DataLoader.load_queries(train_queries_path, format="tsv") train_queries_dict = {q.qid: q.text for q in train_queries_list} train_qrels = DataLoader.load_qrels(train_qrels_path, format="trec") - + self._fewshot_data = { "collection": collection, "train_queries": train_queries_dict, - "train_qrels": train_qrels + "train_qrels": train_qrels, } - + return self._fewshot_data - + except Exception as e: raise RuntimeError(f"Failed to load few-shot training data: {e}") - + def _select_few_shot_examples(self, num_examples: int = 4) -> List[Tuple[str, str]]: """Randomly sample relevant query-passage pairs from training data.""" try: @@ -162,43 +166,43 @@ def _select_few_shot_examples(self, num_examples: int = 4) -> List[Tuple[str, st collection = data["collection"] train_queries = data["train_queries"] train_qrels = data["train_qrels"] - + # Collect all relevant query-doc pairs relevant_pairs = [] for qid, doc_rels in train_qrels.items(): for docid, relevance in doc_rels.items(): if relevance > 0: # Only relevant relevant_pairs.append((qid, docid)) - + if not relevant_pairs: raise RuntimeError("No relevant query-document pairs found") - + # Sample pairs and fetch texts sample_size = min(num_examples * 10, len(relevant_pairs)) sampled_pairs = random.sample(relevant_pairs, sample_size) - + examples = [] for qid, docid in sampled_pairs: if len(examples) >= num_examples: break - + query_text = train_queries.get(qid) doc_text = collection.get(docid) - + if query_text and doc_text: examples.append((query_text, doc_text)) - + if not examples: raise RuntimeError("Could not find valid examples with matching IDs") - + if len(examples) < num_examples: print(f"Warning: Only found {len(examples)}/{num_examples} examples") - + return examples - + except Exception as e: raise RuntimeError(f"Failed to select few-shot examples: {e}") - + def _format_examples(self, examples: List[Tuple[str, str]]) -> str: """Format 
examples for prompt template.""" examples_text = "" @@ -215,9 +219,9 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: mode = "fs" temperature = float(self.cfg.llm.get("temperature", 0.7)) max_tokens = int(self.cfg.llm.get("max_tokens", 256)) - + metadata = {"mode": mode} - + try: # Select prompt based on mode if mode == "fs": @@ -225,8 +229,10 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: num_examples = int(self.cfg.params.get("num_examples", 4)) examples = self._select_few_shot_examples(num_examples) examples_text = self._format_examples(examples) - - msgs = self.prompts.render("query2doc.fewshot.v1", query=q.text, examples=examples_text) + + msgs = self.prompts.render( + "query2doc.fewshot.v1", query=q.text, examples=examples_text + ) metadata["prompt_id"] = "query2doc.fewshot.v1" metadata["num_examples"] = len(examples) else: @@ -234,30 +240,27 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: prompt_id = "query2doc.cot.v1" if mode == "cot" else "query2doc.zeroshot.v1" msgs = self.prompts.render(prompt_id, query=q.text) metadata["prompt_id"] = prompt_id - + # Generate pseudo-document using LLM pseudo_doc = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens) - + if not pseudo_doc or not pseudo_doc.strip(): raise RuntimeError("LLM returned empty response") - + # Use base class concatenation (query_repeat_plus_generated) reformulated = self.concatenate_result(q.text, pseudo_doc) - - metadata.update({ - "pseudo_doc": pseudo_doc, - "temperature": temperature, - "max_tokens": max_tokens - }) - + + metadata.update( + {"pseudo_doc": pseudo_doc, "temperature": temperature, "max_tokens": max_tokens} + ) + return ReformulationResult(q.qid, q.text, reformulated, metadata=metadata) - + except Exception as e: # Graceful error handling - fallback to original query error_msg = f"Query2Doc failed: {e}" print(f"Error for qid={q.qid}: {error_msg}") - + return 
ReformulationResult( - q.qid, q.text, q.text, - metadata={"mode": mode, "error": error_msg, "fallback": True} + q.qid, q.text, q.text, metadata={"mode": mode, "error": error_msg, "fallback": True} ) diff --git a/querygym/methods/query2e.py b/querygym/methods/query2e.py index 5deee01..c2019ee 100644 --- a/querygym/methods/query2e.py +++ b/querygym/methods/query2e.py @@ -6,42 +6,44 @@ import os import csv + @register_method("query2e") class Query2E(BaseReformulator): """ Query2E: Query to keyword expansion. - + Modes: - zs (zero-shot): Simple keyword generation - fs (few-shot): Uses training examples for keyword generation - + Formula: (query × 5) + keywords - + Few-Shot Examples (3 ways to provide): 1. Via --ctx-jsonl CLI flag (JSONL file with {"query": "...", "passage": "..."} per line) 2. Via params["examples"] (list of {"query": "...", "passage": "..."} dicts) 3. Auto-generate from training data (if no examples provided) - + Few-Shot Auto-Generation Config (via params or env vars): - dataset_type: "msmarco", "beir", or "generic" (uses appropriate loader) - num_examples: Number of few-shot examples (default: 4) - + For MS MARCO datasets (dataset_type="msmarco"): - collection_path / COLLECTION_PATH: Path to collection.tsv - train_queries_path / TRAIN_QUERIES_PATH: Path to queries.tsv - train_qrels_path / TRAIN_QRELS_PATH: Path to qrels file - + For BEIR datasets (dataset_type="beir"): - beir_data_dir / BEIR_DATA_DIR: Path to BEIR dataset directory - train_split: "train" or "dev" (default: "train") - + For generic datasets (dataset_type="generic" or omitted): - collection_path: TSV file (docid \t text) - train_queries_path: TSV file (qid \t query) - train_qrels_path: TREC format (qid 0 docid relevance) - + Note: MS MARCO env vars (MSMARCO_COLLECTION, etc.) are supported for backward compatibility. 
""" + VERSION = "1.0" CONCATENATION_STRATEGY = "query_repeat_plus_generated" DEFAULT_QUERY_REPEATS = 5 @@ -51,51 +53,51 @@ def __init__(self, cfg, llm_client, prompt_resolver): self._fewshot_data = None # User-provided examples (set via set_examples() or params["examples"]) self._provided_examples = cfg.params.get("examples", None) - + def _load_fewshot_data(self): """Lazy load training data for few-shot mode (supports MS MARCO, BEIR, or generic datasets).""" if self._fewshot_data is not None: return self._fewshot_data - + try: # Get dataset type (msmarco, beir, or generic) dataset_type = self.cfg.params.get("dataset_type", "").lower() - + # Get paths from config params or environment variables collection_path = ( - self.cfg.params.get("collection_path") or - self.cfg.params.get("msmarco_collection") or - os.getenv("COLLECTION_PATH") or - os.getenv("MSMARCO_COLLECTION") + self.cfg.params.get("collection_path") + or self.cfg.params.get("msmarco_collection") + or os.getenv("COLLECTION_PATH") + or os.getenv("MSMARCO_COLLECTION") ) train_queries_path = ( - self.cfg.params.get("train_queries_path") or - self.cfg.params.get("msmarco_train_queries") or - os.getenv("TRAIN_QUERIES_PATH") or - os.getenv("MSMARCO_TRAIN_QUERIES") + self.cfg.params.get("train_queries_path") + or self.cfg.params.get("msmarco_train_queries") + or os.getenv("TRAIN_QUERIES_PATH") + or os.getenv("MSMARCO_TRAIN_QUERIES") ) train_qrels_path = ( - self.cfg.params.get("train_qrels_path") or - self.cfg.params.get("msmarco_train_qrels") or - os.getenv("TRAIN_QRELS_PATH") or - os.getenv("MSMARCO_TRAIN_QRELS") + self.cfg.params.get("train_qrels_path") + or self.cfg.params.get("msmarco_train_qrels") + or os.getenv("TRAIN_QRELS_PATH") + or os.getenv("MSMARCO_TRAIN_QRELS") ) - + # For BEIR, collection_path is actually the BEIR data directory if dataset_type == "beir": beir_data_dir = ( - self.cfg.params.get("beir_data_dir") or - collection_path or - os.getenv("BEIR_DATA_DIR") + 
self.cfg.params.get("beir_data_dir") + or collection_path + or os.getenv("BEIR_DATA_DIR") ) train_split = self.cfg.params.get("train_split", "train") - + if not beir_data_dir: raise RuntimeError( "Few-shot mode with BEIR requires beir_data_dir (via config or env var):\n" " - beir_data_dir / BEIR_DATA_DIR: Path to BEIR dataset directory" ) - + elif not all([collection_path, train_queries_path, train_qrels_path]): raise RuntimeError( "Few-shot mode requires training data paths (via config params or env vars):\n" @@ -110,10 +112,11 @@ def _load_fewshot_data(self): "\nFor backward compatibility, MS MARCO env vars are also supported:\n" " - MSMARCO_COLLECTION, MSMARCO_TRAIN_QUERIES, MSMARCO_TRAIN_QRELS" ) - + # Load data using appropriate loader based on dataset type if dataset_type == "beir": from ..loaders import beir + corpus = beir.load_corpus(beir_data_dir) # Convert BEIR corpus format (dict with title/text) to simple text collection = {} @@ -122,47 +125,48 @@ def _load_fewshot_data(self): title = doc_dict.get("title", "").strip() text = doc_dict.get("text", "").strip() collection[docid] = f"{title} {text}".strip() if title else text - + train_queries_list = beir.load_queries(beir_data_dir) train_queries_dict = {q.qid: q.text for q in train_queries_list} train_qrels = beir.load_qrels(beir_data_dir, split=train_split) - + elif dataset_type == "msmarco": from ..loaders import msmarco + collection = msmarco.load_collection(collection_path) train_queries_list = msmarco.load_queries(train_queries_path) train_queries_dict = {q.qid: q.text for q in train_queries_list} train_qrels = msmarco.load_qrels(train_qrels_path) - + else: # Generic: Load using DataLoader for maximum flexibility from ..data.dataloader import DataLoader - + # Load collection (TSV format: docid \t text) collection = {} - with open(collection_path, 'r', encoding='utf-8') as f: - reader = csv.reader(f, delimiter='\t') + with open(collection_path, "r", encoding="utf-8") as f: + reader = csv.reader(f, 
delimiter="\t") for row in reader: if len(row) >= 2: docid, text = row[0], row[1] collection[docid] = text - + # Load queries and qrels train_queries_list = DataLoader.load_queries(train_queries_path, format="tsv") train_queries_dict = {q.qid: q.text for q in train_queries_list} train_qrels = DataLoader.load_qrels(train_qrels_path, format="trec") - + self._fewshot_data = { "collection": collection, "train_queries": train_queries_dict, - "train_qrels": train_qrels + "train_qrels": train_qrels, } - + return self._fewshot_data - + except Exception as e: raise RuntimeError(f"Failed to load few-shot training data: {e}") - + def _select_few_shot_examples(self, num_examples: int = 4) -> List[Tuple[str, str]]: """Randomly sample relevant query-document pairs from training data.""" try: @@ -170,43 +174,43 @@ def _select_few_shot_examples(self, num_examples: int = 4) -> List[Tuple[str, st collection = data["collection"] train_queries = data["train_queries"] train_qrels = data["train_qrels"] - + # Collect all relevant query-doc pairs relevant_pairs = [] for qid, doc_rels in train_qrels.items(): for docid, relevance in doc_rels.items(): if relevance > 0: # Only relevant relevant_pairs.append((qid, docid)) - + if not relevant_pairs: raise RuntimeError("No relevant query-document pairs found") - + # Sample pairs and fetch texts sample_size = min(num_examples * 10, len(relevant_pairs)) sampled_pairs = random.sample(relevant_pairs, sample_size) - + examples = [] for qid, docid in sampled_pairs: if len(examples) >= num_examples: break - + query_text = train_queries.get(qid) doc_text = collection.get(docid) - + if query_text and doc_text: examples.append((query_text, doc_text)) - + if not examples: raise RuntimeError("Could not find valid examples with matching IDs") - + if len(examples) < num_examples: print(f"Warning: Only found {len(examples)}/{num_examples} examples") - + return examples - + except Exception as e: raise RuntimeError(f"Failed to select few-shot examples: {e}") - 
+ def _format_examples(self, examples: List[Tuple[str, str]]) -> str: """Format examples for prompt template (query -> keywords extraction).""" examples_text = "" @@ -215,11 +219,11 @@ def _format_examples(self, examples: List[Tuple[str, str]]) -> str: # In practice, you might want more sophisticated keyword extraction words = passage.lower().split() # Take first 5-7 meaningful words as keywords (simple heuristic) - keywords = [w.strip('.,!?;:') for w in words[:7] if len(w) > 3] + keywords = [w.strip(".,!?;:") for w in words[:7] if len(w) > 3] keywords_str = ", ".join(keywords[:5]) if keywords else "keywords, terms, phrases" - + examples_text += f"Query: {query}\nKeywords: {keywords_str}\n" - + return examples_text def _parse_keywords(self, raw_output: str) -> List[str]: @@ -229,41 +233,46 @@ def _parse_keywords(self, raw_output: str) -> List[str]: - Bullet points: "- keyword1\n- keyword2" - Numbered lists: "1. keyword1\n2. keyword2" - Mixed formats - + Returns: List of cleaned keyword strings """ import re - + if not raw_output or not raw_output.strip(): return [] - + # Normalize the text text = raw_output.strip() - + # Remove common prefixes like "Keywords:", "Here are the keywords:", etc. - text = re.sub(r'^(keywords|here are|the keywords|list of keywords)[:\s]*', '', text, flags=re.IGNORECASE) - + text = re.sub( + r"^(keywords|here are|the keywords|list of keywords)[:\s]*", + "", + text, + flags=re.IGNORECASE, + ) + keywords = [] - + # Check if it's a bullet/numbered list format - if '\n' in text or text.strip().startswith('-') or re.match(r'^\d+\.', text.strip()): + if "\n" in text or text.strip().startswith("-") or re.match(r"^\d+\.", text.strip()): # Split by newlines - lines = text.split('\n') + lines = text.split("\n") for line in lines: line = line.strip() if not line: continue - + # Remove bullet points: -, *, •, ▪, etc. 
- line = re.sub(r'^[\-\*•▪→►]\s*', '', line) - + line = re.sub(r"^[\-\*•▪→►]\s*", "", line) + # Remove numbered prefixes: 1., 1), (1), etc. - line = re.sub(r'^[\(\[]?\d+[\)\]\.:\-]?\s*', '', line) - + line = re.sub(r"^[\(\[]?\d+[\)\]\.:\-]?\s*", "", line) + # If line contains commas, it might be multiple keywords - if ',' in line: - for part in line.split(','): + if "," in line: + for part in line.split(","): part = part.strip() if part: keywords.append(part) @@ -271,33 +280,33 @@ def _parse_keywords(self, raw_output: str) -> List[str]: keywords.append(line) else: # Assume comma-separated format - for part in text.split(','): + for part in text.split(","): part = part.strip() if part: keywords.append(part) - + # Clean each keyword cleaned_keywords = [] for kw in keywords: # Remove quotes - kw = kw.strip('"\'`') + kw = kw.strip("\"'`") # Remove trailing punctuation - kw = kw.rstrip('.,;:!?') + kw = kw.rstrip(".,;:!?") # Remove leading/trailing whitespace kw = kw.strip() # Skip empty or very short keywords if kw and len(kw) > 1: cleaned_keywords.append(kw) - + return cleaned_keywords def set_examples(self, examples: List[Tuple[str, str]]) -> None: """ Set user-provided examples for few-shot mode. - + Args: examples: List of (query, passage) tuples or list of dicts with 'query' and 'passage' keys - + Example: >>> reformulator.set_examples([ ... 
("how long is flea life cycle?", "The life cycle of a flea..."), @@ -315,15 +324,15 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: mode = str(self.cfg.params.get("mode", "zs")) # Default to zero-shot temperature = float(self.cfg.llm.get("temperature", 0.3)) max_tokens = int(self.cfg.llm.get("max_tokens", 256)) - + metadata = {"mode": mode} - + try: # Select prompt based on mode if mode in ["fs", "fewshot"]: # Few-shot: use provided examples or auto-generate from training data num_examples = int(self.cfg.params.get("num_examples", 4)) - + if self._provided_examples: # Use user-provided examples # Convert dict format to tuple format if needed @@ -338,9 +347,9 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: # Auto-generate from training data (original behavior) examples = self._select_few_shot_examples(num_examples) metadata["examples_source"] = "auto_generated" - + examples_text = self._format_examples(examples) - + msgs = self.prompts.render("q2e.fs.v1", query=q.text, examples=examples_text) metadata["prompt_id"] = "q2e.fs.v1" metadata["num_examples"] = len(examples) @@ -349,43 +358,37 @@ def reformulate(self, q: QueryItem, contexts=None) -> ReformulationResult: msgs = self.prompts.render(prompt_id, query=q.text) metadata["prompt_id"] = prompt_id else: - raise ValueError(f"Invalid mode '{mode}' for Query2E. Must be 'zs' (zero-shot) or 'fs' (few-shot).") - + raise ValueError( + f"Invalid mode '{mode}' for Query2E. Must be 'zs' (zero-shot) or 'fs' (few-shot)." 
+ ) + # Generate keywords out = self.llm.chat(msgs, temperature=temperature, max_tokens=max_tokens) - + # Parse keywords using robust parser terms = self._parse_keywords(out) - + # Limit to max 20 keywords max_keywords = int(self.cfg.params.get("max_keywords", 20)) if len(terms) > max_keywords: terms = terms[:max_keywords] - + generated_content = " ".join(terms) - + # Concatenate: query + keywords reformulated = self.concatenate_result(q.text, generated_content) - - metadata.update({ - "keywords": terms, - "temperature": temperature, - "max_tokens": max_tokens - }) - - return ReformulationResult( - q.qid, - q.text, - reformulated, - metadata=metadata + + metadata.update( + {"keywords": terms, "temperature": temperature, "max_tokens": max_tokens} ) - + + return ReformulationResult(q.qid, q.text, reformulated, metadata=metadata) + except Exception as e: # Graceful error handling - fallback to original query error_msg = f"Query2E failed: {e}" print(f"Error for qid={q.qid}: {error_msg}") - + return ReformulationResult( - q.qid, q.text, q.text, - metadata={"mode": mode, "error": error_msg, "fallback": True} + q.qid, q.text, q.text, metadata={"mode": mode, "error": error_msg, "fallback": True} ) diff --git a/tests/test_cli_script_gen.py b/tests/test_cli_script_gen.py index f28fcd2..a4c8c8c 100644 --- a/tests/test_cli_script_gen.py +++ b/tests/test_cli_script_gen.py @@ -1,5 +1,8 @@ from querygym.cli import build_script_lines + def test_script_lines(): - lines = build_script_lines(index_path="/idx", topics="/q.tsv", run="/run.txt", qrels="/qrels.txt") + lines = build_script_lines( + index_path="/idx", topics="/q.tsv", run="/run.txt", qrels="/qrels.txt" + ) assert any("--index /idx" in ln for ln in lines) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 902319a..e32a88a 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -2,6 +2,7 @@ from pathlib import Path import tempfile + def test_local_tsv_loading(): with 
tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".tsv") as f: f.write("1\tapple\n2\tbanana\n") diff --git a/tests/test_methods_genqr.py b/tests/test_methods_genqr.py index 078f5a7..c34616e 100644 --- a/tests/test_methods_genqr.py +++ b/tests/test_methods_genqr.py @@ -3,12 +3,16 @@ from querygym.methods.genqr_ensemble import GenQREnsemble from pathlib import Path + class DummyLLM: def chat(self, messages, **kwargs): return "bertopic" + def test_genqr_ensemble(): - cfg = MethodConfig(name="genqr_ensemble", params={"repeat_query_weight":2}, llm={"model":"dummy"}) + cfg = MethodConfig( + name="genqr_ensemble", params={"repeat_query_weight": 2}, llm={"model": "dummy"} + ) llm = DummyLLM() pb = PromptBank(Path(__file__).parents[1] / "querygym" / "prompt_bank.yaml") meth = GenQREnsemble(cfg, llm, pb) diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 7cb9681..004f242 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -1,6 +1,7 @@ from querygym.core.prompts import PromptBank from pathlib import Path + def test_prompt_bank_loads(): pb = PromptBank(Path(__file__).parents[1] / "querygym" / "prompt_bank.yaml") assert len(pb.list()) > 0