diff --git a/.gitignore b/.gitignore index 9ef164cf..95430e74 100644 --- a/.gitignore +++ b/.gitignore @@ -224,3 +224,13 @@ benchmarks/questions.json #For testing **/benchmarks + +# Claude local and transient stuff +.claude/settings.local.json +.claude/skills/.cache/ +.claude/skills/*.log +.claude/skills/state/ +.claude/skills/tmp/ +.claude/conversations/ +.claude/sessions/ +.claude/checkpoints/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..5cf45e3b --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,96 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +ChatDKU is an agentic RAG (Retrieval-Augmented Generation) system for Duke Kunshan University. It answers student questions about courses, policies, requirements, and campus resources using a three-stage DSPy pipeline: **Planner** -> **Executor** (assess-act-distill loop) -> **Synthesizer**. + +## Commands + +```bash +# Install dependencies +uv sync + +# Run the agent CLI +python -m chatdku.core.agent + +# Run Django backend +python manage.py runserver + +# Run tests +python -m pytest tests/ +python -m pytest tests/test_retriever.py # single file + +# Lint (CI runs these on changed .py files in PRs) +black --check . +flake8 --ignore=E203,W503 --max-line-length 120 + +# Format +black . + +# Sync and run on dev server +./devsync.sh # runs the agent remotely +./devsync.sh chatdku/core/tools/your_file.py # runs a specific file +``` + +## Architecture + +### Agent Pipeline (`chatdku/core/`) + +The agent runs three DSPy modules in sequence per user message (items 1–3 below); item 4, `ConversationMemory`, is a supporting module that carries chat history across turns: + +1. **Planner** (`dspy_classes/plan.py`) — Decides whether to answer directly (`send_message`), ask clarifying questions, or produce a plan for what information to gather. +2. **Executor** (`dspy_classes/executor.py`) — Runs an Assess-Act loop guided by the plan, calling tools to gather information. 
A Distill step extracts only relevant context from the trajectory. +3. **Synthesizer** (`dspy_classes/synthesizer.py`) — Generates the final cited response from distilled context. +4. **ConversationMemory** (`dspy_classes/conversation_memory.py`) — Compresses/maintains chat history across turns. + +Entry point: `chatdku/core/agent.py` + +### Tools (`chatdku/core/tools/`) + +Tools available to the executor: `VectorQuery` (ChromaDB semantic search), `KeywordQuery` (Redis BM25), `MajorRequirementsLookup`, `QueryCurriculum` (PostgreSQL course/syllabus DB), `PrerequisiteLookup`, and others (calculator, campus service, email, search, GraphRAG). + +### Document Ingestion (`chatdku/ingestion/`) + +Parses, chunks, and loads documents into vector/keyword stores: +- `update_data.py` — Parse and chunk documents +- `load_chroma.py` — Load into ChromaDB +- `load_redis.py` — Load into Redis (BM25 index) +- `load_postgres.py` — Load course data into PostgreSQL +- `clean_classdata.py` — Clean class schedule CSV data + +### Configuration (`chatdku/config.py`) + +Singleton `Config` class loaded from environment variables (`.env`). Access via `from chatdku.config import config`. Supports attribute access (`config.llm_temperature`), `.set()`, `.update()`, and a read-only `.view()`. All paths, model names, DB connections, and tuning parameters live here. + +### Backend (`chatdku/backend/`) + +Legacy Flask backend and Django backend (`chatdku/django/`). Django app uses `manage.py` at repo root. 
+ +### Infrastructure + +- **LLM**: OpenAI-compatible endpoint (vLLM/SGLang with Qwen models) +- **Embeddings**: TEI with `BAAI/bge-m3` +- **Vector DB**: ChromaDB +- **Keyword Index**: Redis +- **Course DB**: PostgreSQL +- **Observability**: Arize Phoenix +- **Framework**: DSPy for agent logic, LlamaIndex for document ingestion + +## Code Style + +- **Formatter**: Black (pre-commit hook, CI) +- **Linter**: Flake8 — `max-line-length=120`, ignores `E203,W503` +- **Python**: 3.11 (pinned in `.python-version`); `pyproject.toml` requires `>=3.11, <3.13` +- **Docstrings**: NumPy format preferred (per GUIDE.md) + +## Key Environment Variables + +`LLM_BASE_URL`, `LLM_API_KEY`, `LLM_MODEL`, `TEI_URL`, `EMBEDDING_MODEL`, `CHROMA_HOST`, `CHROMA_DB_PORT`, `REDIS_HOST`, `REDIS_PORT`, `REDIS_PASSWORD`, `DB_USER`, `DB_PASSWORD`, `DB_HOST`, `DB_PORT`, `DB_NAME`. See `.env.example` or `chatdku/config.py` for the full set. + +## Git Workflow + +- `main` branch is protected — never push directly; always use PRs with review. +- Create a GitHub issue before starting work. +- Use `devsync.sh` to iterate on the shared dev server (rsyncs code, runs `uv sync`, starts a live session). 
diff --git a/chatdku/ingestion/major_ingest.py b/chatdku/ingestion/major_ingest.py index 1c99d1d9..2818a36f 100644 --- a/chatdku/ingestion/major_ingest.py +++ b/chatdku/ingestion/major_ingest.py @@ -137,7 +137,7 @@ def sanitize_filename(name: str) -> str: # Remove or replace unsafe characters safe = re.sub(r"[^\w\s-]", "", name) safe = re.sub(r"[-\s]+", "-", safe) - return safe.strip("-").lower() + return safe.strip("-").lower()  # whitespace is already "-" here, so trim hyphens (bare strip() is a no-op) def save_major_content(major_name: str, content: Dict, output_dir: Path): diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..ced3b457 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,147 @@ +"""Shared fixtures for ChatDKU tool tests.""" + +from contextlib import contextmanager +from unittest.mock import MagicMock + +import pandas as pd +import pytest + + +@pytest.fixture() +def mock_span_ctx(monkeypatch): + """Mock span_ctx_start so no real tracer/Phoenix is needed. + + Patches at every import site since each tool module binds the name at import time. + Returns the mock span for assertions on set_attributes / set_status. 
+ """ + mock_span = MagicMock() + + @contextmanager + def fake_span_ctx_start(name, kind, parent_context=None): + yield mock_span + + targets = [ + "chatdku.core.utils.span_ctx_start", + "chatdku.core.tools.course_schedule.span_ctx_start", + "chatdku.core.tools.get_prerequisites.span_ctx_start", + "chatdku.core.tools.major_requirements.span_ctx_start", + "chatdku.core.tools.syllabi_tool.query_curriculum_db.span_ctx_start", + "chatdku.core.tools.retriever.base_retriever.span_ctx_start", + ] + for target in targets: + try: + monkeypatch.setattr(target, fake_span_ctx_start) + except (AttributeError, ImportError): + pass # module not yet imported — safe to skip + + return mock_span + + +@pytest.fixture() +def mock_get_current_span(monkeypatch): + """Mock get_current_span for llama_index_tools which uses it directly.""" + mock_span = MagicMock() + monkeypatch.setattr( + "chatdku.core.tools.llama_index_tools.get_current_span", lambda: mock_span + ) + return mock_span + + +@pytest.fixture() +def sample_classdata_csv(tmp_path): + """Create a temporary class schedule CSV with representative data.""" + csv_path = tmp_path / "classdata.csv" + df = pd.DataFrame( + { + "Subject": ["COMPSCI", "COMPSCI", "MATH", "BIOL", "CHINESE"], + "Catalog": ["101", "201", "201", "305", "101A"], + "Section": ["01", "01", "01", "01", "01"], + "Component": ["LEC", "LEC", "LEC", "LAB", "LEC"], + "Instructor": [ + "Alice Smith", + "Bob Jones", + "Carol Lee", + "Dave Kim", + "Eve Wu", + ], + "Days": ["MWF", "TTh", "MWF", "TTh", "MWF"], + "Start Time": ["09:00", "10:30", "11:00", "14:00", "13:00"], + "End Time": ["09:50", "11:45", "11:50", "15:15", "13:50"], + "Enrollment": [30, 25, 40, 15, 20], + } + ) + df.to_csv(csv_path, index=False) + return str(csv_path) + + +@pytest.fixture() +def sample_prereq_csv(tmp_path): + """Create a temporary prerequisites CSV with UTF-16LE encoding. 
+ + Column layout matches positional access in get_prereq: + col 0: ID + col 1: Effective Date (MM/DD/YYYY) + col 2: Subject + col 3: Catalog + cols 4-12: padding + col 13: Description (prerequisite text) + """ + csv_path = tmp_path / "prereq.csv" + rows = [ + # COMPSCI 201 with prereqs, two rows with different dates + [ + 1, + "01/15/2023", + "COMPSCI", + "201", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "Prereq: COMPSCI 101", + ], + [ + 2, + "09/01/2024", + "COMPSCI", + "201", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "Prereq: COMPSCI 101 or COMPSCI 102", + ], + # MATH 201 with prereqs + [ + 3, + "03/10/2024", + "MATH", + "201", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "Prereq: MATH 101", + ], + # BIOL 305 with empty description + [4, "06/01/2024", "BIOL", "305", "", "", "", "", "", "", "", "", "", ""], + ] + columns = [f"col{i}" for i in range(14)] + df = pd.DataFrame(rows, columns=columns) + df.to_csv(csv_path, index=False, encoding="utf-16le") + return str(csv_path) diff --git a/tests/test_course_schedule.py b/tests/test_course_schedule.py new file mode 100644 index 00000000..3ced34ee --- /dev/null +++ b/tests/test_course_schedule.py @@ -0,0 +1,147 @@ +"""Comprehensive tests for chatdku.core.tools.course_schedule.""" + +import json + +import pandas as pd +import pytest +from opentelemetry.trace import StatusCode + +from chatdku.core.tools.course_schedule import ( + CourseScheduleLookupOuter, + _lookup, + _parse_course, +) + + +# --------------------------------------------------------------------------- +# _parse_course (pure function — no mocks needed) +# --------------------------------------------------------------------------- + + +class TestParseCourse: + def test_with_space(self): + assert _parse_course("COMPSCI 101") == ("COMPSCI", "101") + + def test_no_separator(self): + assert _parse_course("COMPSCI101") == ("COMPSCI", "101") + + def test_with_hyphen(self): + assert _parse_course("COMPSCI-101") == ("COMPSCI", 
"101") + + def test_alpha_suffix(self): + assert _parse_course("Chinese 101A") == ("CHINESE", "101A") + + def test_strips_whitespace(self): + assert _parse_course(" COMPSCI 101 ") == ("COMPSCI", "101") + + def test_lowercase_normalised_to_upper(self): + assert _parse_course("math 201") == ("MATH", "201") + + def test_empty_string_raises(self): + with pytest.raises(ValueError): + _parse_course("") + + def test_only_symbols_raises(self): + with pytest.raises(ValueError): + _parse_course("!!!") + + def test_only_numbers_raises(self): + with pytest.raises(ValueError): + _parse_course("12345") + + +# --------------------------------------------------------------------------- +# _lookup (needs a DataFrame, no external mocks) +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def schedule_df(): + return pd.DataFrame( + { + "Subject": ["COMPSCI", "COMPSCI", "MATH", "BIOL"], + "Catalog": ["101", "201", "201", "305"], + "Section": ["01", "02", "01", "01"], + "Instructor": ["Alice", "Bob", "Carol", "Dave"], + } + ) + + +class TestLookup: + def test_finds_matching_rows(self, schedule_df): + rows = _lookup("COMPSCI 101", schedule_df) + assert len(rows) == 1 + assert rows[0]["Subject"] == "COMPSCI" + assert rows[0]["Catalog"] == "101" + + def test_returns_empty_for_nonexistent(self, schedule_df): + assert _lookup("ASTROLOGY 1239", schedule_df) == [] + + def test_case_insensitive(self, schedule_df): + rows = _lookup("compsci 201", schedule_df) + assert len(rows) == 1 + + def test_multiple_sections(self, schedule_df): + # Add a second section for COMPSCI 101 + extra = pd.DataFrame( + { + "Subject": ["COMPSCI"], + "Catalog": ["101"], + "Section": ["02"], + "Instructor": ["Eve"], + } + ) + df = pd.concat([schedule_df, extra], ignore_index=True) + rows = _lookup("COMPSCI 101", df) + assert len(rows) == 2 + + +# --------------------------------------------------------------------------- +# CourseScheduleLookupOuter (needs 
mock_span_ctx + CSV fixture) +# --------------------------------------------------------------------------- + + +class TestCourseScheduleLookupOuter: + def test_returns_callable(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + assert callable(fn) + + def test_single_course_found(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + result = json.loads(fn(["COMPSCI 101"])) + assert "COMPSCI 101" in result + assert isinstance(result["COMPSCI 101"], list) + assert result["COMPSCI 101"][0]["Instructor"] == "Alice Smith" + + def test_multiple_courses(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + result = json.loads(fn(["COMPSCI 101", "MATH 201"])) + assert "COMPSCI 101" in result + assert "MATH 201" in result + + def test_course_not_found_message(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + result = json.loads(fn(["FAKE 999"])) + assert "No schedule found" in result["FAKE 999"] + + def test_mixed_found_and_not_found(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + result = json.loads(fn(["COMPSCI 101", "FAKE 999"])) + assert isinstance(result["COMPSCI 101"], list) + assert "No schedule found" in result["FAKE 999"] + + def test_file_not_found_raises(self, mock_span_ctx): + fn = CourseScheduleLookupOuter("/nonexistent/path.csv") + with pytest.raises(FileNotFoundError): + fn(["COMPSCI 101"]) + + def test_span_status_ok_on_success(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + fn(["COMPSCI 101"]) + calls = mock_span_ctx.set_status.call_args_list + assert any(c.args[0].status_code == StatusCode.OK for c in calls if c.args) + + def test_span_attributes_set(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + 
fn(["COMPSCI 101"]) + assert mock_span_ctx.set_attributes.called diff --git a/tests/test_course_schedule_ret.py b/tests/test_course_schedule_ret.py index d37b8459..b50d97f1 100644 --- a/tests/test_course_schedule_ret.py +++ b/tests/test_course_schedule_ret.py @@ -4,9 +4,6 @@ _parse_course, ) -CSV_PATH = "/tmp/cleaned_classdata.csv" -df = pd.read_csv(CSV_PATH) - def test_parse_course(): assert _parse_course("COMPSCI 101") == ("COMPSCI", "101") @@ -17,4 +14,10 @@ def test_parse_course(): def test_lookup(): + df = pd.DataFrame( + { + "Subject": ["COMPSCI", "MATH"], + "Catalog": ["101", "201"], + } + ) assert _lookup("ASTROLOGY 1239", df) == [] diff --git a/tests/test_llama_index_tools.py b/tests/test_llama_index_tools.py new file mode 100644 index 00000000..167e03d5 --- /dev/null +++ b/tests/test_llama_index_tools.py @@ -0,0 +1,290 @@ +"""Tests for chatdku.core.tools.llama_index_tools (VectorRetrieverOuter, KeywordRetrieverOuter).""" + +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +import pytest + +from chatdku.core.tools.retriever.base_retriever import NodeWithScore +from chatdku.core.tools.utils import QueryTimeoutError + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +SAMPLE_NODES = [ + NodeWithScore(node_id="1", text="doc one", metadata={"src": "a"}, score=0.9), + NodeWithScore(node_id="2", text="doc two", metadata={"src": "b"}, score=0.8), +] + + +@contextmanager +def fake_timeout(seconds=5): + """Drop-in replacement for the real timeout context manager.""" + + class FakeCtx: + def run(self, func, *args, **kwargs): + return func(*args, **kwargs) + + yield FakeCtx() + + +@contextmanager +def fake_timeout_that_expires(seconds=5): + """Simulates a timeout by raising QueryTimeoutError on .run().""" + + class FakeCtx: + def run(self, func, *args, **kwargs): + raise QueryTimeoutError(f"Query exceeded 
{seconds} second timeout") + + yield FakeCtx() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def _patch_vector_retriever(monkeypatch): + """Patch VectorRetriever class so no ChromaDB connection is needed.""" + mock_instance = MagicMock() + mock_instance.query_with_tell.return_value = SAMPLE_NODES + mock_cls = MagicMock(return_value=mock_instance) + monkeypatch.setattr( + "chatdku.core.tools.llama_index_tools.VectorRetriever", mock_cls + ) + return mock_instance + + +@pytest.fixture() +def _patch_keyword_retriever(monkeypatch): + """Patch KeywordRetriever class so no Redis connection is needed.""" + mock_instance = MagicMock() + mock_instance.query_with_tell.return_value = SAMPLE_NODES + mock_cls = MagicMock(return_value=mock_instance) + monkeypatch.setattr( + "chatdku.core.tools.llama_index_tools.KeywordRetriever", mock_cls + ) + return mock_instance + + +@pytest.fixture() +def _patch_rerank(monkeypatch): + """Patch the rerank function.""" + mock_rerank = MagicMock(return_value=SAMPLE_NODES[:1]) + monkeypatch.setattr("chatdku.core.tools.llama_index_tools.rerank", mock_rerank) + return mock_rerank + + +@pytest.fixture() +def _patch_timeout(monkeypatch): + """Replace the real timeout with a synchronous fake.""" + monkeypatch.setattr("chatdku.core.tools.llama_index_tools.timeout", fake_timeout) + + +@pytest.fixture() +def _patch_timeout_expires(monkeypatch): + """Replace the real timeout with one that always times out.""" + monkeypatch.setattr( + "chatdku.core.tools.llama_index_tools.timeout", fake_timeout_that_expires + ) + + +# --------------------------------------------------------------------------- +# VectorRetrieverOuter +# --------------------------------------------------------------------------- + + +class TestVectorRetrieverOuter: + @pytest.fixture(autouse=True) + def _setup( + self, + mock_get_current_span, 
+ _patch_vector_retriever, + _patch_rerank, + _patch_timeout, + ): + self.mock_retriever = _patch_vector_retriever + self.mock_rerank = _patch_rerank + + def _make(self, **kwargs): + from chatdku.core.tools.llama_index_tools import VectorRetrieverOuter + + defaults = dict( + retriever_top_k=10, + use_reranker=False, + reranker_top_n=5, + user_id="Chat_DKU", + search_mode=0, + files=[], + ) + defaults.update(kwargs) + return VectorRetrieverOuter(**defaults) + + def test_returns_callable(self): + assert callable(self._make()) + + def test_query_returns_string(self): + fn = self._make() + result = fn("what is DKU?") + assert isinstance(result, str) + + def test_query_calls_retriever(self): + fn = self._make() + fn("what is DKU?") + self.mock_retriever.query_with_tell.assert_called_once() + + def test_with_reranker_calls_rerank(self): + fn = self._make(use_reranker=True) + fn("what is DKU?") + self.mock_rerank.assert_called_once() + + def test_without_reranker_skips_rerank(self): + fn = self._make(use_reranker=False) + fn("what is DKU?") + self.mock_rerank.assert_not_called() + + def test_invalid_search_mode_defaults_to_zero(self): + # Should not raise; logs a warning and defaults to 0 + fn = self._make(search_mode=5) + result = fn("test") + assert isinstance(result, str) + + def test_search_mode_nonzero_without_files_defaults(self): + # search_mode=1 but files=[] → should default to 0 + fn = self._make(search_mode=1, files=[]) + result = fn("test") + assert isinstance(result, str) + + def test_value_error_propagates(self): + self.mock_retriever.query_with_tell.side_effect = ValueError("bad input") + fn = self._make() + with pytest.raises(ValueError, match="bad input"): + fn("test") + + def test_retrieval_failure_raises_exception(self): + self.mock_retriever.query_with_tell.side_effect = RuntimeError( + "connection lost" + ) + fn = self._make() + with pytest.raises(Exception, match="Vector retrieval failed"): + fn("test") + + +class TestVectorRetrieverOuterTimeout: + 
def test_timeout_raises_exception( + self, + mock_get_current_span, + _patch_vector_retriever, + _patch_rerank, + _patch_timeout_expires, + ): + from chatdku.core.tools.llama_index_tools import VectorRetrieverOuter + + fn = VectorRetrieverOuter( + retriever_top_k=10, + use_reranker=False, + reranker_top_n=5, + user_id="Chat_DKU", + search_mode=0, + files=[], + ) + with pytest.raises(Exception, match="timed out"): + fn("test") + + +# --------------------------------------------------------------------------- +# KeywordRetrieverOuter +# --------------------------------------------------------------------------- + + +class TestKeywordRetrieverOuter: + @pytest.fixture(autouse=True) + def _setup( + self, + mock_get_current_span, + _patch_keyword_retriever, + _patch_rerank, + _patch_timeout, + ): + self.mock_retriever = _patch_keyword_retriever + self.mock_rerank = _patch_rerank + + def _make(self, **kwargs): + from chatdku.core.tools.llama_index_tools import KeywordRetrieverOuter + + defaults = dict( + retriever_top_k=10, + use_reranker=False, + reranker_top_n=5, + user_id="Chat_DKU", + search_mode=0, + files=[], + ) + defaults.update(kwargs) + return KeywordRetrieverOuter(**defaults) + + def test_returns_callable(self): + assert callable(self._make()) + + def test_query_string_returns_string(self): + fn = self._make() + result = fn("DKU courses") + assert isinstance(result, str) + + def test_query_list_converts_to_strings(self): + fn = self._make() + result = fn(["term1", 42, "term3"]) + assert isinstance(result, str) + # The function stringifies list items in-place + self.mock_retriever.query_with_tell.assert_called_once() + + def test_query_calls_retriever(self): + fn = self._make() + fn("test query") + self.mock_retriever.query_with_tell.assert_called_once() + + def test_with_reranker_calls_rerank(self): + fn = self._make(use_reranker=True) + fn("test") + self.mock_rerank.assert_called_once() + + def test_without_reranker_skips_rerank(self): + fn = 
self._make(use_reranker=False) + fn("test") + self.mock_rerank.assert_not_called() + + def test_invalid_search_mode_defaults_to_zero(self): + fn = self._make(search_mode=5) + result = fn("test") + assert isinstance(result, str) + + def test_retrieval_failure_raises_exception(self): + self.mock_retriever.query_with_tell.side_effect = RuntimeError("redis down") + fn = self._make() + with pytest.raises(Exception, match="Keyword retrieval failed"): + fn("test") + + +class TestKeywordRetrieverOuterTimeout: + def test_timeout_raises_exception( + self, + mock_get_current_span, + _patch_keyword_retriever, + _patch_rerank, + _patch_timeout_expires, + ): + from chatdku.core.tools.llama_index_tools import KeywordRetrieverOuter + + fn = KeywordRetrieverOuter( + retriever_top_k=10, + use_reranker=False, + reranker_top_n=5, + user_id="Chat_DKU", + search_mode=0, + files=[], + ) + with pytest.raises(Exception, match="Keyword retriever timeout"): + fn("test") diff --git a/tests/test_major_requirements.py b/tests/test_major_requirements.py new file mode 100644 index 00000000..19548536 --- /dev/null +++ b/tests/test_major_requirements.py @@ -0,0 +1,178 @@ +"""Tests for chatdku.core.tools.major_requirements.""" + +import pytest +from opentelemetry.trace import StatusCode + +from chatdku.core.tools.major_requirements import ( + MajorRequirementsLookupOuter, + _best_match, + _jaccard, + _list_stems, + _tokenize, +) + + +# --------------------------------------------------------------------------- +# _tokenize (pure) +# --------------------------------------------------------------------------- + + +class TestTokenize: + def test_lowercases(self): + assert _tokenize("Data Science") == {"data", "science"} + + def test_strips_separators(self): + result = _tokenize("data-science/track") + assert "data" in result + assert "science" in result + assert "track" in result + + def test_removes_punctuation(self): + result = _tokenize("hello! 
world?") + assert result == {"hello", "world"} + + def test_empty_string(self): + assert _tokenize("") == set() + + +# --------------------------------------------------------------------------- +# _jaccard (pure) +# --------------------------------------------------------------------------- + + +class TestJaccard: + def test_identical_sets(self): + assert _jaccard({"a", "b"}, {"a", "b"}) == 1.0 + + def test_disjoint_sets(self): + assert _jaccard({"a"}, {"b"}) == 0.0 + + def test_partial_overlap(self): + # intersection={b,c}, union={a,b,c,d} → 2/4 = 0.5 + assert _jaccard({"a", "b", "c"}, {"b", "c", "d"}) == 0.5 + + def test_empty_sets(self): + assert _jaccard(set(), set()) == 0.0 + + +# --------------------------------------------------------------------------- +# _best_match (pure) +# --------------------------------------------------------------------------- + + +class TestBestMatch: + STEMS = [ + "data-science", + "computation-and-design-computer-science", + "behavioral-science-psychology", + "requirements-for-all-majors", + ] + + def test_exact_match(self): + assert _best_match("data science", self.STEMS) == "data-science" + + def test_partial_match(self): + result = _best_match("computer science", self.STEMS) + assert result == "computation-and-design-computer-science" + + def test_no_match_returns_none(self): + assert _best_match("astrology", self.STEMS) is None + + def test_empty_query_returns_none(self): + assert _best_match("", self.STEMS) is None + + def test_requirements_for_all(self): + result = _best_match("requirements for all majors", self.STEMS) + assert result == "requirements-for-all-majors" + + +# --------------------------------------------------------------------------- +# _list_stems +# --------------------------------------------------------------------------- + + +class TestListStems: + def test_returns_sorted_stems(self, tmp_path): + (tmp_path / "b-major.md").write_text("B") + (tmp_path / "a-major.md").write_text("A") + stems = 
_list_stems(tmp_path) + assert stems == ["a-major", "b-major"] + + def test_ignores_non_md_files(self, tmp_path): + (tmp_path / "readme.txt").write_text("text") + (tmp_path / "data.md").write_text("data") + stems = _list_stems(tmp_path) + assert stems == ["data"] + + def test_empty_dir(self, tmp_path): + assert _list_stems(tmp_path) == [] + + +# --------------------------------------------------------------------------- +# MajorRequirementsLookupOuter (needs mock_span_ctx + tmp dir with .md files) +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def requirements_dir(tmp_path): + """Create a temporary requirements directory with sample .md files.""" + (tmp_path / "data-science.md").write_text( + "# Data Science\n\n- COMPSCI 101\n- STATS 202\n" + ) + (tmp_path / "computation-and-design-computer-science.md").write_text( + "# Computation and Design / Computer Science\n\n- COMPSCI 201\n" + ) + (tmp_path / "requirements-for-all-majors.md").write_text( + "# General Requirements\n\n- WRIT 101\n- MATH 101\n" + ) + return str(tmp_path) + + +class TestMajorRequirementsLookupOuter: + def test_returns_callable(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + assert callable(fn) + + def test_list_returns_all_majors(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + result = fn("list") + assert "data-science" in result + assert "computation-and-design-computer-science" in result + assert "requirements-for-all-majors" in result + + def test_lookup_returns_file_content(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + result = fn("data science") + assert "COMPSCI 101" in result + assert "STATS 202" in result + + def test_lookup_prepends_requirements_header(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + result = fn("data science") + assert 
result.startswith("# Requirements:") + + def test_no_match_returns_message(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + result = fn("astrology") + assert "No matching major" in result + + def test_nonexistent_directory_raises(self, mock_span_ctx): + fn = MajorRequirementsLookupOuter("/nonexistent/path") + with pytest.raises(FileNotFoundError): + fn("data science") + + def test_empty_directory_raises(self, mock_span_ctx, tmp_path): + fn = MajorRequirementsLookupOuter(str(tmp_path)) + with pytest.raises(FileNotFoundError): + fn("data science") + + def test_span_status_ok_on_success(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + fn("data science") + calls = mock_span_ctx.set_status.call_args_list + assert any(c.args[0].status_code == StatusCode.OK for c in calls if c.args) + + def test_span_attributes_set(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + fn("data science") + assert mock_span_ctx.set_attributes.called diff --git a/tests/test_prerequisites.py b/tests/test_prerequisites.py new file mode 100644 index 00000000..71afdfe4 --- /dev/null +++ b/tests/test_prerequisites.py @@ -0,0 +1,85 @@ +"""Tests for chatdku.core.tools.get_prerequisites.""" + +import pytest +from opentelemetry.trace import StatusCode + +from chatdku.core.tools.get_prerequisites import PrerequisiteLookupOuter, get_prereq + + +# --------------------------------------------------------------------------- +# get_prereq (internal helper) +# --------------------------------------------------------------------------- + + +class TestGetPrereq: + def test_returns_prerequisite_description(self, sample_prereq_csv): + result = get_prereq("COMPSCI 201", sample_prereq_csv) + assert "(Source: DKUHub)" in result + assert "COMPSCI" in result + + def test_uses_latest_effective_date(self, sample_prereq_csv): + """Two rows for COMPSCI 201 — should pick the 09/01/2024 
entry.""" + result = get_prereq("COMPSCI 201", sample_prereq_csv) + assert "COMPSCI 102" in result # only in the newer row + + def test_returns_not_found_for_unknown_course(self, sample_prereq_csv): + result = get_prereq("ASTRO 999", sample_prereq_csv) + assert "No prerequisites found" in result + + def test_empty_description_returns_not_found(self, sample_prereq_csv): + """BIOL 305 has an empty description in col 13.""" + result = get_prereq("BIOL 305", sample_prereq_csv) + assert "No prerequisites found" in result + + def test_file_not_found_raises(self): + with pytest.raises(FileNotFoundError): + get_prereq("COMPSCI 201", "/nonexistent/path.csv") + + def test_handles_extra_spaces_in_course_name(self, sample_prereq_csv): + result = get_prereq("COMPSCI 201", sample_prereq_csv) + # Should still parse — splits on underscore after space→underscore replacement + assert "COMPSCI" in result + + def test_known_course_with_prereqs(self, sample_prereq_csv): + result = get_prereq("MATH 201", sample_prereq_csv) + assert "MATH 101" in result + assert "(Source: DKUHub)" in result + + +# --------------------------------------------------------------------------- +# PrerequisiteLookupOuter (needs mock_span_ctx + sample CSV) +# --------------------------------------------------------------------------- + + +class TestPrerequisiteLookupOuter: + def test_returns_callable(self, mock_span_ctx, sample_prereq_csv): + fn = PrerequisiteLookupOuter(sample_prereq_csv) + assert callable(fn) + + def test_single_course_lookup(self, mock_span_ctx, sample_prereq_csv): + fn = PrerequisiteLookupOuter(sample_prereq_csv) + result = fn(["MATH 201"]) + assert "MATH 101" in result + + def test_multiple_courses_joined_by_newline(self, mock_span_ctx, sample_prereq_csv): + fn = PrerequisiteLookupOuter(sample_prereq_csv) + result = fn(["COMPSCI 201", "MATH 201"]) + assert "\n" in result + assert "COMPSCI" in result + assert "MATH" in result + + def test_file_not_found_propagates(self, mock_span_ctx): + fn 
= PrerequisiteLookupOuter("/nonexistent/path.csv") + with pytest.raises(FileNotFoundError): + fn(["COMPSCI 201"]) + + def test_span_status_ok_on_success(self, mock_span_ctx, sample_prereq_csv): + fn = PrerequisiteLookupOuter(sample_prereq_csv) + fn(["MATH 201"]) + calls = mock_span_ctx.set_status.call_args_list + assert any(c.args[0].status_code == StatusCode.OK for c in calls if c.args) + + def test_span_attributes_set(self, mock_span_ctx, sample_prereq_csv): + fn = PrerequisiteLookupOuter(sample_prereq_csv) + fn(["MATH 201"]) + assert mock_span_ctx.set_attributes.called