diff --git a/.gitignore b/.gitignore index 57252510..b031c1a1 100644 --- a/.gitignore +++ b/.gitignore @@ -183,4 +183,12 @@ examples/db/10.* tests example/ applications -vlm_test \ No newline at end of file +vlm_test +examples/vlm_piezo_test + +# Test results +db +results +elsevier_test.xml +springer_test.xml +wiley_test.pdf diff --git a/CHANGELOG.md b/CHANGELOG.md index e9afc18e..3695ada2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,13 @@ -## [Unreleased] +# Unreleased +- New `value_error_thresholds` parameter added to both `evaluate_semantic()` and `evaluate_agentic()` for range-based absolute error tolerances on numeric property value comparisons: -### Added + - Accepts a dict mapping `(min, max)` tuples to absolute error thresholds. When a ground-truth value falls inside a range, the extracted value is accepted if `|extracted - ground_truth| ≤ threshold`. Values outside all configured ranges fall back to exact comparison. + + - **Semantic evaluation**: handled inside `_is_value_in_range()` via the new `_get_error_threshold()` helper in `MaterialsDataSemanticEvaluator`. + + - **Agentic evaluation**: a new `GetValueErrorThresholdTool` (CrewAI `BaseTool`) is added to the composition evaluator agent when thresholds are configured. The agent calls this tool with the reference value to retrieve the tolerance before deciding on each numeric match. No tool is added and no prompt changes are made when no thresholds are provided. + +- Exposed `value_error_thresholds` in public evaluation methods: `ComProScanner.evaluate_semantic()`, `ComProScanner.evaluate_agentic()`, `comproscanner.evaluate_semantic()`, and `comproscanner.evaluate_agentic()`. - VLM-based graph data extraction added across all publishers and PDF processors: @@ -12,7 +19,25 @@ - New unit tests added for all three agent tools in `tests/test_agent_tools/`. 
-## [0.1.5] - 08-02-2026 +### Fixed + +- `process_articles()` now routes user-provided `doi_list` by `general_publisher` from metadata and sends each DOI only to its matching source processor. + +--- +## [0.1.6] - 2026-04-02 +### Changed +- Updated [README.md](README.md), [CITATION.cff](CITATION.cff) and docs with the published version (advance article) of the ComProScanner paper in _Digital Discovery_ as fully open access: + - [ComProScanner: a multi-agent based framework for composition-property structured data extraction from scientific literature](https://doi.org/10.1039/D5DD00521C) + +### Added +- Guide for API key creation for various LLM providers and publisher APIs added to the documentation at `docs/getting-started/api-key-guide.md` with detailed instructions for each provider. + +### Fixed +- Model prefix handling in `rag_tool.py` standardized to reflect the docs. +- `HF_TOKEN` documentation clarified as optional — only required for gated or private Hugging Face models. + +--- +## [0.1.5] - 2026-02-08 ### Added - Data related to comparison with other agentic data extraction frameworks added for the ComProScanner paper in the `examples/piezo_test/comparing_existing_frameworks` folder. @@ -83,7 +108,8 @@ - README badges section converted from HTML to markdown format for better compatibility across platforms. 
-## [0.1.4] - 02-12-2025 +--- +## [0.1.4] - 2025-12-02 ### Added @@ -118,7 +144,8 @@ - [ComProScanner Logo](https://raw.githubusercontent.com/aritraroy24/ComProScanner/main/assets/comproscanner_logo.png) - [ComProScanner Workflow](https://raw.githubusercontent.com/aritraroy24/ComProScanner/main/assets/overall_workflow.png) -## [0.1.3] - 04-11-2025 +--- +## [0.1.3] - 2025-11-04 ### Fixed @@ -126,14 +153,16 @@ - Changed from `from langchain.text_splitter import RecursiveCharacterTextSplitter` - To `from langchain.text_splitter.recursive_character import RecursiveCharacterTextSplitter` -## [0.1.2] - 24-10-2025 +--- +## [0.1.2] - 2025-10-24 ### Added - Link to ComProScanner preprint on arXiv in the documentation index page and README.md: - [arXiv:2510.20362](https://arxiv.org/abs/2510.20362) -## [0.1.1] - 22-10-2025 +--- +## [0.1.1] - 2025-10-22 ### Fixed @@ -141,7 +170,8 @@ - [ComProScanner Logo](https://i.ibb.co/whHSbGvT/comproscanner-logo.png) - [ComProScanner Workflow](https://i.ibb.co/QWd2qd3/overall-workflow.png) -## [0.1.0] - 22-10-2025 +--- +## [0.1.0] - 2025-10-22 ### Added diff --git a/CITATION.cff b/CITATION.cff index 723c2600..56b4a26c 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -16,7 +16,7 @@ contact: - family-names: Roy given-names: Aritra orcid: "https://orcid.org/0000-0002-4928-2935" -message: If you use this software, please cite our article on arXiv. +message: If you use this software, please cite our article in Digital Discovery. 
preferred-citation: authors: - family-names: Roy @@ -31,21 +31,28 @@ preferred-citation: - family-names: Gattinoni given-names: Chiara orcid: "https://orcid.org/0000-0002-3376-6374" - date-published: 2025-10-23 + doi: "10.1039/D5DD00521C" identifiers: + - type: doi + value: "10.1039/D5DD00521C" + description: "Peer-reviewed article" - type: other value: "arXiv:2510.20362" description: "arXiv preprint" - title: "ComProScanner: A multi-agent based framework for composition-property structured data extraction from scientific literature" + journal: "Digital Discovery" + publisher: + name: "RSC" + status: advance-online + title: "ComProScanner: a multi-agent based framework for composition-property structured data extraction from scientific literature" type: article - url: "https://arxiv.org/abs/2510.20362" + url: "https://doi.org/10.1039/D5DD00521C" repository-code: "https://github.com/slimeslab/ComProScanner" license: MIT title: "ComProScanner: A multi-agent based framework for composition-property structured data extraction from scientific literature" type: software url: "https://slimeslab.github.io/ComProScanner/" -version: "0.1.4" -date-released: 2025-12-03 +version: "0.1.6" +date-released: 2026-04-02 keywords: - materials science - data extraction diff --git a/README.md b/README.md index 2d5a4e05..b1609a5a 100644 --- a/README.md +++ b/README.md @@ -169,14 +169,15 @@ eval_visualizer.plot_multiple_radar_charts( If you use ComProScanner in your research, please cite: ```bibtex -@misc{roy2025comproscannermultiagentbasedframework, - title={ComProScanner: A multi-agent based framework for composition-property structured data extraction from scientific literature}, - author={Aritra Roy and Enrico Grisan and John Buckeridge and Chiara Gattinoni}, - year={2025}, - eprint={2510.20362}, - archivePrefix={arXiv}, - primaryClass={physics.comp-ph}, - url={https://arxiv.org/abs/2510.20362}, +@Article{roy2026comproscannermultiagentbasedframework, +author ="Roy, Aritra and Grisan, 
Enrico and Buckeridge, John and Gattinoni, Chiara", +title ="ComProScanner: a multi-agent based framework for composition-property structured data extraction from scientific literature", +journal ="Digital Discovery", +year ="2026", +pages ="Accepted", +publisher ="RSC", +doi ="10.1039/D5DD00521C", +url ="https://doi.org/10.1039/D5DD00521C" } ``` diff --git a/docs/about/changelog.md b/docs/about/changelog.md index eb6852fa..438b8421 100644 --- a/docs/about/changelog.md +++ b/docs/about/changelog.md @@ -1,6 +1,42 @@ -## Unreleased +# Unreleased +- New `value_error_thresholds` parameter added to both `evaluate_semantic()` and `evaluate_agentic()` for range-based absolute error tolerances on numeric property value comparisons: + + - Accepts a dict mapping `(min, max)` tuples to absolute error thresholds. When a ground-truth value falls inside a range, the extracted value is accepted if `|extracted - ground_truth| ≤ threshold`. Values outside all configured ranges fall back to exact comparison. + + - **Semantic evaluation**: handled inside `_is_value_in_range()` via the new `_get_error_threshold()` helper in `MaterialsDataSemanticEvaluator`. + + - **Agentic evaluation**: a new `GetValueErrorThresholdTool` (CrewAI `BaseTool`) is added to the composition evaluator agent when thresholds are configured. The agent calls this tool with the reference value to retrieve the tolerance before deciding on each numeric match. No tool is added and no prompt changes are made when no thresholds are provided. + +- Exposed `value_error_thresholds` in public evaluation methods: `ComProScanner.evaluate_semantic()`, `ComProScanner.evaluate_agentic()`, `comproscanner.evaluate_semantic()`, and `comproscanner.evaluate_agentic()`. 
+ +- VLM-based graph data extraction added across all publishers and PDF processors: + + - New `GraphExtractorTool` — a CrewAI agent tool that reads saved figures for a given DOI and uses a vision LLM to extract composition-property value pairs from graphs and charts. Default VLM: `gemini/gemini-3-flash-preview`. + + - New `FigureExtractor` utility — shared helper for caption keyword-based figure filtering and saving, used by all article processors. + + - New `caption_keywords` parameter in `process_articles()` and `extract_composition_property_data()`, and new `vlm_model` and `related_figures_base_path` parameters in `extract_composition_property_data()`. + +- New unit tests added for all three agent tools in `tests/test_agent_tools/`. + +### Fixed + +- `process_articles()` now routes user-provided `doi_list` by `general_publisher` from metadata and sends each DOI only to its matching source processor. + +--- +## [0.1.6] - 2026-04-02 +### Changed +- Updated [README.md](README.md), [CITATION.cff](CITATION.cff) and docs with the published version (advance article) of the ComProScanner paper in _Digital Discovery_ as fully open access: + - [ComProScanner: a multi-agent based framework for composition-property structured data extraction from scientific literature](https://doi.org/10.1039/D5DD00521C) + +### Added +- Guide for API key creation for various LLM providers and publisher APIs added to the documentation at `docs/getting-started/api-key-guide.md` with detailed instructions for each provider. + +--- +## [0.1.5] - 2026-02-08 ### Added +- Data related to comparison with other agentic data extraction frameworks added for the ComProScanner paper in the `examples/piezo_test/comparing_existing_frameworks` folder. - New parameter `apply_advanced_cleaning` added to data cleaning methods in `data_cleaner.py`. When set to `True`, it triggers the advanced cleaning pipeline. 
@@ -38,6 +74,11 @@ - [CITATION.cff](https://github.com/slimeslab/ComProScanner/blob/main/CITATION.cff) added for standardized citation information based on the latest release and arXiv preprint. ### Fixed +- OAWorks API is replaced with OpenAlex API as OAWorks is no longer available. + +- Empty/corrupted PDF handled in `pdf_processor.py` and `wiley_processor.py` to avoid having GLYPH errors during text extraction. + +- Data extraction failures fixed if composition-property text data is empty. - CSV progress tracking in `elsevier_processor.py`: @@ -63,7 +104,8 @@ - README badges section converted from HTML to markdown format for better compatibility across platforms. -## [0.1.4] - 02-12-2025 +--- +## [0.1.4] - 2025-12-02 ### Added @@ -94,9 +136,12 @@ ### Changed -- README images updated with raw GitHub links for better reliability: [ComProScanner Logo](https://raw.githubusercontent.com/aritraroy24/ComProScanner/main/assets/comproscanner_logo.png), [ComProScanner Workflow](https://raw.githubusercontent.com/aritraroy24/ComProScanner/main/assets/overall_workflow.png) +- README images updated with raw GitHub links for better reliability: + - [ComProScanner Logo](https://raw.githubusercontent.com/aritraroy24/ComProScanner/main/assets/comproscanner_logo.png) + - [ComProScanner Workflow](https://raw.githubusercontent.com/aritraroy24/ComProScanner/main/assets/overall_workflow.png) -## [0.1.3] - 04-11-2025 +--- +## [0.1.3] - 2025-11-04 ### Fixed @@ -104,19 +149,25 @@ - Changed from `from langchain.text_splitter import RecursiveCharacterTextSplitter` - To `from langchain.text_splitter.recursive_character import RecursiveCharacterTextSplitter` -## [0.1.2] - 24-10-2025 +--- +## [0.1.2] - 2025-10-24 ### Added -- Link to ComProScanner preprint on arXiv in the documentation index page and README.md: [arXiv:2510.20362](https://arxiv.org/abs/2510.20362) +- Link to ComProScanner preprint on arXiv in the documentation index page and README.md: + - 
[arXiv:2510.20362](https://arxiv.org/abs/2510.20362) -## [0.1.1] - 22-10-2025 +--- +## [0.1.1] - 2025-10-22 ### Fixed -- README images updated with external image link to fix PyPI rendering issue. [ComProScanner Logo](https://i.ibb.co/whHSbGvT/comproscanner-logo.png), [ComProScanner Workflow](https://i.ibb.co/QWd2qd3/overall-workflow.png) +- README images updated with external image link to fix PyPI rendering issue. + - [ComProScanner Logo](https://i.ibb.co/whHSbGvT/comproscanner-logo.png) + - [ComProScanner Workflow](https://i.ibb.co/QWd2qd3/overall-workflow.png) -## [0.1.0] - 22-10-2025 +--- +## [0.1.0] - 2025-10-22 ### Added diff --git a/docs/about/citation.md b/docs/about/citation.md index b0d3be68..977b3882 100644 --- a/docs/about/citation.md +++ b/docs/about/citation.md @@ -3,13 +3,14 @@ If you use ComProScanner in your research, please cite our related paper: ```bibtex -@misc{roy2025comproscannermultiagentbasedframework, - title={ComProScanner: A multi-agent based framework for composition-property structured data extraction from scientific literature}, - author={Aritra Roy and Enrico Grisan and John Buckeridge and Chiara Gattinoni}, - year={2025}, - eprint={2510.20362}, - archivePrefix={arXiv}, - primaryClass={physics.comp-ph}, - url={https://arxiv.org/abs/2510.20362}, +@Article{roy2026comproscannermultiagentbasedframework, +author ="Roy, Aritra and Grisan, Enrico and Buckeridge, John and Gattinoni, Chiara", +title ="ComProScanner: a multi-agent based framework for composition-property structured data extraction from scientific literature", +journal ="Digital Discovery", +year ="2026", +pages ="Accepted", +publisher ="RSC", +doi ="10.1039/D5DD00521C", +url ="https://doi.org/10.1039/D5DD00521C" } ``` diff --git a/docs/getting-started/api-key-guide.md b/docs/getting-started/api-key-guide.md new file mode 100644 index 00000000..81645188 --- /dev/null +++ b/docs/getting-started/api-key-guide.md @@ -0,0 +1,323 @@ +# API Key Guide + +This page explains which 
provider credentials ComProScanner can use, what each one is used for, and how to generate or obtain them. + +## Overview + +ComProScanner can work with three groups of external providers: + +!!! important "Which credentials do you actually need?" + + | Provider group | Requirement level | + | --- | --- | + | **Publisher/content providers for article access** | **Optional** for manual or local workflows, but **required** for automated article retrieval. | + | **LLM providers for extraction, vision models and RAG** | **At least one required** for extraction, vision models and RAG workflows. However, default models are different for extraction/RAG and vision-language models. | + | **Default embedding provider for vector database creation** | **Required** unless you configure a custom embedding provider. | + +Use only the providers relevant to your workflow. You do not need every key. + +## Publisher Providers + +### Elsevier / Scopus + +Environment variable: `SCOPUS_API_KEY` + +Used for: + +- Scopus-based metadata retrieval +- Elsevier article retrieval in XML format + +How to get it: + +1. Create or sign in to your [Elsevier developer account](https://dev.elsevier.com/). +2. Open the [API key management area](https://dev.elsevier.com/apikey/manage). +3. Create a key for Scopus or content APIs. +4. Copy the generated key into your `.env` file as `SCOPUS_API_KEY`. + +```bash +SCOPUS_API_KEY=your_scopus_api_key +``` + +### Springer Nature Open Access API + +Environment variable: `SPRINGER_OPENACCESS_API_KEY` + +Used for: + +- Springer Open Access article retrieval in XML format + +How to get it: + +1. Create or sign in to your [Springer Nature account](https://dev.springernature.com/). +2. Fill up the form to request an Open Access API key at [https://dev.springernature.com/register/](https://dev.springernature.com/register/). +3. 
Get the Open Access API key from the [Springer Nature API management page](https://datasolutions.springernature.com/account/api-management/). +4. Copy the key into your `.env` file. + +```bash +SPRINGER_OPENACCESS_API_KEY=your_springer_openaccess_api_key +``` + +### Springer Nature TDM API + +Environment variable: `SPRINGER_TDM_API_KEY` + +Used for: + +- Springer subscription article retrieval in XML format + +How to get it: + +1. Subscribe to the Springer Nature TDM service via [https://dev.springernature.com/subscription/](https://dev.springernature.com/subscription/) and select the appropriate access level based on your institution and use case. +2. Copy the issued TDM key or token into your `.env` file. + +```bash +SPRINGER_TDM_API_KEY=your_springer_tdm_api_key +``` + +### Wiley TDM API + +Environment variable: `WILEY_API_KEY` + +Used for: + +- Wiley full-text article download as PDF + +How to get it: + +1. Create your [Wiley account](https://onlinelibrary.wiley.com/action/registration). +2. Login to your Wiley account at [https://onlinelibrary.wiley.com/library-info/resources/text-and-datamining](https://onlinelibrary.wiley.com/library-info/resources/text-and-datamining) under the "**Get a Text and Data Mining Token**" section. +3. Accept the terms and conditions to generate your API token. +4. Copy the API token into your `.env` file. + +```bash +WILEY_API_KEY=your_wiley_api_key +``` + +### IOP Publishing + +Environment variable: `IOP_papers_path` (*not an API key but a required path variable for processing IOP Science XML files*) + +Used for: + +- Local processing of IOP Science XML files downloaded in bulk + +How to get it: + +1. Email [contentsupport@ioppublishing.org](mailto:contentsupport@ioppublishing.org) to request bulk access to the IOP Science XML files, typically through SFTP as IOP Publishing does not provide direct API access for bulk downloads. +2. Once you have access, download the XML files to a local directory. +3. 
Set `IOP_papers_path` to the absolute local folder path containing all the downloaded files. + +```bash +IOP_papers_path=/absolute/path/to/iop_papers +``` + +## LLM Providers + +These providers can be used for extraction models, RAG chat models, and vision-language models where supported by your configuration. + +### OpenAI + +Environment variable: `OPENAI_API_KEY` + +Typical model prefixes: `openai/...` or OpenAI model names directly + +How to get it: + +1. Create or sign in to your [OpenAI account](https://platform.openai.com/). +2. Open the [API keys section](https://platform.openai.com/api-keys). +3. Create a new secret key. +4. Store it in `.env`. + +```bash +OPENAI_API_KEY=your_openai_api_key +``` + +### Google Gemini + +Environment variable: `GEMINI_API_KEY` + +Typical model prefixes: `gemini/...` + +How to get it: + +1. Create or sign in to your [Google AI Studio account](https://aistudio.google.com/). +2. Generate an API key from the [Gemini API key page](https://aistudio.google.com/app/apikey). +3. Store it in `.env` as `GEMINI_API_KEY`. + +```bash +GEMINI_API_KEY=your_gemini_api_key +``` + +### Anthropic + +Environment variable: `ANTHROPIC_API_KEY` + +Typical model prefixes: `anthropic/...` + +How to get it: + +1. Create or sign in to your [Anthropic Console account](https://console.anthropic.com/). +2. Create a new API key from the [Anthropic keys page](https://console.anthropic.com/settings/keys). +3. Store it in `.env`. + +```bash +ANTHROPIC_API_KEY=your_anthropic_api_key +``` + +### DeepSeek + +Environment variable: `DEEPSEEK_API_KEY` + +Typical model prefixes: `deepseek/...` + +How to get it: + +1. Create or sign in to your [DeepSeek platform account](https://platform.deepseek.com/). +2. Generate an API key from the [DeepSeek API keys page](https://platform.deepseek.com/api_keys). +3. Store it in `.env`. 
+ +```bash +DEEPSEEK_API_KEY=your_deepseek_api_key +``` + +### OpenRouter + +Environment variable: `OPENROUTER_API_KEY` + +Typical model prefixes: `openrouter/...` + +How to get it: + +1. Create or sign in to your [OpenRouter account](https://openrouter.ai/). +2. Generate an API key from the [OpenRouter keys page](https://openrouter.ai/keys). +3. Store it in `.env`. + +```bash +OPENROUTER_API_KEY=your_openrouter_api_key +``` + +### Together AI + +Environment variable: `TOGETHER_API_KEY` + +Typical model prefixes: `together/...` + +How to get it: + +1. Create or sign in to your [Together AI account](https://www.together.ai/). +2. Generate an API key from the [Together AI API keys page](https://api.together.ai/settings/api-keys). +3. Store it in `.env`. + +```bash +TOGETHER_API_KEY=your_together_api_key +``` + +### Cohere + +Environment variable: `COHERE_API_KEY` + +Typical model prefixes: `cohere/...` + +How to get it: + +1. Create or sign in to your [Cohere account](https://dashboard.cohere.com/). +2. Create an API key from the [Cohere API keys page](https://dashboard.cohere.com/api-keys). +3. Store it in `.env`. + +```bash +COHERE_API_KEY=your_cohere_api_key +``` + +### Fireworks AI + +Environment variable: `FIREWORKS_API_KEY` + +Typical model prefixes: `fireworks/...` + +How to get it: + +1. Create or sign in to your [Fireworks AI account](https://fireworks.ai/). +2. Generate an API key from the [Fireworks AI API keys page](https://app.fireworks.ai/settings/users/api-keys). +3. Store it in `.env`. + +```bash +FIREWORKS_API_KEY=your_fireworks_api_key +``` + +### Ollama + +Environment variable: none required + +Used for: + +- Local model inference + +How to set it up: + +1. Install Ollama from the [main Ollama website](https://ollama.com/). +2. Pull the model you want to use by following the [Ollama library and setup docs](https://ollama.com/library). +3. Set `base_url` or `rag_base_url` if needed, such as `http://localhost:11434`. 
+ +## Default Embedding Provider + +### Hugging Face + +Environment variable: `HF_TOKEN` + +> **Optional.** Only required for downloading gated or private Hugging Face models. Public models work without a token. + +Used for: + +- Accessing gated or private Hugging Face models +- Rate-limited API access + +How to get it: + +1. Create or sign in to your [Hugging Face account](https://huggingface.co/). +2. Open the [access tokens page](https://huggingface.co/settings/tokens). +3. Create a new token with the required permissions. +4. Store it in `.env`. + +```bash +HF_TOKEN=your_huggingface_token +``` + +## Recommended `.env` Template + +Use the subset you need: + +```bash +# Publisher providers +SCOPUS_API_KEY=your_scopus_api_key +SPRINGER_OPENACCESS_API_KEY=your_springer_openaccess_api_key +SPRINGER_TDM_API_KEY=your_springer_tdm_api_key +WILEY_API_KEY=your_wiley_api_key +IOP_papers_path=/absolute/path/to/iop_papers + +# LLM providers +OPENAI_API_KEY=your_openai_api_key +GEMINI_API_KEY=your_gemini_api_key +ANTHROPIC_API_KEY=your_anthropic_api_key +DEEPSEEK_API_KEY=your_deepseek_api_key +OPENROUTER_API_KEY=your_openrouter_api_key +TOGETHER_API_KEY=your_together_api_key +COHERE_API_KEY=your_cohere_api_key +FIREWORKS_API_KEY=your_fireworks_api_key + +# Model and embedding access +HF_TOKEN=your_huggingface_token +``` + +## Notes + +- Keep all keys in your local `.env` file and never commit them to version control. +- For most users, the minimum setup is one publisher source plus one LLM provider. +- If you use Gemini models, use `GEMINI_API_KEY`. +- If you use gated or private Hugging Face models, make sure `HF_TOKEN` is available; public models work without it. 
+ +## Related Pages + +- [Installation](installation.md) +- [Article Processing](../usage/article-processing.md) +- [Data Extraction](../usage/data-extraction.md) +- [RAG Configuration](../rag-config.md) diff --git a/docs/rag-config.md b/docs/rag-config.md index 9b843481..52376775 100644 --- a/docs/rag-config.md +++ b/docs/rag-config.md @@ -123,7 +123,7 @@ scanner.extract_composition_property_data( scanner.extract_composition_property_data( main_extraction_keyword="d33", rag_db_path="embeddings/piezo", - rag_chat_model="deepseek-chat", + rag_chat_model="deepseek/deepseek-chat", rag_max_tokens=1024, rag_top_k=4, ) @@ -178,7 +178,7 @@ scanner.extract_composition_property_data( scanner.extract_composition_property_data( main_extraction_keyword="d33", rag_db_path="embeddings/piezo", - rag_chat_model="together_ai/meta-llama/Llama-3-70b-chat-hf", + rag_chat_model="together/meta-llama/Llama-3-70b-chat-hf", rag_max_tokens=1024, rag_top_k=4, ) @@ -220,7 +220,7 @@ scanner.extract_composition_property_data( scanner.extract_composition_property_data( main_extraction_keyword="d33", rag_db_path="embeddings/piezo", - rag_chat_model="fireworks_ai/accounts/fireworks/models/llama-v3-8b-instruct", + rag_chat_model="fireworks/models/llama-v3-8b-instruct", rag_max_tokens=1024, rag_top_k=4, ) diff --git a/docs/usage/evaluation/agentic.md b/docs/usage/evaluation/agentic.md index 4bd119b4..ab6808d6 100644 --- a/docs/usage/evaluation/agentic.md +++ b/docs/usage/evaluation/agentic.md @@ -48,6 +48,22 @@ Whether to evaluate synthesis-related information. An instance of the LiteLLM class. Read more about LiteLLM instance from CrewAI [here](https://docs.crewai.com/en/concepts/llms). +#### :material-square-medium:`value_error_thresholds` _(dict)_ + +Optional mapping of ground-truth value ranges to absolute error tolerances for numeric property value comparisons. When provided, a custom tool is added to the evaluation agent that it calls before deciding whether a numeric value matches. 
The agent is instructed to accept the extracted value if the absolute difference does not exceed the threshold for the range that the ground-truth value falls in. + +Keys must be **tuples** `(min, max)`; `float('inf')` and `float('-inf')` are supported for open-ended bounds. + +```python +value_error_thresholds = { + (-200, 200): 5, # |ref| ≤ 200 → tolerance ±5 + (201, 500): 8, # ref in (200, 500] → tolerance ±8 + (-500, -201): 8, # ref in [-500, -200) → tolerance ±8 + (501, float('inf')): 10, # ref > 500 → tolerance ±10 + (float('-inf'), -501): 10, # ref < -500 → tolerance ±10 +} +``` + !!! info "Default Values" :material-square-small:**`weights`** = { @@ -58,7 +74,7 @@ An instance of the LiteLLM class. Read more about LiteLLM instance from CrewAI [ "precursors": 0.15, "characterization_techniques": 0.15, "steps": 0.1 - }
:material-square-small:**`output_file`** = "agentic_evaluation_result.json"
:material-square-small:**`extraction_agent_model_name`** = "gpt-4o-mini"
:material-square-small:**`is_synthesis_evaluation`** = True
:material-square-small:**`llm`** = LLM(model="o3-mini") + }
:material-square-small:**`output_file`** = "agentic_evaluation_result.json"
:material-square-small:**`extraction_agent_model_name`** = "gpt-4o-mini"
:material-square-small:**`is_synthesis_evaluation`** = True
:material-square-small:**`llm`** = LLM(model="o3-mini")
:material-square-small:**`value_error_thresholds`** = `None` (exact comparison) ## How It Works @@ -92,6 +108,33 @@ results = evaluate_agentic( ) ``` +## Value Error Tolerances + +By default, numeric property values must match exactly. You can allow a tolerance so that values within a specified absolute range of the ground-truth are still accepted as correct. + +When `value_error_thresholds` is set, a `get_value_error_threshold` tool is automatically added to the evaluation agent. The agent calls this tool with the ground-truth value before making each numeric value decision, then applies the returned threshold. + +```python +from comproscanner import evaluate_agentic + +results = evaluate_agentic( + ground_truth_file="ground_truth.json", + test_data_file="test_data.json", + value_error_thresholds={ + (-200, 200): 5, # small values: allow ±5 + (201, 500): 8, # medium values: allow ±8 + (-500, -201): 8, + (501, float('inf')): 10, # large values: allow ±10 + (float('-inf'), -501): 10, + } +) +``` + +!!! note + - Keys must be **tuples** `(min, max)` — Python lists cannot be used as dict keys. + - Key ordering does not matter; `min`/`max` is resolved internally. + - When no thresholds are configured the parameter has zero overhead — no extra tool is added to the agent and the task prompt is unchanged. + ## Output Format ```json diff --git a/docs/usage/evaluation/semantic.md b/docs/usage/evaluation/semantic.md index 64b3bc3b..812ed90f 100644 --- a/docs/usage/evaluation/semantic.md +++ b/docs/usage/evaluation/semantic.md @@ -60,6 +60,22 @@ Name of the fallback model which will be used if the primary model fails for sem Dictionary specifying similarity thresholds for each metric when using semantic evaluation. +#### :material-square-medium:`value_error_thresholds` _(dict)_ + +Optional mapping of ground-truth value ranges to absolute error tolerances for numeric property value comparisons. 
When provided, a property value is accepted as a match if the absolute difference between the extracted value and the ground-truth value does not exceed the threshold for the range that the ground-truth value falls in. If no range matches, exact comparison is used. + +Keys must be **tuples** `(min, max)` representing the range; `float('inf')` and `float('-inf')` are supported for open-ended bounds. + +```python +value_error_thresholds = { + (-200, 200): 5, # |ref| ≤ 200 → tolerance ±5 + (201, 500): 8, # ref in (200, 500] → tolerance ±8 + (-500, -201): 8, # ref in [-500, -200) → tolerance ±8 + (501, float('inf')): 10, # ref > 500 → tolerance ±10 + (float('-inf'), -501): 10, # ref < -500 → tolerance ±10 +} +``` + !!! info "Default Values" :material-square-small:**`weights`** = { @@ -79,7 +95,7 @@ Dictionary specifying similarity thresholds for each metric when using semantic "precursors": 0.8, "characterization_techniques": 0.8, "steps": 0.8 - } + }
:material-square-small:**`value_error_thresholds`** = `None` (exact comparison) ## How It Works @@ -167,6 +183,32 @@ results = evaluate_semantic( ) ``` +## Value Error Tolerances + +By default, numeric property values must match exactly (within floating-point precision). You can relax this for properties where a small deviation is acceptable — for example, when comparing values read from a figure against a table. + +Tolerances are defined as a mapping of ground-truth value **ranges** to **absolute error** thresholds. The range that contains the ground-truth value determines the tolerance applied for that comparison. + +```python +results = evaluate_semantic( + ground_truth_file="ground_truth.json", + test_data_file="test_data.json", + value_error_thresholds={ + (-200, 200): 5, # small values: allow ±5 + (201, 500): 8, # medium values: allow ±8 + (-500, -201): 8, + (501, float('inf')): 10, # large values: allow ±10 + (float('-inf'), -501): 10, + } +) +``` + +!!! note + - Keys must be **tuples** `(min, max)` — Python lists cannot be used as dict keys. + - Key ordering does not matter; `min`/`max` is resolved internally. + - This tolerance applies only to `compositions_property_values` numeric comparisons. Non-numeric values always use exact matching. + - If no range contains the ground-truth value, exact comparison is used for that value. 
+ ## Output Format ```json diff --git a/examples/test_example.py b/examples/test_example.py index c7058481..c57133d2 100644 --- a/examples/test_example.py +++ b/examples/test_example.py @@ -89,7 +89,7 @@ comproscanner.evaluate_semantic( ground_truth_file="piezo_test/ground_truth.json", - test_data_file="piezo_test/model-outputs/deepseek/deepseek-v3-piezo-ceramic-test-results.json", + test_data_file="piezo_test/model-outputs/deepseek/deepseek-v3-0324-piezo-ceramic-test-results.json", output_file="piezo_test/eval-results/semantic-evaluation/deepseek-v3-0324-semantic-evaluation-results.json", extraction_agent_model_name="DeepSeek-V3-0324", ) @@ -98,7 +98,7 @@ comproscanner.evaluate_agentic( ground_truth_file="piezo_test/ground_truth.json", - test_data_file="piezo_test/model-outputs/deepseek/deepseek-v3-piezo-ceramic-test-results.json", + test_data_file="piezo_test/model-outputs/deepseek/deepseek-v3-0324-piezo-ceramic-test-results.json", output_file="piezo_test/eval-results/agentic-evaluation/deepseek-v3-0324-agentic-evaluation-results.json", extraction_agent_model_name="DeepSeek-V3-0324", llm=llm, diff --git a/pyproject.toml b/pyproject.toml index 3c6085be..3c07ded7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "comproscanner" -version = "0.1.5" +version = "0.1.6" description = "Multi-agent system for extracting and processing structured composition-property data from scientific literature" readme = "README.md" authors = [{ name = "Aritra Roy", email = "contact@aritraroy.live" }] diff --git a/src/comproscanner/__init__.py b/src/comproscanner/__init__.py index 4e8e43c4..8d388080 100644 --- a/src/comproscanner/__init__.py +++ b/src/comproscanner/__init__.py @@ -184,6 +184,7 @@ def evaluate_semantic( primary_model_name="thellert/physbert_cased", fallback_model_name="all-mpnet-base-v2", similarity_thresholds=None, + value_error_thresholds=None, ): """ Evaluate the extracted data using semantic 
evaluation. @@ -199,6 +200,8 @@ def evaluate_semantic( primary_model_name (str, optional): Name of the primary model for semantic evaluation. Defaults to "thellert/physbert_cased". fallback_model_name (str, optional): Name of the fallback model for semantic evaluation. Defaults to "all-mpnet-base-v2". similarity_thresholds (dict, optional): Similarity thresholds for evaluation. Defaults to 0.8 for each metric. + value_error_thresholds (dict, optional): Mapping of ``(min, max)`` tuples to + absolute error tolerances for numeric property-value comparisons. """ scanner = ComProScanner(main_property_keyword="placeholder") return scanner.evaluate_semantic( @@ -206,22 +209,24 @@ def evaluate_semantic( test_data_file=test_data_file, weights=weights, output_file=output_file, - agent_model_name=extraction_agent_model_name, + extraction_agent_model_name=extraction_agent_model_name, is_synthesis_evaluation=is_synthesis_evaluation, use_semantic_model=use_semantic_model, primary_model_name=primary_model_name, fallback_model_name=fallback_model_name, similarity_thresholds=similarity_thresholds, + value_error_thresholds=value_error_thresholds, ) def evaluate_agentic( ground_truth_file=None, test_data_file=None, - output_file="detailed_evaluation.json", + output_file="agentic_evaluation_result.json", extraction_agent_model_name="gpt-4o-mini", is_synthesis_evaluation=True, weights=None, llm=None, + value_error_thresholds=None, ): """ Evaluate the extracted data using agentic evaluation. @@ -229,21 +234,24 @@ def evaluate_agentic( Args: ground_truth_file (str, optional): Path to the ground truth file. Defaults to None. test_data_file (str, optional): Path to the test data file. Defaults to None. - output_file (str, optional): Path to the output file for saving the evaluation results. Defaults to "detailed_evaluation.json". + output_file (str, optional): Path to the output file for saving the evaluation results. Defaults to "agentic_evaluation_result.json". 
extraction_agent_model_name (str, optional): Name of the agent model used for extraction. Defaults to "GPT-4o-mini". is_synthesis_evaluation (bool, optional): A flag to indicate if synthesis evaluation is required. Defaults to True. weights (dict, optional): Weights for the evaluation metrics. Defaults to None. llm (LLM, optional): An instance of the LLM class. Defaults to None. + value_error_thresholds (dict, optional): Mapping of ``(min, max)`` tuples to + absolute error tolerances for numeric property-value comparisons. """ scanner = ComProScanner(main_property_keyword="placeholder") return scanner.evaluate_agentic( ground_truth_file=ground_truth_file, test_data_file=test_data_file, output_file=output_file, - agent_model_name=extraction_agent_model_name, + extraction_agent_model_name=extraction_agent_model_name, is_synthesis_evaluation=is_synthesis_evaluation, weights=weights, llm=llm, + value_error_thresholds=value_error_thresholds, ) # Add convenience functions to __all__ diff --git a/src/comproscanner/comproscanner.py b/src/comproscanner/comproscanner.py index 877b59a6..56b18555 100644 --- a/src/comproscanner/comproscanner.py +++ b/src/comproscanner/comproscanner.py @@ -153,6 +153,7 @@ def process_articles( raise ValueErrorHandler( message="Please provide property_keywords dictionary to proceed." ) + source_list = [source.lower() for source in source_list] if caption_keywords is None: caption_keywords = property_keywords rag_config = RAGConfig( @@ -162,8 +163,90 @@ def process_articles( embedding_model=embedding_model, ) + routed_doi_list = { + "elsevier": doi_list, + "springer": doi_list, + "wiley": doi_list, + "iop": doi_list, + } + publisher_sources = {"elsevier", "springer", "wiley", "iop"} + selected_publisher_sources = set(source_list).intersection(publisher_sources) + + # If a manual DOI list is provided, route each DOI to the matching publisher. 
+ if doi_list is not None and len(selected_publisher_sources) > 0: + for src in publisher_sources: + routed_doi_list[src] = [] + try: + import pandas as pd + + metadata_file = ( + f"results/{self.main_property_keyword}_metadata.csv" + ) + if not os.path.exists(metadata_file): + logger.warning( + f"Metadata file '{metadata_file}' not found. " + "Using provided DOI list for all selected sources." + ) + for src in publisher_sources: + routed_doi_list[src] = doi_list + else: + metadata_df = pd.read_csv( + metadata_file, dtype=str, low_memory=False + ).fillna("") + if ( + "doi" not in metadata_df.columns + or "general_publisher" not in metadata_df.columns + ): + logger.warning( + "Metadata file is missing required columns " + "('doi', 'general_publisher'). " + "Using provided DOI list for all selected sources." + ) + for src in publisher_sources: + routed_doi_list[src] = doi_list + else: + doi_to_publisher = { + str(row["doi"]).strip(): str( + row["general_publisher"] + ).strip().lower() + for _, row in metadata_df.iterrows() + if str(row["doi"]).strip() + } + + unresolved_dois = [] + filtered_out_dois = [] + for doi in doi_list: + doi_key = str(doi).strip() + publisher = doi_to_publisher.get(doi_key) + if publisher in selected_publisher_sources: + routed_doi_list[publisher].append(doi) + elif publisher is None or publisher == "": + unresolved_dois.append(doi) + else: + filtered_out_dois.append((doi, publisher)) + + if unresolved_dois: + logger.warning( + f"{len(unresolved_dois)} DOI(s) were not found in metadata " + "and were skipped." + ) + if filtered_out_dois: + logger.info( + f"{len(filtered_out_dois)} DOI(s) were skipped because their " + "publisher is not in source_list." + ) + except Exception as e: + logger.warning( + f"Failed to route DOI list by publisher from metadata: {e}. " + "Using provided DOI list for all selected sources." 
+ ) + for src in publisher_sources: + routed_doi_list[src] = doi_list + # Process Elsevier articles - if "elsevier" in source_list: + if "elsevier" in source_list and ( + doi_list is None or len(routed_doi_list["elsevier"]) > 0 + ): from .article_processors.elsevier_processor import ElsevierArticleProcessor elsevier_processor = ElsevierArticleProcessor( @@ -173,7 +256,7 @@ def process_articles( csv_batch_size=csv_batch_size, start_row=start_row, end_row=end_row, - doi_list=doi_list, + doi_list=routed_doi_list["elsevier"], is_sql_db=is_sql_db, is_save_xml=is_save_xml, rag_config=rag_config, @@ -182,7 +265,9 @@ def process_articles( elsevier_processor.process_elsevier_articles() # Process Springer articles - if "springer" in source_list: + if "springer" in source_list and ( + doi_list is None or len(routed_doi_list["springer"]) > 0 + ): from .article_processors.springer_processor import SpringerArticleProcessor springer_processor = SpringerArticleProcessor( @@ -192,7 +277,7 @@ def process_articles( csv_batch_size=csv_batch_size, start_row=start_row, end_row=end_row, - doi_list=doi_list, + doi_list=routed_doi_list["springer"], is_sql_db=is_sql_db, is_save_xml=is_save_xml, rag_config=rag_config, @@ -201,7 +286,9 @@ def process_articles( springer_processor.process_springer_articles() # Process Wiley articles - if "wiley" in source_list: + if "wiley" in source_list and ( + doi_list is None or len(routed_doi_list["wiley"]) > 0 + ): from .article_processors.wiley_processor import WileyArticleProcessor wiley_processor = WileyArticleProcessor( @@ -211,7 +298,7 @@ def process_articles( csv_batch_size=csv_batch_size, start_row=start_row, end_row=end_row, - doi_list=doi_list, + doi_list=routed_doi_list["wiley"], is_sql_db=is_sql_db, is_save_pdf=is_save_pdf, rag_config=rag_config, @@ -220,7 +307,9 @@ def process_articles( wiley_processor.process_wiley_articles() # Process IOP articles - if "iop" in source_list: + if "iop" in source_list and ( + doi_list is None or 
len(routed_doi_list["iop"]) > 0 + ): from .article_processors.iop_processor import IOPArticleProcessor iop_processor = IOPArticleProcessor( @@ -230,7 +319,7 @@ def process_articles( csv_batch_size=csv_batch_size, start_row=start_row, end_row=end_row, - doi_list=doi_list, + doi_list=routed_doi_list["iop"], is_sql_db=is_sql_db, rag_config=rag_config, caption_keywords=caption_keywords, @@ -681,6 +770,7 @@ def evaluate_semantic( primary_model_name="thellert/physbert_cased", fallback_model_name="all-mpnet-base-v2", similarity_thresholds=None, + value_error_thresholds=None, ): """Evaluate the extracted data using semantic evaluation. @@ -695,6 +785,8 @@ def evaluate_semantic( primary_model_name (str, optional): Name of the primary model for semantic evaluation. Defaults to "thellert/physbert_cased". fallback_model_name (str, optional): Name of the fallback model for semantic evaluation. Defaults to "all-mpnet-base-v2". similarity_thresholds (dict, optional): Similarity thresholds for evaluation. Defaults to 0.8 for each metric. + value_error_thresholds (dict, optional): Mapping of ``(min, max)`` tuples to + absolute error tolerances for numeric property-value comparisons. Returns: results (dict): Evaluation results containing various metrics. 
@@ -712,6 +804,7 @@ def evaluate_semantic( primary_model_name=primary_model_name, fallback_model_name=fallback_model_name, similarity_thresholds=similarity_thresholds, + value_error_thresholds=value_error_thresholds, ) results = evaluator.evaluate( ground_truth_file=ground_truth_file, @@ -720,6 +813,7 @@ def evaluate_semantic( output_file=output_file, extraction_agent_model_name=extraction_agent_model_name, is_synthesis_evaluation=is_synthesis_evaluation, + value_error_thresholds=value_error_thresholds, ) return results @@ -732,6 +826,7 @@ def evaluate_agentic( output_file: str = "agentic_evaluation_result.json", is_synthesis_evaluation: bool = True, llm: Optional[LLM] = None, + value_error_thresholds=None, ): """Evaluate the extracted data using agentic evaluation. @@ -743,6 +838,8 @@ def evaluate_agentic( output_file (str, optional): Path to the output file for saving the evaluation results. Defaults to "agentic_evaluation_result.json". is_synthesis_evaluation (bool, optional): A flag to indicate if synthesis evaluation is required. Defaults to True. llm (LLM, optional): An instance of the LLM class. Defaults to instance of LLM with model="o3-mini" + value_error_thresholds (dict, optional): Mapping of ``(min, max)`` tuples to + absolute error tolerances for numeric property-value comparisons. Returns: results (dict): Evaluation results containing various metrics. 
@@ -764,6 +861,7 @@ def evaluate_agentic( is_synthesis_evaluation=is_synthesis_evaluation, weights=weights, llm=llm, + value_error_thresholds=value_error_thresholds, ) results = evaluator.kickoff() return results diff --git a/src/comproscanner/extract_flow/tools/rag_tool.py b/src/comproscanner/extract_flow/tools/rag_tool.py index deff690d..b8ee8220 100644 --- a/src/comproscanner/extract_flow/tools/rag_tool.py +++ b/src/comproscanner/extract_flow/tools/rag_tool.py @@ -82,28 +82,28 @@ def _get_llm(self) -> BaseChatModel: "callbacks": callbacks, } # OpenAI models - if model.startswith(("gpt-", "text-", "o1", "o3")): + if model.startswith(("openai/", "gpt-", "text-", "o1", "o3")): self._check_package_exists("langchain_openai", model) from langchain_openai import ChatOpenAI return ChatOpenAI(model=model, request_timeout=1000, **common_params) # Deepseek models - if model.startswith("deepseek"): + if model.startswith("deepseek/"): self._check_package_exists("langchain_deepseek", model) from langchain_deepseek import ChatDeepSeek return ChatDeepSeek(model=model, request_timeout=1000, **common_params) # Google Gemini models - elif model.startswith("gemini-"): + elif model.startswith("gemini/"): self._check_package_exists("langchain_google_genai", model) from langchain_google_genai import ChatGoogleGenerativeAI return ChatGoogleGenerativeAI(model=model, **common_params) # Anthropic Claude models - elif model.startswith("claude-"): + elif model.startswith("claude/"): self._check_package_exists("langchain_anthropic", model) from langchain_anthropic import ChatAnthropic @@ -143,7 +143,7 @@ def _get_llm(self) -> BaseChatModel: return ChatCohere(model=model_name, **common_params) # Fireworks models - elif model.startswith(("fireworks/", "accounts/fireworks")): + elif model.startswith(("fireworks/")): self._check_package_exists("langchain_fireworks", model) from langchain_fireworks import ChatFireworks diff --git 
a/src/comproscanner/post_processing/evaluation/eval_flow/crews/composition_evaluation_crew/composition_evaluation_crew.py b/src/comproscanner/post_processing/evaluation/eval_flow/crews/composition_evaluation_crew/composition_evaluation_crew.py index 5eed3d9c..d3746700 100644 --- a/src/comproscanner/post_processing/evaluation/eval_flow/crews/composition_evaluation_crew/composition_evaluation_crew.py +++ b/src/comproscanner/post_processing/evaluation/eval_flow/crews/composition_evaluation_crew/composition_evaluation_crew.py @@ -8,16 +8,64 @@ """ # Standard library imports -from typing import Dict, Optional, Any, List, Union -import json +from typing import Dict, Optional, Any, List, Union, Tuple, Type # Third party imports from crewai import Agent, Task, Crew, Process from crewai.project import CrewBase, agent, crew, task from crewai import LLM +from crewai.tools import BaseTool from pydantic import BaseModel, Field +class ThresholdToolInput(BaseModel): + """Input schema for GetValueErrorThresholdTool.""" + + reference_value: str = Field( + ..., + description="The ground-truth numeric property value as a string, e.g. '150' or '-300.5'.", + ) + + +class GetValueErrorThresholdTool(BaseTool): + """ + Returns the allowed absolute error tolerance for a numeric ground-truth property value. + + When an evaluator has been configured with a ``value_error_thresholds`` mapping, this + tool lets the agent look up how much the extracted value is allowed to differ from the + ground-truth value before the comparison is counted as a mismatch. + + The returned string is one of: + - ``"threshold:<N>"`` — the extracted value matches if |extracted - reference| <= N. + - ``"exact"`` — no tolerance configured; use exact matching (within 1e-6). + """ + + name: str = "get_value_error_threshold" + description: str = ( + "Look up the allowed absolute error tolerance for a numeric ground-truth property value. 
" + "Call this tool with the ground-truth (reference) value before deciding whether a " + "numeric property value extracted from the test data is a match. " + "If the tool returns 'threshold:', the test value is accepted if " + "|test_value - reference_value| <= N. If the tool returns 'exact', require exact equality." + ) + args_schema: Type[BaseModel] = ThresholdToolInput + + # Each element is (lo, hi, threshold) with lo <= hi + thresholds_list: List[Tuple[float, float, float]] = Field(default_factory=list) + + def _run(self, reference_value: str) -> str: + try: + ref_num = float(reference_value) + except (ValueError, TypeError): + return "exact: reference value is not numeric" + + for lo, hi, threshold in self.thresholds_list: + if lo <= ref_num <= hi: + return f"threshold:{threshold}" + + return "exact" + + class CompositionMatch(BaseModel): """Basic match structure with reference and test values""" @@ -87,14 +135,48 @@ class CompositionEvaluationCrew: This crew uses binary matching (yes/no) rather than semantic similarity or exact matching. """ - def __init__(self, llm: Optional[LLM] = None): + def __init__( + self, + llm: Optional[LLM] = None, + value_error_thresholds: Optional[Dict] = None, + ): + """ + Args: + llm: LLM instance for the agent. + value_error_thresholds: Mapping of ``(min, max)`` tuples to absolute error + tolerances for numeric property-value comparisons. When provided, the + ``get_value_error_threshold`` tool is added to the evaluator agent so that + it can look up the tolerance before deciding on value matches. 
Example:: + + { + (-200, 200): 5, + (201, 500): 8, + (-500, -201): 8, + (501, float('inf')): 10, + (float('-inf'), -501): 10, + } + """ self.llm = llm or LLM(model="o3-mini") + # Convert the dict to an internal list of (lo, hi, threshold) triples + self._thresholds_list: List[Tuple[float, float, float]] = [] + if value_error_thresholds: + for range_key, threshold in value_error_thresholds.items(): + lo = min(range_key) + hi = max(range_key) + self._thresholds_list.append((lo, hi, float(threshold))) @agent def composition_evaluator_agent(self) -> Agent: """Agent that evaluates composition data with binary decisions.""" + tools = [] + if self._thresholds_list: + tools = [ + GetValueErrorThresholdTool(thresholds_list=self._thresholds_list) + ] return Agent( - config=self.agents_config["composition_evaluator_agent"], llm=self.llm + config=self.agents_config["composition_evaluator_agent"], + llm=self.llm, + tools=tools, ) @task diff --git a/src/comproscanner/post_processing/evaluation/eval_flow/crews/composition_evaluation_crew/config/tasks.yaml b/src/comproscanner/post_processing/evaluation/eval_flow/crews/composition_evaluation_crew/config/tasks.yaml index d1046dcb..1cccfc48 100644 --- a/src/comproscanner/post_processing/evaluation/eval_flow/crews/composition_evaluation_crew/config/tasks.yaml +++ b/src/comproscanner/post_processing/evaluation/eval_flow/crews/composition_evaluation_crew/config/tasks.yaml @@ -24,7 +24,8 @@ evaluate_composition_data_task: Note: The ground truth values can be given in a range (inside []) or as a single value. - If the test value is inside the range, it is considered a match. - Many times keys can't match exactly but the value will match - so while matching value be less strict about the finding the associated key. 
- + {value_error_threshold_instructions} + You must return your evaluation results in the following exact format (Use only {} JSON data, don't use markdown ```json format): { "composition_data": { diff --git a/src/comproscanner/post_processing/evaluation/eval_flow/eval_flow.py b/src/comproscanner/post_processing/evaluation/eval_flow/eval_flow.py index 5e393969..a5404f47 100644 --- a/src/comproscanner/post_processing/evaluation/eval_flow/eval_flow.py +++ b/src/comproscanner/post_processing/evaluation/eval_flow/eval_flow.py @@ -58,6 +58,9 @@ class AgentEvaluationState(BaseModel): # Weights for evaluation components weights: Dict[str, float] = {} + # Optional per-range absolute error tolerances for numeric property values + value_error_thresholds: Dict = {} + # Results storage evaluation_details: Dict = {} item_results: Dict = {} @@ -97,6 +100,7 @@ def __init__( is_synthesis_evaluation: bool = True, weights: Dict[str, float] = None, llm: Optional[LLM] = None, + value_error_thresholds: Dict = None, ): # Validate required inputs @@ -137,6 +141,9 @@ def __init__( if weights: self.state.weights.update(weights) + # Store optional per-range error tolerances for numeric property values + self.state.value_error_thresholds = value_error_thresholds or {} + def _calculate_tp_fp_fn(self, details, section): """ Calculate true positives, false positives, and false negatives for a given section. @@ -1018,8 +1025,25 @@ def evaluate_items(self, data_info): # Get normalized weights for metric calculations normalized_weights = self._normalize_weights() + # Build threshold instructions for the agent task prompt (injected as template var) + if self.state.value_error_thresholds: + _threshold_instructions = ( + "IMPORTANT — error tolerance is configured for this evaluation. " + "For each numeric property value comparison, call the " + "`get_value_error_threshold` tool with the ground-truth reference value " + "to retrieve the allowed absolute error. 
" + "If the tool returns 'threshold:', mark the value as a match (1) when " + "|test_value - reference_value| <= N. " + "If the tool returns 'exact', require exact equality (within 1e-6)." + ) + else: + _threshold_instructions = "" + # Set up the crews - composition_crew = CompositionEvaluationCrew(llm=self.state.llm).crew() + composition_crew = CompositionEvaluationCrew( + llm=self.state.llm, + value_error_thresholds=self.state.value_error_thresholds, + ).crew() if self.state.is_synthesis_evaluation: synthesis_crew = SynthesisEvaluationCrew(llm=self.state.llm).crew() @@ -1111,6 +1135,7 @@ def evaluate_items(self, data_info): "test_item": json.dumps( test_item.get("composition_data", {}) ), + "value_error_threshold_instructions": _threshold_instructions, } ) diff --git a/src/comproscanner/post_processing/evaluation/semantic_evaluator.py b/src/comproscanner/post_processing/evaluation/semantic_evaluator.py index 42f14957..8d8fea75 100644 --- a/src/comproscanner/post_processing/evaluation/semantic_evaluator.py +++ b/src/comproscanner/post_processing/evaluation/semantic_evaluator.py @@ -34,6 +34,7 @@ def __init__( primary_model_name="thellert/physbert_cased", fallback_model_name="all-mpnet-base-v2", similarity_thresholds=None, + value_error_thresholds=None, ): """ Initialize the evaluator with optional semantic models. @@ -43,12 +44,30 @@ def __init__( primary_model_name (str, optional): Name of the primary model to use fallback_model_name (str, optional): Name of the fallback sentence transformer model similarity_thresholds (dict, optional): Custom thresholds for similarity scoring + value_error_thresholds (dict, optional): Mapping of ground-truth value ranges to + absolute error tolerances. Keys must be 2-element tuples (min, max); values + are the allowed absolute difference between the ground-truth and extracted + numeric property value. Use float('inf') / float('-inf') for open-ended + ranges. 
Example:: + + { + (-200, 200): 5, + (201, 500): 8, + (-500, -201): 8, + (501, float('inf')): 10, + (float('-inf'), -501): 10, + } + + When a ground-truth value falls inside one of the ranges the extracted + value is accepted if ``|extracted - ground_truth| <= threshold``. + If no range matches, exact comparison (epsilon 1e-6) is used. """ self.use_semantic_model = use_semantic_model self.primary_model_name = primary_model_name self.fallback_model_name = fallback_model_name self.physbert_available = False self.model_available = False + self.value_error_thresholds = value_error_thresholds or {} # Load models if requested if self.use_semantic_model: @@ -184,16 +203,41 @@ def _simple_preprocess(self, text): return " ".join(filtered_words) - def _is_value_in_range(self, ref_val, test_val): + def _get_error_threshold(self, ref_val, error_thresholds): + """ + Return the absolute error threshold for *ref_val* from *error_thresholds*. + + Args: + ref_val (float): Ground-truth numeric value. + error_thresholds (dict): Mapping of ``(min, max)`` tuples to absolute + thresholds. Infinity values are supported. + + Returns: + float or None: Threshold if a range contains *ref_val*, else ``None``. + """ + for range_key, threshold in error_thresholds.items(): + lo = min(range_key) + hi = max(range_key) + if lo <= ref_val <= hi: + return threshold + return None + + def _is_value_in_range(self, ref_val, test_val, error_thresholds=None): """ - Check if test_val is within the range specified by ref_val. + Check if test_val is within the range specified by ref_val, optionally + using a per-range absolute error tolerance. Args: ref_val: Reference value, which can be a number or a list [min, max] test_val: Test value to check against the reference + error_thresholds (dict, optional): Mapping of ``(min, max)`` tuples to + absolute error tolerances (see class docstring). 
When provided and a + matching range is found the comparison becomes + ``|ref_val - test_val| <= threshold`` instead of an exact check. Returns: - bool: True if the test value matches or falls within the reference range, False otherwise + bool: True if the test value matches or falls within the reference range, + False otherwise. """ # Handle case where either value is None if ref_val is None or test_val is None: @@ -212,7 +256,15 @@ def _is_value_in_range(self, ref_val, test_val): # Case 2: ref_val is not a range, do regular comparison try: if isinstance(test_val, (int, float)) and isinstance(ref_val, (int, float)): - return abs(float(ref_val) - float(test_val)) < 1e-6 + ref_num = float(ref_val) + test_num = float(test_val) + # Use error threshold dict if provided and a range matches + if error_thresholds: + threshold = self._get_error_threshold(ref_num, error_thresholds) + if threshold is not None: + return abs(ref_num - test_num) <= threshold + # Default: floating-point epsilon comparison + return abs(ref_num - test_num) < 1e-6 else: return ref_val == test_val except (ValueError, TypeError): @@ -374,7 +426,7 @@ def _evaluate_composition_data(self, reference_comp, test_comp, weights=None): # Check if values match exactly ref_val = ref_values[key] test_val = test_values[key] - value_match = self._is_value_in_range(ref_val, test_val) + value_match = self._is_value_in_range(ref_val, test_val, self.value_error_thresholds) value_matches[key] = value_match if value_match: @@ -458,7 +510,7 @@ def _evaluate_composition_data(self, reference_comp, test_comp, weights=None): ref_val = ref_values[ref_key] test_val = test_values[test_key] - value_match = self._is_value_in_range(ref_val, test_val) + value_match = self._is_value_in_range(ref_val, test_val, self.value_error_thresholds) value_matches[ref_key] = value_match if value_match: @@ -1172,6 +1224,7 @@ def evaluate( weights=None, output_file="detailed_evaluation.json", is_synthesis_evaluation=True, + 
value_error_thresholds=None, ): """ Evaluate materials science data using normalized weights to ensure fair comparison @@ -1184,10 +1237,15 @@ def evaluate( weights (dict, optional): Custom weights for different components output_file (str, optional): Path to save the detailed evaluation results is_synthesis_evaluation (bool, optional): Whether to evaluate synthesis data + value_error_thresholds (dict, optional): Per-call override for the instance-level + ``value_error_thresholds``. Mapping of ``(min, max)`` tuples to absolute + error tolerances for numeric property value comparisons (see class docstring). Returns: dict: Evaluation results with scores and details including F1 metrics """ + if value_error_thresholds is not None: + self.value_error_thresholds = value_error_thresholds if not ground_truth_file or not test_data_file: raise ValueErrorHandler( "Both ground truth and test data files are required" diff --git a/tests/conftest.py b/tests/conftest.py index 2d422dd5..39a3136b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import pytest import sys import os +import types from pathlib import Path from unittest.mock import MagicMock, patch @@ -28,20 +29,109 @@ def create_mock_module(name, **kwargs): return mock +def create_real_module(name, **kwargs): + """Create a real module object (not MagicMock) with explicit attributes.""" + module = types.ModuleType(name) + module.__package__ = name.rsplit(".", 1)[0] if "." 
in name else "" + module.__path__ = [name] + + for key, value in kwargs.items(): + setattr(module, key, value) + + return module + + +def _identity_decorator(*args, **kwargs): + if len(args) == 1 and callable(args[0]) and not kwargs: + return args[0] + + def _decorator(func): + return func + + return _decorator + + +def _event_decorator(*args, **kwargs): + def _decorator(func): + return func + + return _decorator + + +def _or_(*signals): + return signals + + +class _DummyFlow: + def __class_getitem__(cls, _item): + return cls + + def __init__(self, *args, **kwargs): + self.state = types.SimpleNamespace( + is_materials_mentioned="", + composition_extracted_data={}, + composition_formatted_data={}, + synthesis_extracted_data={}, + synthesis_formatted_data={}, + doi="", + materials_data_identifier_query="", + main_extraction_keyword="", + composition_property_text_data="", + synthesis_text_data="", + is_extract_synthesis_data=True, + vlm_model="gemini/gemini-3-flash-preview", + related_figures_base_path="results/related_figures", + llm=None, + rag_config=None, + output_log_folder=None, + task_output_folder=None, + is_log_json=False, + verbose=True, + expected_composition_property_example="", + expected_variable_composition_property_example="", + composition_property_extraction_agent_note="", + composition_property_extraction_task_note="", + composition_property_formatting_agent_note="", + composition_property_formatting_task_note="", + synthesis_extraction_agent_note="", + synthesis_extraction_task_note="", + synthesis_formatting_agent_note="", + synthesis_formatting_task_note="", + allowed_synthesis_methods="", + allowed_characterization_techniques="", + ) + + +class _DummyChromaCollection: + def __init__(self, *args, **kwargs): + pass + + def query(self, *args, **kwargs): + return {"ids": [[]], "metadatas": [[]], "documents": [[]], "distances": [[]]} + + def upsert(self, *args, **kwargs): + return None + + +class _DummyCrewAILLM: + pass + + # Mock crewai completely 
before any imports crewai_modules = { "crewai": { - "LLM": MagicMock(), + "LLM": _DummyCrewAILLM, "Agent": MagicMock(), "Task": MagicMock(), "Crew": MagicMock(), }, "crewai.flow": {}, "crewai.flow.flow": { - "Flow": MagicMock(), - "listen": MagicMock(), - "start": MagicMock(), - "router": MagicMock(), + "Flow": _DummyFlow, + "listen": _event_decorator, + "start": _identity_decorator, + "router": _event_decorator, + "or_": _or_, }, "crewai.project": { "CrewBase": MagicMock(), @@ -53,7 +143,7 @@ def create_mock_module(name, **kwargs): "BaseTool": MagicMock(), }, "crewai.agent": {}, - "crewai.llm": {"LLM": MagicMock()}, + "crewai.llm": {"LLM": _DummyCrewAILLM}, "crewai.agents": {}, "crewai.agents.crew_agent_executor": {"CrewAgentExecutor": MagicMock()}, "crewai.agents.agent_builder": {}, @@ -74,8 +164,239 @@ def create_mock_module(name, **kwargs): sys.modules["litellm"] = create_mock_module("litellm") sys.modules["litellm.types"] = create_mock_module("litellm.types") sys.modules["litellm.types.utils"] = create_mock_module("litellm.types.utils") +sys.modules["litellm.exceptions"] = create_real_module( + "litellm.exceptions", + ContextWindowExceededError=type("ContextWindowExceededError", (Exception,), {}), +) +sys.modules["litellm.utils"] = create_real_module( + "litellm.utils", + supports_response_schema=lambda *args, **kwargs: True, + supports_function_calling=lambda *args, **kwargs: True, +) +sys.modules["litellm.litellm_core_utils"] = create_real_module( + "litellm.litellm_core_utils" +) +sys.modules["litellm.litellm_core_utils.get_supported_openai_params"] = ( + create_real_module( + "litellm.litellm_core_utils.get_supported_openai_params", + get_supported_openai_params=lambda *args, **kwargs: ["stop"], + ) +) +sys.modules["litellm.integrations"] = create_real_module("litellm.integrations") +sys.modules["litellm.integrations.custom_logger"] = create_real_module( + "litellm.integrations.custom_logger", + CustomLogger=type("CustomLogger", (), {}), +) 
sys.modules["instructor"] = create_mock_module("instructor") sys.modules["crewai_tools"] = create_mock_module("crewai_tools") +sys.modules["langchain_chroma"] = create_real_module( + "langchain_chroma", Chroma=type("Chroma", (), {}) +) + +# Use lightweight real types/functions (not MagicMock) for chromadb stubs. +# CrewAI/Pydantic dataclass parsing reads type annotations and fails on MagicMock. +class _PydanticAnyTypeMixin: + @classmethod + def __get_pydantic_core_schema__(cls, _source_type, _handler): + from pydantic_core import core_schema + + return core_schema.any_schema() + + +class _DummyPersistentClient(_PydanticAnyTypeMixin): + def __init__(self, *args, **kwargs): + self._args = args + self._kwargs = kwargs + + def get_or_create_collection(self, *args, **kwargs): + return _DummyChromaCollection() + + def reset(self): + return None + + def clear_system_cache(self): + return None + + +class _DummySettings(_PydanticAnyTypeMixin): + def __init__(self, *args, **kwargs): + self._kwargs = kwargs + + +class _DummyAsyncClientAPI(_PydanticAnyTypeMixin): + pass + + +class _DummyClientAPI(_PydanticAnyTypeMixin): + pass + + +class _DummyCollectionConfigurationInterface(_PydanticAnyTypeMixin): + pass + + +class _DummyCollectionMetadata(dict, _PydanticAnyTypeMixin): + pass + + +class _DummyLoadable(_PydanticAnyTypeMixin): + pass + + +class _DummyWhere(dict, _PydanticAnyTypeMixin): + pass + + +class _DummyWhereDocument(dict, _PydanticAnyTypeMixin): + pass + + +class _DummyDataLoader(_PydanticAnyTypeMixin): + def __class_getitem__(cls, _item): + return cls + + +class _DummyEmbeddingFunction(_PydanticAnyTypeMixin): + def __class_getitem__(cls, _item): + return cls + + +class _DummyInclude(list, _PydanticAnyTypeMixin): + pass + + +class _DummyDocuments(list, _PydanticAnyTypeMixin): + pass + + +class _DummyEmbeddings(list, _PydanticAnyTypeMixin): + pass + + +class _DummyMetadata(dict, _PydanticAnyTypeMixin): + pass + + +class _DummyCollection(_PydanticAnyTypeMixin): + pass + + 
+class _DummyOpenAIEmbeddingFunction(_PydanticAnyTypeMixin): + def __init__(self, *args, **kwargs): + pass + + +class _DummyEmbeddingCallable(_PydanticAnyTypeMixin): + def __init__(self, *args, **kwargs): + pass + + +def _dummy_validate_embedding_function(*args, **kwargs): + return None + + +sys.modules["chromadb"] = create_real_module( + "chromadb", + PersistentClient=_DummyPersistentClient, + Collection=_DummyCollection, + Documents=_DummyDocuments, + EmbeddingFunction=_DummyEmbeddingFunction, + Embeddings=_DummyEmbeddings, + Metadata=_DummyMetadata, +) +sys.modules["chromadb.config"] = create_real_module( + "chromadb.config", Settings=_DummySettings +) +sys.modules["chromadb.api"] = create_real_module( + "chromadb.api", + AsyncClientAPI=_DummyAsyncClientAPI, + ClientAPI=_DummyClientAPI, +) +sys.modules["chromadb.api.types"] = create_real_module( + "chromadb.api.types", + CollectionMetadata=_DummyCollectionMetadata, + DataLoader=_DummyDataLoader, + Documents=_DummyDocuments, + EmbeddingFunction=_DummyEmbeddingFunction, + Embeddings=_DummyEmbeddings, + Include=_DummyInclude, + Loadable=_DummyLoadable, + OneOrMany=list, + Where=_DummyWhere, + WhereDocument=_DummyWhereDocument, + validate_embedding_function=_dummy_validate_embedding_function, +) +sys.modules["chromadb.errors"] = create_real_module( + "chromadb.errors", + InvalidDimensionException=type("InvalidDimensionException", (Exception,), {}), +) +sys.modules["chromadb.api.configuration"] = create_real_module( + "chromadb.api.configuration", + CollectionConfigurationInterface=_DummyCollectionConfigurationInterface, +) +sys.modules["chromadb.utils"] = create_real_module("chromadb.utils") +sys.modules["chromadb.utils.embedding_functions"] = create_real_module( + "chromadb.utils.embedding_functions" +) + + +def _register_embedding_function_module(module_name, class_names): + attrs = {class_name: _DummyEmbeddingCallable for class_name in class_names} + sys.modules[module_name] = create_real_module(module_name, 
**attrs) + + +_register_embedding_function_module( + "chromadb.utils.embedding_functions.openai_embedding_function", + ["OpenAIEmbeddingFunction"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.amazon_bedrock_embedding_function", + ["AmazonBedrockEmbeddingFunction"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.cohere_embedding_function", + ["CohereEmbeddingFunction"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.google_embedding_function", + ["GoogleGenerativeAiEmbeddingFunction", "GoogleVertexEmbeddingFunction"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.huggingface_embedding_function", + ["HuggingFaceEmbeddingFunction", "HuggingFaceEmbeddingServer"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.instructor_embedding_function", + ["InstructorEmbeddingFunction"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.jina_embedding_function", + ["JinaEmbeddingFunction"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.ollama_embedding_function", + ["OllamaEmbeddingFunction"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2", + ["ONNXMiniLM_L6_V2"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.open_clip_embedding_function", + ["OpenCLIPEmbeddingFunction"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.roboflow_embedding_function", + ["RoboflowEmbeddingFunction"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.sentence_transformer_embedding_function", + ["SentenceTransformerEmbeddingFunction"], +) +_register_embedding_function_module( + "chromadb.utils.embedding_functions.text2vec_embedding_function", + ["Text2VecEmbeddingFunction"], +) # Mock aiohttp to avoid the ConnectionTimeoutError 
aiohttp_mock = create_mock_module("aiohttp")
diff --git a/tests/test_extract_flow.py b/tests/test_extract_flow.py
index 893f173d..5f176e10 100644
--- a/tests/test_extract_flow.py
+++ b/tests/test_extract_flow.py
@@ -11,27 +11,35 @@
 import sys
 import os
 
-# Ensure we import the real modules first
-if "PYTEST_CURRENT_TEST" in os.environ:
+# Optionally import with real crewai modules only when explicitly enabled.
+# Default test runs should keep conftest mocks active.
+use_real_crewai_import = os.environ.get("USE_REAL_CREWAI_IMPORT", "").lower() in (
+    "1",
+    "true",
+    "yes",
+)
+
+saved_mocks = {}
+if "PYTEST_CURRENT_TEST" in os.environ and use_real_crewai_import:
     # Temporarily remove the mock to import the real class
     crewai_mocks = [key for key in sys.modules.keys() if key.startswith("crewai")]
-    saved_mocks = {
-        key: sys.modules.pop(key) for key in crewai_mocks if key in sys.modules
-    }
+    saved_mocks = {key: sys.modules.pop(key) for key in crewai_mocks if key in sys.modules}
 
 import pytest
 import json
 from unittest.mock import MagicMock, patch, call
 
-# Now import after ensuring real modules are loaded
-from comproscanner.extract_flow.main_extraction_flow import DataExtractionFlow
+# Import the flow inside try/finally so removed crewai mocks are restored even if the import fails (errors still propagate).
+try:
+    from comproscanner.extract_flow.main_extraction_flow import DataExtractionFlow
+finally:
+    # Always restore removed crewai mocks so other test modules don't inherit a broken state.
+ if "PYTEST_CURRENT_TEST" in os.environ and use_real_crewai_import and saved_mocks: + sys.modules.update(saved_mocks) + from comproscanner.utils.error_handler import ValueErrorHandler from comproscanner.utils.configs.rag_config import RAGConfig -# Restore mocks if needed -if "PYTEST_CURRENT_TEST" in os.environ and "saved_mocks" in locals(): - sys.modules.update(saved_mocks) - @pytest.fixture def sample_doi(): diff --git a/tests/test_post_processing/test_evaluation/test_eval_flow.py b/tests/test_post_processing/test_evaluation/test_eval_flow.py index 1055a01e..b2e1c271 100644 --- a/tests/test_post_processing/test_evaluation/test_eval_flow.py +++ b/tests/test_post_processing/test_evaluation/test_eval_flow.py @@ -51,6 +51,7 @@ def test_agent_evaluation_state_default_values(self): assert state.processed_count == 0 assert state.total_count == 0 assert state.remaining_dois == [] + assert state.value_error_thresholds == {} def test_agent_evaluation_state_custom_values(self): """Test AgentEvaluationState with custom values""" @@ -1066,6 +1067,32 @@ def test_file_path_creation(self, temp_files): assert os.path.exists(nested_output) assert os.path.exists(os.path.dirname(nested_output)) + def test_value_error_thresholds_stored_in_state(self, temp_files): + """value_error_thresholds passed at construction are stored in flow.state""" + gt_file, test_file, _ = temp_files + thresholds = {(0, 200): 5, (201, 500): 10} + + flow = MaterialsDataAgenticEvaluatorFlow( + ground_truth_file=gt_file, + test_data_file=test_file, + extraction_agent_model_name="test-model", + value_error_thresholds=thresholds, + ) + + assert flow.state.value_error_thresholds == thresholds + + def test_value_error_thresholds_default_is_empty(self, temp_files): + """When value_error_thresholds is not provided it defaults to empty dict""" + gt_file, test_file, _ = temp_files + + flow = MaterialsDataAgenticEvaluatorFlow( + ground_truth_file=gt_file, + test_data_file=test_file, + 
extraction_agent_model_name="test-model", + ) + + assert flow.state.value_error_thresholds == {} + def test_missing_dois_handling(self, temp_files): """Test handling of DOIs missing from test data""" gt_file, test_file, output_file = temp_files diff --git a/tests/test_post_processing/test_evaluation/test_semantic_evaluator.py b/tests/test_post_processing/test_evaluation/test_semantic_evaluator.py index 5e46db84..1b58ea85 100644 --- a/tests/test_post_processing/test_evaluation/test_semantic_evaluator.py +++ b/tests/test_post_processing/test_evaluation/test_semantic_evaluator.py @@ -186,6 +186,71 @@ def test_is_value_in_range_none_values(self, evaluator_no_model): assert evaluator_no_model._is_value_in_range(None, 10) is False assert evaluator_no_model._is_value_in_range(10, None) is False + def test_is_value_in_range_with_error_thresholds(self, evaluator_no_model): + """Test numeric comparison using per-range absolute error tolerances""" + thresholds = {(0, 200): 5, (201, 500): 10} + + # Within tolerance + assert evaluator_no_model._is_value_in_range(100.0, 104.0, thresholds) is True + assert evaluator_no_model._is_value_in_range(100.0, 95.5, thresholds) is True + + # Outside tolerance + assert evaluator_no_model._is_value_in_range(100.0, 110.0, thresholds) is False + + # Exact boundary of tolerance + assert evaluator_no_model._is_value_in_range(100.0, 105.0, thresholds) is True + assert evaluator_no_model._is_value_in_range(100.0, 105.1, thresholds) is False + + # Value in second range + assert evaluator_no_model._is_value_in_range(300.0, 308.0, thresholds) is True + assert evaluator_no_model._is_value_in_range(300.0, 315.0, thresholds) is False + + # No matching range — falls back to epsilon comparison + assert evaluator_no_model._is_value_in_range(600.0, 600.0, thresholds) is True + assert evaluator_no_model._is_value_in_range(600.0, 605.0, thresholds) is False + + def test_evaluate_uses_value_error_thresholds_from_constructor( + self, temp_json_files + ): + 
"""value_error_thresholds set at construction time is used during evaluate()""" + gt_file, test_file, output_file = temp_json_files + + # Evaluator with loose tolerance: allow ±5 for values in [0, 10] + evaluator = MaterialsDataSemanticEvaluator( + use_semantic_model=False, + value_error_thresholds={(0, 10): 5}, + ) + assert evaluator.value_error_thresholds == {(0, 10): 5} + + results = evaluator.evaluate( + ground_truth_file=gt_file, + test_data_file=test_file, + extraction_agent_model_name="test_model", + output_file=output_file, + ) + assert "overall_accuracy" in results + + def test_evaluate_value_error_thresholds_per_call_override(self, temp_json_files): + """value_error_thresholds passed to evaluate() overrides the instance-level one""" + gt_file, test_file, output_file = temp_json_files + + evaluator = MaterialsDataSemanticEvaluator( + use_semantic_model=False, + value_error_thresholds={(0, 10): 1}, + ) + + # Pass a different threshold via the evaluate() call + results = evaluator.evaluate( + ground_truth_file=gt_file, + test_data_file=test_file, + extraction_agent_model_name="test_model", + output_file=output_file, + value_error_thresholds={(0, 10): 5}, + ) + # After the call the instance-level threshold should have been updated + assert evaluator.value_error_thresholds == {(0, 10): 5} + assert "overall_accuracy" in results + def test_calculate_text_similarity_sequence_matcher(self, evaluator_no_model): """Test text similarity with sequence matcher fallback""" similarity = evaluator_no_model._calculate_text_similarity( diff --git a/tests/test_public_api.py b/tests/test_public_api.py new file mode 100644 index 00000000..f2ef131e --- /dev/null +++ b/tests/test_public_api.py @@ -0,0 +1,198 @@ +import os +import sys +import types +from unittest.mock import MagicMock, patch + +import pandas as pd + +# Avoid importing heavy vector DB runtime deps during test module import. 
+if "langchain_chroma" not in sys.modules: + _fake_langchain_chroma = types.ModuleType("langchain_chroma") + _fake_langchain_chroma.Chroma = MagicMock() + sys.modules["langchain_chroma"] = _fake_langchain_chroma +if "chromadb" not in sys.modules: + _fake_chromadb = types.ModuleType("chromadb") + _fake_chromadb.PersistentClient = MagicMock() + sys.modules["chromadb"] = _fake_chromadb + +from comproscanner.comproscanner import ComProScanner + + +def test_collect_metadata_public_api_calls_fetch_and_filter(): + with ( + patch("comproscanner.comproscanner.FetchMetadata") as mock_fetch_cls, + patch("comproscanner.comproscanner.FilterMetadata") as mock_filter_cls, + ): + scanner = ComProScanner(main_property_keyword="piezoelectric") + scanner.collect_metadata( + base_queries=["q1"], extra_queries=["q2"], start_year=2025, end_year=2024 + ) + + mock_fetch_cls.assert_called_once_with( + main_property_keyword="piezoelectric", + start_year=2025, + end_year=2024, + base_queries=["q1"], + extra_queries=["q2"], + ) + mock_fetch_cls.return_value.main_fetch.assert_called_once_with() + mock_filter_cls.assert_called_once_with(main_property_keyword="piezoelectric") + mock_filter_cls.return_value.update_publisher_information.assert_called_once_with() + + +def test_process_articles_routes_dois_to_matching_publishers(): + scanner = ComProScanner(main_property_keyword="piezoelectric") + doi_list = ["10.1/a", "10.2/b", "10.3/c", "10.4/d", "10.5/e"] + metadata_df = pd.DataFrame( + { + "doi": ["10.1/a", "10.2/b", "10.3/c", "10.4/d"], + "general_publisher": ["elsevier", "springer", "wiley", "iop"], + } + ) + + mock_elsevier_cls = MagicMock() + mock_springer_cls = MagicMock() + mock_wiley_cls = MagicMock() + mock_iop_cls = MagicMock() + + fake_modules = { + "comproscanner.article_processors.elsevier_processor": types.SimpleNamespace( + ElsevierArticleProcessor=mock_elsevier_cls + ), + "comproscanner.article_processors.springer_processor": types.SimpleNamespace( + 
SpringerArticleProcessor=mock_springer_cls + ), + "comproscanner.article_processors.wiley_processor": types.SimpleNamespace( + WileyArticleProcessor=mock_wiley_cls + ), + "comproscanner.article_processors.iop_processor": types.SimpleNamespace( + IOPArticleProcessor=mock_iop_cls + ), + } + + with ( + patch("comproscanner.comproscanner.os.path.exists", return_value=True), + patch("pandas.read_csv", return_value=metadata_df), + patch.dict(sys.modules, fake_modules, clear=False), + ): + scanner.process_articles( + property_keywords={"exact_keywords": ["d33"], "substring_keywords": []}, + source_list=["elsevier", "springer", "wiley", "iop"], + doi_list=doi_list, + ) + + assert mock_elsevier_cls.call_args.kwargs["doi_list"] == ["10.1/a"] + assert mock_springer_cls.call_args.kwargs["doi_list"] == ["10.2/b"] + assert mock_wiley_cls.call_args.kwargs["doi_list"] == ["10.3/c"] + assert mock_iop_cls.call_args.kwargs["doi_list"] == ["10.4/d"] + mock_elsevier_cls.return_value.process_elsevier_articles.assert_called_once_with() + mock_springer_cls.return_value.process_springer_articles.assert_called_once_with() + mock_wiley_cls.return_value.process_wiley_articles.assert_called_once_with() + mock_iop_cls.return_value.process_iop_articles.assert_called_once_with() + + +def test_extract_composition_property_data_public_api_smoke_with_no_papers(tmp_path): + scanner = ComProScanner(main_property_keyword="piezoelectric") + output_file = tmp_path / "results.json" + + mock_preparator = MagicMock() + mock_preparator.get_unprocessed_data.return_value = [] + + with ( + patch("comproscanner.comproscanner.MatPropDataPreparator", return_value=mock_preparator), + patch("comproscanner.comproscanner.LLMConfig") as mock_llm_cfg, + patch("comproscanner.comproscanner.DataCleaner") as mock_cleaner_cls, + ): + mock_llm_cfg.return_value.get_llm.return_value = MagicMock() + mock_cleaner_cls.return_value.get_useful_data.return_value = {} + + scanner.extract_composition_property_data( + 
main_extraction_keyword="d33", + json_results_file=str(output_file), + checked_doi_list_file=str(tmp_path / "checked.txt"), + ) + + assert os.path.exists(output_file) + mock_cleaner_cls.return_value.get_useful_data.assert_called_once_with() + + +def test_clean_data_public_api_forwards_to_cleaner(tmp_path): + scanner = ComProScanner(main_property_keyword="piezoelectric") + input_file = tmp_path / "input.json" + input_file.write_text("{}", encoding="utf-8") + + cleaner = MagicMock() + cleaner.clean_data_with_relevant_compositions.return_value = {"10.x/test": {}} + + with patch("comproscanner.comproscanner.DataCleaner", return_value=cleaner): + result = scanner.clean_data( + json_results_file=str(input_file), + is_save_separate_results=False, + is_save_composition_property_file=False, + cleaning_strategy="full", + ) + + assert result == {"10.x/test": {}} + cleaner.clean_data_with_relevant_compositions.assert_called_once_with( + strategy="full" + ) + + +def test_evaluate_semantic_public_api_supports_value_error_thresholds(): + scanner = ComProScanner(main_property_keyword="piezoelectric") + thresholds = { + (-200, 200): 5, + (201, 500): 8, + (-500, -201): 8, + (501, float("inf")): 10, + (float("-inf"), -501): 10, + } + + with patch( + "comproscanner.comproscanner.MaterialsDataSemanticEvaluator" + ) as mock_evaluator_cls: + mock_evaluator = MagicMock() + mock_evaluator.evaluate.return_value = {"ok": True} + mock_evaluator_cls.return_value = mock_evaluator + + result = scanner.evaluate_semantic( + ground_truth_file="gt.json", + test_data_file="test.json", + extraction_agent_model_name="model-x", + value_error_thresholds=thresholds, + ) + + assert result == {"ok": True} + assert mock_evaluator_cls.call_args.kwargs["value_error_thresholds"] == thresholds + assert ( + mock_evaluator.evaluate.call_args.kwargs["value_error_thresholds"] == thresholds + ) + + +def test_evaluate_agentic_public_api_supports_value_error_thresholds(): + scanner = 
ComProScanner(main_property_keyword="piezoelectric") + thresholds = { + (-200, 200): 5, + (201, 500): 8, + (-500, -201): 8, + (501, float("inf")): 10, + (float("-inf"), -501): 10, + } + + with patch( + "comproscanner.comproscanner.MaterialsDataAgenticEvaluatorFlow" + ) as mock_flow_cls: + mock_flow = MagicMock() + mock_flow.kickoff.return_value = {"ok": True} + mock_flow_cls.return_value = mock_flow + + result = scanner.evaluate_agentic( + ground_truth_file="gt.json", + test_data_file="test.json", + extraction_agent_model_name="model-y", + value_error_thresholds=thresholds, + ) + + assert result == {"ok": True} + assert mock_flow_cls.call_args.kwargs["value_error_thresholds"] == thresholds + mock_flow.kickoff.assert_called_once_with()